229 lines
7.0 KiB
Go
229 lines
7.0 KiB
Go
package mcpserver
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
|
|
|
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
|
|
)
|
|
|
|
func TestServerRegistersKeePassGOTools(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
session, cleanup := newTestSession(t, &fakeVaultClient{})
|
|
defer cleanup()
|
|
|
|
result, err := session.ListTools(context.Background(), nil)
|
|
if err != nil {
|
|
t.Fatalf("ListTools() error = %v", err)
|
|
}
|
|
var got []string
|
|
for _, tool := range result.Tools {
|
|
got = append(got, tool.Name)
|
|
}
|
|
want := []string{"find_browser_logins", "get_browser_credential", "get_entry_password", "get_session_status", "search_entries"}
|
|
if diff := cmp.Diff(want, got); diff != "" {
|
|
t.Errorf("ListTools() names mismatch (-want +got):\n%s", diff)
|
|
}
|
|
}
|
|
|
|
func TestSearchEntriesReturnsMetadataWithoutSecrets(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
client := &fakeVaultClient{
|
|
entries: []*keepassgov1.Entry{{
|
|
Id: "casino-blueprint",
|
|
Title: "Casino Blueprint",
|
|
Username: "linus",
|
|
Password: "three-pin-changeup",
|
|
Url: "https://vault.example.invalid",
|
|
Notes: "meet by the fountains",
|
|
Tags: []string{"heist", "vault"},
|
|
Path: []string{"Root", "Plans"},
|
|
Fields: map[string]string{
|
|
"safe-code": "1138",
|
|
"alias": "Mr. Caldwell",
|
|
},
|
|
}},
|
|
}
|
|
session, cleanup := newTestSession(t, client)
|
|
defer cleanup()
|
|
|
|
result, err := session.CallTool(context.Background(), &mcp.CallToolParams{
|
|
Name: "search_entries",
|
|
Arguments: map[string]any{
|
|
"path": []string{" Root ", "", "Plans"},
|
|
"query": " casino ",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("CallTool(search_entries) error = %v", err)
|
|
}
|
|
got, ok := result.StructuredContent.(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("CallTool(search_entries).StructuredContent type = %T, want map[string]any", result.StructuredContent)
|
|
}
|
|
entries, ok := got["entries"].([]any)
|
|
if !ok || len(entries) != 1 {
|
|
t.Fatalf("CallTool(search_entries).entries = %#v, want one entry", got["entries"])
|
|
}
|
|
entry, ok := entries[0].(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("CallTool(search_entries).entries[0] type = %T, want map[string]any", entries[0])
|
|
}
|
|
if _, ok := entry["password"]; ok {
|
|
t.Error("CallTool(search_entries) returned password, want password omitted")
|
|
}
|
|
if _, ok := entry["notes"]; ok {
|
|
t.Error("CallTool(search_entries) returned notes, want notes omitted")
|
|
}
|
|
if _, ok := entry["fields"]; ok {
|
|
t.Error("CallTool(search_entries) returned field values, want field values omitted")
|
|
}
|
|
if diff := cmp.Diff([]string{"Root", "Plans"}, client.listEntriesPath); diff != "" {
|
|
t.Errorf("CallTool(search_entries) ListEntries path mismatch (-want +got):\n%s", diff)
|
|
}
|
|
if client.listEntriesQuery != "casino" {
|
|
t.Errorf("CallTool(search_entries) ListEntries query = %q, want %q", client.listEntriesQuery, "casino")
|
|
}
|
|
}
|
|
|
|
func TestGetBrowserCredentialReturnsCredential(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
client := &fakeVaultClient{
|
|
credential: &keepassgov1.GetBrowserCredentialResponse{
|
|
Id: "benedict-account",
|
|
Username: "tess",
|
|
Password: "loaded-dice",
|
|
Url: "https://casino.example.invalid",
|
|
},
|
|
}
|
|
session, cleanup := newTestSession(t, client)
|
|
defer cleanup()
|
|
|
|
result, err := session.CallTool(context.Background(), &mcp.CallToolParams{
|
|
Name: "get_browser_credential",
|
|
Arguments: map[string]any{
|
|
"entryId": " benedict-account ",
|
|
"pageUrl": " https://casino.example.invalid/login ",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("CallTool(get_browser_credential) error = %v", err)
|
|
}
|
|
got := result.StructuredContent.(map[string]any)
|
|
if got["password"] != "loaded-dice" {
|
|
t.Errorf("CallTool(get_browser_credential).password = %v, want %q", got["password"], "loaded-dice")
|
|
}
|
|
if client.credentialID != "benedict-account" {
|
|
t.Errorf("CallTool(get_browser_credential) entry id = %q, want %q", client.credentialID, "benedict-account")
|
|
}
|
|
}
|
|
|
|
func TestGetEntryPasswordReturnsPasswordForUniqueMetadataMatch(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
client := &fakeVaultClient{
|
|
entries: []*keepassgov1.Entry{
|
|
{
|
|
Id: "wrong-crew",
|
|
Title: "Home Assistant",
|
|
Username: "rusty",
|
|
Password: "wrong-token",
|
|
Url: "https://lights.example.invalid",
|
|
Path: []string{"Root", "Shared"},
|
|
},
|
|
{
|
|
Id: "codex-token",
|
|
Title: "Home Assistant",
|
|
Username: "codex",
|
|
Password: "right-token",
|
|
Url: "https://lights.example.invalid",
|
|
Path: []string{"Root", "Codex"},
|
|
},
|
|
},
|
|
}
|
|
session, cleanup := newTestSession(t, client)
|
|
defer cleanup()
|
|
|
|
result, err := session.CallTool(context.Background(), &mcp.CallToolParams{
|
|
Name: "get_entry_password",
|
|
Arguments: map[string]any{
|
|
"path": []string{"Root", "Codex"},
|
|
"query": "Home Assistant",
|
|
"title": "home assistant",
|
|
"username": "codex",
|
|
"url": "https://lights.example.invalid",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("CallTool(get_entry_password) error = %v", err)
|
|
}
|
|
got := result.StructuredContent.(map[string]any)
|
|
if got["password"] != "right-token" {
|
|
t.Errorf("CallTool(get_entry_password).password = %v, want %q", got["password"], "right-token")
|
|
}
|
|
}
|
|
|
|
func newTestSession(t *testing.T, vaultClient *fakeVaultClient) (*mcp.ClientSession, func()) {
|
|
t.Helper()
|
|
|
|
ctx := context.Background()
|
|
server := New(vaultClient, Config{GRPCAddress: "unix:///tmp/keepassgo-heist.sock", Version: "test"})
|
|
client := mcp.NewClient(&mcp.Implementation{Name: "keepassgo-test"}, nil)
|
|
serverTransport, clientTransport := mcp.NewInMemoryTransports()
|
|
serverSession, err := server.Connect(ctx, serverTransport, nil)
|
|
if err != nil {
|
|
t.Fatalf("Server.Connect() error = %v", err)
|
|
}
|
|
clientSession, err := client.Connect(ctx, clientTransport, nil)
|
|
if err != nil {
|
|
t.Fatalf("Client.Connect() error = %v", err)
|
|
}
|
|
return clientSession, func() {
|
|
clientSession.Close()
|
|
serverSession.Close()
|
|
}
|
|
}
|
|
|
|
type fakeVaultClient struct {
|
|
status *keepassgov1.GetSessionStatusResponse
|
|
entries []*keepassgov1.Entry
|
|
matches []*keepassgov1.BrowserLoginMatch
|
|
credential *keepassgov1.GetBrowserCredentialResponse
|
|
listEntriesPath []string
|
|
listEntriesQuery string
|
|
credentialID string
|
|
credentialURL string
|
|
}
|
|
|
|
func (c *fakeVaultClient) Status(context.Context) (*keepassgov1.GetSessionStatusResponse, error) {
|
|
if c.status != nil {
|
|
return c.status, nil
|
|
}
|
|
return &keepassgov1.GetSessionStatusResponse{}, nil
|
|
}
|
|
|
|
func (c *fakeVaultClient) FindBrowserLogins(_ context.Context, pageURL string) ([]*keepassgov1.BrowserLoginMatch, error) {
|
|
return c.matches, nil
|
|
}
|
|
|
|
func (c *fakeVaultClient) ListEntries(_ context.Context, path []string, query string) ([]*keepassgov1.Entry, error) {
|
|
c.listEntriesPath = append([]string(nil), path...)
|
|
c.listEntriesQuery = query
|
|
return c.entries, nil
|
|
}
|
|
|
|
func (c *fakeVaultClient) GetBrowserCredential(_ context.Context, entryID, pageURL string) (*keepassgov1.GetBrowserCredentialResponse, error) {
|
|
c.credentialID = entryID
|
|
c.credentialURL = pageURL
|
|
if c.credential != nil {
|
|
return c.credential, nil
|
|
}
|
|
return &keepassgov1.GetBrowserCredentialResponse{}, nil
|
|
}
|