diff --git a/.golangci.yml b/.golangci.yml index 774475b..46cba85 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,3 +1,5 @@ +version: "2" + run: timeout: 5m go: "1.26" @@ -8,15 +10,15 @@ linters: - errcheck - staticcheck - unused - - gosimple - ineffassign - - typecheck - gocritic - - gofmt disable: - exhaustive - wrapcheck issues: - exclude-use-default: false max-same-issues: 0 + +formatters: + enable: + - gofmt diff --git a/auth.go b/auth.go index a37daa2..2a0da1e 100644 --- a/auth.go +++ b/auth.go @@ -3,18 +3,27 @@ package ws import ( + // AX-6-exception: WebSocket requires HTTP upgrade (RFC 6455) "net/http" - "strings" + "reflect" + "unsafe" - core "dappco.re/go/core" - coreerr "dappco.re/go/core/log" + core "dappco.re/go" + coreerr "dappco.re/go/log" ) +const maxClaimsCloneDepth = 64 + // AuthResult holds the outcome of an authentication attempt. +// result := ws.AuthResult{Authenticated: true, UserID: "user-123"} type AuthResult struct { // Valid indicates whether authentication succeeded. Valid bool + // Authenticated is an RFC-compatible alias for Valid. The package + // treats either field as a successful authentication result. + Authenticated bool + // UserID is the authenticated user's identifier. UserID string @@ -26,16 +35,435 @@ type AuthResult struct { Error error } -// Authenticator validates an HTTP request during the WebSocket upgrade -// handshake. Implementations may inspect headers, query parameters, -// cookies, or any other request attribute. +// authenticatedResult builds a successful AuthResult with both success +// flags populated. +func authenticatedResult(userID string, claims map[string]any) AuthResult { + userID = core.Trim(userID) + if userID == "" { + return AuthResult{ + Valid: false, + Error: ErrMissingUserID, + } + } + + clonedClaims, ok := cloneClaims(claims) + if !ok { + return AuthResult{ + Valid: false, + Error: ErrInvalidAuthClaims, + } + } + + return AuthResult{ + Valid: true, + Authenticated: true, + UserID: userID, + Claims: clonedClaims, + } +} + +// normalizeAuthResult ensures the compatibility alias fields stay in sync. +func normalizeAuthResult(result AuthResult) AuthResult { + if result.Valid || result.Authenticated { + result.Valid = true + result.Authenticated = true + } + return result +} + +// authResultAccepted reports whether an authentication attempt succeeded. +func authResultAccepted(result AuthResult) bool { + return result.Valid || result.Authenticated +} + +// finalizeAuthResult rejects successful authentication results that do not +// provide a usable user identity. +func finalizeAuthResult(result AuthResult) AuthResult { + result = normalizeAuthResult(result) + if !authResultAccepted(result) { + return result + } + result.UserID = core.Trim(result.UserID) + if result.UserID == "" { + return AuthResult{ + Valid: false, + Error: ErrMissingUserID, + } + } + clonedClaims, ok := cloneClaims(result.Claims) + if !ok { + return AuthResult{ + Valid: false, + Error: ErrInvalidAuthClaims, + } + } + result.Claims = clonedClaims + return result +} + +// cloneClaims snapshots the auth claims map so caller-side mutations after +// authentication do not change the active session state. +func cloneClaims(claims map[string]any) (map[string]any, bool) { + if len(claims) == 0 { + return nil, true + } + + cloned := make(map[string]any, len(claims)) + seen := make(map[uintptr]reflect.Value) + for key, value := range claims { + clonedValue, ok := cloneClaimsValue(reflect.ValueOf(value), seen, 0) + if !ok { + return nil, false + } + cloned[key] = clonedValue + } + return cloned, true +} + +// cloneClaimsValue snapshots a claim value and rejects unsupported reference +// types so authentication sessions never retain caller-owned mutable state. +func cloneClaimsValue(v reflect.Value, seen map[uintptr]reflect.Value, depth int) (any, bool) { + if !v.IsValid() { + return nil, true + } + + if depth > maxClaimsCloneDepth { + return nil, false + } + + if !v.CanInterface() { + if !v.CanAddr() { + return nil, false + } + + v = reflect.ValueOf(valueInterface(v)) + } + + switch v.Kind() { + case reflect.Pointer: + if v.IsNil() { + return nil, true + } + + ptr := v.Pointer() + if cloned, ok := seen[ptr]; ok { + return cloned.Interface(), true + } + + clone := reflect.New(v.Elem().Type()) + seen[ptr] = clone + if !setClonedValue(clone.Elem(), v.Elem(), seen, depth+1) { + return nil, false + } + return clone.Interface(), true + case reflect.Map: + if v.IsNil() { + return nil, true + } + + ptr := v.Pointer() + if cloned, ok := seen[ptr]; ok { + return cloned.Interface(), true + } + + clone := reflect.MakeMapWithSize(v.Type(), v.Len()) + seen[ptr] = clone + iter := v.MapRange() + for iter.Next() { + clonedKey, ok := cloneClaimsValue(iter.Key(), seen, depth+1) + if !ok { + return nil, false + } + + keyValue := reflect.ValueOf(clonedKey) + if !keyValue.IsValid() { + return nil, false + } + if !keyValue.Type().AssignableTo(v.Type().Key()) { + if keyValue.Type().ConvertibleTo(v.Type().Key()) { + keyValue = keyValue.Convert(v.Type().Key()) + } else { + return nil, false + } + } + + clonedValue, ok := cloneClaimsValue(iter.Value(), seen, depth+1) + if !ok { + return nil, false + } + + if clonedValue == nil { + clone.SetMapIndex(keyValue, reflect.Zero(v.Type().Elem())) + continue + } + + value := reflect.ValueOf(clonedValue) + if value.Type().AssignableTo(v.Type().Elem()) { + clone.SetMapIndex(keyValue, value) + continue + } + if value.Type().ConvertibleTo(v.Type().Elem()) { + clone.SetMapIndex(keyValue, value.Convert(v.Type().Elem())) + continue + } + + return nil, false + } + return clone.Interface(), true + case reflect.Slice: + if v.IsNil() { + return nil, true + } + if v.Type().Elem().Kind() == reflect.Uint8 { + clone := make([]byte, v.Len()) + reflect.Copy(reflect.ValueOf(clone), v) + return clone, true + } + + ptr := v.Pointer() + if cloned, ok := seen[ptr]; ok { + return cloned.Interface(), true + } + + clone := reflect.MakeSlice(v.Type(), v.Len(), v.Len()) + seen[ptr] = clone + for i := 0; i < v.Len(); i++ { + if !setClonedValue(clone.Index(i), v.Index(i), seen, depth+1) { + return nil, false + } + } + return clone.Interface(), true + case reflect.Array: + clone := reflect.New(v.Type()).Elem() + for i := 0; i < v.Len(); i++ { + if !setClonedValue(clone.Index(i), v.Index(i), seen, depth+1) { + return nil, false + } + } + return clone.Interface(), true + case reflect.Struct: + clone := reflect.New(v.Type()).Elem() + clone.Set(v) + for i := 0; i < v.NumField(); i++ { + if !setClonedValue(clone.Field(i), v.Field(i), seen, depth+1) { + return nil, false + } + } + return clone.Interface(), true + case reflect.Interface: + if v.IsNil() { + return nil, true + } + return cloneClaimsValue(v.Elem(), seen, depth+1) + case reflect.Chan, reflect.Func, reflect.UnsafePointer: + return nil, false + default: + return valueInterface(v), true + } +} + +// deepCloneValue recursively copies common composite values so auth claims do +// not retain references to caller-owned mutable state. It preserves scalar +// values as-is and falls back to the original value for unsupported kinds. +func deepCloneValue(v reflect.Value) any { + cloned, _ := deepCloneValueWithState(v, make(map[uintptr]reflect.Value), 0) + return cloned +} + +func deepCloneValueWithState(v reflect.Value, seen map[uintptr]reflect.Value, depth int) (any, bool) { + if !v.IsValid() { + return nil, true + } + + if depth > maxClaimsCloneDepth { + return nil, false + } + + if !v.CanInterface() { + if !v.CanAddr() { + return nil, false + } + + v = reflect.ValueOf(valueInterface(v)) + } + + switch v.Kind() { + case reflect.Pointer: + if v.IsNil() { + return nil, true + } + + ptr := v.Pointer() + if cloned, ok := seen[ptr]; ok { + return cloned.Interface(), true + } + + clone := reflect.New(v.Elem().Type()) + seen[ptr] = clone + if !setClonedValue(clone.Elem(), v.Elem(), seen, depth+1) { + return nil, false + } + return clone.Interface(), true + case reflect.Map: + if v.IsNil() { + return nil, true + } + + ptr := v.Pointer() + if cloned, ok := seen[ptr]; ok { + return cloned.Interface(), true + } + + clone := reflect.MakeMapWithSize(v.Type(), v.Len()) + seen[ptr] = clone + iter := v.MapRange() + for iter.Next() { + clonedValue, ok := deepCloneValueWithState(iter.Value(), seen, depth+1) + if !ok { + return nil, false + } + if clonedValue == nil { + clone.SetMapIndex(iter.Key(), reflect.Zero(v.Type().Elem())) + continue + } + + value := reflect.ValueOf(clonedValue) + if value.Type().AssignableTo(v.Type().Elem()) { + clone.SetMapIndex(iter.Key(), value) + continue + } + if value.Type().ConvertibleTo(v.Type().Elem()) { + clone.SetMapIndex(iter.Key(), value.Convert(v.Type().Elem())) + continue + } + + clone.SetMapIndex(iter.Key(), iter.Value()) + } + return clone.Interface(), true + case reflect.Slice: + if v.IsNil() { + return nil, true + } + if v.Type().Elem().Kind() == reflect.Uint8 { + clone := make([]byte, v.Len()) + reflect.Copy(reflect.ValueOf(clone), v) + return clone, true + } + + ptr := v.Pointer() + if cloned, ok := seen[ptr]; ok { + return cloned.Interface(), true + } + + clone := reflect.MakeSlice(v.Type(), v.Len(), v.Len()) + seen[ptr] = clone + for i := 0; i < v.Len(); i++ { + if !setClonedValue(clone.Index(i), v.Index(i), seen, depth+1) { + return nil, false + } + } + return clone.Interface(), true + case reflect.Array: + clone := reflect.New(v.Type()).Elem() + for i := 0; i < v.Len(); i++ { + if !setClonedValue(clone.Index(i), v.Index(i), seen, depth+1) { + return nil, false + } + } + return clone.Interface(), true + case reflect.Struct: + clone := reflect.New(v.Type()).Elem() + clone.Set(v) + for i := 0; i < v.NumField(); i++ { + if !setClonedValue(clone.Field(i), v.Field(i), seen, depth+1) { + return nil, false + } + } + return clone.Interface(), true + default: + return valueInterface(v), true + } +} + +func setClonedValue(dst reflect.Value, src reflect.Value, seen map[uintptr]reflect.Value, depth int) bool { + cloned, ok := cloneClaimsValue(src, seen, depth) + if !ok { + return false + } + return assignClonedValue(dst, cloned) +} + +func assignClonedValue(dst reflect.Value, cloned any) bool { + if !dst.IsValid() { + return false + } + + if cloned == nil { + return setReflectValue(dst, reflect.Zero(dst.Type())) + } + + value := reflect.ValueOf(cloned) + if value.Type().AssignableTo(dst.Type()) { + return setReflectValue(dst, value) + } + if value.Type().ConvertibleTo(dst.Type()) { + return setReflectValue(dst, value.Convert(dst.Type())) + } + + return false +} + +// setReflectValue sets dst to value, using dst.UnsafeAddr and +// reflect.NewAt when the destination field is unexported. It is only used +// while cloning trusted claim values into a fresh value of the same concrete +// type; callers must pass an addressable destination, a type-compatible value, +// and must not race with other mutation of that destination. +func setReflectValue(dst reflect.Value, value reflect.Value) bool { + if dst.CanSet() { + dst.Set(value) + return true + } + + if !dst.CanAddr() { + return false + } + + writable := reflect.NewAt(dst.Type(), unsafe.Pointer(dst.UnsafeAddr())).Elem() + writable.Set(value) + return true +} + +func valueInterface(v reflect.Value) any { + if !v.IsValid() { + return nil + } + if v.CanInterface() { + return v.Interface() + } + if v.CanAddr() { + return reflect.NewAt(v.Type(), unsafe.Pointer(v.UnsafeAddr())).Elem().Interface() + } + return nil +} + +// Authenticator validates an HTTP upgrade request and returns the identity +// that should be attached to the accepted WebSocket client. +// AX-6-exception: Authentication runs during the RFC 6455 HTTP/1.1 upgrade +// handshake, so authenticators intentionally receive the net/http request +// object that gorilla/websocket validates and upgrades. +// +// auth := ws.AuthenticatorFunc(func(r *http.Request) ws.AuthResult { +// return ws.AuthResult{Authenticated: true, UserID: "user-123"} +// }) type Authenticator interface { Authenticate(r *http.Request) AuthResult } -// AuthenticatorFunc is an adapter that allows ordinary functions to be -// used as Authenticators. If f is a function with the appropriate -// signature, AuthenticatorFunc(f) is an Authenticator that calls f. +// AuthenticatorFunc adapts a function to the Authenticator interface. +// +// auth := ws.AuthenticatorFunc(func(r *http.Request) ws.AuthResult { +// return ws.AuthResult{Authenticated: true, UserID: "user-123"} +// }) type AuthenticatorFunc func(r *http.Request) AuthResult // Authenticate calls f(r). @@ -47,22 +475,75 @@ func (f AuthenticatorFunc) Authenticate(r *http.Request) AuthResult { } } - return f(r) + return finalizeAuthResult(f(r)) } -// APIKeyAuthenticator validates requests against a static map of API -// keys. It expects the key in the Authorization header as a Bearer -// token: `Authorization: Bearer `. Each key maps to a user ID. +// APIKeyAuthenticator validates bearer tokens against a construction-time +// snapshot of API keys to user IDs. +// +// auth := ws.NewAPIKeyAuth(map[string]string{"secret-key": "user-123"}) type APIKeyAuthenticator struct { - // Keys maps API key values to user IDs. + // Keys is a construction-time snapshot of API key values to user IDs. + // Treat it as read-only; Authenticate uses the internal snapshot. Keys map[string]string + + keys map[string]string } -// NewAPIKeyAuth creates an APIKeyAuthenticator from the given key→userID +// NewAPIKeyAuth creates an APIKeyAuthenticator from the given key-to-userID // mapping. The returned authenticator validates `Authorization: Bearer ` // headers against the provided keys. func NewAPIKeyAuth(keys map[string]string) *APIKeyAuthenticator { - return &APIKeyAuthenticator{Keys: keys} + if keys == nil { + return &APIKeyAuthenticator{ + Keys: nil, + keys: nil, + } + } + + snapshot := cloneStringMap(keys) + + return &APIKeyAuthenticator{ + Keys: snapshot, + keys: cloneStringMap(snapshot), + } +} + +func cloneStringMap(values map[string]string) map[string]string { + if len(values) == 0 { + return nil + } + + clone := make(map[string]string, len(values)) + for key, value := range values { + clone[key] = value + } + return clone +} + +// NewBearerTokenAuth creates a bearer-token authenticator. +// +// auth := ws.NewBearerTokenAuth(func(token string) ws.AuthResult { +// return ws.AuthResult{Authenticated: token == "secret", UserID: "user-1"} +// }) +// +// A custom validator should be supplied for production use. When no +// validator is configured, the authenticator rejects the connection. +func NewBearerTokenAuth(validateFns ...func(token string) AuthResult) *BearerTokenAuth { + if len(validateFns) > 0 && validateFns[0] != nil { + return &BearerTokenAuth{ + Validate: validateFns[0], + } + } + + return &BearerTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{ + Valid: false, + Error: coreerr.E("BearerTokenAuth", "validate function is not configured", nil), + } + }, + } } // Authenticate checks the Authorization header for a valid Bearer token. @@ -90,7 +571,7 @@ func (a *APIKeyAuthenticator) Authenticate(r *http.Request) AuthResult { } parts := core.SplitN(header, " ", 2) - if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + if len(parts) != 2 || core.Lower(parts[0]) != "bearer" { return AuthResult{ Valid: false, Error: ErrMalformedAuthHeader, @@ -105,7 +586,7 @@ func (a *APIKeyAuthenticator) Authenticate(r *http.Request) AuthResult { } } - userID, ok := a.Keys[token] + userID, ok := a.keys[token] if !ok { return AuthResult{ Valid: false, @@ -113,20 +594,24 @@ func (a *APIKeyAuthenticator) Authenticate(r *http.Request) AuthResult { } } - return AuthResult{ - Valid: true, - UserID: userID, - Claims: map[string]any{ - "auth_method": "api_key", - }, + if core.Trim(userID) == "" { + return AuthResult{ + Valid: false, + Error: ErrInvalidAPIKey, + } } + + return authenticatedResult(userID, map[string]any{ + "auth_method": "api_key", + }) } -// BearerTokenAuth extracts an Authorization: Bearer header and -// validates it using a caller-supplied function. Unlike APIKeyAuthenticator, -// this authenticator delegates validation entirely to the caller, making -// it suitable for JWT verification, token introspection, or any custom -// bearer scheme. +// BearerTokenAuth validates bearer tokens with a caller-supplied validation +// function. +// +// auth := ws.NewBearerTokenAuth(func(token string) ws.AuthResult { +// return ws.AuthResult{Authenticated: true, UserID: "user-123"} +// }) type BearerTokenAuth struct { // Validate receives the raw bearer token string and should return // an AuthResult. The caller controls UserID, Claims, and error @@ -166,7 +651,7 @@ func (b *BearerTokenAuth) Authenticate(r *http.Request) AuthResult { } parts := core.SplitN(header, " ", 2) - if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + if len(parts) != 2 || core.Lower(parts[0]) != "bearer" { return AuthResult{ Valid: false, Error: ErrMalformedAuthHeader, @@ -181,19 +666,46 @@ func (b *BearerTokenAuth) Authenticate(r *http.Request) AuthResult { } } - return b.Validate(token) + return finalizeAuthResult(b.Validate(token)) } -// QueryTokenAuth extracts a token from the ?token= query parameter and -// validates it using a caller-supplied function. This is useful for -// browser clients that cannot set custom headers on WebSocket connections -// (e.g. the browser's native WebSocket API does not support custom headers). +// QueryTokenAuth validates the token query parameter with a caller-supplied +// validation function. +// +// auth := ws.NewQueryTokenAuth(func(token string) ws.AuthResult { +// return ws.AuthResult{Authenticated: true, UserID: "user-123"} +// }) type QueryTokenAuth struct { // Validate receives the raw token value from the query string and // should return an AuthResult. Validate func(token string) AuthResult } +// NewQueryTokenAuth creates a query-token authenticator. +// +// auth := ws.NewQueryTokenAuth(func(token string) ws.AuthResult { +// return ws.AuthResult{Authenticated: token == "browser-token", UserID: "user-2"} +// }) +// +// A custom validator should be supplied for production use. When no +// validator is configured, the authenticator rejects the connection. +func NewQueryTokenAuth(validateFns ...func(token string) AuthResult) *QueryTokenAuth { + if len(validateFns) > 0 && validateFns[0] != nil { + return &QueryTokenAuth{ + Validate: validateFns[0], + } + } + + return &QueryTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{ + Valid: false, + Error: coreerr.E("QueryTokenAuth", "validate function is not configured", nil), + } + }, + } +} + // Authenticate implements the Authenticator interface for query parameter tokens. func (q *QueryTokenAuth) Authenticate(r *http.Request) AuthResult { if q == nil { @@ -232,5 +744,5 @@ func (q *QueryTokenAuth) Authenticate(r *http.Request) AuthResult { } } - return q.Validate(token) + return finalizeAuthResult(q.Validate(token)) } diff --git a/auth_test.go b/auth_test.go index 1a41aab..107f550 100644 --- a/auth_test.go +++ b/auth_test.go @@ -6,14 +6,14 @@ import ( "context" "net/http" "net/http/httptest" + "reflect" + // Note: AX-6 — internal concurrency primitive; structural for go-ws hub state (RFC mandates concurrent connection map). "sync" "testing" "time" - core "dappco.re/go/core" + core "dappco.re/go" "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // --------------------------------------------------------------------------- @@ -30,11 +30,22 @@ func TestAPIKeyAuthenticator_ValidKey(t *testing.T) { r.Header.Set("Authorization", "Bearer key-abc") result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !(result.Authenticated) { + t.Errorf("expected true") + } + if !testEqual("user-1", result.UserID) { + t.Errorf("expected %v, got %v", "user-1", result.UserID) + } + if !testEqual("api_key", result.Claims["auth_method"]) { + t.Errorf("expected %v, got %v", "api_key", result.Claims["auth_method"]) + } + if err := result.Error; err != nil { + t.Errorf("expected no error, got %v", err) + } - assert.True(t, result.Valid) - assert.Equal(t, "user-1", result.UserID) - assert.Equal(t, "api_key", result.Claims["auth_method"]) - assert.NoError(t, result.Error) } func TestAPIKeyAuthenticator_InvalidKey(t *testing.T) { @@ -46,10 +57,16 @@ func TestAPIKeyAuthenticator_InvalidKey(t *testing.T) { r.Header.Set("Authorization", "Bearer wrong-key") result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if !testIsEmpty(result.UserID) { + t.Errorf("expected empty value, got %v", result.UserID) + } + if !(core.Is(result.Error, ErrInvalidAPIKey)) { + t.Errorf("expected true") + } - assert.False(t, result.Valid) - assert.Empty(t, result.UserID) - assert.True(t, core.Is(result.Error, ErrInvalidAPIKey)) } func TestAPIKeyAuthenticator_MissingHeader(t *testing.T) { @@ -61,9 +78,13 @@ func TestAPIKeyAuthenticator_MissingHeader(t *testing.T) { // No Authorization header set result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if !(core.Is(result.Error, ErrMissingAuthHeader)) { + t.Errorf("expected true") + } - assert.False(t, result.Valid) - assert.True(t, core.Is(result.Error, ErrMissingAuthHeader)) } func TestAPIKeyAuthenticator_MalformedHeader(t *testing.T) { @@ -82,97 +103,1341 @@ func TestAPIKeyAuthenticator_MalformedHeader(t *testing.T) { {"empty bearer with spaces", "Bearer "}, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/ws", nil) - r.Header.Set("Authorization", tt.header) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", tt.header) + + result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if !(core.Is(result.Error, ErrMalformedAuthHeader)) { + t.Errorf("expected true") + } + + }) + } +} + +func TestAPIKeyAuthenticator_CaseInsensitiveScheme(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-abc": "user-1", + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "bearer key-abc") + + result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !(result.Authenticated) { + t.Errorf("expected true") + } + if !testEqual("user-1", result.UserID) { + t.Errorf("expected %v, got %v", "user-1", result.UserID) + } + +} + +func TestAPIKeyAuthenticator_SecondKey(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-abc": "user-1", + "key-def": "user-2", + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer key-def") + + result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !testEqual("user-2", result.UserID) { + t.Errorf("expected %v, got %v", "user-2", result.UserID) + } + +} + +func TestAPIKeyAuthenticator_CopiesInputMap(t *testing.T) { + keys := map[string]string{ + "key-abc": "user-1", + } + + auth := NewAPIKeyAuth(keys) + keys["key-abc"] = "user-2" + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer key-abc") + + result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !testEqual("user-1", result.UserID) { + t.Errorf("expected %v, got %v", "user-1", result.UserID) + } + +} + +func TestAPIKeyAuthenticator_SnapshotsInternalMap(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-abc": "user-1", + }) + + auth.Keys["key-abc"] = "user-2" + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer key-abc") + + result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !testEqual("user-1", result.UserID) { + t.Errorf("expected %v, got %v", "user-1", result.UserID) + } + +} + +func TestAPIKeyAuthenticator_ManualLiteral_DoesNotUseExportedKeys(t *testing.T) { + auth := &APIKeyAuthenticator{ + Keys: map[string]string{ + "key-abc": "user-1", + }, + } + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer key-abc") + + result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(result.Error, ErrInvalidAPIKey)) { + t.Errorf("expected true") + } + +} + +func TestAPIKeyAuthenticator_EmptyUserID_Bad(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-abc": "", + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer key-abc") + + result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(result.Error, ErrInvalidAPIKey)) { + t.Errorf("expected true") + } + +} + +func TestAPIKeyAuthenticator_NilMap_Good(t *testing.T) { + auth := NewAPIKeyAuth(nil) + if testIsNil(auth) { + t.Fatalf("expected non-nil value") + } + if !testIsEmpty(auth.Keys) { + t.Errorf("expected empty value, got %v", auth.Keys) + } + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer key-abc") + + result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(result.Error, ErrInvalidAPIKey)) { + t.Errorf("expected true") + } + +} + +// --------------------------------------------------------------------------- +// Unit tests — AuthenticatorFunc adapter +// --------------------------------------------------------------------------- + +func TestAuthenticatorFunc_Adapter(t *testing.T) { + called := false + fn := AuthenticatorFunc(func(r *http.Request) AuthResult { + called = true + return AuthResult{Valid: true, UserID: "func-user"} + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + result := fn.Authenticate(r) + if !(called) { + t.Errorf("expected true") + } + if !(result.Valid) { + t.Errorf("expected true") + } + if !testEqual("func-user", result.UserID) { + t.Errorf("expected %v, got %v", "func-user", result.UserID) + } + +} + +func TestAuthenticatorFunc_Rejection(t *testing.T) { + fn := AuthenticatorFunc(func(r *http.Request) AuthResult { + return AuthResult{Valid: false, Error: core.NewError("custom rejection")} + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + result := fn.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil || err.Error() != "custom rejection" { + t.Errorf("expected error %q, got %v", "custom rejection", err) + } + +} + +func TestAuthenticatorFunc_NilFunction(t *testing.T) { + var fn AuthenticatorFunc + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + result := fn.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "authenticator function is nil") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "authenticator function is nil") + } + +} + +func TestAuth_NewBearerTokenAuth_DefaultValidator_Bad(t *testing.T) { + auth := NewBearerTokenAuth() + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer token-123") + + result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "validate function is not configured") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "validate function is not configured") + } + +} + +func TestAuth_NewBearerTokenAuth_Bad(t *testing.T) { + auth := NewBearerTokenAuth() + + result := auth.Validate("") + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "validate function is not configured") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "validate function is not configured") + } + +} + +func TestAuth_NewBearerTokenAuth_Ugly(t *testing.T) { + auth := &BearerTokenAuth{} + + result := auth.Authenticate(httptest.NewRequest(http.MethodGet, "/ws", nil)) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "validate function is not configured") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "validate function is not configured") + } + +} + +func TestAuth_NewBearerTokenAuth_CustomValidator_Good(t *testing.T) { + auth := NewBearerTokenAuth(func(token string) AuthResult { + if token == "custom-token" { + return AuthResult{Authenticated: true, UserID: "custom-user"} + } + return AuthResult{Valid: false, Error: core.NewError("bad token")} + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer custom-token") + + result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !(result.Authenticated) { + t.Errorf("expected true") + } + if !testEqual("custom-user", result.UserID) { + t.Errorf("expected %v, got %v", "custom-user", result.UserID) + } + +} + +func TestAuth_authenticatedResult_Good(t *testing.T) { + claims := map[string]any{ + "role": "admin", + } + + result := authenticatedResult("user-123", claims) + if !(result.Valid) { + t.Errorf("expected true") + } + if !(result.Authenticated) { + t.Errorf("expected true") + } + if !testEqual("user-123", result.UserID) { + t.Errorf("expected %v, got %v", "user-123", result.UserID) + } + if !testEqual(claims, result.Claims) { + t.Errorf("expected %v, got %v", claims, result.Claims) + } + if err := result.Error; err != nil { + t.Errorf("expected no error, got %v", err) + } + +} + +func TestAuth_authenticatedResult_Bad(t *testing.T) { + result := authenticatedResult(" ", nil) + if result.Valid { + t.Errorf("expected false") + } + if result.Authenticated { + t.Errorf("expected false") + } + if !testIsEmpty(result.UserID) { + t.Errorf("expected empty value, got %v", result.UserID) + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(result.Error, ErrMissingUserID)) { + t.Errorf("expected true") + } + +} + +type authClaimNode struct { + Next *authClaimNode +} + +func deepAuthClaimNode(depth int) *authClaimNode { + root := &authClaimNode{} + current := root + for i := 0; i < depth; i++ { + next := &authClaimNode{} + current.Next = next + current = next + } + return root +} + +func deepAuthClaimsChain(depth int) map[string]any { + return map[string]any{ + "chain": deepAuthClaimNode(depth), + } +} + +func TestAuth_authenticatedResult_Ugly(t *testing.T) { + claims := deepAuthClaimsChain(maxClaimsCloneDepth + 64) + + result := authenticatedResult("user-123", claims) + if result.Valid { + t.Errorf("expected false") + } + if result.Authenticated { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(result.Error, ErrInvalidAuthClaims)) { + t.Errorf("expected true") + } + +} + +func TestAuth_finalizeAuthResult_Good(t *testing.T) { + claims := map[string]any{ + "role": "admin", + "scope": map[string]any{ + "channels": []string{"alpha", "beta"}, + }, + } + + result := finalizeAuthResult(AuthResult{ + Authenticated: true, + UserID: " user-123 ", + Claims: claims, + }) + if !(result.Valid) { + t.Fatalf("expected true") + } + if !(result.Authenticated) { + t.Fatalf("expected true") + } + if !testEqual("user-123", result.UserID) { + t.Errorf("expected %v, got %v", "user-123", result.UserID) + } + if !testEqual("admin", result.Claims["role"]) { + t.Errorf("expected %v, got %v", "admin", result.Claims["role"]) + } + + claims["role"] = "user" + claimsScope := claims["scope"].(map[string]any) + claimsScope["channels"] = []string{"gamma"} + if !testEqual("admin", result.Claims["role"]) { + t.Errorf("expected %v, got %v", "admin", result.Claims["role"]) + } + + resultScope := result.Claims["scope"].(map[string]any) + if !testEqual([]string{"alpha", "beta"}, resultScope["channels"]) { + t.Errorf("expected %v, got %v", []string{"alpha", "beta"}, resultScope["channels"]) + } + +} + +func TestAuth_finalizeAuthResult_Bad(t *testing.T) { + result := finalizeAuthResult(AuthResult{ + Valid: true, + UserID: " ", + }) + if result.Valid { + t.Errorf("expected false") + } + if result.Authenticated { + t.Errorf("expected false") + } + if !testIsEmpty(result.UserID) { + t.Errorf("expected empty value, got %v", result.UserID) + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(result.Error, ErrMissingUserID)) { + t.Errorf("expected true") + } + +} + +func TestAuth_finalizeAuthResult_Ugly(t *testing.T) { + result := finalizeAuthResult(AuthResult{ + Valid: true, + UserID: "user-123", + Claims: deepAuthClaimsChain(maxClaimsCloneDepth + 64), + }) + if result.Valid { + t.Errorf("expected false") + } + if result.Authenticated { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(result.Error, ErrInvalidAuthClaims)) { + t.Errorf("expected true") + } + +} + +func TestAuth_NewBearerTokenAuth_NilValidator_Bad(t *testing.T) { + auth := NewBearerTokenAuth(nil) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer token-123") + + result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "validate function is not configured") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "validate function is not configured") + } + +} + +func TestAuth_NewQueryTokenAuth_DefaultValidator_ValidateCall_Bad(t *testing.T) { + auth := NewQueryTokenAuth() + + r := httptest.NewRequest(http.MethodGet, "/ws?token=query-123", nil) + + result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "validate function is not configured") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "validate function is not configured") + } + +} + +func TestAuth_NewQueryTokenAuth_Bad(t *testing.T) { + auth := NewQueryTokenAuth() + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + + result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "missing token query parameter") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "missing token query parameter") + } + +} + +func TestAuth_NewQueryTokenAuth_DefaultValidator_ValidateEmpty_Bad(t *testing.T) { + auth := NewQueryTokenAuth() + + result := auth.Validate("") + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "validate function is not configured") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "validate function is not configured") + } + +} + +func TestAuth_NewQueryTokenAuth_Ugly(t *testing.T) { + auth := &QueryTokenAuth{} + + result := auth.Authenticate(httptest.NewRequest(http.MethodGet, "/ws?token=abc", nil)) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "validate function is not configured") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "validate function is not configured") + } + +} + +func TestAuth_NewQueryTokenAuth_CustomValidator_Good(t *testing.T) { + auth := NewQueryTokenAuth(func(token string) AuthResult { + if token == "browser-token" { + return AuthResult{Authenticated: true, UserID: "browser-user"} + } + return AuthResult{Valid: false, Error: core.NewError("bad token")} + }) + + r := httptest.NewRequest(http.MethodGet, "/ws?token=browser-token", nil) + + result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !(result.Authenticated) { + t.Errorf("expected true") + } + if !testEqual("browser-user", result.UserID) { + t.Errorf("expected %v, got %v", "browser-user", result.UserID) + } + +} + +func TestAuth_NewQueryTokenAuth_NilValidator_Bad(t *testing.T) { + auth := NewQueryTokenAuth(nil) + + r := httptest.NewRequest(http.MethodGet, "/ws?token=query-123", nil) + + result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "validate function is not configured") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "validate function is not configured") + } + +} + +func TestAuth_CustomValidator_EmptyUserID_Bad(t *testing.T) { + t.Run("bearer", func(t *testing.T) { + auth := NewBearerTokenAuth(func(token string) AuthResult { + return AuthResult{Valid: true, UserID: ""} + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer token-123") + + result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(result.Error, ErrMissingUserID)) { + t.Errorf("expected true") + } + + }) + + t.Run("query", func(t *testing.T) { + auth := NewQueryTokenAuth(func(token string) AuthResult { + return AuthResult{Authenticated: true} + }) + + r := httptest.NewRequest(http.MethodGet, "/ws?token=query-123", nil) + + result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(result.Error, ErrMissingUserID)) { + t.Errorf("expected true") + } + + }) +} + +func TestAuth_ClaimsAreCloned(t *testing.T) { + claims := map[string]any{ + "role": "admin", + "scope": map[string]any{ + "channels": []string{"alpha", "beta"}, + }, + } + + auth := AuthenticatorFunc(func(r *http.Request) AuthResult { + return AuthResult{Valid: true, UserID: "user-123", Claims: claims} + }) + + result := auth.Authenticate(httptest.NewRequest(http.MethodGet, "/ws", nil)) + if !(result.Valid) { + t.Fatalf("expected true") + } + if testIsNil(result.Claims) { + t.Fatalf("expected non-nil value") + } + + claims["role"] = "user" + claimsScope := claims["scope"].(map[string]any) + claimsScope["channels"] = []string{"gamma"} + if !testEqual("admin", result.Claims["role"]) { + t.Errorf("expected %v, got %v", "admin", result.Claims["role"]) + } + + resultScope := result.Claims["scope"].(map[string]any) + if !testEqual([]string{"alpha", "beta"}, resultScope["channels"]) { + t.Errorf("expected %v, got %v", []string{"alpha", "beta"}, resultScope["channels"]) + } + +} + +func TestAuth_ClaimsAreCloneSafeForCycles(t *testing.T) { + claims := map[string]any{} + claims["self"] = claims + + auth := AuthenticatorFunc(func(r *http.Request) AuthResult { + return AuthResult{Valid: true, UserID: "user-123", Claims: claims} + }) + + result := auth.Authenticate(httptest.NewRequest(http.MethodGet, "/ws", nil)) + if !(result.Valid) { + t.Fatalf("expected true") + } + if testIsNil(result.Claims) { + t.Fatalf("expected non-nil value") + } + + clonedSelf, ok := result.Claims["self"].(map[string]any) + if !(ok) { + t.Fatalf("expected true") + } + if testEqual(reflect.ValueOf(claims).Pointer(), reflect.ValueOf(clonedSelf).Pointer()) { + t.Errorf("expected values to differ: %v", reflect.ValueOf(clonedSelf).Pointer()) + } + +} + +func TestAuth_ClaimsRejectUnsupportedKinds(t *testing.T) { + auth := AuthenticatorFunc(func(r *http.Request) AuthResult { + return AuthResult{ + Valid: true, + UserID: "user-123", + Claims: map[string]any{ + "stream": make(chan int), + }, + } + }) + + result := auth.Authenticate(httptest.NewRequest(http.MethodGet, "/ws", nil)) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(result.Error, ErrInvalidAuthClaims)) { + t.Errorf("expected true") + } + +} + +func TestAuth_deepCloneValueWithState_Good(t *testing.T) { + type secretClaim struct { + Name string + bytes []byte + Next *secretClaim + } + + original := &secretClaim{ + Name: "alice", + bytes: []byte{1, 2, 3}, + } + original.Next = original + + clonedValue, ok := deepCloneValueWithState(reflect.ValueOf(original), make(map[uintptr]reflect.Value), 0) + if !(ok) { + t.Fatalf("expected true") + } + + clone := clonedValue.(*secretClaim) + if testSame(original, clone) { + t.Fatalf("expected different references") + } + if testIsNil(clone.Next) { + t.Fatalf("expected non-nil value") + } + if !testSame(clone, clone.Next) { + t.Errorf("expected same reference") + } + if !testEqual([]byte{1, 2, 3}, clone.bytes) { + t.Errorf("expected %v, got %v", []byte{1, 2, 3}, clone.bytes) + } + + original.bytes[0] = 9 + if !testEqual([]byte{1, 2, 3}, clone.bytes) { + t.Errorf("expected %v, got %v", []byte{1, 2, 3}, clone.bytes) + } + + cyclicMap := map[string]any{} + cyclicMap["self"] = cyclicMap + clonedMap, ok := deepCloneValueWithState(reflect.ValueOf(cyclicMap), make(map[uintptr]reflect.Value), 0) + if !(ok) { + t.Fatalf("expected true") + } + if testIsNil(clonedMap) { + t.Fatalf("expected non-nil value") + } + + cyclicSlice := make([]any, 1) + cyclicSlice[0] = cyclicSlice + clonedSlice, ok := deepCloneValueWithState(reflect.ValueOf(cyclicSlice), make(map[uintptr]reflect.Value), 0) + if !(ok) { + t.Fatalf("expected true") + } + if testIsNil(clonedSlice) { + t.Fatalf("expected non-nil value") + } + +} + +func TestAuth_deepCloneValueWithState_Bad(t *testing.T) { + value := reflect.ValueOf(struct { + secret int + }{secret: 123}).Field(0) + + cloned, ok := deepCloneValueWithState(value, make(map[uintptr]reflect.Value), 0) + if ok { + t.Errorf("expected false") + } + if !testIsNil(cloned) { + t.Errorf("expected nil, got %T", cloned) + } + +} + +func TestAuth_deepCloneValueWithState_Ugly(t *testing.T) { + cloned, ok := deepCloneValueWithState(reflect.ValueOf(deepAuthClaimNode(maxClaimsCloneDepth+1)), make(map[uintptr]reflect.Value), 0) + if ok { + t.Errorf("expected false") + } + if !testIsNil(cloned) { + t.Errorf("expected nil, got %T", cloned) + } + +} + +func TestAuth_valueInterface_Good(t *testing.T) { + type claim struct { + secret int + } + + value := reflect.ValueOf(&claim{secret: 7}).Elem().FieldByName("secret") + if !testEqual(7, valueInterface(value)) { + t.Errorf("expected %v, got %v", 7, valueInterface(value)) + } + +} + +func TestAuth_valueInterface_Bad(t *testing.T) { + if !testIsNil(valueInterface(reflect.Value{})) { + t.Errorf("expected nil, got %T", valueInterface(reflect.Value{})) + } + +} + +func TestAuth_valueInterface_Ugly(t *testing.T) { + type claim struct { + secret int + } + if !testIsNil(valueInterface(reflect.ValueOf(claim{secret: 7}).FieldByName("secret"))) { + t.Errorf("expected nil, got %T", valueInterface(reflect.ValueOf(claim{secret: 7}).FieldByName("secret"))) + } + +} + +func TestAuth_setReflectValue_Good(t *testing.T) { + type claim struct { + Value int + } + + original := &claim{} + field := reflect.ValueOf(original).Elem().FieldByName("Value") + if !(setReflectValue(field, reflect.ValueOf(7))) { + t.Errorf("expected true") + } + if !testEqual(7, original.Value) { + t.Errorf("expected %v, got %v", 7, original.Value) + } + +} + +func TestAuth_setReflectValue_Bad(t *testing.T) { + if setReflectValue(reflect.Value{}, reflect.ValueOf(7)) { + t.Errorf("expected false") + } + +} + +func TestAuth_setReflectValue_Ugly(t *testing.T) { + type claim struct { + secret int + } + + original := &claim{} + field := reflect.ValueOf(original).Elem().FieldByName("secret") + if !(setReflectValue(field, reflect.ValueOf(7))) { + t.Errorf("expected true") + } + if !testEqual(7, original.secret) { + t.Errorf("expected %v, got %v", 7, original.secret) + } + +} + +func TestAuth_assignClonedValue_Good(t *testing.T) { + type alias int + + var dst alias + if !(assignClonedValue(reflect.ValueOf(&dst).Elem(), int64(7))) { + t.Errorf("expected true") + } + if !testEqual(alias(7), dst) { + t.Errorf("expected %v, got %v", alias(7), dst) + } + +} + +func TestAuth_assignClonedValue_Bad(t *testing.T) { + var dst int + if assignClonedValue(reflect.Value{}, 7) { + t.Errorf("expected false") + } + if assignClonedValue(reflect.ValueOf(&dst).Elem(), struct { + }{}) { + t.Errorf("expected false") + } + +} + +func TestAuth_assignClonedValue_Ugly(t *testing.T) { + var dst int + if !(assignClonedValue(reflect.ValueOf(&dst).Elem(), nil)) { + t.Errorf("expected true") + } + if !testIsZero(dst) { + t.Errorf("expected zero value, got %v", dst) + } + +} + +func TestAuth_cloneStringMap_Good(t *testing.T) { + original := map[string]string{ + "key-abc": "user-1", + } + + clone := cloneStringMap(original) + if testIsNil(clone) { + t.Fatalf("expected non-nil value") + } + if !testEqual(original, clone) { + t.Errorf("expected %v, got %v", original, clone) + } + + original["key-abc"] = "user-2" + if !testEqual("user-1", clone["key-abc"]) { + t.Errorf("expected %v, got %v", "user-1", clone["key-abc"]) + } + +} + +func TestAuth_cloneStringMap_Bad(t *testing.T) { + if !testIsNil(cloneStringMap(nil)) { + t.Errorf("expected nil, got %T", cloneStringMap(nil)) + } + +} + +func TestAuth_cloneStringMap_Ugly(t *testing.T) { + if !testIsNil(cloneStringMap(map[string]string{})) { + t.Errorf("expected nil, got %T", cloneStringMap(map[string]string{})) + } + +} + +func TestAuth_deepCloneValue_Good(t *testing.T) { + type nestedClaim struct { + Name string + Tags []string + Bytes []byte + Meta map[string]any + Counts [2]int + Child *struct { + Enabled bool + Flags []string + } + Optional *struct { + Label string + } + } + + original := nestedClaim{ + Name: "alice", + Tags: []string{"alpha", "beta"}, + Bytes: []byte{1, 2, 3}, + Meta: map[string]any{"channels": []string{"one", "two"}}, + Counts: [2]int{7, 9}, + Child: &struct { + Enabled bool + Flags []string + }{ + Enabled: true, + Flags: []string{"root", "admin"}, + }, + Optional: nil, + } + + cloned := deepCloneValue(reflect.ValueOf(original)) + if testIsNil(cloned) { + t.Fatalf("expected non-nil value") + } + + clone := cloned.(nestedClaim) + if testSame(original.Child, clone.Child) { + t.Fatalf("expected different references") + } + if !testEqual(original, clone) { + t.Errorf("expected %v, got %v", original, clone) + } + + original.Tags[0] = "mutated" + original.Bytes[0] = 9 + original.Meta["channels"] = []string{"changed"} + original.Counts[0] = 42 + original.Child.Enabled = false + original.Child.Flags[0] = "guest" + if !testEqual([]string{"alpha", "beta"}, clone.Tags) { + t.Errorf("expected %v, got %v", []string{"alpha", "beta"}, clone.Tags) + } + if !testEqual([]byte{1, 2, 3}, clone.Bytes) { + t.Errorf("expected %v, got %v", []byte{1, 2, 3}, clone.Bytes) + } + if !testEqual([]string{"one", "two"}, clone.Meta["channels"]) { + t.Errorf("expected %v, got %v", []string{"one", "two"}, clone.Meta["channels"]) + } + if !testEqual([2]int{7, 9}, clone.Counts) { + t.Errorf("expected %v, got %v", [2]int{7, 9}, clone.Counts) + } + if !(clone.Child.Enabled) { + t.Errorf("expected true") + } + if !testEqual([]string{"root", "admin"}, clone.Child.Flags) { + t.Errorf("expected %v, got %v", []string{"root", "admin"}, clone.Child.Flags) + } + if !testIsNil(clone.Optional) { + t.Errorf("expected nil, got %T", clone.Optional) + } + +} + +func TestAuth_ClaimsDeepClone_UnexportedMutableFields(t *testing.T) { + type opaqueClaim struct { + Name string + roles []string + meta map[string]any + } + + original := &opaqueClaim{ + Name: "alice", + roles: []string{"admin", "ops"}, + meta: map[string]any{ + "channels": []string{"alpha", "beta"}, + }, + } + + auth := AuthenticatorFunc(func(r *http.Request) AuthResult { + return AuthResult{Valid: true, UserID: "user-123", Claims: map[string]any{"opaque": original}} + }) + + result := auth.Authenticate(httptest.NewRequest(http.MethodGet, "/ws", nil)) + if !(result.Valid) { + t.Fatalf("expected true") + } + + cloned, ok := result.Claims["opaque"].(*opaqueClaim) + if !(ok) { + t.Fatalf("expected true") + } + if testSame(original, cloned) { + t.Fatalf("expected different references") + } + + original.roles[0] = "viewer" + original.meta["channels"] = []string{"gamma"} + if !testEqual([]string{"admin", "ops"}, cloned.roles) { + t.Errorf("expected %v, got %v", []string{"admin", "ops"}, cloned.roles) + } + if !testEqual([]string{"alpha", "beta"}, cloned.meta["channels"]) { + t.Errorf("expected %v, got %v", []string{"alpha", "beta"}, cloned.meta["channels"]) + } + +} + +func TestAuth_cloneClaimsValue_Good(t *testing.T) { + type opaqueClaim struct { + Name string + roles []string + meta map[string]any + } + + original := &opaqueClaim{ + Name: "alice", + roles: []string{"admin", "ops"}, + meta: map[string]any{ + "channels": []string{"alpha", "beta"}, + }, + } + + claims := map[string]any{ + "profile": original, + "self": nil, + } + claims["self"] = claims + + clonedValue, ok := cloneClaimsValue(reflect.ValueOf(claims), make(map[uintptr]reflect.Value), 0) + if !(ok) { + t.Fatalf("expected true") + } + + cloned, ok := clonedValue.(map[string]any) + if !(ok) { + t.Fatalf("expected true") + } + if testEqual(reflect.ValueOf(claims).Pointer(), reflect.ValueOf(cloned).Pointer()) { + t.Errorf("expected values to differ: %v", reflect.ValueOf(cloned).Pointer()) + } + + clonedProfile, ok := cloned["profile"].(*opaqueClaim) + if !(ok) { + t.Fatalf("expected true") + } + if testSame(original, clonedProfile) { + t.Fatalf("expected different references") + } + + clonedSelf, ok := cloned["self"].(map[string]any) + if !(ok) { + t.Fatalf("expected true") + } + if testEqual(reflect.ValueOf(claims).Pointer(), reflect.ValueOf(clonedSelf).Pointer()) { + t.Errorf("expected values to differ: %v", reflect.ValueOf(clonedSelf).Pointer()) + } + if !testEqual("alice", clonedProfile.Name) { + t.Errorf("expected %v, got %v", "alice", clonedProfile.Name) + } + if !testEqual([]string{"admin", "ops"}, clonedProfile.roles) { + t.Errorf("expected %v, got %v", []string{"admin", "ops"}, clonedProfile.roles) + } + if !testEqual([]string{"alpha", "beta"}, clonedProfile.meta["channels"]) { + t.Errorf("expected %v, got %v", []string{"alpha", "beta"}, clonedProfile.meta["channels"]) + } + + original.roles[0] = "viewer" + original.meta["channels"] = []string{"gamma"} + if !testEqual([]string{"admin", "ops"}, clonedProfile.roles) { + t.Errorf("expected %v, got %v", []string{"admin", "ops"}, clonedProfile.roles) + } + if !testEqual([]string{"alpha", "beta"}, clonedProfile.meta["channels"]) { + t.Errorf("expected %v, got %v", []string{"alpha", "beta"}, clonedProfile.meta["channels"]) + } + +} + +func TestAuth_cloneClaimsValue_Bad(t *testing.T) { + tests := []struct { + name string + value reflect.Value + }{ + {name: "unsupported kind", value: reflect.ValueOf(make(chan int))}, + {name: "unsupported func", value: reflect.ValueOf(func() {})}, + { + name: "unaddressable unexported field", + value: reflect.ValueOf(struct{ secret int }{secret: 1}).Field(0), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cloned, ok := cloneClaimsValue(tt.value, make(map[uintptr]reflect.Value), 0) + if ok { + t.Errorf("expected false") + } + if !testIsNil(cloned) { + t.Errorf("expected nil, got %T", cloned) + } + + }) + } +} + +func TestAuth_cloneClaimsValue_Ugly(t *testing.T) { + cloned, ok := cloneClaimsValue(reflect.ValueOf(deepAuthClaimNode(maxClaimsCloneDepth+1)), make(map[uintptr]reflect.Value), 0) + if ok { + t.Errorf("expected false") + } + if !testIsNil(cloned) { + t.Errorf("expected nil, got %T", cloned) + } + +} + +func TestAuth_deepCloneValue_Bad(t *testing.T) { + var nilSlice []string + var nilMap map[string]int + var nilPtr *int + if !testIsNil(deepCloneValue(reflect.ValueOf(nilSlice))) { + t.Errorf("expected nil, got %T", deepCloneValue(reflect.ValueOf(nilSlice))) + } + if !testIsNil(deepCloneValue(reflect.ValueOf(nilMap))) { + t.Errorf("expected nil, got %T", deepCloneValue(reflect.ValueOf(nilMap))) + } + if !testIsNil(deepCloneValue(reflect.ValueOf(nilPtr))) { + t.Errorf("expected nil, got %T", deepCloneValue(reflect.ValueOf(nilPtr))) + } + if !testIsNil(deepCloneValue(reflect.Value{})) { + t.Errorf("expected nil, got %T", deepCloneValue(reflect.Value{})) + } + if !testEqual(42, deepCloneValue(reflect.ValueOf(42))) { + t.Errorf("expected %v, got %v", 42, deepCloneValue(reflect.ValueOf(42))) + } + +} + +func TestAuth_deepCloneValue_Ugly(t *testing.T) { + ch := make(chan int, 1) + fn := func() {} + if !testEqual(ch, deepCloneValue(reflect.ValueOf(ch))) { + t.Errorf("expected %v, got %v", ch, deepCloneValue(reflect.ValueOf(ch))) + } + testNotPanics(t, func() { + _ = deepCloneValue(reflect.ValueOf(fn)) + }) + +} - result := auth.Authenticate(r) +func TestAuth_UserIDIsTrimmedOnSuccess(t *testing.T) { + auth := AuthenticatorFunc(func(r *http.Request) AuthResult { + return AuthResult{ + Valid: true, + UserID: " user-123 ", + } + }) - assert.False(t, result.Valid) - assert.True(t, core.Is(result.Error, ErrMalformedAuthHeader)) - }) + result := auth.Authenticate(httptest.NewRequest(http.MethodGet, "/ws", nil)) + if !(result.Valid) { + t.Fatalf("expected true") } + if !testEqual("user-123", result.UserID) { + t.Errorf("expected %v, got %v", "user-123", result.UserID) + } + } -func TestAPIKeyAuthenticator_CaseInsensitiveScheme(t *testing.T) { - auth := NewAPIKeyAuth(map[string]string{ - "key-abc": "user-1", - }) +func TestAuth_Authenticate_NilReceivers_Ugly(t *testing.T) { + t.Run("api key", func(t *testing.T) { + var auth *APIKeyAuthenticator - r := httptest.NewRequest(http.MethodGet, "/ws", nil) - r.Header.Set("Authorization", "bearer key-abc") + result := auth.Authenticate(httptest.NewRequest(http.MethodGet, "/ws", nil)) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "authenticator is nil") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "authenticator is nil") + } - result := auth.Authenticate(r) + }) - assert.True(t, result.Valid) - assert.Equal(t, "user-1", result.UserID) -} + t.Run("bearer", func(t *testing.T) { + var auth *BearerTokenAuth + + result := auth.Authenticate(httptest.NewRequest(http.MethodGet, "/ws", nil)) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "authenticator is nil") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "authenticator is nil") + } -func TestAPIKeyAuthenticator_SecondKey(t *testing.T) { - auth := NewAPIKeyAuth(map[string]string{ - "key-abc": "user-1", - "key-def": "user-2", }) - r := httptest.NewRequest(http.MethodGet, "/ws", nil) - r.Header.Set("Authorization", "Bearer key-def") + t.Run("query", func(t *testing.T) { + var auth *QueryTokenAuth - result := auth.Authenticate(r) + result := auth.Authenticate(httptest.NewRequest(http.MethodGet, "/ws?token=abc", nil)) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "authenticator is nil") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "authenticator is nil") + } - assert.True(t, result.Valid) - assert.Equal(t, "user-2", result.UserID) + }) } -// --------------------------------------------------------------------------- -// Unit tests — AuthenticatorFunc adapter -// --------------------------------------------------------------------------- +func TestAuth_Authenticate_NilRequest_Ugly(t *testing.T) { + t.Run("api key", func(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{"key": "user"}) + + result := auth.Authenticate(nil) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "request is nil") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "request is nil") + } -func TestAuthenticatorFunc_Adapter(t *testing.T) { - called := false - fn := AuthenticatorFunc(func(r *http.Request) AuthResult { - called = true - return AuthResult{Valid: true, UserID: "func-user"} }) - r := httptest.NewRequest(http.MethodGet, "/ws", nil) - result := fn.Authenticate(r) + t.Run("bearer", func(t *testing.T) { + auth := NewBearerTokenAuth() - assert.True(t, called) - assert.True(t, result.Valid) - assert.Equal(t, "func-user", result.UserID) -} + result := auth.Authenticate(nil) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "request is nil") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "request is nil") + } -func TestAuthenticatorFunc_Rejection(t *testing.T) { - fn := AuthenticatorFunc(func(r *http.Request) AuthResult { - return AuthResult{Valid: false, Error: core.NewError("custom rejection")} }) - r := httptest.NewRequest(http.MethodGet, "/ws", nil) - result := fn.Authenticate(r) - - assert.False(t, result.Valid) - assert.EqualError(t, result.Error, "custom rejection") -} + t.Run("query", func(t *testing.T) { + auth := NewQueryTokenAuth() -func TestAuthenticatorFunc_NilFunction(t *testing.T) { - var fn AuthenticatorFunc + result := auth.Authenticate(nil) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "request is nil") { - r := httptest.NewRequest(http.MethodGet, "/ws", nil) - result := fn.Authenticate(r) + // --------------------------------------------------------------------------- + // Unit tests — nil Authenticator (backward compat) + // --------------------------------------------------------------------------- + t.Errorf("expected %v to contain %v", result.Error.Error(), "request is nil") + } - assert.False(t, result.Valid) - require.Error(t, result.Error) - assert.Contains(t, result.Error.Error(), "authenticator function is nil") + }) } -// --------------------------------------------------------------------------- -// Unit tests — nil Authenticator (backward compat) -// --------------------------------------------------------------------------- - func TestNilAuthenticator_AllConnectionsAccepted(t *testing.T) { - hub := NewHub() // No authenticator set - assert.Nil(t, hub.config.Authenticator) + hub := NewHub() + if // No authenticator set + !testIsNil(hub.config.Authenticator) { + t.Errorf("expected nil, got %T", hub.config.Authenticator) + } + } // --------------------------------------------------------------------------- @@ -185,6 +1450,11 @@ func startAuthTestHub(t *testing.T, config HubConfig) (*httptest.Server, *Hub, c hub := NewHubWithConfig(config) ctx, cancel := context.WithCancel(context.Background()) go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } server := httptest.NewServer(hub.Handler()) t.Cleanup(func() { @@ -219,20 +1489,33 @@ func TestIntegration_AuthenticatedConnect(t *testing.T) { header.Set("Authorization", "Bearer valid-key") conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header) - require.NoError(t, err) - defer conn.Close() - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + if !testEqual(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Errorf( + + // Give the hub a moment to process registration + "expected %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } - // Give the hub a moment to process registration time.Sleep(50 * time.Millisecond) mu.Lock() client := connectedClient mu.Unlock() + if testIsNil(client) { + t.Fatalf("expected non-nil value") + } + if !testEqual("user-42", client.UserID) { + t.Errorf("expected %v, got %v", "user-42", client.UserID) + } + if !testEqual("api_key", client.Claims["auth_method"]) { + t.Errorf("expected %v, got %v", "api_key", client.Claims["auth_method"]) + } - require.NotNil(t, client, "OnConnect should have fired") - assert.Equal(t, "user-42", client.UserID) - assert.Equal(t, "api_key", client.Claims["auth_method"]) } func TestIntegration_RejectedConnect_InvalidKey(t *testing.T) { @@ -249,12 +1532,18 @@ func TestIntegration_RejectedConnect_InvalidKey(t *testing.T) { conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header) if conn != nil { - conn.Close() + _ = conn.Close() + } + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testEqual(http.StatusUnauthorized, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusUnauthorized, resp.StatusCode) + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) } - require.Error(t, err) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Equal(t, 0, hub.ClientCount()) } func TestIntegration_RejectedConnect_NoAuthHeader(t *testing.T) { @@ -269,25 +1558,42 @@ func TestIntegration_RejectedConnect_NoAuthHeader(t *testing.T) { // No Authorization header conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil) if conn != nil { - conn.Close() + _ = conn.Close() + } + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testEqual(http.StatusUnauthorized, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusUnauthorized, resp.StatusCode) + } + if !testEqual(0, hub.ClientCount( + + // No authenticator — all connections should be accepted + )) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) } - require.Error(t, err) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Equal(t, 0, hub.ClientCount()) } func TestIntegration_NilAuthenticator_BackwardCompat(t *testing.T) { - // No authenticator — all connections should be accepted + server, hub, _ := startAuthTestHub(t, HubConfig{}) conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil) - require.NoError(t, err) - defer conn.Close() - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + if !testEqual(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ClientCount()) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 1, hub.ClientCount()) + } + } func TestIntegration_OnAuthFailure_Callback(t *testing.T) { @@ -316,7 +1622,7 @@ func TestIntegration_OnAuthFailure_Callback(t *testing.T) { conn, _, _ := websocket.DefaultDialer.Dial(authWSURL(server), header) if conn != nil { - conn.Close() + _ = conn.Close() } // Give callback time to execute @@ -324,11 +1630,19 @@ func TestIntegration_OnAuthFailure_Callback(t *testing.T) { failureMu.Lock() defer failureMu.Unlock() + if !(failureCalled) { + t.Errorf("expected true") + } + if failureResult.Valid { + t.Errorf("expected false") + } + if !(core.Is(failureResult.Error, ErrInvalidAPIKey)) { + t.Errorf("expected true") + } + if testIsNil(failureRequest) { + t.Errorf("expected non-nil value") + } - assert.True(t, failureCalled, "OnAuthFailure should have been called") - assert.False(t, failureResult.Valid) - assert.True(t, core.Is(failureResult.Error, ErrInvalidAPIKey)) - assert.NotNil(t, failureRequest) } func TestIntegration_MultipleClients_DifferentKeys(t *testing.T) { @@ -365,38 +1679,55 @@ func TestIntegration_MultipleClients_DifferentKeys(t *testing.T) { header.Set("Authorization", "Bearer "+k.key) conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header) - require.NoError(t, err, "key %s should connect", k.key) - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + conns = append(conns, conn) } defer func() { for _, c := range conns { - c.Close() + testClose(t, c.Close) } }() time.Sleep(100 * time.Millisecond) - - assert.Equal(t, 3, hub.ClientCount()) + if !testEqual(3, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 3, hub.ClientCount()) + } mu.Lock() defer mu.Unlock() for _, k := range keys { client, ok := connectedClients[k.userID] - require.True(t, ok, "should have client for %s", k.userID) - assert.Equal(t, k.userID, client.UserID) + if !(ok) { + t.Fatalf("expected true") + } + if !testEqual(k.userID, client.UserID) { + t.Errorf("expected %v, got %v", k.userID, client.UserID) + } + } } func TestIntegration_AuthenticatorFunc_WithHub(t *testing.T) { // Use AuthenticatorFunc as the hub's authenticator + claims := map[string]any{ + "source": "query_param", + "scope": map[string]any{ + "channels": []string{"alpha", "beta"}, + }, + } fn := AuthenticatorFunc(func(r *http.Request) AuthResult { token := r.URL.Query().Get("token") if token == "magic" { return AuthResult{ Valid: true, UserID: "magic-user", - Claims: map[string]any{"source": "query_param"}, + Claims: claims, } } return AuthResult{Valid: false, Error: core.NewError("bad token")} @@ -408,19 +1739,55 @@ func TestIntegration_AuthenticatorFunc_WithHub(t *testing.T) { // Valid token via query parameter conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server)+"?token=magic", nil) - require.NoError(t, err) - defer conn.Close() - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + if !testEqual(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ClientCount()) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 1, hub.ClientCount()) + } + + claims["source"] = "mutated" + claimsScope := claims["scope"].(map[string]any) + claimsScope["channels"] = []string{"gamma"} + + hub.mu.RLock() + var attachedClient *Client + for client := range hub.clients { + attachedClient = client + break + } + hub.mu.RUnlock() + if testIsNil(attachedClient) { + t.Fatalf("expected non-nil value") + } + if !testEqual("magic-user", attachedClient.UserID) { + t.Errorf("expected %v, got %v", "magic-user", attachedClient.UserID) + } + if !testEqual("query_param", attachedClient.Claims["source"]) { + t.Errorf("expected %v, got %v", "query_param", attachedClient.Claims["source"]) + } - // Invalid token + scope := attachedClient.Claims["scope"].(map[string]any) + if !testEqual([]string{"alpha", "beta"}, scope["channels"]) { + t.Errorf("expected %v, got %v", []string{"alpha", "beta"}, scope["channels"]) + } + + // Invalid token. conn2, resp2, _ := websocket.DefaultDialer.Dial(authWSURL(server)+"?token=wrong", nil) if conn2 != nil { - conn2.Close() + _ = conn2.Close() } - assert.Equal(t, http.StatusUnauthorized, resp2.StatusCode) + if !testEqual(http.StatusUnauthorized, resp2.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusUnauthorized, resp2.StatusCode) + } + } func TestIntegration_AuthenticatorFuncNil_WithHub(t *testing.T) { @@ -432,12 +1799,18 @@ func TestIntegration_AuthenticatorFuncNil_WithHub(t *testing.T) { conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil) if conn != nil { - conn.Close() + _ = conn.Close() + } + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testEqual(http.StatusUnauthorized, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusUnauthorized, resp.StatusCode) + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) } - require.Error(t, err) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Equal(t, 0, hub.ClientCount()) } func TestIntegration_AuthenticatorFuncPanic_WithHub(t *testing.T) { @@ -458,18 +1831,30 @@ func TestIntegration_AuthenticatorFuncPanic_WithHub(t *testing.T) { conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil) if conn != nil { - conn.Close() + _ = conn.Close() + } + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testEqual(http.StatusUnauthorized, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusUnauthorized, resp.StatusCode) + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) } - - require.Error(t, err) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Equal(t, 0, hub.ClientCount()) select { case result := <-failureCalled: - assert.False(t, result.Valid) - require.Error(t, result.Error) - assert.Contains(t, result.Error.Error(), "authenticator panicked") + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "authenticator panicked") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "authenticator panicked") + } + case <-time.After(time.Second): t.Fatal("OnAuthFailure should be called when authenticator panics") } @@ -489,29 +1874,45 @@ func TestIntegration_AuthenticatedClient_ReceivesMessages(t *testing.T) { header.Set("Authorization", "Bearer key-1") conn, _, err := websocket.DefaultDialer.Dial(authWSURL(server), header) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) // Broadcast a message err = hub.Broadcast(Message{Type: TypeEvent, Data: "hello"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - // Read it - conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("expected no error, got %v", err) + } _, data, err := conn.ReadMessage() - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } var msg Message - require.True(t, core.JSONUnmarshal(data, &msg).OK) - assert.Equal(t, TypeEvent, msg.Type) - assert.Equal(t, "hello", msg.Data) -} + if !(core.JSONUnmarshal(data, &msg).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeEvent, msg.Type) { + t.Errorf("expected %v, got %v", TypeEvent, -// --------------------------------------------------------------------------- -// Unit tests — BearerTokenAuth -// --------------------------------------------------------------------------- + // --------------------------------------------------------------------------- + // Unit tests — BearerTokenAuth + // --------------------------------------------------------------------------- + msg.Type) + } + if !testEqual("hello", msg.Data) { + t.Errorf("expected %v, got %v", "hello", msg.Data) + } + +} func TestBearerTokenAuth_ValidToken_Good(t *testing.T) { auth := &BearerTokenAuth{ @@ -531,11 +1932,22 @@ func TestBearerTokenAuth_ValidToken_Good(t *testing.T) { r.Header.Set("Authorization", "Bearer jwt-abc-123") result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !(result.Authenticated) { + t.Errorf("expected true") + } + if !testEqual("user-42", result.UserID) { + t.Errorf("expected %v, got %v", "user-42", result.UserID) + } + if !testEqual("admin", result.Claims["role"]) { + t.Errorf("expected %v, got %v", "admin", result.Claims["role"]) + } + if !testEqual("jwt", result.Claims["auth_method"]) { + t.Errorf("expected %v, got %v", "jwt", result.Claims["auth_method"]) + } - assert.True(t, result.Valid) - assert.Equal(t, "user-42", result.UserID) - assert.Equal(t, "admin", result.Claims["role"]) - assert.Equal(t, "jwt", result.Claims["auth_method"]) } func TestBearerTokenAuth_InvalidToken_Bad(t *testing.T) { @@ -549,9 +1961,13 @@ func TestBearerTokenAuth_InvalidToken_Bad(t *testing.T) { r.Header.Set("Authorization", "Bearer expired-token") result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil || err.Error() != "token expired" { + t.Errorf("expected error %q, got %v", "token expired", err) + } - assert.False(t, result.Valid) - assert.EqualError(t, result.Error, "token expired") } func TestBearerTokenAuth_MissingHeader_Bad(t *testing.T) { @@ -564,9 +1980,13 @@ func TestBearerTokenAuth_MissingHeader_Bad(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/ws", nil) result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if !(core.Is(result.Error, ErrMissingAuthHeader)) { + t.Errorf("expected true") + } - assert.False(t, result.Valid) - assert.True(t, core.Is(result.Error, ErrMissingAuthHeader)) } func TestBearerTokenAuth_MalformedHeader_Bad(t *testing.T) { @@ -593,9 +2013,13 @@ func TestBearerTokenAuth_MalformedHeader_Bad(t *testing.T) { r.Header.Set("Authorization", tt.header) result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if !(core.Is(result.Error, ErrMalformedAuthHeader)) { + t.Errorf("expected true") + } - assert.False(t, result.Valid) - assert.True(t, core.Is(result.Error, ErrMalformedAuthHeader)) }) } } @@ -611,14 +2035,18 @@ func TestBearerTokenAuth_CaseInsensitiveScheme_Good(t *testing.T) { r.Header.Set("Authorization", "bearer my-token") result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !testEqual("user-1", result.UserID) { - assert.True(t, result.Valid) - assert.Equal(t, "user-1", result.UserID) -} + // --------------------------------------------------------------------------- + // Integration tests — BearerTokenAuth with Hub + // --------------------------------------------------------------------------- + t.Errorf("expected %v, got %v", "user-1", result.UserID) + } -// --------------------------------------------------------------------------- -// Integration tests — BearerTokenAuth with Hub -// --------------------------------------------------------------------------- +} func TestIntegration_BearerTokenAuth_AcceptsValidToken_Good(t *testing.T) { auth := &BearerTokenAuth{ @@ -650,19 +2078,30 @@ func TestIntegration_BearerTokenAuth_AcceptsValidToken_Good(t *testing.T) { header.Set("Authorization", "Bearer valid-jwt") conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header) - require.NoError(t, err) - defer conn.Close() - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + if !testEqual(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } time.Sleep(50 * time.Millisecond) mu.Lock() client := connectedClient mu.Unlock() + if testIsNil(client) { + t.Fatalf("expected non-nil value") + } + if !testEqual("jwt-user", client.UserID) { + t.Errorf("expected %v, got %v", "jwt-user", client.UserID) + } + if !testEqual("bearer", client.Claims["auth_method"]) { + t.Errorf("expected %v, got %v", "bearer", client.Claims["auth_method"]) + } - require.NotNil(t, client) - assert.Equal(t, "jwt-user", client.UserID) - assert.Equal(t, "bearer", client.Claims["auth_method"]) } func TestIntegration_BearerTokenAuth_RejectsInvalidToken_Bad(t *testing.T) { @@ -681,18 +2120,20 @@ func TestIntegration_BearerTokenAuth_RejectsInvalidToken_Bad(t *testing.T) { conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header) if conn != nil { - conn.Close() + _ = conn.Close() + } + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testEqual(http.StatusUnauthorized, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusUnauthorized, resp.StatusCode) + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) } - require.Error(t, err) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Equal(t, 0, hub.ClientCount()) } -// --------------------------------------------------------------------------- -// Unit tests — QueryTokenAuth -// --------------------------------------------------------------------------- - func TestQueryTokenAuth_ValidToken_Good(t *testing.T) { auth := &QueryTokenAuth{ Validate: func(token string) AuthResult { @@ -710,10 +2151,19 @@ func TestQueryTokenAuth_ValidToken_Good(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/ws?token=browser-token-456", nil) result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !(result.Authenticated) { + t.Errorf("expected true") + } + if !testEqual("browser-user", result.UserID) { + t.Errorf("expected %v, got %v", "browser-user", result.UserID) + } + if !testEqual("query_param", result.Claims["auth_method"]) { + t.Errorf("expected %v, got %v", "query_param", result.Claims["auth_method"]) + } - assert.True(t, result.Valid) - assert.Equal(t, "browser-user", result.UserID) - assert.Equal(t, "query_param", result.Claims["auth_method"]) } func TestQueryTokenAuth_InvalidToken_Bad(t *testing.T) { @@ -726,9 +2176,13 @@ func TestQueryTokenAuth_InvalidToken_Bad(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/ws?token=bad-token", nil) result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil || err.Error() != "unknown token" { + t.Errorf("expected error %q, got %v", "unknown token", err) + } - assert.False(t, result.Valid) - assert.EqualError(t, result.Error, "unknown token") } func TestQueryTokenAuth_MissingParam_Bad(t *testing.T) { @@ -741,9 +2195,13 @@ func TestQueryTokenAuth_MissingParam_Bad(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/ws", nil) result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if !testContains(result.Error.Error(), "missing token query parameter") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "missing token query parameter") + } - assert.False(t, result.Valid) - assert.Contains(t, result.Error.Error(), "missing token query parameter") } func TestQueryTokenAuth_EmptyParam_Bad(t *testing.T) { @@ -756,9 +2214,13 @@ func TestQueryTokenAuth_EmptyParam_Bad(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/ws?token=", nil) result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if !testContains(result.Error.Error(), "missing token query parameter") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "missing token query parameter") + } - assert.False(t, result.Valid) - assert.Contains(t, result.Error.Error(), "missing token query parameter") } func TestQueryTokenAuth_NilURL_Bad(t *testing.T) { @@ -772,16 +2234,25 @@ func TestQueryTokenAuth_NilURL_Bad(t *testing.T) { r := &http.Request{Method: http.MethodGet} result := auth.Authenticate(r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "request URL is nil") { + t.Errorf("expected %v to contain %v", result.Error.Error(), - assert.False(t, result.Valid) - require.Error(t, result.Error) - assert.Contains(t, result.Error.Error(), "request URL is nil") - assert.False(t, called, "validate should not be called when request URL is nil") -} + // --------------------------------------------------------------------------- + // Integration tests — QueryTokenAuth with Hub + // --------------------------------------------------------------------------- + "request URL is nil") + } + if called { + t.Errorf("expected false") + } -// --------------------------------------------------------------------------- -// Integration tests — QueryTokenAuth with Hub -// --------------------------------------------------------------------------- +} func TestIntegration_QueryTokenAuth_AcceptsValidToken_Good(t *testing.T) { auth := &QueryTokenAuth{ @@ -811,20 +2282,33 @@ func TestIntegration_QueryTokenAuth_AcceptsValidToken_Good(t *testing.T) { conn, resp, err := websocket.DefaultDialer.Dial( authWSURL(server)+"?token=browser-secret", nil) - require.NoError(t, err) - defer conn.Close() - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + if !testEqual(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ClientCount()) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 1, hub.ClientCount()) + } mu.Lock() client := connectedClient mu.Unlock() + if testIsNil(client) { + t.Fatalf("expected non-nil value") + } + if !testEqual("browser-user-99", client.UserID) { + t.Errorf("expected %v, got %v", "browser-user-99", client.UserID) + } + if !testEqual("browser", client.Claims["origin"]) { + t.Errorf("expected %v, got %v", "browser", client.Claims["origin"]) + } - require.NotNil(t, client) - assert.Equal(t, "browser-user-99", client.UserID) - assert.Equal(t, "browser", client.Claims["origin"]) } func TestIntegration_QueryTokenAuth_RejectsInvalidToken_Bad(t *testing.T) { @@ -841,12 +2325,18 @@ func TestIntegration_QueryTokenAuth_RejectsInvalidToken_Bad(t *testing.T) { conn, resp, err := websocket.DefaultDialer.Dial( authWSURL(server)+"?token=wrong", nil) if conn != nil { - conn.Close() + _ = conn.Close() + } + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testEqual(http.StatusUnauthorized, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusUnauthorized, resp.StatusCode) + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) } - require.Error(t, err) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Equal(t, 0, hub.ClientCount()) } func TestIntegration_QueryTokenAuth_RejectsMissingToken_Bad(t *testing.T) { @@ -863,16 +2353,25 @@ func TestIntegration_QueryTokenAuth_RejectsMissingToken_Bad(t *testing.T) { // No ?token= parameter conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil) if conn != nil { - conn.Close() + _ = conn.Close() + } + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testEqual(http.StatusUnauthorized, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusUnauthorized, resp.StatusCode) + } + if !testEqual(0, hub.ClientCount( + + // Authenticated via query param, then subscribe and receive messages + )) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) } - require.Error(t, err) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Equal(t, 0, hub.ClientCount()) } func TestIntegration_QueryTokenAuth_EndToEnd_Good(t *testing.T) { - // Authenticated via query param, then subscribe and receive messages + auth := &QueryTokenAuth{ Validate: func(token string) AuthResult { if token == "good-token" { @@ -888,26 +2387,87 @@ func TestIntegration_QueryTokenAuth_EndToEnd_Good(t *testing.T) { conn, _, err := websocket.DefaultDialer.Dial( authWSURL(server)+"?token=good-token", nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) // Subscribe to a channel err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "events"}) - require.NoError(t, err) - time.Sleep(50 * time.Millisecond) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - assert.Equal(t, 1, hub.ChannelSubscriberCount("events")) + time.Sleep(50 * time.Millisecond) + if !testEqual(1, hub.ChannelSubscriberCount("events")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("events")) + } - // Send a message to the channel + // Send a message to the channel. err = hub.SendToChannel("events", Message{Type: TypeEvent, Data: "hello alice"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - conn.SetReadDeadline(time.Now().Add(time.Second)) + if err := conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("expected no error, got %v", err) + } var received Message err = conn.ReadJSON(&received) - require.NoError(t, err) - assert.Equal(t, TypeEvent, received.Type) - assert.Equal(t, "hello alice", received.Data) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(TypeEvent, received.Type) { + t.Errorf("expected %v, got %v", TypeEvent, received.Type) + } + if !testEqual("hello alice", received.Data) { + t.Errorf("expected %v, got %v", "hello alice", received.Data) + } + +} + +func TestAPIKeyAuthenticator_AuthenticatedAlias(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-abc": "user-1", + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer key-abc") + + result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !(result.Authenticated) { + t.Errorf("expected true") + } + +} + +func TestQueryTokenAuth_AuthenticatedAlias(t *testing.T) { + auth := &QueryTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{ + Authenticated: true, + UserID: token, + } + }, + } + + r := httptest.NewRequest(http.MethodGet, "/ws?token=alias-token", nil) + + result := auth.Authenticate(r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !(result.Authenticated) { + t.Errorf("expected true") + } + if !testEqual("alias-token", result.UserID) { + t.Errorf("expected %v, got %v", "alias-token", result.UserID) + } + } diff --git a/ax7_v090_test.go b/ax7_v090_test.go new file mode 100644 index 0000000..021df0c --- /dev/null +++ b/ax7_v090_test.go @@ -0,0 +1,1226 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package ws + +import core "dappco.re/go" + +type T = core.T +type CancelFunc = core.CancelFunc +type HTTPTestServer = core.HTTPTestServer +type Request = core.Request +type Response = core.Response +type Listener = core.Listener +type Conn = core.Conn +type BufReader = core.BufReader +type HandlerFunc = core.HandlerFunc + +const ( + Millisecond = core.Millisecond + Second = core.Second +) + +var ( + AnError = core.AnError + Atoi = core.Atoi + AssertContains = core.AssertContains + AssertEmpty = core.AssertEmpty + AssertEqual = core.AssertEqual + AssertError = core.AssertError + AssertErrorIs = core.AssertErrorIs + AssertFalse = core.AssertFalse + AssertNil = core.AssertNil + AssertNoError = core.AssertNoError + AssertNotEmpty = core.AssertNotEmpty + AssertNotNil = core.AssertNotNil + AssertNotPanics = core.AssertNotPanics + AssertTrue = core.AssertTrue + Background = core.Background + Concat = core.Concat + Errorf = core.Errorf + HasPrefix = core.HasPrefix + HTTPGet = core.HTTPGet + Itoa = core.Itoa + JSONUnmarshal = core.JSONUnmarshal + NetListen = core.NetListen + NewBufReader = core.NewBufReader + NewHTTPTestRecorder = core.NewHTTPTestRecorder + NewHTTPTestRequest = core.NewHTTPTestRequest + NewHTTPTestServer = core.NewHTTPTestServer + Now = core.Now + RequireNoError = core.RequireNoError + RequireTrue = core.RequireTrue + Sleep = core.Sleep + Trim = core.Trim + TrimPrefix = core.TrimPrefix + Upper = core.Upper + WithCancel = core.WithCancel + WithTimeout = core.WithTimeout + WriteString = core.WriteString +) + +func ax7Client() *Client { + return &Client{ + send: make(chan []byte, 16), + subscriptions: make(map[string]bool), + } +} + +func ax7Eventually(condition func() bool) bool { + deadline := Now().Add(Second) + for Now().Before(deadline) { + if condition() { + return true + } + Sleep(5 * Millisecond) + } + return condition() +} + +func ax7StartHub(t *T) (*Hub, CancelFunc) { + hub := NewHub() + ctx, cancel := WithCancel(Background()) + go hub.Run(ctx) + RequireTrue(t, ax7Eventually(func() bool { return hub.isRunning() })) + t.Cleanup(cancel) + return hub, cancel +} + +func ax7StartWSServer(t *T, config HubConfig) (*Hub, *HTTPTestServer) { + hub := NewHubWithConfig(config) + ctx, cancel := WithCancel(Background()) + go hub.Run(ctx) + RequireTrue(t, ax7Eventually(func() bool { return hub.isRunning() })) + server := NewHTTPTestServer(hub.Handler()) + t.Cleanup(func() { + server.Close() + cancel() + }) + return hub, server +} + +func ax7WSURL(server *HTTPTestServer) string { + return Concat("ws", TrimPrefix(server.URL, "http")) +} + +func ax7BroadcastMessage(t *T, hub *Hub) Message { + timeout, cancel := WithTimeout(Background(), Second) + defer cancel() + select { + case raw := <-hub.broadcast: + var msg Message + RequireTrue(t, JSONUnmarshal(raw, &msg).OK) + return msg + case <-timeout.Done(): + t.Fatal("timed out waiting for broadcast") + return Message{} + } +} + +func ax7ClientMessage(t *T, client *Client) Message { + timeout, cancel := WithTimeout(Background(), Second) + defer cancel() + select { + case raw := <-client.send: + var msg Message + RequireTrue(t, JSONUnmarshal(raw, &msg).OK) + return msg + case <-timeout.Done(): + t.Fatal("timed out waiting for client message") + return Message{} + } +} + +func ax7AuthRequest(header string) *Request { + req := NewHTTPTestRequest("GET", "/ws", nil) + if header != "" { + req.Header.Set("Authorization", header) + } + return req +} + +func ax7StartRedis(t *T) string { + r := NetListen("tcp", "127.0.0.1:0") + RequireTrue(t, r.OK) + listener := r.Value.(Listener) + t.Cleanup(func() { + if err := listener.Close(); err != nil { + AssertContains(t, err.Error(), "closed") + } + }) + go ax7AcceptRedis(listener) + return listener.Addr().String() +} + +func ax7AcceptRedis(listener Listener) { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go ax7ServeRedis(conn) + } +} + +func ax7ServeRedis(conn Conn) { + defer func() { + if err := conn.Close(); err != nil { + return + } + }() + + reader := NewBufReader(conn) + for { + parts, err := ax7ReadRedisCommand(reader) + if err != nil { + return + } + if len(parts) == 0 { + continue + } + + switch Upper(parts[0]) { + case "HELLO": + if !ax7WriteRedis(conn, "%7\r\n+server\r\n+redis\r\n+version\r\n+7.2.0\r\n+proto\r\n:3\r\n+id\r\n:1\r\n+mode\r\n+standalone\r\n+role\r\n+master\r\n+modules\r\n*0\r\n") { + return + } + case "PING": + if !ax7WriteRedis(conn, "+PONG\r\n") { + return + } + case "CLIENT", "SELECT", "READONLY": + if !ax7WriteRedis(conn, "+OK\r\n") { + return + } + case "PSUBSCRIBE": + pattern := "" + if len(parts) > 1 { + pattern = parts[1] + } + if !ax7WriteRedis(conn, Concat("*3\r\n$10\r\npsubscribe\r\n$", Itoa(len(pattern)), "\r\n", pattern, "\r\n:1\r\n")) { + return + } + case "PUBLISH": + if !ax7WriteRedis(conn, ":1\r\n") { + return + } + case "QUIT": + if !ax7WriteRedis(conn, "+OK\r\n") { + return + } + return + default: + if !ax7WriteRedis(conn, "+OK\r\n") { + return + } + } + } +} + +func ax7ReadRedisCommand(reader *BufReader) ([]string, error) { + line, err := reader.ReadString('\n') + if err != nil { + return nil, err + } + line = Trim(line) + if !HasPrefix(line, "*") { + return nil, Errorf("unexpected redis frame: %s", line) + } + + count := Atoi(TrimPrefix(line, "*")) + if !count.OK { + return nil, count.Value.(error) + } + + parts := make([]string, 0, count.Value.(int)) + for i := 0; i < count.Value.(int); i++ { + if _, err := reader.ReadString('\n'); err != nil { + return nil, err + } + value, err := reader.ReadString('\n') + if err != nil { + return nil, err + } + parts = append(parts, Trim(value)) + } + return parts, nil +} + +func ax7WriteRedis(conn Conn, payload string) bool { + return WriteString(conn, payload).OK +} + +// --- DefaultHubConfig --- + +func TestAX7_DefaultHubConfig_Good(t *T) { + cfg := DefaultHubConfig() + AssertEqual(t, DefaultHeartbeatInterval, cfg.HeartbeatInterval) + AssertEqual(t, DefaultPongTimeout, cfg.PongTimeout) + AssertEqual(t, DefaultWriteTimeout, cfg.WriteTimeout) + AssertEqual(t, DefaultMaxSubscriptionsPerClient, cfg.MaxSubscriptionsPerClient) +} + +func TestAX7_DefaultHubConfig_Bad(t *T) { + cfg := DefaultHubConfig() + AssertNil(t, cfg.Authenticator) + AssertNil(t, cfg.ChannelAuthoriser) + AssertNil(t, cfg.CheckOrigin) +} + +func TestAX7_DefaultHubConfig_Ugly(t *T) { + cfg := DefaultHubConfig() + cfg.AllowedOrigins = append(cfg.AllowedOrigins, "https://app.example") + again := DefaultHubConfig() + AssertEmpty(t, again.AllowedOrigins) +} + +// --- Hub construction and loop --- + +func TestAX7_NewHub_Good(t *T) { + hub := NewHub() + AssertNotNil(t, hub) + AssertNotNil(t, hub.clients) + AssertNotNil(t, hub.broadcast) + AssertNotNil(t, hub.channels) +} + +func TestAX7_NewHub_Bad(t *T) { + hub := NewHub() + AssertEqual(t, DefaultHeartbeatInterval, hub.config.HeartbeatInterval) + AssertEqual(t, DefaultPongTimeout, hub.config.PongTimeout) + AssertEqual(t, DefaultMaxSubscriptionsPerClient, hub.config.MaxSubscriptionsPerClient) +} + +func TestAX7_NewHub_Ugly(t *T) { + hub := NewHub() + req := NewHTTPTestRequest("GET", "http://evil.example/ws", nil) + req.Header.Set("Origin", "https://evil.example") + AssertTrue(t, hub.config.CheckOrigin(req)) +} + +func TestAX7_Hub_Run_Good(t *T) { + hub, cancel := ax7StartHub(t) + AssertTrue(t, hub.isRunning()) + cancel() + AssertTrue(t, ax7Eventually(func() bool { return !hub.isRunning() })) +} + +func TestAX7_Hub_Run_Bad(t *T) { + var hub *Hub + AssertNotPanics(t, func() { + hub.Run(Background()) + }) + AssertFalse(t, hub.isRunning()) +} + +func TestAX7_Hub_Run_Ugly(t *T) { + hub := NewHub() + ctx, cancel := WithCancel(Background()) + cancel() + hub.Run(ctx) + AssertFalse(t, hub.isRunning()) +} + +// --- Hub subscriptions and delivery --- + +func TestAX7_Hub_Subscribe_Good(t *T) { + hub := NewHub() + client := ax7Client() + err := hub.Subscribe(client, "agent.dispatch") + AssertNoError(t, err) + AssertEqual(t, 1, hub.ChannelSubscriberCount("agent.dispatch")) +} + +func TestAX7_Hub_Subscribe_Bad(t *T) { + hub := NewHub() + client := ax7Client() + err := hub.Subscribe(client, " agent.dispatch") + AssertError(t, err, "invalid channel") + AssertEmpty(t, client.Subscriptions()) +} + +func TestAX7_Hub_Subscribe_Ugly(t *T) { + var hub *Hub + client := ax7Client() + err := hub.Subscribe(client, "agent.dispatch") + AssertError(t, err, "hub must not be nil") + AssertEmpty(t, client.Subscriptions()) +} + +func TestAX7_Hub_Unsubscribe_Good(t *T) { + hub := NewHub() + client := ax7Client() + RequireNoError(t, hub.Subscribe(client, "agent.dispatch")) + hub.Unsubscribe(client, "agent.dispatch") + AssertEqual(t, 0, hub.ChannelSubscriberCount("agent.dispatch")) +} + +func TestAX7_Hub_Unsubscribe_Bad(t *T) { + hub := NewHub() + client := ax7Client() + RequireNoError(t, hub.Subscribe(client, "agent.dispatch")) + hub.Unsubscribe(client, " agent.dispatch") + AssertEqual(t, 1, hub.ChannelSubscriberCount("agent.dispatch")) +} + +func TestAX7_Hub_Unsubscribe_Ugly(t *T) { + var hub *Hub + client := ax7Client() + AssertNotPanics(t, func() { + hub.Unsubscribe(client, "agent.dispatch") + }) + AssertEmpty(t, client.Subscriptions()) +} + +func TestAX7_Hub_Broadcast_Good(t *T) { + hub := NewHub() + err := hub.Broadcast(Message{Type: TypeEvent, Data: "ready"}) + AssertNoError(t, err) + msg := ax7BroadcastMessage(t, hub) + AssertEqual(t, TypeEvent, msg.Type) + AssertFalse(t, msg.Timestamp.IsZero()) +} + +func TestAX7_Hub_Broadcast_Bad(t *T) { + hub := NewHub() + err := hub.Broadcast(Message{Type: TypeProcessOutput, ProcessID: "bad:id"}) + AssertError(t, err, "invalid process ID") + AssertEqual(t, 0, len(hub.broadcast)) +} + +func TestAX7_Hub_Broadcast_Ugly(t *T) { + var hub *Hub + err := hub.Broadcast(Message{Type: TypeEvent, Data: "ready"}) + AssertError(t, err, "hub must not be nil") + AssertNil(t, hub) +} + +func TestAX7_Hub_SendToChannel_Good(t *T) { + hub := NewHub() + client := ax7Client() + RequireNoError(t, hub.Subscribe(client, "agent.dispatch")) + err := hub.SendToChannel("agent.dispatch", Message{Type: TypeEvent, Data: "queued"}) + AssertNoError(t, err) + AssertEqual(t, "agent.dispatch", ax7ClientMessage(t, client).Channel) +} + +func TestAX7_Hub_SendToChannel_Bad(t *T) { + hub := NewHub() + err := hub.SendToChannel(" agent.dispatch", Message{Type: TypeEvent}) + AssertError(t, err, "invalid channel") + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_SendToChannel_Ugly(t *T) { + hub := NewHub() + err := hub.SendToChannel("agent.dispatch", Message{Type: TypeEvent, Data: "nobody"}) + AssertNoError(t, err) + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_SendProcessOutput_Good(t *T) { + hub := NewHub() + client := ax7Client() + RequireNoError(t, hub.Subscribe(client, "process:proc-1")) + RequireNoError(t, hub.SendProcessOutput("proc-1", "line")) + AssertEqual(t, TypeProcessOutput, ax7ClientMessage(t, client).Type) +} + +func TestAX7_Hub_SendProcessOutput_Bad(t *T) { + hub := NewHub() + err := hub.SendProcessOutput("bad:id", "line") + AssertError(t, err, "invalid process ID") + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_SendProcessOutput_Ugly(t *T) { + hub := NewHub() + err := hub.SendProcessOutput("proc-1", "") + AssertNoError(t, err) + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_SendProcessStatus_Good(t *T) { + hub := NewHub() + client := ax7Client() + RequireNoError(t, hub.Subscribe(client, "process:proc-1")) + RequireNoError(t, hub.SendProcessStatus("proc-1", "running", 0)) + AssertEqual(t, TypeProcessStatus, ax7ClientMessage(t, client).Type) +} + +func TestAX7_Hub_SendProcessStatus_Bad(t *T) { + hub := NewHub() + err := hub.SendProcessStatus("bad:id", "failed", 1) + AssertError(t, err, "invalid process ID") + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_SendProcessStatus_Ugly(t *T) { + hub := NewHub() + err := hub.SendProcessStatus("proc-1", "", -1) + AssertNoError(t, err) + AssertEqual(t, 0, hub.ChannelCount()) +} + +func TestAX7_Hub_SendError_Good(t *T) { + hub := NewHub() + RequireNoError(t, hub.SendError("server refused")) + msg := ax7BroadcastMessage(t, hub) + AssertEqual(t, TypeError, msg.Type) + AssertEqual(t, "server refused", msg.Data) +} + +func TestAX7_Hub_SendError_Bad(t *T) { + var hub *Hub + err := hub.SendError("server refused") + AssertError(t, err, "hub must not be nil") + AssertNil(t, hub) +} + +func TestAX7_Hub_SendError_Ugly(t *T) { + hub := NewHub() + RequireNoError(t, hub.SendError("")) + msg := ax7BroadcastMessage(t, hub) + AssertEqual(t, "", msg.Data) +} + +func TestAX7_Hub_SendEvent_Good(t *T) { + hub := NewHub() + RequireNoError(t, hub.SendEvent("agent.ready", "payload")) + msg := ax7BroadcastMessage(t, hub) + AssertEqual(t, TypeEvent, msg.Type) + AssertContains(t, msg.Data.(map[string]any), "event") +} + +func TestAX7_Hub_SendEvent_Bad(t *T) { + var hub *Hub + err := hub.SendEvent("agent.ready", "payload") + AssertError(t, err, "hub must not be nil") + AssertNil(t, hub) +} + +func TestAX7_Hub_SendEvent_Ugly(t *T) { + hub := NewHub() + RequireNoError(t, hub.SendEvent("", nil)) + msg := ax7BroadcastMessage(t, hub) + AssertEqual(t, TypeEvent, msg.Type) + AssertContains(t, msg.Data.(map[string]any), "data") +} + +// --- Hub snapshots and counts --- + +func TestAX7_Hub_ClientCount_Good(t *T) { + hub := NewHub() + client := ax7Client() + hub.clients[client] = true + AssertEqual(t, 1, hub.ClientCount()) +} + +func TestAX7_Hub_ClientCount_Bad(t *T) { + hub := NewHub() + AssertEqual(t, 0, hub.ClientCount()) + AssertNotNil(t, hub.clients) +} + +func TestAX7_Hub_ClientCount_Ugly(t *T) { + var hub *Hub + AssertEqual(t, 0, hub.ClientCount()) + AssertNil(t, hub) +} + +func TestAX7_Hub_ChannelCount_Good(t *T) { + hub := NewHub() + RequireNoError(t, hub.Subscribe(ax7Client(), "alpha")) + RequireNoError(t, hub.Subscribe(ax7Client(), "beta")) + AssertEqual(t, 2, hub.ChannelCount()) +} + +func TestAX7_Hub_ChannelCount_Bad(t *T) { + hub := NewHub() + AssertEqual(t, 0, hub.ChannelCount()) + AssertNotNil(t, hub.channels) +} + +func TestAX7_Hub_ChannelCount_Ugly(t *T) { + var hub *Hub + AssertEqual(t, 0, hub.ChannelCount()) + AssertNil(t, hub) +} + +func TestAX7_Hub_ChannelSubscriberCount_Good(t *T) { + hub := NewHub() + RequireNoError(t, hub.Subscribe(ax7Client(), "alpha")) + RequireNoError(t, hub.Subscribe(ax7Client(), "alpha")) + AssertEqual(t, 2, hub.ChannelSubscriberCount("alpha")) +} + +func TestAX7_Hub_ChannelSubscriberCount_Bad(t *T) { + hub := NewHub() + AssertEqual(t, 0, hub.ChannelSubscriberCount("missing")) + AssertNotNil(t, hub.channels) +} + +func TestAX7_Hub_ChannelSubscriberCount_Ugly(t *T) { + var hub *Hub + AssertEqual(t, 0, hub.ChannelSubscriberCount("alpha")) + AssertNil(t, hub) +} + +func TestAX7_Hub_AllClients_Good(t *T) { + hub := NewHub() + hub.clients[&Client{UserID: "b"}] = true + hub.clients[&Client{UserID: "a"}] = true + var ids []string + for client := range hub.AllClients() { + ids = append(ids, client.UserID) + } + AssertEqual(t, []string{"a", "b"}, ids) +} + +func TestAX7_Hub_AllClients_Bad(t *T) { + hub := NewHub() + var clients []*Client + for client := range hub.AllClients() { + clients = append(clients, client) + } + AssertEmpty(t, clients) +} + +func TestAX7_Hub_AllClients_Ugly(t *T) { + var hub *Hub + var clients []*Client + for client := range hub.AllClients() { + clients = append(clients, client) + } + AssertEmpty(t, clients) +} + +func TestAX7_Hub_AllChannels_Good(t *T) { + hub := NewHub() + RequireNoError(t, hub.Subscribe(ax7Client(), "beta")) + RequireNoError(t, hub.Subscribe(ax7Client(), "alpha")) + var channels []string + for channel := range hub.AllChannels() { + channels = append(channels, channel) + } + AssertEqual(t, []string{"alpha", "beta"}, channels) +} + +func TestAX7_Hub_AllChannels_Bad(t *T) { + hub := NewHub() + var channels []string + for channel := range hub.AllChannels() { + channels = append(channels, channel) + } + AssertEmpty(t, channels) +} + +func TestAX7_Hub_AllChannels_Ugly(t *T) { + var hub *Hub + var channels []string + for channel := range hub.AllChannels() { + channels = append(channels, channel) + } + AssertEmpty(t, channels) +} + +func TestAX7_Hub_Stats_Good(t *T) { + hub := NewHub() + client := ax7Client() + hub.clients[client] = true + RequireNoError(t, hub.Subscribe(client, "alpha")) + stats := hub.Stats() + AssertEqual(t, 1, stats.Clients) + AssertEqual(t, 1, stats.Channels) + AssertEqual(t, 1, stats.Subscribers) +} + +func TestAX7_Hub_Stats_Bad(t *T) { + hub := NewHub() + stats := hub.Stats() + AssertEqual(t, HubStats{}, stats) + AssertEqual(t, 0, stats.Subscribers) +} + +func TestAX7_Hub_Stats_Ugly(t *T) { + var hub *Hub + stats := hub.Stats() + AssertEqual(t, HubStats{}, stats) + AssertNil(t, hub) +} + +// --- HTTP handlers --- + +func TestAX7_Hub_Handler_Good(t *T) { + hub, _ := ax7StartHub(t) + server := NewHTTPTestServer(hub.Handler()) + t.Cleanup(server.Close) + resp := HTTPGet(server.URL) + RequireTrue(t, resp.OK) + AssertEqual(t, 400, resp.Value.(*Response).StatusCode) + AssertNoError(t, resp.Value.(*Response).Body.Close()) +} + +func TestAX7_Hub_Handler_Bad(t *T) { + var hub *Hub + handler := hub.Handler() + rec := NewHTTPTestRecorder() + req := NewHTTPTestRequest("GET", "/ws", nil) + handler(rec, req) + AssertEqual(t, 503, rec.Code) + AssertContains(t, rec.Body.String(), "Hub is not configured") +} + +func TestAX7_Hub_Handler_Ugly(t *T) { + hub, _ := ax7StartHub(t) + hub.config.CheckOrigin = func(*Request) bool { panic("origin panic") } + rec := NewHTTPTestRecorder() + req := NewHTTPTestRequest("GET", "http://example.com/ws", nil) + hub.Handler()(rec, req) + AssertEqual(t, 403, rec.Code) +} + +func TestAX7_Hub_HandleWebSocket_Good(t *T) { + hub, _ := ax7StartHub(t) + server := NewHTTPTestServer(HandlerFunc(hub.HandleWebSocket)) + t.Cleanup(server.Close) + resp := HTTPGet(server.URL) + RequireTrue(t, resp.OK) + AssertEqual(t, 400, resp.Value.(*Response).StatusCode) + AssertNoError(t, resp.Value.(*Response).Body.Close()) +} + +func TestAX7_Hub_HandleWebSocket_Bad(t *T) { + var hub *Hub + rec := NewHTTPTestRecorder() + req := NewHTTPTestRequest("GET", "/ws", nil) + hub.HandleWebSocket(rec, req) + AssertEqual(t, 503, rec.Code) + AssertContains(t, rec.Body.String(), "Hub is not configured") +} + +func TestAX7_Hub_HandleWebSocket_Ugly(t *T) { + hub, _ := ax7StartHub(t) + hub.config.CheckOrigin = func(*Request) bool { return false } + rec := NewHTTPTestRecorder() + req := NewHTTPTestRequest("GET", "http://example.com/ws", nil) + hub.HandleWebSocket(rec, req) + AssertEqual(t, 403, rec.Code) +} + +// --- Client methods --- + +func TestAX7_Client_Subscriptions_Good(t *T) { + client := ax7Client() + client.subscriptions["beta"] = true + client.subscriptions["alpha"] = true + AssertEqual(t, []string{"alpha", "beta"}, client.Subscriptions()) +} + +func TestAX7_Client_Subscriptions_Bad(t *T) { + var client *Client + subscriptions := client.Subscriptions() + AssertNil(t, subscriptions) + AssertEmpty(t, subscriptions) +} + +func TestAX7_Client_Subscriptions_Ugly(t *T) { + client := ax7Client() + client.subscriptions["alpha"] = true + snapshot := client.Subscriptions() + snapshot[0] = "mutated" + AssertEqual(t, []string{"alpha"}, client.Subscriptions()) +} + +func TestAX7_Client_AllSubscriptions_Good(t *T) { + client := ax7Client() + client.subscriptions["beta"] = true + client.subscriptions["alpha"] = true + var subscriptions []string + for channel := range client.AllSubscriptions() { + subscriptions = append(subscriptions, channel) + } + AssertEqual(t, []string{"alpha", "beta"}, subscriptions) +} + +func TestAX7_Client_AllSubscriptions_Bad(t *T) { + client := ax7Client() + var subscriptions []string + for channel := range client.AllSubscriptions() { + subscriptions = append(subscriptions, channel) + } + AssertEmpty(t, subscriptions) +} + +func TestAX7_Client_AllSubscriptions_Ugly(t *T) { + var client *Client + var subscriptions []string + for channel := range client.AllSubscriptions() { + subscriptions = append(subscriptions, channel) + } + AssertEmpty(t, subscriptions) +} + +func TestAX7_Client_Close_Good(t *T) { + hub := NewHub() + client := ax7Client() + client.hub = hub + hub.clients[client] = true + RequireNoError(t, hub.Subscribe(client, "alpha")) + AssertNoError(t, client.Close()) + AssertEqual(t, 0, hub.ClientCount()) + AssertEmpty(t, client.Subscriptions()) +} + +func TestAX7_Client_Close_Bad(t *T) { + var client *Client + err := client.Close() + AssertNoError(t, err) + AssertNil(t, client) +} + +func TestAX7_Client_Close_Ugly(t *T) { + hub, _ := ax7StartHub(t) + client := ax7Client() + client.hub = hub + hub.register <- client + RequireTrue(t, ax7Eventually(func() bool { return hub.ClientCount() == 1 })) + AssertNoError(t, client.Close()) + AssertTrue(t, ax7Eventually(func() bool { return hub.ClientCount() == 0 })) +} + +// --- Reconnecting client --- + +func TestAX7_NewReconnectingClient_Good(t *T) { + rc := NewReconnectingClient(ReconnectConfig{URL: "ws://example.invalid/ws"}) + AssertEqual(t, StateDisconnected, rc.State()) + AssertEqual(t, Second, rc.config.InitialBackoff) + AssertEqual(t, 30*Second, rc.config.MaxBackoff) + AssertNotNil(t, rc.config.Dialer) +} + +func TestAX7_NewReconnectingClient_Bad(t *T) { + rc := NewReconnectingClient(ReconnectConfig{InitialBackoff: 10 * Millisecond, MaxBackoff: 5 * Millisecond}) + AssertEqual(t, 5*Millisecond, rc.config.InitialBackoff) + AssertEqual(t, 5*Millisecond, rc.config.MaxBackoff) + AssertEqual(t, 2.0, rc.config.BackoffMultiplier) +} + +func TestAX7_NewReconnectingClient_Ugly(t *T) { + rc := NewReconnectingClient(ReconnectConfig{BackoffMultiplier: -4, InitialBackoff: -1, MaxBackoff: -1}) + AssertEqual(t, Second, rc.config.InitialBackoff) + AssertEqual(t, 30*Second, rc.config.MaxBackoff) + AssertEqual(t, 2.0, rc.config.BackoffMultiplier) +} + +func TestAX7_ReconnectingClient_Connect_Good(t *T) { + _, server := ax7StartWSServer(t, HubConfig{}) + rc := NewReconnectingClient(ReconnectConfig{URL: ax7WSURL(server)}) + done := make(chan error, 1) + go func() { done <- rc.Connect(Background()) }() + RequireTrue(t, ax7Eventually(func() bool { return rc.State() == StateConnected })) + AssertNoError(t, rc.Close()) + timeout, cancel := WithTimeout(Background(), Second) + defer cancel() + select { + case err := <-done: + if err != nil { + AssertContains(t, err.Error(), "context canceled") + } + case <-timeout.Done(): + t.Fatal("timed out waiting for reconnecting client shutdown") + } +} + +func TestAX7_ReconnectingClient_Connect_Bad(t *T) { + var rc *ReconnectingClient + err := rc.Connect(Background()) + AssertError(t, err, "client must not be nil") + AssertNil(t, rc) +} + +func TestAX7_ReconnectingClient_Connect_Ugly(t *T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://127.0.0.1:1/ws", + InitialBackoff: Millisecond, + MaxBackoff: Millisecond, + MaxReconnectAttempts: 1, + }) + err := rc.Connect(Background()) + AssertError(t, err, "max retries") + AssertEqual(t, StateDisconnected, rc.State()) +} + +func TestAX7_ReconnectingClient_Send_Good(t *T) { + _, server := ax7StartWSServer(t, HubConfig{}) + rc := NewReconnectingClient(ReconnectConfig{URL: ax7WSURL(server)}) + done := make(chan error, 1) + go func() { done <- rc.Connect(Background()) }() + RequireTrue(t, ax7Eventually(func() bool { return rc.State() == StateConnected })) + AssertNoError(t, rc.Send(Message{Type: TypePing})) + AssertNoError(t, rc.Close()) + <-done +} + +func TestAX7_ReconnectingClient_Send_Bad(t *T) { + var rc *ReconnectingClient + err := rc.Send(Message{Type: TypePing}) + AssertError(t, err, "client must not be nil") + AssertNil(t, rc) +} + +func TestAX7_ReconnectingClient_Send_Ugly(t *T) { + rc := NewReconnectingClient(ReconnectConfig{URL: "ws://example.invalid/ws"}) + err := rc.Send(Message{Type: TypePing}) + AssertError(t, err, "not connected") + AssertEqual(t, StateDisconnected, rc.State()) +} + +func TestAX7_ReconnectingClient_State_Good(t *T) { + rc := NewReconnectingClient(ReconnectConfig{}) + rc.setState(StateConnecting) + AssertEqual(t, StateConnecting, rc.State()) +} + +func TestAX7_ReconnectingClient_State_Bad(t *T) { + var rc *ReconnectingClient + state := rc.State() + AssertEqual(t, StateDisconnected, state) + AssertNil(t, rc) +} + +func TestAX7_ReconnectingClient_State_Ugly(t *T) { + rc := NewReconnectingClient(ReconnectConfig{}) + rc.setState(ConnectionState(99)) + AssertEqual(t, ConnectionState(99), rc.State()) +} + +func TestAX7_ReconnectingClient_Close_Good(t *T) { + rc := NewReconnectingClient(ReconnectConfig{}) + rc.setState(StateConnected) + AssertNoError(t, rc.Close()) + AssertEqual(t, StateDisconnected, rc.State()) +} + +func TestAX7_ReconnectingClient_Close_Bad(t *T) { + var rc *ReconnectingClient + err := rc.Close() + AssertNoError(t, err) + AssertNil(t, rc) +} + +func TestAX7_ReconnectingClient_Close_Ugly(t *T) { + rc := NewReconnectingClient(ReconnectConfig{}) + AssertNoError(t, rc.Close()) + AssertNoError(t, rc.Close()) + AssertEqual(t, StateDisconnected, rc.State()) +} + +// --- Authentication --- + +func TestAX7_NewAPIKeyAuth_Good(t *T) { + auth := NewAPIKeyAuth(map[string]string{"secret": "user-1"}) + result := auth.Authenticate(ax7AuthRequest("Bearer secret")) + AssertTrue(t, result.Valid) + AssertEqual(t, "user-1", result.UserID) + AssertEqual(t, "api_key", result.Claims["auth_method"]) +} + +func TestAX7_NewAPIKeyAuth_Bad(t *T) { + auth := NewAPIKeyAuth(nil) + result := auth.Authenticate(ax7AuthRequest("Bearer secret")) + AssertFalse(t, result.Valid) + AssertErrorIs(t, result.Error, ErrInvalidAPIKey) +} + +func TestAX7_NewAPIKeyAuth_Ugly(t *T) { + keys := map[string]string{"secret": "user-1"} + auth := NewAPIKeyAuth(keys) + keys["secret"] = "mutated" + result := auth.Authenticate(ax7AuthRequest("Bearer secret")) + AssertTrue(t, result.Valid) + AssertEqual(t, "user-1", result.UserID) +} + +func TestAX7_APIKeyAuthenticator_Authenticate_Good(t *T) { + auth := NewAPIKeyAuth(map[string]string{"secret": "user-1"}) + result := auth.Authenticate(ax7AuthRequest("bearer secret")) + AssertTrue(t, result.Valid) + AssertTrue(t, result.Authenticated) + AssertEqual(t, "user-1", result.UserID) +} + +func TestAX7_APIKeyAuthenticator_Authenticate_Bad(t *T) { + auth := NewAPIKeyAuth(map[string]string{"secret": "user-1"}) + result := auth.Authenticate(ax7AuthRequest("Bearer wrong")) + AssertFalse(t, result.Valid) + AssertErrorIs(t, result.Error, ErrInvalidAPIKey) + AssertEqual(t, "", result.UserID) +} + +func TestAX7_APIKeyAuthenticator_Authenticate_Ugly(t *T) { + var auth *APIKeyAuthenticator + result := auth.Authenticate(ax7AuthRequest("Bearer secret")) + AssertFalse(t, result.Valid) + AssertError(t, result.Error, "authenticator is nil") +} + +func TestAX7_AuthenticatorFunc_Authenticate_Good(t *T) { + auth := AuthenticatorFunc(func(*Request) AuthResult { + return AuthResult{Authenticated: true, UserID: " user-1 "} + }) + result := auth.Authenticate(NewHTTPTestRequest("GET", "/ws", nil)) + AssertTrue(t, result.Valid) + AssertEqual(t, "user-1", result.UserID) +} + +func TestAX7_AuthenticatorFunc_Authenticate_Bad(t *T) { + var auth AuthenticatorFunc + result := auth.Authenticate(NewHTTPTestRequest("GET", "/ws", nil)) + AssertFalse(t, result.Valid) + AssertError(t, result.Error, "authenticator function is nil") +} + +func TestAX7_AuthenticatorFunc_Authenticate_Ugly(t *T) { + auth := AuthenticatorFunc(func(*Request) AuthResult { + return AuthResult{Authenticated: true, UserID: ""} + }) + result := auth.Authenticate(NewHTTPTestRequest("GET", "/ws", nil)) + AssertFalse(t, result.Valid) + AssertErrorIs(t, result.Error, ErrMissingUserID) +} + +func TestAX7_NewBearerTokenAuth_Good(t *T) { + auth := NewBearerTokenAuth(func(token string) AuthResult { + return AuthResult{Authenticated: token == "secret", UserID: "user-1"} + }) + result := auth.Authenticate(ax7AuthRequest("Bearer secret")) + AssertTrue(t, result.Valid) + AssertEqual(t, "user-1", result.UserID) +} + +func TestAX7_NewBearerTokenAuth_Bad(t *T) { + auth := NewBearerTokenAuth() + result := auth.Authenticate(ax7AuthRequest("Bearer secret")) + AssertFalse(t, result.Valid) + AssertError(t, result.Error, "validate function is not configured") +} + +func TestAX7_NewBearerTokenAuth_Ugly(t *T) { + auth := NewBearerTokenAuth(nil) + result := auth.Authenticate(ax7AuthRequest("Bearer secret")) + AssertFalse(t, result.Valid) + AssertError(t, result.Error, "validate function is not configured") +} + +func TestAX7_BearerTokenAuth_Authenticate_Good(t *T) { + auth := &BearerTokenAuth{Validate: func(token string) AuthResult { + return AuthResult{Authenticated: token == "secret", UserID: "user-1"} + }} + result := auth.Authenticate(ax7AuthRequest("Bearer secret")) + AssertTrue(t, result.Valid) + AssertEqual(t, "user-1", result.UserID) +} + +func TestAX7_BearerTokenAuth_Authenticate_Bad(t *T) { + auth := NewBearerTokenAuth(func(string) AuthResult { + return AuthResult{Valid: false, Error: AnError} + }) + result := auth.Authenticate(ax7AuthRequest("")) + AssertFalse(t, result.Valid) + AssertErrorIs(t, result.Error, ErrMissingAuthHeader) +} + +func TestAX7_BearerTokenAuth_Authenticate_Ugly(t *T) { + var auth *BearerTokenAuth + result := auth.Authenticate(ax7AuthRequest("Bearer secret")) + AssertFalse(t, result.Valid) + AssertError(t, result.Error, "authenticator is nil") +} + +func TestAX7_NewQueryTokenAuth_Good(t *T) { + auth := NewQueryTokenAuth(func(token string) AuthResult { + return AuthResult{Authenticated: token == "secret", UserID: "user-1"} + }) + result := auth.Authenticate(NewHTTPTestRequest("GET", "/ws?token=secret", nil)) + AssertTrue(t, result.Valid) + AssertEqual(t, "user-1", result.UserID) +} + +func TestAX7_NewQueryTokenAuth_Bad(t *T) { + auth := NewQueryTokenAuth() + result := auth.Authenticate(NewHTTPTestRequest("GET", "/ws?token=secret", nil)) + AssertFalse(t, result.Valid) + AssertError(t, result.Error, "validate function is not configured") +} + +func TestAX7_NewQueryTokenAuth_Ugly(t *T) { + auth := NewQueryTokenAuth(nil) + result := auth.Authenticate(NewHTTPTestRequest("GET", "/ws?token=secret", nil)) + AssertFalse(t, result.Valid) + AssertError(t, result.Error, "validate function is not configured") +} + +func TestAX7_QueryTokenAuth_Authenticate_Good(t *T) { + auth := &QueryTokenAuth{Validate: func(token string) AuthResult { + return AuthResult{Authenticated: token == "secret", UserID: "user-1"} + }} + result := auth.Authenticate(NewHTTPTestRequest("GET", "/ws?token=secret", nil)) + AssertTrue(t, result.Valid) + AssertEqual(t, "user-1", result.UserID) +} + +func TestAX7_QueryTokenAuth_Authenticate_Bad(t *T) { + auth := NewQueryTokenAuth(func(string) AuthResult { + return AuthResult{Authenticated: true, UserID: "user-1"} + }) + result := auth.Authenticate(NewHTTPTestRequest("GET", "/ws", nil)) + AssertFalse(t, result.Valid) + AssertError(t, result.Error, "missing token") +} + +func TestAX7_QueryTokenAuth_Authenticate_Ugly(t *T) { + auth := NewQueryTokenAuth(func(string) AuthResult { + return AuthResult{Authenticated: true, UserID: "user-1"} + }) + req := NewHTTPTestRequest("GET", "/ws?token=secret", nil) + req.URL = nil + result := auth.Authenticate(req) + AssertFalse(t, result.Valid) + AssertError(t, result.Error, "request URL is nil") +} + +// --- Redis bridge --- + +func TestAX7_NewRedisBridge_Good(t *T) { + addr := ax7StartRedis(t) + hub := NewHub() + bridge, err := NewRedisBridge(hub, RedisConfig{Addr: addr, Prefix: "ws"}) + RequireNoError(t, err) + AssertEqual(t, hub, bridge.hub) + AssertNotEmpty(t, bridge.SourceID()) + AssertNoError(t, bridge.Stop()) +} + +func TestAX7_NewRedisBridge_Bad(t *T) { + bridge, err := NewRedisBridge(nil, RedisConfig{Addr: "127.0.0.1:1"}) + AssertError(t, err, "hub must not be nil") + AssertNil(t, bridge) +} + +func TestAX7_NewRedisBridge_Ugly(t *T) { + bridge, err := NewRedisBridge(NewHub(), RedisConfig{Addr: "127.0.0.1:1", Prefix: "bad prefix"}) + AssertError(t, err, "invalid redis prefix") + AssertNil(t, bridge) +} + +func TestAX7_RedisBridge_Start_Good(t *T) { + addr := ax7StartRedis(t) + bridge, err := NewRedisBridge(NewHub(), RedisConfig{Addr: addr, Prefix: "ws"}) + RequireNoError(t, err) + AssertNoError(t, bridge.Start(Background())) + AssertTrue(t, redisBridgeListening(bridge)) + AssertNoError(t, bridge.Stop()) +} + +func TestAX7_RedisBridge_Start_Bad(t *T) { + var bridge *RedisBridge + err := bridge.Start(Background()) + AssertError(t, err, "bridge must not be nil") + AssertNil(t, bridge) +} + +func TestAX7_RedisBridge_Start_Ugly(t *T) { + bridge := &RedisBridge{prefix: "ws"} + err := bridge.Start(nil) + AssertError(t, err, "redis client is not available") + AssertFalse(t, redisBridgeListening(bridge)) +} + +func TestAX7_RedisBridge_Stop_Good(t *T) { + addr := ax7StartRedis(t) + bridge, err := NewRedisBridge(NewHub(), RedisConfig{Addr: addr, Prefix: "ws"}) + RequireNoError(t, err) + AssertNoError(t, bridge.Stop()) + AssertNil(t, bridge.client) +} + +func TestAX7_RedisBridge_Stop_Bad(t *T) { + var bridge *RedisBridge + err := bridge.Stop() + AssertNoError(t, err) + AssertNil(t, bridge) +} + +func TestAX7_RedisBridge_Stop_Ugly(t *T) { + bridge := &RedisBridge{} + AssertNoError(t, bridge.Stop()) + AssertNoError(t, bridge.Stop()) + AssertNil(t, bridge.client) +} + +func TestAX7_RedisBridge_PublishToChannel_Good(t *T) { + addr := ax7StartRedis(t) + bridge, err := NewRedisBridge(NewHub(), RedisConfig{Addr: addr, Prefix: "ws"}) + RequireNoError(t, err) + bridge.ctx = Background() + AssertNoError(t, bridge.PublishToChannel("agent.dispatch", Message{Type: TypeEvent, Data: "ready"})) + AssertNoError(t, bridge.Stop()) +} + +func TestAX7_RedisBridge_PublishToChannel_Bad(t *T) { + bridge := &RedisBridge{hub: NewHub(), prefix: "ws"} + err := bridge.PublishToChannel(" agent.dispatch", Message{Type: TypeEvent}) + AssertError(t, err, "invalid channel") + AssertEqual(t, 0, bridge.hub.ChannelCount()) +} + +func TestAX7_RedisBridge_PublishToChannel_Ugly(t *T) { + var bridge *RedisBridge + err := bridge.PublishToChannel("agent.dispatch", Message{Type: TypeEvent}) + AssertError(t, err, "bridge must not be nil") + AssertNil(t, bridge) +} + +func TestAX7_RedisBridge_PublishBroadcast_Good(t *T) { + addr := ax7StartRedis(t) + bridge, err := NewRedisBridge(NewHub(), RedisConfig{Addr: addr, Prefix: "ws"}) + RequireNoError(t, err) + bridge.ctx = Background() + AssertNoError(t, bridge.PublishBroadcast(Message{Type: TypeEvent, Data: "ready"})) + AssertNoError(t, bridge.Stop()) +} + +func TestAX7_RedisBridge_PublishBroadcast_Bad(t *T) { + bridge := &RedisBridge{prefix: "ws"} + err := bridge.PublishBroadcast(Message{Type: TypeEvent}) + AssertError(t, err, "hub must not be nil") + AssertNil(t, bridge.hub) +} + +func TestAX7_RedisBridge_PublishBroadcast_Ugly(t *T) { + bridge := &RedisBridge{hub: NewHub(), prefix: "ws"} + err := bridge.PublishBroadcast(Message{Type: TypeProcessOutput, ProcessID: "bad:id"}) + AssertError(t, err, "invalid process ID") + AssertEqual(t, 0, bridge.hub.ChannelCount()) +} + +func TestAX7_RedisBridge_SourceID_Good(t *T) { + bridge := &RedisBridge{sourceID: "source-1"} + sourceID := bridge.SourceID() + AssertEqual(t, "source-1", sourceID) + AssertNotEmpty(t, sourceID) +} + +func TestAX7_RedisBridge_SourceID_Bad(t *T) { + var bridge *RedisBridge + sourceID := bridge.SourceID() + AssertEqual(t, "", sourceID) + AssertNil(t, bridge) +} + +func TestAX7_RedisBridge_SourceID_Ugly(t *T) { + bridge := &RedisBridge{} + sourceID := bridge.SourceID() + AssertEqual(t, "", sourceID) + AssertEmpty(t, sourceID) +} diff --git a/docs/architecture.md b/docs/architecture.md index 0c6c049..b45ac38 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -415,6 +415,6 @@ These all acquire a read lock and return a snapshot. The iterators copy keys und **Local-only subscriber state.** The Redis bridge relays messages but does not share subscription state. `hub.ChannelSubscriberCount` and `hub.Stats` reflect only the local instance. There is no global subscriber registry. Sticky sessions at the load balancer level (IP hash or cookie) are the recommended approach for most deployments. -**Permissive origin check.** The WebSocket upgrader accepts all origins (`CheckOrigin` returns true). This is appropriate for development and internal tooling. Production deployments should add origin validation in the `Authenticator` or behind a reverse proxy. +**Strict origin check by default.** The WebSocket handler rejects cross-origin upgrades unless `HubConfig.CheckOrigin` explicitly allows them. The built-in default requires the `Origin` scheme and host to match the request target, and callbacks are treated as deny-by-default if they panic. Production deployments should keep the default or supply a narrowly scoped override. **Fixed broadcast buffer.** The hub's broadcast channel has capacity 256. High-throughput broadcast workloads can saturate this buffer, causing `hub.Broadcast` to return an error. Callers should handle this and decide whether to drop or queue at the application level. diff --git a/docs/history.md b/docs/history.md index 1fceae6..59de198 100644 --- a/docs/history.md +++ b/docs/history.md @@ -91,7 +91,7 @@ Sticky sessions at the load balancer level (by client IP or cookie) eliminate th ### Origin Check -The WebSocket upgrader is configured with `CheckOrigin: func(*http.Request) bool { return true }`. This accepts connections from any origin, which is appropriate for local development and internal tooling. Production deployments behind a reverse proxy with strict origin control should override the upgrader or add origin validation in an `Authenticator` implementation. +The WebSocket handler applies a same-origin policy by default. Cross-origin connections are rejected unless `HubConfig.CheckOrigin` explicitly opts in to them. The default check requires the `Origin` scheme and host to match the request target, and origin-check callbacks fail closed if they panic. ### Broadcast Buffer diff --git a/docs/index.md b/docs/index.md index 814ebcd..52158b4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -116,7 +116,6 @@ The entire library lives in a single Go package (`ws`). There are no sub-package |---|---|---| | `github.com/gorilla/websocket` | v1.5.3 | WebSocket server and client implementation | | `github.com/redis/go-redis/v9` | v9.18.0 | Redis pub/sub bridge (runtime opt-in) | -| `github.com/stretchr/testify` | v1.11.1 | Test assertions (test-only) | The Redis dependency is a compile-time import but a runtime opt-in. Applications that never create a `RedisBridge` incur no Redis connections. There are no CGO requirements; the module builds cleanly on Linux, macOS, and Windows. diff --git a/errors.go b/errors.go index 3aaaac4..f23d2ab 100644 --- a/errors.go +++ b/errors.go @@ -2,7 +2,7 @@ package ws -import coreerr "dappco.re/go/core/log" +import coreerr "dappco.re/go/log" // Authentication errors returned by the built-in APIKeyAuthenticator. var ( @@ -16,4 +16,16 @@ var ( // ErrInvalidAPIKey is returned when the provided API key does not // match any known key. ErrInvalidAPIKey = coreerr.E("", "invalid API key", nil) + + // ErrMissingUserID is returned when an authentication result marks a + // request as successful but does not provide a user identifier. + ErrMissingUserID = coreerr.E("", "authenticated user ID must not be empty", nil) + + // ErrInvalidAuthClaims is returned when an authentication result carries + // claims that cannot be safely snapshotted. + ErrInvalidAuthClaims = coreerr.E("", "authentication claims are invalid", nil) + + // ErrSubscriptionLimitExceeded is returned when a client exceeds the + // configured per-client subscription cap. + ErrSubscriptionLimitExceeded = coreerr.E("", "subscription limit exceeded", nil) ) diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..2717f01 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,54 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package ws + +import ( + "fmt" + "testing" + + core "dappco.re/go" +) + +func TestErrors_AuthSentinels_Good(t *testing.T) { + tests := []struct { + name string + err error + want string + }{ + {name: "missing header", err: ErrMissingAuthHeader, want: "missing Authorization header"}, + {name: "malformed header", err: ErrMalformedAuthHeader, want: "malformed Authorization header"}, + {name: "invalid api key", err: ErrInvalidAPIKey, want: "invalid API key"}, + {name: "missing user id", err: ErrMissingUserID, want: "authenticated user ID must not be empty"}, + {name: "invalid auth claims", err: ErrInvalidAuthClaims, want: "authentication claims are invalid"}, + {name: "subscription limit exceeded", err: ErrSubscriptionLimitExceeded, want: "subscription limit exceeded"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.err; err == nil || err.Error() != tt.want { + t.Errorf("expected error %q, got %v", tt.want, err) + } + }) + } +} + +func TestErrors_AuthSentinels_Bad(t *testing.T) { + if testEqual(ErrMissingAuthHeader.Error(), ErrMalformedAuthHeader.Error()) { + t.Errorf("expected values to differ: %v", ErrMalformedAuthHeader.Error()) + } + if testEqual(ErrMissingAuthHeader.Error(), ErrInvalidAPIKey.Error()) { + t.Errorf("expected values to differ: %v", ErrInvalidAPIKey.Error()) + } + if testEqual(ErrMalformedAuthHeader.Error(), ErrInvalidAPIKey.Error()) { + t.Errorf("expected values to differ: %v", ErrInvalidAPIKey.Error()) + } + +} + +func TestErrors_AuthSentinels_Ugly(t *testing.T) { + wrapped := fmt.Errorf("auth rejected: %w", ErrMissingAuthHeader) + if !(core.Is(wrapped, ErrMissingAuthHeader)) { + t.Errorf("expected true") + } + +} diff --git a/go.mod b/go.mod index c12c4eb..9ed1257 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,12 @@ -module dappco.re/go/core/ws +module dappco.re/go/ws -go 1.26.0 +go 1.26.2 require ( - dappco.re/go/core v0.8.0-alpha.1 - dappco.re/go/core/log v0.1.0 + dappco.re/go v0.9.0 + dappco.re/go/log v0.8.0-alpha.1 github.com/gorilla/websocket v1.5.3 github.com/redis/go-redis/v9 v9.18.0 - github.com/stretchr/testify v1.11.1 ) require ( @@ -15,10 +14,11 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect - github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/stretchr/testify v1.11.1 // indirect github.com/zeebo/xxh3 v1.1.0 // indirect go.uber.org/atomic v1.11.0 // indirect golang.org/x/sys v0.42.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace dappco.re/go/log => github.com/dappcore/go-log v0.8.0-alpha.1 diff --git a/go.sum b/go.sum index 1c802d3..6757888 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,13 @@ -dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk= -dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A= -dappco.re/go/core/log v0.1.0 h1:pa71Vq2TD2aoEUQWFKwNcaJ3GBY8HbaNGqtE688Unyc= -dappco.re/go/core/log v0.1.0/go.mod h1:Nkqb8gsXhZAO8VLpx7B8i1iAmohhzqA20b9Zr8VUcJs= +dappco.re/go v0.9.0 h1:4ruZRNqKDDva8o6g65tYggjGVe42E6/lMZfVKXtr3p0= +dappco.re/go v0.9.0/go.mod h1:xapr7fLK4/9Pu2iSCr4qZuIuatmtx1j56zS/oPDbGyQ= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/dappcore/go-log v0.8.0-alpha.1 h1:OqZ9Njhz4fr+2BCHOgWxZZcPj/T46jN2UlOCytOCr2Y= +github.com/dappcore/go-log v0.8.0-alpha.1/go.mod h1:IC04Em9SfVTcXiWc1BqZDQfa1MtOuMDEermZkQcTz9c= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= @@ -17,16 +16,10 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= -github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= -github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs= @@ -35,8 +28,5 @@ go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go.work b/go.work new file mode 100644 index 0000000..af6a920 --- /dev/null +++ b/go.work @@ -0,0 +1,7 @@ +go 1.26.2 + +use ( + ./ + ../go + ../go-log +) diff --git a/go.work.sum b/go.work.sum new file mode 100644 index 0000000..f318895 --- /dev/null +++ b/go.work.sum @@ -0,0 +1,4 @@ +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= diff --git a/redis.go b/redis.go index 57a2026..63ea375 100644 --- a/redis.go +++ b/redis.go @@ -4,17 +4,27 @@ package ws import ( "context" - "crypto/rand" + // AX-6-exception: Redis TLS transport config "crypto/tls" - "encoding/hex" + // Note: AX-6 — internal concurrency primitive; structural for go-ws hub state (RFC mandates concurrent connection map). "sync" + "time" - core "dappco.re/go/core" - coreerr "dappco.re/go/core/log" + core "dappco.re/go" + coreerr "dappco.re/go/log" "github.com/redis/go-redis/v9" ) -// RedisConfig configures the Redis pub/sub bridge. +const ( + redisConnectTimeout = 5 * time.Second + redisPublishTimeout = 5 * time.Second + maxRedisEnvelopeBytes = defaultMaxMessageBytes +) + +// RedisConfig configures the Redis connection and channel namespace used by a +// RedisBridge. +// +// bridge, _ := ws.NewRedisBridge(hub, ws.RedisConfig{Addr: "localhost:6379"}) type RedisConfig struct { // Addr is the Redis server address (e.g. "10.69.69.87:6379"). Addr string @@ -40,9 +50,45 @@ type redisEnvelope struct { Message Message `json:"message"` } -// RedisBridge connects a Hub to Redis pub/sub for cross-instance messaging. -// Multiple Hub instances using the same Redis backend will coordinate -// broadcasts and channel messages transparently. +func decodeRedisEnvelope(payload string) (redisEnvelope, bool) { + if len(payload) == 0 || len(payload) > maxRedisEnvelopeBytes { + return redisEnvelope{}, false + } + + var env redisEnvelope + if r := core.JSONUnmarshal([]byte(payload), &env); !r.OK { + return redisEnvelope{}, false + } + + return env, true +} + +// validRedisForwardedMessage rejects forwarded envelopes that carry an invalid +// process identifier. Redis is an external trust boundary, so process IDs are +// re-validated before messages are delivered to the local hub. +func validRedisForwardedMessage(msg Message) bool { + if msg.ProcessID != "" && !validProcessID(msg.ProcessID) { + return false + } + + return true +} + +func validRedisPublishMessage(msg Message) bool { + if msg.ProcessID != "" && !validProcessID(msg.ProcessID) { + return false + } + + return true +} + +func validRedisPrefix(prefix string) bool { + return validIdentifier(prefix, maxChannelNameLen) +} + +// RedisBridge mirrors hub broadcasts and channel messages through Redis pub/sub. +// +// bridge, _ := ws.NewRedisBridge(hub, ws.RedisConfig{Addr: "localhost:6379"}) type RedisBridge struct { hub *Hub client *redis.Client @@ -52,12 +98,13 @@ type RedisBridge struct { ctx context.Context cancel context.CancelFunc wg sync.WaitGroup + mu sync.RWMutex } -// NewRedisBridge creates a Redis bridge for the given Hub. -// It establishes a connection to Redis and validates connectivity -// before returning. The bridge must be started with Start() to -// begin processing messages. +// NewRedisBridge validates Redis connectivity and returns a bridge ready to be +// started with Start. +// +// ws.NewRedisBridge(hub, ws.RedisConfig{Addr: "localhost:6379"}) func NewRedisBridge(hub *Hub, cfg RedisConfig) (*RedisBridge, error) { if hub == nil { return nil, coreerr.E("NewRedisBridge", "hub must not be nil", nil) @@ -68,118 +115,216 @@ func NewRedisBridge(hub *Hub, cfg RedisConfig) (*RedisBridge, error) { if cfg.Prefix == "" { cfg.Prefix = "ws" } + if !validRedisPrefix(cfg.Prefix) { + return nil, coreerr.E("NewRedisBridge", "invalid redis prefix", nil) + } client := redis.NewClient(newRedisOptions(cfg)) // Verify connectivity. - if err := client.Ping(context.Background()).Err(); err != nil { - client.Close() + pingCtx, cancel := context.WithTimeout(context.Background(), redisConnectTimeout) + defer cancel() + if err := client.Ping(pingCtx).Err(); err != nil { + logCloseError("NewRedisBridge.client", client.Close) return nil, coreerr.E("NewRedisBridge", "redis ping failed", err) } // Generate a unique source ID to prevent echo loops. - idBytes := make([]byte, 16) - if _, err := rand.Read(idBytes); err != nil { - client.Close() - return nil, coreerr.E("NewRedisBridge", "failed to generate source ID", err) - } - sourceID := hex.EncodeToString(idBytes) + sourceID := core.ID() - return &RedisBridge{ + bridge := &RedisBridge{ hub: hub, client: client, prefix: cfg.Prefix, sourceID: sourceID, - }, nil + } + + return bridge, nil } func newRedisOptions(cfg RedisConfig) *redis.Options { return &redis.Options{ - Addr: cfg.Addr, - Password: cfg.Password, - DB: cfg.DB, - TLSConfig: cfg.TLSConfig, + Addr: cfg.Addr, + Password: cfg.Password, + DB: cfg.DB, + TLSConfig: cfg.TLSConfig, + DialTimeout: redisConnectTimeout, + ReadTimeout: redisConnectTimeout, + WriteTimeout: redisConnectTimeout, + PoolTimeout: redisConnectTimeout, } } -// Start begins listening for Redis messages and forwarding them to -// the local Hub's clients. It subscribes to the broadcast channel -// and uses pattern-subscribe for all channel-targeted messages. -// The bridge runs until Stop() is called or the provided context -// is cancelled. +// Start subscribes the bridge to Redis pub/sub channels and launches the +// listener goroutine. Calling Start again replaces the active listener. +// +// err := bridge.Start(ctx) func (rb *RedisBridge) Start(ctx context.Context) error { - rb.ctx, rb.cancel = context.WithCancel(ctx) + if rb == nil { + return coreerr.E("RedisBridge.Start", "bridge must not be nil", nil) + } + + if ctx == nil { + ctx = context.Background() + } + + if err := rb.stopListener(); err != nil { + return err + } + + rb.mu.RLock() + client := rb.client + prefix := rb.prefix + rb.mu.RUnlock() + if client == nil { + return coreerr.E("RedisBridge.Start", "redis client is not available", nil) + } + if !validRedisPrefix(prefix) { + return coreerr.E("RedisBridge.Start", "invalid redis prefix", nil) + } - broadcastChan := rb.prefix + ":broadcast" - channelPattern := rb.prefix + ":channel:*" + runCtx, cancel := context.WithCancel(ctx) - rb.pubsub = rb.client.PSubscribe(rb.ctx, broadcastChan, channelPattern) + broadcastChan := prefix + ":broadcast" + channelPattern := prefix + ":channel:*" + + pubsub := client.PSubscribe(runCtx, broadcastChan, channelPattern) // Wait for the subscription confirmation. - _, err := rb.pubsub.Receive(rb.ctx) + receiveCtx, receiveCancel := context.WithTimeout(runCtx, redisConnectTimeout) + defer receiveCancel() + _, err := pubsub.Receive(receiveCtx) if err != nil { - rb.pubsub.Close() + cancel() + logCloseError("RedisBridge.Start.pubsub", pubsub.Close) return coreerr.E("RedisBridge.Start", "redis subscribe failed", err) } + rb.mu.Lock() + rb.ctx = runCtx + rb.cancel = cancel + rb.pubsub = pubsub + rb.mu.Unlock() + rb.wg.Add(1) - go rb.listen() + go rb.listen(runCtx, pubsub, prefix) return nil } -// Stop cleanly shuts down the Redis bridge. It cancels the listener -// goroutine, closes the pub/sub subscription, and closes the Redis -// client connection. +// Stop closes the Redis listener and client held by the bridge. +// +// defer bridge.Stop() func (rb *RedisBridge) Stop() error { - if rb.cancel != nil { - rb.cancel() + if rb == nil { + return nil } - // Wait for the listener goroutine to exit. - rb.wg.Wait() - var firstErr error - if rb.pubsub != nil { - if err := rb.pubsub.Close(); err != nil && firstErr == nil { - firstErr = err - } + if err := rb.stopListener(); err != nil { + firstErr = err } - if rb.client != nil { - if err := rb.client.Close(); err != nil && firstErr == nil { + + rb.mu.Lock() + client := rb.client + rb.client = nil + rb.mu.Unlock() + if client != nil { + if err := client.Close(); err != nil && firstErr == nil { firstErr = err } } + return firstErr } -// PublishToChannel publishes a message to a specific channel via Redis. -// Other bridge instances subscribed to the same Redis will receive the -// message and deliver it to their local Hub clients on that channel. +// PublishToChannel sends a message to local subscribers and publishes it to the +// Redis channel for the named hub channel. +// +// err := bridge.PublishToChannel("notifications", ws.Message{Type: ws.TypeEvent, Data: "ready"}) func (rb *RedisBridge) PublishToChannel(channel string, msg Message) error { + if rb == nil { + return coreerr.E("RedisBridge.PublishToChannel", "bridge must not be nil", nil) + } + + if err := validateChannelTarget("RedisBridge.PublishToChannel", channel); err != nil { + return err + } + + if rb.hub == nil { + return coreerr.E("RedisBridge.PublishToChannel", "hub must not be nil", nil) + } + + msg = stampServerMessage(msg) + if !validRedisPublishMessage(msg) { + return coreerr.E("RedisBridge.PublishToChannel", "invalid process ID", nil) + } + redisChan := rb.prefix + ":channel:" + channel + if err := rb.hub.sendToChannelMessage(channel, msg, true); err != nil { + return err + } + return rb.publish(redisChan, msg) } -// PublishBroadcast publishes a broadcast message via Redis. All bridge -// instances will receive it and deliver to all their local Hub clients. +// PublishBroadcast sends a message to local clients and publishes it to the +// Redis broadcast channel. +// +// err := bridge.PublishBroadcast(ws.Message{Type: ws.TypeEvent, Data: "ready"}) func (rb *RedisBridge) PublishBroadcast(msg Message) error { + if rb == nil { + return coreerr.E("RedisBridge.PublishBroadcast", "bridge must not be nil", nil) + } + if rb.hub == nil { + return coreerr.E("RedisBridge.PublishBroadcast", "hub must not be nil", nil) + } + + msg = stampServerMessage(msg) + if !validRedisPublishMessage(msg) { + return coreerr.E("RedisBridge.PublishBroadcast", "invalid process ID", nil) + } + + localErr := rb.hub.broadcastMessage(msg, true) redisChan := rb.prefix + ":broadcast" - return rb.publish(redisChan, msg) + redisErr := rb.publish(redisChan, msg) + + if localErr != nil && redisErr != nil { + return coreerr.E("RedisBridge.PublishBroadcast", core.Sprintf("local: %v; redis: %v", localErr, redisErr), redisErr) + } + if redisErr != nil { + return redisErr + } + + return localErr } // publish serialises the envelope and publishes to the given Redis channel. func (rb *RedisBridge) publish(redisChan string, msg Message) error { - if rb.ctx == nil { + if rb == nil { + return coreerr.E("RedisBridge.publish", "bridge must not be nil", nil) + } + + rb.mu.RLock() + ctx := rb.ctx + client := rb.client + sourceID := rb.sourceID + rb.mu.RUnlock() + + if ctx == nil { return coreerr.E("RedisBridge.publish", "bridge has not been started", nil) } - if rb.client == nil { + if client == nil { return coreerr.E("RedisBridge.publish", "redis client is not available", nil) } + if !validRedisPublishMessage(msg) { + return coreerr.E("RedisBridge.publish", "invalid process ID", nil) + } + env := redisEnvelope{ - SourceID: rb.sourceID, + SourceID: sourceID, Message: msg, } @@ -188,31 +333,38 @@ func (rb *RedisBridge) publish(redisChan string, msg Message) error { return coreerr.E("RedisBridge.publish", "failed to marshal redis envelope", nil) } - return rb.client.Publish(rb.ctx, redisChan, r.Value.([]byte)).Err() + if !validRedisPrefix(rb.prefix) { + return coreerr.E("RedisBridge.publish", "invalid redis prefix", nil) + } + + publishCtx, cancel := context.WithTimeout(ctx, redisPublishTimeout) + defer cancel() + + return client.Publish(publishCtx, redisChan, r.Value.([]byte)).Err() } // listen runs in a goroutine, reading messages from the Redis pub/sub // channel and forwarding them to the local Hub. Messages originating // from this bridge instance (matching sourceID) are silently dropped // to prevent infinite loops. -func (rb *RedisBridge) listen() { +func (rb *RedisBridge) listen(ctx context.Context, pubsub *redis.PubSub, prefix string) { defer rb.wg.Done() - ch := rb.pubsub.Channel() - broadcastChan := rb.prefix + ":broadcast" - channelPrefix := rb.prefix + ":channel:" + ch := pubsub.Channel() + broadcastChan := prefix + ":broadcast" + channelPrefix := prefix + ":channel:" for { select { - case <-rb.ctx.Done(): + case <-ctx.Done(): return case redisMsg, ok := <-ch: if !ok { return } - var env redisEnvelope - if r := core.JSONUnmarshal([]byte(redisMsg.Payload), &env); !r.OK { + env, ok := decodeRedisEnvelope(redisMsg.Payload) + if !ok { // Skip malformed messages. continue } @@ -222,22 +374,68 @@ func (rb *RedisBridge) listen() { continue } + if !validRedisForwardedMessage(env.Message) { + continue + } + switch { case redisMsg.Channel == broadcastChan: + if rb.hub == nil { + continue + } // Deliver as a local broadcast. - _ = rb.hub.Broadcast(env.Message) + if err := rb.hub.broadcastMessage(env.Message, true); err != nil { + coreerr.Warn("failed to forward redis broadcast", "op", "RedisBridge.listen", "err", err) + } case core.HasPrefix(redisMsg.Channel, channelPrefix): + if rb.hub == nil { + continue + } // Extract the Hub channel name from the Redis channel. hubChannel := core.TrimPrefix(redisMsg.Channel, channelPrefix) - _ = rb.hub.SendToChannel(hubChannel, env.Message) + if validateChannelTarget("RedisBridge.listen", hubChannel) != nil { + continue + } + if err := rb.hub.sendToChannelMessage(hubChannel, env.Message, true); err != nil { + coreerr.Warn("failed to forward redis channel message", "op", "RedisBridge.listen", "err", err) + } } } } } -// SourceID returns the unique identifier for this bridge instance. -// Useful for testing and debugging. +func (rb *RedisBridge) stopListener() error { + rb.mu.Lock() + cancel := rb.cancel + pubsub := rb.pubsub + rb.cancel = nil + rb.pubsub = nil + rb.ctx = nil + rb.mu.Unlock() + + if cancel != nil { + cancel() + } + + var err error + if pubsub != nil { + err = pubsub.Close() + } + + rb.wg.Wait() + + return err +} + +// SourceID returns the bridge instance ID used to suppress self-echoed Redis +// messages. +// +// sourceID := bridge.SourceID() func (rb *RedisBridge) SourceID() string { + if rb == nil { + return "" + } + return rb.sourceID } diff --git a/redis_test.go b/redis_test.go index ae74bed..890d5fb 100644 --- a/redis_test.go +++ b/redis_test.go @@ -5,14 +5,14 @@ package ws import ( "context" "crypto/tls" + "strings" + // Note: AX-6 — internal concurrency primitive; structural for go-ws hub state (RFC mandates concurrent connection map). "sync" "testing" "time" - core "dappco.re/go/core" + core "dappco.re/go" "github.com/redis/go-redis/v9" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const redisAddr = "10.69.69.87:6379" @@ -46,10 +46,21 @@ func cleanupRedis(t *testing.T, client *redis.Client, prefix string) { for iter.Next(ctx) { client.Del(ctx, iter.Val()) } - client.Close() + _ = client.Close() }) } +func redisBridgeListening(bridge *RedisBridge) bool { + if bridge == nil { + return false + } + + bridge.mu.RLock() + defer bridge.mu.RUnlock() + + return bridge.ctx != nil && bridge.pubsub != nil +} + // startTestHub creates a Hub, starts it, and returns cleanup resources. func startTestHub(t *testing.T) (*Hub, context.Context, context.CancelFunc) { t.Helper() @@ -75,17 +86,26 @@ func TestRedisBridge_CreateAndLifecycle(t *testing.T) { Addr: redisAddr, Prefix: prefix, }) - require.NoError(t, err) - require.NotNil(t, bridge) - assert.NotEmpty(t, bridge.SourceID(), "bridge should have a unique source ID") + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if testIsNil(bridge) { + t.Fatalf("expected non-nil value") + } + if testIsEmpty(bridge.SourceID()) { + t.Errorf("expected non-empty value") + } - // Start the bridge. err = bridge.Start(context.Background()) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - // Stop the bridge. err = bridge.Stop() - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + } func TestRedisBridge_NilHub(t *testing.T) { @@ -94,8 +114,13 @@ func TestRedisBridge_NilHub(t *testing.T) { _, err := NewRedisBridge(nil, RedisConfig{ Addr: redisAddr, }) - require.Error(t, err) - assert.Contains(t, err.Error(), "hub must not be nil") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "hub must not be nil") + } + } func TestRedisBridge_EmptyAddr(t *testing.T) { @@ -104,8 +129,13 @@ func TestRedisBridge_EmptyAddr(t *testing.T) { _, err := NewRedisBridge(hub, RedisConfig{ Addr: "", }) - require.Error(t, err) - assert.Contains(t, err.Error(), "redis address must not be empty") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "redis address must not be empty") { + t.Errorf("expected %v to contain %v", err.Error(), "redis address must not be empty") + } + } func TestRedisBridge_BadAddr(t *testing.T) { @@ -114,8 +144,49 @@ func TestRedisBridge_BadAddr(t *testing.T) { _, err := NewRedisBridge(hub, RedisConfig{ Addr: "127.0.0.1:1", // Nothing listening here. }) - require.Error(t, err) - assert.Contains(t, err.Error(), "redis ping failed") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "redis ping failed") { + t.Errorf("expected %v to contain %v", err.Error(), "redis ping failed") + } + +} + +func TestRedisBridge_InvalidPrefix_Ugly(t *testing.T) { + hub := NewHub() + + _, err := NewRedisBridge(hub, RedisConfig{ + Addr: redisAddr, + Prefix: "bad prefix", + }) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid redis prefix") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid redis prefix") + } + +} + +func TestRedisBridge_NewRedisBridge_SourceIDFailure_Ugly(t *testing.T) { + bridge, err := NewRedisBridge(NewHub(), RedisConfig{}) + if err == nil { + t.Fatalf("expected error") + } + if !testIsNil(bridge) { + t.Fatalf("expected nil bridge when validation fails") + } +} + +func TestRedisBridge_NewRedisBridge_StartFailure_Ugly(t *testing.T) { + bridge, err := NewRedisBridge(NewHub(), RedisConfig{Addr: "", Prefix: "ws"}) + if err == nil { + t.Fatalf("expected error") + } + if !testIsNil(bridge) { + t.Fatalf("expected nil bridge when Redis address is empty") + } } func TestRedisBridge_DefaultPrefix(t *testing.T) { @@ -127,12 +198,19 @@ func TestRedisBridge_DefaultPrefix(t *testing.T) { bridge, err := NewRedisBridge(hub, RedisConfig{ Addr: redisAddr, }) - require.NoError(t, err) - assert.Equal(t, "ws", bridge.prefix) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual("ws", bridge.prefix) { + t.Errorf("expected %v, got %v", "ws", bridge.prefix) + } err = bridge.Start(context.Background()) - require.NoError(t, err) - defer bridge.Stop() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge.Stop) } func TestRedisBridge_TLSConfig(t *testing.T) { @@ -146,148 +224,1014 @@ func TestRedisBridge_TLSConfig(t *testing.T) { DB: 4, TLSConfig: tlsConfig, }) + if !testEqual("redis.example:6380", options.Addr) { + t.Errorf("expected %v, got %v", "redis.example:6380", options.Addr) + } + if !testEqual("secret", options.Password) { + t.Errorf("expected %v, got %v", "secret", options.Password) + } + if !testEqual(4, options.DB) { + t.Errorf("expected %v, got %v", 4, options.DB) + } + if !testSame(tlsConfig, options.TLSConfig) { + t.Errorf("expected same reference") + } + +} + +func TestRedisBridge_newRedisOptions_Good(t *testing.T) { + options := newRedisOptions(RedisConfig{ + Addr: "redis.example:6379", + }) + if !testEqual("redis.example:6379", options.Addr) { + t.Errorf("expected %v, got %v", "redis.example:6379", options.Addr) + } + if !testEqual(redisConnectTimeout, options.DialTimeout) { + t.Errorf("expected %v, got %v", redisConnectTimeout, options.DialTimeout) + } + if !testEqual(redisConnectTimeout, options.ReadTimeout) { + t.Errorf("expected %v, got %v", redisConnectTimeout, options.ReadTimeout) + } + if !testEqual(redisConnectTimeout, options.WriteTimeout) { + t.Errorf("expected %v, got %v", redisConnectTimeout, options.WriteTimeout) + } + if !testEqual(redisConnectTimeout, options.PoolTimeout) { + t.Errorf("expected %v, got %v", redisConnectTimeout, options.PoolTimeout) + } + +} + +func TestRedisBridge_validRedisForwardedMessage(t *testing.T) { + t.Run("accepts messages without a process ID", func(t *testing.T) { + if !(validRedisForwardedMessage(Message{Type: TypeEvent, Data: "hello"})) { + t.Errorf("expected true") + } + + }) + + t.Run("rejects invalid process IDs on forwarded messages", func(t *testing.T) { + if validRedisForwardedMessage(Message{Type: TypeProcessOutput, ProcessID: "bad process", Data: "line"}) { + t.Errorf("expected false") + } + + }) + + t.Run("rejects invalid process IDs even on generic messages", func(t *testing.T) { + if validRedisForwardedMessage(Message{Type: TypeEvent, ProcessID: "bad process", Data: "payload"}) { + t.Errorf("expected false") + } + + }) +} + +func TestRedisBridge_validRedisPrefix_Good(t *testing.T) { + if !(validRedisPrefix("ws")) { + t.Errorf("expected true") + } + if !(validRedisPrefix("my_app-1:prod")) { + t.Errorf("expected true") + } + +} + +func TestRedisBridge_validRedisPrefix_Bad(t *testing.T) { + tests := []string{ + "", + "bad prefix", + strings.Repeat("a", maxChannelNameLen+1), + } + + for _, prefix := range tests { + if validRedisPrefix(prefix) { + t.Errorf("expected false") + } + + } +} + +func TestRedisBridge_validRedisPrefix_Ugly(t *testing.T) { + if validRedisPrefix(" ws ") { + t.Errorf("expected false") + } + +} + +func TestRedisBridge_Start_Bad(t *testing.T) { + bridge := &RedisBridge{} + + err := bridge.Start(context.Background()) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "redis client is not available") { + t.Errorf("expected %v to contain %v", err.Error(), "redis client is not available") + } + +} + +func TestRedisBridge_Start_InvalidPrefix_Bad(t *testing.T) { + bridge := &RedisBridge{ + client: redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}), + prefix: "bad prefix", + } + defer testClose(t, bridge.client.Close) + + err := bridge.Start(context.Background()) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid redis prefix") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid redis prefix") + } + +} + +func TestRedisBridge_Start_ClosedClient_Bad(t *testing.T) { + hub := NewHub() + client := redis.NewClient(&redis.Options{Addr: redisAddr}) + if err := client.Close(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + bridge := &RedisBridge{ + hub: hub, + client: client, + prefix: "ws", + } + + err := bridge.Start(context.Background()) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "redis subscribe failed") { + t.Errorf("expected %v to contain %v", err.Error(), "redis subscribe failed") + } + +} + +// --------------------------------------------------------------------------- +// PublishBroadcast — messages reach local WebSocket clients +// --------------------------------------------------------------------------- + +func TestRedisBridge_PublishBroadcast(t *testing.T) { + rc := skipIfNoRedis(t) + prefix := testPrefix(t) + cleanupRedis(t, rc, prefix) + + hub, _, _ := startTestHub(t) + + // Register a local client. + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + hub.register <- client + time.Sleep(50 * time.Millisecond) + if !testEqual(1, hub.ClientCount()) { + t.Fatalf("expected %v, got %v", 1, hub.ClientCount()) + } + + // Create two bridges on same Redis; bridge1 publishes and bridge2 receives. + bridge1, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err = bridge1.Start(context.Background()) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge1.Stop) + + // A second hub and bridge receive the cross-instance message. + hub2, _, _ := startTestHub(t) + client2 := &Client{ + hub: hub2, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + hub2.register <- client2 + time.Sleep(50 * time.Millisecond) + + bridge2, err := NewRedisBridge(hub2, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err = bridge2.Start(context.Background()) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge2.Stop) + + time.Sleep(100 * time.Millisecond) + + // Publish broadcast from bridge1. + err = bridge1.PublishBroadcast(Message{Type: TypeEvent, Data: "cross-broadcast"}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // bridge1's local hub should also receive the message. + select { + case msg := <-client.send: + var received Message + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeEvent, received.Type) { + t.Errorf("expected %v, got %v", TypeEvent, received.Type) + } + if !testEqual("cross-broadcast", received.Data) { + t.Errorf("expected %v, got %v", "cross-broadcast", received.Data) + } + + case <-time.After(3 * time.Second): + t.Fatal("bridge1 client should have received the local broadcast") + } + + // bridge2's hub should receive the message. + select { + case msg := <-client2.send: + var received Message + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeEvent, received.Type) { + t.Errorf("expected %v, got %v", TypeEvent, received.Type) + } + if !testEqual("cross-broadcast", received.Data) { + t.Errorf("expected %v, got %v", "cross-broadcast", received.Data) + } + + case <-time.After(3 * time.Second): + t.Fatal("bridge2 client should have received the broadcast") + } +} + +// --------------------------------------------------------------------------- +// PublishToChannel — targeted channel delivery +// --------------------------------------------------------------------------- + +func TestRedisBridge_PublishToChannel(t *testing.T) { + rc := skipIfNoRedis(t) + prefix := testPrefix(t) + cleanupRedis(t, rc, prefix) + + hub, _, _ := startTestHub(t) + + // Create a client subscribed to a specific channel. + subClient := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + hub.register <- subClient + time.Sleep(50 * time.Millisecond) + if err := hub.Subscribe(subClient, "process:abc"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // Create a client NOT subscribed to that channel. + otherClient := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + hub.register <- otherClient + time.Sleep(50 * time.Millisecond) + + // Second hub + bridge (the publisher). + hub2, _, _ := startTestHub(t) + bridge2, err := NewRedisBridge(hub2, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err = bridge2.Start(context.Background()) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge2.Stop) + + // Local hub bridge receives cross-instance channel messages. + bridge1, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err = bridge1.Start(context.Background()) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge1.Stop) + + time.Sleep(100 * time.Millisecond) + + // Publish to channel from bridge2. + err = bridge2.PublishToChannel("process:abc", Message{ + Type: TypeProcessOutput, + ProcessID: "abc", + Data: "line of output", + }) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // subClient is subscribed to process:abc, so it should receive the message. + select { + case msg := <-subClient.send: + var received Message + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeProcessOutput, received.Type) { + t.Errorf("expected %v, got %v", TypeProcessOutput, received.Type) + } + if !testEqual("line of output", received.Data) { + t.Errorf("expected %v, got %v", "line of output", received.Data) + } + + case <-time.After(3 * time.Second): + t.Fatal("subscribed client should have received the channel message") + } + + // otherClient should not receive the channel message. + select { + case msg := <-otherClient.send: + t.Fatalf("unsubscribed client should not receive channel message, got: %s", msg) + case <-time.After(300 * time.Millisecond): + // Good — no message delivered. + } +} + +func TestRedisBridge_PublishToChannel_Bad(t *testing.T) { + bridge := &RedisBridge{prefix: "ws"} + + err := bridge.PublishToChannel("bad channel", Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid channel name") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid channel name") + } + + t.Run("rejects process channels with oversized IDs", func(t *testing.T) { + err := bridge.PublishToChannel("process:"+strings.Repeat("a", maxProcessIDLen+1), Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } + + }) + + t.Run("rejects invalid process IDs", func(t *testing.T) { + hub := NewHub() + bridge := &RedisBridge{ + hub: hub, + client: redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}), + ctx: context.Background(), + prefix: "ws", + } + defer testClose(t, bridge.client.Close) + + err := bridge.PublishToChannel("valid-channel", Message{ + Type: TypeProcessOutput, + ProcessID: "bad process", + Data: "payload", + }) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } + + }) + +} + +func TestRedisBridge_PublishToChannel_Ugly_NilHub(t *testing.T) { + bridge := &RedisBridge{prefix: "ws"} + + err := bridge.PublishToChannel("valid-channel", Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "hub must not be nil") + } + +} + +func TestRedisBridge_PublishToChannel_HubMarshalError_Bad(t *testing.T) { + hub := NewHub() + bridge := &RedisBridge{ + hub: hub, + prefix: "ws", + } + + err := bridge.PublishToChannel("valid-channel", Message{Type: TypeEvent, Data: make(chan int)}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "failed to marshal message") { + t.Errorf("expected %v to contain %v", err.Error(), "failed to marshal message") + } + +} + +func TestRedisBridge_PublishToChannel_Ugly(t *testing.T) { + var bridge *RedisBridge + + err := bridge.PublishToChannel("valid-channel", Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "bridge must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "bridge must not be nil") + } + +} + +func TestRedisBridge_PublishBroadcast_Bad(t *testing.T) { + var bridge *RedisBridge + + err := bridge.PublishBroadcast(Message{Type: TypeEvent, Data: "noop"}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "bridge must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "bridge must not be nil") + } + + t.Run("rejects invalid process IDs", func(t *testing.T) { + hub := NewHub() + bridge := &RedisBridge{ + hub: hub, + client: redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}), + ctx: context.Background(), + prefix: "ws", + } + defer testClose(t, bridge.client.Close) + + err := bridge.PublishBroadcast(Message{ + Type: TypeProcessStatus, + ProcessID: "bad process", + Data: "payload", + }) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } + + }) + + t.Run("preserves local and redis failures", func(t *testing.T) { + hub := NewHub() + bridge := &RedisBridge{ + hub: hub, + client: redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}), + ctx: context.Background(), + prefix: "ws", + } + defer testClose(t, bridge.client.Close) + + err := bridge.PublishBroadcast(Message{Type: TypeEvent, Data: make(chan int)}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "local:") { + t.Errorf("expected %v to contain %v", err.Error(), "local:") + } + if !testContains(err.Error(), "redis:") { + t.Errorf("expected %v to contain %v", err.Error(), "redis:") + } + if !testContains(err.Error(), "failed to marshal message") { + t.Errorf("expected %v to contain %v", err.Error(), "failed to marshal message") + } + if !testContains(err.Error(), "failed to marshal redis envelope") { + t.Errorf("expected %v to contain %v", err.Error(), "failed to marshal redis envelope") + } + + }) +} + +func TestRedisBridge_PublishBroadcast_Ugly(t *testing.T) { + bridge := &RedisBridge{ + prefix: "ws", + } + + err := bridge.PublishBroadcast(Message{Type: TypeEvent, Data: "noop"}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "hub must not be nil") + } + +} + +func TestRedisBridge_SourceID_Good(t *testing.T) { + bridge := &RedisBridge{sourceID: "source-123"} + if !testEqual("source-123", bridge.SourceID()) { + t.Errorf("expected %v, got %v", "source-123", bridge.SourceID()) + } + +} + +func TestRedisBridge_SourceID_Bad(t *testing.T) { + var bridge *RedisBridge + if !testIsEmpty(bridge.SourceID()) { + t.Errorf("expected empty value, got %v", bridge.SourceID()) + } + +} + +func TestRedisBridge_SourceID_Ugly(t *testing.T) { + bridge := &RedisBridge{} + if !testIsEmpty(bridge.SourceID()) { + t.Errorf("expected empty value, got %v", bridge.SourceID()) + } + +} + +func TestRedisBridge_Start_Good(t *testing.T) { + t.Run("starts and stops", func(t *testing.T) { + rc := skipIfNoRedis(t) + prefix := testPrefix(t) + cleanupRedis(t, rc, prefix) + + hub, _, _ := startTestHub(t) + + bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err = bridge.Start(context.TODO()) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if testIsNil(bridge.ctx) { + t.Fatalf("expected non-nil value") + } + if testIsNil(bridge.cancel) { + t.Fatalf("expected non-nil value") + } + if testIsNil(bridge.pubsub) { + t.Fatalf("expected non-nil value") + } + if err := bridge.Stop(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + }) + + t.Run("replaces an existing listener when restarted", func(t *testing.T) { + rc := skipIfNoRedis(t) + prefix := testPrefix(t) + cleanupRedis(t, rc, prefix) + + hub, _, _ := startTestHub(t) + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + hub.register <- client + time.Sleep(50 * time.Millisecond) + + bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge.Stop) + + ctx1, cancel1 := context.WithCancel(context.Background()) + if err := bridge.Start(ctx1); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + ctx2, cancel2 := context.WithCancel(context.Background()) + if err := bridge.Start(ctx2); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + cancel1() + + env := redisEnvelope{ + SourceID: "external-source", + Message: Message{ + Type: TypeEvent, + Data: "listener-restart", + }, + } + raw := mustMarshal(env) + if testIsNil(raw) { + t.Fatalf("expected non-nil value") + } + if err := rc.Publish(context.Background(), prefix+":broadcast", raw).Err(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case msg := <-client.send: + var received Message + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual("listener-restart", received.Data) { + t.Errorf("expected %v, got %v", "listener-restart", received.Data) + } + + case <-time.After(3 * time.Second): + t.Fatal("bridge should keep listening after being restarted with a new context") + } + + cancel2() + }) +} + +func TestRedisBridge_Start_NilReceiver_Bad(t *testing.T) { + var bridge *RedisBridge + + err := bridge.Start(context.Background()) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "bridge must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "bridge must not be nil") + } + +} + +func TestRedisBridge_Start_Ugly(t *testing.T) { + bridge := &RedisBridge{} + + err := bridge.Start(context.Background()) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "redis client is not available") { + t.Errorf("expected %v to contain %v", err.Error(), "redis client is not available") + } + +} + +func TestRedisBridge_Stop_Ugly(t *testing.T) { + if err := (*RedisBridge)(nil).Stop(); err != nil { + t.Errorf("expected no error, got %v", err) + } + +} + +func TestRedisBridge_Stop_ZeroValue_Good(t *testing.T) { + bridge := &RedisBridge{} + if err := bridge.Stop(); err != nil { + t.Errorf("expected no error, got %v", err) + } + +} + +func TestRedisBridge_Stop_Good(t *testing.T) { + rc := skipIfNoRedis(t) + prefix := testPrefix(t) + cleanupRedis(t, rc, prefix) + + hub, _, _ := startTestHub(t) + + bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if err := bridge.Start(context.Background()); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if err := bridge.Stop(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + +} + +func TestRedisBridge_MalformedInboundPayload_Ugly(t *testing.T) { + rc := skipIfNoRedis(t) + prefix := testPrefix(t) + cleanupRedis(t, rc, prefix) + + hub, _, _ := startTestHub(t) + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + hub.register <- client + time.Sleep(50 * time.Millisecond) + + bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err = bridge.Start(context.Background()) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge.Stop) + + err = rc.Publish(context.Background(), prefix+":broadcast", []byte("not-json")).Err() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case msg := <-client.send: + t.Fatalf("malformed inbound payload should not be forwarded, got: %s", msg) + case <-time.After(300 * time.Millisecond): + // Good - listener skipped the malformed payload. + } +} + +func TestRedisBridge_listen_NilHubAndClosedChannel_Good(t *testing.T) { + rc := skipIfNoRedis(t) + prefix := testPrefix(t) + cleanupRedis(t, rc, prefix) + + pubsub := rc.PSubscribe(context.Background(), prefix+":broadcast", prefix+":channel:*") + receiveCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := pubsub.Receive(receiveCtx) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + bridge := &RedisBridge{ + sourceID: "listener-source", + } + + bridge.wg.Add(1) + done := make(chan struct{}) + go func() { + bridge.listen(context.Background(), pubsub, prefix) + close(done) + }() + + broadcast := mustMarshal(redisEnvelope{ + SourceID: "external-broadcast", + Message: Message{ + Type: TypeEvent, + Data: "broadcast", + }, + }) + if testIsNil(broadcast) { + t.Fatalf("expected non-nil value") + } + if err := rc.Publish(context.Background(), prefix+":broadcast", broadcast).Err(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + channelMsg := mustMarshal(redisEnvelope{ + SourceID: "external-channel", + Message: Message{ + Type: TypeEvent, + Channel: "target", + Data: "channel", + }, + }) + if testIsNil(channelMsg) { + t.Fatalf("expected non-nil value") + } + if err := rc.Publish(context.Background(), prefix+":channel:target", channelMsg).Err(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + time.Sleep(50 * time.Millisecond) + if err := pubsub.Close(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("listener should stop when the pubsub channel closes") + } + + bridge.wg.Wait() +} + +func TestRedisBridge_DecodeRedisEnvelope_SizeLimit(t *testing.T) { + largePayload := strings.Repeat("A", maxRedisEnvelopeBytes+1) + + _, ok := decodeRedisEnvelope(largePayload) + if ok { + t.Errorf("expected false") + } + +} + +func TestRedisBridge_DecodeRedisEnvelope_Good(t *testing.T) { + payload := core.Sprintf(`{"sourceId":"%s","message":{"type":"event","timestamp":"2024-01-01T00:00:00Z"}}`, "source-123") + + env, ok := decodeRedisEnvelope(payload) + if !(ok) { + t.Fatalf("expected true") + } + if !testEqual("source-123", env.SourceID) { + t.Errorf("expected %v, got %v", "source-123", env.SourceID) + } + if !testEqual(TypeEvent, env.Message.Type) { + t.Errorf("expected %v, got %v", TypeEvent, env.Message.Type) + } + +} + +func TestRedisBridge_publish_Good(t *testing.T) { + rc := skipIfNoRedis(t) + prefix := testPrefix(t) + cleanupRedis(t, rc, prefix) + + hub, _, _ := startTestHub(t) + + bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err = bridge.Start(context.Background()) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge.Stop) + + err = bridge.publish(prefix+":broadcast", Message{Type: TypeEvent, Data: "publish-ok"}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + +} + +func TestRedisBridge_publish_Bad(t *testing.T) { + bridge := &RedisBridge{ + client: redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}), + ctx: context.Background(), + } + defer testClose(t, bridge.client.Close) + + err := bridge.publish("ws:broadcast", Message{Type: TypeEvent, Data: make(chan int)}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "failed to marshal redis envelope") { + t.Errorf("expected %v to contain %v", err.Error(), "failed to marshal redis envelope") + } + +} + +func TestRedisBridge_publish_InvalidProcessID_Bad(t *testing.T) { + bridge := &RedisBridge{ + client: redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}), + ctx: context.Background(), + } + defer testClose(t, bridge.client.Close) + + err := bridge.publish("ws:broadcast", Message{ + Type: TypeProcessOutput, + ProcessID: "bad process", + Data: "payload", + }) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } - assert.Equal(t, "redis.example:6380", options.Addr) - assert.Equal(t, "secret", options.Password) - assert.Equal(t, 4, options.DB) - assert.Same(t, tlsConfig, options.TLSConfig) } -// --------------------------------------------------------------------------- -// PublishBroadcast — messages reach local WebSocket clients -// --------------------------------------------------------------------------- +func TestRedisBridge_publish_Ugly(t *testing.T) { + t.Run("nil receiver", func(t *testing.T) { + var bridge *RedisBridge -func TestRedisBridge_PublishBroadcast(t *testing.T) { - rc := skipIfNoRedis(t) - prefix := testPrefix(t) - cleanupRedis(t, rc, prefix) + err := bridge.publish("ws:broadcast", Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "bridge must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "bridge must not be nil") + } - hub, _, _ := startTestHub(t) + }) - // Register a local client. - client := &Client{ - hub: hub, - send: make(chan []byte, 256), - subscriptions: make(map[string]bool), - } - hub.register <- client - time.Sleep(50 * time.Millisecond) - require.Equal(t, 1, hub.ClientCount()) + t.Run("missing context", func(t *testing.T) { + bridge := &RedisBridge{ + client: redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}), + } + defer testClose(t, bridge.client.Close) - // Create two bridges on same Redis — bridge1 publishes, bridge2 receives. - bridge1, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) - err = bridge1.Start(context.Background()) - require.NoError(t, err) - defer bridge1.Stop() + err := bridge.publish("ws:broadcast", Message{Type: TypeEvent, Data: "payload"}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "bridge has not been started") { + t.Errorf("expected %v to contain %v", err.Error(), "bridge has not been started") + } - // A second hub + bridge to receive the cross-instance message. - hub2, _, _ := startTestHub(t) - client2 := &Client{ - hub: hub2, - send: make(chan []byte, 256), - subscriptions: make(map[string]bool), - } - hub2.register <- client2 - time.Sleep(50 * time.Millisecond) + }) - bridge2, err := NewRedisBridge(hub2, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) - err = bridge2.Start(context.Background()) - require.NoError(t, err) - defer bridge2.Stop() + t.Run("missing client", func(t *testing.T) { + bridge := &RedisBridge{ctx: context.Background()} - // Allow subscriptions to propagate. - time.Sleep(100 * time.Millisecond) + err := bridge.publish("ws:broadcast", Message{Type: TypeEvent, Data: "payload"}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "redis client is not available") { + t.Errorf("expected %v to contain %v", err.Error(), "redis client is not available") + } - // Publish broadcast from bridge1. - err = bridge1.PublishBroadcast(Message{Type: TypeEvent, Data: "cross-broadcast"}) - require.NoError(t, err) + }) - // bridge2's hub should receive the message (client2 gets it). - select { - case msg := <-client2.send: - var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, TypeEvent, received.Type) - assert.Equal(t, "cross-broadcast", received.Data) - case <-time.After(3 * time.Second): - t.Fatal("bridge2 client should have received the broadcast") - } -} + t.Run("invalid prefix", func(t *testing.T) { + bridge := &RedisBridge{ + client: redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}), + ctx: context.Background(), + prefix: "bad prefix", + } + defer testClose(t, bridge.client.Close) -// --------------------------------------------------------------------------- -// PublishToChannel — targeted channel delivery -// --------------------------------------------------------------------------- + err := bridge.publish("bad prefix:broadcast", Message{Type: TypeEvent, Data: "payload"}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid redis prefix") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid redis prefix") + } -func TestRedisBridge_PublishToChannel(t *testing.T) { + }) +} + +func TestRedisBridge_SelfEchoSuppressed_Good(t *testing.T) { rc := skipIfNoRedis(t) prefix := testPrefix(t) cleanupRedis(t, rc, prefix) hub, _, _ := startTestHub(t) - - // Create a client subscribed to a specific channel. - subClient := &Client{ + client := &Client{ hub: hub, send: make(chan []byte, 256), subscriptions: make(map[string]bool), } - hub.register <- subClient + hub.register <- client time.Sleep(50 * time.Millisecond) - hub.Subscribe(subClient, "process:abc") - // Create a client NOT subscribed to that channel. - otherClient := &Client{ - hub: hub, - send: make(chan []byte, 256), - subscriptions: make(map[string]bool), + bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) } - hub.register <- otherClient - time.Sleep(50 * time.Millisecond) - - // Second hub + bridge (the publisher). - hub2, _, _ := startTestHub(t) - bridge2, err := NewRedisBridge(hub2, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) - err = bridge2.Start(context.Background()) - require.NoError(t, err) - defer bridge2.Stop() - // Local hub bridge (the receiver). - bridge1, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) - err = bridge1.Start(context.Background()) - require.NoError(t, err) - defer bridge1.Stop() + err = bridge.Start(context.Background()) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - time.Sleep(100 * time.Millisecond) + defer testClose(t, bridge.Stop) - // Publish to channel from bridge2. - err = bridge2.PublishToChannel("process:abc", Message{ - Type: TypeProcessOutput, - ProcessID: "abc", - Data: "line of output", - }) - require.NoError(t, err) + err = bridge.PublishBroadcast(Message{Type: TypeEvent, Data: "self-echo"}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - // subClient (subscribed to process:abc) should receive the message. select { - case msg := <-subClient.send: + case msg := <-client.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, TypeProcessOutput, received.Type) - assert.Equal(t, "line of output", received.Data) - case <-time.After(3 * time.Second): - t.Fatal("subscribed client should have received the channel message") + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual("self-echo", received.Data) { + t.Errorf("expected %v, got %v", "self-echo", received.Data) + } + + case <-time.After(time.Second): + t.Fatal("client should receive the local broadcast") } - // otherClient should NOT receive the message. select { - case msg := <-otherClient.send: - t.Fatalf("unsubscribed client should not receive channel message, got: %s", msg) + case msg := <-client.send: + t.Fatalf("bridge should not echo its own Redis message, got: %s", msg) case <-time.After(300 * time.Millisecond): - // Good — no message delivered. + // Good - the bridge skipped its own source ID. } } @@ -311,10 +1255,16 @@ func TestRedisBridge_CrossBridge(t *testing.T) { time.Sleep(50 * time.Millisecond) bridgeA, err := NewRedisBridge(hubA, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + err = bridgeA.Start(context.Background()) - require.NoError(t, err) - defer bridgeA.Stop() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridgeA.Stop) // Hub B with a client. hubB, _, _ := startTestHub(t) @@ -327,36 +1277,87 @@ func TestRedisBridge_CrossBridge(t *testing.T) { time.Sleep(50 * time.Millisecond) bridgeB, err := NewRedisBridge(hubB, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + err = bridgeB.Start(context.Background()) - require.NoError(t, err) - defer bridgeB.Stop() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - // Allow subscriptions to settle. - time.Sleep(200 * time.Millisecond) + defer testClose(t, bridgeB.Stop) + + if !testEventually(func() bool { + return redisBridgeListening(bridgeA) && redisBridgeListening(bridgeB) + }, 3*time.Second, 50*time.Millisecond) { + t.Fatal("bridges did not start listening in time") + } // Publish from A, verify B receives. err = bridgeA.PublishBroadcast(Message{Type: TypeEvent, Data: "from-A"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case msg := <-clientA.send: + var received Message + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual("from-A", received.Data) { + t.Errorf("expected %v, got %v", "from-A", received.Data) + } + + case <-time.After(3 * time.Second): + t.Fatal("hub A should receive its local broadcast") + } select { case msg := <-clientB.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, "from-A", received.Data) + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual("from-A", received.Data) { + t.Errorf("expected %v, got %v", "from-A", received.Data) + } + case <-time.After(3 * time.Second): t.Fatal("hub B should receive broadcast from hub A") } // Publish from B, verify A receives. err = bridgeB.PublishBroadcast(Message{Type: TypeEvent, Data: "from-B"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case msg := <-clientB.send: + var received Message + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual("from-B", received.Data) { + t.Errorf("expected %v, got %v", "from-B", received.Data) + } + + case <-time.After(3 * time.Second): + t.Fatal("hub B should receive its local broadcast") + } select { case msg := <-clientA.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, "from-B", received.Data) + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual("from-B", received.Data) { + t.Errorf("expected %v, got %v", "from-B", received.Data) + } + case <-time.After(3 * time.Second): t.Fatal("hub A should receive broadcast from hub B") } @@ -381,23 +1382,44 @@ func TestRedisBridge_LoopPrevention(t *testing.T) { time.Sleep(50 * time.Millisecond) bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + err = bridge.Start(context.Background()) - require.NoError(t, err) - defer bridge.Stop() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge.Stop) time.Sleep(100 * time.Millisecond) - // Publish from this bridge — the same bridge should NOT deliver - // the message back to its own hub. + // Publish from this bridge — the local hub should receive the message once, + // and loop prevention should stop a second echoed copy from Redis. err = bridge.PublishBroadcast(Message{Type: TypeEvent, Data: "echo-test"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case msg := <-client.send: + var received Message + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual("echo-test", received.Data) { + t.Errorf("expected %v, got %v", "echo-test", received.Data) + } + + case <-time.After(3 * time.Second): + t.Fatal("bridge should deliver the broadcast to its local hub") + } select { case msg := <-client.send: - t.Fatalf("bridge should not echo its own messages, got: %s", msg) + t.Fatalf("bridge should not echo its own Redis message twice, got: %s", msg) case <-time.After(500 * time.Millisecond): - // Good — no echo. } } @@ -421,18 +1443,30 @@ func TestRedisBridge_ConcurrentPublishes(t *testing.T) { time.Sleep(50 * time.Millisecond) bridgeRecv, err := NewRedisBridge(hubRecv, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + err = bridgeRecv.Start(context.Background()) - require.NoError(t, err) - defer bridgeRecv.Stop() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridgeRecv.Stop) // Sender hub. hubSend, _, _ := startTestHub(t) bridgeSend, err := NewRedisBridge(hubSend, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + err = bridgeSend.Start(context.Background()) - require.NoError(t, err) - defer bridgeSend.Stop() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridgeSend.Stop) time.Sleep(200 * time.Millisecond) @@ -462,7 +1496,10 @@ func TestRedisBridge_ConcurrentPublishes(t *testing.T) { t.Fatalf("expected %d messages, received %d", numPublishes, received) } } - assert.Equal(t, numPublishes, received) + if !testEqual(numPublishes, received) { + t.Errorf("expected %v, got %v", numPublishes, received) + } + } // --------------------------------------------------------------------------- @@ -477,9 +1514,14 @@ func TestRedisBridge_GracefulShutdown(t *testing.T) { hub, _, _ := startTestHub(t) bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + err = bridge.Start(context.Background()) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } // Stop should not panic or hang. done := make(chan error, 1) @@ -489,14 +1531,20 @@ func TestRedisBridge_GracefulShutdown(t *testing.T) { select { case err := <-done: - assert.NoError(t, err) + if err := err; err != nil { + t.Errorf("expected no error, got %v", err) + } + case <-time.After(5 * time.Second): t.Fatal("Stop() should not hang") } // Publishing after stop should fail gracefully (context cancelled). err = bridge.PublishBroadcast(Message{Type: TypeEvent, Data: "after-stop"}) - assert.Error(t, err, "publishing after stop should error") + if err := err; err == nil { + t.Errorf("expected error") + } + } func TestRedisBridge_StopWithoutStart(t *testing.T) { @@ -507,12 +1555,14 @@ func TestRedisBridge_StopWithoutStart(t *testing.T) { hub, _, _ := startTestHub(t) bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) - + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } // Stop without Start should not panic. - assert.NotPanics(t, func() { + testNotPanics(t, func() { _ = bridge.Stop() }) + } // --------------------------------------------------------------------------- @@ -527,19 +1577,26 @@ func TestRedisBridge_ContextCancellation(t *testing.T) { hub, _, _ := startTestHub(t) bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } ctx, cancel := context.WithCancel(context.Background()) err = bridge.Start(ctx) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - // Cancel the context — the listener should exit gracefully. + // Cancel the context so the listener exits gracefully. cancel() time.Sleep(200 * time.Millisecond) // Cleanup without hanging. err = bridge.Stop() - assert.NoError(t, err) + if err := err; err != nil { + t.Errorf("expected no error, got %v", err) + } + } // --------------------------------------------------------------------------- @@ -568,35 +1625,58 @@ func TestRedisBridge_ChannelPatternMatching(t *testing.T) { hub.register <- clientB time.Sleep(50 * time.Millisecond) - hub.Subscribe(clientA, "events:user:1") - hub.Subscribe(clientB, "events:user:2") + if err := hub.Subscribe(clientA, "events:user:1"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if err := hub.Subscribe(clientB, "events:user:2"); err != nil { + t.Fatalf("expected no error, got %v", err) + } // Receiver bridge. bridge1, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + err = bridge1.Start(context.Background()) - require.NoError(t, err) - defer bridge1.Stop() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge1.Stop) // Sender bridge. hub2, _, _ := startTestHub(t) bridge2, err := NewRedisBridge(hub2, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + err = bridge2.Start(context.Background()) - require.NoError(t, err) - defer bridge2.Stop() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge2.Stop) time.Sleep(200 * time.Millisecond) // Publish to events:user:1 — only clientA should receive. err = bridge2.PublishToChannel("events:user:1", Message{Type: TypeEvent, Data: "for-user-1"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } select { case msg := <-clientA.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, "for-user-1", received.Data) + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual("for-user-1", received.Data) { + t.Errorf("expected %v, got %v", "for-user-1", received.Data) + } + case <-time.After(3 * time.Second): t.Fatal("clientA should receive the channel message") } @@ -610,6 +1690,109 @@ func TestRedisBridge_ChannelPatternMatching(t *testing.T) { } } +func TestRedisBridge_InvalidInboundChannel_Ugly(t *testing.T) { + rc := skipIfNoRedis(t) + prefix := testPrefix(t) + cleanupRedis(t, rc, prefix) + + hub, _, _ := startTestHub(t) + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + hub.register <- client + time.Sleep(50 * time.Millisecond) + + bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err = bridge.Start(context.Background()) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge.Stop) + + env := redisEnvelope{ + SourceID: "external-source", + Message: Message{ + Type: TypeEvent, + Data: "should-be-dropped", + }, + } + raw := mustMarshal(env) + if testIsNil(raw) { + t.Fatalf("expected non-nil value") + } + + err = rc.Publish(context.Background(), prefix+":channel:bad channel", raw).Err() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case msg := <-client.send: + t.Fatalf("invalid inbound channel should not be forwarded, got: %s", msg) + case <-time.After(300 * time.Millisecond): + // Good - listener dropped the invalid channel name. + } +} + +func TestRedisBridge_listen_InvalidProcessID_Ugly(t *testing.T) { + rc := skipIfNoRedis(t) + prefix := testPrefix(t) + cleanupRedis(t, rc, prefix) + + hub, _, _ := startTestHub(t) + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + hub.register <- client + time.Sleep(50 * time.Millisecond) + + bridge, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err = bridge.Start(context.Background()) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, bridge.Stop) + + env := redisEnvelope{ + SourceID: "external-source", + Message: Message{ + Type: TypeProcessOutput, + ProcessID: "bad process", + Data: "should-be-dropped", + }, + } + raw := mustMarshal(env) + if testIsNil(raw) { + t.Fatalf("expected non-nil value") + } + + err = rc.Publish(context.Background(), prefix+":broadcast", raw).Err() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case msg := <-client.send: + t.Fatalf("invalid process ID should not be forwarded, got: %s", msg) + case <-time.After(300 * time.Millisecond): + // Good - listener dropped the forwarded message before local delivery. + } +} + // --------------------------------------------------------------------------- // Unique source IDs per bridge instance // --------------------------------------------------------------------------- @@ -622,13 +1805,17 @@ func TestRedisBridge_UniqueSourceIDs(t *testing.T) { hub, _, _ := startTestHub(t) bridge1, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } bridge2, err := NewRedisBridge(hub, RedisConfig{Addr: redisAddr, Prefix: prefix}) - require.NoError(t, err) - - assert.NotEqual(t, bridge1.SourceID(), bridge2.SourceID(), - "each bridge instance must have a unique source ID") + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if testEqual(bridge1.SourceID(), bridge2.SourceID()) { + t.Errorf("expected values to differ: %v", bridge2.SourceID()) + } _ = bridge1.Stop() _ = bridge2.Stop() diff --git a/test_stdlib_helpers_test.go b/test_stdlib_helpers_test.go new file mode 100644 index 0000000..3e75233 --- /dev/null +++ b/test_stdlib_helpers_test.go @@ -0,0 +1,151 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package ws + +import ( + "errors" + "reflect" + "strings" + "testing" + "time" +) + +func testEqual(want, got any) bool { + return reflect.DeepEqual(want, got) +} + +func testErrorIs(err, target error) bool { + return errors.Is(err, target) +} + +func testIsNil(value any) bool { + if value == nil { + return true + } + + v := reflect.ValueOf(value) + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: + return v.IsNil() + default: + return false + } +} + +func testIsEmpty(value any) bool { + if value == nil { + return true + } + + v := reflect.ValueOf(value) + switch v.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Complex64, reflect.Complex128: + return v.Complex() == 0 + case reflect.Interface, reflect.Pointer: + return v.IsNil() + default: + return reflect.DeepEqual(value, reflect.Zero(v.Type()).Interface()) + } +} + +func testContains(container, element any) bool { + if s, ok := container.(string); ok { + needle, ok := element.(string) + return ok && strings.Contains(s, needle) + } + + v := reflect.ValueOf(container) + if !v.IsValid() { + return false + } + + switch v.Kind() { + case reflect.Array, reflect.Slice: + for i := range v.Len() { + if reflect.DeepEqual(v.Index(i).Interface(), element) { + return true + } + } + case reflect.Map: + key := reflect.ValueOf(element) + if !key.IsValid() { + return false + } + if key.Type().AssignableTo(v.Type().Key()) { + return v.MapIndex(key).IsValid() + } + if key.Type().ConvertibleTo(v.Type().Key()) { + return v.MapIndex(key.Convert(v.Type().Key())).IsValid() + } + } + + return false +} + +func testSame(want, got any) bool { + if testIsNil(want) || testIsNil(got) { + return testIsNil(want) && testIsNil(got) + } + + wantValue := reflect.ValueOf(want) + gotValue := reflect.ValueOf(got) + if wantValue.Type() != gotValue.Type() { + return false + } + if wantValue.Kind() != reflect.Pointer { + return false + } + return wantValue.Pointer() == gotValue.Pointer() +} + +func testIsZero(value any) bool { + if value == nil { + return true + } + return reflect.ValueOf(value).IsZero() +} + +func testEventually(condition func() bool, waitFor, tick time.Duration) bool { + deadline := time.Now().Add(waitFor) + for { + if condition() { + return true + } + if time.Now().After(deadline) { + return false + } + + sleepFor := tick + if remaining := time.Until(deadline); remaining < sleepFor { + sleepFor = remaining + } + if sleepFor > 0 { + time.Sleep(sleepFor) + } + } +} + +func testClose(t testing.TB, closeFn func() error) { + t.Helper() + _ = closeFn() +} + +func testNotPanics(t *testing.T, f func()) { + t.Helper() + defer func() { + if recovered := recover(); recovered != nil { + t.Errorf("expected no panic, got %v", recovered) + } + }() + f() +} diff --git a/tests/cli/ws/Taskfile.yaml b/tests/cli/ws/Taskfile.yaml new file mode 100644 index 0000000..a0a4c33 --- /dev/null +++ b/tests/cli/ws/Taskfile.yaml @@ -0,0 +1,54 @@ +version: "3" + +env: + GOWORK: off + GOCACHE: /tmp/go-ws-go-build-cache + +tasks: + build: + dir: ../../.. + cmds: + - go build ./... + + test: + dir: ../../.. + cmds: + - go test -count=1 -race ./... + + vet: + dir: ../../.. + cmds: + - go vet ./... + + fmt: + dir: ../../.. + cmds: + - gofmt -l . + + lint: + dir: ../../.. + cmds: + - golangci-lint run ./... + + test-unit: + dir: ../../.. + cmds: + - go test -count=1 -race ./... -run Unit + + test-integration: + dir: ../../.. + cmds: + - | + if [ -z "${REDIS_ADDR:-}" ]; then + echo "Skipping integration tests: REDIS_ADDR unset (requires Redis on localhost:6379)" + exit 0 + fi + go test -count=1 -race ./... -run Integration -tags integration + + default: + deps: + - fmt + - build + - test + - vet + - lint diff --git a/ws.go b/ws.go index 61b85aa..726e207 100644 --- a/ws.go +++ b/ws.go @@ -40,7 +40,7 @@ // // Clients can subscribe to specific channels to receive targeted messages: // -// // Client sends: {"type": "subscribe", "data": "process:proc-1"} +// // Client sends: {"type": "subscribe", "channel": "process:proc-1"} // // Server broadcasts only to subscribers of "process:proc-1" // // # Integration with Core @@ -62,29 +62,32 @@ import ( "context" "iter" "maps" + "math" + "net" + // AX-6-exception: WebSocket requires HTTP upgrade (RFC 6455) "net/http" + "net/url" "slices" + // Note: AX-6 — origin, host, and channel normalization is structural HTTP/WebSocket boundary validation. + "strings" + // Note: AX-6 — internal concurrency primitive; structural for go-ws hub state (RFC mandates concurrent connection map). "sync" "time" - core "dappco.re/go/core" - coreerr "dappco.re/go/core/log" + core "dappco.re/go" + coreerr "dappco.re/go/log" "github.com/gorilla/websocket" ) -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return true // Allow all origins for local development - }, -} - // Default timing values for heartbeat and pong timeout. const ( - DefaultHeartbeatInterval = 30 * time.Second - DefaultPongTimeout = 60 * time.Second - DefaultWriteTimeout = 10 * time.Second + DefaultHeartbeatInterval = 30 * time.Second + DefaultPongTimeout = 60 * time.Second + DefaultWriteTimeout = 10 * time.Second + DefaultMaxSubscriptionsPerClient = 1024 + defaultMaxMessageBytes = 64 * 1024 + maxChannelNameLen = 256 + maxProcessIDLen = 128 ) // ConnectionState represents the current state of a reconnecting client. @@ -99,7 +102,8 @@ const ( StateConnected ) -// HubConfig holds configuration for the Hub and its managed connections. +// HubConfig configures the hub. +// ws.NewHubWithConfig(ws.HubConfig{HeartbeatInterval: 30 * time.Second}) type HubConfig struct { // HeartbeatInterval is the interval between server-side ping messages. // Defaults to 30 seconds. @@ -130,17 +134,42 @@ type HubConfig struct { // subscribe to a named channel. When nil, all subscriptions are allowed. ChannelAuthoriser ChannelAuthoriser + // MaxSubscriptionsPerClient limits the number of active subscriptions a + // single client may hold. Zero or negative values use the default limit. + MaxSubscriptionsPerClient int + + // AllowedOrigins lists exact Origin header values accepted during the + // WebSocket upgrade. When empty and CheckOrigin is nil, all origins are + // allowed for development compatibility only; configure this in production. + AllowedOrigins []string + + // CheckOrigin optionally overrides the Origin header policy during the + // WebSocket upgrade. When nil, NewHubWithConfig derives one from + // AllowedOrigins. + // + // hub := ws.NewHubWithConfig(ws.HubConfig{ + // AllowedOrigins: []string{"https://app.example"}, + // }) + // AX-6-exception: WebSocket requires the RFC 6455 HTTP/1.1 upgrade boundary. + // gorilla/websocket exposes that boundary as net/http primitives, so go-ws + // keeps http.Request, http.ResponseWriter, and http.Header at the transport edge. + CheckOrigin func(r *http.Request) bool + // OnAuthFailure is called when a connection is rejected by the // Authenticator. Useful for logging or metrics. Optional. OnAuthFailure func(r *http.Request, result AuthResult) } -// DefaultHubConfig returns a HubConfig with sensible defaults. +// DefaultHubConfig returns the package defaults for hub timing and subscription +// limits. +// +// config := ws.DefaultHubConfig() func DefaultHubConfig() HubConfig { return HubConfig{ - HeartbeatInterval: DefaultHeartbeatInterval, - PongTimeout: DefaultPongTimeout, - WriteTimeout: DefaultWriteTimeout, + HeartbeatInterval: DefaultHeartbeatInterval, + PongTimeout: DefaultPongTimeout, + WriteTimeout: DefaultWriteTimeout, + MaxSubscriptionsPerClient: DefaultMaxSubscriptionsPerClient, } } @@ -167,6 +196,7 @@ const ( ) // Message is the standard WebSocket message format. +// msg := ws.Message{Type: ws.TypeEvent, Data: "hello"} type Message struct { Type MessageType `json:"type"` Channel string `json:"channel,omitempty"` @@ -176,6 +206,7 @@ type Message struct { } // Client represents a connected WebSocket client. +// client := &ws.Client{UserID: "user-123"} type Client struct { hub *Hub conn *websocket.Conn @@ -199,25 +230,42 @@ type Client struct { type ChannelAuthoriser func(client *Client, channel string) bool // Hub manages WebSocket connections and message broadcasting. +// hub := ws.NewHub() type Hub struct { - clients map[*Client]bool - broadcast chan []byte - register chan *Client - unregister chan *Client - channels map[string]map[*Client]bool - config HubConfig - done chan struct{} - doneOnce sync.Once - running bool - mu sync.RWMutex -} - -// NewHub creates a new WebSocket hub with default configuration. + clients map[*Client]bool + broadcast chan []byte + register chan *Client + unregister chan *Client + subscribeRequests chan subscriptionRequest + unsubscribeRequests chan subscriptionRequest + channels map[string]map[*Client]bool + config HubConfig + done chan struct{} + doneOnce sync.Once + running bool + mu sync.RWMutex +} + +type subscriptionRequest struct { + client *Client + channel string + reply chan error +} + +// NewHub constructs a hub with DefaultHubConfig. +// +// ws.NewHub(); go hub.Run(ctx) func NewHub() *Hub { - return NewHubWithConfig(DefaultHubConfig()) + config := DefaultHubConfig() + if config.CheckOrigin == nil && len(config.AllowedOrigins) == 0 { + coreerr.Warn("websocket hub allows all origins; set HubConfig.AllowedOrigins in production") + } + return NewHubWithConfig(config) } -// NewHubWithConfig creates a new WebSocket hub with the given configuration. +// NewHubWithConfig constructs a hub using config after applying default values. +// +// ws.NewHubWithConfig(ws.HubConfig{HeartbeatInterval: 30 * time.Second}) func NewHubWithConfig(config HubConfig) *Hub { if config.HeartbeatInterval <= 0 { config.HeartbeatInterval = DefaultHeartbeatInterval @@ -231,20 +279,128 @@ func NewHubWithConfig(config HubConfig) *Hub { if config.WriteTimeout <= 0 { config.WriteTimeout = DefaultWriteTimeout } + if config.MaxSubscriptionsPerClient <= 0 { + config.MaxSubscriptionsPerClient = DefaultMaxSubscriptionsPerClient + } + if config.CheckOrigin == nil { + config.CheckOrigin = allowedOriginsCheck(config.AllowedOrigins) + } return &Hub{ - clients: make(map[*Client]bool), - broadcast: make(chan []byte, 256), - register: make(chan *Client), - unregister: make(chan *Client), - channels: make(map[string]map[*Client]bool), - config: config, - done: make(chan struct{}), + clients: make(map[*Client]bool), + broadcast: make(chan []byte, 256), + register: make(chan *Client), + unregister: make(chan *Client), + subscribeRequests: make(chan subscriptionRequest), + unsubscribeRequests: make(chan subscriptionRequest), + channels: make(map[string]map[*Client]bool), + config: config, + done: make(chan struct{}), + } +} + +func nilHubError(operation string) error { + return coreerr.E(operation, "hub must not be nil", nil) +} + +func logCloseError(operation string, closeFn func() error) { + if closeFn == nil { + return + } + + if err := closeFn(); err != nil { + coreerr.Warn("close failed", "op", operation, "err", err) + } +} + +func stampServerMessage(msg Message) Message { + // Server-emitted messages own the timestamp field. + msg.Timestamp = time.Now() + return msg +} + +func stampServerMessageIfNeeded(msg Message) Message { + if msg.Timestamp.IsZero() { + return stampServerMessage(msg) + } + + return msg +} + +func validateMessageIdentifiers(operation string, msg Message) error { + if msg.ProcessID != "" && !validProcessID(msg.ProcessID) { + return coreerr.E(operation, "invalid process ID", nil) + } + + return nil +} + +func validateChannelTarget(operation string, channel string) error { + if !validChannelName(channel) { + return coreerr.E(operation, "invalid channel name", nil) + } + + if processID, ok := processChannelID(channel); ok && !validProcessID(processID) { + return coreerr.E(operation, "invalid process ID", nil) + } + + return nil +} + +func processChannelID(channel string) (string, bool) { + if !strings.HasPrefix(channel, "process:") { + return "", false + } + + return strings.TrimPrefix(channel, "process:"), true +} + +func validChannelName(channel string) bool { + return validIdentifier(channel, maxChannelNameLen) +} + +func validProcessID(processID string) bool { + if !validIdentifier(processID, maxProcessIDLen) { + return false + } + + // Process IDs are embedded in `process:` channel names, so the + // identifier itself must not contain the separator token. + return !strings.Contains(processID, ":") +} + +func validIdentifier(value string, maxLen int) bool { + if value == "" || len(value) > maxLen { + return false + } + + if strings.TrimSpace(value) != value { + return false + } + + for _, r := range value { + switch { + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r >= '0' && r <= '9': + case r == '_', r == '-', r == '.', r == ':': + default: + return false + } } + + return true } // Run starts the hub's main loop. It should be called in a goroutine. // The loop exits when the context is cancelled. func (h *Hub) Run(ctx context.Context) { + if h == nil { + return + } + if ctx == nil { + ctx = context.Background() + } + h.mu.Lock() h.running = true h.mu.Unlock() @@ -270,7 +426,7 @@ func (h *Hub) Run(ctx context.Context) { h.mu.Unlock() if h.config.OnDisconnect != nil { for _, client := range disconnected { - safeClientCallback(func() { + go safeClientCallback(func() { h.config.OnDisconnect(client) }) } @@ -285,7 +441,7 @@ func (h *Hub) Run(ctx context.Context) { h.clients[client] = true h.mu.Unlock() if h.config.OnConnect != nil { - safeClientCallback(func() { + go safeClientCallback(func() { h.config.OnConnect(client) }) } @@ -300,21 +456,29 @@ func (h *Hub) Run(ctx context.Context) { h.mu.Unlock() if h.config.OnDisconnect != nil { - safeClientCallback(func() { + go safeClientCallback(func() { h.config.OnDisconnect(client) }) } } else { h.mu.Unlock() } + case request := <-h.subscribeRequests: + err := h.handleSubscribeRequest(request) + if request.reply != nil { + request.reply <- err + } + case request := <-h.unsubscribeRequests: + h.handleUnsubscribeRequest(request) + if request.reply != nil { + request.reply <- nil + } case message := <-h.broadcast: h.mu.RLock() for client := range h.clients { if !trySend(client.send, message) { // Client buffer full or already closed, will be cleaned up. - go func(c *Client) { - h.unregister <- c - }(client) + h.enqueueUnregister(client) } } h.mu.RUnlock() @@ -322,6 +486,41 @@ func (h *Hub) Run(ctx context.Context) { } } +func (h *Hub) handleSubscribeRequest(request subscriptionRequest) error { + if request.client == nil { + return nil + } + + h.mu.Lock() + defer h.mu.Unlock() + + return h.subscribeLocked(request.client, request.channel) +} + +func (h *Hub) handleUnsubscribeRequest(request subscriptionRequest) { + if request.client == nil { + return + } + + h.mu.Lock() + defer h.mu.Unlock() + + h.unsubscribeLocked(request.client, request.channel) +} + +func (h *Hub) enqueueUnregister(client *Client) { + if h == nil || client == nil { + return + } + + go func() { + select { + case h.unregister <- client: + case <-h.done: + } + }() +} + // removeClientLocked removes a client from the hub and all channel // membership maps. The hub lock must be held by the caller. func (h *Hub) removeClientLocked(client *Client) { @@ -345,9 +544,15 @@ func (h *Hub) removeClientLocked(client *Client) { // Subscribe adds a client to a channel. func (h *Hub) Subscribe(client *Client, channel string) error { - if client == nil || channel == "" { + if client == nil { return nil } + if h == nil { + return coreerr.E("Subscribe", "hub must not be nil", nil) + } + if err := validateChannelTarget("Subscribe", channel); err != nil { + return err + } if h != nil && h.config.ChannelAuthoriser != nil && !safeAuthoriserResult(func() bool { return h.config.ChannelAuthoriser(client, channel) @@ -355,9 +560,50 @@ func (h *Hub) Subscribe(client *Client, channel string) error { return coreerr.E("Subscribe", "subscription unauthorised", nil) } + if h.isRunning() { + request := subscriptionRequest{ + client: client, + channel: channel, + reply: make(chan error, 1), + } + + select { + case h.subscribeRequests <- request: + case <-h.done: + return coreerr.E("Subscribe", "hub is not running", nil) + } + + select { + case err := <-request.reply: + return err + case <-h.done: + return coreerr.E("Subscribe", "hub stopped before subscription completed", nil) + } + } + h.mu.Lock() defer h.mu.Unlock() + return h.subscribeLocked(client, channel) +} + +func (h *Hub) subscribeLocked(client *Client, channel string) error { + if client == nil { + return nil + } + + maxSubs := h.config.MaxSubscriptionsPerClient + if maxSubs > 0 { + client.mu.RLock() + currentSubs := len(client.subscriptions) + _, alreadySubscribed := client.subscriptions[channel] + client.mu.RUnlock() + + if !alreadySubscribed && currentSubs >= maxSubs { + return ErrSubscriptionLimitExceeded + } + } + if _, ok := h.channels[channel]; !ok { h.channels[channel] = make(map[*Client]bool) } @@ -378,10 +624,41 @@ func (h *Hub) Unsubscribe(client *Client, channel string) { if client == nil || channel == "" { return } + if h == nil { + return + } + if validateChannelTarget("Unsubscribe", channel) != nil { + return + } + + if h.isRunning() { + request := subscriptionRequest{ + client: client, + channel: channel, + reply: make(chan error, 1), + } + + select { + case h.unsubscribeRequests <- request: + case <-h.done: + return + } + + select { + case <-request.reply: + return + case <-h.done: + return + } + } h.mu.Lock() defer h.mu.Unlock() + h.unsubscribeLocked(client, channel) +} + +func (h *Hub) unsubscribeLocked(client *Client, channel string) { if clients, ok := h.channels[channel]; ok { delete(clients, client) // Clean up empty channels @@ -397,9 +674,37 @@ func (h *Hub) Unsubscribe(client *Client, channel string) { client.mu.Unlock() } -// Broadcast sends a message to all connected clients. +func (h *Hub) isRunning() bool { + if h == nil { + return false + } + + h.mu.RLock() + defer h.mu.RUnlock() + + return h.running +} + +// Broadcast sends msg to every connected client. +// +// hub.Broadcast(ws.Message{Type: ws.TypeEvent, Data: "hello everyone"}) func (h *Hub) Broadcast(msg Message) error { - msg.Timestamp = time.Now() + return h.broadcastMessage(msg, false) +} + +func (h *Hub) broadcastMessage(msg Message, preserveTimestamp bool) error { + if h == nil { + return nilHubError("Broadcast") + } + if err := validateMessageIdentifiers("Broadcast", msg); err != nil { + return err + } + + if preserveTimestamp { + msg = stampServerMessageIfNeeded(msg) + } else { + msg = stampServerMessage(msg) + } r := core.JSONMarshal(msg) if !r.OK { return coreerr.E("Broadcast", "failed to marshal message", nil) @@ -413,9 +718,30 @@ func (h *Hub) Broadcast(msg Message) error { return nil } -// SendToChannel sends a message to all clients subscribed to a channel. +// SendToChannel sends msg to clients subscribed to channel. +// +// hub.SendToChannel("notifications", ws.Message{Type: ws.TypeEvent, Data: "important update"}) func (h *Hub) SendToChannel(channel string, msg Message) error { - msg.Timestamp = time.Now() + return h.sendToChannelMessage(channel, msg, false) +} + +func (h *Hub) sendToChannelMessage(channel string, msg Message, preserveTimestamp bool) error { + if h == nil { + return nilHubError("SendToChannel") + } + + if err := validateChannelTarget("SendToChannel", channel); err != nil { + return err + } + if err := validateMessageIdentifiers("SendToChannel", msg); err != nil { + return err + } + + if preserveTimestamp { + msg = stampServerMessageIfNeeded(msg) + } else { + msg = stampServerMessage(msg) + } msg.Channel = channel r := core.JSONMarshal(msg) if !r.OK { @@ -435,13 +761,76 @@ func (h *Hub) SendToChannel(channel string, msg Message) error { h.mu.RUnlock() for _, client := range targets { - _ = trySend(client.send, data) + if !trySend(client.send, data) { + // Keep the channel membership maps clean if a client can no + // longer accept outbound frames. + h.enqueueUnregister(client) + } } return nil } -// SendProcessOutput sends process output to subscribers of the process channel. +func sortedClientSubscriptions(client *Client) []string { + if client == nil { + return nil + } + + subscriptions := slices.Collect(maps.Keys(client.subscriptions)) + slices.Sort(subscriptions) + return subscriptions +} + +func sortedHubChannels(h *Hub) []string { + if h == nil { + return nil + } + + channels := slices.Collect(maps.Keys(h.channels)) + slices.Sort(channels) + return channels +} + +func sortedHubClients(h *Hub) []*Client { + if h == nil { + return nil + } + + clients := slices.Collect(maps.Keys(h.clients)) + slices.SortStableFunc(clients, func(left *Client, right *Client) int { + switch { + case left == nil && right == nil: + return 0 + case left == nil: + return -1 + case right == nil: + return 1 + } + + if compare := strings.Compare(left.UserID, right.UserID); compare != 0 { + return compare + } + + return strings.Compare(clientSortKey(left), clientSortKey(right)) + }) + return clients +} + +func clientSortKey(client *Client) string { + if client == nil || client.conn == nil || client.conn.RemoteAddr() == nil { + return "" + } + + return client.conn.RemoteAddr().String() +} + +// SendProcessOutput publishes process output to the process channel. +// +// hub.SendProcessOutput("proc-123", "line of output\n") func (h *Hub) SendProcessOutput(processID string, output string) error { + if !validProcessID(processID) { + return coreerr.E("SendProcessOutput", "invalid process ID", nil) + } + return h.SendToChannel("process:"+processID, Message{ Type: TypeProcessOutput, ProcessID: processID, @@ -449,8 +838,14 @@ func (h *Hub) SendProcessOutput(processID string, output string) error { }) } -// SendProcessStatus sends a process status update to subscribers. +// SendProcessStatus publishes a process status update to the process channel. +// +// hub.SendProcessStatus("proc-123", "exited", 0) func (h *Hub) SendProcessStatus(processID string, status string, exitCode int) error { + if !validProcessID(processID) { + return coreerr.E("SendProcessStatus", "invalid process ID", nil) + } + return h.SendToChannel("process:"+processID, Message{ Type: TypeProcessStatus, ProcessID: processID, @@ -461,7 +856,9 @@ func (h *Hub) SendProcessStatus(processID string, status string, exitCode int) e }) } -// SendError sends an error message to all connected clients. +// SendError broadcasts an error message to connected clients. +// +// hub.SendError("server error") func (h *Hub) SendError(errMsg string) error { return h.Broadcast(Message{ Type: TypeError, @@ -469,7 +866,9 @@ func (h *Hub) SendError(errMsg string) error { }) } -// SendEvent sends a generic event to all connected clients. +// SendEvent broadcasts a named event payload to connected clients. +// +// hub.SendEvent("user-joined", map[string]any{"user": "alice"}) func (h *Hub) SendEvent(eventType string, data any) error { return h.Broadcast(Message{ Type: TypeEvent, @@ -480,22 +879,40 @@ func (h *Hub) SendEvent(eventType string, data any) error { }) } -// ClientCount returns the number of connected clients. +// ClientCount returns the number of clients currently registered with the hub. +// +// clientCount := hub.ClientCount() func (h *Hub) ClientCount() int { + if h == nil { + return 0 + } + h.mu.RLock() defer h.mu.RUnlock() return len(h.clients) } -// ChannelCount returns the number of active channels. +// ChannelCount returns the number of channels that currently have subscribers. +// +// channelCount := hub.ChannelCount() func (h *Hub) ChannelCount() int { + if h == nil { + return 0 + } + h.mu.RLock() defer h.mu.RUnlock() return len(h.channels) } -// ChannelSubscriberCount returns the number of subscribers for a channel. +// ChannelSubscriberCount returns the number of clients subscribed to channel. +// +// subscriberCount := hub.ChannelSubscriberCount("notifications") func (h *Hub) ChannelSubscriberCount(channel string) int { + if h == nil { + return 0 + } + h.mu.RLock() defer h.mu.RUnlock() if clients, ok := h.channels[channel]; ok { @@ -504,37 +921,66 @@ func (h *Hub) ChannelSubscriberCount(channel string) int { return 0 } -// AllClients returns an iterator for all connected clients. +// AllClients returns a deterministic snapshot iterator over registered clients. +// +// for client := range hub.AllClients() { _ = client.UserID } func (h *Hub) AllClients() iter.Seq[*Client] { + if h == nil { + return func(yield func(*Client) bool) {} + } + h.mu.RLock() defer h.mu.RUnlock() - return slices.Values(slices.Collect(maps.Keys(h.clients))) + return slices.Values(sortedHubClients(h)) } -// AllChannels returns an iterator for all active channels. +// AllChannels returns a deterministic snapshot iterator over active channels. +// +// for channel := range hub.AllChannels() { _ = channel } func (h *Hub) AllChannels() iter.Seq[string] { + if h == nil { + return func(yield func(string) bool) {} + } + h.mu.RLock() defer h.mu.RUnlock() - return slices.Values(slices.Collect(maps.Keys(h.channels))) + return slices.Values(sortedHubChannels(h)) } -// HubStats contains hub statistics. +// HubStats contains hub statistics, including the total subscriber count. +// stats := hub.Stats() type HubStats struct { - Clients int `json:"clients"` - Channels int `json:"channels"` + Clients int `json:"clients"` + Channels int `json:"channels"` + Subscribers int `json:"subscribers"` } -// Stats returns current hub statistics. +// Stats returns a snapshot of hub client, channel, and subscriber totals. +// +// stats := hub.Stats() func (h *Hub) Stats() HubStats { + if h == nil { + return HubStats{} + } + h.mu.RLock() defer h.mu.RUnlock() + + subscriberCount := 0 + for _, clients := range h.channels { + subscriberCount += len(clients) + } + return HubStats{ - Clients: len(h.clients), - Channels: len(h.channels), + Clients: len(h.clients), + Channels: len(h.channels), + Subscribers: subscriberCount, } } -// HandleWebSocket is an alias for Handler for clearer API. +// HandleWebSocket handles a single WebSocket upgrade request. +// +// http.HandleFunc("/ws", hub.HandleWebSocket) func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { h.Handler()(w, r) } @@ -549,12 +995,12 @@ func safeAuthenticate(auth Authenticator, r *http.Request) (result AuthResult) { } }() - return auth.Authenticate(r) + return finalizeAuthResult(auth.Authenticate(r)) } func safeClientCallback(call func()) { defer func() { - _ = recover() + recover() }() call() } @@ -569,54 +1015,193 @@ func safeAuthoriserResult(authorise func() bool) (ok bool) { return authorise() } -// Handler returns an HTTP handler for WebSocket connections. -func (h *Hub) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // Authenticate if an Authenticator is configured. - var authResult AuthResult - if h.config.Authenticator != nil { - authResult = safeAuthenticate(h.config.Authenticator, r) - if !authResult.Valid { - if h.config.OnAuthFailure != nil { - safeClientCallback(func() { - h.config.OnAuthFailure(r, authResult) - }) - } - http.Error(w, "Unauthorised", http.StatusUnauthorized) - return - } +func safeOriginCheck(checkOrigin func(*http.Request) bool, r *http.Request) (ok bool) { + defer func() { + if recover() != nil { + ok = false } + }() - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } + return checkOrigin(r) +} - client := &Client{ - hub: h, - conn: conn, - send: make(chan []byte, 256), - subscriptions: make(map[string]bool), - } +func allowAllOriginsCheck(*http.Request) bool { + return true +} - // Populate auth fields when authentication succeeded. - if h.config.Authenticator != nil { - client.UserID = authResult.UserID - client.Claims = authResult.Claims +func allowedOriginsCheck(allowedOrigins []string) func(*http.Request) bool { + allowedOrigins = slices.Clone(allowedOrigins) + if len(allowedOrigins) == 0 { + return allowAllOriginsCheck + } + + allowed := make(map[string]struct{}, len(allowedOrigins)) + for _, origin := range allowedOrigins { + allowed[origin] = struct{}{} + } + + return func(r *http.Request) bool { + if r == nil { + return false } - h.mu.RLock() - isRunning := h.running - h.mu.RUnlock() - if !isRunning { - conn.Close() + _, ok := allowed[r.Header.Get("Origin")] + return ok + } +} + +// sameOriginCheck allows requests without an Origin header and otherwise +// requires the Origin scheme and host to match the request target. +func sameOriginCheck(r *http.Request) bool { + if r == nil { + return false + } + + origin := strings.TrimSpace(r.Header.Get("Origin")) + if origin == "" { + return true + } + + originURL, err := url.Parse(origin) + if err != nil || originURL.Host == "" { + return false + } + + requestHost := strings.TrimSpace(r.Host) + if requestHost == "" && r.URL != nil { + requestHost = strings.TrimSpace(r.URL.Host) + } + if requestHost == "" { + return false + } + + requestScheme := "http" + if r.TLS != nil { + requestScheme = "https" + } + + if !strings.EqualFold(originURL.Scheme, requestScheme) { + return false + } + + originHost, originPort, ok := splitHostAndPort(originURL.Host, originURL.Scheme) + if !ok { + return false + } + + requestHostName, requestPort, ok := splitHostAndPort(requestHost, requestScheme) + if !ok { + return false + } + + return strings.EqualFold(originHost, requestHostName) && originPort == requestPort +} + +func splitHostAndPort(host string, scheme string) (string, string, bool) { + host = strings.TrimSpace(host) + if host == "" { + return "", "", false + } + + if hostname, port, err := net.SplitHostPort(host); err == nil { + if hostname == "" { + return "", "", false + } + return hostname, port, true + } + + if strings.HasPrefix(host, "[") { + trimmed := strings.TrimSuffix(strings.TrimPrefix(host, "["), "]") + if trimmed == "" { + return "", "", false + } + return trimmed, defaultPortForScheme(scheme), true + } + + if strings.Contains(host, ":") { + return "", "", false + } + + return host, defaultPortForScheme(scheme), true +} + +func defaultPortForScheme(scheme string) string { + switch strings.ToLower(strings.TrimSpace(scheme)) { + case "https", "wss": + return "443" + default: + return "80" + } +} + +// Handler returns an HTTP handler for WebSocket upgrade requests. +// +// http.HandleFunc("/ws", hub.Handler()) +func (h *Hub) Handler() http.HandlerFunc { + if h == nil { + return func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "Hub is not configured", http.StatusServiceUnavailable) + } + } + + return func(w http.ResponseWriter, r *http.Request) { + if !h.isRunning() { + http.Error(w, "Hub is not running", http.StatusServiceUnavailable) + return + } + + checkOrigin := h.config.CheckOrigin + if checkOrigin == nil { + checkOrigin = allowAllOriginsCheck + } + originAllowed := safeOriginCheck(checkOrigin, r) + if !originAllowed { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + // Authenticate only after the origin policy has accepted the request. + var authResult AuthResult + if h.config.Authenticator != nil { + authResult = safeAuthenticate(h.config.Authenticator, r) + if !authResultAccepted(authResult) { + if h.config.OnAuthFailure != nil { + safeClientCallback(func() { + h.config.OnAuthFailure(r, authResult) + }) + } + http.Error(w, "Unauthorised", http.StatusUnauthorized) + return + } + } + + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(*http.Request) bool { return originAllowed }, + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { return } + client := &Client{ + hub: h, + conn: conn, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + // Populate auth fields when authentication succeeded. + if h.config.Authenticator != nil { + client.UserID = authResult.UserID + client.Claims = authResult.Claims + } + select { case h.register <- client: case <-h.done: - conn.Close() + logCloseError("Hub.Handler", conn.Close) return } @@ -627,6 +1212,10 @@ func (h *Hub) Handler() http.HandlerFunc { // readPump handles incoming messages from the client. func (c *Client) readPump() { + if c == nil || c.hub == nil || c.conn == nil { + return + } + defer func() { if c.hub != nil { select { @@ -635,16 +1224,17 @@ func (c *Client) readPump() { } } if c.conn != nil { - c.conn.Close() + logCloseError("Client.readPump", c.conn.Close) } }() pongTimeout := c.hub.config.PongTimeout - c.conn.SetReadLimit(65536) - c.conn.SetReadDeadline(time.Now().Add(pongTimeout)) + c.conn.SetReadLimit(defaultMaxMessageBytes) + if err := c.conn.SetReadDeadline(time.Now().Add(pongTimeout)); err != nil { + return + } c.conn.SetPongHandler(func(string) error { - c.conn.SetReadDeadline(time.Now().Add(pongTimeout)) - return nil + return c.conn.SetReadDeadline(time.Now().Add(pongTimeout)) }) for { @@ -660,7 +1250,7 @@ func (c *Client) readPump() { switch msg.Type { case TypeSubscribe: - if channel, ok := msg.Data.(string); ok { + if channel := messageTargetChannel(msg); channel != "" { if err := c.hub.Subscribe(c, channel); err != nil { errMsg := mustMarshal(Message{ Type: TypeError, @@ -668,12 +1258,14 @@ func (c *Client) readPump() { Timestamp: time.Now(), }) if errMsg != nil { - _ = trySend(c.send, errMsg) + if !trySend(c.send, errMsg) { + coreerr.Warn("failed to queue websocket error message", "op", "Client.readPump") + } } } } case TypeUnsubscribe: - if channel, ok := msg.Data.(string); ok { + if channel := messageTargetChannel(msg); channel != "" { c.hub.Unsubscribe(c, channel) } case TypePing: @@ -682,27 +1274,52 @@ func (c *Client) readPump() { continue } - _ = trySend(c.send, pongMessage) + if !trySend(c.send, pongMessage) { + coreerr.Warn("failed to queue websocket pong", "op", "Client.readPump") + } } } } +// messageTargetChannel returns the subscription channel named in a client frame. +// The RFC uses the Channel field, while existing callers in this module have +// historically sent the target in Data, so both shapes are accepted. +func messageTargetChannel(msg Message) string { + if msg.Channel != "" { + return msg.Channel + } + + if channel, ok := msg.Data.(string); ok { + return channel + } + + return "" +} + // writePump sends messages to the client. func (c *Client) writePump() { + if c == nil || c.hub == nil || c.conn == nil { + return + } + heartbeat := c.hub.config.HeartbeatInterval writeTimeout := c.hub.config.WriteTimeout ticker := time.NewTicker(heartbeat) defer func() { ticker.Stop() - c.conn.Close() + logCloseError("Client.writePump", c.conn.Close) }() for { select { case message, ok := <-c.send: - c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + if err := c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + return + } if !ok { - c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil { + coreerr.Warn("failed to write websocket close message", "op", "Client.writePump", "err", err) + } return } @@ -710,20 +1327,39 @@ func (c *Client) writePump() { if err != nil { return } - w.Write(message) + closed := false + defer func() { + if !closed { + logCloseError("Client.writePump.writer", w.Close) + } + }() + if _, err := w.Write(message); err != nil { + return + } // Batch queued messages n := len(c.send) - for range n { - w.Write([]byte{'\n'}) - w.Write(<-c.send) + for i := 0; i < n; i++ { + next, ok := <-c.send + if !ok { + return + } + if _, err := w.Write([]byte{'\n'}); err != nil { + return + } + if _, err := w.Write(next); err != nil { + return + } } + closed = true if err := w.Close(); err != nil { return } case <-ticker.C: - c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + if err := c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + return + } if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } @@ -766,22 +1402,37 @@ func (c *Client) closeSend() { }) } -// Subscriptions returns a copy of the client's current subscriptions. +// Subscriptions returns a sorted snapshot of the client's channel subscriptions. +// +// subscriptions := client.Subscriptions() func (c *Client) Subscriptions() []string { + if c == nil { + return nil + } + c.mu.RLock() defer c.mu.RUnlock() - return slices.Collect(maps.Keys(c.subscriptions)) + return sortedClientSubscriptions(c) } -// AllSubscriptions returns an iterator for the client's current subscriptions. +// AllSubscriptions returns a deterministic snapshot iterator over the client's +// channel subscriptions. +// +// for channel := range client.AllSubscriptions() { _ = channel } func (c *Client) AllSubscriptions() iter.Seq[string] { + if c == nil { + return func(yield func(string) bool) {} + } + c.mu.RLock() defer c.mu.RUnlock() - return slices.Values(slices.Collect(maps.Keys(c.subscriptions))) + return slices.Values(sortedClientSubscriptions(c)) } -// Close closes the client connection. +// Close disconnects the client and unregisters it from the hub when attached. +// +// err := client.Close() func (c *Client) Close() error { if c == nil { return nil @@ -794,17 +1445,33 @@ func (c *Client) Close() error { return c.conn.Close() } - select { - case c.hub.unregister <- c: - default: + if c.hub.isRunning() { + c.hub.enqueueUnregister(c) + } else { + var disconnected bool + c.hub.mu.Lock() + if _, ok := c.hub.clients[c]; ok { + c.hub.removeClientLocked(c) + disconnected = true + } + c.hub.mu.Unlock() + + if disconnected && c.hub.config.OnDisconnect != nil { + safeClientCallback(func() { + c.hub.config.OnDisconnect(c) + }) + } } + if c.conn == nil { return nil } return c.conn.Close() } -// ReconnectConfig holds configuration for the reconnecting WebSocket client. +// ReconnectConfig configures a ReconnectingClient. +// +// client := ws.NewReconnectingClient(ws.ReconnectConfig{URL: "ws://localhost:8080/ws"}) type ReconnectConfig struct { // URL is the WebSocket server URL to connect to. URL string @@ -823,8 +1490,16 @@ type ReconnectConfig struct { // MaxRetries is the maximum number of consecutive reconnection attempts. // Zero means unlimited retries. + // + // Deprecated: use MaxReconnectAttempts. Retained for source compatibility. MaxRetries int + // MaxReconnectAttempts is the maximum number of consecutive reconnection attempts. + // Zero means unlimited retries. + // If both MaxReconnectAttempts and MaxRetries are set, MaxReconnectAttempts + // takes precedence. + MaxReconnectAttempts int + // OnConnect is called when the client successfully connects. OnConnect func() @@ -835,8 +1510,17 @@ type ReconnectConfig struct { // after a disconnection. The attempt count is passed in. OnReconnect func(attempt int) + // OnError is called when the client encounters a connection, read, + // or send error. + OnError func(err error) + // OnMessage is called when a message is received from the server. - OnMessage func(msg Message) + // Supported callback shapes are: + // - func([]byte) for raw frame payloads + // - func(Message) for decoded JSON messages + // Raw callbacks receive the frame bytes exactly as read. Message + // callbacks receive each decoded JSON object in the frame. + OnMessage any // Dialer is the WebSocket dialer to use. Defaults to websocket.DefaultDialer. Dialer *websocket.Dialer @@ -845,21 +1529,27 @@ type ReconnectConfig struct { Headers http.Header } -// ReconnectingClient is a WebSocket client that automatically reconnects -// with exponential backoff when the connection drops. +// ReconnectingClient maintains a WebSocket client connection and reconnects +// according to ReconnectConfig. +// +// client := ws.NewReconnectingClient(ws.ReconnectConfig{URL: "ws://localhost:8080/ws"}) type ReconnectingClient struct { - config ReconnectConfig - conn *websocket.Conn - send chan []byte - state ConnectionState - mu sync.RWMutex - writeMu sync.Mutex - done chan struct{} - ctx context.Context - cancel context.CancelFunc -} - -// NewReconnectingClient creates a new reconnecting WebSocket client. + config ReconnectConfig + conn *websocket.Conn + send chan []byte + state ConnectionState + mu sync.RWMutex + writeMu sync.Mutex + done chan struct{} + doneOnce sync.Once + ctx context.Context + cancel context.CancelFunc +} + +// NewReconnectingClient constructs a reconnecting client with validated +// backoff defaults. +// +// ws.NewReconnectingClient(ws.ReconnectConfig{URL: "ws://localhost:8080/ws"}) func NewReconnectingClient(config ReconnectConfig) *ReconnectingClient { if config.InitialBackoff <= 0 { config.InitialBackoff = 1 * time.Second @@ -867,7 +1557,10 @@ func NewReconnectingClient(config ReconnectConfig) *ReconnectingClient { if config.MaxBackoff <= 0 { config.MaxBackoff = 30 * time.Second } - if config.BackoffMultiplier <= 0 { + if config.InitialBackoff > config.MaxBackoff { + config.InitialBackoff = config.MaxBackoff + } + if !(config.BackoffMultiplier >= 1.0) || math.IsInf(config.BackoffMultiplier, 0) { config.BackoffMultiplier = 2.0 } if config.Dialer == nil { @@ -882,48 +1575,113 @@ func NewReconnectingClient(config ReconnectConfig) *ReconnectingClient { } } -// Connect starts the reconnecting client. It blocks until the context is -// cancelled. The client will automatically reconnect on connection loss. +// Connect starts the reconnect loop and blocks until the context is cancelled +// or the client is closed. +// +// err := client.Connect(ctx) func (rc *ReconnectingClient) Connect(ctx context.Context) error { - rc.ctx, rc.cancel = context.WithCancel(ctx) - defer rc.cancel() + if rc == nil { + return coreerr.E("ReconnectingClient.Connect", "client must not be nil", nil) + } + if ctx == nil { + ctx = context.Background() + } + + connectCtx, cancel := context.WithCancel(ctx) + rc.mu.Lock() + rc.ctx = connectCtx + rc.cancel = cancel + rc.mu.Unlock() + defer func() { + cancel() + rc.mu.Lock() + rc.ctx = nil + rc.cancel = nil + rc.mu.Unlock() + }() attempt := 0 wasConnected := false + waitBeforeDial := false for { select { - case <-rc.ctx.Done(): + case <-connectCtx.Done(): + rc.setState(StateDisconnected) + return connectCtx.Err() + case <-rc.done: rc.setState(StateDisconnected) - return rc.ctx.Err() + if err := connectCtx.Err(); err != nil { + return err + } + return nil default: } + if waitBeforeDial { + backoff := rc.calculateBackoff(attempt) + if !waitForReconnectBackoff(connectCtx, rc.done, backoff) { + rc.setState(StateDisconnected) + if err := connectCtx.Err(); err != nil { + return err + } + return nil + } + } + rc.setState(StateConnecting) attempt++ - conn, _, err := rc.config.Dialer.DialContext(rc.ctx, rc.config.URL, rc.config.Headers) + conn, _, err := rc.config.Dialer.DialContext(connectCtx, rc.config.URL, rc.config.Headers) if err != nil { - if rc.config.MaxRetries > 0 && attempt > rc.config.MaxRetries { + maxRetries := rc.maxReconnectAttempts() + if maxRetries > 0 && attempt > maxRetries { rc.setState(StateDisconnected) - return coreerr.E("ReconnectingClient.Connect", core.Sprintf("max retries (%d) exceeded", rc.config.MaxRetries), err) + wrapped := coreerr.E("ReconnectingClient.Connect", core.Sprintf("max retries (%d) exceeded", maxRetries), err) + if rc.config.OnError != nil { + safeReconnectCallback(func() { + rc.config.OnError(wrapped) + }) + } + return wrapped + } + if rc.config.OnError != nil { + safeReconnectCallback(func() { + rc.config.OnError(err) + }) } + rc.setState(StateDisconnected) backoff := rc.calculateBackoff(attempt) - select { - case <-rc.ctx.Done(): + if !waitForReconnectBackoff(connectCtx, rc.done, backoff) { rc.setState(StateDisconnected) - return rc.ctx.Err() - case <-time.After(backoff): - continue + if err := connectCtx.Err(); err != nil { + return err + } + return nil } + continue } - // Connected successfully rc.mu.Lock() rc.conn = conn rc.mu.Unlock() rc.setState(StateConnected) + connDone := make(chan struct{}) + go func(activeConn *websocket.Conn, done <-chan struct{}) { + select { + case <-connectCtx.Done(): + if activeConn != nil { + logCloseError("ReconnectingClient.Connect.context", activeConn.Close) + } + case <-rc.done: + if activeConn != nil { + logCloseError("ReconnectingClient.Connect.done", activeConn.Close) + } + case <-done: + } + }(conn, connDone) + if wasConnected { if rc.config.OnReconnect != nil { safeReconnectCallback(func() { @@ -943,34 +1701,94 @@ func (rc *ReconnectingClient) Connect(ctx context.Context) error { wasConnected = true // Run the read loop — blocks until connection drops - rc.readLoop() + readErr := rc.readLoop() + close(connDone) // Connection lost rc.mu.Lock() rc.conn = nil rc.mu.Unlock() + rc.setState(StateDisconnected) + + if rc.closeRequested() { + if rc.config.OnDisconnect != nil { + safeReconnectCallback(func() { + rc.config.OnDisconnect() + }) + } + if err := connectCtx.Err(); err != nil { + return err + } + return nil + } + + if readErr != nil && connectCtx.Err() == nil && rc.config.OnError != nil { + safeReconnectCallback(func() { + rc.config.OnError(readErr) + }) + } if rc.config.OnDisconnect != nil { safeReconnectCallback(func() { rc.config.OnDisconnect() }) } + + waitBeforeDial = true } } func safeReconnectCallback(call func()) { defer func() { - _ = recover() + recover() }() call() } -// Send sends a message to the server. Returns an error if not connected. -func (rc *ReconnectingClient) Send(msg Message) error { - msg.Timestamp = time.Now() - r := core.JSONMarshal(msg) +func marshalClientMessage(msg Message) []byte { + type clientMessage struct { + Type MessageType `json:"type"` + Channel string `json:"channel,omitempty"` + ProcessID string `json:"processId,omitempty"` + Data any `json:"data,omitempty"` + Timestamp *time.Time `json:"timestamp,omitempty"` + } + + wire := clientMessage{ + Type: msg.Type, + Channel: msg.Channel, + ProcessID: msg.ProcessID, + Data: msg.Data, + } + if !msg.Timestamp.IsZero() { + wire.Timestamp = &msg.Timestamp + } + + r := core.JSONMarshal(wire) if !r.OK { - return coreerr.E("ReconnectingClient.Send", "failed to marshal message", nil) + return nil + } + + return r.Value.([]byte) +} + +// Send writes a message to the active WebSocket connection. +// +// err := client.Send(ws.Message{Type: ws.TypeSubscribe, Channel: "notifications"}) +func (rc *ReconnectingClient) Send(msg Message) error { + if rc == nil { + return coreerr.E("ReconnectingClient.Send", "client must not be nil", nil) + } + + data := marshalClientMessage(msg) + if data == nil { + err := coreerr.E("ReconnectingClient.Send", "failed to marshal message", nil) + if rc.config.OnError != nil { + safeReconnectCallback(func() { + rc.config.OnError(err) + }) + } + return err } rc.mu.RLock() @@ -999,22 +1817,48 @@ func (rc *ReconnectingClient) Send(msg Message) error { } rc.mu.RUnlock() - return conn.WriteMessage(websocket.TextMessage, r.Value.([]byte)) + if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { + if rc.config.OnError != nil { + safeReconnectCallback(func() { + rc.config.OnError(err) + }) + } + logCloseError("ReconnectingClient.Send", conn.Close) + return err + } + + return nil } -// State returns the current connection state. +// State returns the client's current connection state. +// +// state := client.State() func (rc *ReconnectingClient) State() ConnectionState { + if rc == nil { + return StateDisconnected + } + rc.mu.RLock() defer rc.mu.RUnlock() return rc.state } -// Close gracefully shuts down the reconnecting client. +// Close stops reconnect attempts and closes the active WebSocket connection. +// +// err := client.Close() func (rc *ReconnectingClient) Close() error { + if rc == nil { + return nil + } + if rc.cancel != nil { rc.cancel() } + rc.doneOnce.Do(func() { + close(rc.done) + }) + rc.setState(StateDisconnected) rc.mu.Lock() @@ -1022,11 +1866,24 @@ func (rc *ReconnectingClient) Close() error { rc.conn = nil rc.mu.Unlock() if conn != nil { - return conn.Close() + logCloseError("ReconnectingClient.Close", conn.Close) } return nil } +func (rc *ReconnectingClient) closeRequested() bool { + if rc == nil || rc.done == nil { + return false + } + + select { + case <-rc.done: + return true + default: + return false + } +} + func (rc *ReconnectingClient) setState(state ConnectionState) { rc.mu.Lock() rc.state = state @@ -1034,39 +1891,164 @@ func (rc *ReconnectingClient) setState(state ConnectionState) { } func (rc *ReconnectingClient) calculateBackoff(attempt int) time.Duration { - backoff := rc.config.InitialBackoff - for range attempt - 1 { - backoff = time.Duration(float64(backoff) * rc.config.BackoffMultiplier) - if backoff > rc.config.MaxBackoff { - backoff = rc.config.MaxBackoff - break + if attempt <= 1 { + return rc.clampedInitialBackoff() + } + + backoff := rc.clampedInitialBackoff() + maxBackoff := rc.clampedMaxBackoff() + multiplier := rc.clampedBackoffMultiplier() + for i := 1; i < attempt; i++ { + if backoff >= maxBackoff { + return maxBackoff + } + + next := time.Duration(float64(backoff) * multiplier) + if next <= 0 || next > maxBackoff { + return maxBackoff } + backoff = next + } + + if backoff > maxBackoff { + return maxBackoff } + return backoff } -func (rc *ReconnectingClient) readLoop() { +func (rc *ReconnectingClient) clampedInitialBackoff() time.Duration { + backoff := rc.config.InitialBackoff + if backoff <= 0 { + backoff = 1 * time.Second + } + maxBackoff := rc.clampedMaxBackoff() + if backoff > maxBackoff { + return maxBackoff + } + return backoff +} + +func (rc *ReconnectingClient) clampedMaxBackoff() time.Duration { + maxBackoff := rc.config.MaxBackoff + if maxBackoff <= 0 { + maxBackoff = 30 * time.Second + } + return maxBackoff +} + +func (rc *ReconnectingClient) clampedBackoffMultiplier() float64 { + multiplier := rc.config.BackoffMultiplier + if !(multiplier >= 1.0) || math.IsInf(multiplier, 0) { + multiplier = 2.0 + } + return multiplier +} + +func waitForReconnectBackoff(ctx context.Context, done <-chan struct{}, delay time.Duration) bool { + if delay <= 0 { + return true + } + + timer := time.NewTimer(delay) + defer stopTimer(timer) + + select { + case <-ctx.Done(): + return false + case <-done: + return false + case <-timer.C: + return true + } +} + +func stopTimer(timer *time.Timer) { + if timer == nil { + return + } + + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } +} + +func (rc *ReconnectingClient) maxReconnectAttempts() int { + maxRetries := rc.config.MaxReconnectAttempts + if maxRetries == 0 { + maxRetries = rc.config.MaxRetries + } + if maxRetries < 0 { + return 0 + } + return maxRetries +} + +func (rc *ReconnectingClient) readLoop() error { rc.mu.RLock() conn := rc.conn rc.mu.RUnlock() if conn == nil { - return + return nil } + conn.SetReadLimit(defaultMaxMessageBytes) + for { _, data, err := conn.ReadMessage() if err != nil { - return + return err } if rc.config.OnMessage != nil { + dispatchReconnectMessage(rc.config.OnMessage, data) + } + } +} + +func dispatchReconnectMessage(handler any, data []byte) { + switch fn := handler.(type) { + case nil: + return + case func([]byte): + if fn == nil { + return + } + safeReconnectCallback(func() { + fn(data) + }) + case func(Message): + if fn == nil { + return + } + frames := strings.Split(string(data), "\n") + for _, frame := range frames { + frame = strings.TrimSpace(frame) + if frame == "" { + continue + } + var msg Message - if r := core.JSONUnmarshal(data, &msg); r.OK { - safeReconnectCallback(func() { - rc.config.OnMessage(msg) - }) + if r := core.JSONUnmarshal([]byte(frame), &msg); !r.OK { + continue } + + safeReconnectCallback(func() { + fn(msg) + }) } + case func(string): + if fn == nil { + return + } + safeReconnectCallback(func() { + fn(string(data)) + }) + default: + return } } diff --git a/ws_bench_test.go b/ws_bench_test.go index 999253d..c4c4554 100644 --- a/ws_bench_test.go +++ b/ws_bench_test.go @@ -4,10 +4,11 @@ package ws import ( "net/http/httptest" + // Note: AX-6 — internal concurrency primitive; structural for go-ws hub state (RFC mandates concurrent connection map). "sync" "testing" - core "dappco.re/go/core" + core "dappco.re/go" "github.com/gorilla/websocket" ) @@ -64,7 +65,7 @@ func BenchmarkSendToChannel_50(b *testing.B) { hub.mu.Lock() hub.clients[client] = true hub.mu.Unlock() - hub.Subscribe(client, "bench-channel") + _ = hub.Subscribe(client, "bench-channel") } msg := Message{Type: TypeEvent, Data: "bench-chan"} @@ -140,7 +141,7 @@ func BenchmarkWebSocketEndToEnd(b *testing.B) { if err != nil { b.Fatalf("dial failed: %v", err) } - defer conn.Close() + defer testClose(b, conn.Close) for hub.ClientCount() < 1 { } @@ -177,7 +178,7 @@ func BenchmarkSubscribeUnsubscribe(b *testing.B) { b.ReportAllocs() for b.Loop() { - hub.Subscribe(client, "bench-sub") + _ = hub.Subscribe(client, "bench-sub") hub.Unsubscribe(client, "bench-sub") } } @@ -199,7 +200,7 @@ func BenchmarkSendToChannel_Parallel(b *testing.B) { hub.mu.Lock() hub.clients[clients[i]] = true hub.mu.Unlock() - hub.Subscribe(clients[i], "parallel-chan") + _ = hub.Subscribe(clients[i], "parallel-chan") } msg := Message{Type: TypeEvent, Data: "p-bench"} @@ -236,7 +237,7 @@ func BenchmarkMultiChannelFanout(b *testing.B) { hub.mu.Lock() hub.clients[client] = true hub.mu.Unlock() - hub.Subscribe(client, channels[ch]) + _ = hub.Subscribe(client, channels[ch]) } } @@ -268,7 +269,7 @@ func BenchmarkConcurrentSubscribers(b *testing.B) { send: make(chan []byte, 1), subscriptions: make(map[string]bool), } - hub.Subscribe(client, "conc-sub-bench") + _ = hub.Subscribe(client, "conc-sub-bench") }) } wg.Wait() diff --git a/ws_test.go b/ws_test.go index bd5a1c6..c4a9b0c 100644 --- a/ws_test.go +++ b/ws_test.go @@ -3,20 +3,24 @@ package ws import ( + "bytes" "context" + "crypto/tls" + "math" "net" "net/http" "net/http/httptest" "slices" "strings" + // Note: AX-6 — internal concurrency primitive; structural for go-ws hub state (RFC mandates concurrent connection map). "sync" + "sync/atomic" "testing" "time" - core "dappco.re/go/core" + core "dappco.re/go" + coreerr "dappco.re/go/log" "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // wsURL converts an httptest server URL to a WebSocket URL. @@ -24,19 +28,254 @@ func wsURL(server *httptest.Server) string { return "ws" + core.TrimPrefix(server.URL, "http") } +func originRequest(origin string) *http.Request { + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + if origin != "" { + r.Header.Set("Origin", origin) + } + return r +} + func TestNewHub(t *testing.T) { t.Run("creates hub with initialised maps", func(t *testing.T) { hub := NewHub() + if testIsNil(hub) { + t.Fatalf("expected non-nil value") + } + if testIsNil(hub.clients) { + t.Errorf("expected non-nil value") + } + if testIsNil(hub.broadcast) { + t.Errorf("expected non-nil value") + } + if testIsNil(hub.register) { + t.Errorf("expected non-nil value") + } + if testIsNil(hub.unregister) { + t.Errorf("expected non-nil value") + } + if testIsNil(hub.channels) { + t.Errorf("expected non-nil value") + } + + }) +} + +func TestWs_AllowedOrigins_Good(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + AllowedOrigins: []string{ + "https://app.example", + "https://admin.example", + }, + }) + if testIsNil(hub) { + t.Fatalf("expected non-nil value") + } + if testIsNil(hub.config.CheckOrigin) { + t.Fatalf("expected non-nil value") + } + if !(hub.config.CheckOrigin(originRequest("https://app.example"))) { + t.Errorf("expected true") + } + if !(hub.config.CheckOrigin(originRequest("https://admin.example"))) { + t.Errorf("expected true") + } + +} + +func TestWs_AllowedOrigins_Bad(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + AllowedOrigins: []string{"https://app.example"}, + }) + if testIsNil(hub) { + t.Fatalf("expected non-nil value") + } + if testIsNil(hub.config.CheckOrigin) { + t.Fatalf("expected non-nil value") + } + if hub.config.CheckOrigin(originRequest("https://evil.example")) { + t.Errorf("expected false") + } + if hub.config.CheckOrigin(originRequest("")) { + t.Errorf("expected false") + } + +} + +func TestWs_AllowedOrigins_Ugly(t *testing.T) { + var logs bytes.Buffer + originalLogger := coreerr.Default() + coreerr.SetDefault(coreerr.New(coreerr.Options{ + Level: coreerr.LevelWarn, + Output: &logs, + })) + t.Cleanup(func() { + coreerr.SetDefault(originalLogger) + }) + + hub := NewHub() + if testIsNil(hub) { + t.Fatalf("expected non-nil value") + } + if testIsNil(hub.config.CheckOrigin) { + t.Fatalf("expected non-nil value") + } + if !testIsEmpty(hub.config.AllowedOrigins) { + t.Errorf("expected empty value, got %v", hub.config.AllowedOrigins) + } + if !(hub.config.CheckOrigin(originRequest("https://evil.example"))) { + t.Errorf("expected true") + } + if !testContains(logs.String(), "HubConfig.AllowedOrigins") { + t.Errorf("expected %v to contain %v", logs.String(), "HubConfig.AllowedOrigins") + } + +} + +func TestWs_validIdentifier_Good(t *testing.T) { + tests := []struct { + name string + value string + max int + }{ + {name: "simple", value: "alpha", max: 10}, + {name: "safe token", value: "A-Z_0-9-.:", max: 20}, + {name: "exact max length", value: strings.Repeat("a", 8), max: 8}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !(validIdentifier(tt.value, tt.max)) { + t.Errorf("expected true") + } + + }) + } +} + +func TestWs_validIdentifier_Bad(t *testing.T) { + tests := []struct { + name string + value string + max int + }{ + {name: "empty", value: "", max: 8}, + {name: "whitespace padded", value: " alpha", max: 8}, + {name: "embedded whitespace", value: "al pha", max: 8}, + {name: "too long", value: strings.Repeat("a", 9), max: 8}, + {name: "non-ascii", value: "grüße", max: 16}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if validIdentifier(tt.value, tt.max) { + t.Errorf("expected false") + } + + }) + } +} + +func TestWs_validIdentifier_Ugly(t *testing.T) { + if validIdentifier(strings.Repeat(" ", 4), 8) { + t.Errorf("expected false") + } + if validIdentifier("line\nbreak", 16) { + t.Errorf("expected false") + } + if validIdentifier("\tindent", 16) { + t.Errorf("expected false") + } + +} + +func TestWs_validateChannelTarget(t *testing.T) { + t.Run("accepts regular channels", func(t *testing.T) { + if err := validateChannelTarget("test", "events:user-1"); err != nil { + t.Errorf("expected no error, got %v", err) + } + + }) + + t.Run("accepts process channels with bounded IDs", func(t *testing.T) { + if err := validateChannelTarget("test", "process:proc-123"); err != nil { + t.Errorf("expected no error, got %v", err) + } + + }) + + t.Run("rejects process channels with empty IDs", func(t *testing.T) { + err := validateChannelTarget("test", "process:") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } + + }) + + t.Run("rejects process channels with oversized IDs", func(t *testing.T) { + err := validateChannelTarget("test", "process:"+strings.Repeat("a", maxProcessIDLen+1)) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } - require.NotNil(t, hub) - assert.NotNil(t, hub.clients) - assert.NotNil(t, hub.broadcast) - assert.NotNil(t, hub.register) - assert.NotNil(t, hub.unregister) - assert.NotNil(t, hub.channels) }) } +func TestWs_validProcessID_Good(t *testing.T) { + tests := []string{ + "proc-123", + "proc_123", + "proc.123", + strings.Repeat("a", maxProcessIDLen), + } + + for _, processID := range tests { + t.Run(processID, func(t *testing.T) { + if !(validProcessID(processID)) { + t.Errorf("expected true") + } + + }) + } +} + +func TestWs_validProcessID_Bad(t *testing.T) { + tests := []string{ + "", + "bad process", + "proc:123", + "grüße", + } + + for _, processID := range tests { + t.Run(processID, func(t *testing.T) { + if validProcessID(processID) { + t.Errorf("expected false") + } + + }) + } +} + +func TestWs_validProcessID_Ugly(t *testing.T) { + if validProcessID(" proc-123 ") { + t.Errorf("expected false") + } + if validProcessID(strings.Repeat("a", maxProcessIDLen+1)) { + t.Errorf("expected false") + } + if validProcessID("line\nbreak") { + t.Errorf("expected false") + } + +} + func TestHub_Run(t *testing.T) { t.Run("stops on context cancel", func(t *testing.T) { hub := NewHub() @@ -59,6 +298,37 @@ func TestHub_Run(t *testing.T) { }) } +func TestWs_Run_NilClientEvents_Good(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + + go func() { + hub.Run(ctx) + close(done) + }() + + hub.register <- nil + hub.unregister <- nil + + time.Sleep(20 * time.Millisecond) + cancel() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("hub should stop after context cancel") + } +} + +func TestWs_Run_Ugly(t *testing.T) { + testNotPanics(t, func() { + var hub *Hub + hub.Run(context.Background()) + }) + +} + func TestHub_Broadcast(t *testing.T) { t.Run("marshals message with timestamp", func(t *testing.T) { hub := NewHub() @@ -71,7 +341,10 @@ func TestHub_Broadcast(t *testing.T) { } err := hub.Broadcast(msg) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + }) t.Run("returns error when channel full", func(t *testing.T) { @@ -82,8 +355,13 @@ func TestHub_Broadcast(t *testing.T) { } err := hub.Broadcast(Message{Type: TypeEvent}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "broadcast channel full") + if err := err; err == nil { + t.Errorf("expected error") + } + if !testContains(err.Error(), "broadcast channel full") { + t.Errorf("expected %v to contain %v", err.Error(), "broadcast channel full") + } + }) } @@ -92,34 +370,59 @@ func TestHub_Stats(t *testing.T) { hub := NewHub() stats := hub.Stats() + if !testEqual(0, stats.Clients) { + t.Errorf("expected %v, got %v", 0, stats.Clients) + } + if !testEqual(0, stats.Channels) { + t.Errorf("expected %v, got %v", 0, stats.Channels) + } + if !testEqual(0, stats.Subscribers) { + t.Errorf("expected %v, got %v", + + // Manually add clients for testing + 0, stats.Subscribers) + } - assert.Equal(t, 0, stats.Clients) - assert.Equal(t, 0, stats.Channels) }) t.Run("tracks client and channel counts", func(t *testing.T) { hub := NewHub() - // Manually add clients for testing hub.mu.Lock() client1 := &Client{subscriptions: make(map[string]bool)} client2 := &Client{subscriptions: make(map[string]bool)} hub.clients[client1] = true hub.clients[client2] = true - hub.channels["test-channel"] = make(map[*Client]bool) + hub.channels["test-channel"] = map[*Client]bool{ + client1: true, + client2: true, + } + hub.channels["other-channel"] = map[*Client]bool{ + client1: true, + } hub.mu.Unlock() stats := hub.Stats() + if !testEqual(2, stats.Clients) { + t.Errorf("expected %v, got %v", 2, stats.Clients) + } + if !testEqual(2, stats.Channels) { + t.Errorf("expected %v, got %v", 2, stats.Channels) + } + if !testEqual(3, stats.Subscribers) { + t.Errorf("expected %v, got %v", 3, stats.Subscribers) + } - assert.Equal(t, 2, stats.Clients) - assert.Equal(t, 1, stats.Channels) }) } func TestHub_ClientCount(t *testing.T) { t.Run("returns zero for empty hub", func(t *testing.T) { hub := NewHub() - assert.Equal(t, 0, hub.ClientCount()) + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + }) t.Run("counts connected clients", func(t *testing.T) { @@ -129,15 +432,20 @@ func TestHub_ClientCount(t *testing.T) { hub.clients[&Client{}] = true hub.clients[&Client{}] = true hub.mu.Unlock() + if !testEqual(2, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 2, hub.ClientCount()) + } - assert.Equal(t, 2, hub.ClientCount()) }) } func TestHub_ChannelCount(t *testing.T) { t.Run("returns zero for empty hub", func(t *testing.T) { hub := NewHub() - assert.Equal(t, 0, hub.ChannelCount()) + if !testEqual(0, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 0, hub.ChannelCount()) + } + }) t.Run("counts active channels", func(t *testing.T) { @@ -147,15 +455,20 @@ func TestHub_ChannelCount(t *testing.T) { hub.channels["channel1"] = make(map[*Client]bool) hub.channels["channel2"] = make(map[*Client]bool) hub.mu.Unlock() + if !testEqual(2, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 2, hub.ChannelCount()) + } - assert.Equal(t, 2, hub.ChannelCount()) }) } func TestHub_ChannelSubscriberCount(t *testing.T) { t.Run("returns zero for non-existent channel", func(t *testing.T) { hub := NewHub() - assert.Equal(t, 0, hub.ChannelSubscriberCount("non-existent")) + if !testEqual(0, hub.ChannelSubscriberCount("non-existent")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("non-existent")) + } + }) t.Run("counts subscribers in channel", func(t *testing.T) { @@ -166,8 +479,10 @@ func TestHub_ChannelSubscriberCount(t *testing.T) { hub.channels["test-channel"][&Client{}] = true hub.channels["test-channel"][&Client{}] = true hub.mu.Unlock() + if !testEqual(2, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", 2, hub.ChannelSubscriberCount("test-channel")) + } - assert.Equal(t, 2, hub.ChannelSubscriberCount("test-channel")) }) } @@ -184,10 +499,16 @@ func TestHub_Subscribe(t *testing.T) { hub.mu.Unlock() err := hub.Subscribe(client, "test-channel") - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(1, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("test-channel")) + } + if !(client.subscriptions["test-channel"]) { + t.Errorf("expected true") + } - assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) - assert.True(t, client.subscriptions["test-channel"]) }) t.Run("creates channel if not exists", func(t *testing.T) { @@ -198,13 +519,51 @@ func TestHub_Subscribe(t *testing.T) { } err := hub.Subscribe(client, "new-channel") - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } hub.mu.RLock() _, exists := hub.channels["new-channel"] hub.mu.RUnlock() + if !(exists) { + t.Errorf("expected true") + } + + }) + + t.Run("rejects invalid channel names", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + err := hub.Subscribe(client, "bad channel") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid channel name") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid channel name") + } + + }) + + t.Run("rejects process channels with oversized IDs", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + err := hub.Subscribe(client, "process:"+strings.Repeat("a", maxProcessIDLen+1)) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } - assert.True(t, exists) }) } @@ -216,12 +575,19 @@ func TestHub_Unsubscribe(t *testing.T) { subscriptions: make(map[string]bool), } - hub.Subscribe(client, "test-channel") - assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + _ = hub.Subscribe(client, "test-channel") + if !testEqual(1, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("test-channel")) + } hub.Unsubscribe(client, "test-channel") - assert.Equal(t, 0, hub.ChannelSubscriberCount("test-channel")) - assert.False(t, client.subscriptions["test-channel"]) + if !testEqual(0, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("test-channel")) + } + if client.subscriptions["test-channel"] { + t.Errorf("expected false") + } + }) t.Run("cleans up empty channels", func(t *testing.T) { @@ -231,14 +597,16 @@ func TestHub_Unsubscribe(t *testing.T) { subscriptions: make(map[string]bool), } - hub.Subscribe(client, "temp-channel") + _ = hub.Subscribe(client, "temp-channel") hub.Unsubscribe(client, "temp-channel") hub.mu.RLock() _, exists := hub.channels["temp-channel"] hub.mu.RUnlock() + if exists { + t.Errorf("expected false") + } - assert.False(t, exists, "empty channel should be removed") }) t.Run("handles non-existent channel gracefully", func(t *testing.T) { @@ -265,20 +633,29 @@ func TestHub_SendToChannel(t *testing.T) { hub.mu.Lock() hub.clients[client] = true hub.mu.Unlock() - hub.Subscribe(client, "test-channel") + _ = hub.Subscribe(client, "test-channel") err := hub.SendToChannel("test-channel", Message{ Type: TypeEvent, Data: "test", }) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } select { case msg := <-client.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, TypeEvent, received.Type) - assert.Equal(t, "test-channel", received.Channel) + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeEvent, received.Type) { + t.Errorf("expected %v, got %v", TypeEvent, received.Type) + } + if !testEqual("test-channel", received.Channel) { + t.Errorf("expected %v, got %v", "test-channel", received.Channel) + } + case <-time.After(time.Second): t.Fatal("expected message on client send channel") } @@ -288,7 +665,36 @@ func TestHub_SendToChannel(t *testing.T) { hub := NewHub() err := hub.SendToChannel("non-existent", Message{Type: TypeEvent}) - assert.NoError(t, err, "should not error for non-existent channel") + if err := err; err != nil { + t.Errorf("expected no error, got %v", err) + } + + }) + + t.Run("rejects invalid channel names", func(t *testing.T) { + hub := NewHub() + + err := hub.SendToChannel("bad channel", Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid channel name") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid channel name") + } + + }) + + t.Run("rejects process channels with empty IDs", func(t *testing.T) { + hub := NewHub() + + err := hub.SendToChannel("process:", Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } + }) } @@ -304,22 +710,46 @@ func TestHub_SendProcessOutput(t *testing.T) { hub.mu.Lock() hub.clients[client] = true hub.mu.Unlock() - hub.Subscribe(client, "process:proc-1") + _ = hub.Subscribe(client, "process:proc-1") err := hub.SendProcessOutput("proc-1", "hello world") - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } select { case msg := <-client.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, TypeProcessOutput, received.Type) - assert.Equal(t, "proc-1", received.ProcessID) - assert.Equal(t, "hello world", received.Data) + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeProcessOutput, received.Type) { + t.Errorf("expected %v, got %v", TypeProcessOutput, received.Type) + } + if !testEqual("proc-1", received.ProcessID) { + t.Errorf("expected %v, got %v", "proc-1", received.ProcessID) + } + if !testEqual("hello world", received.Data) { + t.Errorf("expected %v, got %v", "hello world", received.Data) + } + case <-time.After(time.Second): t.Fatal("expected message on client send channel") } }) + + t.Run("rejects invalid process IDs", func(t *testing.T) { + hub := NewHub() + + err := hub.SendProcessOutput("bad process", "hello world") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } + + }) } func TestHub_SendProcessStatus(t *testing.T) { @@ -334,26 +764,54 @@ func TestHub_SendProcessStatus(t *testing.T) { hub.mu.Lock() hub.clients[client] = true hub.mu.Unlock() - hub.Subscribe(client, "process:proc-1") + _ = hub.Subscribe(client, "process:proc-1") err := hub.SendProcessStatus("proc-1", "exited", 0) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } select { case msg := <-client.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, TypeProcessStatus, received.Type) - assert.Equal(t, "proc-1", received.ProcessID) + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeProcessStatus, received.Type) { + t.Errorf("expected %v, got %v", TypeProcessStatus, received.Type) + } + if !testEqual("proc-1", received.ProcessID) { + t.Errorf("expected %v, got %v", "proc-1", received.ProcessID) + } data, ok := received.Data.(map[string]any) - require.True(t, ok) - assert.Equal(t, "exited", data["status"]) - assert.Equal(t, float64(0), data["exitCode"]) + if !(ok) { + t.Fatalf("expected true") + } + if !testEqual("exited", data["status"]) { + t.Errorf("expected %v, got %v", "exited", data["status"]) + } + if !testEqual(float64(0), data["exitCode"]) { + t.Errorf("expected %v, got %v", float64(0), data["exitCode"]) + } + case <-time.After(time.Second): t.Fatal("expected message on client send channel") } }) + + t.Run("rejects invalid process IDs", func(t *testing.T) { + hub := NewHub() + + err := hub.SendProcessStatus("bad process", "exited", 1) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } + + }) } func TestHub_SendError(t *testing.T) { @@ -373,22 +831,31 @@ func TestHub_SendError(t *testing.T) { time.Sleep(10 * time.Millisecond) err := hub.SendError("something went wrong") - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } select { case msg := <-client.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, TypeError, received.Type) - assert.Equal(t, "something went wrong", received.Data) + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeError, received.Type) { + t.Errorf("expected %v, got %v", TypeError, received.Type) + } + if !testEqual("something went wrong", received.Data) { + t.Errorf("expected %v, got %v", "something went wrong", received.Data) + } + case <-time.After(time.Second): t.Fatal("expected error message on client send channel") } }) } -func TestHub_SendEvent(t *testing.T) { - t.Run("broadcasts event message", func(t *testing.T) { +func TestHub_Broadcast_AssignsTimestampAndValidatesProcessID(t *testing.T) { + t.Run("assigns a fresh timestamp", func(t *testing.T) { hub := NewHub() ctx := t.Context() go hub.Run(ctx) @@ -402,43 +869,188 @@ func TestHub_SendEvent(t *testing.T) { hub.register <- client time.Sleep(10 * time.Millisecond) - err := hub.SendEvent("user_joined", map[string]string{"user": "alice"}) - require.NoError(t, err) + before := time.Now() + err := hub.Broadcast(Message{ + Type: TypeEvent, + ProcessID: "proc-1", + Data: "hello", + Timestamp: time.Date(2024, time.January, 2, 3, 4, 5, 0, time.UTC), + }) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } select { case msg := <-client.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, TypeEvent, received.Type) + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if received.Timestamp.Before(before) { + t.Errorf("expected false") + } - data, ok := received.Data.(map[string]any) - require.True(t, ok) - assert.Equal(t, "user_joined", data["event"]) case <-time.After(time.Second): - t.Fatal("expected event message on client send channel") + t.Fatal("expected message on client send channel") + } + }) + + t.Run("rejects invalid process IDs", func(t *testing.T) { + hub := NewHub() + + err := hub.Broadcast(Message{ + Type: TypeEvent, + ProcessID: "bad process", + }) + if err := err; err == nil { + t.Fatalf("expected error") } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } + }) } -func TestClient_Subscriptions(t *testing.T) { - t.Run("returns copy of subscriptions", func(t *testing.T) { +func TestHub_SendToChannel_AssignsTimestampAndValidatesProcessID(t *testing.T) { + t.Run("assigns a fresh timestamp", func(t *testing.T) { hub := NewHub() client := &Client{ hub: hub, + send: make(chan []byte, 256), subscriptions: make(map[string]bool), } - hub.Subscribe(client, "channel1") - hub.Subscribe(client, "channel2") - + hub.mu.Lock() + hub.clients[client] = true + hub.mu.Unlock() + if err := hub.Subscribe(client, "events"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + before := time.Now() + err := hub.SendToChannel("events", Message{ + Type: TypeEvent, + ProcessID: "proc-1", + Data: "hello", + Timestamp: time.Date(2024, time.February, 3, 4, 5, 6, 0, time.UTC), + }) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case msg := <-client.send: + var received Message + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if received.Timestamp.Before(before) { + t.Errorf("expected false") + } + if !testEqual("events", received.Channel) { + t.Errorf("expected %v, got %v", "events", received.Channel) + } + + case <-time.After(time.Second): + t.Fatal("expected message on client send channel") + } + }) + + t.Run("rejects invalid process IDs", func(t *testing.T) { + hub := NewHub() + + err := hub.SendToChannel("events", Message{ + Type: TypeEvent, + ProcessID: "bad process", + }) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid process ID") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid process ID") + } + + }) +} + +func TestHub_SendEvent(t *testing.T) { + t.Run("broadcasts event message", func(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) + + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + hub.register <- client + time.Sleep(10 * time.Millisecond) + + err := hub.SendEvent("user_joined", map[string]string{"user": "alice"}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case msg := <-client.send: + var received Message + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeEvent, received.Type) { + t.Errorf("expected %v, got %v", TypeEvent, received.Type) + } + + data, ok := received.Data.(map[string]any) + if !(ok) { + t.Fatalf("expected true") + } + if !testEqual("user_joined", data["event"]) { + t.Errorf("expected %v, got %v", "user_joined", data["event"]) + } + + case <-time.After(time.Second): + t.Fatal("expected event message on client send channel") + } + }) +} + +func TestClient_Subscriptions(t *testing.T) { + t.Run("returns copy of subscriptions", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + _ = hub.Subscribe(client, "channel1") + _ = hub.Subscribe(client, "channel2") + subs := client.Subscriptions() + if gotLen := len(subs); gotLen != 2 { + t.Errorf("expected length %v, got %v", 2, gotLen) + } + if !testContains(subs, "channel1") { + t.Errorf("expected %v to contain %v", subs, "channel1") + } + if !testContains(subs, "channel2") { + t.Errorf("expected %v to contain %v", subs, "channel2") + } - assert.Len(t, subs, 2) - assert.Contains(t, subs, "channel1") - assert.Contains(t, subs, "channel2") }) } +func TestClient_Subscriptions_Ugly(t *testing.T) { + var client *Client + if !testIsNil(client.Subscriptions()) { + t.Errorf("expected nil, got %T", client.Subscriptions()) + } + +} + func TestClient_AllSubscriptions(t *testing.T) { t.Run("returns iterator over subscriptions", func(t *testing.T) { client := &Client{subscriptions: make(map[string]bool)} @@ -446,10 +1058,27 @@ func TestClient_AllSubscriptions(t *testing.T) { client.subscriptions["sub2"] = true subs := slices.Collect(client.AllSubscriptions()) - assert.Len(t, subs, 2) - assert.Contains(t, subs, "sub1") - assert.Contains(t, subs, "sub2") + if gotLen := len(subs); gotLen != 2 { + t.Errorf("expected length %v, got %v", 2, gotLen) + } + if !testContains(subs, "sub1") { + t.Errorf("expected %v to contain %v", subs, "sub1") + } + if !testContains(subs, "sub2") { + t.Errorf("expected %v to contain %v", subs, "sub2") + } + + }) +} + +func TestClient_AllSubscriptions_Ugly(t *testing.T) { + var client *Client + testNotPanics(t, func() { + if !testIsEmpty(slices.Collect(client.AllSubscriptions())) { + t.Errorf("expected empty value, got %v", slices.Collect(client.AllSubscriptions())) + } }) + } func TestHub_AllClients(t *testing.T) { @@ -464,9 +1093,16 @@ func TestHub_AllClients(t *testing.T) { hub.mu.Unlock() clients := slices.Collect(hub.AllClients()) - assert.Len(t, clients, 2) - assert.Contains(t, clients, client1) - assert.Contains(t, clients, client2) + if gotLen := len(clients); gotLen != 2 { + t.Errorf("expected length %v, got %v", 2, gotLen) + } + if !testContains(clients, client1) { + t.Errorf("expected %v to contain %v", clients, client1) + } + if !testContains(clients, client2) { + t.Errorf("expected %v to contain %v", clients, client2) + } + }) } @@ -479,12 +1115,260 @@ func TestHub_AllChannels(t *testing.T) { hub.mu.Unlock() channels := slices.Collect(hub.AllChannels()) - assert.Len(t, channels, 2) - assert.Contains(t, channels, "ch1") - assert.Contains(t, channels, "ch2") + if gotLen := len(channels); gotLen != 2 { + t.Errorf("expected length %v, got %v", 2, gotLen) + } + if !testContains(channels, "ch1") { + t.Errorf("expected %v to contain %v", channels, "ch1") + } + if !testContains(channels, "ch2") { + t.Errorf("expected %v to contain %v", channels, "ch2") + } + }) } +func TestWs_sortedHubClients_Good(t *testing.T) { + hub := NewHub() + clients := []*Client{ + {UserID: "bravo"}, + nil, + {UserID: "alpha"}, + } + + hub.mu.Lock() + for _, client := range clients { + hub.clients[client] = true + } + hub.mu.Unlock() + + ordered := slices.Collect(hub.AllClients()) + if gotLen := len(ordered); gotLen != 3 { + t.Fatalf("expected length %v, got %v", 3, gotLen) + } + if !testIsNil(ordered[0]) { + t.Errorf("expected nil, got %T", ordered[0]) + } + if !testEqual("alpha", ordered[1].UserID) { + t.Errorf("expected %v, got %v", "alpha", ordered[1].UserID) + } + if !testEqual("bravo", ordered[2].UserID) { + t.Errorf("expected %v, got %v", "bravo", ordered[2].UserID) + } + if !testEqual("", clientSortKey(&Client{})) { + t.Errorf("expected %v, got %v", "", clientSortKey(&Client{})) + } + +} + +func TestWs_sortedHubClients_Bad(t *testing.T) { + hub := NewHub() + if !testIsEmpty(sortedHubClients(hub)) { + t.Errorf("expected empty value, got %v", sortedHubClients(hub)) + } + +} + +func TestWs_sortedHubClients_Ugly(t *testing.T) { + if !testIsNil(sortedHubClients(nil)) { + t.Errorf("expected nil, got %T", sortedHubClients(nil)) + } + +} + +func TestWs_sortedHubClients_Good_SameUserID(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + serverA := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + time.Sleep(50 * time.Millisecond) + })) + defer serverA.Close() + + serverB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + time.Sleep(50 * time.Millisecond) + })) + defer serverB.Close() + + left, _, err := websocket.DefaultDialer.Dial(wsURL(serverA), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, left.Close) + right, _, err := websocket.DefaultDialer.Dial(wsURL(serverB), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, right.Close) + + hub := NewHub() + leftClient := &Client{UserID: "shared", conn: left} + rightClient := &Client{UserID: "shared", conn: right} + + hub.mu.Lock() + hub.clients[leftClient] = true + hub.clients[rightClient] = true + hub.mu.Unlock() + + ordered := sortedHubClients(hub) + if gotLen := len(ordered); gotLen != 2 { + t.Fatalf("expected length %v, got %v", 2, gotLen) + } + if !testEqual("shared", ordered[0].UserID) { + t.Errorf("expected %v, got %v", "shared", ordered[0].UserID) + } + if !testEqual("shared", ordered[1].UserID) { + t.Errorf("expected %v, got %v", "shared", ordered[1].UserID) + } + if testEqual(clientSortKey(ordered[0]), clientSortKey(ordered[1])) { + t.Errorf("expected values to differ: %v", clientSortKey(ordered[1])) + } + +} + +func TestWs_sortedClientSubscriptions_Good(t *testing.T) { + client := &Client{ + subscriptions: map[string]bool{ + "zeta": true, + "alpha": true, + "mu": true, + }, + } + if !testEqual([]string{"alpha", "mu", "zeta"}, sortedClientSubscriptions(client)) { + t.Errorf("expected %v, got %v", []string{"alpha", "mu", "zeta"}, sortedClientSubscriptions(client)) + } + +} + +func TestWs_sortedClientSubscriptions_Bad(t *testing.T) { + client := &Client{subscriptions: map[string]bool{}} + if !testIsEmpty(sortedClientSubscriptions(client)) { + t.Errorf("expected empty value, got %v", sortedClientSubscriptions(client)) + } + +} + +func TestWs_sortedClientSubscriptions_Ugly(t *testing.T) { + if !testIsNil(sortedClientSubscriptions(nil)) { + t.Errorf("expected nil, got %T", sortedClientSubscriptions(nil)) + } + +} + +func TestWs_sortedHubChannels_Good(t *testing.T) { + hub := NewHub() + hub.channels["zeta"] = map[*Client]bool{} + hub.channels["alpha"] = map[*Client]bool{} + hub.channels["mu"] = map[*Client]bool{} + if !testEqual([]string{"alpha", "mu", "zeta"}, sortedHubChannels(hub)) { + t.Errorf("expected %v, got %v", []string{"alpha", "mu", "zeta"}, sortedHubChannels(hub)) + } + +} + +func TestWs_sortedHubChannels_Bad(t *testing.T) { + hub := NewHub() + if !testIsEmpty(sortedHubChannels(hub)) { + t.Errorf("expected empty value, got %v", sortedHubChannels(hub)) + } + +} + +func TestWs_sortedHubChannels_Ugly(t *testing.T) { + if !testIsNil(sortedHubChannels(nil)) { + t.Errorf("expected nil, got %T", sortedHubChannels(nil)) + } + +} + +func TestWs_clientSortKey_Good(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + time.Sleep(50 * time.Millisecond) + })) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + + client := &Client{conn: conn} + if testIsEmpty(clientSortKey(client)) { + t.Errorf("expected non-empty value") + } + +} + +func TestWs_clientSortKey_Bad(t *testing.T) { + if !testEqual("", clientSortKey(nil)) { + t.Errorf("expected %v, got %v", "", clientSortKey(nil)) + } + +} + +func TestWs_clientSortKey_Ugly(t *testing.T) { + if !testEqual("", clientSortKey(&Client{})) { + t.Errorf("expected %v, got %v", "", clientSortKey(&Client{})) + } + +} + +func TestWs_subscribeLocked_Good(t *testing.T) { + hub := NewHubWithConfig(HubConfig{MaxSubscriptionsPerClient: 1}) + client := &Client{} + if err := hub.subscribeLocked(client, "alpha"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !(client.subscriptions["alpha"]) { + t.Errorf("expected true") + } + if !testEqual(1, hub.ChannelSubscriberCount("alpha")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("alpha")) + } + +} + +func TestWs_subscribeLocked_Bad(t *testing.T) { + hub := NewHubWithConfig(HubConfig{MaxSubscriptionsPerClient: 1}) + client := &Client{subscriptions: map[string]bool{"alpha": true}} + if err := hub.subscribeLocked(client, "alpha"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(1, hub.ChannelSubscriberCount("alpha")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("alpha")) + } + +} + +func TestWs_subscribeLocked_Ugly(t *testing.T) { + hub := NewHub() + if err := hub.subscribeLocked(nil, "alpha"); err != nil { + t.Errorf("expected no error, got %v", err) + } + +} + func TestMessage_JSON(t *testing.T) { t.Run("marshals correctly", func(t *testing.T) { msg := Message{ @@ -496,23 +1380,40 @@ func TestMessage_JSON(t *testing.T) { } r := core.JSONMarshal(msg) - require.True(t, r.OK) + if !(r.OK) { + t.Fatalf("expected true") + } + data := r.Value.([]byte) + if !testContains(string(data), `"type":"process_output"`) { + t.Errorf("expected %v to contain %v", string(data), `"type":"process_output"`) + } + if !testContains(string(data), `"channel":"process:1"`) { + t.Errorf("expected %v to contain %v", string(data), `"channel":"process:1"`) + } + if !testContains(string(data), `"processId":"1"`) { + t.Errorf("expected %v to contain %v", string(data), `"processId":"1"`) + } + if !testContains(string(data), `"data":"output line"`) { + t.Errorf("expected %v to contain %v", string(data), `"data":"output line"`) + } - assert.Contains(t, string(data), `"type":"process_output"`) - assert.Contains(t, string(data), `"channel":"process:1"`) - assert.Contains(t, string(data), `"processId":"1"`) - assert.Contains(t, string(data), `"data":"output line"`) }) t.Run("unmarshals correctly", func(t *testing.T) { jsonStr := `{"type":"subscribe","data":"channel:test"}` var msg Message - require.True(t, core.JSONUnmarshal([]byte(jsonStr), &msg).OK) + if !(core.JSONUnmarshal([]byte(jsonStr), &msg).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeSubscribe, msg.Type) { + t.Errorf("expected %v, got %v", TypeSubscribe, msg.Type) + } + if !testEqual("channel:test", msg.Data) { + t.Errorf("expected %v, got %v", "channel:test", msg.Data) + } - assert.Equal(t, TypeSubscribe, msg.Type) - assert.Equal(t, "channel:test", msg.Data) }) } @@ -521,6 +1422,11 @@ func TestHub_WebSocketHandler(t *testing.T) { hub := NewHub() ctx := t.Context() go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } server := httptest.NewServer(hub.Handler()) defer server.Close() @@ -528,19 +1434,26 @@ func TestHub_WebSocketHandler(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Give time for registration + err) + } + + defer testClose(t, conn.Close) - // Give time for registration time.Sleep(50 * time.Millisecond) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 1, hub.ClientCount()) + } - assert.Equal(t, 1, hub.ClientCount()) }) - t.Run("handles subscribe message", func(t *testing.T) { + t.Run("drops registration when the hub is shutting down", func(t *testing.T) { hub := NewHub() - ctx := t.Context() - go hub.Run(ctx) + hub.running = true + close(hub.done) server := httptest.NewServer(hub.Handler()) defer server.Close() @@ -548,84 +1461,407 @@ func TestHub_WebSocketHandler(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() - - // Send subscribe message - subscribeMsg := Message{ - Type: TypeSubscribe, - Data: "test-channel", + if conn != nil { + defer testClose(t, conn.Close) + } + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) } - err = conn.WriteJSON(subscribeMsg) - require.NoError(t, err) - // Give time for subscription - time.Sleep(50 * time.Millisecond) + time.Sleep(20 * time.Millisecond) + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } - assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) }) - t.Run("handles unsubscribe message", func(t *testing.T) { + t.Run("allows cross-origin requests with NewHub dev default", func(t *testing.T) { hub := NewHub() ctx := t.Context() go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } server := httptest.NewServer(hub.Handler()) defer server.Close() wsURL := "ws" + core.TrimPrefix(server.URL, "http") - conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + header := http.Header{} + header.Set("Origin", "https://evil.example") - // Subscribe first - err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}) - require.NoError(t, err) - time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, header) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + if testIsNil(resp) { + t.Fatalf("expected non-nil value") + } + if !testEqual(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } - // Unsubscribe - err = conn.WriteJSON(Message{Type: TypeUnsubscribe, Data: "test-channel"}) - require.NoError(t, err) - time.Sleep(50 * time.Millisecond) - assert.Equal(t, 0, hub.ChannelSubscriberCount("test-channel")) }) - t.Run("responds to ping with pong", func(t *testing.T) { + t.Run("allows same-host cross-scheme requests with NewHub dev default", func(t *testing.T) { hub := NewHub() ctx := t.Context() go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } server := httptest.NewServer(hub.Handler()) defer server.Close() wsURL := "ws" + core.TrimPrefix(server.URL, "http") - conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() - - // Give time for registration - time.Sleep(50 * time.Millisecond) + header := http.Header{} + header.Set("Origin", "https://"+core.TrimPrefix(server.URL, "http://")) - // Send ping - err = conn.WriteJSON(Message{Type: TypePing}) - require.NoError(t, err) + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, header) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + if testIsNil(resp) { + t.Fatalf("expected non-nil value") + } + if !testEqual(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + + }) + + t.Run("allows custom origin policy", func(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + }) + ctx := t.Context() + go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + + header := http.Header{} + header.Set("Origin", "https://evil.example") + + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, header) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + if testIsNil(resp) { + t.Fatalf("expected non-nil value") + } + if !testEqual(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + + }) + + t.Run("rejects origin before authenticating", func(t *testing.T) { + var authCalled atomic.Bool + + hub := NewHubWithConfig(HubConfig{ + Authenticator: AuthenticatorFunc(func(r *http.Request) AuthResult { + authCalled.Store(true) + return AuthResult{Valid: true, UserID: "user-1"} + }), + CheckOrigin: func(r *http.Request) bool { + return false + }, + }) + ctx := t.Context() + go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + + header := http.Header{} + header.Set("Origin", "https://evil.example") + + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, header) + if conn != nil { + _ = conn.Close() + } + if err := err; err == nil { + t.Fatalf("expected error") + } + if testIsNil(resp) { + t.Fatalf("expected non-nil value") + } + if !testEqual(http.StatusForbidden, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusForbidden, resp.StatusCode) + } + if authCalled.Load() { + t.Errorf("expected false") + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + + }) + + t.Run("treats panicking origin checks as forbidden", func(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + CheckOrigin: func(r *http.Request) bool { + panic("boom") + }, + }) + ctx := t.Context() + go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + + header := http.Header{} + header.Set("Origin", "https://evil.example") + + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, header) + if conn != nil { + _ = conn.Close() + } + if err := err; err == nil { + t.Fatalf("expected error") + } + if testIsNil(resp) { + t.Fatalf("expected non-nil value") + } + if !testEqual(http.StatusForbidden, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusForbidden, resp.StatusCode) + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + + }) + + t.Run("handles subscribe message", func(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Send subscribe message + err) + } + + defer testClose(t, conn.Close) + + subscribeMsg := Message{ + Type: TypeSubscribe, + Data: "test-channel", + } + err = conn.WriteJSON(subscribeMsg) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Give time for subscription + err) + } + + time.Sleep(50 * time.Millisecond) + if !testEqual(1, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("test-channel")) + } + + }) + + t.Run("rejects invalid subscribe channel names", func(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + + err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "bad channel"}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + var response Message + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) + err = conn.ReadJSON(&response) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(TypeError, response.Type) { + t.Errorf("expected %v, got %v", TypeError, response.Type) + } + if !testContains(response.Data, "invalid channel name") { + t.Errorf("expected %v to contain %v", response.Data, "invalid channel name") + } + + }) + + t.Run("handles unsubscribe message", func(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Subscribe first + err) + } + + defer testClose(t, conn.Close) + + err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + time.Sleep(50 * time.Millisecond) + if !testEqual(1, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", + + // Unsubscribe + 1, hub.ChannelSubscriberCount("test-channel")) + } + + err = conn.WriteJSON(Message{Type: TypeUnsubscribe, Data: "test-channel"}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + time.Sleep(50 * time.Millisecond) + if !testEqual(0, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("test-channel")) + } + + }) + + t.Run("responds to ping with pong", func(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Give time for registration + err) + } + + defer testClose(t, conn.Close) + + time.Sleep(50 * time.Millisecond) + + // Send ping + err = conn.WriteJSON(Message{Type: TypePing}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Read pong response + err) + } - // Read pong response var response Message - conn.SetReadDeadline(time.Now().Add(time.Second)) + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) err = conn.ReadJSON(&response) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(TypePong, response.Type) { + t.Errorf("expected %v, got %v", TypePong, response.Type) + } - assert.Equal(t, TypePong, response.Type) }) t.Run("broadcasts messages to clients", func(t *testing.T) { hub := NewHub() ctx := t.Context() go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } server := httptest.NewServer(hub.Handler()) defer server.Close() @@ -633,10 +1869,15 @@ func TestHub_WebSocketHandler(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Give time for registration + err) + } + + defer testClose(t, conn.Close) - // Give time for registration time.Sleep(50 * time.Millisecond) // Broadcast a message @@ -644,16 +1885,26 @@ func TestHub_WebSocketHandler(t *testing.T) { Type: TypeEvent, Data: "broadcast test", }) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Read the broadcast + err) + } - // Read the broadcast var response Message - conn.SetReadDeadline(time.Now().Add(time.Second)) + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) err = conn.ReadJSON(&response) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(TypeEvent, response.Type) { + t.Errorf("expected %v, got %v", TypeEvent, response.Type) + } + if !testEqual("broadcast test", response.Data) { + t.Errorf("expected %v, got %v", "broadcast test", response.Data) + } - assert.Equal(t, TypeEvent, response.Type) - assert.Equal(t, "broadcast test", response.Data) }) t.Run("unregisters client on connection close", func(t *testing.T) { @@ -667,18 +1918,31 @@ func TestHub_WebSocketHandler(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Wait for registration + err) + } - // Wait for registration time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ClientCount()) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", + + // Close connection + 1, hub.ClientCount( + + // Wait for unregistration + )) + } - // Close connection - conn.Close() + _ = conn.Close() - // Wait for unregistration time.Sleep(50 * time.Millisecond) - assert.Equal(t, 0, hub.ClientCount()) + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + }) t.Run("removes client from channels on disconnect", func(t *testing.T) { @@ -692,20 +1956,35 @@ func TestHub_WebSocketHandler(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Subscribe to channel + err) + } - // Subscribe to channel err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + if !testEqual(1, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", + + // Close connection + 1, hub.ChannelSubscriberCount("test-channel")) + } - // Close connection - conn.Close() + _ = conn.Close() time.Sleep(50 * time.Millisecond) + if !testEqual( + + // Channel should be cleaned up + 0, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("test-channel")) + } - // Channel should be cleaned up - assert.Equal(t, 0, hub.ChannelSubscriberCount("test-channel")) }) } @@ -732,14 +2011,16 @@ func TestHub_Concurrency(t *testing.T) { hub.clients[client] = true hub.mu.Unlock() - hub.Subscribe(client, "shared-channel") - hub.Subscribe(client, "shared-channel") // Double subscribe should be safe + _ = hub.Subscribe(client, "shared-channel") + _ = hub.Subscribe(client, "shared-channel") // Double subscribe should be safe }(i) } wg.Wait() + if !testEqual(numClients, hub.ChannelSubscriberCount("shared-channel")) { + t.Errorf("expected %v, got %v", numClients, hub.ChannelSubscriberCount("shared-channel")) + } - assert.Equal(t, numClients, hub.ChannelSubscriberCount("shared-channel")) }) t.Run("handles concurrent broadcasts", func(t *testing.T) { @@ -787,9 +2068,11 @@ func TestHub_Concurrency(t *testing.T) { break loop } } + // All or most broadcasts should be received. + if received < numBroadcasts-10 { + t.Errorf("expected %v to be greater than or equal to %v", received, numBroadcasts-10) + } - // All or most broadcasts should be received - assert.GreaterOrEqual(t, received, numBroadcasts-10, "should receive most broadcasts") }) } @@ -806,27 +2089,37 @@ func TestHub_HandleWebSocket(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ClientCount()) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 1, hub.ClientCount()) + } + }) } func TestMustMarshal(t *testing.T) { t.Run("marshals valid data", func(t *testing.T) { data := mustMarshal(Message{Type: TypePong}) - assert.Contains(t, string(data), "pong") + if !testContains(string(data), "pong") { + t.Errorf("expected %v to contain %v", string(data), "pong") + } + }) t.Run("handles unmarshalable data without panic", func(t *testing.T) { // Create a channel which cannot be marshaled // This should not panic, even if it returns nil ch := make(chan int) - assert.NotPanics(t, func() { + testNotPanics(t, func() { _ = mustMarshal(ch) }) + }) } @@ -858,21 +2151,36 @@ func TestHub_Run_ShutdownClosesClients(t *testing.T) { hub.register <- client1 hub.register <- client2 time.Sleep(20 * time.Millisecond) + if !testEqual(2, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 2, hub.ClientCount()) + } + + _ = hub.Subscribe(client1, "shutdown-channel") + if !testEqual(1, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 1, hub.ChannelCount()) + } + if !testEqual(1, hub.ChannelSubscriberCount( + + // Cancel context to trigger shutdown + "shutdown-channel")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount( - assert.Equal(t, 2, hub.ClientCount()) - hub.Subscribe(client1, "shutdown-channel") - assert.Equal(t, 1, hub.ChannelCount()) - assert.Equal(t, 1, hub.ChannelSubscriberCount("shutdown-channel")) + // Send channels should be closed + "shutdown-channel")) + } - // Cancel context to trigger shutdown cancel() time.Sleep(50 * time.Millisecond) - // Send channels should be closed _, ok1 := <-client1.send - assert.False(t, ok1, "client1 send channel should be closed") + if ok1 { + t.Errorf("expected false") + } + _, ok2 := <-client2.send - assert.False(t, ok2, "client2 send channel should be closed") + if ok2 { + t.Errorf("expected false") + } select { case <-disconnectCalled: @@ -884,9 +2192,16 @@ func TestHub_Run_ShutdownClosesClients(t *testing.T) { case <-time.After(time.Second): t.Fatal("expected disconnect callback for both clients") } - assert.Equal(t, 0, hub.ClientCount()) - assert.Equal(t, 0, hub.ChannelCount()) - assert.Equal(t, 0, hub.ChannelSubscriberCount("shutdown-channel")) + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + if !testEqual(0, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 0, hub.ChannelCount()) + } + if !testEqual(0, hub.ChannelSubscriberCount("shutdown-channel")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("shutdown-channel")) + } + }) } @@ -905,19 +2220,29 @@ func TestHub_Run_BroadcastToClientWithFullBuffer(t *testing.T) { hub.register <- slowClient time.Sleep(20 * time.Millisecond) - assert.Equal(t, 1, hub.ClientCount()) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", + + // Fill the client's send buffer + 1, hub.ClientCount()) + } - // Fill the client's send buffer slowClient.send <- []byte("blocking") // Broadcast should trigger the overflow path err := hub.Broadcast(Message{Type: TypeEvent, Data: "overflow"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Wait for the unregister goroutine to fire + err) + } - // Wait for the unregister goroutine to fire time.Sleep(100 * time.Millisecond) + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } - assert.Equal(t, 0, hub.ClientCount(), "slow client should be unregistered") }) } @@ -935,16 +2260,25 @@ func TestHub_Run_BroadcastWithClosedSendChannel(t *testing.T) { hub.register <- client time.Sleep(20 * time.Millisecond) - assert.Equal(t, 1, hub.ClientCount()) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", + + // Simulate a concurrent close before the hub attempts delivery. + 1, hub.ClientCount()) + } - // Simulate a concurrent close before the hub attempts delivery. client.closeSend() err := hub.Broadcast(Message{Type: TypeEvent, Data: "closed-channel"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } time.Sleep(100 * time.Millisecond) - assert.Equal(t, 0, hub.ClientCount(), "client with closed send channel should be unregistered") + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + }) } @@ -960,14 +2294,17 @@ func TestHub_SendToChannel_ClientBufferFull(t *testing.T) { hub.mu.Lock() hub.clients[client] = true hub.mu.Unlock() - hub.Subscribe(client, "test-channel") + _ = hub.Subscribe(client, "test-channel") // Fill the client buffer client.send <- []byte("blocking") // SendToChannel should not block; it skips the full client err := hub.SendToChannel("test-channel", Message{Type: TypeEvent, Data: "overflow"}) - assert.NoError(t, err) + if err := err; err != nil { + t.Errorf("expected no error, got %v", err) + } + }) } @@ -983,12 +2320,15 @@ func TestHub_SendToChannel_ClosedSendChannel(t *testing.T) { hub.mu.Lock() hub.clients[client] = true hub.mu.Unlock() - hub.Subscribe(client, "test-channel") + _ = hub.Subscribe(client, "test-channel") client.closeSend() err := hub.SendToChannel("test-channel", Message{Type: TypeEvent, Data: "closed-channel"}) - assert.NoError(t, err) + if err := err; err != nil { + t.Errorf("expected no error, got %v", err) + } + }) } @@ -1003,8 +2343,13 @@ func TestHub_Broadcast_MarshalError(t *testing.T) { } err := hub.Broadcast(msg) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to marshal message") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "failed to marshal message") { + t.Errorf("expected %v to contain %v", err.Error(), "failed to marshal message") + } + }) } @@ -1018,8 +2363,13 @@ func TestHub_SendToChannel_MarshalError(t *testing.T) { } err := hub.SendToChannel("any-channel", msg) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to marshal message") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "failed to marshal message") { + t.Errorf("expected %v to contain %v", err.Error(), "failed to marshal message") + } + }) } @@ -1034,32 +2384,252 @@ func TestHub_Handler_UpgradeError(t *testing.T) { // Make a plain HTTP request (not a WebSocket upgrade) resp, err := http.Get(server.URL) - require.NoError(t, err) - defer resp.Body.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // The handler should have returned an error response + err) + } + + defer testClose(t, resp.Body.Close) + if !testEqual(http.StatusBadRequest, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusBadRequest, resp.StatusCode) + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } - // The handler should have returned an error response - assert.Equal(t, http.StatusBadRequest, resp.StatusCode) - assert.Equal(t, 0, hub.ClientCount()) }) } -func TestClient_Close(t *testing.T) { - t.Run("unregisters and closes connection", func(t *testing.T) { - hub := NewHub() - ctx := t.Context() - go hub.Run(ctx) +func TestWs_Handler_Bad(t *testing.T) { + var hub *Hub - server := httptest.NewServer(hub.Handler()) - defer server.Close() + handler := hub.Handler() + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/ws", nil) + + handler(recorder, req) + if !testEqual(http.StatusServiceUnavailable, recorder.Code) { + t.Errorf("expected %v, got %v", http.StatusServiceUnavailable, recorder.Code) + } + if !testContains(recorder.Body.String(), "Hub is not configured") { + t.Errorf("expected %v to contain %v", recorder.Body.String(), "Hub is not configured") + } + +} + +func TestHub_Handler_AuthSnapshotAndUserID_Good(t *testing.T) { + claims := map[string]any{ + "role": "admin", + } + authCalled := make(chan struct{}, 1) + + hub := NewHubWithConfig(HubConfig{ + Authenticator: AuthenticatorFunc(func(r *http.Request) AuthResult { + select { + case authCalled <- struct{}{}: + default: + } + return AuthResult{ + Valid: true, + UserID: " user-123 ", + Claims: claims, + } + }), + }) + ctx := t.Context() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + conn, resp, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if testIsNil(resp) { + t.Fatalf("expected non-nil value") + } + if !testEqual(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + + defer testClose(t, conn.Close) + + select { + case <-authCalled: + case <-time.After(time.Second): + t.Fatal("authenticator should have been called") + } + + claims["role"] = "user" + if !testEventually(func() bool { + return hub.ClientCount() == 1 + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + hub.mu.RLock() + var client *Client + for c := range hub.clients { + client = c + break + } + hub.mu.RUnlock() + if testIsNil(client) { + t.Fatalf("expected non-nil value") + } + if !testEqual("user-123", client.UserID) { + t.Errorf("expected %v, got %v", "user-123", client.UserID) + } + if testIsNil(client.Claims) { + t.Fatalf("expected non-nil value") + } + if !testEqual("admin", client.Claims["role"]) { + t.Errorf("expected %v, got %v", "admin", client.Claims["role"]) + } + +} + +func TestHub_Handler_RejectsEmptyUserID_Bad(t *testing.T) { + authFailure := make(chan AuthResult, 1) + + hub := NewHubWithConfig(HubConfig{ + Authenticator: AuthenticatorFunc(func(r *http.Request) AuthResult { + return AuthResult{ + Valid: true, + UserID: " ", + Claims: map[string]any{"role": "admin"}, + } + }), + OnAuthFailure: func(r *http.Request, result AuthResult) { + select { + case authFailure <- result: + default: + } + }, + }) + ctx := t.Context() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + conn, resp, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if conn != nil { + _ = conn.Close() + } + if err := err; err == nil { + t.Fatalf("expected error") + } + if testIsNil(resp) { + t.Fatalf("expected non-nil value") + } + if !testEqual(http.StatusUnauthorized, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusUnauthorized, resp.StatusCode) + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + + select { + case result := <-authFailure: + if result.Valid { + t.Errorf("expected false") + } + if result.Authenticated { + t.Errorf("expected false") + } + if !(core.Is(result.Error, ErrMissingUserID)) { + t.Errorf("expected true") + } + + case <-time.After(time.Second): + t.Fatal("expected OnAuthFailure callback to run") + } +} + +func TestHub_Handler_AuthenticatorPanic_Ugly(t *testing.T) { + authFailure := make(chan AuthResult, 1) + + hub := NewHubWithConfig(HubConfig{ + Authenticator: AuthenticatorFunc(func(r *http.Request) AuthResult { + panic("boom") + }), + OnAuthFailure: func(r *http.Request, result AuthResult) { + select { + case authFailure <- result: + default: + } + }, + }) + ctx := t.Context() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + conn, resp, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if conn != nil { + _ = conn.Close() + } + if err := err; err == nil { + t.Fatalf("expected error") + } + if testIsNil(resp) { + t.Fatalf("expected non-nil value") + } + if !testEqual(http.StatusUnauthorized, resp.StatusCode) { + t.Errorf("expected %v, got %v", http.StatusUnauthorized, resp.StatusCode) + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + + select { + case result := <-authFailure: + if result.Valid { + t.Errorf("expected false") + } + if result.Authenticated { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "authenticator panicked") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "authenticator panicked") + } + + case <-time.After(time.Second): + t.Fatal("expected OnAuthFailure callback to run") + } +} + +func TestClient_Close(t *testing.T) { + t.Run("unregisters and closes connection", func(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ClientCount()) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", + + // Get the client from the hub + 1, hub.ClientCount()) + } - // Get the client from the hub hub.mu.RLock() var client *Client for c := range hub.clients { @@ -1067,21 +2637,63 @@ func TestClient_Close(t *testing.T) { break } hub.mu.RUnlock() - require.NotNil(t, client) + if testIsNil(client) { + t.Fatalf("expected non-nil value") + + // Close via Client.Close() + } - // Close via Client.Close() err = client.Close() // conn.Close may return an error if already closing, that is acceptable _ = err time.Sleep(50 * time.Millisecond) - assert.Equal(t, 0, hub.ClientCount()) + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", + + // Connection should be closed — writing should fail + 0, hub.ClientCount()) + } - // Connection should be closed — writing should fail _ = conn.Close() // ensure clean up }) } +func TestClient_Close_NilAndDetached_Ugly(t *testing.T) { + t.Run("nil client", func(t *testing.T) { + var client *Client + if err := client.Close(); err != nil { + t.Errorf("expected no error, got %v", err) + } + + }) + + t.Run("detached client with nil conn", func(t *testing.T) { + client := &Client{} + if err := client.Close(); err != nil { + t.Errorf("expected no error, got %v", err) + } + + }) + + t.Run("hub with nil conn", func(t *testing.T) { + hub := NewHub() + client := &Client{hub: hub} + if err := client.Close(); err != nil { + t.Errorf("expected no error, got %v", err) + } + + }) +} + +func TestClient_closeSend_Nil_Ugly(t *testing.T) { + var client *Client + testNotPanics(t, func() { + client.closeSend() + }) + +} + func TestReadPump_MalformedJSON(t *testing.T) { t.Run("ignores malformed JSON messages", func(t *testing.T) { hub := NewHub() @@ -1093,21 +2705,33 @@ func TestReadPump_MalformedJSON(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) // Send malformed JSON — should be ignored without disconnecting err = conn.WriteMessage(websocket.TextMessage, []byte("this is not json")) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Send a valid subscribe after the bad message — client should still be alive + err) + } - // Send a valid subscribe after the bad message — client should still be alive err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + if !testEqual(1, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("test-channel")) + } + }) } @@ -1122,8 +2746,11 @@ func TestReadPump_SubscribeWithNonStringData(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) @@ -1132,15 +2759,122 @@ func TestReadPump_SubscribeWithNonStringData(t *testing.T) { "type": "subscribe", "data": 12345, }) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } time.Sleep(50 * time.Millisecond) + if !testEqual( + + // No channels should have been created + 0, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 0, hub.ChannelCount()) + } + + }) +} + +func TestClient_readPump_Ugly(t *testing.T) { + t.Run("nil receiver", func(t *testing.T) { + var client *Client + testNotPanics(t, func() { + client.readPump() + }) + + }) + + t.Run("missing hub", func(t *testing.T) { + client := &Client{} + testNotPanics(t, func() { + client.readPump() + }) + + }) +} + +func TestClient_writePump_Ugly(t *testing.T) { + t.Run("nil receiver", func(t *testing.T) { + var client *Client + testNotPanics(t, func() { + client.writePump() + }) + + }) + + t.Run("missing connection", func(t *testing.T) { + client := &Client{ + hub: &Hub{}, + } + testNotPanics(t, func() { + client.writePump() + }) + + }) +} + +func TestReadPump_SubscribeWithChannelField_Good(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + + time.Sleep(50 * time.Millisecond) + + err = conn.WriteJSON(Message{ + Type: TypeSubscribe, + Channel: "field-channel", + }) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + time.Sleep(50 * time.Millisecond) + if !testEqual(1, hub.ChannelSubscriberCount("field-channel")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("field-channel")) + } + +} + +func TestWs_messageTargetChannel_Good(t *testing.T) { + t.Run("prefers the channel field", func(t *testing.T) { + if !testEqual("field-channel", messageTargetChannel(Message{Channel: "field-channel", Data: "data-channel"})) { + t.Errorf("expected %v, got %v", "field-channel", messageTargetChannel(Message{Channel: "field-channel", Data: "data-channel"})) + } + + }) + + t.Run("falls back to string data", func(t *testing.T) { + if !testEqual("data-channel", messageTargetChannel(Message{Data: "data-channel"})) { + t.Errorf("expected %v, got %v", "data-channel", messageTargetChannel(Message{Data: "data-channel"})) + } - // No channels should have been created - assert.Equal(t, 0, hub.ChannelCount()) }) } +func TestWs_messageTargetChannel_Bad(t *testing.T) { + if !testIsEmpty(messageTargetChannel(Message{Data: []string{"data-channel"}})) { + t.Errorf("expected empty value, got %v", messageTargetChannel(Message{Data: []string{"data-channel"}})) + } + +} + +func TestWs_messageTargetChannel_Ugly(t *testing.T) { + if !testIsEmpty(messageTargetChannel(Message{})) { + t.Errorf("expected empty value, got %v", messageTargetChannel(Message{})) + } + +} + func TestReadPump_UnsubscribeWithNonStringData(t *testing.T) { t.Run("ignores unsubscribe with non-string data", func(t *testing.T) { hub := NewHub() @@ -1152,28 +2886,44 @@ func TestReadPump_UnsubscribeWithNonStringData(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) // Subscribe first with valid data err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + if !testEqual(1, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", + + // Send unsubscribe with non-string data — should be ignored + 1, hub.ChannelSubscriberCount("test-channel")) + } - // Send unsubscribe with non-string data — should be ignored err = conn.WriteJSON(map[string]any{ "type": "unsubscribe", "data": []string{"test-channel"}, }) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } time.Sleep(50 * time.Millisecond) + if !testEqual( + + // Channel should still have the subscriber + 1, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("test-channel")) + } - // Channel should still have the subscriber - assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) }) } @@ -1188,21 +2938,65 @@ func TestReadPump_UnknownMessageType(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) // Send a message with an unknown type err = conn.WriteJSON(Message{Type: "unknown_type", Data: "anything"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // Client should still be connected + err) + } - // Client should still be connected time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ClientCount()) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 1, hub.ClientCount()) + } + }) } +func TestReadPump_ReadLimit_Ugly(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + if !testEventually(func() bool { + return hub.ClientCount() == 1 + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + largePayload := strings.Repeat("A", defaultMaxMessageBytes+1) + err = conn.WriteMessage(websocket.TextMessage, []byte(largePayload)) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEventually(func() bool { + return hub.ClientCount() == 0 + }, 2*time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + +} + func TestWritePump_SendsCloseOnChannelClose(t *testing.T) { t.Run("sends close message when send channel is closed", func(t *testing.T) { hub := NewHub() @@ -1214,8 +3008,11 @@ func TestWritePump_SendsCloseOnChannelClose(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) @@ -1232,9 +3029,12 @@ func TestWritePump_SendsCloseOnChannelClose(t *testing.T) { time.Sleep(50 * time.Millisecond) // The client should receive a close message and the connection should end - conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + _ = conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) _, _, readErr := conn.ReadMessage() - assert.Error(t, readErr, "reading should fail after close") + if err := readErr; err == nil { + t.Errorf("expected error") + } + }) } @@ -1249,8 +3049,11 @@ func TestWritePump_BatchesMessages(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) @@ -1262,21 +3065,31 @@ func TestWritePump_BatchesMessages(t *testing.T) { break } hub.mu.RUnlock() - require.NotNil(t, client) + if testIsNil(client) { + t.Fatalf("expected non-nil value") - // Queue multiple messages rapidly through the hub so writePump can - // batch them into a single websocket frame when possible. - require.NoError(t, hub.Broadcast(Message{Type: TypeEvent, Data: "batch-1"})) - require.NoError(t, hub.Broadcast(Message{Type: TypeEvent, Data: "batch-2"})) - require.NoError(t, hub.Broadcast(Message{Type: TypeEvent, Data: "batch-3"})) + // Queue multiple messages rapidly through the hub so writePump can + // batch them into a single websocket frame when possible. + } + if err := hub.Broadcast(Message{Type: TypeEvent, Data: "batch-1"}); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if err := hub.Broadcast(Message{Type: TypeEvent, Data: "batch-2"}); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if err := hub.Broadcast(Message{Type: TypeEvent, Data: "batch-3"}); err != nil { + t.Fatalf("expected no error, got %v", err) + } // Read frames until we have observed all three payloads or time out. deadline := time.Now().Add(time.Second) seen := map[string]bool{} for len(seen) < 3 { - conn.SetReadDeadline(deadline) + _ = conn.SetReadDeadline(deadline) _, data, readErr := conn.ReadMessage() - require.NoError(t, readErr) + if err := readErr; err != nil { + t.Fatalf("expected no error, got %v", err) + } content := string(data) for _, token := range []string{"batch-1", "batch-2", "batch-3"} { @@ -1288,6 +3101,180 @@ func TestWritePump_BatchesMessages(t *testing.T) { }) } +func TestWritePump_Heartbeat_Good(t *testing.T) { + pingSeen := make(chan struct{}, 1) + serverErr := make(chan error, 1) + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + serverErr <- err + return + } + + defer testClose(t, conn.Close) + + conn.SetPingHandler(func(string) error { + select { + case pingSeen <- struct{}{}: + default: + } + return nil + }) + + readDone := make(chan struct{}) + go func() { + defer close(readDone) + for { + if _, _, err := conn.ReadMessage(); err != nil { + return + } + } + }() + + <-readDone + })) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + + hub := NewHubWithConfig(HubConfig{ + HeartbeatInterval: 10 * time.Millisecond, + WriteTimeout: time.Second, + }) + client := &Client{ + hub: hub, + conn: conn, + send: make(chan []byte, 1), + subscriptions: make(map[string]bool), + } + + done := make(chan struct{}) + go func() { + client.writePump() + close(done) + }() + + select { + case <-pingSeen: + case err := <-serverErr: + t.Fatalf("expected no server error, got %v", err) + case <-time.After(time.Second): + t.Fatal("expected heartbeat ping") + } + + close(client.send) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("writePump should exit after the send channel is closed") + } +} + +func TestWs_readPump_PongTimeout_Good(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + HeartbeatInterval: 10 * time.Millisecond, + PongTimeout: 30 * time.Millisecond, + WriteTimeout: time.Second, + }) + ctx := t.Context() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + + // Ignore server pings so the read deadline expires. + conn.SetPingHandler(func(string) error { + return nil + }) + + done := make(chan struct{}) + go func() { + defer close(done) + for { + if _, _, err := conn.ReadMessage(); err != nil { + return + } + } + }() + if !testEventually(func() bool { + return hub.ClientCount() == 1 + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + if !testEventually(func() bool { + return hub.ClientCount() == 0 + }, 2*time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("client connection should close after pong timeout") + } +} + +func TestWritePump_NextWriterError_Bad(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + time.Sleep(200 * time.Millisecond) + })) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + hub := NewHubWithConfig(HubConfig{ + HeartbeatInterval: time.Second, + WriteTimeout: time.Second, + }) + client := &Client{ + hub: hub, + conn: conn, + send: make(chan []byte, 1), + subscriptions: make(map[string]bool), + } + client.send <- []byte("payload") + if err := conn.Close(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + done := make(chan struct{}) + go func() { + client.writePump() + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("writePump should exit when NextWriter fails") + } +} + func TestHub_MultipleClientsOnChannel(t *testing.T) { t.Run("delivers channel messages to all subscribers", func(t *testing.T) { hub := NewHub() @@ -1303,34 +3290,59 @@ func TestHub_MultipleClientsOnChannel(t *testing.T) { conns := make([]*websocket.Conn, 3) for i := range conns { conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) conns[i] = conn } time.Sleep(50 * time.Millisecond) - assert.Equal(t, 3, hub.ClientCount()) + if !testEqual(3, hub.ClientCount()) { + t.Errorf("expected %v, got %v", + + // Subscribe all to "shared" + 3, hub.ClientCount()) + } - // Subscribe all to "shared" for _, conn := range conns { err := conn.WriteJSON(Message{Type: TypeSubscribe, Data: "shared"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + } time.Sleep(50 * time.Millisecond) - assert.Equal(t, 3, hub.ChannelSubscriberCount("shared")) + if !testEqual(3, hub.ChannelSubscriberCount("shared")) { + t.Errorf("expected %v, got %v", + + // Send to channel + 3, hub.ChannelSubscriberCount("shared")) + } - // Send to channel err := hub.SendToChannel("shared", Message{Type: TypeEvent, Data: "hello all"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", + + // All three clients should receive the message + err) + } - // All three clients should receive the message - for i, conn := range conns { - conn.SetReadDeadline(time.Now().Add(time.Second)) + for _, conn := range conns { + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) var received Message err := conn.ReadJSON(&received) - require.NoError(t, err, "client %d should receive message", i) - assert.Equal(t, TypeEvent, received.Type) - assert.Equal(t, "hello all", received.Data) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(TypeEvent, received.Type) { + t.Errorf("expected %v, got %v", TypeEvent, received.Type) + } + if !testEqual("hello all", received.Data) { + t.Errorf("expected %v, got %v", "hello all", received.Data) + } + } }) } @@ -1361,7 +3373,7 @@ func TestHub_ConcurrentSubscribeUnsubscribe(t *testing.T) { wg.Add(1) go func(idx int) { defer wg.Done() - hub.Subscribe(clients[idx], "race-channel") + _ = hub.Subscribe(clients[idx], "race-channel") }(i) } wg.Wait() @@ -1374,15 +3386,21 @@ func TestHub_ConcurrentSubscribeUnsubscribe(t *testing.T) { if idx%2 == 0 { hub.Unsubscribe(clients[idx], "race-channel") } else { - hub.Subscribe(clients[idx], "another-channel") + _ = hub.Subscribe(clients[idx], "another-channel") } }(i) } wg.Wait() + if !testEqual( + + // Verify: half should remain on race-channel, half should be on another-channel + numClients/2, hub.ChannelSubscriberCount("race-channel")) { + t.Errorf("expected %v, got %v", numClients/2, hub.ChannelSubscriberCount("race-channel")) + } + if !testEqual(numClients/2, hub.ChannelSubscriberCount("another-channel")) { + t.Errorf("expected %v, got %v", numClients/2, hub.ChannelSubscriberCount("another-channel")) + } - // Verify: half should remain on race-channel, half should be on another-channel - assert.Equal(t, numClients/2, hub.ChannelSubscriberCount("race-channel")) - assert.Equal(t, numClients/2, hub.ChannelSubscriberCount("another-channel")) }) } @@ -1397,21 +3415,30 @@ func TestHub_ProcessOutputEndToEnd(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) // Subscribe to process channel err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "process:build-42"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + time.Sleep(50 * time.Millisecond) // Send lines one at a time with a small delay to avoid batching lines := []string{"Compiling...", "Linking...", "Done."} for _, line := range lines { err = hub.SendProcessOutput("build-42", line) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + time.Sleep(10 * time.Millisecond) // Allow writePump to flush each individually } @@ -1419,11 +3446,15 @@ func TestHub_ProcessOutputEndToEnd(t *testing.T) { // with newline separators. ReadMessage gives raw frames. var received []Message for len(received) < 3 { - conn.SetReadDeadline(time.Now().Add(time.Second)) + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) _, data, readErr := conn.ReadMessage() - require.NoError(t, readErr) + if err := readErr; err != nil { + t.Fatalf("expected no error, got %v", + + // A single frame may contain multiple newline-separated JSON objects + err) + } - // A single frame may contain multiple newline-separated JSON objects parts := strings.SplitSeq(core.Trim(string(data)), "\n") for part := range parts { part = core.Trim(part) @@ -1431,16 +3462,28 @@ func TestHub_ProcessOutputEndToEnd(t *testing.T) { continue } var msg Message - require.True(t, core.JSONUnmarshal([]byte(part), &msg).OK) + if !(core.JSONUnmarshal([]byte(part), &msg).OK) { + t.Fatalf("expected true") + } + received = append(received, msg) } } + if gotLen := len(received); gotLen != 3 { + t.Fatalf("expected length %v, got %v", 3, gotLen) + } - require.Len(t, received, 3) for i, expected := range lines { - assert.Equal(t, TypeProcessOutput, received[i].Type) - assert.Equal(t, "build-42", received[i].ProcessID) - assert.Equal(t, expected, received[i].Data) + if !testEqual(TypeProcessOutput, received[i].Type) { + t.Errorf("expected %v, got %v", TypeProcessOutput, received[i].Type) + } + if !testEqual("build-42", received[i].ProcessID) { + t.Errorf("expected %v, got %v", "build-42", received[i].ProcessID) + } + if !testEqual(expected, received[i].Data) { + t.Errorf("expected %v, got %v", expected, received[i].Data) + } + } }) } @@ -1456,36 +3499,57 @@ func TestHub_ProcessStatusEndToEnd(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) // Subscribe err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "process:job-7"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + time.Sleep(50 * time.Millisecond) // Send status err = hub.SendProcessStatus("job-7", "exited", 1) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - conn.SetReadDeadline(time.Now().Add(time.Second)) + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) var received Message err = conn.ReadJSON(&received) - require.NoError(t, err) - assert.Equal(t, TypeProcessStatus, received.Type) - assert.Equal(t, "job-7", received.ProcessID) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(TypeProcessStatus, received.Type) { + t.Errorf("expected %v, got %v", TypeProcessStatus, received.Type) + } + if !testEqual("job-7", received.ProcessID) { + t.Errorf("expected %v, got %v", "job-7", received.ProcessID) + } data, ok := received.Data.(map[string]any) - require.True(t, ok) - assert.Equal(t, "exited", data["status"]) - assert.Equal(t, float64(1), data["exitCode"]) + if !(ok) { + t.Fatalf("expected true") + } + if !testEqual("exited", data["status"]) { + t.Errorf("expected %v, got %v", "exited", data["status"]) + + // --- Benchmarks --- + } + if !testEqual(float64(1), data["exitCode"]) { + t.Errorf("expected %v, got %v", float64(1), data["exitCode"]) + } + }) } -// --- Benchmarks --- - func BenchmarkBroadcast(b *testing.B) { hub := NewHub() ctx := b.Context() @@ -1525,7 +3589,7 @@ func BenchmarkSendToChannel(b *testing.B) { hub.mu.Lock() hub.clients[client] = true hub.mu.Unlock() - hub.Subscribe(client, "bench-channel") + _ = hub.Subscribe(client, "bench-channel") } msg := Message{Type: TypeEvent, Data: "benchmark"} @@ -1546,18 +3610,30 @@ func TestNewHubWithConfig(t *testing.T) { WriteTimeout: 3 * time.Second, } hub := NewHubWithConfig(config) + if !testEqual(5*time.Second, hub.config.HeartbeatInterval) { + t.Errorf("expected %v, got %v", 5*time.Second, hub.config.HeartbeatInterval) + } + if !testEqual(10*time.Second, hub.config.PongTimeout) { + t.Errorf("expected %v, got %v", 10*time.Second, hub.config.PongTimeout) + } + if !testEqual(3*time.Second, hub.config.WriteTimeout) { + t.Errorf("expected %v, got %v", 3*time.Second, hub.config.WriteTimeout) + } - assert.Equal(t, 5*time.Second, hub.config.HeartbeatInterval) - assert.Equal(t, 10*time.Second, hub.config.PongTimeout) - assert.Equal(t, 3*time.Second, hub.config.WriteTimeout) }) t.Run("applies defaults for zero values", func(t *testing.T) { hub := NewHubWithConfig(HubConfig{}) + if !testEqual(DefaultHeartbeatInterval, hub.config.HeartbeatInterval) { + t.Errorf("expected %v, got %v", DefaultHeartbeatInterval, hub.config.HeartbeatInterval) + } + if !testEqual(DefaultPongTimeout, hub.config.PongTimeout) { + t.Errorf("expected %v, got %v", DefaultPongTimeout, hub.config.PongTimeout) + } + if !testEqual(DefaultWriteTimeout, hub.config.WriteTimeout) { + t.Errorf("expected %v, got %v", DefaultWriteTimeout, hub.config.WriteTimeout) + } - assert.Equal(t, DefaultHeartbeatInterval, hub.config.HeartbeatInterval) - assert.Equal(t, DefaultPongTimeout, hub.config.PongTimeout) - assert.Equal(t, DefaultWriteTimeout, hub.config.WriteTimeout) }) t.Run("applies defaults for negative values", func(t *testing.T) { @@ -1566,10 +3642,16 @@ func TestNewHubWithConfig(t *testing.T) { PongTimeout: -1, WriteTimeout: -1, }) + if !testEqual(DefaultHeartbeatInterval, hub.config.HeartbeatInterval) { + t.Errorf("expected %v, got %v", DefaultHeartbeatInterval, hub.config.HeartbeatInterval) + } + if !testEqual(DefaultPongTimeout, hub.config.PongTimeout) { + t.Errorf("expected %v, got %v", DefaultPongTimeout, hub.config.PongTimeout) + } + if !testEqual(DefaultWriteTimeout, hub.config.WriteTimeout) { + t.Errorf("expected %v, got %v", DefaultWriteTimeout, hub.config.WriteTimeout) + } - assert.Equal(t, DefaultHeartbeatInterval, hub.config.HeartbeatInterval) - assert.Equal(t, DefaultPongTimeout, hub.config.PongTimeout) - assert.Equal(t, DefaultWriteTimeout, hub.config.WriteTimeout) }) t.Run("expands pong timeout when it does not exceed heartbeat interval", func(t *testing.T) { @@ -1577,22 +3659,41 @@ func TestNewHubWithConfig(t *testing.T) { HeartbeatInterval: 20 * time.Second, PongTimeout: 10 * time.Second, }) + if !testEqual(20*time.Second, hub.config.HeartbeatInterval) { + t.Errorf("expected %v, got %v", 20*time.Second, hub.config.HeartbeatInterval) + } + if !testEqual(40*time.Second, hub.config.PongTimeout) { + t.Errorf("expected %v, got %v", 40*time.Second, hub.config.PongTimeout) + } - assert.Equal(t, 20*time.Second, hub.config.HeartbeatInterval) - assert.Equal(t, 40*time.Second, hub.config.PongTimeout) }) } func TestDefaultHubConfig(t *testing.T) { t.Run("returns sensible defaults", func(t *testing.T) { config := DefaultHubConfig() + if !testEqual(30*time.Second, config.HeartbeatInterval) { + t.Errorf("expected %v, got %v", 30*time.Second, config.HeartbeatInterval) + } + if !testEqual(60*time.Second, config.PongTimeout) { + t.Errorf("expected %v, got %v", 60*time.Second, config.PongTimeout) + } + if !testEqual(10*time.Second, config.WriteTimeout) { + t.Errorf("expected %v, got %v", 10*time.Second, config.WriteTimeout) + } + if !testIsNil(config.OnConnect) { + t.Errorf("expected nil, got %T", config.OnConnect) + } + if !testIsNil(config.OnDisconnect) { + t.Errorf("expected nil, got %T", config.OnDisconnect) + } + if !testIsNil(config.ChannelAuthoriser) { + t.Errorf("expected nil, got %T", config.ChannelAuthoriser) + } + if !testIsEmpty(config.AllowedOrigins) { + t.Errorf("expected empty value, got %v", config.AllowedOrigins) + } - assert.Equal(t, 30*time.Second, config.HeartbeatInterval) - assert.Equal(t, 60*time.Second, config.PongTimeout) - assert.Equal(t, 10*time.Second, config.WriteTimeout) - assert.Nil(t, config.OnConnect) - assert.Nil(t, config.OnDisconnect) - assert.Nil(t, config.ChannelAuthoriser) }) } @@ -1613,12 +3714,18 @@ func TestHub_ConnectionCallbacks(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) select { case c := <-connectCalled: - assert.NotNil(t, c) + if testIsNil(c) { + t.Errorf("expected non-nil value") + } + case <-time.After(time.Second): t.Fatal("OnConnect callback should have been called") } @@ -1640,16 +3747,21 @@ func TestHub_ConnectionCallbacks(t *testing.T) { wsURL := "ws" + core.TrimPrefix(server.URL, "http") conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } time.Sleep(50 * time.Millisecond) // Close the connection to trigger disconnect - conn.Close() + _ = conn.Close() select { case c := <-disconnectCalled: - assert.NotNil(t, c) + if testIsNil(c) { + t.Errorf("expected non-nil value") + } + case <-time.After(time.Second): t.Fatal("OnDisconnect callback should have been called") } @@ -1704,14 +3816,24 @@ func TestHub_ChannelAuthoriser(t *testing.T) { hub.mu.Unlock() err := hub.Subscribe(client, "public:news") - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } err = hub.Subscribe(client, "private:ops") - require.Error(t, err) - assert.Contains(t, err.Error(), "subscription unauthorised") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "subscription unauthorised") { + t.Errorf("expected %v to contain %v", err.Error(), "subscription unauthorised") + } + if !testEqual(1, hub.ChannelSubscriberCount("public:news")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("public:news")) + } + if !testEqual(0, hub.ChannelSubscriberCount("private:ops")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("private:ops")) + } - assert.Equal(t, 1, hub.ChannelSubscriberCount("public:news")) - assert.Equal(t, 0, hub.ChannelSubscriberCount("private:ops")) }) } @@ -1729,16 +3851,85 @@ func TestHub_Subscribe_ReturnsError(t *testing.T) { } err := hub.Subscribe(client, "private:ops") - require.Error(t, err) - assert.Contains(t, err.Error(), "subscription unauthorised") - assert.Empty(t, client.subscriptions) - assert.Equal(t, 0, hub.ChannelCount()) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "subscription unauthorised") { + t.Errorf("expected %v to contain %v", err.Error(), "subscription unauthorised") + } + if !testIsEmpty(client.subscriptions) { + t.Errorf("expected empty value, got %v", client.subscriptions) + } + if !testEqual(0, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 0, hub.ChannelCount()) + } + + }) +} + +func TestHub_ChannelAuthoriser_Panic_Ugly(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + ChannelAuthoriser: func(client *Client, channel string) bool { + panic("boom") + }, + }) + + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + err := hub.Subscribe(client, "panic-channel") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "subscription unauthorised") { + t.Errorf("expected %v to contain %v", err.Error(), "subscription unauthorised") + } + if !testEqual(0, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 0, hub.ChannelCount()) + } + if !testIsEmpty(client.subscriptions) { + t.Errorf("expected empty value, got %v", client.subscriptions) + } + +} + +func TestHub_MaxSubscriptionsPerClient(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + MaxSubscriptionsPerClient: 1, }) + + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + if err := hub.Subscribe(client, "alpha"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err := hub.Subscribe(client, "beta") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(err, ErrSubscriptionLimitExceeded)) { + t.Errorf("expected true") + } + if !testEqual(1, hub.ChannelSubscriberCount("alpha")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("alpha")) + } + if !testEqual(0, hub.ChannelSubscriberCount("beta")) { + t.Errorf("expected %v, got %v", 0, hub. + + // Use a very short heartbeat to test it actually fires + ChannelSubscriberCount("beta")) + } + } func TestHub_CustomHeartbeat(t *testing.T) { t.Run("uses custom heartbeat interval for server pings", func(t *testing.T) { - // Use a very short heartbeat to test it actually fires + hub := NewHubWithConfig(HubConfig{ HeartbeatInterval: 100 * time.Millisecond, PongTimeout: 500 * time.Millisecond, @@ -1755,8 +3946,11 @@ func TestHub_CustomHeartbeat(t *testing.T) { pingReceived := make(chan struct{}, 1) dialer := websocket.Dialer{} conn, _, err := dialer.Dial(wsURL, nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) conn.SetPingHandler(func(appData string) error { select { @@ -1819,7 +4013,9 @@ func TestReconnectingClient_Connect(t *testing.T) { // Run Connect in background clientCtx, clientCancel := context.WithCancel(context.Background()) defer clientCancel() - go rc.Connect(clientCtx) + go func() { + _ = rc.Connect(clientCtx) + }() // Wait for connect select { @@ -1828,20 +4024,30 @@ func TestReconnectingClient_Connect(t *testing.T) { case <-time.After(time.Second): t.Fatal("OnConnect should have been called") } + if !testEqual(StateConnected, rc.State()) { + t.Errorf("expected %v, got %v", - assert.Equal(t, StateConnected, rc.State()) + // Wait for client to register + StateConnected, rc.State()) + } - // Wait for client to register time.Sleep(50 * time.Millisecond) // Broadcast a message err := hub.Broadcast(Message{Type: TypeEvent, Data: "hello"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } select { case msg := <-msgReceived: - assert.Equal(t, TypeEvent, msg.Type) - assert.Equal(t, "hello", msg.Data) + if !testEqual(TypeEvent, msg.Type) { + t.Errorf("expected %v, got %v", TypeEvent, msg.Type) + } + if !testEqual("hello", msg.Data) { + t.Errorf("expected %v, got %v", "hello", msg.Data) + } + case <-time.After(time.Second): t.Fatal("should have received the broadcast message") } @@ -1851,28 +4057,182 @@ func TestReconnectingClient_Connect(t *testing.T) { }) } -func TestReconnectingClient_Reconnect(t *testing.T) { - t.Run("reconnects after server restart", func(t *testing.T) { - hub := NewHub() - ctx, cancel := context.WithCancel(context.Background()) - go hub.Run(ctx) +func TestReconnectingClient_ContextCancel_WhileConnected(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) - // Use a net.Listener so we control the port - listener, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) + server := httptest.NewServer(hub.Handler()) + defer server.Close() - server := &httptest.Server{ - Listener: listener, - Config: &http.Server{Handler: hub.Handler()}, - } - server.Start() + wsURL := "ws" + core.TrimPrefix(server.URL, "http") - wsURL := "ws" + core.TrimPrefix(server.URL, "http") - addr := listener.Addr().String() + connectCalled := make(chan struct{}, 1) + rc := NewReconnectingClient(ReconnectConfig{ + URL: wsURL, + OnConnect: func() { + select { + case connectCalled <- struct{}{}: + default: + } + }, + }) - reconnectCalled := make(chan int, 5) - disconnectCalled := make(chan struct{}, 5) - connectCalled := make(chan struct{}, 5) + clientCtx, clientCancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { + done <- rc.Connect(clientCtx) + }() + + select { + case <-connectCalled: + case <-time.After(time.Second): + t.Fatal("OnConnect should have been called") + } + + clientCancel() + + select { + case err := <-done: + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testEqual(context.Canceled, err) { + t.Errorf("expected %v, got %v", context.Canceled, err) + } + + case <-time.After(2 * time.Second): + t.Fatal("Connect should return after context cancel while connected") + } +} + +func TestReconnectingClient_ReadLimit(t *testing.T) { + largePayload := strings.Repeat("A", defaultMaxMessageBytes+1) + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + + time.Sleep(50 * time.Millisecond) + if err := conn.WriteMessage(websocket.TextMessage, []byte(largePayload)); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + time.Sleep(50 * time.Millisecond) + })) + defer server.Close() + + clientConn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, clientConn.Close) + + rc := &ReconnectingClient{conn: clientConn} + done := make(chan error, 1) + go func() { + done <- rc.readLoop() + }() + + select { + case readErr := <-done: + if err := readErr; err == nil { + t.Fatalf("expected error") + } + if !testContains(readErr.Error(), "read limit") { + t.Errorf("expected %v to contain %v", readErr.Error(), "read limit") + } + + case <-time.After(2 * time.Second): + t.Fatal("read loop should stop after exceeding the read limit") + } +} + +func TestReconnectingClient_OnMessageRawBytes(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + + rawReceived := make(chan []byte, 1) + + rc := NewReconnectingClient(ReconnectConfig{ + URL: wsURL, + OnMessage: func(msg []byte) { + copied := append([]byte(nil), msg...) + select { + case rawReceived <- copied: + default: + } + }, + }) + + clientCtx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + go func() { + _ = rc.Connect(clientCtx) + }() + + time.Sleep(50 * time.Millisecond) + + err := hub.Broadcast(Message{Type: TypeEvent, Data: "raw-bytes"}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case data := <-rawReceived: + if !testContains(string(data), "raw-bytes") { + t.Errorf("expected %v to contain %v", string(data), "raw-bytes") + } + + var received Message + if !(core.JSONUnmarshal(data, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeEvent, received.Type) { + t.Errorf("expected %v, got %v", TypeEvent, received.Type) + } + + case <-time.After(time.Second): + t.Fatal("raw byte callback should have been invoked") + } +} + +func TestReconnectingClient_Reconnect(t *testing.T) { + t.Run("reconnects after server restart", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + go hub.Run(ctx) + + // Use a net.Listener so we control the port + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + server := &httptest.Server{ + Listener: listener, + Config: &http.Server{Handler: hub.Handler()}, + } + server.Start() + + wsURL := "ws" + core.TrimPrefix(server.URL, "http") + addr := listener.Addr().String() + + reconnectCalled := make(chan int, 5) + disconnectCalled := make(chan struct{}, 5) + connectCalled := make(chan struct{}, 5) rc := NewReconnectingClient(ReconnectConfig{ URL: wsURL, @@ -1900,7 +4260,9 @@ func TestReconnectingClient_Reconnect(t *testing.T) { clientCtx, clientCancel := context.WithCancel(context.Background()) defer clientCancel() - go rc.Connect(clientCtx) + go func() { + _ = rc.Connect(clientCtx) + }() // Wait for initial connection select { @@ -1908,9 +4270,13 @@ func TestReconnectingClient_Reconnect(t *testing.T) { case <-time.After(time.Second): t.Fatal("initial connection should have succeeded") } - assert.Equal(t, StateConnected, rc.State()) + if !testEqual(StateConnected, rc.State()) { + t.Errorf("expected %v, got %v", + + // Shut down the server to simulate disconnect + StateConnected, rc.State()) + } - // Shut down the server to simulate disconnect cancel() server.Close() @@ -1943,24 +4309,104 @@ func TestReconnectingClient_Reconnect(t *testing.T) { // Wait for reconnection select { case attempt := <-reconnectCalled: - assert.Greater(t, attempt, 0) + if attempt <= 0 { + t.Errorf("expected %v to be greater than %v", attempt, 0) + } + case <-time.After(3 * time.Second): t.Fatal("OnReconnect should have been called") } + if !testEqual(StateConnected, rc.State()) { + t.Errorf("expected %v, got %v", StateConnected, rc.State()) + } - assert.Equal(t, StateConnected, rc.State()) clientCancel() }) } +func TestReconnectingClient_ReconnectBackoffAfterDisconnect(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + + var acceptedMu sync.Mutex + acceptedAt := make([]time.Time, 0, 2) + releaseSecond := make(chan struct{}) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + acceptedMu.Lock() + acceptedAt = append(acceptedAt, time.Now()) + connectionCount := len(acceptedAt) + acceptedMu.Unlock() + + if connectionCount == 1 { + time.Sleep(20 * time.Millisecond) + _ = conn.Close() + return + } + + <-releaseSecond + _ = conn.Close() + })) + defer server.Close() + + rc := NewReconnectingClient(ReconnectConfig{ + URL: wsURL(server), + InitialBackoff: 150 * time.Millisecond, + MaxBackoff: 150 * time.Millisecond, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- rc.Connect(ctx) + }() + if !testEventually(func() bool { + acceptedMu.Lock() + defer acceptedMu.Unlock() + return len(acceptedAt) >= 2 + }, 3*time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + acceptedMu.Lock() + firstAccepted := acceptedAt[0] + secondAccepted := acceptedAt[1] + acceptedMu.Unlock() + if secondAccepted.Sub(firstAccepted) < 150*time.Millisecond { + t.Errorf("expected %v to be greater than or equal to %v", secondAccepted.Sub(firstAccepted), 150*time.Millisecond) + } + + close(releaseSecond) + cancel() + + select { + case err := <-done: + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testErrorIs(err, context.Canceled) { + t.Errorf("expected %v to match %v", err, context.Canceled) + } + + case <-time.After(2 * time.Second): + t.Fatal("Connect should return after cancellation") + } +} + func TestReconnectingClient_MaxRetries(t *testing.T) { t.Run("stops after max retries exceeded", func(t *testing.T) { // Use a URL that will never connect rc := NewReconnectingClient(ReconnectConfig{ - URL: "ws://127.0.0.1:1", // Should refuse connection - InitialBackoff: 10 * time.Millisecond, - MaxBackoff: 50 * time.Millisecond, - MaxRetries: 3, + URL: "ws://127.0.0.1:1", // Should refuse connection + InitialBackoff: 10 * time.Millisecond, + MaxBackoff: 50 * time.Millisecond, + MaxReconnectAttempts: 3, }) errCh := make(chan error, 1) @@ -1970,13 +4416,20 @@ func TestReconnectingClient_MaxRetries(t *testing.T) { select { case err := <-errCh: - require.Error(t, err) - assert.Contains(t, err.Error(), "max retries (3) exceeded") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "max retries (3) exceeded") { + t.Errorf("expected %v to contain %v", err.Error(), "max retries (3) exceeded") + } + case <-time.After(5 * time.Second): t.Fatal("should have stopped after max retries") } + if !testEqual(StateDisconnected, rc.State()) { + t.Errorf("expected %v, got %v", StateDisconnected, rc.State()) + } - assert.Equal(t, StateDisconnected, rc.State()) }) } @@ -2005,7 +4458,9 @@ func TestReconnectingClient_Send(t *testing.T) { clientCtx, clientCancel := context.WithCancel(context.Background()) defer clientCancel() - go rc.Connect(clientCtx) + go func() { + _ = rc.Connect(clientCtx) + }() <-connected time.Sleep(50 * time.Millisecond) @@ -2015,10 +4470,14 @@ func TestReconnectingClient_Send(t *testing.T) { Type: TypeSubscribe, Data: "test-channel", }) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + if !testEqual(1, hub.ChannelSubscriberCount("test-channel")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("test-channel")) + } clientCancel() }) @@ -2046,7 +4505,9 @@ func TestReconnectingClient_Send(t *testing.T) { clientCtx, clientCancel := context.WithCancel(context.Background()) defer clientCancel() - go rc.Connect(clientCtx) + go func() { + _ = rc.Connect(clientCtx) + }() select { case <-connected: @@ -2073,11 +4534,17 @@ func TestReconnectingClient_Send(t *testing.T) { close(errCh) for err := range errCh { - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + } time.Sleep(100 * time.Millisecond) - assert.GreaterOrEqual(t, hub.ChannelCount(), 1) + if hub.ChannelCount() < 1 { + t.Errorf("expected %v to be greater than or equal to %v", hub.ChannelCount(), 1) + } + }) t.Run("returns error when not connected", func(t *testing.T) { @@ -2086,8 +4553,13 @@ func TestReconnectingClient_Send(t *testing.T) { }) err := rc.Send(Message{Type: TypePing}) - require.Error(t, err) - assert.Contains(t, err.Error(), "not connected") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "not connected") { + t.Errorf("expected %v to contain %v", err.Error(), "not connected") + } + }) t.Run("returns error for unmarshalable message", func(t *testing.T) { @@ -2097,11 +4569,55 @@ func TestReconnectingClient_Send(t *testing.T) { // Force a conn to be set so we get past the nil check // to hit the marshal error first err := rc.Send(Message{Type: TypeEvent, Data: make(chan int)}) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to marshal message") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "failed to marshal message") { + t.Errorf("expected %v to contain %v", err.Error(), "failed to marshal message") + } + }) } +func TestWs_ReconnectingClient_Send_ContextCanceled_Good(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + time.Sleep(50 * time.Millisecond) + })) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + rc := &ReconnectingClient{ + conn: conn, + ctx: ctx, + config: ReconnectConfig{URL: wsURL(server)}, + } + + err = rc.Send(Message{Type: TypeEvent, Data: "payload"}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testErrorIs(err, context.Canceled) { + t.Errorf("expected %v to match %v", err, context.Canceled) + } + +} + func TestReconnectingClient_Close(t *testing.T) { t.Run("stops reconnection loop", func(t *testing.T) { hub := NewHub() @@ -2135,11 +4651,16 @@ func TestReconnectingClient_Close(t *testing.T) { <-connected err := rc.Close() - assert.NoError(t, err) + if err := err; err != nil { + t.Errorf("expected no error, got %v", + + // Good — Connect returned + err) + } select { case <-done: - // Good — Connect returned + case <-time.After(time.Second): t.Fatal("Connect should have returned after Close") } @@ -2151,7 +4672,10 @@ func TestReconnectingClient_Close(t *testing.T) { }) err := rc.Close() - assert.NoError(t, err) + if err := err; err != nil { + t.Errorf("expected no error, got %v", err) + } + }) } @@ -2163,118 +4687,531 @@ func TestReconnectingClient_ExponentialBackoff(t *testing.T) { MaxBackoff: 1 * time.Second, BackoffMultiplier: 2.0, }) + if !testEqual( + + // attempt 1: 100ms + 100*time.Millisecond, rc.calculateBackoff(1)) { + t.Errorf("expected %v, got %v", 100* + + // attempt 2: 200ms + time.Millisecond, rc.calculateBackoff(1)) + } + if !testEqual(200*time.Millisecond, rc.calculateBackoff( + + // attempt 3: 400ms + 2)) { + t.Errorf("expected %v, got %v", 200*time.Millisecond, rc.calculateBackoff( + + // attempt 4: 800ms + 2)) + } + if !testEqual(400*time.Millisecond, rc.calculateBackoff(3)) { + t.Errorf("expected %v, got %v", + + // attempt 5: capped at 1s + 400*time.Millisecond, rc.calculateBackoff(3)) + } + if !testEqual(800*time.Millisecond, + + // attempt 10: still capped at 1s + rc.calculateBackoff(4)) { + t.Errorf("expected %v, got %v", 800*time.Millisecond, rc.calculateBackoff(4)) + } + if !testEqual(1*time.Second, rc.calculateBackoff(5)) { + t.Errorf("expected %v, got %v", 1*time.Second, rc.calculateBackoff(5)) + } + if !testEqual(1*time.Second, rc.calculateBackoff(10)) { + t.Errorf("expected %v, got %v", 1*time.Second, rc.calculateBackoff(10)) + } - // attempt 1: 100ms - assert.Equal(t, 100*time.Millisecond, rc.calculateBackoff(1)) - // attempt 2: 200ms - assert.Equal(t, 200*time.Millisecond, rc.calculateBackoff(2)) - // attempt 3: 400ms - assert.Equal(t, 400*time.Millisecond, rc.calculateBackoff(3)) - // attempt 4: 800ms - assert.Equal(t, 800*time.Millisecond, rc.calculateBackoff(4)) - // attempt 5: capped at 1s - assert.Equal(t, 1*time.Second, rc.calculateBackoff(5)) - // attempt 10: still capped at 1s - assert.Equal(t, 1*time.Second, rc.calculateBackoff(10)) }) -} -func TestReconnectingClient_Defaults(t *testing.T) { - t.Run("applies defaults for zero config values", func(t *testing.T) { + t.Run("caps an oversized initial backoff", func(t *testing.T) { rc := NewReconnectingClient(ReconnectConfig{ - URL: "ws://localhost:1", + URL: "ws://localhost:1", + InitialBackoff: 5 * time.Second, + MaxBackoff: 1 * time.Second, }) + if !testEqual(1*time.Second, rc.config.InitialBackoff) { + t.Errorf("expected %v, got %v", 1*time.Second, rc.config.InitialBackoff) + } + if !testEqual(1*time.Second, rc.calculateBackoff(1)) { + t.Errorf("expected %v, got %v", 1*time.Second, rc.calculateBackoff(1)) + } - assert.Equal(t, 1*time.Second, rc.config.InitialBackoff) - assert.Equal(t, 30*time.Second, rc.config.MaxBackoff) - assert.Equal(t, 2.0, rc.config.BackoffMultiplier) - assert.NotNil(t, rc.config.Dialer) }) -} -func TestReconnectingClient_ContextCancel(t *testing.T) { - t.Run("returns context error on cancel during backoff", func(t *testing.T) { + t.Run("rejects shrinking multipliers", func(t *testing.T) { rc := NewReconnectingClient(ReconnectConfig{ - URL: "ws://127.0.0.1:1", - InitialBackoff: 10 * time.Second, // Long backoff + URL: "ws://localhost:1", + InitialBackoff: 100 * time.Millisecond, + MaxBackoff: 1 * time.Second, + BackoffMultiplier: 0.5, }) + if !testEqual(2.0, rc.config.BackoffMultiplier) { + t.Errorf("expected %v, got %v", 2.0, rc.config.BackoffMultiplier) + } + if !testEqual(100*time.Millisecond, rc.calculateBackoff(1)) { + t.Errorf("expected %v, got %v", 100*time.Millisecond, rc.calculateBackoff(1)) + } + if !testEqual(200*time.Millisecond, rc.calculateBackoff(2)) { + t.Errorf("expected %v, got %v", 200*time.Millisecond, rc.calculateBackoff(2)) + } - ctx, cancel := context.WithCancel(context.Background()) + }) +} - done := make(chan error, 1) - go func() { - done <- rc.Connect(ctx) - }() +func TestWs_calculateBackoff_Good(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://localhost:1", + InitialBackoff: 250 * time.Millisecond, + MaxBackoff: 2 * time.Second, + BackoffMultiplier: 2.0, + }) + if !testEqual(250*time.Millisecond, rc.calculateBackoff(1)) { + t.Errorf("expected %v, got %v", 250*time.Millisecond, rc.calculateBackoff(1)) + } + if !testEqual(500*time.Millisecond, rc.calculateBackoff(2)) { + t.Errorf("expected %v, got %v", 500*time.Millisecond, rc.calculateBackoff(2)) + } + if !testEqual(time.Second, rc.calculateBackoff(3)) { + t.Errorf("expected %v, got %v", time.Second, rc.calculateBackoff(3)) + } - // Allow first dial attempt to fail - time.Sleep(200 * time.Millisecond) +} - // Cancel during backoff - cancel() +func TestWs_calculateBackoff_Bad(t *testing.T) { + rc := &ReconnectingClient{ + config: ReconnectConfig{}, + } + if !testEqual(1*time.Second, rc.calculateBackoff(0)) { + t.Errorf("expected %v, got %v", 1*time.Second, rc.calculateBackoff(0)) + } + if !testEqual(2*time.Second, rc.calculateBackoff(2)) { + t.Errorf("expected %v, got %v", 2*time.Second, rc.calculateBackoff(2)) + } - select { - case err := <-done: - require.Error(t, err) - assert.Equal(t, context.Canceled, err) - case <-time.After(2 * time.Second): - t.Fatal("Connect should have returned after context cancel") + t.Run("returns the ceiling when the initial backoff already matches max", func(t *testing.T) { + rc := &ReconnectingClient{ + config: ReconnectConfig{ + InitialBackoff: 1 * time.Second, + MaxBackoff: 1 * time.Second, + BackoffMultiplier: 2, + }, + } + if !testEqual(1*time.Second, rc.calculateBackoff(2)) { + t.Errorf("expected %v, got %v", 1*time.Second, rc.calculateBackoff(2)) } - }) -} -func TestConnectionState(t *testing.T) { - t.Run("state constants are distinct", func(t *testing.T) { - assert.NotEqual(t, StateDisconnected, StateConnecting) - assert.NotEqual(t, StateConnecting, StateConnected) - assert.NotEqual(t, StateDisconnected, StateConnected) }) } -// --------------------------------------------------------------------------- -// Hub.Run lifecycle — register, broadcast delivery, unregister via channels -// --------------------------------------------------------------------------- +func TestWs_calculateBackoff_Ugly(t *testing.T) { + rc := &ReconnectingClient{ + config: ReconnectConfig{ + InitialBackoff: 5 * time.Second, + MaxBackoff: 1 * time.Second, + }, + } + if !testEqual(1*time.Second, rc.calculateBackoff(1)) { + t.Errorf("expected %v, got %v", 1*time.Second, rc.calculateBackoff(1)) + } -func TestHubRun_RegisterClient_Good(t *testing.T) { - hub := NewHub() - ctx := t.Context() - go hub.Run(ctx) +} - client := &Client{ - hub: hub, - send: make(chan []byte, 256), - subscriptions: make(map[string]bool), +func TestWs_waitForReconnectBackoff_Good(t *testing.T) { + if !(waitForReconnectBackoff(context.Background(), nil, 0)) { + t.Errorf("expected true") } - hub.register <- client - time.Sleep(20 * time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if !(waitForReconnectBackoff(ctx, nil, 10*time.Millisecond)) { + t.Errorf("expected true") + } - assert.Equal(t, 1, hub.ClientCount(), "client should be registered via hub loop") } -func TestHubRun_BroadcastDelivery_Good(t *testing.T) { - hub := NewHub() - ctx := t.Context() - go hub.Run(ctx) - - client := &Client{ - hub: hub, - send: make(chan []byte, 256), - subscriptions: make(map[string]bool), +func TestWs_waitForReconnectBackoff_Bad(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if waitForReconnectBackoff(ctx, nil, 10*time.Millisecond) { + t.Errorf("expected false") } - hub.register <- client - time.Sleep(20 * time.Millisecond) - - err := hub.Broadcast(Message{Type: TypeEvent, Data: "lifecycle-test"}) - require.NoError(t, err) +} + +func TestWs_waitForReconnectBackoff_Ugly(t *testing.T) { + done := make(chan struct{}) + close(done) + if waitForReconnectBackoff(context.Background(), done, 10*time.Millisecond) { + t.Errorf("expected false") + } + +} + +func TestWs_stopTimer_Good(t *testing.T) { + timer := time.NewTimer(time.Second) + stopTimer(timer) + + select { + case <-timer.C: + t.Fatal("stopTimer should drain the timer channel before it can fire") + default: + } +} + +func TestWs_stopTimer_Bad(t *testing.T) { + timer := time.NewTimer(10 * time.Millisecond) + <-timer.C + testNotPanics(t, func() { + stopTimer(timer) + }) + + t.Run("drains a fired timer", func(t *testing.T) { + timer := time.NewTimer(10 * time.Millisecond) + time.Sleep(20 * time.Millisecond) + testNotPanics(t, func() { + stopTimer(timer) + }) + + select { + case <-timer.C: + t.Fatal("stopTimer should drain the fired timer channel") + default: + } + }) +} + +func TestWs_stopTimer_Ugly(t *testing.T) { + testNotPanics(t, func() { + stopTimer(nil) + }) + +} + +func TestWs_closeRequested_Good(t *testing.T) { + rc := &ReconnectingClient{done: make(chan struct{})} + close(rc.done) + if !(rc.closeRequested()) { + t.Errorf("expected true") + } + +} + +func TestWs_closeRequested_Bad(t *testing.T) { + rc := &ReconnectingClient{done: make(chan struct{})} + if rc.closeRequested() { + t.Errorf("expected false") + } + +} + +func TestWs_closeRequested_Ugly(t *testing.T) { + var rc *ReconnectingClient + if rc.closeRequested() { + t.Errorf("expected false") + } + +} + +func TestWs_NewReconnectingClient_InfMultiplier_Ugly(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://localhost:1", + BackoffMultiplier: math.Inf(1), + }) + if !testEqual(2.0, rc.config.BackoffMultiplier) { + t.Errorf("expected %v, got %v", 2.0, rc.config.BackoffMultiplier) + } + +} + +func TestWs_calculateBackoff_InvalidMultiplier_Ugly(t *testing.T) { + rc := &ReconnectingClient{ + config: ReconnectConfig{ + InitialBackoff: 100 * time.Millisecond, + MaxBackoff: 1 * time.Second, + BackoffMultiplier: math.Inf(1), + }, + } + if !testEqual(200*time.Millisecond, rc.calculateBackoff(2)) { + t.Errorf("expected %v, got %v", 200*time.Millisecond, rc.calculateBackoff(2)) + } + +} + +func TestWs_calculateBackoff_Overflow_Ugly(t *testing.T) { + rc := &ReconnectingClient{ + config: ReconnectConfig{ + InitialBackoff: time.Duration(1 << 62), + MaxBackoff: time.Duration(1<<63 - 1), + BackoffMultiplier: 10, + }, + } + if !testEqual(rc.config.MaxBackoff, rc.calculateBackoff(2)) { + t.Errorf("expected %v, got %v", rc.config.MaxBackoff, rc.calculateBackoff(2)) + } + +} + +func TestWs_Connect_DoneClosed_Good(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://127.0.0.1:1", + }) + close(rc.done) + + err := rc.Connect(context.Background()) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + +} + +func TestWs_Connect_NilContext_Good(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://127.0.0.1:1", + InitialBackoff: 10 * time.Millisecond, + MaxBackoff: 20 * time.Millisecond, + MaxReconnectAttempts: 1, + }) + + done := make(chan error, 1) + go func() { + done <- rc.Connect(context.TODO()) + }() + + select { + case err := <-done: + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "max retries (1) exceeded") { + t.Errorf("expected %v to contain %v", err.Error(), "max retries (1) exceeded") + } + + case <-time.After(5 * time.Second): + t.Fatal("Connect should return when the retry limit is reached") + } +} + +func TestReconnectingClient_MaxReconnectAttempts_Precedence_Good(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://127.0.0.1:1", + InitialBackoff: 10 * time.Millisecond, + MaxBackoff: 20 * time.Millisecond, + MaxRetries: 99, + MaxReconnectAttempts: 1, + }) + + errCh := make(chan error, 1) + go func() { + errCh <- rc.Connect(context.Background()) + }() + + select { + case err := <-errCh: + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "max retries (1) exceeded") { + t.Errorf("expected %v to contain %v", err.Error(), "max retries (1) exceeded") + } + + case <-time.After(5 * time.Second): + t.Fatal("Connect should have stopped after MaxReconnectAttempts") + } +} + +func TestReconnectingClient_MaxReconnectAttempts_ZeroMeansUnlimited_Good(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://127.0.0.1:1", + MaxReconnectAttempts: 0, + }) + if !testEqual(0, rc.maxReconnectAttempts()) { + t.Errorf("expected %v, got %v", 0, rc.maxReconnectAttempts()) + } + +} + +func TestReconnectingClient_MaxRetries_Compatibility_Good(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://127.0.0.1:1", + MaxRetries: 3, + }) + if !testEqual(3, rc.maxReconnectAttempts()) { + t.Errorf("expected %v, got %v", 3, rc.maxReconnectAttempts()) + } + +} + +func TestReconnectingClient_MaxReconnectAttempts_Negative_Ugly(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://localhost:1", + MaxRetries: -1, + MaxReconnectAttempts: -5, + }) + if !testEqual(0, rc.maxReconnectAttempts()) { + t.Errorf("expected %v, got %v", 0, rc.maxReconnectAttempts()) + } + +} + +func TestDispatchReconnectMessage_StringAndUnsupported_Good(t *testing.T) { + stringCalled := false + dispatchReconnectMessage(func(s string) { + stringCalled = true + if !testContains(s, "payload") { + t.Errorf("expected %v to contain %v", s, "payload") + } + + }, []byte("payload")) + if !(stringCalled) { + t.Errorf("expected true") + } + testNotPanics(t, func() { + dispatchReconnectMessage(123, []byte("ignored")) + }) + +} + +func TestReconnectingClient_Defaults(t *testing.T) { + t.Run("applies defaults for zero config values", func(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://localhost:1", + }) + if !testEqual(1*time.Second, rc.config.InitialBackoff) { + t.Errorf("expected %v, got %v", 1*time.Second, rc.config.InitialBackoff) + } + if !testEqual(30*time.Second, rc.config.MaxBackoff) { + t.Errorf("expected %v, got %v", 30*time.Second, rc.config.MaxBackoff) + } + if !testEqual(2.0, rc.config.BackoffMultiplier) { + t.Errorf("expected %v, got %v", 2.0, rc.config.BackoffMultiplier) + } + if testIsNil(rc.config.Dialer) { + t.Errorf("expected non-nil value") + } + + }) +} + +func TestReconnectingClient_ContextCancel(t *testing.T) { + t.Run("returns context error on cancel during backoff", func(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://127.0.0.1:1", + InitialBackoff: 10 * time.Second, // Long backoff + }) + + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan error, 1) + go func() { + done <- rc.Connect(ctx) + }() + + // Allow first dial attempt to fail + time.Sleep(200 * time.Millisecond) + + // Cancel during backoff + cancel() + + select { + case err := <-done: + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testEqual(context.Canceled, err) { + t.Errorf("expected %v, got %v", context.Canceled, err) + } + + case <-time.After(2 * time.Second): + t.Fatal("Connect should have returned after context cancel") + } + }) +} + +func TestConnectionState(t *testing.T) { + t.Run("state constants are distinct", func(t *testing.T) { + if testEqual(StateDisconnected, StateConnecting) { + t.Errorf("expected values to differ: %v", StateConnecting) + } + if testEqual(StateConnecting, StateConnected) { + t.Errorf("expected values to differ: %v", StateConnected) + } + if testEqual(StateDisconnected, StateConnected) { + t.Errorf("expected values to differ: %v", StateConnected) + } + + }) +} + +func TestReconnectingClient_State_Ugly(t *testing.T) { + var rc *ReconnectingClient + if !testEqual(StateDisconnected, rc.State()) { + t.Errorf("expected %v, got %v", + + // --------------------------------------------------------------------------- + // Hub.Run lifecycle — register, broadcast delivery, unregister via channels + // --------------------------------------------------------------------------- + StateDisconnected, rc.State()) + } + +} + +func TestHubRun_RegisterClient_Good(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) + + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + hub.register <- client + time.Sleep(20 * time.Millisecond) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 1, hub.ClientCount()) + } + +} + +func TestHubRun_BroadcastDelivery_Good(t *testing.T) { + hub := NewHub() + ctx := t.Context() + go hub.Run(ctx) + + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + hub.register <- client + time.Sleep(20 * time.Millisecond) + + err := hub.Broadcast(Message{Type: TypeEvent, Data: "lifecycle-test"}) + if err := err; err != nil { + t.Fatalf( + + // Hub.Run loop delivers the broadcast to the client's send channel + "expected no error, got %v", err) + } - // Hub.Run loop delivers the broadcast to the client's send channel select { case msg := <-client.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, TypeEvent, received.Type) - assert.Equal(t, "lifecycle-test", received.Data) + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeEvent, received.Type) { + t.Errorf("expected %v, got %v", TypeEvent, received.Type) + } + if !testEqual("lifecycle-test", received.Data) { + t.Errorf("expected %v, got %v", "lifecycle-test", received.Data) + } + case <-time.After(time.Second): t.Fatal("broadcast should be delivered via hub loop") } @@ -2293,17 +5230,27 @@ func TestHubRun_UnregisterClient_Good(t *testing.T) { hub.register <- client time.Sleep(20 * time.Millisecond) - assert.Equal(t, 1, hub.ClientCount()) + if !testEqual(1, hub.ClientCount()) { + t.Errorf( - // Subscribe so we can verify channel cleanup - hub.Subscribe(client, "lifecycle-chan") - assert.Equal(t, 1, hub.ChannelSubscriberCount("lifecycle-chan")) + // Subscribe so we can verify channel cleanup + "expected %v, got %v", 1, hub.ClientCount()) + } + + _ = hub.Subscribe(client, "lifecycle-chan") + if !testEqual(1, hub.ChannelSubscriberCount("lifecycle-chan")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("lifecycle-chan")) + } hub.unregister <- client time.Sleep(20 * time.Millisecond) + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + if !testEqual(0, hub.ChannelSubscriberCount("lifecycle-chan")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("lifecycle-chan")) + } - assert.Equal(t, 0, hub.ClientCount()) - assert.Equal(t, 0, hub.ChannelSubscriberCount("lifecycle-chan")) } func TestHubRun_UnregisterIgnoresDuplicate_Bad(t *testing.T) { @@ -2350,16 +5297,27 @@ func TestSubscribe_MultipleChannels_Good(t *testing.T) { subscriptions: make(map[string]bool), } - hub.Subscribe(client, "alpha") - hub.Subscribe(client, "beta") - hub.Subscribe(client, "gamma") + _ = hub.Subscribe(client, "alpha") + _ = hub.Subscribe(client, "beta") + _ = hub.Subscribe(client, "gamma") + if !testEqual(3, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 3, hub.ChannelCount()) + } - assert.Equal(t, 3, hub.ChannelCount()) subs := client.Subscriptions() - assert.Len(t, subs, 3) - assert.Contains(t, subs, "alpha") - assert.Contains(t, subs, "beta") - assert.Contains(t, subs, "gamma") + if gotLen := len(subs); gotLen != 3 { + t.Errorf("expected length %v, got %v", 3, gotLen) + } + if !testContains(subs, "alpha") { + t.Errorf("expected %v to contain %v", subs, "alpha") + } + if !testContains(subs, "beta") { + t.Errorf("expected %v to contain %v", subs, "beta") + } + if !testContains(subs, "gamma") { + t.Errorf("expected %v to contain %v", subs, "gamma") + } + } func TestSubscribe_IdempotentDoubleSubscribe_Good(t *testing.T) { @@ -2370,11 +5328,15 @@ func TestSubscribe_IdempotentDoubleSubscribe_Good(t *testing.T) { subscriptions: make(map[string]bool), } - hub.Subscribe(client, "dupl") - hub.Subscribe(client, "dupl") + _ = hub.Subscribe(client, "dupl") + _ = hub.Subscribe(client, "dupl") + if !testEqual( + + // Still only one subscriber entry in the channel map + 1, hub.ChannelSubscriberCount("dupl")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("dupl")) + } - // Still only one subscriber entry in the channel map - assert.Equal(t, 1, hub.ChannelSubscriberCount("dupl")) } func TestUnsubscribe_PartialLeave_Good(t *testing.T) { @@ -2382,18 +5344,27 @@ func TestUnsubscribe_PartialLeave_Good(t *testing.T) { client1 := &Client{hub: hub, send: make(chan []byte, 256), subscriptions: make(map[string]bool)} client2 := &Client{hub: hub, send: make(chan []byte, 256), subscriptions: make(map[string]bool)} - hub.Subscribe(client1, "shared") - hub.Subscribe(client2, "shared") - assert.Equal(t, 2, hub.ChannelSubscriberCount("shared")) + _ = hub.Subscribe(client1, "shared") + _ = hub.Subscribe(client2, "shared") + if !testEqual(2, hub.ChannelSubscriberCount("shared")) { + t.Errorf("expected %v, got %v", 2, hub.ChannelSubscriberCount("shared")) + } hub.Unsubscribe(client1, "shared") - assert.Equal(t, 1, hub.ChannelSubscriberCount("shared")) + if !testEqual(1, hub.ChannelSubscriberCount("shared")) { + t.Errorf( + + // Channel still exists because client2 is subscribed + "expected %v, got %v", 1, hub.ChannelSubscriberCount("shared")) + } - // Channel still exists because client2 is subscribed hub.mu.RLock() _, exists := hub.channels["shared"] hub.mu.RUnlock() - assert.True(t, exists, "channel should persist while subscribers remain") + if !(exists) { + t.Errorf("expected true") + } + } // --------------------------------------------------------------------------- @@ -2409,18 +5380,25 @@ func TestSendToChannel_MultipleSubscribers_Good(t *testing.T) { send: make(chan []byte, 256), subscriptions: make(map[string]bool), } - hub.Subscribe(clients[i], "multi") + _ = hub.Subscribe(clients[i], "multi") } err := hub.SendToChannel("multi", Message{Type: TypeEvent, Data: "fanout"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } for i, c := range clients { select { case msg := <-c.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, "multi", received.Channel) + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual("multi", received.Channel) { + t.Errorf("expected %v, got %v", "multi", received.Channel) + } + case <-time.After(time.Second): t.Fatalf("client %d should have received the message", i) } @@ -2434,7 +5412,10 @@ func TestSendToChannel_MultipleSubscribers_Good(t *testing.T) { func TestSendProcessOutput_NoSubscribers_Good(t *testing.T) { hub := NewHub() err := hub.SendProcessOutput("orphan-proc", "some output") - assert.NoError(t, err, "sending to a process with no subscribers should not error") + if err := err; err != nil { + t.Errorf("expected no error, got %v", err) + } + } func TestSendProcessStatus_NonZeroExit_Good(t *testing.T) { @@ -2444,29 +5425,44 @@ func TestSendProcessStatus_NonZeroExit_Good(t *testing.T) { send: make(chan []byte, 256), subscriptions: make(map[string]bool), } - hub.Subscribe(client, "process:fail-1") + _ = hub.Subscribe(client, "process:fail-1") err := hub.SendProcessStatus("fail-1", "exited", 137) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } select { case msg := <-client.send: var received Message - require.True(t, core.JSONUnmarshal(msg, &received).OK) - assert.Equal(t, TypeProcessStatus, received.Type) - assert.Equal(t, "fail-1", received.ProcessID) + if !(core.JSONUnmarshal(msg, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeProcessStatus, received.Type) { + t.Errorf("expected %v, got %v", TypeProcessStatus, received.Type) + } + if !testEqual("fail-1", received.ProcessID) { + t.Errorf("expected %v, got %v", "fail-1", received.ProcessID) + } + data := received.Data.(map[string]any) - assert.Equal(t, "exited", data["status"]) - assert.Equal(t, float64(137), data["exitCode"]) + if !testEqual("exited", data["status"]) { + t.Errorf("expected %v, got %v", "exited", data["status"]) + } + if !testEqual(float64(137), data["exitCode"]) { + t.Errorf("expected %v, got %v", float64(137), + + // --------------------------------------------------------------------------- + // readPump — ping with timestamp verification + // --------------------------------------------------------------------------- + data["exitCode"]) + } + case <-time.After(time.Second): t.Fatal("expected process status message") } } -// --------------------------------------------------------------------------- -// readPump — ping with timestamp verification -// --------------------------------------------------------------------------- - func TestReadPump_PingTimestamp_Good(t *testing.T) { hub := NewHub() ctx := t.Context() @@ -2476,24 +5472,36 @@ func TestReadPump_PingTimestamp_Good(t *testing.T) { defer server.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) err = conn.WriteJSON(Message{Type: TypePing}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - conn.SetReadDeadline(time.Now().Add(time.Second)) + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) var pong Message err = conn.ReadJSON(&pong) - require.NoError(t, err) - assert.Equal(t, TypePong, pong.Type) - assert.False(t, pong.Timestamp.IsZero(), "pong should include a timestamp") -} + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(TypePong, pong.Type) { + t.Errorf("expected %v, got %v", TypePong, pong.Type) -// --------------------------------------------------------------------------- -// writePump — batch sending with multiple messages -// --------------------------------------------------------------------------- + // --------------------------------------------------------------------------- + // writePump — batch sending with multiple messages + // --------------------------------------------------------------------------- + } + if pong.Timestamp.IsZero() { + t.Errorf("expected false") + } + +} func TestWritePump_BatchMultipleMessages_Good(t *testing.T) { hub := NewHub() @@ -2504,8 +5512,11 @@ func TestWritePump_BatchMultipleMessages_Good(t *testing.T) { defer server.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) // Rapidly send multiple broadcasts so they queue up @@ -2515,14 +5526,17 @@ func TestWritePump_BatchMultipleMessages_Good(t *testing.T) { Type: TypeEvent, Data: core.Sprintf("batch-%d", i), }) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + } time.Sleep(100 * time.Millisecond) // Read all messages — batched with newline separators received := 0 - conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) for received < numMessages { _, raw, err := conn.ReadMessage() if err != nil { @@ -2540,8 +5554,10 @@ func TestWritePump_BatchMultipleMessages_Good(t *testing.T) { } } } + if !testEqual(numMessages, received) { + t.Errorf("expected %v, got %v", numMessages, received) + } - assert.Equal(t, numMessages, received, "all batched messages should be received") } // --------------------------------------------------------------------------- @@ -2557,39 +5573,63 @@ func TestIntegration_UnsubscribeStopsDelivery_Good(t *testing.T) { defer server.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) // Subscribe err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "temp:feed"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + time.Sleep(50 * time.Millisecond) // Verify we receive messages on the channel err = hub.SendToChannel("temp:feed", Message{Type: TypeEvent, Data: "before-unsub"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - conn.SetReadDeadline(time.Now().Add(time.Second)) + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) var msg1 Message err = conn.ReadJSON(&msg1) - require.NoError(t, err) - assert.Equal(t, "before-unsub", msg1.Data) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual( + + // Unsubscribe + "before-unsub", msg1.Data) { + t.Errorf("expected %v, got %v", "before-unsub", msg1.Data) + } - // Unsubscribe err = conn.WriteJSON(Message{Type: TypeUnsubscribe, Data: "temp:feed"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + time.Sleep(50 * time.Millisecond) // Send another message -- client should NOT receive it err = hub.SendToChannel("temp:feed", Message{Type: TypeEvent, Data: "after-unsub"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf( + + // Try to read -- should timeout (no message delivered) + "expected no error, got %v", err) + } - // Try to read -- should timeout (no message delivered) - conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + _ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) var msg2 Message err = conn.ReadJSON(&msg2) - assert.Error(t, err, "should not receive messages after unsubscribing") + if err := err; err == nil { + t.Errorf("expected error") + } + } // --------------------------------------------------------------------------- @@ -2608,32 +5648,49 @@ func TestIntegration_BroadcastReachesAllClients_Good(t *testing.T) { conns := make([]*websocket.Conn, numClients) for i := range numClients { conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) conns[i] = conn } time.Sleep(100 * time.Millisecond) - assert.Equal(t, numClients, hub.ClientCount()) + if !testEqual(numClients, hub.ClientCount()) { + t.Errorf( + + // Broadcast -- no channel subscription needed + "expected %v, got %v", numClients, hub.ClientCount()) + } - // Broadcast -- no channel subscription needed err := hub.Broadcast(Message{Type: TypeError, Data: "global-alert"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - for i, conn := range conns { - conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + for _, conn := range conns { + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) var received Message err := conn.ReadJSON(&received) - require.NoError(t, err, "client %d should receive broadcast", i) - assert.Equal(t, TypeError, received.Type) - assert.Equal(t, "global-alert", received.Data) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(TypeError, received.Type) { + t.Errorf("expected %v, got %v", TypeError, received.Type) + } + if !testEqual( + + // --------------------------------------------------------------------------- + // Integration — disconnect cleans up all subscriptions + // --------------------------------------------------------------------------- + "global-alert", received.Data) { + t.Errorf("expected %v, got %v", "global-alert", received.Data) + } + } } -// --------------------------------------------------------------------------- -// Integration — disconnect cleans up all subscriptions -// --------------------------------------------------------------------------- - func TestIntegration_DisconnectCleansUpEverything_Good(t *testing.T) { hub := NewHub() ctx := t.Context() @@ -2643,27 +5700,52 @@ func TestIntegration_DisconnectCleansUpEverything_Good(t *testing.T) { defer server.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf( + + // Subscribe to multiple channels + "expected no error, got %v", err) + } - // Subscribe to multiple channels err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "ch-a"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "ch-b"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + time.Sleep(50 * time.Millisecond) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 1, hub.ClientCount()) + } + if !testEqual(1, hub.ChannelSubscriberCount("ch-a")) { + t.Errorf("expected %v, got %v", - assert.Equal(t, 1, hub.ClientCount()) - assert.Equal(t, 1, hub.ChannelSubscriberCount("ch-a")) - assert.Equal(t, 1, hub.ChannelSubscriberCount("ch-b")) + // Disconnect + 1, hub.ChannelSubscriberCount("ch-a")) + } + if !testEqual(1, hub.ChannelSubscriberCount("ch-b")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("ch-b")) + } - // Disconnect - conn.Close() + _ = conn.Close() time.Sleep(100 * time.Millisecond) + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + if !testEqual(0, hub.ChannelSubscriberCount("ch-a")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("ch-a")) + } + if !testEqual(0, hub.ChannelSubscriberCount("ch-b")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("ch-b")) + } + if !testEqual(0, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 0, hub.ChannelCount()) + } - assert.Equal(t, 0, hub.ClientCount()) - assert.Equal(t, 0, hub.ChannelSubscriberCount("ch-a")) - assert.Equal(t, 0, hub.ChannelSubscriberCount("ch-b")) - assert.Equal(t, 0, hub.ChannelCount(), "empty channels should be cleaned up") } func TestIntegration_ChannelAuthoriser_RejectsForbiddenSubscription_Good(t *testing.T) { @@ -2687,30 +5769,50 @@ func TestIntegration_ChannelAuthoriser_RejectsForbiddenSubscription_Good(t *test defer server.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "private:ops"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } - conn.SetReadDeadline(time.Now().Add(time.Second)) + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) var response Message - require.NoError(t, conn.ReadJSON(&response)) - assert.Equal(t, TypeError, response.Type) - assert.Contains(t, response.Data.(string), "subscription unauthorised") - assert.Equal(t, 0, hub.ChannelSubscriberCount("private:ops")) + if err := conn.ReadJSON(&response); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(TypeError, response.Type) { + t.Errorf("expected %v, got %v", TypeError, response.Type) + } + if !testContains(response.Data.(string), "subscription unauthorised") { + t.Errorf("expected %v to contain %v", response.Data.(string), "subscription unauthorised") + } + if !testEqual(0, hub.ChannelSubscriberCount("private:ops")) { + t.Errorf("expected %v, got %v", 0, hub. + + // --------------------------------------------------------------------------- + // Concurrent broadcast + subscribe via hub loop (race test) + // --------------------------------------------------------------------------- + ChannelSubscriberCount("private:ops")) + } err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "public:news"}) - require.NoError(t, err) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ChannelSubscriberCount("public:news")) -} + if !testEqual(1, hub.ChannelSubscriberCount("public:news")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("public:news")) + } -// --------------------------------------------------------------------------- -// Concurrent broadcast + subscribe via hub loop (race test) -// --------------------------------------------------------------------------- +} func TestConcurrentSubscribeAndBroadcast_Good(t *testing.T) { hub := NewHub() @@ -2738,8 +5840,10 @@ func TestConcurrentSubscribeAndBroadcast_Good(t *testing.T) { wg.Wait() time.Sleep(100 * time.Millisecond) + if !testEqual(50, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 50, hub.ClientCount()) + } - assert.Equal(t, 50, hub.ClientCount()) } func TestHub_Handler_RejectsWhenNotRunning(t *testing.T) { @@ -2751,16 +5855,26 @@ func TestHub_Handler_RejectsWhenNotRunning(t *testing.T) { conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) if err != nil { - assert.Error(t, err) - assert.Equal(t, 0, hub.ClientCount()) + if err := err; err == nil { + t.Errorf("expected error") + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + return } - defer conn.Close() - conn.SetReadDeadline(time.Now().Add(time.Second)) + defer testClose(t, conn.Close) + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) _, _, readErr := conn.ReadMessage() - require.Error(t, readErr) - assert.Equal(t, 0, hub.ClientCount()) + if err := readErr; err == nil { + t.Fatalf("expected error") + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + } func TestHub_OnConnect_CallbackPanic_DoesNotCrashHub(t *testing.T) { @@ -2784,14 +5898,1662 @@ func TestHub_OnConnect_CallbackPanic_DoesNotCrashHub(t *testing.T) { defer server.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) - require.NoError(t, err) - defer conn.Close() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, hub.ClientCount()) + if !testEqual(1, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 1, hub.ClientCount()) + } - conn.Close() + _ = conn.Close() time.Sleep(50 * time.Millisecond) + if gotLen := len(ctxErr); gotLen != 1 { + t.Fatalf("expected length %v, got %v", 1, gotLen) + } + +} + +func TestHub_OnConnect_CallbackCanReenterHub(t *testing.T) { + connected := make(chan struct{}, 1) + subscribeErr := make(chan error, 1) + + hub := NewHubWithConfig(HubConfig{ + OnConnect: func(client *Client) { + connected <- struct{}{} + subscribeErr <- client.hub.Subscribe(client, "callback-channel") + }, + }) + ctx := t.Context() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + + select { + case <-connected: + case <-time.After(time.Second): + t.Fatal("OnConnect callback did not run") + } + + select { + case err := <-subscribeErr: + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + case <-time.After(time.Second): + t.Fatal("re-entrant subscription from OnConnect timed out") + } + if !testEventually(func() bool { + return hub.ChannelSubscriberCount("callback-channel") == 1 + }, time.Second, 10*time.Millisecond) { + t.Errorf("condition was not met before timeout") + } + +} + +func TestWs_nilHubError_Good(t *testing.T) { + err := nilHubError("Broadcast") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "hub must not be nil") + } + if !testContains(err.Error(), "Broadcast") { + t.Errorf("expected %v to contain %v", err.Error(), "Broadcast") + } + +} + +func TestWs_nilHubError_Bad(t *testing.T) { + err := nilHubError("") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "hub must not be nil") + } + +} + +func TestWs_nilHubError_Ugly(t *testing.T) { + err := nilHubError(" \t\n") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "hub must not be nil") + } + +} + +func TestWs_NewHubWithConfig_Good(t *testing.T) { + hub := NewHubWithConfig(HubConfig{}) + if testIsNil(hub) { + t.Fatalf("expected non-nil value") + } + if !testEqual(DefaultHeartbeatInterval, hub.config.HeartbeatInterval) { + t.Errorf("expected %v, got %v", DefaultHeartbeatInterval, hub.config.HeartbeatInterval) + } + if !testEqual(DefaultPongTimeout, hub.config.PongTimeout) { + t.Errorf("expected %v, got %v", DefaultPongTimeout, hub.config.PongTimeout) + } + if !testEqual(DefaultWriteTimeout, hub.config.WriteTimeout) { + t.Errorf("expected %v, got %v", DefaultWriteTimeout, hub.config.WriteTimeout) + } + if !testEqual(DefaultMaxSubscriptionsPerClient, hub.config.MaxSubscriptionsPerClient) { + t.Errorf("expected %v, got %v", DefaultMaxSubscriptionsPerClient, hub.config.MaxSubscriptionsPerClient) + } + +} + +func TestWs_NewHubWithConfig_Bad(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + HeartbeatInterval: 5 * time.Second, + PongTimeout: 4 * time.Second, + WriteTimeout: -1, + MaxSubscriptionsPerClient: -1, + }) + if testIsNil(hub) { + t.Fatalf("expected non-nil value") + } + if !testEqual(5*time.Second, hub.config.HeartbeatInterval) { + t.Errorf("expected %v, got %v", 5*time.Second, hub.config.HeartbeatInterval) + } + if !testEqual(10*time.Second, hub.config.PongTimeout) { + t.Errorf("expected %v, got %v", 10*time.Second, hub.config.PongTimeout) + } + if !testEqual(DefaultWriteTimeout, hub.config.WriteTimeout) { + t.Errorf("expected %v, got %v", DefaultWriteTimeout, hub.config.WriteTimeout) + } + if !testEqual(DefaultMaxSubscriptionsPerClient, hub.config.MaxSubscriptionsPerClient) { + t.Errorf("expected %v, got %v", DefaultMaxSubscriptionsPerClient, hub.config.MaxSubscriptionsPerClient) + } + +} + +func TestWs_NewHubWithConfig_Ugly(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + HeartbeatInterval: -1, + PongTimeout: time.Nanosecond, + WriteTimeout: 0, + MaxSubscriptionsPerClient: 0, + }) + if testIsNil(hub) { + t.Fatalf("expected non-nil value") + } + if !testEqual(DefaultHeartbeatInterval, hub.config.HeartbeatInterval) { + t.Errorf("expected %v, got %v", DefaultHeartbeatInterval, hub.config.HeartbeatInterval) + } + if !testEqual(DefaultPongTimeout, hub.config.PongTimeout) { + t.Errorf("expected %v, got %v", DefaultPongTimeout, hub.config.PongTimeout) + } + if !testEqual(DefaultWriteTimeout, hub.config.WriteTimeout) { + t.Errorf("expected %v, got %v", DefaultWriteTimeout, hub.config.WriteTimeout) + } + if !testEqual(DefaultMaxSubscriptionsPerClient, hub.config.MaxSubscriptionsPerClient) { + t.Errorf("expected %v, got %v", DefaultMaxSubscriptionsPerClient, hub.config.MaxSubscriptionsPerClient) + } + +} + +func TestWs_Subscribe_Good(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + hub.mu.Lock() + hub.clients[client] = true + hub.mu.Unlock() + + err := hub.Subscribe(client, "alpha") + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !(client.subscriptions["alpha"]) { + t.Errorf("expected true") + } + if !testEqual(1, hub.ChannelSubscriberCount("alpha")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("alpha")) + } + +} + +func TestWs_Subscribe_RunningHubClosedDone_Bad(t *testing.T) { + t.Run("nil hub", func(t *testing.T) { + client := &Client{subscriptions: make(map[string]bool)} + + err := (*Hub)(nil).Subscribe(client, "alpha") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "hub must not be nil") + } + + }) + + t.Run("invalid channel", func(t *testing.T) { + hub := NewHub() + client := &Client{subscriptions: make(map[string]bool)} + + err := hub.Subscribe(client, "bad channel") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "invalid channel name") { + t.Errorf("expected %v to contain %v", err.Error(), "invalid channel name") + } + + }) + + t.Run("channel authoriser rejects", func(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + ChannelAuthoriser: func(client *Client, channel string) bool { + return false + }, + }) + client := &Client{hub: hub, subscriptions: make(map[string]bool)} + + err := hub.Subscribe(client, "alpha") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "subscription unauthorised") { + t.Errorf("expected %v to contain %v", err.Error(), "subscription unauthorised") + } + + }) + + t.Run("subscription limit exceeded", func(t *testing.T) { + hub := NewHubWithConfig(HubConfig{MaxSubscriptionsPerClient: 1}) + client := &Client{hub: hub, subscriptions: make(map[string]bool)} + if err := hub.Subscribe(client, "alpha"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err := hub.Subscribe(client, "beta") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !(core.Is(err, ErrSubscriptionLimitExceeded)) { + t.Errorf("expected true") + } + + }) +} + +func TestWs_Subscribe_Ugly(t *testing.T) { + hub := NewHub() + if err := hub.Subscribe(nil, "alpha"); err != nil { + t.Errorf("expected no error, got %v", err) + } + +} + +func TestWs_Subscribe_NilHub_Bad(t *testing.T) { + client := &Client{subscriptions: make(map[string]bool)} + + err := (*Hub)(nil).Subscribe(client, "alpha") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "hub must not be nil") + } + +} + +func TestWs_Subscribe_NilSubscriptions_Good(t *testing.T) { + hub := NewHub() + client := &Client{hub: hub} + if err := hub.Subscribe(client, "alpha"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual([]string{"alpha"}, client.Subscriptions()) { + t.Errorf("expected %v, got %v", []string{"alpha"}, client.Subscriptions()) + } + +} + +func TestWs_Subscribe_HubStoppedBeforeReply_Bad(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + client := &Client{hub: hub, subscriptions: make(map[string]bool)} + client.mu.Lock() + done := make(chan error, 1) + go func() { + done <- hub.Subscribe(client, "alpha") + }() + + time.Sleep(20 * time.Millisecond) + hub.doneOnce.Do(func() { close(hub.done) }) + + select { + case err := <-done: + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub stopped before subscription completed") { + t.Errorf("expected %v to contain %v", err.Error(), "hub stopped before subscription completed") + } + + case <-time.After(time.Second): + t.Fatal("Subscribe should return once the hub shuts down") + } + + client.mu.Unlock() +} + +func TestWs_Unsubscribe_Good(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + hub.mu.Lock() + hub.clients[client] = true + hub.mu.Unlock() + if err := hub.Subscribe(client, "alpha"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + hub.Unsubscribe(client, "alpha") + if client.subscriptions["alpha"] { + t.Errorf("expected false") + } + if !testEqual(0, hub.ChannelSubscriberCount("alpha")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("alpha")) + } + +} + +func TestWs_Unsubscribe_RunningHubClosedDone_Bad(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + if err := hub.Subscribe(client, "alpha"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + hub.Unsubscribe(client, "bad channel") + if !(client.subscriptions["alpha"]) { + t.Errorf("expected true") + } + if !testEqual(1, hub.ChannelSubscriberCount("alpha")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("alpha")) + } + +} + +func TestWs_Unsubscribe_Ugly(t *testing.T) { + testNotPanics(t, func() { + var hub *Hub + hub.Unsubscribe(nil, "alpha") + hub.Unsubscribe(&Client{}, "") + }) + +} + +func TestWs_Unsubscribe_NilHub_Ugly(t *testing.T) { + testNotPanics(t, func() { + (*Hub)(nil).Unsubscribe(&Client{subscriptions: make(map[string]bool)}, "alpha") + }) + +} + +func TestWs_Unsubscribe_HubStoppedBeforeReply_Bad(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + if !testEventually(func() bool { + return hub.isRunning() + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + client := &Client{hub: hub, subscriptions: make(map[string]bool)} + if err := hub.Subscribe(client, "alpha"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + client.mu.Lock() + done := make(chan struct{}) + go func() { + hub.Unsubscribe(client, "alpha") + close(done) + }() + + time.Sleep(20 * time.Millisecond) + hub.doneOnce.Do(func() { close(hub.done) }) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("Unsubscribe should return once the hub shuts down") + } + + client.mu.Unlock() +} + +func TestWs_dispatchReconnectMessage_Good(t *testing.T) { + var seen []Message + + dispatchReconnectMessage(func(msg Message) { + seen = append(seen, msg) + }, []byte("{\"type\":\"event\",\"data\":\"alpha\"}\n{\"type\":\"error\",\"data\":\"beta\"}")) + if gotLen := len(seen); gotLen != 2 { + t.Fatalf("expected length %v, got %v", 2, gotLen) + } + if !testEqual(TypeEvent, seen[0].Type) { + t.Errorf("expected %v, got %v", TypeEvent, seen[0].Type) + } + if !testEqual("alpha", seen[0].Data) { + t.Errorf("expected %v, got %v", "alpha", seen[0].Data) + } + if !testEqual(TypeError, seen[1].Type) { + t.Errorf("expected %v, got %v", TypeError, seen[1].Type) + } + if !testEqual("beta", seen[1].Data) { + t.Errorf("expected %v, got %v", "beta", seen[1].Data) + } + +} + +func TestWs_dispatchReconnectMessage_Bad(t *testing.T) { + called := 0 + + dispatchReconnectMessage(func(msg Message) { + called++ + }, []byte("{not-json}\n{\"type\":\"event\",\"data\":\"ok\"}")) + if !testEqual(1, called) { + t.Errorf("expected %v, got %v", 1, called) + } + +} + +func TestWs_dispatchReconnectMessage_Ugly(t *testing.T) { + testNotPanics(t, func() { + dispatchReconnectMessage(nil, []byte("ignored")) + dispatchReconnectMessage(123, []byte("ignored")) + dispatchReconnectMessage(func(msg Message) { + panic("boom") + }, []byte("{\"type\":\"event\"}")) + }) + +} + +func TestReconnectingClient_Send_Good(t *testing.T) { + msgSeen := make(chan []byte, 1) + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + + _, data, err := conn.ReadMessage() + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + msgSeen <- data + })) + defer server.Close() + + rc := NewReconnectingClient(ReconnectConfig{ + URL: wsURL(server), + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- rc.Connect(ctx) + }() + if !testEventually(func() bool { + return rc.State() == StateConnected + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + if err := rc.Send(Message{Type: TypeEvent, Data: "payload"}); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case data := <-msgSeen: + if !testContains(string(data), "\"type\":\"event\"") { + t.Errorf("expected %v to contain %v", string(data), "\"type\":\"event\"") + } + if !testContains(string(data), "\"data\":\"payload\"") { + t.Errorf("expected %v to contain %v", string(data), "\"data\":\"payload\"") + } + + case <-time.After(time.Second): + t.Fatal("server should have received the sent message") + } + if err := rc.Close(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case err := <-done: + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testEqual(context.Canceled, err) { + t.Errorf("expected %v, got %v", context.Canceled, err) + } + + case <-time.After(time.Second): + t.Fatal("Connect should stop after Close cancels the context") + } +} + +func TestReconnectingClient_Send_Bad(t *testing.T) { + t.Run("nil receiver", func(t *testing.T) { + var rc *ReconnectingClient + + err := rc.Send(Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "client must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "client must not be nil") + } + + }) + + t.Run("not connected", func(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{URL: "ws://127.0.0.1:1"}) + + err := rc.Send(Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "not connected") { + t.Errorf("expected %v to contain %v", err.Error(), "not connected") + } + + }) + + t.Run("marshal failure", func(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://127.0.0.1:1", + OnError: func(err error) { + if !testContains(err.Error(), "failed to marshal message") { + t.Errorf("expected %v to contain %v", err.Error(), "failed to marshal message") + } + + }, + }) + + err := rc.Send(Message{Type: TypeEvent, Data: make(chan int)}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "failed to marshal message") { + t.Errorf("expected %v to contain %v", err.Error(), "failed to marshal message") + } + + }) + + t.Run("context canceled", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + })) + defer server.Close() + + clientConn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, clientConn.Close) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + rc := &ReconnectingClient{ + conn: clientConn, + ctx: ctx, + state: StateConnected, + config: ReconnectConfig{URL: "ws://127.0.0.1:1"}, + } + + err = rc.Send(Message{Type: TypeEvent, Data: "payload"}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testEqual(context.Canceled, err) { + t.Errorf("expected %v, got %v", context.Canceled, err) + } + + }) + + t.Run("write failure", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + })) + defer server.Close() + + clientConn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + rc := &ReconnectingClient{ + conn: clientConn, + state: StateConnected, + done: make(chan struct{}), + config: ReconnectConfig{URL: wsURL(server)}, + } + if err := clientConn.Close(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err = rc.Send(Message{Type: TypeEvent, Data: "payload"}) + if err := err; err == nil { + t.Fatalf("expected error") + } + + }) +} + +func TestReconnectingClient_Close_Ugly(t *testing.T) { + var rc *ReconnectingClient + if err := rc.Close(); err != nil { + t.Errorf("expected no error, got %v", err) + } + +} + +func TestReconnectingClient_Connect_Ugly(t *testing.T) { + var rc *ReconnectingClient + + err := rc.Connect(context.Background()) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "client must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "client must not be nil") + } + +} + +func TestReconnectingClient_Connect_OnError_Good(t *testing.T) { + errs := make(chan error, 4) + + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://127.0.0.1:1", + InitialBackoff: 10 * time.Millisecond, + MaxBackoff: 20 * time.Millisecond, + MaxReconnectAttempts: 1, + OnError: func(err error) { + select { + case errs <- err: + default: + } + }, + }) + + done := make(chan error, 1) + go func() { + done <- rc.Connect(context.Background()) + }() + + select { + case err := <-done: + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "max retries (1) exceeded") { + t.Errorf("expected %v to contain %v", err.Error(), "max retries (1) exceeded") + } + + case <-time.After(5 * time.Second): + t.Fatal("Connect should stop after max retries") + } + if !testEventually(func() bool { + return len(errs) >= 2 + }, time.Second, 10*time.Millisecond) { + t.Fatalf("condition was not met before timeout") + } + + first := <-errs + second := <-errs + if err := first; err == nil { + t.Fatalf("expected error") + } + if err := second; err == nil { + t.Fatalf("expected error") + } + if !testContains(second.Error(), "max retries (1) exceeded") { + t.Errorf("expected %v to contain %v", second.Error(), "max retries (1) exceeded") + } + +} + +func TestReconnectingClient_Send_Ugly(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{URL: "ws://127.0.0.1:1"}) + rc.setState(StateConnected) + + err := rc.Send(Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "not connected") { + t.Errorf("expected %v to contain %v", err.Error(), "not connected") + } + +} + +func TestReconnectingClient_readLoop_Ugly(t *testing.T) { + rc := &ReconnectingClient{} + if err := rc.readLoop(); err != nil { + t.Errorf("expected no error, got %v", err) + } + +} + +func TestWs_sameOriginCheck_Good(t *testing.T) { + tests := []struct { + name string + req func() *http.Request + want bool + }{ + { + name: "no origin header is allowed", + req: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + }, + want: true, + }, + { + name: "matches host and scheme", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Header.Set("Origin", "http://example.com") + return r + }, + want: true, + }, + { + name: "matches https on explicit port", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "https://example.com:443/ws", nil) + r.TLS = &tls.ConnectionState{} + r.Header.Set("Origin", "https://example.com") + return r + }, + want: true, + }, + { + name: "uses request URL host when Host is empty", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.org:8080/ws", nil) + r.Host = "" + r.URL.Host = "example.org:8080" + r.Header.Set("Origin", "http://example.org:8080") + return r + }, + want: true, + }, + { + name: "treats whitespace origin as absent", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Header.Set("Origin", " ") + return r + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !testEqual(tt.want, sameOriginCheck(tt.req())) { + t.Errorf("expected %v, got %v", tt.want, sameOriginCheck(tt.req())) + } + + }) + } +} + +func TestWs_sameOriginCheck_Bad(t *testing.T) { + tests := []struct { + name string + req func() *http.Request + }{ + { + name: "scheme mismatch", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Header.Set("Origin", "https://example.com") + return r + }, + }, + { + name: "host mismatch", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Header.Set("Origin", "http://evil.example") + return r + }, + }, + { + name: "port mismatch", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com:8080/ws", nil) + r.Header.Set("Origin", "http://example.com:9090") + return r + }, + }, + { + name: "malformed origin", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Header.Set("Origin", "://broken") + return r + }, + }, + { + name: "invalid origin host", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Header.Set("Origin", "http://example.com:bad") + return r + }, + }, + { + name: "invalid origin port after parse", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Header.Set("Origin", "http://[2001:db8::1]:bad") + return r + }, + }, + { + name: "origin host requires brackets for ipv6", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Header.Set("Origin", "http://2001:db8::1") + return r + }, + }, + { + name: "missing origin host", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Header.Set("Origin", "http://") + return r + }, + }, + { + name: "invalid request host", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Host = "example.com:bad" + r.Header.Set("Origin", "http://example.com") + return r + }, + }, + { + name: "request host requires brackets for ipv6", + req: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Host = "2001:db8::1" + r.Header.Set("Origin", "http://example.com") + return r + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if sameOriginCheck(tt.req()) { + t.Errorf("expected false") + } + + }) + } +} + +func TestWs_sameOriginCheck_Ugly(t *testing.T) { + if sameOriginCheck(nil) { + t.Errorf("expected false") + } + + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Host = "" + r.URL.Host = "" + r.Header.Set("Origin", "http://example.com") + if sameOriginCheck(r) { + t.Errorf("expected false") + } + +} + +func TestWs_sameOriginCheck_Ugly_NilURL(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.URL = nil + r.Host = "" + r.Header.Set("Origin", "http://example.com") + if sameOriginCheck(r) { + t.Errorf("expected false") + } + +} + +func TestWs_sameOriginCheck_Ugly_MissingSeam(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Host = "[" + r.Header.Set("Origin", "http://example.com") + if sameOriginCheck(r) { + t.Errorf("expected false") + } +} + +func TestWs_safeOriginCheck_Good(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + + called := false + if !(safeOriginCheck(func(req *http.Request) bool { + called = true + if !testSame(r, req) { + t.Errorf("expected same reference") + } + return true + }, r)) { + t.Errorf("expected true") + } + if !(called) { + t.Errorf("expected true") + } + +} + +func TestWs_safeOriginCheck_Bad(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + if safeOriginCheck(func(*http.Request) bool { + return false + }, r) { + t.Errorf("expected false") + } + +} + +func TestWs_safeOriginCheck_Ugly(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + + var check func(*http.Request) bool + if safeOriginCheck(check, r) { + t.Errorf("expected false") + } + +} + +func TestWs_safeAuthenticate_Good(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + + result := safeAuthenticate(AuthenticatorFunc(func(*http.Request) AuthResult { + return AuthResult{Authenticated: true, UserID: "user-123"} + }), r) + if !(result.Valid) { + t.Errorf("expected true") + } + if !(result.Authenticated) { + t.Errorf("expected true") + } + if !testEqual("user-123", result.UserID) { + t.Errorf("expected %v, got %v", "user-123", result.UserID) + } + +} + +func TestWs_safeAuthenticate_Bad(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + + result := safeAuthenticate(AuthenticatorFunc(func(*http.Request) AuthResult { + return AuthResult{Valid: false, Error: core.NewError("denied")} + }), r) + if result.Valid { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if err := result.Error; err == nil || err.Error() != "denied" { + t.Errorf("expected error %q, got %v", "denied", err) + } + +} + +func TestWs_safeAuthenticate_Ugly(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + + result := safeAuthenticate(AuthenticatorFunc(func(*http.Request) AuthResult { + panic("boom") + }), r) + if result.Valid { + t.Errorf("expected false") + } + if result.Authenticated { + t.Errorf("expected false") + } + if err := result.Error; err == nil { + t.Fatalf("expected error") + } + if !testContains(result.Error.Error(), "authenticator panicked") { + t.Errorf("expected %v to contain %v", result.Error.Error(), "authenticator panicked") + } + +} + +func TestWs_splitHostAndPort_Good(t *testing.T) { + tests := []struct { + name string + host string + scheme string + wantH string + wantP string + }{ + {name: "host and port", host: "example.com:8080", scheme: "http", wantH: "example.com", wantP: "8080"}, + {name: "bare host uses http default port", host: "example.com", scheme: "http", wantH: "example.com", wantP: "80"}, + {name: "ipv6 host uses wss default port", host: "[2001:db8::1]", scheme: "wss", wantH: "2001:db8::1", wantP: "443"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + host, port, ok := splitHostAndPort(tt.host, tt.scheme) + if !(ok) { + t.Fatalf("expected true") + } + if !testEqual(tt.wantH, host) { + t.Errorf("expected %v, got %v", tt.wantH, host) + } + if !testEqual(tt.wantP, port) { + t.Errorf("expected %v, got %v", tt.wantP, port) + } + + }) + } +} + +func TestWs_splitHostAndPort_Bad(t *testing.T) { + tests := []struct { + name string + host string + }{ + {name: "empty host", host: ""}, + {name: "bare colon", host: ":"}, + {name: "unbracketed ipv6 with port", host: "2001:db8::1:443"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, ok := splitHostAndPort(tt.host, "http") + if ok { + t.Errorf("expected false") + } + + }) + } +} + +func TestWs_splitHostAndPort_Ugly(t *testing.T) { + host, port, ok := splitHostAndPort(" [::1] ", "https") + if !(ok) { + t.Fatalf("expected true") + } + if !testEqual("::1", host) { + t.Errorf("expected %v, got %v", "::1", host) + } + if !testEqual("443", port) { + t.Errorf("expected %v, got %v", "443", port) + } + + host, port, ok = splitHostAndPort("example.com", " ") + if !(ok) { + t.Fatalf("expected true") + } + if !testEqual("example.com", host) { + t.Errorf("expected %v, got %v", "example.com", host) + } + if !testEqual("80", port) { + t.Errorf("expected %v, got %v", "80", port) + } + +} + +func TestWs_splitHostAndPort_Ugly_EmptyBrackets(t *testing.T) { + _, _, ok := splitHostAndPort("[]", "https") + if ok { + t.Errorf("expected false") + } + +} + +func TestWs_NilHubReceivers_Ugly(t *testing.T) { + var hub *Hub + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + if !testEqual(0, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 0, hub.ChannelCount()) + } + if !testEqual(0, hub.ChannelSubscriberCount("notifications")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("notifications")) + } + if !testIsEmpty(slices.Collect(hub.AllClients())) { + t.Errorf("expected empty value, got %v", slices.Collect(hub.AllClients())) + } + if !testIsEmpty(slices.Collect(hub.AllChannels())) { + t.Errorf("expected empty value, got %v", slices.Collect(hub.AllChannels())) + } + if !testEqual(HubStats{}, hub.Stats()) { + t.Errorf("expected %v, got %v", HubStats{}, hub.Stats()) + } + if hub.isRunning() { + t.Errorf("expected false") + } + +} + +func TestWs_defaultPortForScheme_Good(t *testing.T) { + if !testEqual("443", defaultPortForScheme("https")) { + t.Errorf("expected %v, got %v", "443", defaultPortForScheme("https")) + } + if !testEqual("443", defaultPortForScheme("wss")) { + t.Errorf("expected %v, got %v", "443", defaultPortForScheme("wss")) + } + +} + +func TestWs_defaultPortForScheme_Bad(t *testing.T) { + if !testEqual("80", defaultPortForScheme("http")) { + t.Errorf("expected %v, got %v", "80", defaultPortForScheme("http")) + } + if !testEqual("80", defaultPortForScheme("ws")) { + t.Errorf("expected %v, got %v", "80", defaultPortForScheme("ws")) + } + +} + +func TestWs_defaultPortForScheme_Ugly(t *testing.T) { + if !testEqual("443", defaultPortForScheme(" HTTPS ")) { + t.Errorf("expected %v, got %v", "443", defaultPortForScheme(" HTTPS ")) + } + if !testEqual("80", defaultPortForScheme("")) { + t.Errorf("expected %v, got %v", "80", defaultPortForScheme("")) + } + +} + +func TestWs_ClientClose_Good(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: map[string]bool{"alpha": true}, + send: make(chan []byte, 1), + } + + hub.mu.Lock() + hub.clients[client] = true + hub.channels["alpha"] = map[*Client]bool{client: true} + hub.mu.Unlock() + if err := client.Close(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + if !testEqual(0, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 0, hub.ChannelCount()) + } + if client.subscriptions["alpha"] { + t.Errorf("expected false") + } + +} + +func TestWs_ClientClose_Bad(t *testing.T) { + hub := NewHub() + var called bool + hub.config.OnDisconnect = func(*Client) { + called = true + } + + client := &Client{ + hub: hub, + subscriptions: map[string]bool{"alpha": true}, + } + + hub.mu.Lock() + hub.clients[client] = true + hub.channels["alpha"] = map[*Client]bool{client: true} + hub.mu.Unlock() + if err := client.Close(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !(called) { + t.Errorf("expected true") + } + if !testEqual(0, hub.ClientCount()) { + t.Errorf("expected %v, got %v", 0, hub.ClientCount()) + } + if !testEqual(0, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 0, hub.ChannelCount()) + } + +} + +func TestWs_ClientClose_Ugly(t *testing.T) { + var client *Client + if err := client.Close(); err != nil { + t.Errorf("expected no error, got %v", err) + } + + client = &Client{} + if err := client.Close(); err != nil { + t.Errorf("expected no error, got %v", err) + } + +} + +func TestWs_Broadcast_Good(t *testing.T) { + hub := NewHub() + err := hub.Broadcast(Message{Type: TypeEvent, Data: "broadcast"}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case raw := <-hub.broadcast: + var received Message + if !(core.JSONUnmarshal(raw, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeEvent, received.Type) { + t.Errorf("expected %v, got %v", TypeEvent, received.Type) + } + if !testEqual("broadcast", received.Data) { + t.Errorf("expected %v, got %v", "broadcast", received.Data) + } + if received.Timestamp.IsZero() { + t.Errorf("expected false") + } + + case <-time.After(time.Second): + t.Fatal("broadcast should be queued") + } +} + +func TestWs_Broadcast_Bad(t *testing.T) { + var hub *Hub + + err := hub.Broadcast(Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "hub must not be nil") + } + +} + +func TestWs_SendToChannel_Good(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + send: make(chan []byte, 1), + subscriptions: make(map[string]bool), + } + if err := hub.Subscribe(client, "alpha"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + err := hub.SendToChannel("alpha", Message{Type: TypeEvent, Data: "payload"}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case raw := <-client.send: + var received Message + if !(core.JSONUnmarshal(raw, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual("alpha", received.Channel) { + t.Errorf("expected %v, got %v", "alpha", received.Channel) + } + if !testEqual(TypeEvent, received.Type) { + t.Errorf("expected %v, got %v", TypeEvent, received.Type) + } + if !testEqual("payload", received.Data) { + t.Errorf("expected %v, got %v", "payload", received.Data) + } + if received.Timestamp.IsZero() { + t.Errorf("expected false") + } + + case <-time.After(time.Second): + t.Fatal("channel message should be queued") + } +} + +func TestWs_sendToChannelMessage_PreserveTimestamp_Good(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + send: make(chan []byte, 1), + subscriptions: make(map[string]bool), + } + if err := hub.Subscribe(client, "alpha"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + timestamp := time.Date(2026, time.March, 19, 12, 0, 0, 0, time.UTC) + err := hub.sendToChannelMessage("alpha", Message{ + Type: TypeEvent, + Data: "payload", + Timestamp: timestamp, + }, true) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case raw := <-client.send: + var received Message + if !(core.JSONUnmarshal(raw, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(timestamp, received.Timestamp) { + t.Errorf("expected %v, got %v", timestamp, received.Timestamp) + } + if !testEqual("alpha", received.Channel) { + t.Errorf("expected %v, got %v", "alpha", received.Channel) + } + + case <-time.After(time.Second): + t.Fatal("channel message should be queued") + } +} + +func TestWs_broadcastMessage_PreserveTimestamp_Good(t *testing.T) { + hub := NewHub() + + timestamp := time.Date(2026, time.March, 19, 13, 0, 0, 0, time.UTC) + err := hub.broadcastMessage(Message{ + Type: TypeEvent, + Data: "payload", + Timestamp: timestamp, + }, true) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + select { + case raw := <-hub.broadcast: + var received Message + if !(core.JSONUnmarshal(raw, &received).OK) { + t.Fatalf("expected true") + } + if !testEqual(timestamp, received.Timestamp) { + t.Errorf("expected %v, got %v", timestamp, received.Timestamp) + } + + case <-time.After(time.Second): + t.Fatal("broadcast should be queued") + } +} + +func TestWs_SendToChannel_Bad(t *testing.T) { + var hub *Hub + + err := hub.SendToChannel("alpha", Message{Type: TypeEvent}) + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub must not be nil") { + t.Errorf("expected %v to contain %v", err.Error(), "hub must not be nil") + } + +} + +func TestWs_EnqueueUnregister_Good(t *testing.T) { + hub := &Hub{ + unregister: make(chan *Client, 1), + done: make(chan struct{}), + } + client := &Client{} + + hub.enqueueUnregister(client) + + select { + case got := <-hub.unregister: + if !testSame(client, got) { + t.Errorf("expected same reference") + } + + case <-time.After(time.Second): + t.Fatal("expected client to be queued for unregister") + } +} + +func TestWs_EnqueueUnregister_Ugly(t *testing.T) { + testNotPanics(t, func() { + var hub *Hub + hub.enqueueUnregister(nil) + }) + + // Missing seam: the closed-done branch in enqueueUnregister is + // racey to assert without an injectable send primitive. + t.Skip("missing seam: enqueueUnregister closed-done branch is not directly testable") +} + +func TestWs_HandleSubscribeRequest_Good(t *testing.T) { + hub := NewHub() + client := &Client{hub: hub, subscriptions: make(map[string]bool)} + + err := hub.handleSubscribeRequest(subscriptionRequest{ + client: client, + channel: "alpha", + }) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !(client.subscriptions["alpha"]) { + t.Errorf("expected true") + } + if !testEqual(1, hub.ChannelSubscriberCount("alpha")) { + t.Errorf("expected %v, got %v", 1, hub.ChannelSubscriberCount("alpha")) + } + +} + +func TestWs_HandleSubscribeRequest_Ugly(t *testing.T) { + hub := NewHub() + + err := hub.handleSubscribeRequest(subscriptionRequest{}) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !testEqual(0, hub.ChannelCount()) { + t.Errorf("expected %v, got %v", 0, hub.ChannelCount()) + } + +} + +func TestWs_HandleUnsubscribeRequest_Good(t *testing.T) { + hub := NewHub() + client := &Client{hub: hub, subscriptions: make(map[string]bool)} + if err := hub.Subscribe(client, "alpha"); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + hub.handleUnsubscribeRequest(subscriptionRequest{ + client: client, + channel: "alpha", + }) + if client.subscriptions["alpha"] { + t.Errorf("expected false") + } + if !testEqual(0, hub.ChannelSubscriberCount("alpha")) { + t.Errorf("expected %v, got %v", 0, hub.ChannelSubscriberCount("alpha")) + } + +} + +func TestWs_HandleUnsubscribeRequest_Ugly(t *testing.T) { + hub := NewHub() + testNotPanics(t, func() { + hub.handleUnsubscribeRequest(subscriptionRequest{}) + }) + +} + +func TestWs_Subscribe_Bad(t *testing.T) { + hub := NewHub() + client := &Client{hub: hub, subscriptions: make(map[string]bool)} + hub.running = true + close(hub.done) + + err := hub.Subscribe(client, "alpha") + if err := err; err == nil { + t.Fatalf("expected error") + } + if !testContains(err.Error(), "hub is not running") { + t.Errorf("expected %v to contain %v", err.Error(), "hub is not running") + } + +} + +func TestWs_Unsubscribe_Bad(t *testing.T) { + hub := NewHub() + client := &Client{hub: hub, subscriptions: make(map[string]bool)} + hub.running = true + close(hub.done) + testNotPanics(t, func() { + hub.Unsubscribe(client, "alpha") + }) + +} + +func TestWs_ClientClose_Good_ConnOnly(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + defer testClose(t, conn.Close) + time.Sleep(200 * time.Millisecond) + })) + defer server.Close() + + conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil) + if err := err; err != nil { + t.Fatalf("expected no error, got %v", err) + } + + client := &Client{conn: conn} + if err := client.Close(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if err := conn.WriteMessage(websocket.TextMessage, []byte("after-close")); err == nil { + t.Fatalf("expected error") + } + +} + +func TestWs_marshalClientMessage_Good(t *testing.T) { + timestamp := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC) + data := marshalClientMessage(Message{ + Type: TypeProcessStatus, + Channel: "alpha", + ProcessID: "proc-1", + Data: map[string]any{"state": "done"}, + Timestamp: timestamp, + }) + if testIsNil(data) { + t.Fatalf("expected non-nil value") + } + + var wire struct { + Type MessageType `json:"type"` + Channel string `json:"channel"` + ProcessID string `json:"processId"` + Data map[string]any `json:"data"` + Timestamp time.Time `json:"timestamp"` + } + if !(core.JSONUnmarshal(data, &wire).OK) { + t.Fatalf("expected true") + } + if !testEqual(TypeProcessStatus, wire.Type) { + t.Errorf("expected %v, got %v", TypeProcessStatus, wire.Type) + } + if !testEqual("alpha", wire.Channel) { + t.Errorf("expected %v, got %v", "alpha", wire.Channel) + } + if !testEqual("proc-1", wire.ProcessID) { + t.Errorf("expected %v, got %v", "proc-1", wire.ProcessID) + } + if !testEqual("done", wire.Data["state"]) { + t.Errorf("expected %v, got %v", "done", wire.Data["state"]) + } + if !testEqual(timestamp, wire.Timestamp) { + t.Errorf("expected %v, got %v", timestamp, wire.Timestamp) + } + +} + +func TestWs_marshalClientMessage_Bad(t *testing.T) { + data := marshalClientMessage(Message{ + Type: TypeEvent, + Data: make(chan int), + }) + if !testIsNil(data) { + t.Errorf("expected nil, got %T", data) + } + +} + +func TestWs_dispatchReconnectMessage_Good_BlankFrames(t *testing.T) { + seen := make([]Message, 0, 2) + + dispatchReconnectMessage(func(msg Message) { + seen = append(seen, msg) + }, []byte("\n{\"type\":\"event\",\"data\":\"alpha\"}\n\n{\"type\":\"error\",\"data\":\"beta\"}\n")) + if gotLen := len(seen); gotLen != 2 { + t.Fatalf("expected length %v, got %v", 2, gotLen) + } + if !testEqual(TypeEvent, seen[0].Type) { + t.Errorf("expected %v, got %v", TypeEvent, seen[0].Type) + } + if !testEqual("alpha", seen[0].Data) { + t.Errorf("expected %v, got %v", "alpha", seen[0].Data) + } + if !testEqual(TypeError, seen[1].Type) { + t.Errorf("expected %v, got %v", TypeError, seen[1].Type) + } + if !testEqual("beta", seen[1].Data) { + t.Errorf("expected %v, got %v", "beta", seen[1].Data) + } + +} + +func TestWs_dispatchReconnectMessage_Ugly_NilCallbacks(t *testing.T) { + testNotPanics(t, func() { + var raw func([]byte) + var msgFn func(Message) + var stringFn func(string) + dispatchReconnectMessage(raw, []byte("payload")) + dispatchReconnectMessage(msgFn, []byte("{\"type\":\"event\"}")) + dispatchReconnectMessage(stringFn, []byte("payload")) + }) - require.Len(t, ctxErr, 1) }