Files
keepassgo/internal/mcpserver/server_test.go
Joe Julian 13eeb3fe4a
ci / lint-test (pull_request) Successful in 1m46s
ci / build (pull_request) Successful in 2m40s
Add MCP entry password tool
2026-05-14 09:06:59 -07:00

229 lines
7.0 KiB
Go

package mcpserver
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/modelcontextprotocol/go-sdk/mcp"
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
)
// TestServerRegistersKeePassGOTools checks that the MCP server advertises
// exactly the expected tool names, compared in the order ListTools reports them.
func TestServerRegistersKeePassGOTools(t *testing.T) {
	t.Parallel()

	session, cleanup := newTestSession(t, &fakeVaultClient{})
	defer cleanup()

	ctx := context.Background()
	listed, err := session.ListTools(ctx, nil)
	if err != nil {
		t.Fatalf("ListTools() error = %v", err)
	}

	var names []string
	for i := range listed.Tools {
		names = append(names, listed.Tools[i].Name)
	}

	want := []string{
		"find_browser_logins",
		"get_browser_credential",
		"get_entry_password",
		"get_session_status",
		"search_entries",
	}
	if diff := cmp.Diff(want, names); diff != "" {
		t.Errorf("ListTools() names mismatch (-want +got):\n%s", diff)
	}
}
func TestSearchEntriesReturnsMetadataWithoutSecrets(t *testing.T) {
t.Parallel()
client := &fakeVaultClient{
entries: []*keepassgov1.Entry{{
Id: "casino-blueprint",
Title: "Casino Blueprint",
Username: "linus",
Password: "three-pin-changeup",
Url: "https://vault.example.invalid",
Notes: "meet by the fountains",
Tags: []string{"heist", "vault"},
Path: []string{"Root", "Plans"},
Fields: map[string]string{
"safe-code": "1138",
"alias": "Mr. Caldwell",
},
}},
}
session, cleanup := newTestSession(t, client)
defer cleanup()
result, err := session.CallTool(context.Background(), &mcp.CallToolParams{
Name: "search_entries",
Arguments: map[string]any{
"path": []string{" Root ", "", "Plans"},
"query": " casino ",
},
})
if err != nil {
t.Fatalf("CallTool(search_entries) error = %v", err)
}
got, ok := result.StructuredContent.(map[string]any)
if !ok {
t.Fatalf("CallTool(search_entries).StructuredContent type = %T, want map[string]any", result.StructuredContent)
}
entries, ok := got["entries"].([]any)
if !ok || len(entries) != 1 {
t.Fatalf("CallTool(search_entries).entries = %#v, want one entry", got["entries"])
}
entry, ok := entries[0].(map[string]any)
if !ok {
t.Fatalf("CallTool(search_entries).entries[0] type = %T, want map[string]any", entries[0])
}
if _, ok := entry["password"]; ok {
t.Error("CallTool(search_entries) returned password, want password omitted")
}
if _, ok := entry["notes"]; ok {
t.Error("CallTool(search_entries) returned notes, want notes omitted")
}
if _, ok := entry["fields"]; ok {
t.Error("CallTool(search_entries) returned field values, want field values omitted")
}
if diff := cmp.Diff([]string{"Root", "Plans"}, client.listEntriesPath); diff != "" {
t.Errorf("CallTool(search_entries) ListEntries path mismatch (-want +got):\n%s", diff)
}
if client.listEntriesQuery != "casino" {
t.Errorf("CallTool(search_entries) ListEntries query = %q, want %q", client.listEntriesQuery, "casino")
}
}
// TestGetBrowserCredentialReturnsCredential verifies that get_browser_credential
// returns the password from the vault's credential response and forwards the
// trimmed entry ID to the vault client.
func TestGetBrowserCredentialReturnsCredential(t *testing.T) {
	t.Parallel()
	client := &fakeVaultClient{
		credential: &keepassgov1.GetBrowserCredentialResponse{
			Id:       "benedict-account",
			Username: "tess",
			Password: "loaded-dice",
			Url:      "https://casino.example.invalid",
		},
	}
	session, cleanup := newTestSession(t, client)
	defer cleanup()
	result, err := session.CallTool(context.Background(), &mcp.CallToolParams{
		Name: "get_browser_credential",
		Arguments: map[string]any{
			"entryId": " benedict-account ",
			"pageUrl": " https://casino.example.invalid/login ",
		},
	})
	if err != nil {
		t.Fatalf("CallTool(get_browser_credential) error = %v", err)
	}
	// A tool-level failure is reported in the result, not as err; surface it
	// explicitly rather than tripping over missing structured content.
	if result.IsError {
		t.Fatalf("CallTool(get_browser_credential) returned tool error: %#v", result.Content)
	}
	// Check the assertion so an unexpected shape fails this test instead of
	// panicking and aborting the whole (parallel) test binary.
	got, ok := result.StructuredContent.(map[string]any)
	if !ok {
		t.Fatalf("CallTool(get_browser_credential).StructuredContent type = %T, want map[string]any", result.StructuredContent)
	}
	if got["password"] != "loaded-dice" {
		t.Errorf("CallTool(get_browser_credential).password = %v, want %q", got["password"], "loaded-dice")
	}
	if client.credentialID != "benedict-account" {
		t.Errorf("CallTool(get_browser_credential) entry id = %q, want %q", client.credentialID, "benedict-account")
	}
}
// TestGetEntryPasswordReturnsPasswordForUniqueMetadataMatch verifies the
// get_entry_password tool when two vault entries share a title and URL but
// differ in username and path. It expects the password of the entry matching
// every supplied criterion ("codex" in Root/Codex). The expected title is
// lowercase while the stored title is "Home Assistant", so the title
// comparison must be case-insensitive.
func TestGetEntryPasswordReturnsPasswordForUniqueMetadataMatch(t *testing.T) {
	t.Parallel()
	client := &fakeVaultClient{
		entries: []*keepassgov1.Entry{
			{
				Id:       "wrong-crew",
				Title:    "Home Assistant",
				Username: "rusty",
				Password: "wrong-token",
				Url:      "https://lights.example.invalid",
				Path:     []string{"Root", "Shared"},
			},
			{
				Id:       "codex-token",
				Title:    "Home Assistant",
				Username: "codex",
				Password: "right-token",
				Url:      "https://lights.example.invalid",
				Path:     []string{"Root", "Codex"},
			},
		},
	}
	session, cleanup := newTestSession(t, client)
	defer cleanup()
	result, err := session.CallTool(context.Background(), &mcp.CallToolParams{
		Name: "get_entry_password",
		Arguments: map[string]any{
			"path":     []string{"Root", "Codex"},
			"query":    "Home Assistant",
			"title":    "home assistant",
			"username": "codex",
			"url":      "https://lights.example.invalid",
		},
	})
	if err != nil {
		t.Fatalf("CallTool(get_entry_password) error = %v", err)
	}
	// E.g. an ambiguous or missing match would come back as a tool error
	// result; report it clearly instead of dereferencing absent content.
	if result.IsError {
		t.Fatalf("CallTool(get_entry_password) returned tool error: %#v", result.Content)
	}
	// Checked assertion: a panic here would abort every test in the binary.
	got, ok := result.StructuredContent.(map[string]any)
	if !ok {
		t.Fatalf("CallTool(get_entry_password).StructuredContent type = %T, want map[string]any", result.StructuredContent)
	}
	if got["password"] != "right-token" {
		t.Errorf("CallTool(get_entry_password).password = %v, want %q", got["password"], "right-token")
	}
}
// newTestSession builds the KeePassGO MCP server around vaultClient and
// connects an MCP client to it over in-memory transports.
//
// It returns the connected client session and a cleanup func that closes the
// client and server sessions; callers should defer the cleanup. Connection
// failures abort the test via t.Fatalf.
func newTestSession(t *testing.T, vaultClient *fakeVaultClient) (*mcp.ClientSession, func()) {
	t.Helper()
	ctx := context.Background()
	server := New(vaultClient, Config{GRPCAddress: "unix:///tmp/keepassgo-heist.sock", Version: "test"})
	client := mcp.NewClient(&mcp.Implementation{Name: "keepassgo-test"}, nil)
	serverTransport, clientTransport := mcp.NewInMemoryTransports()
	serverSession, err := server.Connect(ctx, serverTransport, nil)
	if err != nil {
		t.Fatalf("Server.Connect() error = %v", err)
	}
	clientSession, err := client.Connect(ctx, clientTransport, nil)
	if err != nil {
		// The server side is already connected; close it so a failed
		// handshake does not leak the session for the rest of the run.
		serverSession.Close()
		t.Fatalf("Client.Connect() error = %v", err)
	}
	return clientSession, func() {
		// Best-effort teardown: close errors don't affect the test outcome.
		clientSession.Close()
		serverSession.Close()
	}
}
// fakeVaultClient is an in-memory stand-in for the vault client that New is
// constructed with in these tests. The first group of fields configures what
// each method returns. The second group records the arguments of the most
// recent call so tests can assert what the server forwarded.
type fakeVaultClient struct {
	// Canned responses.
	status     *keepassgov1.GetSessionStatusResponse
	entries    []*keepassgov1.Entry
	matches    []*keepassgov1.BrowserLoginMatch
	credential *keepassgov1.GetBrowserCredentialResponse

	// Arguments captured from the latest ListEntries call.
	listEntriesPath  []string
	listEntriesQuery string

	// Arguments captured from the latest GetBrowserCredential call.
	credentialID  string
	credentialURL string
}

// Status returns the configured status, or an empty response when none is set.
func (c *fakeVaultClient) Status(context.Context) (*keepassgov1.GetSessionStatusResponse, error) {
	if c.status == nil {
		return &keepassgov1.GetSessionStatusResponse{}, nil
	}
	return c.status, nil
}

// FindBrowserLogins returns the configured matches; the page URL is ignored.
func (c *fakeVaultClient) FindBrowserLogins(_ context.Context, _ string) ([]*keepassgov1.BrowserLoginMatch, error) {
	return c.matches, nil
}

// ListEntries records the query and a copy of path (so later mutation by the
// caller cannot change what the test observes), then returns every configured
// entry without filtering.
func (c *fakeVaultClient) ListEntries(_ context.Context, path []string, query string) ([]*keepassgov1.Entry, error) {
	c.listEntriesQuery = query
	c.listEntriesPath = append([]string(nil), path...)
	return c.entries, nil
}

// GetBrowserCredential records its arguments and returns the configured
// credential, or an empty response when none is set.
func (c *fakeVaultClient) GetBrowserCredential(_ context.Context, entryID, pageURL string) (*keepassgov1.GetBrowserCredentialResponse, error) {
	c.credentialID, c.credentialURL = entryID, pageURL
	if c.credential == nil {
		return &keepassgov1.GetBrowserCredentialResponse{}, nil
	}
	return c.credential, nil
}