Move app packages under internal
@@ -0,0 +1,122 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/clipboard"
|
||||
"git.julianfamily.org/keepassgo/internal/passwords"
|
||||
"git.julianfamily.org/keepassgo/internal/session"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type DirtyProvider func() bool
|
||||
|
||||
type Host struct {
|
||||
server *Server
|
||||
grpcServer *grpc.Server
|
||||
listener net.Listener
|
||||
lifecycle lifecycleBackend
|
||||
dirty DirtyProvider
|
||||
mu sync.Mutex
|
||||
lastModel vault.Model
|
||||
started bool
|
||||
listenAddr string
|
||||
}
|
||||
|
||||
func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer, dirty DirtyProvider) (*Host, error) {
|
||||
addr = strings.TrimSpace(addr)
|
||||
if addr == "" || strings.EqualFold(addr, "off") {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen gRPC host %s: %w", addr, err)
|
||||
}
|
||||
|
||||
service := NewServerWithLifecycle(vault.Model{}, profiles, clipboardWriter, lifecycle)
|
||||
server := grpc.NewServer(grpc.UnaryInterceptor(AuthInterceptor(service)))
|
||||
keepassgov1.RegisterVaultServiceServer(server, service)
|
||||
|
||||
host := &Host{
|
||||
server: service,
|
||||
grpcServer: server,
|
||||
listener: listener,
|
||||
lifecycle: lifecycle,
|
||||
dirty: dirty,
|
||||
listenAddr: listener.Addr().String(),
|
||||
started: true,
|
||||
}
|
||||
if err := host.SyncFromLifecycle(); err != nil && !errors.Is(err, session.ErrLocked) {
|
||||
_ = listener.Close()
|
||||
server.Stop()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = server.Serve(listener)
|
||||
}()
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
func (h *Host) Address() string {
|
||||
if h == nil {
|
||||
return ""
|
||||
}
|
||||
return h.listenAddr
|
||||
}
|
||||
|
||||
func (h *Host) Server() *Server {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
return h.server
|
||||
}
|
||||
|
||||
func (h *Host) Stop() error {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if !h.started {
|
||||
return nil
|
||||
}
|
||||
h.started = false
|
||||
h.grpcServer.Stop()
|
||||
return h.listener.Close()
|
||||
}
|
||||
|
||||
func (h *Host) SyncFromLifecycle() error {
|
||||
if h == nil || h.lifecycle == nil || h.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
model, err := h.lifecycle.Current()
|
||||
locked := false
|
||||
switch {
|
||||
case err == nil:
|
||||
h.lastModel = model
|
||||
case errors.Is(err, session.ErrLocked):
|
||||
locked = true
|
||||
default:
|
||||
return err
|
||||
}
|
||||
|
||||
dirty := false
|
||||
if h.dirty != nil {
|
||||
dirty = h.dirty()
|
||||
}
|
||||
h.server.SetSessionState(h.lastModel, locked, dirty)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/passwords"
|
||||
"git.julianfamily.org/keepassgo/internal/session"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
func TestStartHostServesVaultLifecycleAndSyncsSessionState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lifecycle := &session.Manager{}
|
||||
if err := lifecycle.Create(vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
testAPITokenEntry(t),
|
||||
{ID: "entry-1", Title: "Vault Console", Path: []string{"Root", "Internet"}},
|
||||
},
|
||||
}, vault.MasterKey{Password: "correct horse battery staple"}); err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
host, err := StartHost("127.0.0.1:0", lifecycle, passwords.DefaultProfiles(), nil, func() bool { return true })
|
||||
if err != nil {
|
||||
t.Fatalf("StartHost() error = %v", err)
|
||||
}
|
||||
defer func() { _ = host.Stop() }()
|
||||
|
||||
conn, err := grpc.NewClient("passthrough:///"+host.Address(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
|
||||
return net.Dial("tcp", host.Address())
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.NewClient() error = %v", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
client := keepassgov1.NewVaultServiceClient(conn)
|
||||
statusResp, err := client.GetSessionStatus(tokenContext(defaultTestTokenSecret), &keepassgov1.GetSessionStatusRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetSessionStatus() error = %v", err)
|
||||
}
|
||||
if statusResp.Locked {
|
||||
t.Fatal("GetSessionStatus().Locked = true, want false")
|
||||
}
|
||||
if !statusResp.Dirty {
|
||||
t.Fatal("GetSessionStatus().Dirty = false, want true from dirty provider")
|
||||
}
|
||||
|
||||
if err := lifecycle.Lock(); err != nil {
|
||||
t.Fatalf("Lock() error = %v", err)
|
||||
}
|
||||
if err := host.SyncFromLifecycle(); err != nil {
|
||||
t.Fatalf("SyncFromLifecycle() after lock error = %v", err)
|
||||
}
|
||||
|
||||
statusResp, err = client.GetSessionStatus(tokenContext(defaultTestTokenSecret), &keepassgov1.GetSessionStatusRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetSessionStatus() after lock error = %v", err)
|
||||
}
|
||||
if !statusResp.Locked {
|
||||
t.Fatal("GetSessionStatus().Locked = false, want true after lifecycle lock")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,979 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/apiapproval"
|
||||
"git.julianfamily.org/keepassgo/internal/apiaudit"
|
||||
"git.julianfamily.org/keepassgo/internal/apitokens"
|
||||
"git.julianfamily.org/keepassgo/internal/clipboard"
|
||||
"git.julianfamily.org/keepassgo/internal/passwords"
|
||||
"git.julianfamily.org/keepassgo/internal/session"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
"git.julianfamily.org/keepassgo/internal/webdav"
|
||||
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
keepassgov1.UnimplementedVaultServiceServer
|
||||
|
||||
mu sync.RWMutex
|
||||
model vault.Model
|
||||
locked bool
|
||||
dirty bool
|
||||
lifecycle lifecycleBackend
|
||||
profiles map[string]passwords.Profile
|
||||
clipboard clipboard.Writer
|
||||
approvals *apiapproval.Broker
|
||||
audit *apiaudit.Log
|
||||
}
|
||||
|
||||
type lifecycleBackend interface {
|
||||
Current() (vault.Model, error)
|
||||
Open(string, vault.MasterKey) error
|
||||
OpenRemote(webdav.Client, string, vault.MasterKey) error
|
||||
Save() error
|
||||
Lock() error
|
||||
Unlock(vault.MasterKey) error
|
||||
}
|
||||
|
||||
type modelReplaceableLifecycle interface {
|
||||
lifecycleBackend
|
||||
Replace(vault.Model)
|
||||
}
|
||||
|
||||
func NewServer(model vault.Model, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer) *Server {
|
||||
return &Server{
|
||||
model: model,
|
||||
profiles: profiles,
|
||||
clipboard: clipboardWriter,
|
||||
approvals: apiapproval.NewBroker(30 * time.Second),
|
||||
audit: apiaudit.New(200),
|
||||
}
|
||||
}
|
||||
|
||||
func NewServerWithLifecycle(model vault.Model, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer, lifecycle lifecycleBackend) *Server {
|
||||
server := NewServer(model, profiles, clipboardWriter)
|
||||
server.lifecycle = lifecycle
|
||||
return server
|
||||
}
|
||||
|
||||
func (s *Server) ApprovalBroker() *apiapproval.Broker {
|
||||
return s.approvals
|
||||
}
|
||||
|
||||
func (s *Server) AuditLog() *apiaudit.Log {
|
||||
return s.audit
|
||||
}
|
||||
|
||||
func (s *Server) ResolveApproval(id string, outcome apiapproval.Outcome) (apiapproval.Request, *apitokens.PolicyRule, error) {
|
||||
return s.approvals.Resolve(id, outcome)
|
||||
}
|
||||
|
||||
func (s *Server) SetSessionState(model vault.Model, locked, dirty bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.model = model
|
||||
s.locked = locked
|
||||
s.dirty = dirty
|
||||
}
|
||||
|
||||
func (s *Server) GetSessionStatus(_ context.Context, _ *keepassgov1.GetSessionStatusRequest) (*keepassgov1.GetSessionStatusResponse, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return &keepassgov1.GetSessionStatusResponse{
|
||||
Locked: s.locked,
|
||||
Dirty: s.dirty,
|
||||
EntryCount: uint32(len(s.model.Entries)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) OpenVault(_ context.Context, req *keepassgov1.OpenVaultRequest) (*keepassgov1.OpenVaultResponse, error) {
|
||||
if s.lifecycle == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
|
||||
}
|
||||
|
||||
key := vault.MasterKey{Password: req.GetPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)}
|
||||
if err := s.lifecycle.Open(req.GetPath(), key); err != nil {
|
||||
return nil, mapLifecycleError("open vault", err)
|
||||
}
|
||||
|
||||
model, err := s.lifecycle.Current()
|
||||
if err != nil {
|
||||
return nil, mapLifecycleError("load opened vault", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.model = model
|
||||
s.locked = false
|
||||
s.dirty = false
|
||||
s.mu.Unlock()
|
||||
|
||||
return &keepassgov1.OpenVaultResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) OpenRemoteVault(_ context.Context, req *keepassgov1.OpenRemoteVaultRequest) (*keepassgov1.OpenRemoteVaultResponse, error) {
|
||||
if s.lifecycle == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
|
||||
}
|
||||
|
||||
client := webdav.Client{
|
||||
BaseURL: req.GetBaseUrl(),
|
||||
Username: req.GetUsername(),
|
||||
Password: req.GetPassword(),
|
||||
}
|
||||
key := vault.MasterKey{Password: req.GetMasterPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)}
|
||||
if err := s.lifecycle.OpenRemote(client, req.GetPath(), key); err != nil {
|
||||
return nil, mapLifecycleError("open remote vault", err)
|
||||
}
|
||||
|
||||
model, err := s.lifecycle.Current()
|
||||
if err != nil {
|
||||
return nil, mapLifecycleError("load opened remote vault", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.model = model
|
||||
s.locked = false
|
||||
s.dirty = false
|
||||
s.mu.Unlock()
|
||||
|
||||
return &keepassgov1.OpenRemoteVaultResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) SaveVault(_ context.Context, _ *keepassgov1.SaveVaultRequest) (*keepassgov1.SaveVaultResponse, error) {
|
||||
if s.lifecycle == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
|
||||
}
|
||||
|
||||
if err := s.lifecycle.Save(); err != nil {
|
||||
return nil, mapLifecycleError("save vault", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.dirty = false
|
||||
s.mu.Unlock()
|
||||
|
||||
return &keepassgov1.SaveVaultResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) LockVault(_ context.Context, _ *keepassgov1.LockVaultRequest) (*keepassgov1.LockVaultResponse, error) {
|
||||
if s.lifecycle == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
|
||||
}
|
||||
|
||||
if err := s.lifecycle.Lock(); err != nil {
|
||||
return nil, mapLifecycleError("lock vault", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.locked = true
|
||||
s.mu.Unlock()
|
||||
|
||||
return &keepassgov1.LockVaultResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) UnlockVault(_ context.Context, req *keepassgov1.UnlockVaultRequest) (*keepassgov1.UnlockVaultResponse, error) {
|
||||
if s.lifecycle == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault lifecycle backend is not configured")
|
||||
}
|
||||
|
||||
key := vault.MasterKey{Password: req.GetPassword(), KeyFileData: append([]byte(nil), req.GetKeyFileData()...)}
|
||||
if err := s.lifecycle.Unlock(key); err != nil {
|
||||
return nil, mapLifecycleError("unlock vault", err)
|
||||
}
|
||||
|
||||
model, err := s.lifecycle.Current()
|
||||
if err != nil {
|
||||
return nil, mapLifecycleError("load unlocked vault", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.model = model
|
||||
s.locked = false
|
||||
s.mu.Unlock()
|
||||
|
||||
return &keepassgov1.UnlockVaultResponse{}, nil
|
||||
}
|
||||
|
||||
func mapLifecycleError(operation string, err error) error {
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
return status.Errorf(codes.NotFound, "%s: %v", operation, err)
|
||||
case errors.Is(err, vault.ErrInvalidMasterKey):
|
||||
return status.Errorf(codes.InvalidArgument, "%s: %v", operation, err)
|
||||
case errors.Is(err, session.ErrLocked), errors.Is(err, session.ErrNoPath):
|
||||
return status.Errorf(codes.FailedPrecondition, "%s: %v", operation, err)
|
||||
case errors.Is(err, webdav.ErrConflict):
|
||||
return status.Errorf(codes.Aborted, "%s: %v", operation, err)
|
||||
default:
|
||||
return status.Errorf(codes.Internal, "%s: %v", operation, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) ListEntries(ctx context.Context, req *keepassgov1.ListEntriesRequest) (*keepassgov1.ListEntriesResponse, error) {
|
||||
model, locked := s.snapshotModel()
|
||||
if locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
if _, err := s.authorizePathRequest(ctx, apitokens.OperationListEntries, req.GetPath()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model = visibleModel(model)
|
||||
var entries []vault.Entry
|
||||
if strings.TrimSpace(req.GetQuery()) != "" {
|
||||
results := model.Search(req.GetQuery())
|
||||
entries = make([]vault.Entry, 0, len(results))
|
||||
for _, result := range results {
|
||||
entries = append(entries, result.Entry)
|
||||
}
|
||||
} else {
|
||||
entries = model.EntriesInPath(req.GetPath())
|
||||
}
|
||||
|
||||
resp := &keepassgov1.ListEntriesResponse{
|
||||
Entries: make([]*keepassgov1.Entry, 0, len(entries)),
|
||||
}
|
||||
for _, entry := range entries {
|
||||
resp.Entries = append(resp.Entries, entryToProto(entry))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *Server) ListGroups(ctx context.Context, req *keepassgov1.ListGroupsRequest) (*keepassgov1.ListGroupsResponse, error) {
|
||||
model, locked := s.snapshotModel()
|
||||
if locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
if _, err := s.authorizePathRequest(ctx, apitokens.OperationListGroups, req.GetPath()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &keepassgov1.ListGroupsResponse{
|
||||
Names: visibleModel(model).ChildGroups(req.GetPath()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) CreateGroup(ctx context.Context, req *keepassgov1.CreateGroupRequest) (*keepassgov1.CreateGroupResponse, error) {
|
||||
if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetParentPath()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
s.model.CreateGroup(req.GetParentPath(), req.GetName())
|
||||
s.dirty = true
|
||||
return &keepassgov1.CreateGroupResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) RenameGroup(ctx context.Context, req *keepassgov1.RenameGroupRequest) (*keepassgov1.RenameGroupResponse, error) {
|
||||
if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetPath()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
if err := s.model.RenameGroup(req.GetPath(), req.GetNewName()); err != nil {
|
||||
if errors.Is(err, vault.ErrEntryNotFound) {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
return nil, status.Errorf(codes.Internal, "rename group: %v", err)
|
||||
}
|
||||
|
||||
s.dirty = true
|
||||
return &keepassgov1.RenameGroupResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) DeleteGroup(ctx context.Context, req *keepassgov1.DeleteGroupRequest) (*keepassgov1.DeleteGroupResponse, error) {
|
||||
if _, err := s.authorizePathRequest(ctx, apitokens.OperationMutateGroup, req.GetPath()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
if err := s.model.DeleteGroup(req.GetPath()); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, vault.ErrEntryNotFound):
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
case errors.Is(err, vault.ErrGroupNotEmpty):
|
||||
return nil, status.Error(codes.FailedPrecondition, err.Error())
|
||||
default:
|
||||
return nil, status.Errorf(codes.Internal, "delete group: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.dirty = true
|
||||
return &keepassgov1.DeleteGroupResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) UpsertEntry(ctx context.Context, req *keepassgov1.UpsertEntryRequest) (*keepassgov1.UpsertEntryResponse, error) {
|
||||
if req.GetEntry() == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "missing entry")
|
||||
}
|
||||
|
||||
entry := entryFromProto(req.GetEntry())
|
||||
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
if s.locked {
|
||||
s.mu.Unlock()
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
s.model.UpsertEntry(entry)
|
||||
s.dirty = true
|
||||
s.mu.Unlock()
|
||||
|
||||
return &keepassgov1.UpsertEntryResponse{Entry: entryToProto(entry)}, nil
|
||||
}
|
||||
|
||||
func (s *Server) DeleteEntry(ctx context.Context, req *keepassgov1.DeleteEntryRequest) (*keepassgov1.DeleteEntryResponse, error) {
|
||||
model, locked := s.snapshotModel()
|
||||
if locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
if entry, err := findEntryByID(model, req.GetId()); err == nil {
|
||||
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.model.DeleteEntry(req.GetId()); err != nil {
|
||||
if errors.Is(err, vault.ErrEntryNotFound) {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
return nil, status.Errorf(codes.Internal, "delete entry: %v", err)
|
||||
}
|
||||
|
||||
s.dirty = true
|
||||
return &keepassgov1.DeleteEntryResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) RestoreEntry(ctx context.Context, req *keepassgov1.RestoreEntryRequest) (*keepassgov1.RestoreEntryResponse, error) {
|
||||
model, locked := s.snapshotModel()
|
||||
if locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
var restored vault.Entry
|
||||
for _, entry := range model.RecycleBin {
|
||||
if entry.ID == req.GetId() {
|
||||
restored = entry
|
||||
break
|
||||
}
|
||||
}
|
||||
if restored.ID != "" {
|
||||
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, restored); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.model.RestoreEntry(req.GetId()); err != nil {
|
||||
if errors.Is(err, vault.ErrEntryNotFound) {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
return nil, status.Errorf(codes.Internal, "restore entry: %v", err)
|
||||
}
|
||||
|
||||
s.dirty = true
|
||||
return &keepassgov1.RestoreEntryResponse{Entry: entryToProto(restored)}, nil
|
||||
}
|
||||
|
||||
func (s *Server) ListEntryHistory(ctx context.Context, req *keepassgov1.ListEntryHistoryRequest) (*keepassgov1.ListEntryHistoryResponse, error) {
|
||||
model, locked := s.snapshotModel()
|
||||
if locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
entry, err := findEntryByID(model, req.GetId())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationReadEntry, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := &keepassgov1.ListEntryHistoryResponse{
|
||||
Entries: make([]*keepassgov1.Entry, 0, len(entry.History)),
|
||||
}
|
||||
for _, historical := range entry.History {
|
||||
resp.Entries = append(resp.Entries, entryToProto(historical))
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *Server) RestoreEntryHistory(ctx context.Context, req *keepassgov1.RestoreEntryHistoryRequest) (*keepassgov1.RestoreEntryHistoryResponse, error) {
|
||||
model, locked := s.snapshotModel()
|
||||
if locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
entry, err := findEntryByID(model, req.GetId())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if err := s.model.RestoreEntryVersion(req.GetId(), int(req.GetHistoryIndex())); err != nil {
|
||||
if errors.Is(err, vault.ErrEntryNotFound) {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
return nil, status.Errorf(codes.Internal, "restore entry history: %v", err)
|
||||
}
|
||||
|
||||
entry, err = findEntryByID(s.model, req.GetId())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
s.dirty = true
|
||||
return &keepassgov1.RestoreEntryHistoryResponse{Entry: entryToProto(entry)}, nil
|
||||
}
|
||||
|
||||
func (s *Server) ListTemplates(_ context.Context, _ *keepassgov1.ListTemplatesRequest) (*keepassgov1.ListTemplatesResponse, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
resp := &keepassgov1.ListTemplatesResponse{
|
||||
Templates: make([]*keepassgov1.Entry, 0, len(s.model.Templates)),
|
||||
}
|
||||
for _, template := range s.model.Templates {
|
||||
resp.Templates = append(resp.Templates, entryToProto(template))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *Server) UpsertTemplate(_ context.Context, req *keepassgov1.UpsertTemplateRequest) (*keepassgov1.UpsertTemplateResponse, error) {
|
||||
if req.GetTemplate() == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "missing template")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
entry := entryFromProto(req.GetTemplate())
|
||||
s.model.UpsertTemplate(entry)
|
||||
s.dirty = true
|
||||
|
||||
return &keepassgov1.UpsertTemplateResponse{Template: entryToProto(entry)}, nil
|
||||
}
|
||||
|
||||
func (s *Server) DeleteTemplate(_ context.Context, req *keepassgov1.DeleteTemplateRequest) (*keepassgov1.DeleteTemplateResponse, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
if err := s.model.DeleteTemplate(req.GetId()); err != nil {
|
||||
if errors.Is(err, vault.ErrEntryNotFound) {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
return nil, status.Errorf(codes.Internal, "delete template: %v", err)
|
||||
}
|
||||
s.dirty = true
|
||||
|
||||
return &keepassgov1.DeleteTemplateResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) InstantiateTemplate(_ context.Context, req *keepassgov1.InstantiateTemplateRequest) (*keepassgov1.InstantiateTemplateResponse, error) {
|
||||
if req.GetOverrides() == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "missing overrides")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
entry, err := s.model.InstantiateTemplate(req.GetTemplateId(), entryFromProto(req.GetOverrides()))
|
||||
if err != nil {
|
||||
if errors.Is(err, vault.ErrEntryNotFound) {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
return nil, status.Errorf(codes.Internal, "instantiate template: %v", err)
|
||||
}
|
||||
|
||||
s.dirty = true
|
||||
return &keepassgov1.InstantiateTemplateResponse{Entry: entryToProto(entry)}, nil
|
||||
}
|
||||
|
||||
func (s *Server) ListAttachments(ctx context.Context, req *keepassgov1.ListAttachmentsRequest) (*keepassgov1.ListAttachmentsResponse, error) {
|
||||
model, locked := s.snapshotModel()
|
||||
if locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
entry, err := findEntryByID(model, req.GetEntryId())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationReadEntry, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(entry.Attachments))
|
||||
for name := range entry.Attachments {
|
||||
names = append(names, name)
|
||||
}
|
||||
slices.Sort(names)
|
||||
|
||||
return &keepassgov1.ListAttachmentsResponse{Names: names}, nil
|
||||
}
|
||||
|
||||
func (s *Server) UploadAttachment(ctx context.Context, req *keepassgov1.UploadAttachmentRequest) (*keepassgov1.UploadAttachmentResponse, error) {
|
||||
model, locked := s.snapshotModel()
|
||||
if locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
entry, err := findEntryByID(model, req.GetEntryId())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
|
||||
if entry.Attachments == nil {
|
||||
entry.Attachments = map[string][]byte{}
|
||||
}
|
||||
entry.Attachments[req.GetName()] = append([]byte(nil), req.GetContent()...)
|
||||
s.model.Entries[index] = entry
|
||||
s.dirty = true
|
||||
|
||||
return &keepassgov1.UploadAttachmentResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) DownloadAttachment(ctx context.Context, req *keepassgov1.DownloadAttachmentRequest) (*keepassgov1.DownloadAttachmentResponse, error) {
|
||||
model, locked := s.snapshotModel()
|
||||
if locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
entry, err := findEntryByID(model, req.GetEntryId())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationReadEntry, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
content, ok := entry.Attachments[req.GetName()]
|
||||
if !ok {
|
||||
return nil, status.Error(codes.NotFound, "attachment not found")
|
||||
}
|
||||
|
||||
return &keepassgov1.DownloadAttachmentResponse{
|
||||
Content: append([]byte(nil), content...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) DeleteAttachment(ctx context.Context, req *keepassgov1.DeleteAttachmentRequest) (*keepassgov1.DeleteAttachmentResponse, error) {
|
||||
model, locked := s.snapshotModel()
|
||||
if locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
entry, err := findEntryByID(model, req.GetEntryId())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
if _, err := s.authorizeEntryRequest(ctx, apitokens.OperationMutateEntry, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
entry, index, err := findMutableEntryByID(&s.model, req.GetEntryId())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
|
||||
if _, ok := entry.Attachments[req.GetName()]; !ok {
|
||||
return nil, status.Error(codes.NotFound, "attachment not found")
|
||||
}
|
||||
|
||||
delete(entry.Attachments, req.GetName())
|
||||
if len(entry.Attachments) == 0 {
|
||||
entry.Attachments = nil
|
||||
}
|
||||
s.model.Entries[index] = entry
|
||||
s.dirty = true
|
||||
|
||||
return &keepassgov1.DeleteAttachmentResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) CopyEntryField(ctx context.Context, req *keepassgov1.CopyEntryFieldRequest) (*keepassgov1.CopyEntryFieldResponse, error) {
|
||||
model, locked := s.snapshotModel()
|
||||
if locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
entry, err := findEntryByID(model, req.GetId())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
if _, err := s.authorizeEntryRequest(ctx, copyOperation(req.GetTarget()), entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service := clipboard.Service{Writer: s.clipboard}
|
||||
if err := service.Copy(model, req.GetId(), clipboard.Target(req.GetTarget())); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, vault.ErrEntryNotFound):
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
case errors.Is(err, clipboard.ErrUnsupportedTarget):
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
default:
|
||||
return nil, status.Errorf(codes.Internal, "copy entry field: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &keepassgov1.CopyEntryFieldResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) GeneratePassword(ctx context.Context, req *keepassgov1.GeneratePasswordRequest) (*keepassgov1.GeneratePasswordResponse, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if _, err := s.authenticateRequest(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.locked {
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is locked")
|
||||
}
|
||||
|
||||
profile, err := passwords.LookupProfile(req.GetProfile(), s.profiles)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
password, err := passwords.Generate(profile)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "generate password: %v", err)
|
||||
}
|
||||
|
||||
return &keepassgov1.GeneratePasswordResponse{Password: password}, nil
|
||||
}
|
||||
|
||||
func entryToProto(entry vault.Entry) *keepassgov1.Entry {
|
||||
return &keepassgov1.Entry{
|
||||
Id: entry.ID,
|
||||
Title: entry.Title,
|
||||
Username: entry.Username,
|
||||
Password: entry.Password,
|
||||
Url: entry.URL,
|
||||
Notes: entry.Notes,
|
||||
Tags: append([]string(nil), entry.Tags...),
|
||||
Path: append([]string(nil), entry.Path...),
|
||||
Fields: maps.Clone(entry.Fields),
|
||||
}
|
||||
}
|
||||
|
||||
func entryFromProto(entry *keepassgov1.Entry) vault.Entry {
|
||||
return vault.Entry{
|
||||
ID: entry.GetId(),
|
||||
Title: entry.GetTitle(),
|
||||
Username: entry.GetUsername(),
|
||||
Password: entry.GetPassword(),
|
||||
URL: entry.GetUrl(),
|
||||
Notes: entry.GetNotes(),
|
||||
Tags: append([]string(nil), entry.GetTags()...),
|
||||
Path: append([]string(nil), entry.GetPath()...),
|
||||
Fields: maps.Clone(entry.GetFields()),
|
||||
}
|
||||
}
|
||||
|
||||
func findEntryByID(model vault.Model, id string) (vault.Entry, error) {
|
||||
for _, entry := range model.Entries {
|
||||
if entry.ID == id {
|
||||
return entry, nil
|
||||
}
|
||||
}
|
||||
return vault.Entry{}, vault.ErrEntryNotFound
|
||||
}
|
||||
|
||||
func findMutableEntryByID(model *vault.Model, id string) (vault.Entry, int, error) {
|
||||
for i, entry := range model.Entries {
|
||||
if entry.ID == id {
|
||||
entry.Attachments = maps.Clone(entry.Attachments)
|
||||
return entry, i, nil
|
||||
}
|
||||
}
|
||||
return vault.Entry{}, -1, vault.ErrEntryNotFound
|
||||
}
|
||||
|
||||
func visibleModel(model vault.Model) vault.Model {
|
||||
out := model
|
||||
out.Entries = nil
|
||||
for _, entry := range model.Entries {
|
||||
token, ok, err := apitokens.TokenFromEntry(entry)
|
||||
if err == nil && ok && token.ID != "" {
|
||||
continue
|
||||
}
|
||||
out.Entries = append(out.Entries, entry)
|
||||
}
|
||||
out.Groups = nil
|
||||
for _, path := range model.Groups {
|
||||
if len(path) >= 2 && path[0] == "Root" && path[1] == "API Tokens" {
|
||||
continue
|
||||
}
|
||||
out.Groups = append(out.Groups, path)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Server) snapshotModel() (vault.Model, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.model, s.locked
|
||||
}
|
||||
|
||||
var timeNow = func() time.Time { return time.Now().UTC() }
|
||||
|
||||
func (s *Server) authenticateRequest(ctx context.Context) (apitokens.Token, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return apitokens.Token{}, status.Error(codes.Unauthenticated, "missing metadata")
|
||||
}
|
||||
values := md.Get("authorization")
|
||||
if len(values) == 0 {
|
||||
s.audit.Record(apiaudit.Event{Type: apiaudit.EventAuthRejected, Message: "missing authorization"})
|
||||
return apitokens.Token{}, status.Error(codes.Unauthenticated, "missing authorization")
|
||||
}
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(values[0], prefix) {
|
||||
s.audit.Record(apiaudit.Event{Type: apiaudit.EventAuthRejected, Message: "invalid bearer token"})
|
||||
return apitokens.Token{}, status.Error(codes.Unauthenticated, "invalid bearer token")
|
||||
}
|
||||
s.mu.RLock()
|
||||
tokens, err := apitokens.Entries(s.model)
|
||||
s.mu.RUnlock()
|
||||
if err != nil {
|
||||
return apitokens.Token{}, status.Errorf(codes.Internal, "load api tokens: %v", err)
|
||||
}
|
||||
token, err := apitokens.Authenticate(tokens, strings.TrimSpace(strings.TrimPrefix(values[0], prefix)), timeNow())
|
||||
if err != nil {
|
||||
switch err {
|
||||
case apitokens.ErrInvalidToken, apitokens.ErrExpiredToken, apitokens.ErrDisabledToken:
|
||||
s.audit.Record(apiaudit.Event{Type: apiaudit.EventAuthRejected, Message: err.Error()})
|
||||
return apitokens.Token{}, status.Error(codes.Unauthenticated, err.Error())
|
||||
default:
|
||||
return apitokens.Token{}, status.Errorf(codes.Internal, "authenticate api token: %v", err)
|
||||
}
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *Server) authorizePathRequest(ctx context.Context, op apitokens.Operation, path []string) (apitokens.Token, error) {
|
||||
token, err := s.authenticateRequest(ctx)
|
||||
if err != nil {
|
||||
return apitokens.Token{}, err
|
||||
}
|
||||
return s.authorizeResourceRequest(ctx, token, op, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: path})
|
||||
}
|
||||
|
||||
func (s *Server) authorizeEntryRequest(ctx context.Context, op apitokens.Operation, entry vault.Entry) (apitokens.Token, error) {
|
||||
token, err := s.authenticateRequest(ctx)
|
||||
if err != nil {
|
||||
return apitokens.Token{}, err
|
||||
}
|
||||
return s.authorizeResourceRequest(ctx, token, op, apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: entry.ID, Path: entry.Path})
|
||||
}
|
||||
|
||||
func (s *Server) authorizeResourceRequest(ctx context.Context, token apitokens.Token, op apitokens.Operation, resource apitokens.Resource) (apitokens.Token, error) {
|
||||
switch apitokens.Evaluate(token, op, resource) {
|
||||
case apitokens.DecisionAllow:
|
||||
return token, nil
|
||||
case apitokens.DecisionDeny:
|
||||
return apitokens.Token{}, status.Error(codes.PermissionDenied, "access is not allowed for this token")
|
||||
case apitokens.DecisionPrompt:
|
||||
s.audit.Record(apiaudit.Event{
|
||||
Type: apiaudit.EventApprovalRequested,
|
||||
TokenID: token.ID,
|
||||
TokenName: token.Name,
|
||||
ClientName: token.ClientName,
|
||||
Operation: op,
|
||||
Resource: resource,
|
||||
})
|
||||
result, err := s.approvals.Request(ctx, token, op, resource)
|
||||
if result.Rule != nil {
|
||||
if persistErr := s.persistApprovalRule(token.ID, *result.Rule); persistErr != nil {
|
||||
return apitokens.Token{}, status.Errorf(codes.Internal, "persist approval decision: %v", persistErr)
|
||||
}
|
||||
}
|
||||
switch {
|
||||
case err == nil:
|
||||
s.audit.Record(apiaudit.Event{
|
||||
Type: apiaudit.EventApprovalAllowed,
|
||||
TokenID: token.ID,
|
||||
TokenName: token.Name,
|
||||
ClientName: token.ClientName,
|
||||
Operation: op,
|
||||
Resource: resource,
|
||||
})
|
||||
return token, nil
|
||||
case errors.Is(err, apiapproval.ErrRequestDenied):
|
||||
s.audit.Record(apiaudit.Event{
|
||||
Type: apiaudit.EventApprovalDenied,
|
||||
TokenID: token.ID,
|
||||
TokenName: token.Name,
|
||||
ClientName: token.ClientName,
|
||||
Operation: op,
|
||||
Resource: resource,
|
||||
})
|
||||
return apitokens.Token{}, status.Error(codes.PermissionDenied, "access denied by user approval")
|
||||
case errors.Is(err, apiapproval.ErrRequestCanceled):
|
||||
s.audit.Record(apiaudit.Event{
|
||||
Type: apiaudit.EventApprovalCanceled,
|
||||
TokenID: token.ID,
|
||||
TokenName: token.Name,
|
||||
ClientName: token.ClientName,
|
||||
Operation: op,
|
||||
Resource: resource,
|
||||
})
|
||||
return apitokens.Token{}, status.Error(codes.Unauthenticated, "authorization request canceled")
|
||||
case errors.Is(err, apiapproval.ErrRequestTimedOut):
|
||||
s.audit.Record(apiaudit.Event{
|
||||
Type: apiaudit.EventApprovalTimedOut,
|
||||
TokenID: token.ID,
|
||||
TokenName: token.Name,
|
||||
ClientName: token.ClientName,
|
||||
Operation: op,
|
||||
Resource: resource,
|
||||
})
|
||||
return apitokens.Token{}, status.Error(codes.DeadlineExceeded, "authorization request timed out")
|
||||
case errors.Is(err, context.Canceled):
|
||||
return apitokens.Token{}, status.Error(codes.Canceled, "authorization request canceled")
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
return apitokens.Token{}, status.Error(codes.DeadlineExceeded, "authorization request timed out")
|
||||
default:
|
||||
return apitokens.Token{}, status.Errorf(codes.Internal, "await authorization request: %v", err)
|
||||
}
|
||||
default:
|
||||
return apitokens.Token{}, status.Error(codes.PermissionDenied, "access is not allowed for this token")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) persistApprovalRule(tokenID string, rule apitokens.PolicyRule) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for i, entry := range s.model.Entries {
|
||||
token, ok, err := apitokens.TokenFromEntry(entry)
|
||||
if err != nil || !ok || token.ID != tokenID {
|
||||
continue
|
||||
}
|
||||
if !hasPolicyRule(token.Policies, rule) {
|
||||
token.Policies = append(token.Policies, rule)
|
||||
}
|
||||
s.model.Entries[i] = token.Entry(entry.Path)
|
||||
s.dirty = true
|
||||
if lifecycle, ok := s.lifecycle.(modelReplaceableLifecycle); ok {
|
||||
lifecycle.Replace(s.model)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return status.Error(codes.NotFound, "api token entry not found")
|
||||
}
|
||||
|
||||
func hasPolicyRule(rules []apitokens.PolicyRule, target apitokens.PolicyRule) bool {
|
||||
for _, rule := range rules {
|
||||
if rule.Effect != target.Effect || rule.Operation != target.Operation {
|
||||
continue
|
||||
}
|
||||
if rule.Resource.Kind != target.Resource.Kind || rule.Resource.EntryID != target.Resource.EntryID {
|
||||
continue
|
||||
}
|
||||
if slices.Equal(rule.Resource.Path, target.Resource.Path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func copyOperation(target string) apitokens.Operation {
|
||||
switch clipboard.Target(target) {
|
||||
case clipboard.TargetUsername:
|
||||
return apitokens.OperationCopyUsername
|
||||
case clipboard.TargetURL:
|
||||
return apitokens.OperationCopyURL
|
||||
default:
|
||||
return apitokens.OperationCopyPassword
|
||||
}
|
||||
}
|
||||
|
||||
func AuthInterceptor(server *Server) grpc.UnaryServerInterceptor {
|
||||
return func(
|
||||
ctx context.Context,
|
||||
req any,
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (any, error) {
|
||||
switch info.FullMethod {
|
||||
case "/keepassgo.v1.VaultService/GetSessionStatus",
|
||||
"/keepassgo.v1.VaultService/OpenVault",
|
||||
"/keepassgo.v1.VaultService/OpenRemoteVault",
|
||||
"/keepassgo.v1.VaultService/SaveVault",
|
||||
"/keepassgo.v1.VaultService/LockVault",
|
||||
"/keepassgo.v1.VaultService/UnlockVault":
|
||||
if _, err := server.authenticateRequest(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,215 @@
|
||||
package apiapproval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/apitokens"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRequestDenied = errors.New("authorization request denied")
|
||||
ErrRequestCanceled = errors.New("authorization request canceled")
|
||||
ErrRequestTimedOut = errors.New("authorization request timed out")
|
||||
ErrRequestNotFound = errors.New("authorization request not found")
|
||||
)
|
||||
|
||||
type Outcome string
|
||||
|
||||
const (
|
||||
OutcomeAllowOnce Outcome = "allow-once"
|
||||
OutcomeDenyOnce Outcome = "deny-once"
|
||||
OutcomeAllowPermanent Outcome = "allow-permanent"
|
||||
OutcomeDenyPermanent Outcome = "deny-permanent"
|
||||
OutcomeCancel Outcome = "cancel"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
ID string
|
||||
TokenID string
|
||||
TokenName string
|
||||
ClientName string
|
||||
Operation apitokens.Operation
|
||||
Resource apitokens.Resource
|
||||
RequestedAt time.Time
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
Outcome Outcome
|
||||
Rule *apitokens.PolicyRule
|
||||
}
|
||||
|
||||
type Broker struct {
|
||||
mu sync.Mutex
|
||||
pending map[string]*pendingRequest
|
||||
timeout time.Duration
|
||||
now func() time.Time
|
||||
nextID func() string
|
||||
}
|
||||
|
||||
type pendingRequest struct {
|
||||
request Request
|
||||
done chan Outcome
|
||||
}
|
||||
|
||||
type idGenerator struct {
|
||||
mu sync.Mutex
|
||||
counter int
|
||||
}
|
||||
|
||||
func NewBroker(timeout time.Duration) *Broker {
|
||||
gen := &idGenerator{}
|
||||
return &Broker{
|
||||
pending: map[string]*pendingRequest{},
|
||||
timeout: timeout,
|
||||
now: func() time.Time {
|
||||
return time.Now().UTC()
|
||||
},
|
||||
nextID: func() string {
|
||||
return gen.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (g *idGenerator) Next() string {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
g.counter++
|
||||
return "approval-" + strconv.Itoa(g.counter)
|
||||
}
|
||||
|
||||
func (b *Broker) Pending() []Request {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
requests := make([]Request, 0, len(b.pending))
|
||||
for _, pending := range b.pending {
|
||||
requests = append(requests, pending.request)
|
||||
}
|
||||
slices.SortFunc(requests, func(a, c Request) int {
|
||||
switch {
|
||||
case a.RequestedAt.Before(c.RequestedAt):
|
||||
return -1
|
||||
case a.RequestedAt.After(c.RequestedAt):
|
||||
return 1
|
||||
case a.ID < c.ID:
|
||||
return -1
|
||||
case a.ID > c.ID:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
return requests
|
||||
}
|
||||
|
||||
func (b *Broker) Request(ctx context.Context, token apitokens.Token, op apitokens.Operation, resource apitokens.Resource) (Result, error) {
|
||||
if b == nil {
|
||||
return Result{}, ErrRequestTimedOut
|
||||
}
|
||||
|
||||
pending := &pendingRequest{
|
||||
request: Request{
|
||||
ID: b.nextID(),
|
||||
TokenID: token.ID,
|
||||
TokenName: token.Name,
|
||||
ClientName: token.ClientName,
|
||||
Operation: op,
|
||||
Resource: resource,
|
||||
RequestedAt: b.now(),
|
||||
},
|
||||
done: make(chan Outcome, 1),
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
b.pending[pending.request.ID] = pending
|
||||
b.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
b.mu.Lock()
|
||||
delete(b.pending, pending.request.ID)
|
||||
b.mu.Unlock()
|
||||
}()
|
||||
|
||||
timer := time.NewTimer(b.timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case outcome := <-pending.done:
|
||||
result, err := resultForOutcome(pending.request, outcome)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
return result, nil
|
||||
case <-timer.C:
|
||||
return Result{}, ErrRequestTimedOut
|
||||
case <-ctx.Done():
|
||||
return Result{}, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) Resolve(id string, outcome Outcome) (Request, *apitokens.PolicyRule, error) {
|
||||
b.mu.Lock()
|
||||
pending, ok := b.pending[id]
|
||||
b.mu.Unlock()
|
||||
if !ok {
|
||||
return Request{}, nil, ErrRequestNotFound
|
||||
}
|
||||
|
||||
rule, err := RuleFromDecision(pending.request, outcome)
|
||||
if err != nil {
|
||||
return Request{}, nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case pending.done <- outcome:
|
||||
default:
|
||||
}
|
||||
return pending.request, rule, nil
|
||||
}
|
||||
|
||||
func resultForOutcome(request Request, outcome Outcome) (Result, error) {
|
||||
rule, err := RuleFromDecision(request, outcome)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
result := Result{Outcome: outcome, Rule: rule}
|
||||
switch outcome {
|
||||
case OutcomeAllowOnce, OutcomeAllowPermanent:
|
||||
return result, nil
|
||||
case OutcomeDenyOnce, OutcomeDenyPermanent:
|
||||
return result, ErrRequestDenied
|
||||
case OutcomeCancel:
|
||||
return result, ErrRequestCanceled
|
||||
default:
|
||||
return Result{}, fmt.Errorf("unsupported approval outcome %q", outcome)
|
||||
}
|
||||
}
|
||||
|
||||
func RuleFromDecision(request Request, outcome Outcome) (*apitokens.PolicyRule, error) {
|
||||
switch outcome {
|
||||
case OutcomeAllowPermanent:
|
||||
rule := apitokens.PolicyRule{
|
||||
Effect: apitokens.EffectAllow,
|
||||
Operation: request.Operation,
|
||||
Resource: request.Resource,
|
||||
}
|
||||
return &rule, nil
|
||||
case OutcomeDenyPermanent:
|
||||
rule := apitokens.PolicyRule{
|
||||
Effect: apitokens.EffectDeny,
|
||||
Operation: request.Operation,
|
||||
Resource: request.Resource,
|
||||
}
|
||||
return &rule, nil
|
||||
case OutcomeAllowOnce, OutcomeDenyOnce, OutcomeCancel:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported approval outcome %q", outcome)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package apiapproval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/apitokens"
|
||||
)
|
||||
|
||||
func TestBrokerCreatesPendingRequestAndAllowsOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
broker := NewBroker(time.Minute)
|
||||
broker.now = func() time.Time { return time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) }
|
||||
|
||||
resultCh := make(chan Result, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
result, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI", ClientName: "grpc-cli"}, apitokens.OperationListEntries, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root", "Internet"}})
|
||||
resultCh <- result
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
waitForPending(t, broker, 1)
|
||||
pending := broker.Pending()
|
||||
if len(pending) != 1 {
|
||||
t.Fatalf("Pending() len = %d, want 1", len(pending))
|
||||
}
|
||||
if pending[0].TokenID != "token-1" {
|
||||
t.Fatalf("Pending()[0].TokenID = %q, want token-1", pending[0].TokenID)
|
||||
}
|
||||
|
||||
if _, _, err := broker.Resolve(pending[0].ID, OutcomeAllowOnce); err != nil {
|
||||
t.Fatalf("Resolve(allow once) error = %v", err)
|
||||
}
|
||||
|
||||
result := <-resultCh
|
||||
if err := <-errCh; err != nil {
|
||||
t.Fatalf("Request() error = %v, want nil", err)
|
||||
}
|
||||
if result.Outcome != OutcomeAllowOnce {
|
||||
t.Fatalf("Request() outcome = %q, want %q", result.Outcome, OutcomeAllowOnce)
|
||||
}
|
||||
if result.Rule != nil {
|
||||
t.Fatalf("Request() rule = %#v, want nil for allow-once", result.Rule)
|
||||
}
|
||||
if got := broker.Pending(); len(got) != 0 {
|
||||
t.Fatalf("Pending() after allow len = %d, want 0", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerReturnsPermanentRuleForDeny(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
broker := NewBroker(time.Minute)
|
||||
reqDone := make(chan struct{})
|
||||
var result Result
|
||||
var err error
|
||||
go func() {
|
||||
result, err = broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationReadEntry, apitokens.Resource{Kind: apitokens.ResourceEntry, EntryID: "entry-1", Path: []string{"Root", "Internet"}})
|
||||
close(reqDone)
|
||||
}()
|
||||
|
||||
waitForPending(t, broker, 1)
|
||||
pending := broker.Pending()[0]
|
||||
request, rule, resolveErr := broker.Resolve(pending.ID, OutcomeDenyPermanent)
|
||||
if resolveErr != nil {
|
||||
t.Fatalf("Resolve(deny permanent) error = %v", resolveErr)
|
||||
}
|
||||
if request.ID != pending.ID {
|
||||
t.Fatalf("Resolve().ID = %q, want %q", request.ID, pending.ID)
|
||||
}
|
||||
if rule == nil || rule.Effect != apitokens.EffectDeny || rule.Operation != apitokens.OperationReadEntry {
|
||||
t.Fatalf("Resolve() rule = %#v, want deny read-entry rule", rule)
|
||||
}
|
||||
|
||||
<-reqDone
|
||||
if !errors.Is(err, ErrRequestDenied) {
|
||||
t.Fatalf("Request() error = %v, want ErrRequestDenied", err)
|
||||
}
|
||||
if result.Rule == nil || result.Rule.Effect != apitokens.EffectDeny {
|
||||
t.Fatalf("Request() rule = %#v, want deny rule", result.Rule)
|
||||
}
|
||||
if result.Outcome != OutcomeDenyPermanent {
|
||||
t.Fatalf("Request() outcome = %q, want %q", result.Outcome, OutcomeDenyPermanent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerSupportsCancellation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
broker := NewBroker(time.Minute)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationListGroups, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root"}})
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
waitForPending(t, broker, 1)
|
||||
if _, _, err := broker.Resolve(broker.Pending()[0].ID, OutcomeCancel); err != nil {
|
||||
t.Fatalf("Resolve(cancel) error = %v", err)
|
||||
}
|
||||
if err := <-errCh; !errors.Is(err, ErrRequestCanceled) {
|
||||
t.Fatalf("Request() error = %v, want ErrRequestCanceled", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerTimesOutPendingRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
broker := NewBroker(10 * time.Millisecond)
|
||||
_, err := broker.Request(context.Background(), apitokens.Token{ID: "token-1", Name: "CLI"}, apitokens.OperationListGroups, apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root"}})
|
||||
if !errors.Is(err, ErrRequestTimedOut) {
|
||||
t.Fatalf("Request() error = %v, want ErrRequestTimedOut", err)
|
||||
}
|
||||
if got := broker.Pending(); len(got) != 0 {
|
||||
t.Fatalf("Pending() len after timeout = %d, want 0", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func waitForPending(t *testing.T, broker *Broker, want int) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if got := len(broker.Pending()); got == want {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("Pending() never reached len %d", want)
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package apiaudit
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/apitokens"
|
||||
)
|
||||
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
EventApprovalRequested EventType = "approval_requested"
|
||||
EventApprovalAllowed EventType = "approval_allowed"
|
||||
EventApprovalDenied EventType = "approval_denied"
|
||||
EventApprovalCanceled EventType = "approval_canceled"
|
||||
EventApprovalTimedOut EventType = "approval_timed_out"
|
||||
EventAutofillFound EventType = "autofill_found"
|
||||
EventAutofillAmbiguous EventType = "autofill_ambiguous"
|
||||
EventAutofillBlocked EventType = "autofill_blocked"
|
||||
EventAuthRejected EventType = "auth_rejected"
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
Type EventType
|
||||
At time.Time
|
||||
TokenID string
|
||||
TokenName string
|
||||
ClientName string
|
||||
Operation apitokens.Operation
|
||||
Resource apitokens.Resource
|
||||
Message string
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
mu sync.Mutex
|
||||
max int
|
||||
now func() time.Time
|
||||
events []Event
|
||||
}
|
||||
|
||||
func New(max int) *Log {
|
||||
if max < 1 {
|
||||
max = 1
|
||||
}
|
||||
return &Log{
|
||||
max: max,
|
||||
now: func() time.Time {
|
||||
return time.Now().UTC()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Log) Record(event Event) {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if event.At.IsZero() {
|
||||
event.At = l.now()
|
||||
}
|
||||
l.events = append([]Event{event}, l.events...)
|
||||
if len(l.events) > l.max {
|
||||
l.events = l.events[:l.max]
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Log) Events() []Event {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return slices.Clone(l.events)
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package apiaudit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/apitokens"
|
||||
)
|
||||
|
||||
func TestLogKeepsNewestEventsWithinBound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := New(2)
|
||||
log.now = func() time.Time { return time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC) }
|
||||
log.Record(Event{Type: EventApprovalRequested, TokenID: "token-1"})
|
||||
log.Record(Event{Type: EventApprovalAllowed, TokenID: "token-2"})
|
||||
log.Record(Event{Type: EventApprovalDenied, TokenID: "token-3"})
|
||||
|
||||
events := log.Events()
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("len(Events()) = %d, want 2", len(events))
|
||||
}
|
||||
if events[0].TokenID != "token-3" || events[1].TokenID != "token-2" {
|
||||
t.Fatalf("Events() = %#v, want newest-first bounded list", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogPreservesRecordedMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := New(5)
|
||||
log.Record(Event{
|
||||
Type: EventApprovalRequested,
|
||||
TokenID: "token-1",
|
||||
TokenName: "CLI",
|
||||
ClientName: "grpc-cli",
|
||||
Operation: apitokens.OperationListEntries,
|
||||
Resource: apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root", "Internet"}},
|
||||
Message: "prompted for access",
|
||||
})
|
||||
|
||||
events := log.Events()
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("len(Events()) = %d, want 1", len(events))
|
||||
}
|
||||
if events[0].Operation != apitokens.OperationListEntries || events[0].Message != "prompted for access" {
|
||||
t.Fatalf("Events()[0] = %#v, want preserved metadata", events[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogStoresAutofillEventTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := New(5)
|
||||
log.Record(Event{
|
||||
Type: EventAutofillAmbiguous,
|
||||
TokenName: "Browser Extension",
|
||||
Message: "multiple matches for example.com",
|
||||
})
|
||||
|
||||
events := log.Events()
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("len(Events()) = %d, want 1", len(events))
|
||||
}
|
||||
if events[0].Type != EventAutofillAmbiguous {
|
||||
t.Fatalf("Events()[0].Type = %q, want %q", events[0].Type, EventAutofillAmbiguous)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,335 @@
|
||||
package apitokens
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
)
|
||||
|
||||
const (
|
||||
EntryTypeAPIToken = "api-token"
|
||||
|
||||
FieldType = "KeePassGO-Type"
|
||||
FieldTokenID = "KeePassGO-API-Token-ID"
|
||||
FieldClientName = "KeePassGO-API-Client-Name"
|
||||
FieldCreatedAt = "KeePassGO-API-Created-At"
|
||||
FieldExpiresAt = "KeePassGO-API-Expires-At"
|
||||
FieldDisabled = "KeePassGO-API-Disabled"
|
||||
FieldRevokedAt = "KeePassGO-API-Revoked-At"
|
||||
FieldSecretHash = "KeePassGO-API-Secret-Hash"
|
||||
FieldPolicies = "KeePassGO-API-Policies"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotAToken = errors.New("entry is not an api token")
|
||||
ErrInvalidToken = errors.New("invalid api token")
|
||||
ErrExpiredToken = errors.New("expired api token")
|
||||
ErrDisabledToken = errors.New("disabled api token")
|
||||
ErrTokenNotFound = errors.New("api token not found")
|
||||
)
|
||||
|
||||
var EntryPath = []string{"Root", "API Tokens"}
|
||||
|
||||
type Effect string
|
||||
type Operation string
|
||||
type ResourceKind string
|
||||
type Decision string
|
||||
|
||||
const (
|
||||
EffectAllow Effect = "allow"
|
||||
EffectDeny Effect = "deny"
|
||||
|
||||
ResourceGroup ResourceKind = "group"
|
||||
ResourceEntry ResourceKind = "entry"
|
||||
|
||||
DecisionAllow Decision = "allow"
|
||||
DecisionDeny Decision = "deny"
|
||||
DecisionPrompt Decision = "prompt"
|
||||
|
||||
OperationListEntries Operation = "list_entries"
|
||||
OperationListGroups Operation = "list_groups"
|
||||
OperationReadEntry Operation = "read_entry"
|
||||
OperationCopyPassword Operation = "copy_password"
|
||||
OperationCopyUsername Operation = "copy_username"
|
||||
OperationCopyURL Operation = "copy_url"
|
||||
OperationMutateEntry Operation = "mutate_entry"
|
||||
OperationMutateGroup Operation = "mutate_group"
|
||||
OperationManageVault Operation = "manage_vault"
|
||||
)
|
||||
|
||||
type Resource struct {
|
||||
Kind ResourceKind `json:"kind"`
|
||||
Path []string `json:"path,omitempty"`
|
||||
EntryID string `json:"entry_id,omitempty"`
|
||||
}
|
||||
|
||||
type PolicyRule struct {
|
||||
Effect Effect `json:"effect"`
|
||||
Operation Operation `json:"operation"`
|
||||
Resource Resource `json:"resource"`
|
||||
}
|
||||
|
||||
type Token struct {
|
||||
ID string
|
||||
Name string
|
||||
ClientName string
|
||||
SecretHash string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt *time.Time
|
||||
RevokedAt *time.Time
|
||||
Disabled bool
|
||||
Policies []PolicyRule
|
||||
}
|
||||
|
||||
func Issue(name, clientName string, expiresAt *time.Time, now time.Time) (Token, string, error) {
|
||||
clear, hashed, err := newSecret()
|
||||
if err != nil {
|
||||
return Token{}, "", err
|
||||
}
|
||||
id, _, err := newSecret()
|
||||
if err != nil {
|
||||
return Token{}, "", err
|
||||
}
|
||||
return Token{
|
||||
ID: id,
|
||||
Name: strings.TrimSpace(name),
|
||||
ClientName: strings.TrimSpace(clientName),
|
||||
SecretHash: hashed,
|
||||
CreatedAt: now.UTC(),
|
||||
ExpiresAt: cloneTime(expiresAt),
|
||||
}, clear, nil
|
||||
}
|
||||
|
||||
func Rotate(token Token, now time.Time) (Token, string, error) {
|
||||
clear, hashed, err := newSecret()
|
||||
if err != nil {
|
||||
return Token{}, "", err
|
||||
}
|
||||
token.SecretHash = hashed
|
||||
token.Disabled = false
|
||||
token.RevokedAt = nil
|
||||
if token.CreatedAt.IsZero() {
|
||||
token.CreatedAt = now.UTC()
|
||||
}
|
||||
return token, clear, nil
|
||||
}
|
||||
|
||||
func Disable(token Token) Token {
|
||||
token.Disabled = true
|
||||
return token
|
||||
}
|
||||
|
||||
func Revoke(token Token, when time.Time) Token {
|
||||
token.Disabled = true
|
||||
t := when.UTC()
|
||||
token.RevokedAt = &t
|
||||
return token
|
||||
}
|
||||
|
||||
func Authenticate(tokens []Token, presentedSecret string, now time.Time) (Token, error) {
|
||||
hashed := hashSecret(presentedSecret)
|
||||
for _, token := range tokens {
|
||||
if token.SecretHash != hashed {
|
||||
continue
|
||||
}
|
||||
if token.Disabled || token.RevokedAt != nil {
|
||||
return Token{}, ErrDisabledToken
|
||||
}
|
||||
if token.ExpiresAt != nil && !token.ExpiresAt.After(now.UTC()) {
|
||||
return Token{}, ErrExpiredToken
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
return Token{}, ErrInvalidToken
|
||||
}
|
||||
|
||||
func Evaluate(token Token, operation Operation, resource Resource) Decision {
|
||||
decision := DecisionPrompt
|
||||
for _, rule := range token.Policies {
|
||||
if rule.Operation != operation {
|
||||
continue
|
||||
}
|
||||
if !matches(rule.Resource, resource) {
|
||||
continue
|
||||
}
|
||||
if rule.Effect == EffectDeny {
|
||||
return DecisionDeny
|
||||
}
|
||||
if rule.Effect == EffectAllow {
|
||||
decision = DecisionAllow
|
||||
}
|
||||
}
|
||||
return decision
|
||||
}
|
||||
|
||||
func Entries(model vault.Model) ([]Token, error) {
|
||||
var out []Token
|
||||
for _, entry := range model.Entries {
|
||||
token, ok, err := TokenFromEntry(entry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
out = append(out, token)
|
||||
}
|
||||
}
|
||||
slices.SortFunc(out, func(a, b Token) int {
|
||||
switch {
|
||||
case a.Name < b.Name:
|
||||
return -1
|
||||
case a.Name > b.Name:
|
||||
return 1
|
||||
default:
|
||||
return strings.Compare(a.ID, b.ID)
|
||||
}
|
||||
})
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func Find(model vault.Model, id string) (Token, error) {
|
||||
tokens, err := Entries(model)
|
||||
if err != nil {
|
||||
return Token{}, err
|
||||
}
|
||||
for _, token := range tokens {
|
||||
if token.ID == id {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
return Token{}, ErrTokenNotFound
|
||||
}
|
||||
|
||||
func Upsert(model *vault.Model, token Token) {
|
||||
model.UpsertEntry(token.Entry(EntryPath))
|
||||
model.CreateGroup([]string{"Root"}, "API Tokens")
|
||||
}
|
||||
|
||||
func Delete(model *vault.Model, id string) error {
|
||||
for i, entry := range model.Entries {
|
||||
token, ok, err := TokenFromEntry(entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ok && token.ID == id {
|
||||
model.Entries = append(model.Entries[:i], model.Entries[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrTokenNotFound
|
||||
}
|
||||
|
||||
func TokenFromEntry(entry vault.Entry) (Token, bool, error) {
|
||||
if entry.Fields[FieldType] != EntryTypeAPIToken {
|
||||
return Token{}, false, nil
|
||||
}
|
||||
createdAt, err := time.Parse(time.RFC3339, entry.Fields[FieldCreatedAt])
|
||||
if err != nil {
|
||||
return Token{}, true, fmt.Errorf("parse created at: %w", err)
|
||||
}
|
||||
var expiresAt *time.Time
|
||||
if raw := strings.TrimSpace(entry.Fields[FieldExpiresAt]); raw != "" {
|
||||
t, err := time.Parse(time.RFC3339, raw)
|
||||
if err != nil {
|
||||
return Token{}, true, fmt.Errorf("parse expires at: %w", err)
|
||||
}
|
||||
expiresAt = &t
|
||||
}
|
||||
var revokedAt *time.Time
|
||||
if raw := strings.TrimSpace(entry.Fields[FieldRevokedAt]); raw != "" {
|
||||
t, err := time.Parse(time.RFC3339, raw)
|
||||
if err != nil {
|
||||
return Token{}, true, fmt.Errorf("parse revoked at: %w", err)
|
||||
}
|
||||
revokedAt = &t
|
||||
}
|
||||
policies := []PolicyRule{}
|
||||
if raw := strings.TrimSpace(entry.Fields[FieldPolicies]); raw != "" {
|
||||
if err := json.Unmarshal([]byte(raw), &policies); err != nil {
|
||||
return Token{}, true, fmt.Errorf("parse policies: %w", err)
|
||||
}
|
||||
}
|
||||
return Token{
|
||||
ID: entry.Fields[FieldTokenID],
|
||||
Name: entry.Title,
|
||||
ClientName: entry.Fields[FieldClientName],
|
||||
SecretHash: entry.Fields[FieldSecretHash],
|
||||
CreatedAt: createdAt,
|
||||
ExpiresAt: expiresAt,
|
||||
RevokedAt: revokedAt,
|
||||
Disabled: strings.EqualFold(entry.Fields[FieldDisabled], "true"),
|
||||
Policies: policies,
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func (t Token) Entry(path []string) vault.Entry {
|
||||
fields := map[string]string{
|
||||
FieldType: EntryTypeAPIToken,
|
||||
FieldTokenID: t.ID,
|
||||
FieldClientName: t.ClientName,
|
||||
FieldCreatedAt: t.CreatedAt.UTC().Format(time.RFC3339),
|
||||
FieldDisabled: fmt.Sprintf("%t", t.Disabled),
|
||||
FieldSecretHash: t.SecretHash,
|
||||
}
|
||||
if t.ExpiresAt != nil {
|
||||
fields[FieldExpiresAt] = t.ExpiresAt.UTC().Format(time.RFC3339)
|
||||
}
|
||||
if t.RevokedAt != nil {
|
||||
fields[FieldRevokedAt] = t.RevokedAt.UTC().Format(time.RFC3339)
|
||||
}
|
||||
if len(t.Policies) > 0 {
|
||||
data, _ := json.Marshal(t.Policies)
|
||||
fields[FieldPolicies] = string(data)
|
||||
}
|
||||
return vault.Entry{
|
||||
ID: t.ID,
|
||||
Title: t.Name,
|
||||
Username: t.ClientName,
|
||||
Path: slices.Clone(path),
|
||||
Fields: fields,
|
||||
}
|
||||
}
|
||||
|
||||
func hashSecret(secret string) string {
|
||||
sum := sha256.Sum256([]byte(secret))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func newSecret() (string, string, error) {
|
||||
buf := make([]byte, 24)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", "", fmt.Errorf("generate secret: %w", err)
|
||||
}
|
||||
clear := base64.RawURLEncoding.EncodeToString(buf)
|
||||
return clear, hashSecret(clear), nil
|
||||
}
|
||||
|
||||
func cloneTime(in *time.Time) *time.Time {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
t := in.UTC()
|
||||
return &t
|
||||
}
|
||||
|
||||
func matches(rule, resource Resource) bool {
|
||||
switch rule.Kind {
|
||||
case ResourceEntry:
|
||||
return rule.EntryID != "" && rule.EntryID == resource.EntryID
|
||||
case ResourceGroup:
|
||||
if len(rule.Path) > len(resource.Path) {
|
||||
return false
|
||||
}
|
||||
return slices.Equal(rule.Path, resource.Path[:len(rule.Path)])
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,282 @@
|
||||
package apitokens
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
)
|
||||
|
||||
func TestTokenEntryRoundTripsThroughVaultEntry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expiresAt := time.Date(2026, 4, 1, 12, 0, 0, 0, time.UTC)
|
||||
token := Token{
|
||||
ID: "token-1",
|
||||
Name: "Browser Connector",
|
||||
ClientName: "browser-extension",
|
||||
SecretHash: "deadbeef",
|
||||
CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC),
|
||||
ExpiresAt: &expiresAt,
|
||||
Disabled: true,
|
||||
Policies: []PolicyRule{
|
||||
{Effect: EffectAllow, Operation: OperationListEntries, Resource: Resource{Kind: ResourceGroup, Path: []string{"Root", "Internet"}}},
|
||||
{Effect: EffectDeny, Operation: OperationCopyPassword, Resource: Resource{Kind: ResourceEntry, EntryID: "bank-token"}},
|
||||
},
|
||||
}
|
||||
|
||||
entry := token.Entry([]string{"Root", "API Tokens"})
|
||||
if entry.Fields[FieldType] != EntryTypeAPIToken {
|
||||
t.Fatalf("FieldType = %q, want %q", entry.Fields[FieldType], EntryTypeAPIToken)
|
||||
}
|
||||
|
||||
got, ok, err := TokenFromEntry(entry)
|
||||
if err != nil {
|
||||
t.Fatalf("TokenFromEntry() error = %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("TokenFromEntry() ok = false, want true")
|
||||
}
|
||||
if got.ID != token.ID || got.Name != token.Name || got.ClientName != token.ClientName || got.SecretHash != token.SecretHash || !got.Disabled {
|
||||
t.Fatalf("TokenFromEntry() = %#v, want %#v", got, token)
|
||||
}
|
||||
if got.ExpiresAt == nil || !got.ExpiresAt.Equal(expiresAt) {
|
||||
t.Fatalf("ExpiresAt = %#v, want %v", got.ExpiresAt, expiresAt)
|
||||
}
|
||||
if len(got.Policies) != 2 {
|
||||
t.Fatalf("len(Policies) = %d, want 2", len(got.Policies))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntriesFiltersOnlyAPITokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
{ID: "entry-1", Title: "Ordinary Entry"},
|
||||
Token{ID: "token-1", Name: "CLI", SecretHash: "hash-1", CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)}.Entry([]string{"Root", "API Tokens"}),
|
||||
Token{ID: "token-2", Name: "Browser", SecretHash: "hash-2", CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)}.Entry([]string{"Root", "API Tokens"}),
|
||||
},
|
||||
}
|
||||
|
||||
got, err := Entries(model)
|
||||
if err != nil {
|
||||
t.Fatalf("Entries() error = %v", err)
|
||||
}
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("len(Entries()) = %d, want 2", len(got))
|
||||
}
|
||||
if got[0].Name != "Browser" || got[1].Name != "CLI" {
|
||||
t.Fatalf("Entries() = %#v, want Browser then CLI by name sort", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertCreatesAndUpdatesVaultTokenEntry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := vault.Model{}
|
||||
token := Token{
|
||||
ID: "token-1",
|
||||
Name: "CLI",
|
||||
ClientName: "grpc-cli",
|
||||
SecretHash: "hash-1",
|
||||
CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
Upsert(&model, token)
|
||||
if len(model.Entries) != 1 {
|
||||
t.Fatalf("len(model.Entries) = %d, want 1", len(model.Entries))
|
||||
}
|
||||
if got := model.ChildGroups([]string{"Root"}); !slices.Equal(got, []string{"API Tokens"}) {
|
||||
t.Fatalf("ChildGroups(Root) = %v, want [API Tokens]", got)
|
||||
}
|
||||
|
||||
token.ClientName = "rotated-client"
|
||||
token.Disabled = true
|
||||
Upsert(&model, token)
|
||||
|
||||
got, err := Find(model, "token-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Find() error = %v", err)
|
||||
}
|
||||
if got.ClientName != "rotated-client" || !got.Disabled {
|
||||
t.Fatalf("Find() = %#v, want updated client name and disabled state", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRemovesTokenEntryByTokenID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
Token{ID: "token-1", Name: "CLI", SecretHash: "hash-1", CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)}.Entry(EntryPath),
|
||||
{ID: "entry-1", Title: "Regular Entry"},
|
||||
},
|
||||
}
|
||||
|
||||
if err := Delete(&model, "token-1"); err != nil {
|
||||
t.Fatalf("Delete() error = %v", err)
|
||||
}
|
||||
if len(model.Entries) != 1 || model.Entries[0].ID != "entry-1" {
|
||||
t.Fatalf("model.Entries after Delete = %#v, want only regular entry", model.Entries)
|
||||
}
|
||||
if err := Delete(&model, "missing"); err != ErrTokenNotFound {
|
||||
t.Fatalf("Delete(missing) error = %v, want %v", err, ErrTokenNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIssueRotateDisableAndRevokeToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)
|
||||
token, secret, err := Issue("Browser Connector", "browser-extension", nil, now)
|
||||
if err != nil {
|
||||
t.Fatalf("Issue() error = %v", err)
|
||||
}
|
||||
if token.ID == "" || secret == "" {
|
||||
t.Fatalf("Issue() returned empty token or secret: %#v %q", token, secret)
|
||||
}
|
||||
if token.SecretHash == secret {
|
||||
t.Fatal("SecretHash should not equal cleartext secret")
|
||||
}
|
||||
if token.Disabled {
|
||||
t.Fatal("Disabled = true, want false after Issue")
|
||||
}
|
||||
|
||||
rotated, newSecret, err := Rotate(token, now.Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("Rotate() error = %v", err)
|
||||
}
|
||||
if rotated.ID != token.ID {
|
||||
t.Fatalf("Rotate() changed ID from %q to %q", token.ID, rotated.ID)
|
||||
}
|
||||
if newSecret == secret {
|
||||
t.Fatal("Rotate() returned the same cleartext secret, want a new one")
|
||||
}
|
||||
if rotated.SecretHash == token.SecretHash {
|
||||
t.Fatal("Rotate() left SecretHash unchanged")
|
||||
}
|
||||
|
||||
disabled := Disable(rotated)
|
||||
if !disabled.Disabled {
|
||||
t.Fatal("Disable() did not set Disabled")
|
||||
}
|
||||
revoked := Revoke(disabled, now.Add(2*time.Hour))
|
||||
if !revoked.Disabled {
|
||||
t.Fatal("Revoke() should leave token disabled")
|
||||
}
|
||||
if revoked.RevokedAt == nil || !revoked.RevokedAt.Equal(now.Add(2*time.Hour)) {
|
||||
t.Fatalf("RevokedAt = %#v, want %v", revoked.RevokedAt, now.Add(2*time.Hour))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticateRejectsDisabledExpiredAndWrongSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)
|
||||
valid, secret, err := Issue("CLI", "cli-tool", nil, now)
|
||||
if err != nil {
|
||||
t.Fatalf("Issue() error = %v", err)
|
||||
}
|
||||
expired, expiredSecret, err := Issue("Expired", "cli-tool", nil, now)
|
||||
if err != nil {
|
||||
t.Fatalf("Issue(expired) error = %v", err)
|
||||
}
|
||||
expiredAt := now.Add(-time.Minute)
|
||||
expired.ExpiresAt = &expiredAt
|
||||
disabled, disabledSecret, err := Issue("Disabled", "cli-tool", nil, now)
|
||||
if err != nil {
|
||||
t.Fatalf("Issue(disabled) error = %v", err)
|
||||
}
|
||||
disabled = Disable(disabled)
|
||||
|
||||
tokens := []Token{expired, disabled, valid}
|
||||
if _, err := Authenticate(tokens, "wrong-secret", now); err != ErrInvalidToken {
|
||||
t.Fatalf("Authenticate(wrong-secret) error = %v, want %v", err, ErrInvalidToken)
|
||||
}
|
||||
if _, err := Authenticate([]Token{expired}, expiredSecret, now); err != ErrExpiredToken {
|
||||
t.Fatalf("Authenticate(expired) error = %v, want %v", err, ErrExpiredToken)
|
||||
}
|
||||
if _, err := Authenticate([]Token{disabled}, disabledSecret, now); err != ErrDisabledToken {
|
||||
t.Fatalf("Authenticate(disabled) error = %v, want %v", err, ErrDisabledToken)
|
||||
}
|
||||
got, err := Authenticate(tokens, secret, now)
|
||||
if err != nil {
|
||||
t.Fatalf("Authenticate(valid) error = %v", err)
|
||||
}
|
||||
if got.ID != valid.ID {
|
||||
t.Fatalf("Authenticate(valid).ID = %q, want %q", got.ID, valid.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvaluatePolicyDistinguishesAllowDenyAndPrompt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token := Token{
|
||||
ID: "token-1",
|
||||
Policies: []PolicyRule{
|
||||
{Effect: EffectAllow, Operation: OperationListEntries, Resource: Resource{Kind: ResourceGroup, Path: []string{"Root", "Internet"}}},
|
||||
{Effect: EffectDeny, Operation: OperationCopyPassword, Resource: Resource{Kind: ResourceEntry, EntryID: "amazon"}},
|
||||
},
|
||||
}
|
||||
|
||||
decision := Evaluate(token, OperationListEntries, Resource{Kind: ResourceGroup, Path: []string{"Root", "Internet"}})
|
||||
if decision != DecisionAllow {
|
||||
t.Fatalf("Evaluate(allow) = %q, want %q", decision, DecisionAllow)
|
||||
}
|
||||
|
||||
decision = Evaluate(token, OperationCopyPassword, Resource{Kind: ResourceEntry, EntryID: "amazon", Path: []string{"Root", "Internet"}})
|
||||
if decision != DecisionDeny {
|
||||
t.Fatalf("Evaluate(deny) = %q, want %q", decision, DecisionDeny)
|
||||
}
|
||||
|
||||
decision = Evaluate(token, OperationCopyPassword, Resource{Kind: ResourceEntry, EntryID: "github", Path: []string{"Root", "Internet"}})
|
||||
if decision != DecisionPrompt {
|
||||
t.Fatalf("Evaluate(prompt) = %q, want %q", decision, DecisionPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntryScopedDenyOverridesGroupAllow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token := Token{
|
||||
ID: "token-1",
|
||||
Policies: []PolicyRule{
|
||||
{Effect: EffectAllow, Operation: OperationCopyPassword, Resource: Resource{Kind: ResourceGroup, Path: []string{"Root", "Internet"}}},
|
||||
{Effect: EffectDeny, Operation: OperationCopyPassword, Resource: Resource{Kind: ResourceEntry, EntryID: "bank", Path: []string{"Root", "Internet"}}},
|
||||
},
|
||||
}
|
||||
|
||||
allowed := Evaluate(token, OperationCopyPassword, Resource{Kind: ResourceEntry, EntryID: "forum", Path: []string{"Root", "Internet"}})
|
||||
denied := Evaluate(token, OperationCopyPassword, Resource{Kind: ResourceEntry, EntryID: "bank", Path: []string{"Root", "Internet"}})
|
||||
if allowed != DecisionAllow || denied != DecisionDeny {
|
||||
t.Fatalf("Evaluate() allow/deny = %q/%q, want %q/%q", allowed, denied, DecisionAllow, DecisionDeny)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenEntryKeepsPoliciesStableAcrossRoundTripOrdering(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token := Token{
|
||||
ID: "token-1",
|
||||
Name: "CLI",
|
||||
ClientName: "cli",
|
||||
SecretHash: "hash",
|
||||
CreatedAt: time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC),
|
||||
Policies: []PolicyRule{
|
||||
{Effect: EffectDeny, Operation: OperationCopyPassword, Resource: Resource{Kind: ResourceEntry, EntryID: "bank"}},
|
||||
{Effect: EffectAllow, Operation: OperationListEntries, Resource: Resource{Kind: ResourceGroup, Path: []string{"Root", "Internet"}}},
|
||||
},
|
||||
}
|
||||
|
||||
got, ok, err := TokenFromEntry(token.Entry([]string{"Root", "API Tokens"}))
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("TokenFromEntry() error = %v, ok=%v", err, ok)
|
||||
}
|
||||
if !slices.EqualFunc(got.Policies, token.Policies, func(a, b PolicyRule) bool {
|
||||
return a.Effect == b.Effect && a.Operation == b.Operation && a.Resource.Kind == b.Resource.Kind && a.Resource.EntryID == b.Resource.EntryID && slices.Equal(a.Resource.Path, b.Resource.Path)
|
||||
}) {
|
||||
t.Fatalf("Policies after round trip = %#v, want %#v", got.Policies, token.Policies)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package appstate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
)
|
||||
|
||||
type SyncMode string
|
||||
|
||||
const (
|
||||
SyncModeManual SyncMode = "manual"
|
||||
SyncModeAutomaticOnOpenSave SyncMode = "automatic_on_open_save"
|
||||
)
|
||||
|
||||
type RemoteBinding struct {
|
||||
LocalVaultPath string `json:"localVaultPath"`
|
||||
RemoteProfileID string `json:"remoteProfileId"`
|
||||
CredentialEntryID string `json:"credentialEntryId"`
|
||||
SyncMode SyncMode `json:"syncMode,omitempty"`
|
||||
}
|
||||
|
||||
type ResolvedRemoteBinding struct {
|
||||
Profile vault.RemoteProfile
|
||||
Credentials vault.Entry
|
||||
}
|
||||
|
||||
type RemoteBindingInput struct {
|
||||
LocalVaultPath string
|
||||
RemoteProfileID string
|
||||
RemoteProfileName string
|
||||
BaseURL string
|
||||
RemotePath string
|
||||
CredentialEntryID string
|
||||
CredentialTitle string
|
||||
Username string
|
||||
Password string
|
||||
CredentialPath []string
|
||||
SyncMode SyncMode
|
||||
}
|
||||
|
||||
func (b RemoteBinding) Resolve(model vault.Model) (ResolvedRemoteBinding, error) {
|
||||
profile, err := model.RemoteProfileByID(b.RemoteProfileID)
|
||||
if err != nil {
|
||||
return ResolvedRemoteBinding{}, fmt.Errorf("resolve remote profile: %w", err)
|
||||
}
|
||||
credentials, err := model.EntryByID(b.CredentialEntryID)
|
||||
if err != nil {
|
||||
return ResolvedRemoteBinding{}, fmt.Errorf("resolve remote credentials: %w", err)
|
||||
}
|
||||
return ResolvedRemoteBinding{
|
||||
Profile: profile,
|
||||
Credentials: credentials,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ConfigureRemoteBinding(model *vault.Model, input RemoteBindingInput) (RemoteBinding, error) {
|
||||
if model == nil {
|
||||
return RemoteBinding{}, fmt.Errorf("model is required")
|
||||
}
|
||||
|
||||
input.LocalVaultPath = strings.TrimSpace(input.LocalVaultPath)
|
||||
input.RemoteProfileID = strings.TrimSpace(input.RemoteProfileID)
|
||||
input.RemoteProfileName = strings.TrimSpace(input.RemoteProfileName)
|
||||
input.BaseURL = strings.TrimSpace(input.BaseURL)
|
||||
input.RemotePath = strings.TrimSpace(input.RemotePath)
|
||||
input.CredentialEntryID = strings.TrimSpace(input.CredentialEntryID)
|
||||
input.CredentialTitle = strings.TrimSpace(input.CredentialTitle)
|
||||
input.Username = strings.TrimSpace(input.Username)
|
||||
|
||||
switch {
|
||||
case input.LocalVaultPath == "":
|
||||
return RemoteBinding{}, fmt.Errorf("local vault path is required")
|
||||
case input.RemoteProfileID == "":
|
||||
return RemoteBinding{}, fmt.Errorf("remote profile id is required")
|
||||
case input.BaseURL == "":
|
||||
return RemoteBinding{}, fmt.Errorf("remote base URL is required")
|
||||
case input.RemotePath == "":
|
||||
return RemoteBinding{}, fmt.Errorf("remote path is required")
|
||||
case input.CredentialEntryID == "":
|
||||
return RemoteBinding{}, fmt.Errorf("credential entry id is required")
|
||||
case input.Password == "":
|
||||
return RemoteBinding{}, fmt.Errorf("credential password is required")
|
||||
}
|
||||
|
||||
if input.RemoteProfileName == "" {
|
||||
input.RemoteProfileName = input.RemoteProfileID
|
||||
}
|
||||
if input.CredentialTitle == "" {
|
||||
input.CredentialTitle = "Remote Sign-In"
|
||||
}
|
||||
|
||||
model.UpsertRemoteProfile(vault.RemoteProfile{
|
||||
ID: input.RemoteProfileID,
|
||||
Name: input.RemoteProfileName,
|
||||
Backend: vault.RemoteBackendWebDAV,
|
||||
BaseURL: input.BaseURL,
|
||||
Path: input.RemotePath,
|
||||
})
|
||||
model.UpsertEntry(vault.Entry{
|
||||
ID: input.CredentialEntryID,
|
||||
Title: input.CredentialTitle,
|
||||
Username: input.Username,
|
||||
Password: input.Password,
|
||||
URL: input.BaseURL,
|
||||
Path: append([]string(nil), input.CredentialPath...),
|
||||
})
|
||||
|
||||
return RemoteBinding{
|
||||
LocalVaultPath: input.LocalVaultPath,
|
||||
RemoteProfileID: input.RemoteProfileID,
|
||||
CredentialEntryID: input.CredentialEntryID,
|
||||
SyncMode: normalizeSyncMode(input.SyncMode),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RemoveRemoteBinding(model *vault.Model, binding RemoteBinding) error {
|
||||
if model == nil {
|
||||
return fmt.Errorf("model is required")
|
||||
}
|
||||
if strings.TrimSpace(binding.RemoteProfileID) == "" {
|
||||
return fmt.Errorf("remote profile id is required")
|
||||
}
|
||||
if strings.TrimSpace(binding.CredentialEntryID) == "" {
|
||||
return fmt.Errorf("credential entry id is required")
|
||||
}
|
||||
|
||||
model.RemoveRemoteProfileByID(binding.RemoteProfileID)
|
||||
model.RemoveEntryByID(binding.CredentialEntryID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeSyncMode(mode SyncMode) SyncMode {
|
||||
switch mode {
|
||||
case SyncModeAutomaticOnOpenSave:
|
||||
return SyncModeAutomaticOnOpenSave
|
||||
default:
|
||||
return SyncModeManual
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,250 @@
|
||||
package appstate
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
)
|
||||
|
||||
func TestRemoteBindingResolveUsesVaultProfileAndCredentialEntry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
{
|
||||
ID: "linuscaldwell-webdav",
|
||||
Title: "Bellagio WebDAV Sign-In",
|
||||
Username: "linuscaldwell",
|
||||
Password: "bellagio-pass-1",
|
||||
Path: []string{"Crew", "Internet"},
|
||||
},
|
||||
},
|
||||
RemoteProfiles: []vault.RemoteProfile{
|
||||
{
|
||||
ID: "bellagio-webdav",
|
||||
Name: "Bellagio Vault",
|
||||
Backend: vault.RemoteBackendWebDAV,
|
||||
BaseURL: "https://dav.example.invalid/remote.php/dav",
|
||||
Path: "files/bellagio/keepass.kdbx",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
binding := RemoteBinding{
|
||||
LocalVaultPath: "/tmp/bellagio.kdbx",
|
||||
RemoteProfileID: "bellagio-webdav",
|
||||
CredentialEntryID: "linuscaldwell-webdav",
|
||||
SyncMode: SyncModeAutomaticOnOpenSave,
|
||||
}
|
||||
|
||||
resolved, err := binding.Resolve(model)
|
||||
if err != nil {
|
||||
t.Fatalf("Resolve() error = %v", err)
|
||||
}
|
||||
if got := resolved.Profile.BaseURL; got != "https://dav.example.invalid/remote.php/dav" {
|
||||
t.Fatalf("resolved profile base URL = %q, want remote.php/dav URL", got)
|
||||
}
|
||||
if got := resolved.Profile.Path; got != "files/bellagio/keepass.kdbx" {
|
||||
t.Fatalf("resolved profile path = %q, want files/bellagio/keepass.kdbx", got)
|
||||
}
|
||||
if got := resolved.Credentials.Username; got != "linuscaldwell" {
|
||||
t.Fatalf("resolved credentials username = %q, want linuscaldwell", got)
|
||||
}
|
||||
if got := resolved.Credentials.Password; got != "bellagio-pass-1" {
|
||||
t.Fatalf("resolved credentials password = %q, want bellagio-pass-1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteBindingResolveFailsWhenVaultReferenceIsMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
{ID: "linuscaldwell-webdav", Title: "Bellagio WebDAV Sign-In"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := (RemoteBinding{
|
||||
LocalVaultPath: "/tmp/bellagio.kdbx",
|
||||
RemoteProfileID: "bellagio-webdav",
|
||||
CredentialEntryID: "missing-creds",
|
||||
}).Resolve(model)
|
||||
if !errors.Is(err, vault.ErrRemoteProfileNotFound) {
|
||||
t.Fatalf("Resolve() error = %v, want ErrRemoteProfileNotFound first", err)
|
||||
}
|
||||
|
||||
model.RemoteProfiles = []vault.RemoteProfile{{
|
||||
ID: "bellagio-webdav",
|
||||
Name: "Bellagio Vault",
|
||||
Backend: vault.RemoteBackendWebDAV,
|
||||
BaseURL: "https://dav.example.invalid/remote.php/dav",
|
||||
Path: "files/bellagio/keepass.kdbx",
|
||||
}}
|
||||
|
||||
_, err = (RemoteBinding{
|
||||
LocalVaultPath: "/tmp/bellagio.kdbx",
|
||||
RemoteProfileID: "bellagio-webdav",
|
||||
CredentialEntryID: "missing-creds",
|
||||
}).Resolve(model)
|
||||
if !errors.Is(err, vault.ErrEntryNotFound) {
|
||||
t.Fatalf("Resolve() error = %v, want ErrEntryNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteBindingJSONStoresOnlyNonSecretReferences(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal(RemoteBinding{
|
||||
LocalVaultPath: "/tmp/bellagio.kdbx",
|
||||
RemoteProfileID: "bellagio-webdav",
|
||||
CredentialEntryID: "remote-creds-1",
|
||||
SyncMode: SyncModeAutomaticOnOpenSave,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal(RemoteBinding) error = %v", err)
|
||||
}
|
||||
|
||||
text := string(content)
|
||||
for _, disallowed := range []string{"bellagio-pass-1", "password", "username", "baseUrl"} {
|
||||
if strings.Contains(text, disallowed) {
|
||||
t.Fatalf("binding JSON %q unexpectedly contains %q", text, disallowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureRemoteBindingStoresProfileAndCredentialsInVault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var model vault.Model
|
||||
|
||||
binding, err := ConfigureRemoteBinding(&model, RemoteBindingInput{
|
||||
LocalVaultPath: "/tmp/bellagio.kdbx",
|
||||
RemoteProfileID: "bellagio-webdav",
|
||||
RemoteProfileName: "Bellagio Vault",
|
||||
BaseURL: "https://dav.example.invalid/remote.php/dav",
|
||||
RemotePath: "files/bellagio/keepass.kdbx",
|
||||
CredentialEntryID: "remote-creds-1",
|
||||
CredentialTitle: "Bellagio WebDAV Sign-In",
|
||||
Username: "linuscaldwell",
|
||||
Password: "bellagio-pass-1",
|
||||
CredentialPath: []string{"Crew", "Internet"},
|
||||
SyncMode: SyncModeAutomaticOnOpenSave,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ConfigureRemoteBinding() error = %v", err)
|
||||
}
|
||||
|
||||
if len(model.RemoteProfiles) != 1 {
|
||||
t.Fatalf("len(RemoteProfiles) = %d, want 1", len(model.RemoteProfiles))
|
||||
}
|
||||
if got := model.RemoteProfiles[0].BaseURL; got != "https://dav.example.invalid/remote.php/dav" {
|
||||
t.Fatalf("stored remote profile base URL = %q, want remote.php/dav URL", got)
|
||||
}
|
||||
|
||||
credentials, err := model.EntryByID("remote-creds-1")
|
||||
if err != nil {
|
||||
t.Fatalf("EntryByID(remote-creds-1) error = %v", err)
|
||||
}
|
||||
if credentials.Username != "linuscaldwell" || credentials.Password != "bellagio-pass-1" {
|
||||
t.Fatalf("stored credential entry = %#v, want linuscaldwell/bellagio-pass-1", credentials)
|
||||
}
|
||||
if credentials.URL != "https://dav.example.invalid/remote.php/dav" {
|
||||
t.Fatalf("stored credential entry URL = %q, want remote.php/dav URL", credentials.URL)
|
||||
}
|
||||
|
||||
if binding.LocalVaultPath != "/tmp/bellagio.kdbx" {
|
||||
t.Fatalf("binding LocalVaultPath = %q, want /tmp/bellagio.kdbx", binding.LocalVaultPath)
|
||||
}
|
||||
if binding.RemoteProfileID != "bellagio-webdav" || binding.CredentialEntryID != "remote-creds-1" {
|
||||
t.Fatalf("binding = %#v, want only vault references", binding)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureRemoteBindingRejectsIncompleteInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
input RemoteBindingInput
|
||||
}{
|
||||
{
|
||||
name: "missing_local_vault_path",
|
||||
input: RemoteBindingInput{
|
||||
RemoteProfileID: "bellagio-webdav",
|
||||
BaseURL: "https://dav.example.invalid/remote.php/dav",
|
||||
RemotePath: "files/bellagio/keepass.kdbx",
|
||||
CredentialEntryID: "remote-creds-1",
|
||||
Password: "bellagio-pass-1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_remote_base_url",
|
||||
input: RemoteBindingInput{
|
||||
LocalVaultPath: "/tmp/bellagio.kdbx",
|
||||
RemoteProfileID: "bellagio-webdav",
|
||||
RemotePath: "files/bellagio/keepass.kdbx",
|
||||
CredentialEntryID: "remote-creds-1",
|
||||
Password: "bellagio-pass-1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_credential_entry_id",
|
||||
input: RemoteBindingInput{
|
||||
LocalVaultPath: "/tmp/bellagio.kdbx",
|
||||
RemoteProfileID: "bellagio-webdav",
|
||||
BaseURL: "https://dav.example.invalid/remote.php/dav",
|
||||
RemotePath: "files/bellagio/keepass.kdbx",
|
||||
Password: "bellagio-pass-1",
|
||||
},
|
||||
},
|
||||
} {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var model vault.Model
|
||||
if _, err := ConfigureRemoteBinding(&model, tc.input); err == nil {
|
||||
t.Fatalf("ConfigureRemoteBinding(%#v) error = nil, want validation error", tc.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveRemoteBindingRemovesProfileAndCredentialsFromVault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := vault.Model{
|
||||
Entries: []vault.Entry{{
|
||||
ID: "remote-creds-1",
|
||||
Title: "Bellagio WebDAV Sign-In",
|
||||
Username: "linuscaldwell",
|
||||
Password: "bellagio-pass-1",
|
||||
}},
|
||||
RemoteProfiles: []vault.RemoteProfile{{
|
||||
ID: "bellagio-webdav",
|
||||
Name: "Bellagio Vault",
|
||||
Backend: vault.RemoteBackendWebDAV,
|
||||
BaseURL: "https://dav.example.invalid/remote.php/dav",
|
||||
Path: "files/bellagio/keepass.kdbx",
|
||||
}},
|
||||
}
|
||||
|
||||
err := RemoveRemoteBinding(&model, RemoteBinding{
|
||||
LocalVaultPath: "/tmp/bellagio.kdbx",
|
||||
RemoteProfileID: "bellagio-webdav",
|
||||
CredentialEntryID: "remote-creds-1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveRemoteBinding() error = %v", err)
|
||||
}
|
||||
|
||||
if got := len(model.RemoteProfiles); got != 0 {
|
||||
t.Fatalf("len(RemoteProfiles) = %d, want 0", got)
|
||||
}
|
||||
if _, err := model.EntryByID("remote-creds-1"); !errors.Is(err, vault.ErrEntryNotFound) {
|
||||
t.Fatalf("EntryByID(remote-creds-1) error = %v, want ErrEntryNotFound", err)
|
||||
}
|
||||
}
|
||||
@@ -31,18 +31,18 @@ import (
|
||||
"gioui.org/widget"
|
||||
"gioui.org/widget/material"
|
||||
"gioui.org/x/explorer"
|
||||
"git.julianfamily.org/keepassgo/api"
|
||||
"git.julianfamily.org/keepassgo/apiapproval"
|
||||
"git.julianfamily.org/keepassgo/apiaudit"
|
||||
"git.julianfamily.org/keepassgo/apitokens"
|
||||
"git.julianfamily.org/keepassgo/appstate"
|
||||
keepassassets "git.julianfamily.org/keepassgo/assets"
|
||||
"git.julianfamily.org/keepassgo/autofillcache"
|
||||
"git.julianfamily.org/keepassgo/clipboard"
|
||||
"git.julianfamily.org/keepassgo/passwords"
|
||||
"git.julianfamily.org/keepassgo/session"
|
||||
"git.julianfamily.org/keepassgo/vault"
|
||||
"git.julianfamily.org/keepassgo/webdav"
|
||||
"git.julianfamily.org/keepassgo/internal/api"
|
||||
"git.julianfamily.org/keepassgo/internal/apiapproval"
|
||||
"git.julianfamily.org/keepassgo/internal/apiaudit"
|
||||
"git.julianfamily.org/keepassgo/internal/apitokens"
|
||||
"git.julianfamily.org/keepassgo/internal/appstate"
|
||||
keepassassets "git.julianfamily.org/keepassgo/internal/assets"
|
||||
"git.julianfamily.org/keepassgo/internal/autofillcache"
|
||||
"git.julianfamily.org/keepassgo/internal/clipboard"
|
||||
"git.julianfamily.org/keepassgo/internal/passwords"
|
||||
"git.julianfamily.org/keepassgo/internal/session"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
"git.julianfamily.org/keepassgo/internal/webdav"
|
||||
"golang.org/x/exp/shiny/materialdesign/icons"
|
||||
)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
gioclipboard "gioui.org/io/clipboard"
|
||||
"gioui.org/layout"
|
||||
|
||||
appclipboard "git.julianfamily.org/keepassgo/clipboard"
|
||||
appclipboard "git.julianfamily.org/keepassgo/internal/clipboard"
|
||||
)
|
||||
|
||||
type clipboardCommandWriter struct {
|
||||
|
||||
@@ -21,15 +21,15 @@ import (
|
||||
"gioui.org/unit"
|
||||
"gioui.org/widget"
|
||||
|
||||
"git.julianfamily.org/keepassgo/apiapproval"
|
||||
"git.julianfamily.org/keepassgo/apiaudit"
|
||||
"git.julianfamily.org/keepassgo/apitokens"
|
||||
"git.julianfamily.org/keepassgo/appstate"
|
||||
"git.julianfamily.org/keepassgo/clipboard"
|
||||
"git.julianfamily.org/keepassgo/passwords"
|
||||
"git.julianfamily.org/keepassgo/session"
|
||||
"git.julianfamily.org/keepassgo/vault"
|
||||
"git.julianfamily.org/keepassgo/webdav"
|
||||
"git.julianfamily.org/keepassgo/internal/apiapproval"
|
||||
"git.julianfamily.org/keepassgo/internal/apiaudit"
|
||||
"git.julianfamily.org/keepassgo/internal/apitokens"
|
||||
"git.julianfamily.org/keepassgo/internal/appstate"
|
||||
"git.julianfamily.org/keepassgo/internal/clipboard"
|
||||
"git.julianfamily.org/keepassgo/internal/passwords"
|
||||
"git.julianfamily.org/keepassgo/internal/session"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
"git.julianfamily.org/keepassgo/internal/webdav"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"gioui.org/unit"
|
||||
"gioui.org/widget"
|
||||
"gioui.org/widget/material"
|
||||
"git.julianfamily.org/keepassgo/apiaudit"
|
||||
"git.julianfamily.org/keepassgo/apitokens"
|
||||
"git.julianfamily.org/keepassgo/internal/apiaudit"
|
||||
"git.julianfamily.org/keepassgo/internal/apitokens"
|
||||
)
|
||||
|
||||
func apiOperations() []apitokens.Operation {
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"gioui.org/widget"
|
||||
"git.julianfamily.org/keepassgo/clipboard"
|
||||
"git.julianfamily.org/keepassgo/passwords"
|
||||
"git.julianfamily.org/keepassgo/vault"
|
||||
"git.julianfamily.org/keepassgo/internal/clipboard"
|
||||
"git.julianfamily.org/keepassgo/internal/passwords"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
)
|
||||
|
||||
func (u *ui) attachmentInput() (string, []byte, error) {
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"gioui.org/unit"
|
||||
"gioui.org/widget"
|
||||
"gioui.org/widget/material"
|
||||
"git.julianfamily.org/keepassgo/appstate"
|
||||
"git.julianfamily.org/keepassgo/internal/appstate"
|
||||
)
|
||||
|
||||
func (u *ui) lifecycleControls(gtx layout.Context) layout.Dimensions {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"gioui.org/io/key"
|
||||
"git.julianfamily.org/keepassgo/appstate"
|
||||
"git.julianfamily.org/keepassgo/internal/appstate"
|
||||
)
|
||||
|
||||
type focusID string
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"gioui.org/unit"
|
||||
"gioui.org/widget"
|
||||
"gioui.org/widget/material"
|
||||
"git.julianfamily.org/keepassgo/vault"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"gioui.org/io/key"
|
||||
"gioui.org/layout"
|
||||
|
||||
"git.julianfamily.org/keepassgo/clipboard"
|
||||
"git.julianfamily.org/keepassgo/internal/clipboard"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"git.julianfamily.org/keepassgo/appstate"
|
||||
"git.julianfamily.org/keepassgo/internal/appstate"
|
||||
)
|
||||
|
||||
type syncMenuModel struct {
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package assets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"image"
|
||||
"image/png"
|
||||
)
|
||||
|
||||
//go:embed *.png *.svg
|
||||
var files embed.FS
|
||||
|
||||
func MustPNG(name string) image.Image {
|
||||
data, err := files.ReadFile(name)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
img, err := png.Decode(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
After Width: | Height: | Size: 14 KiB |
@@ -0,0 +1,18 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="256" height="256" viewBox="0 0 256 256" fill="none">
|
||||
<title>KeePassGO icon</title>
|
||||
<desc>Vault-shaped mark with lock, layered record lines, and navigation accent.</desc>
|
||||
<polygon points="128,18 214,38 214,176 128,218 42,176 42,38" fill="#3F4E63"/>
|
||||
<polygon points="128,34 198,50 198,166 128,200 58,166 58,50" fill="#55657A" opacity="0.16"/>
|
||||
<polygon points="128,178 128,218 42,176 42,160" fill="#6D89A8"/>
|
||||
<rect x="44" y="78" width="62" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
|
||||
<rect x="44" y="98" width="52" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
|
||||
|
||||
<path d="M100 86C100 70.536 112.536 58 128 58C143.464 58 156 70.536 156 86V106H142V86C142 78.268 135.732 72 128 72C120.268 72 114 78.268 114 86V106H100V86Z" fill="#FFFFFF"/>
|
||||
<rect x="76" y="102" width="104" height="64" rx="10" fill="#FFFFFF"/>
|
||||
<rect x="80" y="106" width="96" height="56" rx="8" fill="#E9EDF0"/>
|
||||
<path d="M128 118C139.046 118 148 126.954 148 138C148 145.669 143.68 152.329 137.338 155.7V173H118.662V155.7C112.32 152.329 108 145.669 108 138C108 126.954 116.954 118 128 118Z" fill="#3F4E63"/>
|
||||
<circle cx="128" cy="138" r="8" fill="#FFFFFF" opacity="0.2"/>
|
||||
|
||||
<path d="M178 130H214V150H166V144L178 130Z" fill="#E79A17"/>
|
||||
<path d="M178 130H214V140H174L166 144V144L178 130Z" fill="#F0B13D" opacity="0.35"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
|
After Width: | Height: | Size: 40 KiB |
@@ -0,0 +1,22 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="920" height="260" viewBox="0 0 920 260" fill="none">
|
||||
<title>KeePassGO horizontal logo</title>
|
||||
<desc>KeePassGO symbol with wordmark for light-theme desktop UI.</desc>
|
||||
<defs>
|
||||
<symbol id="kpg-icon" viewBox="0 0 256 256">
|
||||
<polygon points="128,18 214,38 214,176 128,218 42,176 42,38" fill="#3F4E63"/>
|
||||
<polygon points="128,34 198,50 198,166 128,200 58,166 58,50" fill="#55657A" opacity="0.16"/>
|
||||
<polygon points="128,178 128,218 42,176 42,160" fill="#6D89A8"/>
|
||||
<rect x="44" y="78" width="62" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
|
||||
<rect x="44" y="98" width="52" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
|
||||
<path d="M100 86C100 70.536 112.536 58 128 58C143.464 58 156 70.536 156 86V106H142V86C142 78.268 135.732 72 128 72C120.268 72 114 78.268 114 86V106H100V86Z" fill="#FFFFFF"/>
|
||||
<rect x="76" y="102" width="104" height="64" rx="10" fill="#FFFFFF"/>
|
||||
<rect x="80" y="106" width="96" height="56" rx="8" fill="#E9EDF0"/>
|
||||
<path d="M128 118C139.046 118 148 126.954 148 138C148 145.669 143.68 152.329 137.338 155.7V173H118.662V155.7C112.32 152.329 108 145.669 108 138C108 126.954 116.954 118 128 118Z" fill="#3F4E63"/>
|
||||
<path d="M178 130H214V150H166V144L178 130Z" fill="#E79A17"/>
|
||||
<path d="M178 130H214V140H174L166 144V144L178 130Z" fill="#F0B13D" opacity="0.35"/>
|
||||
</symbol>
|
||||
</defs>
|
||||
<rect width="920" height="260" fill="transparent"/>
|
||||
<use href="#kpg-icon" x="24" y="20" width="176" height="176"/>
|
||||
<text x="220" y="132" font-family="Inter, Noto Sans, Segoe UI, Arial, sans-serif" font-size="84" font-weight="650" fill="#3F4E63">KeePassGO</text>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.7 KiB |
@@ -0,0 +1,75 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="1280" height="800" viewBox="0 0 1280 800" fill="none">
|
||||
<title>KeePassGO splash screen</title>
|
||||
<desc>Light-theme desktop splash screen with structured panels and the KeePassGO logo.</desc>
|
||||
<defs>
|
||||
<symbol id="kpg-icon" viewBox="0 0 256 256">
|
||||
<polygon points="128,18 214,38 214,176 128,218 42,176 42,38" fill="#3F4E63"/>
|
||||
<polygon points="128,34 198,50 198,166 128,200 58,166 58,50" fill="#55657A" opacity="0.16"/>
|
||||
<polygon points="128,178 128,218 42,176 42,160" fill="#6D89A8"/>
|
||||
<rect x="44" y="78" width="62" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
|
||||
<rect x="44" y="98" width="52" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
|
||||
<path d="M100 86C100 70.536 112.536 58 128 58C143.464 58 156 70.536 156 86V106H142V86C142 78.268 135.732 72 128 72C120.268 72 114 78.268 114 86V106H100V86Z" fill="#FFFFFF"/>
|
||||
<rect x="76" y="102" width="104" height="64" rx="10" fill="#FFFFFF"/>
|
||||
<rect x="80" y="106" width="96" height="56" rx="8" fill="#E9EDF0"/>
|
||||
<path d="M128 118C139.046 118 148 126.954 148 138C148 145.669 143.68 152.329 137.338 155.7V173H118.662V155.7C112.32 152.329 108 145.669 108 138C108 126.954 116.954 118 128 118Z" fill="#3F4E63"/>
|
||||
<path d="M178 130H214V150H166V144L178 130Z" fill="#E79A17"/>
|
||||
<path d="M178 130H214V140H174L166 144V144L178 130Z" fill="#F0B13D" opacity="0.35"/>
|
||||
</symbol>
|
||||
<linearGradient id="fade" x1="0" y1="0" x2="0" y2="1">
|
||||
<stop offset="0" stop-color="#D9E0E6" stop-opacity="0.9"/>
|
||||
<stop offset="1" stop-color="#D9E0E6" stop-opacity="0.2"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
|
||||
<rect width="1280" height="800" fill="#F7F7F5"/>
|
||||
|
||||
<g opacity="0.9">
|
||||
<rect x="48" y="48" width="238" height="188" fill="#EEF2F4"/>
|
||||
<rect x="286" y="48" width="152" height="104" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="438" y="48" width="214" height="104" fill="#E7EDF1"/>
|
||||
<rect x="652" y="48" width="146" height="188" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="798" y="48" width="224" height="118" fill="#EEF2F4"/>
|
||||
<rect x="1022" y="48" width="210" height="188" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
|
||||
<rect x="48" y="236" width="128" height="164" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="176" y="236" width="262" height="164" fill="#E7EDF1"/>
|
||||
<rect x="438" y="236" width="160" height="84" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="598" y="236" width="200" height="164" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="798" y="236" width="434" height="164" fill="#EEF2F4"/>
|
||||
|
||||
<rect x="48" y="400" width="224" height="128" fill="#EEF2F4"/>
|
||||
<rect x="272" y="400" width="166" height="128" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="438" y="400" width="360" height="128" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="798" y="400" width="152" height="128" fill="#E7EDF1"/>
|
||||
<rect x="950" y="400" width="282" height="128" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
|
||||
<rect x="48" y="528" width="114" height="224" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="162" y="528" width="276" height="224" fill="#E7EDF1"/>
|
||||
<rect x="438" y="528" width="212" height="224" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="650" y="528" width="318" height="224" fill="#EEF2F4"/>
|
||||
<rect x="968" y="528" width="264" height="224" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
</g>
|
||||
|
||||
<g opacity="0.45">
|
||||
<path d="M48 152H1232" stroke="#C7D0D8"/>
|
||||
<path d="M48 320H1232" stroke="#C7D0D8"/>
|
||||
<path d="M48 528H1232" stroke="#C7D0D8"/>
|
||||
<path d="M176 48V752" stroke="#C7D0D8"/>
|
||||
<path d="M438 48V752" stroke="#C7D0D8"/>
|
||||
<path d="M798 48V752" stroke="#C7D0D8"/>
|
||||
<path d="M1022 48V752" stroke="#C7D0D8"/>
|
||||
</g>
|
||||
|
||||
<rect x="290" y="190" width="700" height="420" rx="18" fill="#F7F7F5" opacity="0.92"/>
|
||||
<rect x="290" y="190" width="700" height="420" rx="18" stroke="#C7D0D8"/>
|
||||
<rect x="290" y="190" width="700" height="120" rx="18" fill="url(#fade)"/>
|
||||
|
||||
<use href="#kpg-icon" x="530" y="232" width="220" height="220"/>
|
||||
<text x="640" y="530" text-anchor="middle" font-family="Inter, Noto Sans, Segoe UI, Arial, sans-serif" font-size="88" font-weight="650" fill="#3F4E63">KeePassGO</text>
|
||||
<text x="640" y="582" text-anchor="middle" font-family="Inter, Noto Sans, Segoe UI, Arial, sans-serif" font-size="28" font-weight="500" fill="#55657A">KeePass-compatible password manager</text>
|
||||
|
||||
<g opacity="0.9">
|
||||
<rect x="516" y="634" width="248" height="10" rx="5" fill="#D9E0E6"/>
|
||||
<rect x="516" y="634" width="132" height="10" rx="5" fill="#6D89A8"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 4.5 KiB |
|
After Width: | Height: | Size: 47 KiB |
@@ -0,0 +1,41 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="1024" height="1024" viewBox="0 0 1024 1024" fill="none">
|
||||
<title>KeePassGO square splash preview</title>
|
||||
<desc>Square crop of the KeePassGO splash system.</desc>
|
||||
<defs>
|
||||
<symbol id="kpg-icon" viewBox="0 0 256 256">
|
||||
<polygon points="128,18 214,38 214,176 128,218 42,176 42,38" fill="#3F4E63"/>
|
||||
<polygon points="128,34 198,50 198,166 128,200 58,166 58,50" fill="#55657A" opacity="0.16"/>
|
||||
<polygon points="128,178 128,218 42,176 42,160" fill="#6D89A8"/>
|
||||
<rect x="44" y="78" width="62" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
|
||||
<rect x="44" y="98" width="52" height="10" rx="2" fill="#FFFFFF" opacity="0.92"/>
|
||||
<path d="M100 86C100 70.536 112.536 58 128 58C143.464 58 156 70.536 156 86V106H142V86C142 78.268 135.732 72 128 72C120.268 72 114 78.268 114 86V106H100V86Z" fill="#FFFFFF"/>
|
||||
<rect x="76" y="102" width="104" height="64" rx="10" fill="#FFFFFF"/>
|
||||
<rect x="80" y="106" width="96" height="56" rx="8" fill="#E9EDF0"/>
|
||||
<path d="M128 118C139.046 118 148 126.954 148 138C148 145.669 143.68 152.329 137.338 155.7V173H118.662V155.7C112.32 152.329 108 145.669 108 138C108 126.954 116.954 118 128 118Z" fill="#3F4E63"/>
|
||||
<path d="M178 130H214V150H166V144L178 130Z" fill="#E79A17"/>
|
||||
<path d="M178 130H214V140H174L166 144V144L178 130Z" fill="#F0B13D" opacity="0.35"/>
|
||||
</symbol>
|
||||
</defs>
|
||||
|
||||
<rect width="1024" height="1024" fill="#F7F7F5"/>
|
||||
<g opacity="0.9">
|
||||
<rect x="32" y="32" width="180" height="184" fill="#EEF2F4"/>
|
||||
<rect x="212" y="32" width="264" height="140" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="476" y="32" width="180" height="184" fill="#E7EDF1"/>
|
||||
<rect x="656" y="32" width="336" height="140" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="656" y="172" width="336" height="180" fill="#EEF2F4"/>
|
||||
<rect x="32" y="216" width="236" height="210" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="268" y="216" width="388" height="210" fill="#EEF2F4"/>
|
||||
<rect x="32" y="426" width="180" height="268" fill="#E7EDF1"/>
|
||||
<rect x="212" y="426" width="312" height="268" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="524" y="426" width="468" height="268" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="32" y="694" width="280" height="298" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
<rect x="312" y="694" width="268" height="298" fill="#EEF2F4"/>
|
||||
<rect x="580" y="694" width="412" height="298" fill="#F7F7F5" stroke="#C7D0D8"/>
|
||||
</g>
|
||||
<rect x="162" y="190" width="700" height="644" rx="24" fill="#F7F7F5" opacity="0.95"/>
|
||||
<rect x="162" y="190" width="700" height="644" rx="24" stroke="#C7D0D8"/>
|
||||
<use href="#kpg-icon" x="332" y="252" width="360" height="360"/>
|
||||
<text x="512" y="680" text-anchor="middle" font-family="Inter, Noto Sans, Segoe UI, Arial, sans-serif" font-size="88" font-weight="650" fill="#3F4E63">KeePassGO</text>
|
||||
<text x="512" y="740" text-anchor="middle" font-family="Inter, Noto Sans, Segoe UI, Arial, sans-serif" font-size="28" font-weight="500" fill="#55657A">KeePass-compatible password manager</text>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.0 KiB |
@@ -0,0 +1,296 @@
|
||||
package autofillcache
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
)
|
||||
|
||||
type Entry struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
URL string `json:"url"`
|
||||
Host string `json:"host"`
|
||||
Targets []string `json:"targets,omitempty"`
|
||||
Path []string `json:"path,omitempty"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
Entries []Entry `json:"entries"`
|
||||
}
|
||||
|
||||
type MatchStatus string
|
||||
|
||||
const (
|
||||
MatchStatusNone MatchStatus = ""
|
||||
MatchStatusFound MatchStatus = "found"
|
||||
MatchStatusAmbiguous MatchStatus = "ambiguous"
|
||||
MatchStatusMissing MatchStatus = "missing"
|
||||
)
|
||||
|
||||
type MatchResult struct {
|
||||
Status MatchStatus `json:"status"`
|
||||
Entry Entry `json:"entry,omitempty"`
|
||||
}
|
||||
|
||||
func Match(cache File, webURL string) (Entry, bool) {
|
||||
result := Resolve(cache, webURL)
|
||||
return result.Entry, result.Status == MatchStatusFound
|
||||
}
|
||||
|
||||
func Resolve(cache File, webURL string) MatchResult {
|
||||
target := normalizeURL(webURL)
|
||||
if target.host == "" {
|
||||
return MatchResult{Status: MatchStatusMissing}
|
||||
}
|
||||
|
||||
exactHost := make([]Entry, 0)
|
||||
parentHost := make([]Entry, 0)
|
||||
for _, entry := range cache.Entries {
|
||||
if entryMatchesHost(entry, target.host) {
|
||||
exactHost = append(exactHost, entry)
|
||||
continue
|
||||
}
|
||||
if entryMatchesParentHost(entry, target.host) {
|
||||
parentHost = append(parentHost, entry)
|
||||
}
|
||||
}
|
||||
|
||||
if result := chooseEntry(target, exactHost); result.Status != MatchStatusMissing {
|
||||
return result
|
||||
}
|
||||
return chooseEntry(target, parentHost)
|
||||
}
|
||||
|
||||
func Build(model vault.Model, now time.Time) File {
|
||||
entries := make([]Entry, 0, len(model.Entries))
|
||||
for _, item := range model.Entries {
|
||||
targets := collectTargets(item)
|
||||
host := normalizeHost(item.URL)
|
||||
if host == "" {
|
||||
for _, target := range targets {
|
||||
host = normalizeHost(target)
|
||||
if host != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(item.Username) == "" || strings.TrimSpace(item.Password) == "" {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, Entry{
|
||||
ID: item.ID,
|
||||
Title: item.Title,
|
||||
Username: item.Username,
|
||||
Password: item.Password,
|
||||
URL: item.URL,
|
||||
Host: host,
|
||||
Targets: targets,
|
||||
Path: append([]string(nil), item.Path...),
|
||||
})
|
||||
}
|
||||
return File{
|
||||
UpdatedAt: now.UTC().Format(time.RFC3339),
|
||||
Entries: entries,
|
||||
}
|
||||
}
|
||||
|
||||
func Write(path string, model vault.Model, now time.Time) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.MarshalIndent(Build(model, now), "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
}
|
||||
|
||||
func Clear(path string) error {
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeHost(raw string) string {
|
||||
return normalizeURL(raw).host
|
||||
}
|
||||
|
||||
type normalizedTarget struct {
|
||||
host string
|
||||
path string
|
||||
url string
|
||||
}
|
||||
|
||||
func normalizeURL(raw string) normalizedTarget {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return normalizedTarget{}
|
||||
}
|
||||
if !strings.Contains(value, "://") {
|
||||
value = "https://" + value
|
||||
}
|
||||
parsed, err := url.Parse(value)
|
||||
if err != nil {
|
||||
return normalizedTarget{}
|
||||
}
|
||||
host := strings.TrimSpace(parsed.Hostname())
|
||||
path := cleanPath(parsed.EscapedPath())
|
||||
return normalizedTarget{
|
||||
host: strings.ToLower(host),
|
||||
path: path,
|
||||
url: strings.ToLower(host) + path,
|
||||
}
|
||||
}
|
||||
|
||||
func cleanPath(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" || path == "/" {
|
||||
return "/"
|
||||
}
|
||||
path = strings.TrimRight(path, "/")
|
||||
if path == "" {
|
||||
return "/"
|
||||
}
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func chooseEntry(target normalizedTarget, entries []Entry) MatchResult {
|
||||
switch len(entries) {
|
||||
case 0:
|
||||
return MatchResult{Status: MatchStatusMissing}
|
||||
case 1:
|
||||
return MatchResult{Status: MatchStatusFound, Entry: entries[0]}
|
||||
}
|
||||
|
||||
exact := make([]Entry, 0)
|
||||
bestPrefixLen := -1
|
||||
bestPrefix := make([]Entry, 0)
|
||||
for _, entry := range entries {
|
||||
exactMatch, prefixLen := bestTargetMatch(entry, target)
|
||||
if exactMatch {
|
||||
exact = append(exact, entry)
|
||||
continue
|
||||
}
|
||||
if prefixLen <= 0 {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case prefixLen > bestPrefixLen:
|
||||
bestPrefixLen = prefixLen
|
||||
bestPrefix = []Entry{entry}
|
||||
case prefixLen == bestPrefixLen:
|
||||
bestPrefix = append(bestPrefix, entry)
|
||||
}
|
||||
}
|
||||
if len(exact) == 1 {
|
||||
return MatchResult{Status: MatchStatusFound, Entry: exact[0]}
|
||||
}
|
||||
if len(exact) > 1 {
|
||||
return MatchResult{Status: MatchStatusAmbiguous}
|
||||
}
|
||||
if len(bestPrefix) == 1 {
|
||||
return MatchResult{Status: MatchStatusFound, Entry: bestPrefix[0]}
|
||||
}
|
||||
if len(bestPrefix) == 0 {
|
||||
return MatchResult{Status: MatchStatusMissing}
|
||||
}
|
||||
return MatchResult{Status: MatchStatusAmbiguous}
|
||||
}
|
||||
|
||||
func collectTargets(item vault.Entry) []string {
|
||||
seen := make(map[string]struct{})
|
||||
targets := make([]string, 0, 1+len(item.Fields))
|
||||
appendTarget := func(raw string) {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[value]; ok {
|
||||
return
|
||||
}
|
||||
seen[value] = struct{}{}
|
||||
targets = append(targets, value)
|
||||
}
|
||||
|
||||
appendTarget(item.URL)
|
||||
|
||||
keys := make([]string, 0, len(item.Fields))
|
||||
for key := range item.Fields {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, key := range keys {
|
||||
upper := strings.ToUpper(strings.TrimSpace(key))
|
||||
if strings.HasPrefix(upper, "ANDROIDAPP") || strings.HasPrefix(upper, "KP2A_URL") {
|
||||
appendTarget(item.Fields[key])
|
||||
}
|
||||
}
|
||||
|
||||
return targets
|
||||
}
|
||||
|
||||
func entryTargets(entry Entry) []normalizedTarget {
|
||||
values := entry.Targets
|
||||
if len(values) == 0 {
|
||||
values = []string{entry.URL}
|
||||
}
|
||||
targets := make([]normalizedTarget, 0, len(values))
|
||||
for _, value := range values {
|
||||
target := normalizeURL(value)
|
||||
if target.host == "" {
|
||||
continue
|
||||
}
|
||||
targets = append(targets, target)
|
||||
}
|
||||
return targets
|
||||
}
|
||||
|
||||
func entryMatchesHost(entry Entry, host string) bool {
|
||||
for _, target := range entryTargets(entry) {
|
||||
if target.host == host {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func entryMatchesParentHost(entry Entry, host string) bool {
|
||||
for _, target := range entryTargets(entry) {
|
||||
if target.host != "" && strings.HasSuffix(host, "."+target.host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func bestTargetMatch(entry Entry, target normalizedTarget) (bool, int) {
|
||||
bestPrefixLen := -1
|
||||
for _, candidate := range entryTargets(entry) {
|
||||
if candidate.url == target.url {
|
||||
return true, 0
|
||||
}
|
||||
if candidate.path != "/" && strings.HasPrefix(target.path, candidate.path) {
|
||||
if pathLen := len(candidate.path); pathLen > bestPrefixLen {
|
||||
bestPrefixLen = pathLen
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, bestPrefixLen
|
||||
}
|
||||
@@ -0,0 +1,339 @@
|
||||
package autofillcache
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
)
|
||||
|
||||
func TestBuildFiltersAndNormalizesEntries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Date(2026, time.March, 31, 12, 0, 0, 0, time.UTC)
|
||||
got := Build(vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
{
|
||||
ID: "one",
|
||||
Title: "Chrome Test",
|
||||
Username: "joe",
|
||||
Password: "secret",
|
||||
URL: "https://10.0.2.2:8443/login",
|
||||
Path: []string{"Crew", "Internet"},
|
||||
},
|
||||
{
|
||||
ID: "two",
|
||||
Title: "No Password",
|
||||
Username: "joe",
|
||||
URL: "https://example.com",
|
||||
},
|
||||
{
|
||||
ID: "three",
|
||||
Title: "Bare Host",
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
URL: "surveillance.crew.example.invalid",
|
||||
Fields: map[string]string{
|
||||
"AndroidApp1": "androidapp://com.lights.mobile",
|
||||
"KP2A_URL_1": "https://surveillance.crew.example.invalid/account",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, now)
|
||||
|
||||
if len(got.Entries) != 2 {
|
||||
t.Fatalf("entry count = %d, want 2", len(got.Entries))
|
||||
}
|
||||
if got.Entries[0].Host != "10.0.2.2" {
|
||||
t.Fatalf("first host = %q, want 10.0.2.2", got.Entries[0].Host)
|
||||
}
|
||||
if got.Entries[1].Host != "surveillance.crew.example.invalid" {
|
||||
t.Fatalf("second host = %q, want surveillance.crew.example.invalid", got.Entries[1].Host)
|
||||
}
|
||||
if len(got.Entries[1].Targets) != 3 {
|
||||
t.Fatalf("len(second targets) = %d, want 3", len(got.Entries[1].Targets))
|
||||
}
|
||||
if got.UpdatedAt != "2026-03-31T12:00:00Z" {
|
||||
t.Fatalf("updatedAt = %q", got.UpdatedAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteAndClear(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "autofill-cache.json")
|
||||
model := vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
{ID: "one", Title: "Chrome Test", Username: "joe", Password: "secret", URL: "https://10.0.2.2:8443/login"},
|
||||
},
|
||||
}
|
||||
|
||||
if err := Write(path, model, time.Date(2026, time.March, 31, 12, 0, 0, 0, time.UTC)); err != nil {
|
||||
t.Fatalf("Write() error = %v", err)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile() error = %v", err)
|
||||
}
|
||||
var got File
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(got.Entries) != 1 || got.Entries[0].Host != "10.0.2.2" {
|
||||
t.Fatalf("cache entries = %#v", got.Entries)
|
||||
}
|
||||
if err := Clear(path); err != nil {
|
||||
t.Fatalf("Clear() error = %v", err)
|
||||
}
|
||||
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||
t.Fatalf("cache path still exists, stat err = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchChoosesExactURLWhenHostsRepeat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := File{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "one",
|
||||
Title: "Primary Login",
|
||||
Username: "first",
|
||||
Password: "secret1",
|
||||
URL: "https://10.0.2.2:8443/login/",
|
||||
Host: "10.0.2.2",
|
||||
},
|
||||
{
|
||||
ID: "two",
|
||||
Title: "Alt Login",
|
||||
Username: "second",
|
||||
Password: "secret2",
|
||||
URL: "https://10.0.2.2:8443/alt/",
|
||||
Host: "10.0.2.2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, ok := Match(cache, "https://10.0.2.2:8443/alt/")
|
||||
if !ok {
|
||||
t.Fatalf("Match() found no entry")
|
||||
}
|
||||
if got.ID != "two" {
|
||||
t.Fatalf("Match() entry = %q, want two", got.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchRejectsAmbiguousSharedHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := File{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "one",
|
||||
Title: "Host A",
|
||||
Username: "first",
|
||||
Password: "secret1",
|
||||
URL: "https://surveillance.crew.example.invalid/",
|
||||
Host: "surveillance.crew.example.invalid",
|
||||
},
|
||||
{
|
||||
ID: "two",
|
||||
Title: "Host B",
|
||||
Username: "second",
|
||||
Password: "secret2",
|
||||
URL: "https://surveillance.crew.example.invalid/",
|
||||
Host: "surveillance.crew.example.invalid",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if _, ok := Match(cache, "https://surveillance.crew.example.invalid/"); ok {
|
||||
t.Fatalf("Match() unexpectedly resolved ambiguous shared host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveReportsFoundAmbiguousAndMissingStatuses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := File{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "one",
|
||||
Title: "Admin Login",
|
||||
Username: "admin",
|
||||
Password: "secret1",
|
||||
URL: "https://example.com/admin",
|
||||
Host: "example.com",
|
||||
},
|
||||
{
|
||||
ID: "two",
|
||||
Title: "Shared Login A",
|
||||
Username: "shared-a",
|
||||
Password: "secret2",
|
||||
URL: "https://shared.example.com",
|
||||
Host: "shared.example.com",
|
||||
},
|
||||
{
|
||||
ID: "three",
|
||||
Title: "Shared Login B",
|
||||
Username: "shared-b",
|
||||
Password: "secret3",
|
||||
URL: "https://shared.example.com",
|
||||
Host: "shared.example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if got := Resolve(cache, "https://example.com/admin/login"); got.Status != MatchStatusFound || got.Entry.ID != "one" {
|
||||
t.Fatalf("Resolve(found) = %#v, want found entry one", got)
|
||||
}
|
||||
if got := Resolve(cache, "https://shared.example.com"); got.Status != MatchStatusAmbiguous {
|
||||
t.Fatalf("Resolve(ambiguous) = %#v, want ambiguous", got)
|
||||
}
|
||||
if got := Resolve(cache, "https://nowhere.invalid"); got.Status != MatchStatusMissing {
|
||||
t.Fatalf("Resolve(missing) = %#v, want missing", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchChoosesLongestPathPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := File{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "one",
|
||||
Title: "Generic Login",
|
||||
Username: "generic",
|
||||
Password: "secret1",
|
||||
URL: "https://example.com/",
|
||||
Host: "example.com",
|
||||
},
|
||||
{
|
||||
ID: "two",
|
||||
Title: "Admin Login",
|
||||
Username: "admin",
|
||||
Password: "secret2",
|
||||
URL: "https://example.com/admin",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, ok := Match(cache, "https://example.com/admin/login")
|
||||
if !ok {
|
||||
t.Fatalf("Match() found no entry")
|
||||
}
|
||||
if got.ID != "two" {
|
||||
t.Fatalf("Match() entry = %q, want two", got.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchSupportsAndroidAppPackageTargets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := File{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "one",
|
||||
Title: "Thunderbird",
|
||||
Username: "mail-user",
|
||||
Password: "secret1",
|
||||
URL: "androidapp://org.mozilla.thunderbird/login",
|
||||
Host: "org.mozilla.thunderbird",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, ok := Match(cache, "androidapp://org.mozilla.thunderbird")
|
||||
if !ok {
|
||||
t.Fatalf("Match() found no entry")
|
||||
}
|
||||
if got.ID != "one" {
|
||||
t.Fatalf("Match() entry = %q, want one", got.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchRejectsAmbiguousAndroidAppPackageTargets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := File{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "one",
|
||||
Title: "Thunderbird Primary",
|
||||
Username: "mail-user",
|
||||
Password: "secret1",
|
||||
URL: "androidapp://org.mozilla.thunderbird",
|
||||
Host: "org.mozilla.thunderbird",
|
||||
},
|
||||
{
|
||||
ID: "two",
|
||||
Title: "Thunderbird Secondary",
|
||||
Username: "other-user",
|
||||
Password: "secret2",
|
||||
URL: "androidapp://org.mozilla.thunderbird",
|
||||
Host: "org.mozilla.thunderbird",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if _, ok := Match(cache, "androidapp://org.mozilla.thunderbird"); ok {
|
||||
t.Fatalf("Match() unexpectedly resolved ambiguous android app package target")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchUsesAndroidAppCustomFieldTarget(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := File{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "one",
|
||||
Title: "Blink",
|
||||
Username: "blink-user",
|
||||
Password: "secret1",
|
||||
URL: "https://account.blinknetwork.com",
|
||||
Host: "account.blinknetwork.com",
|
||||
Targets: []string{"https://account.blinknetwork.com", "androidapp://com.blinknetwork.mobile2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, ok := Match(cache, "androidapp://com.blinknetwork.mobile2")
|
||||
if !ok {
|
||||
t.Fatalf("Match() found no entry")
|
||||
}
|
||||
if got.ID != "one" {
|
||||
t.Fatalf("Match() entry = %q, want one", got.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchUsesKP2AURLCustomFieldTarget(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := File{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "one",
|
||||
Title: "Blink",
|
||||
Username: "blink-user",
|
||||
Password: "secret1",
|
||||
URL: "https://blinknetwork.com",
|
||||
Host: "blinknetwork.com",
|
||||
Targets: []string{"https://blinknetwork.com", "https://account.blinknetwork.com"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, ok := Match(cache, "https://account.blinknetwork.com")
|
||||
if !ok {
|
||||
t.Fatalf("Match() found no entry")
|
||||
}
|
||||
if got.ID != "one" {
|
||||
t.Fatalf("Match() entry = %q, want one", got.ID)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package clipboard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
systemclipboard "github.com/atotto/clipboard"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
)
|
||||
|
||||
var ErrUnsupportedTarget = errors.New("unsupported clipboard target")
|
||||
var ErrWriteFailed = errors.New("clipboard write failed")
|
||||
|
||||
type Target string
|
||||
|
||||
const (
|
||||
TargetUsername Target = "username"
|
||||
TargetPassword Target = "password"
|
||||
TargetURL Target = "url"
|
||||
)
|
||||
|
||||
type Writer interface {
|
||||
WriteText(text string) error
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
Writer Writer
|
||||
}
|
||||
|
||||
func (s Service) Copy(model vault.Model, entryID string, target Target) error {
|
||||
entry, err := findEntry(model, entryID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
content, err := contentForTarget(entry, target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.writer().WriteText(content); err != nil {
|
||||
return writeError{err: err}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Service) writer() Writer {
|
||||
if s.Writer != nil {
|
||||
return s.Writer
|
||||
}
|
||||
return systemWriter{}
|
||||
}
|
||||
|
||||
func WriteText(text string) error {
|
||||
return systemWriter{}.WriteText(text)
|
||||
}
|
||||
|
||||
func findEntry(model vault.Model, entryID string) (vault.Entry, error) {
|
||||
for _, entry := range model.Entries {
|
||||
if entry.ID == entryID {
|
||||
return entry, nil
|
||||
}
|
||||
}
|
||||
return vault.Entry{}, vault.ErrEntryNotFound
|
||||
}
|
||||
|
||||
func contentForTarget(entry vault.Entry, target Target) (string, error) {
|
||||
switch target {
|
||||
case TargetUsername:
|
||||
return entry.Username, nil
|
||||
case TargetPassword:
|
||||
return entry.Password, nil
|
||||
case TargetURL:
|
||||
return entry.URL, nil
|
||||
default:
|
||||
return "", ErrUnsupportedTarget
|
||||
}
|
||||
}
|
||||
|
||||
type systemWriter struct{}
|
||||
|
||||
func (systemWriter) WriteText(text string) error {
|
||||
return systemclipboard.WriteAll(text)
|
||||
}
|
||||
|
||||
type writeError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e writeError) Error() string {
|
||||
return ErrWriteFailed.Error()
|
||||
}
|
||||
|
||||
func (e writeError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
func (e writeError) Is(target error) bool {
|
||||
return target == ErrWriteFailed
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package clipboard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
)
|
||||
|
||||
func TestServiceCopiesUsernamePasswordAndURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
{
|
||||
ID: "vault-console",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "token-1",
|
||||
URL: "https://vault.crew.example.invalid",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
target Target
|
||||
want string
|
||||
}{
|
||||
{name: "username", target: TargetUsername, want: "dannyocean"},
|
||||
{name: "password", target: TargetPassword, want: "token-1"},
|
||||
{name: "url", target: TargetURL, want: "https://vault.crew.example.invalid"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var writer memoryWriter
|
||||
service := Service{Writer: &writer}
|
||||
|
||||
if err := service.Copy(model, "vault-console", tt.target); err != nil {
|
||||
t.Fatalf("Copy() error = %v", err)
|
||||
}
|
||||
|
||||
if writer.content != tt.want {
|
||||
t.Fatalf("clipboard content = %q, want %q", writer.content, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceRejectsUnknownEntryAndUnsupportedTarget(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var writer memoryWriter
|
||||
service := Service{Writer: &writer}
|
||||
|
||||
err := service.Copy(vault.Model{}, "missing", TargetPassword)
|
||||
if !errors.Is(err, vault.ErrEntryNotFound) {
|
||||
t.Fatalf("Copy() missing entry error = %v, want ErrEntryNotFound", err)
|
||||
}
|
||||
|
||||
model := vault.Model{
|
||||
Entries: []vault.Entry{{ID: "vault-console", Username: "dannyocean"}},
|
||||
}
|
||||
err = service.Copy(model, "vault-console", Target("unsupported"))
|
||||
if !errors.Is(err, ErrUnsupportedTarget) {
|
||||
t.Fatalf("Copy() unsupported target error = %v, want ErrUnsupportedTarget", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceSanitizesClipboardWriteErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := Service{Writer: failingWriter{err: errors.New("backend refused token-1")}}
|
||||
model := vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
{ID: "vault-console", Password: "token-1"},
|
||||
},
|
||||
}
|
||||
|
||||
err := service.Copy(model, "vault-console", TargetPassword)
|
||||
if !errors.Is(err, ErrWriteFailed) {
|
||||
t.Fatalf("Copy() write error = %v, want ErrWriteFailed", err)
|
||||
}
|
||||
if err.Error() != ErrWriteFailed.Error() {
|
||||
t.Fatalf("Copy() write error string = %q, want %q", err.Error(), ErrWriteFailed.Error())
|
||||
}
|
||||
}
|
||||
|
||||
type memoryWriter struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (w *memoryWriter) WriteText(text string) error {
|
||||
w.content = text
|
||||
return nil
|
||||
}
|
||||
|
||||
type failingWriter struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (w failingWriter) WriteText(string) error {
|
||||
return w.err
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
package passwords
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
lowercaseChars = "abcdefghijkmnopqrstuvwxyz"
|
||||
uppercaseChars = "ABCDEFGHJKLMNPQRSTUVWXYZ"
|
||||
digitChars = "23456789"
|
||||
symbolChars = "!@#$%^&*()-_=+[]{}<>?/."
|
||||
)
|
||||
|
||||
var ErrImpossibleProfile = errors.New("impossible password profile")
|
||||
var ErrUnknownProfile = errors.New("unknown password profile")
|
||||
|
||||
type Profile struct {
|
||||
Name string
|
||||
Length int
|
||||
Lowercase bool
|
||||
Uppercase bool
|
||||
Digits bool
|
||||
Symbols bool
|
||||
MinLowercase int
|
||||
MinUppercase int
|
||||
MinDigits int
|
||||
MinSymbols int
|
||||
ExcludeSimilar bool
|
||||
}
|
||||
|
||||
func DefaultProfiles() map[string]Profile {
|
||||
return map[string]Profile{
|
||||
"strong": {
|
||||
Name: "strong",
|
||||
Length: 24,
|
||||
Lowercase: true,
|
||||
Uppercase: true,
|
||||
Digits: true,
|
||||
Symbols: true,
|
||||
MinLowercase: 2,
|
||||
MinUppercase: 2,
|
||||
MinDigits: 2,
|
||||
MinSymbols: 2,
|
||||
ExcludeSimilar: true,
|
||||
},
|
||||
"memorable": {
|
||||
Name: "memorable",
|
||||
Length: 20,
|
||||
Lowercase: true,
|
||||
Uppercase: true,
|
||||
Digits: true,
|
||||
Symbols: false,
|
||||
MinLowercase: 4,
|
||||
MinUppercase: 2,
|
||||
MinDigits: 2,
|
||||
ExcludeSimilar: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultProfileNames() []string {
|
||||
return ProfileNames(DefaultProfiles())
|
||||
}
|
||||
|
||||
func LookupProfile(name string, profiles map[string]Profile) (Profile, error) {
|
||||
profile, ok := profiles[strings.TrimSpace(name)]
|
||||
if !ok {
|
||||
return Profile{}, fmt.Errorf("%w %q", ErrUnknownProfile, strings.TrimSpace(name))
|
||||
}
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func LookupDefaultProfile(name string) (Profile, error) {
|
||||
return LookupProfile(name, DefaultProfiles())
|
||||
}
|
||||
|
||||
func ProfileNames(profiles map[string]Profile) []string {
|
||||
names := make([]string, 0, len(profiles))
|
||||
for name := range profiles {
|
||||
names = append(names, name)
|
||||
}
|
||||
slices.Sort(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func Generate(profile Profile) (string, error) {
|
||||
if err := validateProfile(profile); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var chars []byte
|
||||
var pool strings.Builder
|
||||
if profile.Lowercase {
|
||||
pool.WriteString(lowercaseChars)
|
||||
chars = append(chars, mustRandomChars(lowercaseChars, profile.MinLowercase)...)
|
||||
}
|
||||
if profile.Uppercase {
|
||||
pool.WriteString(uppercaseChars)
|
||||
chars = append(chars, mustRandomChars(uppercaseChars, profile.MinUppercase)...)
|
||||
}
|
||||
if profile.Digits {
|
||||
pool.WriteString(digitChars)
|
||||
chars = append(chars, mustRandomChars(digitChars, profile.MinDigits)...)
|
||||
}
|
||||
if profile.Symbols {
|
||||
pool.WriteString(symbolChars)
|
||||
chars = append(chars, mustRandomChars(symbolChars, profile.MinSymbols)...)
|
||||
}
|
||||
|
||||
allChars := pool.String()
|
||||
for len(chars) < profile.Length {
|
||||
ch, err := randomChar(allChars)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
chars = append(chars, ch)
|
||||
}
|
||||
|
||||
if err := shuffle(chars); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(chars), nil
|
||||
}
|
||||
|
||||
func validateProfile(profile Profile) error {
|
||||
if profile.Length <= 0 {
|
||||
return fmt.Errorf("%w: length must be positive", ErrImpossibleProfile)
|
||||
}
|
||||
|
||||
required := profile.MinLowercase + profile.MinUppercase + profile.MinDigits + profile.MinSymbols
|
||||
if required > profile.Length {
|
||||
return fmt.Errorf("%w: minimum character counts exceed length", ErrImpossibleProfile)
|
||||
}
|
||||
|
||||
if profile.MinLowercase > 0 && !profile.Lowercase {
|
||||
return fmt.Errorf("%w: lowercase disabled with lowercase minimum", ErrImpossibleProfile)
|
||||
}
|
||||
if profile.MinUppercase > 0 && !profile.Uppercase {
|
||||
return fmt.Errorf("%w: uppercase disabled with uppercase minimum", ErrImpossibleProfile)
|
||||
}
|
||||
if profile.MinDigits > 0 && !profile.Digits {
|
||||
return fmt.Errorf("%w: digits disabled with digit minimum", ErrImpossibleProfile)
|
||||
}
|
||||
if profile.MinSymbols > 0 && !profile.Symbols {
|
||||
return fmt.Errorf("%w: symbols disabled with symbol minimum", ErrImpossibleProfile)
|
||||
}
|
||||
|
||||
if !profile.Lowercase && !profile.Uppercase && !profile.Digits && !profile.Symbols {
|
||||
return fmt.Errorf("%w: no character sets enabled", ErrImpossibleProfile)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustRandomChars(chars string, count int) []byte {
|
||||
if count <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]byte, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
ch, err := randomChar(chars)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
out = append(out, ch)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func randomChar(chars string) (byte, error) {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(chars))))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("random index: %w", err)
|
||||
}
|
||||
return chars[n.Int64()], nil
|
||||
}
|
||||
|
||||
func shuffle(chars []byte) error {
|
||||
for i := len(chars) - 1; i > 0; i-- {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("shuffle password: %w", err)
|
||||
}
|
||||
j := int(n.Int64())
|
||||
chars[i], chars[j] = chars[j], chars[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package passwords
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateRespectsProfileRequirements(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
profile := Profile{
|
||||
Name: "strong",
|
||||
Length: 24,
|
||||
Lowercase: true,
|
||||
Uppercase: true,
|
||||
Digits: true,
|
||||
Symbols: true,
|
||||
MinLowercase: 2,
|
||||
MinUppercase: 2,
|
||||
MinDigits: 2,
|
||||
MinSymbols: 2,
|
||||
ExcludeSimilar: true,
|
||||
}
|
||||
|
||||
password, err := Generate(profile)
|
||||
if err != nil {
|
||||
t.Fatalf("Generate() error = %v", err)
|
||||
}
|
||||
|
||||
if len(password) != 24 {
|
||||
t.Fatalf("len(password) = %d, want 24", len(password))
|
||||
}
|
||||
|
||||
if countFromSet(password, lowercaseChars) < 2 {
|
||||
t.Fatalf("lowercase count in %q is too small", password)
|
||||
}
|
||||
|
||||
if countFromSet(password, uppercaseChars) < 2 {
|
||||
t.Fatalf("uppercase count in %q is too small", password)
|
||||
}
|
||||
|
||||
if countFromSet(password, digitChars) < 2 {
|
||||
t.Fatalf("digit count in %q is too small", password)
|
||||
}
|
||||
|
||||
if countFromSet(password, symbolChars) < 2 {
|
||||
t.Fatalf("symbol count in %q is too small", password)
|
||||
}
|
||||
|
||||
if strings.ContainsAny(password, "O0Il1") {
|
||||
t.Fatalf("password %q contains excluded similar characters", password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRejectsImpossibleProfiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := Generate(Profile{
|
||||
Name: "bad",
|
||||
Length: 6,
|
||||
Lowercase: true,
|
||||
Uppercase: true,
|
||||
Digits: true,
|
||||
Symbols: true,
|
||||
MinLowercase: 2,
|
||||
MinUppercase: 2,
|
||||
MinDigits: 2,
|
||||
MinSymbols: 2,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Generate() error = nil, want impossible profile error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProfileSetReturnsNamedProfiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
set := DefaultProfiles()
|
||||
profile, ok := set["strong"]
|
||||
if !ok {
|
||||
t.Fatalf("DefaultProfiles()[\"strong\"] missing")
|
||||
}
|
||||
|
||||
if profile.Length < 20 || !profile.Symbols {
|
||||
t.Fatalf("strong profile = %#v, want a strong reusable profile", profile)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultProfileNamesReturnsSortedNames(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := DefaultProfileNames()
|
||||
want := []string{"memorable", "strong"}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("DefaultProfileNames() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupDefaultProfileResolvesKnownProfilesAndRejectsUnknownNames(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
profile, err := LookupDefaultProfile(" strong ")
|
||||
if err != nil {
|
||||
t.Fatalf("LookupDefaultProfile(\" strong \") error = %v", err)
|
||||
}
|
||||
if profile.Name != "strong" {
|
||||
t.Fatalf("LookupDefaultProfile(\" strong \").Name = %q, want %q", profile.Name, "strong")
|
||||
}
|
||||
|
||||
_, err = LookupDefaultProfile("invalid")
|
||||
if !errors.Is(err, ErrUnknownProfile) {
|
||||
t.Fatalf("LookupDefaultProfile(\"invalid\") error = %v, want ErrUnknownProfile", err)
|
||||
}
|
||||
}
|
||||
|
||||
func countFromSet(password, chars string) int {
|
||||
count := 0
|
||||
for _, r := range password {
|
||||
if strings.ContainsRune(chars, r) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
@@ -0,0 +1,160 @@
|
||||
package vault
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestUpsertEntryPreservesPreviousVersionInHistory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "vault-console",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "old-token",
|
||||
URL: "https://vault.crew.example.invalid",
|
||||
Notes: "Original note",
|
||||
Path: []string{"Root", "Internet"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
model.UpsertEntry(Entry{
|
||||
ID: "vault-console",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "new-token",
|
||||
URL: "https://vault.crew.example.invalid",
|
||||
Notes: "Updated note",
|
||||
Path: []string{"Root", "Internet"},
|
||||
})
|
||||
|
||||
got := model.EntriesInPath([]string{"Root", "Internet"})
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
|
||||
}
|
||||
|
||||
if got[0].Password != "new-token" {
|
||||
t.Fatalf("Entry.Password = %q, want %q", got[0].Password, "new-token")
|
||||
}
|
||||
|
||||
if len(got[0].History) != 1 {
|
||||
t.Fatalf("len(Entry.History) = %d, want 1", len(got[0].History))
|
||||
}
|
||||
|
||||
if got[0].History[0].Password != "old-token" || got[0].History[0].Notes != "Original note" {
|
||||
t.Fatalf("Entry.History[0] = %#v, want prior entry version", got[0].History[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteEntryMovesItToRecycleBin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "surveillance-console",
|
||||
Title: "Surveillance Console",
|
||||
Username: "codex",
|
||||
Password: "token-2",
|
||||
URL: "https://surveillance.crew.example.invalid",
|
||||
Path: []string{"Root", "Home Assistant"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := model.DeleteEntry("surveillance-console"); err != nil {
|
||||
t.Fatalf("DeleteEntry() error = %v", err)
|
||||
}
|
||||
|
||||
if got := model.EntriesInPath([]string{"Root", "Home Assistant"}); len(got) != 0 {
|
||||
t.Fatalf("EntriesInPath() = %#v, want empty after delete", got)
|
||||
}
|
||||
|
||||
if len(model.RecycleBin) != 1 {
|
||||
t.Fatalf("len(RecycleBin) = %d, want 1", len(model.RecycleBin))
|
||||
}
|
||||
|
||||
if model.RecycleBin[0].Title != "Surveillance Console" {
|
||||
t.Fatalf("RecycleBin[0].Title = %q, want %q", model.RecycleBin[0].Title, "Surveillance Console")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreEntryMovesItBackFromRecycleBin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
RecycleBin: []Entry{
|
||||
{
|
||||
ID: "bellagio",
|
||||
Title: "Bellagio",
|
||||
Username: "rustyryan",
|
||||
Password: "token-3",
|
||||
URL: "https://bellagio.example.invalid",
|
||||
Path: []string{"Root", "Internet"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := model.RestoreEntry("bellagio"); err != nil {
|
||||
t.Fatalf("RestoreEntry() error = %v", err)
|
||||
}
|
||||
|
||||
got := model.EntriesInPath([]string{"Root", "Internet"})
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
|
||||
}
|
||||
|
||||
if got[0].Title != "Bellagio" {
|
||||
t.Fatalf("EntriesInPath()[0].Title = %q, want %q", got[0].Title, "Bellagio")
|
||||
}
|
||||
|
||||
if len(model.RecycleBin) != 0 {
|
||||
t.Fatalf("len(RecycleBin) = %d, want 0", len(model.RecycleBin))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreEntryVersionPromotesHistoricalVersionAndRetainsCurrentInHistory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "vault-console",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "new-token",
|
||||
Notes: "Current note",
|
||||
Path: []string{"Root", "Internet"},
|
||||
History: []Entry{
|
||||
{
|
||||
ID: "vault-console-history-1",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "old-token",
|
||||
Notes: "Previous note",
|
||||
Path: []string{"Root", "Internet"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := model.RestoreEntryVersion("vault-console", 0); err != nil {
|
||||
t.Fatalf("RestoreEntryVersion() error = %v", err)
|
||||
}
|
||||
|
||||
got := model.EntriesInPath([]string{"Root", "Internet"})
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
|
||||
}
|
||||
if got[0].Password != "old-token" || got[0].Notes != "Previous note" {
|
||||
t.Fatalf("restored entry = %#v, want old-token/Previous note current version", got[0])
|
||||
}
|
||||
if len(got[0].History) != 1 {
|
||||
t.Fatalf("len(History) = %d, want 1", len(got[0].History))
|
||||
}
|
||||
if got[0].History[0].Password != "new-token" || got[0].History[0].Notes != "Current note" {
|
||||
t.Fatalf("History[0] = %#v, want prior current version retained", got[0].History[0])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,628 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tobischo/gokeepasslib/v3"
|
||||
w "github.com/tobischo/gokeepasslib/v3/wrappers"
|
||||
)
|
||||
|
||||
type KDBXConfig struct {
|
||||
Header *gokeepasslib.DBHeader
|
||||
InnerHeader *gokeepasslib.InnerHeader
|
||||
}
|
||||
|
||||
var ErrInvalidMasterKey = errors.New("invalid master key")
|
||||
|
||||
const (
|
||||
templatesRoot = "Templates"
|
||||
recycleBinRoot = "Recycle Bin"
|
||||
keepassGOIDField = "KeePassGO-ID"
|
||||
remoteProfilesKey = "keepassgo.remoteProfiles"
|
||||
)
|
||||
|
||||
func LoadKDBX(r io.Reader, password string) (Model, error) {
|
||||
return LoadKDBXWithKey(r, MasterKey{Password: password})
|
||||
}
|
||||
|
||||
func SaveKDBX(wr io.Writer, model Model, password string) error {
|
||||
return SaveKDBXWithKey(wr, model, MasterKey{Password: password})
|
||||
}
|
||||
|
||||
func SaveKDBXWithKey(wr io.Writer, model Model, key MasterKey) error {
|
||||
return SaveKDBXWithConfigAndKey(wr, model, key, nil)
|
||||
}
|
||||
|
||||
func SaveKDBXWithConfigAndKey(wr io.Writer, model Model, key MasterKey, config *KDBXConfig) error {
|
||||
credentials, err := newCredentials(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db := gokeepasslib.NewDatabase(gokeepasslib.WithDatabaseKDBXVersion4())
|
||||
db.Credentials = credentials
|
||||
db.Content.Meta = gokeepasslib.NewMetaData()
|
||||
db.Content.Meta.CustomData = customDataForModel(model)
|
||||
db.Content.Root = &gokeepasslib.RootData{}
|
||||
if config != nil && config.Header != nil {
|
||||
db.Header = cloneHeader(config.Header)
|
||||
db.Hashes = gokeepasslib.NewHashes(db.Header)
|
||||
}
|
||||
if db.Header.IsKdbx4() {
|
||||
if config != nil && config.InnerHeader != nil {
|
||||
db.Content.InnerHeader = cloneInnerHeader(config.InnerHeader)
|
||||
db.Content.InnerHeader.Binaries = nil
|
||||
} else if db.Content.InnerHeader == nil {
|
||||
db.Content.InnerHeader = &gokeepasslib.InnerHeader{
|
||||
InnerRandomStreamID: gokeepasslib.ChaChaStreamID,
|
||||
InnerRandomStreamKey: randomBytes(64),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.Content.InnerHeader = nil
|
||||
}
|
||||
db.Content.Root.Groups = buildGroupTree(db, model)
|
||||
db.Content.Root.DeletedObjects = nil
|
||||
|
||||
if err := db.LockProtectedEntries(); err != nil {
|
||||
return fmt.Errorf("lock protected entries: %w", err)
|
||||
}
|
||||
|
||||
if err := gokeepasslib.NewEncoder(wr).Encode(db); err != nil {
|
||||
return fmt.Errorf("encode kdbx: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendGroupEntries(model *Model, db *gokeepasslib.Database, group gokeepasslib.Group, path []string) {
|
||||
path = append(clonePath(path), group.Name)
|
||||
model.CreateGroup(path[:len(path)-1], group.Name)
|
||||
|
||||
for _, entry := range group.Entries {
|
||||
appendModelEntry(model, Entry{
|
||||
ID: extractEntryID(entry),
|
||||
Title: entry.GetTitle(),
|
||||
Username: entry.GetContent("UserName"),
|
||||
Password: entry.GetPassword(),
|
||||
URL: entry.GetContent("URL"),
|
||||
Notes: entry.GetContent("Notes"),
|
||||
Tags: splitTags(entry.Tags),
|
||||
Fields: extractCustomFields(entry),
|
||||
Attachments: extractAttachments(db, entry),
|
||||
History: extractHistory(db, entry, path),
|
||||
Path: clonePath(path),
|
||||
})
|
||||
}
|
||||
|
||||
for _, child := range group.Groups {
|
||||
appendGroupEntries(model, db, child, path)
|
||||
}
|
||||
}
|
||||
|
||||
func appendModelEntry(model *Model, entry Entry) {
|
||||
if len(entry.Path) == 0 {
|
||||
model.Entries = append(model.Entries, entry)
|
||||
return
|
||||
}
|
||||
|
||||
switch entry.Path[0] {
|
||||
case templatesRoot:
|
||||
model.Templates = append(model.Templates, entry)
|
||||
return
|
||||
case recycleBinRoot:
|
||||
entry.Path = slices.Clone(entry.Path[1:])
|
||||
model.RecycleBin = append(model.RecycleBin, entry)
|
||||
return
|
||||
}
|
||||
|
||||
model.Entries = append(model.Entries, entry)
|
||||
}
|
||||
|
||||
func entriesForPersistence(model Model) []Entry {
|
||||
entries := append(slices.Clone(model.Entries), model.Templates...)
|
||||
for _, entry := range model.RecycleBin {
|
||||
recycleEntry := cloneEntry(entry)
|
||||
recycleEntry.Path = append([]string{recycleBinRoot}, recycleEntry.Path...)
|
||||
entries = append(entries, recycleEntry)
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func marshalUUID(id gokeepasslib.UUID) string {
|
||||
text, err := id.MarshalText()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(text)
|
||||
}
|
||||
|
||||
func clonePath(path []string) []string {
|
||||
if len(path) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(path))
|
||||
copy(out, path)
|
||||
return out
|
||||
}
|
||||
|
||||
func splitTags(tags string) []string {
|
||||
if strings.TrimSpace(tags) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
fields := strings.Split(tags, ";")
|
||||
var out []string
|
||||
for _, field := range fields {
|
||||
field = strings.TrimSpace(field)
|
||||
if field == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, field)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func extractCustomFields(entry gokeepasslib.Entry) map[string]string {
|
||||
fields := map[string]string{}
|
||||
for _, value := range entry.Values {
|
||||
switch value.Key {
|
||||
case "Title", "UserName", "Password", "URL", "Notes", keepassGOIDField:
|
||||
continue
|
||||
default:
|
||||
fields[value.Key] = value.Value.Content
|
||||
}
|
||||
}
|
||||
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
func extractEntryID(entry gokeepasslib.Entry) string {
|
||||
if id := entry.GetContent(keepassGOIDField); id != "" {
|
||||
return id
|
||||
}
|
||||
|
||||
return marshalUUID(entry.UUID)
|
||||
}
|
||||
|
||||
func extractHistory(db *gokeepasslib.Database, entry gokeepasslib.Entry, path []string) []Entry {
|
||||
if len(entry.Histories) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var history []Entry
|
||||
for _, item := range entry.Histories {
|
||||
for _, historical := range item.Entries {
|
||||
history = append(history, Entry{
|
||||
ID: extractEntryID(historical),
|
||||
Title: historical.GetTitle(),
|
||||
Username: historical.GetContent("UserName"),
|
||||
Password: historical.GetPassword(),
|
||||
URL: historical.GetContent("URL"),
|
||||
Notes: historical.GetContent("Notes"),
|
||||
Tags: splitTags(historical.Tags),
|
||||
Fields: extractCustomFields(historical),
|
||||
Attachments: extractAttachments(db, historical),
|
||||
Path: clonePath(path),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
type groupNode struct {
|
||||
name string
|
||||
children map[string]*groupNode
|
||||
entries []Entry
|
||||
}
|
||||
|
||||
type MasterKey struct {
|
||||
Password string
|
||||
KeyFileData []byte
|
||||
}
|
||||
|
||||
func buildGroupTree(db *gokeepasslib.Database, model Model) []gokeepasslib.Group {
|
||||
entries := entriesForPersistence(model)
|
||||
root := &groupNode{children: map[string]*groupNode{}}
|
||||
for _, entry := range entries {
|
||||
node := root
|
||||
for _, segment := range entry.Path {
|
||||
if node.children[segment] == nil {
|
||||
node.children[segment] = &groupNode{
|
||||
name: segment,
|
||||
children: map[string]*groupNode{},
|
||||
}
|
||||
}
|
||||
node = node.children[segment]
|
||||
}
|
||||
node.entries = append(node.entries, entry)
|
||||
}
|
||||
for _, path := range groupPathsForPersistence(model, entries) {
|
||||
node := root
|
||||
for _, segment := range path {
|
||||
if node.children[segment] == nil {
|
||||
node.children[segment] = &groupNode{
|
||||
name: segment,
|
||||
children: map[string]*groupNode{},
|
||||
}
|
||||
}
|
||||
node = node.children[segment]
|
||||
}
|
||||
}
|
||||
|
||||
groups := marshalGroups(db, root)
|
||||
if len(groups) > 0 {
|
||||
return groups
|
||||
}
|
||||
|
||||
group := gokeepasslib.NewGroup()
|
||||
group.Name = "Root"
|
||||
return []gokeepasslib.Group{group}
|
||||
}
|
||||
|
||||
func groupPathsForPersistence(model Model, entries []Entry) [][]string {
|
||||
seen := map[string]bool{}
|
||||
var groups [][]string
|
||||
appendPath := func(path []string) {
|
||||
key := strings.Join(path, "\x00")
|
||||
if seen[key] {
|
||||
return
|
||||
}
|
||||
seen[key] = true
|
||||
groups = append(groups, slices.Clone(path))
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
for i := 1; i <= len(entry.Path); i++ {
|
||||
appendPath(entry.Path[:i])
|
||||
}
|
||||
}
|
||||
for _, path := range model.Groups {
|
||||
for i := 1; i <= len(path); i++ {
|
||||
appendPath(path[:i])
|
||||
}
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
func LoadKDBXWithKey(r io.Reader, key MasterKey) (Model, error) {
|
||||
model, _, err := LoadKDBXWithConfig(r, key)
|
||||
return model, err
|
||||
}
|
||||
|
||||
func LoadKDBXWithConfig(r io.Reader, key MasterKey) (Model, *KDBXConfig, error) {
|
||||
credentials, err := newCredentials(key)
|
||||
if err != nil {
|
||||
return Model{}, nil, err
|
||||
}
|
||||
|
||||
db := gokeepasslib.NewDatabase()
|
||||
db.Credentials = credentials
|
||||
|
||||
if err := gokeepasslib.NewDecoder(r).Decode(db); err != nil {
|
||||
if isInvalidCredentialError(err) {
|
||||
return Model{}, nil, ErrInvalidMasterKey
|
||||
}
|
||||
return Model{}, nil, fmt.Errorf("decode kdbx: %w", err)
|
||||
}
|
||||
|
||||
if err := db.UnlockProtectedEntries(); err != nil {
|
||||
return Model{}, nil, fmt.Errorf("unlock protected entries: %w", err)
|
||||
}
|
||||
|
||||
var model Model
|
||||
for _, group := range db.Content.Root.Groups {
|
||||
appendGroupEntries(&model, db, group, nil)
|
||||
}
|
||||
model.RemoteProfiles = remoteProfilesFromMeta(db.Content.Meta)
|
||||
|
||||
return model, &KDBXConfig{
|
||||
Header: cloneHeader(db.Header),
|
||||
InnerHeader: cloneInnerHeader(db.Content.InnerHeader),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func customDataForModel(model Model) []gokeepasslib.CustomData {
|
||||
if len(model.RemoteProfiles) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
content, err := json.Marshal(model.RemoteProfiles)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []gokeepasslib.CustomData{{
|
||||
Key: remoteProfilesKey,
|
||||
Value: string(content),
|
||||
}}
|
||||
}
|
||||
|
||||
func remoteProfilesFromMeta(meta *gokeepasslib.MetaData) []RemoteProfile {
|
||||
if meta == nil {
|
||||
return nil
|
||||
}
|
||||
for _, item := range meta.CustomData {
|
||||
if item.Key != remoteProfilesKey {
|
||||
continue
|
||||
}
|
||||
var profiles []RemoteProfile
|
||||
if err := json.Unmarshal([]byte(item.Value), &profiles); err != nil {
|
||||
return nil
|
||||
}
|
||||
return profiles
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCredentials(key MasterKey) (*gokeepasslib.DBCredentials, error) {
|
||||
switch {
|
||||
case key.Password != "" && len(key.KeyFileData) > 0:
|
||||
credentials, err := gokeepasslib.NewPasswordAndKeyDataCredentials(key.Password, key.KeyFileData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build password+key credentials: %w", err)
|
||||
}
|
||||
return credentials, nil
|
||||
case len(key.KeyFileData) > 0:
|
||||
credentials, err := gokeepasslib.NewKeyDataCredentials(key.KeyFileData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build key credentials: %w", err)
|
||||
}
|
||||
return credentials, nil
|
||||
default:
|
||||
return gokeepasslib.NewPasswordCredentials(key.Password), nil
|
||||
}
|
||||
}
|
||||
|
||||
func cloneHeader(header *gokeepasslib.DBHeader) *gokeepasslib.DBHeader {
|
||||
if header == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := *header
|
||||
out.RawData = nil
|
||||
if header.Signature != nil {
|
||||
signature := *header.Signature
|
||||
out.Signature = &signature
|
||||
}
|
||||
if header.FileHeaders != nil {
|
||||
fileHeaders := *header.FileHeaders
|
||||
fileHeaders.Comment = slices.Clone(header.FileHeaders.Comment)
|
||||
fileHeaders.CipherID = slices.Clone(header.FileHeaders.CipherID)
|
||||
fileHeaders.MasterSeed = slices.Clone(header.FileHeaders.MasterSeed)
|
||||
fileHeaders.TransformSeed = slices.Clone(header.FileHeaders.TransformSeed)
|
||||
fileHeaders.EncryptionIV = slices.Clone(header.FileHeaders.EncryptionIV)
|
||||
fileHeaders.ProtectedStreamKey = slices.Clone(header.FileHeaders.ProtectedStreamKey)
|
||||
fileHeaders.StreamStartBytes = slices.Clone(header.FileHeaders.StreamStartBytes)
|
||||
if header.FileHeaders.KdfParameters != nil {
|
||||
kdf := *header.FileHeaders.KdfParameters
|
||||
kdf.UUID = slices.Clone(header.FileHeaders.KdfParameters.UUID)
|
||||
kdf.SecretKey = slices.Clone(header.FileHeaders.KdfParameters.SecretKey)
|
||||
kdf.AssocData = slices.Clone(header.FileHeaders.KdfParameters.AssocData)
|
||||
if header.FileHeaders.KdfParameters.RawData != nil {
|
||||
kdf.RawData = cloneVariantDictionary(header.FileHeaders.KdfParameters.RawData)
|
||||
}
|
||||
fileHeaders.KdfParameters = &kdf
|
||||
}
|
||||
if header.FileHeaders.PublicCustomData != nil {
|
||||
fileHeaders.PublicCustomData = cloneVariantDictionary(header.FileHeaders.PublicCustomData)
|
||||
}
|
||||
out.FileHeaders = &fileHeaders
|
||||
}
|
||||
|
||||
return &out
|
||||
}
|
||||
|
||||
func cloneVariantDictionary(dict *gokeepasslib.VariantDictionary) *gokeepasslib.VariantDictionary {
|
||||
if dict == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := &gokeepasslib.VariantDictionary{Version: dict.Version}
|
||||
out.Items = make([]*gokeepasslib.VariantDictionaryItem, 0, len(dict.Items))
|
||||
for _, item := range dict.Items {
|
||||
cloned := *item
|
||||
cloned.Name = slices.Clone(item.Name)
|
||||
cloned.Value = slices.Clone(item.Value)
|
||||
out.Items = append(out.Items, &cloned)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneInnerHeader(header *gokeepasslib.InnerHeader) *gokeepasslib.InnerHeader {
|
||||
if header == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := &gokeepasslib.InnerHeader{
|
||||
InnerRandomStreamID: header.InnerRandomStreamID,
|
||||
InnerRandomStreamKey: slices.Clone(header.InnerRandomStreamKey),
|
||||
}
|
||||
for _, binary := range header.Binaries {
|
||||
out.Binaries = append(out.Binaries, gokeepasslib.Binary{
|
||||
ID: binary.ID,
|
||||
Compressed: binary.Compressed,
|
||||
MemoryProtection: binary.MemoryProtection,
|
||||
Content: slices.Clone(binary.Content),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func randomBytes(length int) []byte {
|
||||
buf := make([]byte, length)
|
||||
_, _ = io.ReadFull(rand.Reader, buf)
|
||||
return buf
|
||||
}
|
||||
|
||||
func isInvalidCredentialError(err error) bool {
|
||||
if errors.Is(err, gokeepasslib.ErrInvalidDatabaseOrCredentials) {
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.Contains(err.Error(), "Wrong password?")
|
||||
}
|
||||
|
||||
func marshalGroups(db *gokeepasslib.Database, node *groupNode) []gokeepasslib.Group {
|
||||
names := slices.Collect(maps.Keys(node.children))
|
||||
slices.SortFunc(names, compareGroupNames)
|
||||
|
||||
var groups []gokeepasslib.Group
|
||||
for _, name := range names {
|
||||
child := node.children[name]
|
||||
group := gokeepasslib.NewGroup()
|
||||
group.Name = child.name
|
||||
group.Entries = marshalEntries(db, child.entries)
|
||||
group.Groups = marshalGroups(db, child)
|
||||
groups = append(groups, group)
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
func compareGroupNames(a, b string) int {
|
||||
switch {
|
||||
case a == b:
|
||||
return 0
|
||||
case a == "Root":
|
||||
return -1
|
||||
case b == "Root":
|
||||
return 1
|
||||
case a == templatesRoot:
|
||||
return -1
|
||||
case b == templatesRoot:
|
||||
return 1
|
||||
case a == recycleBinRoot:
|
||||
return 1
|
||||
case b == recycleBinRoot:
|
||||
return -1
|
||||
case a < b:
|
||||
return -1
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func marshalEntries(db *gokeepasslib.Database, entries []Entry) []gokeepasslib.Entry {
|
||||
slices.SortFunc(entries, func(a, b Entry) int {
|
||||
switch {
|
||||
case a.Title < b.Title:
|
||||
return -1
|
||||
case a.Title > b.Title:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
|
||||
var out []gokeepasslib.Entry
|
||||
for _, entry := range entries {
|
||||
out = append(out, marshalEntry(db, entry))
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func marshalEntry(db *gokeepasslib.Database, entry Entry) gokeepasslib.Entry {
|
||||
item := gokeepasslib.NewEntry()
|
||||
item.UUID = uuidForEntryID(entry.ID)
|
||||
item.Tags = strings.Join(entry.Tags, "; ")
|
||||
item.Values = append(item.Values,
|
||||
value("Title", entry.Title),
|
||||
value("UserName", entry.Username),
|
||||
protectedValue("Password", entry.Password),
|
||||
value("URL", entry.URL),
|
||||
value("Notes", entry.Notes),
|
||||
value(keepassGOIDField, entry.ID),
|
||||
)
|
||||
|
||||
keys := slices.Collect(maps.Keys(entry.Fields))
|
||||
slices.Sort(keys)
|
||||
for _, key := range keys {
|
||||
item.Values = append(item.Values, value(key, entry.Fields[key]))
|
||||
}
|
||||
|
||||
attachmentNames := slices.Collect(maps.Keys(entry.Attachments))
|
||||
slices.Sort(attachmentNames)
|
||||
for _, name := range attachmentNames {
|
||||
binary := db.AddBinary(entry.Attachments[name])
|
||||
item.Binaries = append(item.Binaries, binary.CreateReference(name))
|
||||
}
|
||||
|
||||
for _, historical := range entry.History {
|
||||
item.Histories = append(item.Histories, gokeepasslib.History{
|
||||
Entries: []gokeepasslib.Entry{marshalEntry(db, historical)},
|
||||
})
|
||||
}
|
||||
|
||||
return item
|
||||
}
|
||||
|
||||
func uuidForEntryID(id string) gokeepasslib.UUID {
|
||||
if id != "" {
|
||||
var uuid gokeepasslib.UUID
|
||||
if err := uuid.UnmarshalText([]byte(id)); err == nil {
|
||||
return uuid
|
||||
}
|
||||
}
|
||||
|
||||
sum := sha256.Sum256([]byte(id))
|
||||
var uuid gokeepasslib.UUID
|
||||
copy(uuid[:], sum[:len(uuid)])
|
||||
if id == "" {
|
||||
copy(uuid[:], time.Now().UTC().AppendFormat(nil, time.RFC3339Nano))
|
||||
}
|
||||
return uuid
|
||||
}
|
||||
|
||||
func value(key, content string) gokeepasslib.ValueData {
|
||||
return gokeepasslib.ValueData{Key: key, Value: gokeepasslib.V{Content: content}}
|
||||
}
|
||||
|
||||
func protectedValue(key, content string) gokeepasslib.ValueData {
|
||||
return gokeepasslib.ValueData{
|
||||
Key: key,
|
||||
Value: gokeepasslib.V{Content: content, Protected: w.NewBoolWrapper(true)},
|
||||
}
|
||||
}
|
||||
|
||||
func extractAttachments(db *gokeepasslib.Database, entry gokeepasslib.Entry) map[string][]byte {
|
||||
if len(entry.Binaries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
attachments := map[string][]byte{}
|
||||
for _, ref := range entry.Binaries {
|
||||
binary := db.FindBinary(ref.Value.ID)
|
||||
if binary == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := binary.GetContentBytes()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
attachments[ref.Name] = slices.Clone(content)
|
||||
}
|
||||
|
||||
if len(attachments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return attachments
|
||||
}
|
||||
@@ -0,0 +1,794 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/tobischo/gokeepasslib/v3"
|
||||
w "github.com/tobischo/gokeepasslib/v3/wrappers"
|
||||
)
|
||||
|
||||
func TestLoadKDBXBuildsModelFromNestedGroups(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := &gokeepasslib.Database{
|
||||
Header: gokeepasslib.NewHeader(),
|
||||
Credentials: gokeepasslib.NewPasswordCredentials("correct horse battery staple"),
|
||||
Content: &gokeepasslib.DBContent{
|
||||
Meta: gokeepasslib.NewMetaData(),
|
||||
Root: &gokeepasslib.RootData{
|
||||
Groups: []gokeepasslib.Group{
|
||||
mustGroup("Root",
|
||||
mustGroup("Internet",
|
||||
mustEntry("Bellagio", "rustyryan", "https://bellagio.example.invalid", "hunter2"),
|
||||
mustEntry("Vault Console", "dannyocean", "https://vault.crew.example.invalid", "bellagio-pass-1"),
|
||||
),
|
||||
mustGroup("Security Office",
|
||||
mustEntry("Surveillance Console", "bashertarr", "https://surveillance.crew.example.invalid", "bellagio-pass-2"),
|
||||
),
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := db.LockProtectedEntries(); err != nil {
|
||||
t.Fatalf("LockProtectedEntries failed: %v", err)
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := gokeepasslib.NewEncoder(&encoded).Encode(db); err != nil {
|
||||
t.Fatalf("Encode failed: %v", err)
|
||||
}
|
||||
|
||||
model, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBX failed: %v", err)
|
||||
}
|
||||
|
||||
got := model.EntriesInPath([]string{"Root", "Internet"})
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("len(EntriesInPath()) = %d, want 2", len(got))
|
||||
}
|
||||
|
||||
if got[0].Title != "Bellagio" || got[0].Username != "rustyryan" || got[0].URL != "https://bellagio.example.invalid" {
|
||||
t.Fatalf("unexpected first entry: %#v", got[0])
|
||||
}
|
||||
|
||||
if got[1].Title != "Vault Console" || got[1].Username != "dannyocean" || got[1].URL != "https://vault.crew.example.invalid" {
|
||||
t.Fatalf("unexpected second entry: %#v", got[1])
|
||||
}
|
||||
|
||||
groups := model.ChildGroups([]string{"Root"})
|
||||
if len(groups) != 2 || groups[0] != "Internet" || groups[1] != "Security Office" {
|
||||
t.Fatalf("ChildGroups() = %v, want [Internet Security Office]", groups)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadKDBXPreservesEntryDetails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
entry := mustEntry("Surveillance Console", "bashertarr", "https://surveillance.crew.example.invalid", "bellagio-pass-2")
|
||||
entry.Tags = "automation; home"
|
||||
entry.Values = append(entry.Values,
|
||||
mkValue("Notes", "Long-lived token used by Codex for home automation tasks."),
|
||||
mkValue("X-Role", "automation"),
|
||||
)
|
||||
|
||||
db := &gokeepasslib.Database{
|
||||
Header: gokeepasslib.NewHeader(),
|
||||
Credentials: gokeepasslib.NewPasswordCredentials("correct horse battery staple"),
|
||||
Content: &gokeepasslib.DBContent{
|
||||
Meta: gokeepasslib.NewMetaData(),
|
||||
Root: &gokeepasslib.RootData{
|
||||
Groups: []gokeepasslib.Group{
|
||||
mustGroup("Root", mustGroup("Security Office", entry)),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := db.LockProtectedEntries(); err != nil {
|
||||
t.Fatalf("LockProtectedEntries failed: %v", err)
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := gokeepasslib.NewEncoder(&encoded).Encode(db); err != nil {
|
||||
t.Fatalf("Encode failed: %v", err)
|
||||
}
|
||||
|
||||
model, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBX failed: %v", err)
|
||||
}
|
||||
|
||||
got := model.EntriesInPath([]string{"Root", "Security Office"})
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
|
||||
}
|
||||
|
||||
if got[0].Password != "bellagio-pass-2" {
|
||||
t.Fatalf("Entry.Password = %q, want %q", got[0].Password, "bellagio-pass-2")
|
||||
}
|
||||
|
||||
if got[0].Notes != "Long-lived token used by Codex for home automation tasks." {
|
||||
t.Fatalf("Entry.Notes = %q, want %q", got[0].Notes, "Long-lived token used by Codex for home automation tasks.")
|
||||
}
|
||||
|
||||
if len(got[0].Tags) != 2 || got[0].Tags[0] != "automation" || got[0].Tags[1] != "home" {
|
||||
t.Fatalf("Entry.Tags = %v, want [automation home]", got[0].Tags)
|
||||
}
|
||||
|
||||
if got[0].Fields["X-Role"] != "automation" {
|
||||
t.Fatalf("Entry.Fields[\"X-Role\"] = %q, want %q", got[0].Fields["X-Role"], "automation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveKDBXRoundTripsModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "entry-1",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "bellagio-pass-1",
|
||||
URL: "https://vault.crew.example.invalid",
|
||||
Notes: "Personal git server token entry used for automation and CLI auth.",
|
||||
Tags: []string{"git", "infra"},
|
||||
Fields: map[string]string{
|
||||
"X-Role": "automation",
|
||||
},
|
||||
Path: []string{"Root", "Internet"},
|
||||
},
|
||||
{
|
||||
ID: "entry-2",
|
||||
Title: "Surveillance Console",
|
||||
Username: "bashertarr",
|
||||
Password: "bellagio-pass-2",
|
||||
URL: "https://surveillance.crew.example.invalid",
|
||||
Notes: "Long-lived token used by Codex for home automation tasks.",
|
||||
Tags: []string{"automation", "home"},
|
||||
Path: []string{"Root", "Security Office"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
|
||||
t.Fatalf("SaveKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
got := loaded.Search("vault")
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("len(Search(\"git\")) = %d, want 1", len(got))
|
||||
}
|
||||
|
||||
if got[0].Entry.Notes != "Personal git server token entry used for automation and CLI auth." {
|
||||
t.Fatalf("Search(\"git\") notes = %q, want %q", got[0].Entry.Notes, "Personal git server token entry used for automation and CLI auth.")
|
||||
}
|
||||
|
||||
if got[0].Entry.Fields["X-Role"] != "automation" {
|
||||
t.Fatalf("Search(\"git\") X-Role = %q, want %q", got[0].Entry.Fields["X-Role"], "automation")
|
||||
}
|
||||
|
||||
homeAssistant := loaded.EntriesInPath([]string{"Root", "Security Office"})
|
||||
if len(homeAssistant) != 1 {
|
||||
t.Fatalf("len(EntriesInPath(Security Office)) = %d, want 1", len(homeAssistant))
|
||||
}
|
||||
|
||||
if homeAssistant[0].Password != "bellagio-pass-2" {
|
||||
t.Fatalf("Security Office password = %q, want %q", homeAssistant[0].Password, "bellagio-pass-2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveKDBXRoundTripsTemplates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
Templates: []Entry{
|
||||
{
|
||||
ID: "tpl-1",
|
||||
Title: "Website Login",
|
||||
Username: "template-user",
|
||||
Password: "template-password",
|
||||
URL: "https://example.com",
|
||||
Notes: "Reusable template for website accounts.",
|
||||
Tags: []string{"template", "web"},
|
||||
Fields: map[string]string{
|
||||
"Environment": "prod",
|
||||
},
|
||||
Path: []string{"Templates"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
|
||||
t.Fatalf("SaveKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
if len(loaded.Templates) != 1 {
|
||||
t.Fatalf("len(Templates) = %d, want 1", len(loaded.Templates))
|
||||
}
|
||||
|
||||
if loaded.Templates[0].Title != "Website Login" {
|
||||
t.Fatalf("Templates[0].Title = %q, want %q", loaded.Templates[0].Title, "Website Login")
|
||||
}
|
||||
|
||||
if loaded.Templates[0].Fields["Environment"] != "prod" {
|
||||
t.Fatalf("Templates[0].Fields[Environment] = %q, want %q", loaded.Templates[0].Fields["Environment"], "prod")
|
||||
}
|
||||
|
||||
if len(loaded.Entries) != 0 {
|
||||
t.Fatalf("len(Entries) = %d, want 0", len(loaded.Entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveKDBXRoundTripsRemoteProfiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
RemoteProfiles: []RemoteProfile{
|
||||
{
|
||||
ID: "bellagio-webdav",
|
||||
Name: "Bellagio Vault",
|
||||
Backend: RemoteBackendWebDAV,
|
||||
BaseURL: "https://dav.example.invalid/remote.php/dav",
|
||||
Path: "files/bellagio/keepass.kdbx",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
|
||||
t.Fatalf("SaveKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
if len(loaded.RemoteProfiles) != 1 {
|
||||
t.Fatalf("len(RemoteProfiles) = %d, want 1", len(loaded.RemoteProfiles))
|
||||
}
|
||||
|
||||
got := loaded.RemoteProfiles[0]
|
||||
if got.ID != "bellagio-webdav" || got.Name != "Bellagio Vault" {
|
||||
t.Fatalf("loaded remote profile = %#v, want bellagio-webdav Bellagio Vault", got)
|
||||
}
|
||||
if got.Backend != RemoteBackendWebDAV {
|
||||
t.Fatalf("remote backend = %q, want %q", got.Backend, RemoteBackendWebDAV)
|
||||
}
|
||||
if got.BaseURL != "https://dav.example.invalid/remote.php/dav" {
|
||||
t.Fatalf("remote base URL = %q, want remote.php/dav URL", got.BaseURL)
|
||||
}
|
||||
if got.Path != "files/bellagio/keepass.kdbx" {
|
||||
t.Fatalf("remote path = %q, want files/bellagio/keepass.kdbx", got.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveKDBXRoundTripsEntryHistory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "entry-1",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "new-token",
|
||||
URL: "https://vault.crew.example.invalid",
|
||||
Path: []string{"Root", "Internet"},
|
||||
History: []Entry{
|
||||
{
|
||||
ID: "entry-1-old",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "old-token",
|
||||
URL: "https://vault.crew.example.invalid",
|
||||
Path: []string{"Root", "Internet"},
|
||||
Notes: "Original version",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
|
||||
t.Fatalf("SaveKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
got := loaded.EntriesInPath([]string{"Root", "Internet"})
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
|
||||
}
|
||||
|
||||
if len(got[0].History) != 1 {
|
||||
t.Fatalf("len(History) = %d, want 1", len(got[0].History))
|
||||
}
|
||||
|
||||
if got[0].History[0].Password != "old-token" || got[0].History[0].Notes != "Original version" {
|
||||
t.Fatalf("History[0] = %#v, want preserved prior version", got[0].History[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveKDBXRoundTripsRecycleBinEntries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
RecycleBin: []Entry{
|
||||
{
|
||||
ID: "entry-1",
|
||||
Title: "Surveillance Console",
|
||||
Username: "bashertarr",
|
||||
Password: "bellagio-pass-2",
|
||||
URL: "https://surveillance.crew.example.invalid",
|
||||
Path: []string{"Root", "Security Office"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
|
||||
t.Fatalf("SaveKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
if len(loaded.RecycleBin) != 1 {
|
||||
t.Fatalf("len(RecycleBin) = %d, want 1", len(loaded.RecycleBin))
|
||||
}
|
||||
|
||||
if loaded.RecycleBin[0].Title != "Surveillance Console" {
|
||||
t.Fatalf("RecycleBin[0].Title = %q, want %q", loaded.RecycleBin[0].Title, "Surveillance Console")
|
||||
}
|
||||
|
||||
if len(loaded.RecycleBin[0].Path) != 2 || loaded.RecycleBin[0].Path[0] != "Root" || loaded.RecycleBin[0].Path[1] != "Security Office" {
|
||||
t.Fatalf("RecycleBin[0].Path = %v, want [Root Security Office]", loaded.RecycleBin[0].Path)
|
||||
}
|
||||
|
||||
if len(loaded.Entries) != 0 {
|
||||
t.Fatalf("len(Entries) = %d, want 0", len(loaded.Entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadKDBXWithKeyFileCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyData := []byte(`<?xml version="1.0" encoding="utf-8"?>
|
||||
<KeyFile>
|
||||
<Meta>
|
||||
<Version>1.0</Version>
|
||||
</Meta>
|
||||
<Key>
|
||||
<Data>PbLBYmgEXFhLWf2gxoBMARXgDZGE7f34tr+anCw52LI=</Data>
|
||||
</Key>
|
||||
</KeyFile>
|
||||
`)
|
||||
|
||||
credentials, err := newCredentials(MasterKey{KeyFileData: keyData})
|
||||
if err != nil {
|
||||
t.Fatalf("newCredentials() error = %v", err)
|
||||
}
|
||||
|
||||
db := &gokeepasslib.Database{
|
||||
Header: gokeepasslib.NewHeader(),
|
||||
Credentials: credentials,
|
||||
Content: &gokeepasslib.DBContent{
|
||||
Meta: gokeepasslib.NewMetaData(),
|
||||
Root: &gokeepasslib.RootData{
|
||||
Groups: []gokeepasslib.Group{
|
||||
mustGroup("Root", mustGroup("Internet", mustEntry("Vault Console", "dannyocean", "https://vault.crew.example.invalid", "bellagio-pass-1"))),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := db.LockProtectedEntries(); err != nil {
|
||||
t.Fatalf("LockProtectedEntries failed: %v", err)
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := gokeepasslib.NewEncoder(&encoded).Encode(db); err != nil {
|
||||
t.Fatalf("Encode failed: %v", err)
|
||||
}
|
||||
|
||||
model, err := LoadKDBXWithKey(bytes.NewReader(encoded.Bytes()), MasterKey{KeyFileData: keyData})
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBXWithKey() error = %v", err)
|
||||
}
|
||||
|
||||
got := model.Search("vault")
|
||||
if len(got) != 1 || got[0].Entry.Password != "bellagio-pass-1" {
|
||||
t.Fatalf("LoadKDBXWithKey() = %#v, want password-preserving vault entry", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadKDBXWithCompositeCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyData := []byte(`<?xml version="1.0" encoding="utf-8"?>
|
||||
<KeyFile>
|
||||
<Meta>
|
||||
<Version>1.0</Version>
|
||||
</Meta>
|
||||
<Key>
|
||||
<Data>PbLBYmgEXFhLWf2gxoBMARXgDZGE7f34tr+anCw52LI=</Data>
|
||||
</Key>
|
||||
</KeyFile>
|
||||
`)
|
||||
|
||||
credentials, err := newCredentials(MasterKey{
|
||||
Password: "correct horse battery staple",
|
||||
KeyFileData: keyData,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("newCredentials() error = %v", err)
|
||||
}
|
||||
|
||||
db := &gokeepasslib.Database{
|
||||
Header: gokeepasslib.NewHeader(),
|
||||
Credentials: credentials,
|
||||
Content: &gokeepasslib.DBContent{
|
||||
Meta: gokeepasslib.NewMetaData(),
|
||||
Root: &gokeepasslib.RootData{
|
||||
Groups: []gokeepasslib.Group{
|
||||
mustGroup("Root", mustGroup("Security Office", mustEntry("Surveillance Console", "bashertarr", "https://surveillance.crew.example.invalid", "bellagio-pass-2"))),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := db.LockProtectedEntries(); err != nil {
|
||||
t.Fatalf("LockProtectedEntries failed: %v", err)
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := gokeepasslib.NewEncoder(&encoded).Encode(db); err != nil {
|
||||
t.Fatalf("Encode failed: %v", err)
|
||||
}
|
||||
|
||||
model, err := LoadKDBXWithKey(bytes.NewReader(encoded.Bytes()), MasterKey{
|
||||
Password: "correct horse battery staple",
|
||||
KeyFileData: keyData,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBXWithKey() error = %v", err)
|
||||
}
|
||||
|
||||
got := model.EntriesInPath([]string{"Root", "Security Office"})
|
||||
if len(got) != 1 || got[0].Password != "bellagio-pass-2" {
|
||||
t.Fatalf("LoadKDBXWithKey() = %#v, want Security Office entry with password", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadKDBXReturnsInvalidCredentialsError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := &gokeepasslib.Database{
|
||||
Header: gokeepasslib.NewHeader(),
|
||||
Credentials: gokeepasslib.NewPasswordCredentials("correct horse battery staple"),
|
||||
Content: &gokeepasslib.DBContent{
|
||||
Meta: gokeepasslib.NewMetaData(),
|
||||
Root: &gokeepasslib.RootData{
|
||||
Groups: []gokeepasslib.Group{
|
||||
mustGroup("Root", mustGroup("Internet", mustEntry("Vault Console", "dannyocean", "https://vault.crew.example.invalid", "bellagio-pass-1"))),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := db.LockProtectedEntries(); err != nil {
|
||||
t.Fatalf("LockProtectedEntries failed: %v", err)
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := gokeepasslib.NewEncoder(&encoded).Encode(db); err != nil {
|
||||
t.Fatalf("Encode failed: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "definitely wrong password")
|
||||
if !errors.Is(err, ErrInvalidMasterKey) {
|
||||
t.Fatalf("LoadKDBX() error = %v, want %v", err, ErrInvalidMasterKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveKDBXWithKeyRoundTripsModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyData := []byte(`<?xml version="1.0" encoding="utf-8"?>
|
||||
<KeyFile>
|
||||
<Meta>
|
||||
<Version>1.0</Version>
|
||||
</Meta>
|
||||
<Key>
|
||||
<Data>PbLBYmgEXFhLWf2gxoBMARXgDZGE7f34tr+anCw52LI=</Data>
|
||||
</Key>
|
||||
</KeyFile>
|
||||
`)
|
||||
|
||||
model := Model{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "vault-console",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "bellagio-pass-1",
|
||||
URL: "https://vault.crew.example.invalid",
|
||||
Path: []string{"Root", "Internet"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := SaveKDBXWithKey(&encoded, model, MasterKey{KeyFileData: keyData}); err != nil {
|
||||
t.Fatalf("SaveKDBXWithKey() error = %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadKDBXWithKey(bytes.NewReader(encoded.Bytes()), MasterKey{KeyFileData: keyData})
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBXWithKey() error = %v", err)
|
||||
}
|
||||
|
||||
got := loaded.Search("vault")
|
||||
if len(got) != 1 || got[0].Entry.Password != "bellagio-pass-1" {
|
||||
t.Fatalf("round-trip with key file = %#v, want vault entry with password", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveKDBXWithCompositeKeyRoundTripsModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyData := []byte(`<?xml version="1.0" encoding="utf-8"?>
|
||||
<KeyFile>
|
||||
<Meta>
|
||||
<Version>1.0</Version>
|
||||
</Meta>
|
||||
<Key>
|
||||
<Data>PbLBYmgEXFhLWf2gxoBMARXgDZGE7f34tr+anCw52LI=</Data>
|
||||
</Key>
|
||||
</KeyFile>
|
||||
`)
|
||||
|
||||
model := Model{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "surveillance-console",
|
||||
Title: "Surveillance Console",
|
||||
Username: "bashertarr",
|
||||
Password: "bellagio-pass-2",
|
||||
URL: "https://surveillance.crew.example.invalid",
|
||||
Path: []string{"Root", "Security Office"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
key := MasterKey{
|
||||
Password: "correct horse battery staple",
|
||||
KeyFileData: keyData,
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := SaveKDBXWithKey(&encoded, model, key); err != nil {
|
||||
t.Fatalf("SaveKDBXWithKey() error = %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadKDBXWithKey(bytes.NewReader(encoded.Bytes()), key)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBXWithKey() error = %v", err)
|
||||
}
|
||||
|
||||
got := loaded.EntriesInPath([]string{"Root", "Security Office"})
|
||||
if len(got) != 1 || got[0].Password != "bellagio-pass-2" {
|
||||
t.Fatalf("composite key round-trip = %#v, want Security Office entry with password", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKDBXRoundTripsEntryAttachments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "vault-console",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "bellagio-pass-1",
|
||||
URL: "https://vault.crew.example.invalid",
|
||||
Path: []string{"Root", "Internet"},
|
||||
Attachments: map[string][]byte{
|
||||
"token.txt": []byte("secret attachment contents"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
|
||||
t.Fatalf("SaveKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBX() error = %v", err)
|
||||
}
|
||||
|
||||
got := loaded.EntriesInPath([]string{"Root", "Internet"})
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("len(EntriesInPath()) = %d, want 1", len(got))
|
||||
}
|
||||
|
||||
if string(got[0].Attachments["token.txt"]) != "secret attachment contents" {
|
||||
t.Fatalf("attachment contents = %q, want %q", string(got[0].Attachments["token.txt"]), "secret attachment contents")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKDBXReopenCyclesPreserveStableIDsAndCrossFeatureState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "entry-1",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "bellagio-pass-2",
|
||||
URL: "https://vault.crew.example.invalid",
|
||||
Notes: "Current credential",
|
||||
Path: []string{"Root", "Internet"},
|
||||
Attachments: map[string][]byte{
|
||||
"token.txt": []byte("secret attachment contents"),
|
||||
},
|
||||
History: []Entry{
|
||||
{
|
||||
ID: "entry-1-history-1",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "bellagio-pass-1",
|
||||
URL: "https://vault.crew.example.invalid",
|
||||
Notes: "Original credential",
|
||||
Path: []string{"Root", "Internet"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Templates: []Entry{
|
||||
{
|
||||
ID: "tpl-1",
|
||||
Title: "Website Login",
|
||||
Username: "template-user",
|
||||
Password: "template-password",
|
||||
Path: []string{"Templates", "Web"},
|
||||
},
|
||||
},
|
||||
RecycleBin: []Entry{
|
||||
{
|
||||
ID: "deleted-1",
|
||||
Title: "Retired Entry",
|
||||
Username: "archived-user",
|
||||
Password: "retired-token",
|
||||
Path: []string{"Root", "Archive"},
|
||||
},
|
||||
},
|
||||
Groups: [][]string{
|
||||
{"Root", "Archive"},
|
||||
{"Root", "Empty Group"},
|
||||
{"Templates", "Web"},
|
||||
},
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := SaveKDBX(&encoded, model, "correct horse battery staple"); err != nil {
|
||||
t.Fatalf("SaveKDBX(first cycle) error = %v", err)
|
||||
}
|
||||
|
||||
reopened, err := LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBX(first cycle) error = %v", err)
|
||||
}
|
||||
|
||||
encoded.Reset()
|
||||
if err := SaveKDBX(&encoded, reopened, "correct horse battery staple"); err != nil {
|
||||
t.Fatalf("SaveKDBX(second cycle) error = %v", err)
|
||||
}
|
||||
|
||||
reopened, err = LoadKDBX(bytes.NewReader(encoded.Bytes()), "correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBX(second cycle) error = %v", err)
|
||||
}
|
||||
|
||||
got := reopened.EntriesInPath([]string{"Root", "Internet"})
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("len(EntriesInPath(Root/Internet)) = %d, want 1", len(got))
|
||||
}
|
||||
if got[0].ID != "entry-1" {
|
||||
t.Fatalf("entry ID after reopen cycles = %q, want %q", got[0].ID, "entry-1")
|
||||
}
|
||||
if len(got[0].History) != 1 {
|
||||
t.Fatalf("len(History) after reopen cycles = %d, want 1", len(got[0].History))
|
||||
}
|
||||
if got[0].History[0].ID != "entry-1-history-1" {
|
||||
t.Fatalf("history ID after reopen cycles = %q, want %q", got[0].History[0].ID, "entry-1-history-1")
|
||||
}
|
||||
if string(got[0].Attachments["token.txt"]) != "secret attachment contents" {
|
||||
t.Fatalf("attachment after reopen cycles = %q, want %q", string(got[0].Attachments["token.txt"]), "secret attachment contents")
|
||||
}
|
||||
|
||||
if len(reopened.Templates) != 1 || reopened.Templates[0].Path[1] != "Web" {
|
||||
t.Fatalf("Templates after reopen cycles = %#v, want Website Login in Templates/Web", reopened.Templates)
|
||||
}
|
||||
if len(reopened.RecycleBin) != 1 || reopened.RecycleBin[0].Path[1] != "Archive" {
|
||||
t.Fatalf("RecycleBin after reopen cycles = %#v, want recycled entry in Root/Archive", reopened.RecycleBin)
|
||||
}
|
||||
|
||||
rootGroups := reopened.ChildGroups([]string{"Root"})
|
||||
if !slices.Equal(rootGroups, []string{"Archive", "Empty Group", "Internet"}) {
|
||||
t.Fatalf("ChildGroups(Root) after reopen cycles = %v, want [Archive Empty Group Internet]", rootGroups)
|
||||
}
|
||||
templateGroups := reopened.ChildGroups([]string{"Templates"})
|
||||
if !slices.Equal(templateGroups, []string{"Web"}) {
|
||||
t.Fatalf("ChildGroups(Templates) after reopen cycles = %v, want [Web]", templateGroups)
|
||||
}
|
||||
}
|
||||
|
||||
func mustGroup(name string, children ...any) gokeepasslib.Group {
|
||||
group := gokeepasslib.NewGroup()
|
||||
group.Name = name
|
||||
for _, child := range children {
|
||||
switch value := child.(type) {
|
||||
case gokeepasslib.Group:
|
||||
group.Groups = append(group.Groups, value)
|
||||
case gokeepasslib.Entry:
|
||||
group.Entries = append(group.Entries, value)
|
||||
default:
|
||||
panic("unsupported child type")
|
||||
}
|
||||
}
|
||||
return group
|
||||
}
|
||||
|
||||
func mustEntry(title, username, url, password string) gokeepasslib.Entry {
|
||||
entry := gokeepasslib.NewEntry()
|
||||
entry.Values = append(entry.Values,
|
||||
mkValue("Title", title),
|
||||
mkValue("UserName", username),
|
||||
mkValue("URL", url),
|
||||
mkProtectedValue("Password", password),
|
||||
)
|
||||
return entry
|
||||
}
|
||||
|
||||
func mkValue(key, value string) gokeepasslib.ValueData {
|
||||
return gokeepasslib.ValueData{Key: key, Value: gokeepasslib.V{Content: value}}
|
||||
}
|
||||
|
||||
func mkProtectedValue(key, value string) gokeepasslib.ValueData {
|
||||
return gokeepasslib.ValueData{
|
||||
Key: key,
|
||||
Value: gokeepasslib.V{Content: value, Protected: w.NewBoolWrapper(true)},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package vault
|
||||
|
||||
// MasterKeyMode identifies which key material the user intends to provide.
|
||||
type MasterKeyMode string
|
||||
|
||||
const (
|
||||
// MasterKeyModePasswordOnly requires a master password and no key file.
|
||||
MasterKeyModePasswordOnly MasterKeyMode = "password-only"
|
||||
// MasterKeyModeKeyFileOnly requires a key file and no password.
|
||||
MasterKeyModeKeyFileOnly MasterKeyMode = "key-file-only"
|
||||
// MasterKeyModePasswordAndKeyFile requires both password and key file.
|
||||
MasterKeyModePasswordAndKeyFile MasterKeyMode = "password-and-key-file"
|
||||
)
|
||||
@@ -0,0 +1,601 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrEntryNotFound = errors.New("entry not found")
|
||||
var ErrGroupNotEmpty = errors.New("group is not empty")
|
||||
var ErrRemoteProfileNotFound = errors.New("remote profile not found")
|
||||
|
||||
type RemoteBackend string
|
||||
|
||||
const (
|
||||
RemoteBackendWebDAV RemoteBackend = "webdav"
|
||||
)
|
||||
|
||||
type RemoteProfile struct {
|
||||
ID string
|
||||
Name string
|
||||
Backend RemoteBackend
|
||||
BaseURL string
|
||||
Path string
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
ID string
|
||||
Title string
|
||||
Username string
|
||||
Password string
|
||||
URL string
|
||||
Notes string
|
||||
Tags []string
|
||||
Fields map[string]string
|
||||
Attachments map[string][]byte
|
||||
History []Entry
|
||||
Path []string
|
||||
}
|
||||
|
||||
type SearchResult struct {
|
||||
Entry Entry
|
||||
Path string
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Entries []Entry
|
||||
Templates []Entry
|
||||
RecycleBin []Entry
|
||||
Groups [][]string
|
||||
RemoteProfiles []RemoteProfile
|
||||
}
|
||||
|
||||
func (m Model) ChildGroups(path []string) []string {
|
||||
seen := map[string]bool{}
|
||||
var groups []string
|
||||
for _, entry := range m.Entries {
|
||||
if len(path) > len(entry.Path) {
|
||||
continue
|
||||
}
|
||||
if !slices.Equal(entry.Path[:len(path)], path) {
|
||||
continue
|
||||
}
|
||||
if len(entry.Path) == len(path) {
|
||||
continue
|
||||
}
|
||||
group := entry.Path[len(path)]
|
||||
if seen[group] {
|
||||
continue
|
||||
}
|
||||
seen[group] = true
|
||||
groups = append(groups, group)
|
||||
}
|
||||
for _, groupPath := range m.Groups {
|
||||
if len(path) > len(groupPath) {
|
||||
continue
|
||||
}
|
||||
if !slices.Equal(groupPath[:len(path)], path) {
|
||||
continue
|
||||
}
|
||||
if len(groupPath) == len(path) {
|
||||
continue
|
||||
}
|
||||
group := groupPath[len(path)]
|
||||
if seen[group] {
|
||||
continue
|
||||
}
|
||||
seen[group] = true
|
||||
groups = append(groups, group)
|
||||
}
|
||||
slices.Sort(groups)
|
||||
return groups
|
||||
}
|
||||
|
||||
func (m Model) EntriesInPath(path []string) []Entry {
|
||||
var entries []Entry
|
||||
for _, entry := range m.Entries {
|
||||
if slices.Equal(entry.Path, path) {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
}
|
||||
slices.SortFunc(entries, func(a, b Entry) int {
|
||||
switch {
|
||||
case a.Title < b.Title:
|
||||
return -1
|
||||
case a.Title > b.Title:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
return entries
|
||||
}
|
||||
|
||||
func (m Model) EntriesUnderPath(path []string) []Entry {
|
||||
var entries []Entry
|
||||
for _, entry := range m.Entries {
|
||||
if !hasPathPrefix(entry.Path, path) {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
slices.SortFunc(entries, func(a, b Entry) int {
|
||||
switch {
|
||||
case a.Title < b.Title:
|
||||
return -1
|
||||
case a.Title > b.Title:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
return entries
|
||||
}
|
||||
|
||||
func (m Model) Search(query string) []SearchResult {
|
||||
query = strings.TrimSpace(strings.ToLower(query))
|
||||
if query == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results []SearchResult
|
||||
for _, entry := range m.Entries {
|
||||
haystack := strings.ToLower(
|
||||
entry.Title + " " +
|
||||
entry.Username + " " +
|
||||
entry.URL + " " +
|
||||
strings.Join(entry.Path, " "),
|
||||
)
|
||||
if !strings.Contains(haystack, query) {
|
||||
continue
|
||||
}
|
||||
results = append(results, SearchResult{
|
||||
Entry: entry,
|
||||
Path: strings.Join(entry.Path, " / "),
|
||||
})
|
||||
}
|
||||
|
||||
slices.SortFunc(results, func(a, b SearchResult) int {
|
||||
switch {
|
||||
case a.Entry.Title < b.Entry.Title:
|
||||
return -1
|
||||
case a.Entry.Title > b.Entry.Title:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
return results
|
||||
}
|
||||
|
||||
func (m *Model) UpsertEntry(entry Entry) {
|
||||
for i := range m.Entries {
|
||||
if m.Entries[i].ID != entry.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
previous := cloneEntry(m.Entries[i])
|
||||
entry.History = append([]Entry{previous}, cloneHistory(m.Entries[i].History)...)
|
||||
m.Entries[i] = cloneEntry(entry)
|
||||
return
|
||||
}
|
||||
|
||||
m.Entries = append(m.Entries, cloneEntry(entry))
|
||||
}
|
||||
|
||||
func (m *Model) RemoveEntryByID(id string) bool {
|
||||
for i := range m.Entries {
|
||||
if m.Entries[i].ID != id {
|
||||
continue
|
||||
}
|
||||
m.Entries = append(m.Entries[:i], m.Entries[i+1:]...)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Model) EntryByID(id string) (Entry, error) {
|
||||
for _, entry := range m.Entries {
|
||||
if entry.ID == id {
|
||||
return cloneEntry(entry), nil
|
||||
}
|
||||
}
|
||||
return Entry{}, ErrEntryNotFound
|
||||
}
|
||||
|
||||
func (m *Model) UpsertRemoteProfile(profile RemoteProfile) {
|
||||
for i := range m.RemoteProfiles {
|
||||
if m.RemoteProfiles[i].ID != profile.ID {
|
||||
continue
|
||||
}
|
||||
m.RemoteProfiles[i] = profile
|
||||
return
|
||||
}
|
||||
m.RemoteProfiles = append(m.RemoteProfiles, profile)
|
||||
}
|
||||
|
||||
func (m *Model) RemoveRemoteProfileByID(id string) bool {
|
||||
for i := range m.RemoteProfiles {
|
||||
if m.RemoteProfiles[i].ID != id {
|
||||
continue
|
||||
}
|
||||
m.RemoteProfiles = append(m.RemoteProfiles[:i], m.RemoteProfiles[i+1:]...)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m Model) RemoteProfileByID(id string) (RemoteProfile, error) {
|
||||
for _, profile := range m.RemoteProfiles {
|
||||
if profile.ID == id {
|
||||
return profile, nil
|
||||
}
|
||||
}
|
||||
return RemoteProfile{}, ErrRemoteProfileNotFound
|
||||
}
|
||||
|
||||
func (m *Model) UpsertTemplate(entry Entry) {
|
||||
for i := range m.Templates {
|
||||
if m.Templates[i].ID != entry.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
m.Templates[i] = cloneEntry(entry)
|
||||
return
|
||||
}
|
||||
|
||||
m.Templates = append(m.Templates, cloneEntry(entry))
|
||||
}
|
||||
|
||||
func (m *Model) DeleteTemplate(id string) error {
|
||||
for i := range m.Templates {
|
||||
if m.Templates[i].ID != id {
|
||||
continue
|
||||
}
|
||||
|
||||
m.Templates = append(m.Templates[:i], m.Templates[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
|
||||
func (m *Model) DeleteEntry(id string) error {
|
||||
for i := range m.Entries {
|
||||
if m.Entries[i].ID != id {
|
||||
continue
|
||||
}
|
||||
|
||||
m.RecycleBin = append(m.RecycleBin, cloneEntry(m.Entries[i]))
|
||||
m.Entries = append(m.Entries[:i], m.Entries[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
|
||||
func (m *Model) RestoreEntry(id string) error {
|
||||
for i := range m.RecycleBin {
|
||||
if m.RecycleBin[i].ID != id {
|
||||
continue
|
||||
}
|
||||
|
||||
m.Entries = append(m.Entries, cloneEntry(m.RecycleBin[i]))
|
||||
m.RecycleBin = append(m.RecycleBin[:i], m.RecycleBin[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
|
||||
func (m *Model) InstantiateTemplate(templateID string, overrides Entry) (Entry, error) {
|
||||
for i := range m.Templates {
|
||||
if m.Templates[i].ID != templateID {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := mergeEntryTemplate(m.Templates[i], overrides)
|
||||
m.UpsertEntry(entry)
|
||||
return cloneEntry(entry), nil
|
||||
}
|
||||
|
||||
return Entry{}, ErrEntryNotFound
|
||||
}
|
||||
|
||||
func (m *Model) DuplicateEntry(id, duplicateID string) (Entry, error) {
|
||||
for i := range m.Entries {
|
||||
if m.Entries[i].ID != id {
|
||||
continue
|
||||
}
|
||||
|
||||
duplicate := cloneEntry(m.Entries[i])
|
||||
duplicate.ID = duplicateID
|
||||
duplicate.Title = duplicate.Title + " (Copy)"
|
||||
duplicate.History = nil
|
||||
m.Entries = append(m.Entries, duplicate)
|
||||
return cloneEntry(duplicate), nil
|
||||
}
|
||||
|
||||
return Entry{}, ErrEntryNotFound
|
||||
}
|
||||
|
||||
func (m *Model) RestoreEntryVersion(id string, historyIndex int) error {
|
||||
for i := range m.Entries {
|
||||
if m.Entries[i].ID != id {
|
||||
continue
|
||||
}
|
||||
if historyIndex < 0 || historyIndex >= len(m.Entries[i].History) {
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
|
||||
current := cloneEntry(m.Entries[i])
|
||||
restored := cloneEntry(m.Entries[i].History[historyIndex])
|
||||
restored.ID = current.ID
|
||||
restored.History = append([]Entry{current}, append(
|
||||
cloneHistory(m.Entries[i].History[:historyIndex]),
|
||||
cloneHistory(m.Entries[i].History[historyIndex+1:])...,
|
||||
)...)
|
||||
m.Entries[i] = restored
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
|
||||
func (m *Model) CreateGroup(parent []string, name string) {
|
||||
groupPath := append([]string(nil), parent...)
|
||||
for _, part := range splitGroupPath(name) {
|
||||
groupPath = append(groupPath, part)
|
||||
if groupPathExists(m.Groups, groupPath) {
|
||||
continue
|
||||
}
|
||||
m.Groups = append(m.Groups, append([]string(nil), groupPath...))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) RenameGroup(path []string, newName string) error {
|
||||
if len(path) == 0 {
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
|
||||
renamed := false
|
||||
newPath := append(append([]string(nil), path[:len(path)-1]...), newName)
|
||||
for i := range m.Entries {
|
||||
if !hasPathPrefix(m.Entries[i].Path, path) {
|
||||
continue
|
||||
}
|
||||
m.Entries[i].Path = append(append([]string(nil), newPath...), m.Entries[i].Path[len(path):]...)
|
||||
renamed = true
|
||||
}
|
||||
for i := range m.Templates {
|
||||
if !hasPathPrefix(m.Templates[i].Path, path) {
|
||||
continue
|
||||
}
|
||||
m.Templates[i].Path = append(append([]string(nil), newPath...), m.Templates[i].Path[len(path):]...)
|
||||
renamed = true
|
||||
}
|
||||
for i := range m.Groups {
|
||||
if !hasPathPrefix(m.Groups[i], path) {
|
||||
continue
|
||||
}
|
||||
m.Groups[i] = append(append([]string(nil), newPath...), m.Groups[i][len(path):]...)
|
||||
renamed = true
|
||||
}
|
||||
if !renamed {
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) MoveEntry(id string, path []string) error {
|
||||
for i := range m.Entries {
|
||||
if m.Entries[i].ID != id {
|
||||
continue
|
||||
}
|
||||
m.Entries[i].Path = append([]string(nil), path...)
|
||||
return nil
|
||||
}
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
|
||||
func (m *Model) MoveGroup(path, parent []string) error {
|
||||
if len(path) == 0 {
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
if hasPathPrefix(parent, path) {
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
|
||||
groupName := path[len(path)-1]
|
||||
newPath := append(append([]string(nil), parent...), groupName)
|
||||
moved := false
|
||||
for i := range m.Entries {
|
||||
if !hasPathPrefix(m.Entries[i].Path, path) {
|
||||
continue
|
||||
}
|
||||
m.Entries[i].Path = append(append([]string(nil), newPath...), m.Entries[i].Path[len(path):]...)
|
||||
moved = true
|
||||
}
|
||||
for i := range m.Templates {
|
||||
if !hasPathPrefix(m.Templates[i].Path, path) {
|
||||
continue
|
||||
}
|
||||
m.Templates[i].Path = append(append([]string(nil), newPath...), m.Templates[i].Path[len(path):]...)
|
||||
moved = true
|
||||
}
|
||||
for i := range m.Groups {
|
||||
if !hasPathPrefix(m.Groups[i], path) {
|
||||
continue
|
||||
}
|
||||
m.Groups[i] = append(append([]string(nil), newPath...), m.Groups[i][len(path):]...)
|
||||
moved = true
|
||||
}
|
||||
if !moved {
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
if !groupPathExists(m.Groups, newPath) {
|
||||
m.Groups = append(m.Groups, append([]string(nil), newPath...))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) MoveTemplate(id string, path []string) error {
|
||||
for i := range m.Templates {
|
||||
if m.Templates[i].ID != id {
|
||||
continue
|
||||
}
|
||||
m.Templates[i].Path = append([]string(nil), path...)
|
||||
return nil
|
||||
}
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
|
||||
func (m *Model) DeleteGroup(path []string) error {
|
||||
for _, entry := range m.Entries {
|
||||
if slices.Equal(entry.Path, path) || hasPathPrefix(entry.Path, path) {
|
||||
return ErrGroupNotEmpty
|
||||
}
|
||||
}
|
||||
for _, entry := range m.Templates {
|
||||
if slices.Equal(entry.Path, path) || hasPathPrefix(entry.Path, path) {
|
||||
return ErrGroupNotEmpty
|
||||
}
|
||||
}
|
||||
|
||||
for i := range m.Groups {
|
||||
if slices.Equal(m.Groups[i], path) {
|
||||
m.Groups = append(m.Groups[:i], m.Groups[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrEntryNotFound
|
||||
}
|
||||
|
||||
func hasPathPrefix(path, prefix []string) bool {
|
||||
if len(prefix) > len(path) {
|
||||
return false
|
||||
}
|
||||
return slices.Equal(path[:len(prefix)], prefix)
|
||||
}
|
||||
|
||||
func splitGroupPath(name string) []string {
|
||||
var parts []string
|
||||
for _, part := range strings.Split(name, "/") {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func groupPathExists(groups [][]string, path []string) bool {
|
||||
for _, existing := range groups {
|
||||
if slices.Equal(existing, path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mergeEntryTemplate(template, overrides Entry) Entry {
|
||||
entry := cloneEntry(template)
|
||||
|
||||
if overrides.ID != "" {
|
||||
entry.ID = overrides.ID
|
||||
}
|
||||
if overrides.Title != "" {
|
||||
entry.Title = overrides.Title
|
||||
}
|
||||
if overrides.Username != "" {
|
||||
entry.Username = overrides.Username
|
||||
}
|
||||
if overrides.Password != "" {
|
||||
entry.Password = overrides.Password
|
||||
}
|
||||
if overrides.URL != "" {
|
||||
entry.URL = overrides.URL
|
||||
}
|
||||
if overrides.Notes != "" {
|
||||
entry.Notes = overrides.Notes
|
||||
}
|
||||
if len(overrides.Tags) > 0 {
|
||||
entry.Tags = slices.Clone(overrides.Tags)
|
||||
}
|
||||
if len(overrides.Path) > 0 {
|
||||
entry.Path = slices.Clone(overrides.Path)
|
||||
}
|
||||
|
||||
entry.Fields = mergeStringMaps(template.Fields, overrides.Fields)
|
||||
entry.Attachments = mergeBinaryMaps(template.Attachments, overrides.Attachments)
|
||||
entry.History = nil
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
func mergeStringMaps(base, overrides map[string]string) map[string]string {
|
||||
if len(base) == 0 && len(overrides) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make(map[string]string, len(base)+len(overrides))
|
||||
for key, value := range base {
|
||||
out[key] = value
|
||||
}
|
||||
for key, value := range overrides {
|
||||
out[key] = value
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func mergeBinaryMaps(base, overrides map[string][]byte) map[string][]byte {
|
||||
if len(base) == 0 && len(overrides) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make(map[string][]byte, len(base)+len(overrides))
|
||||
for key, value := range base {
|
||||
out[key] = slices.Clone(value)
|
||||
}
|
||||
for key, value := range overrides {
|
||||
out[key] = slices.Clone(value)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneEntry(entry Entry) Entry {
|
||||
entry.Tags = slices.Clone(entry.Tags)
|
||||
entry.Path = slices.Clone(entry.Path)
|
||||
entry.History = cloneHistory(entry.History)
|
||||
if entry.Fields != nil {
|
||||
fields := make(map[string]string, len(entry.Fields))
|
||||
for key, value := range entry.Fields {
|
||||
fields[key] = value
|
||||
}
|
||||
entry.Fields = fields
|
||||
}
|
||||
if entry.Attachments != nil {
|
||||
attachments := make(map[string][]byte, len(entry.Attachments))
|
||||
for key, value := range entry.Attachments {
|
||||
attachments[key] = slices.Clone(value)
|
||||
}
|
||||
entry.Attachments = attachments
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func cloneHistory(history []Entry) []Entry {
|
||||
if len(history) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]Entry, len(history))
|
||||
for i := range history {
|
||||
out[i] = cloneEntry(history[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,343 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testModel() Model {
|
||||
return Model{
|
||||
Entries: []Entry{
|
||||
{ID: "1", Title: "Bellagio", Username: "rustyryan", URL: "https://bellagio.example.invalid", Path: []string{"Crew", "Internet"}},
|
||||
{ID: "2", Title: "Vault Console", Username: "dannyocean", URL: "https://vault.crew.example.invalid", Path: []string{"Crew", "Internet"}},
|
||||
{ID: "3", Title: "Surveillance Console", Username: "codex", URL: "https://surveillance.crew.example.invalid", Path: []string{"Crew", "Home Assistant"}},
|
||||
{ID: "4", Title: "Alma (WA Prep)", Username: "christina.julian", URL: "https://waprep.getalma.com", Path: []string{"Tricia", "School"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestChildGroupsReturnsImmediateGroupsOnly(t *testing.T) {
|
||||
model := testModel()
|
||||
|
||||
got := model.ChildGroups([]string{"Crew"})
|
||||
want := []string{"Home Assistant", "Internet"}
|
||||
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("ChildGroups() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntriesInPathReturnsOnlyDirectEntries(t *testing.T) {
|
||||
model := testModel()
|
||||
|
||||
got := model.EntriesInPath([]string{"Crew", "Internet"})
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("len(EntriesInPath()) = %d, want 2", len(got))
|
||||
}
|
||||
|
||||
if got[0].Title != "Bellagio" || got[1].Title != "Vault Console" {
|
||||
t.Fatalf("EntriesInPath() titles = %q, %q", got[0].Title, got[1].Title)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntriesUnderPathReturnsDescendantEntries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := testModel()
|
||||
got := model.EntriesUnderPath([]string{"Crew"})
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("len(EntriesUnderPath(Crew)) = %d, want 3", len(got))
|
||||
}
|
||||
if got[0].Title != "Bellagio" || got[1].Title != "Surveillance Console" || got[2].Title != "Vault Console" {
|
||||
t.Fatalf("EntriesUnderPath(Crew) titles = %q, %q, %q", got[0].Title, got[1].Title, got[2].Title)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchReturnsMatchesWithFullPathContext(t *testing.T) {
|
||||
model := testModel()
|
||||
|
||||
got := model.Search("vault")
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("len(Search()) = %d, want 1", len(got))
|
||||
}
|
||||
|
||||
if got[0].Entry.Title != "Vault Console" {
|
||||
t.Fatalf("Search() title = %q, want %q", got[0].Entry.Title, "Vault Console")
|
||||
}
|
||||
|
||||
if got[0].Path != "Crew / Internet" {
|
||||
t.Fatalf("Search() path = %q, want %q", got[0].Path, "Crew / Internet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateEntriesAreStoredSeparatelyFromNormalEntries(t *testing.T) {
|
||||
model := testModel()
|
||||
|
||||
model.UpsertTemplate(Entry{
|
||||
ID: "tpl-1",
|
||||
Title: "Website Login",
|
||||
Username: "template-user",
|
||||
Password: "template-password",
|
||||
URL: "https://example.com",
|
||||
Notes: "Reusable template for website accounts.",
|
||||
Tags: []string{"template", "web"},
|
||||
Path: []string{"Templates"},
|
||||
})
|
||||
|
||||
if len(model.Entries) != 4 {
|
||||
t.Fatalf("len(Entries) = %d, want 4", len(model.Entries))
|
||||
}
|
||||
|
||||
if len(model.Templates) != 1 {
|
||||
t.Fatalf("len(Templates) = %d, want 1", len(model.Templates))
|
||||
}
|
||||
|
||||
if got := model.Templates[0].Title; got != "Website Login" {
|
||||
t.Fatalf("Templates[0].Title = %q, want %q", got, "Website Login")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstantiateTemplateCreatesNormalEntryWithOverrides(t *testing.T) {
|
||||
model := Model{
|
||||
Templates: []Entry{
|
||||
{
|
||||
ID: "tpl-1",
|
||||
Title: "Website Login",
|
||||
Username: "template-user",
|
||||
Password: "template-password",
|
||||
URL: "https://example.com",
|
||||
Notes: "Reusable template for website accounts.",
|
||||
Tags: []string{"template", "web"},
|
||||
Fields: map[string]string{
|
||||
"Environment": "prod",
|
||||
},
|
||||
Path: []string{"Templates"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
entry, err := model.InstantiateTemplate("tpl-1", Entry{
|
||||
ID: "entry-1",
|
||||
Title: "Bellagio",
|
||||
Username: "rustyryan",
|
||||
Password: "hunter2",
|
||||
URL: "https://bellagio.example.invalid",
|
||||
Path: []string{"Crew", "Internet"},
|
||||
Tags: []string{"dns"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("InstantiateTemplate() error = %v", err)
|
||||
}
|
||||
|
||||
if entry.ID != "entry-1" {
|
||||
t.Fatalf("entry.ID = %q, want %q", entry.ID, "entry-1")
|
||||
}
|
||||
|
||||
if entry.Title != "Bellagio" {
|
||||
t.Fatalf("entry.Title = %q, want %q", entry.Title, "Bellagio")
|
||||
}
|
||||
|
||||
if entry.Username != "rustyryan" || entry.Password != "hunter2" || entry.URL != "https://bellagio.example.invalid" {
|
||||
t.Fatalf("entry credentials = %#v, want override values", entry)
|
||||
}
|
||||
|
||||
if entry.Notes != "Reusable template for website accounts." {
|
||||
t.Fatalf("entry.Notes = %q, want %q", entry.Notes, "Reusable template for website accounts.")
|
||||
}
|
||||
|
||||
if !slices.Equal(entry.Tags, []string{"dns"}) {
|
||||
t.Fatalf("entry.Tags = %v, want [dns]", entry.Tags)
|
||||
}
|
||||
|
||||
if entry.Fields["Environment"] != "prod" {
|
||||
t.Fatalf("entry.Fields[Environment] = %q, want %q", entry.Fields["Environment"], "prod")
|
||||
}
|
||||
|
||||
got := model.EntriesInPath([]string{"Crew", "Internet"})
|
||||
if len(got) != 1 || got[0].Title != "Bellagio" {
|
||||
t.Fatalf("EntriesInPath() = %#v, want instantiated Bellagio entry", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstantiateTemplateFailsForUnknownTemplate(t *testing.T) {
|
||||
model := Model{}
|
||||
|
||||
_, err := model.InstantiateTemplate("missing-template", Entry{ID: "entry-1"})
|
||||
if err == nil {
|
||||
t.Fatal("InstantiateTemplate() error = nil, want ErrEntryNotFound")
|
||||
}
|
||||
|
||||
if !errors.Is(err, ErrEntryNotFound) {
|
||||
t.Fatalf("InstantiateTemplate() error = %v, want ErrEntryNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteTemplateRemovesTemplateWithoutTouchingEntries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
Entries: []Entry{
|
||||
{ID: "entry-1", Title: "Vault Console", Path: []string{"Root", "Internet"}},
|
||||
},
|
||||
Templates: []Entry{
|
||||
{ID: "tpl-1", Title: "Website Login", Path: []string{"Templates"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := model.DeleteTemplate("tpl-1"); err != nil {
|
||||
t.Fatalf("DeleteTemplate() error = %v", err)
|
||||
}
|
||||
|
||||
if len(model.Templates) != 0 {
|
||||
t.Fatalf("len(Templates) = %d, want 0", len(model.Templates))
|
||||
}
|
||||
if len(model.Entries) != 1 || model.Entries[0].ID != "entry-1" {
|
||||
t.Fatalf("Entries = %#v, want unchanged normal entry", model.Entries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveTemplateChangesItsPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
Templates: []Entry{
|
||||
{ID: "tpl-1", Title: "Website Login", Path: []string{"Templates", "Web"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := model.MoveTemplate("tpl-1", []string{"Templates", "Infra"}); err != nil {
|
||||
t.Fatalf("MoveTemplate() error = %v", err)
|
||||
}
|
||||
|
||||
if got := model.Templates[0].Path; !slices.Equal(got, []string{"Templates", "Infra"}) {
|
||||
t.Fatalf("Templates[0].Path = %v, want [Templates Infra]", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicateEntryCopiesEntryWithNewIDAndTitle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := Model{
|
||||
Entries: []Entry{
|
||||
{
|
||||
ID: "entry-1",
|
||||
Title: "Vault Console",
|
||||
Username: "dannyocean",
|
||||
Password: "token-1",
|
||||
Path: []string{"Root", "Internet"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
duplicate, err := model.DuplicateEntry("entry-1", "entry-2")
|
||||
if err != nil {
|
||||
t.Fatalf("DuplicateEntry() error = %v", err)
|
||||
}
|
||||
|
||||
if duplicate.ID != "entry-2" {
|
||||
t.Fatalf("duplicate.ID = %q, want %q", duplicate.ID, "entry-2")
|
||||
}
|
||||
if duplicate.Title != "Vault Console (Copy)" {
|
||||
t.Fatalf("duplicate.Title = %q, want %q", duplicate.Title, "Vault Console (Copy)")
|
||||
}
|
||||
got := model.EntriesInPath([]string{"Root", "Internet"})
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("len(EntriesInPath()) = %d, want 2", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateGroupMakesItVisibleAsChildGroup(t *testing.T) {
|
||||
model := testModel()
|
||||
|
||||
model.CreateGroup([]string{"Crew"}, "Finance")
|
||||
|
||||
got := model.ChildGroups([]string{"Crew"})
|
||||
want := []string{"Finance", "Home Assistant", "Internet"}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("ChildGroups() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateGroupSupportsNestedRelativePath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := testModel()
|
||||
model.CreateGroup([]string{"Crew"}, "Infrastructure / Prod")
|
||||
|
||||
got := model.ChildGroups([]string{"Crew"})
|
||||
if !slices.Equal(got, []string{"Home Assistant", "Infrastructure", "Internet"}) {
|
||||
t.Fatalf("ChildGroups(Crew) = %v, want [Home Assistant Infrastructure Internet]", got)
|
||||
}
|
||||
got = model.ChildGroups([]string{"Crew", "Infrastructure"})
|
||||
if !slices.Equal(got, []string{"Prod"}) {
|
||||
t.Fatalf("ChildGroups(Crew/Infrastructure) = %v, want [Prod]", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenameGroupMovesEntriesAndKeepsHierarchy(t *testing.T) {
|
||||
model := testModel()
|
||||
|
||||
if err := model.RenameGroup([]string{"Crew", "Internet"}, "Infra"); err != nil {
|
||||
t.Fatalf("RenameGroup() error = %v", err)
|
||||
}
|
||||
|
||||
got := model.EntriesInPath([]string{"Crew", "Infra"})
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("len(EntriesInPath(Crew/Infra)) = %d, want 2", len(got))
|
||||
}
|
||||
|
||||
if len(model.EntriesInPath([]string{"Crew", "Internet"})) != 0 {
|
||||
t.Fatal("EntriesInPath(Crew/Internet) should be empty after rename")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveEntryChangesItsPath(t *testing.T) {
|
||||
model := testModel()
|
||||
|
||||
if err := model.MoveEntry("1", []string{"Tricia", "School"}); err != nil {
|
||||
t.Fatalf("MoveEntry() error = %v", err)
|
||||
}
|
||||
|
||||
got := model.EntriesInPath([]string{"Tricia", "School"})
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("len(EntriesInPath(Tricia/School)) = %d, want 2", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveGroupMovesEntriesAndNestedGroups(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := testModel()
|
||||
model.CreateGroup([]string{"Crew", "Internet"}, "Infrastructure")
|
||||
if err := model.MoveGroup([]string{"Crew", "Internet"}, []string{"Tricia"}); err != nil {
|
||||
t.Fatalf("MoveGroup() error = %v", err)
|
||||
}
|
||||
|
||||
got := model.EntriesInPath([]string{"Tricia", "Internet"})
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("len(EntriesInPath(Tricia/Internet)) = %d, want 2", len(got))
|
||||
}
|
||||
if len(model.EntriesInPath([]string{"Crew", "Internet"})) != 0 {
|
||||
t.Fatal("EntriesInPath(Crew/Internet) should be empty after move")
|
||||
}
|
||||
gotGroups := model.ChildGroups([]string{"Tricia", "Internet"})
|
||||
if !slices.Equal(gotGroups, []string{"Infrastructure"}) {
|
||||
t.Fatalf("ChildGroups(Tricia/Internet) = %v, want [Infrastructure]", gotGroups)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteEmptyGroupRemovesItFromNavigation(t *testing.T) {
|
||||
model := testModel()
|
||||
|
||||
model.CreateGroup([]string{"Crew"}, "Finance")
|
||||
if err := model.DeleteGroup([]string{"Crew", "Finance"}); err != nil {
|
||||
t.Fatalf("DeleteGroup() error = %v", err)
|
||||
}
|
||||
|
||||
got := model.ChildGroups([]string{"Crew"})
|
||||
want := []string{"Home Assistant", "Internet"}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("ChildGroups() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/tobischo/gokeepasslib/v3"
|
||||
)
|
||||
|
||||
type SecuritySettings struct {
|
||||
Cipher string
|
||||
KDF string
|
||||
}
|
||||
|
||||
const (
|
||||
CipherAES256 = "aes256"
|
||||
CipherChaCha20 = "chacha20"
|
||||
KDFAES = "aes-kdf"
|
||||
KDFArgon2 = "argon2"
|
||||
)
|
||||
|
||||
func SupportedSecuritySettings() (ciphers []string, kdfs []string) {
|
||||
return []string{CipherAES256, CipherChaCha20}, []string{KDFAES, KDFArgon2}
|
||||
}
|
||||
|
||||
func DetectSecuritySettings(config *KDBXConfig) SecuritySettings {
|
||||
settings := SecuritySettings{
|
||||
Cipher: CipherChaCha20,
|
||||
KDF: KDFArgon2,
|
||||
}
|
||||
if config == nil || config.Header == nil || config.Header.FileHeaders == nil {
|
||||
return settings
|
||||
}
|
||||
if slices.Equal(config.Header.FileHeaders.CipherID, gokeepasslib.CipherAES) {
|
||||
settings.Cipher = CipherAES256
|
||||
}
|
||||
if config.Header.FileHeaders.KdfParameters != nil && slices.Equal(config.Header.FileHeaders.KdfParameters.UUID, gokeepasslib.KdfAES4) {
|
||||
settings.KDF = KDFAES
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
func NewSecurityConfig(settings SecuritySettings) (*KDBXConfig, error) {
|
||||
db := gokeepasslib.NewDatabase(gokeepasslib.WithDatabaseKDBXVersion4())
|
||||
config := &KDBXConfig{
|
||||
Header: cloneHeader(db.Header),
|
||||
InnerHeader: cloneInnerHeader(db.Content.InnerHeader),
|
||||
}
|
||||
return ApplySecuritySettings(config, settings)
|
||||
}
|
||||
|
||||
func ApplySecuritySettings(config *KDBXConfig, settings SecuritySettings) (*KDBXConfig, error) {
|
||||
if config == nil || config.Header == nil || config.Header.FileHeaders == nil {
|
||||
return NewSecurityConfig(settings)
|
||||
}
|
||||
out := &KDBXConfig{
|
||||
Header: cloneHeader(config.Header),
|
||||
InnerHeader: cloneInnerHeader(config.InnerHeader),
|
||||
}
|
||||
if out.Header.FileHeaders.KdfParameters == nil {
|
||||
defaults := gokeepasslib.NewDatabase(gokeepasslib.WithDatabaseKDBXVersion4())
|
||||
out.Header.FileHeaders.KdfParameters = cloneHeader(defaults.Header).FileHeaders.KdfParameters
|
||||
}
|
||||
switch settings.Cipher {
|
||||
case "", CipherChaCha20:
|
||||
out.Header.FileHeaders.CipherID = slices.Clone(gokeepasslib.CipherChaCha20)
|
||||
out.Header.FileHeaders.EncryptionIV = randomBytes(12)
|
||||
case CipherAES256:
|
||||
out.Header.FileHeaders.CipherID = slices.Clone(gokeepasslib.CipherAES)
|
||||
out.Header.FileHeaders.EncryptionIV = randomBytes(16)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported cipher %q", settings.Cipher)
|
||||
}
|
||||
|
||||
var salt [32]byte
|
||||
copy(salt[:], randomBytes(32))
|
||||
switch settings.KDF {
|
||||
case "", KDFArgon2:
|
||||
defaults := gokeepasslib.NewDatabase(gokeepasslib.WithDatabaseKDBXVersion4())
|
||||
kdf := defaults.Header.FileHeaders.KdfParameters
|
||||
out.Header.FileHeaders.KdfParameters = &gokeepasslib.KdfParameters{
|
||||
UUID: slices.Clone(gokeepasslib.KdfArgon2),
|
||||
Rounds: kdf.Rounds,
|
||||
Salt: salt,
|
||||
Parallelism: kdf.Parallelism,
|
||||
Memory: kdf.Memory,
|
||||
Iterations: kdf.Iterations,
|
||||
Version: kdf.Version,
|
||||
}
|
||||
case KDFAES:
|
||||
out.Header.FileHeaders.KdfParameters = &gokeepasslib.KdfParameters{
|
||||
UUID: slices.Clone(gokeepasslib.KdfAES4),
|
||||
Rounds: 6000,
|
||||
Salt: salt,
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported KDF %q", settings.KDF)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/tobischo/gokeepasslib/v3"
|
||||
)
|
||||
|
||||
func TestNewSecurityConfigCreatesRequestedCipherAndKDF(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config, err := NewSecurityConfig(SecuritySettings{Cipher: CipherAES256, KDF: KDFAES})
|
||||
if err != nil {
|
||||
t.Fatalf("NewSecurityConfig() error = %v", err)
|
||||
}
|
||||
if !slices.Equal(config.Header.FileHeaders.CipherID, gokeepasslib.CipherAES) {
|
||||
t.Fatalf("CipherID = %x, want %x", config.Header.FileHeaders.CipherID, gokeepasslib.CipherAES)
|
||||
}
|
||||
if !slices.Equal(config.Header.FileHeaders.KdfParameters.UUID, gokeepasslib.KdfAES4) {
|
||||
t.Fatalf("KDF UUID = %x, want %x", config.Header.FileHeaders.KdfParameters.UUID, gokeepasslib.KdfAES4)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplySecuritySettingsPreservesRequestedChoicesAcrossSave(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config, err := NewSecurityConfig(SecuritySettings{Cipher: CipherChaCha20, KDF: KDFArgon2})
|
||||
if err != nil {
|
||||
t.Fatalf("NewSecurityConfig() error = %v", err)
|
||||
}
|
||||
config, err = ApplySecuritySettings(config, SecuritySettings{Cipher: CipherAES256, KDF: KDFAES})
|
||||
if err != nil {
|
||||
t.Fatalf("ApplySecuritySettings() error = %v", err)
|
||||
}
|
||||
|
||||
var encoded bytes.Buffer
|
||||
if err := SaveKDBXWithConfigAndKey(&encoded, Model{}, MasterKey{Password: "correct horse battery staple"}, config); err != nil {
|
||||
t.Fatalf("SaveKDBXWithConfigAndKey() error = %v", err)
|
||||
}
|
||||
_, reloadedConfig, err := LoadKDBXWithConfig(bytes.NewReader(encoded.Bytes()), MasterKey{Password: "correct horse battery staple"})
|
||||
if err != nil {
|
||||
t.Fatalf("LoadKDBXWithConfig() error = %v", err)
|
||||
}
|
||||
got := DetectSecuritySettings(reloadedConfig)
|
||||
if got.Cipher != CipherAES256 || got.KDF != KDFAES {
|
||||
t.Fatalf("DetectSecuritySettings() = %#v, want aes256/aes-kdf", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package webdav
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrConflict = errors.New("webdav conflict")
|
||||
|
||||
type Version struct {
|
||||
ETag string
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
HTTPClient *http.Client
|
||||
BaseURL string
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
func (c Client) Open(path string) ([]byte, Version, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, c.url(path), nil)
|
||||
if err != nil {
|
||||
return nil, Version{}, fmt.Errorf("build GET request: %w", err)
|
||||
}
|
||||
c.applyAuth(req)
|
||||
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, Version{}, fmt.Errorf("GET %s: %w", path, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, Version{}, fmt.Errorf("GET %s: unexpected status %d", path, resp.StatusCode)
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, Version{}, fmt.Errorf("read %s: %w", path, err)
|
||||
}
|
||||
|
||||
return content, Version{ETag: resp.Header.Get("ETag")}, nil
|
||||
}
|
||||
|
||||
func (c Client) Save(path string, content io.Reader, version Version) (Version, error) {
|
||||
req, err := http.NewRequest(http.MethodPut, c.url(path), content)
|
||||
if err != nil {
|
||||
return Version{}, fmt.Errorf("build PUT request: %w", err)
|
||||
}
|
||||
c.applyAuth(req)
|
||||
if version.ETag != "" {
|
||||
req.Header.Set("If-Match", version.ETag)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return Version{}, fmt.Errorf("PUT %s: %w", path, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusPreconditionFailed {
|
||||
return Version{}, ErrConflict
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusNoContent {
|
||||
return Version{}, fmt.Errorf("PUT %s: unexpected status %d", path, resp.StatusCode)
|
||||
}
|
||||
|
||||
return Version{ETag: resp.Header.Get("ETag")}, nil
|
||||
}
|
||||
|
||||
func (c Client) httpClient() *http.Client {
|
||||
if c.HTTPClient != nil {
|
||||
return c.HTTPClient
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
func (c Client) applyAuth(req *http.Request) {
|
||||
if c.Username == "" && c.Password == "" {
|
||||
return
|
||||
}
|
||||
req.SetBasicAuth(c.Username, c.Password)
|
||||
}
|
||||
|
||||
func (c Client) url(path string) string {
|
||||
base := strings.TrimRight(c.BaseURL, "/")
|
||||
path = strings.TrimLeft(path, "/")
|
||||
return base + "/" + path
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package webdav
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClientOpenDownloadsRemoteVault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
t.Fatalf("method = %s, want GET", r.Method)
|
||||
}
|
||||
if user, pass, ok := r.BasicAuth(); !ok || user != "rustyryan" || pass != "secret" {
|
||||
t.Fatalf("basic auth = %q/%q ok=%v, want rustyryan/secret true", user, pass, ok)
|
||||
}
|
||||
w.Header().Set("ETag", `"etag-1"`)
|
||||
_, _ = io.WriteString(w, "vault-bytes")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := Client{
|
||||
HTTPClient: server.Client(),
|
||||
BaseURL: server.URL,
|
||||
Username: "rustyryan",
|
||||
Password: "secret",
|
||||
}
|
||||
|
||||
content, version, err := client.Open("keepass.kdbx")
|
||||
if err != nil {
|
||||
t.Fatalf("Open() error = %v", err)
|
||||
}
|
||||
|
||||
if string(content) != "vault-bytes" {
|
||||
t.Fatalf("Open() content = %q, want %q", string(content), "vault-bytes")
|
||||
}
|
||||
|
||||
if version.ETag != `"etag-1"` {
|
||||
t.Fatalf("Open() ETag = %q, want %q", version.ETag, `"etag-1"`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientSaveUploadsVaultWithIfMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
t.Fatalf("method = %s, want PUT", r.Method)
|
||||
}
|
||||
if got := r.Header.Get("If-Match"); got != `"etag-1"` {
|
||||
t.Fatalf("If-Match = %q, want %q", got, `"etag-1"`)
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll() error = %v", err)
|
||||
}
|
||||
if string(body) != "updated-vault" {
|
||||
t.Fatalf("PUT body = %q, want %q", string(body), "updated-vault")
|
||||
}
|
||||
w.Header().Set("ETag", `"etag-2"`)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := Client{
|
||||
HTTPClient: server.Client(),
|
||||
BaseURL: server.URL,
|
||||
Username: "rustyryan",
|
||||
Password: "secret",
|
||||
}
|
||||
|
||||
version, err := client.Save("keepass.kdbx", bytes.NewBufferString("updated-vault"), Version{ETag: `"etag-1"`})
|
||||
if err != nil {
|
||||
t.Fatalf("Save() error = %v", err)
|
||||
}
|
||||
|
||||
if version.ETag != `"etag-2"` {
|
||||
t.Fatalf("Save() ETag = %q, want %q", version.ETag, `"etag-2"`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientSaveReturnsConflictOnVersionMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusPreconditionFailed)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := Client{
|
||||
HTTPClient: server.Client(),
|
||||
BaseURL: server.URL,
|
||||
Username: "rustyryan",
|
||||
Password: "secret",
|
||||
}
|
||||
|
||||
_, err := client.Save("keepass.kdbx", bytes.NewBufferString("updated-vault"), Version{ETag: `"etag-1"`})
|
||||
if err != ErrConflict {
|
||||
t.Fatalf("Save() error = %v, want %v", err, ErrConflict)
|
||||
}
|
||||
}
|
||||