From f8dc6590e6283c12df61f82b93420a684c7a9e6e Mon Sep 17 00:00:00 2001 From: Mohammad Aziz Date: Mon, 27 Apr 2026 12:18:09 +0530 Subject: [PATCH] feat(wsclient): add Hostlink WebSocket client with signed upgrade, reconnect, and feature flag --- app/services/requestsigner/signer.go | 29 +- app/services/requestsigner/signer_test.go | 44 ++- app/services/wsclient/client.go | 262 ++++++++++++++ app/services/wsclient/client_test.go | 406 ++++++++++++++++++++++ app/services/wsclient/gorilla.go | 53 +++ config/appconf/appconf.go | 58 ++++ config/appconf/appconf_test.go | 61 ++++ go.mod | 1 + go.sum | 2 + main.go | 38 ++ ws_startup_test.go | 76 ++++ 11 files changed, 1022 insertions(+), 8 deletions(-) create mode 100644 app/services/wsclient/client.go create mode 100644 app/services/wsclient/client_test.go create mode 100644 app/services/wsclient/gorilla.go create mode 100644 ws_startup_test.go diff --git a/app/services/requestsigner/signer.go b/app/services/requestsigner/signer.go index b982700..42bc32e 100644 --- a/app/services/requestsigner/signer.go +++ b/app/services/requestsigner/signer.go @@ -65,23 +65,38 @@ func New2(privateKeyPath, agentStatePath string) (*RequestSigner, error) { } func (s *RequestSigner) SignRequest(req *http.Request) error { + headers, err := s.SignHeaders() + if err != nil { + return err + } + for key, values := range headers { + for _, value := range values { + req.Header.Add(key, value) + } + } + + return nil +} + +func (s *RequestSigner) SignHeaders() (http.Header, error) { timestamp := strconv.FormatInt(time.Now().Unix(), 10) nonceValue, err := s.generateNonce() if err != nil { - return fmt.Errorf("failed to generate nonce: %w", err) + return nil, fmt.Errorf("failed to generate nonce: %w", err) } signature, err := s.generateSignature(s.agentID, timestamp, nonceValue) if err != nil { - return fmt.Errorf("failed to generate signature: %w", err) + return nil, fmt.Errorf("failed to generate signature: %w", err) } - req.Header.Set("X-Agent-ID", s.agentID) - req.Header.Set("X-Timestamp", timestamp) - req.Header.Set("X-Nonce", nonceValue) - req.Header.Set("X-Signature", signature) + headers := http.Header{} + headers.Set("X-Agent-ID", s.agentID) + headers.Set("X-Timestamp", timestamp) + headers.Set("X-Nonce", nonceValue) + headers.Set("X-Signature", signature) - return nil + return headers, nil } func (s *RequestSigner) generateSignature(agentID, timestamp, nonce string) (string, error) { diff --git a/app/services/requestsigner/signer_test.go b/app/services/requestsigner/signer_test.go index 373ffb2..944b993 100644 --- a/app/services/requestsigner/signer_test.go +++ b/app/services/requestsigner/signer_test.go @@ -209,6 +209,48 @@ func TestRequestSigner_SignRequest(t *testing.T) { }) } +func TestRequestSigner_SignHeaders(t *testing.T) { + t.Run("should return required headers for websocket upgrade", func(t *testing.T) { + signer := setupTestSigner(t) + + headers, err := signer.SignHeaders() + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + requiredHeaders := []string{"X-Agent-ID", "X-Timestamp", "X-Nonce", "X-Signature"} + for _, header := range requiredHeaders { + if headers.Get(header) == "" { + t.Errorf("expected header %s to be set", header) + } + } + if headers.Get("X-Agent-ID") != "test-agent-123" { + t.Errorf("expected X-Agent-ID test-agent-123, got %q", headers.Get("X-Agent-ID")) + } + }) + + t.Run("should produce signature verifiable from returned headers", func(t *testing.T) { + signer := setupTestSigner(t) + + headers, err := signer.SignHeaders() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + signatureBytes, err := base64.StdEncoding.DecodeString(headers.Get("X-Signature")) + if err != nil { + t.Fatalf("expected base64 signature, got %v", err) + } + + message := fmt.Sprintf("%s|%s|%s", headers.Get("X-Agent-ID"), headers.Get("X-Timestamp"), headers.Get("X-Nonce")) + hashed := sha256.Sum256([]byte(message)) + if err := rsa.VerifyPSS(&signer.privateKey.PublicKey, crypto.SHA256, hashed[:], signatureBytes, nil); err != nil { + t.Errorf("signature verification failed: %v", err) + } + }) +} + func TestRequestSigner_GenerateSignature(t *testing.T) { t.Run("should generate valid RSA-PSS signature", func(t *testing.T) { signer := setupTestSigner(t) @@ -436,4 +478,4 @@ func saveTestPrivateKey(t *testing.T, dir string, key *rsa.PrivateKey) string { } return keyPath -} \ No newline at end of file +} diff --git a/app/services/wsclient/client.go b/app/services/wsclient/client.go new file mode 100644 index 0000000..45bc992 --- /dev/null +++ b/app/services/wsclient/client.go @@ -0,0 +1,262 @@ +package wsclient + +import ( + "context" + "errors" + "fmt" + "math/rand/v2" + "net/http" + "sync" + "time" + + "hostlink/app/services/agentstate" + "hostlink/app/services/requestsigner" + "hostlink/internal/wsprotocol" +) + +var ErrAgentNotRegistered = errors.New("agent not registered: missing agent ID") + +type Dialer interface { + Dial(ctx context.Context, url string, headers http.Header) (Conn, error) +} + +type Conn interface { + WriteEnvelope(ctx context.Context, env wsprotocol.Envelope) error + ReadEnvelope(ctx context.Context) (wsprotocol.Envelope, error) + Ping(ctx context.Context) error + Close() error +} + +type SleepFunc func(context.Context, time.Duration) error + +type Config struct { + URL string + AgentState *agentstate.AgentState + PrivateKeyPath string + Dialer Dialer + ReconnectMin time.Duration + ReconnectMax time.Duration + PingInterval time.Duration + SleepFunc SleepFunc +} + +type Client struct { + url string + agentID string + signer *requestsigner.RequestSigner + dialer Dialer + reconnectMin time.Duration + reconnectMax time.Duration + pingInterval time.Duration + sleep SleepFunc + + mu sync.RWMutex + active bool + lastAck *wsprotocol.AckPayload +} + +func New(cfg Config) (*Client, error) { + if cfg.AgentState == nil { + return nil, ErrAgentNotRegistered + } + agentID := cfg.AgentState.GetAgentID() + if agentID == "" { + return nil, ErrAgentNotRegistered + } + signer, err := requestsigner.New(cfg.PrivateKeyPath, agentID) + if err != nil { + return nil, fmt.Errorf("create request signer: %w", err) + } + if cfg.Dialer == nil { + cfg.Dialer = DefaultDialer{} + } + if cfg.ReconnectMin == 0 { + cfg.ReconnectMin = time.Second + } + if cfg.ReconnectMax == 0 { + cfg.ReconnectMax = 5 * time.Minute + } + if cfg.PingInterval == 0 { + cfg.PingInterval = 30 * time.Second + } + if cfg.SleepFunc == nil { + cfg.SleepFunc = sleepContext + } + + return &Client{ + url: cfg.URL, + agentID: agentID, + signer: signer, + dialer: cfg.Dialer, + reconnectMin: cfg.ReconnectMin, + reconnectMax: cfg.ReconnectMax, + pingInterval: cfg.PingInterval, + sleep: cfg.SleepFunc, + }, nil +} + +func (c *Client) Start(ctx context.Context) error { + backoff := c.reconnectMin + for { + if ctx.Err() != nil { + return nil + } + + err := c.runOnce(ctx) + c.setActive(false) + if ctx.Err() != nil { + return nil + } + if err == nil { + backoff = c.reconnectMin + continue + } + + delay := jitter(backoff) + if err := c.sleep(ctx, delay); err != nil { + return nil + } + backoff *= 2 + if backoff > c.reconnectMax { + backoff = c.reconnectMax + } + } +} + +func (c *Client) IsActive() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.active +} + +func (c *Client) LastAck() *wsprotocol.AckPayload { + c.mu.RLock() + defer c.mu.RUnlock() + if c.lastAck == nil { + return nil + } + ack := *c.lastAck + return &ack +} + +func (c *Client) runOnce(ctx context.Context) error { + headers, err := c.signer.SignHeaders() + if err != nil { + return err + } + conn, err := c.dialer.Dial(ctx, c.url, headers) + if err != nil { + return err + } + defer conn.Close() + + hello := c.buildHello() + if err := conn.WriteEnvelope(ctx, hello); err != nil { + return err + } + + readErr := make(chan error, 1) + go func() { readErr <- c.readLoop(ctx, conn, hello.MessageID) }() + + if c.pingInterval <= 0 { + return <-readErr + } + ticker := time.NewTicker(c.pingInterval) + defer ticker.Stop() + + for { + select { + case err := <-readErr: + return err + case <-ticker.C: + if err := conn.Ping(ctx); err != nil { + _ = conn.Close() + return err + } + case <-ctx.Done(): + return nil + } + } +} + +func (c *Client) readLoop(ctx context.Context, conn Conn, helloMessageID string) error { + for { + env, err := conn.ReadEnvelope(ctx) + if err != nil { + if ctx.Err() != nil { + return nil + } + return err + } + if err := env.Validate(c.agentID); err != nil { + return err + } + + switch env.Type { + case wsprotocol.TypeAgentHelloAck: + ack, err := wsprotocol.DecodePayload[wsprotocol.AckPayload](env) + if err != nil { + return err + } + if ack.AckedMessageID == helloMessageID { + c.setActive(true) + } + c.setLastAck(&ack) + case wsprotocol.TypeAck: + ack, err := wsprotocol.DecodePayload[wsprotocol.AckPayload](env) + if err != nil { + return err + } + c.setLastAck(&ack) + case wsprotocol.TypeError: + return fmt.Errorf("websocket protocol error: %s", env.MessageID) + default: + return fmt.Errorf("unsupported inbound websocket message type: %s", env.Type) + } + } +} + +func (c *Client) buildHello() wsprotocol.Envelope { + return wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), + Type: wsprotocol.TypeAgentHello, + AgentID: c.agentID, + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: map[string]any{}, + } +} + +func (c *Client) setActive(active bool) { + c.mu.Lock() + defer c.mu.Unlock() + c.active = active +} + +func (c *Client) setLastAck(ack *wsprotocol.AckPayload) { + c.mu.Lock() + defer c.mu.Unlock() + c.lastAck = ack +} + +func sleepContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func jitter(d time.Duration) time.Duration { + if d <= 0 { + return 0 + } + delta := d / 4 + if delta <= 0 { + return d + } + return d - delta + time.Duration(rand.Int64N(int64(delta*2))) +} diff --git a/app/services/wsclient/client_test.go b/app/services/wsclient/client_test.go new file mode 100644 index 0000000..b89b00e --- /dev/null +++ b/app/services/wsclient/client_test.go @@ -0,0 +1,406 @@ +package wsclient + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "net/http" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "hostlink/app/services/agentstate" + "hostlink/internal/wsprotocol" +) + +func TestClientSendsHelloAndMarksActiveAfterHelloAck(t *testing.T) { + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + written := conn.waitForWrite(t) + if written.Type != wsprotocol.TypeAgentHello { + t.Fatalf("written type = %q, want %q", written.Type, wsprotocol.TypeAgentHello) + } + if written.AgentID != "agent_ws_test" { + t.Fatalf("written agent_id = %q", written.AgentID) + } + if len(written.Payload) != 0 { + t.Fatalf("hello payload = %#v, want empty object", written.Payload) + } + + conn.readCh <- wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: "msg_ack", + Type: wsprotocol.TypeAgentHelloAck, + AgentID: "agent_ws_test", + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: payloadMap(t, wsprotocol.BuildAck(wsprotocol.AckOptions{ + AckedMessageID: written.MessageID, + AckedType: wsprotocol.TypeAgentHello, + })), + } + + waitFor(t, func() bool { return client.IsActive() }, "client to become active") + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestClientDialUsesSignedUpgradeHeaders(t *testing.T) { + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer) + + runCtx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + conn.waitForWrite(t) + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } + + for _, header := range []string{"X-Agent-ID", "X-Timestamp", "X-Nonce", "X-Signature"} { + if dialer.headers.Get(header) == "" { + t.Fatalf("expected signed upgrade header %s", header) + } + } + if dialer.headers.Get("X-Agent-ID") != "agent_ws_test" { + t.Fatalf("X-Agent-ID = %q", dialer.headers.Get("X-Agent-ID")) + } +} + +func TestClientHandlesAckWithoutTaskSideEffects(t *testing.T) { + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + conn.waitForWrite(t) + conn.readCh <- wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: "msg_generic_ack", + Type: wsprotocol.TypeAck, + AgentID: "agent_ws_test", + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: payloadMap(t, wsprotocol.BuildAck(wsprotocol.AckOptions{ + AckedMessageID: "msg_other", + AckedType: wsprotocol.TypeAck, + })), + } + + waitFor(t, func() bool { return client.LastAck() != nil }, "ack to be recorded") + if client.LastAck().AckedMessageID != "msg_other" { + t.Fatalf("last ack = %#v", client.LastAck()) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestClientErrorMessageTriggersReconnect(t *testing.T) { + first := newFakeConn() + second := newFakeConn() + dialer := &fakeDialer{conns: []*fakeConn{first, second}} + sleeps := make(chan time.Duration, 2) + client := newTestClient(t, dialer, WithSleepFunc(func(ctx context.Context, d time.Duration) error { + sleeps <- d + return nil + })) + + runCtx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + first.waitForWrite(t) + first.readCh <- wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: "msg_error", + Type: wsprotocol.TypeError, + AgentID: "agent_ws_test", + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: payloadMap(t, wsprotocol.BuildError(wsprotocol.ErrorOptions{ + Code: "expected_agent_hello", + Message: "first message must be agent.hello", + RelatedMessageID: "msg_hello", + })), + } + + select { + case <-sleeps: + case <-time.After(time.Second): + t.Fatal("expected reconnect backoff sleep") + } + second.waitForWrite(t) + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestNewFailsForMissingAgentState(t *testing.T) { + state := agentstate.New(t.TempDir()) + + _, err := New(Config{ + URL: "ws://example.test/api/v1/agents/ws", + AgentState: state, + PrivateKeyPath: saveTestPrivateKey(t, t.TempDir()), + }) + + if err == nil || !errors.Is(err, ErrAgentNotRegistered) { + t.Fatalf("New error = %v, want ErrAgentNotRegistered", err) + } +} + +func TestNewFailsForMissingPrivateKey(t *testing.T) { + state := agentstate.New(t.TempDir()) + if err := state.SetAgentID("agent_ws_test"); err != nil { + t.Fatalf("set agent ID: %v", err) + } + + _, err := New(Config{ + URL: "ws://example.test/api/v1/agents/ws", + AgentState: state, + PrivateKeyPath: filepath.Join(t.TempDir(), "missing.key"), + }) + + if err == nil { + t.Fatal("expected missing private key error") + } +} + +func TestReconnectsAfterServerClose(t *testing.T) { + first := newFakeConn() + second := newFakeConn() + dialer := &fakeDialer{conns: []*fakeConn{first, second}} + sleeps := make(chan time.Duration, 2) + client := newTestClient(t, dialer, WithSleepFunc(func(ctx context.Context, d time.Duration) error { + sleeps <- d + return nil + })) + + runCtx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + first.waitForWrite(t) + first.readErr <- errors.New("server closed") + select { + case <-sleeps: + case <-time.After(time.Second): + t.Fatal("expected reconnect backoff sleep") + } + second.waitForWrite(t) + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestReconnectsAfterPingFailure(t *testing.T) { + first := newFakeConn() + first.pingErr = errors.New("ping failed") + second := newFakeConn() + dialer := &fakeDialer{conns: []*fakeConn{first, second}} + client := newTestClient(t, dialer, + WithPingInterval(time.Millisecond), + WithSleepFunc(func(ctx context.Context, d time.Duration) error { return nil }), + ) + + runCtx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + first.waitForWrite(t) + second.waitForWrite(t) + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } + if !first.closed() { + t.Fatal("expected first connection to be closed after ping failure") + } +} + +type clientOption func(*Config) + +func newTestClient(t *testing.T, dialer Dialer, opts ...clientOption) *Client { + t.Helper() + state := agentstate.New(t.TempDir()) + if err := state.SetAgentID("agent_ws_test"); err != nil { + t.Fatalf("set agent ID: %v", err) + } + cfg := Config{ + URL: "ws://example.test/api/v1/agents/ws", + AgentState: state, + PrivateKeyPath: saveTestPrivateKey(t, t.TempDir()), + Dialer: dialer, + ReconnectMin: time.Millisecond, + ReconnectMax: 10 * time.Millisecond, + PingInterval: time.Hour, + } + for _, opt := range opts { + opt(&cfg) + } + client, err := New(cfg) + if err != nil { + t.Fatalf("New client: %v", err) + } + return client +} + +func WithSleepFunc(fn SleepFunc) clientOption { + return func(cfg *Config) { cfg.SleepFunc = fn } +} + +func WithPingInterval(d time.Duration) clientOption { + return func(cfg *Config) { cfg.PingInterval = d } +} + +type fakeDialer struct { + mu sync.Mutex + conn *fakeConn + conns []*fakeConn + headers http.Header + calls int +} + +func (d *fakeDialer) Dial(ctx context.Context, url string, headers http.Header) (Conn, error) { + d.mu.Lock() + defer d.mu.Unlock() + d.calls++ + d.headers = headers.Clone() + if len(d.conns) > 0 { + conn := d.conns[0] + d.conns = d.conns[1:] + return conn, nil + } + return d.conn, nil +} + +type fakeConn struct { + readCh chan wsprotocol.Envelope + readErr chan error + writeCh chan wsprotocol.Envelope + pingErr error + mu sync.Mutex + closedV bool +} + +func newFakeConn() *fakeConn { + return &fakeConn{ + readCh: make(chan wsprotocol.Envelope, 4), + readErr: make(chan error, 4), + writeCh: make(chan wsprotocol.Envelope, 4), + } +} + +func (c *fakeConn) WriteEnvelope(ctx context.Context, env wsprotocol.Envelope) error { + select { + case c.writeCh <- env: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (c *fakeConn) ReadEnvelope(ctx context.Context) (wsprotocol.Envelope, error) { + select { + case env := <-c.readCh: + return env, nil + case err := <-c.readErr: + return wsprotocol.Envelope{}, err + case <-ctx.Done(): + return wsprotocol.Envelope{}, ctx.Err() + } +} + +func (c *fakeConn) Ping(ctx context.Context) error { + return c.pingErr +} + +func (c *fakeConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closedV = true + return nil +} + +func (c *fakeConn) closed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closedV +} + +func (c *fakeConn) waitForWrite(t *testing.T) wsprotocol.Envelope { + t.Helper() + select { + case env := <-c.writeCh: + return env + case <-time.After(time.Second): + t.Fatal("timed out waiting for written envelope") + return wsprotocol.Envelope{} + } +} + +func payloadMap(t *testing.T, value any) map[string]any { + t.Helper() + data, err := json.Marshal(value) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + var payload map[string]any + if err := json.Unmarshal(data, &payload); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + return payload +} + +func waitFor(t *testing.T, check func() bool, description string) { + t.Helper() + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + if check() { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("timed out waiting for %s", description) +} + +func saveTestPrivateKey(t *testing.T, dir string) string { + t.Helper() + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + keyPath := filepath.Join(dir, "agent.key") + file, err := os.Create(keyPath) + if err != nil { + t.Fatalf("create key: %v", err) + } + defer file.Close() + if err := pem.Encode(file, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}); err != nil { + t.Fatalf("write key: %v", err) + } + return keyPath +} diff --git a/app/services/wsclient/gorilla.go b/app/services/wsclient/gorilla.go new file mode 100644 index 0000000..73db8c6 --- /dev/null +++ b/app/services/wsclient/gorilla.go @@ -0,0 +1,53 @@ +package wsclient + +import ( + "context" + "net/http" + "time" + + "hostlink/internal/wsprotocol" + + "github.com/gorilla/websocket" +) + +type DefaultDialer struct{} + +func (DefaultDialer) Dial(ctx context.Context, url string, headers http.Header) (Conn, error) { + conn, _, err := websocket.DefaultDialer.DialContext(ctx, url, headers) + if err != nil { + return nil, err + } + return &gorillaConn{conn: conn}, nil +} + +type gorillaConn struct { + conn *websocket.Conn +} + +func (c *gorillaConn) WriteEnvelope(ctx context.Context, env wsprotocol.Envelope) error { + if deadline, ok := ctx.Deadline(); ok { + _ = c.conn.SetWriteDeadline(deadline) + } + return c.conn.WriteJSON(env) +} + +func (c *gorillaConn) ReadEnvelope(ctx context.Context) (wsprotocol.Envelope, error) { + if deadline, ok := ctx.Deadline(); ok { + _ = c.conn.SetReadDeadline(deadline) + } + var env wsprotocol.Envelope + err := c.conn.ReadJSON(&env) + return env, err +} + +func (c *gorillaConn) Ping(ctx context.Context) error { + deadline := time.Now().Add(10 * time.Second) + if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) { + deadline = ctxDeadline + } + return c.conn.WriteControl(websocket.PingMessage, nil, deadline) +} + +func (c *gorillaConn) Close() error { + return c.conn.Close() +} diff --git a/config/appconf/appconf.go b/config/appconf/appconf.go index bc736ed..6eec133 100644 --- a/config/appconf/appconf.go +++ b/config/appconf/appconf.go @@ -2,7 +2,9 @@ package appconf import ( + "net/url" "os" + "path" "strings" "time" @@ -80,6 +82,62 @@ func SelfUpdateEnabled() bool { } } +// WebSocketEnabled returns whether the agent WebSocket client is enabled. +// Controlled by HOSTLINK_WS_ENABLED (default: false). +func WebSocketEnabled() bool { + v := strings.TrimSpace(os.Getenv("HOSTLINK_WS_ENABLED")) + if v == "" { + return false + } + switch strings.ToLower(v) { + case "true", "1", "yes": + return true + default: + return false + } +} + +// WebSocketURL returns the agent WebSocket gateway URL. +// Controlled by HOSTLINK_WS_URL, otherwise derived from SH_CONTROL_PLANE_URL. +func WebSocketURL() string { + if rawURL := strings.TrimSpace(os.Getenv("HOSTLINK_WS_URL")); rawURL != "" { + return rawURL + } + + parsed, err := url.Parse(ControlPlaneURL()) + if err != nil { + log.Warnf("invalid control plane URL %q, using as websocket URL base", ControlPlaneURL()) + return ControlPlaneURL() + "/api/v1/agents/ws" + } + + switch parsed.Scheme { + case "https": + parsed.Scheme = "wss" + case "http": + parsed.Scheme = "ws" + } + parsed.Path = path.Join(parsed.Path, "/api/v1/agents/ws") + return parsed.String() +} + +// WebSocketReconnectMin returns the initial WebSocket reconnect delay. +// Controlled by HOSTLINK_WS_RECONNECT_MIN (default: 1s, clamped to [100ms, 5m]). +func WebSocketReconnectMin() time.Duration { + return parseDurationClamped("HOSTLINK_WS_RECONNECT_MIN", time.Second, 100*time.Millisecond, 5*time.Minute) +} + +// WebSocketReconnectMax returns the maximum WebSocket reconnect delay. +// Controlled by HOSTLINK_WS_RECONNECT_MAX (default: 5m, clamped to [1s, 1h]). +func WebSocketReconnectMax() time.Duration { + return parseDurationClamped("HOSTLINK_WS_RECONNECT_MAX", 5*time.Minute, time.Second, time.Hour) +} + +// WebSocketPingInterval returns the WebSocket keepalive ping interval. +// Controlled by HOSTLINK_WS_PING_INTERVAL (default: 30s, clamped to [5s, 5m]). +func WebSocketPingInterval() time.Duration { + return parseDurationClamped("HOSTLINK_WS_PING_INTERVAL", 30*time.Second, 5*time.Second, 5*time.Minute) +} + // UpdateCheckInterval returns the interval between update checks. // Controlled by HOSTLINK_UPDATE_CHECK_INTERVAL (default: 5m, clamped to [1m, 24h]). func UpdateCheckInterval() time.Duration { diff --git a/config/appconf/appconf_test.go b/config/appconf/appconf_test.go index 9ed330b..030d3db 100644 --- a/config/appconf/appconf_test.go +++ b/config/appconf/appconf_test.go @@ -86,3 +86,64 @@ func TestInstallPath_CustomValue(t *testing.T) { t.Setenv("HOSTLINK_INSTALL_PATH", "/opt/hostlink/bin/hostlink") assert.Equal(t, "/opt/hostlink/bin/hostlink", InstallPath()) } + +func TestWebSocketEnabled_DefaultFalse(t *testing.T) { + t.Setenv("HOSTLINK_WS_ENABLED", "") + assert.False(t, WebSocketEnabled()) +} + +func TestWebSocketEnabled_ExplicitTrue(t *testing.T) { + t.Setenv("HOSTLINK_WS_ENABLED", "true") + assert.True(t, WebSocketEnabled()) +} + +func TestWebSocketEnabled_ExplicitFalse(t *testing.T) { + t.Setenv("HOSTLINK_WS_ENABLED", "0") + assert.False(t, WebSocketEnabled()) +} + +func TestWebSocketURL_DerivesWSSFromHTTPSControlPlane(t *testing.T) { + t.Setenv("HOSTLINK_WS_URL", "") + t.Setenv("SH_CONTROL_PLANE_URL", "https://api.selfhost.dev") + + assert.Equal(t, "wss://api.selfhost.dev/api/v1/agents/ws", WebSocketURL()) +} + +func TestWebSocketURL_DerivesWSFromHTTPControlPlane(t *testing.T) { + t.Setenv("HOSTLINK_WS_URL", "") + t.Setenv("SH_CONTROL_PLANE_URL", "http://localhost:3000") + + assert.Equal(t, "ws://localhost:3000/api/v1/agents/ws", WebSocketURL()) +} + +func TestWebSocketURL_OverrideWins(t *testing.T) { + t.Setenv("HOSTLINK_WS_URL", "ws://127.0.0.1:9090/custom") + t.Setenv("SH_CONTROL_PLANE_URL", "https://api.selfhost.dev") + + assert.Equal(t, "ws://127.0.0.1:9090/custom", WebSocketURL()) +} + +func TestWebSocketReconnectMin_Default1s(t *testing.T) { + t.Setenv("HOSTLINK_WS_RECONNECT_MIN", "") + assert.Equal(t, time.Second, WebSocketReconnectMin()) +} + +func TestWebSocketReconnectMin_CustomValue(t *testing.T) { + t.Setenv("HOSTLINK_WS_RECONNECT_MIN", "5s") + assert.Equal(t, 5*time.Second, WebSocketReconnectMin()) +} + +func TestWebSocketReconnectMax_Default5m(t *testing.T) { + t.Setenv("HOSTLINK_WS_RECONNECT_MAX", "") + assert.Equal(t, 5*time.Minute, WebSocketReconnectMax()) +} + +func TestWebSocketPingInterval_Default30s(t *testing.T) { + t.Setenv("HOSTLINK_WS_PING_INTERVAL", "") + assert.Equal(t, 30*time.Second, WebSocketPingInterval()) +} + +func TestWebSocketPingInterval_CustomValue(t *testing.T) { + t.Setenv("HOSTLINK_WS_PING_INTERVAL", "45s") + assert.Equal(t, 45*time.Second, WebSocketPingInterval()) +} diff --git a/go.mod b/go.mod index 11a6777..62551cc 100644 --- a/go.mod +++ b/go.mod @@ -50,6 +50,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect diff --git a/go.sum b/go.sum index 31d8438..9dae07d 100644 --- a/go.sum +++ b/go.sum @@ -71,6 +71,8 @@ github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17k github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= diff --git a/main.go b/main.go index 20708da..7cde99b 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ import ( "hostlink/app/services/updatecheck" "hostlink/app/services/updatedownload" "hostlink/app/services/updatepreflight" + "hostlink/app/services/wsclient" "hostlink/cmd/upgrade" "hostlink/config" "hostlink/config/appconf" @@ -258,6 +259,7 @@ func runServer(ctx context.Context, cmd *cli.Command) error { // Wait for registration to complete <-registeredChan log.Println("Agent registered, starting task job...") + startWebSocketClientIfEnabled(ctx, newDefaultWebSocketRuntime) fetcher, err := taskfetcher.NewDefault() if err != nil { @@ -297,6 +299,42 @@ func runServer(ctx context.Context, cmd *cli.Command) error { return e.Start(fmt.Sprintf(":%s", appconf.Port())) } +type webSocketRuntime interface { + Start(context.Context) error +} + +func startWebSocketClientIfEnabled(ctx context.Context, constructor func() (webSocketRuntime, error)) bool { + if !appconf.WebSocketEnabled() { + return false + } + runtime, err := constructor() + if err != nil { + log.Printf("failed to initialize websocket client: %v", err) + return false + } + go func() { + if err := runtime.Start(ctx); err != nil { + log.Printf("websocket client stopped with error: %v", err) + } + }() + return true +} + +func newDefaultWebSocketRuntime() (webSocketRuntime, error) { + state := agentstate.New(appconf.AgentStatePath()) + if err := state.Load(); err != nil { + return nil, fmt.Errorf("failed to load agent state: %w", err) + } + return wsclient.New(wsclient.Config{ + URL: appconf.WebSocketURL(), + AgentState: state, + PrivateKeyPath: appconf.AgentPrivateKeyPath(), + ReconnectMin: appconf.WebSocketReconnectMin(), + ReconnectMax: appconf.WebSocketReconnectMax(), + PingInterval: appconf.WebSocketPingInterval(), + }) +} + func startSelfUpdateJob(ctx context.Context) { paths := update.DefaultPaths() diff --git a/ws_startup_test.go b/ws_startup_test.go new file mode 100644 index 0000000..2e042f8 --- /dev/null +++ b/ws_startup_test.go @@ -0,0 +1,76 @@ +package main + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestStartWebSocketClientIfEnabled_DisabledNoops(t *testing.T) { + t.Setenv("HOSTLINK_WS_ENABLED", "") + called := false + + started := startWebSocketClientIfEnabled(context.Background(), func() (webSocketRuntime, error) { + called = true + return &fakeWebSocketRuntime{}, nil + }) + + if started { + t.Fatal("expected websocket startup to be skipped") + } + if called { + t.Fatal("expected constructor not to be called when websocket is disabled") + } +} + +func TestStartWebSocketClientIfEnabled_EnabledStartsAsync(t *testing.T) { + t.Setenv("HOSTLINK_WS_ENABLED", "true") + startedCh := make(chan struct{}) + releaseCh := make(chan struct{}) + runtime := &fakeWebSocketRuntime{startedCh: startedCh, releaseCh: releaseCh} + + started := startWebSocketClientIfEnabled(context.Background(), func() (webSocketRuntime, error) { + return runtime, nil + }) + + if !started { + t.Fatal("expected websocket startup to be attempted") + } + select { + case <-startedCh: + case <-time.After(time.Second): + t.Fatal("expected websocket client to start asynchronously") + } + close(releaseCh) +} + +func TestStartWebSocketClientIfEnabled_ConstructorFailureDoesNotStart(t *testing.T) { + t.Setenv("HOSTLINK_WS_ENABLED", "true") + + started := startWebSocketClientIfEnabled(context.Background(), func() (webSocketRuntime, error) { + return nil, errors.New("missing agent state") + }) + + if started { + t.Fatal("expected websocket startup to report not started after constructor failure") + } +} + +type fakeWebSocketRuntime struct { + startedCh chan struct{} + releaseCh chan struct{} +} + +func (f *fakeWebSocketRuntime) Start(ctx context.Context) error { + if f.startedCh != nil { + close(f.startedCh) + } + if f.releaseCh != nil { + select { + case <-f.releaseCh: + case <-ctx.Done(): + } + } + return nil +}