package mcpserver

import (
	"context"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/modelcontextprotocol/go-sdk/mcp"

	keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
)

// TestServerRegistersKeePassGOTools verifies that ListTools reports exactly
// the four expected tool names, in the order listed in want.
func TestServerRegistersKeePassGOTools(t *testing.T) {
	t.Parallel()

	session, cleanup := newTestSession(t, &fakeVaultClient{})
	defer cleanup()

	result, err := session.ListTools(context.Background(), nil)
	if err != nil {
		t.Fatalf("ListTools() error = %v", err)
	}

	var got []string
	for _, tool := range result.Tools {
		got = append(got, tool.Name)
	}

	want := []string{"find_browser_logins", "get_browser_credential", "get_session_status", "search_entries"}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("ListTools() names mismatch (-want +got):\n%s", diff)
	}
}

// TestSearchEntriesReturnsMetadataWithoutSecrets verifies that the
// search_entries tool omits the password, notes, and fields keys from the
// returned entry, and that the fake vault client receives a cleaned-up path
// and query.
func TestSearchEntriesReturnsMetadataWithoutSecrets(t *testing.T) {
	t.Parallel()

	client := &fakeVaultClient{
		entries: []*keepassgov1.Entry{{
			Id:       "casino-blueprint",
			Title:    "Casino Blueprint",
			Username: "linus",
			Password: "three-pin-changeup",
			Url:      "https://vault.example.invalid",
			Notes:    "meet by the fountains",
			Tags:     []string{"heist", "vault"},
			Path:     []string{"Root", "Plans"},
			Fields: map[string]string{
				"safe-code": "1138",
				"alias":     "Mr. Caldwell",
			},
		}},
	}
	session, cleanup := newTestSession(t, client)
	defer cleanup()

	// The arguments are deliberately padded with whitespace and include an
	// empty path segment; the assertions at the end of the test check what
	// actually reaches the vault client.
	result, err := session.CallTool(context.Background(), &mcp.CallToolParams{
		Name: "search_entries",
		Arguments: map[string]any{
			"path":  []string{" Root ", "", "Plans"},
			"query": " casino ",
		},
	})
	if err != nil {
		t.Fatalf("CallTool(search_entries) error = %v", err)
	}

	got, ok := result.StructuredContent.(map[string]any)
	if !ok {
		t.Fatalf("CallTool(search_entries).StructuredContent type = %T, want map[string]any", result.StructuredContent)
	}
	entries, ok := got["entries"].([]any)
	if !ok || len(entries) != 1 {
		t.Fatalf("CallTool(search_entries).entries = %#v, want one entry", got["entries"])
	}
	entry, ok := entries[0].(map[string]any)
	if !ok {
		t.Fatalf("CallTool(search_entries).entries[0] type = %T, want map[string]any", entries[0])
	}

	// Secret-bearing keys must not be present in the result at all.
	if _, ok := entry["password"]; ok {
		t.Error("CallTool(search_entries) returned password, want password omitted")
	}
	if _, ok := entry["notes"]; ok {
		t.Error("CallTool(search_entries) returned notes, want notes omitted")
	}
	if _, ok := entry["fields"]; ok {
		t.Error("CallTool(search_entries) returned field values, want field values omitted")
	}

	if diff := cmp.Diff([]string{"Root", "Plans"}, client.listEntriesPath); diff != "" {
		t.Errorf("CallTool(search_entries) ListEntries path mismatch (-want +got):\n%s", diff)
	}
	if client.listEntriesQuery != "casino" {
		t.Errorf("CallTool(search_entries) ListEntries query = %q, want %q", client.listEntriesQuery, "casino")
	}
}

// TestGetBrowserCredentialReturnsCredential verifies that the
// get_browser_credential tool returns the password supplied by the vault
// client and passes the entry id through without surrounding whitespace.
func TestGetBrowserCredentialReturnsCredential(t *testing.T) {
	t.Parallel()

	client := &fakeVaultClient{
		credential: &keepassgov1.GetBrowserCredentialResponse{
			Id:       "benedict-account",
			Username: "tess",
			Password: "loaded-dice",
			Url:      "https://casino.example.invalid",
		},
	}
	session, cleanup := newTestSession(t, client)
	defer cleanup()

	result, err := session.CallTool(context.Background(), &mcp.CallToolParams{
		Name: "get_browser_credential",
		Arguments: map[string]any{
			"entryId": " benedict-account ",
			"pageUrl": " https://casino.example.invalid/login ",
		},
	})
	if err != nil {
		t.Fatalf("CallTool(get_browser_credential) error = %v", err)
	}

	// Use the checked assertion so an unexpected result shape is reported as
	// a test failure rather than a panic, matching the search_entries test.
	got, ok := result.StructuredContent.(map[string]any)
	if !ok {
		t.Fatalf("CallTool(get_browser_credential).StructuredContent type = %T, want map[string]any", result.StructuredContent)
	}
	if got["password"] != "loaded-dice" {
		t.Errorf("CallTool(get_browser_credential).password = %v, want %q", got["password"], "loaded-dice")
	}
	if client.credentialID != "benedict-account" {
		t.Errorf("CallTool(get_browser_credential) entry id = %q, want %q", client.credentialID, "benedict-account")
	}
}

// newTestSession builds a server around vaultClient, connects it to a fresh
// MCP client over in-memory transports, and returns the client session along
// with a cleanup function that closes the client and server sessions.
// Connection failures abort the calling test.
func newTestSession(t *testing.T, vaultClient *fakeVaultClient) (*mcp.ClientSession, func()) {
	t.Helper()

	ctx := context.Background()
	server := New(vaultClient, Config{GRPCAddress: "unix:///tmp/keepassgo-heist.sock", Version: "test"})
	client := mcp.NewClient(&mcp.Implementation{Name: "keepassgo-test"}, nil)

	serverTransport, clientTransport := mcp.NewInMemoryTransports()
	serverSession, err := server.Connect(ctx, serverTransport, nil)
	if err != nil {
		t.Fatalf("Server.Connect() error = %v", err)
	}
	clientSession, err := client.Connect(ctx, clientTransport, nil)
	if err != nil {
		t.Fatalf("Client.Connect() error = %v", err)
	}

	return clientSession, func() {
		clientSession.Close()
		serverSession.Close()
	}
}

// fakeVaultClient is a test double passed to New as the vault client. It
// returns canned values and records the arguments of the calls it receives.
type fakeVaultClient struct {
	// Canned responses returned by the methods below.
	status     *keepassgov1.GetSessionStatusResponse
	entries    []*keepassgov1.Entry
	matches    []*keepassgov1.BrowserLoginMatch
	credential *keepassgov1.GetBrowserCredentialResponse

	// Arguments captured from the most recent ListEntries and
	// GetBrowserCredential calls.
	listEntriesPath  []string
	listEntriesQuery string
	credentialID     string
	credentialURL    string
}

// Status returns the configured status, or an empty response when none is set.
func (c *fakeVaultClient) Status(context.Context) (*keepassgov1.GetSessionStatusResponse, error) {
	if c.status != nil {
		return c.status, nil
	}
	return &keepassgov1.GetSessionStatusResponse{}, nil
}

// FindBrowserLogins returns the configured matches; pageURL is ignored.
func (c *fakeVaultClient) FindBrowserLogins(_ context.Context, pageURL string) ([]*keepassgov1.BrowserLoginMatch, error) {
	return c.matches, nil
}

// ListEntries records path and query, then returns the configured entries.
func (c *fakeVaultClient) ListEntries(_ context.Context, path []string, query string) ([]*keepassgov1.Entry, error) {
	// Store a copy so the recorded path does not alias the caller's slice.
	c.listEntriesPath = append([]string(nil), path...)
	c.listEntriesQuery = query
	return c.entries, nil
}

// GetBrowserCredential records entryID and pageURL, then returns the
// configured credential, or an empty response when none is set.
func (c *fakeVaultClient) GetBrowserCredential(_ context.Context, entryID, pageURL string) (*keepassgov1.GetBrowserCredentialResponse, error) {
	c.credentialID = entryID
	c.credentialURL = pageURL
	if c.credential != nil {
		return c.credential, nil
	}
	return &keepassgov1.GetBrowserCredentialResponse{}, nil
}