Files
keepassgo/internal/browserbridge/bridge_test.go
T
2026-04-23 21:00:29 -07:00

523 lines
17 KiB
Go

package browserbridge
import (
	"bytes"
	"context"
	"encoding/binary"
	"encoding/json"
	"io"
	"os"
	"path/filepath"
	"runtime"
	"slices"
	"strings"
	"testing"

	keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
	gcodes "google.golang.org/grpc/codes"
	gstatus "google.golang.org/grpc/status"
)
// TestReadRequestAndWriteResponse round-trips a request through the framing
// used by ReadRequest (a little-endian uint32 length prefix followed by a JSON
// body) and checks that WriteResponse emits a frame of the same shape.
func TestReadRequestAndWriteResponse(t *testing.T) {
	t.Parallel()

	var input bytes.Buffer
	body, err := json.Marshal(Request{
		Action:      "find-logins",
		BearerToken: "secret",
		URL:         "https://example.invalid/login",
	})
	if err != nil {
		t.Fatalf("Marshal() error = %v", err)
	}
	if err := binary.Write(&input, binary.LittleEndian, uint32(len(body))); err != nil {
		t.Fatalf("binary.Write() error = %v", err)
	}
	if _, err := input.Write(body); err != nil {
		t.Fatalf("Write() error = %v", err)
	}

	req, err := ReadRequest(&input)
	if err != nil {
		t.Fatalf("ReadRequest() error = %v", err)
	}
	if req.Action != "find-logins" || req.BearerToken != "secret" {
		t.Fatalf("ReadRequest() = %#v, want action and token preserved", req)
	}
	if conn, err := req.Connection("127.0.0.1:47777"); err != nil || conn.GRPCAddress != "127.0.0.1:47777" {
		t.Fatalf("req.Connection(127.0.0.1:47777) = (%#v, %v), want explicit tcp address preserved", conn, err)
	}

	var output bytes.Buffer
	if err := WriteResponse(&output, Response{Success: true, Version: "1"}); err != nil {
		t.Fatalf("WriteResponse() error = %v", err)
	}
	var size uint32
	if err := binary.Read(&output, binary.LittleEndian, &size); err != nil {
		t.Fatalf("binary.Read() error = %v", err)
	}
	// io.ReadFull reports a short frame as an error; a plain Read would return
	// fewer bytes than the declared length without complaining.
	payload := make([]byte, size)
	if _, err := io.ReadFull(&output, payload); err != nil {
		t.Fatalf("ReadFull() payload error = %v", err)
	}
	// Anything after the frame would be misread as the next length prefix by
	// the peer, so the frame must account for every byte written.
	if rest := output.Len(); rest != 0 {
		t.Fatalf("WriteResponse() left %d trailing bytes after the %d-byte frame", rest, size)
	}
	var resp Response
	if err := json.Unmarshal(payload, &resp); err != nil {
		t.Fatalf("Unmarshal() error = %v", err)
	}
	if !resp.Success || resp.Version != "1" {
		t.Fatalf("response = %#v, want success version 1", resp)
	}
}
// TestHandleRequestFindLogins checks that a find-logins request returns the
// matches reported by the client and never calls the session status RPC.
func TestHandleRequestFindLogins(t *testing.T) {
	t.Parallel()

	fake := &fakeClient{
		matches: []*keepassgov1.BrowserLoginMatch{{
			Id:       "vault-console",
			Title:    "Vault Console",
			Username: "dannyocean",
			Url:      "https://vault.example.invalid",
			Quality:  "exact-host",
		}},
	}
	req := Request{
		Action:      "find-logins",
		BearerToken: "secret",
		URL:         "https://vault.example.invalid/login",
	}

	resp := HandleRequest(context.Background(), req, "", fake)
	if !resp.Success {
		t.Fatalf("HandleRequest() success = false, error = %q", resp.Error)
	}
	if got := resp.Matches; len(got) != 1 || got[0].ID != "vault-console" {
		t.Fatalf("HandleRequest().Matches = %#v, want vault-console", got)
	}
	if calls := fake.statusCalls; calls != 0 {
		t.Fatalf("HandleRequest(find-logins) statusCalls = %d, want 0", calls)
	}
}
func TestHandleRequestStatusIncludesPendingApprovalCounts(t *testing.T) {
t.Parallel()
client := &fakeClient{
status: &keepassgov1.GetSessionStatusResponse{
Locked: false,
EntryCount: 2,
PendingApprovalCount: 3,
TokenPendingApprovalCount: 1,
},
}
resp := HandleRequest(context.Background(), Request{
Action: "status",
BearerToken: "secret",
}, "", client)
if !resp.Success {
t.Fatalf("HandleRequest(status) success = false, error = %q", resp.Error)
}
if resp.Status == nil {
t.Fatal("HandleRequest(status).Status = nil, want status")
}
if got := resp.Status.PendingApprovalCount; got != 3 {
t.Fatalf("HandleRequest(status).PendingApprovalCount = %d, want 3", got)
}
if got := resp.Status.TokenPendingApprovalCount; got != 1 {
t.Fatalf("HandleRequest(status).TokenPendingApprovalCount = %d, want 1", got)
}
}
// TestHandleRequestGetLogin checks that get-login returns the credential the
// client provides for the requested entry, without a session status RPC.
func TestHandleRequestGetLogin(t *testing.T) {
	t.Parallel()

	fake := &fakeClient{credential: &keepassgov1.GetBrowserCredentialResponse{
		Id:       "vault-console",
		Username: "dannyocean",
		Password: "token-1",
		Url:      "https://vault.example.invalid",
	}}
	req := Request{
		Action:      "get-login",
		BearerToken: "secret",
		EntryID:     "vault-console",
		URL:         "https://vault.example.invalid/login",
	}

	resp := HandleRequest(context.Background(), req, "", fake)
	if !resp.Success {
		t.Fatalf("HandleRequest() success = false, error = %q", resp.Error)
	}
	if cred := resp.Credential; cred == nil || cred.ID != "vault-console" {
		t.Fatalf("HandleRequest().Credential = %#v, want vault-console", cred)
	}
	if calls := fake.statusCalls; calls != 0 {
		t.Fatalf("HandleRequest(get-login) statusCalls = %d, want 0", calls)
	}
}
func TestHandleRequestSearchLogins(t *testing.T) {
t.Parallel()
client := &fakeClient{
entries: []*keepassgov1.Entry{
{Id: "rusty-gitlab", Title: "Rusty GitLab", Username: "rustyryan", Url: "gitlab.com", Path: []string{"Joe", "Internet"}},
},
}
resp := HandleRequest(context.Background(), Request{
Action: "search-logins",
BearerToken: "secret",
Query: "GitLab",
}, "", client)
if !resp.Success {
t.Fatalf("HandleRequest(search-logins) success = false, error = %q", resp.Error)
}
if len(resp.SearchResults) != 1 || resp.SearchResults[0].ID != "rusty-gitlab" {
t.Fatalf("HandleRequest(search-logins).SearchResults = %#v, want rusty-gitlab", resp.SearchResults)
}
}
func TestHandleRequestSaveLoginUpdatesExistingEntry(t *testing.T) {
t.Parallel()
client := &fakeClient{
entries: []*keepassgov1.Entry{
{
Id: "vault-console",
Title: "Vault Console",
Username: "dannyocean",
Password: "old-password",
Url: "https://vault.example.invalid/login",
Path: []string{"Crew", "Internet"},
Fields: map[string]string{
"URL1": "vault.example.invalid",
"X-Role": "inside-man",
},
Tags: []string{"vault"},
Notes: "Original notes stay intact.",
},
},
}
resp := HandleRequest(context.Background(), Request{
Action: "save-login",
BearerToken: "secret",
EntryID: "vault-console",
Username: "dannyocean",
Password: "new-password",
URL: "https://vault.example.invalid/login",
}, "", client)
if !resp.Success {
t.Fatalf("HandleRequest(save-login update) success = false, error = %q", resp.Error)
}
if client.upserted == nil {
t.Fatal("HandleRequest(save-login update) did not upsert an entry")
}
if got := client.upserted.Id; got != "vault-console" {
t.Fatalf("upserted.Id = %q, want vault-console", got)
}
if got := client.upserted.Password; got != "new-password" {
t.Fatalf("upserted.Password = %q, want new-password", got)
}
if got := client.upserted.Fields["X-Role"]; got != "inside-man" {
t.Fatalf("upserted.Fields[X-Role] = %q, want inside-man", got)
}
if got := client.upserted.Notes; got != "Original notes stay intact." {
t.Fatalf("upserted.Notes = %q, want original notes", got)
}
}
// TestHandleRequestSaveLoginCreatesNewEntryInChosenPath checks that saving a
// login without an entry id upserts a new entry with the requested title,
// username and group path, and that the entry is given a non-empty id.
func TestHandleRequestSaveLoginCreatesNewEntryInChosenPath(t *testing.T) {
	t.Parallel()

	fake := &fakeClient{}
	wantPath := []string{"Crew", "Internet"}
	req := Request{
		Action:      "save-login",
		BearerToken: "secret",
		Title:       "Bellagio Login",
		Username:    "linuscaldwell",
		Password:    "yellow-chip",
		URL:         "https://bellagio.example.invalid/login",
		Path:        []string{"Crew", "Internet"},
	}

	resp := HandleRequest(context.Background(), req, "", fake)
	if !resp.Success {
		t.Fatalf("HandleRequest(save-login create) success = false, error = %q", resp.Error)
	}
	created := fake.upserted
	if created == nil {
		t.Fatal("HandleRequest(save-login create) did not upsert an entry")
	}
	if title := created.Title; title != "Bellagio Login" {
		t.Fatalf("upserted.Title = %q, want Bellagio Login", title)
	}
	if user := created.Username; user != "linuscaldwell" {
		t.Fatalf("upserted.Username = %q, want linuscaldwell", user)
	}
	if path := created.Path; !slices.Equal(path, wantPath) {
		t.Fatalf("upserted.Path = %v, want [Crew Internet]", path)
	}
	if created.Id == "" {
		t.Fatal("upserted.Id = empty, want generated id")
	}
}
func TestHandleRequestFindLoginsInfersLockedStatusFromRPC(t *testing.T) {
t.Parallel()
client := &fakeClient{matchesErr: gstatus.Error(gcodes.FailedPrecondition, "vault is locked")}
resp := HandleRequest(context.Background(), Request{
Action: "find-logins",
BearerToken: "secret",
URL: "https://vault.example.invalid/login",
}, "", client)
if !resp.Success {
t.Fatalf("HandleRequest(find-logins locked) success = false, error = %q", resp.Error)
}
if resp.Status == nil || !resp.Status.Locked {
t.Fatalf("HandleRequest(find-logins locked).Status = %#v, want locked status", resp.Status)
}
if client.statusCalls != 0 {
t.Fatalf("HandleRequest(find-logins locked) statusCalls = %d, want 0", client.statusCalls)
}
}
func TestHandleRequestRequiresBearerToken(t *testing.T) {
t.Parallel()
resp := HandleRequest(context.Background(), Request{Action: "status"}, "", &fakeClient{})
if resp.Success {
t.Fatal("HandleRequest().Success = true, want false without token")
}
}
// TestRequestConnectionDefaultsAddress checks that an empty address override
// yields a non-empty default gRPC address. Outside Windows, that default must
// be a unix:// socket or the literal "off".
func TestRequestConnectionDefaultsAddress(t *testing.T) {
	t.Parallel()

	req := Request{Action: "status", BearerToken: "secret"}
	conn, err := req.Connection("")
	if err != nil {
		t.Fatalf(`Connection("") error = %v`, err)
	}

	addr := conn.GRPCAddress
	if addr == "" {
		t.Fatal("Connection().GRPCAddress = empty, want default address")
	}
	acceptable := strings.HasPrefix(addr, "unix://") || addr == "off"
	if runtime.GOOS != "windows" && !acceptable {
		t.Fatalf("Connection().GRPCAddress = %q, want unix socket default on this platform", addr)
	}
}
// TestInstallManifest checks that InstallManifest writes a Firefox host
// manifest that points at the given bridge binary and allows only the default
// Firefox extension id.
func TestInstallManifest(t *testing.T) {
	t.Parallel()

	dir := t.TempDir()
	bridgePath := filepath.Join(dir, "keepassgo-browser-bridge")
	if err := os.WriteFile(bridgePath, []byte("#!/bin/sh\n"), 0o755); err != nil {
		t.Fatalf("WriteFile(binary) error = %v", err)
	}

	target := filepath.Join(dir, "firefox-host.json")
	written, err := InstallManifest(BrowserFirefox, bridgePath, "", target)
	if err != nil {
		t.Fatalf("InstallManifest() error = %v", err)
	}
	raw, err := os.ReadFile(written)
	if err != nil {
		t.Fatalf("ReadFile() error = %v", err)
	}

	var manifest NativeHostManifest
	if err := json.Unmarshal(raw, &manifest); err != nil {
		t.Fatalf("Unmarshal() error = %v", err)
	}
	if manifest.Path != bridgePath {
		t.Fatalf("manifest.Path = %q, want %q", manifest.Path, bridgePath)
	}
	allowed := manifest.AllowedExtensions
	if len(allowed) != 1 || allowed[0] != DefaultFirefoxExtensionID() {
		t.Fatalf("manifest.AllowedExtensions = %#v, want default firefox extension id", allowed)
	}
}
// TestChromiumExtensionIDFromManifestKey pins the extension id derived from a
// fixed base64 manifest public key.
func TestChromiumExtensionIDFromManifestKey(t *testing.T) {
	t.Parallel()

	const (
		publicKey = "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAMfW0u1k4K5A0uN2s0aH7uQKpM3x5Hf8mZfY1xVh0m7E2mJ7M8GiV4m0g0I2w9U9D1yqGQ6w8jzH5v8t7qB2RjMCAwEAAQ=="
		wantID    = "okcdfigpojphpoecpglkkmkjmiaefmpd"
	)

	id, err := ChromiumExtensionIDFromManifestKey(publicKey)
	if err != nil {
		t.Fatalf("ChromiumExtensionIDFromManifestKey() error = %v", err)
	}
	if id != wantID {
		t.Fatalf("ChromiumExtensionIDFromManifestKey() = %q, want %q", id, wantID)
	}
}
// TestManifestSetChromiumIncludesAllOrigins checks that every supplied
// Chromium extension id becomes a chrome-extension:// origin. The result must
// list each origin exactly once, in the order given by want.
func TestManifestSetChromiumIncludesAllOrigins(t *testing.T) {
	t.Parallel()

	// The first id is repeated on purpose to exercise de-duplication.
	ids := []string{
		"mjlnpdomnblnbblhacolncflebbgafhj",
		"ddfbfpcgdjkffmjnialjpookcoedahcn",
		"mjlnpdomnblnbblhacolncflebbgafhj",
	}
	manifest, err := ManifestSet(BrowserChromium, "/tmp/keepassgo-browser-bridge", ids)
	if err != nil {
		t.Fatalf("ManifestSet() error = %v", err)
	}

	want := []string{
		"chrome-extension://ddfbfpcgdjkffmjnialjpookcoedahcn/",
		"chrome-extension://mjlnpdomnblnbblhacolncflebbgafhj/",
	}
	if got := manifest.AllowedOrigins; !slices.Equal(got, want) {
		t.Fatalf("ManifestSet().AllowedOrigins = %#v, want %#v", got, want)
	}
}
func TestDiscoverInstalledExtensionIDsInRoot(t *testing.T) {
t.Parallel()
root := t.TempDir()
writeExtensionManifest(t, filepath.Join(root, "Default", "Extensions", "mjlnpdomnblnbblhacolncflebbgafhj", "1.0.0", "manifest.json"), browserExtensionName)
writeExtensionManifest(t, filepath.Join(root, "Profile 1", "Extensions", "ddfbfpcgdjkffmjnialjpookcoedahcn", "1.2.0", "manifest.json"), browserExtensionName)
writeExtensionManifest(t, filepath.Join(root, "Profile 2", "Extensions", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "3.4.5", "manifest.json"), "Bellagio Notes")
writeExtensionManifest(t, filepath.Join(root, "Profile 3", "Extensions", "mjlnpdomnblnbblhacolncflebbgafhj", "1.1.0", "manifest.json"), browserExtensionName)
got, err := DiscoverInstalledExtensionIDsInRoot(root)
if err != nil {
t.Fatalf("DiscoverInstalledExtensionIDsInRoot() error = %v", err)
}
want := []string{
"ddfbfpcgdjkffmjnialjpookcoedahcn",
"mjlnpdomnblnbblhacolncflebbgafhj",
}
if !slices.Equal(got, want) {
t.Fatalf("DiscoverInstalledExtensionIDsInRoot() = %#v, want %#v", got, want)
}
}
// TestEnsureNativeHostManifestsInstallsFirefoxAndDiscoveredChromium points
// HOME at a temporary directory and seeds two Chromium-family profiles with
// the extension installed. It then checks that EnsureNativeHostManifests
// writes a Firefox host manifest plus a host manifest for each of those
// Chromium browsers, allowing the discovered extension ids.
//
// The asserted locations (~/.mozilla/native-messaging-hosts and
// ~/.config/<browser>/NativeMessagingHosts) are the Linux layouts. macOS keeps
// these manifests under ~/Library/Application Support, and Windows registers
// them in the registry, so the test is skipped on those platforms.
func TestEnsureNativeHostManifestsInstallsFirefoxAndDiscoveredChromium(t *testing.T) {
	switch runtime.GOOS {
	case "darwin", "windows":
		t.Skipf("native messaging host layout asserted here is not used on %s", runtime.GOOS)
	}

	// Not parallel: t.Setenv modifies process-wide state.
	tmp := t.TempDir()
	home := filepath.Join(tmp, "home")
	t.Setenv("HOME", home)

	appDir := filepath.Join(tmp, "app")
	if err := os.MkdirAll(appDir, 0o755); err != nil {
		t.Fatalf("MkdirAll(appDir) error = %v", err)
	}
	appBinaryPath := filepath.Join(appDir, "keepassgo")
	if err := os.WriteFile(appBinaryPath, []byte("#!/bin/sh\n"), 0o755); err != nil {
		t.Fatalf("WriteFile(appBinaryPath) error = %v", err)
	}
	// The bridge binary sits next to the app binary; only the app binary path
	// is passed to EnsureNativeHostManifests.
	bridgeBinaryPath := filepath.Join(appDir, "keepassgo-browser-bridge")
	if err := os.WriteFile(bridgeBinaryPath, []byte("#!/bin/sh\n"), 0o755); err != nil {
		t.Fatalf("WriteFile(bridgeBinaryPath) error = %v", err)
	}

	writeExtensionManifest(t, filepath.Join(home, ".config", "chromium", "Default", "Extensions", "mjlnpdomnblnbblhacolncflebbgafhj", "1.0.0", "manifest.json"), browserExtensionName)
	writeExtensionManifest(t, filepath.Join(home, ".config", "google-chrome", "Profile 7", "Extensions", "ddfbfpcgdjkffmjnialjpookcoedahcn", "1.0.0", "manifest.json"), browserExtensionName)

	if err := EnsureNativeHostManifests(appBinaryPath); err != nil {
		t.Fatalf("EnsureNativeHostManifests() error = %v", err)
	}

	assertManifestContainsExtension(t, filepath.Join(home, ".mozilla", "native-messaging-hosts", NativeHostName+".json"), "allowed_extensions", DefaultFirefoxExtensionID())
	assertManifestContainsExtension(t, filepath.Join(home, ".config", "chromium", "NativeMessagingHosts", NativeHostName+".json"), "allowed_origins", "chrome-extension://mjlnpdomnblnbblhacolncflebbgafhj/")
	assertManifestContainsExtension(t, filepath.Join(home, ".config", "google-chrome", "NativeMessagingHosts", NativeHostName+".json"), "allowed_origins", "chrome-extension://ddfbfpcgdjkffmjnialjpookcoedahcn/")
}
// fakeClient is an in-memory stand-in for the gRPC client used by
// HandleRequest. It returns canned responses and errors, and it records the
// calls and writes that tests inspect.
type fakeClient struct {
	// Canned successful responses.
	status     *keepassgov1.GetSessionStatusResponse
	matches    []*keepassgov1.BrowserLoginMatch
	entries    []*keepassgov1.Entry
	credential *keepassgov1.GetBrowserCredentialResponse

	// Last entry passed to UpsertEntry.
	upserted *keepassgov1.Entry

	// err is the shared failure used as a fallback by the read methods; the
	// others fail only their corresponding method and take precedence.
	err, matchesErr, entriesErr, credentialErr, upsertErr error

	// Number of Status invocations.
	statusCalls int
}
// writeExtensionManifest creates any missing parent directories of path. It
// then writes path as a minimal JSON manifest whose only key is "name",
// followed by a newline. Any failure aborts the calling test.
func writeExtensionManifest(t *testing.T, path, name string) {
	t.Helper()

	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, 0o755); err != nil {
		t.Fatalf("MkdirAll(%q) error = %v", dir, err)
	}

	manifest := map[string]string{"name": name}
	encoded, err := json.Marshal(manifest)
	if err != nil {
		t.Fatalf("Marshal(manifest %q) error = %v", path, err)
	}
	encoded = append(encoded, '\n')

	if err := os.WriteFile(path, encoded, 0o644); err != nil {
		t.Fatalf("WriteFile(%q) error = %v", path, err)
	}
}
// assertManifestContainsExtension reads the JSON object at path and fails the
// test unless field is present, holds an array of strings, and that array
// includes want.
func assertManifestContainsExtension(t *testing.T, path, field, want string) {
	t.Helper()

	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("ReadFile(%q) error = %v", path, err)
	}
	var manifest map[string]any
	if err := json.Unmarshal(data, &manifest); err != nil {
		t.Fatalf("Unmarshal(%q) error = %v", path, err)
	}

	raw, present := manifest[field]
	if !present {
		t.Fatalf("manifest %q missing field %q", path, field)
	}
	list, isList := raw.([]any)
	if !isList {
		t.Fatalf("manifest %q field %q = %#v, want []any", path, field, raw)
	}

	// Made (not nil) so that an empty list prints as []string{} on failure.
	values := make([]string, 0, len(list))
	for i := range list {
		s, isString := list[i].(string)
		if !isString {
			t.Fatalf("manifest %q field %q value = %#v, want string", path, field, list[i])
		}
		values = append(values, s)
	}

	if slices.Contains(values, want) {
		return
	}
	t.Fatalf("manifest %q field %q = %#v, want to contain %q", path, field, values, want)
}
// Status increments statusCalls on every call. It then returns err if set,
// otherwise the canned status, or an empty response when none is configured.
func (f *fakeClient) Status(context.Context) (*keepassgov1.GetSessionStatusResponse, error) {
	f.statusCalls++
	switch {
	case f.err != nil:
		return nil, f.err
	case f.status != nil:
		return f.status, nil
	default:
		return &keepassgov1.GetSessionStatusResponse{}, nil
	}
}
// FindBrowserLogins returns the canned matches. matchesErr takes precedence
// over the shared err.
func (f *fakeClient) FindBrowserLogins(context.Context, string) ([]*keepassgov1.BrowserLoginMatch, error) {
	for _, failure := range []error{f.matchesErr, f.err} {
		if failure != nil {
			return nil, failure
		}
	}
	return f.matches, nil
}
// ListEntries returns the canned entries regardless of path or query.
// entriesErr takes precedence over the shared err.
func (f *fakeClient) ListEntries(context.Context, []string, string) ([]*keepassgov1.Entry, error) {
	for _, failure := range []error{f.entriesErr, f.err} {
		if failure != nil {
			return nil, failure
		}
	}
	return f.entries, nil
}
// GetBrowserCredential returns the canned credential, or an empty response when
// none is configured. credentialErr takes precedence over the shared err.
func (f *fakeClient) GetBrowserCredential(context.Context, string, string) (*keepassgov1.GetBrowserCredentialResponse, error) {
	switch {
	case f.credentialErr != nil:
		return nil, f.credentialErr
	case f.err != nil:
		return nil, f.err
	case f.credential != nil:
		return f.credential, nil
	}
	return &keepassgov1.GetBrowserCredentialResponse{}, nil
}
// UpsertEntry records entry as the last upserted entry and echoes it back.
// Like the other fake methods, it fails with upsertErr first and then falls
// back to the shared err. Nothing is recorded when an error is returned.
func (f *fakeClient) UpsertEntry(_ context.Context, entry *keepassgov1.Entry) (*keepassgov1.Entry, error) {
	if f.upsertErr != nil {
		return nil, f.upsertErr
	}
	if f.err != nil {
		return nil, f.err
	}
	f.upserted = entry
	return entry, nil
}