Files
keepassgo/vault/kdbx.go
T
2026-03-29 11:04:38 -07:00

551 lines
14 KiB
Go

package vault
import (
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"io"
"maps"
"slices"
"strings"
"time"
"github.com/tobischo/gokeepasslib/v3"
w "github.com/tobischo/gokeepasslib/v3/wrappers"
)
// KDBXConfig carries the KDBX container headers of a database so that a later
// save can reuse them. LoadKDBXWithConfig returns one filled with clones of the
// decoded headers, and SaveKDBXWithConfigAndKey accepts one as optional input.
type KDBXConfig struct {
// Header is the outer database header. When it is nil, saving falls back to
// gokeepasslib.NewHeader().
Header *gokeepasslib.DBHeader
// InnerHeader is the inner header. Saving only consults it when the chosen
// header reports KDBX4; when it is nil a fresh inner header is generated.
InnerHeader *gokeepasslib.InnerHeader
}
// ErrInvalidMasterKey is returned by LoadKDBXWithConfig (and the loaders built
// on it) when decoding fails with an error that isInvalidCredentialError
// classifies as a credential failure.
var ErrInvalidMasterKey = errors.New("invalid master key")
// Names that have a special meaning in the persisted group tree and in the
// values stored on each entry.
const (
// templatesRoot is the top-level group name whose entries are loaded into
// Model.Templates.
templatesRoot = "Templates"
// recycleBinRoot is the top-level group name whose entries are loaded into
// Model.RecycleBin; the prefix is stripped on load and re-added on save.
recycleBinRoot = "Recycle Bin"
// keepassGOIDField is the entry value key under which the model's entry ID
// is written on save and looked up on load.
keepassGOIDField = "KeePassGO-ID"
)
// LoadKDBX reads a KDBX database from r using a password as the only
// credential. It is shorthand for LoadKDBXWithKey with a password-only
// MasterKey.
func LoadKDBX(r io.Reader, password string) (Model, error) {
	key := MasterKey{Password: password}
	return LoadKDBXWithKey(r, key)
}
// SaveKDBX writes model to wr as a KDBX database protected by a password
// only. It is shorthand for SaveKDBXWithKey with a password-only MasterKey.
func SaveKDBX(wr io.Writer, model Model, password string) error {
	key := MasterKey{Password: password}
	return SaveKDBXWithKey(wr, model, key)
}
// SaveKDBXWithKey writes model to wr as a KDBX database protected by key,
// passing no KDBXConfig so that SaveKDBXWithConfigAndKey uses its default
// headers.
func SaveKDBXWithKey(wr io.Writer, model Model, key MasterKey) error {
	var defaults *KDBXConfig // nil: no headers to carry over
	return SaveKDBXWithConfigAndKey(wr, model, key, defaults)
}
// SaveKDBXWithConfigAndKey encodes model as a KDBX database and writes it to
// wr.
//
// key supplies the credentials. config is optional: a non-nil config.Header
// replaces the default header from gokeepasslib.NewHeader(), and, for KDBX4
// headers, a non-nil config.InnerHeader supplies the inner random stream
// settings. Without an inner header a ChaCha stream with a fresh 64-byte key
// is used.
//
// Binaries carried by config.InnerHeader are not written: every attachment is
// rebuilt from the model, so keeping the inherited list would add one more
// unreferenced copy of each attachment on every load/save cycle.
//
// It returns the credential construction error unchanged, or an error wrapped
// with "lock protected entries" / "encode kdbx" when those steps fail.
//
// NOTE(review): a reused config.Header keeps the seeds and IV copied by
// cloneHeader. Whether the encoder refreshes them is not visible from this
// file — confirm they are not reused across saves.
func SaveKDBXWithConfigAndKey(wr io.Writer, model Model, key MasterKey, config *KDBXConfig) error {
	credentials, err := newCredentials(key)
	if err != nil {
		return err
	}

	header := gokeepasslib.NewHeader()
	if config != nil && config.Header != nil {
		header = cloneHeader(config.Header)
	}

	content := &gokeepasslib.DBContent{
		Meta: gokeepasslib.NewMetaData(),
		Root: &gokeepasslib.RootData{},
	}
	if header.IsKdbx4() {
		if config != nil && config.InnerHeader != nil {
			content.InnerHeader = cloneInnerHeader(config.InnerHeader)
			// marshalEntry re-registers every attachment through db.AddBinary,
			// so start from an empty list instead of stacking the new binaries
			// on top of the ones inherited from the loaded file.
			content.InnerHeader.Binaries = nil
		} else {
			content.InnerHeader = &gokeepasslib.InnerHeader{
				InnerRandomStreamID:  gokeepasslib.ChaChaStreamID,
				InnerRandomStreamKey: randomBytes(64),
			}
		}
	}

	db := &gokeepasslib.Database{
		Header:      header,
		Credentials: credentials,
		Content:     content,
		Hashes:      gokeepasslib.NewHashes(header),
	}
	db.Content.Root.Groups = buildGroupTree(db, entriesForPersistence(model))
	db.Content.Root.DeletedObjects = marshalDeletedObjects(model.RecycleBin)

	if err := db.LockProtectedEntries(); err != nil {
		return fmt.Errorf("lock protected entries: %w", err)
	}
	if err := gokeepasslib.NewEncoder(wr).Encode(db); err != nil {
		return fmt.Errorf("encode kdbx: %w", err)
	}
	return nil
}
// appendGroupEntries walks group depth-first and adds every entry it contains
// to model via appendModelEntry. path holds the names of the ancestor groups;
// each entry's Path is that chain followed by the name of the group holding
// the entry. The caller's path slice is never modified.
func appendGroupEntries(model *Model, db *gokeepasslib.Database, group gokeepasslib.Group, path []string) {
	groupPath := append(clonePath(path), group.Name)

	for i := range group.Entries {
		source := group.Entries[i]
		loaded := Entry{
			ID:          extractEntryID(source),
			Title:       source.GetTitle(),
			Username:    source.GetContent("UserName"),
			Password:    source.GetPassword(),
			URL:         source.GetContent("URL"),
			Notes:       source.GetContent("Notes"),
			Tags:        splitTags(source.Tags),
			Fields:      extractCustomFields(source),
			Attachments: extractAttachments(db, source),
			History:     extractHistory(db, source, groupPath),
			// Each entry gets its own copy so later edits cannot alias.
			Path: clonePath(groupPath),
		}
		appendModelEntry(model, loaded)
	}

	for i := range group.Groups {
		appendGroupEntries(model, db, group.Groups[i], groupPath)
	}
}
// appendModelEntry files entry into the model collection selected by the
// first segment of its path: templatesRoot goes to Templates, recycleBinRoot
// goes to RecycleBin (with that leading segment removed), and anything else,
// including an empty path, goes to Entries.
func appendModelEntry(model *Model, entry Entry) {
	under := func(root string) bool {
		return len(entry.Path) > 0 && entry.Path[0] == root
	}

	switch {
	case under(templatesRoot):
		model.Templates = append(model.Templates, entry)
	case under(recycleBinRoot):
		// Store the path the entry had before it was recycled.
		entry.Path = slices.Clone(entry.Path[1:])
		model.RecycleBin = append(model.RecycleBin, entry)
	default:
		model.Entries = append(model.Entries, entry)
	}
}
// entriesForPersistence flattens the model into the single list that is
// written to disk: regular entries, then templates, then recycle-bin entries.
// Recycle-bin entries are cloned and their paths prefixed with recycleBinRoot
// so the model itself is left untouched.
func entriesForPersistence(model Model) []Entry {
	combined := slices.Clone(model.Entries)
	combined = append(combined, model.Templates...)

	for i := range model.RecycleBin {
		binned := cloneEntry(model.RecycleBin[i])

		prefixed := make([]string, 0, len(binned.Path)+1)
		prefixed = append(prefixed, recycleBinRoot)
		prefixed = append(prefixed, binned.Path...)
		binned.Path = prefixed

		combined = append(combined, binned)
	}
	return combined
}
// marshalUUID returns the text form of id as produced by its MarshalText
// method, or the empty string when marshalling fails.
func marshalUUID(id gokeepasslib.UUID) string {
	if text, err := id.MarshalText(); err == nil {
		return string(text)
	}
	return ""
}
// clonePath returns an independent copy of path. Both a nil and an empty path
// yield nil.
func clonePath(path []string) []string {
	n := len(path)
	if n == 0 {
		return nil
	}
	return append(make([]string, 0, n), path...)
}
// splitTags breaks a ';'-separated tag string into individual tags, trimming
// surrounding whitespace and discarding empty items. It returns nil when no
// tag remains.
func splitTags(tags string) []string {
	var cleaned []string
	for _, raw := range strings.Split(tags, ";") {
		if tag := strings.TrimSpace(raw); tag != "" {
			cleaned = append(cleaned, tag)
		}
	}
	return cleaned
}
// extractCustomFields collects the values of entry that are not one of the
// standard fields (Title, UserName, Password, URL, Notes) or the
// keepassGOIDField, keyed by value key. It returns nil when the entry has no
// such values.
func extractCustomFields(entry gokeepasslib.Entry) map[string]string {
	var custom map[string]string
	for _, item := range entry.Values {
		switch item.Key {
		case "Title", "UserName", "Password", "URL", "Notes", keepassGOIDField:
			continue
		}
		if custom == nil {
			custom = map[string]string{}
		}
		custom[item.Key] = item.Value.Content
	}
	return custom
}
// extractEntryID returns the model ID for entry: the content of its
// keepassGOIDField value when that is non-empty, otherwise the text form of
// the entry's UUID.
func extractEntryID(entry gokeepasslib.Entry) string {
	stored := entry.GetContent(keepassGOIDField)
	if stored == "" {
		return marshalUUID(entry.UUID)
	}
	return stored
}
// extractHistory converts the history records stored on entry into model
// entries, in stored order. Every historical entry receives a copy of path,
// the path of the live entry. It returns nil when there is no history.
//
// IDs are resolved with extractEntryID, the same rule used for live entries.
// marshalEntry writes the keepassGOIDField value on historical entries too,
// so reading only the UUID here would replace any non-UUID ID with a
// hash-derived one after a save/load round trip.
func extractHistory(db *gokeepasslib.Database, entry gokeepasslib.Entry, path []string) []Entry {
	var history []Entry
	for _, item := range entry.Histories {
		for _, historical := range item.Entries {
			history = append(history, Entry{
				ID:          extractEntryID(historical),
				Title:       historical.GetTitle(),
				Username:    historical.GetContent("UserName"),
				Password:    historical.GetPassword(),
				URL:         historical.GetContent("URL"),
				Notes:       historical.GetContent("Notes"),
				Tags:        splitTags(historical.Tags),
				Fields:      extractCustomFields(historical),
				Attachments: extractAttachments(db, historical),
				Path:        clonePath(path),
			})
		}
	}
	return history
}
// groupNode is one node of the intermediate tree that buildGroupTree assembles
// from entry paths before marshalGroups converts it into gokeepasslib groups.
type groupNode struct {
// name is the path segment this node was created for; it becomes the group name.
name string
// children maps a child segment to its node.
children map[string]*groupNode
// entries are the entries whose path ends at this node.
entries []Entry
}
// MasterKey holds the credentials used to open or protect a database. See
// newCredentials for how the two fields are combined.
type MasterKey struct {
// Password is the master password; it may be empty when only key data is used.
Password string
// KeyFileData is the key material handed to gokeepasslib's key-data
// credential constructors; it is ignored when empty.
KeyFileData []byte
}
// buildGroupTree arranges entries into nested groups following each entry's
// Path and returns the resulting top-level groups, ordered as marshalGroups
// orders them.
//
// Entries with an empty Path are placed in a top-level group named "Root".
// marshalGroups only emits the children of a node, so leaving such entries on
// the synthetic root node would drop them from the saved file. When there are
// no entries at all, a single empty "Root" group is returned.
func buildGroupTree(db *gokeepasslib.Database, entries []Entry) []gokeepasslib.Group {
	// fallbackGroup names the group used for path-less entries and for an
	// otherwise empty database.
	const fallbackGroup = "Root"

	root := &groupNode{children: map[string]*groupNode{}}
	for _, entry := range entries {
		segments := entry.Path
		if len(segments) == 0 {
			segments = []string{fallbackGroup}
		}

		node := root
		for _, segment := range segments {
			child := node.children[segment]
			if child == nil {
				child = &groupNode{
					name:     segment,
					children: map[string]*groupNode{},
				}
				node.children[segment] = child
			}
			node = child
		}
		node.entries = append(node.entries, entry)
	}

	if groups := marshalGroups(db, root); len(groups) > 0 {
		return groups
	}

	group := gokeepasslib.NewGroup()
	group.Name = fallbackGroup
	return []gokeepasslib.Group{group}
}
// LoadKDBXWithKey reads a KDBX database from r using key and returns only the
// model, discarding the header configuration that LoadKDBXWithConfig also
// returns.
func LoadKDBXWithKey(r io.Reader, key MasterKey) (Model, error) {
	var (
		model Model
		err   error
	)
	model, _, err = LoadKDBXWithConfig(r, key)
	return model, err
}
// LoadKDBXWithConfig decodes a KDBX database from r with key and returns the
// resulting model together with clones of the database's outer and inner
// headers.
//
// It returns ErrInvalidMasterKey when decoding fails with an error that
// isInvalidCredentialError recognises. Other decode failures and failures to
// unlock protected entries are returned wrapped. On any error the model is
// the zero Model and the config is nil.
func LoadKDBXWithConfig(r io.Reader, key MasterKey) (Model, *KDBXConfig, error) {
	fail := func(err error) (Model, *KDBXConfig, error) {
		return Model{}, nil, err
	}

	credentials, err := newCredentials(key)
	if err != nil {
		return fail(err)
	}

	db := gokeepasslib.NewDatabase()
	db.Credentials = credentials

	decodeErr := gokeepasslib.NewDecoder(r).Decode(db)
	switch {
	case decodeErr == nil:
		// Decoded successfully; continue below.
	case isInvalidCredentialError(decodeErr):
		return fail(ErrInvalidMasterKey)
	default:
		return fail(fmt.Errorf("decode kdbx: %w", decodeErr))
	}

	if unlockErr := db.UnlockProtectedEntries(); unlockErr != nil {
		return fail(fmt.Errorf("unlock protected entries: %w", unlockErr))
	}

	var model Model
	for i := range db.Content.Root.Groups {
		appendGroupEntries(&model, db, db.Content.Root.Groups[i], nil)
	}

	config := &KDBXConfig{
		Header:      cloneHeader(db.Header),
		InnerHeader: cloneInnerHeader(db.Content.InnerHeader),
	}
	return model, config, nil
}
// newCredentials builds gokeepasslib credentials from key. Without key data
// it uses the password alone (even when the password is empty); with key data
// it uses the key data alone when the password is empty, and password plus
// key data otherwise. Constructor errors are returned wrapped.
func newCredentials(key MasterKey) (*gokeepasslib.DBCredentials, error) {
	if len(key.KeyFileData) == 0 {
		return gokeepasslib.NewPasswordCredentials(key.Password), nil
	}

	if key.Password == "" {
		creds, err := gokeepasslib.NewKeyDataCredentials(key.KeyFileData)
		if err != nil {
			return nil, fmt.Errorf("build key credentials: %w", err)
		}
		return creds, nil
	}

	creds, err := gokeepasslib.NewPasswordAndKeyDataCredentials(key.Password, key.KeyFileData)
	if err != nil {
		return nil, fmt.Errorf("build password+key credentials: %w", err)
	}
	return creds, nil
}
// cloneHeader returns a copy of header whose signature, file headers, byte
// slices, KDF parameters and variant dictionaries are duplicated rather than
// shared with the source. RawData is cleared on the copy. A nil header yields
// nil.
func cloneHeader(header *gokeepasslib.DBHeader) *gokeepasslib.DBHeader {
	if header == nil {
		return nil
	}

	clone := *header
	clone.RawData = nil

	if sig := header.Signature; sig != nil {
		sigCopy := *sig
		clone.Signature = &sigCopy
	}

	src := header.FileHeaders
	if src == nil {
		return &clone
	}

	// Start from a shallow copy, then replace every reference-typed field.
	fh := *src
	fh.Comment = slices.Clone(src.Comment)
	fh.CipherID = slices.Clone(src.CipherID)
	fh.MasterSeed = slices.Clone(src.MasterSeed)
	fh.TransformSeed = slices.Clone(src.TransformSeed)
	fh.EncryptionIV = slices.Clone(src.EncryptionIV)
	fh.ProtectedStreamKey = slices.Clone(src.ProtectedStreamKey)
	fh.StreamStartBytes = slices.Clone(src.StreamStartBytes)

	if params := src.KdfParameters; params != nil {
		kdfCopy := *params
		kdfCopy.UUID = slices.Clone(params.UUID)
		kdfCopy.SecretKey = slices.Clone(params.SecretKey)
		kdfCopy.AssocData = slices.Clone(params.AssocData)
		if params.RawData != nil {
			kdfCopy.RawData = cloneVariantDictionary(params.RawData)
		}
		fh.KdfParameters = &kdfCopy
	}

	if src.PublicCustomData != nil {
		fh.PublicCustomData = cloneVariantDictionary(src.PublicCustomData)
	}

	clone.FileHeaders = &fh
	return &clone
}
// cloneVariantDictionary returns a copy of dict in which every item, and each
// item's Name and Value bytes, is duplicated. A nil dictionary yields nil.
func cloneVariantDictionary(dict *gokeepasslib.VariantDictionary) *gokeepasslib.VariantDictionary {
	if dict == nil {
		return nil
	}

	items := make([]*gokeepasslib.VariantDictionaryItem, len(dict.Items))
	for i, src := range dict.Items {
		dup := *src
		dup.Name = slices.Clone(src.Name)
		dup.Value = slices.Clone(src.Value)
		items[i] = &dup
	}

	return &gokeepasslib.VariantDictionary{
		Version: dict.Version,
		Items:   items,
	}
}
// cloneInnerHeader returns a copy of header with its stream key and each
// binary's content duplicated. Only the ID, Compressed, MemoryProtection and
// Content fields of a binary are carried over. A nil header yields nil.
func cloneInnerHeader(header *gokeepasslib.InnerHeader) *gokeepasslib.InnerHeader {
	if header == nil {
		return nil
	}

	var clone gokeepasslib.InnerHeader
	clone.InnerRandomStreamID = header.InnerRandomStreamID
	clone.InnerRandomStreamKey = slices.Clone(header.InnerRandomStreamKey)

	for i := range header.Binaries {
		src := &header.Binaries[i]
		clone.Binaries = append(clone.Binaries, gokeepasslib.Binary{
			ID:               src.ID,
			Compressed:       src.Compressed,
			MemoryProtection: src.MemoryProtection,
			Content:          slices.Clone(src.Content),
		})
	}
	return &clone
}
// randomBytes returns length bytes read from crypto/rand.
//
// The result is used as key material, so a failed or short read must not be
// ignored: the buffer would be returned all-zero or partly zero and the
// resulting key would be predictable. The signature has no error result, so
// the failure is reported by panicking.
func randomBytes(length int) []byte {
	buf := make([]byte, length)
	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
		panic("vault: reading random bytes: " + err.Error())
	}
	return buf
}
// isInvalidCredentialError reports whether err signals wrong credentials:
// either it matches gokeepasslib.ErrInvalidDatabaseOrCredentials or its
// message contains "Wrong password?". err must be non-nil.
func isInvalidCredentialError(err error) bool {
	return errors.Is(err, gokeepasslib.ErrInvalidDatabaseOrCredentials) ||
		strings.Contains(err.Error(), "Wrong password?")
}
// marshalGroups converts the children of node into gokeepasslib groups,
// recursively and in ascending name order. Entries attached to node itself
// are not emitted; only those of its descendants are. It returns nil when
// node has no children.
func marshalGroups(db *gokeepasslib.Database, node *groupNode) []gokeepasslib.Group {
	var groups []gokeepasslib.Group
	for _, name := range slices.Sorted(maps.Keys(node.children)) {
		child := node.children[name]

		group := gokeepasslib.NewGroup()
		group.Name = child.name
		// Entries are marshalled before subgroups; keep this order, since
		// marshalling registers attachments on db as it goes.
		group.Entries = marshalEntries(db, child.entries)
		group.Groups = marshalGroups(db, child)

		groups = append(groups, group)
	}
	return groups
}
// marshalEntries sorts entries by Title (in place, so the caller's slice is
// reordered) and converts each one with marshalEntry. It returns nil for an
// empty input.
func marshalEntries(db *gokeepasslib.Database, entries []Entry) []gokeepasslib.Entry {
	byTitle := func(a, b Entry) int {
		return strings.Compare(a.Title, b.Title)
	}
	slices.SortFunc(entries, byTitle)

	var out []gokeepasslib.Entry
	for i := range entries {
		out = append(out, marshalEntry(db, entries[i]))
	}
	return out
}
// marshalEntry converts a model entry into a gokeepasslib entry.
//
// The UUID is derived from entry.ID by uuidForEntryID and the ID itself is
// also stored under keepassGOIDField. Standard values come first, followed by
// custom fields in key order. Attachments are registered on db in name order
// and referenced from the entry. Each history item is converted recursively
// and wrapped in its own History record.
func marshalEntry(db *gokeepasslib.Database, entry Entry) gokeepasslib.Entry {
	out := gokeepasslib.NewEntry()
	out.UUID = uuidForEntryID(entry.ID)
	out.Tags = strings.Join(entry.Tags, "; ")

	standard := []gokeepasslib.ValueData{
		value("Title", entry.Title),
		value("UserName", entry.Username),
		protectedValue("Password", entry.Password),
		value("URL", entry.URL),
		value("Notes", entry.Notes),
		value(keepassGOIDField, entry.ID),
	}
	out.Values = append(out.Values, standard...)

	for _, key := range slices.Sorted(maps.Keys(entry.Fields)) {
		out.Values = append(out.Values, value(key, entry.Fields[key]))
	}

	for _, name := range slices.Sorted(maps.Keys(entry.Attachments)) {
		stored := db.AddBinary(entry.Attachments[name])
		out.Binaries = append(out.Binaries, stored.CreateReference(name))
	}

	for i := range entry.History {
		snapshot := marshalEntry(db, entry.History[i])
		out.Histories = append(out.Histories, gokeepasslib.History{
			Entries: []gokeepasslib.Entry{snapshot},
		})
	}
	return out
}
// marshalDeletedObjects produces one deleted-object record per entry, using
// the UUID derived from the entry's ID and a single shared deletion timestamp
// taken when the function runs. It returns nil for an empty input.
//
// NOTE(review): the caller passes the recycle-bin entries, which are also
// written as regular entries under the recycle-bin group — confirm that
// listing them as deleted objects as well is intended.
func marshalDeletedObjects(entries []Entry) []gokeepasslib.DeletedObjectData {
	if len(entries) == 0 {
		return nil
	}

	stamp := w.Now()
	deleted := make([]gokeepasslib.DeletedObjectData, len(entries))
	for i := range entries {
		deleted[i] = gokeepasslib.DeletedObjectData{
			UUID:         uuidForEntryID(entries[i].ID),
			DeletionTime: &stamp,
		}
	}
	return deleted
}
// uuidForEntryID maps a model entry ID to a KDBX UUID.
//
//   - An ID that parses as a UUID in text form is returned as that UUID.
//   - Any other non-empty ID maps deterministically to the leading bytes of
//     its SHA-256 digest, so the same ID always yields the same UUID.
//   - An empty ID yields a fresh UUID on every call, derived from the current
//     time mixed with random bytes.
//
// Copying the formatted timestamp straight into the UUID would keep only its
// first 16 characters ("YYYY-MM-DDTHH:MM"), giving every entry saved within
// the same minute an identical UUID. Hashing the full timestamp together with
// random bytes avoids that. Callers that need a stable UUID must pass a
// non-empty id.
func uuidForEntryID(id string) gokeepasslib.UUID {
	var uuid gokeepasslib.UUID

	if id == "" {
		seed := time.Now().UTC().AppendFormat(nil, time.RFC3339Nano)
		seed = append(seed, randomBytes(len(uuid))...)
		sum := sha256.Sum256(seed)
		copy(uuid[:], sum[:len(uuid)])
		return uuid
	}

	var parsed gokeepasslib.UUID
	if err := parsed.UnmarshalText([]byte(id)); err == nil {
		return parsed
	}

	sum := sha256.Sum256([]byte(id))
	copy(uuid[:], sum[:len(uuid)])
	return uuid
}
// value builds an unprotected entry value with the given key and content.
func value(key, content string) gokeepasslib.ValueData {
	var data gokeepasslib.ValueData
	data.Key = key
	data.Value = gokeepasslib.V{Content: content}
	return data
}
// protectedValue builds an entry value with the given key and content whose
// Protected flag is set to true.
func protectedValue(key, content string) gokeepasslib.ValueData {
	secret := gokeepasslib.V{
		Content:   content,
		Protected: w.NewBoolWrapper(true),
	}
	return gokeepasslib.ValueData{Key: key, Value: secret}
}
// extractAttachments resolves the binary references of entry against db and
// returns a copy of each binary's content keyed by the reference name.
// References whose binary cannot be found or whose content cannot be read are
// skipped. It returns nil when nothing could be extracted.
func extractAttachments(db *gokeepasslib.Database, entry gokeepasslib.Entry) map[string][]byte {
	var found map[string][]byte
	for _, ref := range entry.Binaries {
		binary := db.FindBinary(ref.Value.ID)
		if binary == nil {
			continue
		}
		data, err := binary.GetContentBytes()
		if err != nil {
			continue
		}
		if found == nil {
			found = map[string][]byte{}
		}
		found[ref.Name] = slices.Clone(data)
	}
	return found
}