package vault

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"slices"
	"strings"
	"time"

	"github.com/tobischo/gokeepasslib/v3"
	w "github.com/tobischo/gokeepasslib/v3/wrappers"
)

// KDBXConfig holds clones of the outer and inner headers of a loaded
// database. Passing it to SaveKDBXWithConfigAndKey writes a model back using
// the same header settings instead of a default KDBX 4 header.
type KDBXConfig struct {
	Header      *gokeepasslib.DBHeader
	InnerHeader *gokeepasslib.InnerHeader
}

// ErrInvalidMasterKey is returned by the load functions when decoding fails
// with an error recognised as a wrong-credentials error.
var ErrInvalidMasterKey = errors.New("invalid master key")

// Names of special top-level groups, the custom entry string that stores the
// model's entry ID, and the Meta custom-data key holding remote profiles.
const (
	templatesRoot     = "Templates"
	recycleBinRoot    = "Recycle Bin"
	keepassRoot       = "keepass"
	keepassGOIDField  = "KeePassGO-ID"
	remoteProfilesKey = "keepassgo.remoteProfiles"
)

// LoadKDBX decodes a password-protected KDBX database from r into a Model.
func LoadKDBX(r io.Reader, password string) (Model, error) {
	return LoadKDBXWithKey(r, MasterKey{Password: password})
}

// SaveKDBX writes model to wr as a new KDBX 4 database protected by password.
func SaveKDBX(wr io.Writer, model Model, password string) error {
	return SaveKDBXWithKey(wr, model, MasterKey{Password: password})
}

// SaveKDBXWithKey writes model to wr as a new KDBX 4 database protected by key.
func SaveKDBXWithKey(wr io.Writer, model Model, key MasterKey) error {
	return SaveKDBXWithConfigAndKey(wr, model, key, nil)
}

// SaveKDBXWithConfigAndKey encodes model as a KDBX database protected by key
// and writes it to wr.
//
// Without config.Header a default KDBX 4 header is used. With it, a clone of
// that header is used so the settings of a previously loaded file are kept,
// but its per-save random values (master seed, encryption IV, and for KDBX 3
// the transform seed, protected stream key and stream start bytes) are
// regenerated. For KDBX 4, config.InnerHeader supplies the inner random stream
// algorithm; its binaries are dropped (attachments are re-added from the
// model) and its stream key is regenerated.
//
// Meta is rebuilt from scratch; only the model's remote profiles are stored
// in its custom data.
func SaveKDBXWithConfigAndKey(wr io.Writer, model Model, key MasterKey, config *KDBXConfig) error {
	credentials, err := newCredentials(key)
	if err != nil {
		return err
	}

	db := gokeepasslib.NewDatabase(gokeepasslib.WithDatabaseKDBXVersion4())
	db.Credentials = credentials
	db.Content.Meta = gokeepasslib.NewMetaData()
	db.Content.Meta.CustomData = customDataForModel(model)
	db.Content.Root = &gokeepasslib.RootData{}

	if config != nil && config.Header != nil {
		db.Header = cloneHeader(config.Header)
		// The clone still carries the loaded file's random seeds and IV.
		// KeePass generates fresh ones on every save; reusing them would
		// encrypt successive saves with the same seed/IV under the same key.
		rerandomizeHeaderSecrets(db.Header)
		// NewHashes is built from the header, so call it once the header is final.
		db.Hashes = gokeepasslib.NewHashes(db.Header)
	}

	if db.Header.IsKdbx4() {
		switch {
		case config != nil && config.InnerHeader != nil:
			db.Content.InnerHeader = cloneInnerHeader(config.InnerHeader)
			// Attachments are re-added from the model by marshalEntry.
			db.Content.InnerHeader.Binaries = nil
			// Do not reuse the loaded file's inner stream key either: protected
			// values of successive saves would share the same keystream.
			db.Content.InnerHeader.InnerRandomStreamKey = freshRandomLike(db.Content.InnerHeader.InnerRandomStreamKey)
		case db.Content.InnerHeader == nil:
			db.Content.InnerHeader = &gokeepasslib.InnerHeader{
				InnerRandomStreamID:  gokeepasslib.ChaChaStreamID,
				InnerRandomStreamKey: randomBytes(64),
			}
		}
	} else {
		db.Content.InnerHeader = nil
	}

	// Must run after the inner header is set: marshalEntry registers
	// attachments through db.AddBinary.
	db.Content.Root.Groups = buildGroupTree(db, model)
	db.Content.Root.DeletedObjects = nil

	if err := db.LockProtectedEntries(); err != nil {
		return fmt.Errorf("lock protected entries: %w", err)
	}
	if err := gokeepasslib.NewEncoder(wr).Encode(db); err != nil {
		return fmt.Errorf("encode kdbx: %w", err)
	}
	return nil
}

// rerandomizeHeaderSecrets replaces the per-save random byte fields of a
// header cloned from a loaded file with fresh random bytes of the same
// length. Empty fields (e.g. the KDBX 3-only ones in a KDBX 4 header) are
// left empty. The KDF salt inside KdfParameters is not touched.
func rerandomizeHeaderSecrets(header *gokeepasslib.DBHeader) {
	if header == nil || header.FileHeaders == nil {
		return
	}
	fh := header.FileHeaders
	fh.MasterSeed = freshRandomLike(fh.MasterSeed)
	fh.EncryptionIV = freshRandomLike(fh.EncryptionIV)
	fh.TransformSeed = freshRandomLike(fh.TransformSeed)
	fh.ProtectedStreamKey = freshRandomLike(fh.ProtectedStreamKey)
	fh.StreamStartBytes = freshRandomLike(fh.StreamStartBytes)
}

// freshRandomLike returns len(b) new random bytes, or b itself when it is
// empty, so that optional fields stay absent.
func freshRandomLike(b []byte) []byte {
	if len(b) == 0 {
		return b
	}
	return randomBytes(len(b))
}

// appendGroupEntries adds group and all of its descendants to model. path is
// the list of ancestor group names; each visited group is registered through
// model.CreateGroup and each entry (with its history and attachments) is
// routed by appendModelEntry.
func appendGroupEntries(model *Model, db *gokeepasslib.Database, group gokeepasslib.Group, path []string) {
	path = append(clonePath(path), group.Name)
	model.CreateGroup(path[:len(path)-1], group.Name)
	for _, entry := range group.Entries {
		appendModelEntry(model, Entry{
			ID:          extractEntryID(entry),
			Title:       entry.GetTitle(),
			Username:    entry.GetContent("UserName"),
			Password:    entry.GetPassword(),
			URL:         entry.GetContent("URL"),
			Notes:       entry.GetContent("Notes"),
			Tags:        splitTags(entry.Tags),
			Fields:      extractCustomFields(entry),
			Attachments: extractAttachments(db, entry),
			History:     extractHistory(db, entry, path),
			Path:        clonePath(path),
		})
	}
	for _, child := range group.Groups {
		appendGroupEntries(model, db, child, path)
	}
}

// appendModelEntry files entry into the model list chosen by its first path
// segment: "Templates" entries go to model.Templates (path kept), "Recycle Bin"
// entries go to model.RecycleBin with that first segment stripped, and
// everything else (including entries with an empty path) goes to model.Entries.
func appendModelEntry(model *Model, entry Entry) {
	if len(entry.Path) == 0 {
		model.Entries = append(model.Entries, entry)
		return
	}
	switch entry.Path[0] {
	case templatesRoot:
		model.Templates = append(model.Templates, entry)
		return
	case recycleBinRoot:
		entry.Path = slices.Clone(entry.Path[1:])
		model.RecycleBin = append(model.RecycleBin, entry)
		return
	}
	model.Entries = append(model.Entries, entry)
}

// entriesForPersistence is the inverse of appendModelEntry: it returns
// regular entries, then templates, then recycle-bin entries whose paths are
// re-prefixed with "Recycle Bin". The model's slices are not modified.
func entriesForPersistence(model Model) []Entry {
	entries := append(slices.Clone(model.Entries), model.Templates...)
	for _, entry := range model.RecycleBin {
		recycleEntry := cloneEntry(entry)
		recycleEntry.Path = append([]string{recycleBinRoot}, recycleEntry.Path...)
		entries = append(entries, recycleEntry)
	}
	return entries
}

// marshalUUID returns the text form of id, or "" if it cannot be marshalled.
func marshalUUID(id gokeepasslib.UUID) string {
	text, err := id.MarshalText()
	if err != nil {
		return ""
	}
	return string(text)
}

// clonePath returns an independent copy of path, or nil for an empty path.
func clonePath(path []string) []string {
	if len(path) == 0 {
		return nil
	}
	out := make([]string, len(path))
	copy(out, path)
	return out
}

// splitTags splits a ";"-separated tag string, trimming whitespace and
// dropping empty tags. It returns nil when there are no tags.
func splitTags(tags string) []string {
	if strings.TrimSpace(tags) == "" {
		return nil
	}
	fields := strings.Split(tags, ";")
	var out []string
	for _, field := range fields {
		field = strings.TrimSpace(field)
		if field == "" {
			continue
		}
		out = append(out, field)
	}
	return out
}

// extractCustomFields returns every string value of entry except the standard
// fields (Title, UserName, Password, URL, Notes) and the KeePassGO-ID field,
// or nil if none remain.
func extractCustomFields(entry gokeepasslib.Entry) map[string]string {
	fields := map[string]string{}
	for _, value := range entry.Values {
		switch value.Key {
		case "Title", "UserName", "Password", "URL", "Notes", keepassGOIDField:
			continue
		default:
			fields[value.Key] = value.Value.Content
		}
	}
	if len(fields) == 0 {
		return nil
	}
	return fields
}

// extractEntryID prefers the KeePassGO-ID string written by this package and
// falls back to the text form of the KeePass UUID.
func extractEntryID(entry gokeepasslib.Entry) string {
	if id := entry.GetContent(keepassGOIDField); id != "" {
		return id
	}
	return marshalUUID(entry.UUID)
}

// extractHistory converts every historical entry of entry (across all of its
// History elements) into a model Entry carrying the given group path.
func extractHistory(db *gokeepasslib.Database, entry gokeepasslib.Entry, path []string) []Entry {
	if len(entry.Histories) == 0 {
		return nil
	}
	var history []Entry
	for _, item := range entry.Histories {
		for _, historical := range item.Entries {
			history = append(history, Entry{
				ID:          extractEntryID(historical),
				Title:       historical.GetTitle(),
				Username:    historical.GetContent("UserName"),
				Password:    historical.GetPassword(),
				URL:         historical.GetContent("URL"),
				Notes:       historical.GetContent("Notes"),
				Tags:        splitTags(historical.Tags),
				Fields:      extractCustomFields(historical),
				Attachments: extractAttachments(db, historical),
				Path:        clonePath(path),
			})
		}
	}
	return history
}

// groupNode is an intermediate tree used to assemble KeePass groups from the
// flat entry paths of a model before marshalling.
type groupNode struct {
	name     string
	children map[string]*groupNode
	entries  []Entry
}

// MasterKey holds the composite key used to open or protect a database: a
// password, key-file contents, or both.
type MasterKey struct {
	Password    string
	KeyFileData []byte
}

// buildGroupTree builds the KeePass group hierarchy for model. Groups are
// created for every entry path and for every path in model.Groups (so empty
// groups survive). If the model produces no groups at all, a single empty
// "Root" group is returned.
func buildGroupTree(db *gokeepasslib.Database, model Model) []gokeepasslib.Group {
	entries := entriesForPersistence(model)
	root := &groupNode{children: map[string]*groupNode{}}

	// descend returns the node for path, creating any missing groups on the way.
	descend := func(path []string) *groupNode {
		node := root
		for _, segment := range path {
			child := node.children[segment]
			if child == nil {
				child = &groupNode{name: segment, children: map[string]*groupNode{}}
				node.children[segment] = child
			}
			node = child
		}
		return node
	}

	for _, entry := range entries {
		node := descend(entry.Path)
		node.entries = append(node.entries, entry)
	}
	for _, path := range groupPathsForPersistence(model, entries) {
		descend(path)
	}

	if groups := marshalGroups(db, root); len(groups) > 0 {
		return groups
	}
	group := gokeepasslib.NewGroup()
	group.Name = "Root"
	return []gokeepasslib.Group{group}
}

// groupPathsForPersistence returns every distinct group path (including all
// prefixes) referenced by entries or model.Groups, in first-seen order.
func groupPathsForPersistence(model Model, entries []Entry) [][]string {
	seen := map[string]bool{}
	var groups [][]string
	appendPath := func(path []string) {
		// NUL cannot appear in a printable group name, so it is a safe separator.
		key := strings.Join(path, "\x00")
		if seen[key] {
			return
		}
		seen[key] = true
		groups = append(groups, slices.Clone(path))
	}
	for _, entry := range entries {
		for i := 1; i <= len(entry.Path); i++ {
			appendPath(entry.Path[:i])
		}
	}
	for _, path := range model.Groups {
		for i := 1; i <= len(path); i++ {
			appendPath(path[:i])
		}
	}
	return groups
}

// LoadKDBXWithKey decodes a KDBX database from r using key into a Model.
func LoadKDBXWithKey(r io.Reader, key MasterKey) (Model, error) {
	model, _, err := LoadKDBXWithConfig(r, key)
	return model, err
}

// LoadKDBXWithConfig decodes a KDBX database from r using key and returns the
// resulting Model together with clones of the database headers, suitable for
// SaveKDBXWithConfigAndKey. A decode error recognised as a credentials failure
// is reported as ErrInvalidMasterKey.
func LoadKDBXWithConfig(r io.Reader, key MasterKey) (Model, *KDBXConfig, error) {
	credentials, err := newCredentials(key)
	if err != nil {
		return Model{}, nil, err
	}

	db := gokeepasslib.NewDatabase()
	db.Credentials = credentials
	if err := gokeepasslib.NewDecoder(r).Decode(db); err != nil {
		if isInvalidCredentialError(err) {
			return Model{}, nil, ErrInvalidMasterKey
		}
		return Model{}, nil, fmt.Errorf("decode kdbx: %w", err)
	}
	if err := db.UnlockProtectedEntries(); err != nil {
		return Model{}, nil, fmt.Errorf("unlock protected entries: %w", err)
	}

	var model Model
	for _, group := range db.Content.Root.Groups {
		appendGroupEntries(&model, db, group, nil)
	}
	model.RemoteProfiles = remoteProfilesFromMeta(db.Content.Meta)

	return model, &KDBXConfig{
		Header:      cloneHeader(db.Header),
		InnerHeader: cloneInnerHeader(db.Content.InnerHeader),
	}, nil
}

// customDataForModel stores the model's remote profiles as a JSON string under
// remoteProfilesKey. It returns nil when there are no profiles; if JSON
// encoding fails the profiles are omitted from the file.
func customDataForModel(model Model) []gokeepasslib.CustomData {
	if len(model.RemoteProfiles) == 0 {
		return nil
	}
	content, err := json.Marshal(model.RemoteProfiles)
	if err != nil {
		return nil
	}
	return []gokeepasslib.CustomData{{
		Key:   remoteProfilesKey,
		Value: string(content),
	}}
}

// remoteProfilesFromMeta decodes the profiles stored under remoteProfilesKey.
// It returns nil when meta is nil, the key is absent, or the JSON is invalid.
func remoteProfilesFromMeta(meta *gokeepasslib.MetaData) []RemoteProfile {
	if meta == nil {
		return nil
	}
	for _, item := range meta.CustomData {
		if item.Key != remoteProfilesKey {
			continue
		}
		var profiles []RemoteProfile
		if err := json.Unmarshal([]byte(item.Value), &profiles); err != nil {
			return nil
		}
		return profiles
	}
	return nil
}

// newCredentials builds gokeepasslib credentials from key. It uses password
// plus key data when both are set and key data alone when there is no
// password. Otherwise it uses the password, which may be empty.
func newCredentials(key MasterKey) (*gokeepasslib.DBCredentials, error) {
	switch {
	case key.Password != "" && len(key.KeyFileData) > 0:
		credentials, err := gokeepasslib.NewPasswordAndKeyDataCredentials(key.Password, key.KeyFileData)
		if err != nil {
			return nil, fmt.Errorf("build password+key credentials: %w", err)
		}
		return credentials, nil
	case len(key.KeyFileData) > 0:
		credentials, err := gokeepasslib.NewKeyDataCredentials(key.KeyFileData)
		if err != nil {
			return nil, fmt.Errorf("build key credentials: %w", err)
		}
		return credentials, nil
	default:
		return gokeepasslib.NewPasswordCredentials(key.Password), nil
	}
}

// cloneHeader copies header, cloning its signature, file-header byte fields,
// KDF parameters and public custom data so the copy can be modified
// independently. RawData is cleared in the copy.
func cloneHeader(header *gokeepasslib.DBHeader) *gokeepasslib.DBHeader {
	if header == nil {
		return nil
	}
	out := *header
	out.RawData = nil
	if header.Signature != nil {
		signature := *header.Signature
		out.Signature = &signature
	}
	if header.FileHeaders != nil {
		fileHeaders := *header.FileHeaders
		fileHeaders.Comment = slices.Clone(header.FileHeaders.Comment)
		fileHeaders.CipherID = slices.Clone(header.FileHeaders.CipherID)
		fileHeaders.MasterSeed = slices.Clone(header.FileHeaders.MasterSeed)
		fileHeaders.TransformSeed = slices.Clone(header.FileHeaders.TransformSeed)
		fileHeaders.EncryptionIV = slices.Clone(header.FileHeaders.EncryptionIV)
		fileHeaders.ProtectedStreamKey = slices.Clone(header.FileHeaders.ProtectedStreamKey)
		fileHeaders.StreamStartBytes = slices.Clone(header.FileHeaders.StreamStartBytes)
		if header.FileHeaders.KdfParameters != nil {
			kdf := *header.FileHeaders.KdfParameters
			kdf.UUID = slices.Clone(header.FileHeaders.KdfParameters.UUID)
			kdf.SecretKey = slices.Clone(header.FileHeaders.KdfParameters.SecretKey)
			kdf.AssocData = slices.Clone(header.FileHeaders.KdfParameters.AssocData)
			if header.FileHeaders.KdfParameters.RawData != nil {
				kdf.RawData = cloneVariantDictionary(header.FileHeaders.KdfParameters.RawData)
			}
			fileHeaders.KdfParameters = &kdf
		}
		if header.FileHeaders.PublicCustomData != nil {
			fileHeaders.PublicCustomData = cloneVariantDictionary(header.FileHeaders.PublicCustomData)
		}
		out.FileHeaders = &fileHeaders
	}
	return &out
}

// cloneVariantDictionary deep-copies dict, including each item's name and
// value bytes.
func cloneVariantDictionary(dict *gokeepasslib.VariantDictionary) *gokeepasslib.VariantDictionary {
	if dict == nil {
		return nil
	}
	out := &gokeepasslib.VariantDictionary{Version: dict.Version}
	out.Items = make([]*gokeepasslib.VariantDictionaryItem, 0, len(dict.Items))
	for _, item := range dict.Items {
		cloned := *item
		cloned.Name = slices.Clone(item.Name)
		cloned.Value = slices.Clone(item.Value)
		out.Items = append(out.Items, &cloned)
	}
	return out
}

// cloneInnerHeader deep-copies header, including its binaries' content.
func cloneInnerHeader(header *gokeepasslib.InnerHeader) *gokeepasslib.InnerHeader {
	if header == nil {
		return nil
	}
	out := &gokeepasslib.InnerHeader{
		InnerRandomStreamID:  header.InnerRandomStreamID,
		InnerRandomStreamKey: slices.Clone(header.InnerRandomStreamKey),
	}
	for _, binary := range header.Binaries {
		out.Binaries = append(out.Binaries, gokeepasslib.Binary{
			ID:               binary.ID,
			Compressed:       binary.Compressed,
			MemoryProtection: binary.MemoryProtection,
			Content:          slices.Clone(binary.Content),
		})
	}
	return out
}

// randomBytes returns length bytes read from crypto/rand.
//
// It panics if the system random source fails. The bytes are used as keys,
// seeds and IVs, and continuing with a zero-filled buffer would silently
// produce predictably encrypted files.
func randomBytes(length int) []byte {
	buf := make([]byte, length)
	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
		panic(fmt.Errorf("vault: read crypto/rand: %w", err))
	}
	return buf
}

// isInvalidCredentialError reports whether a decode error indicates wrong
// credentials, either via gokeepasslib's sentinel error or its message text.
func isInvalidCredentialError(err error) bool {
	if errors.Is(err, gokeepasslib.ErrInvalidDatabaseOrCredentials) {
		return true
	}
	return strings.Contains(err.Error(), "Wrong password?")
}

// marshalGroups converts the children of node into KeePass groups, ordered by
// compareGroupNames, recursing into subgroups.
func marshalGroups(db *gokeepasslib.Database, node *groupNode) []gokeepasslib.Group {
	names := slices.SortedFunc(maps.Keys(node.children), compareGroupNames)
	var groups []gokeepasslib.Group
	for _, name := range names {
		child := node.children[name]
		group := gokeepasslib.NewGroup()
		group.Name = child.name
		group.Entries = marshalEntries(db, child.entries)
		group.Groups = marshalGroups(db, child)
		groups = append(groups, group)
	}
	return groups
}

// compareGroupNames orders "Root", "keepass" and "Templates" first (in that
// order), "Recycle Bin" last, and all other names lexicographically between.
func compareGroupNames(a, b string) int {
	switch {
	case a == b:
		return 0
	case a == "Root":
		return -1
	case b == "Root":
		return 1
	case a == keepassRoot:
		return -1
	case b == keepassRoot:
		return 1
	case a == templatesRoot:
		return -1
	case b == templatesRoot:
		return 1
	case a == recycleBinRoot:
		return 1
	case b == recycleBinRoot:
		return -1
	case a < b:
		return -1
	default:
		return 1
	}
}

// marshalEntries converts entries to KeePass entries sorted by title. The
// sort is stable, so entries with equal titles keep their model order. The
// entries slice is sorted in place.
func marshalEntries(db *gokeepasslib.Database, entries []Entry) []gokeepasslib.Entry {
	slices.SortStableFunc(entries, func(a, b Entry) int {
		return strings.Compare(a.Title, b.Title)
	})
	var out []gokeepasslib.Entry
	for _, entry := range entries {
		out = append(out, marshalEntry(db, entry))
	}
	return out
}

// marshalEntry converts a model entry into a KeePass entry.
//
// The password is stored as a protected value and the model ID is kept in the
// KeePassGO-ID string. Custom fields and attachments are emitted in key order;
// attachments are registered with db.AddBinary. History entries are written
// into a single History element and share the entry's UUID, since they are
// earlier versions of the same KeePass entry.
func marshalEntry(db *gokeepasslib.Database, entry Entry) gokeepasslib.Entry {
	item := gokeepasslib.NewEntry()
	item.UUID = uuidForEntryID(entry.ID)
	item.Tags = strings.Join(entry.Tags, "; ")
	item.Values = append(item.Values,
		value("Title", entry.Title),
		value("UserName", entry.Username),
		protectedValue("Password", entry.Password),
		value("URL", entry.URL),
		value("Notes", entry.Notes),
		value(keepassGOIDField, entry.ID),
	)
	for _, key := range slices.Sorted(maps.Keys(entry.Fields)) {
		item.Values = append(item.Values, value(key, entry.Fields[key]))
	}
	for _, name := range slices.Sorted(maps.Keys(entry.Attachments)) {
		binary := db.AddBinary(entry.Attachments[name])
		item.Binaries = append(item.Binaries, binary.CreateReference(name))
	}
	if len(entry.History) > 0 {
		history := gokeepasslib.History{
			Entries: make([]gokeepasslib.Entry, 0, len(entry.History)),
		}
		for _, historical := range entry.History {
			snapshot := marshalEntry(db, historical)
			snapshot.UUID = item.UUID
			history.Entries = append(history.Entries, snapshot)
		}
		item.Histories = append(item.Histories, history)
	}
	return item
}

// uuidForEntryID maps a model entry ID to a KeePass UUID:
//
//   - an ID that parses via UUID.UnmarshalText is used as is, so IDs that
//     came from a KeePass UUID round-trip unchanged;
//   - any other non-empty ID is hashed with SHA-256 and truncated, giving the
//     same UUID for the same ID on every save;
//   - an empty ID gets a fresh UUIDv7-style value (48-bit millisecond
//     timestamp followed by random bits), so ID-less entries never share a
//     UUID even when saved within the same instant.
func uuidForEntryID(id string) gokeepasslib.UUID {
	var uuid gokeepasslib.UUID
	if id == "" {
		copy(uuid[:], randomBytes(len(uuid)))
		ms := uint64(time.Now().UnixMilli())
		for i := range 6 {
			uuid[i] = byte(ms >> (40 - 8*i))
		}
		uuid[6] = uuid[6]&0x0f | 0x70 // version 7
		uuid[8] = uuid[8]&0x3f | 0x80 // RFC 9562 variant
		return uuid
	}
	if err := uuid.UnmarshalText([]byte(id)); err == nil {
		return uuid
	}
	sum := sha256.Sum256([]byte(id))
	copy(uuid[:], sum[:len(uuid)])
	return uuid
}

// value builds an unprotected string value.
func value(key, content string) gokeepasslib.ValueData {
	return gokeepasslib.ValueData{Key: key, Value: gokeepasslib.V{Content: content}}
}

// protectedValue builds a string value flagged as protected.
func protectedValue(key, content string) gokeepasslib.ValueData {
	return gokeepasslib.ValueData{
		Key:   key,
		Value: gokeepasslib.V{Content: content, Protected: w.NewBoolWrapper(true)},
	}
}

// extractAttachments returns entry's attachments keyed by reference name.
// References whose binary is missing or unreadable are skipped; nil is
// returned when nothing could be read.
func extractAttachments(db *gokeepasslib.Database, entry gokeepasslib.Entry) map[string][]byte {
	if len(entry.Binaries) == 0 {
		return nil
	}
	attachments := map[string][]byte{}
	for _, ref := range entry.Binaries {
		binary := db.FindBinary(ref.Value.ID)
		if binary == nil {
			continue
		}
		content, err := binary.GetContentBytes()
		if err != nil {
			continue
		}
		attachments[ref.Name] = slices.Clone(content)
	}
	if len(attachments) == 0 {
		return nil
	}
	return attachments
}