Add API approval broker for gRPC authorization prompts

This commit is contained in:
Joe Julian
2026-03-29 23:09:36 -07:00
parent 6a7594e128
commit f77a185e46
4 changed files with 684 additions and 82 deletions
+162 -80
View File
@@ -10,6 +10,7 @@ import (
"sync"
"time"
"git.julianfamily.org/keepassgo/apiapproval"
"git.julianfamily.org/keepassgo/apitokens"
"git.julianfamily.org/keepassgo/clipboard"
"git.julianfamily.org/keepassgo/passwords"
@@ -33,6 +34,7 @@ type Server struct {
lifecycle lifecycleBackend
profiles map[string]passwords.Profile
clipboard clipboard.Writer
approvals *apiapproval.Broker
}
type lifecycleBackend interface {
@@ -44,11 +46,17 @@ type lifecycleBackend interface {
Unlock(vault.MasterKey) error
}
type modelReplaceableLifecycle interface {
lifecycleBackend
Replace(vault.Model)
}
func NewServer(model vault.Model, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer) *Server {
return &Server{
model: model,
profiles: profiles,
clipboard: clipboardWriter,
approvals: apiapproval.NewBroker(30 * time.Second),
}
}
@@ -58,6 +66,15 @@ func NewServerWithLifecycle(model vault.Model, profiles map[string]passwords.Pro
return server
}
func (s *Server) ApprovalBroker() *apiapproval.Broker {
return s.approvals
}
func (s *Server) ResolveApproval(id string, outcome apiapproval.Outcome) (apiapproval.Request, error) {
request, _, err := s.approvals.Resolve(id, outcome)
return request, err
}
func (s *Server) GetSessionStatus(_ context.Context, _ *keepassgov1.GetSessionStatusRequest) (*keepassgov1.GetSessionStatusResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -193,17 +210,15 @@ func mapLifecycleError(operation string, err error) error {
}
func (s *Server) ListEntries(ctx context.Context, req *keepassgov1.ListEntriesRequest) (*keepassgov1.ListEntriesResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.locked {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
if _, err := s.authorizePathRequest(ctx, apitokens.OperationListEntries, req.GetPath()); err != nil {
return nil, err
}
model := s.visibleModel()
model = visibleModel(model)
var entries []vault.Entry
if strings.TrimSpace(req.GetQuery()) != "" {
results := model.Search(req.GetQuery())
@@ -226,10 +241,8 @@ func (s *Server) ListEntries(ctx context.Context, req *keepassgov1.ListEntriesRe
}
func (s *Server) ListGroups(ctx context.Context, req *keepassgov1.ListGroupsRequest) (*keepassgov1.ListGroupsResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.locked {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
if _, err := s.authorizePathRequest(ctx, apitokens.OperationListGroups, req.GetPath()); err != nil {
@@ -237,18 +250,17 @@ func (s *Server) ListGroups(ctx context.Context, req *keepassgov1.ListGroupsRequ
}
return &keepassgov1.ListGroupsResponse{
Names: s.visibleModel().ChildGroups(req.GetPath()),
Names: visibleModel(model).ChildGroups(req.GetPath()),
}, nil
}
func (s *Server) CreateGroup(ctx context.Context, req *keepassgov1.CreateGroupRequest) (*keepassgov1.CreateGroupResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetParentPath()); err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
@@ -259,13 +271,12 @@ func (s *Server) CreateGroup(ctx context.Context, req *keepassgov1.CreateGroupRe
}
func (s *Server) RenameGroup(ctx context.Context, req *keepassgov1.RenameGroupRequest) (*keepassgov1.RenameGroupResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetPath()); err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
@@ -282,13 +293,12 @@ func (s *Server) RenameGroup(ctx context.Context, req *keepassgov1.RenameGroupRe
}
func (s *Server) DeleteGroup(ctx context.Context, req *keepassgov1.DeleteGroupRequest) (*keepassgov1.DeleteGroupResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetPath()); err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
@@ -314,16 +324,15 @@ func (s *Server) UpsertEntry(ctx context.Context, req *keepassgov1.UpsertEntryRe
}
entry := entryFromProto(req.GetEntry())
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
return nil, err
}
s.mu.Lock()
if s.locked {
s.mu.Unlock()
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
s.mu.Unlock()
return nil, err
}
s.model.UpsertEntry(entry)
s.dirty = true
s.mu.Unlock()
@@ -332,18 +341,19 @@ func (s *Server) UpsertEntry(ctx context.Context, req *keepassgov1.UpsertEntryRe
}
func (s *Server) DeleteEntry(ctx context.Context, req *keepassgov1.DeleteEntryRequest) (*keepassgov1.DeleteEntryResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
if entry, err := findEntryByID(s.model, req.GetId()); err == nil {
if entry, err := findEntryByID(model, req.GetId()); err == nil {
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
return nil, err
}
}
s.mu.Lock()
defer s.mu.Unlock()
if err := s.model.DeleteEntry(req.GetId()); err != nil {
if errors.Is(err, vault.ErrEntryNotFound) {
return nil, status.Error(codes.NotFound, err.Error())
@@ -356,15 +366,13 @@ func (s *Server) DeleteEntry(ctx context.Context, req *keepassgov1.DeleteEntryRe
}
func (s *Server) RestoreEntry(ctx context.Context, req *keepassgov1.RestoreEntryRequest) (*keepassgov1.RestoreEntryResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
var restored vault.Entry
for _, entry := range s.model.RecycleBin {
for _, entry := range model.RecycleBin {
if entry.ID == req.GetId() {
restored = entry
break
@@ -376,6 +384,9 @@ func (s *Server) RestoreEntry(ctx context.Context, req *keepassgov1.RestoreEntry
}
}
s.mu.Lock()
defer s.mu.Unlock()
if err := s.model.RestoreEntry(req.GetId()); err != nil {
if errors.Is(err, vault.ErrEntryNotFound) {
return nil, status.Error(codes.NotFound, err.Error())
@@ -388,14 +399,12 @@ func (s *Server) RestoreEntry(ctx context.Context, req *keepassgov1.RestoreEntry
}
func (s *Server) ListEntryHistory(ctx context.Context, req *keepassgov1.ListEntryHistoryRequest) (*keepassgov1.ListEntryHistoryResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.locked {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := findEntryByID(s.model, req.GetId())
entry, err := findEntryByID(model, req.GetId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
@@ -413,13 +422,11 @@ func (s *Server) ListEntryHistory(ctx context.Context, req *keepassgov1.ListEntr
}
func (s *Server) RestoreEntryHistory(ctx context.Context, req *keepassgov1.RestoreEntryHistoryRequest) (*keepassgov1.RestoreEntryHistoryResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := findEntryByID(s.model, req.GetId())
entry, err := findEntryByID(model, req.GetId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
@@ -427,6 +434,8 @@ func (s *Server) RestoreEntryHistory(ctx context.Context, req *keepassgov1.Resto
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
if err := s.model.RestoreEntryVersion(req.GetId(), int(req.GetHistoryIndex())); err != nil {
if errors.Is(err, vault.ErrEntryNotFound) {
return nil, status.Error(codes.NotFound, err.Error())
@@ -523,14 +532,12 @@ func (s *Server) InstantiateTemplate(_ context.Context, req *keepassgov1.Instant
}
func (s *Server) ListAttachments(ctx context.Context, req *keepassgov1.ListAttachmentsRequest) (*keepassgov1.ListAttachmentsResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.locked {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := findEntryByID(s.model, req.GetEntryId())
entry, err := findEntryByID(model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
@@ -548,14 +555,11 @@ func (s *Server) ListAttachments(ctx context.Context, req *keepassgov1.ListAttac
}
func (s *Server) UploadAttachment(ctx context.Context, req *keepassgov1.UploadAttachmentRequest) (*keepassgov1.UploadAttachmentResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId())
entry, err := findEntryByID(model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
@@ -563,6 +567,13 @@ func (s *Server) UploadAttachment(ctx context.Context, req *keepassgov1.UploadAt
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
if entry.Attachments == nil {
entry.Attachments = map[string][]byte{}
}
@@ -574,14 +585,12 @@ func (s *Server) UploadAttachment(ctx context.Context, req *keepassgov1.UploadAt
}
func (s *Server) DownloadAttachment(ctx context.Context, req *keepassgov1.DownloadAttachmentRequest) (*keepassgov1.DownloadAttachmentResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.locked {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := findEntryByID(s.model, req.GetEntryId())
entry, err := findEntryByID(model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
@@ -600,14 +609,11 @@ func (s *Server) DownloadAttachment(ctx context.Context, req *keepassgov1.Downlo
}
func (s *Server) DeleteAttachment(ctx context.Context, req *keepassgov1.DeleteAttachmentRequest) (*keepassgov1.DeleteAttachmentResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId())
entry, err := findEntryByID(model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
@@ -615,6 +621,13 @@ func (s *Server) DeleteAttachment(ctx context.Context, req *keepassgov1.DeleteAt
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
if _, ok := entry.Attachments[req.GetName()]; !ok {
return nil, status.Error(codes.NotFound, "attachment not found")
}
@@ -630,11 +643,7 @@ func (s *Server) DeleteAttachment(ctx context.Context, req *keepassgov1.DeleteAt
}
func (s *Server) CopyEntryField(ctx context.Context, req *keepassgov1.CopyEntryFieldRequest) (*keepassgov1.CopyEntryFieldResponse, error) {
s.mu.RLock()
model := s.model
locked := s.locked
s.mu.RUnlock()
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
@@ -733,10 +742,10 @@ func findMutableEntryByID(model *vault.Model, id string) (vault.Entry, int, erro
return vault.Entry{}, -1, vault.ErrEntryNotFound
}
func (s *Server) visibleModel() vault.Model {
out := s.model
func visibleModel(model vault.Model) vault.Model {
out := model
out.Entries = nil
for _, entry := range s.model.Entries {
for _, entry := range model.Entries {
token, ok, err := apitokens.TokenFromEntry(entry)
if err == nil && ok && token.ID != "" {
continue
@@ -744,7 +753,7 @@ func (s *Server) visibleModel() vault.Model {
out.Entries = append(out.Entries, entry)
}
out.Groups = nil
for _, path := range s.model.Groups {
for _, path := range model.Groups {
if len(path) >= 2 && path[0] == "Root" && path[1] == "API Tokens" {
continue
}
@@ -753,6 +762,12 @@ func (s *Server) visibleModel() vault.Model {
return out
}
func (s *Server) snapshotModel() (vault.Model, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
return s.model, s.locked
}
var timeNow = func() time.Time { return time.Now().UTC() }
func (s *Server) authenticateRequest(ctx context.Context) (apitokens.Token, error) {
@@ -768,7 +783,9 @@ func (s *Server) authenticateRequest(ctx context.Context) (apitokens.Token, erro
if !strings.HasPrefix(values[0], prefix) {
return apitokens.Token{}, status.Error(codes.Unauthenticated, "invalid bearer token")
}
s.mu.RLock()
tokens, err := apitokens.Entries(s.model)
s.mu.RUnlock()
if err != nil {
return apitokens.Token{}, status.Errorf(codes.Internal, "load api tokens: %v", err)
}
@@ -789,10 +806,7 @@ func (s *Server) authorizePathRequest(ctx context.Context, op apitokens.Operatio
if err != nil {
return apitokens.Token{}, err
}
if apitokens.Evaluate(token, op, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: path}) != apitokens.DecisionAllow {
return apitokens.Token{}, status.Error(codes.PermissionDenied, "access is not allowed for this token")
}
return token, nil
return s.authorizeResourceRequest(ctx, token, op, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: path})
}
func (s *Server) authorizeEntryRequest(ctx context.Context, op apitokens.Operation, entry vault.Entry) (apitokens.Token, error) {
@@ -800,10 +814,78 @@ func (s *Server) authorizeEntryRequest(ctx context.Context, op apitokens.Operati
if err != nil {
return apitokens.Token{}, err
}
if apitokens.Evaluate(token, op, apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: entry.ID, Path: entry.Path}) != apitokens.DecisionAllow {
return s.authorizeResourceRequest(ctx, token, op, apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: entry.ID, Path: entry.Path})
}
func (s *Server) authorizeResourceRequest(ctx context.Context, token apitokens.Token, op apitokens.Operation, resource apitokens.Resource) (apitokens.Token, error) {
switch apitokens.Evaluate(token, op, resource) {
case apitokens.DecisionAllow:
return token, nil
case apitokens.DecisionDeny:
return apitokens.Token{}, status.Error(codes.PermissionDenied, "access is not allowed for this token")
case apitokens.DecisionPrompt:
result, err := s.approvals.Request(ctx, token, op, resource)
if result.Rule != nil {
if persistErr := s.persistApprovalRule(token.ID, *result.Rule); persistErr != nil {
return apitokens.Token{}, status.Errorf(codes.Internal, "persist approval decision: %v", persistErr)
}
}
switch {
case err == nil:
return token, nil
case errors.Is(err, apiapproval.ErrRequestDenied):
return apitokens.Token{}, status.Error(codes.PermissionDenied, "access denied by user approval")
case errors.Is(err, apiapproval.ErrRequestCanceled):
return apitokens.Token{}, status.Error(codes.Unauthenticated, "authorization request canceled")
case errors.Is(err, apiapproval.ErrRequestTimedOut):
return apitokens.Token{}, status.Error(codes.DeadlineExceeded, "authorization request timed out")
case errors.Is(err, context.Canceled):
return apitokens.Token{}, status.Error(codes.Canceled, "authorization request canceled")
case errors.Is(err, context.DeadlineExceeded):
return apitokens.Token{}, status.Error(codes.DeadlineExceeded, "authorization request timed out")
default:
return apitokens.Token{}, status.Errorf(codes.Internal, "await authorization request: %v", err)
}
default:
return apitokens.Token{}, status.Error(codes.PermissionDenied, "access is not allowed for this token")
}
return token, nil
}
func (s *Server) persistApprovalRule(tokenID string, rule apitokens.PolicyRule) error {
s.mu.Lock()
defer s.mu.Unlock()
for i, entry := range s.model.Entries {
token, ok, err := apitokens.TokenFromEntry(entry)
if err != nil || !ok || token.ID != tokenID {
continue
}
if !hasPolicyRule(token.Policies, rule) {
token.Policies = append(token.Policies, rule)
}
s.model.Entries[i] = token.Entry(entry.Path)
s.dirty = true
if lifecycle, ok := s.lifecycle.(modelReplaceableLifecycle); ok {
lifecycle.Replace(s.model)
}
return nil
}
return status.Error(codes.NotFound, "api token entry not found")
}
func hasPolicyRule(rules []apitokens.PolicyRule, target apitokens.PolicyRule) bool {
for _, rule := range rules {
if rule.Effect != target.Effect || rule.Operation != target.Operation {
continue
}
if rule.Resource.Kind != target.Resource.Kind || rule.Resource.EntryID != target.Resource.EntryID {
continue
}
if slices.Equal(rule.Resource.Path, target.Resource.Path) {
return true
}
}
return false
}
func copyOperation(target string) apitokens.Operation {
+173 -2
View File
@@ -10,6 +10,7 @@ import (
"testing"
"time"
"git.julianfamily.org/keepassgo/apiapproval"
"git.julianfamily.org/keepassgo/apitokens"
"git.julianfamily.org/keepassgo/passwords"
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
@@ -85,6 +86,7 @@ func TestVaultServiceRejectsUnauthorizedEntryAccess(t *testing.T) {
},
testAPITokenEntry(t,
apitokens.PolicyRule{Effect: apitokens.EffectAllow, Operation: apitokens.OperationListEntries, Resource: apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root", "Internet"}}},
apitokens.PolicyRule{Effect: apitokens.EffectDeny, Operation: apitokens.OperationCopyPassword, Resource: apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: "vault-console", Path: []string{"Root", "Internet"}}},
),
},
})
@@ -96,6 +98,146 @@ func TestVaultServiceRejectsUnauthorizedEntryAccess(t *testing.T) {
}
}
func TestVaultServicePromptsAndResumesWhenApproved(t *testing.T) {
t.Parallel()
model := vault.Model{
Entries: []vault.Entry{
{
ID: "vault-console",
Title: "Vault Console",
Username: "dannyocean",
Password: "token-1",
URL: "https://vault.crew.example.invalid",
Path: []string{"Root", "Internet"},
},
testAPITokenEntry(t),
},
}
client, _, service, cleanup := newTestHarnessForModel(t, model)
defer cleanup()
service.approvals = apiapproval.NewBroker(time.Minute)
respCh := make(chan *keepassgov1.ListEntriesResponse, 1)
errCh := make(chan error, 1)
go func() {
resp, err := client.ListEntries(tokenContext(defaultTestTokenSecret), &keepassgov1.ListEntriesRequest{Path: []string{"Root", "Internet"}})
respCh <- resp
errCh <- err
}()
pending := waitForServerPendingApproval(t, service, 1)[0]
if pending.Operation != apitokens.OperationListEntries {
t.Fatalf("pending.Operation = %q, want %q", pending.Operation, apitokens.OperationListEntries)
}
if _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeAllowOnce); err != nil {
t.Fatalf("ResolveApproval(allow) error = %v", err)
}
resp := <-respCh
if err := <-errCh; err != nil {
t.Fatalf("ListEntries() error = %v", err)
}
if len(resp.Entries) != 1 || resp.Entries[0].Id != "vault-console" {
t.Fatalf("ListEntries().Entries = %#v, want vault-console", resp.Entries)
}
}
func TestVaultServicePersistsPermanentDenyApproval(t *testing.T) {
t.Parallel()
model := vault.Model{
Entries: []vault.Entry{
{
ID: "vault-console",
Title: "Vault Console",
Username: "dannyocean",
Password: "token-1",
URL: "https://vault.crew.example.invalid",
Path: []string{"Root", "Internet"},
},
testAPITokenEntry(t),
},
}
client, _, service, cleanup := newTestHarnessForModel(t, model)
defer cleanup()
service.approvals = apiapproval.NewBroker(time.Minute)
errCh := make(chan error, 1)
go func() {
_, err := client.ListEntries(tokenContext(defaultTestTokenSecret), &keepassgov1.ListEntriesRequest{Path: []string{"Root", "Internet"}})
errCh <- err
}()
pending := waitForServerPendingApproval(t, service, 1)[0]
if _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeDenyPermanent); err != nil {
t.Fatalf("ResolveApproval(deny permanent) error = %v", err)
}
if err := <-errCh; status.Code(err) != codes.PermissionDenied {
t.Fatalf("ListEntries() code = %v, want %v", status.Code(err), codes.PermissionDenied)
}
service.mu.RLock()
tokens, err := apitokens.Entries(service.model)
service.mu.RUnlock()
if err != nil {
t.Fatalf("Entries() error = %v", err)
}
decision := apitokens.Evaluate(tokens[0], apitokens.OperationListEntries, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root", "Internet"}})
if decision != apitokens.DecisionDeny {
t.Fatalf("Evaluate() after permanent deny = %q, want %q", decision, apitokens.DecisionDeny)
}
}
func TestVaultServiceReturnsCanceledForCanceledApproval(t *testing.T) {
t.Parallel()
model := vault.Model{
Entries: []vault.Entry{
{ID: "vault-console", Title: "Vault Console", Path: []string{"Root", "Internet"}},
testAPITokenEntry(t),
},
}
client, _, service, cleanup := newTestHarnessForModel(t, model)
defer cleanup()
service.approvals = apiapproval.NewBroker(time.Minute)
errCh := make(chan error, 1)
go func() {
_, err := client.ListEntries(tokenContext(defaultTestTokenSecret), &keepassgov1.ListEntriesRequest{Path: []string{"Root", "Internet"}})
errCh <- err
}()
pending := waitForServerPendingApproval(t, service, 1)[0]
if _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeCancel); err != nil {
t.Fatalf("ResolveApproval(cancel) error = %v", err)
}
if err := <-errCh; status.Code(err) != codes.Unauthenticated {
t.Fatalf("ListEntries() code = %v, want %v", status.Code(err), codes.Unauthenticated)
}
}
func TestVaultServiceTimesOutPendingApproval(t *testing.T) {
t.Parallel()
model := vault.Model{
Entries: []vault.Entry{
{ID: "vault-console", Title: "Vault Console", Path: []string{"Root", "Internet"}},
testAPITokenEntry(t),
},
}
client, _, service, cleanup := newTestHarnessForModel(t, model)
defer cleanup()
service.approvals = apiapproval.NewBroker(20 * time.Millisecond)
_, err := client.ListEntries(tokenContext(defaultTestTokenSecret), &keepassgov1.ListEntriesRequest{Path: []string{"Root", "Internet"}})
if status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("ListEntries() code = %v, want %v", status.Code(err), codes.DeadlineExceeded)
}
}
func TestVaultServiceReportsSessionStatusAndSupportsLockUnlock(t *testing.T) {
t.Parallel()
@@ -936,6 +1078,11 @@ func newTestClient(t *testing.T) (keepassgov1.VaultServiceClient, *memoryClipboa
}
func newTestClientForModel(t *testing.T, model vault.Model) (keepassgov1.VaultServiceClient, *memoryClipboardWriter, func()) {
client, clipboardWriter, _, cleanup := newTestHarnessForModel(t, model)
return client, clipboardWriter, cleanup
}
func newTestHarnessForModel(t *testing.T, model vault.Model) (keepassgov1.VaultServiceClient, *memoryClipboardWriter, *Server, func()) {
t.Helper()
listener := bufconn.Listen(1024 * 1024)
@@ -963,10 +1110,15 @@ func newTestClientForModel(t *testing.T, model vault.Model) (keepassgov1.VaultSe
server.Stop()
}
return keepassgov1.NewVaultServiceClient(conn), clipboardWriter, cleanup
return keepassgov1.NewVaultServiceClient(conn), clipboardWriter, service, cleanup
}
func newTestClientWithLifecycle(t *testing.T, lifecycle *stubLifecycle) (keepassgov1.VaultServiceClient, *memoryClipboardWriter, func()) {
client, clipboardWriter, _, cleanup := newTestHarnessWithLifecycle(t, lifecycle)
return client, clipboardWriter, cleanup
}
func newTestHarnessWithLifecycle(t *testing.T, lifecycle *stubLifecycle) (keepassgov1.VaultServiceClient, *memoryClipboardWriter, *Server, func()) {
t.Helper()
listener := bufconn.Listen(1024 * 1024)
@@ -1001,7 +1153,7 @@ func newTestClientWithLifecycle(t *testing.T, lifecycle *stubLifecycle) (keepass
server.Stop()
}
return keepassgov1.NewVaultServiceClient(conn), clipboardWriter, cleanup
return keepassgov1.NewVaultServiceClient(conn), clipboardWriter, service, cleanup
}
type memoryClipboardWriter struct {
@@ -1013,6 +1165,21 @@ func (w *memoryClipboardWriter) WriteText(text string) error {
return nil
}
func waitForServerPendingApproval(t *testing.T, server *Server, want int) []apiapproval.Request {
t.Helper()
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
pending := server.ApprovalBroker().Pending()
if len(pending) == want {
return pending
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("server pending approvals never reached len %d", want)
return nil
}
type stubLifecycle struct {
model vault.Model
openPath string
@@ -1087,3 +1254,7 @@ func (s *stubLifecycle) Unlock(key vault.MasterKey) error {
s.locked = false
return nil
}
func (s *stubLifecycle) Replace(model vault.Model) {
s.model = model
}
+215
View File
@@ -0,0 +1,215 @@
package apiapproval
import (
"context"
"errors"
"fmt"
"slices"
"strconv"
"sync"
"time"
"git.julianfamily.org/keepassgo/apitokens"
)
var (
ErrRequestDenied = errors.New("authorization request denied")
ErrRequestCanceled = errors.New("authorization request canceled")
ErrRequestTimedOut = errors.New("authorization request timed out")
ErrRequestNotFound = errors.New("authorization request not found")
)
type Outcome string
const (
OutcomeAllowOnce Outcome = "allow-once"
OutcomeDenyOnce Outcome = "deny-once"
OutcomeAllowPermanent Outcome = "allow-permanent"
OutcomeDenyPermanent Outcome = "deny-permanent"
OutcomeCancel Outcome = "cancel"
)
type Request struct {
ID string
TokenID string
TokenName string
ClientName string
Operation apitokens.Operation
Resource apitokens.Resource
RequestedAt time.Time
}
type Result struct {
Outcome Outcome
Rule *apitokens.PolicyRule
}
type Broker struct {
mu sync.Mutex
pending map[string]*pendingRequest
timeout time.Duration
now func() time.Time
nextID func() string
}
type pendingRequest struct {
request Request
done chan Outcome
}
type idGenerator struct {
mu sync.Mutex
counter int
}
func NewBroker(timeout time.Duration) *Broker {
gen := &idGenerator{}
return &Broker{
pending: map[string]*pendingRequest{},
timeout: timeout,
now: func() time.Time {
return time.Now().UTC()
},
nextID: func() string {
return gen.Next()
},
}
}
func (g *idGenerator) Next() string {
g.mu.Lock()
defer g.mu.Unlock()
g.counter++
return "approval-" + strconv.Itoa(g.counter)
}
func (b *Broker) Pending() []Request {
b.mu.Lock()
defer b.mu.Unlock()
requests := make([]Request, 0, len(b.pending))
for _, pending := range b.pending {
requests = append(requests, pending.request)
}
slices.SortFunc(requests, func(a, c Request) int {
switch {
case a.RequestedAt.Before(c.RequestedAt):
return -1
case a.RequestedAt.After(c.RequestedAt):
return 1
case a.ID < c.ID:
return -1
case a.ID > c.ID:
return 1
default:
return 0
}
})
return requests
}
func (b *Broker) Request(ctx context.Context, token apitokens.Token, op apitokens.Operation, resource apitokens.Resource) (Result, error) {
if b == nil {
return Result{}, ErrRequestTimedOut
}
pending := &pendingRequest{
request: Request{
ID: b.nextID(),
TokenID: token.ID,
TokenName: token.Name,
ClientName: token.ClientName,
Operation: op,
Resource: resource,
RequestedAt: b.now(),
},
done: make(chan Outcome, 1),
}
b.mu.Lock()
b.pending[pending.request.ID] = pending
b.mu.Unlock()
defer func() {
b.mu.Lock()
delete(b.pending, pending.request.ID)
b.mu.Unlock()
}()
timer := time.NewTimer(b.timeout)
defer timer.Stop()
select {
case outcome := <-pending.done:
result, err := resultForOutcome(pending.request, outcome)
if err != nil {
return result, err
}
return result, nil
case <-timer.C:
return Result{}, ErrRequestTimedOut
case <-ctx.Done():
return Result{}, ctx.Err()
}
}
func (b *Broker) Resolve(id string, outcome Outcome) (Request, *apitokens.PolicyRule, error) {
b.mu.Lock()
pending, ok := b.pending[id]
b.mu.Unlock()
if !ok {
return Request{}, nil, ErrRequestNotFound
}
rule, err := RuleFromDecision(pending.request, outcome)
if err != nil {
return Request{}, nil, err
}
select {
case pending.done <- outcome:
default:
}
return pending.request, rule, nil
}
func resultForOutcome(request Request, outcome Outcome) (Result, error) {
rule, err := RuleFromDecision(request, outcome)
if err != nil {
return Result{}, err
}
result := Result{Outcome: outcome, Rule: rule}
switch outcome {
case OutcomeAllowOnce, OutcomeAllowPermanent:
return result, nil
case OutcomeDenyOnce, OutcomeDenyPermanent:
return result, ErrRequestDenied
case OutcomeCancel:
return result, ErrRequestCanceled
default:
return Result{}, fmt.Errorf("unsupported approval outcome %q", outcome)
}
}
func RuleFromDecision(request Request, outcome Outcome) (*apitokens.PolicyRule, error) {
switch outcome {
case OutcomeAllowPermanent:
rule := apitokens.PolicyRule{
Effect: apitokens.EffectAllow,
Operation: request.Operation,
Resource: request.Resource,
}
return &rule, nil
case OutcomeDenyPermanent:
rule := apitokens.PolicyRule{
Effect: apitokens.EffectDeny,
Operation: request.Operation,
Resource: request.Resource,
}
return &rule, nil
case OutcomeAllowOnce, OutcomeDenyOnce, OutcomeCancel:
return nil, nil
default:
return nil, fmt.Errorf("unsupported approval outcome %q", outcome)
}
}
+134
View File
@@ -0,0 +1,134 @@
package apiapproval
import (
"context"
"errors"
"testing"
"time"
"git.julianfamily.org/keepassgo/apitokens"
)
func TestBrokerCreatesPendingRequestAndAllowsOnce(t *testing.T) {
t.Parallel()
broker := NewBroker(time.Minute)
broker.now = func() time.Time { return time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) }
resultCh := make(chan Result, 1)
errCh := make(chan error, 1)
go func() {
result, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI", ClientName: "grpc-cli"}, apitokens.OperationListEntries, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root", "Internet"}})
resultCh <- result
errCh <- err
}()
waitForPending(t, broker, 1)
pending := broker.Pending()
if len(pending) != 1 {
t.Fatalf("Pending() len = %d, want 1", len(pending))
}
if pending[0].TokenID != "token-1" {
t.Fatalf("Pending()[0].TokenID = %q, want token-1", pending[0].TokenID)
}
if _, _, err := broker.Resolve(pending[0].ID, OutcomeAllowOnce); err != nil {
t.Fatalf("Resolve(allow once) error = %v", err)
}
result := <-resultCh
if err := <-errCh; err != nil {
t.Fatalf("Request() error = %v, want nil", err)
}
if result.Outcome != OutcomeAllowOnce {
t.Fatalf("Request() outcome = %q, want %q", result.Outcome, OutcomeAllowOnce)
}
if result.Rule != nil {
t.Fatalf("Request() rule = %#v, want nil for allow-once", result.Rule)
}
if got := broker.Pending(); len(got) != 0 {
t.Fatalf("Pending() after allow len = %d, want 0", len(got))
}
}
func TestBrokerReturnsPermanentRuleForDeny(t *testing.T) {
t.Parallel()
broker := NewBroker(time.Minute)
reqDone := make(chan struct{})
var result Result
var err error
go func() {
result, err = broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationReadEntry, apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: "entry-1", Path: []string{"Root", "Internet"}})
close(reqDone)
}()
waitForPending(t, broker, 1)
pending := broker.Pending()[0]
request, rule, resolveErr := broker.Resolve(pending.ID, OutcomeDenyPermanent)
if resolveErr != nil {
t.Fatalf("Resolve(deny permanent) error = %v", resolveErr)
}
if request.ID != pending.ID {
t.Fatalf("Resolve().ID = %q, want %q", request.ID, pending.ID)
}
if rule == nil || rule.Effect != apitokens.EffectDeny || rule.Operation != apitokens.OperationReadEntry {
t.Fatalf("Resolve() rule = %#v, want deny read-entry rule", rule)
}
<-reqDone
if !errors.Is(err, ErrRequestDenied) {
t.Fatalf("Request() error = %v, want ErrRequestDenied", err)
}
if result.Rule == nil || result.Rule.Effect != apitokens.EffectDeny {
t.Fatalf("Request() rule = %#v, want deny rule", result.Rule)
}
if result.Outcome != OutcomeDenyPermanent {
t.Fatalf("Request() outcome = %q, want %q", result.Outcome, OutcomeDenyPermanent)
}
}
func TestBrokerSupportsCancellation(t *testing.T) {
t.Parallel()
broker := NewBroker(time.Minute)
errCh := make(chan error, 1)
go func() {
_, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationListGroups, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root"}})
errCh <- err
}()
waitForPending(t, broker, 1)
if _, _, err := broker.Resolve(broker.Pending()[0].ID, OutcomeCancel); err != nil {
t.Fatalf("Resolve(cancel) error = %v", err)
}
if err := <-errCh; !errors.Is(err, ErrRequestCanceled) {
t.Fatalf("Request() error = %v, want ErrRequestCanceled", err)
}
}
func TestBrokerTimesOutPendingRequests(t *testing.T) {
t.Parallel()
broker := NewBroker(10 * time.Millisecond)
_, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationListGroups, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root"}})
if !errors.Is(err, ErrRequestTimedOut) {
t.Fatalf("Request() error = %v, want ErrRequestTimedOut", err)
}
if got := broker.Pending(); len(got) != 0 {
t.Fatalf("Pending() len after timeout = %d, want 0", len(got))
}
}
func waitForPending(t *testing.T, broker *Broker, want int) {
t.Helper()
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
if got := len(broker.Pending()); got == want {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("Pending() never reached len %d", want)
}