Add gRPC vault lifecycle backend flow

This commit is contained in:
Joe Julian
2026-03-29 11:21:52 -07:00
parent 4b4696ce30
commit 6c5e9b42d3
4 changed files with 358 additions and 26 deletions
+54 -14
View File
@@ -4,15 +4,17 @@ import (
"context"
"errors"
"maps"
"os"
"slices"
"sync"
"strings"
"sync"
"git.julianfamily.org/keepassgo/clipboard"
"git.julianfamily.org/keepassgo/passwords"
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
"git.julianfamily.org/keepassgo/webdav"
"git.julianfamily.org/keepassgo/session"
"git.julianfamily.org/keepassgo/vault"
"git.julianfamily.org/keepassgo/webdav"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
@@ -36,6 +38,8 @@ type lifecycleBackend interface {
Open(string, vault.MasterKey) error
OpenRemote(webdav.Client, string, vault.MasterKey) error
Save() error
Lock() error
Unlock(vault.MasterKey) error
}
func NewServer(model vault.Model, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer) *Server {
@@ -57,8 +61,8 @@ func (s *Server) GetSessionStatus(_ context.Context, _ *keepassgov1.GetSessionSt
defer s.mu.RUnlock()
return &keepassgov1.GetSessionStatusResponse{
Locked: s.locked,
Dirty: s.dirty,
Locked: s.locked,
Dirty: s.dirty,
EntryCount: uint32(len(s.model.Entries)),
}, nil
}
@@ -70,12 +74,12 @@ func (s *Server) OpenVault(_ context.Context, req *keepassgov1.OpenVaultRequest)
key := vault.MasterKey{Password: req.GetPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)}
if err := s.lifecycle.Open(req.GetPath(), key); err != nil {
return nil, status.Errorf(codes.Internal, "open vault: %v", err)
return nil, mapLifecycleError("open vault", err)
}
model, err := s.lifecycle.Current()
if err != nil {
return nil, status.Errorf(codes.Internal, "load opened vault: %v", err)
return nil, mapLifecycleError("load opened vault", err)
}
s.mu.Lock()
@@ -99,12 +103,12 @@ func (s *Server) OpenRemoteVault(_ context.Context, req *keepassgov1.OpenRemoteV
}
key := vault.MasterKey{Password: req.GetMasterPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)}
if err := s.lifecycle.OpenRemote(client, req.GetPath(), key); err != nil {
return nil, status.Errorf(codes.Internal, "open remote vault: %v", err)
return nil, mapLifecycleError("open remote vault", err)
}
model, err := s.lifecycle.Current()
if err != nil {
return nil, status.Errorf(codes.Internal, "load opened remote vault: %v", err)
return nil, mapLifecycleError("load opened remote vault", err)
}
s.mu.Lock()
@@ -122,7 +126,7 @@ func (s *Server) SaveVault(_ context.Context, _ *keepassgov1.SaveVaultRequest) (
}
if err := s.lifecycle.Save(); err != nil {
return nil, status.Errorf(codes.Internal, "save vault: %v", err)
return nil, mapLifecycleError("save vault", err)
}
s.mu.Lock()
@@ -133,23 +137,59 @@ func (s *Server) SaveVault(_ context.Context, _ *keepassgov1.SaveVaultRequest) (
}
func (s *Server) LockVault(_ context.Context, _ *keepassgov1.LockVaultRequest) (*keepassgov1.LockVaultResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.lifecycle == nil {
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
}
if err := s.lifecycle.Lock(); err != nil {
return nil, mapLifecycleError("lock vault", err)
}
s.mu.Lock()
s.locked = true
s.mu.Unlock()
return &keepassgov1.LockVaultResponse{}, nil
}
func (s *Server) UnlockVault(_ context.Context, _ *keepassgov1.UnlockVaultRequest) (*keepassgov1.UnlockVaultResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
func (s *Server) UnlockVault(_ context.Context, req *keepassgov1.UnlockVaultRequest) (*keepassgov1.UnlockVaultResponse, error) {
if s.lifecycle == nil {
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
}
key := vault.MasterKey{Password: req.GetPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)}
if err := s.lifecycle.Unlock(key); err != nil {
return nil, mapLifecycleError("unlock vault", err)
}
model, err := s.lifecycle.Current()
if err != nil {
return nil, mapLifecycleError("load unlocked vault", err)
}
s.mu.Lock()
s.model = model
s.locked = false
s.mu.Unlock()
return &keepassgov1.UnlockVaultResponse{}, nil
}
func mapLifecycleError(operation string, err error) error {
switch {
case errors.Is(err, os.ErrNotExist):
return status.Errorf(codes.NotFound, "%s: %v", operation, err)
case errors.Is(err, vault.ErrInvalidMasterKey):
return status.Errorf(codes.InvalidArgument, "%s: %v", operation, err)
case errors.Is(err, session.ErrLocked), errors.Is(err, session.ErrNoPath):
return status.Errorf(codes.FailedPrecondition, "%s: %v", operation, err)
case errors.Is(err, webdav.ErrConflict):
return status.Errorf(codes.Aborted, "%s: %v", operation, err)
default:
return status.Errorf(codes.Internal, "%s: %v", operation, err)
}
}
func (s *Server) ListEntries(_ context.Context, req *keepassgov1.ListEntriesRequest) (*keepassgov1.ListEntriesResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
+280 -9
View File
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"net"
"os"
"testing"
"git.julianfamily.org/keepassgo/passwords"
@@ -34,7 +35,21 @@ func TestVaultServiceRejectsRequestsWithoutBearerToken(t *testing.T) {
func TestVaultServiceReportsSessionStatusAndSupportsLockUnlock(t *testing.T) {
t.Parallel()
client, _, cleanup := newTestClient(t)
lifecycle := &stubLifecycle{
model: vault.Model{
Entries: []vault.Entry{
{
ID: "vault-console",
Title: "Vault Console",
Username: "dannyocean",
Password: "token-1",
URL: "https://vault.crew.example.invalid",
Path: []string{"Root", "Internet"},
},
},
},
}
client, _, cleanup := newTestClientWithLifecycle(t, lifecycle)
defer cleanup()
ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token")
@@ -78,6 +93,82 @@ func TestVaultServiceReportsSessionStatusAndSupportsLockUnlock(t *testing.T) {
}
}
func TestVaultServiceLockAndUnlockUseLifecycleBackend(t *testing.T) {
t.Parallel()
lifecycle := &stubLifecycle{
model: vault.Model{
Entries: []vault.Entry{
{ID: "entry-1", Title: "Remote Git", Path: []string{"Root", "Internet"}},
},
},
unlockPassword: "correct horse battery staple",
unlockKeyFile: []byte("key-material"),
}
client, _, cleanup := newTestClientWithLifecycle(t, lifecycle)
defer cleanup()
ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token")
if _, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{
Path: "/tmp/test.kdbx",
Password: lifecycle.unlockPassword,
KeyFileData: lifecycle.unlockKeyFile,
}); err != nil {
t.Fatalf("OpenVault() error = %v", err)
}
if _, err := client.LockVault(ctx, &keepassgov1.LockVaultRequest{}); err != nil {
t.Fatalf("LockVault() error = %v", err)
}
if !lifecycle.locked {
t.Fatal("LockVault() did not lock lifecycle backend")
}
statusResp, err := client.GetSessionStatus(ctx, &keepassgov1.GetSessionStatusRequest{})
if err != nil {
t.Fatalf("GetSessionStatus() after lock error = %v", err)
}
if !statusResp.Locked {
t.Fatal("GetSessionStatus().Locked = false, want true after lock")
}
if _, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{
Password: "wrong password",
KeyFileData: lifecycle.unlockKeyFile,
}); status.Code(err) != codes.InvalidArgument {
t.Fatalf("UnlockVault(wrong password) code = %v, want %v", status.Code(err), codes.InvalidArgument)
}
statusResp, err = client.GetSessionStatus(ctx, &keepassgov1.GetSessionStatusRequest{})
if err != nil {
t.Fatalf("GetSessionStatus() after failed unlock error = %v", err)
}
if !statusResp.Locked {
t.Fatal("GetSessionStatus().Locked = false, want true after failed unlock")
}
if _, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{
Password: lifecycle.unlockPassword,
KeyFileData: lifecycle.unlockKeyFile,
}); err != nil {
t.Fatalf("UnlockVault() error = %v", err)
}
if lifecycle.lastUnlockKey.Password != lifecycle.unlockPassword {
t.Fatalf("UnlockVault() password = %q, want %q", lifecycle.lastUnlockKey.Password, lifecycle.unlockPassword)
}
if !bytes.Equal(lifecycle.lastUnlockKey.KeyFileData, lifecycle.unlockKeyFile) {
t.Fatalf("UnlockVault() key data = %q, want %q", lifecycle.lastUnlockKey.KeyFileData, lifecycle.unlockKeyFile)
}
listed, err := client.ListEntries(ctx, &keepassgov1.ListEntriesRequest{Path: []string{"Root", "Internet"}})
if err != nil {
t.Fatalf("ListEntries() after unlock error = %v", err)
}
if len(listed.Entries) != 1 || listed.Entries[0].Title != "Remote Git" {
t.Fatalf("ListEntries().Entries = %#v, want Remote Git after unlock", listed.Entries)
}
}
func TestVaultServiceOpensAndSavesVaultThroughLifecycleBackend(t *testing.T) {
t.Parallel()
@@ -131,6 +222,143 @@ func TestVaultServiceOpensAndSavesVaultThroughLifecycleBackend(t *testing.T) {
}
}
func TestVaultServiceLifecycleMethodsRequireLifecycleBackend(t *testing.T) {
t.Parallel()
client, _, cleanup := newTestClient(t)
defer cleanup()
ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token")
testCases := []struct {
name string
call func() error
}{
{
name: "open",
call: func() error {
_, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{Path: "/tmp/test.kdbx"})
return err
},
},
{
name: "open_remote",
call: func() error {
_, err := client.OpenRemoteVault(ctx, &keepassgov1.OpenRemoteVaultRequest{
BaseUrl: "https://dav.example.com",
Path: "vaults/main.kdbx",
})
return err
},
},
{
name: "save",
call: func() error {
_, err := client.SaveVault(ctx, &keepassgov1.SaveVaultRequest{})
return err
},
},
{
name: "lock",
call: func() error {
_, err := client.LockVault(ctx, &keepassgov1.LockVaultRequest{})
return err
},
},
{
name: "unlock",
call: func() error {
_, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{})
return err
},
},
}
for _, tt := range testCases {
tt := tt
t.Run(tt.name, func(t *testing.T) {
err := tt.call()
if status.Code(err) != codes.FailedPrecondition {
t.Fatalf("%s code = %v, want %v", tt.name, status.Code(err), codes.FailedPrecondition)
}
})
}
}
func TestVaultServiceLifecycleMethodsMapBackendErrors(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
call func(keepassgov1.VaultServiceClient, context.Context) error
err error
want codes.Code
}{
{
name: "open not found",
call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error {
_, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{Path: "/tmp/missing.kdbx"})
return err
},
err: os.ErrNotExist,
want: codes.NotFound,
},
{
name: "open invalid master key",
call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error {
_, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{Path: "/tmp/test.kdbx"})
return err
},
err: vault.ErrInvalidMasterKey,
want: codes.InvalidArgument,
},
{
name: "save no path",
call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error {
_, err := client.SaveVault(ctx, &keepassgov1.SaveVaultRequest{})
return err
},
err: session.ErrNoPath,
want: codes.FailedPrecondition,
},
{
name: "lock already locked",
call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error {
_, err := client.LockVault(ctx, &keepassgov1.LockVaultRequest{})
return err
},
err: session.ErrLocked,
want: codes.FailedPrecondition,
},
{
name: "unlock invalid master key",
call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error {
_, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{Password: "wrong"})
return err
},
err: vault.ErrInvalidMasterKey,
want: codes.InvalidArgument,
},
}
for _, tt := range testCases {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
lifecycle := &stubLifecycle{err: tt.err}
client, _, cleanup := newTestClientWithLifecycle(t, lifecycle)
defer cleanup()
ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token")
err := tt.call(client, ctx)
if status.Code(err) != tt.want {
t.Fatalf("%s code = %v, want %v", tt.name, status.Code(err), tt.want)
}
})
}
}
func TestVaultServiceListsEntriesForAuthorizedClients(t *testing.T) {
t.Parallel()
@@ -503,7 +731,7 @@ func newTestClient(t *testing.T) (keepassgov1.VaultServiceClient, *memoryClipboa
Path: []string{"Root", "Internet"},
},
},
Path: []string{"Root", "Internet"},
Path: []string{"Root", "Internet"},
},
{
ID: "surveillance-console",
@@ -560,7 +788,7 @@ func newTestClientWithLifecycle(t *testing.T, lifecycle *stubLifecycle) (keepass
server := grpc.NewServer(grpc.UnaryInterceptor(BearerTokenInterceptor("test-token")))
clipboardWriter := &memoryClipboardWriter{}
keepassgov1.RegisterVaultServiceServer(server, NewServerWithLifecycle(
vault.Model{},
lifecycle.model,
passwords.DefaultProfiles(),
clipboardWriter,
lifecycle,
@@ -598,12 +826,16 @@ func (w *memoryClipboardWriter) WriteText(text string) error {
}
type stubLifecycle struct {
model vault.Model
openPath string
remoteBaseURL string
remotePath string
saved bool
locked bool
model vault.Model
openPath string
remoteBaseURL string
remotePath string
saved bool
locked bool
err error
unlockPassword string
unlockKeyFile []byte
lastUnlockKey vault.MasterKey
}
func (s *stubLifecycle) Current() (vault.Model, error) {
@@ -614,17 +846,56 @@ func (s *stubLifecycle) Current() (vault.Model, error) {
}
func (s *stubLifecycle) Open(path string, _ vault.MasterKey) error {
if s.err != nil {
return s.err
}
s.openPath = path
s.locked = false
return nil
}
func (s *stubLifecycle) OpenRemote(client webdav.Client, path string, _ vault.MasterKey) error {
if s.err != nil {
return s.err
}
s.remoteBaseURL = client.BaseURL
s.remotePath = path
s.locked = false
return nil
}
func (s *stubLifecycle) Save() error {
if s.err != nil {
return s.err
}
s.saved = true
return nil
}
func (s *stubLifecycle) Lock() error {
if s.err != nil {
return s.err
}
s.locked = true
return nil
}
func (s *stubLifecycle) Unlock(key vault.MasterKey) error {
if s.err != nil {
return s.err
}
if s.unlockPassword != "" && key.Password != s.unlockPassword {
return vault.ErrInvalidMasterKey
}
if s.unlockKeyFile != nil && !bytes.Equal(key.KeyFileData, s.unlockKeyFile) {
return vault.ErrInvalidMasterKey
}
s.lastUnlockKey = vault.MasterKey{
Password: key.Password,
KeyFileData: append([]byte(nil), key.KeyFileData...),
}
s.locked = false
return nil
}
+20 -2
View File
@@ -479,6 +479,8 @@ func (*LockVaultResponse) Descriptor() ([]byte, []int) {
type UnlockVaultRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"`
KeyFileData []byte `protobuf:"bytes,2,opt,name=key_file_data,json=keyFileData,proto3" json:"key_file_data,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -513,6 +515,20 @@ func (*UnlockVaultRequest) Descriptor() ([]byte, []int) {
return file_proto_keepassgo_v1_keepassgo_proto_rawDescGZIP(), []int{10}
}
func (x *UnlockVaultRequest) GetPassword() string {
if x != nil {
return x.Password
}
return ""
}
func (x *UnlockVaultRequest) GetKeyFileData() []byte {
if x != nil {
return x.KeyFileData
}
return nil
}
type UnlockVaultResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -2364,8 +2380,10 @@ const file_proto_keepassgo_v1_keepassgo_proto_rawDesc = "" +
"\x10SaveVaultRequest\"\x13\n" +
"\x11SaveVaultResponse\"\x12\n" +
"\x10LockVaultRequest\"\x13\n" +
"\x11LockVaultResponse\"\x14\n" +
"\x12UnlockVaultRequest\"\x15\n" +
"\x11LockVaultResponse\"T\n" +
"\x12UnlockVaultRequest\x12\x1a\n" +
"\bpassword\x18\x01 \x01(\tR\bpassword\x12\"\n" +
"\rkey_file_data\x18\x02 \x01(\fR\vkeyFileData\"\x15\n" +
"\x13UnlockVaultResponse\">\n" +
"\x12ListEntriesRequest\x12\x12\n" +
"\x04path\x18\x01 \x03(\tR\x04path\x12\x14\n" +
+4 -1
View File
@@ -67,7 +67,10 @@ message LockVaultRequest {}
message LockVaultResponse {}
message UnlockVaultRequest {}
message UnlockVaultRequest {
string password = 1;
bytes key_file_data = 2;
}
message UnlockVaultResponse {}