Move app packages under internal

This commit is contained in:
Joe Julian
2026-04-09 06:42:21 -07:00
parent 7751b5472a
commit fe921b8790
55 changed files with 162 additions and 162 deletions
+122
View File
@@ -0,0 +1,122 @@
package api
import (
"errors"
"fmt"
"net"
"strings"
"sync"
"git.julianfamily.org/keepassgo/internal/clipboard"
"git.julianfamily.org/keepassgo/internal/passwords"
"git.julianfamily.org/keepassgo/internal/session"
"git.julianfamily.org/keepassgo/internal/vault"
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
"google.golang.org/grpc"
)
type DirtyProvider func() bool
type Host struct {
server *Server
grpcServer *grpc.Server
listener net.Listener
lifecycle lifecycleBackend
dirty DirtyProvider
mu sync.Mutex
lastModel vault.Model
started bool
listenAddr string
}
func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer, dirty DirtyProvider) (*Host, error) {
addr = strings.TrimSpace(addr)
if addr == "" || strings.EqualFold(addr, "off") {
return nil, nil
}
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, fmt.Errorf("listen gRPC host %s: %w", addr, err)
}
service := NewServerWithLifecycle(vault.Model{}, profiles, clipboardWriter, lifecycle)
server := grpc.NewServer(grpc.UnaryInterceptor(AuthInterceptor(service)))
keepassgov1.RegisterVaultServiceServer(server, service)
host := &Host{
server: service,
grpcServer: server,
listener: listener,
lifecycle: lifecycle,
dirty: dirty,
listenAddr: listener.Addr().String(),
started: true,
}
if err := host.SyncFromLifecycle(); err != nil && !errors.Is(err, session.ErrLocked) {
_ = listener.Close()
server.Stop()
return nil, err
}
go func() {
_ = server.Serve(listener)
}()
return host, nil
}
func (h *Host) Address() string {
if h == nil {
return ""
}
return h.listenAddr
}
func (h *Host) Server() *Server {
if h == nil {
return nil
}
return h.server
}
func (h *Host) Stop() error {
if h == nil {
return nil
}
h.mu.Lock()
defer h.mu.Unlock()
if !h.started {
return nil
}
h.started = false
h.grpcServer.Stop()
return h.listener.Close()
}
func (h *Host) SyncFromLifecycle() error {
if h == nil || h.lifecycle == nil || h.server == nil {
return nil
}
h.mu.Lock()
defer h.mu.Unlock()
model, err := h.lifecycle.Current()
locked := false
switch {
case err == nil:
h.lastModel = model
case errors.Is(err, session.ErrLocked):
locked = true
default:
return err
}
dirty := false
if h.dirty != nil {
dirty = h.dirty()
}
h.server.SetSessionState(h.lastModel, locked, dirty)
return nil
}
+72
View File
@@ -0,0 +1,72 @@
package api
import (
"context"
"net"
"testing"
"git.julianfamily.org/keepassgo/internal/passwords"
"git.julianfamily.org/keepassgo/internal/session"
"git.julianfamily.org/keepassgo/internal/vault"
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
func TestStartHostServesVaultLifecycleAndSyncsSessionState(t *testing.T) {
t.Parallel()
lifecycle := &session.Manager{}
if err := lifecycle.Create(vault.Model{
Entries: []vault.Entry{
testAPITokenEntry(t),
{ID: "entry-1", Title: "Vault Console", Path: []string{"Root", "Internet"}},
},
}, vault.MasterKey{Password: "correct horse battery staple"}); err != nil {
t.Fatalf("Create() error = %v", err)
}
host, err := StartHost("127.0.0.1:0", lifecycle, passwords.DefaultProfiles(), nil, func() bool { return true })
if err != nil {
t.Fatalf("StartHost() error = %v", err)
}
defer func() { _ = host.Stop() }()
conn, err := grpc.NewClient("passthrough:///"+host.Address(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return net.Dial("tcp", host.Address())
}),
)
if err != nil {
t.Fatalf("grpc.NewClient() error = %v", err)
}
defer func() { _ = conn.Close() }()
client := keepassgov1.NewVaultServiceClient(conn)
statusResp, err := client.GetSessionStatus(tokenContext(defaultTestTokenSecret), &keepassgov1.GetSessionStatusRequest{})
if err != nil {
t.Fatalf("GetSessionStatus() error = %v", err)
}
if statusResp.Locked {
t.Fatal("GetSessionStatus().Locked = true, want false")
}
if !statusResp.Dirty {
t.Fatal("GetSessionStatus().Dirty = false, want true from dirty provider")
}
if err := lifecycle.Lock(); err != nil {
t.Fatalf("Lock() error = %v", err)
}
if err := host.SyncFromLifecycle(); err != nil {
t.Fatalf("SyncFromLifecycle() after lock error = %v", err)
}
statusResp, err = client.GetSessionStatus(tokenContext(defaultTestTokenSecret), &keepassgov1.GetSessionStatusRequest{})
if err != nil {
t.Fatalf("GetSessionStatus() after lock error = %v", err)
}
if !statusResp.Locked {
t.Fatal("GetSessionStatus().Locked = false, want true after lifecycle lock")
}
}
+979
View File
@@ -0,0 +1,979 @@
package api
import (
"context"
"errors"
"maps"
"os"
"slices"
"strings"
"sync"
"time"
"git.julianfamily.org/keepassgo/internal/apiapproval"
"git.julianfamily.org/keepassgo/internal/apiaudit"
"git.julianfamily.org/keepassgo/internal/apitokens"
"git.julianfamily.org/keepassgo/internal/clipboard"
"git.julianfamily.org/keepassgo/internal/passwords"
"git.julianfamily.org/keepassgo/internal/session"
"git.julianfamily.org/keepassgo/internal/vault"
"git.julianfamily.org/keepassgo/internal/webdav"
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
type Server struct {
keepassgov1.UnimplementedVaultServiceServer
mu sync.RWMutex
model vault.Model
locked bool
dirty bool
lifecycle lifecycleBackend
profiles map[string]passwords.Profile
clipboard clipboard.Writer
approvals *apiapproval.Broker
audit *apiaudit.Log
}
type lifecycleBackend interface {
Current() (vault.Model, error)
Open(string, vault.MasterKey) error
OpenRemote(webdav.Client, string, vault.MasterKey) error
Save() error
Lock() error
Unlock(vault.MasterKey) error
}
type modelReplaceableLifecycle interface {
lifecycleBackend
Replace(vault.Model)
}
func NewServer(model vault.Model, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer) *Server {
return &Server{
model: model,
profiles: profiles,
clipboard: clipboardWriter,
approvals: apiapproval.NewBroker(30 * time.Second),
audit: apiaudit.New(200),
}
}
func NewServerWithLifecycle(model vault.Model, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer, lifecycle lifecycleBackend) *Server {
server := NewServer(model, profiles, clipboardWriter)
server.lifecycle = lifecycle
return server
}
func (s *Server) ApprovalBroker() *apiapproval.Broker {
return s.approvals
}
func (s *Server) AuditLog() *apiaudit.Log {
return s.audit
}
func (s *Server) ResolveApproval(id string, outcome apiapproval.Outcome) (apiapproval.Request, *apitokens.PolicyRule, error) {
return s.approvals.Resolve(id, outcome)
}
func (s *Server) SetSessionState(model vault.Model, locked, dirty bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.model = model
s.locked = locked
s.dirty = dirty
}
func (s *Server) GetSessionStatus(_ context.Context, _ *keepassgov1.GetSessionStatusRequest) (*keepassgov1.GetSessionStatusResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
return &keepassgov1.GetSessionStatusResponse{
Locked: s.locked,
Dirty: s.dirty,
EntryCount: uint32(len(s.model.Entries)),
}, nil
}
func (s *Server) OpenVault(_ context.Context, req *keepassgov1.OpenVaultRequest) (*keepassgov1.OpenVaultResponse, error) {
if s.lifecycle == nil {
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
}
key := vault.MasterKey{Password: req.GetPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)}
if err := s.lifecycle.Open(req.GetPath(), key); err != nil {
return nil, mapLifecycleError("open vault", err)
}
model, err := s.lifecycle.Current()
if err != nil {
return nil, mapLifecycleError("load opened vault", err)
}
s.mu.Lock()
s.model = model
s.locked = false
s.dirty = false
s.mu.Unlock()
return &keepassgov1.OpenVaultResponse{}, nil
}
func (s *Server) OpenRemoteVault(_ context.Context, req *keepassgov1.OpenRemoteVaultRequest) (*keepassgov1.OpenRemoteVaultResponse, error) {
if s.lifecycle == nil {
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
}
client := webdav.Client{
BaseURL: req.GetBaseUrl(),
Username: req.GetUsername(),
Password: req.GetPassword(),
}
key := vault.MasterKey{Password: req.GetMasterPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)}
if err := s.lifecycle.OpenRemote(client, req.GetPath(), key); err != nil {
return nil, mapLifecycleError("open remote vault", err)
}
model, err := s.lifecycle.Current()
if err != nil {
return nil, mapLifecycleError("load opened remote vault", err)
}
s.mu.Lock()
s.model = model
s.locked = false
s.dirty = false
s.mu.Unlock()
return &keepassgov1.OpenRemoteVaultResponse{}, nil
}
func (s *Server) SaveVault(_ context.Context, _ *keepassgov1.SaveVaultRequest) (*keepassgov1.SaveVaultResponse, error) {
if s.lifecycle == nil {
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
}
if err := s.lifecycle.Save(); err != nil {
return nil, mapLifecycleError("save vault", err)
}
s.mu.Lock()
s.dirty = false
s.mu.Unlock()
return &keepassgov1.SaveVaultResponse{}, nil
}
func (s *Server) LockVault(_ context.Context, _ *keepassgov1.LockVaultRequest) (*keepassgov1.LockVaultResponse, error) {
if s.lifecycle == nil {
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
}
if err := s.lifecycle.Lock(); err != nil {
return nil, mapLifecycleError("lock vault", err)
}
s.mu.Lock()
s.locked = true
s.mu.Unlock()
return &keepassgov1.LockVaultResponse{}, nil
}
func (s *Server) UnlockVault(_ context.Context, req *keepassgov1.UnlockVaultRequest) (*keepassgov1.UnlockVaultResponse, error) {
if s.lifecycle == nil {
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
}
key := vault.MasterKey{Password: req.GetPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)}
if err := s.lifecycle.Unlock(key); err != nil {
return nil, mapLifecycleError("unlock vault", err)
}
model, err := s.lifecycle.Current()
if err != nil {
return nil, mapLifecycleError("load unlocked vault", err)
}
s.mu.Lock()
s.model = model
s.locked = false
s.mu.Unlock()
return &keepassgov1.UnlockVaultResponse{}, nil
}
func mapLifecycleError(operation string, err error) error {
switch {
case errors.Is(err, os.ErrNotExist):
return status.Errorf(codes.NotFound, "%s: %v", operation, err)
case errors.Is(err, vault.ErrInvalidMasterKey):
return status.Errorf(codes.InvalidArgument, "%s: %v", operation, err)
case errors.Is(err, session.ErrLocked), errors.Is(err, session.ErrNoPath):
return status.Errorf(codes.FailedPrecondition, "%s: %v", operation, err)
case errors.Is(err, webdav.ErrConflict):
return status.Errorf(codes.Aborted, "%s: %v", operation, err)
default:
return status.Errorf(codes.Internal, "%s: %v", operation, err)
}
}
func (s *Server) ListEntries(ctx context.Context, req *keepassgov1.ListEntriesRequest) (*keepassgov1.ListEntriesResponse, error) {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
if _, err := s.authorizePathRequest(ctx, apitokens.OperationListEntries, req.GetPath()); err != nil {
return nil, err
}
model = visibleModel(model)
var entries []vault.Entry
if strings.TrimSpace(req.GetQuery()) != "" {
results := model.Search(req.GetQuery())
entries = make([]vault.Entry, 0, len(results))
for _, result := range results {
entries = append(entries, result.Entry)
}
} else {
entries = model.EntriesInPath(req.GetPath())
}
resp := &keepassgov1.ListEntriesResponse{
Entries: make([]*keepassgov1.Entry, 0, len(entries)),
}
for _, entry := range entries {
resp.Entries = append(resp.Entries, entryToProto(entry))
}
return resp, nil
}
func (s *Server) ListGroups(ctx context.Context, req *keepassgov1.ListGroupsRequest) (*keepassgov1.ListGroupsResponse, error) {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
if _, err := s.authorizePathRequest(ctx, apitokens.OperationListGroups, req.GetPath()); err != nil {
return nil, err
}
return &keepassgov1.ListGroupsResponse{
Names: visibleModel(model).ChildGroups(req.GetPath()),
}, nil
}
func (s *Server) CreateGroup(ctx context.Context, req *keepassgov1.CreateGroupRequest) (*keepassgov1.CreateGroupResponse, error) {
if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetParentPath()); err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
s.model.CreateGroup(req.GetParentPath(), req.GetName())
s.dirty = true
return &keepassgov1.CreateGroupResponse{}, nil
}
func (s *Server) RenameGroup(ctx context.Context, req *keepassgov1.RenameGroupRequest) (*keepassgov1.RenameGroupResponse, error) {
if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetPath()); err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
if err := s.model.RenameGroup(req.GetPath(), req.GetNewName()); err != nil {
if errors.Is(err, vault.ErrEntryNotFound) {
return nil, status.Error(codes.NotFound, err.Error())
}
return nil, status.Errorf(codes.Internal, "rename group: %v", err)
}
s.dirty = true
return &keepassgov1.RenameGroupResponse{}, nil
}
func (s *Server) DeleteGroup(ctx context.Context, req *keepassgov1.DeleteGroupRequest) (*keepassgov1.DeleteGroupResponse, error) {
if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetPath()); err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
if err := s.model.DeleteGroup(req.GetPath()); err != nil {
switch {
case errors.Is(err, vault.ErrEntryNotFound):
return nil, status.Error(codes.NotFound, err.Error())
case errors.Is(err, vault.ErrGroupNotEmpty):
return nil, status.Error(codes.FailedPrecondition, err.Error())
default:
return nil, status.Errorf(codes.Internal, "delete group: %v", err)
}
}
s.dirty = true
return &keepassgov1.DeleteGroupResponse{}, nil
}
func (s *Server) UpsertEntry(ctx context.Context, req *keepassgov1.UpsertEntryRequest) (*keepassgov1.UpsertEntryResponse, error) {
if req.GetEntry() == nil {
return nil, status.Error(codes.InvalidArgument, "missing entry")
}
entry := entryFromProto(req.GetEntry())
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
return nil, err
}
s.mu.Lock()
if s.locked {
s.mu.Unlock()
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
s.model.UpsertEntry(entry)
s.dirty = true
s.mu.Unlock()
return &keepassgov1.UpsertEntryResponse{Entry: entryToProto(entry)}, nil
}
func (s *Server) DeleteEntry(ctx context.Context, req *keepassgov1.DeleteEntryRequest) (*keepassgov1.DeleteEntryResponse, error) {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
if entry, err := findEntryByID(model, req.GetId()); err == nil {
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
return nil, err
}
}
s.mu.Lock()
defer s.mu.Unlock()
if err := s.model.DeleteEntry(req.GetId()); err != nil {
if errors.Is(err, vault.ErrEntryNotFound) {
return nil, status.Error(codes.NotFound, err.Error())
}
return nil, status.Errorf(codes.Internal, "delete entry: %v", err)
}
s.dirty = true
return &keepassgov1.DeleteEntryResponse{}, nil
}
func (s *Server) RestoreEntry(ctx context.Context, req *keepassgov1.RestoreEntryRequest) (*keepassgov1.RestoreEntryResponse, error) {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
var restored vault.Entry
for _, entry := range model.RecycleBin {
if entry.ID == req.GetId() {
restored = entry
break
}
}
if restored.ID != "" {
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, restored); err != nil {
return nil, err
}
}
s.mu.Lock()
defer s.mu.Unlock()
if err := s.model.RestoreEntry(req.GetId()); err != nil {
if errors.Is(err, vault.ErrEntryNotFound) {
return nil, status.Error(codes.NotFound, err.Error())
}
return nil, status.Errorf(codes.Internal, "restore entry: %v", err)
}
s.dirty = true
return &keepassgov1.RestoreEntryResponse{Entry: entryToProto(restored)}, nil
}
func (s *Server) ListEntryHistory(ctx context.Context, req *keepassgov1.ListEntryHistoryRequest) (*keepassgov1.ListEntryHistoryResponse, error) {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := findEntryByID(model, req.GetId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationReadEntry, entry); err != nil {
return nil, err
}
resp := &keepassgov1.ListEntryHistoryResponse{
Entries: make([]*keepassgov1.Entry, 0, len(entry.History)),
}
for _, historical := range entry.History {
resp.Entries = append(resp.Entries, entryToProto(historical))
}
return resp, nil
}
func (s *Server) RestoreEntryHistory(ctx context.Context, req *keepassgov1.RestoreEntryHistoryRequest) (*keepassgov1.RestoreEntryHistoryResponse, error) {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := findEntryByID(model, req.GetId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
if err := s.model.RestoreEntryVersion(req.GetId(), int(req.GetHistoryIndex())); err != nil {
if errors.Is(err, vault.ErrEntryNotFound) {
return nil, status.Error(codes.NotFound, err.Error())
}
return nil, status.Errorf(codes.Internal, "restore entry history: %v", err)
}
entry, err = findEntryByID(s.model, req.GetId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
s.dirty = true
return &keepassgov1.RestoreEntryHistoryResponse{Entry: entryToProto(entry)}, nil
}
func (s *Server) ListTemplates(_ context.Context, _ *keepassgov1.ListTemplatesRequest) (*keepassgov1.ListTemplatesResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
resp := &keepassgov1.ListTemplatesResponse{
Templates: make([]*keepassgov1.Entry, 0, len(s.model.Templates)),
}
for _, template := range s.model.Templates {
resp.Templates = append(resp.Templates, entryToProto(template))
}
return resp, nil
}
func (s *Server) UpsertTemplate(_ context.Context, req *keepassgov1.UpsertTemplateRequest) (*keepassgov1.UpsertTemplateResponse, error) {
if req.GetTemplate() == nil {
return nil, status.Error(codes.InvalidArgument, "missing template")
}
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry := entryFromProto(req.GetTemplate())
s.model.UpsertTemplate(entry)
s.dirty = true
return &keepassgov1.UpsertTemplateResponse{Template: entryToProto(entry)}, nil
}
func (s *Server) DeleteTemplate(_ context.Context, req *keepassgov1.DeleteTemplateRequest) (*keepassgov1.DeleteTemplateResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
if err := s.model.DeleteTemplate(req.GetId()); err != nil {
if errors.Is(err, vault.ErrEntryNotFound) {
return nil, status.Error(codes.NotFound, err.Error())
}
return nil, status.Errorf(codes.Internal, "delete template: %v", err)
}
s.dirty = true
return &keepassgov1.DeleteTemplateResponse{}, nil
}
func (s *Server) InstantiateTemplate(_ context.Context, req *keepassgov1.InstantiateTemplateRequest) (*keepassgov1.InstantiateTemplateResponse, error) {
if req.GetOverrides() == nil {
return nil, status.Error(codes.InvalidArgument, "missing overrides")
}
s.mu.Lock()
defer s.mu.Unlock()
if s.locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := s.model.InstantiateTemplate(req.GetTemplateId(), entryFromProto(req.GetOverrides()))
if err != nil {
if errors.Is(err, vault.ErrEntryNotFound) {
return nil, status.Error(codes.NotFound, err.Error())
}
return nil, status.Errorf(codes.Internal, "instantiate template: %v", err)
}
s.dirty = true
return &keepassgov1.InstantiateTemplateResponse{Entry: entryToProto(entry)}, nil
}
func (s *Server) ListAttachments(ctx context.Context, req *keepassgov1.ListAttachmentsRequest) (*keepassgov1.ListAttachmentsResponse, error) {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := findEntryByID(model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationReadEntry, entry); err != nil {
return nil, err
}
names := make([]string, 0, len(entry.Attachments))
for name := range entry.Attachments {
names = append(names, name)
}
slices.Sort(names)
return &keepassgov1.ListAttachmentsResponse{Names: names}, nil
}
func (s *Server) UploadAttachment(ctx context.Context, req *keepassgov1.UploadAttachmentRequest) (*keepassgov1.UploadAttachmentResponse, error) {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := findEntryByID(model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
if entry.Attachments == nil {
entry.Attachments = map[string][]byte{}
}
entry.Attachments[req.GetName()] = append([]byte(nil), req.GetContent()...)
s.model.Entries[index] = entry
s.dirty = true
return &keepassgov1.UploadAttachmentResponse{}, nil
}
func (s *Server) DownloadAttachment(ctx context.Context, req *keepassgov1.DownloadAttachmentRequest) (*keepassgov1.DownloadAttachmentResponse, error) {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := findEntryByID(model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationReadEntry, entry); err != nil {
return nil, err
}
content, ok := entry.Attachments[req.GetName()]
if !ok {
return nil, status.Error(codes.NotFound, "attachment not found")
}
return &keepassgov1.DownloadAttachmentResponse{
Content: append([]byte(nil), content...),
}, nil
}
func (s *Server) DeleteAttachment(ctx context.Context, req *keepassgov1.DeleteAttachmentRequest) (*keepassgov1.DeleteAttachmentResponse, error) {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := findEntryByID(model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
if _, ok := entry.Attachments[req.GetName()]; !ok {
return nil, status.Error(codes.NotFound, "attachment not found")
}
delete(entry.Attachments, req.GetName())
if len(entry.Attachments) == 0 {
entry.Attachments = nil
}
s.model.Entries[index] = entry
s.dirty = true
return &keepassgov1.DeleteAttachmentResponse{}, nil
}
func (s *Server) CopyEntryField(ctx context.Context, req *keepassgov1.CopyEntryFieldRequest) (*keepassgov1.CopyEntryFieldResponse, error) {
model, locked := s.snapshotModel()
if locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
entry, err := findEntryByID(model, req.GetId())
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
if _, err := s.authorizeEntryRequest(ctx, copyOperation(req.GetTarget()), entry); err != nil {
return nil, err
}
service := clipboard.Service{Writer: s.clipboard}
if err := service.Copy(model, req.GetId(), clipboard.Target(req.GetTarget())); err != nil {
switch {
case errors.Is(err, vault.ErrEntryNotFound):
return nil, status.Error(codes.NotFound, err.Error())
case errors.Is(err, clipboard.ErrUnsupportedTarget):
return nil, status.Error(codes.InvalidArgument, err.Error())
default:
return nil, status.Errorf(codes.Internal, "copy entry field: %v", err)
}
}
return &keepassgov1.CopyEntryFieldResponse{}, nil
}
func (s *Server) GeneratePassword(ctx context.Context, req *keepassgov1.GeneratePasswordRequest) (*keepassgov1.GeneratePasswordResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if _, err := s.authenticateRequest(ctx); err != nil {
return nil, err
}
if s.locked {
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
}
profile, err := passwords.LookupProfile(req.GetProfile(), s.profiles)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
password, err := passwords.Generate(profile)
if err != nil {
return nil, status.Errorf(codes.Internal, "generate password: %v", err)
}
return &keepassgov1.GeneratePasswordResponse{Password: password}, nil
}
func entryToProto(entry vault.Entry) *keepassgov1.Entry {
return &keepassgov1.Entry{
Id: entry.ID,
Title: entry.Title,
Username: entry.Username,
Password: entry.Password,
Url: entry.URL,
Notes: entry.Notes,
Tags: append([]string(nil), entry.Tags...),
Path: append([]string(nil), entry.Path...),
Fields: maps.Clone(entry.Fields),
}
}
func entryFromProto(entry *keepassgov1.Entry) vault.Entry {
return vault.Entry{
ID: entry.GetId(),
Title: entry.GetTitle(),
Username: entry.GetUsername(),
Password: entry.GetPassword(),
URL: entry.GetUrl(),
Notes: entry.GetNotes(),
Tags: append([]string(nil), entry.GetTags()...),
Path: append([]string(nil), entry.GetPath()...),
Fields: maps.Clone(entry.GetFields()),
}
}
func findEntryByID(model vault.Model, id string) (vault.Entry, error) {
for _, entry := range model.Entries {
if entry.ID == id {
return entry, nil
}
}
return vault.Entry{}, vault.ErrEntryNotFound
}
func findMutableEntryByID(model *vault.Model, id string) (vault.Entry, int, error) {
for i, entry := range model.Entries {
if entry.ID == id {
entry.Attachments = maps.Clone(entry.Attachments)
return entry, i, nil
}
}
return vault.Entry{}, -1, vault.ErrEntryNotFound
}
func visibleModel(model vault.Model) vault.Model {
out := model
out.Entries = nil
for _, entry := range model.Entries {
token, ok, err := apitokens.TokenFromEntry(entry)
if err == nil && ok && token.ID != "" {
continue
}
out.Entries = append(out.Entries, entry)
}
out.Groups = nil
for _, path := range model.Groups {
if len(path) >= 2 && path[0] == "Root" && path[1] == "API Tokens" {
continue
}
out.Groups = append(out.Groups, path)
}
return out
}
func (s *Server) snapshotModel() (vault.Model, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
return s.model, s.locked
}
var timeNow = func() time.Time { return time.Now().UTC() }
func (s *Server) authenticateRequest(ctx context.Context) (apitokens.Token, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return apitokens.Token{}, status.Error(codes.Unauthenticated, "missing metadata")
}
values := md.Get("authorization")
if len(values) == 0 {
s.audit.Record(apiaudit.Event{Type: apiaudit.EventAuthRejected, Message: "missing authorization"})
return apitokens.Token{}, status.Error(codes.Unauthenticated, "missing authorization")
}
const prefix = "Bearer "
if !strings.HasPrefix(values[0], prefix) {
s.audit.Record(apiaudit.Event{Type: apiaudit.EventAuthRejected, Message: "invalid bearer token"})
return apitokens.Token{}, status.Error(codes.Unauthenticated, "invalid bearer token")
}
s.mu.RLock()
tokens, err := apitokens.Entries(s.model)
s.mu.RUnlock()
if err != nil {
return apitokens.Token{}, status.Errorf(codes.Internal, "load api tokens: %v", err)
}
token, err := apitokens.Authenticate(tokens, strings.TrimSpace(strings.TrimPrefix(values[0], prefix)), timeNow())
if err != nil {
switch err {
case apitokens.ErrInvalidToken, apitokens.ErrExpiredToken, apitokens.ErrDisabledToken:
s.audit.Record(apiaudit.Event{Type: apiaudit.EventAuthRejected, Message: err.Error()})
return apitokens.Token{}, status.Error(codes.Unauthenticated, err.Error())
default:
return apitokens.Token{}, status.Errorf(codes.Internal, "authenticate api token: %v", err)
}
}
return token, nil
}
func (s *Server) authorizePathRequest(ctx context.Context, op apitokens.Operation, path []string) (apitokens.Token, error) {
token, err := s.authenticateRequest(ctx)
if err != nil {
return apitokens.Token{}, err
}
return s.authorizeResourceRequest(ctx, token, op, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: path})
}
func (s *Server) authorizeEntryRequest(ctx context.Context, op apitokens.Operation, entry vault.Entry) (apitokens.Token, error) {
token, err := s.authenticateRequest(ctx)
if err != nil {
return apitokens.Token{}, err
}
return s.authorizeResourceRequest(ctx, token, op, apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: entry.ID, Path: entry.Path})
}
func (s *Server) authorizeResourceRequest(ctx context.Context, token apitokens.Token, op apitokens.Operation, resource apitokens.Resource) (apitokens.Token, error) {
switch apitokens.Evaluate(token, op, resource) {
case apitokens.DecisionAllow:
return token, nil
case apitokens.DecisionDeny:
return apitokens.Token{}, status.Error(codes.PermissionDenied, "access is not allowed for this token")
case apitokens.DecisionPrompt:
s.audit.Record(apiaudit.Event{
Type: apiaudit.EventApprovalRequested,
TokenID: token.ID,
TokenName: token.Name,
ClientName: token.ClientName,
Operation: op,
Resource: resource,
})
result, err := s.approvals.Request(ctx, token, op, resource)
if result.Rule != nil {
if persistErr := s.persistApprovalRule(token.ID, *result.Rule); persistErr != nil {
return apitokens.Token{}, status.Errorf(codes.Internal, "persist approval decision: %v", persistErr)
}
}
switch {
case err == nil:
s.audit.Record(apiaudit.Event{
Type: apiaudit.EventApprovalAllowed,
TokenID: token.ID,
TokenName: token.Name,
ClientName: token.ClientName,
Operation: op,
Resource: resource,
})
return token, nil
case errors.Is(err, apiapproval.ErrRequestDenied):
s.audit.Record(apiaudit.Event{
Type: apiaudit.EventApprovalDenied,
TokenID: token.ID,
TokenName: token.Name,
ClientName: token.ClientName,
Operation: op,
Resource: resource,
})
return apitokens.Token{}, status.Error(codes.PermissionDenied, "access denied by user approval")
case errors.Is(err, apiapproval.ErrRequestCanceled):
s.audit.Record(apiaudit.Event{
Type: apiaudit.EventApprovalCanceled,
TokenID: token.ID,
TokenName: token.Name,
ClientName: token.ClientName,
Operation: op,
Resource: resource,
})
return apitokens.Token{}, status.Error(codes.Unauthenticated, "authorization request canceled")
case errors.Is(err, apiapproval.ErrRequestTimedOut):
s.audit.Record(apiaudit.Event{
Type: apiaudit.EventApprovalTimedOut,
TokenID: token.ID,
TokenName: token.Name,
ClientName: token.ClientName,
Operation: op,
Resource: resource,
})
return apitokens.Token{}, status.Error(codes.DeadlineExceeded, "authorization request timed out")
case errors.Is(err, context.Canceled):
return apitokens.Token{}, status.Error(codes.Canceled, "authorization request canceled")
case errors.Is(err, context.DeadlineExceeded):
return apitokens.Token{}, status.Error(codes.DeadlineExceeded, "authorization request timed out")
default:
return apitokens.Token{}, status.Errorf(codes.Internal, "await authorization request: %v", err)
}
default:
return apitokens.Token{}, status.Error(codes.PermissionDenied, "access is not allowed for this token")
}
}
func (s *Server) persistApprovalRule(tokenID string, rule apitokens.PolicyRule) error {
s.mu.Lock()
defer s.mu.Unlock()
for i, entry := range s.model.Entries {
token, ok, err := apitokens.TokenFromEntry(entry)
if err != nil || !ok || token.ID != tokenID {
continue
}
if !hasPolicyRule(token.Policies, rule) {
token.Policies = append(token.Policies, rule)
}
s.model.Entries[i] = token.Entry(entry.Path)
s.dirty = true
if lifecycle, ok := s.lifecycle.(modelReplaceableLifecycle); ok {
lifecycle.Replace(s.model)
}
return nil
}
return status.Error(codes.NotFound, "api token entry not found")
}
func hasPolicyRule(rules []apitokens.PolicyRule, target apitokens.PolicyRule) bool {
for _, rule := range rules {
if rule.Effect != target.Effect || rule.Operation != target.Operation {
continue
}
if rule.Resource.Kind != target.Resource.Kind || rule.Resource.EntryID != target.Resource.EntryID {
continue
}
if slices.Equal(rule.Resource.Path, target.Resource.Path) {
return true
}
}
return false
}
func copyOperation(target string) apitokens.Operation {
switch clipboard.Target(target) {
case clipboard.TargetUsername:
return apitokens.OperationCopyUsername
case clipboard.TargetURL:
return apitokens.OperationCopyURL
default:
return apitokens.OperationCopyPassword
}
}
func AuthInterceptor(server *Server) grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req any,
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (any, error) {
switch info.FullMethod {
case "/keepassgo.v1.VaultService/GetSessionStatus",
"/keepassgo.v1.VaultService/OpenVault",
"/keepassgo.v1.VaultService/OpenRemoteVault",
"/keepassgo.v1.VaultService/SaveVault",
"/keepassgo.v1.VaultService/LockVault",
"/keepassgo.v1.VaultService/UnlockVault":
if _, err := server.authenticateRequest(ctx); err != nil {
return nil, err
}
}
return handler(ctx, req)
}
}
File diff suppressed because it is too large Load Diff
+215
View File
@@ -0,0 +1,215 @@
package apiapproval
import (
"context"
"errors"
"fmt"
"slices"
"strconv"
"sync"
"time"
"git.julianfamily.org/keepassgo/internal/apitokens"
)
var (
ErrRequestDenied = errors.New("authorization request denied")
ErrRequestCanceled = errors.New("authorization request canceled")
ErrRequestTimedOut = errors.New("authorization request timed out")
ErrRequestNotFound = errors.New("authorization request not found")
)
type Outcome string
const (
OutcomeAllowOnce Outcome = "allow-once"
OutcomeDenyOnce Outcome = "deny-once"
OutcomeAllowPermanent Outcome = "allow-permanent"
OutcomeDenyPermanent Outcome = "deny-permanent"
OutcomeCancel Outcome = "cancel"
)
type Request struct {
ID string
TokenID string
TokenName string
ClientName string
Operation apitokens.Operation
Resource apitokens.Resource
RequestedAt time.Time
}
type Result struct {
Outcome Outcome
Rule *apitokens.PolicyRule
}
type Broker struct {
mu sync.Mutex
pending map[string]*pendingRequest
timeout time.Duration
now func() time.Time
nextID func() string
}
type pendingRequest struct {
request Request
done chan Outcome
}
type idGenerator struct {
mu sync.Mutex
counter int
}
func NewBroker(timeout time.Duration) *Broker {
gen := &idGenerator{}
return &Broker{
pending: map[string]*pendingRequest{},
timeout: timeout,
now: func() time.Time {
return time.Now().UTC()
},
nextID: func() string {
return gen.Next()
},
}
}
func (g *idGenerator) Next() string {
g.mu.Lock()
defer g.mu.Unlock()
g.counter++
return "approval-" + strconv.Itoa(g.counter)
}
func (b *Broker) Pending() []Request {
b.mu.Lock()
defer b.mu.Unlock()
requests := make([]Request, 0, len(b.pending))
for _, pending := range b.pending {
requests = append(requests, pending.request)
}
slices.SortFunc(requests, func(a, c Request) int {
switch {
case a.RequestedAt.Before(c.RequestedAt):
return -1
case a.RequestedAt.After(c.RequestedAt):
return 1
case a.ID < c.ID:
return -1
case a.ID > c.ID:
return 1
default:
return 0
}
})
return requests
}
func (b *Broker) Request(ctx context.Context, token apitokens.Token, op apitokens.Operation, resource apitokens.Resource) (Result, error) {
if b == nil {
return Result{}, ErrRequestTimedOut
}
pending := &pendingRequest{
request: Request{
ID: b.nextID(),
TokenID: token.ID,
TokenName: token.Name,
ClientName: token.ClientName,
Operation: op,
Resource: resource,
RequestedAt: b.now(),
},
done: make(chan Outcome, 1),
}
b.mu.Lock()
b.pending[pending.request.ID] = pending
b.mu.Unlock()
defer func() {
b.mu.Lock()
delete(b.pending, pending.request.ID)
b.mu.Unlock()
}()
timer := time.NewTimer(b.timeout)
defer timer.Stop()
select {
case outcome := <-pending.done:
result, err := resultForOutcome(pending.request, outcome)
if err != nil {
return result, err
}
return result, nil
case <-timer.C:
return Result{}, ErrRequestTimedOut
case <-ctx.Done():
return Result{}, ctx.Err()
}
}
func (b *Broker) Resolve(id string, outcome Outcome) (Request, *apitokens.PolicyRule, error) {
b.mu.Lock()
pending, ok := b.pending[id]
b.mu.Unlock()
if !ok {
return Request{}, nil, ErrRequestNotFound
}
rule, err := RuleFromDecision(pending.request, outcome)
if err != nil {
return Request{}, nil, err
}
select {
case pending.done <- outcome:
default:
}
return pending.request, rule, nil
}
func resultForOutcome(request Request, outcome Outcome) (Result, error) {
rule, err := RuleFromDecision(request, outcome)
if err != nil {
return Result{}, err
}
result := Result{Outcome: outcome, Rule: rule}
switch outcome {
case OutcomeAllowOnce, OutcomeAllowPermanent:
return result, nil
case OutcomeDenyOnce, OutcomeDenyPermanent:
return result, ErrRequestDenied
case OutcomeCancel:
return result, ErrRequestCanceled
default:
return Result{}, fmt.Errorf("unsupported approval outcome %q", outcome)
}
}
func RuleFromDecision(request Request, outcome Outcome) (*apitokens.PolicyRule, error) {
switch outcome {
case OutcomeAllowPermanent:
rule := apitokens.PolicyRule{
Effect: apitokens.EffectAllow,
Operation: request.Operation,
Resource: request.Resource,
}
return &rule, nil
case OutcomeDenyPermanent:
rule := apitokens.PolicyRule{
Effect: apitokens.EffectDeny,
Operation: request.Operation,
Resource: request.Resource,
}
return &rule, nil
case OutcomeAllowOnce, OutcomeDenyOnce, OutcomeCancel:
return nil, nil
default:
return nil, fmt.Errorf("unsupported approval outcome %q", outcome)
}
}
+134
View File
@@ -0,0 +1,134 @@
package apiapproval
import (
"context"
"errors"
"testing"
"time"
"git.julianfamily.org/keepassgo/internal/apitokens"
)
func TestBrokerCreatesPendingRequestAndAllowsOnce(t *testing.T) {
t.Parallel()
broker := NewBroker(time.Minute)
broker.now = func() time.Time { return time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) }
resultCh := make(chan Result, 1)
errCh := make(chan error, 1)
go func() {
result, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI", ClientName: "grpc-cli"}, apitokens.OperationListEntries, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root", "Internet"}})
resultCh <- result
errCh <- err
}()
waitForPending(t, broker, 1)
pending := broker.Pending()
if len(pending) != 1 {
t.Fatalf("Pending() len = %d, want 1", len(pending))
}
if pending[0].TokenID != "token-1" {
t.Fatalf("Pending()[0].TokenID = %q, want token-1", pending[0].TokenID)
}
if _, _, err := broker.Resolve(pending[0].ID, OutcomeAllowOnce); err != nil {
t.Fatalf("Resolve(allow once) error = %v", err)
}
result := <-resultCh
if err := <-errCh; err != nil {
t.Fatalf("Request() error = %v, want nil", err)
}
if result.Outcome != OutcomeAllowOnce {
t.Fatalf("Request() outcome = %q, want %q", result.Outcome, OutcomeAllowOnce)
}
if result.Rule != nil {
t.Fatalf("Request() rule = %#v, want nil for allow-once", result.Rule)
}
if got := broker.Pending(); len(got) != 0 {
t.Fatalf("Pending() after allow len = %d, want 0", len(got))
}
}
func TestBrokerReturnsPermanentRuleForDeny(t *testing.T) {
t.Parallel()
broker := NewBroker(time.Minute)
reqDone := make(chan struct{})
var result Result
var err error
go func() {
result, err = broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationReadEntry, apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: "entry-1", Path: []string{"Root", "Internet"}})
close(reqDone)
}()
waitForPending(t, broker, 1)
pending := broker.Pending()[0]
request, rule, resolveErr := broker.Resolve(pending.ID, OutcomeDenyPermanent)
if resolveErr != nil {
t.Fatalf("Resolve(deny permanent) error = %v", resolveErr)
}
if request.ID != pending.ID {
t.Fatalf("Resolve().ID = %q, want %q", request.ID, pending.ID)
}
if rule == nil || rule.Effect != apitokens.EffectDeny || rule.Operation != apitokens.OperationReadEntry {
t.Fatalf("Resolve() rule = %#v, want deny read-entry rule", rule)
}
<-reqDone
if !errors.Is(err, ErrRequestDenied) {
t.Fatalf("Request() error = %v, want ErrRequestDenied", err)
}
if result.Rule == nil || result.Rule.Effect != apitokens.EffectDeny {
t.Fatalf("Request() rule = %#v, want deny rule", result.Rule)
}
if result.Outcome != OutcomeDenyPermanent {
t.Fatalf("Request() outcome = %q, want %q", result.Outcome, OutcomeDenyPermanent)
}
}
func TestBrokerSupportsCancellation(t *testing.T) {
t.Parallel()
broker := NewBroker(time.Minute)
errCh := make(chan error, 1)
go func() {
_, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationListGroups, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root"}})
errCh <- err
}()
waitForPending(t, broker, 1)
if _, _, err := broker.Resolve(broker.Pending()[0].ID, OutcomeCancel); err != nil {
t.Fatalf("Resolve(cancel) error = %v", err)
}
if err := <-errCh; !errors.Is(err, ErrRequestCanceled) {
t.Fatalf("Request() error = %v, want ErrRequestCanceled", err)
}
}
func TestBrokerTimesOutPendingRequests(t *testing.T) {
t.Parallel()
broker := NewBroker(10 * time.Millisecond)
_, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationListGroups, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root"}})
if !errors.Is(err, ErrRequestTimedOut) {
t.Fatalf("Request() error = %v, want ErrRequestTimedOut", err)
}
if got := broker.Pending(); len(got) != 0 {
t.Fatalf("Pending() len after timeout = %d, want 0", len(got))
}
}
func waitForPending(t *testing.T, broker *Broker, want int) {
t.Helper()
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
if got := len(broker.Pending()); got == want {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("Pending() never reached len %d", want)
}
+80
View File
@@ -0,0 +1,80 @@
package apiaudit
import (
"slices"
"sync"
"time"
"git.julianfamily.org/keepassgo/internal/apitokens"
)
type EventType string
const (
EventApprovalRequested EventType = "approval_requested"
EventApprovalAllowed EventType = "approval_allowed"
EventApprovalDenied EventType = "approval_denied"
EventApprovalCanceled EventType = "approval_canceled"
EventApprovalTimedOut EventType = "approval_timed_out"
EventAutofillFound EventType = "autofill_found"
EventAutofillAmbiguous EventType = "autofill_ambiguous"
EventAutofillBlocked EventType = "autofill_blocked"
EventAuthRejected EventType = "auth_rejected"
)
type Event struct {
Type EventType
At time.Time
TokenID string
TokenName string
ClientName string
Operation apitokens.Operation
Resource apitokens.Resource
Message string
}
type Log struct {
mu sync.Mutex
max int
now func() time.Time
events []Event
}
func New(max int) *Log {
if max < 1 {
max = 1
}
return &Log{
max: max,
now: func() time.Time {
return time.Now().UTC()
},
}
}
func (l *Log) Record(event Event) {
if l == nil {
return
}
l.mu.Lock()
defer l.mu.Unlock()
if event.At.IsZero() {
event.At = l.now()
}
l.events = append([]Event{event}, l.events...)
if len(l.events) > l.max {
l.events = l.events[:l.max]
}
}
func (l *Log) Events() []Event {
if l == nil {
return nil
}
l.mu.Lock()
defer l.mu.Unlock()
return slices.Clone(l.events)
}
+68
View File
@@ -0,0 +1,68 @@
package apiaudit
import (
"testing"
"time"
"git.julianfamily.org/keepassgo/internal/apitokens"
)
func TestLogKeepsNewestEventsWithinBound(t *testing.T) {
t.Parallel()
log := New(2)
log.now = func() time.Time { return time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) }
log.Record(Event{Type: EventApprovalRequested, TokenID: "token-1"})
log.Record(Event{Type: EventApprovalAllowed, TokenID: "token-2"})
log.Record(Event{Type: EventApprovalDenied, TokenID: "token-3"})
events := log.Events()
if len(events) != 2 {
t.Fatalf("len(Events()) = %d, want 2", len(events))
}
if events[0].TokenID != "token-3" || events[1].TokenID != "token-2" {
t.Fatalf("Events() = %#v, want newest-first bounded list", events)
}
}
func TestLogPreservesRecordedMetadata(t *testing.T) {
t.Parallel()
log := New(5)
log.Record(Event{
Type: EventApprovalRequested,
TokenID: "token-1",
TokenName: "CLI",
ClientName: "grpc-cli",
Operation: apitokens.OperationListEntries,
Resource: apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root", "Internet"}},
Message: "prompted for access",
})
events := log.Events()
if len(events) != 1 {
t.Fatalf("len(Events()) = %d, want 1", len(events))
}
if events[0].Operation != apitokens.OperationListEntries || events[0].Message != "prompted for access" {
t.Fatalf("Events()[0] = %#v, want preserved metadata", events[0])
}
}
func TestLogStoresAutofillEventTypes(t *testing.T) {
t.Parallel()
log := New(5)
log.Record(Event{
Type: EventAutofillAmbiguous,
TokenName: "Browser Extension",
Message: "multiple matches for example.com",
})
events := log.Events()
if len(events) != 1 {
t.Fatalf("len(Events()) = %d, want 1", len(events))
}
if events[0].Type != EventAutofillAmbiguous {
t.Fatalf("Events()[0].Type = %q, want %q", events[0].Type, EventAutofillAmbiguous)
}
}
+335
View File
@@ -0,0 +1,335 @@
package apitokens
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"slices"
"strings"
"time"
"git.julianfamily.org/keepassgo/internal/vault"
)
const (
EntryTypeAPIToken = "api-token"
FieldType = "KeePassGO-Type"
FieldTokenID = "KeePassGO-API-Token-ID"
FieldClientName = "KeePassGO-API-Client-Name"
FieldCreatedAt = "KeePassGO-API-Created-At"
FieldExpiresAt = "KeePassGO-API-Expires-At"
FieldDisabled = "KeePassGO-API-Disabled"
FieldRevokedAt = "KeePassGO-API-Revoked-At"
FieldSecretHash = "KeePassGO-API-Secret-Hash"
FieldPolicies = "KeePassGO-API-Policies"
)
var (
ErrNotAToken = errors.New("entry is not an api token")
ErrInvalidToken = errors.New("invalid api token")
ErrExpiredToken = errors.New("expired api token")
ErrDisabledToken = errors.New("disabled api token")
ErrTokenNotFound = errors.New("api token not found")
)
var EntryPath = []string{"Root", "API Tokens"}
type Effect string
type Operation string
type ResourceKind string
type Decision string
const (
EffectAllow Effect = "allow"
EffectDeny Effect = "deny"
ResourceGroup ResourceKind = "group"
ResourceEntry ResourceKind = "entry"
DecisionAllow Decision = "allow"
DecisionDeny Decision = "deny"
DecisionPrompt Decision = "prompt"
OperationListEntries Operation = "list_entries"
OperationListGroups Operation = "list_groups"
OperationReadEntry Operation = "read_entry"
OperationCopyPassword Operation = "copy_password"
OperationCopyUsername Operation = "copy_username"
OperationCopyURL Operation = "copy_url"
OperationMutateEntry Operation = "mutate_entry"
OperationMutateGroup Operation = "mutate_group"
OperationManageVault Operation = "manage_vault"
)
type Resource struct {
Kind ResourceKind `json:"kind"`
Path []string `json:"path,omitempty"`
EntryID string `json:"entry_id,omitempty"`
}
type PolicyRule struct {
Effect Effect `json:"effect"`
Operation Operation `json:"operation"`
Resource Resource `json:"resource"`
}
type Token struct {
ID string
Name string
ClientName string
SecretHash string
CreatedAt time.Time
ExpiresAt *time.Time
RevokedAt *time.Time
Disabled bool
Policies []PolicyRule
}
func Issue(name, clientName string, expiresAt *time.Time, now time.Time) (Token, string, error) {
clear, hashed, err := newSecret()
if err != nil {
return Token{}, "", err
}
id, _, err := newSecret()
if err != nil {
return Token{}, "", err
}
return Token{
ID: id,
Name: strings.TrimSpace(name),
ClientName: strings.TrimSpace(clientName),
SecretHash: hashed,
CreatedAt: now.UTC(),
ExpiresAt: cloneTime(expiresAt),
}, clear, nil
}
func Rotate(token Token, now time.Time) (Token, string, error) {
clear, hashed, err := newSecret()
if err != nil {
return Token{}, "", err
}
token.SecretHash = hashed
token.Disabled = false
token.RevokedAt = nil
if token.CreatedAt.IsZero() {
token.CreatedAt = now.UTC()
}
return token, clear, nil
}
func Disable(token Token) Token {
token.Disabled = true
return token
}
func Revoke(token Token, when time.Time) Token {
token.Disabled = true
t := when.UTC()
token.RevokedAt = &t
return token
}
func Authenticate(tokens []Token, presentedSecret string, now time.Time) (Token, error) {
hashed := hashSecret(presentedSecret)
for _, token := range tokens {
if token.SecretHash != hashed {
continue
}
if token.Disabled || token.RevokedAt != nil {
return Token{}, ErrDisabledToken
}
if token.ExpiresAt != nil && !token.ExpiresAt.After(now.UTC()) {
return Token{}, ErrExpiredToken
}
return token, nil
}
return Token{}, ErrInvalidToken
}
func Evaluate(token Token, operation Operation, resource Resource) Decision {
decision := DecisionPrompt
for _, rule := range token.Policies {
if rule.Operation != operation {
continue
}
if !matches(rule.Resource, resource) {
continue
}
if rule.Effect == EffectDeny {
return DecisionDeny
}
if rule.Effect == EffectAllow {
decision = DecisionAllow
}
}
return decision
}
func Entries(model vault.Model) ([]Token, error) {
var out []Token
for _, entry := range model.Entries {
token, ok, err := TokenFromEntry(entry)
if err != nil {
return nil, err
}
if ok {
out = append(out, token)
}
}
slices.SortFunc(out, func(a, b Token) int {
switch {
case a.Name < b.Name:
return -1
case a.Name > b.Name:
return 1
default:
return strings.Compare(a.ID, b.ID)
}
})
return out, nil
}
func Find(model vault.Model, id string) (Token, error) {
tokens, err := Entries(model)
if err != nil {
return Token{}, err
}
for _, token := range tokens {
if token.ID == id {
return token, nil
}
}
return Token{}, ErrTokenNotFound
}
func Upsert(model *vault.Model, token Token) {
model.UpsertEntry(token.Entry(EntryPath))
model.CreateGroup([]string{"Root"}, "API Tokens")
}
func Delete(model *vault.Model, id string) error {
for i, entry := range model.Entries {
token, ok, err := TokenFromEntry(entry)
if err != nil {
return err
}
if ok && token.ID == id {
model.Entries = append(model.Entries[:i], model.Entries[i+1:]...)
return nil
}
}
return ErrTokenNotFound
}
func TokenFromEntry(entry vault.Entry) (Token, bool, error) {
if entry.Fields[FieldType] != EntryTypeAPIToken {
return Token{}, false, nil
}
createdAt, err := time.Parse(time.RFC3339, entry.Fields[FieldCreatedAt])
if err != nil {
return Token{}, true, fmt.Errorf("parse created at: %w", err)
}
var expiresAt *time.Time
if raw := strings.TrimSpace(entry.Fields[FieldExpiresAt]); raw != "" {
t, err := time.Parse(time.RFC3339, raw)
if err != nil {
return Token{}, true, fmt.Errorf("parse expires at: %w", err)
}
expiresAt = &t
}
var revokedAt *time.Time
if raw := strings.TrimSpace(entry.Fields[FieldRevokedAt]); raw != "" {
t, err := time.Parse(time.RFC3339, raw)
if err != nil {
return Token{}, true, fmt.Errorf("parse revoked at: %w", err)
}
revokedAt = &t
}
policies := []PolicyRule{}
if raw := strings.TrimSpace(entry.Fields[FieldPolicies]); raw != "" {
if err := json.Unmarshal([]byte(raw), &policies); err != nil {
return Token{}, true, fmt.Errorf("parse policies: %w", err)
}
}
return Token{
ID: entry.Fields[FieldTokenID],
Name: entry.Title,
ClientName: entry.Fields[FieldClientName],
SecretHash: entry.Fields[FieldSecretHash],
CreatedAt: createdAt,
ExpiresAt: expiresAt,
RevokedAt: revokedAt,
Disabled: strings.EqualFold(entry.Fields[FieldDisabled], "true"),
Policies: policies,
}, true, nil
}
func (t Token) Entry(path []string) vault.Entry {
fields := map[string]string{
FieldType: EntryTypeAPIToken,
FieldTokenID: t.ID,
FieldClientName: t.ClientName,
FieldCreatedAt: t.CreatedAt.UTC().Format(time.RFC3339),
FieldDisabled: fmt.Sprintf("%t", t.Disabled),
FieldSecretHash: t.SecretHash,
}
if t.ExpiresAt != nil {
fields[FieldExpiresAt] = t.ExpiresAt.UTC().Format(time.RFC3339)
}
if t.RevokedAt != nil {
fields[FieldRevokedAt] = t.RevokedAt.UTC().Format(time.RFC3339)
}
if len(t.Policies) > 0 {
data, _ := json.Marshal(t.Policies)
fields[FieldPolicies] = string(data)
}
return vault.Entry{
ID: t.ID,
Title: t.Name,
Username: t.ClientName,
Path: slices.Clone(path),
Fields: fields,
}
}
func hashSecret(secret string) string {
sum := sha256.Sum256([]byte(secret))
return hex.EncodeToString(sum[:])
}
func newSecret() (string, string, error) {
buf := make([]byte, 24)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate secret: %w", err)
}
clear := base64.RawURLEncoding.EncodeToString(buf)
return clear, hashSecret(clear), nil
}
func cloneTime(in *time.Time) *time.Time {
if in == nil {
return nil
}
t := in.UTC()
return &t
}
func matches(rule, resource Resource) bool {
switch rule.Kind {
case ResourceEntry:
return rule.EntryID != "" && rule.EntryID == resource.EntryID
case ResourceGroup:
if len(rule.Path) > len(resource.Path) {
return false
}
return slices.Equal(rule.Path, resource.Path[:len(rule.Path)])
default:
return false
}
}
+282
View File
@@ -0,0 +1,282 @@
package apitokens
import (
"slices"
"testing"
"time"
"git.julianfamily.org/keepassgo/internal/vault"
)
func TestTokenEntryRoundTripsThroughVaultEntry(t *testing.T) {
t.Parallel()
expiresAt := time.Date(2026, 4, 1, 12, 0, 0, 0, time.UTC)
token := Token{
ID: "token-1",
Name: "Browser Connector",
ClientName: "browser-extension",
SecretHash: "deadbeef",
CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC),
ExpiresAt: &expiresAt,
Disabled: true,
Policies: []PolicyRule{
{Effect: EffectAllow, Operation: OperationListEntries, Resource: Resource{Kind: ResourceGroup, Path: []string{"Root", "Internet"}}},
{Effect: EffectDeny, Operation: OperationCopyPassword, Resource: Resource{Kind: ResourceEntry, EntryID: "bank-token"}},
},
}
entry := token.Entry([]string{"Root", "API Tokens"})
if entry.Fields[FieldType] != EntryTypeAPIToken {
t.Fatalf("FieldType = %q, want %q", entry.Fields[FieldType], EntryTypeAPIToken)
}
got, ok, err := TokenFromEntry(entry)
if err != nil {
t.Fatalf("TokenFromEntry() error = %v", err)
}
if !ok {
t.Fatal("TokenFromEntry() ok = false, want true")
}
if got.ID != token.ID || got.Name != token.Name || got.ClientName != token.ClientName || got.SecretHash != token.SecretHash || !got.Disabled {
t.Fatalf("TokenFromEntry() = %#v, want %#v", got, token)
}
if got.ExpiresAt == nil || !got.ExpiresAt.Equal(expiresAt) {
t.Fatalf("ExpiresAt = %#v, want %v", got.ExpiresAt, expiresAt)
}
if len(got.Policies) != 2 {
t.Fatalf("len(Policies) = %d, want 2", len(got.Policies))
}
}
func TestEntriesFiltersOnlyAPITokens(t *testing.T) {
t.Parallel()
model := vault.Model{
Entries: []vault.Entry{
{ID: "entry-1", Title: "Ordinary Entry"},
Token{ID: "token-1", Name: "CLI", SecretHash: "hash-1", CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)}.Entry([]string{"Root", "API Tokens"}),
Token{ID: "token-2", Name: "Browser", SecretHash: "hash-2", CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)}.Entry([]string{"Root", "API Tokens"}),
},
}
got, err := Entries(model)
if err != nil {
t.Fatalf("Entries() error = %v", err)
}
if len(got) != 2 {
t.Fatalf("len(Entries()) = %d, want 2", len(got))
}
if got[0].Name != "Browser" || got[1].Name != "CLI" {
t.Fatalf("Entries() = %#v, want Browser then CLI by name sort", got)
}
}
func TestUpsertCreatesAndUpdatesVaultTokenEntry(t *testing.T) {
t.Parallel()
model := vault.Model{}
token := Token{
ID: "token-1",
Name: "CLI",
ClientName: "grpc-cli",
SecretHash: "hash-1",
CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC),
}
Upsert(&model, token)
if len(model.Entries) != 1 {
t.Fatalf("len(model.Entries) = %d, want 1", len(model.Entries))
}
if got := model.ChildGroups([]string{"Root"}); !slices.Equal(got, []string{"API Tokens"}) {
t.Fatalf("ChildGroups(Root) = %v, want [API Tokens]", got)
}
token.ClientName = "rotated-client"
token.Disabled = true
Upsert(&model, token)
got, err := Find(model, "token-1")
if err != nil {
t.Fatalf("Find() error = %v", err)
}
if got.ClientName != "rotated-client" || !got.Disabled {
t.Fatalf("Find() = %#v, want updated client name and disabled state", got)
}
}
func TestDeleteRemovesTokenEntryByTokenID(t *testing.T) {
t.Parallel()
model := vault.Model{
Entries: []vault.Entry{
Token{ID: "token-1", Name: "CLI", SecretHash: "hash-1", CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)}.Entry(EntryPath),
{ID: "entry-1", Title: "Regular Entry"},
},
}
if err := Delete(&model, "token-1"); err != nil {
t.Fatalf("Delete() error = %v", err)
}
if len(model.Entries) != 1 || model.Entries[0].ID != "entry-1" {
t.Fatalf("model.Entries after Delete = %#v, want only regular entry", model.Entries)
}
if err := Delete(&model, "missing"); err != ErrTokenNotFound {
t.Fatalf("Delete(missing) error = %v, want %v", err, ErrTokenNotFound)
}
}
func TestIssueRotateDisableAndRevokeToken(t *testing.T) {
t.Parallel()
now := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)
token, secret, err := Issue("Browser Connector", "browser-extension", nil, now)
if err != nil {
t.Fatalf("Issue() error = %v", err)
}
if token.ID == "" || secret == "" {
t.Fatalf("Issue() returned empty token or secret: %#v %q", token, secret)
}
if token.SecretHash == secret {
t.Fatal("SecretHash should not equal cleartext secret")
}
if token.Disabled {
t.Fatal("Disabled = true, want false after Issue")
}
rotated, newSecret, err := Rotate(token, now.Add(time.Hour))
if err != nil {
t.Fatalf("Rotate() error = %v", err)
}
if rotated.ID != token.ID {
t.Fatalf("Rotate() changed ID from %q to %q", token.ID, rotated.ID)
}
if newSecret == secret {
t.Fatal("Rotate() returned the same cleartext secret, want a new one")
}
if rotated.SecretHash == token.SecretHash {
t.Fatal("Rotate() left SecretHash unchanged")
}
disabled := Disable(rotated)
if !disabled.Disabled {
t.Fatal("Disable() did not set Disabled")
}
revoked := Revoke(disabled, now.Add(2*time.Hour))
if !revoked.Disabled {
t.Fatal("Revoke() should leave token disabled")
}
if revoked.RevokedAt == nil || !revoked.RevokedAt.Equal(now.Add(2*time.Hour)) {
t.Fatalf("RevokedAt = %#v, want %v", revoked.RevokedAt, now.Add(2*time.Hour))
}
}
func TestAuthenticateRejectsDisabledExpiredAndWrongSecret(t *testing.T) {
t.Parallel()
now := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)
valid, secret, err := Issue("CLI", "cli-tool", nil, now)
if err != nil {
t.Fatalf("Issue() error = %v", err)
}
expired, expiredSecret, err := Issue("Expired", "cli-tool", nil, now)
if err != nil {
t.Fatalf("Issue(expired) error = %v", err)
}
expiredAt := now.Add(-time.Minute)
expired.ExpiresAt = &expiredAt
disabled, disabledSecret, err := Issue("Disabled", "cli-tool", nil, now)
if err != nil {
t.Fatalf("Issue(disabled) error = %v", err)
}
disabled = Disable(disabled)
tokens := []Token{expired, disabled, valid}
if _, err := Authenticate(tokens, "wrong-secret", now); err != ErrInvalidToken {
t.Fatalf("Authenticate(wrong-secret) error = %v, want %v", err, ErrInvalidToken)
}
if _, err := Authenticate([]Token{expired}, expiredSecret, now); err != ErrExpiredToken {
t.Fatalf("Authenticate(expired) error = %v, want %v", err, ErrExpiredToken)
}
if _, err := Authenticate([]Token{disabled}, disabledSecret, now); err != ErrDisabledToken {
t.Fatalf("Authenticate(disabled) error = %v, want %v", err, ErrDisabledToken)
}
got, err := Authenticate(tokens, secret, now)
if err != nil {
t.Fatalf("Authenticate(valid) error = %v", err)
}
if got.ID != valid.ID {
t.Fatalf("Authenticate(valid).ID = %q, want %q", got.ID, valid.ID)
}
}
func TestEvaluatePolicyDistinguishesAllowDenyAndPrompt(t *testing.T) {
t.Parallel()
token := Token{
ID: "token-1",
Policies: []PolicyRule{
{Effect: EffectAllow, Operation: OperationListEntries, Resource: Resource{Kind: ResourceGroup, Path: []string{"Root", "Internet"}}},
{Effect: EffectDeny, Operation: OperationCopyPassword, Resource: Resource{Kind: ResourceEntry, EntryID: "amazon"}},
},
}
decision := Evaluate(token, OperationListEntries, Resource{Kind: ResourceGroup, Path: []string{"Root", "Internet"}})
if decision != DecisionAllow {
t.Fatalf("Evaluate(allow) = %q, want %q", decision, DecisionAllow)
}
decision = Evaluate(token, OperationCopyPassword, Resource{Kind: ResourceEntry, EntryID: "amazon", Path: []string{"Root", "Internet"}})
if decision != DecisionDeny {
t.Fatalf("Evaluate(deny) = %q, want %q", decision, DecisionDeny)
}
decision = Evaluate(token, OperationCopyPassword, Resource{Kind: ResourceEntry, EntryID: "github", Path: []string{"Root", "Internet"}})
if decision != DecisionPrompt {
t.Fatalf("Evaluate(prompt) = %q, want %q", decision, DecisionPrompt)
}
}
func TestEntryScopedDenyOverridesGroupAllow(t *testing.T) {
t.Parallel()
token := Token{
ID: "token-1",
Policies: []PolicyRule{
{Effect: EffectAllow, Operation: OperationCopyPassword, Resource: Resource{Kind: ResourceGroup, Path: []string{"Root", "Internet"}}},
{Effect: EffectDeny, Operation: OperationCopyPassword, Resource: Resource{Kind: ResourceEntry, EntryID: "bank", Path: []string{"Root", "Internet"}}},
},
}
allowed := Evaluate(token, OperationCopyPassword, Resource{Kind: ResourceEntry, EntryID: "forum", Path: []string{"Root", "Internet"}})
denied := Evaluate(token, OperationCopyPassword, Resource{Kind: ResourceEntry, EntryID: "bank", Path: []string{"Root", "Internet"}})
if allowed != DecisionAllow || denied != DecisionDeny {
t.Fatalf("Evaluate() allow/deny = %q/%q, want %q/%q", allowed, denied, DecisionAllow, DecisionDeny)
}
}
func TestTokenEntryKeepsPoliciesStableAcrossRoundTripOrdering(t *testing.T) {
t.Parallel()
token := Token{
ID: "token-1",
Name: "CLI",
ClientName: "cli",
SecretHash: "hash",
CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC),
Policies: []PolicyRule{
{Effect: EffectDeny, Operation: OperationCopyPassword, Resource: Resource{Kind: ResourceEntry, EntryID: "bank"}},
{Effect: EffectAllow, Operation: OperationListEntries, Resource: Resource{Kind: ResourceGroup, Path: []string{"Root", "Internet"}}},
},
}
got, ok, err := TokenFromEntry(token.Entry([]string{"Root", "API Tokens"}))
if err != nil || !ok {
t.Fatalf("TokenFromEntry() error = %v, ok=%v", err, ok)
}
if !slices.EqualFunc(got.Policies, token.Policies, func(a, b PolicyRule) bool {
return a.Effect == b.Effect && a.Operation == b.Operation && a.Resource.Kind == b.Resource.Kind && a.Resource.EntryID == b.Resource.EntryID && slices.Equal(a.Resource.Path, b.Resource.Path)
}) {
t.Fatalf("Policies after round trip = %#v, want %#v", got.Policies, token.Policies)
}
}
+141
View File
@@ -0,0 +1,141 @@
package appstate
import (
"fmt"
"strings"
"git.julianfamily.org/keepassgo/internal/vault"
)
type SyncMode string
const (
SyncModeManual SyncMode = "manual"
SyncModeAutomaticOnOpenSave SyncMode = "automatic_on_open_save"
)
type RemoteBinding struct {
LocalVaultPath string `json:"localVaultPath"`
RemoteProfileID string `json:"remoteProfileId"`
CredentialEntryID string `json:"credentialEntryId"`
SyncMode SyncMode `json:"syncMode,omitempty"`
}
type ResolvedRemoteBinding struct {
Profile vault.RemoteProfile
Credentials vault.Entry
}
type RemoteBindingInput struct {
LocalVaultPath string
RemoteProfileID string
RemoteProfileName string
BaseURL string
RemotePath string
CredentialEntryID string
CredentialTitle string
Username string
Password string
CredentialPath []string
SyncMode SyncMode
}
func (b RemoteBinding) Resolve(model vault.Model) (ResolvedRemoteBinding, error) {
profile, err := model.RemoteProfileByID(b.RemoteProfileID)
if err != nil {
return ResolvedRemoteBinding{}, fmt.Errorf("resolve remote profile: %w", err)
}
credentials, err := model.EntryByID(b.CredentialEntryID)
if err != nil {
return ResolvedRemoteBinding{}, fmt.Errorf("resolve remote credentials: %w", err)
}
return ResolvedRemoteBinding{
Profile: profile,
Credentials: credentials,
}, nil
}
func ConfigureRemoteBinding(model *vault.Model, input RemoteBindingInput) (RemoteBinding, error) {
if model == nil {
return RemoteBinding{}, fmt.Errorf("model is required")
}
input.LocalVaultPath = strings.TrimSpace(input.LocalVaultPath)
input.RemoteProfileID = strings.TrimSpace(input.RemoteProfileID)
input.RemoteProfileName = strings.TrimSpace(input.RemoteProfileName)
input.BaseURL = strings.TrimSpace(input.BaseURL)
input.RemotePath = strings.TrimSpace(input.RemotePath)
input.CredentialEntryID = strings.TrimSpace(input.CredentialEntryID)
input.CredentialTitle = strings.TrimSpace(input.CredentialTitle)
input.Username = strings.TrimSpace(input.Username)
switch {
case input.LocalVaultPath == "":
return RemoteBinding{}, fmt.Errorf("local vault path is required")
case input.RemoteProfileID == "":
return RemoteBinding{}, fmt.Errorf("remote profile id is required")
case input.BaseURL == "":
return RemoteBinding{}, fmt.Errorf("remote base URL is required")
case input.RemotePath == "":
return RemoteBinding{}, fmt.Errorf("remote path is required")
case input.CredentialEntryID == "":
return RemoteBinding{}, fmt.Errorf("credential entry id is required")
case input.Password == "":
return RemoteBinding{}, fmt.Errorf("credential password is required")
}
if input.RemoteProfileName == "" {
input.RemoteProfileName = input.RemoteProfileID
}
if input.CredentialTitle == "" {
input.CredentialTitle = "Remote Sign-In"
}
model.UpsertRemoteProfile(vault.RemoteProfile{
ID: input.RemoteProfileID,
Name: input.RemoteProfileName,
Backend: vault.RemoteBackendWebDAV,
BaseURL: input.BaseURL,
Path: input.RemotePath,
})
model.UpsertEntry(vault.Entry{
ID: input.CredentialEntryID,
Title: input.CredentialTitle,
Username: input.Username,
Password: input.Password,
URL: input.BaseURL,
Path: append([]string(nil), input.CredentialPath...),
})
return RemoteBinding{
LocalVaultPath: input.LocalVaultPath,
RemoteProfileID: input.RemoteProfileID,
CredentialEntryID: input.CredentialEntryID,
SyncMode: normalizeSyncMode(input.SyncMode),
}, nil
}
func RemoveRemoteBinding(model *vault.Model, binding RemoteBinding) error {
if model == nil {
return fmt.Errorf("model is required")
}
if strings.TrimSpace(binding.RemoteProfileID) == "" {
return fmt.Errorf("remote profile id is required")
}
if strings.TrimSpace(binding.CredentialEntryID) == "" {
return fmt.Errorf("credential entry id is required")
}
model.RemoveRemoteProfileByID(binding.RemoteProfileID)
model.RemoveEntryByID(binding.CredentialEntryID)
return nil
}
func normalizeSyncMode(mode SyncMode) SyncMode {
switch mode {
case SyncModeAutomaticOnOpenSave:
return SyncModeAutomaticOnOpenSave
default:
return SyncModeManual
}
}
+250
View File
@@ -0,0 +1,250 @@
package appstate
import (
"encoding/json"
"errors"
"strings"
"testing"
"git.julianfamily.org/keepassgo/internal/vault"
)
func TestRemoteBindingResolveUsesVaultProfileAndCredentialEntry(t *testing.T) {
t.Parallel()
model := vault.Model{
Entries: []vault.Entry{
{
ID: "linuscaldwell-webdav",
Title: "Bellagio WebDAV Sign-In",
Username: "linuscaldwell",
Password: "bellagio-pass-1",
Path: []string{"Crew", "Internet"},
},
},
RemoteProfiles: []vault.RemoteProfile{
{
ID: "bellagio-webdav",
Name: "Bellagio Vault",
Backend: vault.RemoteBackendWebDAV,
BaseURL: "https://dav.example.invalid/remote.php/dav",
Path: "files/bellagio/keepass.kdbx",
},
},
}
binding := RemoteBinding{
LocalVaultPath: "/tmp/bellagio.kdbx",
RemoteProfileID: "bellagio-webdav",
CredentialEntryID: "linuscaldwell-webdav",
SyncMode: SyncModeAutomaticOnOpenSave,
}
resolved, err := binding.Resolve(model)
if err != nil {
t.Fatalf("Resolve() error = %v", err)
}
if got := resolved.Profile.BaseURL; got != "https://dav.example.invalid/remote.php/dav" {
t.Fatalf("resolved profile base URL = %q, want remote.php/dav URL", got)
}
if got := resolved.Profile.Path; got != "files/bellagio/keepass.kdbx" {
t.Fatalf("resolved profile path = %q, want files/bellagio/keepass.kdbx", got)
}
if got := resolved.Credentials.Username; got != "linuscaldwell" {
t.Fatalf("resolved credentials username = %q, want linuscaldwell", got)
}
if got := resolved.Credentials.Password; got != "bellagio-pass-1" {
t.Fatalf("resolved credentials password = %q, want bellagio-pass-1", got)
}
}
func TestRemoteBindingResolveFailsWhenVaultReferenceIsMissing(t *testing.T) {
t.Parallel()
model := vault.Model{
Entries: []vault.Entry{
{ID: "linuscaldwell-webdav", Title: "Bellagio WebDAV Sign-In"},
},
}
_, err := (RemoteBinding{
LocalVaultPath: "/tmp/bellagio.kdbx",
RemoteProfileID: "bellagio-webdav",
CredentialEntryID: "missing-creds",
}).Resolve(model)
if !errors.Is(err, vault.ErrRemoteProfileNotFound) {
t.Fatalf("Resolve() error = %v, want ErrRemoteProfileNotFound first", err)
}
model.RemoteProfiles = []vault.RemoteProfile{{
ID: "bellagio-webdav",
Name: "Bellagio Vault",
Backend: vault.RemoteBackendWebDAV,
BaseURL: "https://dav.example.invalid/remote.php/dav",
Path: "files/bellagio/keepass.kdbx",
}}
_, err = (RemoteBinding{
LocalVaultPath: "/tmp/bellagio.kdbx",
RemoteProfileID: "bellagio-webdav",
CredentialEntryID: "missing-creds",
}).Resolve(model)
if !errors.Is(err, vault.ErrEntryNotFound) {
t.Fatalf("Resolve() error = %v, want ErrEntryNotFound", err)
}
}
func TestRemoteBindingJSONStoresOnlyNonSecretReferences(t *testing.T) {
t.Parallel()
content, err := json.Marshal(RemoteBinding{
LocalVaultPath: "/tmp/bellagio.kdbx",
RemoteProfileID: "bellagio-webdav",
CredentialEntryID: "remote-creds-1",
SyncMode: SyncModeAutomaticOnOpenSave,
})
if err != nil {
t.Fatalf("json.Marshal(RemoteBinding) error = %v", err)
}
text := string(content)
for _, disallowed := range []string{"bellagio-pass-1", "password", "username", "baseUrl"} {
if strings.Contains(text, disallowed) {
t.Fatalf("binding JSON %q unexpectedly contains %q", text, disallowed)
}
}
}
func TestConfigureRemoteBindingStoresProfileAndCredentialsInVault(t *testing.T) {
t.Parallel()
var model vault.Model
binding, err := ConfigureRemoteBinding(&model, RemoteBindingInput{
LocalVaultPath: "/tmp/bellagio.kdbx",
RemoteProfileID: "bellagio-webdav",
RemoteProfileName: "Bellagio Vault",
BaseURL: "https://dav.example.invalid/remote.php/dav",
RemotePath: "files/bellagio/keepass.kdbx",
CredentialEntryID: "remote-creds-1",
CredentialTitle: "Bellagio WebDAV Sign-In",
Username: "linuscaldwell",
Password: "bellagio-pass-1",
CredentialPath: []string{"Crew", "Internet"},
SyncMode: SyncModeAutomaticOnOpenSave,
})
if err != nil {
t.Fatalf("ConfigureRemoteBinding() error = %v", err)
}
if len(model.RemoteProfiles) != 1 {
t.Fatalf("len(RemoteProfiles) = %d, want 1", len(model.RemoteProfiles))
}
if got := model.RemoteProfiles[0].BaseURL; got != "https://dav.example.invalid/remote.php/dav" {
t.Fatalf("stored remote profile base URL = %q, want remote.php/dav URL", got)
}
credentials, err := model.EntryByID("remote-creds-1")
if err != nil {
t.Fatalf("EntryByID(remote-creds-1) error = %v", err)
}
if credentials.Username != "linuscaldwell" || credentials.Password != "bellagio-pass-1" {
t.Fatalf("stored credential entry = %#v, want linuscaldwell/bellagio-pass-1", credentials)
}
if credentials.URL != "https://dav.example.invalid/remote.php/dav" {
t.Fatalf("stored credential entry URL = %q, want remote.php/dav URL", credentials.URL)
}
if binding.LocalVaultPath != "/tmp/bellagio.kdbx" {
t.Fatalf("binding LocalVaultPath = %q, want /tmp/bellagio.kdbx", binding.LocalVaultPath)
}
if binding.RemoteProfileID != "bellagio-webdav" || binding.CredentialEntryID != "remote-creds-1" {
t.Fatalf("binding = %#v, want only vault references", binding)
}
}
func TestConfigureRemoteBindingRejectsIncompleteInput(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
input RemoteBindingInput
}{
{
name: "missing_local_vault_path",
input: RemoteBindingInput{
RemoteProfileID: "bellagio-webdav",
BaseURL: "https://dav.example.invalid/remote.php/dav",
RemotePath: "files/bellagio/keepass.kdbx",
CredentialEntryID: "remote-creds-1",
Password: "bellagio-pass-1",
},
},
{
name: "missing_remote_base_url",
input: RemoteBindingInput{
LocalVaultPath: "/tmp/bellagio.kdbx",
RemoteProfileID: "bellagio-webdav",
RemotePath: "files/bellagio/keepass.kdbx",
CredentialEntryID: "remote-creds-1",
Password: "bellagio-pass-1",
},
},
{
name: "missing_credential_entry_id",
input: RemoteBindingInput{
LocalVaultPath: "/tmp/bellagio.kdbx",
RemoteProfileID: "bellagio-webdav",
BaseURL: "https://dav.example.invalid/remote.php/dav",
RemotePath: "files/bellagio/keepass.kdbx",
Password: "bellagio-pass-1",
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var model vault.Model
if _, err := ConfigureRemoteBinding(&model, tc.input); err == nil {
t.Fatalf("ConfigureRemoteBinding(%#v) error = nil, want validation error", tc.input)
}
})
}
}
func TestRemoveRemoteBindingRemovesProfileAndCredentialsFromVault(t *testing.T) {
t.Parallel()
model := vault.Model{
Entries: []vault.Entry{{
ID: "remote-creds-1",
Title: "Bellagio WebDAV Sign-In",
Username: "linuscaldwell",
Password: "bellagio-pass-1",
}},
RemoteProfiles: []vault.RemoteProfile{{
ID: "bellagio-webdav",
Name: "Bellagio Vault",
Backend: vault.RemoteBackendWebDAV,
BaseURL: "https://dav.example.invalid/remote.php/dav",
Path: "files/bellagio/keepass.kdbx",
}},
}
err := RemoveRemoteBinding(&model, RemoteBinding{
LocalVaultPath: "/tmp/bellagio.kdbx",
RemoteProfileID: "bellagio-webdav",
CredentialEntryID: "remote-creds-1",
})
if err != nil {
t.Fatalf("RemoveRemoteBinding() error = %v", err)
}
if got := len(model.RemoteProfiles); got != 0 {
t.Fatalf("len(RemoteProfiles) = %d, want 0", got)
}
if _, err := model.EntryByID("remote-creds-1"); !errors.Is(err, vault.ErrEntryNotFound) {
t.Fatalf("EntryByID(remote-creds-1) error = %v, want ErrEntryNotFound", err)
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+12 -12
View File
@@ -31,18 +31,18 @@ import (
"gioui.org/widget"
"gioui.org/widget/material"
"gioui.org/x/explorer"
"git.julianfamily.org/keepassgo/api"
"git.julianfamily.org/keepassgo/apiapproval"
"git.julianfamily.org/keepassgo/apiaudit"
"git.julianfamily.org/keepassgo/apitokens"
"git.julianfamily.org/keepassgo/appstate"
keepassassets "git.julianfamily.org/keepassgo/assets"
"git.julianfamily.org/keepassgo/autofillcache"
"git.julianfamily.org/keepassgo/clipboard"
"git.julianfamily.org/keepassgo/passwords"
"git.julianfamily.org/keepassgo/session"
"git.julianfamily.org/keepassgo/vault"
"git.julianfamily.org/keepassgo/webdav"
"git.julianfamily.org/keepassgo/internal/api"
"git.julianfamily.org/keepassgo/internal/apiapproval"
"git.julianfamily.org/keepassgo/internal/apiaudit"
"git.julianfamily.org/keepassgo/internal/apitokens"
"git.julianfamily.org/keepassgo/internal/appstate"
keepassassets "git.julianfamily.org/keepassgo/internal/assets"
"git.julianfamily.org/keepassgo/internal/autofillcache"
"git.julianfamily.org/keepassgo/internal/clipboard"
"git.julianfamily.org/keepassgo/internal/passwords"
"git.julianfamily.org/keepassgo/internal/session"
"git.julianfamily.org/keepassgo/internal/vault"
"git.julianfamily.org/keepassgo/internal/webdav"
"golang.org/x/exp/shiny/materialdesign/icons"
)
+1 -1
View File
@@ -8,7 +8,7 @@ import (
gioclipboard "gioui.org/io/clipboard"
"gioui.org/layout"
appclipboard "git.julianfamily.org/keepassgo/clipboard"
appclipboard "git.julianfamily.org/keepassgo/internal/clipboard"
)
type clipboardCommandWriter struct {
+9 -9
View File
@@ -21,15 +21,15 @@ import (
"gioui.org/unit"
"gioui.org/widget"
"git.julianfamily.org/keepassgo/apiapproval"
"git.julianfamily.org/keepassgo/apiaudit"
"git.julianfamily.org/keepassgo/apitokens"
"git.julianfamily.org/keepassgo/appstate"
"git.julianfamily.org/keepassgo/clipboard"
"git.julianfamily.org/keepassgo/passwords"
"git.julianfamily.org/keepassgo/session"
"git.julianfamily.org/keepassgo/vault"
"git.julianfamily.org/keepassgo/webdav"
"git.julianfamily.org/keepassgo/internal/apiapproval"
"git.julianfamily.org/keepassgo/internal/apiaudit"
"git.julianfamily.org/keepassgo/internal/apitokens"
"git.julianfamily.org/keepassgo/internal/appstate"
"git.julianfamily.org/keepassgo/internal/clipboard"
"git.julianfamily.org/keepassgo/internal/passwords"
"git.julianfamily.org/keepassgo/internal/session"
"git.julianfamily.org/keepassgo/internal/vault"
"git.julianfamily.org/keepassgo/internal/webdav"
)
func TestMain(m *testing.M) {
+2 -2
View File
@@ -10,8 +10,8 @@ import (
"gioui.org/unit"
"gioui.org/widget"
"gioui.org/widget/material"
"git.julianfamily.org/keepassgo/apiaudit"
"git.julianfamily.org/keepassgo/apitokens"
"git.julianfamily.org/keepassgo/internal/apiaudit"
"git.julianfamily.org/keepassgo/internal/apitokens"
)
func apiOperations() []apitokens.Operation {
+3 -3
View File
@@ -8,9 +8,9 @@ import (
"strings"
"gioui.org/widget"
"git.julianfamily.org/keepassgo/clipboard"
"git.julianfamily.org/keepassgo/passwords"
"git.julianfamily.org/keepassgo/vault"
"git.julianfamily.org/keepassgo/internal/clipboard"
"git.julianfamily.org/keepassgo/internal/passwords"
"git.julianfamily.org/keepassgo/internal/vault"
)
func (u *ui) attachmentInput() (string, []byte, error) {
+1 -1
View File
@@ -15,7 +15,7 @@ import (
"gioui.org/unit"
"gioui.org/widget"
"gioui.org/widget/material"
"git.julianfamily.org/keepassgo/appstate"
"git.julianfamily.org/keepassgo/internal/appstate"
)
func (u *ui) lifecycleControls(gtx layout.Context) layout.Dimensions {
+1 -1
View File
@@ -6,7 +6,7 @@ import (
"strings"
"gioui.org/io/key"
"git.julianfamily.org/keepassgo/appstate"
"git.julianfamily.org/keepassgo/internal/appstate"
)
type focusID string
+1 -1
View File
@@ -12,7 +12,7 @@ import (
"gioui.org/unit"
"gioui.org/widget"
"gioui.org/widget/material"
"git.julianfamily.org/keepassgo/vault"
"git.julianfamily.org/keepassgo/internal/vault"
)
const (
+1 -1
View File
@@ -7,7 +7,7 @@ import (
"gioui.org/io/key"
"gioui.org/layout"
"git.julianfamily.org/keepassgo/clipboard"
"git.julianfamily.org/keepassgo/internal/clipboard"
)
const (
+1 -1
View File
@@ -4,7 +4,7 @@ import (
"runtime"
"strings"
"git.julianfamily.org/keepassgo/appstate"
"git.julianfamily.org/keepassgo/internal/appstate"
)
type syncMenuModel struct {
+23
View File
@@ -0,0 +1,23 @@
package assets
import (
"bytes"
"embed"
"image"
"image/png"
)
//go:embed *.png *.svg
var files embed.FS
func MustPNG(name string) image.Image {
data, err := files.ReadFile(name)
if err != nil {
panic(err)
}
img, err := png.Decode(bytes.NewReader(data))
if err != nil {
panic(err)
}
return img
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

+18
View File
@@ -0,0 +1,18 @@
<svg xmlns="http://www.w3.org/2000/svg" width="256" height="256" viewBox="0 0 256 256" fill="none">
<title>KeePassGO icon</title>
<desc>Vault-shaped mark with lock, layered record lines, and navigation accent.</desc>
<polygon points="128,18 214,38 214,176 128,218 42,176 42,38" fill="#3F4E63"/>
<polygon points="128,34 198,50 198,166 128,200 58,166 58,50" fill="#55657A" opacity="0.16"/>
<polygon points="128,178 128,218 42,176 42,160" fill="#6D89A8"/>
<rect x="44" y="78" width="62" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
<rect x="44" y="98" width="52" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
<path d="M100 86C100 70.536 112.536 58 128 58C143.464 58 156 70.536 156 86V106H142V86C142 78.268 135.732 72 128 72C120.268 72 114 78.268 114 86V106H100V86Z" fill="#FFFFFF"/>
<rect x="76" y="102" width="104" height="64" rx="10" fill="#FFFFFF"/>
<rect x="80" y="106" width="96" height="56" rx="8" fill="#E9EDF0"/>
<path d="M128 118C139.046 118 148 126.954 148 138C148 145.669 143.68 152.329 137.338 155.7V173H118.662V155.7C112.32 152.329 108 145.669 108 138C108 126.954 116.954 118 128 118Z" fill="#3F4E63"/>
<circle cx="128" cy="138" r="8" fill="#FFFFFF" opacity="0.2"/>
<path d="M178 130H214V150H166V144L178 130Z" fill="#E79A17"/>
<path d="M178 130H214V140H174L166 144V144L178 130Z" fill="#F0B13D" opacity="0.35"/>
</svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

@@ -0,0 +1,22 @@
<svg xmlns="http://www.w3.org/2000/svg" width="920" height="260" viewBox="0 0 920 260" fill="none">
<title>KeePassGO horizontal logo</title>
<desc>KeePassGO symbol with wordmark for light-theme desktop UI.</desc>
<defs>
<symbol id="kpg-icon" viewBox="0 0 256 256">
<polygon points="128,18 214,38 214,176 128,218 42,176 42,38" fill="#3F4E63"/>
<polygon points="128,34 198,50 198,166 128,200 58,166 58,50" fill="#55657A" opacity="0.16"/>
<polygon points="128,178 128,218 42,176 42,160" fill="#6D89A8"/>
<rect x="44" y="78" width="62" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
<rect x="44" y="98" width="52" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
<path d="M100 86C100 70.536 112.536 58 128 58C143.464 58 156 70.536 156 86V106H142V86C142 78.268 135.732 72 128 72C120.268 72 114 78.268 114 86V106H100V86Z" fill="#FFFFFF"/>
<rect x="76" y="102" width="104" height="64" rx="10" fill="#FFFFFF"/>
<rect x="80" y="106" width="96" height="56" rx="8" fill="#E9EDF0"/>
<path d="M128 118C139.046 118 148 126.954 148 138C148 145.669 143.68 152.329 137.338 155.7V173H118.662V155.7C112.32 152.329 108 145.669 108 138C108 126.954 116.954 118 128 118Z" fill="#3F4E63"/>
<path d="M178 130H214V150H166V144L178 130Z" fill="#E79A17"/>
<path d="M178 130H214V140H174L166 144V144L178 130Z" fill="#F0B13D" opacity="0.35"/>
</symbol>
</defs>
<rect width="920" height="260" fill="transparent"/>
<use href="#kpg-icon" x="24" y="20" width="176" height="176"/>
<text x="220" y="132" font-family="Inter, Noto Sans, Segoe UI, Arial, sans-serif" font-size="84" font-weight="650" fill="#3F4E63">KeePassGO</text>
</svg>

After

Width:  |  Height:  |  Size: 1.7 KiB

@@ -0,0 +1,75 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1280" height="800" viewBox="0 0 1280 800" fill="none">
<title>KeePassGO splash screen</title>
<desc>Light-theme desktop splash screen with structured panels and the KeePassGO logo.</desc>
<defs>
<symbol id="kpg-icon" viewBox="0 0 256 256">
<polygon points="128,18 214,38 214,176 128,218 42,176 42,38" fill="#3F4E63"/>
<polygon points="128,34 198,50 198,166 128,200 58,166 58,50" fill="#55657A" opacity="0.16"/>
<polygon points="128,178 128,218 42,176 42,160" fill="#6D89A8"/>
<rect x="44" y="78" width="62" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
<rect x="44" y="98" width="52" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
<path d="M100 86C100 70.536 112.536 58 128 58C143.464 58 156 70.536 156 86V106H142V86C142 78.268 135.732 72 128 72C120.268 72 114 78.268 114 86V106H100V86Z" fill="#FFFFFF"/>
<rect x="76" y="102" width="104" height="64" rx="10" fill="#FFFFFF"/>
<rect x="80" y="106" width="96" height="56" rx="8" fill="#E9EDF0"/>
<path d="M128 118C139.046 118 148 126.954 148 138C148 145.669 143.68 152.329 137.338 155.7V173H118.662V155.7C112.32 152.329 108 145.669 108 138C108 126.954 116.954 118 128 118Z" fill="#3F4E63"/>
<path d="M178 130H214V150H166V144L178 130Z" fill="#E79A17"/>
<path d="M178 130H214V140H174L166 144V144L178 130Z" fill="#F0B13D" opacity="0.35"/>
</symbol>
<linearGradient id="fade" x1="0" y1="0" x2="0" y2="1">
<stop offset="0" stop-color="#D9E0E6" stop-opacity="0.9"/>
<stop offset="1" stop-color="#D9E0E6" stop-opacity="0.2"/>
</linearGradient>
</defs>
<rect width="1280" height="800" fill="#F7F7F5"/>
<g opacity="0.9">
<rect x="48" y="48" width="238" height="188" fill="#EEF2F4"/>
<rect x="286" y="48" width="152" height="104" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="438" y="48" width="214" height="104" fill="#E7EDF1"/>
<rect x="652" y="48" width="146" height="188" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="798" y="48" width="224" height="118" fill="#EEF2F4"/>
<rect x="1022" y="48" width="210" height="188" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="48" y="236" width="128" height="164" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="176" y="236" width="262" height="164" fill="#E7EDF1"/>
<rect x="438" y="236" width="160" height="84" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="598" y="236" width="200" height="164" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="798" y="236" width="434" height="164" fill="#EEF2F4"/>
<rect x="48" y="400" width="224" height="128" fill="#EEF2F4"/>
<rect x="272" y="400" width="166" height="128" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="438" y="400" width="360" height="128" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="798" y="400" width="152" height="128" fill="#E7EDF1"/>
<rect x="950" y="400" width="282" height="128" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="48" y="528" width="114" height="224" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="162" y="528" width="276" height="224" fill="#E7EDF1"/>
<rect x="438" y="528" width="212" height="224" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="650" y="528" width="318" height="224" fill="#EEF2F4"/>
<rect x="968" y="528" width="264" height="224" fill="#F7F7F5" stroke="#C7D0D8"/>
</g>
<g opacity="0.45">
<path d="M48 152H1232" stroke="#C7D0D8"/>
<path d="M48 320H1232" stroke="#C7D0D8"/>
<path d="M48 528H1232" stroke="#C7D0D8"/>
<path d="M176 48V752" stroke="#C7D0D8"/>
<path d="M438 48V752" stroke="#C7D0D8"/>
<path d="M798 48V752" stroke="#C7D0D8"/>
<path d="M1022 48V752" stroke="#C7D0D8"/>
</g>
<rect x="290" y="190" width="700" height="420" rx="18" fill="#F7F7F5" opacity="0.92"/>
<rect x="290" y="190" width="700" height="420" rx="18" stroke="#C7D0D8"/>
<rect x="290" y="190" width="700" height="120" rx="18" fill="url(#fade)"/>
<use href="#kpg-icon" x="530" y="232" width="220" height="220"/>
<text x="640" y="530" text-anchor="middle" font-family="Inter, Noto Sans, Segoe UI, Arial, sans-serif" font-size="88" font-weight="650" fill="#3F4E63">KeePassGO</text>
<text x="640" y="582" text-anchor="middle" font-family="Inter, Noto Sans, Segoe UI, Arial, sans-serif" font-size="28" font-weight="500" fill="#55657A">KeePass-compatible password manager</text>
<g opacity="0.9">
<rect x="516" y="634" width="248" height="10" rx="5" fill="#D9E0E6"/>
<rect x="516" y="634" width="132" height="10" rx="5" fill="#6D89A8"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 47 KiB

@@ -0,0 +1,41 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1024" height="1024" viewBox="0 0 1024 1024" fill="none">
<title>KeePassGO square splash preview</title>
<desc>Square crop of the KeePassGO splash system.</desc>
<defs>
<symbol id="kpg-icon" viewBox="0 0 256 256">
<polygon points="128,18 214,38 214,176 128,218 42,176 42,38" fill="#3F4E63"/>
<polygon points="128,34 198,50 198,166 128,200 58,166 58,50" fill="#55657A" opacity="0.16"/>
<polygon points="128,178 128,218 42,176 42,160" fill="#6D89A8"/>
<rect x="44" y="78" width="62" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
<rect x="44" y="98" width="52" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
<path d="M100 86C100 70.536 112.536 58 128 58C143.464 58 156 70.536 156 86V106H142V86C142 78.268 135.732 72 128 72C120.268 72 114 78.268 114 86V106H100V86Z" fill="#FFFFFF"/>
<rect x="76" y="102" width="104" height="64" rx="10" fill="#FFFFFF"/>
<rect x="80" y="106" width="96" height="56" rx="8" fill="#E9EDF0"/>
<path d="M128 118C139.046 118 148 126.954 148 138C148 145.669 143.68 152.329 137.338 155.7V173H118.662V155.7C112.32 152.329 108 145.669 108 138C108 126.954 116.954 118 128 118Z" fill="#3F4E63"/>
<path d="M178 130H214V150H166V144L178 130Z" fill="#E79A17"/>
<path d="M178 130H214V140H174L166 144V144L178 130Z" fill="#F0B13D" opacity="0.35"/>
</symbol>
</defs>
<rect width="1024" height="1024" fill="#F7F7F5"/>
<g opacity="0.9">
<rect x="32" y="32" width="180" height="184" fill="#EEF2F4"/>
<rect x="212" y="32" width="264" height="140" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="476" y="32" width="180" height="184" fill="#E7EDF1"/>
<rect x="656" y="32" width="336" height="140" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="656" y="172" width="336" height="180" fill="#EEF2F4"/>
<rect x="32" y="216" width="236" height="210" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="268" y="216" width="388" height="210" fill="#EEF2F4"/>
<rect x="32" y="426" width="180" height="268" fill="#E7EDF1"/>
<rect x="212" y="426" width="312" height="268" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="524" y="426" width="468" height="268" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="32" y="694" width="280" height="298" fill="#F7F7F5" stroke="#C7D0D8"/>
<rect x="312" y="694" width="268" height="298" fill="#EEF2F4"/>
<rect x="580" y="694" width="412" height="298" fill="#F7F7F5" stroke="#C7D0D8"/>
</g>
<rect x="162" y="190" width="700" height="644" rx="24" fill="#F7F7F5" opacity="0.95"/>
<rect x="162" y="190" width="700" height="644" rx="24" stroke="#C7D0D8"/>
<use href="#kpg-icon" x="332" y="252" width="360" height="360"/>
<text x="512" y="680" text-anchor="middle" font-family="Inter, Noto Sans, Segoe UI, Arial, sans-serif" font-size="88" font-weight="650" fill="#3F4E63">KeePassGO</text>
<text x="512" y="740" text-anchor="middle" font-family="Inter, Noto Sans, Segoe UI, Arial, sans-serif" font-size="28" font-weight="500" fill="#55657A">KeePass-compatible password manager</text>
</svg>

After

Width:  |  Height:  |  Size: 3.0 KiB

+296
View File
@@ -0,0 +1,296 @@
package autofillcache
import (
"encoding/json"
"net/url"
"os"
"path/filepath"
"sort"
"strings"
"time"
"git.julianfamily.org/keepassgo/internal/vault"
)
type Entry struct {
ID string `json:"id"`
Title string `json:"title"`
Username string `json:"username"`
Password string `json:"password"`
URL string `json:"url"`
Host string `json:"host"`
Targets []string `json:"targets,omitempty"`
Path []string `json:"path,omitempty"`
}
type File struct {
UpdatedAt string `json:"updatedAt"`
Entries []Entry `json:"entries"`
}
type MatchStatus string
const (
MatchStatusNone MatchStatus = ""
MatchStatusFound MatchStatus = "found"
MatchStatusAmbiguous MatchStatus = "ambiguous"
MatchStatusMissing MatchStatus = "missing"
)
type MatchResult struct {
Status MatchStatus `json:"status"`
Entry Entry `json:"entry,omitempty"`
}
func Match(cache File, webURL string) (Entry, bool) {
result := Resolve(cache, webURL)
return result.Entry, result.Status == MatchStatusFound
}
func Resolve(cache File, webURL string) MatchResult {
target := normalizeURL(webURL)
if target.host == "" {
return MatchResult{Status: MatchStatusMissing}
}
exactHost := make([]Entry, 0)
parentHost := make([]Entry, 0)
for _, entry := range cache.Entries {
if entryMatchesHost(entry, target.host) {
exactHost = append(exactHost, entry)
continue
}
if entryMatchesParentHost(entry, target.host) {
parentHost = append(parentHost, entry)
}
}
if result := chooseEntry(target, exactHost); result.Status != MatchStatusMissing {
return result
}
return chooseEntry(target, parentHost)
}
func Build(model vault.Model, now time.Time) File {
entries := make([]Entry, 0, len(model.Entries))
for _, item := range model.Entries {
targets := collectTargets(item)
host := normalizeHost(item.URL)
if host == "" {
for _, target := range targets {
host = normalizeHost(target)
if host != "" {
break
}
}
}
if host == "" {
continue
}
if strings.TrimSpace(item.Username) == "" || strings.TrimSpace(item.Password) == "" {
continue
}
entries = append(entries, Entry{
ID: item.ID,
Title: item.Title,
Username: item.Username,
Password: item.Password,
URL: item.URL,
Host: host,
Targets: targets,
Path: append([]string(nil), item.Path...),
})
}
return File{
UpdatedAt: now.UTC().Format(time.RFC3339),
Entries: entries,
}
}
func Write(path string, model vault.Model, now time.Time) error {
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.MarshalIndent(Build(model, now), "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o600)
}
func Clear(path string) error {
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
func normalizeHost(raw string) string {
return normalizeURL(raw).host
}
type normalizedTarget struct {
host string
path string
url string
}
func normalizeURL(raw string) normalizedTarget {
value := strings.TrimSpace(raw)
if value == "" {
return normalizedTarget{}
}
if !strings.Contains(value, "://") {
value = "https://" + value
}
parsed, err := url.Parse(value)
if err != nil {
return normalizedTarget{}
}
host := strings.TrimSpace(parsed.Hostname())
path := cleanPath(parsed.EscapedPath())
return normalizedTarget{
host: strings.ToLower(host),
path: path,
url: strings.ToLower(host) + path,
}
}
func cleanPath(path string) string {
path = strings.TrimSpace(path)
if path == "" || path == "/" {
return "/"
}
path = strings.TrimRight(path, "/")
if path == "" {
return "/"
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return path
}
func chooseEntry(target normalizedTarget, entries []Entry) MatchResult {
switch len(entries) {
case 0:
return MatchResult{Status: MatchStatusMissing}
case 1:
return MatchResult{Status: MatchStatusFound, Entry: entries[0]}
}
exact := make([]Entry, 0)
bestPrefixLen := -1
bestPrefix := make([]Entry, 0)
for _, entry := range entries {
exactMatch, prefixLen := bestTargetMatch(entry, target)
if exactMatch {
exact = append(exact, entry)
continue
}
if prefixLen <= 0 {
continue
}
switch {
case prefixLen > bestPrefixLen:
bestPrefixLen = prefixLen
bestPrefix = []Entry{entry}
case prefixLen == bestPrefixLen:
bestPrefix = append(bestPrefix, entry)
}
}
if len(exact) == 1 {
return MatchResult{Status: MatchStatusFound, Entry: exact[0]}
}
if len(exact) > 1 {
return MatchResult{Status: MatchStatusAmbiguous}
}
if len(bestPrefix) == 1 {
return MatchResult{Status: MatchStatusFound, Entry: bestPrefix[0]}
}
if len(bestPrefix) == 0 {
return MatchResult{Status: MatchStatusMissing}
}
return MatchResult{Status: MatchStatusAmbiguous}
}
func collectTargets(item vault.Entry) []string {
seen := make(map[string]struct{})
targets := make([]string, 0, 1+len(item.Fields))
appendTarget := func(raw string) {
value := strings.TrimSpace(raw)
if value == "" {
return
}
if _, ok := seen[value]; ok {
return
}
seen[value] = struct{}{}
targets = append(targets, value)
}
appendTarget(item.URL)
keys := make([]string, 0, len(item.Fields))
for key := range item.Fields {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
upper := strings.ToUpper(strings.TrimSpace(key))
if strings.HasPrefix(upper, "ANDROIDAPP") || strings.HasPrefix(upper, "KP2A_URL") {
appendTarget(item.Fields[key])
}
}
return targets
}
func entryTargets(entry Entry) []normalizedTarget {
values := entry.Targets
if len(values) == 0 {
values = []string{entry.URL}
}
targets := make([]normalizedTarget, 0, len(values))
for _, value := range values {
target := normalizeURL(value)
if target.host == "" {
continue
}
targets = append(targets, target)
}
return targets
}
func entryMatchesHost(entry Entry, host string) bool {
for _, target := range entryTargets(entry) {
if target.host == host {
return true
}
}
return false
}
func entryMatchesParentHost(entry Entry, host string) bool {
for _, target := range entryTargets(entry) {
if target.host != "" && strings.HasSuffix(host, "."+target.host) {
return true
}
}
return false
}
func bestTargetMatch(entry Entry, target normalizedTarget) (bool, int) {
bestPrefixLen := -1
for _, candidate := range entryTargets(entry) {
if candidate.url == target.url {
return true, 0
}
if candidate.path != "/" && strings.HasPrefix(target.path, candidate.path) {
if pathLen := len(candidate.path); pathLen > bestPrefixLen {
bestPrefixLen = pathLen
}
}
}
return false, bestPrefixLen
}
+339
View File
@@ -0,0 +1,339 @@
package autofillcache
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"git.julianfamily.org/keepassgo/internal/vault"
)
func TestBuildFiltersAndNormalizesEntries(t *testing.T) {
t.Parallel()
now := time.Date(2026, time.March, 31, 12, 0, 0, 0, time.UTC)
got := Build(vault.Model{
Entries: []vault.Entry{
{
ID: "one",
Title: "Chrome Test",
Username: "joe",
Password: "secret",
URL: "https://10.0.2.2:8443/login",
Path: []string{"Crew", "Internet"},
},
{
ID: "two",
Title: "No Password",
Username: "joe",
URL: "https://example.com",
},
{
ID: "three",
Title: "Bare Host",
Username: "user",
Password: "pass",
URL: "surveillance.crew.example.invalid",
Fields: map[string]string{
"AndroidApp1": "androidapp://com.lights.mobile",
"KP2A_URL_1": "https://surveillance.crew.example.invalid/account",
},
},
},
}, now)
if len(got.Entries) != 2 {
t.Fatalf("entry count = %d, want 2", len(got.Entries))
}
if got.Entries[0].Host != "10.0.2.2" {
t.Fatalf("first host = %q, want 10.0.2.2", got.Entries[0].Host)
}
if got.Entries[1].Host != "surveillance.crew.example.invalid" {
t.Fatalf("second host = %q, want surveillance.crew.example.invalid", got.Entries[1].Host)
}
if len(got.Entries[1].Targets) != 3 {
t.Fatalf("len(second targets) = %d, want 3", len(got.Entries[1].Targets))
}
if got.UpdatedAt != "2026-03-31T12:00:00Z" {
t.Fatalf("updatedAt = %q", got.UpdatedAt)
}
}
func TestWriteAndClear(t *testing.T) {
t.Parallel()
dir := t.TempDir()
path := filepath.Join(dir, "autofill-cache.json")
model := vault.Model{
Entries: []vault.Entry{
{ID: "one", Title: "Chrome Test", Username: "joe", Password: "secret", URL: "https://10.0.2.2:8443/login"},
},
}
if err := Write(path, model, time.Date(2026, time.March, 31, 12, 0, 0, 0, time.UTC)); err != nil {
t.Fatalf("Write() error = %v", err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("ReadFile() error = %v", err)
}
var got File
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if len(got.Entries) != 1 || got.Entries[0].Host != "10.0.2.2" {
t.Fatalf("cache entries = %#v", got.Entries)
}
if err := Clear(path); err != nil {
t.Fatalf("Clear() error = %v", err)
}
if _, err := os.Stat(path); !os.IsNotExist(err) {
t.Fatalf("cache path still exists, stat err = %v", err)
}
}
func TestMatchChoosesExactURLWhenHostsRepeat(t *testing.T) {
t.Parallel()
cache := File{
Entries: []Entry{
{
ID: "one",
Title: "Primary Login",
Username: "first",
Password: "secret1",
URL: "https://10.0.2.2:8443/login/",
Host: "10.0.2.2",
},
{
ID: "two",
Title: "Alt Login",
Username: "second",
Password: "secret2",
URL: "https://10.0.2.2:8443/alt/",
Host: "10.0.2.2",
},
},
}
got, ok := Match(cache, "https://10.0.2.2:8443/alt/")
if !ok {
t.Fatalf("Match() found no entry")
}
if got.ID != "two" {
t.Fatalf("Match() entry = %q, want two", got.ID)
}
}
func TestMatchRejectsAmbiguousSharedHost(t *testing.T) {
t.Parallel()
cache := File{
Entries: []Entry{
{
ID: "one",
Title: "Host A",
Username: "first",
Password: "secret1",
URL: "https://surveillance.crew.example.invalid/",
Host: "surveillance.crew.example.invalid",
},
{
ID: "two",
Title: "Host B",
Username: "second",
Password: "secret2",
URL: "https://surveillance.crew.example.invalid/",
Host: "surveillance.crew.example.invalid",
},
},
}
if _, ok := Match(cache, "https://surveillance.crew.example.invalid/"); ok {
t.Fatalf("Match() unexpectedly resolved ambiguous shared host")
}
}
func TestResolveReportsFoundAmbiguousAndMissingStatuses(t *testing.T) {
t.Parallel()
cache := File{
Entries: []Entry{
{
ID: "one",
Title: "Admin Login",
Username: "admin",
Password: "secret1",
URL: "https://example.com/admin",
Host: "example.com",
},
{
ID: "two",
Title: "Shared Login A",
Username: "shared-a",
Password: "secret2",
URL: "https://shared.example.com",
Host: "shared.example.com",
},
{
ID: "three",
Title: "Shared Login B",
Username: "shared-b",
Password: "secret3",
URL: "https://shared.example.com",
Host: "shared.example.com",
},
},
}
if got := Resolve(cache, "https://example.com/admin/login"); got.Status != MatchStatusFound || got.Entry.ID != "one" {
t.Fatalf("Resolve(found) = %#v, want found entry one", got)
}
if got := Resolve(cache, "https://shared.example.com"); got.Status != MatchStatusAmbiguous {
t.Fatalf("Resolve(ambiguous) = %#v, want ambiguous", got)
}
if got := Resolve(cache, "https://nowhere.invalid"); got.Status != MatchStatusMissing {
t.Fatalf("Resolve(missing) = %#v, want missing", got)
}
}
func TestMatchChoosesLongestPathPrefix(t *testing.T) {
t.Parallel()
cache := File{
Entries: []Entry{
{
ID: "one",
Title: "Generic Login",
Username: "generic",
Password: "secret1",
URL: "https://example.com/",
Host: "example.com",
},
{
ID: "two",
Title: "Admin Login",
Username: "admin",
Password: "secret2",
URL: "https://example.com/admin",
Host: "example.com",
},
},
}
got, ok := Match(cache, "https://example.com/admin/login")
if !ok {
t.Fatalf("Match() found no entry")
}
if got.ID != "two" {
t.Fatalf("Match() entry = %q, want two", got.ID)
}
}
func TestMatchSupportsAndroidAppPackageTargets(t *testing.T) {
t.Parallel()
cache := File{
Entries: []Entry{
{
ID: "one",
Title: "Thunderbird",
Username: "mail-user",
Password: "secret1",
URL: "androidapp://org.mozilla.thunderbird/login",
Host: "org.mozilla.thunderbird",
},
},
}
got, ok := Match(cache, "androidapp://org.mozilla.thunderbird")
if !ok {
t.Fatalf("Match() found no entry")
}
if got.ID != "one" {
t.Fatalf("Match() entry = %q, want one", got.ID)
}
}
func TestMatchRejectsAmbiguousAndroidAppPackageTargets(t *testing.T) {
t.Parallel()
cache := File{
Entries: []Entry{
{
ID: "one",
Title: "Thunderbird Primary",
Username: "mail-user",
Password: "secret1",
URL: "androidapp://org.mozilla.thunderbird",
Host: "org.mozilla.thunderbird",
},
{
ID: "two",
Title: "Thunderbird Secondary",
Username: "other-user",
Password: "secret2",
URL: "androidapp://org.mozilla.thunderbird",
Host: "org.mozilla.thunderbird",
},
},
}
if _, ok := Match(cache, "androidapp://org.mozilla.thunderbird"); ok {
t.Fatalf("Match() unexpectedly resolved ambiguous android app package target")
}
}
func TestMatchUsesAndroidAppCustomFieldTarget(t *testing.T) {
t.Parallel()
cache := File{
Entries: []Entry{
{
ID: "one",
Title: "Blink",
Username: "blink-user",
Password: "secret1",
URL: "https://account.blinknetwork.com",
Host: "account.blinknetwork.com",
Targets: []string{"https://account.blinknetwork.com", "androidapp://com.blinknetwork.mobile2"},
},
},
}
got, ok := Match(cache, "androidapp://com.blinknetwork.mobile2")
if !ok {
t.Fatalf("Match() found no entry")
}
if got.ID != "one" {
t.Fatalf("Match() entry = %q, want one", got.ID)
}
}
func TestMatchUsesKP2AURLCustomFieldTarget(t *testing.T) {
t.Parallel()
cache := File{
Entries: []Entry{
{
ID: "one",
Title: "Blink",
Username: "blink-user",
Password: "secret1",
URL: "https://blinknetwork.com",
Host: "blinknetwork.com",
Targets: []string{"https://blinknetwork.com", "https://account.blinknetwork.com"},
},
},
}
got, ok := Match(cache, "https://account.blinknetwork.com")
if !ok {
t.Fatalf("Match() found no entry")
}
if got.ID != "one" {
t.Fatalf("Match() entry = %q, want one", got.ID)
}
}
+101
View File
@@ -0,0 +1,101 @@
package clipboard
import (
"errors"
systemclipboard "github.com/atotto/clipboard"
"git.julianfamily.org/keepassgo/internal/vault"
)
var ErrUnsupportedTarget = errors.New("unsupported clipboard target")
var ErrWriteFailed = errors.New("clipboard write failed")
type Target string
const (
TargetUsername Target = "username"
TargetPassword Target = "password"
TargetURL Target = "url"
)
type Writer interface {
WriteText(text string) error
}
type Service struct {
Writer Writer
}
func (s Service) Copy(model vault.Model, entryID string, target Target) error {
entry, err := findEntry(model, entryID)
if err != nil {
return err
}
content, err := contentForTarget(entry, target)
if err != nil {
return err
}
if err := s.writer().WriteText(content); err != nil {
return writeError{err: err}
}
return nil
}
func (s Service) writer() Writer {
if s.Writer != nil {
return s.Writer
}
return systemWriter{}
}
func WriteText(text string) error {
return systemWriter{}.WriteText(text)
}
func findEntry(model vault.Model, entryID string) (vault.Entry, error) {
for _, entry := range model.Entries {
if entry.ID == entryID {
return entry, nil
}
}
return vault.Entry{}, vault.ErrEntryNotFound
}
func contentForTarget(entry vault.Entry, target Target) (string, error) {
switch target {
case TargetUsername:
return entry.Username, nil
case TargetPassword:
return entry.Password, nil
case TargetURL:
return entry.URL, nil
default:
return "", ErrUnsupportedTarget
}
}
type systemWriter struct{}
func (systemWriter) WriteText(text string) error {
return systemclipboard.WriteAll(text)
}
type writeError struct {
err error
}
func (e writeError) Error() string {
return ErrWriteFailed.Error()
}
func (e writeError) Unwrap() error {
return e.err
}
func (e writeError) Is(target error) bool {
return target == ErrWriteFailed
}
+105
View File
@@ -0,0 +1,105 @@
package clipboard
import (
"errors"
"testing"
"git.julianfamily.org/keepassgo/internal/vault"
)
func TestServiceCopiesUsernamePasswordAndURL(t *testing.T) {
t.Parallel()
model := vault.Model{
Entries: []vault.Entry{
{
ID: "vault-console",
Title: "Vault Console",
Username: "dannyocean",
Password: "token-1",
URL: "https://vault.crew.example.invalid",
},
},
}
tests := []struct {
name string
target Target
want string
}{
{name: "username", target: TargetUsername, want: "dannyocean"},
{name: "password", target: TargetPassword, want: "token-1"},
{name: "url", target: TargetURL, want: "https://vault.crew.example.invalid"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var writer memoryWriter
service := Service{Writer: &writer}
if err := service.Copy(model, "vault-console", tt.target); err != nil {
t.Fatalf("Copy() error = %v", err)
}
if writer.content != tt.want {
t.Fatalf("clipboard content = %q, want %q", writer.content, tt.want)
}
})
}
}
func TestServiceRejectsUnknownEntryAndUnsupportedTarget(t *testing.T) {
t.Parallel()
var writer memoryWriter
service := Service{Writer: &writer}
err := service.Copy(vault.Model{}, "missing", TargetPassword)
if !errors.Is(err, vault.ErrEntryNotFound) {
t.Fatalf("Copy() missing entry error = %v, want ErrEntryNotFound", err)
}
model := vault.Model{
Entries: []vault.Entry{{ID: "vault-console", Username: "dannyocean"}},
}
err = service.Copy(model, "vault-console", Target("unsupported"))
if !errors.Is(err, ErrUnsupportedTarget) {
t.Fatalf("Copy() unsupported target error = %v, want ErrUnsupportedTarget", err)
}
}
func TestServiceSanitizesClipboardWriteErrors(t *testing.T) {
t.Parallel()
service := Service{Writer: failingWriter{err: errors.New("backend refused token-1")}}
model := vault.Model{
Entries: []vault.Entry{
{ID: "vault-console", Password: "token-1"},
},
}
err := service.Copy(model, "vault-console", TargetPassword)
if !errors.Is(err, ErrWriteFailed) {
t.Fatalf("Copy() write error = %v, want ErrWriteFailed", err)
}
if err.Error() != ErrWriteFailed.Error() {
t.Fatalf("Copy() write error string = %q, want %q", err.Error(), ErrWriteFailed.Error())
}
}
type memoryWriter struct {
content string
}
func (w *memoryWriter) WriteText(text string) error {
w.content = text
return nil
}
type failingWriter struct {
err error
}
func (w failingWriter) WriteText(string) error {
return w.err
}
+195
View File
@@ -0,0 +1,195 @@
package passwords
import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"slices"
"strings"
)
const (
lowercaseChars = "abcdefghijkmnopqrstuvwxyz"
uppercaseChars = "ABCDEFGHJKLMNPQRSTUVWXYZ"
digitChars = "23456789"
symbolChars = "!@#$%^&*()-_=+[]{}<>?/."
)
var ErrImpossibleProfile = errors.New("impossible password profile")
var ErrUnknownProfile = errors.New("unknown password profile")
type Profile struct {
Name string
Length int
Lowercase bool
Uppercase bool
Digits bool
Symbols bool
MinLowercase int
MinUppercase int
MinDigits int
MinSymbols int
ExcludeSimilar bool
}
func DefaultProfiles() map[string]Profile {
return map[string]Profile{
"strong": {
Name: "strong",
Length: 24,
Lowercase: true,
Uppercase: true,
Digits: true,
Symbols: true,
MinLowercase: 2,
MinUppercase: 2,
MinDigits: 2,
MinSymbols: 2,
ExcludeSimilar: true,
},
"memorable": {
Name: "memorable",
Length: 20,
Lowercase: true,
Uppercase: true,
Digits: true,
Symbols: false,
MinLowercase: 4,
MinUppercase: 2,
MinDigits: 2,
ExcludeSimilar: true,
},
}
}
func DefaultProfileNames() []string {
return ProfileNames(DefaultProfiles())
}
func LookupProfile(name string, profiles map[string]Profile) (Profile, error) {
profile, ok := profiles[strings.TrimSpace(name)]
if !ok {
return Profile{}, fmt.Errorf("%w %q", ErrUnknownProfile, strings.TrimSpace(name))
}
return profile, nil
}
func LookupDefaultProfile(name string) (Profile, error) {
return LookupProfile(name, DefaultProfiles())
}
func ProfileNames(profiles map[string]Profile) []string {
names := make([]string, 0, len(profiles))
for name := range profiles {
names = append(names, name)
}
slices.Sort(names)
return names
}
func Generate(profile Profile) (string, error) {
if err := validateProfile(profile); err != nil {
return "", err
}
var chars []byte
var pool strings.Builder
if profile.Lowercase {
pool.WriteString(lowercaseChars)
chars = append(chars, mustRandomChars(lowercaseChars, profile.MinLowercase)...)
}
if profile.Uppercase {
pool.WriteString(uppercaseChars)
chars = append(chars, mustRandomChars(uppercaseChars, profile.MinUppercase)...)
}
if profile.Digits {
pool.WriteString(digitChars)
chars = append(chars, mustRandomChars(digitChars, profile.MinDigits)...)
}
if profile.Symbols {
pool.WriteString(symbolChars)
chars = append(chars, mustRandomChars(symbolChars, profile.MinSymbols)...)
}
allChars := pool.String()
for len(chars) < profile.Length {
ch, err := randomChar(allChars)
if err != nil {
return "", err
}
chars = append(chars, ch)
}
if err := shuffle(chars); err != nil {
return "", err
}
return string(chars), nil
}
func validateProfile(profile Profile) error {
if profile.Length <= 0 {
return fmt.Errorf("%w: length must be positive", ErrImpossibleProfile)
}
required := profile.MinLowercase + profile.MinUppercase + profile.MinDigits + profile.MinSymbols
if required > profile.Length {
return fmt.Errorf("%w: minimum character counts exceed length", ErrImpossibleProfile)
}
if profile.MinLowercase > 0 && !profile.Lowercase {
return fmt.Errorf("%w: lowercase disabled with lowercase minimum", ErrImpossibleProfile)
}
if profile.MinUppercase > 0 && !profile.Uppercase {
return fmt.Errorf("%w: uppercase disabled with uppercase minimum", ErrImpossibleProfile)
}
if profile.MinDigits > 0 && !profile.Digits {
return fmt.Errorf("%w: digits disabled with digit minimum", ErrImpossibleProfile)
}
if profile.MinSymbols > 0 && !profile.Symbols {
return fmt.Errorf("%w: symbols disabled with symbol minimum", ErrImpossibleProfile)
}
if !profile.Lowercase && !profile.Uppercase && !profile.Digits && !profile.Symbols {
return fmt.Errorf("%w: no character sets enabled", ErrImpossibleProfile)
}
return nil
}
func mustRandomChars(chars string, count int) []byte {
if count <= 0 {
return nil
}
out := make([]byte, 0, count)
for i := 0; i < count; i++ {
ch, err := randomChar(chars)
if err != nil {
panic(err)
}
out = append(out, ch)
}
return out
}
func randomChar(chars string) (byte, error) {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(chars))))
if err != nil {
return 0, fmt.Errorf("random index: %w", err)
}
return chars[n.Int64()], nil
}
func shuffle(chars []byte) error {
for i := len(chars) - 1; i > 0; i-- {
n, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
if err != nil {
return fmt.Errorf("shuffle password: %w", err)
}
j := int(n.Int64())
chars[i], chars[j] = chars[j], chars[i]
}
return nil
}
+126
View File
@@ -0,0 +1,126 @@
package passwords
import (
"errors"
"slices"
"strings"
"testing"
)
func TestGenerateRespectsProfileRequirements(t *testing.T) {
t.Parallel()
profile := Profile{
Name: "strong",
Length: 24,
Lowercase: true,
Uppercase: true,
Digits: true,
Symbols: true,
MinLowercase: 2,
MinUppercase: 2,
MinDigits: 2,
MinSymbols: 2,
ExcludeSimilar: true,
}
password, err := Generate(profile)
if err != nil {
t.Fatalf("Generate() error = %v", err)
}
if len(password) != 24 {
t.Fatalf("len(password) = %d, want 24", len(password))
}
if countFromSet(password, lowercaseChars) < 2 {
t.Fatalf("lowercase count in %q is too small", password)
}
if countFromSet(password, uppercaseChars) < 2 {
t.Fatalf("uppercase count in %q is too small", password)
}
if countFromSet(password, digitChars) < 2 {
t.Fatalf("digit count in %q is too small", password)
}
if countFromSet(password, symbolChars) < 2 {
t.Fatalf("symbol count in %q is too small", password)
}
if strings.ContainsAny(password, "O0Il1") {
t.Fatalf("password %q contains excluded similar characters", password)
}
}
func TestGenerateRejectsImpossibleProfiles(t *testing.T) {
t.Parallel()
_, err := Generate(Profile{
Name: "bad",
Length: 6,
Lowercase: true,
Uppercase: true,
Digits: true,
Symbols: true,
MinLowercase: 2,
MinUppercase: 2,
MinDigits: 2,
MinSymbols: 2,
})
if err == nil {
t.Fatal("Generate() error = nil, want impossible profile error")
}
}
func TestProfileSetReturnsNamedProfiles(t *testing.T) {
t.Parallel()
set := DefaultProfiles()
profile, ok := set["strong"]
if !ok {
t.Fatalf("DefaultProfiles()[\"strong\"] missing")
}
if profile.Length < 20 || !profile.Symbols {
t.Fatalf("strong profile = %#v, want a strong reusable profile", profile)
}
}
func TestDefaultProfileNamesReturnsSortedNames(t *testing.T) {
t.Parallel()
got := DefaultProfileNames()
want := []string{"memorable", "strong"}
if !slices.Equal(got, want) {
t.Fatalf("DefaultProfileNames() = %v, want %v", got, want)
}
}
func TestLookupDefaultProfileResolvesKnownProfilesAndRejectsUnknownNames(t *testing.T) {
t.Parallel()
profile, err := LookupDefaultProfile(" strong ")
if err != nil {
t.Fatalf("LookupDefaultProfile(\" strong \") error = %v", err)
}
if profile.Name != "strong" {
t.Fatalf("LookupDefaultProfile(\" strong \").Name = %q, want %q", profile.Name, "strong")
}
_, err = LookupDefaultProfile("invalid")
if !errors.Is(err, ErrUnknownProfile) {
t.Fatalf("LookupDefaultProfile(\"invalid\") error = %v, want ErrUnknownProfile", err)
}
}
func countFromSet(password, chars string) int {
count := 0
for _, r := range password {
if strings.ContainsRune(chars, r) {
count++
}
}
return count
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+160
View File
@@ -0,0 +1,160 @@
package vault
import "testing"
func TestUpsertEntryPreservesPreviousVersionInHistory(t *testing.T) {
t.Parallel()
model := Model{
Entries: []Entry{
{
ID: "vault-console",
Title: "Vault Console",
Username: "dannyocean",
Password: "old-token",
URL: "https://vault.crew.example.invalid",
Notes: "Original note",
Path: []string{"Root", "Internet"},
},
},
}
model.UpsertEntry(Entry{
ID: "vault-console",
Title: "Vault Console",
Username: "dannyocean",
Password: "new-token",
URL: "https://vault.crew.example.invalid",
Notes: "Updated note",
Path: []string{"Root", "Internet"},
})
got := model.EntriesInPath([]string{"Root", "Internet"})
if len(got) != 1 {
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
}
if got[0].Password != "new-token" {
t.Fatalf("Entry.Password = %q, want %q", got[0].Password, "new-token")
}
if len(got[0].History) != 1 {
t.Fatalf("len(Entry.History) = %d, want 1", len(got[0].History))
}
if got[0].History[0].Password != "old-token" || got[0].History[0].Notes != "Original note" {
t.Fatalf("Entry.History[0] = %#v, want prior entry version", got[0].History[0])
}
}
func TestDeleteEntryMovesItToRecycleBin(t *testing.T) {
t.Parallel()
model := Model{
Entries: []Entry{
{
ID: "surveillance-console",
Title: "Surveillance Console",
Username: "codex",
Password: "token-2",
URL: "https://surveillance.crew.example.invalid",
Path: []string{"Root", "Home Assistant"},
},
},
}
if err := model.DeleteEntry("surveillance-console"); err != nil {
t.Fatalf("DeleteEntry() error = %v", err)
}
if got := model.EntriesInPath([]string{"Root", "Home Assistant"}); len(got) != 0 {
t.Fatalf("EntriesInPath() = %#v, want empty after delete", got)
}
if len(model.RecycleBin) != 1 {
t.Fatalf("len(RecycleBin) = %d, want 1", len(model.RecycleBin))
}
if model.RecycleBin[0].Title != "Surveillance Console" {
t.Fatalf("RecycleBin[0].Title = %q, want %q", model.RecycleBin[0].Title, "Surveillance Console")
}
}
func TestRestoreEntryMovesItBackFromRecycleBin(t *testing.T) {
t.Parallel()
model := Model{
RecycleBin: []Entry{
{
ID: "bellagio",
Title: "Bellagio",
Username: "rustyryan",
Password: "token-3",
URL: "https://bellagio.example.invalid",
Path: []string{"Root", "Internet"},
},
},
}
if err := model.RestoreEntry("bellagio"); err != nil {
t.Fatalf("RestoreEntry() error = %v", err)
}
got := model.EntriesInPath([]string{"Root", "Internet"})
if len(got) != 1 {
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
}
if got[0].Title != "Bellagio" {
t.Fatalf("EntriesInPath()[0].Title = %q, want %q", got[0].Title, "Bellagio")
}
if len(model.RecycleBin) != 0 {
t.Fatalf("len(RecycleBin) = %d, want 0", len(model.RecycleBin))
}
}
func TestRestoreEntryVersionPromotesHistoricalVersionAndRetainsCurrentInHistory(t *testing.T) {
t.Parallel()
model := Model{
Entries: []Entry{
{
ID: "vault-console",
Title: "Vault Console",
Username: "dannyocean",
Password: "new-token",
Notes: "Current note",
Path: []string{"Root", "Internet"},
History: []Entry{
{
ID: "vault-console-history-1",
Title: "Vault Console",
Username: "dannyocean",
Password: "old-token",
Notes: "Previous note",
Path: []string{"Root", "Internet"},
},
},
},
},
}
if err := model.RestoreEntryVersion("vault-console", 0); err != nil {
t.Fatalf("RestoreEntryVersion() error = %v", err)
}
got := model.EntriesInPath([]string{"Root", "Internet"})
if len(got) != 1 {
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
}
if got[0].Password != "old-token" || got[0].Notes != "Previous note" {
t.Fatalf("restored entry = %#v, want old-token/Previous note current version", got[0])
}
if len(got[0].History) != 1 {
t.Fatalf("len(History) = %d, want 1", len(got[0].History))
}
if got[0].History[0].Password != "new-token" || got[0].History[0].Notes != "Current note" {
t.Fatalf("History[0] = %#v, want prior current version retained", got[0].History[0])
}
}
+628
View File
@@ -0,0 +1,628 @@
package vault
import (
"crypto/rand"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"maps"
"slices"
"strings"
"time"
"github.com/tobischo/gokeepasslib/v3"
w "github.com/tobischo/gokeepasslib/v3/wrappers"
)
type KDBXConfig struct {
Header *gokeepasslib.DBHeader
InnerHeader *gokeepasslib.InnerHeader
}
var ErrInvalidMasterKey = errors.New("invalid master key")
const (
templatesRoot = "Templates"
recycleBinRoot = "Recycle Bin"
keepassGOIDField = "KeePassGO-ID"
remoteProfilesKey = "keepassgo.remoteProfiles"
)
func LoadKDBX(r io.Reader, password string) (Model, error) {
return LoadKDBXWithKey(r, MasterKey{Password: password})
}
func SaveKDBX(wr io.Writer, model Model, password string) error {
return SaveKDBXWithKey(wr, model, MasterKey{Password: password})
}
func SaveKDBXWithKey(wr io.Writer, model Model, key MasterKey) error {
return SaveKDBXWithConfigAndKey(wr, model, key, nil)
}
func SaveKDBXWithConfigAndKey(wr io.Writer, model Model, key MasterKey, config *KDBXConfig) error {
credentials, err := newCredentials(key)
if err != nil {
return err
}
db := gokeepasslib.NewDatabase(gokeepasslib.WithDatabaseKDBXVersion4())
db.Credentials = credentials
db.Content.Meta = gokeepasslib.NewMetaData()
db.Content.Meta.CustomData = customDataForModel(model)
db.Content.Root = &gokeepasslib.RootData{}
if config != nil && config.Header != nil {
db.Header = cloneHeader(config.Header)
db.Hashes = gokeepasslib.NewHashes(db.Header)
}
if db.Header.IsKdbx4() {
if config != nil && config.InnerHeader != nil {
db.Content.InnerHeader = cloneInnerHeader(config.InnerHeader)
db.Content.InnerHeader.Binaries = nil
} else if db.Content.InnerHeader == nil {
db.Content.InnerHeader = &gokeepasslib.InnerHeader{
InnerRandomStreamID: gokeepasslib.ChaChaStreamID,
InnerRandomStreamKey: randomBytes(64),
}
}
} else {
db.Content.InnerHeader = nil
}
db.Content.Root.Groups = buildGroupTree(db, model)
db.Content.Root.DeletedObjects = nil
if err := db.LockProtectedEntries(); err != nil {
return fmt.Errorf("lock protected entries: %w", err)
}
if err := gokeepasslib.NewEncoder(wr).Encode(db); err != nil {
return fmt.Errorf("encode kdbx: %w", err)
}
return nil
}
func appendGroupEntries(model *Model, db *gokeepasslib.Database, group gokeepasslib.Group, path []string) {
path = append(clonePath(path), group.Name)
model.CreateGroup(path[:len(path)-1], group.Name)
for _, entry := range group.Entries {
appendModelEntry(model, Entry{
ID: extractEntryID(entry),
Title: entry.GetTitle(),
Username: entry.GetContent("UserName"),
Password: entry.GetPassword(),
URL: entry.GetContent("URL"),
Notes: entry.GetContent("Notes"),
Tags: splitTags(entry.Tags),
Fields: extractCustomFields(entry),
Attachments: extractAttachments(db, entry),
History: extractHistory(db, entry, path),
Path: clonePath(path),
})
}
for _, child := range group.Groups {
appendGroupEntries(model, db, child, path)
}
}
func appendModelEntry(model *Model, entry Entry) {
if len(entry.Path) == 0 {
model.Entries = append(model.Entries, entry)
return
}
switch entry.Path[0] {
case templatesRoot:
model.Templates = append(model.Templates, entry)
return
case recycleBinRoot:
entry.Path = slices.Clone(entry.Path[1:])
model.RecycleBin = append(model.RecycleBin, entry)
return
}
model.Entries = append(model.Entries, entry)
}
func entriesForPersistence(model Model) []Entry {
entries := append(slices.Clone(model.Entries), model.Templates...)
for _, entry := range model.RecycleBin {
recycleEntry := cloneEntry(entry)
recycleEntry.Path = append([]string{recycleBinRoot}, recycleEntry.Path...)
entries = append(entries, recycleEntry)
}
return entries
}
func marshalUUID(id gokeepasslib.UUID) string {
text, err := id.MarshalText()
if err != nil {
return ""
}
return string(text)
}
func clonePath(path []string) []string {
if len(path) == 0 {
return nil
}
out := make([]string, len(path))
copy(out, path)
return out
}
func splitTags(tags string) []string {
if strings.TrimSpace(tags) == "" {
return nil
}
fields := strings.Split(tags, ";")
var out []string
for _, field := range fields {
field = strings.TrimSpace(field)
if field == "" {
continue
}
out = append(out, field)
}
return out
}
func extractCustomFields(entry gokeepasslib.Entry) map[string]string {
fields := map[string]string{}
for _, value := range entry.Values {
switch value.Key {
case "Title", "UserName", "Password", "URL", "Notes", keepassGOIDField:
continue
default:
fields[value.Key] = value.Value.Content
}
}
if len(fields) == 0 {
return nil
}
return fields
}
func extractEntryID(entry gokeepasslib.Entry) string {
if id := entry.GetContent(keepassGOIDField); id != "" {
return id
}
return marshalUUID(entry.UUID)
}
func extractHistory(db *gokeepasslib.Database, entry gokeepasslib.Entry, path []string) []Entry {
if len(entry.Histories) == 0 {
return nil
}
var history []Entry
for _, item := range entry.Histories {
for _, historical := range item.Entries {
history = append(history, Entry{
ID: extractEntryID(historical),
Title: historical.GetTitle(),
Username: historical.GetContent("UserName"),
Password: historical.GetPassword(),
URL: historical.GetContent("URL"),
Notes: historical.GetContent("Notes"),
Tags: splitTags(historical.Tags),
Fields: extractCustomFields(historical),
Attachments: extractAttachments(db, historical),
Path: clonePath(path),
})
}
}
return history
}
type groupNode struct {
name string
children map[string]*groupNode
entries []Entry
}
type MasterKey struct {
Password string
KeyFileData []byte
}
func buildGroupTree(db *gokeepasslib.Database, model Model) []gokeepasslib.Group {
entries := entriesForPersistence(model)
root := &groupNode{children: map[string]*groupNode{}}
for _, entry := range entries {
node := root
for _, segment := range entry.Path {
if node.children[segment] == nil {
node.children[segment] = &groupNode{
name: segment,
children: map[string]*groupNode{},
}
}
node = node.children[segment]
}
node.entries = append(node.entries, entry)
}
for _, path := range groupPathsForPersistence(model, entries) {
node := root
for _, segment := range path {
if node.children[segment] == nil {
node.children[segment] = &groupNode{
name: segment,
children: map[string]*groupNode{},
}
}
node = node.children[segment]
}
}
groups := marshalGroups(db, root)
if len(groups) > 0 {
return groups
}
group := gokeepasslib.NewGroup()
group.Name = "Root"
return []gokeepasslib.Group{group}
}
func groupPathsForPersistence(model Model, entries []Entry) [][]string {
seen := map[string]bool{}
var groups [][]string
appendPath := func(path []string) {
key := strings.Join(path, "\x00")
if seen[key] {
return
}
seen[key] = true
groups = append(groups, slices.Clone(path))
}
for _, entry := range entries {
for i := 1; i <= len(entry.Path); i++ {
appendPath(entry.Path[:i])
}
}
for _, path := range model.Groups {
for i := 1; i <= len(path); i++ {
appendPath(path[:i])
}
}
return groups
}
func LoadKDBXWithKey(r io.Reader, key MasterKey) (Model, error) {
model, _, err := LoadKDBXWithConfig(r, key)
return model, err
}
func LoadKDBXWithConfig(r io.Reader, key MasterKey) (Model, *KDBXConfig, error) {
credentials, err := newCredentials(key)
if err != nil {
return Model{}, nil, err
}
db := gokeepasslib.NewDatabase()
db.Credentials = credentials
if err := gokeepasslib.NewDecoder(r).Decode(db); err != nil {
if isInvalidCredentialError(err) {
return Model{}, nil, ErrInvalidMasterKey
}
return Model{}, nil, fmt.Errorf("decode kdbx: %w", err)
}
if err := db.UnlockProtectedEntries(); err != nil {
return Model{}, nil, fmt.Errorf("unlock protected entries: %w", err)
}
var model Model
for _, group := range db.Content.Root.Groups {
appendGroupEntries(&model, db, group, nil)
}
model.RemoteProfiles = remoteProfilesFromMeta(db.Content.Meta)
return model, &KDBXConfig{
Header: cloneHeader(db.Header),
InnerHeader: cloneInnerHeader(db.Content.InnerHeader),
}, nil
}
func customDataForModel(model Model) []gokeepasslib.CustomData {
if len(model.RemoteProfiles) == 0 {
return nil
}
content, err := json.Marshal(model.RemoteProfiles)
if err != nil {
return nil
}
return []gokeepasslib.CustomData{{
Key: remoteProfilesKey,
Value: string(content),
}}
}
func remoteProfilesFromMeta(meta *gokeepasslib.MetaData) []RemoteProfile {
if meta == nil {
return nil
}
for _, item := range meta.CustomData {
if item.Key != remoteProfilesKey {
continue
}
var profiles []RemoteProfile
if err := json.Unmarshal([]byte(item.Value), &profiles); err != nil {
return nil
}
return profiles
}
return nil
}
func newCredentials(key MasterKey) (*gokeepasslib.DBCredentials, error) {
switch {
case key.Password != "" && len(key.KeyFileData) > 0:
credentials, err := gokeepasslib.NewPasswordAndKeyDataCredentials(key.Password, key.KeyFileData)
if err != nil {
return nil, fmt.Errorf("build password+key credentials: %w", err)
}
return credentials, nil
case len(key.KeyFileData) > 0:
credentials, err := gokeepasslib.NewKeyDataCredentials(key.KeyFileData)
if err != nil {
return nil, fmt.Errorf("build key credentials: %w", err)
}
return credentials, nil
default:
return gokeepasslib.NewPasswordCredentials(key.Password), nil
}
}
func cloneHeader(header *gokeepasslib.DBHeader) *gokeepasslib.DBHeader {
if header == nil {
return nil
}
out := *header
out.RawData = nil
if header.Signature != nil {
signature := *header.Signature
out.Signature = &signature
}
if header.FileHeaders != nil {
fileHeaders := *header.FileHeaders
fileHeaders.Comment = slices.Clone(header.FileHeaders.Comment)
fileHeaders.CipherID = slices.Clone(header.FileHeaders.CipherID)
fileHeaders.MasterSeed = slices.Clone(header.FileHeaders.MasterSeed)
fileHeaders.TransformSeed = slices.Clone(header.FileHeaders.TransformSeed)
fileHeaders.EncryptionIV = slices.Clone(header.FileHeaders.EncryptionIV)
fileHeaders.ProtectedStreamKey = slices.Clone(header.FileHeaders.ProtectedStreamKey)
fileHeaders.StreamStartBytes = slices.Clone(header.FileHeaders.StreamStartBytes)
if header.FileHeaders.KdfParameters != nil {
kdf := *header.FileHeaders.KdfParameters
kdf.UUID = slices.Clone(header.FileHeaders.KdfParameters.UUID)
kdf.SecretKey = slices.Clone(header.FileHeaders.KdfParameters.SecretKey)
kdf.AssocData = slices.Clone(header.FileHeaders.KdfParameters.AssocData)
if header.FileHeaders.KdfParameters.RawData != nil {
kdf.RawData = cloneVariantDictionary(header.FileHeaders.KdfParameters.RawData)
}
fileHeaders.KdfParameters = &kdf
}
if header.FileHeaders.PublicCustomData != nil {
fileHeaders.PublicCustomData = cloneVariantDictionary(header.FileHeaders.PublicCustomData)
}
out.FileHeaders = &fileHeaders
}
return &out
}
func cloneVariantDictionary(dict *gokeepasslib.VariantDictionary) *gokeepasslib.VariantDictionary {
if dict == nil {
return nil
}
out := &gokeepasslib.VariantDictionary{Version: dict.Version}
out.Items = make([]*gokeepasslib.VariantDictionaryItem, 0, len(dict.Items))
for _, item := range dict.Items {
cloned := *item
cloned.Name = slices.Clone(item.Name)
cloned.Value = slices.Clone(item.Value)
out.Items = append(out.Items, &cloned)
}
return out
}
func cloneInnerHeader(header *gokeepasslib.InnerHeader) *gokeepasslib.InnerHeader {
if header == nil {
return nil
}
out := &gokeepasslib.InnerHeader{
InnerRandomStreamID: header.InnerRandomStreamID,
InnerRandomStreamKey: slices.Clone(header.InnerRandomStreamKey),
}
for _, binary := range header.Binaries {
out.Binaries = append(out.Binaries, gokeepasslib.Binary{
ID: binary.ID,
Compressed: binary.Compressed,
MemoryProtection: binary.MemoryProtection,
Content: slices.Clone(binary.Content),
})
}
return out
}
func randomBytes(length int) []byte {
buf := make([]byte, length)
_, _ = io.ReadFull(rand.Reader, buf)
return buf
}
func isInvalidCredentialError(err error) bool {
if errors.Is(err, gokeepasslib.ErrInvalidDatabaseOrCredentials) {
return true
}
return strings.Contains(err.Error(), "Wrong password?")
}
func marshalGroups(db *gokeepasslib.Database, node *groupNode) []gokeepasslib.Group {
names := slices.Collect(maps.Keys(node.children))
slices.SortFunc(names, compareGroupNames)
var groups []gokeepasslib.Group
for _, name := range names {
child := node.children[name]
group := gokeepasslib.NewGroup()
group.Name = child.name
group.Entries = marshalEntries(db, child.entries)
group.Groups = marshalGroups(db, child)
groups = append(groups, group)
}
return groups
}
func compareGroupNames(a, b string) int {
switch {
case a == b:
return 0
case a == "Root":
return -1
case b == "Root":
return 1
case a == templatesRoot:
return -1
case b == templatesRoot:
return 1
case a == recycleBinRoot:
return 1
case b == recycleBinRoot:
return -1
case a < b:
return -1
default:
return 1
}
}
func marshalEntries(db *gokeepasslib.Database, entries []Entry) []gokeepasslib.Entry {
slices.SortFunc(entries, func(a, b Entry) int {
switch {
case a.Title < b.Title:
return -1
case a.Title > b.Title:
return 1
default:
return 0
}
})
var out []gokeepasslib.Entry
for _, entry := range entries {
out = append(out, marshalEntry(db, entry))
}
return out
}
func marshalEntry(db *gokeepasslib.Database, entry Entry) gokeepasslib.Entry {
item := gokeepasslib.NewEntry()
item.UUID = uuidForEntryID(entry.ID)
item.Tags = strings.Join(entry.Tags, "; ")
item.Values = append(item.Values,
value("Title", entry.Title),
value("UserName", entry.Username),
protectedValue("Password", entry.Password),
value("URL", entry.URL),
value("Notes", entry.Notes),
value(keepassGOIDField, entry.ID),
)
keys := slices.Collect(maps.Keys(entry.Fields))
slices.Sort(keys)
for _, key := range keys {
item.Values = append(item.Values, value(key, entry.Fields[key]))
}
attachmentNames := slices.Collect(maps.Keys(entry.Attachments))
slices.Sort(attachmentNames)
for _, name := range attachmentNames {
binary := db.AddBinary(entry.Attachments[name])
item.Binaries = append(item.Binaries, binary.CreateReference(name))
}
for _, historical := range entry.History {
item.Histories = append(item.Histories, gokeepasslib.History{
Entries: []gokeepasslib.Entry{marshalEntry(db, historical)},
})
}
return item
}
func uuidForEntryID(id string) gokeepasslib.UUID {
if id != "" {
var uuid gokeepasslib.UUID
if err := uuid.UnmarshalText([]byte(id)); err == nil {
return uuid
}
}
sum := sha256.Sum256([]byte(id))
var uuid gokeepasslib.UUID
copy(uuid[:], sum[:len(uuid)])
if id == "" {
copy(uuid[:], time.Now().UTC().AppendFormat(nil, time.RFC3339Nano))
}
return uuid
}
func value(key, content string) gokeepasslib.ValueData {
return gokeepasslib.ValueData{Key: key, Value: gokeepasslib.V{Content: content}}
}
func protectedValue(key, content string) gokeepasslib.ValueData {
return gokeepasslib.ValueData{
Key: key,
Value: gokeepasslib.V{Content: content, Protected: w.NewBoolWrapper(true)},
}
}
func extractAttachments(db *gokeepasslib.Database, entry gokeepasslib.Entry) map[string][]byte {
if len(entry.Binaries) == 0 {
return nil
}
attachments := map[string][]byte{}
for _, ref := range entry.Binaries {
binary := db.FindBinary(ref.Value.ID)
if binary == nil {
continue
}
content, err := binary.GetContentBytes()
if err != nil {
continue
}
attachments[ref.Name] = slices.Clone(content)
}
if len(attachments) == 0 {
return nil
}
return attachments
}
+794
View File
@@ -0,0 +1,794 @@
package vault
import (
"bytes"
"errors"
"slices"
"testing"
"github.com/tobischo/gokeepasslib/v3"
w "github.com/tobischo/gokeepasslib/v3/wrappers"
)
func TestLoadKDBXBuildsModelFromNestedGroups(t *testing.T) {
t.Parallel()
db := &gokeepasslib.Database{
Header: gokeepasslib.NewHeader(),
Credentials: gokeepasslib.NewPasswordCredentials("correct horse battery staple"),
Content: &gokeepasslib.DBContent{
Meta: gokeepasslib.NewMetaData(),
Root: &gokeepasslib.RootData{
Groups: []gokeepasslib.Group{
mustGroup("Root",
mustGroup("Internet",
mustEntry("Bellagio", "rustyryan", "https://bellagio.example.invalid", "hunter2"),
mustEntry("Vault Console", "dannyocean", "https://vault.crew.example.invalid", "bellagio-pass-1"),
),
mustGroup("Security Office",
mustEntry("Surveillance Console", "bashertarr", "https://surveillance.crew.example.invalid", "bellagio-pass-2"),
),
),
},
},
},
}
if err := db.LockProtectedEntries(); err != nil {
t.Fatalf("LockProtectedEntries failed: %v", err)
}
var encoded bytes.Buffer
if err := gokeepasslib.NewEncoder(&encoded).Encode(db); err != nil {
t.Fatalf("Encode failed: %v", err)
}
model, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
if err != nil {
t.Fatalf("LoadKDBX failed: %v", err)
}
got := model.EntriesInPath([]string{"Root", "Internet"})
if len(got) != 2 {
t.Fatalf("len(EntriesInPath()) = %d, want 2", len(got))
}
if got[0].Title != "Bellagio" || got[0].Username != "rustyryan" || got[0].URL != "https://bellagio.example.invalid" {
t.Fatalf("unexpected first entry: %#v", got[0])
}
if got[1].Title != "Vault Console" || got[1].Username != "dannyocean" || got[1].URL != "https://vault.crew.example.invalid" {
t.Fatalf("unexpected second entry: %#v", got[1])
}
groups := model.ChildGroups([]string{"Root"})
if len(groups) != 2 || groups[0] != "Internet" || groups[1] != "Security Office" {
t.Fatalf("ChildGroups() = %v, want [Internet Security Office]", groups)
}
}
func TestLoadKDBXPreservesEntryDetails(t *testing.T) {
t.Parallel()
entry := mustEntry("Surveillance Console", "bashertarr", "https://surveillance.crew.example.invalid", "bellagio-pass-2")
entry.Tags = "automation; home"
entry.Values = append(entry.Values,
mkValue("Notes", "Long-lived token used by Codex for home automation tasks."),
mkValue("X-Role", "automation"),
)
db := &gokeepasslib.Database{
Header: gokeepasslib.NewHeader(),
Credentials: gokeepasslib.NewPasswordCredentials("correct horse battery staple"),
Content: &gokeepasslib.DBContent{
Meta: gokeepasslib.NewMetaData(),
Root: &gokeepasslib.RootData{
Groups: []gokeepasslib.Group{
mustGroup("Root", mustGroup("Security Office", entry)),
},
},
},
}
if err := db.LockProtectedEntries(); err != nil {
t.Fatalf("LockProtectedEntries failed: %v", err)
}
var encoded bytes.Buffer
if err := gokeepasslib.NewEncoder(&encoded).Encode(db); err != nil {
t.Fatalf("Encode failed: %v", err)
}
model, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
if err != nil {
t.Fatalf("LoadKDBX failed: %v", err)
}
got := model.EntriesInPath([]string{"Root", "Security Office"})
if len(got) != 1 {
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
}
if got[0].Password != "bellagio-pass-2" {
t.Fatalf("Entry.Password = %q, want %q", got[0].Password, "bellagio-pass-2")
}
if got[0].Notes != "Long-lived token used by Codex for home automation tasks." {
t.Fatalf("Entry.Notes = %q, want %q", got[0].Notes, "Long-lived token used by Codex for home automation tasks.")
}
if len(got[0].Tags) != 2 || got[0].Tags[0] != "automation" || got[0].Tags[1] != "home" {
t.Fatalf("Entry.Tags = %v, want [automation home]", got[0].Tags)
}
if got[0].Fields["X-Role"] != "automation" {
t.Fatalf("Entry.Fields[\"X-Role\"] = %q, want %q", got[0].Fields["X-Role"], "automation")
}
}
func TestSaveKDBXRoundTripsModel(t *testing.T) {
t.Parallel()
model := Model{
Entries: []Entry{
{
ID: "entry-1",
Title: "Vault Console",
Username: "dannyocean",
Password: "bellagio-pass-1",
URL: "https://vault.crew.example.invalid",
Notes: "Personal git server token entry used for automation and CLI auth.",
Tags: []string{"git", "infra"},
Fields: map[string]string{
"X-Role": "automation",
},
Path: []string{"Root", "Internet"},
},
{
ID: "entry-2",
Title: "Surveillance Console",
Username: "bashertarr",
Password: "bellagio-pass-2",
URL: "https://surveillance.crew.example.invalid",
Notes: "Long-lived token used by Codex for home automation tasks.",
Tags: []string{"automation", "home"},
Path: []string{"Root", "Security Office"},
},
},
}
var encoded bytes.Buffer
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
t.Fatalf("SaveKDBX() error = %v", err)
}
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
if err != nil {
t.Fatalf("LoadKDBX() error = %v", err)
}
got := loaded.Search("vault")
if len(got) != 1 {
t.Fatalf("len(Search(\"git\")) = %d, want 1", len(got))
}
if got[0].Entry.Notes != "Personal git server token entry used for automation and CLI auth." {
t.Fatalf("Search(\"git\") notes = %q, want %q", got[0].Entry.Notes, "Personal git server token entry used for automation and CLI auth.")
}
if got[0].Entry.Fields["X-Role"] != "automation" {
t.Fatalf("Search(\"git\") X-Role = %q, want %q", got[0].Entry.Fields["X-Role"], "automation")
}
homeAssistant := loaded.EntriesInPath([]string{"Root", "Security Office"})
if len(homeAssistant) != 1 {
t.Fatalf("len(EntriesInPath(Security Office)) = %d, want 1", len(homeAssistant))
}
if homeAssistant[0].Password != "bellagio-pass-2" {
t.Fatalf("Security Office password = %q, want %q", homeAssistant[0].Password, "bellagio-pass-2")
}
}
func TestSaveKDBXRoundTripsTemplates(t *testing.T) {
t.Parallel()
model := Model{
Templates: []Entry{
{
ID: "tpl-1",
Title: "Website Login",
Username: "template-user",
Password: "template-password",
URL: "https://example.com",
Notes: "Reusable template for website accounts.",
Tags: []string{"template", "web"},
Fields: map[string]string{
"Environment": "prod",
},
Path: []string{"Templates"},
},
},
}
var encoded bytes.Buffer
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
t.Fatalf("SaveKDBX() error = %v", err)
}
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
if err != nil {
t.Fatalf("LoadKDBX() error = %v", err)
}
if len(loaded.Templates) != 1 {
t.Fatalf("len(Templates) = %d, want 1", len(loaded.Templates))
}
if loaded.Templates[0].Title != "Website Login" {
t.Fatalf("Templates[0].Title = %q, want %q", loaded.Templates[0].Title, "Website Login")
}
if loaded.Templates[0].Fields["Environment"] != "prod" {
t.Fatalf("Templates[0].Fields[Environment] = %q, want %q", loaded.Templates[0].Fields["Environment"], "prod")
}
if len(loaded.Entries) != 0 {
t.Fatalf("len(Entries) = %d, want 0", len(loaded.Entries))
}
}
func TestSaveKDBXRoundTripsRemoteProfiles(t *testing.T) {
t.Parallel()
model := Model{
RemoteProfiles: []RemoteProfile{
{
ID: "bellagio-webdav",
Name: "Bellagio Vault",
Backend: RemoteBackendWebDAV,
BaseURL: "https://dav.example.invalid/remote.php/dav",
Path: "files/bellagio/keepass.kdbx",
},
},
}
var encoded bytes.Buffer
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
t.Fatalf("SaveKDBX() error = %v", err)
}
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
if err != nil {
t.Fatalf("LoadKDBX() error = %v", err)
}
if len(loaded.RemoteProfiles) != 1 {
t.Fatalf("len(RemoteProfiles) = %d, want 1", len(loaded.RemoteProfiles))
}
got := loaded.RemoteProfiles[0]
if got.ID != "bellagio-webdav" || got.Name != "Bellagio Vault" {
t.Fatalf("loaded remote profile = %#v, want bellagio-webdav Bellagio Vault", got)
}
if got.Backend != RemoteBackendWebDAV {
t.Fatalf("remote backend = %q, want %q", got.Backend, RemoteBackendWebDAV)
}
if got.BaseURL != "https://dav.example.invalid/remote.php/dav" {
t.Fatalf("remote base URL = %q, want remote.php/dav URL", got.BaseURL)
}
if got.Path != "files/bellagio/keepass.kdbx" {
t.Fatalf("remote path = %q, want files/bellagio/keepass.kdbx", got.Path)
}
}
func TestSaveKDBXRoundTripsEntryHistory(t *testing.T) {
t.Parallel()
model := Model{
Entries: []Entry{
{
ID: "entry-1",
Title: "Vault Console",
Username: "dannyocean",
Password: "new-token",
URL: "https://vault.crew.example.invalid",
Path: []string{"Root", "Internet"},
History: []Entry{
{
ID: "entry-1-old",
Title: "Vault Console",
Username: "dannyocean",
Password: "old-token",
URL: "https://vault.crew.example.invalid",
Path: []string{"Root", "Internet"},
Notes: "Original version",
},
},
},
},
}
var encoded bytes.Buffer
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
t.Fatalf("SaveKDBX() error = %v", err)
}
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
if err != nil {
t.Fatalf("LoadKDBX() error = %v", err)
}
got := loaded.EntriesInPath([]string{"Root", "Internet"})
if len(got) != 1 {
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
}
if len(got[0].History) != 1 {
t.Fatalf("len(History) = %d, want 1", len(got[0].History))
}
if got[0].History[0].Password != "old-token" || got[0].History[0].Notes != "Original version" {
t.Fatalf("History[0] = %#v, want preserved prior version", got[0].History[0])
}
}
func TestSaveKDBXRoundTripsRecycleBinEntries(t *testing.T) {
t.Parallel()
model := Model{
RecycleBin: []Entry{
{
ID: "entry-1",
Title: "Surveillance Console",
Username: "bashertarr",
Password: "bellagio-pass-2",
URL: "https://surveillance.crew.example.invalid",
Path: []string{"Root", "Security Office"},
},
},
}
var encoded bytes.Buffer
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
t.Fatalf("SaveKDBX() error = %v", err)
}
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
if err != nil {
t.Fatalf("LoadKDBX() error = %v", err)
}
if len(loaded.RecycleBin) != 1 {
t.Fatalf("len(RecycleBin) = %d, want 1", len(loaded.RecycleBin))
}
if loaded.RecycleBin[0].Title != "Surveillance Console" {
t.Fatalf("RecycleBin[0].Title = %q, want %q", loaded.RecycleBin[0].Title, "Surveillance Console")
}
if len(loaded.RecycleBin[0].Path) != 2 || loaded.RecycleBin[0].Path[0] != "Root" || loaded.RecycleBin[0].Path[1] != "Security Office" {
t.Fatalf("RecycleBin[0].Path = %v, want [Root Security Office]", loaded.RecycleBin[0].Path)
}
if len(loaded.Entries) != 0 {
t.Fatalf("len(Entries) = %d, want 0", len(loaded.Entries))
}
}
func TestLoadKDBXWithKeyFileCredentials(t *testing.T) {
t.Parallel()
keyData := []byte(`<?xml version="1.0" encoding="utf-8"?>
<KeyFile>
<Meta>
<Version>1.0</Version>
</Meta>
<Key>
<Data>PbLBYmgEXFhLWf2gxoBMARXgDZGE7f34tr+anCw52LI=</Data>
</Key>
</KeyFile>
`)
credentials, err := newCredentials(MasterKey{KeyFileData: keyData})
if err != nil {
t.Fatalf("newCredentials() error = %v", err)
}
db := &gokeepasslib.Database{
Header: gokeepasslib.NewHeader(),
Credentials: credentials,
Content: &gokeepasslib.DBContent{
Meta: gokeepasslib.NewMetaData(),
Root: &gokeepasslib.RootData{
Groups: []gokeepasslib.Group{
mustGroup("Root", mustGroup("Internet", mustEntry("Vault Console", "dannyocean", "https://vault.crew.example.invalid", "bellagio-pass-1"))),
},
},
},
}
if err := db.LockProtectedEntries(); err != nil {
t.Fatalf("LockProtectedEntries failed: %v", err)
}
var encoded bytes.Buffer
if err := gokeepasslib.NewEncoder(&encoded).Encode(db); err != nil {
t.Fatalf("Encode failed: %v", err)
}
model, err := LoadKDBXWithKey(bytes.NewReader(encoded.Bytes()), MasterKey{KeyFileData: keyData})
if err != nil {
t.Fatalf("LoadKDBXWithKey() error = %v", err)
}
got := model.Search("vault")
if len(got) != 1 || got[0].Entry.Password != "bellagio-pass-1" {
t.Fatalf("LoadKDBXWithKey() = %#v, want password-preserving vault entry", got)
}
}
func TestLoadKDBXWithCompositeCredentials(t *testing.T) {
t.Parallel()
keyData := []byte(`<?xml version="1.0" encoding="utf-8"?>
<KeyFile>
<Meta>
<Version>1.0</Version>
</Meta>
<Key>
<Data>PbLBYmgEXFhLWf2gxoBMARXgDZGE7f34tr+anCw52LI=</Data>
</Key>
</KeyFile>
`)
credentials, err := newCredentials(MasterKey{
Password: "correct horse battery staple",
KeyFileData: keyData,
})
if err != nil {
t.Fatalf("newCredentials() error = %v", err)
}
db := &gokeepasslib.Database{
Header: gokeepasslib.NewHeader(),
Credentials: credentials,
Content: &gokeepasslib.DBContent{
Meta: gokeepasslib.NewMetaData(),
Root: &gokeepasslib.RootData{
Groups: []gokeepasslib.Group{
mustGroup("Root", mustGroup("Security Office", mustEntry("Surveillance Console", "bashertarr", "https://surveillance.crew.example.invalid", "bellagio-pass-2"))),
},
},
},
}
if err := db.LockProtectedEntries(); err != nil {
t.Fatalf("LockProtectedEntries failed: %v", err)
}
var encoded bytes.Buffer
if err := gokeepasslib.NewEncoder(&encoded).Encode(db); err != nil {
t.Fatalf("Encode failed: %v", err)
}
model, err := LoadKDBXWithKey(bytes.NewReader(encoded.Bytes()), MasterKey{
Password: "correct horse battery staple",
KeyFileData: keyData,
})
if err != nil {
t.Fatalf("LoadKDBXWithKey() error = %v", err)
}
got := model.EntriesInPath([]string{"Root", "Security Office"})
if len(got) != 1 || got[0].Password != "bellagio-pass-2" {
t.Fatalf("LoadKDBXWithKey() = %#v, want Security Office entry with password", got)
}
}
func TestLoadKDBXReturnsInvalidCredentialsError(t *testing.T) {
t.Parallel()
db := &gokeepasslib.Database{
Header: gokeepasslib.NewHeader(),
Credentials: gokeepasslib.NewPasswordCredentials("correct horse battery staple"),
Content: &gokeepasslib.DBContent{
Meta: gokeepasslib.NewMetaData(),
Root: &gokeepasslib.RootData{
Groups: []gokeepasslib.Group{
mustGroup("Root", mustGroup("Internet", mustEntry("Vault Console", "dannyocean", "https://vault.crew.example.invalid", "bellagio-pass-1"))),
},
},
},
}
if err := db.LockProtectedEntries(); err != nil {
t.Fatalf("LockProtectedEntries failed: %v", err)
}
var encoded bytes.Buffer
if err := gokeepasslib.NewEncoder(&encoded).Encode(db); err != nil {
t.Fatalf("Encode failed: %v", err)
}
_, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "definitely wrong password")
if !errors.Is(err, ErrInvalidMasterKey) {
t.Fatalf("LoadKDBX() error = %v, want %v", err, ErrInvalidMasterKey)
}
}
func TestSaveKDBXWithKeyRoundTripsModel(t *testing.T) {
t.Parallel()
keyData := []byte(`<?xml version="1.0" encoding="utf-8"?>
<KeyFile>
<Meta>
<Version>1.0</Version>
</Meta>
<Key>
<Data>PbLBYmgEXFhLWf2gxoBMARXgDZGE7f34tr+anCw52LI=</Data>
</Key>
</KeyFile>
`)
model := Model{
Entries: []Entry{
{
ID: "vault-console",
Title: "Vault Console",
Username: "dannyocean",
Password: "bellagio-pass-1",
URL: "https://vault.crew.example.invalid",
Path: []string{"Root", "Internet"},
},
},
}
var encoded bytes.Buffer
if err := SaveKDBXWithKey(&encoded, model, MasterKey{KeyFileData: keyData}); err != nil {
t.Fatalf("SaveKDBXWithKey() error = %v", err)
}
loaded, err := LoadKDBXWithKey(bytes.NewReader(encoded.Bytes()), MasterKey{KeyFileData: keyData})
if err != nil {
t.Fatalf("LoadKDBXWithKey() error = %v", err)
}
got := loaded.Search("vault")
if len(got) != 1 || got[0].Entry.Password != "bellagio-pass-1" {
t.Fatalf("round-trip with key file = %#v, want vault entry with password", got)
}
}
func TestSaveKDBXWithCompositeKeyRoundTripsModel(t *testing.T) {
t.Parallel()
keyData := []byte(`<?xml version="1.0" encoding="utf-8"?>
<KeyFile>
<Meta>
<Version>1.0</Version>
</Meta>
<Key>
<Data>PbLBYmgEXFhLWf2gxoBMARXgDZGE7f34tr+anCw52LI=</Data>
</Key>
</KeyFile>
`)
model := Model{
Entries: []Entry{
{
ID: "surveillance-console",
Title: "Surveillance Console",
Username: "bashertarr",
Password: "bellagio-pass-2",
URL: "https://surveillance.crew.example.invalid",
Path: []string{"Root", "Security Office"},
},
},
}
key := MasterKey{
Password: "correct horse battery staple",
KeyFileData: keyData,
}
var encoded bytes.Buffer
if err := SaveKDBXWithKey(&encoded, model, key); err != nil {
t.Fatalf("SaveKDBXWithKey() error = %v", err)
}
loaded, err := LoadKDBXWithKey(bytes.NewReader(encoded.Bytes()), key)
if err != nil {
t.Fatalf("LoadKDBXWithKey() error = %v", err)
}
got := loaded.EntriesInPath([]string{"Root", "Security Office"})
if len(got) != 1 || got[0].Password != "bellagio-pass-2" {
t.Fatalf("composite key round-trip = %#v, want Security Office entry with password", got)
}
}
func TestKDBXRoundTripsEntryAttachments(t *testing.T) {
t.Parallel()
model := Model{
Entries: []Entry{
{
ID: "vault-console",
Title: "Vault Console",
Username: "dannyocean",
Password: "bellagio-pass-1",
URL: "https://vault.crew.example.invalid",
Path: []string{"Root", "Internet"},
Attachments: map[string][]byte{
"token.txt": []byte("secret attachment contents"),
},
},
},
}
var encoded bytes.Buffer
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
t.Fatalf("SaveKDBX() error = %v", err)
}
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
if err != nil {
t.Fatalf("LoadKDBX() error = %v", err)
}
got := loaded.EntriesInPath([]string{"Root", "Internet"})
if len(got) != 1 {
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
}
if string(got[0].Attachments["token.txt"]) != "secret attachment contents" {
t.Fatalf("attachment contents = %q, want %q", string(got[0].Attachments["token.txt"]), "secret attachment contents")
}
}
func TestKDBXReopenCyclesPreserveStableIDsAndCrossFeatureState(t *testing.T) {
t.Parallel()
model := Model{
Entries: []Entry{
{
ID: "entry-1",
Title: "Vault Console",
Username: "dannyocean",
Password: "bellagio-pass-2",
URL: "https://vault.crew.example.invalid",
Notes: "Current credential",
Path: []string{"Root", "Internet"},
Attachments: map[string][]byte{
"token.txt": []byte("secret attachment contents"),
},
History: []Entry{
{
ID: "entry-1-history-1",
Title: "Vault Console",
Username: "dannyocean",
Password: "bellagio-pass-1",
URL: "https://vault.crew.example.invalid",
Notes: "Original credential",
Path: []string{"Root", "Internet"},
},
},
},
},
Templates: []Entry{
{
ID: "tpl-1",
Title: "Website Login",
Username: "template-user",
Password: "template-password",
Path: []string{"Templates", "Web"},
},
},
RecycleBin: []Entry{
{
ID: "deleted-1",
Title: "Retired Entry",
Username: "archived-user",
Password: "retired-token",
Path: []string{"Root", "Archive"},
},
},
Groups: [][]string{
{"Root", "Archive"},
{"Root", "Empty Group"},
{"Templates", "Web"},
},
}
var encoded bytes.Buffer
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
t.Fatalf("SaveKDBX(first cycle) error = %v", err)
}
reopened, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
if err != nil {
t.Fatalf("LoadKDBX(first cycle) error = %v", err)
}
encoded.Reset()
if err := SaveKDBX(&encoded, reopened, "correct horse battery staple"); err != nil {
t.Fatalf("SaveKDBX(second cycle) error = %v", err)
}
reopened, err = LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
if err != nil {
t.Fatalf("LoadKDBX(second cycle) error = %v", err)
}
got := reopened.EntriesInPath([]string{"Root", "Internet"})
if len(got) != 1 {
t.Fatalf("len(EntriesInPath(Root/Internet)) = %d, want 1", len(got))
}
if got[0].ID != "entry-1" {
t.Fatalf("entry ID after reopen cycles = %q, want %q", got[0].ID, "entry-1")
}
if len(got[0].History) != 1 {
t.Fatalf("len(History) after reopen cycles = %d, want 1", len(got[0].History))
}
if got[0].History[0].ID != "entry-1-history-1" {
t.Fatalf("history ID after reopen cycles = %q, want %q", got[0].History[0].ID, "entry-1-history-1")
}
if string(got[0].Attachments["token.txt"]) != "secret attachment contents" {
t.Fatalf("attachment after reopen cycles = %q, want %q", string(got[0].Attachments["token.txt"]), "secret attachment contents")
}
if len(reopened.Templates) != 1 || reopened.Templates[0].Path[1] != "Web" {
t.Fatalf("Templates after reopen cycles = %#v, want Website Login in Templates/Web", reopened.Templates)
}
if len(reopened.RecycleBin) != 1 || reopened.RecycleBin[0].Path[1] != "Archive" {
t.Fatalf("RecycleBin after reopen cycles = %#v, want recycled entry in Root/Archive", reopened.RecycleBin)
}
rootGroups := reopened.ChildGroups([]string{"Root"})
if !slices.Equal(rootGroups, []string{"Archive", "Empty Group", "Internet"}) {
t.Fatalf("ChildGroups(Root) after reopen cycles = %v, want [Archive Empty Group Internet]", rootGroups)
}
templateGroups := reopened.ChildGroups([]string{"Templates"})
if !slices.Equal(templateGroups, []string{"Web"}) {
t.Fatalf("ChildGroups(Templates) after reopen cycles = %v, want [Web]", templateGroups)
}
}
func mustGroup(name string, children ...any) gokeepasslib.Group {
group := gokeepasslib.NewGroup()
group.Name = name
for _, child := range children {
switch value := child.(type) {
case gokeepasslib.Group:
group.Groups = append(group.Groups, value)
case gokeepasslib.Entry:
group.Entries = append(group.Entries, value)
default:
panic("unsupported child type")
}
}
return group
}
func mustEntry(title, username, url, password string) gokeepasslib.Entry {
entry := gokeepasslib.NewEntry()
entry.Values = append(entry.Values,
mkValue("Title", title),
mkValue("UserName", username),
mkValue("URL", url),
mkProtectedValue("Password", password),
)
return entry
}
func mkValue(key, value string) gokeepasslib.ValueData {
return gokeepasslib.ValueData{Key: key, Value: gokeepasslib.V{Content: value}}
}
func mkProtectedValue(key, value string) gokeepasslib.ValueData {
return gokeepasslib.ValueData{
Key: key,
Value: gokeepasslib.V{Content: value, Protected: w.NewBoolWrapper(true)},
}
}
+13
View File
@@ -0,0 +1,13 @@
package vault
// MasterKeyMode identifies which key material the user intends to provide.
type MasterKeyMode string
const (
// MasterKeyModePasswordOnly requires a master password and no key file.
MasterKeyModePasswordOnly MasterKeyMode = "password-only"
// MasterKeyModeKeyFileOnly requires a key file and no password.
MasterKeyModeKeyFileOnly MasterKeyMode = "key-file-only"
// MasterKeyModePasswordAndKeyFile requires both password and key file.
MasterKeyModePasswordAndKeyFile MasterKeyMode = "password-and-key-file"
)
+601
View File
@@ -0,0 +1,601 @@
package vault
import (
"errors"
"slices"
"strings"
)
var ErrEntryNotFound = errors.New("entry not found")
var ErrGroupNotEmpty = errors.New("group is not empty")
var ErrRemoteProfileNotFound = errors.New("remote profile not found")
type RemoteBackend string
const (
RemoteBackendWebDAV RemoteBackend = "webdav"
)
type RemoteProfile struct {
ID string
Name string
Backend RemoteBackend
BaseURL string
Path string
}
type Entry struct {
ID string
Title string
Username string
Password string
URL string
Notes string
Tags []string
Fields map[string]string
Attachments map[string][]byte
History []Entry
Path []string
}
type SearchResult struct {
Entry Entry
Path string
}
type Model struct {
Entries []Entry
Templates []Entry
RecycleBin []Entry
Groups [][]string
RemoteProfiles []RemoteProfile
}
func (m Model) ChildGroups(path []string) []string {
seen := map[string]bool{}
var groups []string
for _, entry := range m.Entries {
if len(path) > len(entry.Path) {
continue
}
if !slices.Equal(entry.Path[:len(path)], path) {
continue
}
if len(entry.Path) == len(path) {
continue
}
group := entry.Path[len(path)]
if seen[group] {
continue
}
seen[group] = true
groups = append(groups, group)
}
for _, groupPath := range m.Groups {
if len(path) > len(groupPath) {
continue
}
if !slices.Equal(groupPath[:len(path)], path) {
continue
}
if len(groupPath) == len(path) {
continue
}
group := groupPath[len(path)]
if seen[group] {
continue
}
seen[group] = true
groups = append(groups, group)
}
slices.Sort(groups)
return groups
}
func (m Model) EntriesInPath(path []string) []Entry {
var entries []Entry
for _, entry := range m.Entries {
if slices.Equal(entry.Path, path) {
entries = append(entries, entry)
}
}
slices.SortFunc(entries, func(a, b Entry) int {
switch {
case a.Title < b.Title:
return -1
case a.Title > b.Title:
return 1
default:
return 0
}
})
return entries
}
func (m Model) EntriesUnderPath(path []string) []Entry {
var entries []Entry
for _, entry := range m.Entries {
if !hasPathPrefix(entry.Path, path) {
continue
}
entries = append(entries, entry)
}
slices.SortFunc(entries, func(a, b Entry) int {
switch {
case a.Title < b.Title:
return -1
case a.Title > b.Title:
return 1
default:
return 0
}
})
return entries
}
func (m Model) Search(query string) []SearchResult {
query = strings.TrimSpace(strings.ToLower(query))
if query == "" {
return nil
}
var results []SearchResult
for _, entry := range m.Entries {
haystack := strings.ToLower(
entry.Title + " " +
entry.Username + " " +
entry.URL + " " +
strings.Join(entry.Path, " "),
)
if !strings.Contains(haystack, query) {
continue
}
results = append(results, SearchResult{
Entry: entry,
Path: strings.Join(entry.Path, " / "),
})
}
slices.SortFunc(results, func(a, b SearchResult) int {
switch {
case a.Entry.Title < b.Entry.Title:
return -1
case a.Entry.Title > b.Entry.Title:
return 1
default:
return 0
}
})
return results
}
func (m *Model) UpsertEntry(entry Entry) {
for i := range m.Entries {
if m.Entries[i].ID != entry.ID {
continue
}
previous := cloneEntry(m.Entries[i])
entry.History = append([]Entry{previous}, cloneHistory(m.Entries[i].History)...)
m.Entries[i] = cloneEntry(entry)
return
}
m.Entries = append(m.Entries, cloneEntry(entry))
}
func (m *Model) RemoveEntryByID(id string) bool {
for i := range m.Entries {
if m.Entries[i].ID != id {
continue
}
m.Entries = append(m.Entries[:i], m.Entries[i+1:]...)
return true
}
return false
}
func (m *Model) EntryByID(id string) (Entry, error) {
for _, entry := range m.Entries {
if entry.ID == id {
return cloneEntry(entry), nil
}
}
return Entry{}, ErrEntryNotFound
}
func (m *Model) UpsertRemoteProfile(profile RemoteProfile) {
for i := range m.RemoteProfiles {
if m.RemoteProfiles[i].ID != profile.ID {
continue
}
m.RemoteProfiles[i] = profile
return
}
m.RemoteProfiles = append(m.RemoteProfiles, profile)
}
func (m *Model) RemoveRemoteProfileByID(id string) bool {
for i := range m.RemoteProfiles {
if m.RemoteProfiles[i].ID != id {
continue
}
m.RemoteProfiles = append(m.RemoteProfiles[:i], m.RemoteProfiles[i+1:]...)
return true
}
return false
}
func (m Model) RemoteProfileByID(id string) (RemoteProfile, error) {
for _, profile := range m.RemoteProfiles {
if profile.ID == id {
return profile, nil
}
}
return RemoteProfile{}, ErrRemoteProfileNotFound
}
func (m *Model) UpsertTemplate(entry Entry) {
for i := range m.Templates {
if m.Templates[i].ID != entry.ID {
continue
}
m.Templates[i] = cloneEntry(entry)
return
}
m.Templates = append(m.Templates, cloneEntry(entry))
}
func (m *Model) DeleteTemplate(id string) error {
for i := range m.Templates {
if m.Templates[i].ID != id {
continue
}
m.Templates = append(m.Templates[:i], m.Templates[i+1:]...)
return nil
}
return ErrEntryNotFound
}
func (m *Model) DeleteEntry(id string) error {
for i := range m.Entries {
if m.Entries[i].ID != id {
continue
}
m.RecycleBin = append(m.RecycleBin, cloneEntry(m.Entries[i]))
m.Entries = append(m.Entries[:i], m.Entries[i+1:]...)
return nil
}
return ErrEntryNotFound
}
func (m *Model) RestoreEntry(id string) error {
for i := range m.RecycleBin {
if m.RecycleBin[i].ID != id {
continue
}
m.Entries = append(m.Entries, cloneEntry(m.RecycleBin[i]))
m.RecycleBin = append(m.RecycleBin[:i], m.RecycleBin[i+1:]...)
return nil
}
return ErrEntryNotFound
}
func (m *Model) InstantiateTemplate(templateID string, overrides Entry) (Entry, error) {
for i := range m.Templates {
if m.Templates[i].ID != templateID {
continue
}
entry := mergeEntryTemplate(m.Templates[i], overrides)
m.UpsertEntry(entry)
return cloneEntry(entry), nil
}
return Entry{}, ErrEntryNotFound
}
func (m *Model) DuplicateEntry(id, duplicateID string) (Entry, error) {
for i := range m.Entries {
if m.Entries[i].ID != id {
continue
}
duplicate := cloneEntry(m.Entries[i])
duplicate.ID = duplicateID
duplicate.Title = duplicate.Title + " (Copy)"
duplicate.History = nil
m.Entries = append(m.Entries, duplicate)
return cloneEntry(duplicate), nil
}
return Entry{}, ErrEntryNotFound
}
func (m *Model) RestoreEntryVersion(id string, historyIndex int) error {
for i := range m.Entries {
if m.Entries[i].ID != id {
continue
}
if historyIndex < 0 || historyIndex >= len(m.Entries[i].History) {
return ErrEntryNotFound
}
current := cloneEntry(m.Entries[i])
restored := cloneEntry(m.Entries[i].History[historyIndex])
restored.ID = current.ID
restored.History = append([]Entry{current}, append(
cloneHistory(m.Entries[i].History[:historyIndex]),
cloneHistory(m.Entries[i].History[historyIndex+1:])...,
)...)
m.Entries[i] = restored
return nil
}
return ErrEntryNotFound
}
func (m *Model) CreateGroup(parent []string, name string) {
groupPath := append([]string(nil), parent...)
for _, part := range splitGroupPath(name) {
groupPath = append(groupPath, part)
if groupPathExists(m.Groups, groupPath) {
continue
}
m.Groups = append(m.Groups, append([]string(nil), groupPath...))
}
}
func (m *Model) RenameGroup(path []string, newName string) error {
if len(path) == 0 {
return ErrEntryNotFound
}
renamed := false
newPath := append(append([]string(nil), path[:len(path)-1]...), newName)
for i := range m.Entries {
if !hasPathPrefix(m.Entries[i].Path, path) {
continue
}
m.Entries[i].Path = append(append([]string(nil), newPath...), m.Entries[i].Path[len(path):]...)
renamed = true
}
for i := range m.Templates {
if !hasPathPrefix(m.Templates[i].Path, path) {
continue
}
m.Templates[i].Path = append(append([]string(nil), newPath...), m.Templates[i].Path[len(path):]...)
renamed = true
}
for i := range m.Groups {
if !hasPathPrefix(m.Groups[i], path) {
continue
}
m.Groups[i] = append(append([]string(nil), newPath...), m.Groups[i][len(path):]...)
renamed = true
}
if !renamed {
return ErrEntryNotFound
}
return nil
}
func (m *Model) MoveEntry(id string, path []string) error {
for i := range m.Entries {
if m.Entries[i].ID != id {
continue
}
m.Entries[i].Path = append([]string(nil), path...)
return nil
}
return ErrEntryNotFound
}
func (m *Model) MoveGroup(path, parent []string) error {
if len(path) == 0 {
return ErrEntryNotFound
}
if hasPathPrefix(parent, path) {
return ErrEntryNotFound
}
groupName := path[len(path)-1]
newPath := append(append([]string(nil), parent...), groupName)
moved := false
for i := range m.Entries {
if !hasPathPrefix(m.Entries[i].Path, path) {
continue
}
m.Entries[i].Path = append(append([]string(nil), newPath...), m.Entries[i].Path[len(path):]...)
moved = true
}
for i := range m.Templates {
if !hasPathPrefix(m.Templates[i].Path, path) {
continue
}
m.Templates[i].Path = append(append([]string(nil), newPath...), m.Templates[i].Path[len(path):]...)
moved = true
}
for i := range m.Groups {
if !hasPathPrefix(m.Groups[i], path) {
continue
}
m.Groups[i] = append(append([]string(nil), newPath...), m.Groups[i][len(path):]...)
moved = true
}
if !moved {
return ErrEntryNotFound
}
if !groupPathExists(m.Groups, newPath) {
m.Groups = append(m.Groups, append([]string(nil), newPath...))
}
return nil
}
func (m *Model) MoveTemplate(id string, path []string) error {
for i := range m.Templates {
if m.Templates[i].ID != id {
continue
}
m.Templates[i].Path = append([]string(nil), path...)
return nil
}
return ErrEntryNotFound
}
func (m *Model) DeleteGroup(path []string) error {
for _, entry := range m.Entries {
if slices.Equal(entry.Path, path) || hasPathPrefix(entry.Path, path) {
return ErrGroupNotEmpty
}
}
for _, entry := range m.Templates {
if slices.Equal(entry.Path, path) || hasPathPrefix(entry.Path, path) {
return ErrGroupNotEmpty
}
}
for i := range m.Groups {
if slices.Equal(m.Groups[i], path) {
m.Groups = append(m.Groups[:i], m.Groups[i+1:]...)
return nil
}
}
return ErrEntryNotFound
}
func hasPathPrefix(path, prefix []string) bool {
if len(prefix) > len(path) {
return false
}
return slices.Equal(path[:len(prefix)], prefix)
}
func splitGroupPath(name string) []string {
var parts []string
for _, part := range strings.Split(name, "/") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
parts = append(parts, part)
}
return parts
}
func groupPathExists(groups [][]string, path []string) bool {
for _, existing := range groups {
if slices.Equal(existing, path) {
return true
}
}
return false
}
func mergeEntryTemplate(template, overrides Entry) Entry {
entry := cloneEntry(template)
if overrides.ID != "" {
entry.ID = overrides.ID
}
if overrides.Title != "" {
entry.Title = overrides.Title
}
if overrides.Username != "" {
entry.Username = overrides.Username
}
if overrides.Password != "" {
entry.Password = overrides.Password
}
if overrides.URL != "" {
entry.URL = overrides.URL
}
if overrides.Notes != "" {
entry.Notes = overrides.Notes
}
if len(overrides.Tags) > 0 {
entry.Tags = slices.Clone(overrides.Tags)
}
if len(overrides.Path) > 0 {
entry.Path = slices.Clone(overrides.Path)
}
entry.Fields = mergeStringMaps(template.Fields, overrides.Fields)
entry.Attachments = mergeBinaryMaps(template.Attachments, overrides.Attachments)
entry.History = nil
return entry
}
func mergeStringMaps(base, overrides map[string]string) map[string]string {
if len(base) == 0 && len(overrides) == 0 {
return nil
}
out := make(map[string]string, len(base)+len(overrides))
for key, value := range base {
out[key] = value
}
for key, value := range overrides {
out[key] = value
}
return out
}
func mergeBinaryMaps(base, overrides map[string][]byte) map[string][]byte {
if len(base) == 0 && len(overrides) == 0 {
return nil
}
out := make(map[string][]byte, len(base)+len(overrides))
for key, value := range base {
out[key] = slices.Clone(value)
}
for key, value := range overrides {
out[key] = slices.Clone(value)
}
return out
}
func cloneEntry(entry Entry) Entry {
entry.Tags = slices.Clone(entry.Tags)
entry.Path = slices.Clone(entry.Path)
entry.History = cloneHistory(entry.History)
if entry.Fields != nil {
fields := make(map[string]string, len(entry.Fields))
for key, value := range entry.Fields {
fields[key] = value
}
entry.Fields = fields
}
if entry.Attachments != nil {
attachments := make(map[string][]byte, len(entry.Attachments))
for key, value := range entry.Attachments {
attachments[key] = slices.Clone(value)
}
entry.Attachments = attachments
}
return entry
}
func cloneHistory(history []Entry) []Entry {
if len(history) == 0 {
return nil
}
out := make([]Entry, len(history))
for i := range history {
out[i] = cloneEntry(history[i])
}
return out
}
+343
View File
@@ -0,0 +1,343 @@
package vault
import (
"errors"
"slices"
"testing"
)
func testModel() Model {
return Model{
Entries: []Entry{
{ID: "1", Title: "Bellagio", Username: "rustyryan", URL: "https://bellagio.example.invalid", Path: []string{"Crew", "Internet"}},
{ID: "2", Title: "Vault Console", Username: "dannyocean", URL: "https://vault.crew.example.invalid", Path: []string{"Crew", "Internet"}},
{ID: "3", Title: "Surveillance Console", Username: "codex", URL: "https://surveillance.crew.example.invalid", Path: []string{"Crew", "Home Assistant"}},
{ID: "4", Title: "Alma (WA Prep)", Username: "christina.julian", URL: "https://waprep.getalma.com", Path: []string{"Tricia", "School"}},
},
}
}
func TestChildGroupsReturnsImmediateGroupsOnly(t *testing.T) {
model := testModel()
got := model.ChildGroups([]string{"Crew"})
want := []string{"Home Assistant", "Internet"}
if !slices.Equal(got, want) {
t.Fatalf("ChildGroups() = %v, want %v", got, want)
}
}
func TestEntriesInPathReturnsOnlyDirectEntries(t *testing.T) {
model := testModel()
got := model.EntriesInPath([]string{"Crew", "Internet"})
if len(got) != 2 {
t.Fatalf("len(EntriesInPath()) = %d, want 2", len(got))
}
if got[0].Title != "Bellagio" || got[1].Title != "Vault Console" {
t.Fatalf("EntriesInPath() titles = %q, %q", got[0].Title, got[1].Title)
}
}
func TestEntriesUnderPathReturnsDescendantEntries(t *testing.T) {
t.Parallel()
model := testModel()
got := model.EntriesUnderPath([]string{"Crew"})
if len(got) != 3 {
t.Fatalf("len(EntriesUnderPath(Crew)) = %d, want 3", len(got))
}
if got[0].Title != "Bellagio" || got[1].Title != "Surveillance Console" || got[2].Title != "Vault Console" {
t.Fatalf("EntriesUnderPath(Crew) titles = %q, %q, %q", got[0].Title, got[1].Title, got[2].Title)
}
}
func TestSearchReturnsMatchesWithFullPathContext(t *testing.T) {
model := testModel()
got := model.Search("vault")
if len(got) != 1 {
t.Fatalf("len(Search()) = %d, want 1", len(got))
}
if got[0].Entry.Title != "Vault Console" {
t.Fatalf("Search() title = %q, want %q", got[0].Entry.Title, "Vault Console")
}
if got[0].Path != "Crew / Internet" {
t.Fatalf("Search() path = %q, want %q", got[0].Path, "Crew / Internet")
}
}
func TestTemplateEntriesAreStoredSeparatelyFromNormalEntries(t *testing.T) {
model := testModel()
model.UpsertTemplate(Entry{
ID: "tpl-1",
Title: "Website Login",
Username: "template-user",
Password: "template-password",
URL: "https://example.com",
Notes: "Reusable template for website accounts.",
Tags: []string{"template", "web"},
Path: []string{"Templates"},
})
if len(model.Entries) != 4 {
t.Fatalf("len(Entries) = %d, want 4", len(model.Entries))
}
if len(model.Templates) != 1 {
t.Fatalf("len(Templates) = %d, want 1", len(model.Templates))
}
if got := model.Templates[0].Title; got != "Website Login" {
t.Fatalf("Templates[0].Title = %q, want %q", got, "Website Login")
}
}
func TestInstantiateTemplateCreatesNormalEntryWithOverrides(t *testing.T) {
model := Model{
Templates: []Entry{
{
ID: "tpl-1",
Title: "Website Login",
Username: "template-user",
Password: "template-password",
URL: "https://example.com",
Notes: "Reusable template for website accounts.",
Tags: []string{"template", "web"},
Fields: map[string]string{
"Environment": "prod",
},
Path: []string{"Templates"},
},
},
}
entry, err := model.InstantiateTemplate("tpl-1", Entry{
ID: "entry-1",
Title: "Bellagio",
Username: "rustyryan",
Password: "hunter2",
URL: "https://bellagio.example.invalid",
Path: []string{"Crew", "Internet"},
Tags: []string{"dns"},
})
if err != nil {
t.Fatalf("InstantiateTemplate() error = %v", err)
}
if entry.ID != "entry-1" {
t.Fatalf("entry.ID = %q, want %q", entry.ID, "entry-1")
}
if entry.Title != "Bellagio" {
t.Fatalf("entry.Title = %q, want %q", entry.Title, "Bellagio")
}
if entry.Username != "rustyryan" || entry.Password != "hunter2" || entry.URL != "https://bellagio.example.invalid" {
t.Fatalf("entry credentials = %#v, want override values", entry)
}
if entry.Notes != "Reusable template for website accounts." {
t.Fatalf("entry.Notes = %q, want %q", entry.Notes, "Reusable template for website accounts.")
}
if !slices.Equal(entry.Tags, []string{"dns"}) {
t.Fatalf("entry.Tags = %v, want [dns]", entry.Tags)
}
if entry.Fields["Environment"] != "prod" {
t.Fatalf("entry.Fields[Environment] = %q, want %q", entry.Fields["Environment"], "prod")
}
got := model.EntriesInPath([]string{"Crew", "Internet"})
if len(got) != 1 || got[0].Title != "Bellagio" {
t.Fatalf("EntriesInPath() = %#v, want instantiated Bellagio entry", got)
}
}
func TestInstantiateTemplateFailsForUnknownTemplate(t *testing.T) {
model := Model{}
_, err := model.InstantiateTemplate("missing-template", Entry{ID: "entry-1"})
if err == nil {
t.Fatal("InstantiateTemplate() error = nil, want ErrEntryNotFound")
}
if !errors.Is(err, ErrEntryNotFound) {
t.Fatalf("InstantiateTemplate() error = %v, want ErrEntryNotFound", err)
}
}
func TestDeleteTemplateRemovesTemplateWithoutTouchingEntries(t *testing.T) {
t.Parallel()
model := Model{
Entries: []Entry{
{ID: "entry-1", Title: "Vault Console", Path: []string{"Root", "Internet"}},
},
Templates: []Entry{
{ID: "tpl-1", Title: "Website Login", Path: []string{"Templates"}},
},
}
if err := model.DeleteTemplate("tpl-1"); err != nil {
t.Fatalf("DeleteTemplate() error = %v", err)
}
if len(model.Templates) != 0 {
t.Fatalf("len(Templates) = %d, want 0", len(model.Templates))
}
if len(model.Entries) != 1 || model.Entries[0].ID != "entry-1" {
t.Fatalf("Entries = %#v, want unchanged normal entry", model.Entries)
}
}
func TestMoveTemplateChangesItsPath(t *testing.T) {
t.Parallel()
model := Model{
Templates: []Entry{
{ID: "tpl-1", Title: "Website Login", Path: []string{"Templates", "Web"}},
},
}
if err := model.MoveTemplate("tpl-1", []string{"Templates", "Infra"}); err != nil {
t.Fatalf("MoveTemplate() error = %v", err)
}
if got := model.Templates[0].Path; !slices.Equal(got, []string{"Templates", "Infra"}) {
t.Fatalf("Templates[0].Path = %v, want [Templates Infra]", got)
}
}
func TestDuplicateEntryCopiesEntryWithNewIDAndTitle(t *testing.T) {
t.Parallel()
model := Model{
Entries: []Entry{
{
ID: "entry-1",
Title: "Vault Console",
Username: "dannyocean",
Password: "token-1",
Path: []string{"Root", "Internet"},
},
},
}
duplicate, err := model.DuplicateEntry("entry-1", "entry-2")
if err != nil {
t.Fatalf("DuplicateEntry() error = %v", err)
}
if duplicate.ID != "entry-2" {
t.Fatalf("duplicate.ID = %q, want %q", duplicate.ID, "entry-2")
}
if duplicate.Title != "Vault Console (Copy)" {
t.Fatalf("duplicate.Title = %q, want %q", duplicate.Title, "Vault Console (Copy)")
}
got := model.EntriesInPath([]string{"Root", "Internet"})
if len(got) != 2 {
t.Fatalf("len(EntriesInPath()) = %d, want 2", len(got))
}
}
func TestCreateGroupMakesItVisibleAsChildGroup(t *testing.T) {
model := testModel()
model.CreateGroup([]string{"Crew"}, "Finance")
got := model.ChildGroups([]string{"Crew"})
want := []string{"Finance", "Home Assistant", "Internet"}
if !slices.Equal(got, want) {
t.Fatalf("ChildGroups() = %v, want %v", got, want)
}
}
func TestCreateGroupSupportsNestedRelativePath(t *testing.T) {
t.Parallel()
model := testModel()
model.CreateGroup([]string{"Crew"}, "Infrastructure / Prod")
got := model.ChildGroups([]string{"Crew"})
if !slices.Equal(got, []string{"Home Assistant", "Infrastructure", "Internet"}) {
t.Fatalf("ChildGroups(Crew) = %v, want [Home Assistant Infrastructure Internet]", got)
}
got = model.ChildGroups([]string{"Crew", "Infrastructure"})
if !slices.Equal(got, []string{"Prod"}) {
t.Fatalf("ChildGroups(Crew/Infrastructure) = %v, want [Prod]", got)
}
}
func TestRenameGroupMovesEntriesAndKeepsHierarchy(t *testing.T) {
model := testModel()
if err := model.RenameGroup([]string{"Crew", "Internet"}, "Infra"); err != nil {
t.Fatalf("RenameGroup() error = %v", err)
}
got := model.EntriesInPath([]string{"Crew", "Infra"})
if len(got) != 2 {
t.Fatalf("len(EntriesInPath(Crew/Infra)) = %d, want 2", len(got))
}
if len(model.EntriesInPath([]string{"Crew", "Internet"})) != 0 {
t.Fatal("EntriesInPath(Crew/Internet) should be empty after rename")
}
}
func TestMoveEntryChangesItsPath(t *testing.T) {
model := testModel()
if err := model.MoveEntry("1", []string{"Tricia", "School"}); err != nil {
t.Fatalf("MoveEntry() error = %v", err)
}
got := model.EntriesInPath([]string{"Tricia", "School"})
if len(got) != 2 {
t.Fatalf("len(EntriesInPath(Tricia/School)) = %d, want 2", len(got))
}
}
func TestMoveGroupMovesEntriesAndNestedGroups(t *testing.T) {
t.Parallel()
model := testModel()
model.CreateGroup([]string{"Crew", "Internet"}, "Infrastructure")
if err := model.MoveGroup([]string{"Crew", "Internet"}, []string{"Tricia"}); err != nil {
t.Fatalf("MoveGroup() error = %v", err)
}
got := model.EntriesInPath([]string{"Tricia", "Internet"})
if len(got) != 2 {
t.Fatalf("len(EntriesInPath(Tricia/Internet)) = %d, want 2", len(got))
}
if len(model.EntriesInPath([]string{"Crew", "Internet"})) != 0 {
t.Fatal("EntriesInPath(Crew/Internet) should be empty after move")
}
gotGroups := model.ChildGroups([]string{"Tricia", "Internet"})
if !slices.Equal(gotGroups, []string{"Infrastructure"}) {
t.Fatalf("ChildGroups(Tricia/Internet) = %v, want [Infrastructure]", gotGroups)
}
}
func TestDeleteEmptyGroupRemovesItFromNavigation(t *testing.T) {
model := testModel()
model.CreateGroup([]string{"Crew"}, "Finance")
if err := model.DeleteGroup([]string{"Crew", "Finance"}); err != nil {
t.Fatalf("DeleteGroup() error = %v", err)
}
got := model.ChildGroups([]string{"Crew"})
want := []string{"Home Assistant", "Internet"}
if !slices.Equal(got, want) {
t.Fatalf("ChildGroups() = %v, want %v", got, want)
}
}
+100
View File
@@ -0,0 +1,100 @@
package vault
import (
"fmt"
"slices"
"github.com/tobischo/gokeepasslib/v3"
)
type SecuritySettings struct {
Cipher string
KDF string
}
const (
CipherAES256 = "aes256"
CipherChaCha20 = "chacha20"
KDFAES = "aes-kdf"
KDFArgon2 = "argon2"
)
func SupportedSecuritySettings() (ciphers []string, kdfs []string) {
return []string{CipherAES256, CipherChaCha20}, []string{KDFAES, KDFArgon2}
}
func DetectSecuritySettings(config *KDBXConfig) SecuritySettings {
settings := SecuritySettings{
Cipher: CipherChaCha20,
KDF: KDFArgon2,
}
if config == nil || config.Header == nil || config.Header.FileHeaders == nil {
return settings
}
if slices.Equal(config.Header.FileHeaders.CipherID, gokeepasslib.CipherAES) {
settings.Cipher = CipherAES256
}
if config.Header.FileHeaders.KdfParameters != nil && slices.Equal(config.Header.FileHeaders.KdfParameters.UUID, gokeepasslib.KdfAES4) {
settings.KDF = KDFAES
}
return settings
}
func NewSecurityConfig(settings SecuritySettings) (*KDBXConfig, error) {
db := gokeepasslib.NewDatabase(gokeepasslib.WithDatabaseKDBXVersion4())
config := &KDBXConfig{
Header: cloneHeader(db.Header),
InnerHeader: cloneInnerHeader(db.Content.InnerHeader),
}
return ApplySecuritySettings(config, settings)
}
func ApplySecuritySettings(config *KDBXConfig, settings SecuritySettings) (*KDBXConfig, error) {
if config == nil || config.Header == nil || config.Header.FileHeaders == nil {
return NewSecurityConfig(settings)
}
out := &KDBXConfig{
Header: cloneHeader(config.Header),
InnerHeader: cloneInnerHeader(config.InnerHeader),
}
if out.Header.FileHeaders.KdfParameters == nil {
defaults := gokeepasslib.NewDatabase(gokeepasslib.WithDatabaseKDBXVersion4())
out.Header.FileHeaders.KdfParameters = cloneHeader(defaults.Header).FileHeaders.KdfParameters
}
switch settings.Cipher {
case "", CipherChaCha20:
out.Header.FileHeaders.CipherID = slices.Clone(gokeepasslib.CipherChaCha20)
out.Header.FileHeaders.EncryptionIV = randomBytes(12)
case CipherAES256:
out.Header.FileHeaders.CipherID = slices.Clone(gokeepasslib.CipherAES)
out.Header.FileHeaders.EncryptionIV = randomBytes(16)
default:
return nil, fmt.Errorf("unsupported cipher %q", settings.Cipher)
}
var salt [32]byte
copy(salt[:], randomBytes(32))
switch settings.KDF {
case "", KDFArgon2:
defaults := gokeepasslib.NewDatabase(gokeepasslib.WithDatabaseKDBXVersion4())
kdf := defaults.Header.FileHeaders.KdfParameters
out.Header.FileHeaders.KdfParameters = &gokeepasslib.KdfParameters{
UUID: slices.Clone(gokeepasslib.KdfArgon2),
Rounds: kdf.Rounds,
Salt: salt,
Parallelism: kdf.Parallelism,
Memory: kdf.Memory,
Iterations: kdf.Iterations,
Version: kdf.Version,
}
case KDFAES:
out.Header.FileHeaders.KdfParameters = &gokeepasslib.KdfParameters{
UUID: slices.Clone(gokeepasslib.KdfAES4),
Rounds: 6000,
Salt: salt,
}
default:
return nil, fmt.Errorf("unsupported KDF %q", settings.KDF)
}
return out, nil
}
+50
View File
@@ -0,0 +1,50 @@
package vault
import (
"bytes"
"slices"
"testing"
"github.com/tobischo/gokeepasslib/v3"
)
func TestNewSecurityConfigCreatesRequestedCipherAndKDF(t *testing.T) {
t.Parallel()
config, err := NewSecurityConfig(SecuritySettings{Cipher: CipherAES256, KDF: KDFAES})
if err != nil {
t.Fatalf("NewSecurityConfig() error = %v", err)
}
if !slices.Equal(config.Header.FileHeaders.CipherID, gokeepasslib.CipherAES) {
t.Fatalf("CipherID = %x, want %x", config.Header.FileHeaders.CipherID, gokeepasslib.CipherAES)
}
if !slices.Equal(config.Header.FileHeaders.KdfParameters.UUID, gokeepasslib.KdfAES4) {
t.Fatalf("KDF UUID = %x, want %x", config.Header.FileHeaders.KdfParameters.UUID, gokeepasslib.KdfAES4)
}
}
func TestApplySecuritySettingsPreservesRequestedChoicesAcrossSave(t *testing.T) {
t.Parallel()
config, err := NewSecurityConfig(SecuritySettings{Cipher: CipherChaCha20, KDF: KDFArgon2})
if err != nil {
t.Fatalf("NewSecurityConfig() error = %v", err)
}
config, err = ApplySecuritySettings(config, SecuritySettings{Cipher: CipherAES256, KDF: KDFAES})
if err != nil {
t.Fatalf("ApplySecuritySettings() error = %v", err)
}
var encoded bytes.Buffer
if err := SaveKDBXWithConfigAndKey(&encoded, Model{}, MasterKey{Password: "correct horse battery staple"}, config); err != nil {
t.Fatalf("SaveKDBXWithConfigAndKey() error = %v", err)
}
_, reloadedConfig, err := LoadKDBXWithConfig(bytes.NewReader(encoded.Bytes()), MasterKey{Password: "correct horse battery staple"})
if err != nil {
t.Fatalf("LoadKDBXWithConfig() error = %v", err)
}
got := DetectSecuritySettings(reloadedConfig)
if got.Cipher != CipherAES256 || got.KDF != KDFAES {
t.Fatalf("DetectSecuritySettings() = %#v, want aes256/aes-kdf", got)
}
}
+93
View File
@@ -0,0 +1,93 @@
package webdav
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
)
var ErrConflict = errors.New("webdav conflict")
type Version struct {
ETag string
}
type Client struct {
HTTPClient *http.Client
BaseURL string
Username string
Password string
}
func (c Client) Open(path string) ([]byte, Version, error) {
req, err := http.NewRequest(http.MethodGet, c.url(path), nil)
if err != nil {
return nil, Version{}, fmt.Errorf("build GET request: %w", err)
}
c.applyAuth(req)
resp, err := c.httpClient().Do(req)
if err != nil {
return nil, Version{}, fmt.Errorf("GET %s: %w", path, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, Version{}, fmt.Errorf("GET %s: unexpected status %d", path, resp.StatusCode)
}
content, err := io.ReadAll(resp.Body)
if err != nil {
return nil, Version{}, fmt.Errorf("read %s: %w", path, err)
}
return content, Version{ETag: resp.Header.Get("ETag")}, nil
}
func (c Client) Save(path string, content io.Reader, version Version) (Version, error) {
req, err := http.NewRequest(http.MethodPut, c.url(path), content)
if err != nil {
return Version{}, fmt.Errorf("build PUT request: %w", err)
}
c.applyAuth(req)
if version.ETag != "" {
req.Header.Set("If-Match", version.ETag)
}
resp, err := c.httpClient().Do(req)
if err != nil {
return Version{}, fmt.Errorf("PUT %s: %w", path, err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusPreconditionFailed {
return Version{}, ErrConflict
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusNoContent {
return Version{}, fmt.Errorf("PUT %s: unexpected status %d", path, resp.StatusCode)
}
return Version{ETag: resp.Header.Get("ETag")}, nil
}
func (c Client) httpClient() *http.Client {
if c.HTTPClient != nil {
return c.HTTPClient
}
return http.DefaultClient
}
func (c Client) applyAuth(req *http.Request) {
if c.Username == "" && c.Password == "" {
return
}
req.SetBasicAuth(c.Username, c.Password)
}
func (c Client) url(path string) string {
base := strings.TrimRight(c.BaseURL, "/")
path = strings.TrimLeft(path, "/")
return base + "/" + path
}
+105
View File
@@ -0,0 +1,105 @@
package webdav
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
)
func TestClientOpenDownloadsRemoteVault(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Fatalf("method = %s, want GET", r.Method)
}
if user, pass, ok := r.BasicAuth(); !ok || user != "rustyryan" || pass != "secret" {
t.Fatalf("basic auth = %q/%q ok=%v, want rustyryan/secret true", user, pass, ok)
}
w.Header().Set("ETag", `"etag-1"`)
_, _ = io.WriteString(w, "vault-bytes")
}))
defer server.Close()
client := Client{
HTTPClient: server.Client(),
BaseURL: server.URL,
Username: "rustyryan",
Password: "secret",
}
content, version, err := client.Open("keepass.kdbx")
if err != nil {
t.Fatalf("Open() error = %v", err)
}
if string(content) != "vault-bytes" {
t.Fatalf("Open() content = %q, want %q", string(content), "vault-bytes")
}
if version.ETag != `"etag-1"` {
t.Fatalf("Open() ETag = %q, want %q", version.ETag, `"etag-1"`)
}
}
func TestClientSaveUploadsVaultWithIfMatch(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
t.Fatalf("method = %s, want PUT", r.Method)
}
if got := r.Header.Get("If-Match"); got != `"etag-1"` {
t.Fatalf("If-Match = %q, want %q", got, `"etag-1"`)
}
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("ReadAll() error = %v", err)
}
if string(body) != "updated-vault" {
t.Fatalf("PUT body = %q, want %q", string(body), "updated-vault")
}
w.Header().Set("ETag", `"etag-2"`)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := Client{
HTTPClient: server.Client(),
BaseURL: server.URL,
Username: "rustyryan",
Password: "secret",
}
version, err := client.Save("keepass.kdbx", bytes.NewBufferString("updated-vault"), Version{ETag: `"etag-1"`})
if err != nil {
t.Fatalf("Save() error = %v", err)
}
if version.ETag != `"etag-2"` {
t.Fatalf("Save() ETag = %q, want %q", version.ETag, `"etag-2"`)
}
}
func TestClientSaveReturnsConflictOnVersionMismatch(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusPreconditionFailed)
}))
defer server.Close()
client := Client{
HTTPClient: server.Client(),
BaseURL: server.URL,
Username: "rustyryan",
Password: "secret",
}
_, err := client.Save("keepass.kdbx", bytes.NewBufferString("updated-vault"), Version{ETag: `"etag-1"`})
if err != ErrConflict {
t.Fatalf("Save() error = %v, want %v", err, ErrConflict)
}
}