Use runtime-dir Unix sockets for local gRPC
This commit is contained in:
+50
-3
@@ -4,10 +4,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/clipboard"
|
||||
"git.julianfamily.org/keepassgo/internal/grpcaddr"
|
||||
"git.julianfamily.org/keepassgo/internal/passwords"
|
||||
"git.julianfamily.org/keepassgo/internal/session"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
@@ -27,6 +30,7 @@ type Host struct {
|
||||
lastModel vault.Model
|
||||
started bool
|
||||
listenAddr string
|
||||
socketPath string
|
||||
}
|
||||
|
||||
func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer, dirty DirtyProvider) (*Host, error) {
|
||||
@@ -35,7 +39,11 @@ func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]pass
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
network, endpoint, err := grpcaddr.Parse(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listener, socketPath, err := listen(network, endpoint)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen gRPC host %s: %w", addr, err)
|
||||
}
|
||||
@@ -50,7 +58,8 @@ func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]pass
|
||||
listener: listener,
|
||||
lifecycle: lifecycle,
|
||||
dirty: dirty,
|
||||
listenAddr: listener.Addr().String(),
|
||||
listenAddr: formatListenAddress(network, listener.Addr().String(), socketPath),
|
||||
socketPath: socketPath,
|
||||
started: true,
|
||||
}
|
||||
if err := host.SyncFromLifecycle(); err != nil && !errors.Is(err, session.ErrLocked) {
|
||||
@@ -91,7 +100,13 @@ func (h *Host) Stop() error {
|
||||
}
|
||||
h.started = false
|
||||
h.grpcServer.Stop()
|
||||
return h.listener.Close()
|
||||
err := h.listener.Close()
|
||||
if h.socketPath != "" {
|
||||
if removeErr := os.Remove(h.socketPath); removeErr != nil && !errors.Is(removeErr, os.ErrNotExist) && err == nil {
|
||||
err = removeErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Host) SyncFromLifecycle() error {
|
||||
@@ -120,3 +135,35 @@ func (h *Host) SyncFromLifecycle() error {
|
||||
h.server.SetSessionState(h.lastModel, locked, dirty)
|
||||
return nil
|
||||
}
|
||||
|
||||
func listen(network, endpoint string) (net.Listener, string, error) {
|
||||
if network == "unix" {
|
||||
if err := os.MkdirAll(filepath.Dir(endpoint), 0o700); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err := os.Remove(endpoint); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, "", err
|
||||
}
|
||||
listener, err := net.Listen("unix", endpoint)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err := os.Chmod(endpoint, 0o600); err != nil {
|
||||
_ = listener.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
return listener, endpoint, nil
|
||||
}
|
||||
listener, err := net.Listen(network, endpoint)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return listener, "", nil
|
||||
}
|
||||
|
||||
func formatListenAddress(network, listenerAddr, socketPath string) string {
|
||||
if network == "unix" {
|
||||
return "unix://" + socketPath
|
||||
}
|
||||
return listenerAddr
|
||||
}
|
||||
|
||||
@@ -2,10 +2,13 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/apitokens"
|
||||
"git.julianfamily.org/keepassgo/internal/grpcaddr"
|
||||
"git.julianfamily.org/keepassgo/internal/passwords"
|
||||
"git.julianfamily.org/keepassgo/internal/session"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
@@ -42,10 +45,14 @@ func TestStartHostServesVaultLifecycleAndSyncsSessionState(t *testing.T) {
|
||||
}
|
||||
defer func() { _ = host.Stop() }()
|
||||
|
||||
network, endpoint, err := grpcaddr.Parse(host.Address())
|
||||
if err != nil {
|
||||
t.Fatalf("Parse(host.Address()) error = %v", err)
|
||||
}
|
||||
conn, err := grpc.NewClient("passthrough:///"+host.Address(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
|
||||
return net.Dial("tcp", host.Address())
|
||||
return net.Dial(network, endpoint)
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -80,3 +87,37 @@ func TestStartHostServesVaultLifecycleAndSyncsSessionState(t *testing.T) {
|
||||
t.Fatal("GetSessionStatus().Locked = false, want true after lifecycle lock")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartHostServesOverUnixSocket(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketDir := t.TempDir()
|
||||
socketPath := socketDir + "/keepassgo.sock"
|
||||
lifecycle := &session.Manager{}
|
||||
if err := lifecycle.Create(vault.Model{
|
||||
Entries: []vault.Entry{
|
||||
testAPITokenEntry(t,
|
||||
apitokens.PolicyRule{Effect: apitokens.EffectAllow, Operation: apitokens.OperationManageVault, Resource: apitokens.Resource{Kind: apitokens.ResourceGroup, Path: []string{"Root"}}},
|
||||
),
|
||||
},
|
||||
}, vault.MasterKey{Password: "correct horse battery staple"}); err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
host, err := StartHost("unix://"+socketPath, lifecycle, passwords.DefaultProfiles(), nil, func() bool { return false })
|
||||
if err != nil {
|
||||
t.Fatalf("StartHost() error = %v", err)
|
||||
}
|
||||
if got := host.Address(); got != "unix://"+socketPath {
|
||||
t.Fatalf("host.Address() = %q, want %q", got, "unix://"+socketPath)
|
||||
}
|
||||
if _, err := os.Stat(socketPath); err != nil {
|
||||
t.Fatalf("Stat(socketPath) error = %v", err)
|
||||
}
|
||||
if err := host.Stop(); err != nil {
|
||||
t.Fatalf("Stop() error = %v", err)
|
||||
}
|
||||
if _, err := os.Stat(socketPath); !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("socket exists after Stop(), err = %v, want not-exist", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"git.julianfamily.org/keepassgo/internal/apiapproval"
|
||||
"git.julianfamily.org/keepassgo/internal/apitokens"
|
||||
"git.julianfamily.org/keepassgo/internal/appui/platform"
|
||||
"git.julianfamily.org/keepassgo/internal/grpcaddr"
|
||||
"git.julianfamily.org/keepassgo/internal/passwords"
|
||||
"git.julianfamily.org/keepassgo/internal/session"
|
||||
"git.julianfamily.org/keepassgo/internal/vault"
|
||||
@@ -56,10 +57,7 @@ func Main() {
|
||||
}
|
||||
|
||||
func defaultGRPCAddr(goos string) string {
|
||||
if strings.EqualFold(strings.TrimSpace(goos), "android") {
|
||||
return "off"
|
||||
}
|
||||
return "127.0.0.1:47777"
|
||||
return grpcaddr.Default(goos)
|
||||
}
|
||||
|
||||
func run(w *app.Window, mode string, paths statePaths, grpcAddr string) error {
|
||||
|
||||
@@ -2,22 +2,26 @@ package browserbridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/grpcaddr"
|
||||
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
NativeHostName = "org.keepassgo.browser"
|
||||
DefaultGRPCAddress = "127.0.0.1:47777"
|
||||
defaultFirefoxID = "browser@keepassgo.invalid"
|
||||
NativeHostName = "com.keepassgo.browser"
|
||||
defaultFirefoxID = "browser@keepassgo.com"
|
||||
maxNativeMessageSize = 1024 * 1024
|
||||
chromiumIDBytes = 16
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
@@ -136,7 +140,7 @@ func (r Request) Connection() (Connection, error) {
|
||||
BearerToken: strings.TrimSpace(r.BearerToken),
|
||||
}
|
||||
if conn.GRPCAddress == "" {
|
||||
conn.GRPCAddress = DefaultGRPCAddress
|
||||
conn.GRPCAddress = grpcaddr.Default(runtime.GOOS)
|
||||
}
|
||||
if conn.BearerToken == "" {
|
||||
return Connection{}, fmt.Errorf("browser bridge bearer token is required")
|
||||
@@ -277,6 +281,31 @@ func Manifest(browser Browser, binaryPath, extensionID string) (NativeHostManife
|
||||
}
|
||||
}
|
||||
|
||||
func ChromiumExtensionIDFromManifestKey(raw string) (string, error) {
|
||||
normalized := strings.TrimSpace(raw)
|
||||
normalized = strings.ReplaceAll(normalized, "-----BEGIN PUBLIC KEY-----", "")
|
||||
normalized = strings.ReplaceAll(normalized, "-----END PUBLIC KEY-----", "")
|
||||
normalized = strings.ReplaceAll(normalized, "\n", "")
|
||||
normalized = strings.ReplaceAll(normalized, "\r", "")
|
||||
normalized = strings.ReplaceAll(normalized, "\t", "")
|
||||
normalized = strings.ReplaceAll(normalized, " ", "")
|
||||
if normalized == "" {
|
||||
return "", fmt.Errorf("chromium extension key is required")
|
||||
}
|
||||
publicKeyDER, err := base64.StdEncoding.DecodeString(normalized)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode chromium extension key: %w", err)
|
||||
}
|
||||
hash := sha256.Sum256(publicKeyDER)
|
||||
var builder strings.Builder
|
||||
builder.Grow(chromiumIDBytes * 2)
|
||||
for _, b := range hash[:chromiumIDBytes] {
|
||||
builder.WriteByte('a' + ((b >> 4) & 0x0f))
|
||||
builder.WriteByte('a' + (b & 0x0f))
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func DefaultManifestPath(browser Browser) (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
|
||||
@@ -39,6 +41,9 @@ func TestReadRequestAndWriteResponse(t *testing.T) {
|
||||
if req.Action != "find-logins" || req.BearerToken != "secret" {
|
||||
t.Fatalf("ReadRequest() = %#v, want action and token preserved", req)
|
||||
}
|
||||
if conn, err := req.Connection(); err != nil || conn.GRPCAddress != "127.0.0.1:47777" {
|
||||
t.Fatalf("req.Connection() = (%#v, %v), want explicit tcp address preserved", conn, err)
|
||||
}
|
||||
|
||||
var output bytes.Buffer
|
||||
if err := WriteResponse(&output, Response{Success: true, Version: "1"}); err != nil {
|
||||
@@ -118,6 +123,22 @@ func TestHandleRequestRequiresBearerToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestConnectionDefaultsAddress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := Request{Action: "status", BearerToken: "secret"}
|
||||
conn, err := req.Connection()
|
||||
if err != nil {
|
||||
t.Fatalf("Connection() error = %v", err)
|
||||
}
|
||||
if conn.GRPCAddress == "" {
|
||||
t.Fatal("Connection().GRPCAddress = empty, want default address")
|
||||
}
|
||||
if runtime.GOOS != "windows" && !strings.HasPrefix(conn.GRPCAddress, "unix://") && conn.GRPCAddress != "off" {
|
||||
t.Fatalf("Connection().GRPCAddress = %q, want unix socket default on this platform", conn.GRPCAddress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallManifest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -147,6 +168,19 @@ func TestInstallManifest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChromiumExtensionIDFromManifestKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const publicKey = "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAMfW0u1k4K5A0uN2s0aH7uQKpM3x5Hf8mZfY1xVh0m7E2mJ7M8GiV4m0g0I2w9U9D1yqGQ6w8jzH5v8t7qB2RjMCAwEAAQ=="
|
||||
got, err := ChromiumExtensionIDFromManifestKey(publicKey)
|
||||
if err != nil {
|
||||
t.Fatalf("ChromiumExtensionIDFromManifestKey() error = %v", err)
|
||||
}
|
||||
if got != "okcdfigpojphpoecpglkkmkjmiaefmpd" {
|
||||
t.Fatalf("ChromiumExtensionIDFromManifestKey() = %q, want %q", got, "okcdfigpojphpoecpglkkmkjmiaefmpd")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClient struct {
|
||||
status *keepassgov1.GetSessionStatusResponse
|
||||
matches []*keepassgov1.BrowserLoginMatch
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"git.julianfamily.org/keepassgo/internal/grpcaddr"
|
||||
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@@ -18,20 +20,27 @@ type GRPCClient struct {
|
||||
|
||||
func Dial(ctx context.Context, conn Connection) (*grpc.ClientConn, *GRPCClient, context.Context, error) {
|
||||
if strings.TrimSpace(conn.GRPCAddress) == "" {
|
||||
conn.GRPCAddress = DefaultGRPCAddress
|
||||
conn.GRPCAddress = grpcaddr.Default(runtime.GOOS)
|
||||
}
|
||||
if strings.TrimSpace(conn.BearerToken) == "" {
|
||||
return nil, nil, nil, fmt.Errorf("browser bridge bearer token is required")
|
||||
}
|
||||
address := strings.TrimSpace(conn.GRPCAddress)
|
||||
grpcConn, err := grpc.NewClient("passthrough:///"+address,
|
||||
network, endpoint, err := grpcaddr.Parse(conn.GRPCAddress)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
target := endpoint
|
||||
if network == "unix" {
|
||||
target = "passthrough:///" + endpoint
|
||||
}
|
||||
grpcConn, err := grpc.NewClient(target,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
|
||||
return net.Dial("tcp", address)
|
||||
return net.Dial(network, endpoint)
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("dial gRPC host %s: %w", address, err)
|
||||
return nil, nil, nil, fmt.Errorf("dial gRPC host %s: %w", strings.TrimSpace(conn.GRPCAddress), err)
|
||||
}
|
||||
ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+strings.TrimSpace(conn.BearerToken))
|
||||
return grpcConn, &GRPCClient{client: keepassgov1.NewVaultServiceClient(grpcConn)}, ctx, nil
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package grpcaddr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const socketName = "keepassgo-grpc.sock"
|
||||
|
||||
func Default(goos string) string {
|
||||
if strings.EqualFold(strings.TrimSpace(goos), "android") {
|
||||
return "off"
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(goos), "windows") {
|
||||
return "127.0.0.1:47777"
|
||||
}
|
||||
return "unix://" + DefaultSocketPath()
|
||||
}
|
||||
|
||||
func DefaultSocketPath() string {
|
||||
return filepath.Join(runtimeDir(), "keepassgo", socketName)
|
||||
}
|
||||
|
||||
func runtimeDir() string {
|
||||
if dir := strings.TrimSpace(os.Getenv("XDG_RUNTIME_DIR")); dir != "" {
|
||||
return dir
|
||||
}
|
||||
if runtime.GOOS != "windows" {
|
||||
uid := strconv.Itoa(os.Getuid())
|
||||
runUserDir := filepath.Join("/run/user", uid)
|
||||
if info, err := os.Stat(runUserDir); err == nil && info.IsDir() {
|
||||
return runUserDir
|
||||
}
|
||||
}
|
||||
return filepath.Join(os.TempDir(), fmt.Sprintf("keepassgo-runtime-%d", os.Getuid()))
|
||||
}
|
||||
|
||||
func Parse(raw string) (network, endpoint string, err error) {
|
||||
value := strings.TrimSpace(raw)
|
||||
switch {
|
||||
case value == "":
|
||||
return "", "", fmt.Errorf("gRPC address is required")
|
||||
case strings.EqualFold(value, "off"):
|
||||
return "", "", nil
|
||||
case strings.HasPrefix(value, "unix://"):
|
||||
path := strings.TrimSpace(strings.TrimPrefix(value, "unix://"))
|
||||
if path == "" {
|
||||
return "", "", fmt.Errorf("unix gRPC socket path is required")
|
||||
}
|
||||
return "unix", path, nil
|
||||
case strings.HasPrefix(value, "tcp://"):
|
||||
addr := strings.TrimSpace(strings.TrimPrefix(value, "tcp://"))
|
||||
if addr == "" {
|
||||
return "", "", fmt.Errorf("tcp gRPC address is required")
|
||||
}
|
||||
return "tcp", addr, nil
|
||||
case strings.HasPrefix(value, "/"):
|
||||
return "unix", value, nil
|
||||
default:
|
||||
return "tcp", value, nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package grpcaddr
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultUsesUnixSocketOnUnixLikeSystems(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix default is not expected on windows")
|
||||
}
|
||||
t.Setenv("XDG_RUNTIME_DIR", "/tmp/keepassgo-runtime-test")
|
||||
|
||||
got := Default("linux")
|
||||
want := "unix:///tmp/keepassgo-runtime-test/keepassgo/keepassgo-grpc.sock"
|
||||
if got != want {
|
||||
t.Fatalf("Default() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantNetwork string
|
||||
wantEnd string
|
||||
}{
|
||||
{name: "unix scheme", input: "unix:///tmp/keepassgo.sock", wantNetwork: "unix", wantEnd: "/tmp/keepassgo.sock"},
|
||||
{name: "tcp scheme", input: "tcp://127.0.0.1:47777", wantNetwork: "tcp", wantEnd: "127.0.0.1:47777"},
|
||||
{name: "bare path", input: filepath.Clean("/tmp/keepassgo.sock"), wantNetwork: "unix", wantEnd: filepath.Clean("/tmp/keepassgo.sock")},
|
||||
{name: "bare tcp", input: "127.0.0.1:47777", wantNetwork: "tcp", wantEnd: "127.0.0.1:47777"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotNetwork, gotEnd, err := Parse(tt.input)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse() error = %v", err)
|
||||
}
|
||||
if gotNetwork != tt.wantNetwork || gotEnd != tt.wantEnd {
|
||||
t.Fatalf("Parse() = (%q, %q), want (%q, %q)", gotNetwork, gotEnd, tt.wantNetwork, tt.wantEnd)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user