Simplify KeePassGO desktop vault workflow
This commit is contained in:
@@ -5,6 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"git.julianfamily.org/keepassgo/vault"
|
||||
"git.julianfamily.org/keepassgo/webdav"
|
||||
@@ -40,6 +44,18 @@ func (m *Manager) Create(model vault.Model, key vault.MasterKey) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) HasVault() bool {
|
||||
return len(m.encoded) > 0 || m.path != "" || m.remotePath != ""
|
||||
}
|
||||
|
||||
func (m *Manager) IsLocked() bool {
|
||||
return m.locked
|
||||
}
|
||||
|
||||
func (m *Manager) IsRemote() bool {
|
||||
return m.remoteClient != nil && m.remotePath != ""
|
||||
}
|
||||
|
||||
func (m *Manager) Open(path string, key vault.MasterKey) error {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
@@ -114,6 +130,17 @@ func (m *Manager) SaveRemote() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Synchronize() error {
|
||||
switch {
|
||||
case m.remoteClient != nil && m.remotePath != "":
|
||||
return m.synchronizeRemote()
|
||||
case m.path != "":
|
||||
return m.synchronizeLocal()
|
||||
default:
|
||||
return ErrNoPath
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) SaveAs(path string) error {
|
||||
if err := m.saveToPath(path); err != nil {
|
||||
return err
|
||||
@@ -203,6 +230,9 @@ func (m *Manager) saveToPath(path string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
|
||||
return fmt.Errorf("create parent dir for %s: %w", path, err)
|
||||
}
|
||||
if err := os.WriteFile(path, encoded, 0o600); err != nil {
|
||||
return fmt.Errorf("write %s: %w", path, err)
|
||||
}
|
||||
@@ -222,3 +252,223 @@ func (m *Manager) persistableBytes() ([]byte, error) {
|
||||
}
|
||||
return encoded.Bytes(), nil
|
||||
}
|
||||
|
||||
func (m *Manager) synchronizeLocal() error {
|
||||
current, err := m.currentModelForPersistence()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(m.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return m.saveToPath(m.path)
|
||||
}
|
||||
return fmt.Errorf("read %s: %w", m.path, err)
|
||||
}
|
||||
|
||||
latest, config, err := vault.LoadKDBXWithConfig(bytes.NewReader(content), m.key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open %s for synchronize: %w", m.path, err)
|
||||
}
|
||||
|
||||
base, err := m.baseModel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
merged := mergeModels(base, current, latest)
|
||||
var encoded bytes.Buffer
|
||||
if err := vault.SaveKDBXWithConfigAndKey(&encoded, merged, m.key, config); err != nil {
|
||||
return fmt.Errorf("encode synchronized vault: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(m.path, encoded.Bytes(), 0o600); err != nil {
|
||||
return fmt.Errorf("write synchronized %s: %w", m.path, err)
|
||||
}
|
||||
|
||||
m.model = merged
|
||||
m.config = config
|
||||
m.encoded = encoded.Bytes()
|
||||
m.locked = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) synchronizeRemote() error {
|
||||
current, err := m.currentModelForPersistence()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
content, version, err := m.remoteClient.Open(m.remotePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open remote %s for synchronize: %w", m.remotePath, err)
|
||||
}
|
||||
|
||||
latest, config, err := vault.LoadKDBXWithConfig(bytes.NewReader(content), m.key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode remote %s for synchronize: %w", m.remotePath, err)
|
||||
}
|
||||
|
||||
base, err := m.baseModel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
merged := mergeModels(base, current, latest)
|
||||
var encoded bytes.Buffer
|
||||
if err := vault.SaveKDBXWithConfigAndKey(&encoded, merged, m.key, config); err != nil {
|
||||
return fmt.Errorf("encode synchronized remote vault: %w", err)
|
||||
}
|
||||
|
||||
nextVersion, err := m.remoteClient.Save(m.remotePath, bytes.NewReader(encoded.Bytes()), version)
|
||||
if err != nil {
|
||||
return fmt.Errorf("save synchronized remote %s: %w", m.remotePath, err)
|
||||
}
|
||||
|
||||
m.model = merged
|
||||
m.config = config
|
||||
m.encoded = encoded.Bytes()
|
||||
m.remoteVersion = nextVersion
|
||||
m.locked = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) currentModelForPersistence() (vault.Model, error) {
|
||||
if m.locked {
|
||||
return vault.LoadKDBXWithKey(bytes.NewReader(m.encoded), m.key)
|
||||
}
|
||||
return m.model, nil
|
||||
}
|
||||
|
||||
func (m *Manager) baseModel() (vault.Model, error) {
|
||||
if len(m.encoded) == 0 {
|
||||
return vault.Model{}, nil
|
||||
}
|
||||
model, err := vault.LoadKDBXWithKey(bytes.NewReader(m.encoded), m.key)
|
||||
if err != nil {
|
||||
return vault.Model{}, fmt.Errorf("decode baseline vault: %w", err)
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func mergeModels(base, local, latest vault.Model) vault.Model {
|
||||
merged := latest
|
||||
merged.Entries = mergeEntrySet(base.Entries, local.Entries, latest.Entries)
|
||||
merged.Templates = mergeEntrySet(base.Templates, local.Templates, latest.Templates)
|
||||
merged.RecycleBin = mergeEntrySet(base.RecycleBin, local.RecycleBin, latest.RecycleBin)
|
||||
merged.Groups = mergeGroups(base.Groups, local.Groups, latest.Groups)
|
||||
return merged
|
||||
}
|
||||
|
||||
func mergeEntrySet(base, local, latest []vault.Entry) []vault.Entry {
|
||||
baseByID := mapEntries(base)
|
||||
localByID := mapEntries(local)
|
||||
latestByID := mapEntries(latest)
|
||||
|
||||
for id, current := range localByID {
|
||||
original, hadBase := baseByID[id]
|
||||
if !hadBase || !entriesEqual(original, current) {
|
||||
latestByID[id] = current
|
||||
}
|
||||
}
|
||||
for id := range baseByID {
|
||||
if _, stillLocal := localByID[id]; stillLocal {
|
||||
continue
|
||||
}
|
||||
delete(latestByID, id)
|
||||
}
|
||||
|
||||
out := make([]vault.Entry, 0, len(latestByID))
|
||||
for _, item := range latestByID {
|
||||
out = append(out, item)
|
||||
}
|
||||
slices.SortFunc(out, func(a, b vault.Entry) int {
|
||||
switch {
|
||||
case a.Title < b.Title:
|
||||
return -1
|
||||
case a.Title > b.Title:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func mapEntries(entries []vault.Entry) map[string]vault.Entry {
|
||||
out := make(map[string]vault.Entry, len(entries))
|
||||
for _, item := range entries {
|
||||
out[item.ID] = item
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func entriesEqual(a, b vault.Entry) bool {
|
||||
return a.ID == b.ID &&
|
||||
a.Title == b.Title &&
|
||||
a.Username == b.Username &&
|
||||
a.Password == b.Password &&
|
||||
a.URL == b.URL &&
|
||||
a.Notes == b.Notes &&
|
||||
slices.Equal(a.Tags, b.Tags) &&
|
||||
slices.Equal(a.Path, b.Path) &&
|
||||
reflect.DeepEqual(a.History, b.History) &&
|
||||
reflect.DeepEqual(a.Fields, b.Fields) &&
|
||||
equalAttachments(a.Attachments, b.Attachments)
|
||||
}
|
||||
|
||||
func equalAttachments(a, b map[string][]byte) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for key, value := range a {
|
||||
if !slices.Equal(value, b[key]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func mergeGroups(base, local, latest [][]string) [][]string {
|
||||
set := map[string][]string{}
|
||||
for _, path := range latest {
|
||||
set[pathKey(path)] = append([]string(nil), path...)
|
||||
}
|
||||
baseSet := map[string]bool{}
|
||||
for _, path := range base {
|
||||
baseSet[pathKey(path)] = true
|
||||
}
|
||||
localSet := map[string]bool{}
|
||||
for _, path := range local {
|
||||
key := pathKey(path)
|
||||
localSet[key] = true
|
||||
set[key] = append([]string(nil), path...)
|
||||
}
|
||||
for key := range baseSet {
|
||||
if localSet[key] {
|
||||
continue
|
||||
}
|
||||
delete(set, key)
|
||||
}
|
||||
out := make([][]string, 0, len(set))
|
||||
for _, path := range set {
|
||||
out = append(out, path)
|
||||
}
|
||||
slices.SortFunc(out, func(a, b []string) int {
|
||||
joinedA := pathKey(a)
|
||||
joinedB := pathKey(b)
|
||||
switch {
|
||||
case joinedA < joinedB:
|
||||
return -1
|
||||
case joinedA > joinedB:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func pathKey(path []string) string {
|
||||
return strings.Join(path, "\x00")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user