From f77a185e465cad689624d350c0702df0ec1c236f Mon Sep 17 00:00:00 2001 From: Joe Julian Date: Sun, 29 Mar 2026 23:09:36 -0700 Subject: [PATCH] Add API approval broker for gRPC authorization prompts --- api/server.go | 242 +++++++++++++++++++++++------------ api/server_test.go | 175 ++++++++++++++++++++++++- apiapproval/approval.go | 215 +++++++++++++++++++++++++++++++ apiapproval/approval_test.go | 134 +++++++++++++++++++ 4 files changed, 684 insertions(+), 82 deletions(-) create mode 100644 apiapproval/approval.go create mode 100644 apiapproval/approval_test.go diff --git a/api/server.go b/api/server.go index 58a34a0..8d7f872 100644 --- a/api/server.go +++ b/api/server.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "git.julianfamily.org/keepassgo/apiapproval" "git.julianfamily.org/keepassgo/apitokens" "git.julianfamily.org/keepassgo/clipboard" "git.julianfamily.org/keepassgo/passwords" @@ -33,6 +34,7 @@ type Server struct { lifecycle lifecycleBackend profiles map[string]passwords.Profile clipboard clipboard.Writer + approvals *apiapproval.Broker } type lifecycleBackend interface { @@ -44,11 +46,17 @@ type lifecycleBackend interface { Unlock(vault.MasterKey) error } +type modelReplaceableLifecycle interface { + lifecycleBackend + Replace(vault.Model) +} + func NewServer(model vault.Model, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer) *Server { return &Server{ model: model, profiles: profiles, clipboard: clipboardWriter, + approvals: apiapproval.NewBroker(30 * time.Second), } } @@ -58,6 +66,15 @@ func NewServerWithLifecycle(model vault.Model, profiles map[string]passwords.Pro return server } +func (s *Server) ApprovalBroker() *apiapproval.Broker { + return s.approvals +} + +func (s *Server) ResolveApproval(id string, outcome apiapproval.Outcome) (apiapproval.Request, error) { + request, _, err := s.approvals.Resolve(id, outcome) + return request, err +} + func (s *Server) GetSessionStatus(_ context.Context, _ *keepassgov1.GetSessionStatusRequest) (*keepassgov1.GetSessionStatusResponse, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -193,17 +210,15 @@ func mapLifecycleError(operation string, err error) error { } func (s *Server) ListEntries(ctx context.Context, req *keepassgov1.ListEntriesRequest) (*keepassgov1.ListEntriesResponse, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.locked { + model, locked := s.snapshotModel() + if locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } if _, err := s.authorizePathRequest(ctx, apitokens.OperationListEntries, req.GetPath()); err != nil { return nil, err } - model := s.visibleModel() + model = visibleModel(model) var entries []vault.Entry if strings.TrimSpace(req.GetQuery()) != "" { results := model.Search(req.GetQuery()) @@ -226,10 +241,8 @@ func (s *Server) ListEntries(ctx context.Context, req *keepassgov1.ListEntriesRe } func (s *Server) ListGroups(ctx context.Context, req *keepassgov1.ListGroupsRequest) (*keepassgov1.ListGroupsResponse, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.locked { + model, locked := s.snapshotModel() + if locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } if _, err := s.authorizePathRequest(ctx, apitokens.OperationListGroups, req.GetPath()); err != nil { @@ -237,18 +250,17 @@ func (s *Server) ListGroups(ctx context.Context, req *keepassgov1.ListGroupsRequ } return &keepassgov1.ListGroupsResponse{ - Names: s.visibleModel().ChildGroups(req.GetPath()), + Names: visibleModel(model).ChildGroups(req.GetPath()), }, nil } func (s *Server) CreateGroup(ctx context.Context, req *keepassgov1.CreateGroupRequest) (*keepassgov1.CreateGroupResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetParentPath()); err != nil { return nil, err } + s.mu.Lock() + defer s.mu.Unlock() if s.locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } @@ -259,13 +271,12 @@ func (s *Server) CreateGroup(ctx context.Context, req *keepassgov1.CreateGroupRe } func (s *Server) RenameGroup(ctx context.Context, req *keepassgov1.RenameGroupRequest) (*keepassgov1.RenameGroupResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetPath()); err != nil { return nil, err } + s.mu.Lock() + defer s.mu.Unlock() if s.locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } @@ -282,13 +293,12 @@ func (s *Server) RenameGroup(ctx context.Context, req *keepassgov1.RenameGroupRe } func (s *Server) DeleteGroup(ctx context.Context, req *keepassgov1.DeleteGroupRequest) (*keepassgov1.DeleteGroupResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetPath()); err != nil { return nil, err } + s.mu.Lock() + defer s.mu.Unlock() if s.locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } @@ -314,16 +324,15 @@ func (s *Server) UpsertEntry(ctx context.Context, req *keepassgov1.UpsertEntryRe } entry := entryFromProto(req.GetEntry()) + if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil { + return nil, err + } s.mu.Lock() if s.locked { s.mu.Unlock() return nil, status.Error(codes.FailedPrecondition, "vault is locked") } - if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil { - s.mu.Unlock() - return nil, err - } s.model.UpsertEntry(entry) s.dirty = true s.mu.Unlock() @@ -332,18 +341,19 @@ func (s *Server) UpsertEntry(ctx context.Context, req *keepassgov1.UpsertEntryRe } func (s *Server) DeleteEntry(ctx context.Context, req *keepassgov1.DeleteEntryRequest) (*keepassgov1.DeleteEntryResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if s.locked { + model, locked := s.snapshotModel() + if locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } - if entry, err := findEntryByID(s.model, req.GetId()); err == nil { + if entry, err := findEntryByID(model, req.GetId()); err == nil { if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil { return nil, err } } + s.mu.Lock() + defer s.mu.Unlock() + if err := s.model.DeleteEntry(req.GetId()); err != nil { if errors.Is(err, vault.ErrEntryNotFound) { return nil, status.Error(codes.NotFound, err.Error()) @@ -356,15 +366,13 @@ func (s *Server) DeleteEntry(ctx context.Context, req *keepassgov1.DeleteEntryRe } func (s *Server) RestoreEntry(ctx context.Context, req *keepassgov1.RestoreEntryRequest) (*keepassgov1.RestoreEntryResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if s.locked { + model, locked := s.snapshotModel() + if locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } var restored vault.Entry - for _, entry := range s.model.RecycleBin { + for _, entry := range model.RecycleBin { if entry.ID == req.GetId() { restored = entry break @@ -376,6 +384,9 @@ func (s *Server) RestoreEntry(ctx context.Context, req *keepassgov1.RestoreEntry } } + s.mu.Lock() + defer s.mu.Unlock() + if err := s.model.RestoreEntry(req.GetId()); err != nil { if errors.Is(err, vault.ErrEntryNotFound) { return nil, status.Error(codes.NotFound, err.Error()) @@ -388,14 +399,12 @@ func (s *Server) RestoreEntry(ctx context.Context, req *keepassgov1.RestoreEntry } func (s *Server) ListEntryHistory(ctx context.Context, req *keepassgov1.ListEntryHistoryRequest) (*keepassgov1.ListEntryHistoryResponse, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.locked { + model, locked := s.snapshotModel() + if locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } - entry, err := findEntryByID(s.model, req.GetId()) + entry, err := findEntryByID(model, req.GetId()) if err != nil { return nil, status.Error(codes.NotFound, err.Error()) } @@ -413,13 +422,11 @@ func (s *Server) ListEntryHistory(ctx context.Context, req *keepassgov1.ListEntr } func (s *Server) RestoreEntryHistory(ctx context.Context, req *keepassgov1.RestoreEntryHistoryRequest) (*keepassgov1.RestoreEntryHistoryResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if s.locked { + model, locked := s.snapshotModel() + if locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } - entry, err := findEntryByID(s.model, req.GetId()) + entry, err := findEntryByID(model, req.GetId()) if err != nil { return nil, status.Error(codes.NotFound, err.Error()) } @@ -427,6 +434,8 @@ func (s *Server) RestoreEntryHistory(ctx context.Context, req *keepassgov1.Resto return nil, err } + s.mu.Lock() + defer s.mu.Unlock() if err := s.model.RestoreEntryVersion(req.GetId(), int(req.GetHistoryIndex())); err != nil { if errors.Is(err, vault.ErrEntryNotFound) { return nil, status.Error(codes.NotFound, err.Error()) @@ -523,14 +532,12 @@ func (s *Server) InstantiateTemplate(_ context.Context, req *keepassgov1.Instant } func (s *Server) ListAttachments(ctx context.Context, req *keepassgov1.ListAttachmentsRequest) (*keepassgov1.ListAttachmentsResponse, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.locked { + model, locked := s.snapshotModel() + if locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } - entry, err := findEntryByID(s.model, req.GetEntryId()) + entry, err := findEntryByID(model, req.GetEntryId()) if err != nil { return nil, status.Error(codes.NotFound, err.Error()) } @@ -548,14 +555,11 @@ func (s *Server) ListAttachments(ctx context.Context, req *keepassgov1.ListAttac } func (s *Server) UploadAttachment(ctx context.Context, req *keepassgov1.UploadAttachmentRequest) (*keepassgov1.UploadAttachmentResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if s.locked { + model, locked := s.snapshotModel() + if locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } - - entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId()) + entry, err := findEntryByID(model, req.GetEntryId()) if err != nil { return nil, status.Error(codes.NotFound, err.Error()) } @@ -563,6 +567,13 @@ func (s *Server) UploadAttachment(ctx context.Context, req *keepassgov1.UploadAt return nil, err } + s.mu.Lock() + defer s.mu.Unlock() + entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId()) + if err != nil { + return nil, status.Error(codes.NotFound, err.Error()) + } + if entry.Attachments == nil { entry.Attachments = map[string][]byte{} } @@ -574,14 +585,12 @@ func (s *Server) UploadAttachment(ctx context.Context, req *keepassgov1.UploadAt } func (s *Server) DownloadAttachment(ctx context.Context, req *keepassgov1.DownloadAttachmentRequest) (*keepassgov1.DownloadAttachmentResponse, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.locked { + model, locked := s.snapshotModel() + if locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } - entry, err := findEntryByID(s.model, req.GetEntryId()) + entry, err := findEntryByID(model, req.GetEntryId()) if err != nil { return nil, status.Error(codes.NotFound, err.Error()) } @@ -600,14 +609,11 @@ func (s *Server) DownloadAttachment(ctx context.Context, req *keepassgov1.Downlo } func (s *Server) DeleteAttachment(ctx context.Context, req *keepassgov1.DeleteAttachmentRequest) (*keepassgov1.DeleteAttachmentResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if s.locked { + model, locked := s.snapshotModel() + if locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } - - entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId()) + entry, err := findEntryByID(model, req.GetEntryId()) if err != nil { return nil, status.Error(codes.NotFound, err.Error()) } @@ -615,6 +621,13 @@ func (s *Server) DeleteAttachment(ctx context.Context, req *keepassgov1.DeleteAt return nil, err } + s.mu.Lock() + defer s.mu.Unlock() + entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId()) + if err != nil { + return nil, status.Error(codes.NotFound, err.Error()) + } + if _, ok := entry.Attachments[req.GetName()]; !ok { return nil, status.Error(codes.NotFound, "attachment not found") } @@ -630,11 +643,7 @@ func (s *Server) DeleteAttachment(ctx context.Context, req *keepassgov1.DeleteAt } func (s *Server) CopyEntryField(ctx context.Context, req *keepassgov1.CopyEntryFieldRequest) (*keepassgov1.CopyEntryFieldResponse, error) { - s.mu.RLock() - model := s.model - locked := s.locked - s.mu.RUnlock() - + model, locked := s.snapshotModel() if locked { return nil, status.Error(codes.FailedPrecondition, "vault is locked") } @@ -733,10 +742,10 @@ func findMutableEntryByID(model *vault.Model, id string) (vault.Entry, int, erro return vault.Entry{}, -1, vault.ErrEntryNotFound } -func (s *Server) visibleModel() vault.Model { - out := s.model +func visibleModel(model vault.Model) vault.Model { + out := model out.Entries = nil - for _, entry := range s.model.Entries { + for _, entry := range model.Entries { token, ok, err := apitokens.TokenFromEntry(entry) if err == nil && ok && token.ID != "" { continue @@ -744,7 +753,7 @@ func (s *Server) visibleModel() vault.Model { out.Entries = append(out.Entries, entry) } out.Groups = nil - for _, path := range s.model.Groups { + for _, path := range model.Groups { if len(path) >= 2 && path[0] == "Root" && path[1] == "API Tokens" { continue } @@ -753,6 +762,12 @@ func (s *Server) visibleModel() vault.Model { return out } +func (s *Server) snapshotModel() (vault.Model, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + return s.model, s.locked +} + var timeNow = func() time.Time { return time.Now().UTC() } func (s *Server) authenticateRequest(ctx context.Context) (apitokens.Token, error) { @@ -768,7 +783,9 @@ func (s *Server) authenticateRequest(ctx context.Context) (apitokens.Token, erro if !strings.HasPrefix(values[0], prefix) { return apitokens.Token{}, status.Error(codes.Unauthenticated, "invalid bearer token") } + s.mu.RLock() tokens, err := apitokens.Entries(s.model) + s.mu.RUnlock() if err != nil { return apitokens.Token{}, status.Errorf(codes.Internal, "load api tokens: %v", err) } @@ -789,10 +806,7 @@ func (s *Server) authorizePathRequest(ctx context.Context, op apitokens.Operatio if err != nil { return apitokens.Token{}, err } - if apitokens.Evaluate(token, op, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: path}) != apitokens.DecisionAllow { - return apitokens.Token{}, status.Error(codes.PermissionDenied, "access is not allowed for this token") - } - return token, nil + return s.authorizeResourceRequest(ctx, token, op, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: path}) } func (s *Server) authorizeEntryRequest(ctx context.Context, op apitokens.Operation, entry vault.Entry) (apitokens.Token, error) { @@ -800,10 +814,78 @@ func (s *Server) authorizeEntryRequest(ctx context.Context, op apitokens.Operati if err != nil { return apitokens.Token{}, err } - if apitokens.Evaluate(token, op, apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: entry.ID, Path: entry.Path}) != apitokens.DecisionAllow { + return s.authorizeResourceRequest(ctx, token, op, apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: entry.ID, Path: entry.Path}) +} + +func (s *Server) authorizeResourceRequest(ctx context.Context, token apitokens.Token, op apitokens.Operation, resource apitokens.Resource) (apitokens.Token, error) { + switch apitokens.Evaluate(token, op, resource) { + case apitokens.DecisionAllow: + return token, nil + case apitokens.DecisionDeny: + return apitokens.Token{}, status.Error(codes.PermissionDenied, "access is not allowed for this token") + case apitokens.DecisionPrompt: + result, err := s.approvals.Request(ctx, token, op, resource) + if result.Rule != nil { + if persistErr := s.persistApprovalRule(token.ID, *result.Rule); persistErr != nil { + return apitokens.Token{}, status.Errorf(codes.Internal, "persist approval decision: %v", persistErr) + } + } + switch { + case err == nil: + return token, nil + case errors.Is(err, apiapproval.ErrRequestDenied): + return apitokens.Token{}, status.Error(codes.PermissionDenied, "access denied by user approval") + case errors.Is(err, apiapproval.ErrRequestCanceled): + return apitokens.Token{}, status.Error(codes.Unauthenticated, "authorization request canceled") + case errors.Is(err, apiapproval.ErrRequestTimedOut): + return apitokens.Token{}, status.Error(codes.DeadlineExceeded, "authorization request timed out") + case errors.Is(err, context.Canceled): + return apitokens.Token{}, status.Error(codes.Canceled, "authorization request canceled") + case errors.Is(err, context.DeadlineExceeded): + return apitokens.Token{}, status.Error(codes.DeadlineExceeded, "authorization request timed out") + default: + return apitokens.Token{}, status.Errorf(codes.Internal, "await authorization request: %v", err) + } + default: return apitokens.Token{}, status.Error(codes.PermissionDenied, "access is not allowed for this token") } - return token, nil +} + +func (s *Server) persistApprovalRule(tokenID string, rule apitokens.PolicyRule) error { + s.mu.Lock() + defer s.mu.Unlock() + + for i, entry := range s.model.Entries { + token, ok, err := apitokens.TokenFromEntry(entry) + if err != nil || !ok || token.ID != tokenID { + continue + } + if !hasPolicyRule(token.Policies, rule) { + token.Policies = append(token.Policies, rule) + } + s.model.Entries[i] = token.Entry(entry.Path) + s.dirty = true + if lifecycle, ok := s.lifecycle.(modelReplaceableLifecycle); ok { + lifecycle.Replace(s.model) + } + return nil + } + return status.Error(codes.NotFound, "api token entry not found") +} + +func hasPolicyRule(rules []apitokens.PolicyRule, target apitokens.PolicyRule) bool { + for _, rule := range rules { + if rule.Effect != target.Effect || rule.Operation != target.Operation { + continue + } + if rule.Resource.Kind != target.Resource.Kind || rule.Resource.EntryID != target.Resource.EntryID { + continue + } + if slices.Equal(rule.Resource.Path, target.Resource.Path) { + return true + } + } + return false } func copyOperation(target string) apitokens.Operation { diff --git a/api/server_test.go b/api/server_test.go index e3481a7..4653ee9 100644 --- a/api/server_test.go +++ b/api/server_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "git.julianfamily.org/keepassgo/apiapproval" "git.julianfamily.org/keepassgo/apitokens" "git.julianfamily.org/keepassgo/passwords" keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1" @@ -85,6 +86,7 @@ func TestVaultServiceRejectsUnauthorizedEntryAccess(t *testing.T) { }, testAPITokenEntry(t, apitokens.PolicyRule{Effect: apitokens.EffectAllow, Operation: apitokens.OperationListEntries, Resource: apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root", "Internet"}}}, + apitokens.PolicyRule{Effect: apitokens.EffectDeny, Operation: apitokens.OperationCopyPassword, Resource: apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: "vault-console", Path: []string{"Root", "Internet"}}}, ), }, }) @@ -96,6 +98,146 @@ func TestVaultServiceRejectsUnauthorizedEntryAccess(t *testing.T) { } } +func TestVaultServicePromptsAndResumesWhenApproved(t *testing.T) { + t.Parallel() + + model := vault.Model{ + Entries: []vault.Entry{ + { + ID: "vault-console", + Title: "Vault Console", + Username: "dannyocean", + Password: "token-1", + URL: "https://vault.crew.example.invalid", + Path: []string{"Root", "Internet"}, + }, + testAPITokenEntry(t), + }, + } + client, _, service, cleanup := newTestHarnessForModel(t, model) + defer cleanup() + service.approvals = apiapproval.NewBroker(time.Minute) + + respCh := make(chan *keepassgov1.ListEntriesResponse, 1) + errCh := make(chan error, 1) + go func() { + resp, err := client.ListEntries(tokenContext(defaultTestTokenSecret), &keepassgov1.ListEntriesRequest{Path: []string{"Root", "Internet"}}) + respCh <- resp + errCh <- err + }() + + pending := waitForServerPendingApproval(t, service, 1)[0] + if pending.Operation != apitokens.OperationListEntries { + t.Fatalf("pending.Operation = %q, want %q", pending.Operation, apitokens.OperationListEntries) + } + if _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeAllowOnce); err != nil { + t.Fatalf("ResolveApproval(allow) error = %v", err) + } + + resp := <-respCh + if err := <-errCh; err != nil { + t.Fatalf("ListEntries() error = %v", err) + } + if len(resp.Entries) != 1 || resp.Entries[0].Id != "vault-console" { + t.Fatalf("ListEntries().Entries = %#v, want vault-console", resp.Entries) + } +} + +func TestVaultServicePersistsPermanentDenyApproval(t *testing.T) { + t.Parallel() + + model := vault.Model{ + Entries: []vault.Entry{ + { + ID: "vault-console", + Title: "Vault Console", + Username: "dannyocean", + Password: "token-1", + URL: "https://vault.crew.example.invalid", + Path: []string{"Root", "Internet"}, + }, + testAPITokenEntry(t), + }, + } + client, _, service, cleanup := newTestHarnessForModel(t, model) + defer cleanup() + service.approvals = apiapproval.NewBroker(time.Minute) + + errCh := make(chan error, 1) + go func() { + _, err := client.ListEntries(tokenContext(defaultTestTokenSecret), &keepassgov1.ListEntriesRequest{Path: []string{"Root", "Internet"}}) + errCh <- err + }() + + pending := waitForServerPendingApproval(t, service, 1)[0] + if _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeDenyPermanent); err != nil { + t.Fatalf("ResolveApproval(deny permanent) error = %v", err) + } + + if err := <-errCh; status.Code(err) != codes.PermissionDenied { + t.Fatalf("ListEntries() code = %v, want %v", status.Code(err), codes.PermissionDenied) + } + + service.mu.RLock() + tokens, err := apitokens.Entries(service.model) + service.mu.RUnlock() + if err != nil { + t.Fatalf("Entries() error = %v", err) + } + decision := apitokens.Evaluate(tokens[0], apitokens.OperationListEntries, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root", "Internet"}}) + if decision != apitokens.DecisionDeny { + t.Fatalf("Evaluate() after permanent deny = %q, want %q", decision, apitokens.DecisionDeny) + } +} + +func TestVaultServiceReturnsCanceledForCanceledApproval(t *testing.T) { + t.Parallel() + + model := vault.Model{ + Entries: []vault.Entry{ + {ID: "vault-console", Title: "Vault Console", Path: []string{"Root", "Internet"}}, + testAPITokenEntry(t), + }, + } + client, _, service, cleanup := newTestHarnessForModel(t, model) + defer cleanup() + service.approvals = apiapproval.NewBroker(time.Minute) + + errCh := make(chan error, 1) + go func() { + _, err := client.ListEntries(tokenContext(defaultTestTokenSecret), &keepassgov1.ListEntriesRequest{Path: []string{"Root", "Internet"}}) + errCh <- err + }() + + pending := waitForServerPendingApproval(t, service, 1)[0] + if _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeCancel); err != nil { + t.Fatalf("ResolveApproval(cancel) error = %v", err) + } + + if err := <-errCh; status.Code(err) != codes.Unauthenticated { + t.Fatalf("ListEntries() code = %v, want %v", status.Code(err), codes.Unauthenticated) + } +} + +func TestVaultServiceTimesOutPendingApproval(t *testing.T) { + t.Parallel() + + model := vault.Model{ + Entries: []vault.Entry{ + {ID: "vault-console", Title: "Vault Console", Path: []string{"Root", "Internet"}}, + testAPITokenEntry(t), + }, + } + client, _, service, cleanup := newTestHarnessForModel(t, model) + defer cleanup() + service.approvals = apiapproval.NewBroker(20 * time.Millisecond) + + _, err := client.ListEntries(tokenContext(defaultTestTokenSecret), &keepassgov1.ListEntriesRequest{Path: []string{"Root", "Internet"}}) + if status.Code(err) != codes.DeadlineExceeded { + t.Fatalf("ListEntries() code = %v, want %v", status.Code(err), codes.DeadlineExceeded) + } +} + func TestVaultServiceReportsSessionStatusAndSupportsLockUnlock(t *testing.T) { t.Parallel() @@ -936,6 +1078,11 @@ func newTestClient(t *testing.T) (keepassgov1.VaultServiceClient, *memoryClipboa } func newTestClientForModel(t *testing.T, model vault.Model) (keepassgov1.VaultServiceClient, *memoryClipboardWriter, func()) { + client, clipboardWriter, _, cleanup := newTestHarnessForModel(t, model) + return client, clipboardWriter, cleanup +} + +func newTestHarnessForModel(t *testing.T, model vault.Model) (keepassgov1.VaultServiceClient, *memoryClipboardWriter, *Server, func()) { t.Helper() listener := bufconn.Listen(1024 * 1024) @@ -963,10 +1110,15 @@ func newTestClientForModel(t *testing.T, model vault.Model) (keepassgov1.VaultSe server.Stop() } - return keepassgov1.NewVaultServiceClient(conn), clipboardWriter, cleanup + return keepassgov1.NewVaultServiceClient(conn), clipboardWriter, service, cleanup } func newTestClientWithLifecycle(t *testing.T, lifecycle *stubLifecycle) (keepassgov1.VaultServiceClient, *memoryClipboardWriter, func()) { + client, clipboardWriter, _, cleanup := newTestHarnessWithLifecycle(t, lifecycle) + return client, clipboardWriter, cleanup +} + +func newTestHarnessWithLifecycle(t *testing.T, lifecycle *stubLifecycle) (keepassgov1.VaultServiceClient, *memoryClipboardWriter, *Server, func()) { t.Helper() listener := bufconn.Listen(1024 * 1024) @@ -1001,7 +1153,7 @@ func newTestClientWithLifecycle(t *testing.T, lifecycle *stubLifecycle) (keepass server.Stop() } - return keepassgov1.NewVaultServiceClient(conn), clipboardWriter, cleanup + return keepassgov1.NewVaultServiceClient(conn), clipboardWriter, service, cleanup } type memoryClipboardWriter struct { @@ -1013,6 +1165,21 @@ func (w *memoryClipboardWriter) WriteText(text string) error { return nil } +func waitForServerPendingApproval(t *testing.T, server *Server, want int) []apiapproval.Request { + t.Helper() + + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + pending := server.ApprovalBroker().Pending() + if len(pending) == want { + return pending + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("server pending approvals never reached len %d", want) + return nil +} + type stubLifecycle struct { model vault.Model openPath string @@ -1087,3 +1254,7 @@ func (s *stubLifecycle) Unlock(key vault.MasterKey) error { s.locked = false return nil } + +func (s *stubLifecycle) Replace(model vault.Model) { + s.model = model +} diff --git a/apiapproval/approval.go b/apiapproval/approval.go new file mode 100644 index 0000000..b569ae9 --- /dev/null +++ b/apiapproval/approval.go @@ -0,0 +1,215 @@ +package apiapproval + +import ( + "context" + "errors" + "fmt" + "slices" + "strconv" + "sync" + "time" + + "git.julianfamily.org/keepassgo/apitokens" +) + +var ( + ErrRequestDenied = errors.New("authorization request denied") + ErrRequestCanceled = errors.New("authorization request canceled") + ErrRequestTimedOut = errors.New("authorization request timed out") + ErrRequestNotFound = errors.New("authorization request not found") +) + +type Outcome string + +const ( + OutcomeAllowOnce Outcome = "allow-once" + OutcomeDenyOnce Outcome = "deny-once" + OutcomeAllowPermanent Outcome = "allow-permanent" + OutcomeDenyPermanent Outcome = "deny-permanent" + OutcomeCancel Outcome = "cancel" +) + +type Request struct { + ID string + TokenID string + TokenName string + ClientName string + Operation apitokens.Operation + Resource apitokens.Resource + RequestedAt time.Time +} + +type Result struct { + Outcome Outcome + Rule *apitokens.PolicyRule +} + +type Broker struct { + mu sync.Mutex + pending map[string]*pendingRequest + timeout time.Duration + now func() time.Time + nextID func() string +} + +type pendingRequest struct { + request Request + done chan Outcome +} + +type idGenerator struct { + mu sync.Mutex + counter int +} + +func NewBroker(timeout time.Duration) *Broker { + gen := &idGenerator{} + return &Broker{ + pending: map[string]*pendingRequest{}, + timeout: timeout, + now: func() time.Time { + return time.Now().UTC() + }, + nextID: func() string { + return gen.Next() + }, + } +} + +func (g *idGenerator) Next() string { + g.mu.Lock() + defer g.mu.Unlock() + g.counter++ + return "approval-" + strconv.Itoa(g.counter) +} + +func (b *Broker) Pending() []Request { + b.mu.Lock() + defer b.mu.Unlock() + + requests := make([]Request, 0, len(b.pending)) + for _, pending := range b.pending { + requests = append(requests, pending.request) + } + slices.SortFunc(requests, func(a, c Request) int { + switch { + case a.RequestedAt.Before(c.RequestedAt): + return -1 + case a.RequestedAt.After(c.RequestedAt): + return 1 + case a.ID < c.ID: + return -1 + case a.ID > c.ID: + return 1 + default: + return 0 + } + }) + return requests +} + +func (b *Broker) Request(ctx context.Context, token apitokens.Token, op apitokens.Operation, resource apitokens.Resource) (Result, error) { + if b == nil { + return Result{}, ErrRequestTimedOut + } + + pending := &pendingRequest{ + request: Request{ + ID: b.nextID(), + TokenID: token.ID, + TokenName: token.Name, + ClientName: token.ClientName, + Operation: op, + Resource: resource, + RequestedAt: b.now(), + }, + done: make(chan Outcome, 1), + } + + b.mu.Lock() + b.pending[pending.request.ID] = pending + b.mu.Unlock() + + defer func() { + b.mu.Lock() + delete(b.pending, pending.request.ID) + b.mu.Unlock() + }() + + timer := time.NewTimer(b.timeout) + defer timer.Stop() + + select { + case outcome := <-pending.done: + result, err := resultForOutcome(pending.request, outcome) + if err != nil { + return result, err + } + return result, nil + case <-timer.C: + return Result{}, ErrRequestTimedOut + case <-ctx.Done(): + return Result{}, ctx.Err() + } +} + +func (b *Broker) Resolve(id string, outcome Outcome) (Request, *apitokens.PolicyRule, error) { + b.mu.Lock() + pending, ok := b.pending[id] + b.mu.Unlock() + if !ok { + return Request{}, nil, ErrRequestNotFound + } + + rule, err := RuleFromDecision(pending.request, outcome) + if err != nil { + return Request{}, nil, err + } + + select { + case pending.done <- outcome: + default: + } + return pending.request, rule, nil +} + +func resultForOutcome(request Request, outcome Outcome) (Result, error) { + rule, err := RuleFromDecision(request, outcome) + if err != nil { + return Result{}, err + } + result := Result{Outcome: outcome, Rule: rule} + switch outcome { + case OutcomeAllowOnce, OutcomeAllowPermanent: + return result, nil + case OutcomeDenyOnce, OutcomeDenyPermanent: + return result, ErrRequestDenied + case OutcomeCancel: + return result, ErrRequestCanceled + default: + return Result{}, fmt.Errorf("unsupported approval outcome %q", outcome) + } +} + +func RuleFromDecision(request Request, outcome Outcome) (*apitokens.PolicyRule, error) { + switch outcome { + case OutcomeAllowPermanent: + rule := apitokens.PolicyRule{ + Effect: apitokens.EffectAllow, + Operation: request.Operation, + Resource: request.Resource, + } + return &rule, nil + case OutcomeDenyPermanent: + rule := apitokens.PolicyRule{ + Effect: apitokens.EffectDeny, + Operation: request.Operation, + Resource: request.Resource, + } + return &rule, nil + case OutcomeAllowOnce, OutcomeDenyOnce, OutcomeCancel: + return nil, nil + default: + return nil, fmt.Errorf("unsupported approval outcome %q", outcome) + } +} diff --git a/apiapproval/approval_test.go b/apiapproval/approval_test.go new file mode 100644 index 0000000..89bbadd --- /dev/null +++ b/apiapproval/approval_test.go @@ -0,0 +1,134 @@ +package apiapproval + +import ( + "context" + "errors" + "testing" + "time" + + "git.julianfamily.org/keepassgo/apitokens" +) + +func TestBrokerCreatesPendingRequestAndAllowsOnce(t *testing.T) { + t.Parallel() + + broker := NewBroker(time.Minute) + broker.now = func() time.Time { return time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) } + + resultCh := make(chan Result, 1) + errCh := make(chan error, 1) + go func() { + result, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI", ClientName: "grpc-cli"}, apitokens.OperationListEntries, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root", "Internet"}}) + resultCh <- result + errCh <- err + }() + + waitForPending(t, broker, 1) + pending := broker.Pending() + if len(pending) != 1 { + t.Fatalf("Pending() len = %d, want 1", len(pending)) + } + if pending[0].TokenID != "token-1" { + t.Fatalf("Pending()[0].TokenID = %q, want token-1", pending[0].TokenID) + } + + if _, _, err := broker.Resolve(pending[0].ID, OutcomeAllowOnce); err != nil { + t.Fatalf("Resolve(allow once) error = %v", err) + } + + result := <-resultCh + if err := <-errCh; err != nil { + t.Fatalf("Request() error = %v, want nil", err) + } + if result.Outcome != OutcomeAllowOnce { + t.Fatalf("Request() outcome = %q, want %q", result.Outcome, OutcomeAllowOnce) + } + if result.Rule != nil { + t.Fatalf("Request() rule = %#v, want nil for allow-once", result.Rule) + } + if got := broker.Pending(); len(got) != 0 { + t.Fatalf("Pending() after allow len = %d, want 0", len(got)) + } +} + +func TestBrokerReturnsPermanentRuleForDeny(t *testing.T) { + t.Parallel() + + broker := NewBroker(time.Minute) + reqDone := make(chan struct{}) + var result Result + var err error + go func() { + result, err = broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationReadEntry, apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: "entry-1", Path: []string{"Root", "Internet"}}) + close(reqDone) + }() + + waitForPending(t, broker, 1) + pending := broker.Pending()[0] + request, rule, resolveErr := broker.Resolve(pending.ID, OutcomeDenyPermanent) + if resolveErr != nil { + t.Fatalf("Resolve(deny permanent) error = %v", resolveErr) + } + if request.ID != pending.ID { + t.Fatalf("Resolve().ID = %q, want %q", request.ID, pending.ID) + } + if rule == nil || rule.Effect != apitokens.EffectDeny || rule.Operation != apitokens.OperationReadEntry { + t.Fatalf("Resolve() rule = %#v, want deny read-entry rule", rule) + } + + <-reqDone + if !errors.Is(err, ErrRequestDenied) { + t.Fatalf("Request() error = %v, want ErrRequestDenied", err) + } + if result.Rule == nil || result.Rule.Effect != apitokens.EffectDeny { + t.Fatalf("Request() rule = %#v, want deny rule", result.Rule) + } + if result.Outcome != OutcomeDenyPermanent { + t.Fatalf("Request() outcome = %q, want %q", result.Outcome, OutcomeDenyPermanent) + } +} + +func TestBrokerSupportsCancellation(t *testing.T) { + t.Parallel() + + broker := NewBroker(time.Minute) + errCh := make(chan error, 1) + go func() { + _, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationListGroups, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root"}}) + errCh <- err + }() + + waitForPending(t, broker, 1) + if _, _, err := broker.Resolve(broker.Pending()[0].ID, OutcomeCancel); err != nil { + t.Fatalf("Resolve(cancel) error = %v", err) + } + if err := <-errCh; !errors.Is(err, ErrRequestCanceled) { + t.Fatalf("Request() error = %v, want ErrRequestCanceled", err) + } +} + +func TestBrokerTimesOutPendingRequests(t *testing.T) { + t.Parallel() + + broker := NewBroker(10 * time.Millisecond) + _, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationListGroups, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root"}}) + if !errors.Is(err, ErrRequestTimedOut) { + t.Fatalf("Request() error = %v, want ErrRequestTimedOut", err) + } + if got := broker.Pending(); len(got) != 0 { + t.Fatalf("Pending() len after timeout = %d, want 0", len(got)) + } +} + +func waitForPending(t *testing.T, broker *Broker, want int) { + t.Helper() + + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + if got := len(broker.Pending()); got == want { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("Pending() never reached len %d", want) +}