Files
keepassgo/session/session.go
T
2026-03-29 21:19:42 -07:00

532 lines
12 KiB
Go

package session
import (
"bytes"
"errors"
"fmt"
"os"
"path/filepath"
"reflect"
"slices"
"strings"
"git.julianfamily.org/keepassgo/vault"
"git.julianfamily.org/keepassgo/webdav"
)
// ErrLocked is returned by operations that need the decrypted vault while the
// manager is locked.
var ErrLocked = errors.New("vault is locked")

// ErrNoPath is returned when a save or synchronize is requested but neither a
// local path nor a remote target has been configured.
var ErrNoPath = errors.New("no vault path configured")
// Manager owns the lifecycle of a single KDBX vault. It creates vaults, opens
// them from a local file or a WebDAV server, locks and unlocks them, changes
// the master key, saves, and three-way synchronizes with the stored copy.
//
// NOTE(review): no mutex guards these fields; presumably callers serialize
// access — confirm before sharing a Manager across goroutines.
type Manager struct {
// model is the decrypted vault contents; Lock resets it to the zero Model.
model vault.Model
// config holds the KDBX settings loaded with the vault and reused when
// re-encoding. It may still be nil when Create is used on a fresh Manager
// (presumably meaning vault-package defaults — confirm).
config *vault.KDBXConfig
// path is the local file location ("" when none is configured).
path string
// key is the master key used to encrypt/decrypt encoded; it is kept in memory
// even while locked.
key vault.MasterKey
// locked reports whether model has been discarded by Lock.
locked bool
// encoded is the encrypted KDBX snapshot: the bytes last opened, saved, or
// produced by Create/Lock/ChangeMasterKey. It also serves as the merge base
// for Synchronize.
encoded []byte
// remoteClient is the WebDAV client set by OpenRemote; nil for local vaults.
remoteClient *webdav.Client
// remotePath is the vault's path on the WebDAV server.
remotePath string
// remoteVersion is the version token returned by the last remote open/save.
// It is passed back on the next save (presumably for conflict detection —
// confirm against the webdav package).
remoteVersion webdav.Version
}
// Create makes model, protected by key, the active vault. The model is encoded
// immediately with the manager's current KDBX configuration, so an encrypted
// snapshot is available for Lock, Save, and Synchronize. If encoding fails,
// the manager is not modified.
//
// NOTE(review): Create does not clear a previously configured local path or
// remote target. A following Save therefore writes the new vault over
// whatever location was opened before — confirm that is intended.
func (m *Manager) Create(model vault.Model, key vault.MasterKey) error {
	buf := new(bytes.Buffer)
	if err := vault.SaveKDBXWithConfigAndKey(buf, model, key, m.config); err != nil {
		return fmt.Errorf("encode new vault: %w", err)
	}
	m.model, m.key = model, key
	m.encoded, m.locked = buf.Bytes(), false
	return nil
}
// HasVault reports whether the manager holds a vault. That is the case when it
// has an encoded snapshot in memory, a local path, or a remote path.
func (m *Manager) HasVault() bool {
	switch {
	case len(m.encoded) > 0, m.path != "", m.remotePath != "":
		return true
	default:
		return false
	}
}
// IsLocked reports whether Lock has discarded the decrypted model. While
// locked, Current returns ErrLocked until Unlock succeeds.
func (m *Manager) IsLocked() bool {
	locked := m.locked
	return locked
}
// IsRemote reports whether a WebDAV client and remote path are both configured.
// When they are, Save and Synchronize target the remote copy.
func (m *Manager) IsRemote() bool {
	if m.remoteClient == nil {
		return false
	}
	return m.remotePath != ""
}
// Open reads the KDBX file at path, decrypts it with key, and makes it the
// active, unlocked vault. The raw file bytes are kept as the encoded snapshot,
// which is later used for locking and as the synchronization baseline.
//
// Opening a local file detaches any previously opened remote vault, so that
// subsequent Save and Synchronize calls target path. On error, the manager is
// left unchanged.
func (m *Manager) Open(path string, key vault.MasterKey) error {
	content, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("read %s: %w", path, err)
	}
	model, config, err := vault.LoadKDBXWithConfig(bytes.NewReader(content), key)
	if err != nil {
		return fmt.Errorf("open %s: %w", path, err)
	}
	m.model = model
	m.config = config
	m.path = path
	m.key = key
	m.encoded = content
	m.locked = false
	// Save and Synchronize prefer the remote target whenever one is set. A stale
	// remote binding would silently send this vault's writes to the old remote
	// location instead of path.
	var noVersion webdav.Version
	m.remoteClient = nil
	m.remotePath = ""
	m.remoteVersion = noVersion
	return nil
}
// Save persists the vault to its configured location. A remote target takes
// precedence over a local path. ErrNoPath is returned when neither is set.
func (m *Manager) Save() error {
	switch {
	case m.remoteClient != nil && m.remotePath != "":
		return m.SaveRemote()
	case m.path == "":
		return ErrNoPath
	default:
		return m.saveToPath(m.path)
	}
}
// OpenRemote downloads the vault at path through client, decrypts it with key,
// and makes it the active, unlocked vault. The downloaded bytes become the
// encoded snapshot, and the version returned by the server is remembered for
// the next remote save. The manager stores a pointer to its own copy of
// client. On error, nothing is modified.
//
// NOTE(review): any previously configured local path is left in place. Save
// and Synchronize ignore it while a remote target is set.
func (m *Manager) OpenRemote(client webdav.Client, path string, key vault.MasterKey) error {
	content, version, err := client.Open(path)
	if err != nil {
		return fmt.Errorf("open remote %s: %w", path, err)
	}
	model, config, decodeErr := vault.LoadKDBXWithConfig(bytes.NewReader(content), key)
	if decodeErr != nil {
		return fmt.Errorf("decode remote %s: %w", path, decodeErr)
	}
	m.remoteClient = &client
	m.remotePath = path
	m.remoteVersion = version
	m.model = model
	m.config = config
	m.key = key
	m.encoded = content
	m.locked = false
	return nil
}
// SaveRemote uploads the vault to the configured WebDAV target. It sends the
// version token from the last open/save, then records the new token and the
// uploaded bytes. It returns ErrNoPath when no remote target is configured.
func (m *Manager) SaveRemote() error {
	if !m.IsRemote() {
		return ErrNoPath
	}
	payload, err := m.persistableBytes()
	if err != nil {
		return err
	}
	next, err := m.remoteClient.Save(m.remotePath, bytes.NewReader(payload), m.remoteVersion)
	if err != nil {
		return fmt.Errorf("save remote %s: %w", m.remotePath, err)
	}
	m.encoded, m.remoteVersion = payload, next
	return nil
}
// Synchronize merges in-memory changes with the stored copy of the vault. It
// uses the remote copy when a remote target is configured and otherwise the
// local file. It returns ErrNoPath when neither location is set.
func (m *Manager) Synchronize() error {
	if m.remoteClient != nil && m.remotePath != "" {
		return m.synchronizeRemote()
	}
	if m.path != "" {
		return m.synchronizeLocal()
	}
	return ErrNoPath
}
// SaveAs writes the vault to path and, only if that succeeds, records path as
// the local location.
//
// NOTE(review): while a remote target is configured, Save and Synchronize keep
// using the remote copy even after SaveAs — confirm that is intended.
func (m *Manager) SaveAs(path string) error {
	err := m.saveToPath(path)
	if err == nil {
		m.path = path
	}
	return err
}
// Replace installs model as the current vault contents and marks the manager
// unlocked. The key, configuration, and encoded snapshot are untouched, so the
// new model is only encoded by a later Save, Lock, ChangeMasterKey, or
// Synchronize.
//
// NOTE(review): calling Replace on a locked manager unlocks it without any key
// check — confirm callers only do this while unlocked.
func (m *Manager) Replace(model vault.Model) {
	m.locked = false
	m.model = model
}
// Current returns the decrypted vault model. When the manager is locked, it
// returns the zero Model and ErrLocked instead.
func (m *Manager) Current() (vault.Model, error) {
	if !m.locked {
		return m.model, nil
	}
	return vault.Model{}, ErrLocked
}
// Lock encrypts the current model into the in-memory snapshot and discards the
// decrypted model. Locking an already locked manager is a no-op. If encoding
// fails, the manager stays unlocked and unchanged.
//
// The master key is not cleared, so a locked manager can still be re-keyed,
// saved, or synchronized.
//
// NOTE(review): overwriting m.encoded also replaces the baseline that
// Synchronize merges against. Unsaved edits captured here are afterwards
// indistinguishable from the base, so a later Synchronize lets the stored copy
// win for those entries — confirm.
func (m *Manager) Lock() error {
	if m.locked {
		return nil
	}
	snapshot := new(bytes.Buffer)
	if err := vault.SaveKDBXWithConfigAndKey(snapshot, m.model, m.key, m.config); err != nil {
		return fmt.Errorf("encode vault for lock: %w", err)
	}
	m.encoded = snapshot.Bytes()
	m.model = vault.Model{}
	m.locked = true
	return nil
}
// Unlock decrypts the encoded snapshot with key. On success, it restores the
// model and configuration, adopts key as the master key, and clears the locked
// flag. A wrong key leaves the manager unchanged.
//
// NOTE(review): there is no locked check. Calling Unlock while already unlocked
// replaces the in-memory model with the snapshot and discards unsaved edits.
func (m *Manager) Unlock(key vault.MasterKey) error {
	model, config, err := vault.LoadKDBXWithConfig(bytes.NewReader(m.encoded), key)
	if err != nil {
		return fmt.Errorf("unlock vault: %w", err)
	}
	m.model, m.config, m.key = model, config, key
	m.locked = false
	return nil
}
// ChangeMasterKey re-encrypts the vault under key. When locked, the snapshot is
// first decrypted with the current key. When unlocked, the in-memory model and
// configuration are used. The lock state and the decrypted model are left as
// they were. Nothing is modified if decoding or encoding fails.
func (m *Manager) ChangeMasterKey(key vault.MasterKey) error {
	model, config := m.model, m.config
	if m.locked {
		decoded, decodedConfig, err := vault.LoadKDBXWithConfig(bytes.NewReader(m.encoded), m.key)
		if err != nil {
			return fmt.Errorf("decode locked vault: %w", err)
		}
		model, config = decoded, decodedConfig
	}
	var reencoded bytes.Buffer
	if err := vault.SaveKDBXWithConfigAndKey(&reencoded, model, key, config); err != nil {
		return fmt.Errorf("encode vault with updated master key: %w", err)
	}
	m.key = key
	m.config = config
	m.encoded = reencoded.Bytes()
	return nil
}
// saveToPath writes the vault to path. When unlocked it encodes the current
// model; when locked it reuses the encrypted snapshot. Missing parent
// directories are created with owner-only permissions. The file is replaced
// atomically, so an interrupted save never leaves a truncated vault behind. On
// success, the written bytes become the new snapshot in m.encoded. On failure,
// the manager state is unchanged.
func (m *Manager) saveToPath(path string) error {
	encoded, err := m.persistableBytes()
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
		return fmt.Errorf("create parent dir for %s: %w", path, err)
	}
	if err := replaceFileAtomically(path, encoded); err != nil {
		return fmt.Errorf("write %s: %w", path, err)
	}
	m.encoded = encoded
	return nil
}

// replaceFileAtomically replaces the file at path with data without ever
// exposing a truncated or half-written file. The bytes go to a temporary file
// in the destination directory, are flushed with Sync, and the temporary file
// is then renamed over the destination.
//
// If path is an existing symlink, the link's target is replaced rather than
// the link itself. The new file is created by os.CreateTemp and so has mode
// 0o600. On any failure, the temporary file is removed and the original file
// is left untouched.
func replaceFileAtomically(path string, data []byte) (err error) {
	target := path
	if resolved, evalErr := filepath.EvalSymlinks(path); evalErr == nil {
		target = resolved
	}
	tmp, err := os.CreateTemp(filepath.Dir(target), "."+filepath.Base(target)+".tmp-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	defer func() {
		if err != nil {
			// Best-effort cleanup; the original error is what the caller needs.
			_ = tmp.Close()
			_ = os.Remove(tmpName)
		}
	}()
	if _, err = tmp.Write(data); err != nil {
		return err
	}
	// Flush to stable storage before the rename makes the new content visible.
	if err = tmp.Sync(); err != nil {
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmpName, target)
}
// persistableBytes returns the encrypted bytes that should be written to
// storage. When unlocked, the current model is freshly encoded with the stored
// key and configuration. When locked, it returns a copy of the snapshot, so the
// caller may keep it without aliasing m.encoded. An empty snapshot yields nil.
func (m *Manager) persistableBytes() ([]byte, error) {
	if !m.locked {
		var out bytes.Buffer
		if err := vault.SaveKDBXWithConfigAndKey(&out, m.model, m.key, m.config); err != nil {
			return nil, fmt.Errorf("encode vault: %w", err)
		}
		return out.Bytes(), nil
	}
	return append([]byte(nil), m.encoded...), nil
}
// synchronizeLocal performs a three-way merge against the vault file at m.path
// and writes the result back.
//
// The merge base is the encoded snapshot (baseModel). "Local" is the current
// vault, decoded from the snapshot when locked. "Latest" is what is on disk
// now. The merged vault is encoded with the on-disk file's KDBX configuration.
// If the file does not exist yet, the vault is simply saved there. Manager
// state is updated only after the write succeeds, and the manager is then left
// unlocked.
//
// NOTE(review): when the file exists, a locked manager ends up unlocked. In the
// missing-file branch it stays locked — confirm which is intended. The write
// here is also a plain truncate-and-write rather than an atomic replace.
func (m *Manager) synchronizeLocal() error {
	local, err := m.currentModelForPersistence()
	if err != nil {
		return err
	}
	onDisk, readErr := os.ReadFile(m.path)
	switch {
	case errors.Is(readErr, os.ErrNotExist):
		return m.saveToPath(m.path)
	case readErr != nil:
		return fmt.Errorf("read %s: %w", m.path, readErr)
	}
	latest, latestConfig, err := vault.LoadKDBXWithConfig(bytes.NewReader(onDisk), m.key)
	if err != nil {
		return fmt.Errorf("open %s for synchronize: %w", m.path, err)
	}
	base, err := m.baseModel()
	if err != nil {
		return err
	}
	merged := mergeModels(base, local, latest)
	var out bytes.Buffer
	if err := vault.SaveKDBXWithConfigAndKey(&out, merged, m.key, latestConfig); err != nil {
		return fmt.Errorf("encode synchronized vault: %w", err)
	}
	payload := out.Bytes()
	if err := os.WriteFile(m.path, payload, 0o600); err != nil {
		return fmt.Errorf("write synchronized %s: %w", m.path, err)
	}
	m.model, m.config, m.encoded = merged, latestConfig, payload
	m.locked = false
	return nil
}
// synchronizeRemote performs a three-way merge against the copy on the WebDAV
// server and uploads the result.
//
// The merge base is the encoded snapshot (baseModel). "Local" is the current
// vault, decoded from the snapshot when locked. "Latest" is what the server
// returns now. The merged vault is encoded with the server copy's KDBX
// configuration. It is uploaded with the version token returned by that same
// download — presumably so the client rejects the save if the server copy
// changed in between; confirm against webdav.Client.Save. Manager state is
// updated only after the upload succeeds, and the manager is then left
// unlocked.
func (m *Manager) synchronizeRemote() error {
	local, err := m.currentModelForPersistence()
	if err != nil {
		return err
	}
	remoteContent, remoteVersion, err := m.remoteClient.Open(m.remotePath)
	if err != nil {
		return fmt.Errorf("open remote %s for synchronize: %w", m.remotePath, err)
	}
	latest, latestConfig, err := vault.LoadKDBXWithConfig(bytes.NewReader(remoteContent), m.key)
	if err != nil {
		return fmt.Errorf("decode remote %s for synchronize: %w", m.remotePath, err)
	}
	base, err := m.baseModel()
	if err != nil {
		return err
	}
	merged := mergeModels(base, local, latest)
	var out bytes.Buffer
	if err := vault.SaveKDBXWithConfigAndKey(&out, merged, m.key, latestConfig); err != nil {
		return fmt.Errorf("encode synchronized remote vault: %w", err)
	}
	payload := out.Bytes()
	uploadedVersion, err := m.remoteClient.Save(m.remotePath, bytes.NewReader(payload), remoteVersion)
	if err != nil {
		return fmt.Errorf("save synchronized remote %s: %w", m.remotePath, err)
	}
	m.model = merged
	m.config = latestConfig
	m.encoded = payload
	m.remoteVersion = uploadedVersion
	m.locked = false
	return nil
}
// currentModelForPersistence returns the vault contents that should be merged
// and persisted. When unlocked, that is the in-memory model. When locked, the
// encoded snapshot is decrypted with the stored key. A decode failure is
// wrapped with context, matching baseModel and ChangeMasterKey, and is returned
// together with the zero Model.
func (m *Manager) currentModelForPersistence() (vault.Model, error) {
	if !m.locked {
		return m.model, nil
	}
	model, err := vault.LoadKDBXWithKey(bytes.NewReader(m.encoded), m.key)
	if err != nil {
		return vault.Model{}, fmt.Errorf("decode locked vault: %w", err)
	}
	return model, nil
}
// baseModel decodes the encoded snapshot with the stored key and returns it as
// the common ancestor for a three-way merge. It returns the zero Model when no
// snapshot exists.
//
// NOTE(review): m.encoded is also rewritten by Lock and ChangeMasterKey, which
// re-encode the current in-memory model. After either call the baseline equals
// the local state, so edits made before that point are no longer recognized as
// local changes during a merge — confirm this is acceptable.
func (m *Manager) baseModel() (vault.Model, error) {
	if len(m.encoded) > 0 {
		model, err := vault.LoadKDBXWithKey(bytes.NewReader(m.encoded), m.key)
		if err != nil {
			return vault.Model{}, fmt.Errorf("decode baseline vault: %w", err)
		}
		return model, nil
	}
	return vault.Model{}, nil
}
// mergeModels three-way merges two vault models. Entries, templates, recycle
// bin, and groups are merged individually. Every other field of the result is
// taken from latest.
func mergeModels(base, local, latest vault.Model) vault.Model {
	entries := mergeEntrySet(base.Entries, local.Entries, latest.Entries)
	templates := mergeEntrySet(base.Templates, local.Templates, latest.Templates)
	recycled := mergeEntrySet(base.RecycleBin, local.RecycleBin, latest.RecycleBin)
	groups := mergeGroups(base.Groups, local.Groups, latest.Groups)

	out := latest
	out.Entries, out.Templates, out.RecycleBin, out.Groups = entries, templates, recycled, groups
	return out
}
// mergeEntrySet three-way merges one collection of entries, keyed by ID. The
// result is sorted by Title, with ID breaking ties so the order is
// deterministic.
//
// Rules, applied per ID:
//   - unchanged locally (in base and local, and equal): latest wins, including
//     latest's edits and deletions;
//   - added or edited locally: the local version wins. If latest also edited
//     it (differs from both base and local), the latest version is kept in the
//     entry's history via mergeConflictedEntry;
//   - deleted locally (in base, absent from local): removed from the result,
//     even if latest edited it.
//
// Entries that exist only in latest are kept.
func mergeEntrySet(base, local, latest []vault.Entry) []vault.Entry {
	baseByID := mapEntries(base)
	localByID := mapEntries(local)
	mergedByID := mapEntries(latest)

	for id, current := range localByID {
		original, hadBase := baseByID[id]
		if hadBase && entriesEqual(original, current) {
			continue
		}
		// Each ID is visited once, so mergedByID[id] still holds latest's value.
		latestCurrent, inLatest := mergedByID[id]
		if hadBase && inLatest && !entriesEqual(original, latestCurrent) && !entriesEqual(latestCurrent, current) {
			current = mergeConflictedEntry(current, latestCurrent)
		}
		mergedByID[id] = current
	}

	for id := range baseByID {
		if _, kept := localByID[id]; !kept {
			delete(mergedByID, id)
		}
	}

	out := make([]vault.Entry, 0, len(mergedByID))
	for _, item := range mergedByID {
		out = append(out, item)
	}
	// Map iteration order is random and SortFunc is not stable. Without a unique
	// tie-breaker, entries sharing a title would be saved in a different order
	// on every run.
	slices.SortFunc(out, func(a, b vault.Entry) int {
		switch {
		case a.Title < b.Title:
			return -1
		case a.Title > b.Title:
			return 1
		case a.ID < b.ID:
			return -1
		case a.ID > b.ID:
			return 1
		default:
			return 0
		}
	})
	return out
}
// mergeConflictedEntry resolves an entry edited on both sides in favor of
// current, preserving latest as the first element of current's history. History
// items equal to that preserved version are dropped, and the rest are
// deep-copied. If current and latest are already equal, current is returned
// unchanged.
//
// NOTE(review): the preserved copy keeps latest's own History (cloneEntry
// copies it). History entries can therefore nest history of their own —
// confirm the vault serializer expects that.
func mergeConflictedEntry(current, latest vault.Entry) vault.Entry {
	previous := cloneEntry(latest)
	if sameEntryVersion(current, previous) {
		return current
	}
	history := make([]vault.Entry, 1, len(current.History)+1)
	history[0] = previous
	for i := range current.History {
		if sameEntryVersion(current.History[i], previous) {
			continue
		}
		history = append(history, cloneEntry(current.History[i]))
	}
	current.History = history
	return current
}
// mapEntries indexes entries by ID. If several entries share an ID, the last
// one wins.
func mapEntries(entries []vault.Entry) map[string]vault.Entry {
	byID := make(map[string]vault.Entry, len(entries))
	for i := range entries {
		byID[entries[i].ID] = entries[i]
	}
	return byID
}
// entriesEqual reports whether a and b agree on the compared fields: ID, Title,
// Username, Password, URL, Notes, Tags, Path, History, Fields, and Attachments.
// Any other fields vault.Entry may have are ignored — NOTE(review): confirm
// none of them should participate in merge decisions.
func entriesEqual(a, b vault.Entry) bool {
	switch {
	case a.ID != b.ID, a.Title != b.Title, a.Username != b.Username,
		a.Password != b.Password, a.URL != b.URL, a.Notes != b.Notes:
		return false
	case !slices.Equal(a.Tags, b.Tags), !slices.Equal(a.Path, b.Path):
		return false
	case !reflect.DeepEqual(a.History, b.History), !reflect.DeepEqual(a.Fields, b.Fields):
		return false
	}
	return equalAttachments(a.Attachments, b.Attachments)
}
// equalAttachments reports whether a and b contain the same attachment names
// with byte-identical contents. A nil map and an empty map are equal, and so
// are a nil and an empty value for the same name.
func equalAttachments(a, b map[string][]byte) bool {
	if len(a) != len(b) {
		return false
	}
	for key, value := range a {
		// Check presence explicitly. Otherwise a missing key reads as nil and
		// {"x": nil} would compare equal to {"y": nil}.
		other, ok := b[key]
		if !ok || !slices.Equal(value, other) {
			return false
		}
	}
	return true
}
// cloneEntry returns a deep copy of entry. Tags, Path, History (recursively),
// Fields, and the Attachments byte slices are duplicated, so the copy shares no
// mutable storage with the original. Nil slices and maps stay nil.
func cloneEntry(entry vault.Entry) vault.Entry {
	entry.Tags = slices.Clone(entry.Tags)
	entry.Path = slices.Clone(entry.Path)
	entry.History = cloneHistory(entry.History)
	if src := entry.Fields; src != nil {
		dst := make(map[string]string, len(src))
		for name, text := range src {
			dst[name] = text
		}
		entry.Fields = dst
	}
	if src := entry.Attachments; src != nil {
		dst := make(map[string][]byte, len(src))
		for name, blob := range src {
			dst[name] = slices.Clone(blob)
		}
		entry.Attachments = dst
	}
	return entry
}
// cloneHistory deep-copies each history entry via cloneEntry. An empty or nil
// history yields nil.
func cloneHistory(history []vault.Entry) []vault.Entry {
	var out []vault.Entry
	if n := len(history); n > 0 {
		out = make([]vault.Entry, 0, n)
		for _, item := range history {
			out = append(out, cloneEntry(item))
		}
	}
	return out
}
// sameEntryVersion reports whether a and b represent the same revision of an
// entry. Currently this is exactly entriesEqual(a, b), argument order included.
func sameEntryVersion(a, b vault.Entry) bool {
	same := entriesEqual(a, b)
	return same
}
// mergeGroups three-way merges group paths.
//
// The result starts from latest and adds every path present in local. It then
// drops any path that existed in base but is missing from local (a local
// deletion). Returned paths are fresh copies, sorted by their NUL-joined form
// so the output is deterministic.
//
// NOTE(review): because every local path is re-added, a group that latest
// deleted but local left untouched comes back. Deletions on the latest side
// are never honored — confirm that asymmetry with mergeEntrySet is intended.
func mergeGroups(base, local, latest [][]string) [][]string {
	joined := func(path []string) string { return strings.Join(path, "\x00") }

	merged := make(map[string][]string, len(latest)+len(local))
	for _, path := range latest {
		merged[joined(path)] = append([]string(nil), path...)
	}
	localKeys := make(map[string]struct{}, len(local))
	for _, path := range local {
		key := joined(path)
		localKeys[key] = struct{}{}
		merged[key] = append([]string(nil), path...)
	}
	for _, path := range base {
		key := joined(path)
		if _, kept := localKeys[key]; !kept {
			delete(merged, key)
		}
	}

	keys := make([]string, 0, len(merged))
	for key := range merged {
		keys = append(keys, key)
	}
	slices.Sort(keys)
	out := make([][]string, 0, len(keys))
	for _, key := range keys {
		out = append(out, merged[key])
	}
	return out
}
// pathKey turns a group path into one string, joining its components with NUL
// bytes, so the path can serve as a map key and a sort key. NOTE(review): this
// assumes group names never contain NUL — confirm.
func pathKey(path []string) string {
	var b strings.Builder
	for i, part := range path {
		if i > 0 {
			b.WriteByte(0)
		}
		b.WriteString(part)
	}
	return b.String()
}