Host the gRPC API and add token admin views

This commit is contained in:
Joe Julian
2026-03-30 07:50:34 -07:00
parent 84c188129e
commit 9afddd7a93
9 changed files with 1175 additions and 23 deletions
+122
View File
@@ -0,0 +1,122 @@
package api
import (
"errors"
"fmt"
"net"
"strings"
"sync"
"git.julianfamily.org/keepassgo/clipboard"
"git.julianfamily.org/keepassgo/passwords"
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
"git.julianfamily.org/keepassgo/session"
"git.julianfamily.org/keepassgo/vault"
"google.golang.org/grpc"
)
type DirtyProvider func() bool
type Host struct {
server *Server
grpcServer *grpc.Server
listener net.Listener
lifecycle lifecycleBackend
dirty DirtyProvider
mu sync.Mutex
lastModel vault.Model
started bool
listenAddr string
}
func StartHost(addr string, lifecycle lifecycleBackend, profiles map[string]passwords.Profile, clipboardWriter clipboard.Writer, dirty DirtyProvider) (*Host, error) {
addr = strings.TrimSpace(addr)
if addr == "" || strings.EqualFold(addr, "off") {
return nil, nil
}
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, fmt.Errorf("listen gRPC host %s: %w", addr, err)
}
service := NewServerWithLifecycle(vault.Model{}, profiles, clipboardWriter, lifecycle)
server := grpc.NewServer(grpc.UnaryInterceptor(AuthInterceptor(service)))
keepassgov1.RegisterVaultServiceServer(server, service)
host := &Host{
server: service,
grpcServer: server,
listener: listener,
lifecycle: lifecycle,
dirty: dirty,
listenAddr: listener.Addr().String(),
started: true,
}
if err := host.SyncFromLifecycle(); err != nil && !errors.Is(err, session.ErrLocked) {
_ = listener.Close()
server.Stop()
return nil, err
}
go func() {
_ = server.Serve(listener)
}()
return host, nil
}
func (h *Host) Address() string {
if h == nil {
return ""
}
return h.listenAddr
}
func (h *Host) Server() *Server {
if h == nil {
return nil
}
return h.server
}
func (h *Host) Stop() error {
if h == nil {
return nil
}
h.mu.Lock()
defer h.mu.Unlock()
if !h.started {
return nil
}
h.started = false
h.grpcServer.Stop()
return h.listener.Close()
}
func (h *Host) SyncFromLifecycle() error {
if h == nil || h.lifecycle == nil || h.server == nil {
return nil
}
h.mu.Lock()
defer h.mu.Unlock()
model, err := h.lifecycle.Current()
locked := false
switch {
case err == nil:
h.lastModel = model
case errors.Is(err, session.ErrLocked):
locked = true
default:
return err
}
dirty := false
if h.dirty != nil {
dirty = h.dirty()
}
h.server.SetSessionState(h.lastModel, locked, dirty)
return nil
}
+72
View File
@@ -0,0 +1,72 @@
package api
import (
"context"
"net"
"testing"
"git.julianfamily.org/keepassgo/passwords"
keepassgov1 "git.julianfamily.org/keepassgo/proto/keepassgo/v1"
"git.julianfamily.org/keepassgo/session"
"git.julianfamily.org/keepassgo/vault"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
func TestStartHostServesVaultLifecycleAndSyncsSessionState(t *testing.T) {
t.Parallel()
lifecycle := &session.Manager{}
if err := lifecycle.Create(vault.Model{
Entries: []vault.Entry{
testAPITokenEntry(t),
{ID: "entry-1", Title: "Git Server", Path: []string{"Root", "Internet"}},
},
}, vault.MasterKey{Password: "correct horse battery staple"}); err != nil {
t.Fatalf("Create() error = %v", err)
}
host, err := StartHost("127.0.0.1:0", lifecycle, passwords.DefaultProfiles(), nil, func() bool { return true })
if err != nil {
t.Fatalf("StartHost() error = %v", err)
}
defer func() { _ = host.Stop() }()
conn, err := grpc.NewClient("passthrough:///"+host.Address(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return net.Dial("tcp", host.Address())
}),
)
if err != nil {
t.Fatalf("grpc.NewClient() error = %v", err)
}
defer func() { _ = conn.Close() }()
client := keepassgov1.NewVaultServiceClient(conn)
statusResp, err := client.GetSessionStatus(tokenContext(defaultTestTokenSecret), &keepassgov1.GetSessionStatusRequest{})
if err != nil {
t.Fatalf("GetSessionStatus() error = %v", err)
}
if statusResp.Locked {
t.Fatal("GetSessionStatus().Locked = true, want false")
}
if !statusResp.Dirty {
t.Fatal("GetSessionStatus().Dirty = false, want true from dirty provider")
}
if err := lifecycle.Lock(); err != nil {
t.Fatalf("Lock() error = %v", err)
}
if err := host.SyncFromLifecycle(); err != nil {
t.Fatalf("SyncFromLifecycle() after lock error = %v", err)
}
statusResp, err = client.GetSessionStatus(tokenContext(defaultTestTokenSecret), &keepassgov1.GetSessionStatusRequest{})
if err != nil {
t.Fatalf("GetSessionStatus() after lock error = %v", err)
}
if !statusResp.Locked {
t.Fatal("GetSessionStatus().Locked = false, want true after lifecycle lock")
}
}
+10 -3
View File
@@ -77,9 +77,16 @@ func (s *Server) AuditLog() *apiaudit.Log {
return s.audit
}
func (s *Server) ResolveApproval(id string, outcome apiapproval.Outcome) (apiapproval.Request, error) {
request, _, err := s.approvals.Resolve(id, outcome)
return request, err
func (s *Server) ResolveApproval(id string, outcome apiapproval.Outcome) (apiapproval.Request, *apitokens.PolicyRule, error) {
return s.approvals.Resolve(id, outcome)
}
func (s *Server) SetSessionState(model vault.Model, locked, dirty bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.model = model
s.locked = locked
s.dirty = dirty
}
func (s *Server) GetSessionStatus(_ context.Context, _ *keepassgov1.GetSessionStatusRequest) (*keepassgov1.GetSessionStatusResponse, error) {
+4 -4
View File
@@ -131,7 +131,7 @@ func TestVaultServicePromptsAndResumesWhenApproved(t *testing.T) {
if pending.Operation != apitokens.OperationListEntries {
t.Fatalf("pending.Operation = %q, want %q", pending.Operation, apitokens.OperationListEntries)
}
if _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeAllowOnce); err != nil {
if _, _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeAllowOnce); err != nil {
t.Fatalf("ResolveApproval(allow) error = %v", err)
}
@@ -171,7 +171,7 @@ func TestVaultServicePersistsPermanentDenyApproval(t *testing.T) {
}()
pending := waitForServerPendingApproval(t, service, 1)[0]
if _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeDenyPermanent); err != nil {
if _, _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeDenyPermanent); err != nil {
t.Fatalf("ResolveApproval(deny permanent) error = %v", err)
}
@@ -211,7 +211,7 @@ func TestVaultServiceReturnsCanceledForCanceledApproval(t *testing.T) {
}()
pending := waitForServerPendingApproval(t, service, 1)[0]
if _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeCancel); err != nil {
if _, _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeCancel); err != nil {
t.Fatalf("ResolveApproval(cancel) error = %v", err)
}
@@ -259,7 +259,7 @@ func TestVaultServiceRecordsApprovalAuditEvents(t *testing.T) {
}()
pending := waitForServerPendingApproval(t, service, 1)[0]
if _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeAllowPermanent); err != nil {
if _, _, err := service.ResolveApproval(pending.ID, apiapproval.OutcomeAllowPermanent); err != nil {
t.Fatalf("ResolveApproval(allow permanent) error = %v", err)
}
if err := <-errCh; err != nil {