From 8791c99b79c5ffd4721b0c88549e6f8b99731441 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 8 Feb 2026 18:35:36 +0100 Subject: [PATCH 01/19] refactor(channel): use Metadata directly instead of Message wrapper Replace the Message wrapper with ordering.Metadata in the request struct and channel operations. This simplifies the wire protocol by using the generated protobuf types directly. Changes: - Replace `msg *Message` with `md *ordering.Metadata` in request struct - Update getStream() to return the typed Gorums_NodeStreamClient - Use stream.Send()/Recv() with Metadata instead of SendMsg()/RecvMsg() - Replace GetMessageID() with GetMessageSeqNo() for message routing - Add UnmarshalResponse() for deserializing response messages - Remove unused newRequestMessage() function from encoding.go --- channel.go | 32 ++++++++++++++------------ channel_test.go | 18 ++++++++++----- client_interceptor.go | 10 +++++++-- encoding.go | 52 +++++++++++++++++++++++++++++++++++++------ rpc.go | 7 +++++- unicast.go | 9 +++++--- 6 files changed, 96 insertions(+), 32 deletions(-) diff --git a/channel.go b/channel.go index 7b61ca37..ea677b30 100644 --- a/channel.go +++ b/channel.go @@ -44,7 +44,7 @@ var ( type request struct { ctx context.Context - msg *Message + md *ordering.Metadata waitSendDone bool streaming bool responseChan chan<- NodeResponse[msg] @@ -154,7 +154,7 @@ func (c *channel) ensureConnectedNodeStream() (err error) { } // getStream returns the current stream, or nil if no stream is available. -func (c *channel) getStream() grpc.ClientStream { +func (c *channel) getStream() ordering.Gorums_NodeStreamClient { c.streamMut.Lock() defer c.streamMut.Unlock() return c.stream @@ -180,7 +180,7 @@ func (c *channel) isConnected() bool { func (c *channel) enqueue(req request) { if req.responseChan != nil { req.sendTime = time.Now() - msgID := req.msg.GetMessageID() + msgID := req.md.GetMessageSeqNo() c.responseMut.Lock() c.responseRouters[msgID] = req c.responseMut.Unlock() @@ -190,7 +190,7 @@ func (c *channel) enqueue(req request) { select { case <-c.connCtx.Done(): // the node's close() method was called: respond with error instead of enqueueing - c.routeResponse(req.msg.GetMessageID(), NodeResponse[msg]{NodeID: c.id, Err: nodeClosedErr}) + c.routeResponse(req.md.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: nodeClosedErr}) return case c.sendQ <- req: // enqueued successfully @@ -252,11 +252,11 @@ func (c *channel) sender() { // take next request from sendQ } if err := c.ensureStream(); err != nil { - c.routeResponse(req.msg.GetMessageID(), NodeResponse[msg]{NodeID: c.id, Err: err}) + c.routeResponse(req.md.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: err}) continue } if err := c.sendMsg(req); err != nil { - c.routeResponse(req.msg.GetMessageID(), NodeResponse[msg]{NodeID: c.id, Err: err}) + c.routeResponse(req.md.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: err}) } } } @@ -279,14 +279,18 @@ func (c *channel) receiver() { } } - resp := newMessage(responseType) - if err := stream.RecvMsg(resp); err != nil { - c.setLastErr(err) - c.cancelPendingMsgs(err) + md, e := stream.Recv() + if e != nil { + c.setLastErr(e) + c.cancelPendingMsgs(e) c.clearStream() } else { - err := resp.GetStatus().Err() - c.routeResponse(resp.GetMessageID(), NodeResponse[msg]{NodeID: c.id, Value: resp.GetProtoMessage(), Err: err}) + err := status.FromProto(md.GetStatus()).Err() + var resp msg + if err == nil { + resp, err = UnmarshalResponse(md) + } + c.routeResponse(md.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Value: resp, Err: err}) } select { @@ -316,7 +320,7 @@ func (c *channel) sendMsg(req request) (err error) { // wait for actual server responses, so waitSendDone is false for them. if req.waitSendDone && err == nil { // Send succeeded: unblock the caller and clean up the responseRouter - c.routeResponse(req.msg.GetMessageID(), NodeResponse[msg]{}) + c.routeResponse(req.md.GetMessageSeqNo(), NodeResponse[msg]{}) } }() @@ -353,7 +357,7 @@ func (c *channel) sendMsg(req request) (err error) { } }() - if err = stream.SendMsg(req.msg); err != nil { + if err = stream.Send(req.md); err != nil { c.setLastErr(err) c.clearStream() } diff --git a/channel_test.go b/channel_test.go index 1da69fe6..ea2823a4 100644 --- a/channel_test.go +++ b/channel_test.go @@ -59,7 +59,11 @@ func sendRequest(t testing.TB, node *Node, req request, msgID uint64) NodeRespon if req.ctx == nil { req.ctx = t.Context() } - req.msg = NewRequest(req.ctx, msgID, mock.TestMethod, nil) + md, err := MarshalMetadata(req.ctx, msgID, mock.TestMethod, nil) + if err != nil { + t.Fatalf("MarshalMetadata failed: %v", err) + } + req.md = md replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -593,7 +597,8 @@ func TestChannelDeadlock(t *testing.T) { for id := range 10 { go func() { ctx := TestContext(t, 3*time.Second) - req := request{ctx: ctx, msg: NewRequest(ctx, uint64(100+id), mock.TestMethod, nil)} + md, _ := MarshalMetadata(ctx, uint64(100+id), mock.TestMethod, nil) + req := request{ctx: ctx, md: md} // try to enqueue select { @@ -798,7 +803,8 @@ func BenchmarkChannelStreamReadyFirstRequest(b *testing.B) { // Use a fresh context for the benchmark request ctx := TestContext(b, defaultTestTimeout) - req := request{ctx: ctx, msg: NewRequest(ctx, 1, mock.TestMethod, nil)} + md, _ := MarshalMetadata(ctx, 1, mock.TestMethod, nil) + req := request{ctx: ctx, md: md} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -847,7 +853,8 @@ func BenchmarkChannelStreamReadyReconnect(b *testing.B) { // Establish initial stream with a fresh context ctx := context.Background() - req := request{ctx: ctx, msg: NewRequest(ctx, 0, mock.TestMethod, nil)} + md, _ := MarshalMetadata(ctx, 0, mock.TestMethod, nil) + req := request{ctx: ctx, md: md} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -872,7 +879,8 @@ func BenchmarkChannelStreamReadyReconnect(b *testing.B) { // Now send a request which will trigger ensureStream -> newNodeStream -> signal ctx := context.Background() - req := request{ctx: ctx, msg: NewRequest(ctx, uint64(i+1), mock.TestMethod, nil)} + md, _ := MarshalMetadata(ctx, uint64(i+1), mock.TestMethod, nil) + req := request{ctx: ctx, md: md} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) diff --git a/client_interceptor.go b/client_interceptor.go index 938d4648..938ac18c 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -211,11 +211,17 @@ func (c *ClientCtx[Req, Resp]) send() { } expected++ // Clone metadata for each request to avoid race conditions during - // concurrent marshaling when gorumsMarshal calls SetMessageData. + // concurrent marshaling when SetMessageData is called. md := proto.CloneOf(c.md) + // Marshal the proto message into the metadata's message_data field + msgData, err := proto.Marshal(msg) + if err != nil { + continue // Skip node if marshaling fails + } + md.SetMessageData(msgData) n.channel.enqueue(request{ ctx: c.Context, - msg: newRequestMessage(md, msg), + md: md, streaming: c.streaming, waitSendDone: c.waitSendDone, responseChan: c.replyChan, diff --git a/encoding.go b/encoding.go index 17805d75..7cbfc9ce 100644 --- a/encoding.go +++ b/encoding.go @@ -42,13 +42,6 @@ func newMessage(msgType gorumsMsgType) *Message { return &Message{metadata: &ordering.Metadata{}, msgType: msgType} } -// newRequestMessage creates a new Gorums Message for the given metadata and request message. -// -// This function should be used by generated code and tests only. -func newRequestMessage(md *ordering.Metadata, req proto.Message) *Message { - return &Message{metadata: md, message: req, msgType: requestType} -} - // NewRequest creates a new Gorums Message for the given context, message ID, method, and request. // This is a convenience function that combines NewGorumsMetadata and NewRequestMessage. // @@ -58,6 +51,22 @@ func NewRequest(ctx context.Context, msgID uint64, method string, req proto.Mess return &Message{metadata: md, message: req, msgType: requestType} } +// MarshalMetadata creates metadata with the serialized proto message for type-safe Send. +// It marshals the proto message into the metadata's message_data field. +// +// This function should be used by generated code and internal channel operations only. +func MarshalMetadata(ctx context.Context, msgID uint64, method string, msg proto.Message) (*ordering.Metadata, error) { + md := ordering.NewGorumsMetadata(ctx, msgID, method) + if msg != nil { + msgData, err := proto.Marshal(msg) + if err != nil { + return nil, err + } + md.SetMessageData(msgData) + } + return md, nil +} + // NewResponseMessage creates a new Gorums Message for the given metadata and response message. // // This function should be used by generated code only. @@ -232,3 +241,32 @@ func (c Codec) gorumsUnmarshal(b []byte, msg *Message) (err error) { } return nil } + +// UnmarshalResponse extracts and unmarshals the response proto message from metadata. +// It uses the method name in metadata to look up the Output type from the proto registry. +// +// This function should be used by internal channel operations only. +func UnmarshalResponse(md *ordering.Metadata) (proto.Message, error) { + // get method descriptor from registry + desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(md.GetMethod())) + if err != nil { + return nil, fmt.Errorf("gorums: could not find method descriptor for %s", md.GetMethod()) + } + methodDesc := desc.(protoreflect.MethodDescriptor) + + // get the response message type (Output type) + msgType, err := protoregistry.GlobalTypes.FindMessageByName(methodDesc.Output().FullName()) + if err != nil { + return nil, fmt.Errorf("gorums: could not find message type %s", methodDesc.Output().FullName()) + } + resp := msgType.New().Interface() + + // unmarshal message from metadata.message_data + msgData := md.GetMessageData() + if len(msgData) > 0 { + if err := proto.Unmarshal(msgData, resp); err != nil { + return nil, fmt.Errorf("gorums: could not unmarshal response: %w", err) + } + } + return resp, nil +} diff --git a/rpc.go b/rpc.go index e16f1c55..69cef168 100644 --- a/rpc.go +++ b/rpc.go @@ -5,7 +5,12 @@ package gorums // This method should be used by generated code only. func RPCCall[Req, Resp msg](ctx *NodeContext, req Req, method string) (Resp, error) { replyChan := make(chan NodeResponse[msg], 1) - ctx.enqueue(request{ctx: ctx, msg: NewRequest(ctx, ctx.nextMsgID(), method, req), responseChan: replyChan}) + md, err := MarshalMetadata(ctx, ctx.nextMsgID(), method, req) + if err != nil { + var zero Resp + return zero, err + } + ctx.enqueue(request{ctx: ctx, md: md, responseChan: replyChan}) select { case r := <-replyChan: diff --git a/unicast.go b/unicast.go index 3aaad386..261f4dc2 100644 --- a/unicast.go +++ b/unicast.go @@ -13,18 +13,21 @@ package gorums // This method should be used by generated code only. func Unicast[Req msg](ctx *NodeContext, req Req, method string, opts ...CallOption) error { callOpts := getCallOptions(E_Unicast, opts...) - message := NewRequest(ctx, ctx.nextMsgID(), method, req) + md, err := MarshalMetadata(ctx, ctx.nextMsgID(), method, req) + if err != nil { + return err + } waitSendDone := callOpts.mustWaitSendDone() if !waitSendDone { // Fire-and-forget: enqueue and return immediately - ctx.enqueue(request{ctx: ctx, msg: message}) + ctx.enqueue(request{ctx: ctx, md: md}) return nil } // Default: block until send completes replyChan := make(chan NodeResponse[msg], 1) - ctx.enqueue(request{ctx: ctx, msg: message, waitSendDone: true, responseChan: replyChan}) + ctx.enqueue(request{ctx: ctx, md: md, waitSendDone: true, responseChan: replyChan}) // Wait for send confirmation select { From e5f7c3f76e51bdbb3e60a37566394b4c0662c23c Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 8 Feb 2026 19:01:21 +0100 Subject: [PATCH 02/19] refactor(server): use type-safe Send/Recv instead of SendMsg/RecvMsg Replace legacy SendMsg/RecvMsg with type-safe Send/Recv methods in server.go, continuing the work from issue #252. Changes: - Change finished channel from chan *Message to chan *ordering.Metadata - Use srv.Recv() to receive *ordering.Metadata directly - Use srv.Send(md) to send metadata directly - Add UnmarshalRequest() to convert metadata to *Message for handlers - Add MarshalResponseMetadata() to convert *Message to metadata for sending - Update ServerCtx to use *ordering.Metadata channel internally - Clone metadata in MarshalResponseMetadata to avoid race conditions The Handler interface remains unchanged - conversion between *Message and *ordering.Metadata happens at the server boundaries, so no changes are needed to generated code. --- encoding.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++------ server.go | 26 +++++++++++++++---------- 2 files changed, 66 insertions(+), 16 deletions(-) diff --git a/encoding.go b/encoding.go index 7cbfc9ce..ff9ba882 100644 --- a/encoding.go +++ b/encoding.go @@ -36,12 +36,6 @@ type Message struct { msgType gorumsMsgType } -// newMessage creates a new Message struct for unmarshaling. -// msgType specifies the message type to be unmarshaled. -func newMessage(msgType gorumsMsgType) *Message { - return &Message{metadata: &ordering.Metadata{}, msgType: msgType} -} - // NewRequest creates a new Gorums Message for the given context, message ID, method, and request. // This is a convenience function that combines NewGorumsMetadata and NewRequestMessage. // @@ -270,3 +264,53 @@ func UnmarshalResponse(md *ordering.Metadata) (proto.Message, error) { } return resp, nil } + +// UnmarshalRequest extracts and unmarshals the request proto message from metadata. +// It uses the method name in metadata to look up the Input type from the proto registry. +// Returns a *Message suitable for passing to handlers. +// +// This function should be used by server-side operations only. +func UnmarshalRequest(md *ordering.Metadata) (*Message, error) { + // get method descriptor from registry + desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(md.GetMethod())) + if err != nil { + return nil, fmt.Errorf("gorums: could not find method descriptor for %s", md.GetMethod()) + } + methodDesc := desc.(protoreflect.MethodDescriptor) + + // get the request message type (Input type) + msgType, err := protoregistry.GlobalTypes.FindMessageByName(methodDesc.Input().FullName()) + if err != nil { + return nil, fmt.Errorf("gorums: could not find message type %s", methodDesc.Input().FullName()) + } + req := msgType.New().Interface() + + // unmarshal message from metadata.message_data + msgData := md.GetMessageData() + if len(msgData) > 0 { + if err := proto.Unmarshal(msgData, req); err != nil { + return nil, fmt.Errorf("gorums: could not unmarshal request: %w", err) + } + } + return &Message{metadata: md, message: req, msgType: requestType}, nil +} + +// MarshalResponseMetadata marshals a response message into metadata for type-safe Send. +// It clones the metadata to avoid race conditions with concurrent send operations. +// +// This function should be used by server-side operations only. +func MarshalResponseMetadata(msg *Message) (*ordering.Metadata, error) { + if msg == nil { + return nil, nil + } + // Clone metadata to avoid race with concurrent send operations + md := proto.CloneOf(msg.metadata) + if msg.message != nil { + msgData, err := proto.Marshal(msg.message) + if err != nil { + return nil, err + } + md.SetMessageData(msgData) + } + return md, nil +} diff --git a/server.go b/server.go index 49702847..c7379700 100644 --- a/server.go +++ b/server.go @@ -36,7 +36,7 @@ func newOrderingServer(opts *serverOptions) *orderingServer { // is any error with sending or receiving. func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error { var mut sync.Mutex // used to achieve mutex between request handlers - finished := make(chan *Message, s.opts.buffer) + finished := make(chan *ordering.Metadata, s.opts.buffer) ctx := srv.Context() if s.opts.connectCallback != nil { @@ -48,9 +48,8 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error select { case <-ctx.Done(): return - case msg := <-finished: - err := srv.SendMsg(msg) - if err != nil { + case md := <-finished: + if err := srv.Send(md); err != nil { return } } @@ -62,8 +61,11 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error defer mut.Unlock() for { - req := newMessage(requestType) - err := srv.RecvMsg(req) + md, err := srv.Recv() + if err != nil { + return err + } + req, err := UnmarshalRequest(md) if err != nil { return err } @@ -219,11 +221,11 @@ type ServerCtx struct { context.Context once *sync.Once // must be a pointer to avoid passing ctx by value mut *sync.Mutex - c chan<- *Message + c chan<- *ordering.Metadata } -// newServerCtx creates a new ServerCtx with the given context, mutex and message channel. -func newServerCtx(ctx context.Context, mut *sync.Mutex, c chan<- *Message) ServerCtx { +// newServerCtx creates a new ServerCtx with the given context, mutex and metadata channel. +func newServerCtx(ctx context.Context, mut *sync.Mutex, c chan<- *ordering.Metadata) ServerCtx { return ServerCtx{ Context: ctx, once: new(sync.Once), @@ -244,8 +246,12 @@ func (ctx *ServerCtx) Release() { // // This function should be used by generated code only. func (ctx *ServerCtx) SendMessage(msg *Message) error { + md, err := MarshalResponseMetadata(msg) + if err != nil { + return err + } select { - case ctx.c <- msg: + case ctx.c <- md: case <-ctx.Done(): return ctx.Err() } From 62cb5bb4930006f15f324c85d344622f84898f82 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 8 Feb 2026 20:11:34 +0100 Subject: [PATCH 03/19] refactor(server): simplify response with responseWithError helper Add responseWithError() helper function that creates a response message if needed and sets the error status in a single call. This simplifies the server's handler goroutine by: - Properly handling UnmarshalRequest errors (send error response to client) - Consolidating the nil message check and setError call into one line - Reducing the handler goroutine from 26 lines to 21 lines The helper is also useful for interceptors that need to return errors. --- encoding.go | 11 +++++++++++ server.go | 25 ++++++++++--------------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/encoding.go b/encoding.go index ff9ba882..3eddb0f5 100644 --- a/encoding.go +++ b/encoding.go @@ -132,6 +132,17 @@ func (m *Message) setError(err error) { m.metadata.SetStatus(errStatus.Proto()) } +// responseWithError ensures a response message exists and sets the error status. +// If msg is nil, a new response message is created using the provided metadata. +// This is used by the server to send error responses back to the client. +func responseWithError(msg *Message, md *ordering.Metadata, err error) *Message { + if msg == nil { + msg = NewResponseMessage(md, nil) + } + msg.setError(err) + return msg +} + // Codec is the gRPC codec used by gorums. type Codec struct { marshaler proto.MarshalOptions diff --git a/server.go b/server.go index c7379700..40ca4b21 100644 --- a/server.go +++ b/server.go @@ -65,11 +65,7 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error if err != nil { return err } - req, err := UnmarshalRequest(md) - if err != nil { - return err - } - if handler, ok := s.handlers[req.GetMethod()]; ok { + if handler, ok := s.handlers[md.GetMethod()]; ok { // We start the handler in a new goroutine in order to allow multiple handlers to run concurrently. // However, to preserve request ordering, the handler must unlock the shared mutex when it has either // finished, or when it is safe to start processing the next request. @@ -77,10 +73,15 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error // This func() is the default interceptor; it is the first and last handler in the chain. // It is responsible for releasing the mutex when the handler chain is done. go func() { - metadata := req.GetMetadata() - srvCtx := newServerCtx(metadata.AppendToIncomingContext(ctx), &mut, finished) + srvCtx := newServerCtx(md.AppendToIncomingContext(ctx), &mut, finished) defer srvCtx.Release() + req, err := UnmarshalRequest(md) + if err != nil { + _ = srvCtx.SendMessage(responseWithError(nil, md, err)) + return + } + message, err := handler(srvCtx, req) // If there is no message and no error, we do not send anything back to the client. // This corresponds to a unidirectional message from client to server, where clients @@ -88,15 +89,9 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error if message == nil && err == nil { return } - // If there was an error in the interceptor chain or the method handler, they may return a nil message. - // Thus, we need to create a response message to send the error back to the client. - if message == nil { - message = NewResponseMessage(metadata, nil) - } - message.setError(err) - _ = srvCtx.SendMessage(message) // to the client + _ = srvCtx.SendMessage(responseWithError(message, md, err)) // We ignore the error from SendMessage here; it means that the stream is closed. - // The for-loop above will exit on the next RecvMsg call. + // The for-loop above will exit on the next Recv call. }() // Wait until the handler releases the mutex. mut.Lock() From a9ab22e89b12e09641a69eebcbe92802520b722e Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 8 Feb 2026 20:29:57 +0100 Subject: [PATCH 04/19] refactor: remove custom gRPC codec The custom Codec is no longer needed now that we use type-safe Send/Recv methods with *ordering.Metadata. gRPC's default proto codec handles serialization directly. We instead provide custom functions. Removed: - Codec struct and methods (Marshal, Unmarshal, gorumsMarshal, gorumsUnmarshal) - NewCodec() constructor - ContentSubtype constant - init() function that registered the codec - grpc.CallContentSubtype dial option - Test init() functions that registered the codec - mgr_test.go (only contained codec registration) This simplifies the codebase by removing an abstraction layer that was only needed for the legacy SendMsg/RecvMsg approach. --- config_test.go | 7 -- encoding.go | 198 +++++++++++-------------------------------------- mgr.go | 3 - mgr_test.go | 30 -------- rpc_test.go | 7 -- server_test.go | 7 -- 6 files changed, 43 insertions(+), 209 deletions(-) delete mode 100644 mgr_test.go diff --git a/config_test.go b/config_test.go index eb40df16..c43b3755 100644 --- a/config_test.go +++ b/config_test.go @@ -5,15 +5,8 @@ import ( "testing" "github.com/relab/gorums" - "google.golang.org/grpc/encoding" ) -func init() { - if encoding.GetCodec(gorums.ContentSubtype) == nil { - encoding.RegisterCodec(gorums.NewCodec()) - } -} - var ( nodes = []string{"127.0.0.1:9081", "127.0.0.1:9082", "127.0.0.1:9083"} nodeMap = map[uint32]testNode{ diff --git a/encoding.go b/encoding.go index 3eddb0f5..dbd969f2 100644 --- a/encoding.go +++ b/encoding.go @@ -6,20 +6,12 @@ import ( "github.com/relab/gorums/ordering" "google.golang.org/grpc/codes" - "google.golang.org/grpc/encoding" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" ) -func init() { - encoding.RegisterCodec(NewCodec()) -} - -// ContentSubtype is the subtype used by gorums when sending messages via gRPC. -const ContentSubtype = "gorums" - type gorumsMsgType uint8 const ( @@ -45,22 +37,6 @@ func NewRequest(ctx context.Context, msgID uint64, method string, req proto.Mess return &Message{metadata: md, message: req, msgType: requestType} } -// MarshalMetadata creates metadata with the serialized proto message for type-safe Send. -// It marshals the proto message into the metadata's message_data field. -// -// This function should be used by generated code and internal channel operations only. -func MarshalMetadata(ctx context.Context, msgID uint64, method string, msg proto.Message) (*ordering.Metadata, error) { - md := ordering.NewGorumsMetadata(ctx, msgID, method) - if msg != nil { - msgData, err := proto.Marshal(msg) - if err != nil { - return nil, err - } - md.SetMessageData(msgData) - } - return md, nil -} - // NewResponseMessage creates a new Gorums Message for the given metadata and response message. // // This function should be used by generated code only. @@ -143,108 +119,70 @@ func responseWithError(msg *Message, md *ordering.Metadata, err error) *Message return msg } -// Codec is the gRPC codec used by gorums. -type Codec struct { - marshaler proto.MarshalOptions - unmarshaler proto.UnmarshalOptions -} - -// NewCodec returns a new Codec. -func NewCodec() *Codec { - return &Codec{ - marshaler: proto.MarshalOptions{AllowPartial: true}, - unmarshaler: proto.UnmarshalOptions{AllowPartial: true}, +// MarshalMetadata creates metadata with the serialized proto message for type-safe Send. +// It marshals the proto message into the metadata's message_data field. +// +// This function should be used by client-side operations only. +func MarshalMetadata(ctx context.Context, msgID uint64, method string, msg proto.Message) (*ordering.Metadata, error) { + md := ordering.NewGorumsMetadata(ctx, msgID, method) + if msg != nil { + msgData, err := proto.Marshal(msg) + if err != nil { + return nil, err + } + md.SetMessageData(msgData) } + return md, nil } -// Name returns the name of the Codec. -func (Codec) Name() string { - return ContentSubtype -} - -func (Codec) String() string { - return ContentSubtype -} - -// Marshal marshals the message m into a byte slice. -func (c Codec) Marshal(m any) (b []byte, err error) { - switch msg := m.(type) { - case *Message: - return c.gorumsMarshal(msg) - case proto.Message: - return c.marshaler.Marshal(msg) - default: - return nil, fmt.Errorf("gorums: cannot marshal message of type '%T'", m) +// MarshalResponseMetadata marshals a response message into metadata for type-safe Send. +// It clones the metadata to avoid race conditions with concurrent send operations. +// +// This function should be used by server-side operations only. +func MarshalResponseMetadata(msg *Message) (*ordering.Metadata, error) { + if msg == nil { + return nil, nil } -} - -// gorumsMarshal marshals a Message by serializing the application message -// into metadata.message_data, then marshaling the metadata. -func (c Codec) gorumsMarshal(msg *Message) (b []byte, err error) { - // serialize the application message into metadata.message_data + // Clone metadata to avoid race with concurrent send operations + md := proto.CloneOf(msg.metadata) if msg.message != nil { - msgData, err := c.marshaler.Marshal(msg.message) + msgData, err := proto.Marshal(msg.message) if err != nil { - return nil, fmt.Errorf("gorums: could not marshal message: %w", err) + return nil, err } - msg.metadata.SetMessageData(msgData) - } - return c.marshaler.Marshal(msg.metadata) -} - -// Unmarshal unmarshals a byte slice into m. -func (c Codec) Unmarshal(b []byte, m any) (err error) { - switch msg := m.(type) { - case *Message: - return c.gorumsUnmarshal(b, msg) - case proto.Message: - return c.unmarshaler.Unmarshal(b, msg) - default: - return fmt.Errorf("gorums: cannot unmarshal message of type '%T'", m) + md.SetMessageData(msgData) } + return md, nil } -// gorumsUnmarshal extracts metadata and message data from b and places the result in msg. -func (c Codec) gorumsUnmarshal(b []byte, msg *Message) (err error) { - // unmarshal metadata - err = c.unmarshaler.Unmarshal(b, msg.metadata) - if err != nil { - return fmt.Errorf("gorums: could not unmarshal metadata: %w", err) - } - +// UnmarshalRequest extracts and unmarshals the request proto message from metadata. +// It uses the method name in metadata to look up the Input type from the proto registry. +// Returns a *Message suitable for passing to handlers. +// +// This function should be used by server-side operations only. +func UnmarshalRequest(md *ordering.Metadata) (*Message, error) { // get method descriptor from registry - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(msg.GetMethod())) + desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(md.GetMethod())) if err != nil { - // err is a NotFound error with no method name information; return a more informative error - return fmt.Errorf("gorums: could not find method descriptor for %s", msg.GetMethod()) + return nil, fmt.Errorf("gorums: could not find method descriptor for %s", md.GetMethod()) } methodDesc := desc.(protoreflect.MethodDescriptor) - // get message name depending on whether we are creating a request or response message - var messageName protoreflect.FullName - switch msg.msgType { - case requestType: - messageName = methodDesc.Input().FullName() - case responseType: - messageName = methodDesc.Output().FullName() - default: - return fmt.Errorf("gorums: unknown message type %d", msg.msgType) - } - - // now get the message type from the types registry - msgType, err := protoregistry.GlobalTypes.FindMessageByName(messageName) + // get the request message type (Input type) + msgType, err := protoregistry.GlobalTypes.FindMessageByName(methodDesc.Input().FullName()) if err != nil { - // err is a NotFound error with no message name information; return a more informative error - return fmt.Errorf("gorums: could not find message type %s", messageName) + return nil, fmt.Errorf("gorums: could not find message type %s", methodDesc.Input().FullName()) } - msg.message = msgType.New().Interface() + req := msgType.New().Interface() // unmarshal message from metadata.message_data - msgData := msg.metadata.GetMessageData() + msgData := md.GetMessageData() if len(msgData) > 0 { - return c.unmarshaler.Unmarshal(msgData, msg.message) + if err := proto.Unmarshal(msgData, req); err != nil { + return nil, fmt.Errorf("gorums: could not unmarshal request: %w", err) + } } - return nil + return &Message{metadata: md, message: req, msgType: requestType}, nil } // UnmarshalResponse extracts and unmarshals the response proto message from metadata. @@ -275,53 +213,3 @@ func UnmarshalResponse(md *ordering.Metadata) (proto.Message, error) { } return resp, nil } - -// UnmarshalRequest extracts and unmarshals the request proto message from metadata. -// It uses the method name in metadata to look up the Input type from the proto registry. -// Returns a *Message suitable for passing to handlers. -// -// This function should be used by server-side operations only. -func UnmarshalRequest(md *ordering.Metadata) (*Message, error) { - // get method descriptor from registry - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(md.GetMethod())) - if err != nil { - return nil, fmt.Errorf("gorums: could not find method descriptor for %s", md.GetMethod()) - } - methodDesc := desc.(protoreflect.MethodDescriptor) - - // get the request message type (Input type) - msgType, err := protoregistry.GlobalTypes.FindMessageByName(methodDesc.Input().FullName()) - if err != nil { - return nil, fmt.Errorf("gorums: could not find message type %s", methodDesc.Input().FullName()) - } - req := msgType.New().Interface() - - // unmarshal message from metadata.message_data - msgData := md.GetMessageData() - if len(msgData) > 0 { - if err := proto.Unmarshal(msgData, req); err != nil { - return nil, fmt.Errorf("gorums: could not unmarshal request: %w", err) - } - } - return &Message{metadata: md, message: req, msgType: requestType}, nil -} - -// MarshalResponseMetadata marshals a response message into metadata for type-safe Send. -// It clones the metadata to avoid race conditions with concurrent send operations. -// -// This function should be used by server-side operations only. -func MarshalResponseMetadata(msg *Message) (*ordering.Metadata, error) { - if msg == nil { - return nil, nil - } - // Clone metadata to avoid race with concurrent send operations - md := proto.CloneOf(msg.metadata) - if msg.message != nil { - msgData, err := proto.Marshal(msg.message) - if err != nil { - return nil, err - } - md.SetMessageData(msgData) - } - return md, nil -} diff --git a/mgr.go b/mgr.go index 3f407a14..18b0996b 100644 --- a/mgr.go +++ b/mgr.go @@ -37,9 +37,6 @@ func NewManager(opts ...ManagerOption) *Manager { if m.opts.logger != nil { m.logger = m.opts.logger } - m.opts.grpcDialOpts = append(m.opts.grpcDialOpts, grpc.WithDefaultCallOptions( - grpc.CallContentSubtype(ContentSubtype), - )) if m.opts.backoff != backoff.DefaultConfig { m.opts.grpcDialOpts = append(m.opts.grpcDialOpts, grpc.WithConnectParams( grpc.ConnectParams{Backoff: m.opts.backoff}, diff --git a/mgr_test.go b/mgr_test.go deleted file mode 100644 index a012c216..00000000 --- a/mgr_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorums - -import ( - "bytes" - "log" - "strings" - "testing" - - "google.golang.org/grpc/encoding" -) - -func init() { - if encoding.GetCodec(ContentSubtype) == nil { - encoding.RegisterCodec(NewCodec()) - } -} - -func TestManagerLogging(t *testing.T) { - var ( - buf bytes.Buffer - logger = log.New(&buf, "logger: ", log.Lshortfile) - ) - mgr := NewManager(InsecureDialOptions(t), WithLogger(logger)) - t.Cleanup(Closer(t, mgr)) - - want := "logger: mgr.go:49: ready" - if strings.TrimSpace(buf.String()) != want { - t.Errorf("logger: got %q, want %q", buf.String(), want) - } -} diff --git a/rpc_test.go b/rpc_test.go index 576662ea..0be0feb2 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -9,16 +9,9 @@ import ( "github.com/relab/gorums" "github.com/relab/gorums/internal/testutils/mock" - "google.golang.org/grpc/encoding" pb "google.golang.org/protobuf/types/known/wrapperspb" ) -func init() { - if encoding.GetCodec(gorums.ContentSubtype) == nil { - encoding.RegisterCodec(gorums.NewCodec()) - } -} - func TestRPCCallSuccess(t *testing.T) { node := gorums.TestNode(t, gorums.DefaultTestServer) diff --git a/server_test.go b/server_test.go index c70abae3..5d8064ad 100644 --- a/server_test.go +++ b/server_test.go @@ -8,17 +8,10 @@ import ( "github.com/relab/gorums" "github.com/relab/gorums/internal/testutils/mock" - "google.golang.org/grpc/encoding" "google.golang.org/grpc/metadata" pb "google.golang.org/protobuf/types/known/wrapperspb" ) -func init() { - if encoding.GetCodec(gorums.ContentSubtype) == nil { - encoding.RegisterCodec(gorums.NewCodec()) - } -} - func TestServerCallback(t *testing.T) { var message string signal := make(chan struct{}) From 2e97ff835b1736809e607ca40baf3778827c4cff Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 8 Feb 2026 20:56:17 +0100 Subject: [PATCH 05/19] chore: unexport marshal and unmarshal functions --- channel.go | 2 +- channel_test.go | 12 ++++++------ encoding.go | 16 ++++++++-------- rpc.go | 2 +- server.go | 4 ++-- unicast.go | 2 +- 6 files changed, 19 insertions(+), 19 deletions(-) diff --git a/channel.go b/channel.go index ea677b30..8576d20a 100644 --- a/channel.go +++ b/channel.go @@ -288,7 +288,7 @@ func (c *channel) receiver() { err := status.FromProto(md.GetStatus()).Err() var resp msg if err == nil { - resp, err = UnmarshalResponse(md) + resp, err = unmarshalResponse(md) } c.routeResponse(md.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Value: resp, Err: err}) } diff --git a/channel_test.go b/channel_test.go index ea2823a4..50ae3283 100644 --- a/channel_test.go +++ b/channel_test.go @@ -59,9 +59,9 @@ func sendRequest(t testing.TB, node *Node, req request, msgID uint64) NodeRespon if req.ctx == nil { req.ctx = t.Context() } - md, err := MarshalMetadata(req.ctx, msgID, mock.TestMethod, nil) + md, err := marshalRequest(req.ctx, msgID, mock.TestMethod, nil) if err != nil { - t.Fatalf("MarshalMetadata failed: %v", err) + t.Fatalf("marshalRequest failed: %v", err) } req.md = md replyChan := make(chan NodeResponse[msg], 1) @@ -597,7 +597,7 @@ func TestChannelDeadlock(t *testing.T) { for id := range 10 { go func() { ctx := TestContext(t, 3*time.Second) - md, _ := MarshalMetadata(ctx, uint64(100+id), mock.TestMethod, nil) + md, _ := marshalRequest(ctx, uint64(100+id), mock.TestMethod, nil) req := request{ctx: ctx, md: md} // try to enqueue @@ -803,7 +803,7 @@ func BenchmarkChannelStreamReadyFirstRequest(b *testing.B) { // Use a fresh context for the benchmark request ctx := TestContext(b, defaultTestTimeout) - md, _ := MarshalMetadata(ctx, 1, mock.TestMethod, nil) + md, _ := marshalRequest(ctx, 1, mock.TestMethod, nil) req := request{ctx: ctx, md: md} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan @@ -853,7 +853,7 @@ func BenchmarkChannelStreamReadyReconnect(b *testing.B) { // Establish initial stream with a fresh context ctx := context.Background() - md, _ := MarshalMetadata(ctx, 0, mock.TestMethod, nil) + md, _ := marshalRequest(ctx, 0, mock.TestMethod, nil) req := request{ctx: ctx, md: md} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan @@ -879,7 +879,7 @@ func BenchmarkChannelStreamReadyReconnect(b *testing.B) { // Now send a request which will trigger ensureStream -> newNodeStream -> signal ctx := context.Background() - md, _ := MarshalMetadata(ctx, uint64(i+1), mock.TestMethod, nil) + md, _ := marshalRequest(ctx, uint64(i+1), mock.TestMethod, nil) req := request{ctx: ctx, md: md} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan diff --git a/encoding.go b/encoding.go index dbd969f2..5c2c1617 100644 --- a/encoding.go +++ b/encoding.go @@ -119,11 +119,11 @@ func responseWithError(msg *Message, md *ordering.Metadata, err error) *Message return msg } -// MarshalMetadata creates metadata with the serialized proto message for type-safe Send. +// marshalRequest marshals the request proto message into metadata for type-safe Send. // It marshals the proto message into the metadata's message_data field. // // This function should be used by client-side operations only. -func MarshalMetadata(ctx context.Context, msgID uint64, method string, msg proto.Message) (*ordering.Metadata, error) { +func marshalRequest(ctx context.Context, msgID uint64, method string, msg proto.Message) (*ordering.Metadata, error) { md := ordering.NewGorumsMetadata(ctx, msgID, method) if msg != nil { msgData, err := proto.Marshal(msg) @@ -135,11 +135,11 @@ func MarshalMetadata(ctx context.Context, msgID uint64, method string, msg proto return md, nil } -// MarshalResponseMetadata marshals a response message into metadata for type-safe Send. +// marshalResponse marshals the response message into metadata for type-safe Send. // It clones the metadata to avoid race conditions with concurrent send operations. // // This function should be used by server-side operations only. -func MarshalResponseMetadata(msg *Message) (*ordering.Metadata, error) { +func marshalResponse(msg *Message) (*ordering.Metadata, error) { if msg == nil { return nil, nil } @@ -155,12 +155,12 @@ func MarshalResponseMetadata(msg *Message) (*ordering.Metadata, error) { return md, nil } -// UnmarshalRequest extracts and unmarshals the request proto message from metadata. +// unmarshalRequest unmarshals the request proto message from metadata. // It uses the method name in metadata to look up the Input type from the proto registry. // Returns a *Message suitable for passing to handlers. // // This function should be used by server-side operations only. -func UnmarshalRequest(md *ordering.Metadata) (*Message, error) { +func unmarshalRequest(md *ordering.Metadata) (*Message, error) { // get method descriptor from registry desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(md.GetMethod())) if err != nil { @@ -185,11 +185,11 @@ func UnmarshalRequest(md *ordering.Metadata) (*Message, error) { return &Message{metadata: md, message: req, msgType: requestType}, nil } -// UnmarshalResponse extracts and unmarshals the response proto message from metadata. +// unmarshalResponse unmarshals the response proto message from metadata. // It uses the method name in metadata to look up the Output type from the proto registry. // // This function should be used by internal channel operations only. -func UnmarshalResponse(md *ordering.Metadata) (proto.Message, error) { +func unmarshalResponse(md *ordering.Metadata) (proto.Message, error) { // get method descriptor from registry desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(md.GetMethod())) if err != nil { diff --git a/rpc.go b/rpc.go index 69cef168..a6da3cf3 100644 --- a/rpc.go +++ b/rpc.go @@ -5,7 +5,7 @@ package gorums // This method should be used by generated code only. func RPCCall[Req, Resp msg](ctx *NodeContext, req Req, method string) (Resp, error) { replyChan := make(chan NodeResponse[msg], 1) - md, err := MarshalMetadata(ctx, ctx.nextMsgID(), method, req) + md, err := marshalRequest(ctx, ctx.nextMsgID(), method, req) if err != nil { var zero Resp return zero, err diff --git a/server.go b/server.go index 40ca4b21..979a49b0 100644 --- a/server.go +++ b/server.go @@ -76,7 +76,7 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error srvCtx := newServerCtx(md.AppendToIncomingContext(ctx), &mut, finished) defer srvCtx.Release() - req, err := UnmarshalRequest(md) + req, err := unmarshalRequest(md) if err != nil { _ = srvCtx.SendMessage(responseWithError(nil, md, err)) return @@ -241,7 +241,7 @@ func (ctx *ServerCtx) Release() { // // This function should be used by generated code only. func (ctx *ServerCtx) SendMessage(msg *Message) error { - md, err := MarshalResponseMetadata(msg) + md, err := marshalResponse(msg) if err != nil { return err } diff --git a/unicast.go b/unicast.go index 261f4dc2..d0beb7e5 100644 --- a/unicast.go +++ b/unicast.go @@ -13,7 +13,7 @@ package gorums // This method should be used by generated code only. func Unicast[Req msg](ctx *NodeContext, req Req, method string, opts ...CallOption) error { callOpts := getCallOptions(E_Unicast, opts...) - md, err := MarshalMetadata(ctx, ctx.nextMsgID(), method, req) + md, err := marshalRequest(ctx, ctx.nextMsgID(), method, req) if err != nil { return err } From a62bb77a8a46ddfb03743641c8dd14eefae54c3e Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 8 Feb 2026 21:06:44 +0100 Subject: [PATCH 06/19] refactor(encoding): remove msgType from Message struct After removing the custom codec we no longer need the msgType field in the gorums.Message struct. --- encoding.go | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/encoding.go b/encoding.go index 5c2c1617..a4ec44f2 100644 --- a/encoding.go +++ b/encoding.go @@ -12,20 +12,12 @@ import ( "google.golang.org/protobuf/reflect/protoregistry" ) -type gorumsMsgType uint8 - -const ( - requestType gorumsMsgType = iota + 1 - responseType -) - // Message encapsulates a protobuf message and metadata. // // This struct should be used by generated code only. type Message struct { metadata *ordering.Metadata message proto.Message - msgType gorumsMsgType } // NewRequest creates a new Gorums Message for the given context, message ID, method, and request. @@ -34,14 +26,14 @@ type Message struct { // This function should be used by generated code and tests only. func NewRequest(ctx context.Context, msgID uint64, method string, req proto.Message) *Message { md := ordering.NewGorumsMetadata(ctx, msgID, method) - return &Message{metadata: md, message: req, msgType: requestType} + return &Message{metadata: md, message: req} } // NewResponseMessage creates a new Gorums Message for the given metadata and response message. // // This function should be used by generated code only. func NewResponseMessage(md *ordering.Metadata, resp proto.Message) *Message { - return &Message{metadata: md, message: resp, msgType: responseType} + return &Message{metadata: md, message: resp} } // AsProto returns msg's underlying protobuf message of the specified type T. @@ -182,7 +174,7 @@ func unmarshalRequest(md *ordering.Metadata) (*Message, error) { return nil, fmt.Errorf("gorums: could not unmarshal request: %w", err) } } - return &Message{metadata: md, message: req, msgType: requestType}, nil + return &Message{metadata: md, message: req}, nil } // unmarshalResponse unmarshals the response proto message from metadata. From 0a8e891e7db50357bec142741383d27d8fa9947a Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 8 Feb 2026 22:12:15 +0100 Subject: [PATCH 07/19] refactor: remove unused NewRequest We had only one remaining use of NewRequest in a test; this removes NewRequest and replaces its use in TestAsProto with NewResponseMessage. --- encoding.go | 9 --------- encoding_test.go | 8 ++++---- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/encoding.go b/encoding.go index a4ec44f2..0cce4bda 100644 --- a/encoding.go +++ b/encoding.go @@ -20,15 +20,6 @@ type Message struct { message proto.Message } -// NewRequest creates a new Gorums Message for the given context, message ID, method, and request. -// This is a convenience function that combines NewGorumsMetadata and NewRequestMessage. -// -// This function should be used by generated code and tests only. -func NewRequest(ctx context.Context, msgID uint64, method string, req proto.Message) *Message { - md := ordering.NewGorumsMetadata(ctx, msgID, method) - return &Message{metadata: md, message: req} -} - // NewResponseMessage creates a new Gorums Message for the given metadata and response message. // // This function should be used by generated code only. diff --git a/encoding_test.go b/encoding_test.go index 950fab25..b17bdced 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -18,7 +18,7 @@ func TestAsProto(t *testing.T) { }{ { name: "Success", - msg: gorums.NewRequest(t.Context(), 0, "", config.Request_builder{Num: 42}.Build()), + msg: gorums.NewResponseMessage(nil, config.Response_builder{Name: "test", Num: 42}.Build()), wantNil: false, wantNum: 42, }, @@ -29,7 +29,7 @@ func TestAsProto(t *testing.T) { }, { name: "WrongType", - msg: gorums.NewResponseMessage(nil, config.Response_builder{Name: "test", Num: 99}.Build()), + msg: gorums.NewResponseMessage(nil, config.Request_builder{Num: 99}.Build()), wantNil: true, }, } @@ -37,7 +37,7 @@ func TestAsProto(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() - req := gorums.AsProto[*config.Request](tc.msg) + req := gorums.AsProto[*config.Response](tc.msg) if tc.wantNil { if req != nil { t.Errorf("AsProto returned %v, want nil", req) @@ -45,7 +45,7 @@ func TestAsProto(t *testing.T) { return } if req == nil { - t.Errorf("AsProto returned nil, want *config.Request") + t.Errorf("AsProto returned nil, want *config.Response") } if got := req.GetNum(); got != tc.wantNum { t.Errorf("Num = %d, want %d", got, tc.wantNum) From c8bc6e93947b70176b711971c0461d91eaec6bee Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 8 Feb 2026 22:14:25 +0100 Subject: [PATCH 08/19] refactor(metadata): consolidate metadata creation into NewMetadata This replaces prior uses of NewGorumsMetadata. --- client_interceptor.go | 4 ++-- encoding.go | 12 ++---------- ordering/gorums_metadata.go | 37 +++++++++++++++++++++++-------------- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/client_interceptor.go b/client_interceptor.go index 938ac18c..faffbee3 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -117,8 +117,8 @@ func (b *clientCtxBuilder[Req, Resp]) WithWaitSendDone(waitSendDone bool) *clien // Build finalizes the ClientCtx configuration and returns the constructed instance. // It creates the metadata and reply channel, and sets up the appropriate response iterator. func (b *clientCtxBuilder[Req, Resp]) Build() *ClientCtx[Req, Resp] { - // Create metadata and reply channel at build time - b.c.md = ordering.NewGorumsMetadata(b.c.Context, b.c.config.nextMsgID(), b.c.method) + // Create metadata (without message) and reply channel at build time + b.c.md, _ = ordering.NewMetadata(b.c.Context, b.c.config.nextMsgID(), b.c.method, nil) b.c.replyChan = make(chan NodeResponse[msg], b.c.config.Size()*b.chanMultiplier) if b.c.streaming { diff --git a/encoding.go b/encoding.go index 0cce4bda..c7d8289d 100644 --- a/encoding.go +++ b/encoding.go @@ -106,16 +106,8 @@ func responseWithError(msg *Message, md *ordering.Metadata, err error) *Message // It marshals the proto message into the metadata's message_data field. // // This function should be used by client-side operations only. -func marshalRequest(ctx context.Context, msgID uint64, method string, msg proto.Message) (*ordering.Metadata, error) { - md := ordering.NewGorumsMetadata(ctx, msgID, method) - if msg != nil { - msgData, err := proto.Marshal(msg) - if err != nil { - return nil, err - } - md.SetMessageData(msgData) - } - return md, nil +func marshalRequest(ctx context.Context, msgID uint64, method string, req proto.Message) (*ordering.Metadata, error) { + return ordering.NewMetadata(ctx, msgID, method, req) } // marshalResponse marshals the response message into metadata for type-safe Send. diff --git a/ordering/gorums_metadata.go b/ordering/gorums_metadata.go index 0a4716cb..1ad0656b 100644 --- a/ordering/gorums_metadata.go +++ b/ordering/gorums_metadata.go @@ -4,32 +4,41 @@ import ( "context" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" ) -// NewGorumsMetadata creates a new Gorums metadata object for the given method -// and client message ID. It also appends any client-specific metadata from the -// context to the Gorums metadata object. +// NewMetadata creates a new [Metadata] proto message for the given method and message ID. +// If a non-nil proto message is provided, it is marshaled and included in the metadata. +// This function also extracts any client-specific metadata from the context and appends +// it to the metadata, allowing client-specific metadata to be passed to the server. // -// This is used to pass client-specific metadata to the server via Gorums. -// This method should be used by generated code only. -func NewGorumsMetadata(ctx context.Context, msgID uint64, method string) *Metadata { - gorumsMetadata := Metadata_builder{MessageSeqNo: msgID, Method: method} +// This method is intended for Gorums internal use. +func NewMetadata(ctx context.Context, msgID uint64, method string, msg proto.Message) (*Metadata, error) { + // Marshal the message to bytes (nil message returns nil bytes and no error) + msgBytes, err := proto.Marshal(msg) + if err != nil { + return nil, err + } + mdBuilder := Metadata_builder{ + MessageSeqNo: msgID, + Method: method, + MessageData: msgBytes, + } md, _ := metadata.FromOutgoingContext(ctx) for k, vv := range md { for _, v := range vv { entry := MetadataEntry_builder{Key: k, Value: v}.Build() - gorumsMetadata.Entry = append(gorumsMetadata.Entry, entry) + mdBuilder.Entry = append(mdBuilder.Entry, entry) } } - return gorumsMetadata.Build() + return mdBuilder.Build(), nil } -// AppendToIncomingContext appends client-specific metadata from the -// Gorums metadata object to the incoming context. +// AppendToIncomingContext appends client-specific metadata from the [Metadata] proto message +// to the incoming gRPC context, allowing server implementations to extract and use said +// metadata directly from the server method's context. // -// This is used to pass client-specific metadata from the Gorums runtime -// to the server implementation. -// This method should be used by generated code only. +// This method is intended for Gorums internal use. func (x *Metadata) AppendToIncomingContext(ctx context.Context) context.Context { existingMD, _ := metadata.FromIncomingContext(ctx) newMD := existingMD.Copy() // copy to avoid mutating the original From f7a191faf302570ffed2d29207d313a357eba52f Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 9 Feb 2026 23:39:34 +0100 Subject: [PATCH 09/19] refactor: rename ordering to stream and unify message envelope This commit refactors the internal messaging layer by replacing the `ordering` package with `stream` and unifying the message envelope structure. Key changes: - Rename the `ordering` package and proto files to `stream`. - Update `stream.proto` to define a single `Message` type containing both metadata and payload (as `bytes`), replacing the previous split between `Metadata` and application message. - Update `server.go` to: - Use the new `stream` package. - Implement explicit payload marshaling within `SendMessage`. - Send error messages to the client on marshaling failures instead of closing the stream. - Update `encoding.go` to support packing/unpacking payloads into `stream.Message`. - Update `client_interceptor.go` to work with the new `stream.Message` structure. - Update `cmd/protoc-gen-gorums` templates to generate code compatible with the new stream package and message format. --- .vscode/gorums.txt | 1 + Makefile | 2 +- channel.go | 34 +-- channel_test.go | 28 +-- client_interceptor.go | 82 +++++--- client_interceptor_test.go | 71 +++++++ .../dev/generated_code_test.go | 2 +- .../gengorums/gorums_func_map.go | 2 +- .../gengorums/template_server.go | 10 +- encoding.go | 195 +++++++----------- encoding_test.go | 97 ++++++++- examples/interceptors/server_interceptors.go | 48 +++-- quorumcall.go | 4 +- rpc.go | 6 +- server.go | 71 ++++--- server_test.go | 50 ++++- .../gorums_message.go | 25 +-- .../ordering.pb.go => stream/stream.pb.go | 124 +++++------ .../ordering.proto => stream/stream.proto | 12 +- .../stream_grpc.pb.go | 26 +-- testing_shared.go | 23 ++- unicast.go | 8 +- 22 files changed, 552 insertions(+), 369 deletions(-) rename ordering/gorums_metadata.go => stream/gorums_message.go (62%) rename ordering/ordering.pb.go => stream/stream.pb.go (61%) rename ordering/ordering.proto => stream/stream.proto (76%) rename ordering/ordering_grpc.pb.go => stream/stream_grpc.pb.go (86%) diff --git a/.vscode/gorums.txt b/.vscode/gorums.txt index 2896c781..e22ddbb8 100644 --- a/.vscode/gorums.txt +++ b/.vscode/gorums.txt @@ -69,6 +69,7 @@ pprof proto protobuf protoc +protocmp protodesc protogen protoimpl diff --git a/Makefile b/Makefile index 22b54312..9a003cd6 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ static_files := $(shell find $(dev_path) -name "*.go" -not -name "zorums*" -no proto_path := $(dev_path):third_party:. plugin_deps := gorums.pb.go $(static_file) -runtime_deps := ordering/ordering.pb.go ordering/ordering_grpc.pb.go +runtime_deps := stream/stream.pb.go stream/stream_grpc.pb.go benchmark_deps := benchmark/benchmark.pb.go benchmark/benchmark_gorums.pb.go .PHONY: all dev tools bootstrapgorums installgorums benchmark test compiletests genproto benchtest bench diff --git a/channel.go b/channel.go index 8576d20a..633ece43 100644 --- a/channel.go +++ b/channel.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "github.com/relab/gorums/ordering" + "github.com/relab/gorums/stream" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" @@ -44,9 +44,9 @@ var ( type request struct { ctx context.Context - md *ordering.Metadata - waitSendDone bool + msg *stream.Message streaming bool + waitSendDone bool responseChan chan<- NodeResponse[msg] sendTime time.Time } @@ -68,8 +68,8 @@ type channel struct { // Stream lifecycle management for FIFO ordered message delivery // stream is a bidirectional stream for - // sending and receiving ordering.Metadata messages. - stream ordering.Gorums_NodeStreamClient + // sending and receiving stream.Message messages. + stream stream.Gorums_NodeStreamClient streamMut sync.Mutex streamCtx context.Context streamCancel context.CancelFunc @@ -149,12 +149,12 @@ func (c *channel) ensureConnectedNodeStream() (err error) { return nil } c.streamCtx, c.streamCancel = context.WithCancel(c.connCtx) - c.stream, err = ordering.NewGorumsClient(c.conn).NodeStream(c.streamCtx) + c.stream, err = stream.NewGorumsClient(c.conn).NodeStream(c.streamCtx) return err } // getStream returns the current stream, or nil if no stream is available. -func (c *channel) getStream() ordering.Gorums_NodeStreamClient { +func (c *channel) getStream() stream.Gorums_NodeStreamClient { c.streamMut.Lock() defer c.streamMut.Unlock() return c.stream @@ -180,7 +180,7 @@ func (c *channel) isConnected() bool { func (c *channel) enqueue(req request) { if req.responseChan != nil { req.sendTime = time.Now() - msgID := req.md.GetMessageSeqNo() + msgID := req.msg.GetMessageSeqNo() c.responseMut.Lock() c.responseRouters[msgID] = req c.responseMut.Unlock() @@ -190,7 +190,7 @@ func (c *channel) enqueue(req request) { select { case <-c.connCtx.Done(): // the node's close() method was called: respond with error instead of enqueueing - c.routeResponse(req.md.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: nodeClosedErr}) + c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: nodeClosedErr}) return case c.sendQ <- req: // enqueued successfully @@ -252,11 +252,11 @@ func (c *channel) sender() { // take next request from sendQ } if err := c.ensureStream(); err != nil { - c.routeResponse(req.md.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: err}) + c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: err}) continue } if err := c.sendMsg(req); err != nil { - c.routeResponse(req.md.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: err}) + c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: err}) } } } @@ -279,18 +279,18 @@ func (c *channel) receiver() { } } - md, e := stream.Recv() + respMsg, e := stream.Recv() if e != nil { c.setLastErr(e) c.cancelPendingMsgs(e) c.clearStream() } else { - err := status.FromProto(md.GetStatus()).Err() + err := status.FromProto(respMsg.GetStatus()).Err() var resp msg if err == nil { - resp, err = unmarshalResponse(md) + resp, err = unmarshalResponse(respMsg) } - c.routeResponse(md.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Value: resp, Err: err}) + c.routeResponse(respMsg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Value: resp, Err: err}) } select { @@ -320,7 +320,7 @@ func (c *channel) sendMsg(req request) (err error) { // wait for actual server responses, so waitSendDone is false for them. if req.waitSendDone && err == nil { // Send succeeded: unblock the caller and clean up the responseRouter - c.routeResponse(req.md.GetMessageSeqNo(), NodeResponse[msg]{}) + c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{}) } }() @@ -357,7 +357,7 @@ func (c *channel) sendMsg(req request) (err error) { } }() - if err = stream.Send(req.md); err != nil { + if err = stream.Send(req.msg); err != nil { c.setLastErr(err) c.clearStream() } diff --git a/channel_test.go b/channel_test.go index 50ae3283..a73e898a 100644 --- a/channel_test.go +++ b/channel_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/relab/gorums/internal/testutils/mock" + "github.com/relab/gorums/stream" "google.golang.org/grpc" pb "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -48,7 +49,10 @@ func delayServerFn(delay time.Duration) func(_ int) ServerIface { time.Sleep(delay) req := AsProto[*pb.StringValue](in) resp, err := mockSrv.Test(ctx, req) - return NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return NewResponseMessage(in, resp), nil }) return srv } @@ -59,11 +63,11 @@ func sendRequest(t testing.TB, node *Node, req request, msgID uint64) NodeRespon if req.ctx == nil { req.ctx = t.Context() } - md, err := marshalRequest(req.ctx, msgID, mock.TestMethod, nil) + reqMsg, err := stream.NewMessage(req.ctx, msgID, mock.TestMethod, nil) if err != nil { - t.Fatalf("marshalRequest failed: %v", err) + t.Fatalf("NewMessage failed: %v", err) } - req.md = md + req.msg = reqMsg replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -597,8 +601,8 @@ func TestChannelDeadlock(t *testing.T) { for id := range 10 { go func() { ctx := TestContext(t, 3*time.Second) - md, _ := marshalRequest(ctx, uint64(100+id), mock.TestMethod, nil) - req := request{ctx: ctx, md: md} + reqMsg, _ := stream.NewMessage(ctx, uint64(100+id), mock.TestMethod, nil) + req := request{ctx: ctx, msg: reqMsg} // try to enqueue select { @@ -803,8 +807,8 @@ func BenchmarkChannelStreamReadyFirstRequest(b *testing.B) { // Use a fresh context for the benchmark request ctx := TestContext(b, defaultTestTimeout) - md, _ := marshalRequest(ctx, 1, mock.TestMethod, nil) - req := request{ctx: ctx, md: md} + reqMsg, _ := stream.NewMessage(ctx, 1, mock.TestMethod, nil) + req := request{ctx: ctx, msg: reqMsg} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -853,8 +857,8 @@ func BenchmarkChannelStreamReadyReconnect(b *testing.B) { // Establish initial stream with a fresh context ctx := context.Background() - md, _ := marshalRequest(ctx, 0, mock.TestMethod, nil) - req := request{ctx: ctx, md: md} + reqMsg, _ := stream.NewMessage(ctx, 0, mock.TestMethod, nil) + req := request{ctx: ctx, msg: reqMsg} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -879,8 +883,8 @@ func BenchmarkChannelStreamReadyReconnect(b *testing.B) { // Now send a request which will trigger ensureStream -> newNodeStream -> signal ctx := context.Background() - md, _ := marshalRequest(ctx, uint64(i+1), mock.TestMethod, nil) - req := request{ctx: ctx, md: md} + reqMsg, _ := stream.NewMessage(ctx, uint64(i+1), mock.TestMethod, nil) + req := request{ctx: ctx, msg: reqMsg} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) diff --git a/client_interceptor.go b/client_interceptor.go index faffbee3..949151c8 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -1,11 +1,12 @@ package gorums import ( + "cmp" "context" "slices" "sync" - "github.com/relab/gorums/ordering" + "github.com/relab/gorums/stream" "google.golang.org/protobuf/proto" ) @@ -42,7 +43,7 @@ type ClientCtx[Req, Resp msg] struct { config Configuration request Req method string - md *ordering.Metadata + msgID uint64 replyChan chan NodeResponse[msg] // reqTransforms holds request transformation functions registered by interceptors. @@ -117,8 +118,10 @@ func (b *clientCtxBuilder[Req, Resp]) WithWaitSendDone(waitSendDone bool) *clien // Build finalizes the ClientCtx configuration and returns the constructed instance. // It creates the metadata and reply channel, and sets up the appropriate response iterator. func (b *clientCtxBuilder[Req, Resp]) Build() *ClientCtx[Req, Resp] { - // Create metadata (without message) and reply channel at build time - b.c.md, _ = ordering.NewMetadata(b.c.Context, b.c.config.nextMsgID(), b.c.method, nil) + // Assign a unique message ID and create the reply channel at build time. + // The stream.Message is created lazily in applyTransforms, where the + // request payload is marshaled together with the metadata. + b.c.msgID = b.c.config.nextMsgID() b.c.replyChan = make(chan NodeResponse[msg], b.c.config.Size()*b.chanMultiplier) if b.c.streaming { @@ -170,23 +173,6 @@ func (c *ClientCtx[Req, Resp]) Size() int { return c.config.Size() } -// applyTransforms returns the transformed request as a proto.Message, or nil if the result is -// invalid or the node should be skipped. It applies the registered transformation functions to -// the given request for the specified node. Transformation functions are applied in the order -// they were registered. -func (c *ClientCtx[Req, Resp]) applyTransforms(req Req, node *Node) proto.Message { - result := req - for _, transform := range c.reqTransforms { - result = transform(result, node) - } - if protoMsg, ok := any(result).(proto.Message); ok { - if protoMsg.ProtoReflect().IsValid() { - return protoMsg - } - } - return nil -} - // applyInterceptors chains the given interceptors, wrapping the response sequence. // Each interceptor receives the current response sequence and returns a new one. // Interceptors are applied in order, with each wrapping the previous result. @@ -204,24 +190,32 @@ func (c *ClientCtx[Req, Resp]) applyInterceptors(interceptors []any) { // (nodes may be skipped if a transformation returns nil). func (c *ClientCtx[Req, Resp]) send() { var expected int + + // Fast path: marshal once when no per-node transforms are registered. + var sharedMsg *stream.Message + if len(c.reqTransforms) == 0 { + var err error + sharedMsg, err = stream.NewMessage(c.Context, c.msgID, c.method, c.request) + if err != nil { + // Marshaling fails identically for all nodes; report and return. + for _, n := range c.config { + c.replyChan <- NodeResponse[msg]{NodeID: n.ID(), Err: err} + expected++ + } + c.expectedReplies = expected + return + } + } for _, n := range c.config { - msg := c.applyTransforms(c.request, n) - if msg == nil { - continue // Skip node if transformation returns nil + // transform only if there are registered transforms; otherwise reuse the shared message + streamMsg := cmp.Or(sharedMsg, c.transformAndMarshal(n)) + if streamMsg == nil { + continue // Skip node } expected++ - // Clone metadata for each request to avoid race conditions during - // concurrent marshaling when SetMessageData is called. - md := proto.CloneOf(c.md) - // Marshal the proto message into the metadata's message_data field - msgData, err := proto.Marshal(msg) - if err != nil { - continue // Skip node if marshaling fails - } - md.SetMessageData(msgData) n.channel.enqueue(request{ ctx: c.Context, - md: md, + msg: streamMsg, streaming: c.streaming, waitSendDone: c.waitSendDone, responseChan: c.replyChan, @@ -230,6 +224,26 @@ func (c *ClientCtx[Req, Resp]) send() { c.expectedReplies = expected } +// transformAndMarshal applies transformations to the request for the given node, +// then marshals it into a stream.Message. Returns nil if transformation fails +// or marshaling fails (in which case the error is sent on replyChan). +func (c *ClientCtx[Req, Resp]) transformAndMarshal(n *Node) *stream.Message { + result := c.request + for _, transform := range c.reqTransforms { + result = transform(result, n) + } + // Check if the result is valid + if protoReq, ok := any(result).(proto.Message); !ok || !protoReq.ProtoReflect().IsValid() { + return nil + } + streamMsg, err := stream.NewMessage(c.Context, c.msgID, c.method, result) + if err != nil { + c.replyChan <- NodeResponse[msg]{NodeID: n.ID(), Err: err} + return nil + } + return streamMsg +} + // defaultResponseSeq returns an iterator that yields at most c.expectedReplies responses // from nodes until the context is canceled or all expected responses are received. func (c *ClientCtx[Req, Resp]) defaultResponseSeq() ResponseSeq[Resp] { diff --git a/client_interceptor_test.go b/client_interceptor_test.go index 9755da86..9bc7513a 100644 --- a/client_interceptor_test.go +++ b/client_interceptor_test.go @@ -1,6 +1,7 @@ package gorums_test import ( + "fmt" "testing" "time" @@ -179,3 +180,73 @@ func TestCustomInterceptorWithMapRequest(t *testing.T) { t.Errorf("Expected 3 responses counted, got %d", count) } } + +// BenchmarkQuorumCallMapRequest compares the performance of quorum calls +// with and without per-node MapRequest interceptors. +// Without MapRequest, the request is marshaled once and shared across all nodes. +// With MapRequest, each node gets a cloned message with individually marshaled payload. +func BenchmarkQuorumCallMapRequest(b *testing.B) { + for _, numNodes := range []int{3, 7, 13} { + config := gorums.TestConfiguration(b, numNodes, gorums.EchoServerFn) + cfgCtx := config.Context(b.Context()) + + // Baseline: no interceptors — single marshal, shared message + b.Run(fmt.Sprintf("NoTransform/%d", numNodes), func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + cfgCtx, + pb.String("benchmark payload"), + mock.TestMethod, + ) + if _, err := responses.Majority(); err != nil { + b.Fatal(err) + } + } + }) + + // With MapRequest identity transform — per-node clone + marshal + b.Run(fmt.Sprintf("MapRequestIdentity/%d", numNodes), func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + cfgCtx, + pb.String("benchmark payload"), + mock.TestMethod, + gorums.Interceptors( + gorums.MapRequest[*pb.StringValue, *pb.StringValue]( + func(req *pb.StringValue, _ *gorums.Node) *pb.StringValue { + return req + }, + ), + ), + ) + if _, err := responses.Majority(); err != nil { + b.Fatal(err) + } + } + }) + + // With MapRequest that modifies the request per node + b.Run(fmt.Sprintf("MapRequestPerNode/%d", numNodes), func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + cfgCtx, + pb.String("benchmark payload"), + mock.TestMethod, + gorums.Interceptors( + gorums.MapRequest[*pb.StringValue, *pb.StringValue]( + func(req *pb.StringValue, n *gorums.Node) *pb.StringValue { + return pb.String(fmt.Sprintf("%s-node-%d", req.GetValue(), n.ID())) + }, + ), + ), + ) + if _, err := responses.Majority(); err != nil { + b.Fatal(err) + } + } + }) + } +} diff --git a/cmd/protoc-gen-gorums/dev/generated_code_test.go b/cmd/protoc-gen-gorums/dev/generated_code_test.go index d0d52e5b..5c53cd7a 100644 --- a/cmd/protoc-gen-gorums/dev/generated_code_test.go +++ b/cmd/protoc-gen-gorums/dev/generated_code_test.go @@ -25,7 +25,7 @@ func quorumCallServer(_ int) gorums.ServerIface { req := gorums.AsProto[*dev.Request](in) resp := &dev.Response{} resp.SetResult(int64(len(req.GetValue()))) - return gorums.NewResponseMessage(in.GetMetadata(), resp), nil + return gorums.NewResponseMessage(in, resp), nil }) return srv } diff --git a/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go b/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go index becfa5d8..b755aaf7 100644 --- a/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go +++ b/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go @@ -29,7 +29,7 @@ var importMap = map[string]protogen.GoImportPath{ "backoff": protogen.GoImportPath("google.golang.org/grpc/backoff"), "proto": protogen.GoImportPath("google.golang.org/protobuf/proto"), "gorums": protogen.GoImportPath("github.com/relab/gorums"), - "ordering": protogen.GoImportPath("github.com/relab/gorums/ordering"), + "stream": protogen.GoImportPath("github.com/relab/gorums/stream"), "protoreflect": protogen.GoImportPath("google.golang.org/protobuf/reflect/protoreflect"), } diff --git a/cmd/protoc-gen-gorums/gengorums/template_server.go b/cmd/protoc-gen-gorums/gengorums/template_server.go index 17aee759..39f30d6c 100644 --- a/cmd/protoc-gen-gorums/gengorums/template_server.go +++ b/cmd/protoc-gen-gorums/gengorums/template_server.go @@ -39,14 +39,16 @@ func Register{{$service}}Server(srv *{{use "gorums.Server" $genFile}}, impl {{$s return nil, nil {{- else if isStreamingServer .}} err := impl.{{.GoName}}(ctx, req, func(resp *{{out $genFile .}}) error { - // create a copy of the metadata, to avoid a data race between NewResponseMessage and SendMsg - md := {{use "proto.CloneOf" $genFile}}(in.GetMetadata()) - return ctx.SendMessage({{$newMessage}}(md, resp)) + out := {{$newMessage}}(in, resp) + return ctx.SendMessage(out) }) return nil, err {{- else }} resp, err := impl.{{.GoName}}(ctx, req) - return {{$newMessage}}(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return {{$newMessage}}(in, resp), nil {{- end}} }) {{- end}} diff --git a/encoding.go b/encoding.go index c7d8289d..377d4bb3 100644 --- a/encoding.go +++ b/encoding.go @@ -1,10 +1,10 @@ package gorums import ( - "context" + "cmp" "fmt" - "github.com/relab/gorums/ordering" + "github.com/relab/gorums/stream" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" @@ -12,134 +12,85 @@ import ( "google.golang.org/protobuf/reflect/protoregistry" ) -// Message encapsulates a protobuf message and metadata. -// -// This struct should be used by generated code only. +// Message encapsulates the stream.Message and the actual proto.Message. type Message struct { - metadata *ordering.Metadata - message proto.Message + Msg proto.Message + *stream.Message } -// NewResponseMessage creates a new Gorums Message for the given metadata and response message. -// -// This function should be used by generated code only. -func NewResponseMessage(md *ordering.Metadata, resp proto.Message) *Message { - return &Message{metadata: md, message: resp} -} - -// AsProto returns msg's underlying protobuf message of the specified type T. -// If msg is nil or the contained message is not of type T, the zero value of T is returned. -func AsProto[T proto.Message](msg *Message) T { - var zero T - if msg == nil || msg.message == nil { - return zero - } - if req, ok := msg.message.(T); ok { - return req - } - return zero -} +// MetadataEntry is a type alias for stream.MetadataEntry. +type MetadataEntry = stream.MetadataEntry -// GetProtoMessage returns the protobuf message contained in the Message. -func (m *Message) GetProtoMessage() proto.Message { - if m == nil { - return nil - } - return m.message -} +// MetadataEntry_builder is a type alias for stream.MetadataEntry_builder. +type MetadataEntry_builder = stream.MetadataEntry_builder -// GetMetadata returns the metadata of the message. -func (m *Message) GetMetadata() *ordering.Metadata { - if m == nil { +// NewResponseMessage creates a new response message based on the provided proto +// message. The response message includes the message ID and method from the request +// message to facilitate routing the response back to the caller on the client side. +// The payload, error status, and metadata entries are left empty; the error status +// of the response message can be set using [messageWithError], and the payload will +// be marshaled by [ServerCtx.SendMessage]. This function is safe for concurrent use. +// +// This function should only be used in generated code. +func NewResponseMessage(in *Message, resp proto.Message) *Message { + if in == nil { return nil } - return m.metadata -} - -// GetMethod returns the method name from the message metadata. -func (m *Message) GetMethod() string { - if m == nil { - return "nil" + // Create a new stream.Message to avoid race conditions when the sender + // goroutine marshals while the handler creates the next response. + // This can happen for stream-based quorum calls where the handler can + // call SendMessage multiple times before returning. + msgBuilder := stream.Message_builder{ + MessageSeqNo: in.GetMessageSeqNo(), // needed in channel.routeResponse to lookup the response channel + Method: in.GetMethod(), // needed in unmarshalResponse to look up the response type in the proto registry + // Payload is left empty; SendMessage will marshal resp into the payload when sending the message + // Status is left empty; it can be set by messageWithError if needed } - return m.metadata.GetMethod() -} - -// GetMessageID returns the message ID from the message metadata. -func (m *Message) GetMessageID() uint64 { - if m == nil { - return 0 + return &Message{ + Msg: resp, + Message: msgBuilder.Build(), } - return m.metadata.GetMessageSeqNo() } -func (m *Message) GetStatus() *status.Status { - if m == nil { - return status.New(codes.Unknown, "nil message") +// AsProto extracts the payload from the message. +// If msg is nil or invalid, the zero value of T is returned. +// +// This function should only be used in generated code. +func AsProto[T proto.Message](msg *Message) T { + var zero T + if msg == nil || msg.Msg == nil { + return zero } - return status.FromProto(m.metadata.GetStatus()) -} - -// setError sets the error status in the message metadata in preparation for sending -// the response to the client. The provided error may include several wrapped errors. -// If err is nil, the status is set to OK. -// This method should be called just prior to sending the response to the client. -func (m *Message) setError(err error) { - errStatus, ok := status.FromError(err) - if !ok { - errStatus = status.New(codes.Unknown, err.Error()) + if req, ok := msg.Msg.(T); ok { + return req } - m.metadata.SetStatus(errStatus.Proto()) + return zero } -// responseWithError ensures a response message exists and sets the error status. -// If msg is nil, a new response message is created using the provided metadata. +// messageWithError ensures a response message exists and sets the error status. +// If out is nil, the in message (request) is reused to return the error status. // This is used by the server to send error responses back to the client. -func responseWithError(msg *Message, md *ordering.Metadata, err error) *Message { - if msg == nil { - msg = NewResponseMessage(md, nil) - } - msg.setError(err) - return msg -} - -// marshalRequest marshals the request proto message into metadata for type-safe Send. -// It marshals the proto message into the metadata's message_data field. -// -// This function should be used by client-side operations only. -func marshalRequest(ctx context.Context, msgID uint64, method string, req proto.Message) (*ordering.Metadata, error) { - return ordering.NewMetadata(ctx, msgID, method, req) -} - -// marshalResponse marshals the response message into metadata for type-safe Send. -// It clones the metadata to avoid race conditions with concurrent send operations. -// -// This function should be used by server-side operations only. -func marshalResponse(msg *Message) (*ordering.Metadata, error) { - if msg == nil { - return nil, nil - } - // Clone metadata to avoid race with concurrent send operations - md := proto.CloneOf(msg.metadata) - if msg.message != nil { - msgData, err := proto.Marshal(msg.message) - if err != nil { - return nil, err +func messageWithError(in, out *Message, err error) *Message { + msg := cmp.Or(out, in) + if err != nil { + errStatus, ok := status.FromError(err) + if !ok { + errStatus = status.New(codes.Unknown, err.Error()) } - md.SetMessageData(msgData) + msg.SetStatus(errStatus.Proto()) } - return md, nil + return msg } -// unmarshalRequest unmarshals the request proto message from metadata. -// It uses the method name in metadata to look up the Input type from the proto registry. -// Returns a *Message suitable for passing to handlers. +// unmarshalRequest unmarshals the request proto message from the message. +// It uses the method name in the message to look up the Input type from the proto registry. // -// This function should be used by server-side operations only. -func unmarshalRequest(md *ordering.Metadata) (*Message, error) { +// This function should only be used by internal channel operations. +func unmarshalRequest(in *stream.Message) (proto.Message, error) { // get method descriptor from registry - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(md.GetMethod())) + desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(in.GetMethod())) if err != nil { - return nil, fmt.Errorf("gorums: could not find method descriptor for %s", md.GetMethod()) + return nil, fmt.Errorf("gorums: could not find method descriptor for %s", in.GetMethod()) } methodDesc := desc.(protoreflect.MethodDescriptor) @@ -150,25 +101,25 @@ func unmarshalRequest(md *ordering.Metadata) (*Message, error) { } req := msgType.New().Interface() - // unmarshal message from metadata.message_data - msgData := md.GetMessageData() - if len(msgData) > 0 { - if err := proto.Unmarshal(msgData, req); err != nil { + // unmarshal message from the Message.Payload field + payload := in.GetPayload() + if len(payload) > 0 { + if err := proto.Unmarshal(payload, req); err != nil { return nil, fmt.Errorf("gorums: could not unmarshal request: %w", err) } } - return &Message{metadata: md, message: req}, nil + return req, nil } -// unmarshalResponse unmarshals the response proto message from metadata. -// It uses the method name in metadata to look up the Output type from the proto registry. +// unmarshalResponse unmarshals the response proto message from the message. +// It uses the method name in the message to look up the Output type from the proto registry. // -// This function should be used by internal channel operations only. -func unmarshalResponse(md *ordering.Metadata) (proto.Message, error) { +// This function should only be used by internal channel operations. +func unmarshalResponse(out *stream.Message) (proto.Message, error) { // get method descriptor from registry - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(md.GetMethod())) + desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(out.GetMethod())) if err != nil { - return nil, fmt.Errorf("gorums: could not find method descriptor for %s", md.GetMethod()) + return nil, fmt.Errorf("gorums: could not find method descriptor for %s", out.GetMethod()) } methodDesc := desc.(protoreflect.MethodDescriptor) @@ -179,10 +130,10 @@ func unmarshalResponse(md *ordering.Metadata) (proto.Message, error) { } resp := msgType.New().Interface() - // unmarshal message from metadata.message_data - msgData := md.GetMessageData() - if len(msgData) > 0 { - if err := proto.Unmarshal(msgData, resp); err != nil { + // unmarshal message from the Message.Payload field + payload := out.GetPayload() + if len(payload) > 0 { + if err := proto.Unmarshal(payload, resp); err != nil { return nil, fmt.Errorf("gorums: could not unmarshal response: %w", err) } } diff --git a/encoding_test.go b/encoding_test.go index b17bdced..b3e76585 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -3,10 +3,97 @@ package gorums_test import ( "testing" + "github.com/google/go-cmp/cmp" "github.com/relab/gorums" "github.com/relab/gorums/internal/tests/config" + "github.com/relab/gorums/stream" + "google.golang.org/protobuf/testing/protocmp" ) +func TestNewResponseMessage(t *testing.T) { + req := config.Request_builder{Num: 99}.Build() + resp := config.Response_builder{Name: "test", Num: 42}.Build() + + streamIn := stream.Message_builder{ + MessageSeqNo: 100, + Method: "/pkg.Svc/Call", + Payload: []byte("request payload"), + Entry: []*stream.MetadataEntry{stream.MetadataEntry_builder{Key: "key1", Value: "val1"}.Build()}, + }.Build() + streamOut := stream.Message_builder{MessageSeqNo: 100, Method: "/pkg.Svc/Call"}.Build() + + tests := []struct { + name string + in *gorums.Message + resp *config.Response + want *gorums.Message + }{ + { + name: "NilIn/NilResp/NilOut", + in: nil, + resp: nil, + want: nil, + }, + { + name: "NilIn/Resp/NilOut", + in: nil, + resp: resp, + want: nil, + }, + { + name: "NilReq/NilResp/StreamIn/StreamOut", + in: &gorums.Message{Msg: nil, Message: streamIn}, + resp: nil, + want: &gorums.Message{Msg: (*config.Response)(nil), Message: streamOut}, + }, + { + name: "NilReq/Resp/StreamIn/StreamOut", + in: &gorums.Message{Msg: nil, Message: streamIn}, + resp: resp, + want: &gorums.Message{Msg: resp, Message: streamOut}, + }, + { + name: "Req/NilResp/StreamIn/StreamOut", + in: &gorums.Message{Msg: req, Message: streamIn}, + resp: nil, + want: &gorums.Message{Msg: (*config.Response)(nil), Message: streamOut}, + }, + { + name: "Req/Resp/StreamIn/StreamOut", + in: &gorums.Message{Msg: req, Message: streamIn}, + resp: resp, + want: &gorums.Message{Msg: resp, Message: streamOut}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := gorums.NewResponseMessage(tt.in, tt.resp) + if tt.want == nil { + if got != nil { + t.Errorf("NewResponseMessage returned %v, want nil", got) + } + return + } + if got == nil { + t.Fatalf("NewResponseMessage returned nil, want non-nil") + } + // Compare Msg field - handle nil specially + if (tt.want.Msg == nil) != (got.Msg == nil) { + t.Errorf("Msg field: want nil=%v, got nil=%v", tt.want.Msg == nil, got.Msg == nil) + } else if tt.want.Msg != nil && got.Msg != nil { + if diff := cmp.Diff(tt.want.Msg, got.Msg, protocmp.Transform()); diff != "" { + t.Errorf("Msg field mismatch (-want, +got):\n%s", diff) + } + } + // Compare the stream.Message field + if diff := cmp.Diff(tt.want.Message, got.Message, protocmp.Transform()); diff != "" { + t.Errorf("Message field mismatch (-want, +got):\n%s", diff) + } + }) + } +} + func TestAsProto(t *testing.T) { t.Parallel() @@ -18,7 +105,7 @@ func TestAsProto(t *testing.T) { }{ { name: "Success", - msg: gorums.NewResponseMessage(nil, config.Response_builder{Name: "test", Num: 42}.Build()), + msg: gorums.NewResponseMessage(&gorums.Message{}, config.Response_builder{Name: "test", Num: 42}.Build()), wantNil: false, wantNum: 42, }, @@ -29,7 +116,7 @@ func TestAsProto(t *testing.T) { }, { name: "WrongType", - msg: gorums.NewResponseMessage(nil, config.Request_builder{Num: 99}.Build()), + msg: gorums.NewResponseMessage(&gorums.Message{}, config.Request_builder{Num: 99}.Build()), wantNil: true, }, } @@ -40,15 +127,15 @@ func TestAsProto(t *testing.T) { req := gorums.AsProto[*config.Response](tc.msg) if tc.wantNil { if req != nil { - t.Errorf("AsProto returned %v, want nil", req) + t.Errorf("AsProto(%v) returned %v, want nil", tc.msg, req) } return } if req == nil { - t.Errorf("AsProto returned nil, want *config.Response") + t.Errorf("AsProto(%v) returned nil, want *config.Response", tc.msg) } if got := req.GetNum(); got != tc.wantNum { - t.Errorf("Num = %d, want %d", got, tc.wantNum) + t.Errorf("Num() = %d, want %d", got, tc.wantNum) } }) } diff --git a/examples/interceptors/server_interceptors.go b/examples/interceptors/server_interceptors.go index b5a40243..2a001e9e 100644 --- a/examples/interceptors/server_interceptors.go +++ b/examples/interceptors/server_interceptors.go @@ -6,30 +6,34 @@ import ( "time" "github.com/relab/gorums" - "github.com/relab/gorums/ordering" "google.golang.org/grpc/peer" + "google.golang.org/protobuf/proto" ) func LoggingInterceptor(addr string) gorums.Interceptor { - return func(ctx gorums.ServerCtx, msg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { - log.Printf("[%s]: LoggingInterceptor(incoming): Method=%s, Message=%s", addr, msg.GetMethod(), msg.GetProtoMessage()) + return func(ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { + req := gorums.AsProto[proto.Message](in) + log.Printf("[%s]: LoggingInterceptor(incoming): Method=%s, Message=%s", addr, in.GetMethod(), req) start := time.Now() - resp, err := next(ctx, msg) + out, err := next(ctx, in) duration := time.Since(start) - log.Printf("[%s]: LoggingInterceptor(outgoing): Method=%s, Duration=%s, Err=%v, Message=%v, Type=%T", addr, msg.GetMethod(), duration, err, resp.GetProtoMessage(), resp.GetProtoMessage()) - return resp, err + resp := gorums.AsProto[proto.Message](out) + log.Printf("[%s]: LoggingInterceptor(outgoing): Method=%s, Duration=%s, Err=%v, Message=%v, Type=%T", addr, in.GetMethod(), duration, err, resp, resp) + return out, err } } -func LoggingSimpleInterceptor(ctx gorums.ServerCtx, msg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { - log.Printf("LoggingSimpleInterceptor(incoming): Method=%s, Message=%v)", msg.GetMethod(), msg.GetProtoMessage()) - resp, err := next(ctx, msg) - log.Printf("LoggingSimpleInterceptor(outgoing): Method=%s, Err=%v, Message=%v", msg.GetMethod(), err, resp.GetProtoMessage()) - return resp, err +func LoggingSimpleInterceptor(ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { + req := gorums.AsProto[proto.Message](in) + log.Printf("LoggingSimpleInterceptor(incoming): Method=%s, Message=%v)", in.GetMethod(), req) + out, err := next(ctx, in) + resp := gorums.AsProto[proto.Message](out) + log.Printf("LoggingSimpleInterceptor(outgoing): Method=%s, Err=%v, Message=%v", in.GetMethod(), err, resp) + return out, err } -func DelayedInterceptor(ctx gorums.ServerCtx, msg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { +func DelayedInterceptor(ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { // delay based on sending node address delay := 0 * time.Millisecond peer, ok := peer.FromContext(ctx) @@ -45,35 +49,35 @@ func DelayedInterceptor(ctx gorums.ServerCtx, msg *gorums.Message, next gorums.H time.Sleep(delay) // Call the next handler in the chain - resp, err := next(ctx, msg) + out, err := next(ctx, in) log.Printf("DelayedInterceptor: Finished processing message after %s", delay) - return resp, err + return out, err } /** NoFooAllowedInterceptor rejects requests for messages with key "foo". */ -func NoFooAllowedInterceptor[T interface{ GetKey() string }](ctx gorums.ServerCtx, msg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { - if req, ok := msg.GetProtoMessage().(T); ok { +func NoFooAllowedInterceptor[T interface{ GetKey() string }](ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { + if req, ok := gorums.AsProto[proto.Message](in).(T); ok { log.Printf("NoFooAllowedInterceptor: Received request for key '%s'", req.GetKey()) if req.GetKey() == "foo" { log.Printf("NoFooAllowedInterceptor: Rejecting request for key 'foo'") return nil, fmt.Errorf("requests for key 'foo' are not allowed") } } - return next(ctx, msg) + return next(ctx, in) } -func MetadataInterceptor(ctx gorums.ServerCtx, msg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { +func MetadataInterceptor(ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { log.Printf("MetadataInterceptor: Adding custom metadata to message(customKey=customValue)") // Add a custom metadata field - entry := ordering.MetadataEntry_builder{ + entry := gorums.MetadataEntry_builder{ Key: "customKey", Value: "customValue", }.Build() - msg.GetMetadata().SetEntry([]*ordering.MetadataEntry{ + in.SetEntry([]*gorums.MetadataEntry{ entry, }) // Call the next handler in the chain - resp, err := next(ctx, msg) + out, err := next(ctx, in) log.Printf("MetadataInterceptor: Finished processing message with custom metadata") - return resp, err + return out, err } diff --git a/quorumcall.go b/quorumcall.go index afbc91ac..80d168ee 100644 --- a/quorumcall.go +++ b/quorumcall.go @@ -15,7 +15,7 @@ package gorums // or iterator method (like Seq) is called, applying any registered request transformations. // This lazy sending is necessary to allow interceptors to register transformations prior to dispatch. // -// This function should be used by generated code only. +// This function should only be used by generated code. func QuorumCall[Req, Resp msg]( ctx *ConfigContext, req Req, @@ -31,7 +31,7 @@ func QuorumCall[Req, Resp msg]( // In streaming mode, the response iterator continues indefinitely until the context // is canceled, allowing the server to send multiple responses over time. // -// This function should be used by generated code only. +// This function should only be used by generated code. func QuorumCallStream[Req, Resp msg]( ctx *ConfigContext, req Req, diff --git a/rpc.go b/rpc.go index a6da3cf3..7b2441c0 100644 --- a/rpc.go +++ b/rpc.go @@ -1,16 +1,18 @@ package gorums +import "github.com/relab/gorums/stream" + // RPCCall executes a remote procedure call on the node. // // This method should be used by generated code only. func RPCCall[Req, Resp msg](ctx *NodeContext, req Req, method string) (Resp, error) { replyChan := make(chan NodeResponse[msg], 1) - md, err := marshalRequest(ctx, ctx.nextMsgID(), method, req) + reqMsg, err := stream.NewMessage(ctx, ctx.nextMsgID(), method, req) if err != nil { var zero Resp return zero, err } - ctx.enqueue(request{ctx: ctx, md: md, responseChan: replyChan}) + ctx.enqueue(request{ctx: ctx, msg: reqMsg, responseChan: replyChan}) select { case r := <-replyChan: diff --git a/server.go b/server.go index 979a49b0..7c733170 100644 --- a/server.go +++ b/server.go @@ -5,8 +5,9 @@ import ( "net" "sync" - "github.com/relab/gorums/ordering" + "github.com/relab/gorums/stream" "google.golang.org/grpc" + "google.golang.org/protobuf/proto" ) type ( @@ -15,18 +16,18 @@ type ( // a Handler representing the next element in the chain (either another // Interceptor or the actual server method). It returns a Message and an error. Interceptor func(ServerCtx, *Message, Handler) (*Message, error) - // Handler is a function that processes a request message and returns a response message. + // Handler is a function that processes a request and returns a response. Handler func(ServerCtx, *Message) (*Message, error) ) -type orderingServer struct { +type streamServer struct { handlers map[string]Handler opts *serverOptions - ordering.UnimplementedGorumsServer + stream.UnimplementedGorumsServer } -func newOrderingServer(opts *serverOptions) *orderingServer { - return &orderingServer{ +func newStreamServer(opts *serverOptions) *streamServer { + return &streamServer{ handlers: make(map[string]Handler), opts: opts, } @@ -34,9 +35,9 @@ func newOrderingServer(opts *serverOptions) *orderingServer { // NodeStream handles a connection to a single client. The stream is aborted if there // is any error with sending or receiving. -func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error { +func (s *streamServer) NodeStream(srv stream.Gorums_NodeStreamServer) error { var mut sync.Mutex // used to achieve mutex between request handlers - finished := make(chan *ordering.Metadata, s.opts.buffer) + finished := make(chan *stream.Message, s.opts.buffer) ctx := srv.Context() if s.opts.connectCallback != nil { @@ -48,8 +49,8 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error select { case <-ctx.Done(): return - case md := <-finished: - if err := srv.Send(md); err != nil { + case streamOut := <-finished: + if err := srv.Send(streamOut); err != nil { return } } @@ -61,11 +62,11 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error defer mut.Unlock() for { - md, err := srv.Recv() + streamIn, err := srv.Recv() if err != nil { return err } - if handler, ok := s.handlers[md.GetMethod()]; ok { + if handler, ok := s.handlers[streamIn.GetMethod()]; ok { // We start the handler in a new goroutine in order to allow multiple handlers to run concurrently. // However, to preserve request ordering, the handler must unlock the shared mutex when it has either // finished, or when it is safe to start processing the next request. @@ -73,23 +74,24 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error // This func() is the default interceptor; it is the first and last handler in the chain. // It is responsible for releasing the mutex when the handler chain is done. go func() { - srvCtx := newServerCtx(md.AppendToIncomingContext(ctx), &mut, finished) + srvCtx := newServerCtx(streamIn.AppendToIncomingContext(ctx), &mut, finished) defer srvCtx.Release() - req, err := unmarshalRequest(md) + msg, err := unmarshalRequest(streamIn) + in := &Message{Msg: msg, Message: streamIn} if err != nil { - _ = srvCtx.SendMessage(responseWithError(nil, md, err)) + _ = srvCtx.SendMessage(messageWithError(in, nil, err)) return } - message, err := handler(srvCtx, req) - // If there is no message and no error, we do not send anything back to the client. + out, err := handler(srvCtx, in) + // If there is no response and no error, we do not send anything back to the client. // This corresponds to a unidirectional message from client to server, where clients // are not expected to receive a response. - if message == nil && err == nil { + if out == nil && err == nil { return } - _ = srvCtx.SendMessage(responseWithError(message, md, err)) + _ = srvCtx.SendMessage(messageWithError(in, out, err)) // We ignore the error from SendMessage here; it means that the stream is closed. // The for-loop above will exit on the next Recv call. }() @@ -158,8 +160,8 @@ func chainInterceptors(final Handler, interceptors ...Interceptor) Handler { for i := len(interceptors) - 1; i >= 0; i-- { curr := interceptors[i] next := handler - handler = func(ctx ServerCtx, msg *Message) (*Message, error) { - return curr(ctx, msg, next) + handler = func(ctx ServerCtx, in *Message) (*Message, error) { + return curr(ctx, in, next) } } return handler @@ -167,7 +169,7 @@ func chainInterceptors(final Handler, interceptors ...Interceptor) Handler { // Server serves all ordering based RPCs using registered handlers. type Server struct { - srv *orderingServer + srv *streamServer grpcServer *grpc.Server interceptors []Interceptor } @@ -179,11 +181,11 @@ func NewServer(opts ...ServerOption) *Server { opt(&serverOpts) } s := &Server{ - srv: newOrderingServer(&serverOpts), + srv: newStreamServer(&serverOpts), grpcServer: grpc.NewServer(serverOpts.grpcOpts...), interceptors: serverOpts.interceptors, } - ordering.RegisterGorumsServer(s.grpcServer, s.srv) + stream.RegisterGorumsServer(s.grpcServer, s.srv) return s } @@ -216,11 +218,11 @@ type ServerCtx struct { context.Context once *sync.Once // must be a pointer to avoid passing ctx by value mut *sync.Mutex - c chan<- *ordering.Metadata + c chan<- *stream.Message } // newServerCtx creates a new ServerCtx with the given context, mutex and metadata channel. -func newServerCtx(ctx context.Context, mut *sync.Mutex, c chan<- *ordering.Metadata) ServerCtx { +func newServerCtx(ctx context.Context, mut *sync.Mutex, c chan<- *stream.Message) ServerCtx { return ServerCtx{ Context: ctx, once: new(sync.Once), @@ -239,14 +241,19 @@ func (ctx *ServerCtx) Release() { // SendMessage attempts to send the given message to the client. // This may fail if the stream was closed or the stream context got canceled. // -// This function should be used by generated code only. -func (ctx *ServerCtx) SendMessage(msg *Message) error { - md, err := marshalResponse(msg) - if err != nil { - return err +// This function should only be used by generated code. +func (ctx *ServerCtx) SendMessage(out *Message) error { + // If Msg is set, marshal it to payload before sending. + if out.Msg != nil && len(out.GetPayload()) == 0 { + payload, err := proto.Marshal(out.Msg) + if err == nil { + out.SetPayload(payload) + } + // Return an error to the client if marshaling failed on the server side; don't close the stream. + out = messageWithError(nil, out, err) } select { - case ctx.c <- md: + case ctx.c <- out.Message: case <-ctx.Done(): return ctx.Err() } diff --git a/server_test.go b/server_test.go index 5d8064ad..ac449d35 100644 --- a/server_test.go +++ b/server_test.go @@ -9,6 +9,7 @@ import ( "github.com/relab/gorums" "github.com/relab/gorums/internal/testutils/mock" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" pb "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -38,17 +39,41 @@ func TestServerCallback(t *testing.T) { } } -func appendStringInterceptor(in, out string) gorums.Interceptor { - return func(ctx gorums.ServerCtx, inMsg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { - req := gorums.AsProto[*pb.StringValue](inMsg) +func appendStringInterceptor(inStr, outStr string) gorums.Interceptor { + return func(ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { + req := gorums.AsProto[*pb.StringValue](in) // update the underlying request gorums.Message's message field (pb.StringValue in this case) - req.Value += in + req.Value += inStr + + // TODO(meling): I did not like this change; I need to think more about this; + // can we make interceptors modify the request and responses in a cleaner way? + // Maybe we can add a method to stream.Message to modify/marshal the payload? + // Or maybe interceptors can operate on a InterceptorMessage wrapper that can + // hold the proto message that can be manipulated directly, and only on the way + // out of the interceptor chain, the wrapper is marshaled into the stream.Message.Payload. + + // re-marshal the modified request into the in message + reqPayload, err := proto.Marshal(req) + if err != nil { + return nil, err + } + in.SetPayload(reqPayload) + // call the next handler - outMsg, err := next(ctx, inMsg) - resp := gorums.AsProto[*pb.StringValue](outMsg) + out, err := next(ctx, in) + if err != nil { + return nil, err + } + resp := gorums.AsProto[*pb.StringValue](out) // update the underlying response gorums.Message's message field (pb.StringValue in this case) - resp.Value += out - return outMsg, err + resp.Value += outStr + // re-marshal the modified response into the out message + respPayload, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + out.SetPayload(respPayload) + return out, err } } @@ -70,7 +95,10 @@ func TestServerInterceptorsChain(t *testing.T) { s.RegisterHandler(mock.TestMethod, func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*pb.StringValue](in) resp, err := interceptorSrv.Test(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) return s } @@ -97,7 +125,7 @@ func TestTCPReconnection(t *testing.T) { srv := gorums.NewServer() srv.RegisterHandler(mock.TestMethod, func(_ gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*pb.StringValue](in) - return gorums.NewResponseMessage(in.GetMetadata(), req), nil + return gorums.NewResponseMessage(in, req), nil }) lis, err := net.Listen("tcp", "127.0.0.1:0") @@ -153,7 +181,7 @@ func TestTCPReconnection(t *testing.T) { srv2 := gorums.NewServer() srv2.RegisterHandler(mock.TestMethod, func(_ gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*pb.StringValue](in) - return gorums.NewResponseMessage(in.GetMetadata(), req), nil + return gorums.NewResponseMessage(in, req), nil }) go func() { _ = srv2.Serve(lis2) diff --git a/ordering/gorums_metadata.go b/stream/gorums_message.go similarity index 62% rename from ordering/gorums_metadata.go rename to stream/gorums_message.go index 1ad0656b..c1d2b8c7 100644 --- a/ordering/gorums_metadata.go +++ b/stream/gorums_message.go @@ -1,4 +1,4 @@ -package ordering +package stream import ( "context" @@ -7,39 +7,40 @@ import ( "google.golang.org/protobuf/proto" ) -// NewMetadata creates a new [Metadata] proto message for the given method and message ID. -// If a non-nil proto message is provided, it is marshaled and included in the metadata. +// NewMessage creates a new [Message] proto message for the given method and message ID. +// If a non-nil proto message is provided, it is marshaled and included in the message payload. // This function also extracts any client-specific metadata from the context and appends -// it to the metadata, allowing client-specific metadata to be passed to the server. +// it to the message, allowing client-specific metadata to be passed to the server. // // This method is intended for Gorums internal use. -func NewMetadata(ctx context.Context, msgID uint64, method string, msg proto.Message) (*Metadata, error) { +// This function is used on the client-side to create outgoing request messages. +func NewMessage(ctx context.Context, msgID uint64, method string, msg proto.Message) (*Message, error) { // Marshal the message to bytes (nil message returns nil bytes and no error) - msgBytes, err := proto.Marshal(msg) + payload, err := proto.Marshal(msg) if err != nil { return nil, err } - mdBuilder := Metadata_builder{ + msgBuilder := Message_builder{ MessageSeqNo: msgID, Method: method, - MessageData: msgBytes, + Payload: payload, } md, _ := metadata.FromOutgoingContext(ctx) for k, vv := range md { for _, v := range vv { entry := MetadataEntry_builder{Key: k, Value: v}.Build() - mdBuilder.Entry = append(mdBuilder.Entry, entry) + msgBuilder.Entry = append(msgBuilder.Entry, entry) } } - return mdBuilder.Build(), nil + return msgBuilder.Build(), nil } -// AppendToIncomingContext appends client-specific metadata from the [Metadata] proto message +// AppendToIncomingContext appends client-specific metadata from the [Message] proto message // to the incoming gRPC context, allowing server implementations to extract and use said // metadata directly from the server method's context. // // This method is intended for Gorums internal use. -func (x *Metadata) AppendToIncomingContext(ctx context.Context) context.Context { +func (x *Message) AppendToIncomingContext(ctx context.Context) context.Context { existingMD, _ := metadata.FromIncomingContext(ctx) newMD := existingMD.Copy() // copy to avoid mutating the original for _, entry := range x.GetEntry() { diff --git a/ordering/ordering.pb.go b/stream/stream.pb.go similarity index 61% rename from ordering/ordering.pb.go rename to stream/stream.pb.go index 362656ca..213e2419 100644 --- a/ordering/ordering.pb.go +++ b/stream/stream.pb.go @@ -2,9 +2,9 @@ // versions: // protoc-gen-go v1.36.11 // protoc v6.33.4 -// source: ordering/ordering.proto +// source: stream/stream.proto -package ordering +package stream import ( status "google.golang.org/genproto/googleapis/rpc/status" @@ -21,34 +21,34 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// Metadata is sent together with application-specific message types, +// Message is sent together with application-specific message types, // and contains information necessary for Gorums to handle the messages. -type Metadata struct { +type Message struct { state protoimpl.MessageState `protogen:"opaque.v1"` xxx_hidden_MessageSeqNo uint64 `protobuf:"varint,1,opt,name=message_seq_no,json=messageSeqNo"` xxx_hidden_Method string `protobuf:"bytes,2,opt,name=method"` xxx_hidden_Status *status.Status `protobuf:"bytes,3,opt,name=status"` xxx_hidden_Entry *[]*MetadataEntry `protobuf:"bytes,4,rep,name=entry"` - xxx_hidden_MessageData []byte `protobuf:"bytes,5,opt,name=message_data,json=messageData"` + xxx_hidden_Payload []byte `protobuf:"bytes,5,opt,name=payload"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *Metadata) Reset() { - *x = Metadata{} - mi := &file_ordering_ordering_proto_msgTypes[0] +func (x *Message) Reset() { + *x = Message{} + mi := &file_stream_stream_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *Metadata) String() string { +func (x *Message) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Metadata) ProtoMessage() {} +func (*Message) ProtoMessage() {} -func (x *Metadata) ProtoReflect() protoreflect.Message { - mi := &file_ordering_ordering_proto_msgTypes[0] +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_stream_stream_proto_msgTypes[0] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -59,28 +59,28 @@ func (x *Metadata) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -func (x *Metadata) GetMessageSeqNo() uint64 { +func (x *Message) GetMessageSeqNo() uint64 { if x != nil { return x.xxx_hidden_MessageSeqNo } return 0 } -func (x *Metadata) GetMethod() string { +func (x *Message) GetMethod() string { if x != nil { return x.xxx_hidden_Method } return "" } -func (x *Metadata) GetStatus() *status.Status { +func (x *Message) GetStatus() *status.Status { if x != nil { return x.xxx_hidden_Status } return nil } -func (x *Metadata) GetEntry() []*MetadataEntry { +func (x *Message) GetEntry() []*MetadataEntry { if x != nil { if x.xxx_hidden_Entry != nil { return *x.xxx_hidden_Entry @@ -89,66 +89,66 @@ func (x *Metadata) GetEntry() []*MetadataEntry { return nil } -func (x *Metadata) GetMessageData() []byte { +func (x *Message) GetPayload() []byte { if x != nil { - return x.xxx_hidden_MessageData + return x.xxx_hidden_Payload } return nil } -func (x *Metadata) SetMessageSeqNo(v uint64) { +func (x *Message) SetMessageSeqNo(v uint64) { x.xxx_hidden_MessageSeqNo = v } -func (x *Metadata) SetMethod(v string) { +func (x *Message) SetMethod(v string) { x.xxx_hidden_Method = v } -func (x *Metadata) SetStatus(v *status.Status) { +func (x *Message) SetStatus(v *status.Status) { x.xxx_hidden_Status = v } -func (x *Metadata) SetEntry(v []*MetadataEntry) { +func (x *Message) SetEntry(v []*MetadataEntry) { x.xxx_hidden_Entry = &v } -func (x *Metadata) SetMessageData(v []byte) { +func (x *Message) SetPayload(v []byte) { if v == nil { v = []byte{} } - x.xxx_hidden_MessageData = v + x.xxx_hidden_Payload = v } -func (x *Metadata) HasStatus() bool { +func (x *Message) HasStatus() bool { if x == nil { return false } return x.xxx_hidden_Status != nil } -func (x *Metadata) ClearStatus() { +func (x *Message) ClearStatus() { x.xxx_hidden_Status = nil } -type Metadata_builder struct { +type Message_builder struct { _ [0]func() // Prevents comparability and use of unkeyed literals for the builder. MessageSeqNo uint64 Method string Status *status.Status Entry []*MetadataEntry - MessageData []byte + Payload []byte } -func (b0 Metadata_builder) Build() *Metadata { - m0 := &Metadata{} +func (b0 Message_builder) Build() *Message { + m0 := &Message{} b, x := &b0, m0 _, _ = b, x x.xxx_hidden_MessageSeqNo = b.MessageSeqNo x.xxx_hidden_Method = b.Method x.xxx_hidden_Status = b.Status x.xxx_hidden_Entry = &b.Entry - x.xxx_hidden_MessageData = b.MessageData + x.xxx_hidden_Payload = b.Payload return m0 } @@ -163,7 +163,7 @@ type MetadataEntry struct { func (x *MetadataEntry) Reset() { *x = MetadataEntry{} - mi := &file_ordering_ordering_proto_msgTypes[1] + mi := &file_stream_stream_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -175,7 +175,7 @@ func (x *MetadataEntry) String() string { func (*MetadataEntry) ProtoMessage() {} func (x *MetadataEntry) ProtoReflect() protoreflect.Message { - mi := &file_ordering_ordering_proto_msgTypes[1] + mi := &file_stream_stream_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -224,35 +224,35 @@ func (b0 MetadataEntry_builder) Build() *MetadataEntry { return m0 } -var File_ordering_ordering_proto protoreflect.FileDescriptor +var File_stream_stream_proto protoreflect.FileDescriptor -const file_ordering_ordering_proto_rawDesc = "" + +const file_stream_stream_proto_rawDesc = "" + "\n" + - "\x17ordering/ordering.proto\x12\bordering\x1a\x17google/rpc/status.proto\"\xc6\x01\n" + - "\bMetadata\x12$\n" + + "\x13stream/stream.proto\x12\x06stream\x1a\x17google/rpc/status.proto\"\xba\x01\n" + + "\aMessage\x12$\n" + "\x0emessage_seq_no\x18\x01 \x01(\x04R\fmessageSeqNo\x12\x16\n" + "\x06method\x18\x02 \x01(\tR\x06method\x12*\n" + - "\x06status\x18\x03 \x01(\v2\x12.google.rpc.StatusR\x06status\x12-\n" + - "\x05entry\x18\x04 \x03(\v2\x17.ordering.MetadataEntryR\x05entry\x12!\n" + - "\fmessage_data\x18\x05 \x01(\fR\vmessageData\"7\n" + + "\x06status\x18\x03 \x01(\v2\x12.google.rpc.StatusR\x06status\x12+\n" + + "\x05entry\x18\x04 \x03(\v2\x15.stream.MetadataEntryR\x05entry\x12\x18\n" + + "\apayload\x18\x05 \x01(\fR\apayload\"7\n" + "\rMetadataEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + - "\x05value\x18\x02 \x01(\tR\x05value2B\n" + - "\x06Gorums\x128\n" + + "\x05value\x18\x02 \x01(\tR\x05value2<\n" + + "\x06Gorums\x122\n" + "\n" + - "NodeStream\x12\x12.ordering.Metadata\x1a\x12.ordering.Metadata(\x010\x01B'Z github.com/relab/gorums/ordering\x92\x03\x02\b\x02b\beditionsp\xe8\a" + "NodeStream\x12\x0f.stream.Message\x1a\x0f.stream.Message(\x010\x01B%Z\x1egithub.com/relab/gorums/stream\x92\x03\x02\b\x02b\beditionsp\xe8\a" -var file_ordering_ordering_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_ordering_ordering_proto_goTypes = []any{ - (*Metadata)(nil), // 0: ordering.Metadata - (*MetadataEntry)(nil), // 1: ordering.MetadataEntry +var file_stream_stream_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_stream_stream_proto_goTypes = []any{ + (*Message)(nil), // 0: stream.Message + (*MetadataEntry)(nil), // 1: stream.MetadataEntry (*status.Status)(nil), // 2: google.rpc.Status } -var file_ordering_ordering_proto_depIdxs = []int32{ - 2, // 0: ordering.Metadata.status:type_name -> google.rpc.Status - 1, // 1: ordering.Metadata.entry:type_name -> ordering.MetadataEntry - 0, // 2: ordering.Gorums.NodeStream:input_type -> ordering.Metadata - 0, // 3: ordering.Gorums.NodeStream:output_type -> ordering.Metadata +var file_stream_stream_proto_depIdxs = []int32{ + 2, // 0: stream.Message.status:type_name -> google.rpc.Status + 1, // 1: stream.Message.entry:type_name -> stream.MetadataEntry + 0, // 2: stream.Gorums.NodeStream:input_type -> stream.Message + 0, // 3: stream.Gorums.NodeStream:output_type -> stream.Message 3, // [3:4] is the sub-list for method output_type 2, // [2:3] is the sub-list for method input_type 2, // [2:2] is the sub-list for extension type_name @@ -260,26 +260,26 @@ var file_ordering_ordering_proto_depIdxs = []int32{ 0, // [0:2] is the sub-list for field type_name } -func init() { file_ordering_ordering_proto_init() } -func file_ordering_ordering_proto_init() { - if File_ordering_ordering_proto != nil { +func init() { file_stream_stream_proto_init() } +func file_stream_stream_proto_init() { + if File_stream_stream_proto != nil { return } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_ordering_ordering_proto_rawDesc), len(file_ordering_ordering_proto_rawDesc)), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_stream_stream_proto_rawDesc), len(file_stream_stream_proto_rawDesc)), NumEnums: 0, NumMessages: 2, NumExtensions: 0, NumServices: 1, }, - GoTypes: file_ordering_ordering_proto_goTypes, - DependencyIndexes: file_ordering_ordering_proto_depIdxs, - MessageInfos: file_ordering_ordering_proto_msgTypes, + GoTypes: file_stream_stream_proto_goTypes, + DependencyIndexes: file_stream_stream_proto_depIdxs, + MessageInfos: file_stream_stream_proto_msgTypes, }.Build() - File_ordering_ordering_proto = out.File - file_ordering_ordering_proto_goTypes = nil - file_ordering_ordering_proto_depIdxs = nil + File_stream_stream_proto = out.File + file_stream_stream_proto_goTypes = nil + file_stream_stream_proto_depIdxs = nil } diff --git a/ordering/ordering.proto b/stream/stream.proto similarity index 76% rename from ordering/ordering.proto rename to stream/stream.proto index 21a20d1a..38046888 100644 --- a/ordering/ordering.proto +++ b/stream/stream.proto @@ -1,7 +1,7 @@ edition = "2023"; -package ordering; -option go_package = "github.com/relab/gorums/ordering"; +package stream; +option go_package = "github.com/relab/gorums/stream"; option features.field_presence = IMPLICIT; import "google/rpc/status.proto"; @@ -11,17 +11,17 @@ service Gorums { // NodeStream is a stream that connects a client to a Node. // The messages that are sent on the stream contain both Metadata // and an application-specific message. - rpc NodeStream(stream Metadata) returns (stream Metadata); + rpc NodeStream(stream Message) returns (stream Message); } -// Metadata is sent together with application-specific message types, +// Message is sent together with application-specific message types, // and contains information necessary for Gorums to handle the messages. -message Metadata { +message Message { uint64 message_seq_no = 1; // sequence number for this message string method = 2; // method name to invoke on the server google.rpc.Status status = 3; // status of an invocation (for responses) repeated MetadataEntry entry = 4; // per message client-generated metadata - bytes message_data = 5; // serialized application message + bytes payload = 5; // serialized application message } // MetadataEntry is a key-value pair for Metadata entries. diff --git a/ordering/ordering_grpc.pb.go b/stream/stream_grpc.pb.go similarity index 86% rename from ordering/ordering_grpc.pb.go rename to stream/stream_grpc.pb.go index a1b79e9f..7b66258a 100644 --- a/ordering/ordering_grpc.pb.go +++ b/stream/stream_grpc.pb.go @@ -2,9 +2,9 @@ // versions: // - protoc-gen-go-grpc v1.6.0 // - protoc v6.33.4 -// source: ordering/ordering.proto +// source: stream/stream.proto -package ordering +package stream import ( context "context" @@ -19,7 +19,7 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - Gorums_NodeStream_FullMethodName = "/ordering.Gorums/NodeStream" + Gorums_NodeStream_FullMethodName = "/stream.Gorums/NodeStream" ) // GorumsClient is the client API for Gorums service. @@ -31,7 +31,7 @@ type GorumsClient interface { // NodeStream is a stream that connects a client to a Node. // The messages that are sent on the stream contain both Metadata // and an application-specific message. - NodeStream(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Metadata, Metadata], error) + NodeStream(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Message, Message], error) } type gorumsClient struct { @@ -42,18 +42,18 @@ func NewGorumsClient(cc grpc.ClientConnInterface) GorumsClient { return &gorumsClient{cc} } -func (c *gorumsClient) NodeStream(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Metadata, Metadata], error) { +func (c *gorumsClient) NodeStream(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Message, Message], error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &Gorums_ServiceDesc.Streams[0], Gorums_NodeStream_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[Metadata, Metadata]{ClientStream: stream} + x := &grpc.GenericClientStream[Message, Message]{ClientStream: stream} return x, nil } // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type Gorums_NodeStreamClient = grpc.BidiStreamingClient[Metadata, Metadata] +type Gorums_NodeStreamClient = grpc.BidiStreamingClient[Message, Message] // GorumsServer is the server API for Gorums service. // All implementations must embed UnimplementedGorumsServer @@ -64,7 +64,7 @@ type GorumsServer interface { // NodeStream is a stream that connects a client to a Node. // The messages that are sent on the stream contain both Metadata // and an application-specific message. - NodeStream(grpc.BidiStreamingServer[Metadata, Metadata]) error + NodeStream(grpc.BidiStreamingServer[Message, Message]) error mustEmbedUnimplementedGorumsServer() } @@ -75,7 +75,7 @@ type GorumsServer interface { // pointer dereference when methods are called. type UnimplementedGorumsServer struct{} -func (UnimplementedGorumsServer) NodeStream(grpc.BidiStreamingServer[Metadata, Metadata]) error { +func (UnimplementedGorumsServer) NodeStream(grpc.BidiStreamingServer[Message, Message]) error { return status.Error(codes.Unimplemented, "method NodeStream not implemented") } func (UnimplementedGorumsServer) mustEmbedUnimplementedGorumsServer() {} @@ -100,17 +100,17 @@ func RegisterGorumsServer(s grpc.ServiceRegistrar, srv GorumsServer) { } func _Gorums_NodeStream_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(GorumsServer).NodeStream(&grpc.GenericServerStream[Metadata, Metadata]{ServerStream: stream}) + return srv.(GorumsServer).NodeStream(&grpc.GenericServerStream[Message, Message]{ServerStream: stream}) } // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type Gorums_NodeStreamServer = grpc.BidiStreamingServer[Metadata, Metadata] +type Gorums_NodeStreamServer = grpc.BidiStreamingServer[Message, Message] // Gorums_ServiceDesc is the grpc.ServiceDesc for Gorums service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) var Gorums_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "ordering.Gorums", + ServiceName: "stream.Gorums", HandlerType: (*GorumsServer)(nil), Methods: []grpc.MethodDesc{}, Streams: []grpc.StreamDesc{ @@ -121,5 +121,5 @@ var Gorums_ServiceDesc = grpc.ServiceDesc{ ClientStreams: true, }, }, - Metadata: "ordering/ordering.proto", + Metadata: "stream/stream.proto", } diff --git a/testing_shared.go b/testing_shared.go index 5ac73831..a8634e01 100644 --- a/testing_shared.go +++ b/testing_shared.go @@ -247,12 +247,18 @@ func defaultTestServer(i int, opts ...ServerOption) ServerIface { srv.RegisterHandler(mock.TestMethod, func(ctx ServerCtx, in *Message) (*Message, error) { req := AsProto[*pb.StringValue](in) resp, err := ts.Test(ctx, req) - return NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return NewResponseMessage(in, resp), nil }) srv.RegisterHandler(mock.GetValueMethod, func(ctx ServerCtx, in *Message) (*Message, error) { req := AsProto[*pb.Int32Value](in) resp, err := ts.GetValue(ctx, req) - return NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return NewResponseMessage(in, resp), nil }) return srv } @@ -274,7 +280,10 @@ func EchoServerFn(_ int) ServerIface { srv.RegisterHandler(mock.TestMethod, func(ctx ServerCtx, in *Message) (*Message, error) { req := AsProto[*pb.StringValue](in) resp, err := echoSrv{}.Test(ctx, req) - return NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return NewResponseMessage(in, resp), nil }) return srv @@ -296,8 +305,8 @@ func StreamServerFn(_ int) ServerIface { // Send 3 responses for i := 1; i <= 3; i++ { resp := pb.String(fmt.Sprintf("echo: %s-%d", val, i)) - msg := NewResponseMessage(in.GetMetadata(), resp) - if err := ctx.SendMessage(msg); err != nil { + out := NewResponseMessage(in, resp) + if err := ctx.SendMessage(out); err != nil { return nil, err } time.Sleep(10 * time.Millisecond) @@ -316,8 +325,8 @@ func StreamBenchmarkServerFn(_ int) ServerIface { // Send 3 responses for i := 1; i <= 3; i++ { resp := pb.String(fmt.Sprintf("echo: %s-%d", val, i)) - msg := NewResponseMessage(in.GetMetadata(), resp) - if err := ctx.SendMessage(msg); err != nil { + out := NewResponseMessage(in, resp) + if err := ctx.SendMessage(out); err != nil { return nil, err } } diff --git a/unicast.go b/unicast.go index d0beb7e5..a1990c9e 100644 --- a/unicast.go +++ b/unicast.go @@ -1,5 +1,7 @@ package gorums +import "github.com/relab/gorums/stream" + // Unicast is a one-way call; no replies are returned to the client. // // By default, this method blocks until the message has been sent to the node. @@ -13,7 +15,7 @@ package gorums // This method should be used by generated code only. func Unicast[Req msg](ctx *NodeContext, req Req, method string, opts ...CallOption) error { callOpts := getCallOptions(E_Unicast, opts...) - md, err := marshalRequest(ctx, ctx.nextMsgID(), method, req) + reqMsg, err := stream.NewMessage(ctx, ctx.nextMsgID(), method, req) if err != nil { return err } @@ -21,13 +23,13 @@ func Unicast[Req msg](ctx *NodeContext, req Req, method string, opts ...CallOpti waitSendDone := callOpts.mustWaitSendDone() if !waitSendDone { // Fire-and-forget: enqueue and return immediately - ctx.enqueue(request{ctx: ctx, md: md}) + ctx.enqueue(request{ctx: ctx, msg: reqMsg}) return nil } // Default: block until send completes replyChan := make(chan NodeResponse[msg], 1) - ctx.enqueue(request{ctx: ctx, md: md, waitSendDone: true, responseChan: replyChan}) + ctx.enqueue(request{ctx: ctx, msg: reqMsg, waitSendDone: true, responseChan: replyChan}) // Wait for send confirmation select { From a8ce766adc967ab751d17124cb8e791b3d388e31 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 11 Feb 2026 16:01:32 +0100 Subject: [PATCH 10/19] chore: regenerate proto files after refactor to get tests to pass --- benchmark/benchmark_gorums.pb.go | 30 ++++++++++++---- .../dev/zorums_server_gorums.pb.go | 36 +++++++++++-------- examples/storage/proto/storage_gorums.pb.go | 20 ++++++++--- internal/tests/config/config_gorums.pb.go | 5 ++- .../correctable/correctable_gorums.pb.go | 11 +++--- internal/tests/metadata/metadata_gorums.pb.go | 10 ++++-- internal/tests/ordering/order_gorums.pb.go | 10 ++++-- internal/tests/tls/tls_gorums.pb.go | 5 ++- .../unresponsive/unresponsive_gorums.pb.go | 5 ++- 9 files changed, 96 insertions(+), 36 deletions(-) diff --git a/benchmark/benchmark_gorums.pb.go b/benchmark/benchmark_gorums.pb.go index e199144a..954949cc 100644 --- a/benchmark/benchmark_gorums.pb.go +++ b/benchmark/benchmark_gorums.pb.go @@ -168,32 +168,50 @@ func RegisterBenchmarkServer(srv *gorums.Server, impl BenchmarkServer) { srv.RegisterHandler("benchmark.Benchmark.StartServerBenchmark", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*StartRequest](in) resp, err := impl.StartServerBenchmark(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.StopServerBenchmark", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*StopRequest](in) resp, err := impl.StopServerBenchmark(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.StartBenchmark", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*StartRequest](in) resp, err := impl.StartBenchmark(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.StopBenchmark", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*StopRequest](in) resp, err := impl.StopBenchmark(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.QuorumCall", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Echo](in) resp, err := impl.QuorumCall(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.SlowServer", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Echo](in) resp, err := impl.SlowServer(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.Multicast", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*TimedMsg](in) diff --git a/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go b/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go index 98759772..5bee3e2d 100644 --- a/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go +++ b/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go @@ -8,7 +8,6 @@ package dev import ( gorums "github.com/relab/gorums" - proto "google.golang.org/protobuf/proto" emptypb "google.golang.org/protobuf/types/known/emptypb" ) @@ -40,22 +39,34 @@ func RegisterZorumsServiceServer(srv *gorums.Server, impl ZorumsServiceServer) { srv.RegisterHandler("dev.ZorumsService.GRPCCall", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.GRPCCall(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("dev.ZorumsService.QuorumCall", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.QuorumCall(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("dev.ZorumsService.QuorumCallEmpty", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*emptypb.Empty](in) resp, err := impl.QuorumCallEmpty(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("dev.ZorumsService.QuorumCallEmpty2", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.QuorumCallEmpty2(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("dev.ZorumsService.Multicast", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) @@ -80,27 +91,24 @@ func RegisterZorumsServiceServer(srv *gorums.Server, impl ZorumsServiceServer) { srv.RegisterHandler("dev.ZorumsService.QuorumCallStream", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) err := impl.QuorumCallStream(ctx, req, func(resp *Response) error { - // create a copy of the metadata, to avoid a data race between NewResponseMessage and SendMsg - md := proto.CloneOf(in.GetMetadata()) - return ctx.SendMessage(gorums.NewResponseMessage(md, resp)) + out := gorums.NewResponseMessage(in, resp) + return ctx.SendMessage(out) }) return nil, err }) srv.RegisterHandler("dev.ZorumsService.QuorumCallStreamWithEmpty", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) err := impl.QuorumCallStreamWithEmpty(ctx, req, func(resp *emptypb.Empty) error { - // create a copy of the metadata, to avoid a data race between NewResponseMessage and SendMsg - md := proto.CloneOf(in.GetMetadata()) - return ctx.SendMessage(gorums.NewResponseMessage(md, resp)) + out := gorums.NewResponseMessage(in, resp) + return ctx.SendMessage(out) }) return nil, err }) srv.RegisterHandler("dev.ZorumsService.QuorumCallStreamWithEmpty2", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*emptypb.Empty](in) err := impl.QuorumCallStreamWithEmpty2(ctx, req, func(resp *Response) error { - // create a copy of the metadata, to avoid a data race between NewResponseMessage and SendMsg - md := proto.CloneOf(in.GetMetadata()) - return ctx.SendMessage(gorums.NewResponseMessage(md, resp)) + out := gorums.NewResponseMessage(in, resp) + return ctx.SendMessage(out) }) return nil, err }) diff --git a/examples/storage/proto/storage_gorums.pb.go b/examples/storage/proto/storage_gorums.pb.go index b189b1f4..26dd9556 100644 --- a/examples/storage/proto/storage_gorums.pb.go +++ b/examples/storage/proto/storage_gorums.pb.go @@ -137,22 +137,34 @@ func RegisterStorageServer(srv *gorums.Server, impl StorageServer) { srv.RegisterHandler("proto.Storage.ReadRPC", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*ReadRequest](in) resp, err := impl.ReadRPC(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("proto.Storage.WriteRPC", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*WriteRequest](in) resp, err := impl.WriteRPC(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("proto.Storage.ReadQC", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*ReadRequest](in) resp, err := impl.ReadQC(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("proto.Storage.WriteQC", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*WriteRequest](in) resp, err := impl.WriteQC(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("proto.Storage.WriteMulticast", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*WriteRequest](in) diff --git a/internal/tests/config/config_gorums.pb.go b/internal/tests/config/config_gorums.pb.go index 7d2f7de2..dfac39e9 100644 --- a/internal/tests/config/config_gorums.pb.go +++ b/internal/tests/config/config_gorums.pb.go @@ -101,6 +101,9 @@ func RegisterConfigTestServer(srv *gorums.Server, impl ConfigTestServer) { srv.RegisterHandler("config.ConfigTest.Config", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.Config(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) } diff --git a/internal/tests/correctable/correctable_gorums.pb.go b/internal/tests/correctable/correctable_gorums.pb.go index d2042920..a967da93 100644 --- a/internal/tests/correctable/correctable_gorums.pb.go +++ b/internal/tests/correctable/correctable_gorums.pb.go @@ -8,7 +8,6 @@ package correctable import ( gorums "github.com/relab/gorums" - proto "google.golang.org/protobuf/proto" ) const ( @@ -118,14 +117,16 @@ func RegisterCorrectableTestServer(srv *gorums.Server, impl CorrectableTestServe srv.RegisterHandler("correctable.CorrectableTest.Correctable", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.Correctable(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("correctable.CorrectableTest.CorrectableStream", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) err := impl.CorrectableStream(ctx, req, func(resp *Response) error { - // create a copy of the metadata, to avoid a data race between NewResponseMessage and SendMsg - md := proto.CloneOf(in.GetMetadata()) - return ctx.SendMessage(gorums.NewResponseMessage(md, resp)) + out := gorums.NewResponseMessage(in, resp) + return ctx.SendMessage(out) }) return nil, err }) diff --git a/internal/tests/metadata/metadata_gorums.pb.go b/internal/tests/metadata/metadata_gorums.pb.go index dadb7fa4..8dd8bda8 100644 --- a/internal/tests/metadata/metadata_gorums.pb.go +++ b/internal/tests/metadata/metadata_gorums.pb.go @@ -93,11 +93,17 @@ func RegisterMetadataTestServer(srv *gorums.Server, impl MetadataTestServer) { srv.RegisterHandler("metadata.MetadataTest.IDFromMD", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*emptypb.Empty](in) resp, err := impl.IDFromMD(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("metadata.MetadataTest.WhatIP", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*emptypb.Empty](in) resp, err := impl.WhatIP(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) } diff --git a/internal/tests/ordering/order_gorums.pb.go b/internal/tests/ordering/order_gorums.pb.go index 5e72a50f..54d0d4c8 100644 --- a/internal/tests/ordering/order_gorums.pb.go +++ b/internal/tests/ordering/order_gorums.pb.go @@ -107,11 +107,17 @@ func RegisterGorumsTestServer(srv *gorums.Server, impl GorumsTestServer) { srv.RegisterHandler("ordering.GorumsTest.QuorumCall", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.QuorumCall(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("ordering.GorumsTest.UnaryRPC", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.UnaryRPC(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) } diff --git a/internal/tests/tls/tls_gorums.pb.go b/internal/tests/tls/tls_gorums.pb.go index a037a912..e9b0efad 100644 --- a/internal/tests/tls/tls_gorums.pb.go +++ b/internal/tests/tls/tls_gorums.pb.go @@ -86,6 +86,9 @@ func RegisterTLSServer(srv *gorums.Server, impl TLSServer) { srv.RegisterHandler("tls.TLS.TestTLS", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.TestTLS(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) } diff --git a/internal/tests/unresponsive/unresponsive_gorums.pb.go b/internal/tests/unresponsive/unresponsive_gorums.pb.go index 9fb964a8..f1d6b149 100644 --- a/internal/tests/unresponsive/unresponsive_gorums.pb.go +++ b/internal/tests/unresponsive/unresponsive_gorums.pb.go @@ -86,6 +86,9 @@ func RegisterUnresponsiveServer(srv *gorums.Server, impl UnresponsiveServer) { srv.RegisterHandler("unresponsive.Unresponsive.TestUnresponsive", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Empty](in) resp, err := impl.TestUnresponsive(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) } From 8f6d17bc47593bd01677e3ce688e4c9f3e6048e6 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 11 Feb 2026 16:21:40 +0100 Subject: [PATCH 11/19] refactor: move stream package to internal/stream This is to avoid directly exposing the internal message types for the Gorums runtime (NodeStream). --- Makefile | 2 +- channel.go | 2 +- channel_test.go | 2 +- client_interceptor.go | 2 +- .../gengorums/gorums_func_map.go | 2 +- encoding.go | 2 +- encoding_test.go | 2 +- {stream => internal/stream}/gorums_message.go | 0 {stream => internal/stream}/stream.pb.go | 44 +++++++++---------- {stream => internal/stream}/stream.proto | 2 +- {stream => internal/stream}/stream_grpc.pb.go | 4 +- rpc.go | 2 +- server.go | 2 +- unicast.go | 2 +- 14 files changed, 35 insertions(+), 35 deletions(-) rename {stream => internal/stream}/gorums_message.go (100%) rename {stream => internal/stream}/stream.pb.go (82%) rename {stream => internal/stream}/stream.proto (94%) rename {stream => internal/stream}/stream_grpc.pb.go (98%) diff --git a/Makefile b/Makefile index 9a003cd6..8481cb28 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ static_files := $(shell find $(dev_path) -name "*.go" -not -name "zorums*" -no proto_path := $(dev_path):third_party:. plugin_deps := gorums.pb.go $(static_file) -runtime_deps := stream/stream.pb.go stream/stream_grpc.pb.go +runtime_deps := internal/stream/stream.pb.go internal/stream/stream_grpc.pb.go benchmark_deps := benchmark/benchmark.pb.go benchmark/benchmark_gorums.pb.go .PHONY: all dev tools bootstrapgorums installgorums benchmark test compiletests genproto benchtest bench diff --git a/channel.go b/channel.go index 633ece43..32e89d19 100644 --- a/channel.go +++ b/channel.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "github.com/relab/gorums/stream" + "github.com/relab/gorums/internal/stream" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" diff --git a/channel_test.go b/channel_test.go index a73e898a..ec7aee29 100644 --- a/channel_test.go +++ b/channel_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" + "github.com/relab/gorums/internal/stream" "github.com/relab/gorums/internal/testutils/mock" - "github.com/relab/gorums/stream" "google.golang.org/grpc" pb "google.golang.org/protobuf/types/known/wrapperspb" ) diff --git a/client_interceptor.go b/client_interceptor.go index 949151c8..aa5cde21 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -6,7 +6,7 @@ import ( "slices" "sync" - "github.com/relab/gorums/stream" + "github.com/relab/gorums/internal/stream" "google.golang.org/protobuf/proto" ) diff --git a/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go b/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go index b755aaf7..d195bfc8 100644 --- a/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go +++ b/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go @@ -29,7 +29,7 @@ var importMap = map[string]protogen.GoImportPath{ "backoff": protogen.GoImportPath("google.golang.org/grpc/backoff"), "proto": protogen.GoImportPath("google.golang.org/protobuf/proto"), "gorums": protogen.GoImportPath("github.com/relab/gorums"), - "stream": protogen.GoImportPath("github.com/relab/gorums/stream"), + "stream": protogen.GoImportPath("github.com/relab/gorums/internal/stream"), "protoreflect": protogen.GoImportPath("google.golang.org/protobuf/reflect/protoreflect"), } diff --git a/encoding.go b/encoding.go index 377d4bb3..0a5caecc 100644 --- a/encoding.go +++ b/encoding.go @@ -4,7 +4,7 @@ import ( "cmp" "fmt" - "github.com/relab/gorums/stream" + "github.com/relab/gorums/internal/stream" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" diff --git a/encoding_test.go b/encoding_test.go index b3e76585..f6158681 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -5,8 +5,8 @@ import ( "github.com/google/go-cmp/cmp" "github.com/relab/gorums" + "github.com/relab/gorums/internal/stream" "github.com/relab/gorums/internal/tests/config" - "github.com/relab/gorums/stream" "google.golang.org/protobuf/testing/protocmp" ) diff --git a/stream/gorums_message.go b/internal/stream/gorums_message.go similarity index 100% rename from stream/gorums_message.go rename to internal/stream/gorums_message.go diff --git a/stream/stream.pb.go b/internal/stream/stream.pb.go similarity index 82% rename from stream/stream.pb.go rename to internal/stream/stream.pb.go index 213e2419..f7671a6b 100644 --- a/stream/stream.pb.go +++ b/internal/stream/stream.pb.go @@ -2,7 +2,7 @@ // versions: // protoc-gen-go v1.36.11 // protoc v6.33.4 -// source: stream/stream.proto +// source: internal/stream/stream.proto package stream @@ -36,7 +36,7 @@ type Message struct { func (x *Message) Reset() { *x = Message{} - mi := &file_stream_stream_proto_msgTypes[0] + mi := &file_internal_stream_stream_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -48,7 +48,7 @@ func (x *Message) String() string { func (*Message) ProtoMessage() {} func (x *Message) ProtoReflect() protoreflect.Message { - mi := &file_stream_stream_proto_msgTypes[0] + mi := &file_internal_stream_stream_proto_msgTypes[0] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -163,7 +163,7 @@ type MetadataEntry struct { func (x *MetadataEntry) Reset() { *x = MetadataEntry{} - mi := &file_stream_stream_proto_msgTypes[1] + mi := &file_internal_stream_stream_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -175,7 +175,7 @@ func (x *MetadataEntry) String() string { func (*MetadataEntry) ProtoMessage() {} func (x *MetadataEntry) ProtoReflect() protoreflect.Message { - mi := &file_stream_stream_proto_msgTypes[1] + mi := &file_internal_stream_stream_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -224,11 +224,11 @@ func (b0 MetadataEntry_builder) Build() *MetadataEntry { return m0 } -var File_stream_stream_proto protoreflect.FileDescriptor +var File_internal_stream_stream_proto protoreflect.FileDescriptor -const file_stream_stream_proto_rawDesc = "" + +const file_internal_stream_stream_proto_rawDesc = "" + "\n" + - "\x13stream/stream.proto\x12\x06stream\x1a\x17google/rpc/status.proto\"\xba\x01\n" + + "\x1cinternal/stream/stream.proto\x12\x06stream\x1a\x17google/rpc/status.proto\"\xba\x01\n" + "\aMessage\x12$\n" + "\x0emessage_seq_no\x18\x01 \x01(\x04R\fmessageSeqNo\x12\x16\n" + "\x06method\x18\x02 \x01(\tR\x06method\x12*\n" + @@ -240,15 +240,15 @@ const file_stream_stream_proto_rawDesc = "" + "\x05value\x18\x02 \x01(\tR\x05value2<\n" + "\x06Gorums\x122\n" + "\n" + - "NodeStream\x12\x0f.stream.Message\x1a\x0f.stream.Message(\x010\x01B%Z\x1egithub.com/relab/gorums/stream\x92\x03\x02\b\x02b\beditionsp\xe8\a" + "NodeStream\x12\x0f.stream.Message\x1a\x0f.stream.Message(\x010\x01B.Z'github.com/relab/gorums/internal/stream\x92\x03\x02\b\x02b\beditionsp\xe8\a" -var file_stream_stream_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_stream_stream_proto_goTypes = []any{ +var file_internal_stream_stream_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_internal_stream_stream_proto_goTypes = []any{ (*Message)(nil), // 0: stream.Message (*MetadataEntry)(nil), // 1: stream.MetadataEntry (*status.Status)(nil), // 2: google.rpc.Status } -var file_stream_stream_proto_depIdxs = []int32{ +var file_internal_stream_stream_proto_depIdxs = []int32{ 2, // 0: stream.Message.status:type_name -> google.rpc.Status 1, // 1: stream.Message.entry:type_name -> stream.MetadataEntry 0, // 2: stream.Gorums.NodeStream:input_type -> stream.Message @@ -260,26 +260,26 @@ var file_stream_stream_proto_depIdxs = []int32{ 0, // [0:2] is the sub-list for field type_name } -func init() { file_stream_stream_proto_init() } -func file_stream_stream_proto_init() { - if File_stream_stream_proto != nil { +func init() { file_internal_stream_stream_proto_init() } +func file_internal_stream_stream_proto_init() { + if File_internal_stream_stream_proto != nil { return } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_stream_stream_proto_rawDesc), len(file_stream_stream_proto_rawDesc)), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_internal_stream_stream_proto_rawDesc), len(file_internal_stream_stream_proto_rawDesc)), NumEnums: 0, NumMessages: 2, NumExtensions: 0, NumServices: 1, }, - GoTypes: file_stream_stream_proto_goTypes, - DependencyIndexes: file_stream_stream_proto_depIdxs, - MessageInfos: file_stream_stream_proto_msgTypes, + GoTypes: file_internal_stream_stream_proto_goTypes, + DependencyIndexes: file_internal_stream_stream_proto_depIdxs, + MessageInfos: file_internal_stream_stream_proto_msgTypes, }.Build() - File_stream_stream_proto = out.File - file_stream_stream_proto_goTypes = nil - file_stream_stream_proto_depIdxs = nil + File_internal_stream_stream_proto = out.File + file_internal_stream_stream_proto_goTypes = nil + file_internal_stream_stream_proto_depIdxs = nil } diff --git a/stream/stream.proto b/internal/stream/stream.proto similarity index 94% rename from stream/stream.proto rename to internal/stream/stream.proto index 38046888..10143bad 100644 --- a/stream/stream.proto +++ b/internal/stream/stream.proto @@ -1,7 +1,7 @@ edition = "2023"; package stream; -option go_package = "github.com/relab/gorums/stream"; +option go_package = "github.com/relab/gorums/internal/stream"; option features.field_presence = IMPLICIT; import "google/rpc/status.proto"; diff --git a/stream/stream_grpc.pb.go b/internal/stream/stream_grpc.pb.go similarity index 98% rename from stream/stream_grpc.pb.go rename to internal/stream/stream_grpc.pb.go index 7b66258a..b692df2e 100644 --- a/stream/stream_grpc.pb.go +++ b/internal/stream/stream_grpc.pb.go @@ -2,7 +2,7 @@ // versions: // - protoc-gen-go-grpc v1.6.0 // - protoc v6.33.4 -// source: stream/stream.proto +// source: internal/stream/stream.proto package stream @@ -121,5 +121,5 @@ var Gorums_ServiceDesc = grpc.ServiceDesc{ ClientStreams: true, }, }, - Metadata: "stream/stream.proto", + Metadata: "internal/stream/stream.proto", } diff --git a/rpc.go b/rpc.go index 7b2441c0..12b80472 100644 --- a/rpc.go +++ b/rpc.go @@ -1,6 +1,6 @@ package gorums -import "github.com/relab/gorums/stream" +import "github.com/relab/gorums/internal/stream" // RPCCall executes a remote procedure call on the node. // diff --git a/server.go b/server.go index 7c733170..0e1c31f9 100644 --- a/server.go +++ b/server.go @@ -5,7 +5,7 @@ import ( "net" "sync" - "github.com/relab/gorums/stream" + "github.com/relab/gorums/internal/stream" "google.golang.org/grpc" "google.golang.org/protobuf/proto" ) diff --git a/unicast.go b/unicast.go index a1990c9e..a58f9349 100644 --- a/unicast.go +++ b/unicast.go @@ -1,6 +1,6 @@ package gorums -import "github.com/relab/gorums/stream" +import "github.com/relab/gorums/internal/stream" // Unicast is a one-way call; no replies are returned to the client. // From 21388655c022d34033ca44dba628c9e80fb7cd8f Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 11 Feb 2026 19:26:45 +0100 Subject: [PATCH 12/19] refactor: client interceptor to ensure consistent response handling - Remove expectedReplies field and always send one response per node (success, error, or skip) - Added ErrSkipNode for nodes skipped by request transformations - Update quorum and multicast logic to treat skips as non-errors - Simplify response iterator to use config size instead of expected count This fixes the issue where marshal failures or skips could cause the response iterator to stop early or block, ensuring exactly one response is sent per node in the configuration. --- client_interceptor.go | 32 ++++++++++---------------------- errors.go | 4 ++++ multicast.go | 6 ++++-- responses.go | 3 ++- responses_test.go | 7 +++---- 5 files changed, 23 insertions(+), 29 deletions(-) diff --git a/client_interceptor.go b/client_interceptor.go index aa5cde21..9cd32d72 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -53,11 +53,6 @@ type ClientCtx[Req, Resp msg] struct { // Interceptors can wrap this iterator to modify responses. responseSeq ResponseSeq[Resp] - // expectedReplies is the number of responses expected from nodes. - // It is set when messages are sent and may be lower than config size - // if some nodes are skipped by request transformations. - expectedReplies int - // streaming indicates whether this is a streaming call (for correctable streams). streaming bool @@ -88,11 +83,10 @@ func newClientCtxBuilder[Req, Resp msg]( ) *clientCtxBuilder[Req, Resp] { return &clientCtxBuilder[Req, Resp]{ c: &ClientCtx[Req, Resp]{ - Context: ctx, - config: ctx.Configuration(), - request: req, - method: method, - expectedReplies: ctx.Configuration().Size(), + Context: ctx, + config: ctx.Configuration(), + request: req, + method: method, }, chanMultiplier: 1, } @@ -186,11 +180,8 @@ func (c *ClientCtx[Req, Resp]) applyInterceptors(interceptors []any) { } // send dispatches requests to all nodes, applying any registered transformations. -// It updates expectedReplies based on how many nodes actually receive requests -// (nodes may be skipped if a transformation returns nil). +// It ensures that exactly one response (success or error) is sent per node on replyChan. func (c *ClientCtx[Req, Resp]) send() { - var expected int - // Fast path: marshal once when no per-node transforms are registered. var sharedMsg *stream.Message if len(c.reqTransforms) == 0 { @@ -200,9 +191,7 @@ func (c *ClientCtx[Req, Resp]) send() { // Marshaling fails identically for all nodes; report and return. for _, n := range c.config { c.replyChan <- NodeResponse[msg]{NodeID: n.ID(), Err: err} - expected++ } - c.expectedReplies = expected return } } @@ -210,9 +199,8 @@ func (c *ClientCtx[Req, Resp]) send() { // transform only if there are registered transforms; otherwise reuse the shared message streamMsg := cmp.Or(sharedMsg, c.transformAndMarshal(n)) if streamMsg == nil { - continue // Skip node + continue // Skip node: transformAndMarshal already sent ErrSkipNode } - expected++ n.channel.enqueue(request{ ctx: c.Context, msg: streamMsg, @@ -221,7 +209,6 @@ func (c *ClientCtx[Req, Resp]) send() { responseChan: c.replyChan, }) } - c.expectedReplies = expected } // transformAndMarshal applies transformations to the request for the given node, @@ -233,7 +220,8 @@ func (c *ClientCtx[Req, Resp]) transformAndMarshal(n *Node) *stream.Message { result = transform(result, n) } // Check if the result is valid - if protoReq, ok := any(result).(proto.Message); !ok || !protoReq.ProtoReflect().IsValid() { + if protoReq, ok := any(result).(proto.Message); !ok || protoReq == nil || !protoReq.ProtoReflect().IsValid() { + c.replyChan <- NodeResponse[msg]{NodeID: n.ID(), Err: ErrSkipNode} return nil } streamMsg, err := stream.NewMessage(c.Context, c.msgID, c.method, result) @@ -250,7 +238,7 @@ func (c *ClientCtx[Req, Resp]) defaultResponseSeq() ResponseSeq[Resp] { return func(yield func(NodeResponse[Resp]) bool) { // Trigger sending on first iteration c.sendOnce.Do(c.send) - for range c.expectedReplies { + for range c.Size() { select { case r := <-c.replyChan: res := newNodeResponse[Resp](r) @@ -293,7 +281,7 @@ func (c *ClientCtx[Req, Resp]) streamingResponseSeq() ResponseSeq[Resp] { // // The fn receives the original request and a node, and returns the transformed // request to send to that node. If the function returns an invalid message or nil, -// the request to that node is skipped. +// an ErrSkipNode error is sent for that node, indicating it was skipped. func MapRequest[Req, Resp msg](fn func(Req, *Node) Req) QuorumInterceptor[Req, Resp] { return func(ctx *ClientCtx[Req, Resp], next ResponseSeq[Resp]) ResponseSeq[Resp] { if fn != nil { diff --git a/errors.go b/errors.go index b2201858..0834ad9c 100644 --- a/errors.go +++ b/errors.go @@ -16,6 +16,10 @@ var ErrSendFailure = errors.New("send failure") // ErrTypeMismatch is returned when a response cannot be cast to the expected type. var ErrTypeMismatch = errors.New("response type mismatch") +// ErrSkipNode is returned when a node is skipped by request transformations. +// This allows the response iterator to account for all nodes without blocking. +var ErrSkipNode = errors.New("skip node") + // QuorumCallError reports on a failed quorum call. // It provides detailed information about which nodes failed. type QuorumCallError struct { diff --git a/multicast.go b/multicast.go index a620aeec..7ff2cd54 100644 --- a/multicast.go +++ b/multicast.go @@ -1,6 +1,8 @@ package gorums import ( + "errors" + "google.golang.org/protobuf/types/known/emptypb" ) @@ -31,10 +33,10 @@ func Multicast[Req msg](ctx *ConfigContext, req Req, method string, opts ...Call // If waiting for send completion, drain the reply channel and return the first error. if waitSendDone { var errs []nodeError - for range clientCtx.expectedReplies { + for range clientCtx.Size() { select { case r := <-clientCtx.replyChan: - if r.Err != nil { + if r.Err != nil && !errors.Is(r.Err, ErrSkipNode) { errs = append(errs, nodeError{cause: r.Err, nodeID: r.NodeID}) } case <-ctx.Done(): diff --git a/responses.go b/responses.go index 856e4b7a..a6d7bf87 100644 --- a/responses.go +++ b/responses.go @@ -1,6 +1,7 @@ package gorums import ( + "errors" "iter" "google.golang.org/protobuf/proto" @@ -192,7 +193,7 @@ func (r *Responses[Resp]) Threshold(threshold int) (resp Resp, err error) { errs []nodeError ) for result := range r.ResponseSeq { - if result.Err != nil { + if result.Err != nil && !errors.Is(result.Err, ErrSkipNode) { errs = append(errs, nodeError{nodeID: result.NodeID, cause: result.Err}) continue } diff --git a/responses_test.go b/responses_test.go index 9d4c8178..ad69333d 100644 --- a/responses_test.go +++ b/responses_test.go @@ -24,10 +24,9 @@ func makeClientCtx[Req, Resp msg](t *testing.T, numNodes int, responses []NodeRe } c := &ClientCtx[Req, Resp]{ - Context: t.Context(), - config: config, - replyChan: resultChan, - expectedReplies: numNodes, + Context: t.Context(), + config: config, + replyChan: resultChan, } // Mark sendOnce as done since test responses are already in the channel c.sendOnce.Do(func() {}) From d1976b05d7e9568378c2f76af10854249c36f5a5 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 11 Feb 2026 20:34:07 +0100 Subject: [PATCH 13/19] fix: multicast tests to sort responses to avoid flakiness --- internal/tests/oneway/oneway_test.go | 34 +++++++++++++++++----------- testing_shared.go | 2 +- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/internal/tests/oneway/oneway_test.go b/internal/tests/oneway/oneway_test.go index 0d91c972..915aa7c9 100644 --- a/internal/tests/oneway/oneway_test.go +++ b/internal/tests/oneway/oneway_test.go @@ -37,7 +37,7 @@ func (s *onewaySrv) Multicast(_ gorums.ServerCtx, r *oneway.Request) { } // setupWithNodeMap sets up servers and configuration with sequential node IDs -// (0, 1, 2, ...) matching the server array indices. This is needed for tests like +// (1, 2, 3, ...) matching the server array indices. This is needed for tests like // TestMulticastPerNode that verify per-node message transformations based on node ID. func setupWithNodeMap(t testing.TB, cfgSize int) (cfg oneway.Configuration, srvs []*onewaySrv) { t.Helper() @@ -45,7 +45,6 @@ func setupWithNodeMap(t testing.TB, cfgSize int) (cfg oneway.Configuration, srvs for i := range cfgSize { srvs[i] = &onewaySrv{received: make(chan *oneway.Request, numCalls)} } - cfg = gorums.TestConfiguration(t, cfgSize, func(i int) gorums.ServerIface { srv := gorums.NewServer() oneway.RegisterOnewayTestServer(srv, srvs[i]) @@ -70,14 +69,12 @@ func TestOnewayCalls(t *testing.T) { {name: "MulticastSendWaiting__", calls: numCalls, servers: 9, sendWait: true}, {name: "MulticastNoSendWaiting", calls: numCalls, servers: 9, sendWait: false}, } - for _, test := range tests { t.Run(fmt.Sprintf("%s/Servers=%d", test.name, test.servers), func(t *testing.T) { config, srvs := setupWithNodeMap(t, test.servers) for i := range srvs { srvs[i].wg.Add(test.calls) } - for c := 1; c <= test.calls; c++ { in := oneway.Request_builder{Num: uint64(c)}.Build() if config.Size() == 1 { @@ -110,12 +107,18 @@ func TestOnewayCalls(t *testing.T) { for i := range srvs { srvs[i].wg.Wait() close(srvs[i].received) - expected := uint64(1) + received := make([]uint64, 0, test.calls) for r := range srvs[i].received { - if expected != r.GetNum() { - t.Errorf("%s(%d) = %d, expected %d", test.name, expected, r.GetNum(), expected) + received = append(received, r.GetNum()) + } + // Sort received messages to avoid test flakiness + // due to message reordering in multicast tests + slices.Sort(received) + for j, got := range received { + want := uint64(j + 1) + if want != got { + t.Errorf("%s: received[%d] = %d, expected %d", test.name, j, got, want) } - expected++ } } }) @@ -163,7 +166,6 @@ func TestMulticastPerNode(t *testing.T) { {name: "MulticastPerNodeSendWaitingIgnoreNodes", calls: numCalls, servers: 3, sendWait: true, ignoreNodes: []uint32{0, 1}}, {name: "MulticastPerNodeSendWaitingIgnoreNodes", calls: numCalls, servers: 3, sendWait: true, ignoreNodes: []uint32{0, 1, 2}}, } - for _, test := range tests { t.Run(fmt.Sprintf("%s/Servers=%d/IgnoredNodes=%v", test.name, test.servers, test.ignoreNodes), func(t *testing.T) { config, srvs := setupWithNodeMap(t, test.servers) @@ -206,12 +208,18 @@ func TestMulticastPerNode(t *testing.T) { } srvs[i].wg.Wait() close(srvs[i].received) - expected := add(uint64(1), nodeIDs[i]) + received := make([]uint64, 0, test.calls) for r := range srvs[i].received { - if expected != r.GetNum() { - t.Errorf("%s -> %d, expected %d, nodeID=%d", test.name, r.GetNum(), expected, nodeIDs[i]) + received = append(received, r.GetNum()) + } + // Sort received messages to avoid test flakiness + // due to message reordering in multicast tests + slices.Sort(received) + for j, got := range received { + want := add(uint64(j+1), nodeIDs[i]) + if want != got { + t.Errorf("%s: received[%d] = %d, expected %d, nodeID=%d", test.name, j, got, want, nodeIDs[i]) } - expected++ } } }) diff --git a/testing_shared.go b/testing_shared.go index a8634e01..ccd73720 100644 --- a/testing_shared.go +++ b/testing_shared.go @@ -53,7 +53,7 @@ func TestQuorumCallError(_ testing.TB, nodeErrors map[uint32]error) QuorumCallEr // // Optional TestOptions can be provided to customize the manager, server, or configuration. // -// By default, nodes are assigned sequential IDs (0, 1, 2, ...) matching the server +// By default, nodes are assigned sequential IDs (1, 2, 3, ...) matching the server // creation order. This can be overridden by providing a NodeListOption. // // This is the recommended way to set up tests that need both servers and a configuration. From 202d6064b62635ebb2cb22fbb971e7ea406d9c8a Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 11 Feb 2026 20:51:59 +0100 Subject: [PATCH 14/19] refactor: simplify messageWithError logic for response message creation Don't reuse the in message as the response message as explained in a review comment; this can cause unnecessary bandwidth usage to send back the request payload to the client for no good reason. Review comment from Copilot: messageWithError reuses the incoming request message (in) when out is nil. That means error responses will echo the original request payload back to the client (wasted bandwidth, and surprising semantics). Prefer constructing a fresh response stream.Message (seqno+method+status, empty payload) or at least clear Payload/Entry when building an error-only response. --- encoding.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/encoding.go b/encoding.go index 0a5caecc..5ddd4712 100644 --- a/encoding.go +++ b/encoding.go @@ -1,7 +1,6 @@ package gorums import ( - "cmp" "fmt" "github.com/relab/gorums/internal/stream" @@ -68,18 +67,21 @@ func AsProto[T proto.Message](msg *Message) T { } // messageWithError ensures a response message exists and sets the error status. -// If out is nil, the in message (request) is reused to return the error status. -// This is used by the server to send error responses back to the client. +// If out is nil, a new response message is created based on the in request message; +// otherwise, out is modified in place. This is used by the server to send error +// responses back to the client. func messageWithError(in, out *Message, err error) *Message { - msg := cmp.Or(out, in) + if out == nil { + out = NewResponseMessage(in, nil) + } if err != nil { errStatus, ok := status.FromError(err) if !ok { errStatus = status.New(codes.Unknown, err.Error()) } - msg.SetStatus(errStatus.Proto()) + out.SetStatus(errStatus.Proto()) } - return msg + return out } // unmarshalRequest unmarshals the request proto message from the message. From de50d158191260d79d35f08621cd2339681a5654 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 11 Feb 2026 21:13:45 +0100 Subject: [PATCH 15/19] doc: corrected the AsProto documentation --- encoding.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/encoding.go b/encoding.go index 5ddd4712..f916bdc6 100644 --- a/encoding.go +++ b/encoding.go @@ -51,10 +51,9 @@ func NewResponseMessage(in *Message, resp proto.Message) *Message { } } -// AsProto extracts the payload from the message. -// If msg is nil or invalid, the zero value of T is returned. -// -// This function should only be used in generated code. +// AsProto returns the message's already-decoded proto message as type T. +// If the message is nil, or the underlying message cannot be asserted to T, +// the zero value of T is returned. func AsProto[T proto.Message](msg *Message) T { var zero T if msg == nil || msg.Msg == nil { From 4d9de3c150ae2658f9e658371c1f77e32eccec83 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 11 Feb 2026 21:17:30 +0100 Subject: [PATCH 16/19] chore: fix TestNewResponse to use mock.GetValueMethod string This resolves a code review complaint about incorrect gRPC-style path. --- encoding_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/encoding_test.go b/encoding_test.go index f6158681..4149e0bf 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -7,6 +7,7 @@ import ( "github.com/relab/gorums" "github.com/relab/gorums/internal/stream" "github.com/relab/gorums/internal/tests/config" + "github.com/relab/gorums/internal/testutils/mock" "google.golang.org/protobuf/testing/protocmp" ) @@ -16,11 +17,11 @@ func TestNewResponseMessage(t *testing.T) { streamIn := stream.Message_builder{ MessageSeqNo: 100, - Method: "/pkg.Svc/Call", + Method: mock.GetValueMethod, Payload: []byte("request payload"), Entry: []*stream.MetadataEntry{stream.MetadataEntry_builder{Key: "key1", Value: "val1"}.Build()}, }.Build() - streamOut := stream.Message_builder{MessageSeqNo: 100, Method: "/pkg.Svc/Call"}.Build() + streamOut := stream.Message_builder{MessageSeqNo: 100, Method: mock.GetValueMethod}.Build() tests := []struct { name string From a237be59fd7fdbfbb84e87c200f59cb5a8d51516 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 11 Feb 2026 21:25:24 +0100 Subject: [PATCH 17/19] chore: removed obsolete comment in clientCtxBuilder.Build() --- client_interceptor.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/client_interceptor.go b/client_interceptor.go index 9cd32d72..72860187 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -113,8 +113,6 @@ func (b *clientCtxBuilder[Req, Resp]) WithWaitSendDone(waitSendDone bool) *clien // It creates the metadata and reply channel, and sets up the appropriate response iterator. func (b *clientCtxBuilder[Req, Resp]) Build() *ClientCtx[Req, Resp] { // Assign a unique message ID and create the reply channel at build time. - // The stream.Message is created lazily in applyTransforms, where the - // request payload is marshaled together with the metadata. b.c.msgID = b.c.config.nextMsgID() b.c.replyChan = make(chan NodeResponse[msg], b.c.config.Size()*b.chanMultiplier) From 23f3e463cdd626b5755f70e8442f353e051a8166 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 11 Feb 2026 21:47:59 +0100 Subject: [PATCH 18/19] refactor: message handling in ClientCtx.send method This avoids the cmp.Or function here since it will always call the transformAndMarshal() method even if there is nothing to transform. --- client_interceptor.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/client_interceptor.go b/client_interceptor.go index 72860187..52cceac5 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -1,7 +1,6 @@ package gorums import ( - "cmp" "context" "slices" "sync" @@ -180,9 +179,9 @@ func (c *ClientCtx[Req, Resp]) applyInterceptors(interceptors []any) { // send dispatches requests to all nodes, applying any registered transformations. // It ensures that exactly one response (success or error) is sent per node on replyChan. func (c *ClientCtx[Req, Resp]) send() { - // Fast path: marshal once when no per-node transforms are registered. var sharedMsg *stream.Message if len(c.reqTransforms) == 0 { + // Fast path: marshal once when no per-node transforms are registered. var err error sharedMsg, err = stream.NewMessage(c.Context, c.msgID, c.method, c.request) if err != nil { @@ -195,7 +194,10 @@ func (c *ClientCtx[Req, Resp]) send() { } for _, n := range c.config { // transform only if there are registered transforms; otherwise reuse the shared message - streamMsg := cmp.Or(sharedMsg, c.transformAndMarshal(n)) + streamMsg := sharedMsg + if streamMsg == nil { + streamMsg = c.transformAndMarshal(n) + } if streamMsg == nil { continue // Skip node: transformAndMarshal already sent ErrSkipNode } From 3139e1261a07217acd587df941294cf6caaf95d2 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 11 Feb 2026 22:15:45 +0100 Subject: [PATCH 19/19] refactor: add ErrorStatus method in response handling --- channel.go | 2 +- internal/stream/gorums_message.go | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/channel.go b/channel.go index 32e89d19..edabcf00 100644 --- a/channel.go +++ b/channel.go @@ -285,7 +285,7 @@ func (c *channel) receiver() { c.cancelPendingMsgs(e) c.clearStream() } else { - err := status.FromProto(respMsg.GetStatus()).Err() + err := respMsg.ErrorStatus() var resp msg if err == nil { resp, err = unmarshalResponse(respMsg) diff --git a/internal/stream/gorums_message.go b/internal/stream/gorums_message.go index c1d2b8c7..b39269ad 100644 --- a/internal/stream/gorums_message.go +++ b/internal/stream/gorums_message.go @@ -4,6 +4,7 @@ import ( "context" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" ) @@ -48,3 +49,11 @@ func (x *Message) AppendToIncomingContext(ctx context.Context) context.Context { } return metadata.NewIncomingContext(ctx, newMD) } + +func (x *Message) ErrorStatus() error { + s := x.GetStatus() + if s == nil { + return nil + } + return status.ErrorProto(s) +}