diff --git a/api/server.go b/api/server.go index 157cec5..6aa6b34 100644 --- a/api/server.go +++ b/api/server.go @@ -4,15 +4,17 @@ import ( "context" "errors" "maps" + "os" "slices" - "sync" "strings" + "sync" "git.julianfamily.org/keepassgo/clipboard" "git.julianfamily.org/keepassgo/passwords" keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1" - "git.julianfamily.org/keepassgo/webdav" + "git.julianfamily.org/keepassgo/session" "git.julianfamily.org/keepassgo/vault" + "git.julianfamily.org/keepassgo/webdav" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -36,6 +38,8 @@ type lifecycleBackend interface { Open(string, vault.MasterKey) error OpenRemote(webdav.Client, string, vault.MasterKey) error Save() error + Lock() error + Unlock(vault.MasterKey) error } func NewServer(model vault.Model, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer) *Server { @@ -57,8 +61,8 @@ func (s *Server) GetSessionStatus(_ context.Context, _ *keepassgov1.GetSessionSt defer s.mu.RUnlock() return &keepassgov1.GetSessionStatusResponse{ - Locked: s.locked, - Dirty: s.dirty, + Locked: s.locked, + Dirty: s.dirty, EntryCount: uint32(len(s.model.Entries)), }, nil } @@ -70,12 +74,12 @@ func (s *Server) OpenVault(_ context.Context, req *keepassgov1.OpenVaultRequest) key := vault.MasterKey{Password: req.GetPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)} if err := s.lifecycle.Open(req.GetPath(), key); err != nil { - return nil, status.Errorf(codes.Internal, "open vault: %v", err) + return nil, mapLifecycleError("open vault", err) } model, err := s.lifecycle.Current() if err != nil { - return nil, status.Errorf(codes.Internal, "load opened vault: %v", err) + return nil, mapLifecycleError("load opened vault", err) } s.mu.Lock() @@ -99,12 +103,12 @@ func (s *Server) OpenRemoteVault(_ context.Context, req *keepassgov1.OpenRemoteV } key := vault.MasterKey{Password: req.GetMasterPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)} if err := s.lifecycle.OpenRemote(client, req.GetPath(), key); err != nil { - return nil, status.Errorf(codes.Internal, "open remote vault: %v", err) + return nil, mapLifecycleError("open remote vault", err) } model, err := s.lifecycle.Current() if err != nil { - return nil, status.Errorf(codes.Internal, "load opened remote vault: %v", err) + return nil, mapLifecycleError("load opened remote vault", err) } s.mu.Lock() @@ -122,7 +126,7 @@ func (s *Server) SaveVault(_ context.Context, _ *keepassgov1.SaveVaultRequest) ( } if err := s.lifecycle.Save(); err != nil { - return nil, status.Errorf(codes.Internal, "save vault: %v", err) + return nil, mapLifecycleError("save vault", err) } s.mu.Lock() @@ -133,23 +137,59 @@ func (s *Server) SaveVault(_ context.Context, _ *keepassgov1.SaveVaultRequest) ( } func (s *Server) LockVault(_ context.Context, _ *keepassgov1.LockVaultRequest) (*keepassgov1.LockVaultResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() + if s.lifecycle == nil { + return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured") + } + if err := s.lifecycle.Lock(); err != nil { + return nil, mapLifecycleError("lock vault", err) + } + + s.mu.Lock() s.locked = true + s.mu.Unlock() return &keepassgov1.LockVaultResponse{}, nil } -func (s *Server) UnlockVault(_ context.Context, _ *keepassgov1.UnlockVaultRequest) (*keepassgov1.UnlockVaultResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *Server) UnlockVault(_ context.Context, req *keepassgov1.UnlockVaultRequest) (*keepassgov1.UnlockVaultResponse, error) { + if s.lifecycle == nil { + return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured") + } + key := vault.MasterKey{Password: req.GetPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)} + if err := s.lifecycle.Unlock(key); err != nil { + return nil, mapLifecycleError("unlock vault", err) + } + + model, err := s.lifecycle.Current() + if err != nil { + return nil, mapLifecycleError("load unlocked vault", err) + } + + s.mu.Lock() + s.model = model s.locked = false + s.mu.Unlock() return &keepassgov1.UnlockVaultResponse{}, nil } +func mapLifecycleError(operation string, err error) error { + switch { + case errors.Is(err, os.ErrNotExist): + return status.Errorf(codes.NotFound, "%s: %v", operation, err) + case errors.Is(err, vault.ErrInvalidMasterKey): + return status.Errorf(codes.InvalidArgument, "%s: %v", operation, err) + case errors.Is(err, session.ErrLocked), errors.Is(err, session.ErrNoPath): + return status.Errorf(codes.FailedPrecondition, "%s: %v", operation, err) + case errors.Is(err, webdav.ErrConflict): + return status.Errorf(codes.Aborted, "%s: %v", operation, err) + default: + return status.Errorf(codes.Internal, "%s: %v", operation, err) + } +} + func (s *Server) ListEntries(_ context.Context, req *keepassgov1.ListEntriesRequest) (*keepassgov1.ListEntriesResponse, error) { s.mu.RLock() defer s.mu.RUnlock() diff --git a/api/server_test.go b/api/server_test.go index d12494c..02500fb 100644 --- a/api/server_test.go +++ b/api/server_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "net" + "os" "testing" "git.julianfamily.org/keepassgo/passwords" @@ -34,7 +35,21 @@ func TestVaultServiceRejectsRequestsWithoutBearerToken(t *testing.T) { func TestVaultServiceReportsSessionStatusAndSupportsLockUnlock(t *testing.T) { t.Parallel() - client, _, cleanup := newTestClient(t) + lifecycle := &stubLifecycle{ + model: vault.Model{ + Entries: []vault.Entry{ + { + ID: "vault-console", + Title: "Vault Console", + Username: "dannyocean", + Password: "token-1", + URL: "https://vault.crew.example.invalid", + Path: []string{"Root", "Internet"}, + }, + }, + }, + } + client, _, cleanup := newTestClientWithLifecycle(t, lifecycle) defer cleanup() ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token") @@ -78,6 +93,82 @@ func TestVaultServiceReportsSessionStatusAndSupportsLockUnlock(t *testing.T) { } } +func TestVaultServiceLockAndUnlockUseLifecycleBackend(t *testing.T) { + t.Parallel() + + lifecycle := &stubLifecycle{ + model: vault.Model{ + Entries: []vault.Entry{ + {ID: "entry-1", Title: "Remote Git", Path: []string{"Root", "Internet"}}, + }, + }, + unlockPassword: "correct horse battery staple", + unlockKeyFile: []byte("key-material"), + } + client, _, cleanup := newTestClientWithLifecycle(t, lifecycle) + defer cleanup() + + ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token") + if _, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{ + Path: "/tmp/test.kdbx", + Password: lifecycle.unlockPassword, + KeyFileData: lifecycle.unlockKeyFile, + }); err != nil { + t.Fatalf("OpenVault() error = %v", err) + } + + if _, err := client.LockVault(ctx, &keepassgov1.LockVaultRequest{}); err != nil { + t.Fatalf("LockVault() error = %v", err) + } + if !lifecycle.locked { + t.Fatal("LockVault() did not lock lifecycle backend") + } + + statusResp, err := client.GetSessionStatus(ctx, &keepassgov1.GetSessionStatusRequest{}) + if err != nil { + t.Fatalf("GetSessionStatus() after lock error = %v", err) + } + if !statusResp.Locked { + t.Fatal("GetSessionStatus().Locked = false, want true after lock") + } + + if _, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{ + Password: "wrong password", + KeyFileData: lifecycle.unlockKeyFile, + }); status.Code(err) != codes.InvalidArgument { + t.Fatalf("UnlockVault(wrong password) code = %v, want %v", status.Code(err), codes.InvalidArgument) + } + + statusResp, err = client.GetSessionStatus(ctx, &keepassgov1.GetSessionStatusRequest{}) + if err != nil { + t.Fatalf("GetSessionStatus() after failed unlock error = %v", err) + } + if !statusResp.Locked { + t.Fatal("GetSessionStatus().Locked = false, want true after failed unlock") + } + + if _, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{ + Password: lifecycle.unlockPassword, + KeyFileData: lifecycle.unlockKeyFile, + }); err != nil { + t.Fatalf("UnlockVault() error = %v", err) + } + if lifecycle.lastUnlockKey.Password != lifecycle.unlockPassword { + t.Fatalf("UnlockVault() password = %q, want %q", lifecycle.lastUnlockKey.Password, lifecycle.unlockPassword) + } + if !bytes.Equal(lifecycle.lastUnlockKey.KeyFileData, lifecycle.unlockKeyFile) { + t.Fatalf("UnlockVault() key data = %q, want %q", lifecycle.lastUnlockKey.KeyFileData, lifecycle.unlockKeyFile) + } + + listed, err := client.ListEntries(ctx, &keepassgov1.ListEntriesRequest{Path: []string{"Root", "Internet"}}) + if err != nil { + t.Fatalf("ListEntries() after unlock error = %v", err) + } + if len(listed.Entries) != 1 || listed.Entries[0].Title != "Remote Git" { + t.Fatalf("ListEntries().Entries = %#v, want Remote Git after unlock", listed.Entries) + } +} + func TestVaultServiceOpensAndSavesVaultThroughLifecycleBackend(t *testing.T) { t.Parallel() @@ -131,6 +222,143 @@ func TestVaultServiceOpensAndSavesVaultThroughLifecycleBackend(t *testing.T) { } } +func TestVaultServiceLifecycleMethodsRequireLifecycleBackend(t *testing.T) { + t.Parallel() + + client, _, cleanup := newTestClient(t) + defer cleanup() + + ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token") + + testCases := []struct { + name string + call func() error + }{ + { + name: "open", + call: func() error { + _, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{Path: "/tmp/test.kdbx"}) + return err + }, + }, + { + name: "open_remote", + call: func() error { + _, err := client.OpenRemoteVault(ctx, &keepassgov1.OpenRemoteVaultRequest{ + BaseUrl: "https://dav.example.com", + Path: "vaults/main.kdbx", + }) + return err + }, + }, + { + name: "save", + call: func() error { + _, err := client.SaveVault(ctx, &keepassgov1.SaveVaultRequest{}) + return err + }, + }, + { + name: "lock", + call: func() error { + _, err := client.LockVault(ctx, &keepassgov1.LockVaultRequest{}) + return err + }, + }, + { + name: "unlock", + call: func() error { + _, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{}) + return err + }, + }, + } + + for _, tt := range testCases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + err := tt.call() + if status.Code(err) != codes.FailedPrecondition { + t.Fatalf("%s code = %v, want %v", tt.name, status.Code(err), codes.FailedPrecondition) + } + }) + } +} + +func TestVaultServiceLifecycleMethodsMapBackendErrors(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + call func(keepassgov1.VaultServiceClient, context.Context) error + err error + want codes.Code + }{ + { + name: "open not found", + call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error { + _, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{Path: "/tmp/missing.kdbx"}) + return err + }, + err: os.ErrNotExist, + want: codes.NotFound, + }, + { + name: "open invalid master key", + call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error { + _, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{Path: "/tmp/test.kdbx"}) + return err + }, + err: vault.ErrInvalidMasterKey, + want: codes.InvalidArgument, + }, + { + name: "save no path", + call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error { + _, err := client.SaveVault(ctx, &keepassgov1.SaveVaultRequest{}) + return err + }, + err: session.ErrNoPath, + want: codes.FailedPrecondition, + }, + { + name: "lock already locked", + call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error { + _, err := client.LockVault(ctx, &keepassgov1.LockVaultRequest{}) + return err + }, + err: session.ErrLocked, + want: codes.FailedPrecondition, + }, + { + name: "unlock invalid master key", + call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error { + _, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{Password: "wrong"}) + return err + }, + err: vault.ErrInvalidMasterKey, + want: codes.InvalidArgument, + }, + } + + for _, tt := range testCases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + lifecycle := &stubLifecycle{err: tt.err} + client, _, cleanup := newTestClientWithLifecycle(t, lifecycle) + defer cleanup() + + ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token") + err := tt.call(client, ctx) + if status.Code(err) != tt.want { + t.Fatalf("%s code = %v, want %v", tt.name, status.Code(err), tt.want) + } + }) + } +} + func TestVaultServiceListsEntriesForAuthorizedClients(t *testing.T) { t.Parallel() @@ -503,7 +731,7 @@ func newTestClient(t *testing.T) (keepassgov1.VaultServiceClient, *memoryClipboa Path: []string{"Root", "Internet"}, }, }, - Path: []string{"Root", "Internet"}, + Path: []string{"Root", "Internet"}, }, { ID: "surveillance-console", @@ -560,7 +788,7 @@ func newTestClientWithLifecycle(t *testing.T, lifecycle *stubLifecycle) (keepass server := grpc.NewServer(grpc.UnaryInterceptor(BearerTokenInterceptor("test-token"))) clipboardWriter := &memoryClipboardWriter{} keepassgov1.RegisterVaultServiceServer(server, NewServerWithLifecycle( - vault.Model{}, + lifecycle.model, passwords.DefaultProfiles(), clipboardWriter, lifecycle, @@ -598,12 +826,16 @@ func (w *memoryClipboardWriter) WriteText(text string) error { } type stubLifecycle struct { - model vault.Model - openPath string - remoteBaseURL string - remotePath string - saved bool - locked bool + model vault.Model + openPath string + remoteBaseURL string + remotePath string + saved bool + locked bool + err error + unlockPassword string + unlockKeyFile []byte + lastUnlockKey vault.MasterKey } func (s *stubLifecycle) Current() (vault.Model, error) { @@ -614,17 +846,56 @@ func (s *stubLifecycle) Current() (vault.Model, error) { } func (s *stubLifecycle) Open(path string, _ vault.MasterKey) error { + if s.err != nil { + return s.err + } s.openPath = path + s.locked = false return nil } func (s *stubLifecycle) OpenRemote(client webdav.Client, path string, _ vault.MasterKey) error { + if s.err != nil { + return s.err + } s.remoteBaseURL = client.BaseURL s.remotePath = path + s.locked = false return nil } func (s *stubLifecycle) Save() error { + if s.err != nil { + return s.err + } s.saved = true return nil } + +func (s *stubLifecycle) Lock() error { + if s.err != nil { + return s.err + } + + s.locked = true + return nil +} + +func (s *stubLifecycle) Unlock(key vault.MasterKey) error { + if s.err != nil { + return s.err + } + if s.unlockPassword != "" && key.Password != s.unlockPassword { + return vault.ErrInvalidMasterKey + } + if s.unlockKeyFile != nil && !bytes.Equal(key.KeyFileData, s.unlockKeyFile) { + return vault.ErrInvalidMasterKey + } + + s.lastUnlockKey = vault.MasterKey{ + Password: key.Password, + KeyFileData: append([]byte(nil), key.KeyFileData...), + } + s.locked = false + return nil +} diff --git a/proto/keepassgo/v1/keepassgo.pb.go b/proto/keepassgo/v1/keepassgo.pb.go index 7136e87..8708e19 100644 --- a/proto/keepassgo/v1/keepassgo.pb.go +++ b/proto/keepassgo/v1/keepassgo.pb.go @@ -479,6 +479,8 @@ func (*LockVaultResponse) Descriptor() ([]byte, []int) { type UnlockVaultRequest struct { state protoimpl.MessageState `protogen:"open.v1"` + Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"` + KeyFileData []byte `protobuf:"bytes,2,opt,name=key_file_data,json=keyFileData,proto3" json:"key_file_data,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -513,6 +515,20 @@ func (*UnlockVaultRequest) Descriptor() ([]byte, []int) { return file_proto_keepassgo_v1_keepassgo_proto_rawDescGZIP(), []int{10} } +func (x *UnlockVaultRequest) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +func (x *UnlockVaultRequest) GetKeyFileData() []byte { + if x != nil { + return x.KeyFileData + } + return nil +} + type UnlockVaultResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -2364,8 +2380,10 @@ const file_proto_keepassgo_v1_keepassgo_proto_rawDesc = "" + "\x10SaveVaultRequest\"\x13\n" + "\x11SaveVaultResponse\"\x12\n" + "\x10LockVaultRequest\"\x13\n" + - "\x11LockVaultResponse\"\x14\n" + - "\x12UnlockVaultRequest\"\x15\n" + + "\x11LockVaultResponse\"T\n" + + "\x12UnlockVaultRequest\x12\x1a\n" + + "\bpassword\x18\x01 \x01(\tR\bpassword\x12\"\n" + + "\rkey_file_data\x18\x02 \x01(\fR\vkeyFileData\"\x15\n" + "\x13UnlockVaultResponse\">\n" + "\x12ListEntriesRequest\x12\x12\n" + "\x04path\x18\x01 \x03(\tR\x04path\x12\x14\n" + diff --git a/proto/keepassgo/v1/keepassgo.proto b/proto/keepassgo/v1/keepassgo.proto index 4005184..f8d1583 100644 --- a/proto/keepassgo/v1/keepassgo.proto +++ b/proto/keepassgo/v1/keepassgo.proto @@ -67,7 +67,10 @@ message LockVaultRequest {} message LockVaultResponse {} -message UnlockVaultRequest {} +message UnlockVaultRequest { + string password = 1; + bytes key_file_data = 2; +} message UnlockVaultResponse {}