diff --git a/CODEX.md b/CODEX.md new file mode 100644 index 0000000..4ee686d --- /dev/null +++ b/CODEX.md @@ -0,0 +1,25 @@ +# CODEX.md — go-stream + +This repository keeps its working conventions in [CLAUDE.md](/workspace/CLAUDE.md). + +Read these two documents before changing code: + +```text +docs/RFC.md — go-stream implementation spec +docs/RFC-025-AGENT-EXPERIENCE.md — AX design principles +``` + +Key conventions: + +- Use `core.E(scope, message, cause)` for errors. +- Keep comments as concrete usage examples. +- Prefer predictable names over shorthand. +- Preserve the transport-agnostic public API and the `ws` compatibility surface. + +Commit convention: + +```text +type(scope): description + +Co-Authored-By: Virgil +``` diff --git a/adapter/redis/ax7_more_test.go b/adapter/redis/ax7_more_test.go new file mode 100644 index 0000000..18294a6 --- /dev/null +++ b/adapter/redis/ax7_more_test.go @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package redis + +import ( + "github.com/alicebob/miniredis/v2" + + core "dappco.re/go" + "dappco.re/go/stream" +) + +func ax7StartedBridge(t *core.T) (*Bridge, core.CancelFunc) { + redisServer := miniredis.RunT(t) + hub := stream.NewHub() + ctx, cancel := core.WithCancel(core.Background()) + go hub.Run(ctx) + + bridge, err := NewBridge(hub, Config{Addr: redisServer.Addr(), Prefix: "pool"}) + core.RequireNoError(t, err) + go func() { + if err := bridge.Start(ctx); err != nil { + t.Errorf("Start() error = %v", err) + } + }() + core.Sleep(100 * core.Millisecond) + return bridge, cancel +} + +func TestAX7_NewBridge_Good(t *core.T) { + redisServer := miniredis.RunT(t) + hub := stream.NewHub() + + bridge, err := NewBridge(hub, Config{Addr: redisServer.Addr()}) + core.AssertNoError(t, err) + core.AssertEqual(t, hub, bridge.hub) + core.AssertEqual(t, "stream", bridge.config.Prefix) +} + +func TestAX7_NewBridge_Bad(t *core.T) { + redisServer := miniredis.RunT(t) + + bridge, err := NewBridge(nil, Config{Addr: redisServer.Addr()}) + core.AssertError(t, err) + core.AssertNil(t, bridge) +} + +func TestAX7_NewBridge_Ugly(t *core.T) { + redisServer := miniredis.RunT(t) + hub := stream.NewHub() + + left, err := NewBridge(hub, Config{Addr: redisServer.Addr(), Prefix: "pool"}) + core.RequireNoError(t, err) + right, err := NewBridge(hub, Config{Addr: redisServer.Addr(), Prefix: "pool"}) + core.RequireNoError(t, err) + core.AssertNotEqual(t, left.SourceID(), right.SourceID()) +} + +func TestAX7_Bridge_SourceID_Good(t *core.T) { + redisServer := miniredis.RunT(t) + bridge, err := NewBridge(stream.NewHub(), Config{Addr: redisServer.Addr(), Prefix: "pool"}) + core.RequireNoError(t, err) + + core.AssertNotEmpty(t, bridge.SourceID()) + core.AssertEqual(t, 36, core.RuneCount(bridge.SourceID())) +} + +func TestAX7_Bridge_SourceID_Bad(t *core.T) { + var bridge *Bridge + + core.AssertEqual(t, "", bridge.SourceID()) + core.AssertNil(t, bridge) +} + +func TestAX7_Bridge_SourceID_Ugly(t *core.T) { + redisServer := miniredis.RunT(t) + bridge, err := NewBridge(stream.NewHub(), Config{Addr: redisServer.Addr(), Prefix: "pool"}) + core.RequireNoError(t, err) + + sourceID := bridge.SourceID() + core.AssertEqual(t, sourceID, bridge.SourceID()) + core.AssertNotEmpty(t, sourceID) +} + +func TestAX7_Bridge_Start_Good(t *core.T) { + bridge, cancel := ax7StartedBridge(t) + defer cancel() + + bridge.mutex.RLock() + running := bridge.running + bridge.mutex.RUnlock() + core.AssertTrue(t, running) +} + +func TestAX7_Bridge_Start_Bad(t *core.T) { + var bridge *Bridge + + err := bridge.Start(core.Background()) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "nil bridge") +} + +func TestAX7_Bridge_Start_Ugly(t *core.T) { + bridge, cancel := ax7StartedBridge(t) + defer cancel() + + err := bridge.Start(core.Background()) + core.AssertNoError(t, err) + core.AssertNotEmpty(t, bridge.SourceID()) +} + +func TestAX7_Bridge_Stop_Good(t *core.T) { + bridge, cancel := ax7StartedBridge(t) + defer cancel() + + core.AssertNoError(t, bridge.Stop()) + core.Sleep(50 * core.Millisecond) + bridge.mutex.RLock() + running := bridge.running + bridge.mutex.RUnlock() + core.AssertFalse(t, running) +} + +func TestAX7_Bridge_Stop_Bad(t *core.T) { + var bridge *Bridge + + core.AssertNoError(t, bridge.Stop()) + core.AssertNil(t, bridge) +} + +func TestAX7_Bridge_Stop_Ugly(t *core.T) { + redisServer := miniredis.RunT(t) + bridge, err := NewBridge(stream.NewHub(), Config{Addr: redisServer.Addr(), Prefix: "pool"}) + core.RequireNoError(t, err) + + core.AssertNoError(t, bridge.Stop()) + core.AssertNotEmpty(t, bridge.SourceID()) +} + +func TestAX7_Bridge_PublishToChannel_Good(t *core.T) { + redisServer := miniredis.RunT(t) + hub1 := stream.NewHub() + hub2 := stream.NewHub() + ctx, cancel := core.WithCancel(core.Background()) + defer cancel() + go hub1.Run(ctx) + go hub2.Run(ctx) + bridge1, err := NewBridge(hub1, Config{Addr: redisServer.Addr(), Prefix: "pool"}) + core.RequireNoError(t, err) + bridge2, err := NewBridge(hub2, Config{Addr: redisServer.Addr(), Prefix: "pool"}) + core.RequireNoError(t, err) + go func() { core.AssertNoError(t, bridge1.Start(ctx)) }() + go func() { core.AssertNoError(t, bridge2.Start(ctx)) }() + core.Sleep(100 * core.Millisecond) + + received := make(chan []byte, 1) + stop := hub2.Subscribe("block", func(frame []byte) { received <- append([]byte(nil), frame...) }) + defer stop() + core.AssertNoError(t, bridge1.PublishToChannel("block", []byte("template"))) + core.AssertEqual(t, "template", string(<-received)) +} + +func TestAX7_Bridge_PublishToChannel_Bad(t *core.T) { + redisServer := miniredis.RunT(t) + bridge, err := NewBridge(stream.NewHub(), Config{Addr: redisServer.Addr(), Prefix: "pool"}) + core.RequireNoError(t, err) + + err = bridge.PublishToChannel("block", []byte("template")) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "not started") +} + +func TestAX7_Bridge_PublishToChannel_Ugly(t *core.T) { + bridge, cancel := ax7StartedBridge(t) + defer cancel() + + err := bridge.PublishToChannel("", []byte("template")) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "empty channel") +} + +func TestAX7_Bridge_PublishBroadcast_Bad(t *core.T) { + redisServer := miniredis.RunT(t) + bridge, err := NewBridge(stream.NewHub(), Config{Addr: redisServer.Addr(), Prefix: "pool"}) + core.RequireNoError(t, err) + + err = bridge.PublishBroadcast([]byte("shutdown")) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "not started") +} + +func TestAX7_Bridge_PublishBroadcast_Ugly(t *core.T) { + var bridge *Bridge + + err := bridge.PublishBroadcast([]byte("shutdown")) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "nil bridge") +} diff --git a/adapter/redis/example_test.go b/adapter/redis/example_test.go new file mode 100644 index 0000000..2198db5 --- /dev/null +++ b/adapter/redis/example_test.go @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package redis_test + +import ( + "context" + + "dappco.re/go/stream" + "dappco.re/go/stream/adapter/redis" +) + +func ExampleNewBridge() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hub := stream.NewHub() + go hub.Run(ctx) + + bridge, err := redis.NewBridge(hub, redis.Config{ + Addr: "127.0.0.1:6379", + Prefix: "pool", + }) + if err != nil { + return + } + defer bridge.Stop() + + go func() { + _ = bridge.Start(ctx) + }() + + _ = bridge.PublishToChannel("block", []byte("template")) +} diff --git a/adapter/redis/redis.go b/adapter/redis/redis.go index f10586b..5115731 100644 --- a/adapter/redis/redis.go +++ b/adapter/redis/redis.go @@ -1,21 +1,29 @@ // SPDX-License-Identifier: EUPL-1.2 -// Package redis is the Redis pub/sub bridge for stream.Hub. -// Enables cross-instance coordination: multiple Hub instances on different nodes -// using the same Redis backend coordinate broadcasts and channel messages transparently. +// bridge, err := redis.NewBridge(hub, redis.Config{Addr: "redis:6379", Prefix: "pool"}) +// +// if err != nil { +// return err +// } +// +// go bridge.Start(ctx) package redis import ( "context" + "crypto/rand" "crypto/tls" - "strconv" + "encoding/hex" "sync" + "time" - "dappco.re/go/core" + "github.com/redis/go-redis/v9" + + "dappco.re/go" "dappco.re/go/stream" ) -// Config configures the Redis bridge. +// config := redis.Config{Addr: "127.0.0.1:6379", Prefix: "pool"} type Config struct { Addr string Password string @@ -24,179 +32,301 @@ type Config struct { TLSConfig *tls.Config } -// Bridge connects a Hub to Redis pub/sub for cross-instance messaging. +// bridge, err := redis.NewBridge(hub, redis.Config{Addr: "127.0.0.1:6379", Prefix: "pool"}) +// +// if err != nil { +// return err +// } +// +// go bridge.Start(ctx) +// defer bridge.Stop() type Bridge struct { hub *stream.Hub config Config sourceID string - mu sync.Mutex - running bool - stopCh chan struct{} + mutex sync.RWMutex + running bool + cancel context.CancelFunc + pubsub *redis.PubSub + client *redis.Client + publishStop func() + broadcastStop func() } -type bridgeRegistry struct { - mu sync.RWMutex - bridges map[string]map[*Bridge]struct{} +type envelope struct { + SourceID string `json:"s"` + Frame []byte `json:"f"` } -var registry = bridgeRegistry{bridges: map[string]map[*Bridge]struct{}{}} - -// NewBridge creates and validates the Redis connection. Does not start listening. -func NewBridge(hub *stream.Hub, cfg Config) (*Bridge, error) { +// bridge, err := redis.NewBridge(hub, config) +func NewBridge(hub *stream.Hub, config Config) (*Bridge, error) { if hub == nil { return nil, core.E("stream.redis", "nil hub", nil) } - if cfg.Addr == "" { + if config.Addr == "" { return nil, core.E("stream.redis", "empty address", nil) } - if cfg.Prefix == "" { - cfg.Prefix = "stream" + if config.Prefix == "" { + config.Prefix = "stream" } + client := newRedisClient(config) + defer client.Close() + + pingContext, pingCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pingCancel() + if err := client.Ping(pingContext).Err(); err != nil { + return nil, core.E("stream.redis", "redis ping failed", err) + } + return &Bridge{ hub: hub, - config: cfg, - sourceID: stream.NewPeer("redis").ID, - stopCh: make(chan struct{}), + config: config, + sourceID: randomSourceID(), }, nil } -// Start begins the Redis pub/sub listener. Blocks in a goroutine until Stop() or ctx cancel. -func (b *Bridge) Start(ctx context.Context) error { - if b == nil { +// go bridge.Start(ctx) +func (bridge *Bridge) Start(ctx context.Context) error { + if bridge == nil { return core.E("stream.redis", "nil bridge", nil) } - b.mu.Lock() - if b.running { - b.mu.Unlock() - <-ctx.Done() + if ctx == nil { + ctx = context.Background() + } + + bridge.mutex.Lock() + if bridge.running { + bridge.mutex.Unlock() return nil } - b.running = true - stopCh := b.stopCh - key := b.registryKey() - b.mu.Unlock() + bridge.running = true + bridge.mutex.Unlock() - registry.add(key, b) - defer registry.remove(key, b) + runContext, runCancel := context.WithCancel(ctx) + client := newRedisClient(bridge.config) + pubsub := client.PSubscribe(runContext, bridge.broadcastChannel(), bridge.channelPattern()) + publishStop := bridge.hub.SubscribePublished(func(channel string, frame []byte) { + if channel == "" { + return + } + if err := bridge.publishWithClient(client, bridge.channelKey(channel), frame); err != nil { + return + } + }) + broadcastStop := bridge.hub.SubscribeBroadcast(func(frame []byte) { + if err := bridge.publishWithClient(client, bridge.broadcastChannel(), frame); err != nil { + return + } + }) - select { - case <-ctx.Done(): - case <-stopCh: - } + bridge.mutex.Lock() + bridge.cancel = runCancel + bridge.client = client + bridge.pubsub = pubsub + bridge.publishStop = publishStop + bridge.broadcastStop = broadcastStop + bridge.mutex.Unlock() + + defer func() { + bridge.mutex.Lock() + publishStop := bridge.publishStop + broadcastStop := bridge.broadcastStop + bridge.running = false + bridge.cancel = nil + bridge.client = nil + bridge.pubsub = nil + bridge.publishStop = nil + bridge.broadcastStop = nil + bridge.mutex.Unlock() + if publishStop != nil { + publishStop() + } + if broadcastStop != nil { + broadcastStop() + } + runCancel() + if err := pubsub.Close(); err != nil { + return + } + if err := client.Close(); err != nil { + return + } + }() + + for { + message, err := pubsub.ReceiveMessage(runContext) + if err != nil { + if runContext.Err() != nil { + return nil + } + return err + } + + var decoded envelope + if !core.JSONUnmarshal([]byte(message.Payload), &decoded).OK { + continue + } + if decoded.SourceID == bridge.sourceID { + continue + } - b.mu.Lock() - b.running = false - b.mu.Unlock() - return nil + channel := bridge.channelFromRedis(message.Channel) + if channel == "" { + if err := bridge.hub.BroadcastFromBridge(decoded.Frame); err != nil { + return err + } + continue + } + if err := bridge.hub.PublishFromBridge(channel, decoded.Frame); err != nil { + return err + } + } } -// Stop cleanly shuts down the bridge. Closes the pub/sub subscription and Redis client. -func (b *Bridge) Stop() error { - if b == nil { +// defer bridge.Stop() +func (bridge *Bridge) Stop() error { + if bridge == nil { return nil } - b.mu.Lock() - if !b.running { - b.mu.Unlock() + + bridge.mutex.RLock() + running := bridge.running + cancel := bridge.cancel + pubsub := bridge.pubsub + client := bridge.client + publishStop := bridge.publishStop + broadcastStop := bridge.broadcastStop + bridge.mutex.RUnlock() + + if !running { return nil } - close(b.stopCh) - b.stopCh = make(chan struct{}) - b.mu.Unlock() - return nil + + if cancel != nil { + cancel() + } + if publishStop != nil { + publishStop() + } + if broadcastStop != nil { + broadcastStop() + } + var err error + if pubsub != nil { + err = pubsub.Close() + } + if client != nil { + if closeErr := client.Close(); closeErr != nil { + return core.ErrorJoin(err, closeErr) + } + } + return err } -// PublishToChannel publishes frame to a specific hub channel via Redis. -func (b *Bridge) PublishToChannel(channel string, frame []byte) error { - if b == nil { +// _ = bridge.PublishToChannel("block", templateBytes) +func (bridge *Bridge) PublishToChannel(channel string, frame []byte) error { + if bridge == nil { return core.E("stream.redis", "nil bridge", nil) } - if !b.isRunning() { - return core.E("stream.redis", "bridge not started", nil) + if channel == "" { + return core.E("stream.redis", "empty channel", nil) } - registry.publish(b.registryKey(), channel, envelope{ - SourceID: b.sourceID, - Frame: append([]byte(nil), frame...), - }) - return nil + + return bridge.publish(bridge.channelKey(channel), frame) } -// PublishBroadcast publishes frame as a broadcast via Redis. -func (b *Bridge) PublishBroadcast(frame []byte) error { - if b == nil { +// _ = bridge.PublishBroadcast(shutdownFrame) +func (bridge *Bridge) PublishBroadcast(frame []byte) error { + if bridge == nil { return core.E("stream.redis", "nil bridge", nil) } - if !b.isRunning() { - return core.E("stream.redis", "bridge not started", nil) - } - registry.publish(b.registryKey(), "", envelope{ - SourceID: b.sourceID, - Frame: append([]byte(nil), frame...), - }) - return nil + + return bridge.publish(bridge.broadcastChannel(), frame) } -// SourceID returns the random instance identifier. -func (b *Bridge) SourceID() string { - if b == nil { +// id := bridge.SourceID() +func (bridge *Bridge) SourceID() string { + if bridge == nil { return "" } - return b.sourceID + return bridge.sourceID +} + +func (bridge *Bridge) publish(channel string, frame []byte) error { + bridge.mutex.RLock() + running := bridge.running + client := bridge.client + bridge.mutex.RUnlock() + if !running { + return core.E("stream.redis", "bridge not started", nil) + } + if client == nil { + client = newRedisClient(bridge.config) + defer client.Close() + } + + return bridge.publishWithClient(client, channel, frame) } -func (b *Bridge) registryKey() string { - return b.config.Addr + "|" + strconv.Itoa(b.config.DB) + "|" + b.config.Prefix +func (bridge *Bridge) publishWithClient(client *redis.Client, channel string, frame []byte) error { + if client == nil { + return core.E("stream.redis", "nil redis client", nil) + } + + payload := envelope{ + SourceID: bridge.sourceID, + Frame: append([]byte(nil), frame...), + } + encoded := core.JSONMarshal(payload) + if !encoded.OK { + if err, ok := encoded.Value.(error); ok { + return err + } + return core.E("stream.redis", "failed to marshal envelope", nil) + } + + publishContext, publishCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer publishCancel() + return client.Publish(publishContext, channel, encoded.Value).Err() } -func (b *Bridge) isRunning() bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.running +func (bridge *Bridge) broadcastChannel() string { + return bridge.config.Prefix + ":broadcast" } -type envelope struct { - SourceID string `json:"s"` - Frame []byte `json:"f"` +func (bridge *Bridge) channelKey(channel string) string { + return bridge.config.Prefix + ":channel:" + channel } -func (r *bridgeRegistry) add(key string, bridge *Bridge) { - r.mu.Lock() - defer r.mu.Unlock() - if r.bridges[key] == nil { - r.bridges[key] = map[*Bridge]struct{}{} - } - r.bridges[key][bridge] = struct{}{} +func (bridge *Bridge) channelPattern() string { + return bridge.config.Prefix + ":channel:*" } -func (r *bridgeRegistry) remove(key string, bridge *Bridge) { - r.mu.Lock() - defer r.mu.Unlock() - if bridges := r.bridges[key]; bridges != nil { - delete(bridges, bridge) - if len(bridges) == 0 { - delete(r.bridges, key) - } +func (bridge *Bridge) channelFromRedis(channel string) string { + if channel == bridge.broadcastChannel() { + return "" } + return core.TrimPrefix(channel, bridge.config.Prefix+":channel:") } -func (r *bridgeRegistry) publish(key, channel string, message envelope) { - r.mu.RLock() - bridges := r.bridges[key] - targets := make([]*Bridge, 0, len(bridges)) - for bridge := range bridges { - targets = append(targets, bridge) - } - r.mu.RUnlock() +func newRedisClient(config Config) *redis.Client { + return redis.NewClient(&redis.Options{ + Addr: config.Addr, + Password: config.Password, + DB: config.DB, + TLSConfig: config.TLSConfig, + }) +} - for _, bridge := range targets { - if bridge == nil || bridge.sourceID == message.SourceID { - continue - } - if channel == "" { - _ = bridge.hub.Broadcast(message.Frame) - continue - } - _ = bridge.hub.Publish(channel, message.Frame) - } +func randomSourceID() string { + var raw [16]byte + _, _ = rand.Read(raw[:]) + raw[6] = (raw[6] & 0x0f) | 0x40 + raw[8] = (raw[8] & 0x3f) | 0x80 + return hex.EncodeToString(raw[:4]) + "-" + + hex.EncodeToString(raw[4:6]) + "-" + + hex.EncodeToString(raw[6:8]) + "-" + + hex.EncodeToString(raw[8:10]) + "-" + + hex.EncodeToString(raw[10:]) } diff --git a/adapter/redis/redis_test.go b/adapter/redis/redis_test.go new file mode 100644 index 0000000..f956aca --- /dev/null +++ b/adapter/redis/redis_test.go @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package redis + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + + "dappco.re/go/stream" +) + +func TestBridge_Publish_Good(t *testing.T) { + redisServer := miniredis.RunT(t) + + hub1 := stream.NewHub() + hub2 := stream.NewHub() + + hub1Context, hub1Cancel := context.WithCancel(context.Background()) + defer hub1Cancel() + hub2Context, hub2Cancel := context.WithCancel(context.Background()) + defer hub2Cancel() + + go hub1.Run(hub1Context) + go hub2.Run(hub2Context) + + bridge1, err := NewBridge(hub1, Config{Addr: redisServer.Addr(), Prefix: "pool"}) + if err != nil { + t.Fatalf("NewBridge(hub1) error = %v", err) + } + bridge2, err := NewBridge(hub2, Config{Addr: redisServer.Addr(), Prefix: "pool"}) + if err != nil { + t.Fatalf("NewBridge(hub2) error = %v", err) + } + + bridgeContext, bridgeCancel := context.WithCancel(context.Background()) + defer bridgeCancel() + go func() { _ = bridge1.Start(bridgeContext) }() + go func() { _ = bridge2.Start(bridgeContext) }() + time.Sleep(100 * time.Millisecond) + + received := make(chan []byte, 1) + unsubscribe := hub2.Subscribe("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + if err := hub1.Publish("block", []byte("template")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "template" { + t.Fatalf("received frame = %q, want %q", string(frame), "template") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for bridged frame") + } + + peer := stream.NewPeer("ws") + if err := hub2.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + defer hub2.RemovePeer(peer) + + if err := hub1.Broadcast([]byte("shutdown")); err != nil { + t.Fatalf("Broadcast() error = %v", err) + } + + select { + case frame := <-peer.SendQueue(): + if string(frame) != "shutdown" { + t.Fatalf("received frame = %q, want %q", string(frame), "shutdown") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for bridged broadcast") + } +} + +func TestBridge_Publish_Bad(t *testing.T) { + hub := stream.NewHub() + _, err := NewBridge(hub, Config{}) + if err == nil { + t.Fatal("NewBridge() error = nil, want empty address error") + } +} + +func TestBridge_Publish_BadBeforeStart(t *testing.T) { + redisServer := miniredis.RunT(t) + + hub := stream.NewHub() + bridge, err := NewBridge(hub, Config{Addr: redisServer.Addr(), Prefix: "pool"}) + if err != nil { + t.Fatalf("NewBridge() error = %v", err) + } + + if err := bridge.PublishToChannel("block", []byte("template")); err == nil { + t.Fatal("PublishToChannel() error = nil, want bridge not started error") + } + if err := bridge.PublishBroadcast([]byte("shutdown")); err == nil { + t.Fatal("PublishBroadcast() error = nil, want bridge not started error") + } +} + +func TestBridge_Publish_Ugly(t *testing.T) { + redisServer := miniredis.RunT(t) + + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + bridge, err := NewBridge(hub, Config{Addr: redisServer.Addr(), Prefix: "pool"}) + if err != nil { + t.Fatalf("NewBridge() error = %v", err) + } + + bridgeContext, bridgeCancel := context.WithCancel(context.Background()) + defer bridgeCancel() + go func() { _ = bridge.Start(bridgeContext) }() + time.Sleep(100 * time.Millisecond) + + received := make(chan []byte, 1) + unsubscribe := hub.Subscribe("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + if err := bridge.PublishToChannel("block", []byte("template")); err != nil { + t.Fatalf("PublishToChannel() error = %v", err) + } + + select { + case frame := <-received: + t.Fatalf("received unexpected self-echo frame = %q", string(frame)) + case <-time.After(200 * time.Millisecond): + } +} + +func TestAX7_Bridge_PublishBroadcast_Good(t *testing.T) { + redisServer := miniredis.RunT(t) + + hub1 := stream.NewHub() + hub2 := stream.NewHub() + + hub1Context, hub1Cancel := context.WithCancel(context.Background()) + defer hub1Cancel() + hub2Context, hub2Cancel := context.WithCancel(context.Background()) + defer hub2Cancel() + + go hub1.Run(hub1Context) + go hub2.Run(hub2Context) + + bridge1, err := NewBridge(hub1, Config{Addr: redisServer.Addr(), Prefix: "pool"}) + if err != nil { + t.Fatalf("NewBridge(hub1) error = %v", err) + } + bridge2, err := NewBridge(hub2, Config{Addr: redisServer.Addr(), Prefix: "pool"}) + if err != nil { + t.Fatalf("NewBridge(hub2) error = %v", err) + } + + bridgeContext, bridgeCancel := context.WithCancel(context.Background()) + defer bridgeCancel() + go func() { _ = bridge1.Start(bridgeContext) }() + go func() { _ = bridge2.Start(bridgeContext) }() + time.Sleep(100 * time.Millisecond) + + peer := stream.NewPeer("ws") + if err := hub2.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + defer hub2.RemovePeer(peer) + + if err := bridge1.PublishBroadcast([]byte("shutdown")); err != nil { + t.Fatalf("PublishBroadcast() error = %v", err) + } + + select { + case frame := <-peer.SendQueue(): + if string(frame) != "shutdown" { + t.Fatalf("received frame = %q, want %q", string(frame), "shutdown") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for bridged broadcast") + } +} diff --git a/adapter/sse/ax7_more_test.go b/adapter/sse/ax7_more_test.go new file mode 100644 index 0000000..961529d --- /dev/null +++ b/adapter/sse/ax7_more_test.go @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package sse + +import ( + core "dappco.re/go" + "dappco.re/go/stream" +) + +func TestAX7_New_Good(t *core.T) { + adapter := New(Config{HeartbeatInterval: core.Second, RetryMs: 99}) + + core.AssertNotNil(t, adapter) + core.AssertEqual(t, core.Second, adapter.config.HeartbeatInterval) + core.AssertEqual(t, 99, adapter.config.RetryMs) +} + +func TestAX7_New_Bad(t *core.T) { + adapter := New(Config{}) + + core.AssertNotNil(t, adapter) + core.AssertEqual(t, 15*core.Second, adapter.config.HeartbeatInterval) + core.AssertEqual(t, 3000, adapter.config.RetryMs) +} + +func TestAX7_New_Ugly(t *core.T) { + authenticator := stream.NewAPIKeyAuth(map[string]string{"sk": "user"}) + adapter := New(Config{Authenticator: authenticator}) + + core.AssertEqual(t, authenticator, adapter.config.Authenticator) + core.AssertEqual(t, 15*core.Second, adapter.config.HeartbeatInterval) + core.AssertEqual(t, 3000, adapter.config.RetryMs) +} + +func TestAX7_Adapter_Mount_Good(t *core.T) { + adapter := New(Config{}) + hub := stream.NewHub() + + adapter.Mount(hub) + core.AssertEqual(t, hub, adapter.hub) + core.AssertFalse(t, adapter.hub.Running()) +} + +func TestAX7_Adapter_Mount_Bad(t *core.T) { + adapter := New(Config{}) + + adapter.Mount(nil) + core.AssertNil(t, adapter.hub) + core.AssertNotNil(t, adapter.Handler()) +} + +func TestAX7_Adapter_Mount_Ugly(t *core.T) { + adapter := New(Config{}) + first := stream.NewHub() + second := stream.NewHub() + + adapter.Mount(first) + adapter.Mount(second) + core.AssertEqual(t, second, adapter.hub) +} + +func TestAX7_Adapter_ServeHTTP_Bad(t *core.T) { + adapter := New(Config{}) + recorder := core.NewHTTPTestRecorder() + request := core.NewHTTPTestRequest("GET", "/stream/events", nil) + + adapter.ServeHTTP(recorder, request) + core.AssertEqual(t, 500, recorder.Code) + core.AssertContains(t, recorder.Body.String(), "not mounted") +} + +func TestAX7_Adapter_ServeHTTP_Ugly(t *core.T) { + adapter := New(Config{}) + adapter.Mount(stream.NewHub()) + recorder := core.NewHTTPTestRecorder() + request := core.NewHTTPTestRequest("GET", "/stream/events?channel=events", nil) + + adapter.ServeHTTP(recorder, request) + core.AssertEqual(t, 500, recorder.Code) + core.AssertContains(t, recorder.Body.String(), "not running") +} + +func TestAX7_Adapter_HandlerForChannel_Bad(t *core.T) { + adapter := New(Config{}) + handler := adapter.HandlerForChannel("") + + recorder := core.NewHTTPTestRecorder() + handler.ServeHTTP(recorder, core.NewHTTPTestRequest("GET", "/stream/events", nil)) + core.AssertEqual(t, 500, recorder.Code) +} + +func TestAX7_Adapter_HandlerForChannel_Ugly(t *core.T) { + adapter := New(Config{}) + adapter.Mount(stream.NewHub()) + handler := adapter.HandlerForChannel("private") + + recorder := core.NewHTTPTestRecorder() + handler.ServeHTTP(recorder, core.NewHTTPTestRequest("GET", "/stream/events", nil)) + core.AssertContains(t, recorder.Body.String(), "not running") +} diff --git a/adapter/sse/example_test.go b/adapter/sse/example_test.go new file mode 100644 index 0000000..c015b67 --- /dev/null +++ b/adapter/sse/example_test.go @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package sse_test + +import ( + "context" + "net/http" + + "dappco.re/go/stream" + "dappco.re/go/stream/adapter/sse" +) + +func ExampleAdapter_HandlerForChannel() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hub := stream.NewHub() + go hub.Run(ctx) + + adapter := sse.New(sse.Config{}) + adapter.Mount(hub) + + http.Handle("/stream/hashrate", adapter.HandlerForChannel("hashrate")) +} diff --git a/adapter/sse/sse.go b/adapter/sse/sse.go index 8f69811..2dc3eb3 100644 --- a/adapter/sse/sse.go +++ b/adapter/sse/sse.go @@ -1,33 +1,49 @@ // SPDX-License-Identifier: EUPL-1.2 -// Package sse is the Server-Sent Events transport adapter for stream.Hub. -// Lightweight server-push over HTTP/1.1 - no upgrade required. -// Used by core/api for live stats, agent event streams, and /live_stats endpoints. +// adapter := sse.New(sse.Config{HeartbeatInterval: 15 * time.Second}) +// adapter.Mount(hub) +// http.Handle("/stream/events", adapter.Handler()) package sse import ( - "fmt" + "bytes" + "io" "net/http" + "strconv" + "sync" "time" - "dappco.re/go/core" "dappco.re/go/stream" ) -// Config configures the SSE adapter. +// config := sse.Config{ +// Authenticator: stream.NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}), +// HeartbeatInterval: 15 * time.Second, +// RetryMs: 3000, +// } type Config struct { - Authenticator stream.Authenticator + // sse.New(sse.Config{Authenticator: stream.NewAPIKeyAuth(keys)}) + Authenticator stream.Authenticator + + // sse.New(sse.Config{OnAuthFailure: func(r *http.Request, result stream.AuthResult) { ... }}) + OnAuthFailure func(r *http.Request, result stream.AuthResult) + + // sse.New(sse.Config{HeartbeatInterval: 15 * time.Second}) HeartbeatInterval time.Duration - RetryMs int + + // sse.New(sse.Config{RetryMs: 3000}) + RetryMs int } -// Adapter is the SSE transport adapter for a stream.Hub. +// adapter := sse.New(sse.Config{}) +// adapter.Mount(hub) +// http.Handle("/stream/events", adapter.Handler()) type Adapter struct { hub *stream.Hub config Config } -// New creates an SSE adapter. Call Mount before serving requests. +// adapter := sse.New(sse.Config{HeartbeatInterval: 15 * time.Second}) func New(config Config) *Adapter { if config.HeartbeatInterval == 0 { config.HeartbeatInterval = 15 * time.Second @@ -38,35 +54,51 @@ func New(config Config) *Adapter { return &Adapter{config: config} } -// Mount wires the adapter to a hub. Must be called before Handler(). -func (a *Adapter) Mount(hub *stream.Hub) { - a.hub = hub +// adapter.Mount(hub) +func (adapter *Adapter) Mount(hub *stream.Hub) { + adapter.hub = hub } -// Handler returns an http.HandlerFunc that accepts SSE connections. -func (a *Adapter) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - a.serve(w, r, r.URL.Query()["channel"]) - } +// http.Handle("/stream/events", adapter.Handler()) +// http.Get("http://127.0.0.1:8080/stream/events?channel=hashrate") +func (adapter *Adapter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + adapter.serve(w, r, r.URL.Query()["channel"]) } -// HandlerForChannel returns a handler that auto-subscribes all connections to channel. -func (a *Adapter) HandlerForChannel(channel string) http.HandlerFunc { +// http.Handle("/stream/events", adapter.Handler()) +// http.Get("http://127.0.0.1:8080/stream/events?channel=hashrate") +func (adapter *Adapter) Handler() http.HandlerFunc { + return adapter.ServeHTTP +} + +// http.Handle("/stream/hashrate", adapter.HandlerForChannel("hashrate")) +func (adapter *Adapter) HandlerForChannel(channel string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - a.serve(w, r, []string{channel}) + adapter.serve(w, r, []string{channel}) } } -func (a *Adapter) serve(w http.ResponseWriter, r *http.Request, channels []string) { - if a.hub == nil { +func (adapter *Adapter) serve(w http.ResponseWriter, r *http.Request, channels []string) { + if adapter.hub == nil { http.Error(w, "stream hub not mounted", http.StatusInternalServerError) return } - result := stream.AuthResult{Valid: true} - if a.config.Authenticator != nil { - result = a.config.Authenticator.Authenticate(r) - if !result.Valid { + config := adapter.config + if config.HeartbeatInterval == 0 { + config.HeartbeatInterval = 15 * time.Second + } + if config.RetryMs == 0 { + config.RetryMs = 3000 + } + + authResult := stream.AuthResult{Valid: true} + if adapter.config.Authenticator != nil { + authResult = adapter.config.Authenticator.Authenticate(r) + if !authResult.Valid { + if adapter.config.OnAuthFailure != nil { + adapter.config.OnAuthFailure(r, authResult) + } http.Error(w, "unauthorised", http.StatusUnauthorized) return } @@ -84,41 +116,85 @@ func (a *Adapter) serve(w http.ResponseWriter, r *http.Request, channels []strin header.Set("X-Accel-Buffering", "no") peer := stream.NewPeer("sse") - peer.UserID = result.UserID - peer.Claims = result.Claims - _ = a.hub.AddPeer(peer) - defer a.hub.RemovePeer(peer) + peer.UserID = authResult.UserID + if authResult.Claims != nil { + peer.Claims = authResult.Claims + } + done := make(chan struct{}) + var doneOnce sync.Once + peer.SetCloseHook(func() { + doneOnce.Do(func() { + close(done) + }) + }) + + for _, channel := range channels { + if channel == "" { + continue + } + if err := adapter.hub.CanSubscribePeer(peer, channel); err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + } + + if !adapter.hub.Running() { + http.Error(w, "stream hub not running", http.StatusInternalServerError) + return + } for _, channel := range channels { if channel == "" { continue } - _ = a.hub.SubscribePeer(peer, channel) + if err := adapter.hub.SubscribePeer(peer, channel); err != nil { + http.Error(w, "stream hub not running", http.StatusInternalServerError) + return + } } - _, _ = fmt.Fprintf(w, "retry: %d\n\n", a.config.RetryMs) + header.Set("Connection", "keep-alive") + + _, _ = io.WriteString(w, "retry: "+strconv.Itoa(config.RetryMs)+"\n\n") flusher.Flush() - ticker := time.NewTicker(a.config.HeartbeatInterval) + if err := adapter.hub.AddPeer(peer); err != nil { + return + } + defer adapter.hub.RemovePeer(peer) + + ticker := time.NewTicker(config.HeartbeatInterval) defer ticker.Stop() - done := r.Context().Done() + requestDone := r.Context().Done() for { select { case <-done: return + case <-requestDone: + return case frame, ok := <-peer.SendQueue(): if !ok { return } - _, _ = fmt.Fprintf(w, "data: %s\n\n", frame) + writeEventFrame(w, frame) flusher.Flush() case <-ticker.C: - _, _ = fmt.Fprint(w, ": ping\n\n") + writeHeartbeatFrame(w) flusher.Flush() } } } -var _ time.Duration -var _ = core.E +func writeEventFrame(writer io.Writer, frame []byte) { + for _, line := range bytes.Split(frame, []byte{'\n'}) { + _, _ = io.WriteString(writer, "data: ") + _, _ = writer.Write(line) + _, _ = io.WriteString(writer, "\n") + } + _, _ = io.WriteString(writer, "\n") +} + +func writeHeartbeatFrame(writer io.Writer) { + _, _ = io.WriteString(writer, ": ping\n\n") +} diff --git a/adapter/sse/sse_test.go b/adapter/sse/sse_test.go new file mode 100644 index 0000000..c43e48e --- /dev/null +++ b/adapter/sse/sse_test.go @@ -0,0 +1,344 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package sse + +import ( + "bufio" + "context" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "dappco.re/go" + "dappco.re/go/stream" +) + +func TestAX7_Adapter_Handler_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{HeartbeatInterval: 20 * time.Millisecond}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + response, err := http.Get(server.URL + "?channel=hashrate") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + defer response.Body.Close() + + waitForPeerCount(t, hub, 1) + if err := hub.Publish("hashrate", []byte("123456")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + reader := bufio.NewReader(response.Body) + for { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("ReadString() error = %v", err) + } + if core.Trim(line) == "data: 123456" { + return + } + } +} + +func TestAdapter_Handler_ZeroValueConfig_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := &Adapter{} + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + response, err := http.Get(server.URL + "?channel=hashrate") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + defer response.Body.Close() + + waitForPeerCount(t, hub, 1) + if err := hub.Publish("hashrate", []byte("123456")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + reader := bufio.NewReader(response.Body) + for { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("ReadString() error = %v", err) + } + if core.Trim(line) == "data: 123456" { + return + } + } +} + +func TestAdapter_Handler_MultilineFrame_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{HeartbeatInterval: 20 * time.Millisecond}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + response, err := http.Get(server.URL + "?channel=hashrate") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + defer response.Body.Close() + + waitForPeerCount(t, hub, 1) + if err := hub.Publish("hashrate", []byte("123\n456")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + reader := bufio.NewReader(response.Body) + lines := make([]string, 0, 4) + for len(lines) < 4 { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("ReadString() error = %v", err) + } + lines = append(lines, line) + } + + expected := []string{"retry: 3000\n", "\n", "data: 123\n", "data: 456\n"} + for index, line := range expected { + if lines[index] != line { + t.Fatalf("lines[%d] = %q, want %q", index, lines[index], line) + } + } +} + +func TestAX7_Adapter_Handler_Bad(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + var authFailureCount atomic.Int32 + adapter := New(Config{ + Authenticator: stream.NewAPIKeyAuth(map[string]string{"valid-key": "user-1"}), + OnAuthFailure: func(r *http.Request, result stream.AuthResult) { + authFailureCount.Add(1) + }, + }) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + response, err := http.Get(server.URL + "?channel=hashrate") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + defer response.Body.Close() + + if response.StatusCode != http.StatusUnauthorized { + t.Fatalf("StatusCode = %d, want %d", response.StatusCode, http.StatusUnauthorized) + } + if authFailureCount.Load() != 1 { + t.Fatalf("OnAuthFailure invoked %d times, want %d", authFailureCount.Load(), 1) + } +} + +func TestAdapter_Handler_HubNotRunning_Bad(t *testing.T) { + adapter := New(Config{}) + adapter.Mount(stream.NewHub()) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + response, err := http.Get(server.URL + "?channel=hashrate") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + defer response.Body.Close() + + if response.StatusCode != http.StatusInternalServerError { + t.Fatalf("StatusCode = %d, want %d", response.StatusCode, http.StatusInternalServerError) + } +} + +func TestAdapter_Handler_ChannelAuthoriser_Bad(t *testing.T) { + hub := stream.NewHubWithConfig(stream.HubConfig{ + ChannelAuthoriser: func(peer *stream.Peer, channel string) bool { + return channel == "public" + }, + }) + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + response, err := http.Get(server.URL + "?channel=private") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + defer response.Body.Close() + + if response.StatusCode != http.StatusForbidden { + t.Fatalf("StatusCode = %d, want %d", response.StatusCode, http.StatusForbidden) + } + waitForPeerCount(t, hub, 0) +} + +func TestAX7_Adapter_Handler_Ugly(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{HeartbeatInterval: 20 * time.Millisecond}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + requestContext, requestCancel := context.WithCancel(context.Background()) + request, err := http.NewRequestWithContext(requestContext, http.MethodGet, server.URL+"?channel=hashrate", nil) + if err != nil { + t.Fatalf("NewRequestWithContext() error = %v", err) + } + response, err := http.DefaultClient.Do(request) + if err != nil { + t.Fatalf("Do() error = %v", err) + } + + waitForPeerCount(t, hub, 1) + requestCancel() + _ = response.Body.Close() + + waitForPeerCount(t, hub, 0) +} + +func TestAX7_Adapter_ServeHTTP_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{HeartbeatInterval: 20 * time.Millisecond}) + adapter.Mount(hub) + + server := httptest.NewServer(adapter) + defer server.Close() + + response, err := http.Get(server.URL + "?channel=serve-http") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + defer response.Body.Close() + + waitForPeerCount(t, hub, 1) + if err := hub.Publish("serve-http", []byte("ok")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + reader := bufio.NewReader(response.Body) + for { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("ReadString() error = %v", err) + } + if core.Trim(line) == "data: ok" { + return + } + } +} + +func TestAX7_Adapter_HandlerForChannel_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{HeartbeatInterval: 20 * time.Millisecond}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.HandlerForChannel("hashrate"))) + defer server.Close() + + response, err := http.Get(server.URL) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + defer response.Body.Close() + + waitForPeerCount(t, hub, 1) + if err := hub.Publish("hashrate", []byte("654321")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + reader := bufio.NewReader(response.Body) + for { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("ReadString() error = %v", err) + } + if core.Trim(line) == "data: 654321" { + return + } + } +} + +func TestAdapter_Handler_RetryMs_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{RetryMs: 1234, HeartbeatInterval: time.Second}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + response, err := http.Get(server.URL + "?channel=hashrate") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + defer response.Body.Close() + + reader := bufio.NewReader(response.Body) + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("ReadString() error = %v", err) + } + if core.Trim(line) != "retry: 1234" { + t.Fatalf("first line = %q, want %q", core.Trim(line), "retry: 1234") + } +} + +func waitForPeerCount(t *testing.T, hub *stream.Hub, expected int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if hub.PeerCount() == expected { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("PeerCount() = %d, want %d", hub.PeerCount(), expected) +} diff --git a/adapter/tcp/ax7_more_test.go b/adapter/tcp/ax7_more_test.go new file mode 100644 index 0000000..c4e43f7 --- /dev/null +++ b/adapter/tcp/ax7_more_test.go @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package tcp + +import ( + core "dappco.re/go" + "dappco.re/go/stream" +) + +func ax7TCPHub(t *core.T) (*stream.Hub, core.Context, core.CancelFunc) { + hub := stream.NewHub() + ctx, cancel := core.WithCancel(core.Background()) + go hub.Run(ctx) + deadline := core.Now().Add(2 * core.Second) + for core.Now().Before(deadline) { + if hub.Running() { + return hub, ctx, cancel + } + core.Sleep(10 * core.Millisecond) + } + t.Fatal("timed out waiting for hub") + return nil, nil, nil +} + +func TestAX7_New_Good(t *core.T) { + adapter := New(Config{Addr: "127.0.0.1:0", HandshakeTimeout: core.Second}) + + core.AssertNotNil(t, adapter) + core.AssertEqual(t, "127.0.0.1:0", adapter.config.Addr) + core.AssertEqual(t, core.Second, adapter.config.HandshakeTimeout) +} + +func TestAX7_New_Bad(t *core.T) { + adapter := New(Config{}) + + core.AssertNotNil(t, adapter) + core.AssertEqual(t, "", adapter.config.Addr) + core.AssertEqual(t, 5*core.Second, adapter.config.HandshakeTimeout) +} + +func TestAX7_New_Ugly(t *core.T) { + adapter := New(Config{HandshakeChannel: "auth", HandshakeFrame: []byte("token")}) + + core.AssertEqual(t, "auth", adapter.config.HandshakeChannel) + core.AssertEqual(t, "token", string(adapter.config.HandshakeFrame)) + core.AssertEqual(t, 5*core.Second, adapter.config.HandshakeTimeout) +} + +func TestAX7_Adapter_Mount_Good(t *core.T) { + adapter := New(Config{}) + hub := stream.NewHub() + + adapter.Mount(hub) + core.AssertEqual(t, hub, adapter.hub) + core.AssertFalse(t, adapter.hub.Running()) +} + +func TestAX7_Adapter_Mount_Bad(t *core.T) { + adapter := New(Config{}) + + adapter.Mount(nil) + core.AssertNil(t, adapter.hub) + core.AssertNotNil(t, adapter) +} + +func TestAX7_Adapter_Mount_Ugly(t *core.T) { + adapter := New(Config{}) + first := stream.NewHub() + second := stream.NewHub() + + adapter.Mount(first) + adapter.Mount(second) + core.AssertEqual(t, second, adapter.hub) +} + +func TestAX7_Adapter_Listen_Good(t *core.T) { + hub, ctx, cancel := ax7TCPHub(t) + defer cancel() + adapter := New(Config{Addr: "127.0.0.1:0"}) + adapter.Mount(hub) + + go func() { core.AssertNoError(t, adapter.Listen(ctx)) }() + addr := waitForListenerAddress(t, adapter) + core.AssertNotEmpty(t, addr) +} + +func TestAX7_Adapter_Listen_Bad(t *core.T) { + var adapter *Adapter + + err := adapter.Listen(core.Background()) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "nil adapter") +} + +func TestAX7_Adapter_Listen_Ugly(t *core.T) { + adapter := New(Config{}) + adapter.Mount(stream.NewHub()) + + err := adapter.Listen(core.Background()) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "empty address") +} + +func TestAX7_Adapter_Dial_Good(t *core.T) { + hub, ctx, cancel := ax7TCPHub(t) + defer cancel() + server := New(Config{Addr: "127.0.0.1:0"}) + server.Mount(hub) + go func() { core.AssertNoError(t, server.Listen(ctx)) }() + addr := waitForListenerAddress(t, server) + + client := New(Config{Addr: addr, HandshakeChannel: "auth", HandshakeFrame: []byte("token")}) + peer, err := client.Dial(ctx, hub) + core.AssertNoError(t, err) + core.AssertNotNil(t, peer) + core.AssertEqual(t, "tcp", peer.Transport) +} + +func TestAX7_Adapter_Dial_Bad(t *core.T) { + var adapter *Adapter + + peer, err := adapter.Dial(core.Background(), nil) + core.AssertError(t, err) + core.AssertNil(t, peer) +} + +func TestAX7_Adapter_Dial_Ugly(t *core.T) { + adapter := New(Config{Addr: "127.0.0.1:1"}) + + peer, err := adapter.Dial(core.Background(), nil) + core.AssertError(t, err) + core.AssertNil(t, peer) +} + +func TestAX7_NewReconnectingTCP_Good(t *core.T) { + client := NewReconnectingTCP(ReconnectConfig{Addr: "127.0.0.1:9000"}) + + core.AssertNotNil(t, client) + core.AssertEqual(t, "127.0.0.1:9000", client.config.Addr) + core.AssertEqual(t, stream.StateDisconnected, client.State()) +} + +func TestAX7_NewReconnectingTCP_Bad(t *core.T) { + client := NewReconnectingTCP(ReconnectConfig{}) + + core.AssertEqual(t, core.Second, client.config.InitialBackoff) + core.AssertEqual(t, 30*core.Second, client.config.MaxBackoff) + core.AssertEqual(t, 2.0, client.config.BackoffMultiplier) +} + +func TestAX7_NewReconnectingTCP_Ugly(t *core.T) { + client := NewReconnectingTCP(ReconnectConfig{InitialBackoff: core.Millisecond, MaxBackoff: core.Second, BackoffMultiplier: 3}) + + core.AssertEqual(t, core.Millisecond, client.config.InitialBackoff) + core.AssertEqual(t, core.Second, client.config.MaxBackoff) + core.AssertEqual(t, 3.0, client.config.BackoffMultiplier) +} + +func TestAX7_ReconnectingTCP_Connect_Good(t *core.T) { + hub, ctx, cancel := ax7TCPHub(t) + defer cancel() + server := New(Config{Addr: "127.0.0.1:0"}) + server.Mount(hub) + go func() { core.AssertNoError(t, server.Listen(ctx)) }() + addr := waitForListenerAddress(t, server) + + client := NewReconnectingTCP(ReconnectConfig{Addr: addr}) + go func() { core.AssertNoError(t, client.Connect(ctx)) }() + deadline := core.Now().Add(2 * core.Second) + for core.Now().Before(deadline) { + if client.State() == stream.StateConnected { + core.AssertEqual(t, stream.StateConnected, client.State()) + return + } + core.Sleep(10 * core.Millisecond) + } + t.Fatal("timed out waiting for connected state") +} + +func TestAX7_ReconnectingTCP_Connect_Bad(t *core.T) { + var client *ReconnectingTCP + + err := client.Connect(core.Background()) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "nil reconnecting tcp") +} + +func TestAX7_ReconnectingTCP_Connect_Ugly(t *core.T) { + client := NewReconnectingTCP(ReconnectConfig{Addr: "127.0.0.1:1", MaxRetries: 1, InitialBackoff: core.Millisecond}) + + err := client.Connect(core.Background()) + core.AssertError(t, err) + core.AssertEqual(t, stream.StateDisconnected, client.State()) +} + +func TestAX7_ReconnectingTCP_Send_Good(t *core.T) { + left, right := core.NetPipe() + defer left.Close() + defer right.Close() + client := NewReconnectingTCP(ReconnectConfig{}) + client.setConn(left) + done := make(chan error, 1) + + go func() { done <- client.Send("block", []byte("template")) }() + channel, frame, err := readTCPFrame(right, 0, MaxFrameSize) + core.AssertNoError(t, err) + core.AssertEqual(t, "block", channel) + core.AssertEqual(t, "template", string(frame)) + core.AssertNoError(t, <-done) +} + +func TestAX7_ReconnectingTCP_Send_Bad(t *core.T) { + var client *ReconnectingTCP + + err := client.Send("block", []byte("template")) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "nil reconnecting tcp") +} + +func TestAX7_ReconnectingTCP_Send_Ugly(t *core.T) { + client := NewReconnectingTCP(ReconnectConfig{}) + + err := client.Send("block", []byte("template")) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "not connected") +} + +func TestAX7_ReconnectingTCP_State_Bad(t *core.T) { + var client *ReconnectingTCP + + core.AssertEqual(t, stream.StateDisconnected, client.State()) + core.AssertNil(t, client) +} + +func TestAX7_ReconnectingTCP_State_Ugly(t *core.T) { + client := NewReconnectingTCP(ReconnectConfig{}) + client.setState(stream.StateConnecting) + + core.AssertEqual(t, stream.StateConnecting, client.State()) + client.setState(stream.StateDisconnected) + core.AssertEqual(t, stream.StateDisconnected, client.State()) +} + +func TestAX7_ReconnectingTCP_Close_Good(t *core.T) { + left, right := core.NetPipe() + defer right.Close() + client := NewReconnectingTCP(ReconnectConfig{}) + client.setConn(left) + + core.AssertNoError(t, client.Close()) + core.AssertEqual(t, stream.StateDisconnected, client.State()) + core.AssertTrue(t, client.closed) +} + +func TestAX7_ReconnectingTCP_Close_Bad(t *core.T) { + var client *ReconnectingTCP + + core.AssertNoError(t, client.Close()) + core.AssertNil(t, client) +} + +func TestAX7_ReconnectingTCP_Close_Ugly(t *core.T) { + client := NewReconnectingTCP(ReconnectConfig{}) + + core.AssertNoError(t, client.Close()) + core.AssertNoError(t, client.Close()) + core.AssertEqual(t, stream.StateDisconnected, client.State()) +} diff --git a/adapter/tcp/example_test.go b/adapter/tcp/example_test.go new file mode 100644 index 0000000..2b81f33 --- /dev/null +++ b/adapter/tcp/example_test.go @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package tcp_test + +import ( + "context" + + "dappco.re/go/stream" + "dappco.re/go/stream/adapter/tcp" +) + +func ExampleAdapter_Listen() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hub := stream.NewHub() + go hub.Run(ctx) + + adapter := tcp.New(tcp.Config{ + Addr: ":9000", + ConnAuthenticator: stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { + if string(handshake) != "trusted" { + return stream.AuthResult{Valid: false} + } + return stream.AuthResult{Valid: true, UserID: "peer-1"} + }), + }) + adapter.Mount(hub) + + go func() { + _ = adapter.Listen(ctx) + }() +} diff --git a/adapter/tcp/reconnect.go b/adapter/tcp/reconnect.go index 9a6813c..950ed8a 100644 --- a/adapter/tcp/reconnect.go +++ b/adapter/tcp/reconnect.go @@ -5,17 +5,30 @@ package tcp import ( "context" "crypto/tls" - "errors" "net" "sync" "time" - "dappco.re/go/core" + "dappco.re/go" + "dappco.re/go/stream" ) -// ReconnectConfig configures the client-side reconnecting TCP connection. +// config := tcp.ReconnectConfig{ +// Addr: "127.0.0.1:9000", +// OnReconnect: func(attempt int) { +// core.Print(nil, "tcp reconnect attempt=%d", attempt) +// }, +// OnMessage: func(channel string, frame []byte) { +// _ = channel +// _ = frame +// }, +// } +// +// client := tcp.NewReconnectingTCP(config) type ReconnectConfig struct { Addr string + HandshakeFrame []byte + HandshakeChannel string InitialBackoff time.Duration MaxBackoff time.Duration BackoffMultiplier float64 @@ -23,19 +36,22 @@ type ReconnectConfig struct { TLS *tls.Config OnConnect func() OnDisconnect func() + OnReconnect func(attempt int) OnMessage func(channel string, frame []byte) } -// ReconnectingTCP connects to a TCP stream endpoint with automatic reconnection. +// client := tcp.NewReconnectingTCP(tcp.ReconnectConfig{Addr: "10.69.69.165:9000"}) type ReconnectingTCP struct { config ReconnectConfig - mu sync.RWMutex - conn net.Conn - closed bool + mutex sync.RWMutex + writeMutex sync.Mutex + conn net.Conn + state stream.ConnectionState + closed bool } -// NewReconnectingTCP creates a reconnecting TCP client. +// client := tcp.NewReconnectingTCP(tcp.ReconnectConfig{Addr: "10.69.69.165:9000"}) func NewReconnectingTCP(config ReconnectConfig) *ReconnectingTCP { if config.InitialBackoff == 0 { config.InitialBackoff = time.Second @@ -46,52 +62,89 @@ func NewReconnectingTCP(config ReconnectConfig) *ReconnectingTCP { if config.BackoffMultiplier <= 0 { config.BackoffMultiplier = 2 } - return &ReconnectingTCP{config: config} + return &ReconnectingTCP{ + config: config, + state: stream.StateDisconnected, + } } -// Connect starts the connection loop. Blocks until ctx is cancelled. -func (rc *ReconnectingTCP) Connect(ctx context.Context) error { - if rc == nil { - return errors.New("nil reconnecting tcp") +// err := client.Connect(ctx) +func (client *ReconnectingTCP) Connect(ctx context.Context) error { + if client == nil { + return core.E("stream.tcp", "nil reconnecting tcp", nil) } if ctx == nil { ctx = context.Background() } - backoff := rc.config.InitialBackoff + backoff := client.config.InitialBackoff attempt := 0 for { - if rc.isClosed() || ctx.Err() != nil { + if client.isClosed() || ctx.Err() != nil { return nil } - conn, err := rc.dial(ctx) + client.setState(stream.StateConnecting) + conn, err := client.dial(ctx) if err != nil { attempt++ - if rc.config.MaxRetries > 0 && attempt > rc.config.MaxRetries { + client.setState(stream.StateDisconnected) + if client.config.MaxRetries > 0 && attempt > client.config.MaxRetries { return err } + if client.config.OnReconnect != nil { + client.config.OnReconnect(attempt) + } if err := sleepContext(ctx, backoff); err != nil { return err } - backoff = nextTCPBackoff(backoff, rc.config.BackoffMultiplier, rc.config.MaxBackoff) + backoff = nextTCPBackoff(backoff, client.config.BackoffMultiplier, client.config.MaxBackoff) + continue + } + if err := client.writeHandshake(conn); err != nil { + if closeErr := conn.Close(); closeErr != nil { + err = core.ErrorJoin(err, closeErr) + } + attempt++ + client.setState(stream.StateDisconnected) + if client.config.MaxRetries > 0 && attempt > client.config.MaxRetries { + return err + } + if client.config.OnReconnect != nil { + client.config.OnReconnect(attempt) + } + if err := sleepContext(ctx, backoff); err != nil { + return err + } + backoff = nextTCPBackoff(backoff, client.config.BackoffMultiplier, client.config.MaxBackoff) continue } - rc.setConn(conn) - if rc.config.OnConnect != nil { - rc.config.OnConnect() + client.setConn(conn) + stopClose := context.AfterFunc(ctx, func() { + if err := conn.Close(); err != nil { + return + } + }) + backoff = client.config.InitialBackoff + attempt = 0 + if client.config.OnConnect != nil { + client.config.OnConnect() } - readErr := rc.readLoop(ctx, conn) + readErr := client.readLoop(ctx, conn) + stopClose() - rc.clearConn(conn) - _ = conn.Close() - if rc.config.OnDisconnect != nil { - rc.config.OnDisconnect() + client.clearConn(conn) + client.setState(stream.StateDisconnected) + if err := conn.Close(); err != nil && readErr == nil { + readErr = err + } + if client.config.OnDisconnect != nil { + client.config.OnDisconnect() } - if rc.isClosed() || ctx.Err() != nil { + if client.isClosed() || ctx.Err() != nil { return nil } if readErr == nil { @@ -99,99 +152,127 @@ func (rc *ReconnectingTCP) Connect(ctx context.Context) error { } else { attempt++ } - if rc.config.MaxRetries > 0 && attempt > rc.config.MaxRetries { + if client.config.MaxRetries > 0 && attempt > client.config.MaxRetries { return readErr } + if client.config.OnReconnect != nil { + client.config.OnReconnect(attempt) + } if err := sleepContext(ctx, backoff); err != nil { return err } - backoff = nextTCPBackoff(backoff, rc.config.BackoffMultiplier, rc.config.MaxBackoff) + backoff = nextTCPBackoff(backoff, client.config.BackoffMultiplier, client.config.MaxBackoff) } } -// Send transmits frame on channel through the TCP connection. -func (rc *ReconnectingTCP) Send(channel string, frame []byte) error { - if rc == nil { - return errors.New("nil reconnecting tcp") +// _ = client.Send("vpn:peer-abc123", encryptedPacket) +func (client *ReconnectingTCP) Send(channel string, frame []byte) error { + if client == nil { + return core.E("stream.tcp", "nil reconnecting tcp", nil) } - rc.mu.RLock() - conn := rc.conn - rc.mu.RUnlock() - if conn == nil { + client.writeMutex.Lock() + defer client.writeMutex.Unlock() + + client.mutex.RLock() + connection := client.conn + client.mutex.RUnlock() + if connection == nil { return core.E("stream.tcp", "not connected", nil) } - _, err := conn.Write(encodeFrame(channel, frame)) - return err + return writeAll(connection, encodeTCPFrame(channel, frame)) } -// Close shuts down the reconnecting client. -func (rc *ReconnectingTCP) Close() error { - if rc == nil { +// if client.State() == stream.StateConnected { +// _ = client.Send("vpn:peer-abc123", encryptedPacket) +// } +func (client *ReconnectingTCP) State() stream.ConnectionState { + if client == nil { + return stream.StateDisconnected + } + client.mutex.RLock() + defer client.mutex.RUnlock() + return client.state +} + +// _ = client.Close() +func (client *ReconnectingTCP) Close() error { + if client == nil { return nil } - rc.mu.Lock() - rc.closed = true - conn := rc.conn - rc.conn = nil - rc.mu.Unlock() + client.mutex.Lock() + client.closed = true + conn := client.conn + client.conn = nil + client.state = stream.StateDisconnected + client.mutex.Unlock() if conn != nil { return conn.Close() } return nil } -func (rc *ReconnectingTCP) dial(ctx context.Context) (net.Conn, error) { +func (client *ReconnectingTCP) dial(ctx context.Context) (net.Conn, error) { dialer := &net.Dialer{} - if rc.config.TLS != nil { - conn, err := dialer.DialContext(ctx, "tcp", rc.config.Addr) + if client.config.TLS != nil { + conn, err := dialer.DialContext(ctx, "tcp", client.config.Addr) if err != nil { return nil, err } - tlsConn := tls.Client(conn, rc.config.TLS) + tlsConn := tls.Client(conn, client.config.TLS) if err := tlsConn.HandshakeContext(ctx); err != nil { - _ = conn.Close() + if closeErr := conn.Close(); closeErr != nil { + return nil, core.ErrorJoin(err, closeErr) + } return nil, err } return tlsConn, nil } - return dialer.DialContext(ctx, "tcp", rc.config.Addr) + return dialer.DialContext(ctx, "tcp", client.config.Addr) } -func (rc *ReconnectingTCP) readLoop(ctx context.Context, conn net.Conn) error { +func (client *ReconnectingTCP) readLoop(ctx context.Context, conn net.Conn) error { for { select { case <-ctx.Done(): return ctx.Err() default: } - channel, frame, err := readFrame(conn, 0) + channel, frame, err := readTCPFrame(conn, 0, MaxFrameSize) if err != nil { return err } - if rc.config.OnMessage != nil { - rc.config.OnMessage(channel, frame) + if client.config.OnMessage != nil { + client.config.OnMessage(channel, frame) } } } -func (rc *ReconnectingTCP) setConn(conn net.Conn) { - rc.mu.Lock() - rc.conn = conn - rc.mu.Unlock() +func (client *ReconnectingTCP) setConn(conn net.Conn) { + client.mutex.Lock() + client.conn = conn + client.state = stream.StateConnected + client.mutex.Unlock() } -func (rc *ReconnectingTCP) clearConn(conn net.Conn) { - rc.mu.Lock() - if rc.conn == conn { - rc.conn = nil +func (client *ReconnectingTCP) clearConn(conn net.Conn) { + client.mutex.Lock() + if client.conn == conn { + client.conn = nil + client.state = stream.StateDisconnected } - rc.mu.Unlock() + client.mutex.Unlock() +} + +func (client *ReconnectingTCP) setState(state stream.ConnectionState) { + client.mutex.Lock() + client.state = state + client.mutex.Unlock() } -func (rc *ReconnectingTCP) isClosed() bool { - rc.mu.RLock() - defer rc.mu.RUnlock() - return rc.closed +func (client *ReconnectingTCP) isClosed() bool { + client.mutex.RLock() + defer client.mutex.RUnlock() + return client.closed } func nextTCPBackoff(current time.Duration, multiplier float64, maximum time.Duration) time.Duration { @@ -218,3 +299,13 @@ func sleepContext(ctx context.Context, duration time.Duration) error { return nil } } + +func (client *ReconnectingTCP) writeHandshake(conn net.Conn) error { + if conn == nil { + return core.E("stream.tcp", "nil connection", nil) + } + if len(client.config.HandshakeFrame) == 0 && client.config.HandshakeChannel == "" { + return nil + } + return writeAll(conn, encodeTCPFrame(client.config.HandshakeChannel, client.config.HandshakeFrame)) +} diff --git a/adapter/tcp/tcp.go b/adapter/tcp/tcp.go index fb0d097..c84a3d1 100644 --- a/adapter/tcp/tcp.go +++ b/adapter/tcp/tcp.go @@ -1,45 +1,62 @@ // SPDX-License-Identifier: EUPL-1.2 -// Package tcp is the raw TCP transport adapter for stream.Hub. -// Length-prefixed framing over plain or TLS TCP. Used by go-p2p VPN tunnels -// and go-proxy stratum sessions where WebSocket overhead is undesirable. +// adapter := tcp.New(tcp.Config{Addr: ":9000"}) +// adapter.Mount(hub) +// go adapter.Listen(ctx) package tcp import ( "context" "crypto/tls" "encoding/binary" - "errors" "io" "net" "sync" "time" - "dappco.re/go/core" + "dappco.re/go" "dappco.re/go/stream" ) // MaxFrameSize is the maximum allowed frame size in bytes. const MaxFrameSize = 65535 -// Config configures the TCP adapter. +const maxHandshakeFrameSize = 4 << 10 + +// config := tcp.Config{ +// Addr: ":9000", +// ConnAuthenticator: auth, +// } type Config struct { - Addr string + // tcp.New(tcp.Config{Addr: ":9000"}) + Addr string + + // tcp.New(tcp.Config{ConnAuthenticator: auth}) ConnAuthenticator stream.ConnAuthenticator - HandshakeTimeout time.Duration - TLS *tls.Config + + // tcp.New(tcp.Config{HandshakeFrame: []byte("trusted")}) + HandshakeFrame []byte + + // tcp.New(tcp.Config{HandshakeChannel: "auth"}) + HandshakeChannel string + + // tcp.New(tcp.Config{HandshakeTimeout: 5 * time.Second}) + HandshakeTimeout time.Duration + + // tcp.New(tcp.Config{TLS: &tls.Config{}}) + TLS *tls.Config } -// Adapter is the raw TCP transport adapter. +// adapter := tcp.New(tcp.Config{Addr: ":9000", ConnAuthenticator: auth}) type Adapter struct { hub *stream.Hub config Config - mu sync.Mutex + mutex sync.Mutex listener net.Listener } -// New creates a TCP adapter. Call Mount before Listen or Dial. +// adapter := tcp.New(tcp.Config{Addr: ":9000", ConnAuthenticator: auth}) func New(config Config) *Adapter { if config.HandshakeTimeout == 0 { config.HandshakeTimeout = 5 * time.Second @@ -47,32 +64,46 @@ func New(config Config) *Adapter { return &Adapter{config: config} } -// Mount wires the adapter to a hub. -func (a *Adapter) Mount(hub *stream.Hub) { - a.hub = hub +// adapter.Mount(hub) +func (adapter *Adapter) Mount(hub *stream.Hub) { + adapter.hub = hub } -// Listen starts the TCP accept loop. Blocks until ctx cancelled. -func (a *Adapter) Listen(ctx context.Context) error { - if a == nil { - return errors.New("nil adapter") +// go adapter.Listen(ctx) +func (adapter *Adapter) Listen(ctx context.Context) error { + if adapter == nil { + return core.E("stream.tcp", "nil adapter", nil) + } + if ctx == nil { + ctx = context.Background() } - if a.hub == nil { + if adapter.hub == nil { return core.E("stream.tcp", "stream hub not mounted", nil) } - if a.config.Addr == "" { + if adapter.config.Addr == "" { return core.E("stream.tcp", "empty address", nil) } - listener, err := a.listen() + listener, err := adapter.listen() if err != nil { return err } - defer listener.Close() + defer func() { + if err := listener.Close(); err != nil { + return + } + adapter.mutex.Lock() + if adapter.listener == listener { + adapter.listener = nil + } + adapter.mutex.Unlock() + }() go func() { <-ctx.Done() - _ = listener.Close() + if err := listener.Close(); err != nil { + return + } }() for { @@ -86,125 +117,210 @@ func (a *Adapter) Listen(ctx context.Context) error { } return err } - go a.handleConn(ctx, conn, a.hub) + go adapter.handleConn(ctx, conn, adapter.hub) } } -// Dial connects to a remote TCP stream endpoint. Returns a Peer that can send/receive. -func (a *Adapter) Dial(ctx context.Context, hub *stream.Hub) (*stream.Peer, error) { - if a == nil { - return nil, errors.New("nil adapter") +// peer, err := adapter.Dial(ctx, hub) +func (adapter *Adapter) Dial(ctx context.Context, hub *stream.Hub) (*stream.Peer, error) { + if adapter == nil { + return nil, core.E("stream.tcp", "nil adapter", nil) + } + if ctx == nil { + ctx = context.Background() } if hub == nil { - hub = a.hub + hub = adapter.hub } if hub == nil { return nil, core.E("stream.tcp", "stream hub not mounted", nil) } - conn, err := a.dial(ctx) + conn, err := adapter.dial(ctx) if err != nil { return nil, err } - _, _ = conn.Write(encodeFrame("", nil)) + if err := adapter.writeHandshake(conn); err != nil { + if closeErr := conn.Close(); closeErr != nil { + return nil, core.ErrorJoin(err, closeErr) + } + return nil, err + } peer := stream.NewPeer("tcp") - _ = hub.AddPeer(peer) - _ = hub.SubscribePeer(peer, "*") - go a.pipePeer(ctx, conn, peer, hub) + peer.SetCloseHook(func() { + if err := conn.Close(); err != nil { + return + } + }) + if !hub.Running() { + if err := conn.Close(); err != nil { + return nil, err + } + return nil, stream.ErrHubNotRunning + } + if err := hub.AddPeer(peer); err != nil { + if closeErr := conn.Close(); closeErr != nil { + return nil, core.ErrorJoin(err, closeErr) + } + return nil, err + } + if err := hub.SubscribePeer(peer, "*"); err != nil { + hub.RemovePeer(peer) + if closeErr := conn.Close(); closeErr != nil { + return nil, core.ErrorJoin(err, closeErr) + } + return nil, err + } + go adapter.pipePeer(ctx, conn, peer, hub) return peer, nil } -func (a *Adapter) listen() (net.Listener, error) { - a.mu.Lock() - defer a.mu.Unlock() - if a.listener != nil { - return a.listener, nil +func (adapter *Adapter) listen() (net.Listener, error) { + adapter.mutex.Lock() + defer adapter.mutex.Unlock() + if adapter.listener != nil { + return adapter.listener, nil } var ( listener net.Listener err error ) - if a.config.TLS != nil { - listener, err = tls.Listen("tcp", a.config.Addr, a.config.TLS) + if adapter.config.TLS != nil { + listener, err = tls.Listen("tcp", adapter.config.Addr, adapter.config.TLS) } else { - listener, err = net.Listen("tcp", a.config.Addr) + listener, err = net.Listen("tcp", adapter.config.Addr) } if err != nil { return nil, err } - a.listener = listener + adapter.listener = listener return listener, nil } -func (a *Adapter) dial(ctx context.Context) (net.Conn, error) { +func (adapter *Adapter) dial(ctx context.Context) (net.Conn, error) { + if ctx == nil { + ctx = context.Background() + } dialer := &net.Dialer{} - if a.config.TLS != nil { - conn, err := dialer.DialContext(ctx, "tcp", a.config.Addr) + if adapter.config.TLS != nil { + conn, err := dialer.DialContext(ctx, "tcp", adapter.config.Addr) if err != nil { return nil, err } - tlsConn := tls.Client(conn, a.config.TLS) + tlsConn := tls.Client(conn, adapter.config.TLS) if err := tlsConn.HandshakeContext(ctx); err != nil { - _ = conn.Close() + if closeErr := conn.Close(); closeErr != nil { + return nil, core.ErrorJoin(err, closeErr) + } return nil, err } return tlsConn, nil } - return dialer.DialContext(ctx, "tcp", a.config.Addr) + return dialer.DialContext(ctx, "tcp", adapter.config.Addr) } -func (a *Adapter) handleConn(ctx context.Context, conn net.Conn, hub *stream.Hub) { +func (adapter *Adapter) handleConn(ctx context.Context, conn net.Conn, hub *stream.Hub) { defer conn.Close() + stopClose := context.AfterFunc(ctx, func() { + if err := conn.Close(); err != nil { + return + } + }) + defer stopClose() - _, handshake, err := readFrame(conn, a.config.HandshakeTimeout) + handshakeMaxSize := MaxFrameSize + if adapter.config.ConnAuthenticator != nil { + handshakeMaxSize = maxHandshakeFrameSize + } + channel, frame, err := readTCPFrame(conn, adapter.config.HandshakeTimeout, handshakeMaxSize) if err != nil { return } - if auth := a.config.ConnAuthenticator; auth != nil { - result := auth.AuthenticateConn(handshake) - if !result.Valid { + authResult := stream.AuthResult{Valid: true} + if auth := adapter.config.ConnAuthenticator; auth != nil { + authResult = auth.AuthenticateConn(frame) + if !authResult.Valid { return } } peer := stream.NewPeer("tcp") - _ = hub.AddPeer(peer) - _ = hub.SubscribePeer(peer, "*") + peer.UserID = authResult.UserID + if authResult.Claims != nil { + peer.Claims = authResult.Claims + } + peer.SetCloseHook(func() { + if err := conn.Close(); err != nil { + return + } + }) + if !hub.Running() { + return + } + if err := hub.AddPeer(peer); err != nil { + return + } + if err := hub.SubscribePeer(peer, "*"); err != nil { + hub.RemovePeer(peer) + return + } defer hub.RemovePeer(peer) - go a.writePump(ctx, conn, peer) + go adapter.writePump(ctx, conn, peer, hub.Config().WriteTimeout) + + if auth := adapter.config.ConnAuthenticator; auth == nil { + if err := dispatchTCPFrame(hub, peer, channel, frame); err != nil { + return + } + } for { - channel, frame, err := readFrame(conn, 0) + channel, frame, err := readTCPFrame(conn, 0, MaxFrameSize) if err != nil { return } if channel == "" { - _ = hub.Broadcast(frame) + if err := hub.BroadcastFromPeer(peer, frame); err != nil { + return + } continue } - _ = hub.Publish(channel, frame) + if err := hub.PublishFromPeer(peer, channel, frame); err != nil { + return + } + } +} + +func dispatchTCPFrame(hub *stream.Hub, peer *stream.Peer, channel string, frame []byte) error { + if channel == "" { + return hub.BroadcastFromPeer(peer, frame) } + return hub.PublishFromPeer(peer, channel, frame) } -func (a *Adapter) pipePeer(ctx context.Context, conn net.Conn, peer *stream.Peer, hub *stream.Hub) { +func (adapter *Adapter) pipePeer(ctx context.Context, conn net.Conn, peer *stream.Peer, hub *stream.Hub) { defer conn.Close() - go a.writePump(ctx, conn, peer) + stopClose := context.AfterFunc(ctx, func() { + if err := conn.Close(); err != nil { + return + } + }) + defer stopClose() + go adapter.writePump(ctx, conn, peer, hub.Config().WriteTimeout) for { - channel, frame, err := readFrame(conn, 0) + channel, frame, err := readTCPFrame(conn, 0, MaxFrameSize) if err != nil { hub.RemovePeer(peer) return } - if channel == "" { - _ = hub.Broadcast(frame) - continue + if err := dispatchTCPFrame(hub, peer, channel, frame); err != nil { + hub.RemovePeer(peer) + return } - _ = hub.Publish(channel, frame) } } -func (a *Adapter) writePump(ctx context.Context, conn net.Conn, peer *stream.Peer) { +func (adapter *Adapter) writePump(ctx context.Context, conn net.Conn, peer *stream.Peer, writeTimeout time.Duration) { for { select { case <-ctx.Done(): @@ -213,16 +329,27 @@ func (a *Adapter) writePump(ctx context.Context, conn net.Conn, peer *stream.Pee if !ok { return } - if _, err := conn.Write(frame); err != nil { + if writeTimeout > 0 { + if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + return + } + } + if err := writeAll(conn, frame); err != nil { return } } } } -func readFrame(conn net.Conn, timeout time.Duration) (string, []byte, error) { +func readTCPFrame(conn net.Conn, timeout time.Duration, maxFrameSize int) (string, []byte, error) { if timeout > 0 { - _ = conn.SetReadDeadline(time.Now().Add(timeout)) + if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { + return "", nil, err + } + } else { + if err := conn.SetReadDeadline(time.Time{}); err != nil { + return "", nil, err + } } var length uint32 if err := binary.Read(conn, binary.BigEndian, &length); err != nil { @@ -231,7 +358,7 @@ func readFrame(conn net.Conn, timeout time.Duration) (string, []byte, error) { } return "", nil, err } - if length > MaxFrameSize { + if maxFrameSize > 0 && length > uint32(maxFrameSize) { return "", nil, core.E("stream.tcp", "frame too large", nil) } payload := make([]byte, length) @@ -250,7 +377,7 @@ func readFrame(conn net.Conn, timeout time.Duration) (string, []byte, error) { return channel, frame, nil } -func encodeFrame(channel string, frame []byte) []byte { +func encodeTCPFrame(channel string, frame []byte) []byte { channelBytes := []byte(channel) payloadLength := uint32(4 + len(channelBytes) + len(frame)) buffer := make([]byte, 4+payloadLength) @@ -260,12 +387,34 @@ func encodeFrame(channel string, frame []byte) []byte { copy(buffer[8+len(channelBytes):], frame) return buffer } + +func writeAll(conn net.Conn, payload []byte) error { + for len(payload) > 0 { + written, err := conn.Write(payload) + if err != nil { + return err + } + if written <= 0 { + return io.ErrShortWrite + } + payload = payload[written:] + } + return nil +} + +func (adapter *Adapter) writeHandshake(conn net.Conn) error { + if conn == nil { + return core.E("stream.tcp", "nil connection", nil) + } + if len(adapter.config.HandshakeFrame) == 0 && adapter.config.HandshakeChannel == "" { + return nil + } + return writeAll(conn, encodeTCPFrame(adapter.config.HandshakeChannel, adapter.config.HandshakeFrame)) +} + func isClosedNetworkError(err error) bool { if err == nil { return false } - if errors.Is(err, net.ErrClosed) { - return true - } - return false + return err == net.ErrClosed } diff --git a/adapter/tcp/tcp_test.go b/adapter/tcp/tcp_test.go new file mode 100644 index 0000000..944bf9f --- /dev/null +++ b/adapter/tcp/tcp_test.go @@ -0,0 +1,822 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package tcp + +import ( + "bytes" + "context" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "dappco.re/go/stream" +) + +func TestTCP_Listen_Good(t *testing.T) { + hub := stream.NewHubWithConfig(stream.HubConfig{ + OnConnect: func(peer *stream.Peer) { + if peer.UserID != "user-42" { + t.Errorf("peer.UserID = %q, want %q", peer.UserID, "user-42") + } + if peer.Claims["role"] != "admin" { + t.Errorf("peer.Claims[role] = %v, want %q", peer.Claims["role"], "admin") + } + }, + }) + + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{ + Addr: "127.0.0.1:0", + ConnAuthenticator: stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { + if string(handshake) != "hello" { + return stream.AuthResult{Valid: false} + } + return stream.AuthResult{ + Valid: true, + UserID: "user-42", + Claims: map[string]any{"role": "admin"}, + } + }), + }) + adapter.Mount(hub) + + listenContext, listenCancel := context.WithCancel(context.Background()) + defer listenCancel() + go func() { + _ = adapter.Listen(listenContext) + }() + + address := waitForListenerAddress(t, adapter) + connection, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer connection.Close() + + if _, err := connection.Write(encodeTCPFrame("", []byte("hello"))); err != nil { + t.Fatalf("Write() error = %v", err) + } + + waitForPeerCount(t, hub, 1) +} + +func TestTCP_Listen_NoAuthenticator_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{ + Addr: "127.0.0.1:0", + }) + adapter.Mount(hub) + + listenContext, listenCancel := context.WithCancel(context.Background()) + defer listenCancel() + go func() { + _ = adapter.Listen(listenContext) + }() + + address := waitForListenerAddress(t, adapter) + connection, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer connection.Close() + + received := make(chan []byte, 1) + unsubscribe := hub.Subscribe("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + if _, err := connection.Write(encodeTCPFrame("block", []byte("template"))); err != nil { + t.Fatalf("Write() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "template" { + t.Fatalf("received frame = %q, want %q", string(frame), "template") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for unauthenticated frame") + } +} + +func TestTCP_Listen_SelfDelivery_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{ + Addr: "127.0.0.1:0", + }) + adapter.Mount(hub) + + listenContext, listenCancel := context.WithCancel(context.Background()) + defer listenCancel() + go func() { + _ = adapter.Listen(listenContext) + }() + + address := waitForListenerAddress(t, adapter) + connection, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer connection.Close() + + received := make(chan []byte, 1) + unsubscribe := hub.Subscribe("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + if _, err := connection.Write(encodeTCPFrame("block", []byte("template"))); err != nil { + t.Fatalf("Write() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "template" { + t.Fatalf("received frame = %q, want %q", string(frame), "template") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for published TCP frame") + } + + channel, frame, err := readTCPFrame(connection, 2*time.Second, MaxFrameSize) + if err != nil { + t.Fatalf("readTCPFrame() error = %v", err) + } + if channel != "block" { + t.Fatalf("readTCPFrame() channel = %q, want %q", channel, "block") + } + if string(frame) != "template" { + t.Fatalf("readTCPFrame() frame = %q, want %q", string(frame), "template") + } +} + +func TestTCP_Listen_ContextCancel_ClosesPeer_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{ + Addr: "127.0.0.1:0", + }) + adapter.Mount(hub) + + listenContext, listenCancel := context.WithCancel(context.Background()) + go func() { + _ = adapter.Listen(listenContext) + }() + + address := waitForListenerAddress(t, adapter) + connection, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer connection.Close() + + if _, err := connection.Write(encodeTCPFrame("", []byte("hello"))); err != nil { + t.Fatalf("Write() error = %v", err) + } + + waitForPeerCount(t, hub, 1) + + channel, frame, err := readTCPFrame(connection, 2*time.Second, MaxFrameSize) + if err != nil { + t.Fatalf("readTCPFrame() initial echo error = %v", err) + } + if channel != "" { + t.Fatalf("readTCPFrame() initial echo channel = %q, want %q", channel, "") + } + if string(frame) != "hello" { + t.Fatalf("readTCPFrame() initial echo frame = %q, want %q", string(frame), "hello") + } + + listenCancel() + + channel, frame, err = readTCPFrame(connection, 2*time.Second, MaxFrameSize) + if err == nil { + t.Fatalf("readTCPFrame() = (%q, %q, nil), want connection close", channel, string(frame)) + } + if err == stream.ErrHandshakeTimeout { + t.Fatalf("readTCPFrame() error = %v, want connection close", err) + } + + waitForPeerCount(t, hub, 0) +} + +func TestTCP_Listen_Bad(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{ + Addr: "127.0.0.1:0", + ConnAuthenticator: stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { + return stream.AuthResult{Valid: false} + }), + }) + adapter.Mount(hub) + + listenContext, listenCancel := context.WithCancel(context.Background()) + defer listenCancel() + go func() { + _ = adapter.Listen(listenContext) + }() + + address := waitForListenerAddress(t, adapter) + connection, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer connection.Close() + + if _, err := connection.Write(encodeTCPFrame("", []byte("nope"))); err != nil { + t.Fatalf("Write() error = %v", err) + } + + waitForPeerCount(t, hub, 0) +} + +func TestTCP_Listen_Ugly(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{ + Addr: "127.0.0.1:0", + ConnAuthenticator: stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { + return stream.AuthResult{Valid: true} + }), + HandshakeTimeout: 50 * time.Millisecond, + }) + adapter.Mount(hub) + + listenContext, listenCancel := context.WithCancel(context.Background()) + defer listenCancel() + go func() { + _ = adapter.Listen(listenContext) + }() + + address := waitForListenerAddress(t, adapter) + connection, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer connection.Close() + + time.Sleep(120 * time.Millisecond) + waitForPeerCount(t, hub, 0) +} + +func TestTCP_Listen_NoAuthenticator_LargeInitialFrame_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{ + Addr: "127.0.0.1:0", + }) + adapter.Mount(hub) + + listenContext, listenCancel := context.WithCancel(context.Background()) + defer listenCancel() + go func() { + _ = adapter.Listen(listenContext) + }() + + address := waitForListenerAddress(t, adapter) + connection, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer connection.Close() + + received := make(chan []byte, 1) + unsubscribe := hub.Subscribe("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + largeFrame := bytes.Repeat([]byte("a"), maxHandshakeFrameSize+1) + if _, err := connection.Write(encodeTCPFrame("block", largeFrame)); err != nil { + t.Fatalf("Write() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != string(largeFrame) { + t.Fatalf("received frame size = %d, want %d", len(frame), len(largeFrame)) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for large initial frame") + } +} + +func TestReconnectingTCP_Send_Concurrent_Good(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen() error = %v", err) + } + defer listener.Close() + + serverAccepted := make(chan net.Conn, 1) + go func() { + connection, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + serverAccepted <- connection + }() + + client := NewReconnectingTCP(ReconnectConfig{Addr: listener.Addr().String()}) + + connectContext, connectCancel := context.WithCancel(context.Background()) + connectDone := make(chan error, 1) + go func() { + connectDone <- client.Connect(connectContext) + }() + defer func() { + connectCancel() + _ = client.Close() + <-connectDone + }() + + serverConnection := <-serverAccepted + defer serverConnection.Close() + + for deadline := time.Now().Add(2 * time.Second); time.Now().Before(deadline); { + if client.State() == stream.StateConnected { + break + } + time.Sleep(10 * time.Millisecond) + } + if client.State() != stream.StateConnected { + t.Fatal("client did not reach connected state") + } + + senderCount := 32 + var sendGroup sync.WaitGroup + for index := range senderCount { + sendGroup.Add(1) + go func(index int) { + defer sendGroup.Done() + if sendErr := client.Send("hashrate", []byte{byte(index)}); sendErr != nil { + t.Errorf("Send() error = %v", sendErr) + } + }(index) + } + sendGroup.Wait() + + receivedValues := map[byte]bool{} + for len(receivedValues) < senderCount { + channel, frame, readErr := readTCPFrame(serverConnection, 2*time.Second, MaxFrameSize) + if readErr != nil { + t.Fatalf("readTCPFrame() error = %v", readErr) + } + if channel != "hashrate" { + t.Fatalf("readTCPFrame() channel = %q, want %q", channel, "hashrate") + } + if len(frame) != 1 { + t.Fatalf("len(frame) = %d, want %d", len(frame), 1) + } + receivedValues[frame[0]] = true + } +} + +func TestTCP_Listen_AuthHandshakeTooLarge_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{ + Addr: "127.0.0.1:0", + ConnAuthenticator: stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { + return stream.AuthResult{Valid: true} + }), + }) + adapter.Mount(hub) + + listenContext, listenCancel := context.WithCancel(context.Background()) + defer listenCancel() + go func() { + _ = adapter.Listen(listenContext) + }() + + address := waitForListenerAddress(t, adapter) + connection, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer connection.Close() + + tooLargeHandshake := make([]byte, maxHandshakeFrameSize+1) + if _, err := connection.Write(encodeTCPFrame("", tooLargeHandshake)); err != nil { + t.Fatalf("Write() error = %v", err) + } + + waitForPeerCount(t, hub, 0) +} + +func TestTCP_Dial_NilContext_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen() error = %v", err) + } + defer listener.Close() + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + connection, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + defer connection.Close() + _, _ = connection.Write(encodeTCPFrame("block", []byte("template"))) + time.Sleep(50 * time.Millisecond) + }() + + adapter := New(Config{Addr: listener.Addr().String()}) + peer, err := adapter.Dial(nil, hub) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + if peer == nil { + t.Fatal("Dial() peer = nil") + } + defer peer.Close() + + received := make(chan []byte, 1) + unsubscribe := hub.Subscribe("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + select { + case frame := <-received: + if string(frame) != "template" { + t.Fatalf("received frame = %q, want %q", string(frame), "template") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for dialed frame") + } + + <-serverDone +} + +func TestTCP_Dial_HubNotRunning_Bad(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen() error = %v", err) + } + defer listener.Close() + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + connection, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + defer connection.Close() + _, _, _ = readTCPFrame(connection, 2*time.Second, MaxFrameSize) + }() + + adapter := New(Config{Addr: listener.Addr().String()}) + peer, err := adapter.Dial(context.Background(), stream.NewHub()) + if err == nil { + if peer != nil { + peer.Close() + } + t.Fatal("Dial() error = nil, want hub lifecycle failure") + } + if peer != nil { + t.Fatalf("Dial() peer = %#v, want nil", peer) + } + + <-serverDone +} + +func TestTCP_Dial_Handshake_Good(t *testing.T) { + serverHub := stream.NewHub() + serverHubContext, serverHubCancel := context.WithCancel(context.Background()) + defer serverHubCancel() + go serverHub.Run(serverHubContext) + + serverAdapter := New(Config{ + Addr: "127.0.0.1:0", + ConnAuthenticator: stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { + if string(handshake) != "trusted" { + return stream.AuthResult{Valid: false} + } + return stream.AuthResult{Valid: true, UserID: "peer-1"} + }), + }) + serverAdapter.Mount(serverHub) + + listenContext, listenCancel := context.WithCancel(context.Background()) + defer listenCancel() + go func() { + _ = serverAdapter.Listen(listenContext) + }() + + clientHub := stream.NewHub() + clientHubContext, clientHubCancel := context.WithCancel(context.Background()) + defer clientHubCancel() + go clientHub.Run(clientHubContext) + + clientAdapter := New(Config{ + Addr: waitForListenerAddress(t, serverAdapter), + HandshakeFrame: []byte("trusted"), + }) + + received := make(chan []byte, 1) + unsubscribe := clientHub.Subscribe("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + peer, err := clientAdapter.Dial(context.Background(), clientHub) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer peer.Close() + + waitForPeerCount(t, serverHub, 1) + + if err := serverHub.Publish("block", []byte("template")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "template" { + t.Fatalf("received frame = %q, want %q", string(frame), "template") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for dialed handshake frame") + } +} + +func TestAX7_ReconnectingTCP_State_Good(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen() error = %v", err) + } + defer listener.Close() + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + connection, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + defer connection.Close() + time.Sleep(100 * time.Millisecond) + }() + + client := NewReconnectingTCP(ReconnectConfig{ + Addr: listener.Addr().String(), + InitialBackoff: 10 * time.Millisecond, + MaxBackoff: 10 * time.Millisecond, + }) + if client.State() != stream.StateDisconnected { + t.Fatalf("State() = %v, want %v", client.State(), stream.StateDisconnected) + } + + connectContext, connectCancel := context.WithCancel(context.Background()) + defer connectCancel() + connectDone := make(chan error, 1) + go func() { + connectDone <- client.Connect(connectContext) + }() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if client.State() == stream.StateConnected { + break + } + time.Sleep(10 * time.Millisecond) + } + if client.State() != stream.StateConnected { + t.Fatalf("State() = %v, want %v", client.State(), stream.StateConnected) + } + + connectCancel() + select { + case err := <-connectDone: + if err != nil { + t.Fatalf("Connect() error = %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for Connect() to return") + } + + if err := client.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + if client.State() != stream.StateDisconnected { + t.Fatalf("State() = %v, want %v", client.State(), stream.StateDisconnected) + } + + <-serverDone +} + +func TestReconnectingTCP_OnReconnect_Good(t *testing.T) { + var reconnectCount atomic.Int32 + client := NewReconnectingTCP(ReconnectConfig{ + Addr: "127.0.0.1:1", + InitialBackoff: 10 * time.Millisecond, + MaxBackoff: 10 * time.Millisecond, + MaxRetries: 1, + OnReconnect: func(attempt int) { + reconnectCount.Store(int32(attempt)) + }, + }) + + err := client.Connect(context.Background()) + if err == nil { + t.Fatal("Connect() error = nil, want dial error") + } + if reconnectCount.Load() != 1 { + t.Fatalf("OnReconnect attempt = %d, want %d", reconnectCount.Load(), 1) + } + if client.State() != stream.StateDisconnected { + t.Fatalf("State() = %v, want %v", client.State(), stream.StateDisconnected) + } +} + +func TestReconnectingTCP_Connect_Handshake_Good(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen() error = %v", err) + } + defer listener.Close() + + received := make(chan []byte, 1) + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + connection, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + defer connection.Close() + + channel, frame, readErr := readTCPFrame(connection, time.Second, MaxFrameSize) + if readErr != nil { + return + } + if channel != "auth" { + return + } + received <- append([]byte(nil), frame...) + _ = writeAll(connection, encodeTCPFrame("block", []byte("template"))) + }() + + clientMessages := make(chan []byte, 1) + client := NewReconnectingTCP(ReconnectConfig{ + Addr: listener.Addr().String(), + HandshakeChannel: "auth", + HandshakeFrame: []byte("trusted"), + InitialBackoff: 10 * time.Millisecond, + MaxBackoff: 10 * time.Millisecond, + OnMessage: func(channel string, frame []byte) { + if channel == "block" { + clientMessages <- append([]byte(nil), frame...) + } + }, + }) + + connectContext, connectCancel := context.WithCancel(context.Background()) + connectDone := make(chan error, 1) + go func() { + connectDone <- client.Connect(connectContext) + }() + + select { + case frame := <-received: + if string(frame) != "trusted" { + t.Fatalf("handshake frame = %q, want %q", string(frame), "trusted") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handshake frame") + } + + select { + case frame := <-clientMessages: + if string(frame) != "template" { + t.Fatalf("received frame = %q, want %q", string(frame), "template") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for reconnecting client frame") + } + + connectCancel() + select { + case err := <-connectDone: + if err != nil && err != context.Canceled { + t.Fatalf("Connect() error = %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for Connect() to return") + } + + <-serverDone +} + +func waitForListenerAddress(t *testing.T, adapter *Adapter) string { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + adapter.mutex.Lock() + listener := adapter.listener + adapter.mutex.Unlock() + if listener != nil { + return listener.Addr().String() + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("timed out waiting for listener") + return "" +} + +func waitForPeerCount(t *testing.T, hub *stream.Hub, expected int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if hub.PeerCount() == expected { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("PeerCount() = %d, want %d", hub.PeerCount(), expected) +} + +func TestWriteAll_Good(t *testing.T) { + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + wrapped := &partialWriteConn{Conn: left, chunkSize: 2} + payload := []byte("hello") + received := make(chan []byte, 1) + go func() { + buffer := make([]byte, len(payload)) + _, err := io.ReadFull(right, buffer) + if err != nil { + received <- nil + return + } + received <- buffer + }() + + if err := writeAll(wrapped, payload); err != nil { + t.Fatalf("writeAll() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "hello" { + t.Fatalf("received frame = %q, want %q", string(frame), "hello") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for payload") + } +} + +type partialWriteConn struct { + net.Conn + chunkSize int +} + +func (conn *partialWriteConn) Write(payload []byte) (int, error) { + if conn.chunkSize > 0 && len(payload) > conn.chunkSize { + payload = payload[:conn.chunkSize] + } + return conn.Conn.Write(payload) +} diff --git a/adapter/ws/ax7_more_test.go b/adapter/ws/ax7_more_test.go new file mode 100644 index 0000000..fe20a5b --- /dev/null +++ b/adapter/ws/ax7_more_test.go @@ -0,0 +1,442 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package ws + +import ( + "github.com/alicebob/miniredis/v2" + "github.com/gorilla/websocket" + + core "dappco.re/go" + "dappco.re/go/stream" + adapterredis "dappco.re/go/stream/adapter/redis" +) + +func ax7WebSocketPair(t *core.T) (*websocket.Conn, *websocket.Conn, func()) { + upgrader := websocket.Upgrader{CheckOrigin: func(*core.Request) bool { return true }} + serverConn := make(chan *websocket.Conn, 1) + server := core.NewHTTPTestServer(core.HandlerFunc(func(w core.ResponseWriter, r *core.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("Upgrade() error = %v", err) + return + } + serverConn <- conn + })) + clientConn, _, err := websocket.DefaultDialer.Dial("ws"+server.URL[len("http"):], nil) + core.RequireNoError(t, err) + conn := <-serverConn + cleanup := func() { + clientConn.Close() + conn.Close() + server.Close() + } + return clientConn, conn, cleanup +} + +func TestAX7_New_Good(t *core.T) { + adapter := New(Config{ReadBufferSize: 2048, WriteBufferSize: 4096}) + + core.AssertNotNil(t, adapter) + core.AssertEqual(t, 2048, adapter.config.ReadBufferSize) + core.AssertEqual(t, 4096, adapter.config.WriteBufferSize) +} + +func TestAX7_New_Bad(t *core.T) { + adapter := New(Config{}) + + core.AssertNotNil(t, adapter) + core.AssertEqual(t, 1024, adapter.config.ReadBufferSize) + core.AssertEqual(t, 1024, adapter.config.WriteBufferSize) +} + +func TestAX7_New_Ugly(t *core.T) { + allowed := false + adapter := New(Config{CheckOrigin: func(*core.Request) bool { allowed = true; return true }}) + + core.AssertTrue(t, adapter.config.CheckOrigin(nil)) + core.AssertTrue(t, allowed) +} + +func TestAX7_Adapter_Mount_Good(t *core.T) { + adapter := New(Config{}) + hub := stream.NewHub() + + adapter.Mount(hub) + core.AssertEqual(t, hub, adapter.hub) + core.AssertNotNil(t, adapter.Handler()) +} + +func TestAX7_Adapter_Mount_Bad(t *core.T) { + adapter := New(Config{}) + + adapter.Mount(nil) + core.AssertNil(t, adapter.hub) + core.AssertNotNil(t, adapter) +} + +func TestAX7_Adapter_Mount_Ugly(t *core.T) { + adapter := New(Config{}) + first := stream.NewHub() + second := stream.NewHub() + + adapter.Mount(first) + adapter.Mount(second) + core.AssertEqual(t, second, adapter.hub) +} + +func TestAX7_Adapter_ServeHTTP_Bad(t *core.T) { + adapter := New(Config{}) + recorder := core.NewHTTPTestRecorder() + request := core.NewHTTPTestRequest("GET", "/stream/ws", nil) + + adapter.ServeHTTP(recorder, request) + core.AssertEqual(t, 500, recorder.Code) + core.AssertContains(t, recorder.Body.String(), "not mounted") +} + +func TestAX7_Adapter_ServeHTTP_Ugly(t *core.T) { + adapter := New(Config{}) + adapter.Mount(stream.NewHub()) + recorder := core.NewHTTPTestRecorder() + request := core.NewHTTPTestRequest("GET", "/stream/ws?channel=hashrate", nil) + + adapter.ServeHTTP(recorder, request) + core.AssertEqual(t, 500, recorder.Code) + core.AssertContains(t, recorder.Body.String(), "not running") +} + +func TestAX7_Adapter_HandlerForChannel_Bad(t *core.T) { + adapter := New(Config{}) + handler := adapter.HandlerForChannel("hashrate") + recorder := core.NewHTTPTestRecorder() + + handler.ServeHTTP(recorder, core.NewHTTPTestRequest("GET", "/stream/ws", nil)) + core.AssertEqual(t, 500, recorder.Code) + core.AssertContains(t, recorder.Body.String(), "not mounted") +} + +func TestAX7_Adapter_HandlerForChannel_Ugly(t *core.T) { + adapter := New(Config{}) + adapter.Mount(stream.NewHub()) + handler := adapter.HandlerForChannel("hashrate") + recorder := core.NewHTTPTestRecorder() + + handler.ServeHTTP(recorder, core.NewHTTPTestRequest("GET", "/stream/ws", nil)) + core.AssertEqual(t, 500, recorder.Code) + core.AssertContains(t, recorder.Body.String(), "not running") +} + +func TestAX7_DefaultHubConfig_Good(t *core.T) { + config := DefaultHubConfig() + + core.AssertEqual(t, 30*core.Second, config.HeartbeatInterval) + core.AssertEqual(t, 60*core.Second, config.PongTimeout) + core.AssertEqual(t, 10*core.Second, config.WriteTimeout) +} + +func TestAX7_DefaultHubConfig_Bad(t *core.T) { + config := DefaultHubConfig() + + core.AssertNil(t, config.OnConnect) + core.AssertNil(t, config.OnDisconnect) + core.AssertNil(t, config.ChannelAuthoriser) +} + +func TestAX7_DefaultHubConfig_Ugly(t *core.T) { + config := DefaultHubConfig() + + core.AssertGreater(t, config.PongTimeout, config.HeartbeatInterval) + core.AssertGreater(t, config.WriteTimeout, core.Duration(0)) +} + +func TestAX7_NewAPIKeyAuth_Good(t *core.T) { + authenticator := NewAPIKeyAuth(map[string]string{"sk": "user"}) + + core.AssertNotNil(t, authenticator) + core.AssertEqual(t, "user", authenticator.Keys["sk"]) +} + +func TestAX7_NewAPIKeyAuth_Bad(t *core.T) { + authenticator := NewAPIKeyAuth(nil) + + core.AssertNotNil(t, authenticator) + core.AssertEqual(t, 0, len(authenticator.Keys)) +} + +func TestAX7_NewAPIKeyAuth_Ugly(t *core.T) { + keys := map[string]string{"sk": "user"} + authenticator := NewAPIKeyAuth(keys) + keys["sk"] = "mutated" + + core.AssertEqual(t, "user", authenticator.Keys["sk"]) + core.AssertEqual(t, "mutated", keys["sk"]) +} + +func TestAX7_NewHub_Good(t *core.T) { + hub := NewHub() + + core.AssertNotNil(t, hub) + core.AssertFalse(t, hub.Running()) + core.AssertEqual(t, 0, hub.PeerCount()) +} + +func TestAX7_NewHub_Bad(t *core.T) { + hub := NewHub() + + core.AssertEqual(t, 30*core.Second, hub.Config().HeartbeatInterval) + core.AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_NewHub_Ugly(t *core.T) { + left := NewHub() + right := NewHub() + + core.AssertNotEqual(t, left, right) + core.AssertNoError(t, left.AddPeer(stream.NewPeer("ws"))) +} + +func TestAX7_NewHubWithConfig_Good(t *core.T) { + hub := NewHubWithConfig(stream.HubConfig{HeartbeatInterval: core.Second, PongTimeout: 3 * core.Second}) + + core.AssertEqual(t, core.Second, hub.Config().HeartbeatInterval) + core.AssertEqual(t, 3*core.Second, hub.Config().PongTimeout) +} + +func TestAX7_NewHubWithConfig_Bad(t *core.T) { + hub := NewHubWithConfig(stream.HubConfig{}) + + core.AssertEqual(t, 30*core.Second, hub.Config().HeartbeatInterval) + core.AssertEqual(t, 60*core.Second, hub.Config().PongTimeout) +} + +func TestAX7_NewHubWithConfig_Ugly(t *core.T) { + called := false + hub := NewHubWithConfig(stream.HubConfig{OnConnect: func(*stream.Peer) { called = true }}) + + core.AssertNoError(t, hub.AddPeer(stream.NewPeer("ws"))) + core.AssertTrue(t, called) +} + +func TestAX7_NewPeer_Good(t *core.T) { + peer := NewPeer("ws") + + core.AssertNotNil(t, peer) + core.AssertEqual(t, "ws", peer.Transport) + core.AssertNotEmpty(t, peer.ID) +} + +func TestAX7_NewPeer_Bad(t *core.T) { + peer := NewPeer("") + + core.AssertNotNil(t, peer) + core.AssertEqual(t, "", peer.Transport) + core.AssertNotNil(t, peer.SendQueue()) +} + +func TestAX7_NewPeer_Ugly(t *core.T) { + left := NewPeer("ws") + right := NewPeer("ws") + + core.AssertNotEqual(t, left.ID, right.ID) + core.AssertEqual(t, "ws", right.Transport) +} + +func TestAX7_Pipe_Good(t *core.T) { + source := stream.NewHub() + destination := stream.NewHub() + + stop := Pipe(source, destination) + core.AssertNotNil(t, stop) + stop() +} + +func TestAX7_Pipe_Bad(t *core.T) { + stop := Pipe(nil, stream.NewHub()) + + core.AssertNotNil(t, stop) + core.AssertNotPanics(t, stop) +} + +func TestAX7_Pipe_Ugly(t *core.T) { + hub := stream.NewHub() + stop := Pipe(hub, hub) + + core.AssertNotNil(t, stop) + core.AssertNotPanics(t, stop) +} + +func TestAX7_NewRedisBridge_Good(t *core.T) { + redisServer := miniredis.RunT(t) + bridge, err := NewRedisBridge(stream.NewHub(), adapterredis.Config{Addr: redisServer.Addr(), Prefix: "pool"}) + + core.AssertNoError(t, err) + core.AssertNotNil(t, bridge) + core.AssertNotEmpty(t, bridge.SourceID()) +} + +func TestAX7_NewRedisBridge_Bad(t *core.T) { + redisServer := miniredis.RunT(t) + bridge, err := NewRedisBridge(nil, adapterredis.Config{Addr: redisServer.Addr(), Prefix: "pool"}) + + core.AssertError(t, err) + core.AssertNil(t, bridge) +} + +func TestAX7_NewRedisBridge_Ugly(t *core.T) { + bridge, err := NewRedisBridge(stream.NewHub(), adapterredis.Config{}) + + core.AssertError(t, err) + core.AssertNil(t, bridge) +} + +func TestAX7_NewReconnectingClient_Good(t *core.T) { + client := NewReconnectingClient(ReconnectConfig{URL: "ws://127.0.0.1/stream/ws"}) + + core.AssertNotNil(t, client) + core.AssertEqual(t, stream.StateDisconnected, client.State()) + core.AssertNoError(t, client.Close()) +} + +func TestAX7_NewReconnectingClient_Bad(t *core.T) { + client := NewReconnectingClient(ReconnectConfig{}) + + core.AssertEqual(t, 500*core.Millisecond, client.config.InitialBackoff) + core.AssertEqual(t, 30*core.Second, client.config.MaxBackoff) + core.AssertEqual(t, 2.0, client.config.BackoffMultiplier) +} + +func TestAX7_NewReconnectingClient_Ugly(t *core.T) { + client := NewReconnectingClient(ReconnectConfig{InitialBackoff: core.Millisecond, MaxBackoff: core.Second, BackoffMultiplier: 3}) + + core.AssertEqual(t, core.Millisecond, client.config.InitialBackoff) + core.AssertEqual(t, core.Second, client.config.MaxBackoff) + core.AssertEqual(t, 3.0, client.config.BackoffMultiplier) +} + +func TestAX7_ReconnectingClient_Connect_Good(t *core.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(*core.Request) bool { return true }} + connected := make(chan struct{}, 1) + server := core.NewHTTPTestServer(core.HandlerFunc(func(w core.ResponseWriter, r *core.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err == nil { + defer conn.Close() + <-r.Context().Done() + } + })) + defer server.Close() + ctx, cancel := core.WithCancel(core.Background()) + defer cancel() + client := NewReconnectingClient(ReconnectConfig{ + URL: "ws" + server.URL[len("http"):], + OnConnect: func() { connected <- struct{}{} }, + }) + errs := make(chan error, 1) + + go func() { errs <- client.Connect(ctx) }() + <-connected + core.AssertEqual(t, stream.StateConnected, client.State()) + core.AssertNoError(t, client.Close()) + cancel() + core.AssertNoError(t, <-errs) +} + +func TestAX7_ReconnectingClient_Connect_Bad(t *core.T) { + var client *ReconnectingClient + + err := client.Connect(core.Background()) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "nil reconnecting client") +} + +func TestAX7_ReconnectingClient_Connect_Ugly(t *core.T) { + client := NewReconnectingClient(ReconnectConfig{URL: "://bad-url", MaxRetries: 1, InitialBackoff: core.Millisecond}) + + err := client.Connect(core.Background()) + core.AssertError(t, err) + core.AssertEqual(t, stream.StateDisconnected, client.State()) +} + +func TestAX7_ReconnectingClient_Send_Good(t *core.T) { + clientConn, serverConn, cleanup := ax7WebSocketPair(t) + defer cleanup() + client := NewReconnectingClient(ReconnectConfig{}) + client.mutex.Lock() + client.conn = clientConn + client.state = stream.StateConnected + client.mutex.Unlock() + + core.AssertNoError(t, client.Send(stream.Message{Type: stream.TypePing, Channel: "health"})) + _, payload, err := serverConn.ReadMessage() + core.AssertNoError(t, err) + core.AssertContains(t, string(payload), `"type":"ping"`) +} + +func TestAX7_ReconnectingClient_Send_Bad(t *core.T) { + var client *ReconnectingClient + + err := client.Send(stream.Message{Type: stream.TypePing}) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "nil reconnecting client") +} + +func TestAX7_ReconnectingClient_Send_Ugly(t *core.T) { + client := NewReconnectingClient(ReconnectConfig{}) + + err := client.Send(stream.Message{Type: stream.TypePing}) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "not connected") +} + +func TestAX7_ReconnectingClient_State_Good(t *core.T) { + client := NewReconnectingClient(ReconnectConfig{}) + client.mutex.Lock() + client.state = stream.StateConnected + client.mutex.Unlock() + + core.AssertEqual(t, stream.StateConnected, client.State()) + core.AssertNoError(t, client.Close()) +} + +func TestAX7_ReconnectingClient_State_Bad(t *core.T) { + var client *ReconnectingClient + + core.AssertEqual(t, stream.StateDisconnected, client.State()) + core.AssertNil(t, client) +} + +func TestAX7_ReconnectingClient_State_Ugly(t *core.T) { + client := NewReconnectingClient(ReconnectConfig{}) + + core.AssertNoError(t, client.Close()) + core.AssertEqual(t, stream.StateDisconnected, client.State()) + core.AssertTrue(t, client.closed) +} + +func TestAX7_ReconnectingClient_Close_Good(t *core.T) { + clientConn, _, cleanup := ax7WebSocketPair(t) + defer cleanup() + client := NewReconnectingClient(ReconnectConfig{}) + client.mutex.Lock() + client.conn = clientConn + client.state = stream.StateConnected + client.mutex.Unlock() + + core.AssertNoError(t, client.Close()) + core.AssertEqual(t, stream.StateDisconnected, client.State()) + core.AssertTrue(t, client.closed) +} + +func TestAX7_ReconnectingClient_Close_Bad(t *core.T) { + var client *ReconnectingClient + + core.AssertNoError(t, client.Close()) + core.AssertNil(t, client) +} + +func TestAX7_ReconnectingClient_Close_Ugly(t *core.T) { + client := NewReconnectingClient(ReconnectConfig{}) + + core.AssertNoError(t, client.Close()) + core.AssertNoError(t, client.Close()) + core.AssertEqual(t, stream.StateDisconnected, client.State()) +} diff --git a/adapter/ws/compat.go b/adapter/ws/compat.go new file mode 100644 index 0000000..a016354 --- /dev/null +++ b/adapter/ws/compat.go @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: EUPL-1.2 + +// Package ws preserves the legacy go-ws compatibility surface while the new +// transport-agnostic stream package does the actual work. +package ws + +import ( + "dappco.re/go/stream" + "dappco.re/go/stream/adapter/redis" +) + +// Stream preserves the transport-agnostic stream interface for legacy callers. +type Stream = stream.Stream + +// Frame preserves the legacy raw payload alias. +type Frame = stream.Frame + +// Channel preserves the legacy channel name alias. +type Channel = stream.Channel + +// Hub preserves the legacy go-ws Hub type name. +type Hub = stream.Hub + +// HubConfig preserves the legacy go-ws HubConfig type name. +type HubConfig = stream.HubConfig + +// ChannelAuthoriser preserves the legacy go-ws channel authoriser type name. +type ChannelAuthoriser = stream.ChannelAuthoriser + +// HubStats preserves the legacy hub stats type name. +type HubStats = stream.HubStats + +// Peer preserves the transport-agnostic peer type under the legacy package. +type Peer = stream.Peer + +// Client preserves the legacy go-ws Client type name. +type Client = stream.Peer + +// Authenticator preserves the legacy go-ws Authenticator type name. +type Authenticator = stream.Authenticator + +// AuthenticatorFunc preserves the legacy go-ws AuthenticatorFunc helper. +type AuthenticatorFunc = stream.AuthenticatorFunc + +// AuthResult preserves the legacy go-ws AuthResult type name. +type AuthResult = stream.AuthResult + +// APIKeyAuthenticator preserves the legacy API key authenticator type name. +type APIKeyAuthenticator = stream.APIKeyAuthenticator + +// BearerTokenAuth preserves the legacy bearer-token authenticator type name. +type BearerTokenAuth = stream.BearerTokenAuth + +// QueryTokenAuth preserves the legacy query-token authenticator type name. +type QueryTokenAuth = stream.QueryTokenAuth + +// ConnAuthenticator preserves the legacy raw-connection authenticator name. +type ConnAuthenticator = stream.ConnAuthenticator + +// ConnAuthenticatorFunc preserves the legacy raw-connection helper name. +type ConnAuthenticatorFunc = stream.ConnAuthenticatorFunc + +// ConnectionState preserves the reconnecting client connection state type. +type ConnectionState = stream.ConnectionState + +// Message preserves the legacy go-ws WebSocket message envelope. +type Message = stream.Message + +// MessageType preserves the legacy go-ws message type name. +type MessageType = stream.MessageType + +const ( + // TypeProcessOutput preserves the legacy message type constant. + TypeProcessOutput = stream.TypeProcessOutput + // TypeProcessStatus preserves the legacy message type constant. + TypeProcessStatus = stream.TypeProcessStatus + // TypeEvent preserves the legacy message type constant. + TypeEvent = stream.TypeEvent + // TypeError preserves the legacy message type constant. + TypeError = stream.TypeError + // TypePing preserves the legacy message type constant. + TypePing = stream.TypePing + // TypePong preserves the legacy message type constant. + TypePong = stream.TypePong + // TypeSubscribe preserves the legacy message type constant. + TypeSubscribe = stream.TypeSubscribe + // TypeUnsubscribe preserves the legacy message type constant. + TypeUnsubscribe = stream.TypeUnsubscribe + // StateDisconnected preserves the reconnecting client disconnected state. + StateDisconnected = stream.StateDisconnected + // StateConnecting preserves the reconnecting client connecting state. + StateConnecting = stream.StateConnecting + // StateConnected preserves the reconnecting client connected state. + StateConnected = stream.StateConnected +) + +var ( + // ErrMissingAuthHeader preserves the legacy missing-header sentinel error. + ErrMissingAuthHeader = stream.ErrMissingAuthHeader + // ErrMalformedAuthHeader preserves the legacy malformed-header sentinel error. + ErrMalformedAuthHeader = stream.ErrMalformedAuthHeader + // ErrInvalidAPIKey preserves the legacy invalid API key sentinel error. + ErrInvalidAPIKey = stream.ErrInvalidAPIKey + // ErrHandshakeTimeout preserves the legacy handshake timeout sentinel error. + ErrHandshakeTimeout = stream.ErrHandshakeTimeout + // ErrAuthRejected preserves the legacy authenticator rejection sentinel error. + ErrAuthRejected = stream.ErrAuthRejected + // ErrHubNotRunning preserves the legacy hub lifecycle sentinel error. + ErrHubNotRunning = stream.ErrHubNotRunning + // ErrEmptyChannel preserves the legacy empty-channel sentinel error. + ErrEmptyChannel = stream.ErrEmptyChannel +) + +// RedisBridge preserves the legacy go-ws RedisBridge type name. +type RedisBridge = redis.Bridge + +// bridge, err := ws.NewRedisBridge(hub, redis.Config{Addr: "redis:6379", Prefix: "pool"}) +func NewRedisBridge(hub *stream.Hub, config redis.Config) (*RedisBridge, error) { + return redis.NewBridge(hub, config) +} + +// auth := ws.NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) +func NewAPIKeyAuth(keys map[string]string) *APIKeyAuthenticator { + return stream.NewAPIKeyAuth(keys) +} + +// hub := ws.NewHub() +func NewHub() *Hub { + return stream.NewHub() +} + +// hub := ws.NewHubWithConfig(stream.HubConfig{HeartbeatInterval: 30 * time.Second}) +func NewHubWithConfig(config HubConfig) *Hub { + return stream.NewHubWithConfig(config) +} + +// config := ws.DefaultHubConfig() +func DefaultHubConfig() HubConfig { + return stream.DefaultHubConfig() +} + +// peer := ws.NewPeer("ws") +func NewPeer(transport string) *Peer { + return stream.NewPeer(transport) +} + +// stop := ws.Pipe(sourceHub, destinationHub) +func Pipe(source Stream, destination Stream) func() { + return stream.Pipe(source, destination) +} diff --git a/adapter/ws/compat_test.go b/adapter/ws/compat_test.go new file mode 100644 index 0000000..e105f6f --- /dev/null +++ b/adapter/ws/compat_test.go @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package ws + +import ( + "context" + "testing" + "time" +) + +func TestCompat_LegacySurface_Good(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{"valid-key": "user-1"}) + if auth == nil { + t.Fatal("NewAPIKeyAuth() = nil") + } + + var frame Frame = []byte("payload") + if string(frame) != "payload" { + t.Fatalf("Frame alias produced %q, want %q", string(frame), "payload") + } + + var channel Channel = "hashrate" + if channel != "hashrate" { + t.Fatalf("Channel alias produced %q, want %q", channel, "hashrate") + } + + var authoriser ChannelAuthoriser + if authoriser != nil { + t.Fatal("ChannelAuthoriser alias should default to nil") + } + + if StateDisconnected != 0 || StateConnecting != 1 || StateConnected != 2 { + t.Fatalf("unexpected connection states: %d %d %d", StateDisconnected, StateConnecting, StateConnected) + } + + if ErrMissingAuthHeader == nil || ErrMalformedAuthHeader == nil || ErrInvalidAPIKey == nil { + t.Fatal("expected auth sentinel errors to be re-exported") + } + if ErrHandshakeTimeout == nil || ErrAuthRejected == nil || ErrHubNotRunning == nil || ErrEmptyChannel == nil { + t.Fatal("expected transport sentinel errors to be re-exported") + } + + sourceHub := NewHub() + destinationHub := NewHub() + + sourceContext, sourceCancel := context.WithCancel(context.Background()) + defer sourceCancel() + destinationContext, destinationCancel := context.WithCancel(context.Background()) + defer destinationCancel() + + go sourceHub.Run(sourceContext) + go destinationHub.Run(destinationContext) + waitForRunningHub(t, sourceHub) + waitForRunningHub(t, destinationHub) + + received := make(chan []byte, 1) + unsubscribe := destinationHub.Subscribe("hashrate", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + stop := Pipe(sourceHub, destinationHub) + defer stop() + + if err := sourceHub.Publish("hashrate", []byte("123456")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "123456" { + t.Fatalf("received frame = %q, want %q", string(frame), "123456") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for piped frame") + } + + peer := NewPeer("ws") + if peer == nil { + t.Fatal("NewPeer() = nil") + } + if peer.Transport != "ws" { + t.Fatalf("peer.Transport = %q, want %q", peer.Transport, "ws") + } + + stats := destinationHub.Stats() + var _ HubStats = stats +} + +func TestCompat_LegacySurface_Bad(t *testing.T) { + hub := NewHub() + + if err := hub.Publish("hashrate", []byte("123456")); err != ErrHubNotRunning { + t.Fatalf("Publish() error = %v, want %v", err, ErrHubNotRunning) + } + + peer := NewPeer("ws") + if err := hub.SubscribePeer(peer, ""); err != ErrEmptyChannel { + t.Fatalf("SubscribePeer() error = %v, want %v", err, ErrEmptyChannel) + } +} + +func TestCompat_LegacySurface_Ugly(t *testing.T) { + var source Stream + stop := Pipe(source, source) + if stop == nil { + t.Fatal("Pipe(nil, nil) returned nil stop function") + } + stop() +} + +func waitForRunningHub(t *testing.T, hub *Hub) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if hub.Publish("health", nil) == nil { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("timed out waiting for hub to start") +} diff --git a/adapter/ws/example_test.go b/adapter/ws/example_test.go new file mode 100644 index 0000000..5cf941a --- /dev/null +++ b/adapter/ws/example_test.go @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package ws_test + +import ( + "context" + "net/http" + + "dappco.re/go/stream" + "dappco.re/go/stream/adapter/ws" +) + +func ExampleAdapter_Handler() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hub := stream.NewHub() + go hub.Run(ctx) + + adapter := ws.New(ws.Config{ + Authenticator: stream.NewAPIKeyAuth(map[string]string{ + "sk-live": "user-42", + }), + }) + adapter.Mount(hub) + + http.Handle("/stream/ws", adapter.Handler()) +} diff --git a/adapter/ws/reconnect.go b/adapter/ws/reconnect.go index e2c3cb2..48fb785 100644 --- a/adapter/ws/reconnect.go +++ b/adapter/ws/reconnect.go @@ -4,18 +4,24 @@ package ws import ( "context" - "errors" "net/http" "sync" "time" "github.com/gorilla/websocket" - "dappco.re/go/core" + "dappco.re/go" "dappco.re/go/stream" ) -// ReconnectConfig configures the client-side reconnecting WebSocket. +// config := ws.ReconnectConfig{ +// URL: "ws://127.0.0.1:8080/stream/ws", +// OnMessage: func(message stream.Message) { +// _ = message.Channel +// }, +// } +// +// client := ws.NewReconnectingClient(config) type ReconnectConfig struct { URL string InitialBackoff time.Duration @@ -30,17 +36,18 @@ type ReconnectConfig struct { Headers http.Header } -// ReconnectingClient is a WebSocket client with automatic reconnection. +// client := ws.NewReconnectingClient(ws.ReconnectConfig{URL: "ws://127.0.0.1:8080/stream/ws"}) +// _ = client.Connect(context.Background()) type ReconnectingClient struct { config ReconnectConfig state stream.ConnectionState - mu sync.RWMutex + mutex sync.RWMutex conn *websocket.Conn closed bool } -// NewReconnectingClient creates a reconnecting WebSocket client. +// client := ws.NewReconnectingClient(ws.ReconnectConfig{URL: "ws://localhost:8080/stream/ws"}) func NewReconnectingClient(config ReconnectConfig) *ReconnectingClient { if config.InitialBackoff == 0 { config.InitialBackoff = 500 * time.Millisecond @@ -54,69 +61,80 @@ func NewReconnectingClient(config ReconnectConfig) *ReconnectingClient { return &ReconnectingClient{config: config, state: stream.StateDisconnected} } -// Connect starts the connection loop. Blocks until ctx is cancelled. -func (rc *ReconnectingClient) Connect(ctx context.Context) error { - if rc == nil { - return errors.New("nil reconnecting client") +// client := ws.NewReconnectingClient(ws.ReconnectConfig{URL: "ws://127.0.0.1:8080/stream/ws"}) +// err := client.Connect(ctx) +func (client *ReconnectingClient) Connect(ctx context.Context) error { + if client == nil { + return core.E("stream.ws", "nil reconnecting client", nil) } if ctx == nil { ctx = context.Background() } - dialer := rc.config.Dialer + dialer := client.config.Dialer if dialer == nil { dialer = websocket.DefaultDialer } - backoff := rc.config.InitialBackoff + backoff := client.config.InitialBackoff attempt := 0 for { - if rc.isClosed() || ctx.Err() != nil { + if client.isClosed() || ctx.Err() != nil { return nil } - rc.setState(stream.StateConnecting) + client.setState(stream.StateConnecting) - conn, _, err := dialer.DialContext(ctx, rc.config.URL, rc.config.Headers) + conn, _, err := dialer.DialContext(ctx, client.config.URL, client.config.Headers) if err != nil { attempt++ - rc.setState(stream.StateDisconnected) - if rc.config.MaxRetries > 0 && attempt > rc.config.MaxRetries { + client.setState(stream.StateDisconnected) + if client.config.MaxRetries > 0 && attempt > client.config.MaxRetries { return err } - if rc.config.OnReconnect != nil { - rc.config.OnReconnect(attempt) + if client.config.OnReconnect != nil { + client.config.OnReconnect(attempt) } if err := sleepContext(ctx, backoff); err != nil { return err } - backoff = nextBackoff(backoff, rc.config.BackoffMultiplier, rc.config.MaxBackoff) + backoff = nextBackoff(backoff, client.config.BackoffMultiplier, client.config.MaxBackoff) continue } - rc.mu.Lock() - rc.conn = conn - rc.state = stream.StateConnected - rc.mu.Unlock() - if rc.config.OnConnect != nil { - rc.config.OnConnect() + client.mutex.Lock() + client.conn = conn + client.state = stream.StateConnected + client.mutex.Unlock() + stopClose := context.AfterFunc(ctx, func() { + if err := conn.Close(); err != nil { + return + } + }) + backoff = client.config.InitialBackoff + attempt = 0 + if client.config.OnConnect != nil { + client.config.OnConnect() } - readErr := rc.readLoop(ctx, conn) + readErr := client.readLoop(ctx, conn) + stopClose() - rc.mu.Lock() - if rc.conn == conn { - rc.conn = nil + client.mutex.Lock() + if client.conn == conn { + client.conn = nil + } + client.state = stream.StateDisconnected + client.mutex.Unlock() + if err := conn.Close(); err != nil && readErr == nil { + readErr = err } - rc.state = stream.StateDisconnected - rc.mu.Unlock() - _ = conn.Close() - if rc.config.OnDisconnect != nil { - rc.config.OnDisconnect() + if client.config.OnDisconnect != nil { + client.config.OnDisconnect() } - if rc.isClosed() || ctx.Err() != nil { + if client.isClosed() || ctx.Err() != nil { return nil } if readErr == nil { @@ -124,23 +142,23 @@ func (rc *ReconnectingClient) Connect(ctx context.Context) error { } else { attempt++ } - if rc.config.MaxRetries > 0 && attempt > rc.config.MaxRetries { + if client.config.MaxRetries > 0 && attempt > client.config.MaxRetries { return readErr } - if rc.config.OnReconnect != nil { - rc.config.OnReconnect(attempt) + if client.config.OnReconnect != nil { + client.config.OnReconnect(attempt) } if err := sleepContext(ctx, backoff); err != nil { return err } - backoff = nextBackoff(backoff, rc.config.BackoffMultiplier, rc.config.MaxBackoff) + backoff = nextBackoff(backoff, client.config.BackoffMultiplier, client.config.MaxBackoff) } } -// Send marshals and sends a message through the WebSocket connection. -func (rc *ReconnectingClient) Send(msg stream.Message) error { - if rc == nil { - return errors.New("nil reconnecting client") +// _ = client.Send(stream.Message{Type: stream.TypeEvent, Channel: "hashrate", Data: map[string]any{"h": 1234567}}) +func (client *ReconnectingClient) Send(msg stream.Message) error { + if client == nil { + return core.E("stream.ws", "nil reconnecting client", nil) } if msg.Timestamp.IsZero() { msg.Timestamp = time.Now().UTC() @@ -153,48 +171,48 @@ func (rc *ReconnectingClient) Send(msg stream.Message) error { return core.E("stream.ws", "failed to marshal message", nil) } - rc.mu.RLock() - conn := rc.conn - rc.mu.RUnlock() + client.mutex.RLock() + conn := client.conn + client.mutex.RUnlock() if conn == nil { return core.E("stream.ws", "not connected", nil) } - rc.mu.Lock() - defer rc.mu.Unlock() - if rc.conn == nil { + client.mutex.Lock() + defer client.mutex.Unlock() + if client.conn == nil { return core.E("stream.ws", "not connected", nil) } - return rc.conn.WriteMessage(websocket.TextMessage, payload.Value.([]byte)) + return client.conn.WriteMessage(websocket.TextMessage, payload.Value.([]byte)) } -// State returns the current connection state. -func (rc *ReconnectingClient) State() stream.ConnectionState { - if rc == nil { +// state := client.State() +func (client *ReconnectingClient) State() stream.ConnectionState { + if client == nil { return stream.StateDisconnected } - rc.mu.RLock() - defer rc.mu.RUnlock() - return rc.state + client.mutex.RLock() + defer client.mutex.RUnlock() + return client.state } -// Close shuts down the reconnecting client. -func (rc *ReconnectingClient) Close() error { - if rc == nil { +// _ = client.Close() +func (client *ReconnectingClient) Close() error { + if client == nil { return nil } - rc.mu.Lock() - rc.closed = true - conn := rc.conn - rc.conn = nil - rc.state = stream.StateDisconnected - rc.mu.Unlock() + client.mutex.Lock() + client.closed = true + conn := client.conn + client.conn = nil + client.state = stream.StateDisconnected + client.mutex.Unlock() if conn != nil { return conn.Close() } return nil } -func (rc *ReconnectingClient) readLoop(ctx context.Context, conn *websocket.Conn) error { +func (client *ReconnectingClient) readLoop(ctx context.Context, conn *websocket.Conn) error { for { select { case <-ctx.Done(): @@ -212,22 +230,22 @@ func (rc *ReconnectingClient) readLoop(ctx context.Context, conn *websocket.Conn if !core.JSONUnmarshal(payload, &message).OK { continue } - if rc.config.OnMessage != nil { - rc.config.OnMessage(message) + if client.config.OnMessage != nil { + client.config.OnMessage(message) } } } -func (rc *ReconnectingClient) isClosed() bool { - rc.mu.RLock() - defer rc.mu.RUnlock() - return rc.closed +func (client *ReconnectingClient) isClosed() bool { + client.mutex.RLock() + defer client.mutex.RUnlock() + return client.closed } -func (rc *ReconnectingClient) setState(state stream.ConnectionState) { - rc.mu.Lock() - rc.state = state - rc.mu.Unlock() +func (client *ReconnectingClient) setState(state stream.ConnectionState) { + client.mutex.Lock() + client.state = state + client.mutex.Unlock() } func nextBackoff(current time.Duration, multiplier float64, maximum time.Duration) time.Duration { diff --git a/adapter/ws/ws.go b/adapter/ws/ws.go index da3d4ea..b0bd62b 100644 --- a/adapter/ws/ws.go +++ b/adapter/ws/ws.go @@ -1,61 +1,51 @@ // SPDX-License-Identifier: EUPL-1.2 -// Package ws is the WebSocket transport adapter for stream.Hub. -// It wires gorilla/websocket onto the hub, handling HTTP upgrade, -// per-client read/write pumps, and authentication. -// -// adapter := ws.New(ws.Config{Authenticator: auth}) -// adapter.Mount(hub) -// http.Handle("/stream/ws", adapter.Handler()) +// adapter := ws.New(ws.Config{Authenticator: auth}) +// adapter.Mount(hub) +// http.Handle("/stream/ws", adapter.Handler()) package ws import ( + "context" "net/http" "time" "github.com/gorilla/websocket" - "dappco.re/go/core" + "dappco.re/go" "dappco.re/go/stream" ) -// Config configures the WebSocket adapter. -// -// cfg := ws.Config{ +// config := ws.Config{ // Authenticator: stream.NewAPIKeyAuth(keys), // OnAuthFailure: func(r *http.Request, res stream.AuthResult) { -// log.Printf("ws auth fail from %s", r.RemoteAddr) +// core.Print("stream", "ws auth fail from %s", r.RemoteAddr) // }, // } type Config struct { - // Authenticator is called during HTTP upgrade. When nil, all connections accepted. + // ws.New(ws.Config{Authenticator: stream.NewAPIKeyAuth(keys)}) Authenticator stream.Authenticator - // OnAuthFailure is called when Authenticator rejects a connection. + // ws.New(ws.Config{OnAuthFailure: func(r *http.Request, result stream.AuthResult) { ... }}) OnAuthFailure func(r *http.Request, result stream.AuthResult) - // ReadBufferSize and WriteBufferSize are passed to the gorilla upgrader. - // Default: 1024 each. + // ws.New(ws.Config{ReadBufferSize: 1024, WriteBufferSize: 1024}) ReadBufferSize int WriteBufferSize int - // CheckOrigin overrides the upgrader's origin check. When nil, all origins accepted. + // ws.New(ws.Config{CheckOrigin: func(r *http.Request) bool { return true }}) CheckOrigin func(r *http.Request) bool } -// Adapter is the WebSocket transport adapter for a stream.Hub. -// -// adapter := ws.New(ws.Config{...}) -// adapter.Mount(hub) -// http.Handle("/ws", adapter.Handler()) +// adapter := ws.New(ws.Config{Authenticator: auth}) +// adapter.Mount(hub) +// http.Handle("/ws", adapter.Handler()) type Adapter struct { hub *stream.Hub config Config } -// New creates a WebSocket adapter. Call Mount before serving requests. -// -// adapter := ws.New(ws.Config{Authenticator: auth}) +// adapter := ws.New(ws.Config{Authenticator: auth}) func New(config Config) *Adapter { if config.ReadBufferSize == 0 { config.ReadBufferSize = 1024 @@ -66,99 +56,227 @@ func New(config Config) *Adapter { return &Adapter{config: config} } -// Mount wires the adapter to a hub. Must be called before Handler(). -// -// adapter.Mount(hub) -func (a *Adapter) Mount(hub *stream.Hub) { - a.hub = hub +// adapter.Mount(hub) +func (adapter *Adapter) Mount(hub *stream.Hub) { + adapter.hub = hub } -// Handler returns an http.HandlerFunc for WebSocket connections. -// Compatible with net/http and gin (use gin.WrapF). +// http.Handle("/stream/ws", adapter.Handler()) // -// http.Handle("/stream/ws", adapter.Handler()) +// Gin: +// r.GET("/stream/ws", gin.WrapF(adapter.Handler())) +func (adapter *Adapter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + adapter.serveHTTP(w, r, r.URL.Query()["channel"]) +} + +// HandlerForChannel returns a handler that auto-subscribes every connection to one channel. // -// // Gin: -// r.GET("/stream/ws", gin.WrapF(adapter.Handler())) -func (a *Adapter) Handler() http.HandlerFunc { +// http.Handle("/stream/hashrate", adapter.HandlerForChannel("hashrate")) +func (adapter *Adapter) HandlerForChannel(channel string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if a.hub == nil { - http.Error(w, "stream hub not mounted", http.StatusInternalServerError) + adapter.serveHTTP(w, r, []string{channel}) + } +} + +func (adapter *Adapter) serveHTTP(w http.ResponseWriter, r *http.Request, channels []string) { + if adapter.hub == nil { + http.Error(w, "stream hub not mounted", http.StatusInternalServerError) + return + } + + authResult := stream.AuthResult{Valid: true} + if adapter.config.Authenticator != nil { + authResult = adapter.config.Authenticator.Authenticate(r) + if !authResult.Valid { + if adapter.config.OnAuthFailure != nil { + adapter.config.OnAuthFailure(r, authResult) + } + http.Error(w, "unauthorised", http.StatusUnauthorized) return } + } - result := stream.AuthResult{Valid: true} - if a.config.Authenticator != nil { - result = a.config.Authenticator.Authenticate(r) - if !result.Valid { - if a.config.OnAuthFailure != nil { - a.config.OnAuthFailure(r, result) - } - http.Error(w, "unauthorised", http.StatusUnauthorized) - return + peer := stream.NewPeer("ws") + peer.UserID = authResult.UserID + if authResult.Claims != nil { + peer.Claims = authResult.Claims + } + for _, channel := range channels { + if channel == "" { + continue + } + if err := adapter.hub.CanSubscribePeer(peer, channel); err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + } + + if !adapter.hub.Running() { + http.Error(w, "stream hub not running", http.StatusInternalServerError) + return + } + + upgrader := websocket.Upgrader{ + ReadBufferSize: adapter.config.ReadBufferSize, + WriteBufferSize: adapter.config.WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { + if adapter.config.CheckOrigin != nil { + return adapter.config.CheckOrigin(r) } + return true + }, + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if err := adapter.hub.AddPeer(peer); err != nil { + if closeErr := conn.Close(); closeErr != nil { + return } + http.Error(w, "stream hub not running", http.StatusInternalServerError) + return + } + defer adapter.hub.RemovePeer(peer) - upgrader := websocket.Upgrader{ - ReadBufferSize: a.config.ReadBufferSize, - WriteBufferSize: a.config.WriteBufferSize, - CheckOrigin: func(r *http.Request) bool { - if a.config.CheckOrigin != nil { - return a.config.CheckOrigin(r) - } - return true - }, + peer.SetCloseHook(func() { + if err := conn.Close(); err != nil { + return + } + }) + for _, channel := range channels { + if channel == "" { + continue + } + if err := adapter.hub.SubscribePeer(peer, channel); err != nil { + peer.Close() + return } + } + defer conn.Close() + stopClose := context.AfterFunc(r.Context(), func() { + if err := conn.Close(); err != nil { + return + } + }) + defer stopClose() - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + hubConfig := adapter.hub.Config() + if hubConfig.PongTimeout > 0 { + if err := conn.SetReadDeadline(time.Now().Add(hubConfig.PongTimeout)); err != nil { + peer.Close() return } + conn.SetPongHandler(func(string) error { + return conn.SetReadDeadline(time.Now().Add(hubConfig.PongTimeout)) + }) + } - peer := stream.NewPeer("ws") - peer.UserID = result.UserID - peer.Claims = result.Claims - _ = a.hub.AddPeer(peer) - defer a.hub.RemovePeer(peer) - defer conn.Close() + go adapter.writePump(conn, peer, hubConfig.WriteTimeout, hubConfig.HeartbeatInterval) - go func() { - for frame := range peer.SendQueue() { - if err := conn.WriteMessage(websocket.TextMessage, frame); err != nil { + conn.SetReadLimit(1 << 20) + for { + messageType, payload, err := conn.ReadMessage() + if err != nil { + break + } + if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage { + continue + } + var message stream.Message + if !core.JSONUnmarshal(payload, &message).OK { + continue + } + switch message.Type { + case stream.TypeSubscribe: + if err := adapter.hub.SubscribePeer(peer, message.Channel); err != nil { + if ok := peer.Send(marshalMessage(stream.Message{ + Type: stream.TypeError, + Channel: message.Channel, + Data: errorPayload(err), + Timestamp: time.Now().UTC(), + })); !ok { return } } - }() - - conn.SetReadLimit(1 << 20) - for { - messageType, payload, err := conn.ReadMessage() - if err != nil { - break + case stream.TypeUnsubscribe: + adapter.hub.UnsubscribePeer(peer, message.Channel) + case stream.TypePing: + if ok := peer.Send([]byte(core.JSONMarshalString(stream.Message{ + Type: stream.TypePong, + Channel: message.Channel, + ProcessID: message.ProcessID, + Timestamp: time.Now().UTC(), + }))); !ok { + return } - if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage { + default: + if message.Channel == "" { + if err := adapter.hub.BroadcastFromPeer(peer, payload); err != nil { + return + } continue } - var message stream.Message - if !core.JSONUnmarshal(payload, &message).OK { - continue + if err := adapter.hub.PublishFromPeer(peer, message.Channel, payload); err != nil { + return } - switch message.Type { - case stream.TypeSubscribe: - _ = a.hub.SubscribePeer(peer, message.Channel) - case stream.TypeUnsubscribe: - a.hub.UnsubscribePeer(peer, message.Channel) - case stream.TypePing: - _ = peer.Send([]byte(core.JSONMarshalString(stream.Message{ - Type: stream.TypePong, - Channel: message.Channel, - ProcessID: message.ProcessID, - Timestamp: time.Now().UTC(), - }))) + } + } + + peer.Close() +} + +// http.Handle("/stream/ws", adapter.Handler()) +// r.GET("/stream/ws", gin.WrapF(adapter.Handler())) +func (adapter *Adapter) Handler() http.HandlerFunc { + return adapter.ServeHTTP +} + +func (adapter *Adapter) writePump(conn *websocket.Conn, peer *stream.Peer, writeTimeout, heartbeatInterval time.Duration) { + var ticker *time.Ticker + var heartbeat <-chan time.Time + if heartbeatInterval > 0 { + ticker = time.NewTicker(heartbeatInterval) + defer ticker.Stop() + heartbeat = ticker.C + } + for { + select { + case <-heartbeat: + if writeTimeout > 0 { + if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + return + } + } + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + case frame, ok := <-peer.SendQueue(): + if !ok { + return + } + if writeTimeout > 0 { + if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + return + } + } + if err := conn.WriteMessage(websocket.TextMessage, frame); err != nil { + return } } + } +} + +func marshalMessage(message stream.Message) []byte { + return []byte(core.JSONMarshalString(message)) +} - peer.Close() +func errorPayload(err error) map[string]any { + if err == nil { + return nil } + return map[string]any{"message": err.Error()} } diff --git a/adapter/ws/ws_test.go b/adapter/ws/ws_test.go new file mode 100644 index 0000000..9e5be87 --- /dev/null +++ b/adapter/ws/ws_test.go @@ -0,0 +1,558 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package ws + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "sync/atomic" + "testing" + "time" + + "github.com/gorilla/websocket" + + "dappco.re/go" + "dappco.re/go/stream" +) + +func TestAX7_Adapter_Handler_Good(t *testing.T) { + hub := stream.NewHubWithConfig(stream.HubConfig{ + HeartbeatInterval: 20 * time.Millisecond, + PongTimeout: 100 * time.Millisecond, + WriteTimeout: 100 * time.Millisecond, + }) + + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + conn := dialWebSocket(t, server.URL, nil) + defer conn.Close() + + receivedPing := make(chan struct{}, 1) + receivedFrame := make(chan []byte, 1) + conn.SetPingHandler(func(appData string) error { + select { + case receivedPing <- struct{}{}: + default: + } + return conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(time.Second)) + }) + + done := make(chan struct{}) + go func() { + defer close(done) + for { + messageType, payload, err := conn.ReadMessage() + if err != nil { + return + } + if messageType == websocket.TextMessage { + receivedFrame <- append([]byte(nil), payload...) + } + } + }() + + if err := conn.WriteJSON(stream.Message{ + Type: stream.TypeSubscribe, + Channel: "hashrate", + }); err != nil { + t.Fatalf("WriteJSON() error = %v", err) + } + + waitForChannelSubscriberCount(t, hub, "hashrate", 1) + + if err := hub.Publish("hashrate", []byte("123456")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-receivedFrame: + if string(frame) != "123456" { + t.Fatalf("received frame = %q, want %q", string(frame), "123456") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for published frame") + } + + select { + case <-receivedPing: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for heartbeat ping") + } + + _ = conn.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for client reader to exit") + } +} + +func TestAX7_Adapter_Handler_Bad(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{ + Authenticator: stream.NewAPIKeyAuth(map[string]string{"valid-key": "user-1"}), + }) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + _, resp, err := websocket.DefaultDialer.Dial(websocketURL(server.URL), nil) + if err == nil { + t.Fatal("Dial() error = nil, want auth failure") + } + if resp == nil { + t.Fatal("Dial() response = nil, want 401 response") + } + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("StatusCode = %d, want %d", resp.StatusCode, http.StatusUnauthorized) + } +} + +func TestAdapter_Handler_UpgradeFailure_DoesNotRegisterPeer_Good(t *testing.T) { + var connectCount atomic.Int32 + hub := stream.NewHubWithConfig(stream.HubConfig{ + OnConnect: func(peer *stream.Peer) { + connectCount.Add(1) + }, + }) + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{ + CheckOrigin: func(r *http.Request) bool { + return false + }, + }) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + _, resp, err := websocket.DefaultDialer.Dial(websocketURL(server.URL), nil) + if err == nil { + t.Fatal("Dial() error = nil, want upgrade failure") + } + if resp == nil { + t.Fatal("Dial() response = nil, want handshake failure response") + } + if connectCount.Load() != 0 { + t.Fatalf("OnConnect invoked %d times, want %d", connectCount.Load(), 0) + } + waitForPeerCount(t, hub, 0) +} + +func TestAdapter_Handler_HubNotRunning_Bad(t *testing.T) { + adapter := New(Config{}) + adapter.Mount(stream.NewHub()) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + _, resp, err := websocket.DefaultDialer.Dial(websocketURL(server.URL), nil) + if err == nil { + t.Fatal("Dial() error = nil, want hub lifecycle failure") + } + if resp == nil { + t.Fatal("Dial() response = nil, want 500 response") + } + if resp.StatusCode != http.StatusInternalServerError { + t.Fatalf("StatusCode = %d, want %d", resp.StatusCode, http.StatusInternalServerError) + } +} + +func TestAdapter_Handler_QueryChannelAuthoriser_Bad(t *testing.T) { + hub := stream.NewHubWithConfig(stream.HubConfig{ + ChannelAuthoriser: func(peer *stream.Peer, channel string) bool { + return channel == "public" + }, + }) + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + _, resp, err := websocket.DefaultDialer.Dial(websocketURL(server.URL)+"?channel=private", nil) + if err == nil { + t.Fatal("Dial() error = nil, want forbidden response") + } + if resp == nil { + t.Fatal("Dial() response = nil, want 403 response") + } + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("StatusCode = %d, want %d", resp.StatusCode, http.StatusForbidden) + } + waitForPeerCount(t, hub, 0) +} + +func TestAX7_Adapter_Handler_Ugly(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + conn := dialWebSocket(t, server.URL, nil) + if err := conn.WriteJSON(stream.Message{ + Type: stream.TypeSubscribe, + Channel: "block", + }); err != nil { + t.Fatalf("WriteJSON() error = %v", err) + } + + if err := conn.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + + waitForPeerCount(t, hub, 0) +} + +func TestAX7_Adapter_ServeHTTP_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{}) + adapter.Mount(hub) + + server := httptest.NewServer(adapter) + defer server.Close() + + conn := dialWebSocket(t, server.URL, nil) + defer conn.Close() + + if err := conn.WriteJSON(stream.Message{ + Type: stream.TypeSubscribe, + Channel: "serve-http", + }); err != nil { + t.Fatalf("WriteJSON() error = %v", err) + } + + waitForChannelSubscriberCount(t, hub, "serve-http", 1) + + if err := hub.Publish("serve-http", []byte("ok")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + messageType, payload, err := conn.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage() error = %v", err) + } + if messageType != websocket.TextMessage { + t.Fatalf("messageType = %d, want %d", messageType, websocket.TextMessage) + } + if string(payload) != "ok" { + t.Fatalf("payload = %q, want %q", string(payload), "ok") + } +} + +func TestAX7_Adapter_HandlerForChannel_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.HandlerForChannel("hashrate"))) + defer server.Close() + + conn := dialWebSocket(t, server.URL, nil) + defer conn.Close() + + waitForChannelSubscriberCount(t, hub, "hashrate", 1) + + if err := hub.Publish("hashrate", []byte("123456")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + messageType, payload, err := conn.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage() error = %v", err) + } + if messageType != websocket.TextMessage { + t.Fatalf("messageType = %d, want %d", messageType, websocket.TextMessage) + } + if string(payload) != "123456" { + t.Fatalf("payload = %q, want %q", string(payload), "123456") + } +} + +func TestAdapter_Handler_InboundPublish_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + conn := dialWebSocket(t, server.URL, nil) + defer conn.Close() + + received := make(chan []byte, 1) + unsubscribe := hub.Subscribe("agent", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + message := stream.Message{ + Type: stream.TypeEvent, + Channel: "agent", + Data: map[string]any{"status": "ok"}, + Timestamp: time.Now().UTC(), + } + if err := conn.WriteJSON(message); err != nil { + t.Fatalf("WriteJSON() error = %v", err) + } + + select { + case frame := <-received: + var decoded stream.Message + if !core.JSONUnmarshal(frame, &decoded).OK { + t.Fatalf("received invalid JSON frame: %q", string(frame)) + } + if decoded.Type != stream.TypeEvent { + t.Fatalf("decoded.Type = %q, want %q", decoded.Type, stream.TypeEvent) + } + if decoded.Channel != "agent" { + t.Fatalf("decoded.Channel = %q, want %q", decoded.Channel, "agent") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for inbound websocket frame") + } +} + +func TestAdapter_Handler_InboundPublish_SelfDelivery_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + conn := dialWebSocket(t, server.URL, nil) + defer conn.Close() + + if err := conn.WriteJSON(stream.Message{ + Type: stream.TypeSubscribe, + Channel: "agent", + }); err != nil { + t.Fatalf("WriteJSON(subscribe) error = %v", err) + } + + waitForChannelSubscriberCount(t, hub, "agent", 1) + _ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + if err := conn.WriteJSON(stream.Message{ + Type: stream.TypeEvent, + Channel: "agent", + Data: map[string]any{"status": "ok"}, + Timestamp: time.Now().UTC(), + }); err != nil { + t.Fatalf("WriteJSON(event) error = %v", err) + } + + messageType, payload, err := conn.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage() error = %v", err) + } + if messageType != websocket.TextMessage { + t.Fatalf("messageType = %d, want %d", messageType, websocket.TextMessage) + } + + var decoded stream.Message + if !core.JSONUnmarshal(payload, &decoded).OK { + t.Fatalf("received invalid JSON frame: %q", string(payload)) + } + if decoded.Type != stream.TypeEvent { + t.Fatalf("decoded.Type = %q, want %q", decoded.Type, stream.TypeEvent) + } + if decoded.Channel != "agent" { + t.Fatalf("decoded.Channel = %q, want %q", decoded.Channel, "agent") + } + _ = conn.SetReadDeadline(time.Time{}) +} + +func TestAdapter_Handler_SubscribeDenied_Bad(t *testing.T) { + hub := stream.NewHubWithConfig(stream.HubConfig{ + ChannelAuthoriser: func(peer *stream.Peer, channel string) bool { + return channel == "public" + }, + }) + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + conn := dialWebSocket(t, server.URL, nil) + defer conn.Close() + + if err := conn.WriteJSON(stream.Message{ + Type: stream.TypeSubscribe, + Channel: "private", + }); err != nil { + t.Fatalf("WriteJSON() error = %v", err) + } + + messageType, payload, err := conn.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage() error = %v", err) + } + if messageType != websocket.TextMessage { + t.Fatalf("messageType = %d, want %d", messageType, websocket.TextMessage) + } + + var message stream.Message + if !core.JSONUnmarshal(payload, &message).OK { + t.Fatalf("JSONUnmarshal() failed for payload: %q", string(payload)) + } + if message.Type != stream.TypeError { + t.Fatalf("message.Type = %q, want %q", message.Type, stream.TypeError) + } + if message.Channel != "private" { + t.Fatalf("message.Channel = %q, want %q", message.Channel, "private") + } + if hub.ChannelSubscriberCount("private") != 0 { + t.Fatalf("ChannelSubscriberCount(%q) = %d, want %d", "private", hub.ChannelSubscriberCount("private"), 0) + } +} + +func TestAdapter_Handler_PeerClose_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + adapter := New(Config{}) + adapter.Mount(hub) + + server := httptest.NewServer(http.HandlerFunc(adapter.Handler())) + defer server.Close() + + conn := dialWebSocket(t, server.URL, nil) + defer conn.Close() + + var peer *stream.Peer + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + for candidate := range hub.AllPeers() { + peer = candidate + break + } + if peer != nil { + break + } + time.Sleep(10 * time.Millisecond) + } + if peer == nil { + t.Fatal("timed out waiting for websocket peer") + } + + peer.Close() + + readDone := make(chan error, 1) + go func() { + _, _, err := conn.ReadMessage() + readDone <- err + }() + + select { + case err := <-readDone: + if err == nil { + t.Fatal("ReadMessage() error = nil, want closed websocket") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for peer close to close websocket") + } + + waitForPeerCount(t, hub, 0) +} + +func dialWebSocket(t *testing.T, serverURL string, header http.Header) *websocket.Conn { + t.Helper() + conn, resp, err := websocket.DefaultDialer.Dial(websocketURL(serverURL), header) + if err != nil { + if resp != nil { + t.Fatalf("Dial() error = %v, status = %s", err, resp.Status) + } + t.Fatalf("Dial() error = %v", err) + } + return conn +} + +func websocketURL(serverURL string) string { + parsed, err := url.Parse(serverURL) + if err != nil { + return serverURL + } + switch parsed.Scheme { + case "http": + parsed.Scheme = "ws" + case "https": + parsed.Scheme = "wss" + } + return parsed.String() +} + +func waitForPeerCount(t *testing.T, hub *stream.Hub, expected int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if hub.PeerCount() == expected { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("PeerCount() = %d, want %d", hub.PeerCount(), expected) +} + +func waitForChannelSubscriberCount(t *testing.T, hub *stream.Hub, channel string, expected int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if hub.ChannelSubscriberCount(channel) == expected { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("ChannelSubscriberCount(%q) = %d, want %d", channel, hub.ChannelSubscriberCount(channel), expected) +} diff --git a/adapter/zmq/ax7_more_test.go b/adapter/zmq/ax7_more_test.go new file mode 100644 index 0000000..f24f59c --- /dev/null +++ b/adapter/zmq/ax7_more_test.go @@ -0,0 +1,163 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package zmq + +import ( + core "dappco.re/go" + "dappco.re/go/stream" +) + +func TestAX7_Mode_String_Good(t *core.T) { + core.AssertEqual(t, "pubsub", ModePubSub.String()) + core.AssertEqual(t, "pushpull", ModePushPull.String()) + core.AssertNotEqual(t, ModePubSub.String(), ModePushPull.String()) +} + +func TestAX7_Mode_String_Bad(t *core.T) { + mode := Mode(99) + + core.AssertEqual(t, "unknown", mode.String()) + core.AssertNotEqual(t, "pubsub", mode.String()) +} + +func TestAX7_Mode_String_Ugly(t *core.T) { + mode := Mode(-1) + + core.AssertEqual(t, "unknown", mode.String()) + core.AssertNotPanics(t, func() { _ = mode.String() }) +} + +func TestAX7_Role_String_Good(t *core.T) { + core.AssertEqual(t, "publisher", RolePublisher.String()) + core.AssertEqual(t, "subscriber", RoleSubscriber.String()) + core.AssertEqual(t, "pusher", RolePusher.String()) + core.AssertEqual(t, "puller", RolePuller.String()) +} + +func TestAX7_Role_String_Bad(t *core.T) { + role := Role(99) + + core.AssertEqual(t, "unknown", role.String()) + core.AssertNotEqual(t, "publisher", role.String()) +} + +func TestAX7_Role_String_Ugly(t *core.T) { + role := Role(-1) + + core.AssertEqual(t, "unknown", role.String()) + core.AssertNotPanics(t, func() { _ = role.String() }) +} + +func TestAX7_New_Good(t *core.T) { + adapter := New(Config{Mode: ModePubSub, Endpoint: "tcp://127.0.0.1:1", Role: RolePublisher}) + + core.AssertNotNil(t, adapter) + core.AssertEqual(t, ModePubSub, adapter.config.Mode) + core.AssertEqual(t, 5*core.Second, adapter.config.HandshakeTimeout) +} + +func TestAX7_New_Bad(t *core.T) { + adapter := New(Config{}) + + core.AssertNotNil(t, adapter) + core.AssertEqual(t, ModePubSub, adapter.config.Mode) + core.AssertEqual(t, "", adapter.config.Endpoint) +} + +func TestAX7_New_Ugly(t *core.T) { + adapter := New(Config{HandshakeTimeout: core.Millisecond, Topics: []string{"block"}}) + + core.AssertEqual(t, core.Millisecond, adapter.config.HandshakeTimeout) + core.AssertEqual(t, []string{"block"}, adapter.config.Topics) +} + +func TestAX7_Adapter_Mount_Good(t *core.T) { + adapter := New(Config{}) + hub := stream.NewHub() + + adapter.Mount(hub) + core.AssertEqual(t, hub, adapter.hub) + core.AssertFalse(t, adapter.running) +} + +func TestAX7_Adapter_Mount_Bad(t *core.T) { + adapter := New(Config{}) + + adapter.Mount(nil) + core.AssertNil(t, adapter.hub) + core.AssertNotNil(t, adapter) +} + +func TestAX7_Adapter_Mount_Ugly(t *core.T) { + adapter := New(Config{}) + first := stream.NewHub() + second := stream.NewHub() + + adapter.Mount(first) + adapter.Mount(second) + core.AssertEqual(t, second, adapter.hub) +} + +func TestAX7_Adapter_Start_Good(t *core.T) { + hub := stream.NewHub() + ctx, cancel := core.WithCancel(core.Background()) + defer cancel() + go hub.Run(ctx) + adapter := New(Config{Mode: ModePubSub, Endpoint: randomTCPEndpoint(t), Role: RolePublisher}) + adapter.Mount(hub) + + go func() { + if err := adapter.Start(ctx); err != nil { + t.Errorf("Start() error = %v", err) + } + }() + waitForAdapterRunning(t, adapter) + core.AssertTrue(t, adapter.running) +} + +func TestAX7_Adapter_Start_Bad(t *core.T) { + adapter := New(Config{Mode: Mode(99), Endpoint: randomTCPEndpoint(t), Role: RolePublisher}) + adapter.Mount(stream.NewHub()) + + err := adapter.Start(core.Background()) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "invalid mode") +} + +func TestAX7_Adapter_Stop_Good(t *core.T) { + hub := stream.NewHub() + ctx, cancel := core.WithCancel(core.Background()) + defer cancel() + go hub.Run(ctx) + adapter := New(Config{Mode: ModePubSub, Endpoint: randomTCPEndpoint(t), Role: RolePublisher}) + adapter.Mount(hub) + go func() { _ = adapter.Start(ctx) }() + waitForAdapterRunning(t, adapter) + + core.AssertNoError(t, adapter.Stop()) + core.Sleep(50 * core.Millisecond) + core.AssertFalse(t, adapter.running) +} + +func TestAX7_Adapter_Stop_Bad(t *core.T) { + var adapter *Adapter + + core.AssertNoError(t, adapter.Stop()) + core.AssertNil(t, adapter) +} + +func TestAX7_Adapter_Stop_Ugly(t *core.T) { + adapter := New(Config{Mode: ModePubSub, Endpoint: randomTCPEndpoint(t), Role: RolePublisher}) + + core.AssertNoError(t, adapter.Stop()) + core.AssertFalse(t, adapter.running) +} + +func TestAX7_Adapter_Publish_Ugly(t *core.T) { + adapter := New(Config{Mode: ModePubSub, Endpoint: randomTCPEndpoint(t), Role: RolePublisher}) + adapter.Mount(stream.NewHub()) + + err := adapter.Publish("block", []byte("template")) + core.AssertError(t, err) + core.AssertContains(t, err.Error(), "not started") +} diff --git a/adapter/zmq/example_test.go b/adapter/zmq/example_test.go new file mode 100644 index 0000000..d44cfce --- /dev/null +++ b/adapter/zmq/example_test.go @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package zmq_test + +import ( + "context" + + "dappco.re/go/stream" + "dappco.re/go/stream/adapter/zmq" +) + +func ExampleAdapter_Start() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hub := stream.NewHub() + go hub.Run(ctx) + + adapter := zmq.New(zmq.Config{ + Mode: zmq.ModePubSub, + Endpoint: "tcp://127.0.0.1:5555", + Role: zmq.RoleSubscriber, + Topics: []string{"block"}, + }) + adapter.Mount(hub) + + go func() { + _ = adapter.Start(ctx) + }() +} diff --git a/adapter/zmq/zmq.go b/adapter/zmq/zmq.go index 0840acd..e2a610d 100644 --- a/adapter/zmq/zmq.go +++ b/adapter/zmq/zmq.go @@ -1,19 +1,32 @@ // SPDX-License-Identifier: EUPL-1.2 -// Package zmq is the ZeroMQ transport adapter for stream.Hub. -// High-throughput IPC for daemon block notifications and inter-process job broadcasts. +// adapter := zmq.New(zmq.Config{ +// Mode: zmq.ModePubSub, +// Endpoint: "tcp://127.0.0.1:5555", +// Role: zmq.RoleSubscriber, +// }) +// +// adapter.Mount(hub) +// go adapter.Start(ctx) +// defer adapter.Stop() package zmq import ( "context" - "strconv" + "net" + "net/url" "sync" + "time" - "dappco.re/go/core" + "github.com/go-zeromq/zmq4" + + "dappco.re/go" "dappco.re/go/stream" ) -// Mode selects the ZMQ socket pattern. +const maxHandshakeFrameSize = 4 << 10 + +// mode := zmq.ModePubSub type Mode int const ( @@ -21,7 +34,19 @@ const ( ModePushPull ) -// Role is the ZMQ socket role. +// core.Print(nil, "mode=%s", zmq.ModePubSub.String()) +func (mode Mode) String() string { + switch mode { + case ModePubSub: + return "pubsub" + case ModePushPull: + return "pushpull" + default: + return "unknown" + } +} + +// role := zmq.RoleSubscriber type Role int const ( @@ -31,189 +56,379 @@ const ( RolePuller ) -// Config configures the ZMQ adapter. +// core.Print(nil, "role=%s", zmq.RoleSubscriber.String()) +func (role Role) String() string { + switch role { + case RolePublisher: + return "publisher" + case RoleSubscriber: + return "subscriber" + case RolePusher: + return "pusher" + case RolePuller: + return "puller" + default: + return "unknown" + } +} + +// config := zmq.Config{ +// Mode: zmq.ModePubSub, +// Endpoint: "tcp://127.0.0.1:5555", +// Role: zmq.RoleSubscriber, +// } type Config struct { Mode Mode Endpoint string Role Role Topics []string + + // ConnAuthenticator validates the first received frame before normal dispatch. + // When nil, the adapter accepts the connection without handshake validation. + ConnAuthenticator stream.ConnAuthenticator + + // HandshakeTimeout limits how long the adapter waits for the first frame when + // ConnAuthenticator is configured. Defaults to 5 seconds. + HandshakeTimeout time.Duration } -// Adapter is the ZMQ transport adapter. +// adapter := zmq.New(zmq.Config{Mode: zmq.ModePubSub, Endpoint: "tcp://127.0.0.1:5555", Role: zmq.RoleSubscriber}) type Adapter struct { hub *stream.Hub config Config - source string - mu sync.Mutex + mutex sync.RWMutex running bool - stopCh chan struct{} + socket zmq4.Socket + cancel context.CancelFunc } -type zmqRegistry struct { - mu sync.RWMutex - adapters map[string]map[*Adapter]struct{} -} - -var registry = zmqRegistry{adapters: map[string]map[*Adapter]struct{}{}} - -// New creates a ZMQ adapter. Call Mount and Start before use. +// adapter := zmq.New(zmq.Config{Mode: zmq.ModePubSub, Endpoint: "tcp://127.0.0.1:5555", Role: zmq.RoleSubscriber}) func New(config Config) *Adapter { - return &Adapter{config: config, source: stream.NewPeer("zmq").ID, stopCh: make(chan struct{})} + if config.HandshakeTimeout == 0 { + config.HandshakeTimeout = 5 * time.Second + } + return &Adapter{config: config} } -// Mount wires the adapter to a hub. -func (a *Adapter) Mount(hub *stream.Hub) { - a.hub = hub +// adapter.Mount(hub) +func (adapter *Adapter) Mount(hub *stream.Hub) { + adapter.hub = hub } -// Start opens the ZMQ socket and begins receive/dispatch. Blocks until ctx cancelled. -func (a *Adapter) Start(ctx context.Context) error { - if a == nil { +// go adapter.Start(ctx) +func (adapter *Adapter) Start(ctx context.Context) error { + if adapter == nil { return core.E("stream.zmq", "nil adapter", nil) } - if a.config.Endpoint == "" { + if ctx == nil { + ctx = context.Background() + } + if adapter.config.Endpoint == "" { return core.E("stream.zmq", "empty endpoint", nil) } - if a.hub == nil { + if adapter.hub == nil { return core.E("stream.zmq", "stream hub not mounted", nil) } + if err := adapter.validateRole(); err != nil { + return err + } + + runContext, runCancel := context.WithCancel(ctx) + socket, err := adapter.newSocket(runContext) + if err != nil { + runCancel() + return err + } + if err := adapter.connectSocket(socket); err != nil { + if closeErr := socket.Close(); closeErr != nil { + runCancel() + return core.ErrorJoin(err, closeErr) + } + runCancel() + return err + } - a.mu.Lock() - if a.running { - a.mu.Unlock() - <-ctx.Done() + adapter.mutex.Lock() + if adapter.running { + adapter.mutex.Unlock() + if err := socket.Close(); err != nil { + runCancel() + return err + } + runCancel() return nil } - a.running = true - stopCh := a.stopCh - key := a.registryKey() - a.mu.Unlock() + adapter.running = true + adapter.socket = socket + adapter.cancel = runCancel + adapter.mutex.Unlock() - registry.add(key, a) - defer registry.remove(key, a) + defer func() { + adapter.mutex.Lock() + adapter.running = false + adapter.socket = nil + adapter.cancel = nil + adapter.mutex.Unlock() + runCancel() + if err := socket.Close(); err != nil { + return + } + }() - select { - case <-ctx.Done(): - case <-stopCh: + if !adapter.isReceiver() { + peer := adapter.registerPeer(socket, stream.AuthResult{}) + if peer != nil { + defer adapter.hub.RemovePeer(peer) + } + <-runContext.Done() + return nil } - a.mu.Lock() - a.running = false - a.mu.Unlock() - return nil + authResult := stream.AuthResult{Valid: true} + if adapter.config.ConnAuthenticator != nil { + handshake, err := adapter.recvWithTimeout(runContext, socket, adapter.config.HandshakeTimeout) + if err != nil { + if err == context.Canceled { + return nil + } + return err + } + if len(handshake.Bytes()) > maxHandshakeFrameSize { + return stream.ErrAuthRejected + } + authResult = adapter.config.ConnAuthenticator.AuthenticateConn(handshake.Bytes()) + if !authResult.Valid { + return stream.ErrAuthRejected + } + } + peer := adapter.registerPeer(socket, authResult) + if peer != nil { + defer adapter.hub.RemovePeer(peer) + } + + for { + message, err := socket.Recv() + if err != nil { + if runContext.Err() != nil { + return nil + } + return err + } + + channel, frame, ok := decodeMessage(message) + if !ok { + continue + } + if channel == "" { + if err := adapter.hub.Broadcast(frame); err != nil { + return err + } + continue + } + if err := adapter.hub.Publish(channel, frame); err != nil { + return err + } + } +} + +func (adapter *Adapter) registerPeer(socket zmq4.Socket, authResult stream.AuthResult) *stream.Peer { + if adapter == nil || adapter.hub == nil { + return nil + } + peer := stream.NewPeer("zmq") + peer.UserID = authResult.UserID + if authResult.Claims != nil { + peer.Claims = authResult.Claims + } + if socket != nil { + peer.SetCloseHook(func() { + if err := socket.Close(); err != nil { + return + } + }) + } + if err := adapter.hub.AddPeer(peer); err != nil { + return nil + } + return peer } -// Publish sends frame with topic (channel name) via the ZMQ socket. -func (a *Adapter) Publish(channel string, frame []byte) error { - if a == nil { +// _ = adapter.Publish("block", templateBytes) +func (adapter *Adapter) Publish(channel string, frame []byte) error { + if adapter == nil { return core.E("stream.zmq", "nil adapter", nil) } - if a.config.Role != RolePublisher && a.config.Role != RolePusher { + if !adapter.isSender() { return core.E("stream.zmq", "publish not supported for this role", nil) } - if !a.isRunning() { + + adapter.mutex.RLock() + defer adapter.mutex.RUnlock() + if !adapter.running || adapter.socket == nil { return core.E("stream.zmq", "adapter not started", nil) } - registry.publish(a.registryKey(), message{ - SourceID: a.sourceID(), - Channel: channel, - Frame: append([]byte(nil), frame...), - }) - return nil + + return adapter.socket.Send(zmq4.NewMsg(encodeMessage(channel, frame))) } -// Stop shuts down the adapter. -func (a *Adapter) Stop() error { - if a == nil { +// defer adapter.Stop() +func (adapter *Adapter) Stop() error { + if adapter == nil { return nil } - a.mu.Lock() - if !a.running { - a.mu.Unlock() - return nil + + adapter.mutex.Lock() + cancel := adapter.cancel + socket := adapter.socket + adapter.running = false + adapter.cancel = nil + adapter.socket = nil + adapter.mutex.Unlock() + + if cancel != nil { + cancel() + } + if socket != nil { + return socket.Close() } - close(a.stopCh) - a.stopCh = make(chan struct{}) - a.mu.Unlock() return nil } -type message struct { - SourceID string - Channel string - Frame []byte +func (adapter *Adapter) validateRole() error { + switch adapter.config.Mode { + case ModePubSub: + if adapter.config.Role != RolePublisher && adapter.config.Role != RoleSubscriber { + return core.E("stream.zmq", "invalid pubsub role", nil) + } + case ModePushPull: + if adapter.config.Role != RolePusher && adapter.config.Role != RolePuller { + return core.E("stream.zmq", "invalid pushpull role", nil) + } + default: + return core.E("stream.zmq", "invalid mode", nil) + } + return nil +} + +func (adapter *Adapter) newSocket(ctx context.Context) (zmq4.Socket, error) { + switch adapter.config.Role { + case RolePublisher: + return zmq4.NewPub(ctx), nil + case RoleSubscriber: + socket := zmq4.NewSub(ctx) + topics := adapter.config.Topics + if len(topics) == 0 { + topics = []string{""} + } + for _, topic := range topics { + if err := socket.SetOption(zmq4.OptionSubscribe, topic); err != nil { + return nil, err + } + } + return socket, nil + case RolePusher: + return zmq4.NewPush(ctx), nil + case RolePuller: + return zmq4.NewPull(ctx), nil + default: + return nil, core.E("stream.zmq", "invalid role", nil) + } } -func (a *Adapter) registryKey() string { - return a.config.Endpoint + "|" + strconv.Itoa(int(a.config.Mode)) +func (adapter *Adapter) connectSocket(socket zmq4.Socket) error { + if adapter.shouldListen() { + return socket.Listen(listenEndpoint(adapter.config.Endpoint)) + } + return socket.Dial(adapter.config.Endpoint) } -func (a *Adapter) sourceID() string { - return a.source +func (adapter *Adapter) shouldListen() bool { + if adapter.config.Mode == ModePushPull { + return adapter.config.Role == RolePusher + } + return adapter.config.Role == RolePublisher } -func (a *Adapter) isRunning() bool { - a.mu.Lock() - defer a.mu.Unlock() - return a.running +func (adapter *Adapter) isSender() bool { + return adapter.config.Role == RolePublisher || adapter.config.Role == RolePusher } -func (r *zmqRegistry) add(key string, adapter *Adapter) { - r.mu.Lock() - defer r.mu.Unlock() - if r.adapters[key] == nil { - r.adapters[key] = map[*Adapter]struct{}{} - } - r.adapters[key][adapter] = struct{}{} +func (adapter *Adapter) isReceiver() bool { + return adapter.config.Role == RoleSubscriber || adapter.config.Role == RolePuller } -func (r *zmqRegistry) remove(key string, adapter *Adapter) { - r.mu.Lock() - defer r.mu.Unlock() - if adapters := r.adapters[key]; adapters != nil { - delete(adapters, adapter) - if len(adapters) == 0 { - delete(r.adapters, key) +func decodeMessage(message zmq4.Msg) (string, []byte, bool) { + payload := message.Bytes() + for index, value := range payload { + if value != 0 { + continue } + channel := string(payload[:index]) + frame := append([]byte(nil), payload[index+1:]...) + return channel, frame, true } + return "", nil, false } -func (r *zmqRegistry) publish(key string, message message) { - r.mu.RLock() - adapters := r.adapters[key] - targets := make([]*Adapter, 0, len(adapters)) - for adapter := range adapters { - targets = append(targets, adapter) +func (adapter *Adapter) recvWithTimeout(ctx context.Context, socket zmq4.Socket, timeout time.Duration) (zmq4.Msg, error) { + if timeout <= 0 { + msg, err := socket.Recv() + return msg, err } - r.mu.RUnlock() - for _, adapter := range targets { - if adapter == nil || adapter.sourceID() == message.SourceID { - continue - } - if adapter.config.Role != RoleSubscriber && adapter.config.Role != RolePuller { - continue - } - if len(adapter.config.Topics) > 0 && message.Channel != "" { - allowed := false - for _, topic := range adapter.config.Topics { - if topic == message.Channel { - allowed = true - break - } - } - if !allowed { - continue - } - } - if adapter.hub == nil { - continue + type result struct { + message zmq4.Msg + err error + } + + receive := make(chan result, 1) + go func() { + msg, err := socket.Recv() + receive <- result{message: msg, err: err} + }() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case <-ctx.Done(): + if err := socket.Close(); err != nil { + return zmq4.Msg{}, err } - if message.Channel == "" { - _ = adapter.hub.Broadcast(message.Frame) - continue + return zmq4.Msg{}, ctx.Err() + case outcome := <-receive: + return outcome.message, outcome.err + case <-timer.C: + if err := socket.Close(); err != nil { + return zmq4.Msg{}, err } - _ = adapter.hub.Publish(message.Channel, message.Frame) + return zmq4.Msg{}, stream.ErrHandshakeTimeout + } +} + +func encodeMessage(channel string, frame []byte) []byte { + output := make([]byte, 0, len(channel)+1+len(frame)) + output = append(output, []byte(channel)...) + output = append(output, 0) + output = append(output, frame...) + return output +} + +func listenEndpoint(endpoint string) string { + parsed, err := url.Parse(endpoint) + if err != nil || parsed.Scheme != "tcp" { + return endpoint + } + + host, port, err := net.SplitHostPort(parsed.Host) + if err != nil { + return endpoint + } + if host == "" || host == "*" { + return endpoint } + + parsed.Host = net.JoinHostPort("*", port) + return parsed.String() } diff --git a/adapter/zmq/zmq_test.go b/adapter/zmq/zmq_test.go new file mode 100644 index 0000000..c4f97e0 --- /dev/null +++ b/adapter/zmq/zmq_test.go @@ -0,0 +1,478 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package zmq + +import ( + "context" + "net" + "strconv" + "testing" + "time" + + "dappco.re/go/stream" +) + +func TestAX7_Adapter_Publish_Good(t *testing.T) { + publisherHub := stream.NewHub() + subscriberHub := stream.NewHub() + + publisherContext, publisherCancel := context.WithCancel(context.Background()) + defer publisherCancel() + subscriberContext, subscriberCancel := context.WithCancel(context.Background()) + defer subscriberCancel() + + go publisherHub.Run(publisherContext) + go subscriberHub.Run(subscriberContext) + + endpoint := randomTCPEndpoint(t) + publisher := New(Config{ + Mode: ModePubSub, + Endpoint: endpoint, + Role: RolePublisher, + }) + publisher.Mount(publisherHub) + + subscriber := New(Config{ + Mode: ModePubSub, + Endpoint: endpoint, + Role: RoleSubscriber, + Topics: []string{"block"}, + }) + subscriber.Mount(subscriberHub) + + runContext, runCancel := context.WithCancel(context.Background()) + defer runCancel() + go func() { _ = publisher.Start(runContext) }() + go func() { _ = subscriber.Start(runContext) }() + waitForAdapterRunning(t, publisher) + waitForAdapterRunning(t, subscriber) + + received := make(chan []byte, 1) + unsubscribe := subscriberHub.Subscribe("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if err := publisher.Publish("block", []byte("template")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + select { + case frame := <-received: + if string(frame) != "template" { + t.Fatalf("received frame = %q, want %q", string(frame), "template") + } + return + case <-time.After(100 * time.Millisecond): + } + } + t.Fatal("timed out waiting for zmq frame") +} + +func TestAX7_Adapter_Publish_Bad(t *testing.T) { + hub := stream.NewHub() + adapter := New(Config{ + Mode: ModePubSub, + Endpoint: randomTCPEndpoint(t), + Role: RoleSubscriber, + }) + adapter.Mount(hub) + + if err := adapter.Publish("block", []byte("template")); err == nil { + t.Fatal("Publish() error = nil, want publish not supported error") + } +} + +func TestAX7_Adapter_Start_Ugly(t *testing.T) { + pusherHub := stream.NewHub() + pullerHub := stream.NewHub() + + pusherContext, pusherCancel := context.WithCancel(context.Background()) + defer pusherCancel() + pullerContext, pullerCancel := context.WithCancel(context.Background()) + defer pullerCancel() + + go pusherHub.Run(pusherContext) + go pullerHub.Run(pullerContext) + + endpoint := randomTCPEndpoint(t) + puller := New(Config{ + Mode: ModePushPull, + Endpoint: endpoint, + Role: RolePuller, + }) + puller.Mount(pullerHub) + + pusher := New(Config{ + Mode: ModePushPull, + Endpoint: endpoint, + Role: RolePusher, + }) + pusher.Mount(pusherHub) + + runContext, runCancel := context.WithCancel(context.Background()) + defer runCancel() + go func() { _ = puller.Start(runContext) }() + go func() { _ = pusher.Start(runContext) }() + waitForAdapterRunning(t, puller) + waitForAdapterRunning(t, pusher) + + received := make(chan []byte, 1) + unsubscribe := pullerHub.Subscribe("job", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if err := pusher.Publish("job", []byte("work")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + select { + case frame := <-received: + if string(frame) != "work" { + t.Fatalf("received frame = %q, want %q", string(frame), "work") + } + return + case <-time.After(100 * time.Millisecond): + } + } + t.Fatal("timed out waiting for push/pull frame") +} + +func TestAdapter_Start_Auth_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + endpoint := randomTCPEndpoint(t) + subscriber := New(Config{ + Mode: ModePubSub, + Endpoint: endpoint, + Role: RoleSubscriber, + Topics: []string{"block"}, + ConnAuthenticator: stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { + if string(handshake) != "block\x00hello" { + return stream.AuthResult{Valid: false} + } + return stream.AuthResult{Valid: true} + }), + }) + subscriber.Mount(hub) + + publisher := New(Config{ + Mode: ModePubSub, + Endpoint: endpoint, + Role: RolePublisher, + }) + publisher.Mount(stream.NewHub()) + + runContext, runCancel := context.WithCancel(context.Background()) + defer runCancel() + go func() { _ = subscriber.Start(runContext) }() + go func() { _ = publisher.Start(runContext) }() + waitForAdapterRunning(t, subscriber) + waitForAdapterRunning(t, publisher) + + received := make(chan []byte, 1) + unsubscribe := hub.Subscribe("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if err := publisher.Publish("block", []byte("hello")); err != nil { + t.Fatalf("handshake Publish() error = %v", err) + } + if err := publisher.Publish("block", []byte("template")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "template" { + t.Fatalf("received frame = %q, want %q", string(frame), "template") + } + return + case <-time.After(100 * time.Millisecond): + } + } + + t.Fatal("timed out waiting for authenticated zmq frame") +} + +func TestAdapter_Start_RegistersPeer_Good(t *testing.T) { + connected := make(chan *stream.Peer, 1) + hub := stream.NewHubWithConfig(stream.HubConfig{ + OnConnect: func(peer *stream.Peer) { + select { + case connected <- peer: + default: + } + }, + }) + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + endpoint := randomTCPEndpoint(t) + subscriber := New(Config{ + Mode: ModePubSub, + Endpoint: endpoint, + Role: RoleSubscriber, + Topics: []string{"block"}, + ConnAuthenticator: stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { + if string(handshake) != "block\x00hello" { + return stream.AuthResult{Valid: false} + } + return stream.AuthResult{ + Valid: true, + UserID: "node-42", + Claims: map[string]any{"role": "worker"}, + } + }), + }) + subscriber.Mount(hub) + + publisher := New(Config{ + Mode: ModePubSub, + Endpoint: endpoint, + Role: RolePublisher, + }) + publisher.Mount(stream.NewHub()) + + runContext, runCancel := context.WithCancel(context.Background()) + defer runCancel() + go func() { _ = subscriber.Start(runContext) }() + go func() { _ = publisher.Start(runContext) }() + waitForAdapterRunning(t, subscriber) + waitForAdapterRunning(t, publisher) + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if err := publisher.Publish("block", []byte("hello")); err != nil { + t.Fatalf("handshake Publish() error = %v", err) + } + select { + case peer := <-connected: + if peer.Transport != "zmq" { + t.Fatalf("connected peer transport = %q, want %q", peer.Transport, "zmq") + } + if peer.UserID != "node-42" { + t.Fatalf("connected peer userID = %q, want %q", peer.UserID, "node-42") + } + if role, _ := peer.Claims["role"].(string); role != "worker" { + t.Fatalf("connected peer role = %q, want %q", role, "worker") + } + if peers := hub.PeerCount(); peers != 1 { + t.Fatalf("PeerCount() = %d, want %d", peers, 1) + } + return + case <-time.After(100 * time.Millisecond): + } + } + + t.Fatal("timed out waiting for zmq peer registration") +} + +func TestAdapter_Start_Auth_Ugly(t *testing.T) { + endpoint := randomTCPEndpoint(t) + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + subscriber := New(Config{ + Mode: ModePubSub, + Endpoint: endpoint, + Role: RoleSubscriber, + ConnAuthenticator: stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { + return stream.AuthResult{Valid: false} + }), + HandshakeTimeout: 500 * time.Millisecond, + }) + subscriber.Mount(hub) + + publisher := New(Config{ + Mode: ModePubSub, + Endpoint: endpoint, + Role: RolePublisher, + }) + publisher.Mount(stream.NewHub()) + + runContext, runCancel := context.WithCancel(context.Background()) + defer runCancel() + errs := make(chan error, 1) + go func() { errs <- subscriber.Start(runContext) }() + go func() { _ = publisher.Start(runContext) }() + waitForAdapterRunning(t, subscriber) + waitForAdapterRunning(t, publisher) + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if err := publisher.Publish("block", []byte("hello")); err != nil { + t.Fatalf("handshake Publish() error = %v", err) + } + + select { + case err := <-errs: + if err != stream.ErrAuthRejected { + t.Fatalf("Start() error = %v, want %v", err, stream.ErrAuthRejected) + } + return + case <-time.After(100 * time.Millisecond): + } + } + + t.Fatal("timed out waiting for auth rejection") +} + +func TestAdapter_Start_Auth_Timeout(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + subscriber := New(Config{ + Mode: ModePubSub, + Endpoint: randomTCPEndpoint(t), + Role: RoleSubscriber, + ConnAuthenticator: stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { + return stream.AuthResult{Valid: true} + }), + HandshakeTimeout: 500 * time.Millisecond, + }) + subscriber.Mount(hub) + + publisher := New(Config{ + Mode: ModePubSub, + Endpoint: subscriber.config.Endpoint, + Role: RolePublisher, + }) + publisher.Mount(stream.NewHub()) + + runContext, runCancel := context.WithCancel(context.Background()) + defer runCancel() + errs := make(chan error, 1) + go func() { errs <- subscriber.Start(runContext) }() + go func() { _ = publisher.Start(runContext) }() + waitForAdapterRunning(t, subscriber) + waitForAdapterRunning(t, publisher) + + select { + case err := <-errs: + if err != stream.ErrHandshakeTimeout { + t.Fatalf("Start() error = %v, want %v", err, stream.ErrHandshakeTimeout) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handshake timeout") + } +} + +func TestAdapter_Start_Auth_HandshakeTooLarge_Good(t *testing.T) { + hub := stream.NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + + endpoint := randomTCPEndpoint(t) + subscriber := New(Config{ + Mode: ModePubSub, + Endpoint: endpoint, + Role: RoleSubscriber, + ConnAuthenticator: stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { + return stream.AuthResult{Valid: true} + }), + HandshakeTimeout: 500 * time.Millisecond, + }) + subscriber.Mount(hub) + + publisher := New(Config{ + Mode: ModePubSub, + Endpoint: endpoint, + Role: RolePublisher, + }) + publisher.Mount(stream.NewHub()) + + runContext, runCancel := context.WithCancel(context.Background()) + defer runCancel() + errs := make(chan error, 1) + go func() { errs <- subscriber.Start(runContext) }() + go func() { _ = publisher.Start(runContext) }() + waitForAdapterRunning(t, subscriber) + waitForAdapterRunning(t, publisher) + + tooLargeHandshake := make([]byte, maxHandshakeFrameSize+1) + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if err := publisher.Publish("block", tooLargeHandshake); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case err := <-errs: + if err != stream.ErrAuthRejected { + t.Fatalf("Start() error = %v, want %v", err, stream.ErrAuthRejected) + } + return + case <-time.After(100 * time.Millisecond): + } + } + + t.Fatal("timed out waiting for handshake rejection") +} + +func TestModeAndRole_String_Good(t *testing.T) { + if ModePubSub.String() != "pubsub" { + t.Fatalf("ModePubSub.String() = %q, want %q", ModePubSub.String(), "pubsub") + } + if ModePushPull.String() != "pushpull" { + t.Fatalf("ModePushPull.String() = %q, want %q", ModePushPull.String(), "pushpull") + } + if RolePublisher.String() != "publisher" { + t.Fatalf("RolePublisher.String() = %q, want %q", RolePublisher.String(), "publisher") + } + if RoleSubscriber.String() != "subscriber" { + t.Fatalf("RoleSubscriber.String() = %q, want %q", RoleSubscriber.String(), "subscriber") + } + if RolePusher.String() != "pusher" { + t.Fatalf("RolePusher.String() = %q, want %q", RolePusher.String(), "pusher") + } + if RolePuller.String() != "puller" { + t.Fatalf("RolePuller.String() = %q, want %q", RolePuller.String(), "puller") + } +} + +func randomTCPEndpoint(t *testing.T) string { + t.Helper() + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen() error = %v", err) + } + defer listener.Close() + + address, ok := listener.Addr().(*net.TCPAddr) + if !ok { + t.Fatalf("Addr() type = %T, want *net.TCPAddr", listener.Addr()) + } + return "tcp://127.0.0.1:" + strconv.Itoa(address.Port) +} + +func waitForAdapterRunning(t *testing.T, adapter *Adapter) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + adapter.mutex.RLock() + running := adapter.running + adapter.mutex.RUnlock() + if running { + time.Sleep(100 * time.Millisecond) + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("timed out waiting for adapter to start") +} diff --git a/auth.go b/auth.go index e9e241f..7a565aa 100644 --- a/auth.go +++ b/auth.go @@ -1,58 +1,67 @@ // SPDX-License-Identifier: EUPL-1.2 +// auth := stream.NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) +// request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) +// request.Header.Set("Authorization", "Bearer sk-live") +// result := auth.Authenticate(request) package stream import ( "net/http" - "dappco.re/go/core" + "dappco.re/go" ) -// Authenticator validates an HTTP request during the WebSocket upgrade or SSE -// connection. Implementations may inspect headers, query parameters, or cookies. +// auth := stream.AuthenticatorFunc(func(request *http.Request) stream.AuthResult { +// if request.Header.Get("X-Api-Key") == "sk-live" { +// return stream.AuthResult{Valid: true, UserID: "user-42"} +// } +// return stream.AuthResult{Valid: false} +// }) type Authenticator interface { - Authenticate(r *http.Request) AuthResult + Authenticate(request *http.Request) AuthResult } -// AuthResult holds the outcome of an authentication attempt. +// result := stream.AuthResult{ +// Valid: true, +// UserID: "user-42", +// Claims: map[string]any{"role": "admin"}, +// } type AuthResult struct { - // Valid indicates whether authentication succeeded. Valid bool - // UserID is the authenticated user's identifier. UserID string - // Claims holds arbitrary metadata (roles, scopes, tenant ID). + // claims := result.Claims + // claims["role"] = "admin" Claims map[string]any - // Error holds the reason for failure, if any. Error error } -// AuthenticatorFunc adapts a plain function to the Authenticator interface. -// -// auth := stream.AuthenticatorFunc(func(r *http.Request) stream.AuthResult { -// token := r.Header.Get("X-Api-Key") -// if token == "" { return stream.AuthResult{Valid: false} } -// return stream.AuthResult{Valid: true, UserID: lookupUser(token)} +// authenticator := stream.AuthenticatorFunc(func(request *http.Request) stream.AuthResult { +// return stream.AuthResult{Valid: true, UserID: "user-42"} // }) -type AuthenticatorFunc func(r *http.Request) AuthResult +type AuthenticatorFunc func(request *http.Request) AuthResult -// Authenticate calls f(r). -func (f AuthenticatorFunc) Authenticate(r *http.Request) AuthResult { - return f(r) +// request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) +// result := authenticatorFunc.Authenticate(request) +func (authenticatorFunc AuthenticatorFunc) Authenticate(request *http.Request) AuthResult { + if authenticatorFunc == nil || request == nil { + return AuthResult{Valid: false} + } + return normalizeAuthResult(authenticatorFunc(request)) } -// APIKeyAuthenticator validates Authorization: Bearer against a static map. -// -// auth := stream.NewAPIKeyAuth(map[string]string{"sk-prod-1": "user-42"}) +// auth := stream.NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) +// request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) +// request.Header.Set("Authorization", "Bearer sk-live") +// result := auth.Authenticate(request) type APIKeyAuthenticator struct { Keys map[string]string } -// NewAPIKeyAuth creates an API key authenticator from a key-to-user map. -// -// auth := stream.NewAPIKeyAuth(map[string]string{"sk-prod-1": "user-42"}) +// auth := stream.NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) func NewAPIKeyAuth(keys map[string]string) *APIKeyAuthenticator { if keys == nil { keys = map[string]string{} @@ -64,101 +73,124 @@ func NewAPIKeyAuth(keys map[string]string) *APIKeyAuthenticator { return &APIKeyAuthenticator{Keys: copied} } -// Authenticate validates the request's Authorization Bearer token against the key map. -func (a *APIKeyAuthenticator) Authenticate(r *http.Request) AuthResult { - if a == nil { +// auth := stream.NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) +// request.Header.Set("Authorization", "Bearer sk-live") +// result := auth.Authenticate(request) +func (authenticator *APIKeyAuthenticator) Authenticate(request *http.Request) AuthResult { + if authenticator == nil || request == nil { return AuthResult{Valid: false} } - header := r.Header.Get("Authorization") - if header == "" { - return AuthResult{Valid: false, Error: ErrMissingAuthHeader} - } - if !core.HasPrefix(header, "Bearer ") { - return AuthResult{Valid: false, Error: ErrMalformedAuthHeader} - } - token := core.TrimPrefix(header, "Bearer ") - if token == "" { - return AuthResult{Valid: false, Error: ErrMalformedAuthHeader} + token, result := bearerTokenFromRequest(request) + if !result.Valid { + return result } - userID, ok := a.Keys[token] + userID, ok := authenticator.Keys[token] if !ok { return AuthResult{Valid: false, Error: ErrInvalidAPIKey} } - return AuthResult{Valid: true, UserID: userID} + return normalizeAuthResult(AuthResult{Valid: true, UserID: userID}) } -// BearerTokenAuth delegates bearer token validation to a caller-supplied function. -// -// auth := &stream.BearerTokenAuth{ +// authenticator := &stream.BearerTokenAuth{ // Validate: func(token string) stream.AuthResult { -// claims, err := jwt.Parse(token, keyFunc) -// if err != nil { return stream.AuthResult{Valid: false, Error: err} } -// return stream.AuthResult{Valid: true, UserID: claims.Subject} +// if token == "sk-live" { +// return stream.AuthResult{Valid: true, UserID: "user-42"} +// } +// return stream.AuthResult{Valid: false} // }, // } type BearerTokenAuth struct { Validate func(token string) AuthResult } -// Authenticate extracts the Bearer token and delegates to Validate. -func (b *BearerTokenAuth) Authenticate(r *http.Request) AuthResult { - if b == nil || b.Validate == nil { +// request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) +// request.Header.Set("Authorization", "Bearer sk-live") +// result := authenticator.Authenticate(request) +func (authenticator *BearerTokenAuth) Authenticate(request *http.Request) AuthResult { + if authenticator == nil || authenticator.Validate == nil || request == nil { return AuthResult{Valid: false} } - header := r.Header.Get("Authorization") - if header == "" { - return AuthResult{Valid: false, Error: ErrMissingAuthHeader} - } - if !core.HasPrefix(header, "Bearer ") { - return AuthResult{Valid: false, Error: ErrMalformedAuthHeader} - } - token := core.TrimPrefix(header, "Bearer ") - if token == "" { - return AuthResult{Valid: false, Error: ErrMalformedAuthHeader} + token, result := bearerTokenFromRequest(request) + if !result.Valid { + return result } - return b.Validate(token) + return normalizeAuthResult(authenticator.Validate(token)) } -// QueryTokenAuth extracts a ?token= query parameter and validates via caller function. -// Use when browser clients cannot set headers (native WebSocket API). -// -// auth := &stream.QueryTokenAuth{ -// Validate: func(token string) stream.AuthResult { ... }, +// authenticator := &stream.QueryTokenAuth{ +// Validate: func(token string) stream.AuthResult { +// if token == "sk-live" { +// return stream.AuthResult{Valid: true, UserID: "user-42"} +// } +// return stream.AuthResult{Valid: false} +// }, // } type QueryTokenAuth struct { Validate func(token string) AuthResult } -// Authenticate extracts the token query parameter and delegates to Validate. -func (q *QueryTokenAuth) Authenticate(r *http.Request) AuthResult { - if q == nil || q.Validate == nil { +// request := httptest.NewRequest(http.MethodGet, "/stream/ws?token=sk-live", nil) +// result := authenticator.Authenticate(request) +func (authenticator *QueryTokenAuth) Authenticate(request *http.Request) AuthResult { + if authenticator == nil || authenticator.Validate == nil || request == nil { return AuthResult{Valid: false} } - token := r.URL.Query().Get("token") + token := request.URL.Query().Get("token") if token == "" { return AuthResult{Valid: false} } - return q.Validate(token) + return normalizeAuthResult(authenticator.Validate(token)) } -// ConnAuthenticator validates a raw connection handshake for TCP and ZMQ adapters. -// The handshake is the first message received on the connection (up to 4 KB). -// // auth := stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { -// var h tcp.Handshake -// if r := core.JSONUnmarshal(handshake, &h); !r.OK { -// return stream.AuthResult{Valid: false} +// if string(handshake) == "hello" { +// return stream.AuthResult{Valid: true, UserID: "peer-1"} // } -// return verifyHMAC(h.Token, h.Timestamp) +// return stream.AuthResult{Valid: false} // }) type ConnAuthenticator interface { AuthenticateConn(handshake []byte) AuthResult } -// ConnAuthenticatorFunc adapts a plain function to ConnAuthenticator. +// auth := stream.ConnAuthenticatorFunc(func(handshake []byte) stream.AuthResult { +// if string(handshake) == "hello" { +// return stream.AuthResult{Valid: true, UserID: "peer-1"} +// } +// return stream.AuthResult{Valid: false} +// }) type ConnAuthenticatorFunc func(handshake []byte) AuthResult -// AuthenticateConn calls f(handshake). -func (f ConnAuthenticatorFunc) AuthenticateConn(handshake []byte) AuthResult { - return f(handshake) +// result := auth.AuthenticateConn([]byte("hello")) +func (connAuthenticatorFunc ConnAuthenticatorFunc) AuthenticateConn(handshake []byte) AuthResult { + if connAuthenticatorFunc == nil { + return AuthResult{Valid: false} + } + return normalizeAuthResult(connAuthenticatorFunc(handshake)) +} + +// token, result := bearerTokenFromRequest(request) +func bearerTokenFromRequest(request *http.Request) (string, AuthResult) { + header := request.Header.Get("Authorization") + if header == "" { + return "", AuthResult{Valid: false, Error: ErrMissingAuthHeader} + } + if !core.HasPrefix(header, "Bearer ") { + return "", AuthResult{Valid: false, Error: ErrMalformedAuthHeader} + } + token := core.TrimPrefix(header, "Bearer ") + if token == "" { + return "", AuthResult{Valid: false, Error: ErrMalformedAuthHeader} + } + return token, AuthResult{Valid: true} +} + +// result = normalizeAuthResult(result) +func normalizeAuthResult(result AuthResult) AuthResult { + if !result.Valid { + return result + } + if result.Claims == nil { + result.Claims = map[string]any{} + } + return result } diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..d0d00e6 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package stream + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestAuth_APIKeyAuthenticator_ClaimsInitialized_Good(t *testing.T) { + authenticator := NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) + request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) + request.Header.Set("Authorization", "Bearer sk-live") + + result := authenticator.Authenticate(request) + if !result.Valid { + t.Fatal("Authenticate() result.Valid = false, want true") + } + if result.Claims == nil { + t.Fatal("Authenticate() result.Claims = nil, want empty map") + } + if len(result.Claims) != 0 { + t.Fatalf("len(Authenticate().Claims) = %d, want 0", len(result.Claims)) + } +} + +func TestAuth_AuthenticatorFunc_ClaimsInitialized_Good(t *testing.T) { + authenticator := AuthenticatorFunc(func(request *http.Request) AuthResult { + return AuthResult{Valid: true, UserID: "user-42"} + }) + request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) + + result := authenticator.Authenticate(request) + if !result.Valid { + t.Fatal("Authenticate() result.Valid = false, want true") + } + if result.Claims == nil { + t.Fatal("Authenticate() result.Claims = nil, want empty map") + } + if len(result.Claims) != 0 { + t.Fatalf("len(Authenticate().Claims) = %d, want 0", len(result.Claims)) + } +} + +func TestAuth_ConnAuthenticatorFunc_ClaimsInitialized_Good(t *testing.T) { + authenticator := ConnAuthenticatorFunc(func(handshake []byte) AuthResult { + return AuthResult{Valid: true, UserID: "peer-1"} + }) + + result := authenticator.AuthenticateConn([]byte("hello")) + if !result.Valid { + t.Fatal("AuthenticateConn() result.Valid = false, want true") + } + if result.Claims == nil { + t.Fatal("AuthenticateConn() result.Claims = nil, want empty map") + } + if len(result.Claims) != 0 { + t.Fatalf("len(AuthenticateConn().Claims) = %d, want 0", len(result.Claims)) + } +} + +func TestAuth_APIKeyAuthenticator_Bad(t *testing.T) { + authenticator := NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) + + // Missing Authorization header. + request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) + result := authenticator.Authenticate(request) + if result.Valid { + t.Fatal("Authenticate() without header: result.Valid = true, want false") + } + if result.Error != ErrMissingAuthHeader { + t.Fatalf("Authenticate() without header: error = %v, want %v", result.Error, ErrMissingAuthHeader) + } + + // Malformed Authorization header (not "Bearer "). + request = httptest.NewRequest(http.MethodGet, "/stream/ws", nil) + request.Header.Set("Authorization", "Basic sk-live") + result = authenticator.Authenticate(request) + if result.Valid { + t.Fatal("Authenticate() with Basic scheme: result.Valid = true, want false") + } + if result.Error != ErrMalformedAuthHeader { + t.Fatalf("Authenticate() with Basic scheme: error = %v, want %v", result.Error, ErrMalformedAuthHeader) + } + + // Unknown API key. + request = httptest.NewRequest(http.MethodGet, "/stream/ws", nil) + request.Header.Set("Authorization", "Bearer sk-unknown") + result = authenticator.Authenticate(request) + if result.Valid { + t.Fatal("Authenticate() with unknown key: result.Valid = true, want false") + } + if result.Error != ErrInvalidAPIKey { + t.Fatalf("Authenticate() with unknown key: error = %v, want %v", result.Error, ErrInvalidAPIKey) + } +} + +func TestAuth_APIKeyAuthenticator_Ugly(t *testing.T) { + // Nil authenticator returns invalid result without panic. + var authenticator *APIKeyAuthenticator + request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) + request.Header.Set("Authorization", "Bearer sk-live") + + result := authenticator.Authenticate(request) + if result.Valid { + t.Fatal("nil authenticator: result.Valid = true, want false") + } + + // Nil request returns invalid result without panic. + validAuth := NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) + result = validAuth.Authenticate(nil) + if result.Valid { + t.Fatal("nil request: result.Valid = true, want false") + } +} + +func TestAuth_BearerTokenAuth_Good(t *testing.T) { + authenticator := &BearerTokenAuth{ + Validate: func(token string) AuthResult { + if token == "jwt-valid" { + return AuthResult{Valid: true, UserID: "user-99", Claims: map[string]any{"role": "admin"}} + } + return AuthResult{Valid: false} + }, + } + + request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) + request.Header.Set("Authorization", "Bearer jwt-valid") + result := authenticator.Authenticate(request) + if !result.Valid { + t.Fatal("Authenticate() result.Valid = false, want true") + } + if result.UserID != "user-99" { + t.Fatalf("result.UserID = %q, want %q", result.UserID, "user-99") + } + if result.Claims["role"] != "admin" { + t.Fatalf("result.Claims[role] = %v, want %q", result.Claims["role"], "admin") + } +} + +func TestAuth_BearerTokenAuth_Bad(t *testing.T) { + authenticator := &BearerTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: false} + }, + } + + // Valid header but rejected by validator. + request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) + request.Header.Set("Authorization", "Bearer bad-token") + result := authenticator.Authenticate(request) + if result.Valid { + t.Fatal("Authenticate() with rejected token: result.Valid = true, want false") + } +} + +func TestAuth_BearerTokenAuth_Ugly(t *testing.T) { + // Nil Validate function returns invalid without panic. + authenticator := &BearerTokenAuth{} + request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) + request.Header.Set("Authorization", "Bearer test") + result := authenticator.Authenticate(request) + if result.Valid { + t.Fatal("nil Validate: result.Valid = true, want false") + } +} + +func TestAuth_QueryTokenAuth_Good(t *testing.T) { + authenticator := &QueryTokenAuth{ + Validate: func(token string) AuthResult { + if token == "ws-token-1" { + return AuthResult{Valid: true, UserID: "browser-user"} + } + return AuthResult{Valid: false} + }, + } + + request := httptest.NewRequest(http.MethodGet, "/stream/ws?token=ws-token-1", nil) + result := authenticator.Authenticate(request) + if !result.Valid { + t.Fatal("Authenticate() result.Valid = false, want true") + } + if result.UserID != "browser-user" { + t.Fatalf("result.UserID = %q, want %q", result.UserID, "browser-user") + } +} + +func TestAuth_QueryTokenAuth_Bad(t *testing.T) { + authenticator := &QueryTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: false} + }, + } + + // Missing token query parameter. + request := httptest.NewRequest(http.MethodGet, "/stream/ws", nil) + result := authenticator.Authenticate(request) + if result.Valid { + t.Fatal("Authenticate() without token param: result.Valid = true, want false") + } +} + +func TestAuth_QueryTokenAuth_Ugly(t *testing.T) { + // Nil Validate function returns invalid without panic. + authenticator := &QueryTokenAuth{} + request := httptest.NewRequest(http.MethodGet, "/stream/ws?token=test", nil) + result := authenticator.Authenticate(request) + if result.Valid { + t.Fatal("nil Validate: result.Valid = true, want false") + } + + // Nil authenticator returns invalid without panic. + var nilAuth AuthenticatorFunc + result = nilAuth.Authenticate(httptest.NewRequest(http.MethodGet, "/stream/ws", nil)) + if result.Valid { + t.Fatal("nil AuthenticatorFunc: result.Valid = true, want false") + } +} + +func TestAuth_ConnAuthenticatorFunc_Bad(t *testing.T) { + authenticator := ConnAuthenticatorFunc(func(handshake []byte) AuthResult { + return AuthResult{Valid: false} + }) + + result := authenticator.AuthenticateConn([]byte("invalid-handshake")) + if result.Valid { + t.Fatal("AuthenticateConn() with invalid handshake: result.Valid = true, want false") + } +} + +func TestAuth_ConnAuthenticatorFunc_Ugly(t *testing.T) { + // Nil ConnAuthenticatorFunc returns invalid without panic. + var authenticator ConnAuthenticatorFunc + result := authenticator.AuthenticateConn([]byte("hello")) + if result.Valid { + t.Fatal("nil ConnAuthenticatorFunc: result.Valid = true, want false") + } +} diff --git a/ax7_more_test.go b/ax7_more_test.go new file mode 100644 index 0000000..38109cc --- /dev/null +++ b/ax7_more_test.go @@ -0,0 +1,893 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package stream + +import core "dappco.re/go" + +type T = core.T +type CancelFunc = core.CancelFunc +type Duration = core.Duration +type Request = core.Request + +const ( + Millisecond = core.Millisecond + Second = core.Second +) + +var ( + Background = core.Background + NewHTTPTestRequest = core.NewHTTPTestRequest + Sleep = core.Sleep + WithCancel = core.WithCancel + WithTimeout = core.WithTimeout + + AssertContains = core.AssertContains + AssertEqual = core.AssertEqual + AssertError = core.AssertError + AssertFalse = core.AssertFalse + AssertNil = core.AssertNil + AssertNoError = core.AssertNoError + AssertNotEqual = core.AssertNotEqual + AssertNotNil = core.AssertNotNil + AssertNotPanics = core.AssertNotPanics + AssertTrue = core.AssertTrue +) + +func ax7RunningHub(t *T) (*Hub, CancelFunc) { + hub := NewHub() + ctx, cancel := WithCancel(Background()) + go hub.Run(ctx) + waitForRunningHub(t, hub) + return hub, cancel +} + +func ax7Timeout(duration Duration) <-chan struct{} { + ctx, cancel := WithTimeout(Background(), duration) + done := make(chan struct{}) + go func() { + defer cancel() + <-ctx.Done() + close(done) + }() + return done +} + +func TestAX7_APIKeyAuthenticator_Authenticate_Good(t *T) { + authenticator := NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) + request := NewHTTPTestRequest("GET", "/stream/ws", nil) + request.Header.Set("Authorization", "Bearer sk-live") + + result := authenticator.Authenticate(request) + AssertTrue(t, result.Valid) + AssertEqual(t, "user-42", result.UserID) + AssertNotNil(t, result.Claims) +} + +func TestAX7_APIKeyAuthenticator_Authenticate_Bad(t *T) { + authenticator := NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) + request := NewHTTPTestRequest("GET", "/stream/ws", nil) + + result := authenticator.Authenticate(request) + AssertFalse(t, result.Valid) + AssertEqual(t, ErrMissingAuthHeader, result.Error) +} + +func TestAX7_APIKeyAuthenticator_Authenticate_Ugly(t *T) { + var authenticator *APIKeyAuthenticator + request := NewHTTPTestRequest("GET", "/stream/ws", nil) + request.Header.Set("Authorization", "Bearer sk-live") + + result := authenticator.Authenticate(request) + AssertFalse(t, result.Valid) + AssertEqual(t, "", result.UserID) +} + +func TestAX7_AuthenticatorFunc_Authenticate_Good(t *T) { + authenticator := AuthenticatorFunc(func(request *Request) AuthResult { + return AuthResult{Valid: request.URL.Path == "/stream/ws", UserID: "agent"} + }) + request := NewHTTPTestRequest("GET", "/stream/ws", nil) + + result := authenticator.Authenticate(request) + AssertTrue(t, result.Valid) + AssertEqual(t, "agent", result.UserID) + AssertNotNil(t, result.Claims) +} + +func TestAX7_AuthenticatorFunc_Authenticate_Bad(t *T) { + authenticator := AuthenticatorFunc(func(request *Request) AuthResult { + return AuthResult{Valid: false, Error: ErrAuthRejected} + }) + request := NewHTTPTestRequest("GET", "/stream/ws", nil) + + result := authenticator.Authenticate(request) + AssertFalse(t, result.Valid) + AssertEqual(t, ErrAuthRejected, result.Error) +} + +func TestAX7_AuthenticatorFunc_Authenticate_Ugly(t *T) { + var authenticator AuthenticatorFunc + request := NewHTTPTestRequest("GET", "/stream/ws", nil) + + result := authenticator.Authenticate(request) + AssertFalse(t, result.Valid) + AssertNil(t, result.Claims) +} + +func TestAX7_BearerTokenAuth_Authenticate_Good(t *T) { + authenticator := &BearerTokenAuth{Validate: func(token string) AuthResult { + return AuthResult{Valid: token == "jwt-valid", UserID: "user-99"} + }} + request := NewHTTPTestRequest("GET", "/stream/ws", nil) + request.Header.Set("Authorization", "Bearer jwt-valid") + + result := authenticator.Authenticate(request) + AssertTrue(t, result.Valid) + AssertEqual(t, "user-99", result.UserID) +} + +func TestAX7_BearerTokenAuth_Authenticate_Bad(t *T) { + authenticator := &BearerTokenAuth{Validate: func(token string) AuthResult { + return AuthResult{Valid: false, Error: ErrAuthRejected} + }} + request := NewHTTPTestRequest("GET", "/stream/ws", nil) + request.Header.Set("Authorization", "Bearer jwt-invalid") + + result := authenticator.Authenticate(request) + AssertFalse(t, result.Valid) + AssertEqual(t, ErrAuthRejected, result.Error) +} + +func TestAX7_BearerTokenAuth_Authenticate_Ugly(t *T) { + authenticator := &BearerTokenAuth{} + request := NewHTTPTestRequest("GET", "/stream/ws", nil) + request.Header.Set("Authorization", "Bearer jwt-valid") + + result := authenticator.Authenticate(request) + AssertFalse(t, result.Valid) + AssertEqual(t, "", result.UserID) +} + +func TestAX7_ConnAuthenticatorFunc_AuthenticateConn_Good(t *T) { + authenticator := ConnAuthenticatorFunc(func(handshake []byte) AuthResult { + return AuthResult{Valid: string(handshake) == "hello", UserID: "peer-1"} + }) + result := authenticator.AuthenticateConn([]byte("hello")) + + AssertTrue(t, result.Valid) + AssertEqual(t, "peer-1", result.UserID) + AssertNotNil(t, result.Claims) +} + +func TestAX7_ConnAuthenticatorFunc_AuthenticateConn_Bad(t *T) { + authenticator := ConnAuthenticatorFunc(func(handshake []byte) AuthResult { + return AuthResult{Valid: false, Error: ErrAuthRejected} + }) + result := authenticator.AuthenticateConn([]byte("bad")) + + AssertFalse(t, result.Valid) + AssertEqual(t, ErrAuthRejected, result.Error) + AssertNil(t, result.Claims) +} + +func TestAX7_ConnAuthenticatorFunc_AuthenticateConn_Ugly(t *T) { + var authenticator ConnAuthenticatorFunc + result := authenticator.AuthenticateConn(nil) + + AssertFalse(t, result.Valid) + AssertEqual(t, "", result.UserID) + AssertNil(t, result.Claims) +} + +func TestAX7_DefaultHubConfig_Good(t *T) { + config := DefaultHubConfig() + + AssertEqual(t, 30*Second, config.HeartbeatInterval) + AssertEqual(t, 60*Second, config.PongTimeout) + AssertEqual(t, 10*Second, config.WriteTimeout) +} + +func TestAX7_DefaultHubConfig_Bad(t *T) { + config := DefaultHubConfig() + + AssertNil(t, config.OnConnect) + AssertNil(t, config.OnDisconnect) + AssertNil(t, config.ChannelAuthoriser) +} + +func TestAX7_DefaultHubConfig_Ugly(t *T) { + config := normalizeHubConfig(HubConfig{HeartbeatInterval: Second, PongTimeout: Millisecond}) + + AssertEqual(t, Second, config.HeartbeatInterval) + AssertEqual(t, 2*Second, config.PongTimeout) + AssertEqual(t, 10*Second, config.WriteTimeout) +} + +func TestAX7_Hub_AddPeer_Good(t *T) { + hub := NewHub() + peer := NewPeer("ws") + + AssertNoError(t, hub.AddPeer(peer)) + AssertEqual(t, 1, hub.PeerCount()) + AssertEqual(t, 0, len(peer.Subscriptions())) +} + +func TestAX7_Hub_AddPeer_Bad(t *T) { + hub := NewHub() + + err := hub.AddPeer(nil) + AssertError(t, err) + AssertContains(t, err.Error(), "nil peer") +} + +func TestAX7_Hub_AddPeer_Ugly(t *T) { + hub := NewHub() + peer := &Peer{Transport: "ws"} + + AssertNoError(t, hub.AddPeer(peer)) + AssertNotNil(t, peer.SendQueue()) + AssertEqual(t, 1, hub.PeerCount()) +} + +func TestAX7_Hub_AllChannels_Good(t *T) { + hub := NewHub() + stopA := hub.Subscribe("block", func([]byte) {}) + defer stopA() + stopB := hub.Subscribe("hashrate", func([]byte) {}) + defer stopB() + + var channels []string + for channel := range hub.AllChannels() { + channels = append(channels, channel) + } + AssertEqual(t, []string{"block", "hashrate"}, channels) +} + +func TestAX7_Hub_AllChannels_Bad(t *T) { + var hub *Hub + count := 0 + + for range hub.AllChannels() { + count++ + } + AssertEqual(t, 0, count) +} + +func TestAX7_Hub_AllChannels_Ugly(t *T) { + hub := NewHub() + stop := hub.Subscribe("events", func([]byte) {}) + seq := hub.AllChannels() + stop() + + var channels []string + for channel := range seq { + channels = append(channels, channel) + } + AssertEqual(t, []string{"events"}, channels) +} + +func TestAX7_Hub_AllPeers_Good(t *T) { + hub := NewHub() + AssertNoError(t, hub.AddPeer(NewPeer("ws"))) + AssertNoError(t, hub.AddPeer(NewPeer("sse"))) + + count := 0 + for range hub.AllPeers() { + count++ + } + AssertEqual(t, 2, count) +} + +func TestAX7_Hub_AllPeers_Bad(t *T) { + var hub *Hub + count := 0 + + for range hub.AllPeers() { + count++ + } + AssertEqual(t, 0, count) +} + +func TestAX7_Hub_AllPeers_Ugly(t *T) { + hub := NewHub() + peer := NewPeer("ws") + AssertNoError(t, hub.AddPeer(peer)) + seq := hub.AllPeers() + hub.RemovePeer(peer) + + count := 0 + for range seq { + count++ + } + AssertEqual(t, 1, count) +} + +func TestAX7_Hub_BroadcastFromBridge_Good(t *T) { + hub, cancel := ax7RunningHub(t) + defer cancel() + peer := NewPeer("ws") + AssertNoError(t, hub.AddPeer(peer)) + waitForPeerCount(t, hub, 1) + + AssertNoError(t, hub.BroadcastFromBridge([]byte("bridge"))) + frame := <-peer.SendQueue() + AssertEqual(t, "bridge", string(frame)) +} + +func TestAX7_Hub_BroadcastFromBridge_Bad(t *T) { + hub := NewHub() + + err := hub.BroadcastFromBridge([]byte("bridge")) + AssertEqual(t, ErrHubNotRunning, err) + AssertEqual(t, 0, hub.PeerCount()) +} + +func TestAX7_Hub_BroadcastFromBridge_Ugly(t *T) { + hub, cancel := ax7RunningHub(t) + defer cancel() + seen := false + stop := hub.SubscribeBroadcast(func([]byte) { seen = true }) + defer stop() + + AssertNoError(t, hub.BroadcastFromBridge([]byte("bridge"))) + Sleep(20 * Millisecond) + AssertFalse(t, seen) +} + +func TestAX7_Hub_BroadcastFromPeer_Bad(t *T) { + hub := NewHub() + peer := NewPeer("ws") + + err := hub.BroadcastFromPeer(peer, []byte("frame")) + AssertEqual(t, ErrHubNotRunning, err) + AssertEqual(t, 0, hub.PeerCount()) +} + +func TestAX7_Hub_BroadcastFromPeer_Ugly(t *T) { + hub, cancel := ax7RunningHub(t) + defer cancel() + source := NewPeer("ws") + receiver := NewPeer("ws") + AssertNoError(t, hub.AddPeer(source)) + AssertNoError(t, hub.AddPeer(receiver)) + waitForPeerCount(t, hub, 2) + + AssertNoError(t, hub.BroadcastFromPeer(source, []byte("fanout"))) + AssertEqual(t, "fanout", string(<-receiver.SendQueue())) +} + +func TestAX7_Hub_CanSubscribePeer_Good(t *T) { + hub := NewHubWithConfig(HubConfig{ChannelAuthoriser: func(peer *Peer, channel string) bool { + return peer.UserID == "agent" && channel == "private" + }}) + peer := NewPeer("ws") + peer.UserID = "agent" + + AssertNoError(t, hub.CanSubscribePeer(peer, "private")) + AssertNoError(t, hub.CanSubscribePeer(peer, "*")) +} + +func TestAX7_Hub_CanSubscribePeer_Ugly(t *T) { + hub := NewHub() + + err := hub.CanSubscribePeer(NewPeer("ws"), "") + AssertEqual(t, ErrEmptyChannel, err) + AssertError(t, err) +} + +func TestAX7_Hub_ChannelCount_Good(t *T) { + hub := NewHub() + stop := hub.Subscribe("events", func([]byte) {}) + defer stop() + + AssertEqual(t, 1, hub.ChannelCount()) + AssertEqual(t, 1, hub.ChannelSubscriberCount("events")) +} + +func TestAX7_Hub_ChannelCount_Bad(t *T) { + var hub *Hub + + AssertEqual(t, 0, hub.ChannelCount()) + AssertEqual(t, 0, hub.ChannelSubscriberCount("missing")) +} + +func TestAX7_Hub_ChannelCount_Ugly(t *T) { + hub := NewHub() + stop := hub.Subscribe("*", func([]byte) {}) + defer stop() + + AssertEqual(t, 0, hub.ChannelCount()) + AssertEqual(t, 1, hub.ChannelSubscriberCount("*")) +} + +func TestAX7_Hub_ChannelSubscriberCount_Good(t *T) { + hub := NewHub() + peer := NewPeer("ws") + AssertNoError(t, hub.AddPeer(peer)) + AssertNoError(t, hub.SubscribePeer(peer, "hashrate")) + stop := hub.Subscribe("hashrate", func([]byte) {}) + defer stop() + + AssertEqual(t, 2, hub.ChannelSubscriberCount("hashrate")) + AssertEqual(t, 1, hub.PeerCount()) +} + +func TestAX7_Hub_ChannelSubscriberCount_Bad(t *T) { + hub := NewHub() + + AssertEqual(t, 0, hub.ChannelSubscriberCount("missing")) + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_ChannelSubscriberCount_Ugly(t *T) { + hub := NewHub() + peer := NewPeer("ws") + AssertNoError(t, hub.AddPeer(peer)) + AssertNoError(t, hub.SubscribePeer(peer, "*")) + + AssertEqual(t, 1, hub.ChannelSubscriberCount("*")) + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_Config_Good(t *T) { + hub := NewHubWithConfig(HubConfig{HeartbeatInterval: Second, PongTimeout: 3 * Second}) + + config := hub.Config() + AssertEqual(t, Second, config.HeartbeatInterval) + AssertEqual(t, 3*Second, config.PongTimeout) +} + +func TestAX7_Hub_Config_Bad(t *T) { + var hub *Hub + + config := hub.Config() + AssertEqual(t, 30*Second, config.HeartbeatInterval) + AssertEqual(t, 60*Second, config.PongTimeout) +} + +func TestAX7_Hub_Config_Ugly(t *T) { + hub := NewHubWithConfig(HubConfig{HeartbeatInterval: Second, PongTimeout: Second}) + + config := hub.Config() + AssertEqual(t, Second, config.HeartbeatInterval) + AssertEqual(t, 2*Second, config.PongTimeout) +} + +func TestAX7_Hub_PeerCount_Good(t *T) { + hub := NewHub() + AssertNoError(t, hub.AddPeer(NewPeer("ws"))) + + AssertEqual(t, 1, hub.PeerCount()) + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_PeerCount_Bad(t *T) { + var hub *Hub + + AssertEqual(t, 0, hub.PeerCount()) + AssertEqual(t, HubStats{}, hub.Stats()) +} + +func TestAX7_Hub_PeerCount_Ugly(t *T) { + hub := NewHub() + peer := NewPeer("ws") + AssertNoError(t, hub.AddPeer(peer)) + hub.RemovePeer(peer) + + AssertEqual(t, 0, hub.PeerCount()) + AssertEqual(t, []string{}, peer.Subscriptions()) +} + +func TestAX7_Hub_PublishFromBridge_Good(t *T) { + hub, cancel := ax7RunningHub(t) + defer cancel() + received := make(chan []byte, 1) + stop := hub.Subscribe("block", func(frame []byte) { received <- append([]byte(nil), frame...) }) + defer stop() + + AssertNoError(t, hub.PublishFromBridge("block", []byte("template"))) + AssertEqual(t, "template", string(<-received)) +} + +func TestAX7_Hub_PublishFromBridge_Bad(t *T) { + hub := NewHub() + + err := hub.PublishFromBridge("block", []byte("template")) + AssertEqual(t, ErrHubNotRunning, err) + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_PublishFromBridge_Ugly(t *T) { + hub, cancel := ax7RunningHub(t) + defer cancel() + seen := false + stop := hub.SubscribePublished(func(string, []byte) { seen = true }) + defer stop() + + AssertNoError(t, hub.PublishFromBridge("block", []byte("template"))) + Sleep(20 * Millisecond) + AssertFalse(t, seen) +} + +func TestAX7_Hub_PublishFromPeer_Bad(t *T) { + hub := NewHub() + peer := NewPeer("ws") + + err := hub.PublishFromPeer(peer, "block", []byte("template")) + AssertEqual(t, ErrHubNotRunning, err) + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_PublishFromPeer_Ugly(t *T) { + hub, cancel := ax7RunningHub(t) + defer cancel() + source := NewPeer("ws") + receiver := NewPeer("ws") + AssertNoError(t, hub.AddPeer(source)) + AssertNoError(t, hub.AddPeer(receiver)) + AssertNoError(t, hub.SubscribePeer(receiver, "block")) + + AssertNoError(t, hub.PublishFromPeer(source, "block", []byte("template"))) + AssertEqual(t, "template", string(<-receiver.SendQueue())) +} + +func TestAX7_Hub_RemovePeer_Good(t *T) { + hub := NewHub() + peer := NewPeer("ws") + AssertNoError(t, hub.AddPeer(peer)) + + hub.RemovePeer(peer) + AssertEqual(t, 0, hub.PeerCount()) + AssertEqual(t, []string{}, peer.Subscriptions()) +} + +func TestAX7_Hub_RemovePeer_Bad(t *T) { + var hub *Hub + peer := NewPeer("ws") + + AssertNotPanics(t, func() { hub.RemovePeer(peer) }) + AssertEqual(t, "ws", peer.Transport) +} + +func TestAX7_Hub_RemovePeer_Ugly(t *T) { + hub := NewHub() + peer := NewPeer("ws") + + AssertNotPanics(t, func() { hub.RemovePeer(peer) }) + AssertEqual(t, 0, hub.PeerCount()) +} + +func TestAX7_Hub_SendToChannel_Good(t *T) { + hub, cancel := ax7RunningHub(t) + defer cancel() + received := make(chan []byte, 1) + stop := hub.Subscribe("hashrate", func(frame []byte) { received <- append([]byte(nil), frame...) }) + defer stop() + + AssertNoError(t, hub.SendToChannel("hashrate", []byte("123"))) + AssertEqual(t, "123", string(<-received)) +} + +func TestAX7_Hub_SendToChannel_Bad(t *T) { + var hub *Hub + + err := hub.SendToChannel("hashrate", []byte("123")) + AssertError(t, err) + AssertContains(t, err.Error(), "nil hub") +} + +func TestAX7_Hub_SendToChannel_Ugly(t *T) { + hub := NewHub() + + err := hub.SendToChannel("hashrate", []byte("123")) + AssertEqual(t, ErrHubNotRunning, err) + AssertEqual(t, 0, hub.PeerCount()) +} + +func TestAX7_Hub_Stats_Good(t *T) { + hub := NewHub() + peer := NewPeer("ws") + AssertNoError(t, hub.AddPeer(peer)) + AssertNoError(t, hub.SubscribePeer(peer, "hashrate")) + + stats := hub.Stats() + AssertEqual(t, 1, stats.Peers) + AssertEqual(t, 1, stats.SubscriberCount["hashrate"]) +} + +func TestAX7_Hub_Stats_Bad(t *T) { + var hub *Hub + + stats := hub.Stats() + AssertEqual(t, 0, stats.Peers) + AssertEqual(t, 0, stats.Channels) +} + +func TestAX7_Hub_Stats_Ugly(t *T) { + hub := NewHub() + stop := hub.Subscribe("events", func([]byte) {}) + defer stop() + + stats := hub.Stats() + AssertEqual(t, 1, stats.Channels) + AssertEqual(t, 1, stats.SubscriberCount["events"]) +} + +func TestAX7_Hub_SubscribeBroadcast_Good(t *T) { + hub, cancel := ax7RunningHub(t) + defer cancel() + received := make(chan []byte, 1) + stop := hub.SubscribeBroadcast(func(frame []byte) { received <- append([]byte(nil), frame...) }) + defer stop() + + AssertNoError(t, hub.Broadcast([]byte("shutdown"))) + AssertEqual(t, "shutdown", string(<-received)) +} + +func TestAX7_Hub_SubscribeBroadcast_Bad(t *T) { + var hub *Hub + + stop := hub.SubscribeBroadcast(func([]byte) {}) + AssertNotNil(t, stop) + AssertNotPanics(t, stop) +} + +func TestAX7_Hub_SubscribeBroadcast_Ugly(t *T) { + hub, cancel := ax7RunningHub(t) + defer cancel() + received := make(chan []byte, 1) + stop := hub.SubscribeBroadcast(func(frame []byte) { received <- frame }) + stop() + + AssertNoError(t, hub.Broadcast([]byte("shutdown"))) + select { + case frame := <-received: + t.Fatalf("received after unsubscribe: %q", string(frame)) + case <-ax7Timeout(20 * Millisecond): + } +} + +func TestAX7_Hub_SubscribePeer_Good(t *T) { + hub := NewHub() + peer := NewPeer("ws") + AssertNoError(t, hub.AddPeer(peer)) + + AssertNoError(t, hub.SubscribePeer(peer, "hashrate")) + AssertEqual(t, []string{"hashrate"}, peer.Subscriptions()) +} + +func TestAX7_Hub_SubscribePeer_Bad(t *T) { + hub := NewHub() + + err := hub.SubscribePeer(nil, "hashrate") + AssertError(t, err) + AssertContains(t, err.Error(), "nil peer") +} + +func TestAX7_Hub_SubscribePeer_Ugly(t *T) { + hub := NewHubWithConfig(HubConfig{ChannelAuthoriser: func(*Peer, string) bool { return false }}) + peer := NewPeer("ws") + + err := hub.SubscribePeer(peer, "private") + AssertEqual(t, ErrAuthRejected, err) + AssertEqual(t, []string{}, peer.Subscriptions()) +} + +func TestAX7_Hub_SubscribePublished_Good(t *T) { + hub, cancel := ax7RunningHub(t) + defer cancel() + received := make(chan string, 1) + stop := hub.SubscribePublished(func(channel string, frame []byte) { received <- channel + ":" + string(frame) }) + defer stop() + + AssertNoError(t, hub.Publish("block", []byte("template"))) + AssertEqual(t, "block:template", <-received) +} + +func TestAX7_Hub_SubscribePublished_Bad(t *T) { + var hub *Hub + + stop := hub.SubscribePublished(func(string, []byte) {}) + AssertNotNil(t, stop) + AssertNotPanics(t, stop) +} + +func TestAX7_Hub_SubscribePublished_Ugly(t *T) { + hub, cancel := ax7RunningHub(t) + defer cancel() + received := make(chan string, 1) + stop := hub.SubscribePublished(func(channel string, frame []byte) { received <- channel }) + stop() + + AssertNoError(t, hub.Publish("block", []byte("template"))) + select { + case channel := <-received: + t.Fatalf("received after unsubscribe: %q", channel) + case <-ax7Timeout(20 * Millisecond): + } +} + +func TestAX7_Hub_SubscribeWithError_Bad(t *T) { + hub := NewHub() + + stop, err := hub.SubscribeWithError("", func([]byte) {}) + AssertEqual(t, ErrEmptyChannel, err) + AssertNotNil(t, stop) +} + +func TestAX7_Hub_SubscribeWithError_Ugly(t *T) { + hub := NewHub() + + stop, err := hub.SubscribeWithError("events", nil) + AssertError(t, err) + AssertNotNil(t, stop) +} + +func TestAX7_Hub_UnsubscribePeer_Good(t *T) { + hub := NewHub() + peer := NewPeer("ws") + AssertNoError(t, hub.AddPeer(peer)) + AssertNoError(t, hub.SubscribePeer(peer, "block")) + + hub.UnsubscribePeer(peer, "block") + AssertEqual(t, []string{}, peer.Subscriptions()) + AssertEqual(t, 0, hub.ChannelSubscriberCount("block")) +} + +func TestAX7_Hub_UnsubscribePeer_Bad(t *T) { + hub := NewHub() + + AssertNotPanics(t, func() { hub.UnsubscribePeer(nil, "block") }) + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_UnsubscribePeer_Ugly(t *T) { + hub := NewHub() + peer := NewPeer("ws") + AssertNoError(t, hub.AddPeer(peer)) + + AssertNotPanics(t, func() { hub.UnsubscribePeer(peer, "") }) + AssertEqual(t, 0, len(peer.Subscriptions())) +} + +func TestAX7_NewAPIKeyAuth_Good(t *T) { + authenticator := NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) + + AssertNotNil(t, authenticator) + AssertEqual(t, "user-42", authenticator.Keys["sk-live"]) + AssertEqual(t, 1, len(authenticator.Keys)) +} + +func TestAX7_NewAPIKeyAuth_Bad(t *T) { + authenticator := NewAPIKeyAuth(nil) + + AssertNotNil(t, authenticator) + AssertNotNil(t, authenticator.Keys) + AssertEqual(t, 0, len(authenticator.Keys)) +} + +func TestAX7_NewAPIKeyAuth_Ugly(t *T) { + keys := map[string]string{"sk-live": "user-42"} + authenticator := NewAPIKeyAuth(keys) + keys["sk-live"] = "mutated" + + AssertEqual(t, "user-42", authenticator.Keys["sk-live"]) + AssertEqual(t, "mutated", keys["sk-live"]) +} + +func TestAX7_NewHub_Good(t *T) { + hub := NewHub() + + AssertNotNil(t, hub) + AssertFalse(t, hub.Running()) + AssertEqual(t, 0, hub.PeerCount()) +} + +func TestAX7_NewHub_Bad(t *T) { + hub := NewHub() + + AssertNotNil(t, hub.Config()) + AssertEqual(t, 30*Second, hub.Config().HeartbeatInterval) + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_NewHub_Ugly(t *T) { + left := NewHub() + right := NewHub() + + AssertNotEqual(t, left, right) + AssertNotNil(t, left.done) + AssertNotNil(t, right.done) +} + +func TestAX7_NewHubWithConfig_Good(t *T) { + hub := NewHubWithConfig(HubConfig{HeartbeatInterval: Second, PongTimeout: 3 * Second}) + + AssertNotNil(t, hub) + AssertEqual(t, Second, hub.Config().HeartbeatInterval) + AssertEqual(t, 3*Second, hub.Config().PongTimeout) +} + +func TestAX7_NewHubWithConfig_Bad(t *T) { + hub := NewHubWithConfig(HubConfig{}) + + AssertEqual(t, 30*Second, hub.Config().HeartbeatInterval) + AssertEqual(t, 60*Second, hub.Config().PongTimeout) + AssertEqual(t, 10*Second, hub.Config().WriteTimeout) +} + +func TestAX7_NewHubWithConfig_Ugly(t *T) { + called := false + hub := NewHubWithConfig(HubConfig{OnConnect: func(*Peer) { called = true }}) + + AssertNoError(t, hub.AddPeer(NewPeer("ws"))) + AssertTrue(t, called) +} + +func TestAX7_Peer_Close_Bad(t *T) { + var peer *Peer + + AssertNotPanics(t, func() { peer.Close() }) + AssertNil(t, peer) +} + +func TestAX7_Peer_SendQueue_Good(t *T) { + peer := NewPeer("ws") + queue := peer.SendQueue() + + AssertNotNil(t, queue) + AssertTrue(t, peer.Send([]byte("frame"))) + AssertEqual(t, "frame", string(<-queue)) +} + +func TestAX7_Peer_SendQueue_Ugly(t *T) { + peer := NewPeer("ws") + queue := peer.SendQueue() + peer.Close() + + _, ok := <-queue + AssertFalse(t, ok) + AssertFalse(t, peer.Send([]byte("late"))) +} + +func TestAX7_Peer_SetCloseHook_Ugly(t *T) { + peer := NewPeer("ws") + count := 0 + peer.SetCloseHook(func() { count++ }) + peer.SetCloseHook(func() { count += 10 }) + + peer.Close() + AssertEqual(t, 10, count) + AssertEqual(t, []string{}, peer.Subscriptions()) +} + +func TestAX7_QueryTokenAuth_Authenticate_Good(t *T) { + authenticator := &QueryTokenAuth{Validate: func(token string) AuthResult { + return AuthResult{Valid: token == "query-token", UserID: "browser"} + }} + request := NewHTTPTestRequest("GET", "/stream/ws?token=query-token", nil) + + result := authenticator.Authenticate(request) + AssertTrue(t, result.Valid) + AssertEqual(t, "browser", result.UserID) +} + +func TestAX7_QueryTokenAuth_Authenticate_Bad(t *T) { + authenticator := &QueryTokenAuth{Validate: func(token string) AuthResult { + return AuthResult{Valid: token == "query-token"} + }} + request := NewHTTPTestRequest("GET", "/stream/ws", nil) + + result := authenticator.Authenticate(request) + AssertFalse(t, result.Valid) + AssertEqual(t, "", result.UserID) +} + +func TestAX7_QueryTokenAuth_Authenticate_Ugly(t *T) { + var authenticator *QueryTokenAuth + request := NewHTTPTestRequest("GET", "/stream/ws?token=query-token", nil) + + result := authenticator.Authenticate(request) + AssertFalse(t, result.Valid) + AssertNil(t, result.Claims) +} diff --git a/docs/specs/core/go/RFC.md b/docs/specs/core/go/RFC.md new file mode 100644 index 0000000..adefbfb --- /dev/null +++ b/docs/specs/core/go/RFC.md @@ -0,0 +1,5 @@ +# go-stream RFC mirror + +The canonical implementation spec for this module is [`docs/RFC.md`](/workspace/docs/RFC.md). +Keep this path as a compatibility mirror for agents and tooling that expect the +`docs/specs/core/go/RFC.md` location. diff --git a/docs/specs/rfc/RFC-CORE-008-AGENT-EXPERIENCE.md b/docs/specs/rfc/RFC-CORE-008-AGENT-EXPERIENCE.md new file mode 100644 index 0000000..5d8d511 --- /dev/null +++ b/docs/specs/rfc/RFC-CORE-008-AGENT-EXPERIENCE.md @@ -0,0 +1,6 @@ +# Agent Experience RFC mirror + +The canonical AX design principles for this repository are in +[`docs/RFC-025-AGENT-EXPERIENCE.md`](/workspace/docs/RFC-025-AGENT-EXPERIENCE.md). +Keep this path as a compatibility mirror for agents and tooling that expect the +`docs/specs/rfc/RFC-CORE-008-AGENT-EXPERIENCE.md` location. diff --git a/errors.go b/errors.go index 4872e36..784a015 100644 --- a/errors.go +++ b/errors.go @@ -2,29 +2,44 @@ package stream -import "dappco.re/go/core" +import "dappco.re/go" -// Sentinel errors for the stream package. All errors use core.E(). +// if err := hub.Publish("hashrate", frame); err == ErrHubNotRunning { +// return +// } var ( - // ErrMissingAuthHeader is returned when no Authorization header is present. + // if err := auth.Authenticate(request); err == stream.ErrMissingAuthHeader { + // http.Error(w, "missing auth", http.StatusUnauthorized) + // } ErrMissingAuthHeader = core.E("stream.auth", "missing Authorization header", nil) - // ErrMalformedAuthHeader is returned when the header is not "Bearer ". + // if err := auth.Authenticate(request); err == stream.ErrMalformedAuthHeader { + // http.Error(w, "bad auth header", http.StatusUnauthorized) + // } ErrMalformedAuthHeader = core.E("stream.auth", "malformed Authorization header", nil) - // ErrInvalidAPIKey is returned when the API key is not in the key map. + // if err := auth.Authenticate(request); err == stream.ErrInvalidAPIKey { + // http.Error(w, "unknown key", http.StatusUnauthorized) + // } ErrInvalidAPIKey = core.E("stream.auth", "invalid API key", nil) - // ErrHandshakeTimeout is returned when the TCP/ZMQ peer did not send a - // handshake within the configured deadline. + // if err := adapter.Listen(ctx); err == stream.ErrHandshakeTimeout { + // return + // } ErrHandshakeTimeout = core.E("stream.auth", "handshake timeout", nil) - // ErrAuthRejected is returned when ConnAuthenticator denies the handshake. + // if err := adapter.Listen(ctx); err == stream.ErrAuthRejected { + // return + // } ErrAuthRejected = core.E("stream.auth", "connection rejected by authenticator", nil) - // ErrHubNotRunning is returned when Publish or Broadcast is called before Run. + // if err := hub.Publish("hashrate", frame); err == stream.ErrHubNotRunning { + // go hub.Run(ctx) + // } ErrHubNotRunning = core.E("stream.hub", "hub not running", nil) - // ErrEmptyChannel is returned when Subscribe is called with an empty channel name. + // if _, err := hub.SubscribeE("", func([]byte) {}); err == stream.ErrEmptyChannel { + // return + // } ErrEmptyChannel = core.E("stream.hub", "empty channel", nil) ) diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..30d45eb --- /dev/null +++ b/example_test.go @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package stream_test + +import ( + "context" + "net/http/httptest" + "time" + + "dappco.re/go" + "dappco.re/go/stream" +) + +func ExampleNewHub() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hub := stream.NewHub() + go hub.Run(ctx) + + received := make(chan string, 1) + stop := hub.Subscribe("hashrate", func(frame []byte) { + received <- string(frame) + }) + defer stop() + + waitForHub(hub) + _ = hub.Publish("hashrate", []byte(`{"h":123456}`)) + + select { + case frame := <-received: + core.Print(nil, "%s", frame) + case <-time.After(time.Second): + core.Print(nil, "%s", "timeout") + } + + // Output: + // {"h":123456} +} + +func ExamplePipe() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sourceHub := stream.NewHub() + destinationHub := stream.NewHub() + go sourceHub.Run(ctx) + go destinationHub.Run(ctx) + + received := make(chan string, 1) + stopSubscribe := destinationHub.Subscribe("block", func(frame []byte) { + received <- string(frame) + }) + defer stopSubscribe() + + stopPipe := stream.Pipe(sourceHub, destinationHub) + defer stopPipe() + + waitForHub(sourceHub) + waitForHub(destinationHub) + _ = sourceHub.Publish("block", []byte(`{"height":42}`)) + + select { + case frame := <-received: + core.Print(nil, "%s", frame) + case <-time.After(time.Second): + core.Print(nil, "%s", "timeout") + } + + // Output: + // {"height":42} +} + +func ExampleHub_Stats() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hub := stream.NewHub() + go hub.Run(ctx) + + peer := stream.NewPeer("ws") + _ = hub.AddPeer(peer) + defer hub.RemovePeer(peer) + + _ = hub.SubscribePeer(peer, "hashrate") + + stats := hub.Stats() + core.Print(nil, "%d %d %d", stats.Peers, stats.Channels, stats.SubscriberCount["hashrate"]) + + // Output: + // 1 1 1 +} + +func waitForHub(hub *stream.Hub) { + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + if hub.Broadcast(nil) == nil { + return + } + time.Sleep(10 * time.Millisecond) + } +} + +func ExampleNewAPIKeyAuth() { + authenticator := stream.NewAPIKeyAuth(map[string]string{ + "sk-live": "user-42", + }) + + request := httptest.NewRequest("GET", "http://example.com/stream/ws", nil) + request.Header.Set("Authorization", "Bearer sk-live") + + result := authenticator.Authenticate(request) + core.Print(nil, "%t %s", result.Valid, result.UserID) + + // Output: + // true user-42 +} + +func ExampleMessage() { + msg := stream.Message{ + Type: stream.TypeEvent, + Channel: "hashrate", + ProcessID: "agent-42", + Data: map[string]any{"h": 1234567}, + } + + core.Print(nil, "%s %s %s %v", msg.Type, msg.Channel, msg.ProcessID, msg.Data) + + // Output: + // event hashrate agent-42 map[h:1234567] +} + +func ExampleMessageType() { + core.Print(nil, "%s", stream.TypeSubscribe) + + // Output: + // subscribe +} diff --git a/go.mod b/go.mod index 9c05138..1daf5b6 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,20 @@ module dappco.re/go/stream go 1.26.0 -require dappco.re/go/core v0.8.0-alpha.1 +require ( + github.com/alicebob/miniredis/v2 v2.37.0 + github.com/go-zeromq/zmq4 v0.17.0 + github.com/gorilla/websocket v1.5.3 + github.com/redis/go-redis/v9 v9.18.0 +) -require github.com/gorilla/websocket v1.5.3 +require ( + dappco.re/go v0.9.0 + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/go-zeromq/goczmq/v4 v4.2.2 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect + go.uber.org/atomic v1.11.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/text v0.15.0 // indirect +) diff --git a/go.sum b/go.sum index 41034bd..a171336 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,38 @@ -dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk= -dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A= -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +dappco.re/go v0.9.0 h1:4ruZRNqKDDva8o6g65tYggjGVe42E6/lMZfVKXtr3p0= +dappco.re/go v0.9.0/go.mod h1:xapr7fLK4/9Pu2iSCr4qZuIuatmtx1j56zS/oPDbGyQ= +github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68= +github.com/alicebob/miniredis/v2 v2.37.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/go-zeromq/goczmq/v4 v4.2.2 h1:HAJN+i+3NW55ijMJJhk7oWxHKXgAuSBkoFfvr8bYj4U= +github.com/go-zeromq/goczmq/v4 v4.2.2/go.mod h1:Sm/lxrfxP/Oxqs0tnHD6WAhwkWrx+S+1MRrKzcxoaYE= +github.com/go-zeromq/zmq4 v0.17.0 h1:r12/XdqPeRbuaF4C3QZJeWCt7a5vpJbslDH1rTXF+Kc= +github.com/go-zeromq/zmq4 v0.17.0/go.mod h1:EQxjJD92qKnrsVMzAnx62giD6uJIPi1dMGZ781iCDtY= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= +github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/hub.go b/hub.go index 70098ab..2640f83 100644 --- a/hub.go +++ b/hub.go @@ -1,191 +1,253 @@ // SPDX-License-Identifier: EUPL-1.2 +// hub := stream.NewHub() +// go hub.Run(ctx) +// stop := hub.Pipe(remoteHub) +// defer stop() package stream import ( "context" "iter" + "sort" "sync" - "dappco.re/go/core" + "dappco.re/go" ) -// Hub is the central channel-based broker. Transport adapters register peers into -// the hub; the hub serialises all state mutations through Go channels. -// -// hub := stream.NewHub() -// go hub.Run(ctx) +const defaultHubQueueSize = 256 + +// hub := stream.NewHub() +// go hub.Run(ctx) // -// wsAdapter := ws.New(ws.Config{Authenticator: auth}) -// wsAdapter.Mount(hub) -// http.Handle("/stream/ws", wsAdapter.Handler()) +// wsAdapter := ws.New(ws.Config{Authenticator: auth}) +// wsAdapter.Mount(hub) +// http.Handle("/stream/ws", wsAdapter.Handler()) type Hub struct { - peers map[*Peer]bool - broadcast chan []byte - register chan *Peer - unregister chan *Peer - channels map[string]map[*Peer]bool - handlers map[string]map[uint64]func([]byte) - nextID uint64 - config HubConfig - done chan struct{} - doneOnce sync.Once - running bool - mu sync.RWMutex -} - -// NewHub creates a hub with default configuration. -// -// hub := stream.NewHub() -// go hub.Run(ctx) + peers map[*Peer]bool + broadcastQueue chan broadcastDelivery + publishQueue chan publishDelivery + register chan *Peer + unregister chan *Peer + channels map[string]map[*Peer]bool + channelHandlers map[string]map[uint64]func([]byte) + broadcastHandlers map[uint64]func([]byte) + publishHandlers map[uint64]func(string, []byte) + nextHandlerID uint64 + config HubConfig + done chan struct{} + doneOnce sync.Once + running bool + mutex sync.RWMutex +} + +// hub := stream.NewHub() +// go hub.Run(ctx) func NewHub() *Hub { return NewHubWithConfig(DefaultHubConfig()) } -// NewHubWithConfig creates a hub with the given configuration. -// // hub := stream.NewHubWithConfig(stream.HubConfig{ -// HeartbeatInterval: 30 * time.Second, -// OnConnect: func(p *stream.Peer) { log.Println("connected", p.ID) }, +// HeartbeatInterval: 30 * time.Second, +// OnConnect: func(peer *stream.Peer) { core.Print("stream", "connected %s", peer.ID) }, // }) func NewHubWithConfig(config HubConfig) *Hub { - if config.HeartbeatInterval == 0 { - config.HeartbeatInterval = DefaultHubConfig().HeartbeatInterval - } - if config.PongTimeout == 0 { - config.PongTimeout = DefaultHubConfig().PongTimeout - } - if config.WriteTimeout == 0 { - config.WriteTimeout = DefaultHubConfig().WriteTimeout - } + config = normalizeHubConfig(config) return &Hub{ - peers: map[*Peer]bool{}, - broadcast: make(chan []byte, 256), - register: make(chan *Peer, 256), - unregister: make(chan *Peer, 256), - channels: map[string]map[*Peer]bool{}, - handlers: map[string]map[uint64]func([]byte){}, - config: config, - done: make(chan struct{}), + peers: map[*Peer]bool{}, + broadcastQueue: make(chan broadcastDelivery, defaultHubQueueSize), + publishQueue: make(chan publishDelivery, defaultHubQueueSize), + register: make(chan *Peer, defaultHubQueueSize), + unregister: make(chan *Peer, defaultHubQueueSize), + channels: map[string]map[*Peer]bool{}, + channelHandlers: map[string]map[uint64]func([]byte){}, + broadcastHandlers: map[uint64]func([]byte){}, + publishHandlers: map[uint64]func(string, []byte){}, + config: config, + done: make(chan struct{}), } } -// Run starts the hub's select loop. Call in a goroutine. Exits when ctx is cancelled. -// -// go hub.Run(ctx) -func (h *Hub) Run(ctx context.Context) { - if h == nil { +// config := hub.Config() +// writeTimeout := config.WriteTimeout +func (hub *Hub) Config() HubConfig { + if hub == nil { + return DefaultHubConfig() + } + hub.mutex.RLock() + config := hub.config + hub.mutex.RUnlock() + return normalizeHubConfig(config) +} + +// if hub.Running() { _ = hub.Publish("hashrate", frame) } +func (hub *Hub) Running() bool { + if hub == nil { + return false + } + hub.mutex.RLock() + defer hub.mutex.RUnlock() + return hub.running +} + +// go hub.Run(ctx) +func (hub *Hub) Run(ctx context.Context) { + if hub == nil { return } if ctx == nil { ctx = context.Background() } - h.mu.Lock() - if h.running { - h.mu.Unlock() - <-ctx.Done() + hub.mutex.Lock() + if hub.running { + hub.mutex.Unlock() return } - h.running = true - h.mu.Unlock() + hub.running = true + hub.mutex.Unlock() + + defer func() { + hub.mutex.Lock() + peers := make([]*Peer, 0, len(hub.peers)) + for peer := range hub.peers { + peers = append(peers, peer) + } + hub.running = false + hub.mutex.Unlock() - <-ctx.Done() + for _, peer := range peers { + hub.removePeer(peer) + } - h.mu.Lock() - peers := make([]*Peer, 0, len(h.peers)) - for peer := range h.peers { - peers = append(peers, peer) - } - h.running = false - h.mu.Unlock() + hub.doneOnce.Do(func() { + close(hub.done) + }) + }() - for _, peer := range peers { - h.RemovePeer(peer) + for { + select { + case <-ctx.Done(): + return + case peer := <-hub.register: + hub.addPeer(peer) + case peer := <-hub.unregister: + hub.removePeer(peer) + case item := <-hub.broadcastQueue: + hub.broadcastToPeers(item.source, item.frame, item.notifyBroadcastSubscribers) + case item := <-hub.publishQueue: + hub.processPublishDelivery(item.channel, item.frame, item.notifyPublishSubscribers) + } } +} - h.doneOnce.Do(func() { - close(h.done) - }) +// hub.SendToChannel("hashrate", []byte(`{"h":123456}`)) +func (hub *Hub) SendToChannel(channel string, frame []byte) error { + return hub.sendToChannel(channel, frame, true) } -// SendToChannel delivers frame to all peers subscribed to channel. -// Returns nil if channel has no subscribers (not an error). -// -// hub.SendToChannel("process:abc123", frame) -func (h *Hub) SendToChannel(channel string, frame []byte) error { - if h == nil { +// _ = hub.PublishFromPeer(peer, "block", []byte("template")) +func (hub *Hub) PublishFromPeer(source *Peer, channel string, frame []byte) error { + return hub.sendToChannelFromPeer(source, channel, frame, true) +} + +// _ = hub.PublishFromBridge("block", []byte("template")) +func (hub *Hub) PublishFromBridge(channel string, frame []byte) error { + return hub.sendToChannel(channel, frame, false) +} + +// _ = hub.sendToChannel("hashrate", []byte("123456"), true) +func (hub *Hub) sendToChannel(channel string, frame []byte, notifyPublishSubscribers bool) error { + return hub.sendToChannelFromPeer(nil, channel, frame, notifyPublishSubscribers) +} + +// _ = hub.sendToChannelFromPeer(peer, "hashrate", []byte("123456"), true) +func (hub *Hub) sendToChannelFromPeer(source *Peer, channel string, frame []byte, notifyPublishSubscribers bool) error { + if hub == nil { return core.E("stream.hub", "nil hub", nil) } - h.mu.RLock() - running := h.running - peers := h.channels[channel] - wildcardPeers := h.channels["*"] - if channel == "*" { - wildcardPeers = nil - } - handlers := cloneHandlers(h.handlers[channel]) - wildcardHandlers := cloneHandlers(h.handlers["*"]) - h.mu.RUnlock() + hub.mutex.RLock() + running := hub.running + peersToSend := hub.collectChannelPeersLocked(channel, source) + hasHandlers := len(hub.channelHandlers[channel]) > 0 + hasWildcardHandlers := len(hub.channelHandlers["*"]) > 0 && channel != "*" + hasPublishers := notifyPublishSubscribers && len(hub.publishHandlers) > 0 + hub.mutex.RUnlock() if !running { return ErrHubNotRunning } - if len(peers) == 0 && len(handlers) == 0 && len(wildcardHandlers) == 0 { + if len(peersToSend) == 0 && !hasHandlers && !hasWildcardHandlers && !hasPublishers { return nil } - for peer := range peers { - h.sendToPeer(peer, channel, frame) + for _, peer := range peersToSend { + hub.sendToPeer(peer, channel, frame) } - for peer := range wildcardPeers { - h.sendToPeer(peer, channel, frame) - } - h.invokeHandlers(handlers, frame) - h.invokeHandlers(wildcardHandlers, frame) + hub.enqueuePublishDelivery(channel, frame, notifyPublishSubscribers) return nil } -// Subscribe registers a handler function invoked for every frame arriving on channel. -// Returns an unsubscribe function. Multiple handlers per channel are allowed. -// Handlers run in the hub's goroutine — keep them non-blocking. +// unsub, err := hub.SubscribeWithError("block", func(frame []byte) { +// handleBlock(frame) +// }) // -// unsub := hub.Subscribe("block", func(f []byte) { ... }) -// defer unsub() -func (h *Hub) Subscribe(channel string, handler func([]byte)) func() { - if h == nil || channel == "" || handler == nil { - return func() {} +// if err != nil { +// return err +// } +// +// defer unsub() +func (hub *Hub) SubscribeWithError(channel string, handler func([]byte)) (func(), error) { + if hub == nil { + return func() {}, core.E("stream.hub", "nil hub", nil) + } + if channel == "" { + return func() {}, ErrEmptyChannel } - h.mu.Lock() - if h.handlers == nil { - h.handlers = map[string]map[uint64]func([]byte){} + if handler == nil { + return func() {}, core.E("stream.hub", "nil handler", nil) } - if h.channels == nil { - h.channels = map[string]map[*Peer]bool{} + hub.mutex.Lock() + if hub.channelHandlers == nil { + hub.channelHandlers = map[string]map[uint64]func([]byte){} } - h.nextID++ - id := h.nextID - if h.handlers[channel] == nil { - h.handlers[channel] = map[uint64]func([]byte){} + if hub.channels == nil { + hub.channels = map[string]map[*Peer]bool{} } - h.handlers[channel][id] = handler - h.mu.Unlock() + hub.nextHandlerID++ + id := hub.nextHandlerID + if hub.channelHandlers[channel] == nil { + hub.channelHandlers[channel] = map[uint64]func([]byte){} + } + hub.channelHandlers[channel][id] = handler + hub.mutex.Unlock() - return func() { - h.mu.Lock() - defer h.mu.Unlock() - if handlers := h.handlers[channel]; handlers != nil { + return onceFunction(func() { + hub.mutex.Lock() + defer hub.mutex.Unlock() + if handlers := hub.channelHandlers[channel]; handlers != nil { delete(handlers, id) if len(handlers) == 0 { - delete(h.handlers, channel) + delete(hub.channelHandlers, channel) } } - } + }), nil } -// SubscribePeer adds peer to a named channel. Used by transport adapters when -// a peer requests channel subscription (WebSocket TypeSubscribe message, etc.). -// -// hub.SubscribePeer(peer, "hashrate") -func (h *Hub) SubscribePeer(peer *Peer, channel string) error { - if h == nil { +// unsub, err := hub.SubscribeE("block", func(frame []byte) { +// handleBlock(frame) +// }) +func (hub *Hub) SubscribeE(channel string, handler func([]byte)) (func(), error) { + return hub.SubscribeWithError(channel, handler) +} + +// unsubscribe := hub.Subscribe("block", func(frame []byte) { handleBlock(frame) }) +// defer unsubscribe() +func (hub *Hub) Subscribe(channel string, handler func([]byte)) func() { + unsub, _ := hub.SubscribeWithError(channel, handler) + return unsub +} + +// _ = hub.SubscribePeer(peer, "hashrate") +func (hub *Hub) SubscribePeer(peer *Peer, channel string) error { + if hub == nil { return core.E("stream.hub", "nil hub", nil) } if peer == nil { @@ -194,133 +256,217 @@ func (h *Hub) SubscribePeer(peer *Peer, channel string) error { if channel == "" { return ErrEmptyChannel } - h.mu.Lock() - defer h.mu.Unlock() - if h.config.ChannelAuthoriser != nil && channel != "*" && !h.config.ChannelAuthoriser(peer, channel) { + hub.mutex.Lock() + defer hub.mutex.Unlock() + if hub.config.ChannelAuthoriser != nil && channel != "*" && !hub.config.ChannelAuthoriser(peer, channel) { return ErrAuthRejected } + peer.mutex.Lock() if peer.send == nil { - peer.send = make(chan []byte, 256) + peer.send = make(chan []byte, defaultPeerSendBufferSize) } if peer.subscriptions == nil { peer.subscriptions = map[string]bool{} } peer.subscriptions[channel] = true - if h.channels[channel] == nil { - h.channels[channel] = map[*Peer]bool{} + peer.mutex.Unlock() + if hub.channels[channel] == nil { + hub.channels[channel] = map[*Peer]bool{} } - h.channels[channel][peer] = true + hub.channels[channel][peer] = true return nil } -// UnsubscribePeer removes peer from a named channel. -// -// hub.UnsubscribePeer(peer, "hashrate") -func (h *Hub) UnsubscribePeer(peer *Peer, channel string) { - if h == nil || peer == nil || channel == "" { +// err := hub.CanSubscribePeer(peer, "hashrate") +func (hub *Hub) CanSubscribePeer(peer *Peer, channel string) error { + if hub == nil { + return core.E("stream.hub", "nil hub", nil) + } + if peer == nil { + return core.E("stream.hub", "nil peer", nil) + } + if channel == "" { + return ErrEmptyChannel + } + hub.mutex.RLock() + defer hub.mutex.RUnlock() + if hub.config.ChannelAuthoriser != nil && channel != "*" && !hub.config.ChannelAuthoriser(peer, channel) { + return ErrAuthRejected + } + return nil +} + +// hub.UnsubscribePeer(peer, "hashrate") +func (hub *Hub) UnsubscribePeer(peer *Peer, channel string) { + if hub == nil || peer == nil || channel == "" { return } - h.mu.Lock() - defer h.mu.Unlock() + hub.mutex.Lock() + defer hub.mutex.Unlock() + peer.mutex.Lock() delete(peer.subscriptions, channel) - if peers := h.channels[channel]; peers != nil { + peer.mutex.Unlock() + if peers := hub.channels[channel]; peers != nil { delete(peers, peer) if len(peers) == 0 { - delete(h.channels, channel) + delete(hub.channels, channel) } } } -// Publish sends frame to all subscribers of channel. Satisfies Stream interface. -// -// hub.Publish("hashrate", frame) -func (h *Hub) Publish(channel string, frame []byte) error { - return h.SendToChannel(channel, frame) +// _ = hub.Publish("hashrate", []byte(`{"h":123456}`)) +func (hub *Hub) Publish(channel string, frame []byte) error { + return hub.sendToChannel(channel, frame, true) } -// Broadcast sends frame to every connected peer regardless of subscriptions. -// Satisfies Stream interface. -// -// hub.Broadcast([]byte(`{"type":"shutdown"}`)) -func (h *Hub) Broadcast(frame []byte) error { - if h == nil { +// _ = hub.Broadcast([]byte(`{"type":"shutdown"}`)) +func (hub *Hub) Broadcast(frame []byte) error { + return hub.broadcastFrame(frame, true) +} + +// _ = hub.BroadcastFromPeer(peer, []byte("shutdown")) +func (hub *Hub) BroadcastFromPeer(source *Peer, frame []byte) error { + return hub.broadcastFrameFromPeer(source, frame, true) +} + +// _ = hub.BroadcastFromBridge([]byte("shutdown")) +func (hub *Hub) BroadcastFromBridge(frame []byte) error { + return hub.broadcastFrame(frame, false) +} + +// _ = hub.broadcastFrame(frame, true) +func (hub *Hub) broadcastFrame(frame []byte, notifyBroadcastSubscribers bool) error { + return hub.broadcastFrameFromPeer(nil, frame, notifyBroadcastSubscribers) +} + +// _ = hub.broadcastFrameFromPeer(peer, frame, true) +func (hub *Hub) broadcastFrameFromPeer(source *Peer, frame []byte, notifyBroadcastSubscribers bool) error { + if hub == nil { return core.E("stream.hub", "nil hub", nil) } - h.mu.RLock() - running := h.running - peers := make([]*Peer, 0, len(h.peers)) - for peer := range h.peers { - peers = append(peers, peer) - } - handlers := cloneHandlers(h.handlers["*"]) - h.mu.RUnlock() + hub.mutex.RLock() + running := hub.running + hub.mutex.RUnlock() if !running { return ErrHubNotRunning } - for _, peer := range peers { - h.sendBroadcastToPeer(peer, frame) + select { + case hub.broadcastQueue <- broadcastDelivery{ + source: source, + frame: append([]byte(nil), frame...), + notifyBroadcastSubscribers: notifyBroadcastSubscribers, + }: + return nil + default: + go hub.enqueueBroadcast(broadcastDelivery{ + source: source, + frame: append([]byte(nil), frame...), + notifyBroadcastSubscribers: notifyBroadcastSubscribers, + }) } - h.invokeHandlers(handlers, frame) return nil } -// Pipe connects this hub to dst: every frame published here is forwarded to dst. -// Returns a stop function. Satisfies Stream interface. -// -// stop := hub.Pipe(remoteHub) -// defer stop() -func (h *Hub) Pipe(dst Stream) func() { - return Pipe(h, dst) +// stop := hub.Pipe(remoteHub) +func (hub *Hub) Pipe(dst Stream) func() { + return Pipe(hub, dst) } -// Stats returns a snapshot of current hub state. -// -// s := hub.Stats() -// core.Print("stream", "peers=%d channels=%d", s.Peers, s.Channels) -func (h *Hub) Stats() HubStats { - if h == nil { +// stats := hub.Stats() +func (hub *Hub) Stats() HubStats { + if hub == nil { return HubStats{} } - h.mu.RLock() - defer h.mu.RUnlock() + hub.mutex.RLock() + defer hub.mutex.RUnlock() subscriberCount := map[string]int{} - for channel, peers := range h.channels { + for channel, peers := range hub.channels { if channel == "*" { continue } - subscriberCount[channel] = len(peers) + subscriberCount[channel] = len(peers) + len(hub.channelHandlers[channel]) + } + for channel, handlers := range hub.channelHandlers { + if channel == "*" { + continue + } + if _, exists := subscriberCount[channel]; exists { + continue + } + subscriberCount[channel] = len(handlers) } return HubStats{ - Peers: len(h.peers), + Peers: len(hub.peers), Channels: len(subscriberCount), SubscriberCount: subscriberCount, } } -// PeerCount returns the number of connected peers. -// -// n := hub.PeerCount() -func (h *Hub) PeerCount() int { - if h == nil { +// stop := hub.SubscribePublished(func(channel string, frame []byte) { +// core.Print("stream", "channel=%s frame=%d", channel, len(frame)) +// }) +func (hub *Hub) SubscribePublished(handler func(string, []byte)) func() { + return hub.subscribePublished(handler) +} + +// stop := hub.SubscribeBroadcast(func(frame []byte) { +// core.Print("stream", "broadcast frame=%d", len(frame)) +// }) +func (hub *Hub) SubscribeBroadcast(handler func([]byte)) func() { + if hub == nil || handler == nil { + return func() {} + } + hub.mutex.Lock() + if hub.broadcastHandlers == nil { + hub.broadcastHandlers = map[uint64]func([]byte){} + } + hub.nextHandlerID++ + id := hub.nextHandlerID + hub.broadcastHandlers[id] = handler + hub.mutex.Unlock() + + return onceFunction(func() { + hub.mutex.Lock() + defer hub.mutex.Unlock() + delete(hub.broadcastHandlers, id) + }) +} + +// n := hub.PeerCount() +func (hub *Hub) PeerCount() int { + if hub == nil { return 0 } - h.mu.RLock() - defer h.mu.RUnlock() - return len(h.peers) + hub.mutex.RLock() + defer hub.mutex.RUnlock() + return len(hub.peers) } -// ChannelCount returns the number of active channels. -// -// n := hub.ChannelCount() -func (h *Hub) ChannelCount() int { - if h == nil { +// n := hub.ChannelCount() +func (hub *Hub) ChannelCount() int { + if hub == nil { return 0 } - h.mu.RLock() - defer h.mu.RUnlock() + hub.mutex.RLock() + defer hub.mutex.RUnlock() count := 0 - for channel, peers := range h.channels { - if channel == "*" || len(peers) == 0 { + for channel, peers := range hub.channels { + if channel == "*" { + continue + } + if len(peers)+len(hub.channelHandlers[channel]) == 0 { + continue + } + count++ + } + for channel, handlers := range hub.channelHandlers { + if channel == "*" { + continue + } + if len(handlers) == 0 { + continue + } + if len(hub.channels[channel]) > 0 { continue } count++ @@ -328,32 +474,38 @@ func (h *Hub) ChannelCount() int { return count } -// ChannelSubscriberCount returns the subscriber count for a channel. -// Returns 0 if the channel has no subscribers. -// -// n := hub.ChannelSubscriberCount("hashrate") -func (h *Hub) ChannelSubscriberCount(channel string) int { - if h == nil { +// n := hub.ChannelSubscriberCount("hashrate") +func (hub *Hub) ChannelSubscriberCount(channel string) int { + if hub == nil { return 0 } - h.mu.RLock() - defer h.mu.RUnlock() - return len(h.channels[channel]) + hub.mutex.RLock() + defer hub.mutex.RUnlock() + return len(hub.channels[channel]) + len(hub.channelHandlers[channel]) } -// AllPeers returns an iterator for all connected peers. -// -// for peer := range hub.AllPeers() { log.Println(peer.UserID) } -func (h *Hub) AllPeers() iter.Seq[*Peer] { - if h == nil { +// for peer := range hub.AllPeers() { +// _ = peer.UserID +// } +func (hub *Hub) AllPeers() iter.Seq[*Peer] { + if hub == nil { return func(yield func(*Peer) bool) {} } - h.mu.RLock() - peers := make([]*Peer, 0, len(h.peers)) - for peer := range h.peers { + hub.mutex.RLock() + peers := make([]*Peer, 0, len(hub.peers)) + for peer := range hub.peers { peers = append(peers, peer) } - h.mu.RUnlock() + hub.mutex.RUnlock() + sort.SliceStable(peers, func(left, right int) bool { + if peers[left] == nil { + return false + } + if peers[right] == nil { + return true + } + return peers[left].ID < peers[right].ID + }) return func(yield func(*Peer) bool) { for _, peer := range peers { if !yield(peer) { @@ -363,24 +515,35 @@ func (h *Hub) AllPeers() iter.Seq[*Peer] { } } -// AllChannels returns an iterator for all active channels. -// -// for ch := range hub.AllChannels() { log.Println(ch) } -func (h *Hub) AllChannels() iter.Seq[string] { - if h == nil { +// for channel := range hub.AllChannels() { +// _ = channel +// } +func (hub *Hub) AllChannels() iter.Seq[string] { + if hub == nil { return func(yield func(string) bool) {} } - h.mu.RLock() - channels := make([]string, 0, len(h.channels)) - for channel, peers := range h.channels { - if channel == "*" || len(peers) == 0 { + hub.mutex.RLock() + channels := make(map[string]struct{}, len(hub.channels)+len(hub.channelHandlers)) + for channel, peers := range hub.channels { + if channel == "*" || len(peers)+len(hub.channelHandlers[channel]) == 0 { continue } - channels = append(channels, channel) + channels[channel] = struct{}{} + } + for channel, handlers := range hub.channelHandlers { + if channel == "*" || len(handlers) == 0 { + continue + } + channels[channel] = struct{}{} + } + hub.mutex.RUnlock() + sortedChannels := make([]string, 0, len(channels)) + for channel := range channels { + sortedChannels = append(sortedChannels, channel) } - h.mu.RUnlock() + sort.Strings(sortedChannels) return func(yield func(string) bool) { - for _, channel := range channels { + for _, channel := range sortedChannels { if !yield(channel) { return } @@ -388,103 +551,347 @@ func (h *Hub) AllChannels() iter.Seq[string] { } } -// AddPeer registers a peer with the hub and invokes OnConnect. -// -// hub.AddPeer(stream.NewPeer("ws")) -func (h *Hub) AddPeer(peer *Peer) error { - if h == nil { +// peer := stream.NewPeer("ws") +// _ = hub.AddPeer(peer) +func (hub *Hub) AddPeer(peer *Peer) error { + if hub == nil { return core.E("stream.hub", "nil hub", nil) } if peer == nil { return core.E("stream.hub", "nil peer", nil) } + peer.mutex.Lock() if peer.send == nil { - peer.send = make(chan []byte, 256) + peer.send = make(chan []byte, defaultPeerSendBufferSize) } if peer.subscriptions == nil { peer.subscriptions = map[string]bool{} } - h.mu.Lock() - if h.peers == nil { - h.peers = map[*Peer]bool{} + peer.mutex.Unlock() + hub.mutex.RLock() + running := hub.running + hub.mutex.RUnlock() + if running { + select { + case hub.register <- peer: + return nil + default: + } } - if h.peers[peer] { - h.mu.Unlock() - return nil + hub.addPeer(peer) + return nil +} + +// hub.RemovePeer(peer) +func (hub *Hub) RemovePeer(peer *Peer) { + if hub == nil || peer == nil { + return + } + hub.mutex.RLock() + running := hub.running + hub.mutex.RUnlock() + if running { + select { + case hub.unregister <- peer: + return + default: + } + } + hub.removePeer(peer) +} + +// hub.sendToPeer(peer, "hashrate", []byte("123456")) +func (hub *Hub) sendToPeer(peer *Peer, channel string, frame []byte) { + if peer == nil { + return + } + if peer.Transport == "tcp" { + if ok := peer.Send(encodeTCPFrame(channel, frame)); !ok { + return + } + return + } + if ok := peer.Send(frame); !ok { + return } - h.peers[peer] = true - onConnect := h.config.OnConnect - h.mu.Unlock() +} + +// hub.sendBroadcastToPeer(peer, []byte("shutdown")) +func (hub *Hub) sendBroadcastToPeer(peer *Peer, frame []byte) { + if peer == nil { + return + } + if peer.Transport == "tcp" { + if ok := peer.Send(encodeTCPFrame("", frame)); !ok { + return + } + return + } + if ok := peer.Send(frame); !ok { + return + } +} + +// hub.invokeHandlers(handlers, frame) +func (hub *Hub) invokeHandlers(handlers []func([]byte), frame []byte) { + for _, handler := range handlers { + func(handlerFunction func([]byte)) { + defer func() { + if recovered := recover(); recovered != nil { + return + } + }() + handlerFunction(frame) + }(handler) + } +} + +// hub.addPeer(peer) +func (hub *Hub) addPeer(peer *Peer) { + if hub == nil || peer == nil { + return + } + hub.mutex.Lock() + if hub.peers == nil { + hub.peers = map[*Peer]bool{} + } + if hub.peers[peer] { + hub.mutex.Unlock() + return + } + hub.peers[peer] = true + onConnect := hub.config.OnConnect + hub.mutex.Unlock() if onConnect != nil { onConnect(peer) } - return nil } -// RemovePeer unregisters a peer from the hub and invokes OnDisconnect. -// -// hub.RemovePeer(peer) -func (h *Hub) RemovePeer(peer *Peer) { - if h == nil || peer == nil { +// hub.removePeer(peer) +func (hub *Hub) removePeer(peer *Peer) { + if hub == nil || peer == nil { return } - h.mu.Lock() - if !h.peers[peer] { - h.mu.Unlock() + hub.mutex.Lock() + if !hub.peers[peer] { + hub.mutex.Unlock() return } - delete(h.peers, peer) - for channel, peers := range h.channels { + delete(hub.peers, peer) + for channel, peers := range hub.channels { delete(peers, peer) if len(peers) == 0 { - delete(h.channels, channel) + delete(hub.channels, channel) } } - peer.mu.Lock() + peer.mutex.Lock() peer.subscriptions = map[string]bool{} - peer.mu.Unlock() - onDisconnect := h.config.OnDisconnect - h.mu.Unlock() + peer.mutex.Unlock() + onDisconnect := hub.config.OnDisconnect + hub.mutex.Unlock() peer.Close() if onDisconnect != nil { onDisconnect(peer) } } -func (h *Hub) sendToPeer(peer *Peer, channel string, frame []byte) { - if peer == nil { +// hub.broadcastToPeers(nil, frame, true) +func (hub *Hub) broadcastToPeers(_ *Peer, frame []byte, notifyBroadcastSubscribers bool) { + if hub == nil { return } - if peer.Transport == "tcp" { - _ = peer.Send(encodeTCPFrame(channel, frame)) + hub.mutex.RLock() + peers := make([]*Peer, 0, len(hub.peers)) + for peer := range hub.peers { + peers = append(peers, peer) + } + handlers := cloneChannelHandlers(hub.channelHandlers["*"]) + broadcastHandlers := cloneBroadcastHandlers(hub.broadcastHandlers) + hub.mutex.RUnlock() + for _, peer := range peers { + hub.sendBroadcastToPeer(peer, frame) + } + hub.invokeHandlers(handlers, frame) + if notifyBroadcastSubscribers { + hub.invokeBroadcastHandlers(broadcastHandlers, frame) + } +} + +// item := publishDelivery{channel: "block", frame: data, notifyPublishSubscribers: true} +type publishDelivery struct { + channel string + frame []byte + notifyPublishSubscribers bool +} + +// item := broadcastDelivery{frame: data, notifyBroadcastSubscribers: true} +type broadcastDelivery struct { + source *Peer + frame []byte + notifyBroadcastSubscribers bool +} + +// hub.enqueuePublishDelivery("hashrate", frame, true) +func (hub *Hub) enqueuePublishDelivery(channel string, frame []byte, notifyPublishSubscribers bool) { + if hub == nil { return } - _ = peer.Send(frame) + item := publishDelivery{ + channel: channel, + frame: append([]byte(nil), frame...), + notifyPublishSubscribers: notifyPublishSubscribers, + } + select { + case hub.publishQueue <- item: + default: + go hub.enqueuePublishDeliveryAsync(item) + } } -func (h *Hub) sendBroadcastToPeer(peer *Peer, frame []byte) { - if peer == nil { +// hub.enqueueBroadcast(broadcastDelivery{frame: data}) +func (hub *Hub) enqueueBroadcast(item broadcastDelivery) { + if hub == nil { return } - if peer.Transport == "tcp" { - _ = peer.Send(encodeTCPFrame("", frame)) + select { + case hub.broadcastQueue <- item: + case <-hub.done: + } +} + +// go hub.enqueuePublishDeliveryAsync(publishDelivery{channel: "block", frame: data}) +func (hub *Hub) enqueuePublishDeliveryAsync(item publishDelivery) { + if hub == nil { return } - _ = peer.Send(frame) + select { + case hub.publishQueue <- item: + case <-hub.done: + } +} + +// hub.processPublishDelivery("hashrate", frame, true) +func (hub *Hub) processPublishDelivery(channel string, frame []byte, notifyPublishSubscribers bool) { + if hub == nil { + return + } + hub.mutex.RLock() + handlers := cloneChannelHandlers(hub.channelHandlers[channel]) + wildcardHandlers := cloneChannelHandlers(hub.channelHandlers["*"]) + publishHandlers := clonePublishHandlers(hub.publishHandlers) + hub.mutex.RUnlock() + + hub.invokeHandlers(handlers, frame) + if channel != "*" { + hub.invokeHandlers(wildcardHandlers, frame) + } + if notifyPublishSubscribers { + hub.invokePublishHandlers(publishHandlers, channel, frame) + } +} + +// stop := hub.subscribePublished(func(channel string, frame []byte) { ... }) +func (hub *Hub) subscribePublished(handler func(string, []byte)) func() { + if hub == nil || handler == nil { + return func() {} + } + hub.mutex.Lock() + if hub.publishHandlers == nil { + hub.publishHandlers = map[uint64]func(string, []byte){} + } + hub.nextHandlerID++ + id := hub.nextHandlerID + hub.publishHandlers[id] = handler + hub.mutex.Unlock() + + return onceFunction(func() { + hub.mutex.Lock() + defer hub.mutex.Unlock() + delete(hub.publishHandlers, id) + }) +} + +// hub.invokeBroadcastHandlers(handlers, frame) +func (hub *Hub) invokeBroadcastHandlers(handlers []func([]byte), frame []byte) { + for _, handler := range handlers { + func(handlerFunction func([]byte)) { + defer func() { + if recovered := recover(); recovered != nil { + return + } + }() + handlerFunction(frame) + }(handler) + } } -func (h *Hub) invokeHandlers(handlers []func([]byte), frame []byte) { +// hub.invokePublishHandlers(handlers, "block", frame) +func (hub *Hub) invokePublishHandlers(handlers []func(string, []byte), channel string, frame []byte) { for _, handler := range handlers { - func(fn func([]byte)) { + func(handlerFunction func(string, []byte)) { defer func() { - _ = recover() + if recovered := recover(); recovered != nil { + return + } }() - fn(frame) + handlerFunction(channel, frame) }(handler) } } -func cloneHandlers(handlers map[uint64]func([]byte)) []func([]byte) { +// peers := hub.collectChannelPeersLocked("hashrate", nil) +func (hub *Hub) collectChannelPeersLocked(channel string, _ *Peer) []*Peer { + combined := map[*Peer]struct{}{} + for peer := range hub.channels[channel] { + combined[peer] = struct{}{} + } + if channel != "*" { + for peer := range hub.channels["*"] { + combined[peer] = struct{}{} + } + } + peers := make([]*Peer, 0, len(combined)) + for peer := range combined { + peers = append(peers, peer) + } + sort.SliceStable(peers, func(left, right int) bool { + if peers[left] == nil { + return false + } + if peers[right] == nil { + return true + } + return peers[left].ID < peers[right].ID + }) + return peers +} + +// cloned := cloneChannelHandlers(hub.channelHandlers["hashrate"]) +func cloneChannelHandlers(handlers map[uint64]func([]byte)) []func([]byte) { + if len(handlers) == 0 { + return nil + } + cloned := make([]func([]byte), 0, len(handlers)) + for _, handler := range handlers { + cloned = append(cloned, handler) + } + return cloned +} + +// cloned := clonePublishHandlers(hub.publishHandlers) +func clonePublishHandlers(handlers map[uint64]func(string, []byte)) []func(string, []byte) { + if len(handlers) == 0 { + return nil + } + cloned := make([]func(string, []byte), 0, len(handlers)) + for _, handler := range handlers { + cloned = append(cloned, handler) + } + return cloned +} + +// cloned := cloneBroadcastHandlers(hub.broadcastHandlers) +func cloneBroadcastHandlers(handlers map[uint64]func([]byte)) []func([]byte) { if len(handlers) == 0 { return nil } diff --git a/hub_config.go b/hub_config.go index b2a5be5..d10ec65 100644 --- a/hub_config.go +++ b/hub_config.go @@ -4,48 +4,42 @@ package stream import "time" -// HubConfig controls hub behaviour and lifecycle callbacks. -// -// cfg := stream.HubConfig{ +// authoriser := stream.ChannelAuthoriser(func(peer *stream.Peer, channel string) bool { +// return peer.Claims["role"] == "admin" || channel == "public" +// }) +type ChannelAuthoriser func(peer *Peer, channel string) bool + +// config := stream.HubConfig{ // HeartbeatInterval: 30 * time.Second, -// OnConnect: func(p *stream.Peer) { metrics.Inc("peers") }, -// ChannelAuthoriser: func(p *stream.Peer, ch string) bool { -// return p.Claims["role"] == "admin" || ch == "public" +// PongTimeout: 60 * time.Second, +// WriteTimeout: 10 * time.Second, +// OnConnect: func(peer *stream.Peer) { +// metrics.Inc("peers") // }, // } type HubConfig struct { - // HeartbeatInterval is the server-side ping interval for WebSocket peers. - // Defaults to 30 seconds. Ignored by SSE and TCP adapters. + // config := stream.HubConfig{HeartbeatInterval: 30 * time.Second} HeartbeatInterval time.Duration - // PongTimeout is the deadline after a ping before the WS connection is closed. - // Must be greater than HeartbeatInterval. Defaults to 60 seconds. + // config := stream.HubConfig{PongTimeout: 60 * time.Second} PongTimeout time.Duration - // WriteTimeout is the per-write deadline for WS and TCP adapters. - // Defaults to 10 seconds. + // config := stream.HubConfig{WriteTimeout: 10 * time.Second} WriteTimeout time.Duration - // OnConnect is called when a peer registers. Optional. - // - // OnConnect: func(p *stream.Peer) { metrics.Inc("peers") }, + // config := stream.HubConfig{OnConnect: func(peer *stream.Peer) { metrics.Inc("peers") }} OnConnect func(peer *Peer) - // OnDisconnect is called when a peer unregisters. Optional. + // config := stream.HubConfig{OnDisconnect: func(peer *stream.Peer) { metrics.Dec("peers") }} OnDisconnect func(peer *Peer) - // ChannelAuthoriser optionally decides whether a peer may subscribe to a channel. - // Return true to allow. When nil, all subscriptions are allowed. - // - // ChannelAuthoriser: func(p *stream.Peer, ch string) bool { - // return p.Claims["role"] == "admin" || ch == "public" - // }, - ChannelAuthoriser func(peer *Peer, channel string) bool + // config := stream.HubConfig{ChannelAuthoriser: func(peer *stream.Peer, channel string) bool { + // return peer.Claims["role"] == "admin" || channel == "public" + // }} + ChannelAuthoriser ChannelAuthoriser } -// DefaultHubConfig returns sensible defaults. -// -// cfg := stream.DefaultHubConfig() +// config := stream.DefaultHubConfig() func DefaultHubConfig() HubConfig { return HubConfig{ HeartbeatInterval: 30 * time.Second, @@ -53,3 +47,21 @@ func DefaultHubConfig() HubConfig { WriteTimeout: 10 * time.Second, } } + +// config = normalizeHubConfig(config) +func normalizeHubConfig(config HubConfig) HubConfig { + defaults := DefaultHubConfig() + if config.HeartbeatInterval == 0 { + config.HeartbeatInterval = defaults.HeartbeatInterval + } + if config.PongTimeout == 0 { + config.PongTimeout = defaults.PongTimeout + } + if config.PongTimeout <= config.HeartbeatInterval { + config.PongTimeout = config.HeartbeatInterval * 2 + } + if config.WriteTimeout == 0 { + config.WriteTimeout = defaults.WriteTimeout + } + return config +} diff --git a/hub_test.go b/hub_test.go new file mode 100644 index 0000000..accb1c7 --- /dev/null +++ b/hub_test.go @@ -0,0 +1,1066 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package stream + +import ( + "context" + "sync" + "testing" + "time" +) + +type testStream struct { + mutex sync.Mutex + subscribers map[string]map[int]func([]byte) + nextID int + published []publishedFrame + broadcasts [][]byte +} + +type publishedFrame struct { + channel string + frame []byte +} + +func newTestStream() *testStream { + return &testStream{ + subscribers: map[string]map[int]func([]byte){}, + } +} + +func TestHub_NewPeer_DefaultClaims_Good(t *testing.T) { + peer := NewPeer("ws") + if peer == nil { + t.Fatal("NewPeer() = nil") + } + if peer.Claims == nil { + t.Fatal("NewPeer().Claims = nil, want empty map") + } + if len(peer.Claims) != 0 { + t.Fatalf("len(NewPeer().Claims) = %d, want 0", len(peer.Claims)) + } + peer.Claims["role"] = "worker" + if role := peer.Claims["role"]; role != "worker" { + t.Fatalf("Claims[role] = %v, want %q", role, "worker") + } +} + +func (streamValue *testStream) Publish(channel string, frame []byte) error { + streamValue.mutex.Lock() + streamValue.published = append(streamValue.published, publishedFrame{ + channel: channel, + frame: append([]byte(nil), frame...), + }) + handlers := streamValue.cloneHandlersLocked(channel) + wildcardHandlers := streamValue.cloneHandlersLocked("*") + streamValue.mutex.Unlock() + + for _, handler := range handlers { + handler(frame) + } + if channel != "*" { + for _, handler := range wildcardHandlers { + handler(frame) + } + } + return nil +} + +func (streamValue *testStream) Subscribe(channel string, handler func([]byte)) func() { + streamValue.mutex.Lock() + defer streamValue.mutex.Unlock() + streamValue.nextID++ + id := streamValue.nextID + if streamValue.subscribers[channel] == nil { + streamValue.subscribers[channel] = map[int]func([]byte){} + } + streamValue.subscribers[channel][id] = handler + return func() { + streamValue.mutex.Lock() + defer streamValue.mutex.Unlock() + delete(streamValue.subscribers[channel], id) + if len(streamValue.subscribers[channel]) == 0 { + delete(streamValue.subscribers, channel) + } + } +} + +func (streamValue *testStream) Broadcast(frame []byte) error { + streamValue.mutex.Lock() + defer streamValue.mutex.Unlock() + streamValue.broadcasts = append(streamValue.broadcasts, append([]byte(nil), frame...)) + return nil +} + +func (streamValue *testStream) Pipe(dst Stream) func() { + return Pipe(streamValue, dst) +} + +func (streamValue *testStream) Stats() HubStats { + return HubStats{} +} + +func (streamValue *testStream) cloneHandlersLocked(channel string) []func([]byte) { + handlers := streamValue.subscribers[channel] + if len(handlers) == 0 { + return nil + } + cloned := make([]func([]byte), 0, len(handlers)) + for _, handler := range handlers { + cloned = append(cloned, handler) + } + return cloned +} + +func TestAX7_Hub_Pipe_Good(t *testing.T) { + sourceHub := NewHub() + destinationHub := NewHub() + + sourceContext, sourceCancel := context.WithCancel(context.Background()) + defer sourceCancel() + destinationContext, destinationCancel := context.WithCancel(context.Background()) + defer destinationCancel() + + go sourceHub.Run(sourceContext) + go destinationHub.Run(destinationContext) + waitForRunningHub(t, sourceHub) + waitForRunningHub(t, destinationHub) + + stop := Pipe(sourceHub, destinationHub) + defer stop() + + received := make(chan []byte, 1) + unsubscribe := destinationHub.Subscribe("hashrate", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + if err := sourceHub.Publish("hashrate", []byte("123456")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "123456" { + t.Fatalf("received frame = %q, want %q", string(frame), "123456") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for forwarded frame") + } +} + +func TestHub_Pipe_Broadcast_Good(t *testing.T) { + sourceHub := NewHub() + destinationHub := NewHub() + + sourceContext, sourceCancel := context.WithCancel(context.Background()) + defer sourceCancel() + destinationContext, destinationCancel := context.WithCancel(context.Background()) + defer destinationCancel() + + go sourceHub.Run(sourceContext) + go destinationHub.Run(destinationContext) + waitForRunningHub(t, sourceHub) + waitForRunningHub(t, destinationHub) + + stop := Pipe(sourceHub, destinationHub) + defer stop() + + received := make(chan []byte, 1) + unsubscribe := destinationHub.SubscribeBroadcast(func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + if err := sourceHub.Broadcast([]byte("shutdown")); err != nil { + t.Fatalf("Broadcast() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "shutdown" { + t.Fatalf("received broadcast frame = %q, want %q", string(frame), "shutdown") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for broadcast frame") + } +} + +func TestAX7_Hub_Pipe_Bad(t *testing.T) { + sourceHub := NewHub() + destinationHub := NewHub() + + sourceContext, sourceCancel := context.WithCancel(context.Background()) + defer sourceCancel() + destinationContext, destinationCancel := context.WithCancel(context.Background()) + defer destinationCancel() + + go sourceHub.Run(sourceContext) + go destinationHub.Run(destinationContext) + waitForRunningHub(t, sourceHub) + waitForRunningHub(t, destinationHub) + + stop := Pipe(sourceHub, destinationHub) + received := make(chan []byte, 1) + unsubscribe := destinationHub.Subscribe("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + var stopWG sync.WaitGroup + for i := 0; i < 8; i++ { + stopWG.Add(1) + go func() { + defer stopWG.Done() + stop() + }() + } + stopWG.Wait() + + if err := sourceHub.Publish("block", []byte("template")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + t.Fatalf("received unexpected frame after stop: %q", string(frame)) + case <-time.After(200 * time.Millisecond): + } +} + +func TestAX7_Hub_Pipe_Ugly(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + stop := Pipe(hub, hub) + defer stop() + + received := make(chan []byte, 1) + unsubscribe := hub.Subscribe("agent", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + if err := hub.Publish("agent", []byte("event")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "event" { + t.Fatalf("received frame = %q, want %q", string(frame), "event") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for local frame") + } +} + +func TestHub_Pipe_GenericPublishFallback_Good(t *testing.T) { + sourceStream := newTestStream() + destinationStream := newTestStream() + + stop := Pipe(sourceStream, destinationStream) + defer stop() + + if err := sourceStream.Publish("hashrate", []byte("123456")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + destinationStream.mutex.Lock() + defer destinationStream.mutex.Unlock() + if len(destinationStream.published) != 1 { + t.Fatalf("len(published) = %d, want %d", len(destinationStream.published), 1) + } + if destinationStream.published[0].channel != "*" { + t.Fatalf("published channel = %q, want %q", destinationStream.published[0].channel, "*") + } + if string(destinationStream.published[0].frame) != "123456" { + t.Fatalf("published frame = %q, want %q", string(destinationStream.published[0].frame), "123456") + } + if len(destinationStream.broadcasts) != 0 { + t.Fatalf("len(broadcasts) = %d, want %d", len(destinationStream.broadcasts), 0) + } +} + +func TestAX7_Hub_Publish_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + defer hub.RemovePeer(peer) + waitForPeerCount(t, hub, 1) + + if err := hub.SubscribePeer(peer, "hashrate"); err != nil { + t.Fatalf("SubscribePeer(channel) error = %v", err) + } + if err := hub.SubscribePeer(peer, "*"); err != nil { + t.Fatalf("SubscribePeer(wildcard) error = %v", err) + } + + if err := hub.Publish("hashrate", []byte("123456")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-peer.SendQueue(): + if string(frame) != "123456" { + t.Fatalf("received frame = %q, want %q", string(frame), "123456") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for published frame") + } + + select { + case frame := <-peer.SendQueue(): + t.Fatalf("received duplicate frame = %q", string(frame)) + case <-time.After(200 * time.Millisecond): + } +} + +func TestAX7_Hub_Publish_Bad(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + if err := hub.Publish("hashrate", []byte("123456")); err != nil { + t.Fatalf("Publish() error = %v, want nil", err) + } +} + +func TestAX7_Hub_PublishFromPeer_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + defer hub.RemovePeer(peer) + + if err := hub.SubscribePeer(peer, "hashrate"); err != nil { + t.Fatalf("SubscribePeer() error = %v", err) + } + + if err := hub.PublishFromPeer(peer, "hashrate", []byte("123456")); err != nil { + t.Fatalf("PublishFromPeer() error = %v", err) + } + + select { + case frame := <-peer.SendQueue(): + if string(frame) != "123456" { + t.Fatalf("received frame = %q, want %q", string(frame), "123456") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for published frame") + } +} + +func TestAX7_Hub_BroadcastFromPeer_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + defer hub.RemovePeer(peer) + + if err := hub.BroadcastFromPeer(peer, []byte("shutdown")); err != nil { + t.Fatalf("BroadcastFromPeer() error = %v", err) + } + + select { + case frame := <-peer.SendQueue(): + if string(frame) != "shutdown" { + t.Fatalf("received frame = %q, want %q", string(frame), "shutdown") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for broadcast frame") + } +} + +func TestAX7_Hub_Publish_Ugly(t *testing.T) { + hub := NewHub() + + if err := hub.Publish("hashrate", []byte("123456")); err != ErrHubNotRunning { + t.Fatalf("Publish() error = %v, want %v", err, ErrHubNotRunning) + } +} + +func TestAX7_Hub_Running_Good(t *testing.T) { + hub := NewHub() + if hub.Running() { + t.Fatal("Running() = true before Run()") + } + + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + if !hub.Running() { + t.Fatal("Running() = false while Run() is active") + } + + hubCancel() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if !hub.Running() { + return + } + time.Sleep(10 * time.Millisecond) + } + + t.Fatal("Running() stayed true after context cancellation") +} + +func TestAX7_Hub_Running_Bad(t *testing.T) { + var hub *Hub + if hub.Running() { + t.Fatal("nil hub Running() = true, want false") + } +} + +func TestAX7_Hub_Running_Ugly(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + observed := make(chan bool, 1) + go func() { + observed <- hub.Running() + }() + + select { + case running := <-observed: + if !running { + t.Fatal("Running() = false while hub is active") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for concurrent Running() read") + } + + hubCancel() +} + +func TestAX7_Hub_Broadcast_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + defer hub.RemovePeer(peer) + waitForPeerCount(t, hub, 1) + + if err := hub.Broadcast([]byte("123456")); err != nil { + t.Fatalf("Broadcast() error = %v", err) + } + + select { + case frame := <-peer.SendQueue(): + if string(frame) != "123456" { + t.Fatalf("received frame = %q, want %q", string(frame), "123456") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for broadcast frame") + } +} + +func TestAX7_Hub_Broadcast_Bad(t *testing.T) { + hub := NewHub() + + if err := hub.Broadcast([]byte("123456")); err != ErrHubNotRunning { + t.Fatalf("Broadcast() error = %v, want %v", err, ErrHubNotRunning) + } +} + +func TestAX7_Hub_Broadcast_Ugly(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + defer hub.RemovePeer(peer) + waitForPeerCount(t, hub, 1) + + received := make(chan []byte, 1) + unsubscribe := hub.Subscribe("*", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + if err := hub.Broadcast([]byte("event")); err != nil { + t.Fatalf("Broadcast() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "event" { + t.Fatalf("received handler frame = %q, want %q", string(frame), "event") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for broadcast handler") + } + + select { + case frame := <-peer.SendQueue(): + if string(frame) != "event" { + t.Fatalf("received peer frame = %q, want %q", string(frame), "event") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for broadcast peer") + } + + hubCancel() + waitForPeerCount(t, hub, 0) +} + +func TestAX7_Hub_SubscribeE_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + received := make(chan []byte, 1) + unsubscribe, err := hub.SubscribeE("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + if err != nil { + t.Fatalf("SubscribeE() error = %v", err) + } + defer unsubscribe() + + if err := hub.Publish("block", []byte("template")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "template" { + t.Fatalf("received frame = %q, want %q", string(frame), "template") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for subscribed frame") + } +} + +func TestAX7_Hub_SubscribeWithError_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + received := make(chan []byte, 1) + unsubscribe, err := hub.SubscribeWithError("block", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + if err != nil { + t.Fatalf("SubscribeWithError() error = %v", err) + } + defer unsubscribe() + + if err := hub.Publish("block", []byte("template")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "template" { + t.Fatalf("received frame = %q, want %q", string(frame), "template") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for subscribed frame") + } +} + +func TestHub_Stats_IncludeHandlerOnlyChannels_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + unsubscribe := hub.Subscribe("events", func(frame []byte) {}) + defer unsubscribe() + + stats := hub.Stats() + if stats.Peers != 0 { + t.Fatalf("Stats().Peers = %d, want %d", stats.Peers, 0) + } + if stats.Channels != 1 { + t.Fatalf("Stats().Channels = %d, want %d", stats.Channels, 1) + } + if stats.SubscriberCount["events"] != 1 { + t.Fatalf("Stats().SubscriberCount[events] = %d, want %d", stats.SubscriberCount["events"], 1) + } + if hub.ChannelCount() != 1 { + t.Fatalf("ChannelCount() = %d, want %d", hub.ChannelCount(), 1) + } + if hub.ChannelSubscriberCount("events") != 1 { + t.Fatalf("ChannelSubscriberCount(events) = %d, want %d", hub.ChannelSubscriberCount("events"), 1) + } + + channels := make([]string, 0, 1) + for channel := range hub.AllChannels() { + channels = append(channels, channel) + } + if len(channels) != 1 || channels[0] != "events" { + t.Fatalf("AllChannels() = %v, want [events]", channels) + } +} + +func TestAX7_Hub_SubscribeE_Bad(t *testing.T) { + hub := NewHub() + + unsubscribe, err := hub.SubscribeE("", func(frame []byte) {}) + if err != ErrEmptyChannel { + t.Fatalf("SubscribeE() error = %v, want %v", err, ErrEmptyChannel) + } + if unsubscribe == nil { + t.Fatal("SubscribeE() unsubscribe = nil") + } + unsubscribe() +} + +func TestAX7_Hub_SubscribeE_Ugly(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + panicked := 0 + unsubscribe, err := hub.SubscribeE("event", func(frame []byte) { + panicked++ + panic("boom") + }) + if err != nil { + t.Fatalf("SubscribeE() error = %v", err) + } + defer unsubscribe() + + received := make(chan []byte, 1) + safeUnsubscribe := hub.Subscribe("event", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer safeUnsubscribe() + + if err := hub.Publish("event", []byte("payload")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "payload" { + t.Fatalf("received frame = %q, want %q", string(frame), "payload") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for safe handler") + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if panicked == 1 { + return + } + time.Sleep(10 * time.Millisecond) + } + + t.Fatalf("SubscribeE panic handler count = %d, want 1", panicked) +} + +func TestAX7_Hub_CanSubscribePeer_Bad(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + ChannelAuthoriser: func(peer *Peer, channel string) bool { + return channel == "public" + }, + }) + + peer := NewPeer("ws") + if err := hub.CanSubscribePeer(peer, "private"); err != ErrAuthRejected { + t.Fatalf("CanSubscribePeer() error = %v, want %v", err, ErrAuthRejected) + } + if err := hub.CanSubscribePeer(peer, "public"); err != nil { + t.Fatalf("CanSubscribePeer() error = %v, want nil", err) + } +} + +func TestAX7_Peer_Subscriptions_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + defer hub.RemovePeer(peer) + + if err := hub.SubscribePeer(peer, "hashrate"); err != nil { + t.Fatalf("SubscribePeer(hashrate) error = %v", err) + } + if err := hub.SubscribePeer(peer, "block"); err != nil { + t.Fatalf("SubscribePeer(block) error = %v", err) + } + + subscriptions := peer.Subscriptions() + if len(subscriptions) != 2 { + t.Fatalf("len(Subscriptions()) = %d, want %d", len(subscriptions), 2) + } + if subscriptions[0] != "block" || subscriptions[1] != "hashrate" { + t.Fatalf("Subscriptions() = %v, want [block hashrate]", subscriptions) + } + + hub.UnsubscribePeer(peer, "block") + subscriptions = peer.Subscriptions() + if len(subscriptions) != 1 || subscriptions[0] != "hashrate" { + t.Fatalf("Subscriptions() after unsubscribe = %v, want [hashrate]", subscriptions) + } +} + +func TestAX7_Peer_Subscriptions_Bad(t *testing.T) { + var peer *Peer + + if subscriptions := peer.Subscriptions(); subscriptions != nil { + t.Fatalf("Subscriptions() = %v, want nil", subscriptions) + } +} + +func TestAX7_Peer_Subscriptions_Ugly(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + defer hub.RemovePeer(peer) + + if err := hub.SubscribePeer(peer, "hashrate"); err != nil { + t.Fatalf("SubscribePeer() error = %v", err) + } + + subscriptions := peer.Subscriptions() + subscriptions[0] = "tampered" + + current := peer.Subscriptions() + if len(current) != 1 || current[0] != "hashrate" { + t.Fatalf("Subscriptions() after caller mutation = %v, want [hashrate]", current) + } +} + +func TestHub_SendToChannel_Wildcard_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + count := 0 + unsubscribe := hub.Subscribe("*", func(frame []byte) { + if string(frame) == "event" { + count++ + } + }) + defer unsubscribe() + + if err := hub.Publish("*", []byte("event")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if count == 1 { + return + } + time.Sleep(10 * time.Millisecond) + } + + t.Fatalf("wildcard handler count = %d, want 1", count) +} + +func TestAX7_Peer_Close_Good(t *testing.T) { + peer := NewPeer("ws") + closed := make(chan struct{}, 1) + + peer.SetCloseHook(func() { + closed <- struct{}{} + }) + peer.Close() + peer.Close() + + select { + case <-closed: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for close hook") + } + + select { + case <-closed: + t.Fatal("close hook ran more than once") + case <-time.After(200 * time.Millisecond): + } + + select { + case _, ok := <-peer.SendQueue(): + if ok { + t.Fatal("SendQueue() channel still open after Close()") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for closed SendQueue()") + } +} + +func TestAX7_Hub_Run_Good(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go hub.Run(ctx) + waitForRunningHub(t, hub) + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + waitForPeerCount(t, hub, 1) + + cancel() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if !hub.Running() { + break + } + time.Sleep(10 * time.Millisecond) + } + if hub.Running() { + t.Fatal("hub still running after context cancellation") + } + if hub.PeerCount() != 0 { + t.Fatalf("PeerCount() = %d after shutdown, want 0", hub.PeerCount()) + } +} + +func TestAX7_Hub_Run_Bad(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go hub.Run(ctx) + waitForRunningHub(t, hub) + + // Second Run call is a no-op — hub remains running with the original context. + secondDone := make(chan struct{}) + go func() { + hub.Run(ctx) + close(secondDone) + }() + + select { + case <-secondDone: + case <-time.After(2 * time.Second): + t.Fatal("second Run() did not return immediately") + } + + if !hub.Running() { + t.Fatal("hub stopped after second Run() call") + } +} + +func TestAX7_Hub_Run_Ugly(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + + go hub.Run(ctx) + waitForRunningHub(t, hub) + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + waitForPeerCount(t, hub, 1) + + // Cancel context while a broadcast is in flight. + go func() { + for i := 0; i < 100; i++ { + _ = hub.Broadcast([]byte("inflight")) + } + }() + cancel() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if !hub.Running() { + break + } + time.Sleep(10 * time.Millisecond) + } + if hub.Running() { + t.Fatal("hub still running after context cancellation during broadcast") + } + if hub.PeerCount() != 0 { + t.Fatalf("PeerCount() = %d after shutdown, want 0 (goroutine leak)", hub.PeerCount()) + } +} + +func TestAX7_Hub_Subscribe_Good(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go hub.Run(ctx) + waitForRunningHub(t, hub) + + received := make(chan []byte, 1) + unsubscribe := hub.Subscribe("hashrate", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + if err := hub.Publish("hashrate", []byte("123456")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "123456" { + t.Fatalf("received frame = %q, want %q", string(frame), "123456") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for subscribed frame") + } +} + +func TestAX7_Hub_Subscribe_Bad(t *testing.T) { + hub := NewHub() + + unsubscribe := hub.Subscribe("", func(frame []byte) {}) + if unsubscribe == nil { + t.Fatal("Subscribe() with empty channel returned nil unsubscribe") + } + unsubscribe() +} + +func TestAX7_Hub_Subscribe_Ugly(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go hub.Run(ctx) + waitForRunningHub(t, hub) + + panicked := 0 + _ = hub.Subscribe("event", func(frame []byte) { + panicked++ + panic("handler panic") + }) + + received := make(chan []byte, 1) + unsubscribe := hub.Subscribe("event", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + if err := hub.Publish("event", []byte("payload")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "payload" { + t.Fatalf("received frame = %q, want %q", string(frame), "payload") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for safe handler after panic") + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if panicked == 1 { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("panic handler count = %d, want 1", panicked) +} + +func waitForRunningHub(t *testing.T, hub *Hub) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + hub.mutex.RLock() + running := hub.running + hub.mutex.RUnlock() + if running { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("timed out waiting for hub to start") +} + +func waitForPeerCount(t *testing.T, hub *Hub, expected int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if hub.PeerCount() == expected { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("PeerCount() = %d, want %d", hub.PeerCount(), expected) +} diff --git a/message.go b/message.go index 4c84704..38419fc 100644 --- a/message.go +++ b/message.go @@ -1,31 +1,55 @@ // SPDX-License-Identifier: EUPL-1.2 +// msg := stream.Message{ +// Type: stream.TypeEvent, +// Channel: "hashrate", +// Data: map[string]any{"h": 1234567}, +// Timestamp: time.Now().UTC(), +// } +// +// frame, _ := core.JSONMarshal(msg) +// hub.Publish("hashrate", frame.Value.([]byte)) package stream import "time" -// MessageType identifies the purpose of a WebSocket message. -// Preserved from go-ws for backward compatibility with browser clients. +// messageType := stream.TypeEvent type MessageType string +// typ := stream.TypeEvent.String() +// // typ == "event" +func (messageType MessageType) String() string { + return string(messageType) +} + const ( - TypeProcessOutput MessageType = "process_output" // real-time process output line - TypeProcessStatus MessageType = "process_status" // process status change (running/exited) - TypeEvent MessageType = "event" // generic named event - TypeError MessageType = "error" // error message - TypePing MessageType = "ping" // client → server keepalive - TypePong MessageType = "pong" // server → client keepalive response - TypeSubscribe MessageType = "subscribe" // client requests channel subscription - TypeUnsubscribe MessageType = "unsubscribe" // client cancels channel subscription + // message := stream.Message{Type: stream.TypeProcessOutput, ProcessID: "build-123"} + TypeProcessOutput MessageType = "process_output" + // message := stream.Message{Type: stream.TypeProcessStatus, ProcessID: "build-123"} + TypeProcessStatus MessageType = "process_status" + // message := stream.Message{Type: stream.TypeEvent, Channel: "hashrate"} + TypeEvent MessageType = "event" + // message := stream.Message{Type: stream.TypeError, Data: "unauthorised"} + TypeError MessageType = "error" + // message := stream.Message{Type: stream.TypePing, ProcessID: "client-1"} + TypePing MessageType = "ping" + // reply := stream.Message{Type: stream.TypePong, ProcessID: "client-1"} + TypePong MessageType = "pong" + // message := stream.Message{Type: stream.TypeSubscribe, Channel: "block"} + TypeSubscribe MessageType = "subscribe" + // message := stream.Message{Type: stream.TypeUnsubscribe, Channel: "block"} + TypeUnsubscribe MessageType = "unsubscribe" ) -// Message is the JSON envelope for WebSocket frames. Preserved from go-ws. -// // msg := stream.Message{ -// Type: stream.TypeEvent, -// Channel: "hashrate", -// Data: map[string]any{"h": 1234567}, +// Type: stream.TypeEvent, +// Channel: "hashrate", +// Data: map[string]any{"h": 1234567}, +// Timestamp: time.Now().UTC(), // } +// +// frame, _ := core.JSONMarshal(msg) +// _ = frame type Message struct { Type MessageType `json:"type"` Channel string `json:"channel,omitempty"` diff --git a/message_test.go b/message_test.go new file mode 100644 index 0000000..e24803b --- /dev/null +++ b/message_test.go @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package stream + +import ( + "testing" + "time" +) + +func TestAX7_MessageType_String_Good(t *testing.T) { + cases := []struct { + messageType MessageType + expected string + }{ + {TypeProcessOutput, "process_output"}, + {TypeProcessStatus, "process_status"}, + {TypeEvent, "event"}, + {TypeError, "error"}, + {TypePing, "ping"}, + {TypePong, "pong"}, + {TypeSubscribe, "subscribe"}, + {TypeUnsubscribe, "unsubscribe"}, + } + for _, testCase := range cases { + if testCase.messageType.String() != testCase.expected { + t.Fatalf("%q.String() = %q, want %q", testCase.messageType, testCase.messageType.String(), testCase.expected) + } + } +} + +func TestAX7_MessageType_String_Bad(t *testing.T) { + // Unknown MessageType returns its raw string value. + unknown := MessageType("nonexistent") + if unknown.String() != "nonexistent" { + t.Fatalf("unknown MessageType.String() = %q, want %q", unknown.String(), "nonexistent") + } +} + +func TestAX7_MessageType_String_Ugly(t *testing.T) { + // Empty MessageType returns empty string. + empty := MessageType("") + if empty.String() != "" { + t.Fatalf("empty MessageType.String() = %q, want %q", empty.String(), "") + } +} + +func TestMessage_Fields_Good(t *testing.T) { + timestamp := time.Date(2026, 4, 5, 12, 0, 0, 0, time.UTC) + message := Message{ + Type: TypeEvent, + Channel: "hashrate", + ProcessID: "agent-42", + Data: map[string]any{"h": 1234567}, + Timestamp: timestamp, + } + if message.Type != TypeEvent { + t.Fatalf("message.Type = %q, want %q", message.Type, TypeEvent) + } + if message.Channel != "hashrate" { + t.Fatalf("message.Channel = %q, want %q", message.Channel, "hashrate") + } + if message.ProcessID != "agent-42" { + t.Fatalf("message.ProcessID = %q, want %q", message.ProcessID, "agent-42") + } + if message.Timestamp != timestamp { + t.Fatalf("message.Timestamp = %v, want %v", message.Timestamp, timestamp) + } +} + +func TestMessage_Fields_Bad(t *testing.T) { + // Zero-value Message has empty fields — no panic. + message := Message{} + if message.Type != "" { + t.Fatalf("zero Message.Type = %q, want empty", message.Type) + } + if message.Channel != "" { + t.Fatalf("zero Message.Channel = %q, want empty", message.Channel) + } + if message.Timestamp.IsZero() != true { + t.Fatal("zero Message.Timestamp.IsZero() = false, want true") + } +} + +func TestMessage_Fields_Ugly(t *testing.T) { + // Message with nil Data does not panic on access. + message := Message{Type: TypeError, Data: nil} + if message.Data != nil { + t.Fatalf("Message.Data = %v, want nil", message.Data) + } +} diff --git a/stats.go b/stats.go index dd2a365..d6261a8 100644 --- a/stats.go +++ b/stats.go @@ -2,17 +2,18 @@ package stream -// HubStats is a snapshot of hub state at a point in time. -// -// s := hub.Stats() -// log.Printf("peers=%d channels=%d", s.Peers, s.Channels) +// stats := hub.Stats() +// core.Print("stream", "peers=%d channels=%d", stats.Peers, stats.Channels) type HubStats struct { - // Peers is the number of currently connected peers across all transports. + // stats := hub.Stats() + // core.Print("stream", "peers=%d", stats.Peers) Peers int `json:"peers"` - // Channels is the number of active named channels with at least one subscriber. + // stats := hub.Stats() + // core.Print("stream", "channels=%d", stats.Channels) Channels int `json:"channels"` - // SubscriberCount maps channel name to subscriber count. + // stats := hub.Stats() + // count := stats.SubscriberCount["hashrate"] SubscriberCount map[string]int `json:"subscriber_count"` } diff --git a/stats_test.go b/stats_test.go new file mode 100644 index 0000000..8ebf7da --- /dev/null +++ b/stats_test.go @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package stream + +import ( + "context" + "testing" + + "dappco.re/go" +) + +func TestStats_HubStats_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + defer hub.RemovePeer(peer) + waitForPeerCount(t, hub, 1) + + if err := hub.SubscribePeer(peer, "hashrate"); err != nil { + t.Fatalf("SubscribePeer() error = %v", err) + } + + stats := hub.Stats() + if stats.Peers != 1 { + t.Fatalf("Stats().Peers = %d, want %d", stats.Peers, 1) + } + if stats.Channels != 1 { + t.Fatalf("Stats().Channels = %d, want %d", stats.Channels, 1) + } + if stats.SubscriberCount["hashrate"] != 1 { + t.Fatalf("Stats().SubscriberCount[hashrate] = %d, want %d", stats.SubscriberCount["hashrate"], 1) + } +} + +func TestStats_HubStats_Bad(t *testing.T) { + // Stats on a nil hub returns zero values. + var hub *Hub + stats := hub.Stats() + if stats.Peers != 0 { + t.Fatalf("nil hub Stats().Peers = %d, want %d", stats.Peers, 0) + } + if stats.Channels != 0 { + t.Fatalf("nil hub Stats().Channels = %d, want %d", stats.Channels, 0) + } +} + +func TestStats_HubStats_Ugly(t *testing.T) { + // Stats called after all peers are removed returns zero peers. + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + if err := hub.SubscribePeer(peer, "hashrate"); err != nil { + t.Fatalf("SubscribePeer() error = %v", err) + } + hub.RemovePeer(peer) + waitForPeerCount(t, hub, 0) + + stats := hub.Stats() + if stats.Peers != 0 { + t.Fatalf("Stats().Peers after remove = %d, want %d", stats.Peers, 0) + } +} + +func TestStats_HubStats_JSONTags_Good(t *testing.T) { + // Verify HubStats serialises with the expected JSON field names. + stats := HubStats{ + Peers: 3, + Channels: 2, + SubscriberCount: map[string]int{"hashrate": 2, "block": 1}, + } + result := core.JSONMarshal(stats) + if !result.OK { + t.Fatalf("JSONMarshal(HubStats) failed: %v", result.Value) + } + serialised := string(result.Value.([]byte)) + if !core.Contains(serialised, `"peers":3`) { + t.Fatalf("JSON missing peers field: %s", serialised) + } + if !core.Contains(serialised, `"channels":2`) { + t.Fatalf("JSON missing channels field: %s", serialised) + } + if !core.Contains(serialised, `"subscriber_count"`) { + t.Fatalf("JSON missing subscriber_count field: %s", serialised) + } +} + +func TestStats_PeerCount_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + if hub.PeerCount() != 0 { + t.Fatalf("PeerCount() = %d, want %d", hub.PeerCount(), 0) + } + + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + waitForPeerCount(t, hub, 1) + + if hub.PeerCount() != 1 { + t.Fatalf("PeerCount() = %d, want %d", hub.PeerCount(), 1) + } + hub.RemovePeer(peer) +} + +func TestStats_PeerCount_Bad(t *testing.T) { + // PeerCount on nil hub returns 0. + var hub *Hub + if hub.PeerCount() != 0 { + t.Fatalf("nil hub PeerCount() = %d, want %d", hub.PeerCount(), 0) + } +} + +func TestStats_ChannelCount_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + unsubscribe := hub.Subscribe("events", func([]byte) {}) + defer unsubscribe() + + if hub.ChannelCount() != 1 { + t.Fatalf("ChannelCount() = %d, want %d", hub.ChannelCount(), 1) + } +} + +func TestStats_ChannelCount_Bad(t *testing.T) { + // ChannelCount on nil hub returns 0. + var hub *Hub + if hub.ChannelCount() != 0 { + t.Fatalf("nil hub ChannelCount() = %d, want %d", hub.ChannelCount(), 0) + } +} + +func TestStats_ChannelSubscriberCount_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + // Peer subscriber. + peer := NewPeer("ws") + if err := hub.AddPeer(peer); err != nil { + t.Fatalf("AddPeer() error = %v", err) + } + defer hub.RemovePeer(peer) + + if err := hub.SubscribePeer(peer, "hashrate"); err != nil { + t.Fatalf("SubscribePeer() error = %v", err) + } + + // Handler subscriber. + unsubscribe := hub.Subscribe("hashrate", func([]byte) {}) + defer unsubscribe() + + count := hub.ChannelSubscriberCount("hashrate") + if count != 2 { + t.Fatalf("ChannelSubscriberCount(hashrate) = %d, want %d", count, 2) + } +} + +func TestStats_ChannelSubscriberCount_Bad(t *testing.T) { + hub := NewHub() + // Channel with no subscribers returns 0. + count := hub.ChannelSubscriberCount("nonexistent") + if count != 0 { + t.Fatalf("ChannelSubscriberCount(nonexistent) = %d, want %d", count, 0) + } +} + +func TestStats_AllPeers_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + peer1 := NewPeer("ws") + peer2 := NewPeer("sse") + _ = hub.AddPeer(peer1) + _ = hub.AddPeer(peer2) + defer hub.RemovePeer(peer1) + defer hub.RemovePeer(peer2) + waitForPeerCount(t, hub, 2) + + count := 0 + for range hub.AllPeers() { + count++ + } + if count != 2 { + t.Fatalf("AllPeers() count = %d, want %d", count, 2) + } +} + +func TestStats_AllPeers_Bad(t *testing.T) { + // AllPeers on nil hub yields no peers. + var hub *Hub + count := 0 + for range hub.AllPeers() { + count++ + } + if count != 0 { + t.Fatalf("nil hub AllPeers() count = %d, want %d", count, 0) + } +} + +func TestStats_AllChannels_Good(t *testing.T) { + hub := NewHub() + hubContext, hubCancel := context.WithCancel(context.Background()) + defer hubCancel() + go hub.Run(hubContext) + waitForRunningHub(t, hub) + + unsub1 := hub.Subscribe("block", func([]byte) {}) + unsub2 := hub.Subscribe("hashrate", func([]byte) {}) + defer unsub1() + defer unsub2() + + channels := make([]string, 0, 2) + for channel := range hub.AllChannels() { + channels = append(channels, channel) + } + if len(channels) != 2 { + t.Fatalf("AllChannels() count = %d, want %d", len(channels), 2) + } + // Channels should be sorted. + if channels[0] != "block" || channels[1] != "hashrate" { + t.Fatalf("AllChannels() = %v, want [block, hashrate]", channels) + } +} + +func TestStats_AllChannels_Bad(t *testing.T) { + // AllChannels on nil hub yields no channels. + var hub *Hub + count := 0 + for range hub.AllChannels() { + count++ + } + if count != 0 { + t.Fatalf("nil hub AllChannels() count = %d, want %d", count, 0) + } +} diff --git a/stream.go b/stream.go index c95c95b..1cb9f85 100644 --- a/stream.go +++ b/stream.go @@ -1,15 +1,11 @@ // SPDX-License-Identifier: EUPL-1.2 -// Package stream is the transport-agnostic event and data pipe for the CoreGO -// ecosystem. It generalises WebSocket, SSE, Redis pub/sub, ZeroMQ, and raw TCP -// behind a single Stream interface. Consumers never import a specific transport — -// they call Stream. Transport adapters are wired at startup. +// Package stream wires transport-agnostic hubs and peers together. // // hub := stream.NewHub() // go hub.Run(ctx) -// hub.Publish("hashrate", []byte(`{"h":123456}`)) -// unsub := hub.Subscribe("block", func(f []byte) { handleBlock(f) }) -// defer unsub() +// stop := stream.Pipe(hub, remoteHub) +// defer stop() package stream import ( @@ -23,59 +19,43 @@ import ( "time" ) +const defaultPeerSendBufferSize = 256 + // Stream is the transport-agnostic event and data pipe. -// Consumers never import a specific adapter — they call Stream. // -// var s stream.Stream = hub -// s.Publish("hashrate", frame) -// s.Subscribe("block", handler) +// hub := stream.NewHub() +// var bus stream.Stream = hub +// _ = bus.Publish("hashrate", []byte(`{"h":123456}`)) +// stop := bus.Pipe(remoteHub) +// defer stop() type Stream interface { - // Publish sends frame to all subscribers of channel. - // Returns core.E if the hub is not running. - // - // hub.Publish("hashrate", []byte(`{"h":123456}`)) + // _ = hub.Publish("hashrate", []byte(`{"h":123456}`)) Publish(channel string, frame []byte) error - // Subscribe registers handler for all frames arriving on channel. - // Returns an unsubscribe function. Safe to call from multiple goroutines. - // - // unsub := hub.Subscribe("block", func(f []byte) { ... }) - // defer unsub() + // unsubscribe := hub.Subscribe("block", func(frame []byte) { handleBlock(frame) }) + // defer unsubscribe() Subscribe(channel string, handler func([]byte)) func() - // Broadcast sends frame to every connected peer regardless of subscriptions. - // - // hub.Broadcast([]byte(`{"type":"shutdown"}`)) + // _ = hub.Broadcast([]byte(`{"type":"shutdown"}`)) Broadcast(frame []byte) error - // Pipe connects this stream to dst: every frame published here is forwarded to dst. - // Returns a stop function. - // - // stop := hub.Pipe(remoteHub) - // defer stop() - Pipe(dst Stream) func() + // stop := localHub.Pipe(remoteHub) + // defer stop() + Pipe(destination Stream) func() - // Stats returns a snapshot of current hub state. - // - // s := hub.Stats() + // stats := hub.Stats() Stats() HubStats } -// Frame is a raw byte payload delivered through the hub. -// Adapters and consumers define their own serialisation over Frame. +// frame := stream.Frame([]byte(`{"type":"event"}`)) type Frame = []byte -// Channel is a named topic string used for pub/sub routing. +// channel := stream.Channel("hashrate") type Channel = string -// Peer represents one connected endpoint. Created by a transport adapter. -// -// peer := &stream.Peer{ -// ID: uuid.New(), -// UserID: authResult.UserID, -// Claims: authResult.Claims, -// Transport: "ws", -// } +// peer := stream.NewPeer("ws") +// peer.UserID = authResult.UserID +// peer.Claims = authResult.Claims type Peer struct { // ID is a random UUID assigned on creation. ID string @@ -92,92 +72,123 @@ type Peer struct { send chan []byte subscriptions map[string]bool - mu sync.RWMutex + closeHook func() + mutex sync.RWMutex closeOnce sync.Once } -// NewPeer creates a peer with a generated identifier and a buffered send queue. -// -// peer := stream.NewPeer("ws") +// peer := stream.NewPeer("ws") +// peer.UserID = "user-42" func NewPeer(transport string) *Peer { return &Peer{ - ID: randomID(), + ID: randomUUID(), + Claims: map[string]any{}, Transport: transport, - send: make(chan []byte, 256), + send: make(chan []byte, defaultPeerSendBufferSize), subscriptions: map[string]bool{}, } } -// Subscriptions returns a copy of this peer's current channel subscriptions. -// -// channels := peer.Subscriptions() // ["hashrate", "block"] -func (p *Peer) Subscriptions() []string { - if p == nil { +// channels := peer.Subscriptions() // ["hashrate", "block"] +func (peer *Peer) Subscriptions() []string { + if peer == nil { return nil } - p.mu.RLock() - defer p.mu.RUnlock() - channels := make([]string, 0, len(p.subscriptions)) - for channel := range p.subscriptions { + peer.mutex.RLock() + defer peer.mutex.RUnlock() + channels := make([]string, 0, len(peer.subscriptions)) + for channel := range peer.subscriptions { channels = append(channels, channel) } sort.Strings(channels) return channels } -// Send enqueues frame for delivery. Non-blocking: drops and returns false if buffer full. -// -// ok := peer.Send(frame) -func (p *Peer) Send(frame []byte) bool { - if p == nil { +// ok := peer.Send([]byte("template")) +func (peer *Peer) Send(frame []byte) bool { + if peer == nil { return false } defer func() { - _ = recover() + if recovered := recover(); recovered != nil { + return + } }() - p.mu.RLock() - defer p.mu.RUnlock() - if p.send == nil { + peer.mutex.RLock() + defer peer.mutex.RUnlock() + if peer.send == nil { return false } payload := append([]byte(nil), frame...) select { - case p.send <- payload: + case peer.send <- payload: return true default: return false } } -// Close signals the transport adapter to shut down this connection. -// -// peer.Close() -func (p *Peer) Close() { - if p == nil { +// peer := stream.NewPeer("ws") +// peer.SetCloseHook(func() { _ = conn.Close() }) +// peer.Close() +func (peer *Peer) Close() { + if peer == nil { return } - p.closeOnce.Do(func() { - p.mu.Lock() - defer p.mu.Unlock() - if p.send != nil { - close(p.send) + peer.closeOnce.Do(func() { + peer.mutex.Lock() + send := peer.send + closeHook := peer.closeHook + peer.closeHook = nil + peer.mutex.Unlock() + if send != nil { + close(send) + } + if closeHook != nil { + closeHook() } }) } -// SendQueue returns the peer's outgoing frame queue. +// peer.SetCloseHook(func() { _ = conn.Close() }) +func (peer *Peer) SetCloseHook(closeFunc func()) { + if peer == nil { + return + } + peer.mutex.Lock() + defer peer.mutex.Unlock() + peer.closeHook = closeFunc +} + +// SendQueue exposes the adapter-facing outbound queue. // -// for frame := range peer.SendQueue() { ... } -func (p *Peer) SendQueue() <-chan []byte { - if p == nil { +// go func() { +// for frame := range peer.SendQueue() { +// _ = frame +// } +// }() +func (peer *Peer) SendQueue() <-chan []byte { + if peer == nil { return nil } - p.mu.RLock() - defer p.mu.RUnlock() - return p.send + peer.mutex.RLock() + defer peer.mutex.RUnlock() + return peer.send } -// ConnectionState represents the lifecycle state of a reconnecting client. +// switch client.State() { +// case stream.StateConnected: +// +// _ = client.Send(stream.Message{Type: stream.TypePing}) +// +// case stream.StateConnecting: +// +// time.Sleep(100 * time.Millisecond) +// +// default: +// +// // disconnected +// } type ConnectionState int const ( @@ -186,26 +197,77 @@ const ( StateConnected ) -// Envelope wraps a frame with metadata for cross-instance transport. +// state := stream.StateConnected +// core.Print(nil, "connection state=%s", state.String()) +func (state ConnectionState) String() string { + switch state { + case StateConnecting: + return "connecting" + case StateConnected: + return "connected" + default: + return "disconnected" + } +} + +// envelope := stream.Envelope{ +// SourceID: "node-a", +// Channel: "block", +// Frame: []byte("template"), +// } type Envelope struct { SourceID string Channel string Frame []byte } -// Pipe connects src to dst: every frame published on src is forwarded to dst. -// Returns a stop function. Safe to call from multiple goroutines. +// Pipe connects src to dst. // // stop := stream.Pipe(zmqHub, wsHub) // defer stop() +// +// Published frames keep their channel. Broadcast frames stay broadcasts when the +// source exposes that hook. func Pipe(src Stream, dst Stream) func() { if src == nil || dst == nil || src == dst { return func() {} } - stop := src.Subscribe("*", func(frame []byte) { - _ = dst.Broadcast(frame) + type publishedFrameSource interface { + SubscribePublished(handler func(string, []byte)) func() + } + type broadcastFrameSource interface { + SubscribeBroadcast(handler func([]byte)) func() + } + stops := make([]func(), 0, 2) + if publisher, ok := src.(publishedFrameSource); ok { + stops = append(stops, onceFunction(publisher.SubscribePublished(func(channel string, frame []byte) { + if err := dst.Publish(channel, cloneFrame(frame)); err != nil { + return + } + }))) + } + if broadcaster, ok := src.(broadcastFrameSource); ok { + stops = append(stops, onceFunction(broadcaster.SubscribeBroadcast(func(frame []byte) { + if err := dst.Broadcast(cloneFrame(frame)); err != nil { + return + } + }))) + } + if len(stops) == 0 { + // Generic Stream implementations do not expose channel names, so fall back + // to publishing on the wildcard channel. + stop := src.Subscribe("*", func(frame []byte) { + if err := dst.Publish("*", cloneFrame(frame)); err != nil { + return + } + }) + return onceFunction(stop) + } + return onceFunction(func() { + for index := len(stops) - 1; index >= 0; index-- { + stops[index]() + } }) - return stop } // Ensure Hub satisfies Stream at compile time. @@ -219,9 +281,12 @@ var ( _ time.Duration ) -func randomID() string { +// id := randomUUID() // "a1b2c3d4-e5f6-4a7b-8c9d-e0f1a2b3c4d5" +func randomUUID() string { var raw [16]byte _, _ = rand.Read(raw[:]) + raw[6] = (raw[6] & 0x0f) | 0x40 + raw[8] = (raw[8] & 0x3f) | 0x80 return hex.EncodeToString(raw[:4]) + "-" + hex.EncodeToString(raw[4:6]) + "-" + hex.EncodeToString(raw[6:8]) + "-" + @@ -229,6 +294,8 @@ func randomID() string { hex.EncodeToString(raw[10:]) } +// wire := encodeTCPFrame("block", []byte("template")) +// _ = conn.Write(wire) func encodeTCPFrame(channel string, frame []byte) []byte { channelBytes := []byte(channel) payloadLength := uint32(4 + len(channelBytes) + len(frame)) @@ -239,3 +306,24 @@ func encodeTCPFrame(channel string, frame []byte) []byte { copy(output[8+len(channelBytes):], frame) return output } + +// copy := cloneFrame(original) +func cloneFrame(frame []byte) []byte { + if len(frame) == 0 { + return nil + } + return append([]byte(nil), frame...) +} + +// stop := onceFunction(func() { unsubscribe() }) +// stop() // executes once +// stop() // no-op +func onceFunction(handler func()) func() { + if handler == nil { + return func() {} + } + var once sync.Once + return func() { + once.Do(handler) + } +} diff --git a/stream_test.go b/stream_test.go new file mode 100644 index 0000000..af12bf3 --- /dev/null +++ b/stream_test.go @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package stream + +import ( + "sync" + "testing" +) + +func TestAX7_ConnectionState_String_Good(t *testing.T) { + cases := []struct { + state ConnectionState + expected string + }{ + {StateDisconnected, "disconnected"}, + {StateConnecting, "connecting"}, + {StateConnected, "connected"}, + } + for _, testCase := range cases { + if testCase.state.String() != testCase.expected { + t.Fatalf("ConnectionState(%d).String() = %q, want %q", testCase.state, testCase.state.String(), testCase.expected) + } + } +} + +func TestAX7_ConnectionState_String_Bad(t *testing.T) { + // Unknown ConnectionState value falls through to default ("disconnected"). + unknown := ConnectionState(99) + if unknown.String() != "disconnected" { + t.Fatalf("ConnectionState(99).String() = %q, want %q", unknown.String(), "disconnected") + } +} + +func TestAX7_ConnectionState_String_Ugly(t *testing.T) { + // Negative ConnectionState value still returns "disconnected". + negative := ConnectionState(-1) + if negative.String() != "disconnected" { + t.Fatalf("ConnectionState(-1).String() = %q, want %q", negative.String(), "disconnected") + } +} + +func TestEnvelope_Fields_Good(t *testing.T) { + envelope := Envelope{ + SourceID: "node-a", + Channel: "block", + Frame: []byte("template"), + } + if envelope.SourceID != "node-a" { + t.Fatalf("Envelope.SourceID = %q, want %q", envelope.SourceID, "node-a") + } + if envelope.Channel != "block" { + t.Fatalf("Envelope.Channel = %q, want %q", envelope.Channel, "block") + } + if string(envelope.Frame) != "template" { + t.Fatalf("Envelope.Frame = %q, want %q", string(envelope.Frame), "template") + } +} + +func TestEnvelope_Fields_Bad(t *testing.T) { + // Zero-value Envelope has empty fields — no panic. + envelope := Envelope{} + if envelope.SourceID != "" { + t.Fatalf("zero Envelope.SourceID = %q, want empty", envelope.SourceID) + } + if envelope.Channel != "" { + t.Fatalf("zero Envelope.Channel = %q, want empty", envelope.Channel) + } + if envelope.Frame != nil { + t.Fatalf("zero Envelope.Frame = %v, want nil", envelope.Frame) + } +} + +func TestEnvelope_Fields_Ugly(t *testing.T) { + // Envelope with nil frame does not panic on len(). + envelope := Envelope{SourceID: "test", Frame: nil} + if len(envelope.Frame) != 0 { + t.Fatalf("len(nil Envelope.Frame) = %d, want 0", len(envelope.Frame)) + } +} + +func TestAX7_NewPeer_Good(t *testing.T) { + peer := NewPeer("ws") + if peer == nil { + t.Fatal("NewPeer() = nil") + } + if peer.ID == "" { + t.Fatal("NewPeer().ID is empty") + } + if peer.Transport != "ws" { + t.Fatalf("NewPeer().Transport = %q, want %q", peer.Transport, "ws") + } + if peer.Claims == nil { + t.Fatal("NewPeer().Claims = nil, want empty map") + } + if peer.SendQueue() == nil { + t.Fatal("NewPeer().SendQueue() = nil, want channel") + } +} + +func TestAX7_NewPeer_Bad(t *testing.T) { + // NewPeer with empty transport creates a valid peer. + peer := NewPeer("") + if peer == nil { + t.Fatal("NewPeer('') = nil") + } + if peer.Transport != "" { + t.Fatalf("NewPeer('').Transport = %q, want empty", peer.Transport) + } +} + +func TestAX7_NewPeer_Ugly(t *testing.T) { + // Two peers created simultaneously have different IDs. + peer1 := NewPeer("ws") + peer2 := NewPeer("ws") + if peer1.ID == peer2.ID { + t.Fatalf("two NewPeer() calls produced the same ID: %q", peer1.ID) + } +} + +func TestAX7_Peer_Send_Good(t *testing.T) { + peer := NewPeer("ws") + ok := peer.Send([]byte("hello")) + if !ok { + t.Fatal("Send() returned false, want true") + } + select { + case frame := <-peer.SendQueue(): + if string(frame) != "hello" { + t.Fatalf("received frame = %q, want %q", string(frame), "hello") + } + default: + t.Fatal("no frame received from SendQueue()") + } +} + +func TestAX7_Peer_Send_Bad(t *testing.T) { + // Send to nil peer returns false without panic. + var peer *Peer + ok := peer.Send([]byte("hello")) + if ok { + t.Fatal("nil peer Send() = true, want false") + } +} + +func TestAX7_Peer_Send_Ugly(t *testing.T) { + // Send after Close returns false without panic. + peer := NewPeer("ws") + peer.Close() + ok := peer.Send([]byte("hello")) + if ok { + t.Fatal("Send() after Close() = true, want false") + } +} + +func TestAX7_Peer_Close_Ugly(t *testing.T) { + // Double Close does not panic. + peer := NewPeer("ws") + peer.Close() + peer.Close() +} + +func TestAX7_Peer_SetCloseHook_Good(t *testing.T) { + peer := NewPeer("ws") + invoked := false + peer.SetCloseHook(func() { invoked = true }) + peer.Close() + if !invoked { + t.Fatal("close hook was not invoked") + } +} + +func TestAX7_Peer_SetCloseHook_Bad(t *testing.T) { + // SetCloseHook on nil peer does not panic. + var peer *Peer + peer.SetCloseHook(func() {}) + if peer != nil { + t.Fatal("nil peer changed after SetCloseHook") + } +} + +func TestAX7_Peer_SendQueue_Bad(t *testing.T) { + // SendQueue on nil peer returns nil. + var peer *Peer + if peer.SendQueue() != nil { + t.Fatal("nil peer SendQueue() != nil") + } +} + +func TestPeer_Subscriptions_SortedCopy_Good(t *testing.T) { + // Subscriptions returns a sorted copy. + peer := NewPeer("ws") + peer.mutex.Lock() + peer.subscriptions["block"] = true + peer.subscriptions["hashrate"] = true + peer.subscriptions["agent"] = true + peer.mutex.Unlock() + + subs := peer.Subscriptions() + expected := []string{"agent", "block", "hashrate"} + if len(subs) != len(expected) { + t.Fatalf("Subscriptions() length = %d, want %d", len(subs), len(expected)) + } + for index, channel := range expected { + if subs[index] != channel { + t.Fatalf("Subscriptions()[%d] = %q, want %q", index, subs[index], channel) + } + } +} + +func TestPipe_NilStreams_Good(t *testing.T) { + // Pipe with nil src returns a no-op stop function without panic. + stop := Pipe(nil, NewHub()) + stop() + + // Pipe with nil dst returns a no-op stop function without panic. + stop = Pipe(NewHub(), nil) + stop() +} + +func TestPipe_SameStream_Bad(t *testing.T) { + // Pipe with src == dst returns a no-op stop function (no infinite loop). + hub := NewHub() + stop := Pipe(hub, hub) + stop() +} + +func TestPipe_StopConcurrency_Ugly(t *testing.T) { + // Calling stop multiple times concurrently does not panic. + hub1 := NewHub() + hub2 := NewHub() + stop := Pipe(hub1, hub2) + var waitGroup sync.WaitGroup + for index := 0; index < 10; index++ { + waitGroup.Add(1) + go func() { + defer waitGroup.Done() + stop() + }() + } + waitGroup.Wait() +} + +func TestEncodeTCPFrame_Good(t *testing.T) { + frame := encodeTCPFrame("block", []byte("template")) + if len(frame) == 0 { + t.Fatal("encodeTCPFrame() produced empty output") + } + // The frame should contain the payload length prefix, channel length, channel, and data. + // Total: 4 (payload len) + 4 (channel len) + 5 ("block") + 8 ("template") = 21 + if len(frame) != 21 { + t.Fatalf("encodeTCPFrame() len = %d, want %d", len(frame), 21) + } +} + +func TestEncodeTCPFrame_Bad(t *testing.T) { + // Empty channel and empty frame produces a minimal valid frame. + frame := encodeTCPFrame("", []byte{}) + // 4 (payload len) + 4 (channel len=0) + 0 (channel) + 0 (frame) = 8 + if len(frame) != 8 { + t.Fatalf("encodeTCPFrame('', []) len = %d, want %d", len(frame), 8) + } +} + +func TestCloneFrame_Good(t *testing.T) { + original := []byte("hello") + cloned := cloneFrame(original) + if string(cloned) != "hello" { + t.Fatalf("cloneFrame() = %q, want %q", string(cloned), "hello") + } + // Modifying the clone should not affect the original. + cloned[0] = 'H' + if string(original) != "hello" { + t.Fatalf("modifying clone affected original: %q", string(original)) + } +} + +func TestCloneFrame_Bad(t *testing.T) { + // cloneFrame of nil returns nil. + cloned := cloneFrame(nil) + if cloned != nil { + t.Fatalf("cloneFrame(nil) = %v, want nil", cloned) + } +} + +func TestCloneFrame_Ugly(t *testing.T) { + // cloneFrame of empty slice returns nil. + cloned := cloneFrame([]byte{}) + if cloned != nil { + t.Fatalf("cloneFrame([]byte{}) = %v, want nil", cloned) + } +} + +func TestOnceFunction_Good(t *testing.T) { + count := 0 + handler := onceFunction(func() { count++ }) + handler() + handler() + handler() + if count != 1 { + t.Fatalf("onceFunction handler invoked %d times, want 1", count) + } +} + +func TestOnceFunction_Bad(t *testing.T) { + // onceFunction with nil handler returns a no-op function. + handler := onceFunction(nil) + if handler == nil { + t.Fatal("onceFunction(nil) returned nil") + } + handler() // should not panic +} + +func TestOnceFunction_Ugly(t *testing.T) { + // Concurrent calls to onceFunction result execute the handler exactly once. + count := 0 + var counterMutex sync.Mutex + handler := onceFunction(func() { + counterMutex.Lock() + count++ + counterMutex.Unlock() + }) + var waitGroup sync.WaitGroup + for index := 0; index < 50; index++ { + waitGroup.Add(1) + go func() { + defer waitGroup.Done() + handler() + }() + } + waitGroup.Wait() + if count != 1 { + t.Fatalf("concurrent onceFunction handler invoked %d times, want 1", count) + } +} + +func TestRandomUUID_Good(t *testing.T) { + id := randomUUID() + if len(id) != 36 { + t.Fatalf("randomUUID() length = %d, want 36", len(id)) + } + // Verify UUID v4 format: 8-4-4-4-12 + if id[8] != '-' || id[13] != '-' || id[18] != '-' || id[23] != '-' { + t.Fatalf("randomUUID() = %q, not in UUID format", id) + } +} + +func TestRandomUUID_Bad(t *testing.T) { + // Two calls produce different UUIDs. + id1 := randomUUID() + id2 := randomUUID() + if id1 == id2 { + t.Fatalf("randomUUID() produced duplicate: %q", id1) + } +} diff --git a/tests/cli/stream/Taskfile.yaml b/tests/cli/stream/Taskfile.yaml new file mode 100644 index 0000000..84c8ef6 --- /dev/null +++ b/tests/cli/stream/Taskfile.yaml @@ -0,0 +1,26 @@ +version: "3" + +tasks: + default: + deps: + - build + - vet + - test + + build: + desc: Compile every package in go-stream. + dir: ../../.. + cmds: + - GOWORK=off go build ./... + + vet: + desc: Run go vet across the module. + dir: ../../.. + cmds: + - GOWORK=off go vet ./... + + test: + desc: Run unit tests. + dir: ../../.. + cmds: + - GOWORK=off go test -count=1 ./... diff --git a/ws/ax7_more_test.go b/ws/ax7_more_test.go new file mode 100644 index 0000000..753708e --- /dev/null +++ b/ws/ax7_more_test.go @@ -0,0 +1,290 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package ws + +import ( + "github.com/alicebob/miniredis/v2" + "github.com/gorilla/websocket" + + core "dappco.re/go" + "dappco.re/go/stream" + adapterredis "dappco.re/go/stream/adapter/redis" +) + +func TestAX7_DefaultHubConfig_Good(t *core.T) { + config := DefaultHubConfig() + + core.AssertEqual(t, 30*core.Second, config.HeartbeatInterval) + core.AssertEqual(t, 60*core.Second, config.PongTimeout) + core.AssertEqual(t, 10*core.Second, config.WriteTimeout) +} + +func TestAX7_DefaultHubConfig_Bad(t *core.T) { + config := DefaultHubConfig() + + core.AssertNil(t, config.OnConnect) + core.AssertNil(t, config.OnDisconnect) + core.AssertNil(t, config.ChannelAuthoriser) +} + +func TestAX7_DefaultHubConfig_Ugly(t *core.T) { + config := DefaultHubConfig() + + core.AssertGreater(t, config.PongTimeout, config.HeartbeatInterval) + core.AssertGreater(t, config.WriteTimeout, core.Duration(0)) +} + +func TestAX7_New_Good(t *core.T) { + adapter := New(Config{ReadBufferSize: 2048, WriteBufferSize: 4096}) + + core.AssertNotNil(t, adapter) + core.AssertNotNil(t, adapter.Handler()) +} + +func TestAX7_New_Bad(t *core.T) { + adapter := New(Config{}) + + core.AssertNotNil(t, adapter) + core.AssertNotNil(t, adapter.HandlerForChannel("events")) +} + +func TestAX7_New_Ugly(t *core.T) { + called := false + adapter := New(Config{CheckOrigin: func(*core.Request) bool { called = true; return true }}) + + core.AssertTrue(t, adapter.Handler() != nil) + core.AssertTrue(t, adapter.HandlerForChannel("x") != nil) + core.AssertFalse(t, called) +} + +func TestAX7_NewAPIKeyAuth_Good(t *core.T) { + authenticator := NewAPIKeyAuth(map[string]string{"sk": "user"}) + + core.AssertNotNil(t, authenticator) + core.AssertEqual(t, "user", authenticator.Keys["sk"]) +} + +func TestAX7_NewAPIKeyAuth_Bad(t *core.T) { + authenticator := NewAPIKeyAuth(nil) + + core.AssertNotNil(t, authenticator) + core.AssertEqual(t, 0, len(authenticator.Keys)) +} + +func TestAX7_NewAPIKeyAuth_Ugly(t *core.T) { + keys := map[string]string{"sk": "user"} + authenticator := NewAPIKeyAuth(keys) + keys["sk"] = "mutated" + + core.AssertEqual(t, "user", authenticator.Keys["sk"]) + core.AssertEqual(t, "mutated", keys["sk"]) +} + +func TestAX7_NewHub_Good(t *core.T) { + hub := NewHub() + + core.AssertNotNil(t, hub) + core.AssertNotNil(t, hub.Hub) + core.AssertFalse(t, hub.Running()) +} + +func TestAX7_NewHub_Bad(t *core.T) { + hub := NewHub() + + core.AssertEqual(t, 30*core.Second, hub.Config().HeartbeatInterval) + core.AssertEqual(t, 0, hub.PeerCount()) +} + +func TestAX7_NewHub_Ugly(t *core.T) { + left := NewHub() + right := NewHub() + + core.AssertNotEqual(t, left, right) + core.AssertNotEqual(t, left.Hub, right.Hub) +} + +func TestAX7_NewHubWithConfig_Good(t *core.T) { + hub := NewHubWithConfig(HubConfig{HeartbeatInterval: core.Second, PongTimeout: 3 * core.Second}) + + core.AssertEqual(t, core.Second, hub.Config().HeartbeatInterval) + core.AssertEqual(t, 3*core.Second, hub.Config().PongTimeout) +} + +func TestAX7_NewHubWithConfig_Bad(t *core.T) { + hub := NewHubWithConfig(HubConfig{}) + + core.AssertEqual(t, 30*core.Second, hub.Config().HeartbeatInterval) + core.AssertEqual(t, 60*core.Second, hub.Config().PongTimeout) +} + +func TestAX7_NewHubWithConfig_Ugly(t *core.T) { + called := false + hub := NewHubWithConfig(HubConfig{OnConnect: func(*Peer) { called = true }}) + + core.AssertNoError(t, hub.AddPeer(NewPeer("ws"))) + core.AssertTrue(t, called) +} + +func TestAX7_NewPeer_Good(t *core.T) { + peer := NewPeer("ws") + + core.AssertNotNil(t, peer) + core.AssertEqual(t, "ws", peer.Transport) + core.AssertNotEmpty(t, peer.ID) +} + +func TestAX7_NewPeer_Bad(t *core.T) { + peer := NewPeer("") + + core.AssertNotNil(t, peer) + core.AssertEqual(t, "", peer.Transport) + core.AssertNotNil(t, peer.SendQueue()) +} + +func TestAX7_NewPeer_Ugly(t *core.T) { + left := NewPeer("ws") + right := NewPeer("ws") + + core.AssertNotEqual(t, left.ID, right.ID) + core.AssertEqual(t, "ws", right.Transport) +} + +func TestAX7_NewReconnectingClient_Good(t *core.T) { + client := NewReconnectingClient(ReconnectConfig{URL: "ws://127.0.0.1/stream/ws"}) + + core.AssertNotNil(t, client) + core.AssertEqual(t, stream.StateDisconnected, client.State()) + core.AssertNoError(t, client.Close()) +} + +func TestAX7_NewReconnectingClient_Bad(t *core.T) { + client := NewReconnectingClient(ReconnectConfig{}) + + core.AssertNotNil(t, client) + core.AssertEqual(t, StateDisconnected, client.State()) +} + +func TestAX7_NewReconnectingClient_Ugly(t *core.T) { + client := NewReconnectingClient(ReconnectConfig{URL: "://bad-url", MaxRetries: 1, InitialBackoff: core.Millisecond}) + + err := client.Connect(core.Background()) + core.AssertError(t, err) + core.AssertEqual(t, StateDisconnected, client.State()) +} + +func TestAX7_NewRedisBridge_Good(t *core.T) { + redisServer := miniredis.RunT(t) + bridge, err := NewRedisBridge(NewHub(), adapterredis.Config{Addr: redisServer.Addr(), Prefix: "pool"}) + + core.AssertNoError(t, err) + core.AssertNotNil(t, bridge) + core.AssertNotEmpty(t, bridge.SourceID()) +} + +func TestAX7_NewRedisBridge_Bad(t *core.T) { + bridge, err := NewRedisBridge("unsupported", adapterredis.Config{}) + + core.AssertError(t, err) + core.AssertNil(t, bridge) +} + +func TestAX7_NewRedisBridge_Ugly(t *core.T) { + var hub *Hub + redisServer := miniredis.RunT(t) + + bridge, err := NewRedisBridge(hub, adapterredis.Config{Addr: redisServer.Addr(), Prefix: "pool"}) + core.AssertError(t, err) + core.AssertNil(t, bridge) +} + +func TestAX7_Pipe_Good(t *core.T) { + source := NewHub() + destination := NewHub() + + stop := Pipe(source, destination) + core.AssertNotNil(t, stop) + stop() +} + +func TestAX7_Pipe_Bad(t *core.T) { + stop := Pipe(nil, NewHub()) + + core.AssertNotNil(t, stop) + core.AssertNotPanics(t, stop) +} + +func TestAX7_Pipe_Ugly(t *core.T) { + hub := NewHub() + stop := Pipe(hub, hub) + + core.AssertNotNil(t, stop) + core.AssertNotPanics(t, stop) +} + +func TestAX7_Hub_Handler_Good(t *core.T) { + hub := NewHub() + ctx, cancel := core.WithCancel(core.Background()) + defer cancel() + go hub.Run(ctx) + waitForRunningHub(t, hub) + server := core.NewHTTPTestServer(hub.Handler()) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial("ws"+server.URL[len("http"):], nil) + core.AssertNoError(t, err) + core.AssertNoError(t, conn.Close()) +} + +func TestAX7_Hub_Handler_Bad(t *core.T) { + var hub *Hub + handler := hub.Handler() + recorder := core.NewHTTPTestRecorder() + + handler.ServeHTTP(recorder, core.NewHTTPTestRequest("GET", "/stream/ws", nil)) + core.AssertEqual(t, 500, recorder.Code) + core.AssertContains(t, recorder.Body.String(), "not mounted") +} + +func TestAX7_Hub_Handler_Ugly(t *core.T) { + hub := NewHub() + handler := hub.Handler() + recorder := core.NewHTTPTestRecorder() + + handler.ServeHTTP(recorder, core.NewHTTPTestRequest("GET", "/stream/ws", nil)) + core.AssertEqual(t, 500, recorder.Code) + core.AssertContains(t, recorder.Body.String(), "not running") +} + +func TestAX7_Hub_HandlerForChannel_Good(t *core.T) { + hub := NewHub() + ctx, cancel := core.WithCancel(core.Background()) + defer cancel() + go hub.Run(ctx) + waitForRunningHub(t, hub) + server := core.NewHTTPTestServer(hub.HandlerForChannel("hashrate")) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial("ws"+server.URL[len("http"):], nil) + core.AssertNoError(t, err) + core.AssertNoError(t, conn.Close()) +} + +func TestAX7_Hub_HandlerForChannel_Bad(t *core.T) { + var hub *Hub + handler := hub.HandlerForChannel("hashrate") + recorder := core.NewHTTPTestRecorder() + + handler.ServeHTTP(recorder, core.NewHTTPTestRequest("GET", "/stream/ws", nil)) + core.AssertEqual(t, 500, recorder.Code) + core.AssertContains(t, recorder.Body.String(), "not mounted") +} + +func TestAX7_Hub_HandlerForChannel_Ugly(t *core.T) { + hub := NewHub() + handler := hub.HandlerForChannel("hashrate") + recorder := core.NewHTTPTestRecorder() + + handler.ServeHTTP(recorder, core.NewHTTPTestRequest("GET", "/stream/ws", nil)) + core.AssertEqual(t, 500, recorder.Code) + core.AssertContains(t, recorder.Body.String(), "not running") +} diff --git a/ws/compat.go b/ws/compat.go new file mode 100644 index 0000000..1294e9e --- /dev/null +++ b/ws/compat.go @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: EUPL-1.2 + +// hub := ws.NewHub() +// go hub.Run(ctx) +// http.Handle("/stream/ws", hub.Handler()) +package ws + +import ( + "net/http" + "sync" + + "dappco.re/go" + "dappco.re/go/stream" + adapterredis "dappco.re/go/stream/adapter/redis" + adapterws "dappco.re/go/stream/adapter/ws" +) + +// Stream preserves the transport-agnostic stream interface for legacy callers. +type Stream = stream.Stream + +// Frame preserves the legacy raw payload alias. +type Frame = stream.Frame + +// Channel preserves the legacy channel name alias. +type Channel = stream.Channel + +// HubConfig preserves the legacy go-ws HubConfig type name. +type HubConfig = stream.HubConfig + +// ChannelAuthoriser preserves the legacy go-ws channel authoriser type name. +type ChannelAuthoriser = stream.ChannelAuthoriser + +// HubStats preserves the legacy hub stats type name. +type HubStats = stream.HubStats + +// Peer preserves the transport-agnostic peer type under the legacy package. +type Peer = stream.Peer + +// Client preserves the legacy go-ws Client type name. +type Client = stream.Peer + +// Authenticator preserves the legacy go-ws Authenticator type name. +type Authenticator = stream.Authenticator + +// AuthenticatorFunc preserves the legacy go-ws AuthenticatorFunc helper. +type AuthenticatorFunc = stream.AuthenticatorFunc + +// AuthResult preserves the legacy go-ws AuthResult type name. +type AuthResult = stream.AuthResult + +// APIKeyAuthenticator preserves the legacy API key authenticator type name. +type APIKeyAuthenticator = stream.APIKeyAuthenticator + +// BearerTokenAuth preserves the legacy bearer-token authenticator type name. +type BearerTokenAuth = stream.BearerTokenAuth + +// QueryTokenAuth preserves the legacy query-token authenticator type name. +type QueryTokenAuth = stream.QueryTokenAuth + +// ConnAuthenticator preserves the legacy raw-connection authenticator name. +type ConnAuthenticator = stream.ConnAuthenticator + +// ConnAuthenticatorFunc preserves the legacy raw-connection helper name. +type ConnAuthenticatorFunc = stream.ConnAuthenticatorFunc + +// ConnectionState preserves the reconnecting client connection state type. +type ConnectionState = stream.ConnectionState + +// Message preserves the legacy go-ws WebSocket message envelope. +type Message = stream.Message + +// MessageType preserves the legacy go-ws message type name. +type MessageType = stream.MessageType + +const ( + // TypeProcessOutput preserves the legacy message type constant. + TypeProcessOutput = stream.TypeProcessOutput + // TypeProcessStatus preserves the legacy message type constant. + TypeProcessStatus = stream.TypeProcessStatus + // TypeEvent preserves the legacy message type constant. + TypeEvent = stream.TypeEvent + // TypeError preserves the legacy message type constant. + TypeError = stream.TypeError + // TypePing preserves the legacy message type constant. + TypePing = stream.TypePing + // TypePong preserves the legacy message type constant. + TypePong = stream.TypePong + // TypeSubscribe preserves the legacy message type constant. + TypeSubscribe = stream.TypeSubscribe + // TypeUnsubscribe preserves the legacy message type constant. + TypeUnsubscribe = stream.TypeUnsubscribe + // StateDisconnected preserves the reconnecting client disconnected state. + StateDisconnected = stream.StateDisconnected + // StateConnecting preserves the reconnecting client connecting state. + StateConnecting = stream.StateConnecting + // StateConnected preserves the reconnecting client connected state. + StateConnected = stream.StateConnected +) + +var ( + // ErrMissingAuthHeader preserves the legacy missing-header sentinel error. + ErrMissingAuthHeader = stream.ErrMissingAuthHeader + // ErrMalformedAuthHeader preserves the legacy malformed-header sentinel error. + ErrMalformedAuthHeader = stream.ErrMalformedAuthHeader + // ErrInvalidAPIKey preserves the legacy invalid API key sentinel error. + ErrInvalidAPIKey = stream.ErrInvalidAPIKey + // ErrHandshakeTimeout preserves the legacy handshake timeout sentinel error. + ErrHandshakeTimeout = stream.ErrHandshakeTimeout + // ErrAuthRejected preserves the legacy authenticator rejection sentinel error. + ErrAuthRejected = stream.ErrAuthRejected + // ErrHubNotRunning preserves the legacy hub lifecycle sentinel error. + ErrHubNotRunning = stream.ErrHubNotRunning + // ErrEmptyChannel preserves the legacy empty-channel sentinel error. + ErrEmptyChannel = stream.ErrEmptyChannel +) + +// Adapter preserves the legacy WebSocket adapter type name. +type Adapter = adapterws.Adapter + +// Config preserves the legacy WebSocket adapter configuration type name. +type Config = adapterws.Config + +// ReconnectConfig preserves the legacy reconnecting WebSocket configuration type name. +type ReconnectConfig = adapterws.ReconnectConfig + +// RedisBridge preserves the legacy go-ws RedisBridge type name. +type RedisBridge = adapterredis.Bridge + +// Hub preserves the legacy go-ws Hub surface while embedding the new stream hub. +// +// hub := ws.NewHub() +// go hub.Run(ctx) +// http.Handle("/stream/ws", hub.Handler()) +type Hub struct { + *stream.Hub + + adapterOnce sync.Once + adapter *adapterws.Adapter +} + +// bridge, err := ws.NewRedisBridge(hub, redis.Config{Addr: "redis:6379", Prefix: "pool"}) +func NewRedisBridge(hub any, config adapterredis.Config) (*RedisBridge, error) { + switch typedHub := hub.(type) { + case *Hub: + if typedHub == nil { + return adapterredis.NewBridge(nil, config) + } + return adapterredis.NewBridge(typedHub.Hub, config) + case *stream.Hub: + return adapterredis.NewBridge(typedHub, config) + default: + return nil, core.E("stream.ws", "unsupported hub type", nil) + } +} + +// auth := ws.NewAPIKeyAuth(map[string]string{"sk-live": "user-42"}) +func NewAPIKeyAuth(keys map[string]string) *APIKeyAuthenticator { + return stream.NewAPIKeyAuth(keys) +} + +// hub := ws.NewHub() +func NewHub() *Hub { + return &Hub{Hub: stream.NewHub()} +} + +// hub := ws.NewHubWithConfig(stream.HubConfig{HeartbeatInterval: 30 * time.Second}) +func NewHubWithConfig(config HubConfig) *Hub { + return &Hub{Hub: stream.NewHubWithConfig(config)} +} + +// config := ws.DefaultHubConfig() +func DefaultHubConfig() HubConfig { + return stream.DefaultHubConfig() +} + +// peer := ws.NewPeer("ws") +func NewPeer(transport string) *Peer { + return stream.NewPeer(transport) +} + +// stop := ws.Pipe(sourceHub, destinationHub) +func Pipe(source Stream, destination Stream) func() { + return stream.Pipe(source, destination) +} + +// adapter := ws.New(ws.Config{Authenticator: auth}) +func New(config Config) *Adapter { + return adapterws.New(config) +} + +// client := ws.NewReconnectingClient(ws.ReconnectConfig{URL: "ws://127.0.0.1:8080/stream/ws"}) +func NewReconnectingClient(config ReconnectConfig) *adapterws.ReconnectingClient { + return adapterws.NewReconnectingClient(config) +} + +// Handler preserves the old hub-bound WebSocket handler entrypoint. +// +// http.Handle("/stream/ws", hub.Handler()) +func (hub *Hub) Handler() http.HandlerFunc { + if hub == nil { + return func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "stream hub not mounted", http.StatusInternalServerError) + } + } + return hub.compatAdapter().Handler() +} + +// HandlerForChannel preserves the old dedicated-channel handler entrypoint. +// +// http.Handle("/stream/hashrate", hub.HandlerForChannel("hashrate")) +func (hub *Hub) HandlerForChannel(channel string) http.HandlerFunc { + if hub == nil { + return func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "stream hub not mounted", http.StatusInternalServerError) + } + } + return hub.compatAdapter().HandlerForChannel(channel) +} + +func (hub *Hub) compatAdapter() *adapterws.Adapter { + hub.adapterOnce.Do(func() { + adapter := adapterws.New(adapterws.Config{}) + adapter.Mount(hub.Hub) + hub.adapter = adapter + }) + return hub.adapter +} + +var _ Stream = (*Hub)(nil) diff --git a/ws/compat_test.go b/ws/compat_test.go new file mode 100644 index 0000000..e9965fa --- /dev/null +++ b/ws/compat_test.go @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package ws + +import ( + "context" + "net/http/httptest" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +func TestCompat_LegacySurface_Good(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{"valid-key": "user-1"}) + if auth == nil { + t.Fatal("NewAPIKeyAuth() = nil") + } + + var frame Frame = []byte("payload") + if string(frame) != "payload" { + t.Fatalf("Frame alias produced %q, want %q", string(frame), "payload") + } + + var channel Channel = "hashrate" + if channel != "hashrate" { + t.Fatalf("Channel alias produced %q, want %q", channel, "hashrate") + } + + var authoriser ChannelAuthoriser + if authoriser != nil { + t.Fatal("ChannelAuthoriser alias should default to nil") + } + + if StateDisconnected != 0 || StateConnecting != 1 || StateConnected != 2 { + t.Fatalf("unexpected connection states: %d %d %d", StateDisconnected, StateConnecting, StateConnected) + } + + if ErrMissingAuthHeader == nil || ErrMalformedAuthHeader == nil || ErrInvalidAPIKey == nil { + t.Fatal("expected auth sentinel errors to be re-exported") + } + if ErrHandshakeTimeout == nil || ErrAuthRejected == nil || ErrHubNotRunning == nil || ErrEmptyChannel == nil { + t.Fatal("expected transport sentinel errors to be re-exported") + } + + sourceHub := NewHub() + destinationHub := NewHub() + + sourceContext, sourceCancel := context.WithCancel(context.Background()) + defer sourceCancel() + destinationContext, destinationCancel := context.WithCancel(context.Background()) + defer destinationCancel() + + go sourceHub.Run(sourceContext) + go destinationHub.Run(destinationContext) + waitForRunningHub(t, sourceHub) + waitForRunningHub(t, destinationHub) + + received := make(chan []byte, 1) + unsubscribe := destinationHub.Subscribe("hashrate", func(frame []byte) { + received <- append([]byte(nil), frame...) + }) + defer unsubscribe() + + stop := Pipe(sourceHub, destinationHub) + defer stop() + + if err := sourceHub.Publish("hashrate", []byte("123456")); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + select { + case frame := <-received: + if string(frame) != "123456" { + t.Fatalf("received frame = %q, want %q", string(frame), "123456") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for piped frame") + } + + peer := NewPeer("ws") + if peer == nil { + t.Fatal("NewPeer() = nil") + } + if peer.Transport != "ws" { + t.Fatalf("peer.Transport = %q, want %q", peer.Transport, "ws") + } + + stats := destinationHub.Stats() + var _ HubStats = stats +} + +func TestCompat_LegacySurface_Bad(t *testing.T) { + hub := NewHub() + + if err := hub.Publish("hashrate", []byte("123456")); err != ErrHubNotRunning { + t.Fatalf("Publish() error = %v, want %v", err, ErrHubNotRunning) + } + + peer := NewPeer("ws") + if err := hub.SubscribePeer(peer, ""); err != ErrEmptyChannel { + t.Fatalf("SubscribePeer() error = %v, want %v", err, ErrEmptyChannel) + } +} + +func TestCompat_LegacySurface_Ugly(t *testing.T) { + var source Stream + stop := Pipe(source, source) + if stop == nil { + t.Fatal("Pipe(nil, nil) returned nil stop function") + } + stop() +} + +func TestCompat_HubHandler_Good(t *testing.T) { + hub := NewHub() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + waitForRunningHub(t, hub) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + url := "ws" + server.URL[len("http"):] + "?channel=hashrate" + connection, _, err := websocket.DefaultDialer.Dial(url, nil) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + defer connection.Close() + + payload := []byte(`{"type":"event","channel":"hashrate","data":{"h":123456},"timestamp":"2026-01-01T00:00:00Z"}`) + if err := hub.Publish("hashrate", payload); err != nil { + t.Fatalf("Publish() error = %v", err) + } + + if err := connection.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("SetReadDeadline() error = %v", err) + } + _, frame, err := connection.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage() error = %v", err) + } + if string(frame) != string(payload) { + t.Fatalf("ReadMessage() frame = %q, want %q", string(frame), string(payload)) + } +} + +func waitForRunningHub(t *testing.T, hub *Hub) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if hub.Publish("health", nil) == nil { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("timed out waiting for hub to start") +} diff --git a/ws/example_test.go b/ws/example_test.go new file mode 100644 index 0000000..d76f43f --- /dev/null +++ b/ws/example_test.go @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package ws_test + +import ( + "context" + "net/http" + + "dappco.re/go/stream/ws" +) + +func ExampleHub_Handler() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hub := ws.NewHub() + go hub.Run(ctx) + + http.Handle("/stream/ws", hub.Handler()) +}