diff --git a/cmd/keepassgo-browser-bridge/main.go b/cmd/keepassgo-browser-bridge/main.go index d350a5b..ab78aed 100644 --- a/cmd/keepassgo-browser-bridge/main.go +++ b/cmd/keepassgo-browser-bridge/main.go @@ -11,6 +11,7 @@ import ( "git.julianfamily.org/keepassgo/internal/browserbridge" "git.julianfamily.org/keepassgo/internal/grpcaddr" + "google.golang.org/grpc" ) func main() { @@ -91,11 +92,7 @@ func runStatus(args []string) error { GRPCAddress: strings.TrimSpace(*grpcAddr), BearerToken: strings.TrimSpace(*token), } - connCfg, err := req.Connection() - if err != nil { - return err - } - conn, client, ctx, err := browserbridge.Dial(context.Background(), connCfg) + conn, client, ctx, err := dialBridge(context.Background(), req) if err != nil { return err } @@ -111,11 +108,7 @@ func runNativeMessage() error { if err != nil { return err } - connCfg, err := req.Connection() - if err != nil { - return browserbridge.WriteResponse(os.Stdout, browserbridge.Response{Success: false, Error: err.Error()}) - } - conn, client, ctx, err := browserbridge.Dial(context.Background(), connCfg) + conn, client, ctx, err := dialBridge(context.Background(), req) if err != nil { return browserbridge.WriteResponse(os.Stdout, browserbridge.Response{Success: false, Error: err.Error()}) } @@ -123,6 +116,14 @@ func runNativeMessage() error { return browserbridge.WriteResponse(os.Stdout, browserbridge.HandleRequest(ctx, req, client)) } +func dialBridge(ctx context.Context, req browserbridge.Request) (*grpc.ClientConn, *browserbridge.GRPCClient, context.Context, error) { + connCfg, err := req.Connection() + if err != nil { + return nil, nil, nil, err + } + return browserbridge.Dial(ctx, connCfg) +} + func defaultBinaryPath() (string, error) { return browserbridge.ResolveBridgeBinaryPath("") } diff --git a/internal/api/server.go b/internal/api/server.go index e05d87d..2d5fc7f 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -113,9 +113,6 @@ func (s *Server) GetSessionStatus(ctx context.Context, _ *keepassgov1.GetSession if err != nil { return nil, err } - s.mu.RLock() - defer s.mu.RUnlock() - pendingApprovals := s.approvals.Pending() var tokenPending uint32 for _, pending := range pendingApprovals { @@ -123,11 +120,16 @@ func (s *Server) GetSessionStatus(ctx context.Context, _ *keepassgov1.GetSession tokenPending++ } } + s.mu.RLock() + locked := s.locked + dirty := s.dirty + entryCount := uint32(len(s.model.Entries)) + s.mu.RUnlock() return &keepassgov1.GetSessionStatusResponse{ - Locked: s.locked, - Dirty: s.dirty, - EntryCount: uint32(len(s.model.Entries)), + Locked: locked, + Dirty: dirty, + EntryCount: entryCount, PendingApprovalCount: uint32(len(pendingApprovals)), TokenPendingApprovalCount: tokenPending, }, nil @@ -486,76 +488,75 @@ func (s *Server) ListGroups(ctx context.Context, req *keepassgov1.ListGroupsRequ }, nil } -func (s *Server) CreateGroup(ctx context.Context, req *keepassgov1.CreateGroupRequest) (*keepassgov1.CreateGroupResponse, error) { +func (s *Server) mutateAuthorizedVisiblePath(ctx context.Context, clientPath []string, op apitokens.Operation, mutate func(*vault.Model, []string) error) error { model, locked := s.snapshotModel() if locked { - return nil, status.Error(codes.FailedPrecondition, "vault is locked") + return status.Error(codes.FailedPrecondition, "vault is locked") } - parentPath := expandClientPath(visibleModel(model), req.GetParentPath()) - if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, parentPath); err != nil { - return nil, err + internalPath := expandClientPath(visibleModel(model), clientPath) + if _, err := s.authorizePathRequest(ctx, op, internalPath); err != nil { + return err } + return s.mutateAuthorizedModel(func() error { return nil }, func(model *vault.Model) error { + return mutate(model, internalPath) + }) +} +func (s *Server) mutateAuthorizedModel(authorize func() error, mutate func(*vault.Model) error) error { + if err := authorize(); err != nil { + return err + } s.mu.Lock() defer s.mu.Unlock() - - s.model.CreateGroup(parentPath, req.GetName()) + if err := mutate(&s.model); err != nil { + return err + } s.dirty = true s.syncMutationLocked() + return nil +} + +func (s *Server) CreateGroup(ctx context.Context, req *keepassgov1.CreateGroupRequest) (*keepassgov1.CreateGroupResponse, error) { + if err := s.mutateAuthorizedVisiblePath(ctx, req.GetParentPath(), apitokens.OperationMutateGroup, func(model *vault.Model, parentPath []string) error { + model.CreateGroup(parentPath, req.GetName()) + return nil + }); err != nil { + return nil, err + } return &keepassgov1.CreateGroupResponse{}, nil } func (s *Server) RenameGroup(ctx context.Context, req *keepassgov1.RenameGroupRequest) (*keepassgov1.RenameGroupResponse, error) { - model, locked := s.snapshotModel() - if locked { - return nil, status.Error(codes.FailedPrecondition, "vault is locked") - } - groupPath := expandClientPath(visibleModel(model), req.GetPath()) - if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, groupPath); err != nil { + if err := s.mutateAuthorizedVisiblePath(ctx, req.GetPath(), apitokens.OperationMutateGroup, func(model *vault.Model, groupPath []string) error { + if err := model.RenameGroup(groupPath, req.GetNewName()); err != nil { + if errors.Is(err, vault.ErrEntryNotFound) { + return status.Error(codes.NotFound, err.Error()) + } + return status.Errorf(codes.Internal, "rename group: %v", err) + } + return nil + }); err != nil { return nil, err } - - s.mu.Lock() - defer s.mu.Unlock() - - if err := s.model.RenameGroup(groupPath, req.GetNewName()); err != nil { - if errors.Is(err, vault.ErrEntryNotFound) { - return nil, status.Error(codes.NotFound, err.Error()) - } - return nil, status.Errorf(codes.Internal, "rename group: %v", err) - } - - s.dirty = true - s.syncMutationLocked() return &keepassgov1.RenameGroupResponse{}, nil } func (s *Server) DeleteGroup(ctx context.Context, req *keepassgov1.DeleteGroupRequest) (*keepassgov1.DeleteGroupResponse, error) { - model, locked := s.snapshotModel() - if locked { - return nil, status.Error(codes.FailedPrecondition, "vault is locked") - } - groupPath := expandClientPath(visibleModel(model), req.GetPath()) - if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, groupPath); err != nil { + if err := s.mutateAuthorizedVisiblePath(ctx, req.GetPath(), apitokens.OperationMutateGroup, func(model *vault.Model, groupPath []string) error { + if err := model.DeleteGroup(groupPath); err != nil { + switch { + case errors.Is(err, vault.ErrEntryNotFound): + return status.Error(codes.NotFound, err.Error()) + case errors.Is(err, vault.ErrGroupNotEmpty): + return status.Error(codes.FailedPrecondition, err.Error()) + default: + return status.Errorf(codes.Internal, "delete group: %v", err) + } + } + return nil + }); err != nil { return nil, err } - - s.mu.Lock() - defer s.mu.Unlock() - - if err := s.model.DeleteGroup(groupPath); err != nil { - switch { - case errors.Is(err, vault.ErrEntryNotFound): - return nil, status.Error(codes.NotFound, err.Error()) - case errors.Is(err, vault.ErrGroupNotEmpty): - return nil, status.Error(codes.FailedPrecondition, err.Error()) - default: - return nil, status.Errorf(codes.Internal, "delete group: %v", err) - } - } - - s.dirty = true - s.syncMutationLocked() return &keepassgov1.DeleteGroupResponse{}, nil } @@ -572,12 +573,12 @@ func (s *Server) UpsertEntry(ctx context.Context, req *keepassgov1.UpsertEntryRe if _, err := s.authorizeUpsertEntryRequest(ctx, entry); err != nil { return nil, err } - - s.mu.Lock() - s.model.UpsertEntry(entry) - s.dirty = true - s.syncMutationLocked() - s.mu.Unlock() + if err := s.mutateAuthorizedModel(func() error { return nil }, func(model *vault.Model) error { + model.UpsertEntry(entry) + return nil + }); err != nil { + return nil, err + } return &keepassgov1.UpsertEntryResponse{Entry: entryToProtoWithModel(visibleModel(model), entry)}, nil } diff --git a/internal/apiapproval/approval.go b/internal/apiapproval/approval.go index ae2cd94..0968402 100644 --- a/internal/apiapproval/approval.go +++ b/internal/apiapproval/approval.go @@ -13,10 +13,11 @@ import ( ) var ( - ErrRequestDenied = errors.New("authorization request denied") - ErrRequestCanceled = errors.New("authorization request canceled") - ErrRequestTimedOut = errors.New("authorization request timed out") - ErrRequestNotFound = errors.New("authorization request not found") + ErrRequestDenied = errors.New("authorization request denied") + ErrRequestCanceled = errors.New("authorization request canceled") + ErrRequestTimedOut = errors.New("authorization request timed out") + ErrRequestNotFound = errors.New("authorization request not found") + ErrBrokerNotConfigured = errors.New("authorization broker is not configured") ) type Outcome string @@ -120,7 +121,7 @@ func (b *Broker) SetChangeNotifier(notify func()) { func (b *Broker) Request(ctx context.Context, token apitokens.Token, op apitokens.Operation, resource apitokens.Resource) (Result, error) { if b == nil { - return Result{}, ErrRequestTimedOut + return Result{}, ErrBrokerNotConfigured } pending := &pendingRequest{ diff --git a/internal/apiapproval/approval_test.go b/internal/apiapproval/approval_test.go index ab1a719..5725ade 100644 --- a/internal/apiapproval/approval_test.go +++ b/internal/apiapproval/approval_test.go @@ -121,6 +121,16 @@ func TestBrokerTimesOutPendingRequests(t *testing.T) { } } +func TestNilBrokerReturnsConfigurationError(t *testing.T) { + t.Parallel() + + var broker *Broker + _, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationListGroups, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root"}}) + if !errors.Is(err, ErrBrokerNotConfigured) { + t.Fatalf("Request(nil broker) error = %v, want %v", err, ErrBrokerNotConfigured) + } +} + func TestBrokerNotifiesWhenPendingRequestsChange(t *testing.T) { t.Parallel() diff --git a/internal/appstate/state.go b/internal/appstate/state.go index cf982e3..1b7d52f 100644 --- a/internal/appstate/state.go +++ b/internal/appstate/state.go @@ -192,139 +192,111 @@ func (s *State) RemoteCredentialEntries() ([]vault.Entry, error) { } func (s *State) IssueAPIToken(name, clientName string, expiresAt *time.Time, now time.Time) (apitokens.Token, string, error) { - session, ok := s.Session.(MutableSession) - if !ok { - return apitokens.Token{}, "", fmt.Errorf("session is not mutable") - } - model, err := session.Current() + result, err := s.mutateAPITokens(apiaudit.EventTokenIssued, "issued API token", func(model *vault.Model) (tokenMutationResult, error) { + token, secret, err := apitokens.Issue(name, clientName, expiresAt, now) + if err != nil { + return tokenMutationResult{}, err + } + apitokens.Upsert(model, token) + return tokenMutationResult{token: token, secret: secret}, nil + }) if err != nil { return apitokens.Token{}, "", err } - token, secret, err := apitokens.Issue(name, clientName, expiresAt, now) - if err != nil { - return apitokens.Token{}, "", err - } - apitokens.Upsert(&model, token) - session.Replace(model) - if err := s.markDirtyAndAutoSave(); err != nil { - return apitokens.Token{}, "", err - } - s.recordTokenAudit(apiaudit.EventTokenIssued, token, "issued API token") - return token, secret, nil + return result.token, result.secret, nil } func (s *State) RotateAPIToken(id string, now time.Time) (apitokens.Token, string, error) { - session, ok := s.Session.(MutableSession) - if !ok { - return apitokens.Token{}, "", fmt.Errorf("session is not mutable") - } - model, err := session.Current() + result, err := s.mutateAPITokens(apiaudit.EventTokenRotated, "rotated API token", func(model *vault.Model) (tokenMutationResult, error) { + token, err := apitokens.Find(*model, id) + if err != nil { + return tokenMutationResult{}, err + } + token, secret, err := apitokens.Rotate(token, now) + if err != nil { + return tokenMutationResult{}, err + } + apitokens.Upsert(model, token) + return tokenMutationResult{token: token, secret: secret}, nil + }) if err != nil { return apitokens.Token{}, "", err } - token, err := apitokens.Find(model, id) - if err != nil { - return apitokens.Token{}, "", err - } - token, secret, err := apitokens.Rotate(token, now) - if err != nil { - return apitokens.Token{}, "", err - } - apitokens.Upsert(&model, token) - session.Replace(model) - if err := s.markDirtyAndAutoSave(); err != nil { - return apitokens.Token{}, "", err - } - s.recordTokenAudit(apiaudit.EventTokenRotated, token, "rotated API token") - return token, secret, nil + return result.token, result.secret, nil } func (s *State) UpsertAPIToken(token apitokens.Token) error { - session, ok := s.Session.(MutableSession) - if !ok { - return fmt.Errorf("session is not mutable") - } - model, err := session.Current() - if err != nil { - return err - } - apitokens.Upsert(&model, token) - session.Replace(model) - if err := s.markDirtyAndAutoSave(); err != nil { - return err - } - s.recordTokenAudit(apiaudit.EventTokenUpdated, token, "updated API token") - return nil + _, err := s.mutateAPITokens(apiaudit.EventTokenUpdated, "updated API token", func(model *vault.Model) (tokenMutationResult, error) { + apitokens.Upsert(model, token) + return tokenMutationResult{token: token}, nil + }) + return err } func (s *State) DisableAPIToken(id string) error { - session, ok := s.Session.(MutableSession) - if !ok { - return fmt.Errorf("session is not mutable") - } - model, err := session.Current() - if err != nil { - return err - } - token, err := apitokens.Find(model, id) - if err != nil { - return err - } - token = apitokens.Disable(token) - apitokens.Upsert(&model, token) - session.Replace(model) - if err := s.markDirtyAndAutoSave(); err != nil { - return err - } - s.recordTokenAudit(apiaudit.EventTokenDisabled, token, "disabled API token") - return nil + _, err := s.mutateAPITokens(apiaudit.EventTokenDisabled, "disabled API token", func(model *vault.Model) (tokenMutationResult, error) { + token, err := apitokens.Find(*model, id) + if err != nil { + return tokenMutationResult{}, err + } + token = apitokens.Disable(token) + apitokens.Upsert(model, token) + return tokenMutationResult{token: token}, nil + }) + return err } func (s *State) RevokeAPIToken(id string, when time.Time) error { - session, ok := s.Session.(MutableSession) - if !ok { - return fmt.Errorf("session is not mutable") - } - model, err := session.Current() - if err != nil { - return err - } - token, err := apitokens.Find(model, id) - if err != nil { - return err - } - token = apitokens.Revoke(token, when) - apitokens.Upsert(&model, token) - session.Replace(model) - if err := s.markDirtyAndAutoSave(); err != nil { - return err - } - s.recordTokenAudit(apiaudit.EventTokenRevoked, token, "revoked API token") - return nil + _, err := s.mutateAPITokens(apiaudit.EventTokenRevoked, "revoked API token", func(model *vault.Model) (tokenMutationResult, error) { + token, err := apitokens.Find(*model, id) + if err != nil { + return tokenMutationResult{}, err + } + token = apitokens.Revoke(token, when) + apitokens.Upsert(model, token) + return tokenMutationResult{token: token}, nil + }) + return err } func (s *State) DeleteAPIToken(id string) error { + _, err := s.mutateAPITokens(apiaudit.EventTokenDeleted, "deleted API token", func(model *vault.Model) (tokenMutationResult, error) { + token, err := apitokens.Find(*model, id) + if err != nil { + return tokenMutationResult{}, err + } + if err := apitokens.Delete(model, id); err != nil { + return tokenMutationResult{}, err + } + return tokenMutationResult{token: token}, nil + }) + return err +} + +type tokenMutationResult struct { + token apitokens.Token + secret string +} + +func (s *State) mutateAPITokens(eventType apiaudit.EventType, message string, mutate func(*vault.Model) (tokenMutationResult, error)) (tokenMutationResult, error) { session, ok := s.Session.(MutableSession) if !ok { - return fmt.Errorf("session is not mutable") + return tokenMutationResult{}, fmt.Errorf("session is not mutable") } model, err := session.Current() if err != nil { - return err + return tokenMutationResult{}, err } - token, err := apitokens.Find(model, id) + result, err := mutate(&model) if err != nil { - return err - } - if err := apitokens.Delete(&model, id); err != nil { - return err + return tokenMutationResult{}, err } session.Replace(model) if err := s.markDirtyAndAutoSave(); err != nil { - return err + return tokenMutationResult{}, err } - s.recordTokenAudit(apiaudit.EventTokenDeleted, token, "deleted API token") - return nil + s.recordTokenAudit(eventType, result.token, message) + return result, nil } func (s *State) recordTokenAudit(eventType apiaudit.EventType, token apitokens.Token, message string) { diff --git a/internal/appui/runtime.go b/internal/appui/runtime.go index 9f6b474..b9b38b8 100644 --- a/internal/appui/runtime.go +++ b/internal/appui/runtime.go @@ -109,7 +109,9 @@ func ensureBrowserNativeHosts() { if err != nil { return } - _ = browserbridge.EnsureNativeHostManifests(appBinaryPath) + if err := browserbridge.EnsureNativeHostManifests(appBinaryPath); err != nil { + platform.LogInfo("KeePassGO", fmt.Sprintf("keepassgo browser native host registration failed: %v", err)) + } } type uiApprovalManager struct { diff --git a/internal/browserbridge/bridge.go b/internal/browserbridge/bridge.go index 0bc7e8e..d223c90 100644 --- a/internal/browserbridge/bridge.go +++ b/internal/browserbridge/bridge.go @@ -15,6 +15,8 @@ import ( "git.julianfamily.org/keepassgo/internal/grpcaddr" keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1" + gcodes "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" ) const ( @@ -22,6 +24,7 @@ const ( defaultFirefoxID = "browser@keepassgo.com" maxNativeMessageSize = 1024 * 1024 chromiumIDBytes = 16 + responseVersion = "1" ) type Request struct { @@ -162,33 +165,25 @@ func HandleRequest(ctx context.Context, req Request, client Client) Response { if err != nil { return Response{Success: false, Error: err.Error(), Status: disconnectedStatus(conn.GRPCAddress)} } - return Response{Success: true, Status: status, Version: "1"} + return Response{Success: true, Status: status, Version: responseVersion} case "find-logins": - status, err := statusResponse(ctx, client, conn.GRPCAddress) - if err != nil { - return Response{Success: false, Error: err.Error(), Status: disconnectedStatus(conn.GRPCAddress)} - } - if status.Locked { - return Response{Success: true, Status: status, Matches: nil, Version: "1"} - } matches, err := findMatches(ctx, client, req.URL) if err != nil { - return Response{Success: false, Error: err.Error(), Status: status} - } - return Response{Success: true, Status: status, Matches: matches, Version: "1"} - case "get-login": - status, err := statusResponse(ctx, client, conn.GRPCAddress) - if err != nil { + if status := inferredActionStatus(conn.GRPCAddress, err); status != nil { + return Response{Success: true, Status: status, Matches: nil, Version: responseVersion} + } return Response{Success: false, Error: err.Error(), Status: disconnectedStatus(conn.GRPCAddress)} } - if status.Locked { - return Response{Success: false, Error: "vault is locked", Status: status} - } + return Response{Success: true, Status: availableStatus(conn.GRPCAddress), Matches: matches, Version: responseVersion} + case "get-login": credential, err := loadCredential(ctx, client, req.EntryID, req.URL) if err != nil { - return Response{Success: false, Error: err.Error(), Status: status} + if status := inferredActionStatus(conn.GRPCAddress, err); status != nil { + return Response{Success: false, Error: err.Error(), Status: status} + } + return Response{Success: false, Error: err.Error(), Status: disconnectedStatus(conn.GRPCAddress)} } - return Response{Success: true, Status: status, Credential: credential, Version: "1"} + return Response{Success: true, Status: availableStatus(conn.GRPCAddress), Credential: credential, Version: responseVersion} default: return Response{Success: false, Error: fmt.Sprintf("unsupported action %q", action)} } @@ -198,6 +193,21 @@ func disconnectedStatus(addr string) *Status { return &Status{Connected: false, GRPCAddress: strings.TrimSpace(addr)} } +func availableStatus(addr string) *Status { + return &Status{Connected: true, Locked: false, GRPCAddress: strings.TrimSpace(addr)} +} + +func inferredActionStatus(addr string, err error) *Status { + switch gstatus.Code(err) { + case gcodes.FailedPrecondition: + return &Status{Connected: true, Locked: true, GRPCAddress: strings.TrimSpace(addr)} + case gcodes.OK: + return availableStatus(addr) + default: + return nil + } +} + func statusResponse(ctx context.Context, client Client, addr string) (*Status, error) { resp, err := client.Status(ctx) if err != nil { diff --git a/internal/browserbridge/bridge_test.go b/internal/browserbridge/bridge_test.go index 15c79e8..ca05759 100644 --- a/internal/browserbridge/bridge_test.go +++ b/internal/browserbridge/bridge_test.go @@ -13,6 +13,8 @@ import ( "testing" keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1" + gcodes "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" ) func TestReadRequestAndWriteResponse(t *testing.T) { @@ -70,8 +72,7 @@ func TestReadRequestAndWriteResponse(t *testing.T) { func TestHandleRequestFindLogins(t *testing.T) { t.Parallel() - client := fakeClient{ - status: &keepassgov1.GetSessionStatusResponse{Locked: false, EntryCount: 2}, + client := &fakeClient{ matches: []*keepassgov1.BrowserLoginMatch{ {Id: "vault-console", Title: "Vault Console", Username: "dannyocean", Url: "https://vault.example.invalid", Quality: "exact-host"}, }, @@ -87,12 +88,15 @@ func TestHandleRequestFindLogins(t *testing.T) { if len(resp.Matches) != 1 || resp.Matches[0].ID != "vault-console" { t.Fatalf("HandleRequest().Matches = %#v, want vault-console", resp.Matches) } + if client.statusCalls != 0 { + t.Fatalf("HandleRequest(find-logins) statusCalls = %d, want 0", client.statusCalls) + } } func TestHandleRequestStatusIncludesPendingApprovalCounts(t *testing.T) { t.Parallel() - client := fakeClient{ + client := &fakeClient{ status: &keepassgov1.GetSessionStatusResponse{ Locked: false, EntryCount: 2, @@ -121,8 +125,7 @@ func TestHandleRequestStatusIncludesPendingApprovalCounts(t *testing.T) { func TestHandleRequestGetLogin(t *testing.T) { t.Parallel() - client := fakeClient{ - status: &keepassgov1.GetSessionStatusResponse{Locked: false, EntryCount: 1}, + client := &fakeClient{ credential: &keepassgov1.GetBrowserCredentialResponse{ Id: "vault-console", Username: "dannyocean", @@ -142,12 +145,35 @@ func TestHandleRequestGetLogin(t *testing.T) { if resp.Credential == nil || resp.Credential.ID != "vault-console" { t.Fatalf("HandleRequest().Credential = %#v, want vault-console", resp.Credential) } + if client.statusCalls != 0 { + t.Fatalf("HandleRequest(get-login) statusCalls = %d, want 0", client.statusCalls) + } +} + +func TestHandleRequestFindLoginsInfersLockedStatusFromRPC(t *testing.T) { + t.Parallel() + + client := &fakeClient{matchesErr: gstatus.Error(gcodes.FailedPrecondition, "vault is locked")} + resp := HandleRequest(context.Background(), Request{ + Action: "find-logins", + BearerToken: "secret", + URL: "https://vault.example.invalid/login", + }, client) + if !resp.Success { + t.Fatalf("HandleRequest(find-logins locked) success = false, error = %q", resp.Error) + } + if resp.Status == nil || !resp.Status.Locked { + t.Fatalf("HandleRequest(find-logins locked).Status = %#v, want locked status", resp.Status) + } + if client.statusCalls != 0 { + t.Fatalf("HandleRequest(find-logins locked) statusCalls = %d, want 0", client.statusCalls) + } } func TestHandleRequestRequiresBearerToken(t *testing.T) { t.Parallel() - resp := HandleRequest(context.Background(), Request{Action: "status"}, fakeClient{}) + resp := HandleRequest(context.Background(), Request{Action: "status"}, &fakeClient{}) if resp.Success { t.Fatal("HandleRequest().Success = true, want false without token") } @@ -282,10 +308,13 @@ func TestEnsureNativeHostManifestsInstallsFirefoxAndDiscoveredChromium(t *testin } type fakeClient struct { - status *keepassgov1.GetSessionStatusResponse - matches []*keepassgov1.BrowserLoginMatch - credential *keepassgov1.GetBrowserCredentialResponse - err error + status *keepassgov1.GetSessionStatusResponse + matches []*keepassgov1.BrowserLoginMatch + credential *keepassgov1.GetBrowserCredentialResponse + err error + matchesErr error + credentialErr error + statusCalls int } func writeExtensionManifest(t *testing.T, path, name string) { @@ -333,7 +362,8 @@ func assertManifestContainsExtension(t *testing.T, path, field, want string) { } } -func (f fakeClient) Status(context.Context) (*keepassgov1.GetSessionStatusResponse, error) { +func (f *fakeClient) Status(context.Context) (*keepassgov1.GetSessionStatusResponse, error) { + f.statusCalls++ if f.err != nil { return nil, f.err } @@ -343,14 +373,20 @@ func (f fakeClient) Status(context.Context) (*keepassgov1.GetSessionStatusRespon return f.status, nil } -func (f fakeClient) FindBrowserLogins(context.Context, string) ([]*keepassgov1.BrowserLoginMatch, error) { +func (f *fakeClient) FindBrowserLogins(context.Context, string) ([]*keepassgov1.BrowserLoginMatch, error) { + if f.matchesErr != nil { + return nil, f.matchesErr + } if f.err != nil { return nil, f.err } return f.matches, nil } -func (f fakeClient) GetBrowserCredential(context.Context, string, string) (*keepassgov1.GetBrowserCredentialResponse, error) { +func (f *fakeClient) GetBrowserCredential(context.Context, string, string) (*keepassgov1.GetBrowserCredentialResponse, error) { + if f.credentialErr != nil { + return nil, f.credentialErr + } if f.err != nil { return nil, f.err }