diff --git a/session/session.go b/session/session.go index 7594999..8f61ccc 100644 --- a/session/session.go +++ b/session/session.go @@ -24,6 +24,7 @@ type Manager struct { config *vault.KDBXConfig path string key vault.MasterKey + vaultRoot string locked bool encoded []byte remoteClient *webdav.Client @@ -32,6 +33,8 @@ type Manager struct { } func (m *Manager) Create(model vault.Model, key vault.MasterKey) error { + root := detectSingleVaultRoot(model) + model = normalizeUnderRoot(model, root) var encoded bytes.Buffer if err := vault.SaveKDBXWithConfigAndKey(&encoded, model, key, m.config); err != nil { return fmt.Errorf("encode new vault: %w", err) @@ -39,6 +42,7 @@ func (m *Manager) Create(model vault.Model, key vault.MasterKey) error { m.model = model m.key = key + m.vaultRoot = root m.encoded = encoded.Bytes() m.locked = false return nil @@ -71,6 +75,7 @@ func (m *Manager) Open(path string, key vault.MasterKey) error { m.config = config m.path = path m.key = key + m.vaultRoot = detectSingleVaultRoot(model) m.encoded = content m.locked = false return nil @@ -102,6 +107,7 @@ func (m *Manager) OpenRemote(client webdav.Client, path string, key vault.Master m.model = model m.config = config m.key = key + m.vaultRoot = detectSingleVaultRoot(model) m.encoded = content m.locked = false m.remoteClient = &client @@ -162,10 +168,11 @@ func (m *Manager) SynchronizeToLocal(path string) error { if err != nil { return err } + merged = normalizeUnderRoot(merged, m.vaultRoot) if err := saveModelToLocal(path, merged, m.key, configOrCurrent(config, m.config)); err != nil { return err } - m.model = merged + m.model = normalizeUnderRoot(merged, m.vaultRoot) m.locked = false return nil } @@ -191,10 +198,11 @@ func (m *Manager) SynchronizeToRemote(client webdav.Client, path string) error { if err != nil { return err } + merged = normalizeUnderRoot(merged, m.vaultRoot) if err := saveModelToRemote(client, path, merged, m.key, configOrCurrent(config, m.config), version); err != nil { return err } - m.model = merged + m.model = normalizeUnderRoot(merged, m.vaultRoot) m.locked = false return nil } @@ -209,7 +217,12 @@ func (m *Manager) SaveAs(path string) error { } func (m *Manager) Replace(model vault.Model) { - m.model = model + root := m.vaultRoot + if root == "" { + root = detectSingleVaultRoot(model) + } + m.model = normalizeUnderRoot(model, root) + m.vaultRoot = root m.locked = false } @@ -227,7 +240,8 @@ func (m *Manager) Lock() error { } var encoded bytes.Buffer - if err := vault.SaveKDBXWithConfigAndKey(&encoded, m.model, m.key, m.config); err != nil { + model := normalizeUnderRoot(m.model, m.vaultRoot) + if err := vault.SaveKDBXWithConfigAndKey(&encoded, model, m.key, m.config); err != nil { return fmt.Errorf("encode vault for lock: %w", err) } @@ -246,6 +260,7 @@ func (m *Manager) Unlock(key vault.MasterKey) error { m.model = model m.config = config m.key = key + m.vaultRoot = detectSingleVaultRoot(model) m.locked = false return nil } @@ -303,9 +318,13 @@ func (m *Manager) persistableBytes() ([]byte, error) { if m.locked { return append([]byte(nil), m.encoded...), nil } + model, err := m.currentModelForPersistence() + if err != nil { + return nil, err + } var encoded bytes.Buffer - if err := vault.SaveKDBXWithConfigAndKey(&encoded, m.model, m.key, m.config); err != nil { + if err := vault.SaveKDBXWithConfigAndKey(&encoded, model, m.key, m.config); err != nil { return nil, fmt.Errorf("encode vault: %w", err) } return encoded.Bytes(), nil @@ -336,6 +355,7 @@ func (m *Manager) synchronizeLocal() error { } merged := mergeModels(base, current, latest) + merged = normalizeUnderRoot(merged, m.vaultRoot) var encoded bytes.Buffer if err := vault.SaveKDBXWithConfigAndKey(&encoded, merged, m.key, config); err != nil { return fmt.Errorf("encode synchronized vault: %w", err) @@ -373,6 +393,7 @@ func (m *Manager) synchronizeRemote() error { } merged := mergeModels(base, current, latest) + merged = normalizeUnderRoot(merged, m.vaultRoot) var encoded bytes.Buffer if err := vault.SaveKDBXWithConfigAndKey(&encoded, merged, m.key, config); err != nil { return fmt.Errorf("encode synchronized remote vault: %w", err) @@ -393,9 +414,13 @@ func (m *Manager) synchronizeRemote() error { func (m *Manager) currentModelForPersistence() (vault.Model, error) { if m.locked { - return vault.LoadKDBXWithKey(bytes.NewReader(m.encoded), m.key) + model, err := vault.LoadKDBXWithKey(bytes.NewReader(m.encoded), m.key) + if err != nil { + return vault.Model{}, err + } + return normalizeUnderRoot(model, m.vaultRoot), nil } - return m.model, nil + return normalizeUnderRoot(m.model, m.vaultRoot), nil } func (m *Manager) baseModel() (vault.Model, error) { @@ -418,6 +443,7 @@ func (m *Manager) mergedWithPeer(other vault.Model) (vault.Model, error) { } func (m *Manager) persistMergedToCurrentSource(merged vault.Model) error { + merged = normalizeUnderRoot(merged, m.vaultRoot) switch { case m.remoteClient != nil && m.remotePath != "": if err := saveModelToRemote(*m.remoteClient, m.remotePath, merged, m.key, configOrCurrent(m.config, nil), m.remoteVersion); err != nil { @@ -435,17 +461,22 @@ func (m *Manager) persistMergedToCurrentSource(merged vault.Model) error { } func (m *Manager) reloadCurrentLocal(merged vault.Model) error { + merged = normalizeUnderRoot(merged, m.vaultRoot) encoded, err := encodeModelWithConfig(merged, m.key, configOrCurrent(m.config, nil)) if err != nil { return err } m.model = merged + if root := detectSingleVaultRoot(merged); root != "" { + m.vaultRoot = root + } m.encoded = encoded m.locked = false return nil } func (m *Manager) reloadCurrentRemote(merged vault.Model) error { + merged = normalizeUnderRoot(merged, m.vaultRoot) encoded, err := encodeModelWithConfig(merged, m.key, configOrCurrent(m.config, nil)) if err != nil { return err @@ -455,6 +486,9 @@ func (m *Manager) reloadCurrentRemote(merged vault.Model) error { return fmt.Errorf("reopen remote %s after synchronize: %w", m.remotePath, err) } m.model = merged + if root := detectSingleVaultRoot(merged); root != "" { + m.vaultRoot = root + } m.encoded = encoded m.remoteVersion = version m.locked = false @@ -716,6 +750,52 @@ func mergePeerGroups(primary, secondary [][]string) [][]string { return out } +func detectSingleVaultRoot(model vault.Model) string { + if len(model.EntriesInPath(nil)) != 0 { + return "" + } + groups := model.ChildGroups(nil) + if len(groups) != 1 { + return "" + } + return groups[0] +} + +func normalizeUnderRoot(model vault.Model, root string) vault.Model { + if root == "" { + return model + } + + out := cloneModel(model) + normalizePath := func(path []string) []string { + switch { + case len(path) == 0: + return []string{root} + case path[0] == root: + return path + default: + return append([]string{root}, path...) + } + } + + for i := range out.Entries { + out.Entries[i].Path = normalizePath(out.Entries[i].Path) + for j := range out.Entries[i].History { + out.Entries[i].History[j].Path = normalizePath(out.Entries[i].History[j].Path) + } + } + for i := range out.RecycleBin { + out.RecycleBin[i].Path = normalizePath(out.RecycleBin[i].Path) + for j := range out.RecycleBin[i].History { + out.RecycleBin[i].History[j].Path = normalizePath(out.RecycleBin[i].History[j].Path) + } + } + for i := range out.Groups { + out.Groups[i] = normalizePath(out.Groups[i]) + } + return out +} + func loadLocalSource(path string, key vault.MasterKey) (vault.Model, *vault.KDBXConfig, error) { content, err := os.ReadFile(path) if err != nil { diff --git a/session/session_test.go b/session/session_test.go index 9263f46..7446328 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -175,6 +175,80 @@ func TestSavePersistsEditsBackToCurrentPath(t *testing.T) { } } +func TestSaveReparentsMixedPathsUnderSingleVaultRoot(t *testing.T) { + t.Parallel() + + key := vault.MasterKey{Password: "correct horse battery staple"} + path := filepath.Join(t.TempDir(), "hidden-root.kdbx") + + var initial bytes.Buffer + if err := vault.SaveKDBX(&initial, vault.Model{ + Entries: []vault.Entry{ + { + ID: "entry-1", + Title: "Vault Console", + Username: "dannyocean", + Password: "token-1", + URL: "https://vault.crew.example.invalid", + Path: []string{"keepass", "Crew", "Internet"}, + }, + { + ID: "entry-2", + Title: "Mail", + Username: "dannyocean", + Password: "token-2", + URL: "https://dispatch.crew.example.invalid", + Path: []string{"keepass", "Crew", "eMail"}, + }, + }, + }, key.Password); err != nil { + t.Fatalf("SaveKDBX() error = %v", err) + } + if err := os.WriteFile(path, initial.Bytes(), 0o600); err != nil { + t.Fatalf("WriteFile(hidden-root.kdbx) error = %v", err) + } + + var sess Manager + if err := sess.Open(path, key); err != nil { + t.Fatalf("Open() error = %v", err) + } + + current, err := sess.Current() + if err != nil { + t.Fatalf("Current() error = %v", err) + } + current.Entries[0].Path = []string{"Crew", "Internet"} + current.Groups = append(current.Groups, []string{"Crew"}, []string{"Crew", "Internet"}, []string{"Crew", "eMail"}) + sess.Replace(current) + + if err := sess.Save(); err != nil { + t.Fatalf("Save() error = %v", err) + } + + reopened, err := os.Open(path) + if err != nil { + t.Fatalf("Open(saved path) error = %v", err) + } + defer reopened.Close() + + db := gokeepasslib.NewDatabase() + db.Credentials = gokeepasslib.NewPasswordCredentials(key.Password) + if err := gokeepasslib.NewDecoder(reopened).Decode(db); err != nil { + t.Fatalf("Decode(saved path) error = %v", err) + } + if err := db.UnlockProtectedEntries(); err != nil { + t.Fatalf("UnlockProtectedEntries() error = %v", err) + } + + if len(db.Content.Root.Groups) != 1 || db.Content.Root.Groups[0].Name != "keepass" { + t.Fatalf("top-level groups = %#v, want single keepass root", db.Content.Root.Groups) + } + rootGroups := db.Content.Root.Groups[0].Groups + if len(rootGroups) != 1 || rootGroups[0].Name != "Crew" { + t.Fatalf("keepass child groups = %#v, want single Crew group", rootGroups) + } +} + func TestSaveWithoutPathFails(t *testing.T) { t.Parallel()