Use runtime-dir Unix sockets for local gRPC

This commit is contained in:
Joe Julian
2026-04-11 08:26:37 -07:00
parent c017308aa1
commit 2ef571c241
16 changed files with 346 additions and 29 deletions
+50 -3
View File
@@ -4,10 +4,13 @@ import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"git.julianfamily.org/keepassgo/internal/clipboard"
"git.julianfamily.org/keepassgo/internal/grpcaddr"
"git.julianfamily.org/keepassgo/internal/passwords"
"git.julianfamily.org/keepassgo/internal/session"
"git.julianfamily.org/keepassgo/internal/vault"
@@ -27,6 +30,7 @@ type Host struct {
lastModel vault.Model
started bool
listenAddr string
socketPath string
}
func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer, dirty DirtyProvider) (*Host, error) {
@@ -35,7 +39,11 @@ func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]pass
return nil, nil
}
listener, err := net.Listen("tcp", addr)
network, endpoint, err := grpcaddr.Parse(addr)
if err != nil {
return nil, err
}
listener, socketPath, err := listen(network, endpoint)
if err != nil {
return nil, fmt.Errorf("listen gRPC host %s: %w", addr, err)
}
@@ -50,7 +58,8 @@ func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]pass
listener: listener,
lifecycle: lifecycle,
dirty: dirty,
listenAddr: listener.Addr().String(),
listenAddr: formatListenAddress(network, listener.Addr().String(), socketPath),
socketPath: socketPath,
started: true,
}
if err := host.SyncFromLifecycle(); err != nil && !errors.Is(err, session.ErrLocked) {
@@ -91,7 +100,13 @@ func (h *Host) Stop() error {
}
h.started = false
h.grpcServer.Stop()
return h.listener.Close()
err := h.listener.Close()
if h.socketPath != "" {
if removeErr := os.Remove(h.socketPath); removeErr != nil && !errors.Is(removeErr, os.ErrNotExist) && err == nil {
err = removeErr
}
}
return err
}
func (h *Host) SyncFromLifecycle() error {
@@ -120,3 +135,35 @@ func (h *Host) SyncFromLifecycle() error {
h.server.SetSessionState(h.lastModel, locked, dirty)
return nil
}
func listen(network, endpoint string) (net.Listener, string, error) {
if network == "unix" {
if err := os.MkdirAll(filepath.Dir(endpoint), 0o700); err != nil {
return nil, "", err
}
if err := os.Remove(endpoint); err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, "", err
}
listener, err := net.Listen("unix", endpoint)
if err != nil {
return nil, "", err
}
if err := os.Chmod(endpoint, 0o600); err != nil {
_ = listener.Close()
return nil, "", err
}
return listener, endpoint, nil
}
listener, err := net.Listen(network, endpoint)
if err != nil {
return nil, "", err
}
return listener, "", nil
}
func formatListenAddress(network, listenerAddr, socketPath string) string {
if network == "unix" {
return "unix://" + socketPath
}
return listenerAddr
}