From 6359fdec947653a24fcd4f43af8deec82c957eaa Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Fri, 13 Feb 2026 14:00:10 +0100 Subject: [PATCH] refactor: stream management into internal/stream package This commit moves the core stream management logic from the `gorums` package into a new `internal/stream` package to improve encapsulation and testability. Key changes: - Moved channel.go to `internal/stream` containing `Channel`, `Request`, and `Message` types along with their associated logic. - `Channel` now encapsulates the `grpc.ClientStream` lifecycle, handling connection establishment, reconnection, and message routing. - The `Node` struct delegates underlying communication and metrics (Latency, LastErr) to `Channel`. - Refactored `sender` and `receiver` loops in `Channel` for better robustness and to properly drain the send queue on closure, fixing potential deadlocks. - Updated `channel_test.go` with robust, table-driven tests covering connection states, context cancellation, and concurrent sends. - Added benchmarks for channel performance and stream reconnection. --- client_interceptor.go | 16 +- encoding.go | 62 -- errors.go | 4 +- channel.go => internal/stream/channel.go | 189 ++-- .../stream/channel_test.go | 897 ++++++++++-------- internal/stream/marshaling.go | 67 ++ internal/stream/response.go | 20 + node.go | 18 +- node_test.go | 34 +- responses.go | 23 + rpc.go | 2 +- server.go | 2 +- unicast.go | 4 +- 13 files changed, 730 insertions(+), 608 deletions(-) rename channel.go => internal/stream/channel.go (69%) rename channel_test.go => internal/stream/channel_test.go (52%) create mode 100644 internal/stream/marshaling.go create mode 100644 internal/stream/response.go diff --git a/client_interceptor.go b/client_interceptor.go index 52cceac54..9a0937583 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -201,12 +201,12 @@ func (c *ClientCtx[Req, Resp]) send() { if streamMsg == nil { continue // Skip node: transformAndMarshal already sent ErrSkipNode } - n.channel.enqueue(request{ - ctx: c.Context, - msg: streamMsg, - streaming: c.streaming, - waitSendDone: c.waitSendDone, - responseChan: c.replyChan, + n.channel.Enqueue(stream.Request{ + Ctx: c.Context, + Msg: streamMsg, + Streaming: c.streaming, + WaitSendDone: c.waitSendDone, + ResponseChan: c.replyChan, }) } } @@ -241,7 +241,7 @@ func (c *ClientCtx[Req, Resp]) defaultResponseSeq() ResponseSeq[Resp] { for range c.Size() { select { case r := <-c.replyChan: - res := newNodeResponse[Resp](r) + res := mapToCallResponse[Resp](r) if !yield(res) { return // Consumer stopped iteration } @@ -261,7 +261,7 @@ func (c *ClientCtx[Req, Resp]) streamingResponseSeq() ResponseSeq[Resp] { for { select { case r := <-c.replyChan: - res := newNodeResponse[Resp](r) + res := mapToCallResponse[Resp](r) if !yield(res) { return // Consumer stopped iteration } diff --git a/encoding.go b/encoding.go index f916bdc64..4df19b50d 100644 --- a/encoding.go +++ b/encoding.go @@ -1,14 +1,10 @@ package gorums import ( - "fmt" - "github.com/relab/gorums/internal/stream" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/reflect/protoregistry" ) // Message encapsulates the stream.Message and the actual proto.Message. @@ -82,61 +78,3 @@ func messageWithError(in, out *Message, err error) *Message { } return out } - -// unmarshalRequest unmarshals the request proto message from the message. 
-// It uses the method name in the message to look up the Input type from the proto registry. -// -// This function should only be used by internal channel operations. -func unmarshalRequest(in *stream.Message) (proto.Message, error) { - // get method descriptor from registry - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(in.GetMethod())) - if err != nil { - return nil, fmt.Errorf("gorums: could not find method descriptor for %s", in.GetMethod()) - } - methodDesc := desc.(protoreflect.MethodDescriptor) - - // get the request message type (Input type) - msgType, err := protoregistry.GlobalTypes.FindMessageByName(methodDesc.Input().FullName()) - if err != nil { - return nil, fmt.Errorf("gorums: could not find message type %s", methodDesc.Input().FullName()) - } - req := msgType.New().Interface() - - // unmarshal message from the Message.Payload field - payload := in.GetPayload() - if len(payload) > 0 { - if err := proto.Unmarshal(payload, req); err != nil { - return nil, fmt.Errorf("gorums: could not unmarshal request: %w", err) - } - } - return req, nil -} - -// unmarshalResponse unmarshals the response proto message from the message. -// It uses the method name in the message to look up the Output type from the proto registry. -// -// This function should only be used by internal channel operations. -func unmarshalResponse(out *stream.Message) (proto.Message, error) { - // get method descriptor from registry - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(out.GetMethod())) - if err != nil { - return nil, fmt.Errorf("gorums: could not find method descriptor for %s", out.GetMethod()) - } - methodDesc := desc.(protoreflect.MethodDescriptor) - - // get the response message type (Output type) - msgType, err := protoregistry.GlobalTypes.FindMessageByName(methodDesc.Output().FullName()) - if err != nil { - return nil, fmt.Errorf("gorums: could not find message type %s", methodDesc.Output().FullName()) - } - resp := msgType.New().Interface() - - // unmarshal message from the Message.Payload field - payload := out.GetPayload() - if len(payload) > 0 { - if err := proto.Unmarshal(payload, resp); err != nil { - return nil, fmt.Errorf("gorums: could not unmarshal response: %w", err) - } - } - return resp, nil -} diff --git a/errors.go b/errors.go index 0834ad9c5..bb6794a8d 100644 --- a/errors.go +++ b/errors.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "strings" + + "github.com/relab/gorums/internal/stream" ) // ErrIncomplete is the error returned by a quorum call when the call cannot be completed @@ -14,7 +16,7 @@ var ErrIncomplete = errors.New("incomplete call") var ErrSendFailure = errors.New("send failure") // ErrTypeMismatch is returned when a response cannot be cast to the expected type. -var ErrTypeMismatch = errors.New("response type mismatch") +var ErrTypeMismatch = stream.ErrTypeMismatch // ErrSkipNode is returned when a node is skipped by request transformations. // This allows the response iterator to account for all nodes without blocking. 
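
The new package-level flow, in brief: a Channel wraps an existing grpc.ClientConn and is
driven by enqueueing Requests, as the relocated channel_test.go below exercises at length.
A minimal in-package sketch of the intended usage (addr, method, payload, and ctx are
placeholder names, not identifiers introduced by this patch):

    conn, _ := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
    c := NewChannel(ctx, conn, 1, 10) // node id 1, sendQ buffer 10; starts sender/receiver goroutines
    defer c.Close()

    msg, _ := NewMessage(ctx, 1, method, payload) // message seq no 1
    replyChan := make(chan response, 1)
    c.Enqueue(Request{Ctx: ctx, Msg: msg, ResponseChan: replyChan})
    r := <-replyChan // carries NodeID, Value (proto.Message), and Err
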
diff --git a/channel.go b/internal/stream/channel.go similarity index 69% rename from channel.go rename to internal/stream/channel.go index edabcf004..3c88a4ad6 100644 --- a/channel.go +++ b/internal/stream/channel.go @@ -1,58 +1,33 @@ -package gorums +package stream import ( "context" "sync" "time" - "github.com/relab/gorums/internal/stream" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" ) -// NodeResponse wraps a response value from node ID, and an error if any. -type NodeResponse[T any] struct { - NodeID uint32 - Value T - Err error -} - -// newNodeResponse converts a NodeResponse[msg] to a NodeResponse[Resp]. -// This is necessary because the channel layer's response router returns a -// NodeResponse[msg], while the calltype expects a NodeResponse[Resp]. -func newNodeResponse[Resp msg](r NodeResponse[msg]) NodeResponse[Resp] { - res := NodeResponse[Resp]{ - NodeID: r.NodeID, - Err: r.Err, - } - if r.Err == nil { - if val, ok := r.Value.(Resp); ok { - res.Value = val - } else { - res.Err = ErrTypeMismatch - } - } - return res -} - var ( - nodeClosedErr = status.Error(codes.Unavailable, "node closed") - streamDownErr = status.Error(codes.Unavailable, "stream is down") + ErrNodeClosed = status.Error(codes.Unavailable, "node closed") + ErrStreamDown = status.Error(codes.Unavailable, "stream is down") ) -type request struct { - ctx context.Context - msg *stream.Message - streaming bool - waitSendDone bool - responseChan chan<- NodeResponse[msg] - sendTime time.Time +type Request struct { + Ctx context.Context + Msg *Message + Streaming bool + WaitSendDone bool + ResponseChan chan<- response + SendTime time.Time } -type channel struct { - sendQ chan request +type Channel struct { + sendQ chan Request id uint32 // Connection lifecycle management: node close() cancels the @@ -69,7 +44,7 @@ type channel struct { // Stream lifecycle management for FIFO ordered message delivery // stream is a bidirectional stream for // sending and receiving stream.Message messages. - stream stream.Gorums_NodeStreamClient + stream Gorums_NodeStreamClient streamMut sync.Mutex streamCtx context.Context streamCancel context.CancelFunc @@ -78,33 +53,35 @@ type channel struct { // Response routing; the map holds pending requests waiting for responses. // The request contains the responseChan on which to send the response // to the caller. - responseRouters map[uint64]request + responseRouters map[uint64]Request responseMut sync.Mutex closeOnceFunc func() error } -// newChannel creates a new channel for the given node and starts the sender +// NewChannel creates a new channel for the given node and starts the sender // and receiver goroutines. // // Note that we start both goroutines even though the connection and stream // have not yet been established. This is to prevent deadlock when invoking // a call type. The sender blocks on the sendQ and the receiver waits for // the stream to become available. 
-func newChannel(parentCtx context.Context, conn *grpc.ClientConn, id uint32, sendBufferSize uint) *channel { +func NewChannel(parentCtx context.Context, conn *grpc.ClientConn, id uint32, sendBufferSize uint) *Channel { ctx, connCancel := context.WithCancel(parentCtx) - c := &channel{ - sendQ: make(chan request, sendBufferSize), + c := &Channel{ + sendQ: make(chan Request, sendBufferSize), id: id, conn: conn, connCtx: ctx, connCancel: connCancel, latency: -1 * time.Second, - responseRouters: make(map[uint64]request), + responseRouters: make(map[uint64]Request), streamReady: make(chan struct{}, 1), } c.closeOnceFunc = sync.OnceValue(func() error { // important to cancel first to stop goroutines c.connCancel() + // unblocks any pending senders/receivers + c.cancelPendingMsgs(ErrNodeClosed) if c.conn != nil { return c.conn.Close() } @@ -115,8 +92,17 @@ func newChannel(parentCtx context.Context, conn *grpc.ClientConn, id uint32, sen return c } -// close closes the channel and the underlying connection exactly once. -func (c *channel) close() error { +// NewChannelWithState creates a new Channel with a specific state for testing. +// This function should only be used in tests. +func NewChannelWithState(latency time.Duration, lastErr error) *Channel { + return &Channel{ + latency: latency, + lastError: lastErr, + } +} + +// Close closes the channel and the underlying connection exactly once. +func (c *Channel) Close() error { return c.closeOnceFunc() } @@ -124,7 +110,7 @@ func (c *channel) close() error { // receiver goroutines, and signals the receiver when the stream is ready. // gRPC automatically handles TCP connection state when creating the stream. // This method is safe for concurrent use. -func (c *channel) ensureStream() error { +func (c *Channel) ensureStream() error { if err := c.ensureConnectedNodeStream(); err != nil { return err } @@ -140,7 +126,7 @@ func (c *channel) ensureStream() error { // ensureConnectedNodeStream ensures there is an active and connected // NodeStream, or creates a new stream if one doesn't already exist. // This method is safe for concurrent use. -func (c *channel) ensureConnectedNodeStream() (err error) { +func (c *Channel) ensureConnectedNodeStream() (err error) { c.streamMut.Lock() defer c.streamMut.Unlock() // if we already have a ready connection and an active stream, do nothing @@ -149,12 +135,12 @@ func (c *channel) ensureConnectedNodeStream() (err error) { return nil } c.streamCtx, c.streamCancel = context.WithCancel(c.connCtx) - c.stream, err = stream.NewGorumsClient(c.conn).NodeStream(c.streamCtx) + c.stream, err = NewGorumsClient(c.conn).NodeStream(c.streamCtx) return err } // getStream returns the current stream, or nil if no stream is available. -func (c *channel) getStream() stream.Gorums_NodeStreamClient { +func (c *Channel) getStream() Gorums_NodeStreamClient { c.streamMut.Lock() defer c.streamMut.Unlock() return c.stream @@ -162,7 +148,7 @@ func (c *channel) getStream() stream.Gorums_NodeStreamClient { // clearStream cancels the current stream context and clears the stream reference. // This triggers reconnection on the next send attempt. -func (c *channel) clearStream() { +func (c *Channel) clearStream() { c.streamMut.Lock() c.streamCancel() c.stream = nil @@ -171,16 +157,16 @@ func (c *channel) clearStream() { // isConnected returns true if the gRPC connection is in Ready state and we have an active stream. // This method is safe for concurrent use. 
-func (c *channel) isConnected() bool {
+func (c *Channel) isConnected() bool {
 	return c.conn.GetState() == connectivity.Ready && c.getStream() != nil
 }
 
-// enqueue adds the request to the send queue and sets up response routing if needed.
+// Enqueue adds the request to the send queue and sets up response routing if needed.
 // If the node is closed, it responds with an error instead.
-func (c *channel) enqueue(req request) {
-	if req.responseChan != nil {
-		req.sendTime = time.Now()
-		msgID := req.msg.GetMessageSeqNo()
+func (c *Channel) Enqueue(req Request) {
+	if req.ResponseChan != nil {
+		req.SendTime = time.Now()
+		msgID := req.Msg.GetMessageSeqNo()
 		c.responseMut.Lock()
 		c.responseRouters[msgID] = req
 		c.responseMut.Unlock()
@@ -188,9 +174,15 @@
 	// either enqueue the request on the sendQ or respond
 	// with error if the node is closed
 	select {
+	// non-blocking fast path: a closed node must always fail the request,
+	// even if the blocking select below could still win the race to enqueue
+	case <-c.connCtx.Done():
+		c.routeResponse(req.Msg.GetMessageSeqNo(), response{NodeID: c.id, Err: ErrNodeClosed})
+		return
+	default:
+	}
+	select {
 	case <-c.connCtx.Done():
 		// the node's close() method was called: respond with error instead of enqueueing
-		c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: nodeClosedErr})
+		c.routeResponse(req.Msg.GetMessageSeqNo(), response{NodeID: c.id, Err: ErrNodeClosed})
 		return
 	case c.sendQ <- req:
 		// enqueued successfully
@@ -199,16 +191,16 @@
 
 // routeResponse routes the response to the appropriate response channel based on msgID.
 // If no matching request is found, the response is discarded.
-func (c *channel) routeResponse(msgID uint64, resp NodeResponse[msg]) {
+func (c *Channel) routeResponse(msgID uint64, resp response) {
 	c.responseMut.Lock()
 	defer c.responseMut.Unlock()
 	if req, ok := c.responseRouters[msgID]; ok {
 		if resp.Err == nil {
-			c.updateLatency(time.Since(req.sendTime))
+			c.updateLatency(time.Since(req.SendTime))
 		}
-		req.responseChan <- resp
+		req.ResponseChan <- resp
 		// delete the router if we are only expecting a single reply message
-		if !req.streaming {
+		if !req.Streaming {
 			delete(c.responseRouters, msgID)
 		}
 	}
@@ -216,33 +208,38 @@
 
 // cancelPendingMsgs cancels all pending messages by sending an error response to each
 // associated request. This is called when the stream goes down to notify all waiting calls.
-func (c *channel) cancelPendingMsgs(err error) {
+func (c *Channel) cancelPendingMsgs(err error) {
 	c.responseMut.Lock()
 	defer c.responseMut.Unlock()
 	for msgID, req := range c.responseRouters {
-		req.responseChan <- NodeResponse[msg]{NodeID: c.id, Err: err}
+		req.ResponseChan <- response{NodeID: c.id, Err: err}
 		// delete the router if we are only expecting a single reply message
-		if !req.streaming {
+		if !req.Streaming {
 			delete(c.responseRouters, msgID)
 		}
 	}
 }
 
-// deleteRouter removes the response router for the given msgID.
-// This is used for cleaning up after streaming calls are done.
-func (c *channel) deleteRouter(msgID uint64) {
-	c.responseMut.Lock()
-	defer c.responseMut.Unlock()
-	delete(c.responseRouters, msgID)
-}
-
 // sender goroutine takes requests from the sendQ and sends them on the stream.
 // If the stream is down, it tries to re-establish it. 
-func (c *channel) sender() { +func (c *Channel) sender() { + defer func() { + // drain sendQ and error all requests + for { + select { + case req := <-c.sendQ: + c.routeResponse(req.Msg.GetMessageSeqNo(), response{NodeID: c.id, Err: ErrNodeClosed}) + default: + // sendQ is empty + return + } + } + }() + // eager connect; ignored if stream is down (will be retried on send) _ = c.ensureStream() - var req request + var req Request for { select { case <-c.connCtx.Done(): @@ -252,11 +249,11 @@ func (c *channel) sender() { // take next request from sendQ } if err := c.ensureStream(); err != nil { - c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: err}) + c.routeResponse(req.Msg.GetMessageSeqNo(), response{NodeID: c.id, Err: err}) continue } if err := c.sendMsg(req); err != nil { - c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: err}) + c.routeResponse(req.Msg.GetMessageSeqNo(), response{NodeID: c.id, Err: err}) } } } @@ -264,7 +261,7 @@ func (c *channel) sender() { // receiver goroutine receives messages from the stream and routes them to // the appropriate response router. If the stream goes down, it clears the // stream reference and cancels all pending messages with a stream down error. -func (c *channel) receiver() { +func (c *Channel) receiver() { for { stream := c.getStream() if stream == nil { @@ -286,11 +283,11 @@ func (c *channel) receiver() { c.clearStream() } else { err := respMsg.ErrorStatus() - var resp msg + var resp proto.Message if err == nil { - resp, err = unmarshalResponse(respMsg) + resp, err = UnmarshalResponse(respMsg) } - c.routeResponse(respMsg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Value: resp, Err: err}) + c.routeResponse(respMsg.GetMessageSeqNo(), response{NodeID: c.id, Value: resp, Err: err}) } select { @@ -302,7 +299,7 @@ func (c *channel) receiver() { } } -func (c *channel) sendMsg(req request) (err error) { +func (c *Channel) sendMsg(req Request) (err error) { defer func() { // For one-way call types (Unicast/Multicast), the caller can choose between two behaviors: // @@ -318,20 +315,20 @@ func (c *channel) sendMsg(req request) (err error) { // // Note: Two-way call types (RPCCall, QuorumCall) do not use this mechanism, they always // wait for actual server responses, so waitSendDone is false for them. - if req.waitSendDone && err == nil { + if req.WaitSendDone && err == nil { // Send succeeded: unblock the caller and clean up the responseRouter - c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{}) + c.routeResponse(req.Msg.GetMessageSeqNo(), response{}) } }() // don't send if context is already cancelled. - if req.ctx.Err() != nil { - return req.ctx.Err() + if req.Ctx.Err() != nil { + return req.Ctx.Err() } stream := c.getStream() if stream == nil { - return streamDownErr + return ErrStreamDown } done := make(chan struct{}) @@ -345,7 +342,7 @@ func (c *channel) sendMsg(req request) (err error) { select { case <-done: // all is good - case <-req.ctx.Done(): + case <-req.Ctx.Done(): // Both channels could be ready at the same time, so we must check 'done' again. 
 			select {
 			case <-done:
@@ -357,7 +354,7 @@
 		}
 	}()
 
-	if err = stream.Send(req.msg); err != nil {
+	if err = stream.Send(req.Msg); err != nil {
 		c.setLastErr(err)
 		c.clearStream()
 	}
@@ -366,21 +363,21 @@
 	return err
 }
 
-func (c *channel) setLastErr(err error) {
+func (c *Channel) setLastErr(err error) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	c.lastError = err
 }
 
-// lastErr returns the last error encountered (if any) when using this channel.
-func (c *channel) lastErr() error {
+// LastErr returns the last error encountered (if any) when using this channel.
+func (c *Channel) LastErr() error {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	return c.lastError
 }
 
-// channelLatency returns the latency between the client and the server associated with this channel.
-func (c *channel) channelLatency() time.Duration {
+// Latency returns the latency between the client and the server associated with this channel.
+func (c *Channel) Latency() time.Duration {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	return c.latency
@@ -388,7 +385,7 @@
 
 // updateLatency updates the latency between the client and the server associated with this channel.
 // It uses a simple moving average to calculate the latency.
-func (c *channel) updateLatency(rtt time.Duration) {
+func (c *Channel) updateLatency(rtt time.Duration) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	if c.latency < 0 {
diff --git a/channel_test.go b/internal/stream/channel_test.go
similarity index 52%
rename from channel_test.go
rename to internal/stream/channel_test.go
index ec7aee29d..39f5b4963 100644
--- a/channel_test.go
+++ b/internal/stream/channel_test.go
@@ -1,146 +1,210 @@
-package gorums
+package stream
 
 import (
 	"context"
 	"errors"
 	"fmt"
+	"net"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
-	"github.com/relab/gorums/internal/stream"
 	"github.com/relab/gorums/internal/testutils/mock"
 	"google.golang.org/grpc"
-	pb "google.golang.org/protobuf/types/known/wrapperspb"
+	"google.golang.org/grpc/credentials/insecure"
 )
 
 const (
 	defaultTestTimeout   = 3 * time.Second
-	streamConnectTimeout = 500 * time.Millisecond
+	streamConnectTimeout = 3 * time.Second
 )
 
-// waitForConnection polls until the node is connected or timeout expires.
-// Returns true if connected, false if timeout expired.
-func waitForConnection(node *Node, timeout time.Duration) bool {
-	deadline := time.Now().Add(timeout)
-	for time.Now().Before(deadline) {
-		if node.channel.isConnected() {
-			return true
+// testChannel bundles the channel under test with its backing gRPC server and listener.
+type testChannel struct {
+	*Channel
+	srv *grpc.Server
+	lis net.Listener
+}
+
+// echoServer is a generic stream handler that echoes back every message. 
+func echoServer(stream Gorums_NodeStreamServer) error {
+	for {
+		in, err := stream.Recv()
+		if err != nil {
+			return err
+		}
+		// Echo back
+		if err := stream.Send(in); err != nil {
+			return err
 		}
-		time.Sleep(time.Millisecond)
 	}
-	return node.channel.isConnected()
 }
 
-type mockSrv struct{}
+// delayServer returns a stream handler that delays each echoed message by the given duration.
+func delayServer(delay time.Duration) func(stream Gorums_NodeStreamServer) error {
+	return func(stream Gorums_NodeStreamServer) error {
+		for {
+			in, err := stream.Recv()
+			if err != nil {
+				return err
+			}
+			time.Sleep(delay)
+			if err := stream.Send(in); err != nil {
+				return err
+			}
+		}
+	}
+}
 
-func (mockSrv) Test(_ ServerCtx, req *pb.StringValue) (*pb.StringValue, error) {
-	return pb.String(req.GetValue() + "-mocked-"), nil
+// breakStreamServer echoes the first message and then drops the stream with an error.
+func breakStreamServer(stream Gorums_NodeStreamServer) error {
+	msg, err := stream.Recv()
+	if err != nil {
+		return err
+	}
+	if err := stream.Send(msg); err != nil {
+		return err
+	}
+	return errors.New("stream broken")
 }
 
-// delayServerFn returns a server function that delays responses by the given duration.
-func delayServerFn(delay time.Duration) func(_ int) ServerIface {
-	return func(_ int) ServerIface {
-		mockSrv := &mockSrv{}
-		srv := NewServer()
-		srv.RegisterHandler(mock.TestMethod, func(ctx ServerCtx, in *Message) (*Message, error) {
-			// Simulate slow processing
-			time.Sleep(delay)
-			req := AsProto[*pb.StringValue](in)
-			resp, err := mockSrv.Test(ctx, req)
-			if err != nil {
-				return nil, err
-			}
-			return NewResponseMessage(in, resp), nil
-		})
-		return srv
+// holdServer hangs, effectively blocking the stream until context cancellation.
+func holdServer(stream Gorums_NodeStreamServer) error {
+	<-stream.Context().Done()
+	return nil
+}
+
+// setupChannel creates a channel connected to a server.
+func setupChannel(t testing.TB, serverFn func(Gorums_NodeStreamServer) error, opts ...grpc.ServerOption) *testChannel {
+	t.Helper()
+
+	// Start listener
+	lis, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("failed to listen: %v", err)
+	}
+
+	// Start server
+	srv := grpc.NewServer(opts...)
+	if serverFn == nil {
+		t.Fatal("setupChannel: serverFn must be provided; use echoServer for default behavior")
+	}
+	RegisterGorumsServer(srv, &mockServer{handler: serverFn})
+	go srv.Serve(lis)
+
+	// Create channel
+	conn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+
+	c := NewChannel(t.Context(), conn, 1, 10)
+	tc := &testChannel{
+		Channel: c,
+		srv:     srv,
+		lis:     lis,
+	}
+
+	t.Cleanup(func() {
+		c.Close()
+		srv.Stop()
+		lis.Close()
+	})
+	return tc
+}
+
+type mockServer struct {
+	UnimplementedGorumsServer
+	handler func(Gorums_NodeStreamServer) error
+}
+
+func (s *mockServer) NodeStream(srv Gorums_NodeStreamServer) error {
+	return s.handler(srv)
+}
+
+// setupChannelWithoutServer creates a channel that tries to connect to a non-existent server. 
+func setupChannelWithoutServer(t testing.TB) *testChannel {
+	t.Helper()
+	// Use a fixed loopback port with no listener, so connection attempts fail
+	conn, err := grpc.NewClient("127.0.0.1:54321", grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	ctx, cancel := context.WithCancel(context.Background())
+	c := NewChannel(ctx, conn, 1, 10)
+	t.Cleanup(func() {
+		cancel()
+		c.Close()
+	})
+	return &testChannel{
+		Channel: c,
+	}
+}
+
+// waitForConnection polls until the channel is connected or the timeout expires.
+// Returns true if connected, false if the timeout expired.
+func waitForConnection(c *Channel, timeout time.Duration) bool {
+	deadline := time.Now().Add(timeout)
+	for time.Now().Before(deadline) {
+		if c.isConnected() {
+			return true
+		}
+		time.Sleep(10 * time.Millisecond)
 	}
+	return c.isConnected()
 }
 
-func sendRequest(t testing.TB, node *Node, req request, msgID uint64) NodeResponse[msg] {
+func sendRequest(t testing.TB, c *Channel, req Request, msgID uint64) response {
 	t.Helper()
-	if req.ctx == nil {
-		req.ctx = t.Context()
+	if req.Ctx == nil {
+		req.Ctx = context.Background()
 	}
-	reqMsg, err := stream.NewMessage(req.ctx, msgID, mock.TestMethod, nil)
+	reqMsg, err := NewMessage(req.Ctx, msgID, mock.TestMethod, nil)
 	if err != nil {
 		t.Fatalf("NewMessage failed: %v", err)
 	}
-	req.msg = reqMsg
-	replyChan := make(chan NodeResponse[msg], 1)
-	req.responseChan = replyChan
-	node.channel.enqueue(req)
+	req.Msg = reqMsg
+	replyChan := make(chan response, 1)
+	req.ResponseChan = replyChan
+	c.Enqueue(req)
 	select {
 	case resp := <-replyChan:
 		return resp
 	case <-time.After(defaultTestTimeout):
 		t.Fatalf("timeout waiting for response to message %d", msgID)
-		return NodeResponse[msg]{}
+		return response{}
 	}
 }
 
 type msgResponse struct {
 	msgID uint64
-	resp  NodeResponse[msg]
+	resp  response
 }
 
-func sendReq(t testing.TB, results chan<- msgResponse, node *Node, goroutineID, msgsToSend int, req request) {
+func sendReq(t testing.TB, results chan<- msgResponse, c *Channel, goroutineID, msgsToSend int, req Request) {
 	for j := range msgsToSend {
 		msgID := uint64(goroutineID*1000 + j)
-		resp := sendRequest(t, node, req, msgID)
+		resp := sendRequest(t, c, req, msgID)
 		results <- msgResponse{msgID: msgID, resp: resp}
 	}
 }
 
-// testNodeWithoutServer creates a node for the given server address and
-// adds it to a new manager. This is useful for testing node and channel
-// behavior without an active server. The manager is automatically closed
-// when the test finishes.
-func testNodeWithoutServer(t testing.TB, opts ...ManagerOption) *Node {
-	t.Helper()
-	mgrOpts := append([]ManagerOption{InsecureDialOptions(t)}, opts...)
-	mgr := NewManager(mgrOpts...)
-	t.Cleanup(Closer(t, mgr))
-	// Use a high port number that's unlikely to have anything listening.
-	// We use a fixed ID for simplicity. 
-	node, err := mgr.newNode("127.0.0.1:59999", 1)
-	if err != nil {
-		t.Fatal(err)
-	}
-	return node
-}
-
-// Helper functions for accessing channel internals
-
-func routerExists(node *Node, msgID uint64) bool {
-	node.channel.responseMut.Lock()
-	defer node.channel.responseMut.Unlock()
-	_, exists := node.channel.responseRouters[msgID]
-	return exists
-}
-
-func getStream(node *Node) grpc.ClientStream {
-	return node.channel.getStream()
-}
-
 func TestChannelCreation(t *testing.T) {
-	node := testNodeWithoutServer(t)
+	tc := setupChannelWithoutServer(t)
 	// send message when server is down
-	resp := sendRequest(t, node, request{waitSendDone: true}, 1)
+	resp := sendRequest(t, tc.Channel, Request{WaitSendDone: true}, 1)
 	if resp.Err == nil {
-		t.Error("response err: got , want error")
+		t.Error("response err: got nil, want error")
 	}
 }
 
-// TestChannelShutdown verifies proper cleanup when channel is closed.
 func TestChannelShutdown(t *testing.T) {
-	node := TestNode(t, delayServerFn(0))
-	// Wait for stream to be established
-	if !waitForConnection(node, streamConnectTimeout) {
-		t.Fatal("node should be connected")
+	tc := setupChannel(t, echoServer)
+
+	if !waitForConnection(tc.Channel, streamConnectTimeout) {
+		t.Fatal("channel should be connected")
 	}
 
 	// enqueue several messages to confirm normal operation
@@ -148,7 +212,7 @@
 	var wg sync.WaitGroup
 	for i := range numMessages {
 		wg.Go(func() {
-			resp := sendRequest(t, node, request{}, uint64(i))
+			resp := sendRequest(t, tc.Channel, Request{}, uint64(i))
 			if resp.Err != nil {
 				t.Errorf("unexpected error for message %d, got error: %v", i, resp.Err)
 			}
@@ -156,76 +220,49 @@
 	}
 	wg.Wait()
 
-	// shut down the node's channel
-	if err := node.close(); err != nil {
-		t.Errorf("error closing node: %v", err)
+	// shut down the channel
+	if err := tc.Close(); err != nil {
+		t.Errorf("error closing channel: %v", err)
 	}
 
-	// try to send a message after node closure
-	resp := sendRequest(t, node, request{}, 999)
+	// try to send a message after closure
+	resp := sendRequest(t, tc.Channel, Request{}, 999)
 	if resp.Err == nil {
 		t.Error("expected error when sending to closed channel")
 	} else if !strings.Contains(resp.Err.Error(), "node closed") {
 		t.Errorf("expected 'node closed' error, got: %v", resp.Err)
 	}
 
-	if node.channel.isConnected() {
+	if tc.isConnected() {
 		t.Error("channel should not be connected after close")
 	}
 }
 
-// TestChannelSendBufferSize verifies that setting the send buffer size option
-// works as expected (at least that it doesn't panic and channels are functional).
-func TestChannelSendBufferSize(t *testing.T) {
-	bufferSizes := []uint{0, 1, 10, 100, 1024}
-
-	for _, size := range bufferSizes {
-		t.Run(string(rune(size)), func(t *testing.T) {
-			node := TestNode(t, EchoServerFn, WithSendBufferSize(size))
-
-			ctx := TestContext(t, time.Second)
-			_, err := RPCCall[*pb.StringValue, *pb.StringValue](node.Context(ctx), pb.String("test"), mock.TestMethod)
-			if err != nil {
-				t.Errorf("RPCCall failed with buffer size %d: %v", size, err)
-			}
-		})
-	}
-}
-
-// TestChannelLatency verifies that the channel latency is updated after requests. 
func TestChannelLatency(t *testing.T) { const minDelay = 20 * time.Millisecond - node := TestNode(t, delayServerFn(minDelay)) + tc := setupChannel(t, delayServer(minDelay)) // Initial latency should be -1 - if latency := node.Latency(); latency != -1*time.Second { + if latency := tc.Latency(); latency != -1*time.Second { t.Errorf("Initial latency = %v, expected -1s", latency) } - // Send a few requests to update latency; we need a few samples for the - // exponential weighted moving average to stabilize to the real RTT. + // Send a few requests to update latency for i := range 10 { - ctx := TestContext(t, time.Second) - _, err := RPCCall[*pb.StringValue, *pb.StringValue](node.Context(ctx), pb.String("ping"), mock.TestMethod) - if err != nil { - t.Fatalf("RPCCall %d failed: %v", i, err) - } + sendRequest(t, tc.Channel, Request{WaitSendDone: false}, uint64(i)) } - // Latency should be positive and roughly > 20ms - latency := node.Latency() + latency := tc.Latency() if latency <= 0 { t.Errorf("Latency = %v, expected > 0", latency) } if latency < minDelay { t.Errorf("Latency = %v, expected >= %v (server delay)", latency, minDelay) } - t.Logf("Measured latency: %v", latency) } -// TestChannelSendCompletionWaiting verifies the behavior of send completion waiting. func TestChannelSendCompletionWaiting(t *testing.T) { - node := TestNode(t, delayServerFn(0)) + tc := setupChannel(t, echoServer) tests := []struct { name string @@ -237,7 +274,7 @@ func TestChannelSendCompletionWaiting(t *testing.T) { for i, tt := range tests { t.Run(tt.name, func(t *testing.T) { start := time.Now() - resp := sendRequest(t, node, request{waitSendDone: tt.waitSendDone}, uint64(i)) + resp := sendRequest(t, tc.Channel, Request{WaitSendDone: tt.waitSendDone}, uint64(i)) elapsed := time.Since(start) if resp.Err != nil { t.Errorf("unexpected error: %v", resp.Err) @@ -247,66 +284,52 @@ func TestChannelSendCompletionWaiting(t *testing.T) { } } -// TestChannelErrors verifies error detection and handling in various scenarios. 
func TestChannelErrors(t *testing.T) { tests := []struct { name string - setup func(t *testing.T) *Node + setup func(t *testing.T) *testChannel wantErr string }{ { name: "EnqueueWithoutServer", - setup: func(t *testing.T) *Node { - return testNodeWithoutServer(t) + setup: func(t *testing.T) *testChannel { + return setupChannelWithoutServer(t) }, wantErr: "connection error", }, { name: "EnqueueToClosedChannel", - setup: func(t *testing.T) *Node { - node := testNodeWithoutServer(t) - err := node.close() - if err != nil { - t.Errorf("failed to close node: %v", err) + setup: func(t *testing.T) *testChannel { + tc := setupChannelWithoutServer(t) + if err := tc.Close(); err != nil { + t.Errorf("failed to close channel: %v", err) } - return node - }, - wantErr: "node closed", - }, - { - name: "EnqueueToServerWithClosedNode", - setup: func(t *testing.T) *Node { - node := TestNode(t, delayServerFn(0)) - err := node.close() - if err != nil { - t.Errorf("failed to close node: %v", err) - } - return node + return tc }, wantErr: "node closed", }, { name: "ServerFailureDuringCommunication", - setup: func(t *testing.T) *Node { - var stopServer func(...int) - node := TestNode(t, delayServerFn(0), WithStopFunc(t, &stopServer)) - resp := sendRequest(t, node, request{waitSendDone: true}, 1) + setup: func(t *testing.T) *testChannel { + tc := setupChannel(t, echoServer) + // Send a message to ensure connection is established + resp := sendRequest(t, tc.Channel, Request{WaitSendDone: true}, 1) if resp.Err != nil { - t.Errorf("first message should succeed, got error: %v", resp.Err) + t.Errorf("initial message send should succeed, got error: %v", resp.Err) } - stopServer() - return node + // Stop the server to simulate failure + tc.srv.Stop() + return tc }, wantErr: "connection error", }, } for i, tt := range tests { t.Run(tt.name, func(t *testing.T) { - node := tt.setup(t) + tc := tt.setup(t) time.Sleep(100 * time.Millisecond) - // Send message and verify error - resp := sendRequest(t, node, request{waitSendDone: true}, uint64(i)) + resp := sendRequest(t, tc.Channel, Request{WaitSendDone: true}, uint64(i)) if resp.Err == nil { t.Errorf("expected error containing %q but got nil", tt.wantErr) } else if !strings.Contains(resp.Err.Error(), tt.wantErr) { @@ -319,15 +342,15 @@ func TestChannelErrors(t *testing.T) { // TestChannelEnsureStream verifies that ensureStream correctly manages stream lifecycle. 
 func TestChannelEnsureStream(t *testing.T) {
-	// Helper to prepare a fresh node with no stream
-	newNodeWithoutStream := func(t *testing.T) *Node {
-		node := TestNode(t, delayServerFn(0))
+	// Helper to prepare a fresh channel with no stream
+	newChannelWithoutStream := func(t *testing.T) *testChannel {
+		tc := setupChannel(t, echoServer)
 		// ensure sender and receiver goroutines are stopped
-		node.channel.connCancel()
+		tc.connCancel()
 		// Extract grpc.ClientConn from existing channel
-		conn := node.channel.conn
+		conn := tc.conn
 		// Create new channel with test context without metadata (real implementation captures metadata)
-		node.channel = newChannel(t.Context(), conn, node.id, 10)
-		return node
+		tc.Channel = NewChannel(t.Context(), conn, tc.id, 10)
+		return tc
 	}
 
 	// Helper to verify stream expectations
@@ -348,18 +371,18 @@
 
 	tests := []struct {
 		name     string
-		setup    func(t *testing.T) *Node
-		action   func(node *Node) (first, second grpc.ClientStream)
+		setup    func(t *testing.T) *testChannel
+		action   func(tc *testChannel) (first, second grpc.ClientStream)
 		wantSame bool
 	}{
 		{
 			name:  "UnconnectedNodeHasNoStream",
-			setup: func(t *testing.T) *Node { return testNodeWithoutServer(t) },
-			action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) {
-				if err := node.channel.ensureStream(); err == nil {
+			setup: func(t *testing.T) *testChannel { return setupChannelWithoutServer(t) },
+			action: func(tc *testChannel) (grpc.ClientStream, grpc.ClientStream) {
+				if err := tc.ensureStream(); err == nil {
 					t.Error("ensureStream succeeded unexpectedly")
 				}
-				if getStream(node) != nil {
+				if tc.getStream() != nil {
 					t.Error("stream should be nil")
 				}
 				return nil, nil
 			},
 		},
 		{
 			name:  "CreatesStreamWhenConnected",
-			setup: newNodeWithoutStream,
-			action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) {
-				if err := node.channel.ensureStream(); err != nil {
+			setup: newChannelWithoutStream,
+			action: func(tc *testChannel) (grpc.ClientStream, grpc.ClientStream) {
+				if err := tc.ensureStream(); err != nil {
 					t.Errorf("ensureStream failed: %v", err)
 				}
-				return getStream(node), nil
+				return tc.getStream(), nil
 			},
 		},
 		{
 			name:  "RepeatedCallsReturnSameStream",
-			setup: newNodeWithoutStream,
-			action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) {
-				if err := node.channel.ensureStream(); err != nil {
+			setup: newChannelWithoutStream,
+			action: func(tc *testChannel) (grpc.ClientStream, grpc.ClientStream) {
+				if err := tc.ensureStream(); err != nil {
 					t.Errorf("first ensureStream failed: %v", err)
 				}
-				first := getStream(node)
-				if err := node.channel.ensureStream(); err != nil {
+				first := tc.getStream()
+				if err := tc.ensureStream(); err != nil {
 					t.Errorf("second ensureStream failed: %v", err)
 				}
-				return first, getStream(node)
+				return first, tc.getStream()
 			},
 			wantSame: true,
 		},
 		{
 			name:  "StreamDisconnectionCreatesNewStream",
-			setup: newNodeWithoutStream,
-			action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) {
-				if err := node.channel.ensureStream(); err != nil {
+			setup: newChannelWithoutStream,
+			action: func(tc *testChannel) (grpc.ClientStream, grpc.ClientStream) {
+				if err := tc.ensureStream(); err != nil {
 					t.Errorf("initial ensureStream failed: %v", err)
 				}
-				first := getStream(node)
-				node.channel.clearStream()
-				if err := node.channel.ensureStream(); err != nil {
+				first := tc.getStream()
+				tc.clearStream()
+				if err := tc.ensureStream(); err != nil {
 					t.Errorf("ensureStream after disconnect failed: %v", err)
 				}
-				return first, getStream(node) 
+ return first, tc.getStream() }, wantSame: false, }, @@ -410,55 +433,72 @@ func TestChannelEnsureStream(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - node := tt.setup(t) - first, second := tt.action(node) + tc := tt.setup(t) + first, second := tt.action(tc) cmpStream(t, first, second, tt.wantSame) }) } } +func TestChannelEnsureStreamAfterBroken(t *testing.T) { + tc := setupChannel(t, echoServer) + + // Ensure we have a stream + if err := tc.ensureStream(); err != nil { + t.Fatalf("ensureStream failed: %v", err) + } + + // Break the stream + tc.clearStream() + + // Ensure we can get it back + if err := tc.ensureStream(); err != nil { + t.Fatalf("ensureStream failed after clear: %v", err) + } +} + // TestChannelConnectionState verifies connection state detection and behavior. func TestChannelConnectionState(t *testing.T) { tests := []struct { name string - setup func(t *testing.T) *Node + setup func(t *testing.T) *testChannel wantConnected bool }{ { name: "WithoutServer", - setup: func(t *testing.T) *Node { return testNodeWithoutServer(t) }, + setup: func(t *testing.T) *testChannel { return setupChannelWithoutServer(t) }, wantConnected: false, }, { name: "WithLiveServer", - setup: func(t *testing.T) *Node { return TestNode(t, delayServerFn(0)) }, + setup: func(t *testing.T) *testChannel { return setupChannel(t, echoServer) }, wantConnected: true, }, { name: "RequiresBothReadyAndStream", - setup: func(t *testing.T) *Node { - node := TestNode(t, delayServerFn(0)) + setup: func(t *testing.T) *testChannel { + tc := setupChannel(t, echoServer) // Wait for stream to be established - if !waitForConnection(node, streamConnectTimeout) { + if !waitForConnection(tc.Channel, streamConnectTimeout) { t.Fatal("node should be connected before clearing stream") } - node.channel.clearStream() - return node + tc.clearStream() + return tc }, wantConnected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - node := tt.setup(t) + tc := tt.setup(t) if tt.wantConnected { // For tests expecting connection, poll until connected or timeout - if !waitForConnection(node, streamConnectTimeout) { + if !waitForConnection(tc.Channel, streamConnectTimeout) { t.Errorf("isConnected() = false, want true") } } else { // For tests expecting no connection, verify immediately - if node.channel.isConnected() { + if tc.isConnected() { t.Errorf("isConnected() = true, want false") } } @@ -466,41 +506,6 @@ func TestChannelConnectionState(t *testing.T) { } } -// TestChannelConcurrentSends tests sending multiple messages concurrently from multiple goroutines. 
-func TestChannelConcurrentSends(t *testing.T) { - node := TestNode(t, delayServerFn(0)) - - const numMessages = 1000 - const numGoroutines = 10 - msgsPerGoroutine := numMessages / (2 * numGoroutines) - - results := make(chan msgResponse, numMessages) - for goID := range numGoroutines { - go func() { - sendReq(t, results, node, goID, msgsPerGoroutine, request{waitSendDone: true}) - sendReq(t, results, node, goID, msgsPerGoroutine, request{waitSendDone: false}) - }() - } - - var errs []error - for range numMessages { - res := <-results - if res.resp.Err != nil { - errs = append(errs, res.resp.Err) - } - } - - if len(errs) > 0 { - t.Errorf("got %d errors during concurrent sends (first few): %v", len(errs), errs[:min(3, len(errs))]) - } - if !node.channel.isConnected() { - t.Error("node should still be connected after concurrent sends") - } - if node.mgr == nil { - t.Error("manager should not be nil") - } -} - func TestChannelContext(t *testing.T) { // Helper context setup functions cancelledContext := func(ctx context.Context) (context.Context, context.CancelFunc) { @@ -559,8 +564,8 @@ func TestChannelContext(t *testing.T) { ctx, cancel := tt.contextSetup(t.Context()) t.Cleanup(cancel) - node := TestNode(t, delayServerFn(tt.serverDelay)) - resp := sendRequest(t, node, request{ctx: ctx, waitSendDone: tt.waitSendDone}, uint64(i)) + tc := setupChannel(t, delayServer(tt.serverDelay)) + resp := sendRequest(t, tc.Channel, Request{Ctx: ctx, WaitSendDone: tt.waitSendDone}, uint64(i)) if !errors.Is(resp.Err, tt.wantErr) { t.Errorf("expected %v, got: %v", tt.wantErr, resp.Err) } @@ -568,88 +573,77 @@ func TestChannelContext(t *testing.T) { } } -// TestChannelDeadlock reproduces a deadlock bug (issue #235) that occurred -// in channel.go when the stream broke during active communication. -// -// Root Cause: -// The receiver goroutine held a read lock while performing a blocking I/O operation -// that could hang indefinitely when the stream broke. Meanwhile, the sender goroutine -// tried to acquire a write lock to reconnect, creating a deadlock. -// -// This test verifies the fix by: -// 1. Establishing a connection and activating the stream -// 2. Breaking the stream by stopping the server -// 3. Sending multiple messages concurrently to trigger the deadlock condition -// 4. Verifying all goroutines can successfully enqueue without hanging -func TestChannelDeadlock(t *testing.T) { - node := TestNode(t, delayServerFn(0)) - // Wait for the stream to be established - if !waitForConnection(node, streamConnectTimeout) { - t.Fatal("node should be connected") +// TestChannelStreamReadySignaling verifies that the receiver goroutine is properly notified +// when a stream becomes available. 
+func TestChannelStreamReadySignaling(t *testing.T) {
+	tc := setupChannel(t, echoServer)
+
+	start := time.Now()
+	resp := sendRequest(t, tc.Channel, Request{}, 1)
+	firstLatency := time.Since(start)
+
+	if resp.Err != nil {
+		t.Fatalf("unexpected error on first request: %v", resp.Err)
 	}
-	// Send a message to activate the stream
-	sendRequest(t, node, request{waitSendDone: true}, 1)
+	start = time.Now()
+	resp = sendRequest(t, tc.Channel, Request{}, 2)
+	secondLatency := time.Since(start)
 
-	// Break the stream, forcing a reconnection on next send
-	node.channel.clearStream()
-	time.Sleep(20 * time.Millisecond)
+	if resp.Err != nil {
+		t.Fatalf("unexpected error on second request: %v", resp.Err)
+	}
 
-	// Send multiple messages concurrently when stream is broken with the
-	// goal to trigger a deadlock between sender and receiver goroutines.
-	doneChan := make(chan bool, 10)
-	for id := range 10 {
-		go func() {
-			ctx := TestContext(t, 3*time.Second)
-			reqMsg, _ := stream.NewMessage(ctx, uint64(100+id), mock.TestMethod, nil)
-			req := request{ctx: ctx, msg: reqMsg}
+	t.Logf("first request latency: %v", firstLatency)
+	t.Logf("second request latency: %v", secondLatency)
 
-			// try to enqueue
-			select {
-			case node.channel.sendQ <- req:
-				// successfully enqueued
-				doneChan <- true
-			case <-ctx.Done():
-				// timed out trying to enqueue (deadlock!)
-				doneChan <- false
-			}
-		}()
+	const maxAcceptableLatency = 100 * time.Millisecond
+	if firstLatency > maxAcceptableLatency {
+		t.Errorf("first request took %v, expected < %v", firstLatency, maxAcceptableLatency)
 	}
+}
 
-	// Wait for all goroutines to complete
-	timeout := time.After(5 * time.Second)
-	successful := 0
-	for completed := range 10 {
-		select {
-		case success := <-doneChan:
-			if success {
-				successful++
-			}
-		case <-timeout:
-			// remaining goroutines are stuck trying to enqueue.
-			t.Fatalf("DEADLOCK: Only %d/10 goroutines completed (%d successful).", completed, successful)
-		}
+// TestChannelStreamReadyAfterReconnect verifies that the receiver is properly notified
+// when a stream is re-established after being cleared (simulating reconnection).
+func TestChannelStreamReadyAfterReconnect(t *testing.T) {
+	tc := setupChannel(t, echoServer)
+
+	// Wait for initial connection
+	if !waitForConnection(tc.Channel, streamConnectTimeout) {
+		t.Fatal("channel should be connected")
 	}
 
-	// If we reach here, all 10 goroutines completed (but some may have failed to enqueue)
-	if successful < 10 {
-		t.Fatalf("DEADLOCK: %d/10 goroutines timed out trying to enqueue (sendQ blocked)", 10-successful)
+	// Consume initial signal
+	select {
+	case <-tc.streamReady:
+	case <-time.After(time.Second):
+		t.Fatal("timeout waiting for initial streamReady")
+	}
+
+	// Force reconnect
+	tc.clearStream()
+	// Re-establish the stream directly; a send would also trigger ensureStream
+	if err := tc.ensureStream(); err != nil {
+		t.Fatalf("ensureStream after clearStream failed: %v", err)
+	}
+
+	// Should get a new signal
+	select {
+	case <-tc.streamReady:
+	case <-time.After(time.Second):
+		t.Fatal("timeout waiting for streamReady after reconnect")
 	}
 }
 
-// TestChannelRouterLifecycle tests router creation, persistence, and cleanup behavior. 
func TestChannelRouterLifecycle(t *testing.T) { - node := TestNode(t, delayServerFn(0)) - // Wait for the stream to be established - if !waitForConnection(node, streamConnectTimeout) { - t.Fatal("node should be connected") + tc := setupChannel(t, echoServer) + + if !waitForConnection(tc.Channel, streamConnectTimeout) { + t.Fatal("channel should be connected") } tests := []struct { name string waitSendDone bool streaming bool - afterSend func(t *testing.T, node *Node, msgID uint64) wantRouter bool }{ {name: "WaitSendDone/NonStreamingAutoCleanup", waitSendDone: true, streaming: false, wantRouter: false}, @@ -658,31 +652,44 @@ func TestChannelRouterLifecycle(t *testing.T) { {name: "NoSendWaiting/StreamingKeepsRouterAlive", waitSendDone: false, streaming: true, wantRouter: true}, } for i, tt := range tests { - name := fmt.Sprintf("msgID=%d/%s/streaming=%t", i, tt.name, tt.streaming) + name := fmt.Sprintf("%s/msgID=%d/streaming=%t", tt.name, i, tt.streaming) t.Run(name, func(t *testing.T) { msgID := uint64(i) - resp := sendRequest(t, node, request{waitSendDone: tt.waitSendDone, streaming: tt.streaming}, msgID) + resp := sendRequest(t, tc.Channel, Request{WaitSendDone: tt.waitSendDone, Streaming: tt.streaming}, msgID) if resp.Err != nil { t.Errorf("unexpected error: %v", resp.Err) } - if exists := routerExists(node, msgID); exists != tt.wantRouter { + if exists := routerExists(tc.Channel, msgID); exists != tt.wantRouter { t.Errorf("router exists = %v, want %v", exists, tt.wantRouter) } - node.channel.deleteRouter(msgID) // just for kicks + deleteRouter(tc.Channel, msgID) }) } } -// TestChannelResponseRouting sends multiple messages and verifies that -// responses are correctly routed to their callers. +// Helper functions for testing channel response routing and router lifecycle + +func routerExists(c *Channel, msgID uint64) bool { + c.responseMut.Lock() + defer c.responseMut.Unlock() + _, exists := c.responseRouters[msgID] + return exists +} + +func deleteRouter(c *Channel, msgID uint64) { + c.responseMut.Lock() + defer c.responseMut.Unlock() + delete(c.responseRouters, msgID) +} + func TestChannelResponseRouting(t *testing.T) { - node := TestNode(t, delayServerFn(0)) + tc := setupChannel(t, echoServer) const numMessages = 20 results := make(chan msgResponse, numMessages) for i := range numMessages { - go sendReq(t, results, node, i, 1, request{}) + go sendReq(t, results, tc.Channel, i, 1, Request{WaitSendDone: true}) } // Collect and verify results @@ -698,91 +705,105 @@ func TestChannelResponseRouting(t *testing.T) { received[result.msgID] = true } - // Verify all messages were received if len(received) != numMessages { t.Errorf("got %d unique responses, want %d", len(received), numMessages) } } -// TestChannelStreamReadySignaling verifies that the receiver goroutine is properly notified -// when a stream becomes available. This tests the stream-ready signaling mechanism -// that replaces the old time.Sleep polling approach. -func TestChannelStreamReadySignaling(t *testing.T) { - node := TestNode(t, delayServerFn(0)) +func TestChannelConcurrentSends(t *testing.T) { + tc := setupChannel(t, echoServer) - // The first request triggers stream creation. We measure how quickly - // the receiver starts processing after the stream is ready. 
-	start := time.Now()
-	resp := sendRequest(t, node, request{}, 1)
-	firstLatency := time.Since(start)
+	const numMessages = 1000
+	const numGoroutines = 10
+	msgsPerGoroutine := numMessages / (2 * numGoroutines)
 
-	if resp.Err != nil {
-		t.Fatalf("unexpected error on first request: %v", resp.Err)
+	results := make(chan msgResponse, numMessages)
+	for goID := range numGoroutines {
+		go func() {
+			sendReq(t, results, tc.Channel, goID, msgsPerGoroutine, Request{WaitSendDone: true})
+			sendReq(t, results, tc.Channel, goID, msgsPerGoroutine, Request{WaitSendDone: false})
+		}()
 	}
 
-	// Second request should be faster since stream is already established
-	start = time.Now()
-	resp = sendRequest(t, node, request{}, 2)
-	secondLatency := time.Since(start)
-
-	if resp.Err != nil {
-		t.Fatalf("unexpected error on second request: %v", resp.Err)
+	var errs []error
+	for range numMessages {
+		res := <-results
+		if res.resp.Err != nil {
+			errs = append(errs, res.resp.Err)
+		}
 	}
 
-	t.Logf("first request latency: %v", firstLatency)
-	t.Logf("second request latency: %v", secondLatency)
-
-	// The first request should complete in reasonable time (not waiting for polling timeout).
-	// With the old 10ms polling, the first request could take 10-20ms just waiting for the stream.
-	// With proper signaling, it should be much faster (sub-millisecond for the signal itself).
-	const maxAcceptableLatency = 100 * time.Millisecond // generous bound for CI environments
-	if firstLatency > maxAcceptableLatency {
-		t.Errorf("first request took %v, expected < %v (possible stream-ready polling delay)", firstLatency, maxAcceptableLatency)
+	if len(errs) > 0 {
+		t.Errorf("got %d errors during concurrent sends (first few): %v", len(errs), errs[:min(3, len(errs))])
+	}
+	if !tc.isConnected() {
+		t.Error("channel should still be connected after concurrent sends")
 	}
 }
 
-// TestChannelStreamReadyAfterReconnect verifies that the receiver is properly notified
-// when a stream is re-established after being cleared (simulating reconnection).
-func TestChannelStreamReadyAfterReconnect(t *testing.T) {
-	node := TestNode(t, delayServerFn(0))
+// TestChannelDeadlock reproduces a deadlock bug (issue #235) that occurred
+// in channel.go when the stream broke during active communication.
+//
+// Root Cause:
+// The receiver goroutine held a read lock while performing a blocking I/O operation
+// that could hang indefinitely when the stream broke. Meanwhile, the sender goroutine
+// tried to acquire a write lock to reconnect, creating a deadlock.
+//
+// This test verifies the fix by:
+// 1. Establishing a connection and activating the stream
+// 2. Breaking the stream (the server handler errors out after the first message)
+// 3. Sending multiple messages concurrently to trigger the deadlock condition
+// 4. Verifying all goroutines can successfully enqueue without hanging
+func TestChannelDeadlock(t *testing.T) {
+	tc := setupChannel(t, breakStreamServer)
 
-	// First request to establish the stream
-	resp := sendRequest(t, node, request{}, 1)
-	if resp.Err != nil {
-		t.Fatalf("unexpected error on first request: %v", resp.Err)
+	if !waitForConnection(tc.Channel, streamConnectTimeout) {
+		t.Fatal("channel should be connected")
 	}
 
-	// Clear the stream to simulate disconnection.
-	// Note: In production, clearStream is called by the receiver when it detects an error.
-	// The next request will trigger stream re-creation via ensureStream(). 
-	node.channel.clearStream()
+	// Send message to activate stream
+	sendRequest(t, tc.Channel, Request{WaitSendDone: true}, 1)
 
-	// The sender's ensureStream() will recreate the stream.
-	// We may need to retry a few times since there's a race between
-	// clearStream and the sender's stream check.
-	var reconnectLatency time.Duration
-	var lastErr error
-	start := time.Now()
-	for i := range 5 {
-		resp = sendRequest(t, node, request{}, uint64(i+2))
-		if resp.Err == nil {
-			reconnectLatency = time.Since(start)
-			break
-		}
-		lastErr = resp.Err
-		// Give the sender a chance to reconnect
-		time.Sleep(5 * time.Millisecond)
-	}
-	if resp.Err != nil {
-		t.Fatalf("failed to reconnect after 5 attempts: %v", lastErr)
-	}
+	// Break the stream, forcing a reconnection on next send
+	tc.clearStream()
+	time.Sleep(20 * time.Millisecond)
 
-	t.Logf("reconnect latency: %v", reconnectLatency)
+	// Send multiple messages concurrently when stream is broken with the
+	// goal to trigger a deadlock between sender and receiver goroutines.
+	doneChan := make(chan bool, 10)
+	for id := range 10 {
+		go func() {
+			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+			defer cancel()
+			reqMsg, _ := NewMessage(ctx, uint64(100+id), mock.TestMethod, nil)
+			req := Request{Ctx: ctx, Msg: reqMsg}
 
-	// Even after reconnection, the latency should be reasonable
-	const maxAcceptableLatency = 200 * time.Millisecond
-	if reconnectLatency > maxAcceptableLatency {
-		t.Errorf("reconnect took %v, expected < %v", reconnectLatency, maxAcceptableLatency)
+			select {
+			case tc.Channel.sendQ <- req:
+				doneChan <- true
+			case <-ctx.Done():
+				doneChan <- false
+			}
+		}()
+	}
+
+	// Wait for all goroutines to complete
+	timeout := time.After(5 * time.Second)
+	successful := 0
+	for completed := range 10 {
+		select {
+		case success := <-doneChan:
+			if success {
+				successful++
+			}
+		case <-timeout:
+			// remaining goroutines are stuck trying to enqueue.
+			t.Fatalf("DEADLOCK: Only %d/10 goroutines completed", completed)
+		}
+	}
+	// If we reach here, all 10 goroutines completed (but some may have failed to enqueue)
+	if successful < 10 {
+		t.Fatalf("DEADLOCK: %d/10 goroutines timed out", 10-successful)
 	}
 }
 
@@ -802,16 +823,16 @@ func BenchmarkChannelStreamReadyFirstRequest(b *testing.B) {
 	}
 
 	for b.Loop() {
-		var stopServer func(...int)
-		node := TestNode(b, delayServerFn(0), WithStopFunc(b, &stopServer))
+		tc := setupChannel(b, echoServer)
 
 		// Use a fresh context for the benchmark request
-		ctx := TestContext(b, defaultTestTimeout)
-		reqMsg, _ := stream.NewMessage(ctx, 1, mock.TestMethod, nil)
-		req := request{ctx: ctx, msg: reqMsg}
-		replyChan := make(chan NodeResponse[msg], 1)
-		req.responseChan = replyChan
-		node.channel.enqueue(req)
+		ctx, cancel := context.WithTimeout(b.Context(), defaultTestTimeout)
+		reqMsg, _ := NewMessage(ctx, 1, mock.TestMethod, nil)
+		req := Request{Ctx: ctx, Msg: reqMsg}
+		replyChan := make(chan response, 1)
+		req.ResponseChan = replyChan
+		tc.Enqueue(req)
 
 		select {
 		case resp := <-replyChan:
@@ -823,28 +844,8 @@
 		}
 
-		// Close the node before stopping the server to ensure clean shutdown
-		_ = node.close()
-		stopServer()
+		// Release the request context explicitly; a defer here would accumulate across b.Loop iterations
+		cancel()
+		// Close the channel before stopping the server to ensure clean shutdown
+		_ = tc.Close()
+		tc.srv.Stop()
 	}
 }
-
-// BenchmarkChannelStreamReadySubsequentRequest measures the latency of requests
-// after the stream is already established (steady-state performance). 
-func BenchmarkChannelStreamReadySubsequentRequest(b *testing.B) {
-    node := TestNode(b, delayServerFn(0))
-
-    // Warm up: establish the stream
-    resp := sendRequest(b, node, request{}, 0)
-    if resp.Err != nil {
-        b.Fatalf("warmup error: %v", resp.Err)
-    }
-
-    b.ResetTimer()
-    for i := range b.N {
-        resp := sendRequest(b, node, request{}, uint64(i+1))
-        if resp.Err != nil {
-            b.Fatalf("unexpected error: %v", resp.Err)
-        }
+        cancel() // release the per-iteration context; a defer would pile up inside b.Loop
+        _ = tc.Close()
+        tc.srv.Stop()
     }
 }
 
@@ -853,15 +854,15 @@
 // Note: This benchmark has inherent variability due to the race between
 // clearStream and the sender's ensureStream call.
 func BenchmarkChannelStreamReadyReconnect(b *testing.B) {
-    node := TestNode(b, delayServerFn(0))
+    tc := setupChannel(b, echoServer)
 
     // Establish initial stream with a fresh context
     ctx := context.Background()
-    reqMsg, _ := stream.NewMessage(ctx, 0, mock.TestMethod, nil)
-    req := request{ctx: ctx, msg: reqMsg}
-    replyChan := make(chan NodeResponse[msg], 1)
-    req.responseChan = replyChan
-    node.channel.enqueue(req)
+    reqMsg, _ := NewMessage(ctx, 0, mock.TestMethod, nil)
+    req := Request{Ctx: ctx, Msg: reqMsg}
+    replyChan := make(chan response, 1)
+    req.ResponseChan = replyChan
+    tc.Enqueue(req)
 
     select {
     case resp := <-replyChan:
@@ -874,7 +875,7 @@
     b.ResetTimer()
     for i := range b.N {
-        node.channel.clearStream()
+        tc.clearStream()
 
         // Wait a tiny bit for the receiver to notice the stream is gone
         // and be ready for the signal. This simulates real-world behavior
@@ -883,20 +884,90 @@
         // Now send a request which will trigger ensureStream -> newNodeStream -> signal
         ctx := context.Background()
-        reqMsg, _ := stream.NewMessage(ctx, uint64(i+1), mock.TestMethod, nil)
-        req := request{ctx: ctx, msg: reqMsg}
-        replyChan := make(chan NodeResponse[msg], 1)
-        req.responseChan = replyChan
-        node.channel.enqueue(req)
+        reqMsg, _ := NewMessage(ctx, uint64(i+1), mock.TestMethod, nil)
+        req := Request{Ctx: ctx, Msg: reqMsg}
+        replyChan := make(chan response, 1)
+        req.ResponseChan = replyChan
+        tc.Enqueue(req)
 
         select {
         case resp := <-replyChan:
             if resp.Err != nil {
                 // Stream down error is expected sometimes due to race; just continue
-                b.Logf("request %d error: %v", i, resp.Err)
+                // Intentionally not logged: logging here would skew the benchmark timing.
             }
         case <-time.After(500 * time.Millisecond):
             b.Fatalf("timeout on request %d", i)
         }
     }
 }
+
+func BenchmarkChannelSend(b *testing.B) {
+    tc := setupChannel(b, echoServer)
+
+    tests := []struct {
+        name string
+        size int // payload size in bytes
+    }{
+        {"100B", 100},
+        {"1KB", 1024},
+        {"10KB", 10 * 1024},
+        {"100KB", 100 * 1024},
+    }
+
+    for _, tt := range tests {
+        b.Run(tt.name, func(b *testing.B) {
+            payload := make([]byte, tt.size)
+            b.ResetTimer()
+            for i := range b.N {
+                // A fresh 1-buffered reply channel per iteration; it could be
+                // reused, since it is always drained before the next send.
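+                // With WaitSendDone set, the receive below completes as soon as
+                // the channel layer confirms the send (see Unicast), so each
+                // iteration measures send latency with one message in flight.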
+                replyChan := make(chan response, 1)
+                msg := Message_builder{
+                    MessageSeqNo: uint64(i),
+                    Method: mock.TestMethod,
+                    Payload: payload,
+                }.Build()
+                req := Request{Ctx: context.Background(), Msg: msg, WaitSendDone: true, ResponseChan: replyChan}
+                tc.Enqueue(req)
+                <-replyChan
+            }
+        })
+    }
+}
+
+var msgID atomic.Uint64
+
+func BenchmarkChannelSendParallel(b *testing.B) {
+    tc := setupChannel(b, echoServer)
+
+    tests := []struct {
+        name string
+        size int
+    }{
+        {"100B", 100},
+        {"1KB", 1024},
+        {"10KB", 10 * 1024},
+        {"100KB", 100 * 1024},
+    }
+
+    for _, tt := range tests {
+        b.Run(tt.name, func(b *testing.B) {
+            payload := make([]byte, tt.size)
+            b.ResetTimer()
+            b.RunParallel(func(pb *testing.PB) {
+                replyChan := make(chan response, 1)
+                for pb.Next() {
+                    id := msgID.Add(1)
+                    msg := Message_builder{
+                        MessageSeqNo: id,
+                        Method: mock.TestMethod,
+                        Payload: payload,
+                    }.Build()
+                    req := Request{Ctx: context.Background(), Msg: msg, WaitSendDone: true, ResponseChan: replyChan}
+                    tc.Enqueue(req)
+                    <-replyChan
+                }
+            })
+        })
+    }
+}
diff --git a/internal/stream/marshaling.go b/internal/stream/marshaling.go
new file mode 100644
index 000000000..2a70f3092
--- /dev/null
+++ b/internal/stream/marshaling.go
@@ -0,0 +1,67 @@
+package stream
+
+import (
+    "fmt"
+
+    "google.golang.org/protobuf/proto"
+    "google.golang.org/protobuf/reflect/protoreflect"
+    "google.golang.org/protobuf/reflect/protoregistry"
+)
+
+// UnmarshalRequest unmarshals the request proto message from the message.
+// It uses the method name in the message to look up the Input type from the proto registry.
+//
+// This function should only be used by internal channel operations.
+func UnmarshalRequest(in *Message) (proto.Message, error) {
+    // get method descriptor from registry
+    desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(in.GetMethod()))
+    if err != nil {
+        return nil, fmt.Errorf("gorums: could not find method descriptor for %s", in.GetMethod())
+    }
+    methodDesc := desc.(protoreflect.MethodDescriptor)
+
+    // get the request message type (Input type)
+    msgType, err := protoregistry.GlobalTypes.FindMessageByName(methodDesc.Input().FullName())
+    if err != nil {
+        return nil, fmt.Errorf("gorums: could not find message type %s", methodDesc.Input().FullName())
+    }
+    req := msgType.New().Interface()
+
+    // unmarshal message from the Message.Payload field
+    payload := in.GetPayload()
+    if len(payload) > 0 {
+        if err := proto.Unmarshal(payload, req); err != nil {
+            return nil, fmt.Errorf("gorums: could not unmarshal request: %w", err)
+        }
+    }
+    return req, nil
+}
+
+// UnmarshalResponse unmarshals the response proto message from the message.
+// It uses the method name in the message to look up the Output type from the proto registry.
+//
+// This function should only be used by internal channel operations.
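+//
+// As with UnmarshalRequest, an empty payload is not an error: the result is
+// then simply a zero-value instance of the method's Output message type.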
+func UnmarshalResponse(out *Message) (proto.Message, error) {
+    // get method descriptor from registry
+    desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(out.GetMethod()))
+    if err != nil {
+        return nil, fmt.Errorf("gorums: could not find method descriptor for %s", out.GetMethod())
+    }
+    methodDesc := desc.(protoreflect.MethodDescriptor)
+
+    // get the response message type (Output type)
+    msgType, err := protoregistry.GlobalTypes.FindMessageByName(methodDesc.Output().FullName())
+    if err != nil {
+        return nil, fmt.Errorf("gorums: could not find message type %s", methodDesc.Output().FullName())
+    }
+    resp := msgType.New().Interface()
+
+    // unmarshal message from the Message.Payload field
+    payload := out.GetPayload()
+    if len(payload) > 0 {
+        if err := proto.Unmarshal(payload, resp); err != nil {
+            return nil, fmt.Errorf("gorums: could not unmarshal response: %w", err)
+        }
+    }
+    return resp, nil
+}
diff --git a/internal/stream/response.go b/internal/stream/response.go
new file mode 100644
index 000000000..764398e88
--- /dev/null
+++ b/internal/stream/response.go
@@ -0,0 +1,20 @@
+package stream
+
+import (
+    "errors"
+
+    "google.golang.org/protobuf/proto"
+)
+
+// NodeResponse wraps a response value from the node identified by NodeID, and an error if any.
+type NodeResponse[T any] struct {
+    NodeID uint32
+    Value T
+    Err error
+}
+
+// response is a type alias for NodeResponse[proto.Message] to avoid long type names.
+type response = NodeResponse[proto.Message]
+
+// ErrTypeMismatch is returned when a response cannot be cast to the expected type.
+var ErrTypeMismatch = errors.New("response type mismatch")
diff --git a/node.go b/node.go
index 9a02e2af3..cd3b1dac3 100644
--- a/node.go
+++ b/node.go
@@ -10,6 +10,8 @@ import (
     "google.golang.org/grpc"
     "google.golang.org/grpc/metadata"
+
+    "github.com/relab/gorums/internal/stream"
 )
 
 const nilAngleString = "<nil>"
 
@@ -29,8 +31,8 @@ func (c NodeContext) Node() *Node {
 }
 
 // enqueue enqueues a request to this node's channel.
-func (c NodeContext) enqueue(req request) {
-    c.node.channel.enqueue(req)
+func (c NodeContext) enqueue(req stream.Request) {
+    c.node.channel.Enqueue(req)
 }
 
 // nextMsgID returns the next message ID from this client's manager.
@@ -47,7 +49,7 @@ type Node struct {
     mgr *Manager // only used for backward compatibility to allow Configuration.Manager()
     msgIDGen func() uint64
 
-    channel *channel
+    channel *stream.Channel
 }
 
 // Context creates a new NodeContext from the given parent context
@@ -104,14 +106,14 @@ func newNode(addr string, opts nodeOptions) (*Node, error) {
     ctx := metadata.NewOutgoingContext(context.Background(), md)
 
     // Create channel and establish gRPC node stream
-    n.channel = newChannel(ctx, conn, n.id, opts.SendBufferSize)
+    n.channel = stream.NewChannel(ctx, conn, n.id, opts.SendBufferSize)
     return n, nil
 }
 
 // close this node.
 func (n *Node) close() error {
     if n.channel != nil {
-        return n.channel.close()
+        return n.channel.Close()
     }
     return nil
 }
@@ -168,12 +170,12 @@ func (n *Node) FullString() string {
 
 // LastErr returns the last error encountered (if any) for this node.
 func (n *Node) LastErr() error {
-    return n.channel.lastErr()
+    return n.channel.LastErr()
 }
 
 // Latency returns the latency between the client and this node.
 func (n *Node) Latency() time.Duration {
-    return n.channel.channelLatency()
+    return n.channel.Latency()
 }
 
 type lessFunc func(n1, n2 *Node) bool
@@ -251,7 +253,7 @@ var Port = func(n1, n2 *Node) bool {
 
 // LastNodeError sorts nodes by their LastErr() status in increasing order. A
 // node with LastErr() != nil is larger than a node with LastErr() == nil.
 var LastNodeError = func(n1, n2 *Node) bool {
-    if n1.channel.lastErr() != nil && n2.channel.lastErr() == nil {
+    if n1.channel.LastErr() != nil && n2.channel.LastErr() == nil {
         return false
     }
     return true
diff --git a/node_test.go b/node_test.go
index 3a2604116..0c365d0c1 100644
--- a/node_test.go
+++ b/node_test.go
@@ -5,37 +5,39 @@ import (
     "fmt"
     "testing"
     "time"
+
+    "github.com/relab/gorums/internal/stream"
 )
 
 func TestNodeSort(t *testing.T) {
     nodes := []*Node{
         {
             id: 100,
-            channel: &channel{
-                lastError: nil,
-                latency: time.Second,
-            },
+            channel: stream.NewChannelWithState(
+                time.Second,
+                nil,
+            ),
         },
         {
             id: 101,
-            channel: &channel{
-                lastError: errors.New("some error"),
-                latency: 250 * time.Millisecond,
-            },
+            channel: stream.NewChannelWithState(
+                250*time.Millisecond,
+                errors.New("some error"),
+            ),
         },
         {
             id: 42,
-            channel: &channel{
-                lastError: nil,
-                latency: 300 * time.Millisecond,
-            },
+            channel: stream.NewChannelWithState(
+                300*time.Millisecond,
+                nil,
+            ),
         },
         {
             id: 99,
-            channel: &channel{
-                lastError: errors.New("some error"),
-                latency: 500 * time.Millisecond,
-            },
+            channel: stream.NewChannelWithState(
+                500*time.Millisecond,
+                errors.New("some error"),
+            ),
         },
     }
diff --git a/responses.go b/responses.go
index a6d7bf87f..dea89e512 100644
--- a/responses.go
+++ b/responses.go
@@ -5,11 +5,34 @@ import (
     "iter"
 
     "google.golang.org/protobuf/proto"
+
+    "github.com/relab/gorums/internal/stream"
 )
 
 // msg is a type alias for proto.Message intended to be used as a type parameter.
 type msg = proto.Message
 
+// NodeResponse is a type alias for stream.NodeResponse.
+type NodeResponse[T any] = stream.NodeResponse[T]
+
+// mapToCallResponse converts a NodeResponse[msg] to a NodeResponse[Resp].
+// This is necessary because the channel layer's response router returns a
+// NodeResponse[msg], while the calltype expects a NodeResponse[Resp].
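+//
+// If the response value cannot be type-asserted to Resp, the returned
+// NodeResponse carries ErrTypeMismatch in place of a value.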
+func mapToCallResponse[Resp msg](channelResp NodeResponse[msg]) NodeResponse[Resp] {
+    callResp := NodeResponse[Resp]{
+        NodeID: channelResp.NodeID,
+        Err: channelResp.Err,
+    }
+    if channelResp.Err == nil {
+        if val, ok := channelResp.Value.(Resp); ok {
+            callResp.Value = val
+        } else {
+            callResp.Err = ErrTypeMismatch
+        }
+    }
+    return callResp
+}
+
 // -------------------------------------------------------------------------
 // Iterator Helpers
 // -------------------------------------------------------------------------
diff --git a/rpc.go b/rpc.go
index 12b804720..aa808ab7e 100644
--- a/rpc.go
+++ b/rpc.go
@@ -12,7 +12,7 @@ func RPCCall[Req, Resp msg](ctx *NodeContext, req Req, method string) (Resp, err
         var zero Resp
         return zero, err
     }
-    ctx.enqueue(request{ctx: ctx, msg: reqMsg, responseChan: replyChan})
+    ctx.enqueue(stream.Request{Ctx: ctx, Msg: reqMsg, ResponseChan: replyChan})
 
     select {
     case r := <-replyChan:
diff --git a/server.go b/server.go
index 0e1c31f92..cba5bdfb0 100644
--- a/server.go
+++ b/server.go
@@ -77,7 +77,7 @@ func (s *streamServer) NodeStream(srv stream.Gorums_NodeStreamServer) error {
         srvCtx := newServerCtx(streamIn.AppendToIncomingContext(ctx), &mut, finished)
         defer srvCtx.Release()
 
-        msg, err := unmarshalRequest(streamIn)
+        msg, err := stream.UnmarshalRequest(streamIn)
         in := &Message{Msg: msg, Message: streamIn}
         if err != nil {
             _ = srvCtx.SendMessage(messageWithError(in, nil, err))
diff --git a/unicast.go b/unicast.go
index a58f93493..d724c76f5 100644
--- a/unicast.go
+++ b/unicast.go
@@ -23,13 +23,13 @@ func Unicast[Req msg](ctx *NodeContext, req Req, method string, opts ...CallOpti
     waitSendDone := callOpts.mustWaitSendDone()
     if !waitSendDone {
         // Fire-and-forget: enqueue and return immediately
-        ctx.enqueue(request{ctx: ctx, msg: reqMsg})
+        ctx.enqueue(stream.Request{Ctx: ctx, Msg: reqMsg})
         return nil
     }
 
     // Default: block until send completes
     replyChan := make(chan NodeResponse[msg], 1)
-    ctx.enqueue(request{ctx: ctx, msg: reqMsg, waitSendDone: true, responseChan: replyChan})
+    ctx.enqueue(stream.Request{Ctx: ctx, Msg: reqMsg, WaitSendDone: true, ResponseChan: replyChan})
 
     // Wait for send confirmation
     select {