Use runtime-dir Unix sockets for local gRPC
This commit is contained in:
+50
-3
@@ -4,10 +4,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/clipboard"
|
||||
"git.julianfamily.org/keepassgo/internal/grpcaddr"
|
||||
"git.julianfamily.org/keepassgo/internal/passwords"
|
||||
"git.julianfamily.org/keepassgo/internal/session"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
@@ -27,6 +30,7 @@ type Host struct {
|
||||
lastModel vault.Model
|
||||
started bool
|
||||
listenAddr string
|
||||
socketPath string
|
||||
}
|
||||
|
||||
func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer, dirty DirtyProvider) (*Host, error) {
|
||||
@@ -35,7 +39,11 @@ func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]pass
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
network, endpoint, err := grpcaddr.Parse(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listener, socketPath, err := listen(network, endpoint)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen gRPC host %s: %w", addr, err)
|
||||
}
|
||||
@@ -50,7 +58,8 @@ func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]pass
|
||||
listener: listener,
|
||||
lifecycle: lifecycle,
|
||||
dirty: dirty,
|
||||
listenAddr: listener.Addr().String(),
|
||||
listenAddr: formatListenAddress(network, listener.Addr().String(), socketPath),
|
||||
socketPath: socketPath,
|
||||
started: true,
|
||||
}
|
||||
if err := host.SyncFromLifecycle(); err != nil && !errors.Is(err, session.ErrLocked) {
|
||||
@@ -91,7 +100,13 @@ func (h *Host) Stop() error {
|
||||
}
|
||||
h.started = false
|
||||
h.grpcServer.Stop()
|
||||
return h.listener.Close()
|
||||
err := h.listener.Close()
|
||||
if h.socketPath != "" {
|
||||
if removeErr := os.Remove(h.socketPath); removeErr != nil && !errors.Is(removeErr, os.ErrNotExist) && err == nil {
|
||||
err = removeErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Host) SyncFromLifecycle() error {
|
||||
@@ -120,3 +135,35 @@ func (h *Host) SyncFromLifecycle() error {
|
||||
h.server.SetSessionState(h.lastModel, locked, dirty)
|
||||
return nil
|
||||
}
|
||||
|
||||
func listen(network, endpoint string) (net.Listener, string, error) {
|
||||
if network == "unix" {
|
||||
if err := os.MkdirAll(filepath.Dir(endpoint), 0o700); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err := os.Remove(endpoint); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, "", err
|
||||
}
|
||||
listener, err := net.Listen("unix", endpoint)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err := os.Chmod(endpoint, 0o600); err != nil {
|
||||
_ = listener.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
return listener, endpoint, nil
|
||||
}
|
||||
listener, err := net.Listen(network, endpoint)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return listener, "", nil
|
||||
}
|
||||
|
||||
func formatListenAddress(network, listenerAddr, socketPath string) string {
|
||||
if network == "unix" {
|
||||
return "unix://" + socketPath
|
||||
}
|
||||
return listenerAddr
|
||||
}
|
||||
|
||||
@@ -2,10 +2,13 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/apitokens"
|
||||
"git.julianfamily.org/keepassgo/internal/grpcaddr"
|
||||
"git.julianfamily.org/keepassgo/internal/passwords"
|
||||
"git.julianfamily.org/keepassgo/internal/session"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
@@ -42,10 +45,14 @@ func TestStartHostServesVaultLifecycleAndSyncsSessionState(t *testing.T) {
|
||||
}
|
||||
defer func() { _ = host.Stop() }()
|
||||
|
||||
network, endpoint, err := grpcaddr.Parse(host.Address())
|
||||
if err != nil {
|
||||
t.Fatalf("Parse(host.Address()) error = %v", err)
|
||||
}
|
||||
conn, err := grpc.NewClient("passthrough:///"+host.Address(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
|
||||
return net.Dial("tcp", host.Address())
|
||||
return net.Dial(network, endpoint)
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -80,3 +87,37 @@ func TestStartHostServesVaultLifecycleAndSyncsSessionState(t *testing.T) {
|
||||
t.Fatal("GetSessionStatus().Locked = false, want true after lifecycle lock")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartHostServesOverUnixSocket(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketDir := t.TempDir()
|
||||
socketPath := socketDir + "/keepassgo.sock"
|
||||
lifecycle := &session.Manager{}
|
||||
if err := lifecycle.Create(vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
testAPITokenEntry(t,
|
||||
apitokens.PolicyRule{Effect: apitokens.EffectAllow, Operation: apitokens.OperationManageVault, Resource: apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root"}}},
|
||||
),
|
||||
},
|
||||
}, vault.MasterKey{Password: "correct horse battery staple"}); err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
host, err := StartHost("unix://"+socketPath, lifecycle, passwords.DefaultProfiles(), nil, func() bool { return false })
|
||||
if err != nil {
|
||||
t.Fatalf("StartHost() error = %v", err)
|
||||
}
|
||||
if got := host.Address(); got != "unix://"+socketPath {
|
||||
t.Fatalf("host.Address() = %q, want %q", got, "unix://"+socketPath)
|
||||
}
|
||||
if _, err := os.Stat(socketPath); err != nil {
|
||||
t.Fatalf("Stat(socketPath) error = %v", err)
|
||||
}
|
||||
if err := host.Stop(); err != nil {
|
||||
t.Fatalf("Stop() error = %v", err)
|
||||
}
|
||||
if _, err := os.Stat(socketPath); !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("socket exists after Stop(), err = %v, want not-exist", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user