Files
keepassgo/internal/api/host.go
T
2026-04-11 23:56:48 -07:00

173 lines
3.7 KiB
Go

package api
import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"git.julianfamily.org/keepassgo/internal/clipboard"
"git.julianfamily.org/keepassgo/internal/grpcaddr"
"git.julianfamily.org/keepassgo/internal/passwords"
"git.julianfamily.org/keepassgo/internal/session"
"git.julianfamily.org/keepassgo/internal/vault"
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
"google.golang.org/grpc"
)
// DirtyProvider is polled by Host.SyncFromLifecycle. Its result is forwarded
// to Server.SetSessionState as the session's dirty flag (presumably "the vault
// has unsaved changes" — confirm with the caller that supplies it). A nil
// DirtyProvider is treated as always returning false.
type DirtyProvider func() bool
// Host owns a gRPC server that exposes the VaultService API on one listener.
// It also mirrors the lifecycle backend's current model / lock state into that
// API server. Create one with StartHost. A nil *Host is what StartHost returns
// when hosting is disabled, and every method in this file is nil-safe.
type Host struct {
	server     *Server          // API implementation registered with grpcServer
	grpcServer *grpc.Server     // gRPC transport serving on listener
	listener   net.Listener     // TCP or unix-socket listener created by listen
	lifecycle  lifecycleBackend // source of the current model; reports session.ErrLocked when locked
	dirty      DirtyProvider    // optional; nil means the dirty flag is always false
	// mu serializes Stop and SyncFromLifecycle and guards started/lastModel.
	mu         sync.Mutex
	lastModel  vault.Model // last model read successfully from lifecycle; kept (not cleared) while locked
	started    bool        // true from StartHost until the first Stop
	listenAddr string      // value returned by Address (see formatListenAddress)
	socketPath string      // unix socket file removed by Stop; empty for non-unix networks
}
// StartHost starts a gRPC VaultService host listening on addr.
//
// Whitespace is trimmed from addr. An empty addr, or "off" in any case,
// disables hosting. In that case StartHost returns (nil, nil), and the nil
// *Host is safe to call methods on. Any other value is parsed by
// grpcaddr.Parse into a network and endpoint, and the listener is created by
// listen.
//
// Before serving, the API server's session state is seeded from lifecycle via
// SyncFromLifecycle. If that fails, the partially built host is torn down with
// Stop and the error is returned. Serving then runs on a background goroutine
// until Stop is called.
func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer, dirty DirtyProvider) (*Host, error) {
	addr = strings.TrimSpace(addr)
	if addr == "" || strings.EqualFold(addr, "off") {
		return nil, nil
	}
	network, endpoint, err := grpcaddr.Parse(addr)
	if err != nil {
		return nil, fmt.Errorf("parse gRPC host address %q: %w", addr, err)
	}
	listener, socketPath, err := listen(network, endpoint)
	if err != nil {
		return nil, fmt.Errorf("listen gRPC host %s: %w", addr, err)
	}
	service := NewServerWithLifecycle(vault.Model{}, profiles, clipboardWriter, lifecycle)
	grpcServer := grpc.NewServer()
	keepassgov1.RegisterVaultServiceServer(grpcServer, service)
	host := &Host{
		server:     service,
		grpcServer: grpcServer,
		listener:   listener,
		lifecycle:  lifecycle,
		dirty:      dirty,
		listenAddr: formatListenAddress(network, listener.Addr().String(), socketPath),
		socketPath: socketPath,
		started:    true,
	}
	if err := host.SyncFromLifecycle(); err != nil && !errors.Is(err, session.ErrLocked) {
		// Tear down through Stop so the gRPC server, listener and unix socket
		// file share a single cleanup path. The caller needs the sync error,
		// not any secondary shutdown error.
		_ = host.Stop()
		return nil, err
	}
	go func() {
		// Serve blocks until the server is stopped or the listener fails.
		// Its error is deliberately discarded: there is no caller left to
		// report it to once StartHost has returned.
		_ = grpcServer.Serve(listener)
	}()
	return host, nil
}
// Address reports where the host is listening. For unix sockets this is
// "unix://" followed by the socket path. Otherwise it is the listener's own
// address string. A nil (disabled) host reports "".
func (h *Host) Address() string {
	if h != nil {
		return h.listenAddr
	}
	return ""
}
// Server exposes the API server registered with the host's gRPC server.
// A nil *Host yields a nil *Server.
func (h *Host) Server() *Server {
	var srv *Server
	if h != nil {
		srv = h.server
	}
	return srv
}
// Stop shuts the host down. It stops the gRPC server, closes the listener and,
// for unix sockets, removes the socket file. Calling Stop on a nil Host, or
// calling it again after the first time, is a no-op that returns nil.
//
// Stop returns the first real failure it hits. A listener that is already
// closed (net.ErrClosed) and a socket file that is already gone
// (os.ErrNotExist) are not failures.
func (h *Host) Stop() error {
	if h == nil {
		return nil
	}
	h.mu.Lock()
	defer h.mu.Unlock()
	if !h.started {
		return nil
	}
	h.started = false

	h.grpcServer.Stop()

	var firstErr error
	// grpcServer.Stop may already have closed the listener; that is expected.
	if closeErr := h.listener.Close(); closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
		firstErr = closeErr
	}
	if h.socketPath == "" {
		return firstErr
	}
	removeErr := os.Remove(h.socketPath)
	if firstErr == nil && removeErr != nil && !errors.Is(removeErr, os.ErrNotExist) {
		firstErr = removeErr
	}
	return firstErr
}
// SyncFromLifecycle copies the lifecycle backend's current state into the API
// server via Server.SetSessionState.
//
// If the backend returns a model, it is cached as lastModel and forwarded with
// locked=false. If the backend returns session.ErrLocked, the previously cached
// model is forwarded with locked=true. Any other error is returned unchanged
// and the server is not updated. The dirty flag comes from the DirtyProvider,
// or is false when none was supplied; the provider is invoked while h.mu is
// held.
//
// It is a no-op returning nil on a nil Host or when the lifecycle or server is
// missing.
//
// NOTE(review): the last decrypted model stays in memory after the session
// locks; confirm SetSessionState hides it and that retaining it is intended.
func (h *Host) SyncFromLifecycle() error {
	if h == nil || h.lifecycle == nil || h.server == nil {
		return nil
	}
	h.mu.Lock()
	defer h.mu.Unlock()

	model, err := h.lifecycle.Current()
	if err != nil && !errors.Is(err, session.ErrLocked) {
		return err
	}
	locked := err != nil
	if !locked {
		h.lastModel = model
	}

	isDirty := h.dirty != nil && h.dirty()
	h.server.SetSessionState(h.lastModel, locked, isDirty)
	return nil
}
// listen opens a listener for a network/endpoint pair produced by
// grpcaddr.Parse.
//
// For the "unix" network it does four things, then returns the socket path so
// the caller can remove it on shutdown:
//   - creates the socket's parent directory (mode 0700 if it has to be created);
//   - removes a stale socket left at endpoint;
//   - listens on endpoint;
//   - restricts the socket file to mode 0600.
//
// For any other network it defers to net.Listen and returns an empty socket
// path.
//
// An existing entry at a unix endpoint is removed only when it is itself a
// socket. Anything else, such as a vault file named by a mistyped address, is
// left untouched and reported as an error instead of being deleted.
func listen(network, endpoint string) (net.Listener, string, error) {
	if network != "unix" {
		listener, err := net.Listen(network, endpoint)
		if err != nil {
			return nil, "", err
		}
		return listener, "", nil
	}

	if err := os.MkdirAll(filepath.Dir(endpoint), 0o700); err != nil {
		return nil, "", err
	}

	// Clear a stale socket (e.g. left behind by a crash) so the path can be
	// re-bound, but never delete a non-socket file or directory.
	info, err := os.Lstat(endpoint)
	switch {
	case errors.Is(err, os.ErrNotExist):
		// Nothing to clean up.
	case err != nil:
		return nil, "", err
	case info.Mode()&os.ModeSocket == 0:
		return nil, "", fmt.Errorf("refusing to replace %q: existing file is not a unix socket", endpoint)
	default:
		if err := os.Remove(endpoint); err != nil && !errors.Is(err, os.ErrNotExist) {
			return nil, "", err
		}
	}

	listener, err := net.Listen("unix", endpoint)
	if err != nil {
		return nil, "", err
	}
	// NOTE(review): between Listen and Chmod the socket carries umask-derived
	// permissions. Only a private parent directory protects that window, and
	// MkdirAll applies 0700 solely to directories it creates.
	if err := os.Chmod(endpoint, 0o600); err != nil {
		_ = listener.Close()
		return nil, "", err
	}
	return listener, endpoint, nil
}
// formatListenAddress builds the string reported by Host.Address. For the
// "unix" network it is a "unix://" URI built from socketPath. For every other
// network listenerAddr is returned unchanged.
//
// NOTE(review): a relative socketPath yields "unix://rel/path", which gRPC
// URI parsers read as an authority; confirm socket paths are always absolute.
func formatListenAddress(network, listenerAddr, socketPath string) string {
	switch network {
	case "unix":
		return "unix://" + socketPath
	default:
		return listenerAddr
	}
}