Add gRPC vault lifecycle backend flow
This commit is contained in:
+280
-9
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"git.julianfamily.org/keepassgo/passwords"
|
||||
@@ -34,7 +35,21 @@ func TestVaultServiceRejectsRequestsWithoutBearerToken(t *testing.T) {
|
||||
func TestVaultServiceReportsSessionStatusAndSupportsLockUnlock(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, _, cleanup := newTestClient(t)
|
||||
lifecycle := &stubLifecycle{
|
||||
model: vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
{
|
||||
ID: "vault-console",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "token-1",
|
||||
URL: "https://vault.crew.example.invalid",
|
||||
Path: []string{"Root", "Internet"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
client, _, cleanup := newTestClientWithLifecycle(t, lifecycle)
|
||||
defer cleanup()
|
||||
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token")
|
||||
@@ -78,6 +93,82 @@ func TestVaultServiceReportsSessionStatusAndSupportsLockUnlock(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultServiceLockAndUnlockUseLifecycleBackend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lifecycle := &stubLifecycle{
|
||||
model: vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
{ID: "entry-1", Title: "Remote Git", Path: []string{"Root", "Internet"}},
|
||||
},
|
||||
},
|
||||
unlockPassword: "correct horse battery staple",
|
||||
unlockKeyFile: []byte("key-material"),
|
||||
}
|
||||
client, _, cleanup := newTestClientWithLifecycle(t, lifecycle)
|
||||
defer cleanup()
|
||||
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token")
|
||||
if _, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{
|
||||
Path: "/tmp/test.kdbx",
|
||||
Password: lifecycle.unlockPassword,
|
||||
KeyFileData: lifecycle.unlockKeyFile,
|
||||
}); err != nil {
|
||||
t.Fatalf("OpenVault() error = %v", err)
|
||||
}
|
||||
|
||||
if _, err := client.LockVault(ctx, &keepassgov1.LockVaultRequest{}); err != nil {
|
||||
t.Fatalf("LockVault() error = %v", err)
|
||||
}
|
||||
if !lifecycle.locked {
|
||||
t.Fatal("LockVault() did not lock lifecycle backend")
|
||||
}
|
||||
|
||||
statusResp, err := client.GetSessionStatus(ctx, &keepassgov1.GetSessionStatusRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetSessionStatus() after lock error = %v", err)
|
||||
}
|
||||
if !statusResp.Locked {
|
||||
t.Fatal("GetSessionStatus().Locked = false, want true after lock")
|
||||
}
|
||||
|
||||
if _, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{
|
||||
Password: "wrong password",
|
||||
KeyFileData: lifecycle.unlockKeyFile,
|
||||
}); status.Code(err) != codes.InvalidArgument {
|
||||
t.Fatalf("UnlockVault(wrong password) code = %v, want %v", status.Code(err), codes.InvalidArgument)
|
||||
}
|
||||
|
||||
statusResp, err = client.GetSessionStatus(ctx, &keepassgov1.GetSessionStatusRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetSessionStatus() after failed unlock error = %v", err)
|
||||
}
|
||||
if !statusResp.Locked {
|
||||
t.Fatal("GetSessionStatus().Locked = false, want true after failed unlock")
|
||||
}
|
||||
|
||||
if _, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{
|
||||
Password: lifecycle.unlockPassword,
|
||||
KeyFileData: lifecycle.unlockKeyFile,
|
||||
}); err != nil {
|
||||
t.Fatalf("UnlockVault() error = %v", err)
|
||||
}
|
||||
if lifecycle.lastUnlockKey.Password != lifecycle.unlockPassword {
|
||||
t.Fatalf("UnlockVault() password = %q, want %q", lifecycle.lastUnlockKey.Password, lifecycle.unlockPassword)
|
||||
}
|
||||
if !bytes.Equal(lifecycle.lastUnlockKey.KeyFileData, lifecycle.unlockKeyFile) {
|
||||
t.Fatalf("UnlockVault() key data = %q, want %q", lifecycle.lastUnlockKey.KeyFileData, lifecycle.unlockKeyFile)
|
||||
}
|
||||
|
||||
listed, err := client.ListEntries(ctx, &keepassgov1.ListEntriesRequest{Path: []string{"Root", "Internet"}})
|
||||
if err != nil {
|
||||
t.Fatalf("ListEntries() after unlock error = %v", err)
|
||||
}
|
||||
if len(listed.Entries) != 1 || listed.Entries[0].Title != "Remote Git" {
|
||||
t.Fatalf("ListEntries().Entries = %#v, want Remote Git after unlock", listed.Entries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultServiceOpensAndSavesVaultThroughLifecycleBackend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -131,6 +222,143 @@ func TestVaultServiceOpensAndSavesVaultThroughLifecycleBackend(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultServiceLifecycleMethodsRequireLifecycleBackend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, _, cleanup := newTestClient(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
call func() error
|
||||
}{
|
||||
{
|
||||
name: "open",
|
||||
call: func() error {
|
||||
_, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{Path: "/tmp/test.kdbx"})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "open_remote",
|
||||
call: func() error {
|
||||
_, err := client.OpenRemoteVault(ctx, &keepassgov1.OpenRemoteVaultRequest{
|
||||
BaseUrl: "https://dav.example.com",
|
||||
Path: "vaults/main.kdbx",
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "save",
|
||||
call: func() error {
|
||||
_, err := client.SaveVault(ctx, &keepassgov1.SaveVaultRequest{})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "lock",
|
||||
call: func() error {
|
||||
_, err := client.LockVault(ctx, &keepassgov1.LockVaultRequest{})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unlock",
|
||||
call: func() error {
|
||||
_, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{})
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.call()
|
||||
if status.Code(err) != codes.FailedPrecondition {
|
||||
t.Fatalf("%s code = %v, want %v", tt.name, status.Code(err), codes.FailedPrecondition)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultServiceLifecycleMethodsMapBackendErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
call func(keepassgov1.VaultServiceClient, context.Context) error
|
||||
err error
|
||||
want codes.Code
|
||||
}{
|
||||
{
|
||||
name: "open not found",
|
||||
call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error {
|
||||
_, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{Path: "/tmp/missing.kdbx"})
|
||||
return err
|
||||
},
|
||||
err: os.ErrNotExist,
|
||||
want: codes.NotFound,
|
||||
},
|
||||
{
|
||||
name: "open invalid master key",
|
||||
call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error {
|
||||
_, err := client.OpenVault(ctx, &keepassgov1.OpenVaultRequest{Path: "/tmp/test.kdbx"})
|
||||
return err
|
||||
},
|
||||
err: vault.ErrInvalidMasterKey,
|
||||
want: codes.InvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "save no path",
|
||||
call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error {
|
||||
_, err := client.SaveVault(ctx, &keepassgov1.SaveVaultRequest{})
|
||||
return err
|
||||
},
|
||||
err: session.ErrNoPath,
|
||||
want: codes.FailedPrecondition,
|
||||
},
|
||||
{
|
||||
name: "lock already locked",
|
||||
call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error {
|
||||
_, err := client.LockVault(ctx, &keepassgov1.LockVaultRequest{})
|
||||
return err
|
||||
},
|
||||
err: session.ErrLocked,
|
||||
want: codes.FailedPrecondition,
|
||||
},
|
||||
{
|
||||
name: "unlock invalid master key",
|
||||
call: func(client keepassgov1.VaultServiceClient, ctx context.Context) error {
|
||||
_, err := client.UnlockVault(ctx, &keepassgov1.UnlockVaultRequest{Password: "wrong"})
|
||||
return err
|
||||
},
|
||||
err: vault.ErrInvalidMasterKey,
|
||||
want: codes.InvalidArgument,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lifecycle := &stubLifecycle{err: tt.err}
|
||||
client, _, cleanup := newTestClientWithLifecycle(t, lifecycle)
|
||||
defer cleanup()
|
||||
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", "Bearer test-token")
|
||||
err := tt.call(client, ctx)
|
||||
if status.Code(err) != tt.want {
|
||||
t.Fatalf("%s code = %v, want %v", tt.name, status.Code(err), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultServiceListsEntriesForAuthorizedClients(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -503,7 +731,7 @@ func newTestClient(t *testing.T) (keepassgov1.VaultServiceClient, *memoryClipboa
|
||||
Path: []string{"Root", "Internet"},
|
||||
},
|
||||
},
|
||||
Path: []string{"Root", "Internet"},
|
||||
Path: []string{"Root", "Internet"},
|
||||
},
|
||||
{
|
||||
ID: "surveillance-console",
|
||||
@@ -560,7 +788,7 @@ func newTestClientWithLifecycle(t *testing.T, lifecycle *stubLifecycle) (keepass
|
||||
server := grpc.NewServer(grpc.UnaryInterceptor(BearerTokenInterceptor("test-token")))
|
||||
clipboardWriter := &memoryClipboardWriter{}
|
||||
keepassgov1.RegisterVaultServiceServer(server, NewServerWithLifecycle(
|
||||
vault.Model{},
|
||||
lifecycle.model,
|
||||
passwords.DefaultProfiles(),
|
||||
clipboardWriter,
|
||||
lifecycle,
|
||||
@@ -598,12 +826,16 @@ func (w *memoryClipboardWriter) WriteText(text string) error {
|
||||
}
|
||||
|
||||
type stubLifecycle struct {
|
||||
model vault.Model
|
||||
openPath string
|
||||
remoteBaseURL string
|
||||
remotePath string
|
||||
saved bool
|
||||
locked bool
|
||||
model vault.Model
|
||||
openPath string
|
||||
remoteBaseURL string
|
||||
remotePath string
|
||||
saved bool
|
||||
locked bool
|
||||
err error
|
||||
unlockPassword string
|
||||
unlockKeyFile []byte
|
||||
lastUnlockKey vault.MasterKey
|
||||
}
|
||||
|
||||
func (s *stubLifecycle) Current() (vault.Model, error) {
|
||||
@@ -614,17 +846,56 @@ func (s *stubLifecycle) Current() (vault.Model, error) {
|
||||
}
|
||||
|
||||
func (s *stubLifecycle) Open(path string, _ vault.MasterKey) error {
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
s.openPath = path
|
||||
s.locked = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubLifecycle) OpenRemote(client webdav.Client, path string, _ vault.MasterKey) error {
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
s.remoteBaseURL = client.BaseURL
|
||||
s.remotePath = path
|
||||
s.locked = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubLifecycle) Save() error {
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
s.saved = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubLifecycle) Lock() error {
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
|
||||
s.locked = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubLifecycle) Unlock(key vault.MasterKey) error {
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
if s.unlockPassword != "" && key.Password != s.unlockPassword {
|
||||
return vault.ErrInvalidMasterKey
|
||||
}
|
||||
if s.unlockKeyFile != nil && !bytes.Equal(key.KeyFileData, s.unlockKeyFile) {
|
||||
return vault.ErrInvalidMasterKey
|
||||
}
|
||||
|
||||
s.lastUnlockKey = vault.MasterKey{
|
||||
Password: key.Password,
|
||||
KeyFileData: append([]byte(nil), key.KeyFileData...),
|
||||
}
|
||||
s.locked = false
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user