diff --git a/.vscode/gorums.txt b/.vscode/gorums.txt index 2896c781f..e22ddbb89 100644 --- a/.vscode/gorums.txt +++ b/.vscode/gorums.txt @@ -69,6 +69,7 @@ pprof proto protobuf protoc +protocmp protodesc protogen protoimpl diff --git a/Makefile b/Makefile index 22b543127..8481cb28c 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ static_files := $(shell find $(dev_path) -name "*.go" -not -name "zorums*" -no proto_path := $(dev_path):third_party:. plugin_deps := gorums.pb.go $(static_file) -runtime_deps := ordering/ordering.pb.go ordering/ordering_grpc.pb.go +runtime_deps := internal/stream/stream.pb.go internal/stream/stream_grpc.pb.go benchmark_deps := benchmark/benchmark.pb.go benchmark/benchmark_gorums.pb.go .PHONY: all dev tools bootstrapgorums installgorums benchmark test compiletests genproto benchtest bench diff --git a/benchmark/benchmark_gorums.pb.go b/benchmark/benchmark_gorums.pb.go index e199144ac..954949cc8 100644 --- a/benchmark/benchmark_gorums.pb.go +++ b/benchmark/benchmark_gorums.pb.go @@ -168,32 +168,50 @@ func RegisterBenchmarkServer(srv *gorums.Server, impl BenchmarkServer) { srv.RegisterHandler("benchmark.Benchmark.StartServerBenchmark", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*StartRequest](in) resp, err := impl.StartServerBenchmark(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.StopServerBenchmark", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*StopRequest](in) resp, err := impl.StopServerBenchmark(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.StartBenchmark", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*StartRequest](in) resp, err := impl.StartBenchmark(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.StopBenchmark", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*StopRequest](in) resp, err := impl.StopBenchmark(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.QuorumCall", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Echo](in) resp, err := impl.QuorumCall(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.SlowServer", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Echo](in) resp, err := impl.SlowServer(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("benchmark.Benchmark.Multicast", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*TimedMsg](in) diff --git a/channel.go b/channel.go index 7b61ca378..edabcf004 100644 --- a/channel.go +++ b/channel.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "github.com/relab/gorums/ordering" + "github.com/relab/gorums/internal/stream" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" @@ -44,9 +44,9 @@ var ( type request struct { ctx context.Context - msg *Message - waitSendDone bool + msg *stream.Message streaming bool + waitSendDone bool responseChan chan<- NodeResponse[msg] sendTime time.Time } @@ -68,8 +68,8 @@ type channel struct { // Stream lifecycle management for FIFO ordered message delivery // stream is a bidirectional stream for - // sending and receiving ordering.Metadata messages. - stream ordering.Gorums_NodeStreamClient + // sending and receiving stream.Message messages. + stream stream.Gorums_NodeStreamClient streamMut sync.Mutex streamCtx context.Context streamCancel context.CancelFunc @@ -149,12 +149,12 @@ func (c *channel) ensureConnectedNodeStream() (err error) { return nil } c.streamCtx, c.streamCancel = context.WithCancel(c.connCtx) - c.stream, err = ordering.NewGorumsClient(c.conn).NodeStream(c.streamCtx) + c.stream, err = stream.NewGorumsClient(c.conn).NodeStream(c.streamCtx) return err } // getStream returns the current stream, or nil if no stream is available. -func (c *channel) getStream() grpc.ClientStream { +func (c *channel) getStream() stream.Gorums_NodeStreamClient { c.streamMut.Lock() defer c.streamMut.Unlock() return c.stream @@ -180,7 +180,7 @@ func (c *channel) isConnected() bool { func (c *channel) enqueue(req request) { if req.responseChan != nil { req.sendTime = time.Now() - msgID := req.msg.GetMessageID() + msgID := req.msg.GetMessageSeqNo() c.responseMut.Lock() c.responseRouters[msgID] = req c.responseMut.Unlock() @@ -190,7 +190,7 @@ func (c *channel) enqueue(req request) { select { case <-c.connCtx.Done(): // the node's close() method was called: respond with error instead of enqueueing - c.routeResponse(req.msg.GetMessageID(), NodeResponse[msg]{NodeID: c.id, Err: nodeClosedErr}) + c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: nodeClosedErr}) return case c.sendQ <- req: // enqueued successfully @@ -252,11 +252,11 @@ func (c *channel) sender() { // take next request from sendQ } if err := c.ensureStream(); err != nil { - c.routeResponse(req.msg.GetMessageID(), NodeResponse[msg]{NodeID: c.id, Err: err}) + c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: err}) continue } if err := c.sendMsg(req); err != nil { - c.routeResponse(req.msg.GetMessageID(), NodeResponse[msg]{NodeID: c.id, Err: err}) + c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Err: err}) } } } @@ -279,14 +279,18 @@ func (c *channel) receiver() { } } - resp := newMessage(responseType) - if err := stream.RecvMsg(resp); err != nil { - c.setLastErr(err) - c.cancelPendingMsgs(err) + respMsg, e := stream.Recv() + if e != nil { + c.setLastErr(e) + c.cancelPendingMsgs(e) c.clearStream() } else { - err := resp.GetStatus().Err() - c.routeResponse(resp.GetMessageID(), NodeResponse[msg]{NodeID: c.id, Value: resp.GetProtoMessage(), Err: err}) + err := respMsg.ErrorStatus() + var resp msg + if err == nil { + resp, err = unmarshalResponse(respMsg) + } + c.routeResponse(respMsg.GetMessageSeqNo(), NodeResponse[msg]{NodeID: c.id, Value: resp, Err: err}) } select { @@ -316,7 +320,7 @@ func (c *channel) sendMsg(req request) (err error) { // wait for actual server responses, so waitSendDone is false for them. if req.waitSendDone && err == nil { // Send succeeded: unblock the caller and clean up the responseRouter - c.routeResponse(req.msg.GetMessageID(), NodeResponse[msg]{}) + c.routeResponse(req.msg.GetMessageSeqNo(), NodeResponse[msg]{}) } }() @@ -353,7 +357,7 @@ func (c *channel) sendMsg(req request) (err error) { } }() - if err = stream.SendMsg(req.msg); err != nil { + if err = stream.Send(req.msg); err != nil { c.setLastErr(err) c.clearStream() } diff --git a/channel_test.go b/channel_test.go index 1da69fe60..ec7aee29d 100644 --- a/channel_test.go +++ b/channel_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/relab/gorums/internal/stream" "github.com/relab/gorums/internal/testutils/mock" "google.golang.org/grpc" pb "google.golang.org/protobuf/types/known/wrapperspb" @@ -48,7 +49,10 @@ func delayServerFn(delay time.Duration) func(_ int) ServerIface { time.Sleep(delay) req := AsProto[*pb.StringValue](in) resp, err := mockSrv.Test(ctx, req) - return NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return NewResponseMessage(in, resp), nil }) return srv } @@ -59,7 +63,11 @@ func sendRequest(t testing.TB, node *Node, req request, msgID uint64) NodeRespon if req.ctx == nil { req.ctx = t.Context() } - req.msg = NewRequest(req.ctx, msgID, mock.TestMethod, nil) + reqMsg, err := stream.NewMessage(req.ctx, msgID, mock.TestMethod, nil) + if err != nil { + t.Fatalf("NewMessage failed: %v", err) + } + req.msg = reqMsg replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -593,7 +601,8 @@ func TestChannelDeadlock(t *testing.T) { for id := range 10 { go func() { ctx := TestContext(t, 3*time.Second) - req := request{ctx: ctx, msg: NewRequest(ctx, uint64(100+id), mock.TestMethod, nil)} + reqMsg, _ := stream.NewMessage(ctx, uint64(100+id), mock.TestMethod, nil) + req := request{ctx: ctx, msg: reqMsg} // try to enqueue select { @@ -798,7 +807,8 @@ func BenchmarkChannelStreamReadyFirstRequest(b *testing.B) { // Use a fresh context for the benchmark request ctx := TestContext(b, defaultTestTimeout) - req := request{ctx: ctx, msg: NewRequest(ctx, 1, mock.TestMethod, nil)} + reqMsg, _ := stream.NewMessage(ctx, 1, mock.TestMethod, nil) + req := request{ctx: ctx, msg: reqMsg} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -847,7 +857,8 @@ func BenchmarkChannelStreamReadyReconnect(b *testing.B) { // Establish initial stream with a fresh context ctx := context.Background() - req := request{ctx: ctx, msg: NewRequest(ctx, 0, mock.TestMethod, nil)} + reqMsg, _ := stream.NewMessage(ctx, 0, mock.TestMethod, nil) + req := request{ctx: ctx, msg: reqMsg} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -872,7 +883,8 @@ func BenchmarkChannelStreamReadyReconnect(b *testing.B) { // Now send a request which will trigger ensureStream -> newNodeStream -> signal ctx := context.Background() - req := request{ctx: ctx, msg: NewRequest(ctx, uint64(i+1), mock.TestMethod, nil)} + reqMsg, _ := stream.NewMessage(ctx, uint64(i+1), mock.TestMethod, nil) + req := request{ctx: ctx, msg: reqMsg} replyChan := make(chan NodeResponse[msg], 1) req.responseChan = replyChan node.channel.enqueue(req) diff --git a/client_interceptor.go b/client_interceptor.go index 938d46487..52cceac54 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -5,7 +5,7 @@ import ( "slices" "sync" - "github.com/relab/gorums/ordering" + "github.com/relab/gorums/internal/stream" "google.golang.org/protobuf/proto" ) @@ -42,7 +42,7 @@ type ClientCtx[Req, Resp msg] struct { config Configuration request Req method string - md *ordering.Metadata + msgID uint64 replyChan chan NodeResponse[msg] // reqTransforms holds request transformation functions registered by interceptors. @@ -52,11 +52,6 @@ type ClientCtx[Req, Resp msg] struct { // Interceptors can wrap this iterator to modify responses. responseSeq ResponseSeq[Resp] - // expectedReplies is the number of responses expected from nodes. - // It is set when messages are sent and may be lower than config size - // if some nodes are skipped by request transformations. - expectedReplies int - // streaming indicates whether this is a streaming call (for correctable streams). streaming bool @@ -87,11 +82,10 @@ func newClientCtxBuilder[Req, Resp msg]( ) *clientCtxBuilder[Req, Resp] { return &clientCtxBuilder[Req, Resp]{ c: &ClientCtx[Req, Resp]{ - Context: ctx, - config: ctx.Configuration(), - request: req, - method: method, - expectedReplies: ctx.Configuration().Size(), + Context: ctx, + config: ctx.Configuration(), + request: req, + method: method, }, chanMultiplier: 1, } @@ -117,8 +111,8 @@ func (b *clientCtxBuilder[Req, Resp]) WithWaitSendDone(waitSendDone bool) *clien // Build finalizes the ClientCtx configuration and returns the constructed instance. // It creates the metadata and reply channel, and sets up the appropriate response iterator. func (b *clientCtxBuilder[Req, Resp]) Build() *ClientCtx[Req, Resp] { - // Create metadata and reply channel at build time - b.c.md = ordering.NewGorumsMetadata(b.c.Context, b.c.config.nextMsgID(), b.c.method) + // Assign a unique message ID and create the reply channel at build time. + b.c.msgID = b.c.config.nextMsgID() b.c.replyChan = make(chan NodeResponse[msg], b.c.config.Size()*b.chanMultiplier) if b.c.streaming { @@ -170,23 +164,6 @@ func (c *ClientCtx[Req, Resp]) Size() int { return c.config.Size() } -// applyTransforms returns the transformed request as a proto.Message, or nil if the result is -// invalid or the node should be skipped. It applies the registered transformation functions to -// the given request for the specified node. Transformation functions are applied in the order -// they were registered. -func (c *ClientCtx[Req, Resp]) applyTransforms(req Req, node *Node) proto.Message { - result := req - for _, transform := range c.reqTransforms { - result = transform(result, node) - } - if protoMsg, ok := any(result).(proto.Message); ok { - if protoMsg.ProtoReflect().IsValid() { - return protoMsg - } - } - return nil -} - // applyInterceptors chains the given interceptors, wrapping the response sequence. // Each interceptor receives the current response sequence and returns a new one. // Interceptors are applied in order, with each wrapping the previous result. @@ -200,28 +177,59 @@ func (c *ClientCtx[Req, Resp]) applyInterceptors(interceptors []any) { } // send dispatches requests to all nodes, applying any registered transformations. -// It updates expectedReplies based on how many nodes actually receive requests -// (nodes may be skipped if a transformation returns nil). +// It ensures that exactly one response (success or error) is sent per node on replyChan. func (c *ClientCtx[Req, Resp]) send() { - var expected int + var sharedMsg *stream.Message + if len(c.reqTransforms) == 0 { + // Fast path: marshal once when no per-node transforms are registered. + var err error + sharedMsg, err = stream.NewMessage(c.Context, c.msgID, c.method, c.request) + if err != nil { + // Marshaling fails identically for all nodes; report and return. + for _, n := range c.config { + c.replyChan <- NodeResponse[msg]{NodeID: n.ID(), Err: err} + } + return + } + } for _, n := range c.config { - msg := c.applyTransforms(c.request, n) - if msg == nil { - continue // Skip node if transformation returns nil + // transform only if there are registered transforms; otherwise reuse the shared message + streamMsg := sharedMsg + if streamMsg == nil { + streamMsg = c.transformAndMarshal(n) + } + if streamMsg == nil { + continue // Skip node: transformAndMarshal already sent ErrSkipNode } - expected++ - // Clone metadata for each request to avoid race conditions during - // concurrent marshaling when gorumsMarshal calls SetMessageData. - md := proto.CloneOf(c.md) n.channel.enqueue(request{ ctx: c.Context, - msg: newRequestMessage(md, msg), + msg: streamMsg, streaming: c.streaming, waitSendDone: c.waitSendDone, responseChan: c.replyChan, }) } - c.expectedReplies = expected +} + +// transformAndMarshal applies transformations to the request for the given node, +// then marshals it into a stream.Message. Returns nil if transformation fails +// or marshaling fails (in which case the error is sent on replyChan). +func (c *ClientCtx[Req, Resp]) transformAndMarshal(n *Node) *stream.Message { + result := c.request + for _, transform := range c.reqTransforms { + result = transform(result, n) + } + // Check if the result is valid + if protoReq, ok := any(result).(proto.Message); !ok || protoReq == nil || !protoReq.ProtoReflect().IsValid() { + c.replyChan <- NodeResponse[msg]{NodeID: n.ID(), Err: ErrSkipNode} + return nil + } + streamMsg, err := stream.NewMessage(c.Context, c.msgID, c.method, result) + if err != nil { + c.replyChan <- NodeResponse[msg]{NodeID: n.ID(), Err: err} + return nil + } + return streamMsg } // defaultResponseSeq returns an iterator that yields at most c.expectedReplies responses @@ -230,7 +238,7 @@ func (c *ClientCtx[Req, Resp]) defaultResponseSeq() ResponseSeq[Resp] { return func(yield func(NodeResponse[Resp]) bool) { // Trigger sending on first iteration c.sendOnce.Do(c.send) - for range c.expectedReplies { + for range c.Size() { select { case r := <-c.replyChan: res := newNodeResponse[Resp](r) @@ -273,7 +281,7 @@ func (c *ClientCtx[Req, Resp]) streamingResponseSeq() ResponseSeq[Resp] { // // The fn receives the original request and a node, and returns the transformed // request to send to that node. If the function returns an invalid message or nil, -// the request to that node is skipped. +// an ErrSkipNode error is sent for that node, indicating it was skipped. func MapRequest[Req, Resp msg](fn func(Req, *Node) Req) QuorumInterceptor[Req, Resp] { return func(ctx *ClientCtx[Req, Resp], next ResponseSeq[Resp]) ResponseSeq[Resp] { if fn != nil { diff --git a/client_interceptor_test.go b/client_interceptor_test.go index 9755da869..9bc7513a7 100644 --- a/client_interceptor_test.go +++ b/client_interceptor_test.go @@ -1,6 +1,7 @@ package gorums_test import ( + "fmt" "testing" "time" @@ -179,3 +180,73 @@ func TestCustomInterceptorWithMapRequest(t *testing.T) { t.Errorf("Expected 3 responses counted, got %d", count) } } + +// BenchmarkQuorumCallMapRequest compares the performance of quorum calls +// with and without per-node MapRequest interceptors. +// Without MapRequest, the request is marshaled once and shared across all nodes. +// With MapRequest, each node gets a cloned message with individually marshaled payload. +func BenchmarkQuorumCallMapRequest(b *testing.B) { + for _, numNodes := range []int{3, 7, 13} { + config := gorums.TestConfiguration(b, numNodes, gorums.EchoServerFn) + cfgCtx := config.Context(b.Context()) + + // Baseline: no interceptors — single marshal, shared message + b.Run(fmt.Sprintf("NoTransform/%d", numNodes), func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + cfgCtx, + pb.String("benchmark payload"), + mock.TestMethod, + ) + if _, err := responses.Majority(); err != nil { + b.Fatal(err) + } + } + }) + + // With MapRequest identity transform — per-node clone + marshal + b.Run(fmt.Sprintf("MapRequestIdentity/%d", numNodes), func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + cfgCtx, + pb.String("benchmark payload"), + mock.TestMethod, + gorums.Interceptors( + gorums.MapRequest[*pb.StringValue, *pb.StringValue]( + func(req *pb.StringValue, _ *gorums.Node) *pb.StringValue { + return req + }, + ), + ), + ) + if _, err := responses.Majority(); err != nil { + b.Fatal(err) + } + } + }) + + // With MapRequest that modifies the request per node + b.Run(fmt.Sprintf("MapRequestPerNode/%d", numNodes), func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + cfgCtx, + pb.String("benchmark payload"), + mock.TestMethod, + gorums.Interceptors( + gorums.MapRequest[*pb.StringValue, *pb.StringValue]( + func(req *pb.StringValue, n *gorums.Node) *pb.StringValue { + return pb.String(fmt.Sprintf("%s-node-%d", req.GetValue(), n.ID())) + }, + ), + ), + ) + if _, err := responses.Majority(); err != nil { + b.Fatal(err) + } + } + }) + } +} diff --git a/cmd/protoc-gen-gorums/dev/generated_code_test.go b/cmd/protoc-gen-gorums/dev/generated_code_test.go index d0d52e5b3..5c53cd7a3 100644 --- a/cmd/protoc-gen-gorums/dev/generated_code_test.go +++ b/cmd/protoc-gen-gorums/dev/generated_code_test.go @@ -25,7 +25,7 @@ func quorumCallServer(_ int) gorums.ServerIface { req := gorums.AsProto[*dev.Request](in) resp := &dev.Response{} resp.SetResult(int64(len(req.GetValue()))) - return gorums.NewResponseMessage(in.GetMetadata(), resp), nil + return gorums.NewResponseMessage(in, resp), nil }) return srv } diff --git a/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go b/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go index 987597727..5bee3e2df 100644 --- a/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go +++ b/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go @@ -8,7 +8,6 @@ package dev import ( gorums "github.com/relab/gorums" - proto "google.golang.org/protobuf/proto" emptypb "google.golang.org/protobuf/types/known/emptypb" ) @@ -40,22 +39,34 @@ func RegisterZorumsServiceServer(srv *gorums.Server, impl ZorumsServiceServer) { srv.RegisterHandler("dev.ZorumsService.GRPCCall", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.GRPCCall(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("dev.ZorumsService.QuorumCall", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.QuorumCall(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("dev.ZorumsService.QuorumCallEmpty", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*emptypb.Empty](in) resp, err := impl.QuorumCallEmpty(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("dev.ZorumsService.QuorumCallEmpty2", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.QuorumCallEmpty2(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("dev.ZorumsService.Multicast", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) @@ -80,27 +91,24 @@ func RegisterZorumsServiceServer(srv *gorums.Server, impl ZorumsServiceServer) { srv.RegisterHandler("dev.ZorumsService.QuorumCallStream", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) err := impl.QuorumCallStream(ctx, req, func(resp *Response) error { - // create a copy of the metadata, to avoid a data race between NewResponseMessage and SendMsg - md := proto.CloneOf(in.GetMetadata()) - return ctx.SendMessage(gorums.NewResponseMessage(md, resp)) + out := gorums.NewResponseMessage(in, resp) + return ctx.SendMessage(out) }) return nil, err }) srv.RegisterHandler("dev.ZorumsService.QuorumCallStreamWithEmpty", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) err := impl.QuorumCallStreamWithEmpty(ctx, req, func(resp *emptypb.Empty) error { - // create a copy of the metadata, to avoid a data race between NewResponseMessage and SendMsg - md := proto.CloneOf(in.GetMetadata()) - return ctx.SendMessage(gorums.NewResponseMessage(md, resp)) + out := gorums.NewResponseMessage(in, resp) + return ctx.SendMessage(out) }) return nil, err }) srv.RegisterHandler("dev.ZorumsService.QuorumCallStreamWithEmpty2", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*emptypb.Empty](in) err := impl.QuorumCallStreamWithEmpty2(ctx, req, func(resp *Response) error { - // create a copy of the metadata, to avoid a data race between NewResponseMessage and SendMsg - md := proto.CloneOf(in.GetMetadata()) - return ctx.SendMessage(gorums.NewResponseMessage(md, resp)) + out := gorums.NewResponseMessage(in, resp) + return ctx.SendMessage(out) }) return nil, err }) diff --git a/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go b/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go index becfa5d8d..d195bfc8e 100644 --- a/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go +++ b/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go @@ -29,7 +29,7 @@ var importMap = map[string]protogen.GoImportPath{ "backoff": protogen.GoImportPath("google.golang.org/grpc/backoff"), "proto": protogen.GoImportPath("google.golang.org/protobuf/proto"), "gorums": protogen.GoImportPath("github.com/relab/gorums"), - "ordering": protogen.GoImportPath("github.com/relab/gorums/ordering"), + "stream": protogen.GoImportPath("github.com/relab/gorums/internal/stream"), "protoreflect": protogen.GoImportPath("google.golang.org/protobuf/reflect/protoreflect"), } diff --git a/cmd/protoc-gen-gorums/gengorums/template_server.go b/cmd/protoc-gen-gorums/gengorums/template_server.go index 17aee7596..39f30d6cd 100644 --- a/cmd/protoc-gen-gorums/gengorums/template_server.go +++ b/cmd/protoc-gen-gorums/gengorums/template_server.go @@ -39,14 +39,16 @@ func Register{{$service}}Server(srv *{{use "gorums.Server" $genFile}}, impl {{$s return nil, nil {{- else if isStreamingServer .}} err := impl.{{.GoName}}(ctx, req, func(resp *{{out $genFile .}}) error { - // create a copy of the metadata, to avoid a data race between NewResponseMessage and SendMsg - md := {{use "proto.CloneOf" $genFile}}(in.GetMetadata()) - return ctx.SendMessage({{$newMessage}}(md, resp)) + out := {{$newMessage}}(in, resp) + return ctx.SendMessage(out) }) return nil, err {{- else }} resp, err := impl.{{.GoName}}(ctx, req) - return {{$newMessage}}(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return {{$newMessage}}(in, resp), nil {{- end}} }) {{- end}} diff --git a/config_test.go b/config_test.go index eb40df16b..c43b37556 100644 --- a/config_test.go +++ b/config_test.go @@ -5,15 +5,8 @@ import ( "testing" "github.com/relab/gorums" - "google.golang.org/grpc/encoding" ) -func init() { - if encoding.GetCodec(gorums.ContentSubtype) == nil { - encoding.RegisterCodec(gorums.NewCodec()) - } -} - var ( nodes = []string{"127.0.0.1:9081", "127.0.0.1:9082", "127.0.0.1:9083"} nodeMap = map[uint32]testNode{ diff --git a/encoding.go b/encoding.go index 17805d750..f916bdc64 100644 --- a/encoding.go +++ b/encoding.go @@ -1,234 +1,142 @@ package gorums import ( - "context" "fmt" - "github.com/relab/gorums/ordering" + "github.com/relab/gorums/internal/stream" "google.golang.org/grpc/codes" - "google.golang.org/grpc/encoding" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" ) -func init() { - encoding.RegisterCodec(NewCodec()) -} - -// ContentSubtype is the subtype used by gorums when sending messages via gRPC. -const ContentSubtype = "gorums" - -type gorumsMsgType uint8 - -const ( - requestType gorumsMsgType = iota + 1 - responseType -) - -// Message encapsulates a protobuf message and metadata. -// -// This struct should be used by generated code only. +// Message encapsulates the stream.Message and the actual proto.Message. type Message struct { - metadata *ordering.Metadata - message proto.Message - msgType gorumsMsgType + Msg proto.Message + *stream.Message } -// newMessage creates a new Message struct for unmarshaling. -// msgType specifies the message type to be unmarshaled. -func newMessage(msgType gorumsMsgType) *Message { - return &Message{metadata: &ordering.Metadata{}, msgType: msgType} -} +// MetadataEntry is a type alias for stream.MetadataEntry. +type MetadataEntry = stream.MetadataEntry -// newRequestMessage creates a new Gorums Message for the given metadata and request message. -// -// This function should be used by generated code and tests only. -func newRequestMessage(md *ordering.Metadata, req proto.Message) *Message { - return &Message{metadata: md, message: req, msgType: requestType} -} +// MetadataEntry_builder is a type alias for stream.MetadataEntry_builder. +type MetadataEntry_builder = stream.MetadataEntry_builder -// NewRequest creates a new Gorums Message for the given context, message ID, method, and request. -// This is a convenience function that combines NewGorumsMetadata and NewRequestMessage. +// NewResponseMessage creates a new response message based on the provided proto +// message. The response message includes the message ID and method from the request +// message to facilitate routing the response back to the caller on the client side. +// The payload, error status, and metadata entries are left empty; the error status +// of the response message can be set using [messageWithError], and the payload will +// be marshaled by [ServerCtx.SendMessage]. This function is safe for concurrent use. // -// This function should be used by generated code and tests only. -func NewRequest(ctx context.Context, msgID uint64, method string, req proto.Message) *Message { - md := ordering.NewGorumsMetadata(ctx, msgID, method) - return &Message{metadata: md, message: req, msgType: requestType} -} - -// NewResponseMessage creates a new Gorums Message for the given metadata and response message. -// -// This function should be used by generated code only. -func NewResponseMessage(md *ordering.Metadata, resp proto.Message) *Message { - return &Message{metadata: md, message: resp, msgType: responseType} +// This function should only be used in generated code. +func NewResponseMessage(in *Message, resp proto.Message) *Message { + if in == nil { + return nil + } + // Create a new stream.Message to avoid race conditions when the sender + // goroutine marshals while the handler creates the next response. + // This can happen for stream-based quorum calls where the handler can + // call SendMessage multiple times before returning. + msgBuilder := stream.Message_builder{ + MessageSeqNo: in.GetMessageSeqNo(), // needed in channel.routeResponse to lookup the response channel + Method: in.GetMethod(), // needed in unmarshalResponse to look up the response type in the proto registry + // Payload is left empty; SendMessage will marshal resp into the payload when sending the message + // Status is left empty; it can be set by messageWithError if needed + } + return &Message{ + Msg: resp, + Message: msgBuilder.Build(), + } } -// AsProto returns msg's underlying protobuf message of the specified type T. -// If msg is nil or the contained message is not of type T, the zero value of T is returned. +// AsProto returns the message's already-decoded proto message as type T. +// If the message is nil, or the underlying message cannot be asserted to T, +// the zero value of T is returned. func AsProto[T proto.Message](msg *Message) T { var zero T - if msg == nil || msg.message == nil { + if msg == nil || msg.Msg == nil { return zero } - if req, ok := msg.message.(T); ok { + if req, ok := msg.Msg.(T); ok { return req } return zero } -// GetProtoMessage returns the protobuf message contained in the Message. -func (m *Message) GetProtoMessage() proto.Message { - if m == nil { - return nil - } - return m.message -} - -// GetMetadata returns the metadata of the message. -func (m *Message) GetMetadata() *ordering.Metadata { - if m == nil { - return nil - } - return m.metadata -} - -// GetMethod returns the method name from the message metadata. -func (m *Message) GetMethod() string { - if m == nil { - return "nil" - } - return m.metadata.GetMethod() -} - -// GetMessageID returns the message ID from the message metadata. -func (m *Message) GetMessageID() uint64 { - if m == nil { - return 0 - } - return m.metadata.GetMessageSeqNo() -} - -func (m *Message) GetStatus() *status.Status { - if m == nil { - return status.New(codes.Unknown, "nil message") +// messageWithError ensures a response message exists and sets the error status. +// If out is nil, a new response message is created based on the in request message; +// otherwise, out is modified in place. This is used by the server to send error +// responses back to the client. +func messageWithError(in, out *Message, err error) *Message { + if out == nil { + out = NewResponseMessage(in, nil) } - return status.FromProto(m.metadata.GetStatus()) -} - -// setError sets the error status in the message metadata in preparation for sending -// the response to the client. The provided error may include several wrapped errors. -// If err is nil, the status is set to OK. -// This method should be called just prior to sending the response to the client. -func (m *Message) setError(err error) { - errStatus, ok := status.FromError(err) - if !ok { - errStatus = status.New(codes.Unknown, err.Error()) + if err != nil { + errStatus, ok := status.FromError(err) + if !ok { + errStatus = status.New(codes.Unknown, err.Error()) + } + out.SetStatus(errStatus.Proto()) } - m.metadata.SetStatus(errStatus.Proto()) -} - -// Codec is the gRPC codec used by gorums. -type Codec struct { - marshaler proto.MarshalOptions - unmarshaler proto.UnmarshalOptions + return out } -// NewCodec returns a new Codec. -func NewCodec() *Codec { - return &Codec{ - marshaler: proto.MarshalOptions{AllowPartial: true}, - unmarshaler: proto.UnmarshalOptions{AllowPartial: true}, +// unmarshalRequest unmarshals the request proto message from the message. +// It uses the method name in the message to look up the Input type from the proto registry. +// +// This function should only be used by internal channel operations. +func unmarshalRequest(in *stream.Message) (proto.Message, error) { + // get method descriptor from registry + desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(in.GetMethod())) + if err != nil { + return nil, fmt.Errorf("gorums: could not find method descriptor for %s", in.GetMethod()) } -} - -// Name returns the name of the Codec. -func (Codec) Name() string { - return ContentSubtype -} - -func (Codec) String() string { - return ContentSubtype -} + methodDesc := desc.(protoreflect.MethodDescriptor) -// Marshal marshals the message m into a byte slice. -func (c Codec) Marshal(m any) (b []byte, err error) { - switch msg := m.(type) { - case *Message: - return c.gorumsMarshal(msg) - case proto.Message: - return c.marshaler.Marshal(msg) - default: - return nil, fmt.Errorf("gorums: cannot marshal message of type '%T'", m) + // get the request message type (Input type) + msgType, err := protoregistry.GlobalTypes.FindMessageByName(methodDesc.Input().FullName()) + if err != nil { + return nil, fmt.Errorf("gorums: could not find message type %s", methodDesc.Input().FullName()) } -} + req := msgType.New().Interface() -// gorumsMarshal marshals a Message by serializing the application message -// into metadata.message_data, then marshaling the metadata. -func (c Codec) gorumsMarshal(msg *Message) (b []byte, err error) { - // serialize the application message into metadata.message_data - if msg.message != nil { - msgData, err := c.marshaler.Marshal(msg.message) - if err != nil { - return nil, fmt.Errorf("gorums: could not marshal message: %w", err) + // unmarshal message from the Message.Payload field + payload := in.GetPayload() + if len(payload) > 0 { + if err := proto.Unmarshal(payload, req); err != nil { + return nil, fmt.Errorf("gorums: could not unmarshal request: %w", err) } - msg.metadata.SetMessageData(msgData) - } - return c.marshaler.Marshal(msg.metadata) -} - -// Unmarshal unmarshals a byte slice into m. -func (c Codec) Unmarshal(b []byte, m any) (err error) { - switch msg := m.(type) { - case *Message: - return c.gorumsUnmarshal(b, msg) - case proto.Message: - return c.unmarshaler.Unmarshal(b, msg) - default: - return fmt.Errorf("gorums: cannot unmarshal message of type '%T'", m) } + return req, nil } -// gorumsUnmarshal extracts metadata and message data from b and places the result in msg. -func (c Codec) gorumsUnmarshal(b []byte, msg *Message) (err error) { - // unmarshal metadata - err = c.unmarshaler.Unmarshal(b, msg.metadata) - if err != nil { - return fmt.Errorf("gorums: could not unmarshal metadata: %w", err) - } - +// unmarshalResponse unmarshals the response proto message from the message. +// It uses the method name in the message to look up the Output type from the proto registry. +// +// This function should only be used by internal channel operations. +func unmarshalResponse(out *stream.Message) (proto.Message, error) { // get method descriptor from registry - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(msg.GetMethod())) + desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(out.GetMethod())) if err != nil { - // err is a NotFound error with no method name information; return a more informative error - return fmt.Errorf("gorums: could not find method descriptor for %s", msg.GetMethod()) + return nil, fmt.Errorf("gorums: could not find method descriptor for %s", out.GetMethod()) } methodDesc := desc.(protoreflect.MethodDescriptor) - // get message name depending on whether we are creating a request or response message - var messageName protoreflect.FullName - switch msg.msgType { - case requestType: - messageName = methodDesc.Input().FullName() - case responseType: - messageName = methodDesc.Output().FullName() - default: - return fmt.Errorf("gorums: unknown message type %d", msg.msgType) - } - - // now get the message type from the types registry - msgType, err := protoregistry.GlobalTypes.FindMessageByName(messageName) + // get the response message type (Output type) + msgType, err := protoregistry.GlobalTypes.FindMessageByName(methodDesc.Output().FullName()) if err != nil { - // err is a NotFound error with no message name information; return a more informative error - return fmt.Errorf("gorums: could not find message type %s", messageName) + return nil, fmt.Errorf("gorums: could not find message type %s", methodDesc.Output().FullName()) } - msg.message = msgType.New().Interface() + resp := msgType.New().Interface() - // unmarshal message from metadata.message_data - msgData := msg.metadata.GetMessageData() - if len(msgData) > 0 { - return c.unmarshaler.Unmarshal(msgData, msg.message) + // unmarshal message from the Message.Payload field + payload := out.GetPayload() + if len(payload) > 0 { + if err := proto.Unmarshal(payload, resp); err != nil { + return nil, fmt.Errorf("gorums: could not unmarshal response: %w", err) + } } - return nil + return resp, nil } diff --git a/encoding_test.go b/encoding_test.go index 950fab259..4149e0bfa 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -3,10 +3,98 @@ package gorums_test import ( "testing" + "github.com/google/go-cmp/cmp" "github.com/relab/gorums" + "github.com/relab/gorums/internal/stream" "github.com/relab/gorums/internal/tests/config" + "github.com/relab/gorums/internal/testutils/mock" + "google.golang.org/protobuf/testing/protocmp" ) +func TestNewResponseMessage(t *testing.T) { + req := config.Request_builder{Num: 99}.Build() + resp := config.Response_builder{Name: "test", Num: 42}.Build() + + streamIn := stream.Message_builder{ + MessageSeqNo: 100, + Method: mock.GetValueMethod, + Payload: []byte("request payload"), + Entry: []*stream.MetadataEntry{stream.MetadataEntry_builder{Key: "key1", Value: "val1"}.Build()}, + }.Build() + streamOut := stream.Message_builder{MessageSeqNo: 100, Method: mock.GetValueMethod}.Build() + + tests := []struct { + name string + in *gorums.Message + resp *config.Response + want *gorums.Message + }{ + { + name: "NilIn/NilResp/NilOut", + in: nil, + resp: nil, + want: nil, + }, + { + name: "NilIn/Resp/NilOut", + in: nil, + resp: resp, + want: nil, + }, + { + name: "NilReq/NilResp/StreamIn/StreamOut", + in: &gorums.Message{Msg: nil, Message: streamIn}, + resp: nil, + want: &gorums.Message{Msg: (*config.Response)(nil), Message: streamOut}, + }, + { + name: "NilReq/Resp/StreamIn/StreamOut", + in: &gorums.Message{Msg: nil, Message: streamIn}, + resp: resp, + want: &gorums.Message{Msg: resp, Message: streamOut}, + }, + { + name: "Req/NilResp/StreamIn/StreamOut", + in: &gorums.Message{Msg: req, Message: streamIn}, + resp: nil, + want: &gorums.Message{Msg: (*config.Response)(nil), Message: streamOut}, + }, + { + name: "Req/Resp/StreamIn/StreamOut", + in: &gorums.Message{Msg: req, Message: streamIn}, + resp: resp, + want: &gorums.Message{Msg: resp, Message: streamOut}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := gorums.NewResponseMessage(tt.in, tt.resp) + if tt.want == nil { + if got != nil { + t.Errorf("NewResponseMessage returned %v, want nil", got) + } + return + } + if got == nil { + t.Fatalf("NewResponseMessage returned nil, want non-nil") + } + // Compare Msg field - handle nil specially + if (tt.want.Msg == nil) != (got.Msg == nil) { + t.Errorf("Msg field: want nil=%v, got nil=%v", tt.want.Msg == nil, got.Msg == nil) + } else if tt.want.Msg != nil && got.Msg != nil { + if diff := cmp.Diff(tt.want.Msg, got.Msg, protocmp.Transform()); diff != "" { + t.Errorf("Msg field mismatch (-want, +got):\n%s", diff) + } + } + // Compare the stream.Message field + if diff := cmp.Diff(tt.want.Message, got.Message, protocmp.Transform()); diff != "" { + t.Errorf("Message field mismatch (-want, +got):\n%s", diff) + } + }) + } +} + func TestAsProto(t *testing.T) { t.Parallel() @@ -18,7 +106,7 @@ func TestAsProto(t *testing.T) { }{ { name: "Success", - msg: gorums.NewRequest(t.Context(), 0, "", config.Request_builder{Num: 42}.Build()), + msg: gorums.NewResponseMessage(&gorums.Message{}, config.Response_builder{Name: "test", Num: 42}.Build()), wantNil: false, wantNum: 42, }, @@ -29,7 +117,7 @@ func TestAsProto(t *testing.T) { }, { name: "WrongType", - msg: gorums.NewResponseMessage(nil, config.Response_builder{Name: "test", Num: 99}.Build()), + msg: gorums.NewResponseMessage(&gorums.Message{}, config.Request_builder{Num: 99}.Build()), wantNil: true, }, } @@ -37,18 +125,18 @@ func TestAsProto(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() - req := gorums.AsProto[*config.Request](tc.msg) + req := gorums.AsProto[*config.Response](tc.msg) if tc.wantNil { if req != nil { - t.Errorf("AsProto returned %v, want nil", req) + t.Errorf("AsProto(%v) returned %v, want nil", tc.msg, req) } return } if req == nil { - t.Errorf("AsProto returned nil, want *config.Request") + t.Errorf("AsProto(%v) returned nil, want *config.Response", tc.msg) } if got := req.GetNum(); got != tc.wantNum { - t.Errorf("Num = %d, want %d", got, tc.wantNum) + t.Errorf("Num() = %d, want %d", got, tc.wantNum) } }) } diff --git a/errors.go b/errors.go index b22018584..0834ad9c5 100644 --- a/errors.go +++ b/errors.go @@ -16,6 +16,10 @@ var ErrSendFailure = errors.New("send failure") // ErrTypeMismatch is returned when a response cannot be cast to the expected type. var ErrTypeMismatch = errors.New("response type mismatch") +// ErrSkipNode is returned when a node is skipped by request transformations. +// This allows the response iterator to account for all nodes without blocking. +var ErrSkipNode = errors.New("skip node") + // QuorumCallError reports on a failed quorum call. // It provides detailed information about which nodes failed. type QuorumCallError struct { diff --git a/examples/interceptors/server_interceptors.go b/examples/interceptors/server_interceptors.go index b5a402438..2a001e9e6 100644 --- a/examples/interceptors/server_interceptors.go +++ b/examples/interceptors/server_interceptors.go @@ -6,30 +6,34 @@ import ( "time" "github.com/relab/gorums" - "github.com/relab/gorums/ordering" "google.golang.org/grpc/peer" + "google.golang.org/protobuf/proto" ) func LoggingInterceptor(addr string) gorums.Interceptor { - return func(ctx gorums.ServerCtx, msg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { - log.Printf("[%s]: LoggingInterceptor(incoming): Method=%s, Message=%s", addr, msg.GetMethod(), msg.GetProtoMessage()) + return func(ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { + req := gorums.AsProto[proto.Message](in) + log.Printf("[%s]: LoggingInterceptor(incoming): Method=%s, Message=%s", addr, in.GetMethod(), req) start := time.Now() - resp, err := next(ctx, msg) + out, err := next(ctx, in) duration := time.Since(start) - log.Printf("[%s]: LoggingInterceptor(outgoing): Method=%s, Duration=%s, Err=%v, Message=%v, Type=%T", addr, msg.GetMethod(), duration, err, resp.GetProtoMessage(), resp.GetProtoMessage()) - return resp, err + resp := gorums.AsProto[proto.Message](out) + log.Printf("[%s]: LoggingInterceptor(outgoing): Method=%s, Duration=%s, Err=%v, Message=%v, Type=%T", addr, in.GetMethod(), duration, err, resp, resp) + return out, err } } -func LoggingSimpleInterceptor(ctx gorums.ServerCtx, msg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { - log.Printf("LoggingSimpleInterceptor(incoming): Method=%s, Message=%v)", msg.GetMethod(), msg.GetProtoMessage()) - resp, err := next(ctx, msg) - log.Printf("LoggingSimpleInterceptor(outgoing): Method=%s, Err=%v, Message=%v", msg.GetMethod(), err, resp.GetProtoMessage()) - return resp, err +func LoggingSimpleInterceptor(ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { + req := gorums.AsProto[proto.Message](in) + log.Printf("LoggingSimpleInterceptor(incoming): Method=%s, Message=%v)", in.GetMethod(), req) + out, err := next(ctx, in) + resp := gorums.AsProto[proto.Message](out) + log.Printf("LoggingSimpleInterceptor(outgoing): Method=%s, Err=%v, Message=%v", in.GetMethod(), err, resp) + return out, err } -func DelayedInterceptor(ctx gorums.ServerCtx, msg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { +func DelayedInterceptor(ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { // delay based on sending node address delay := 0 * time.Millisecond peer, ok := peer.FromContext(ctx) @@ -45,35 +49,35 @@ func DelayedInterceptor(ctx gorums.ServerCtx, msg *gorums.Message, next gorums.H time.Sleep(delay) // Call the next handler in the chain - resp, err := next(ctx, msg) + out, err := next(ctx, in) log.Printf("DelayedInterceptor: Finished processing message after %s", delay) - return resp, err + return out, err } /** NoFooAllowedInterceptor rejects requests for messages with key "foo". */ -func NoFooAllowedInterceptor[T interface{ GetKey() string }](ctx gorums.ServerCtx, msg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { - if req, ok := msg.GetProtoMessage().(T); ok { +func NoFooAllowedInterceptor[T interface{ GetKey() string }](ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { + if req, ok := gorums.AsProto[proto.Message](in).(T); ok { log.Printf("NoFooAllowedInterceptor: Received request for key '%s'", req.GetKey()) if req.GetKey() == "foo" { log.Printf("NoFooAllowedInterceptor: Rejecting request for key 'foo'") return nil, fmt.Errorf("requests for key 'foo' are not allowed") } } - return next(ctx, msg) + return next(ctx, in) } -func MetadataInterceptor(ctx gorums.ServerCtx, msg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { +func MetadataInterceptor(ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { log.Printf("MetadataInterceptor: Adding custom metadata to message(customKey=customValue)") // Add a custom metadata field - entry := ordering.MetadataEntry_builder{ + entry := gorums.MetadataEntry_builder{ Key: "customKey", Value: "customValue", }.Build() - msg.GetMetadata().SetEntry([]*ordering.MetadataEntry{ + in.SetEntry([]*gorums.MetadataEntry{ entry, }) // Call the next handler in the chain - resp, err := next(ctx, msg) + out, err := next(ctx, in) log.Printf("MetadataInterceptor: Finished processing message with custom metadata") - return resp, err + return out, err } diff --git a/examples/storage/proto/storage_gorums.pb.go b/examples/storage/proto/storage_gorums.pb.go index b189b1f4d..26dd9556d 100644 --- a/examples/storage/proto/storage_gorums.pb.go +++ b/examples/storage/proto/storage_gorums.pb.go @@ -137,22 +137,34 @@ func RegisterStorageServer(srv *gorums.Server, impl StorageServer) { srv.RegisterHandler("proto.Storage.ReadRPC", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*ReadRequest](in) resp, err := impl.ReadRPC(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("proto.Storage.WriteRPC", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*WriteRequest](in) resp, err := impl.WriteRPC(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("proto.Storage.ReadQC", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*ReadRequest](in) resp, err := impl.ReadQC(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("proto.Storage.WriteQC", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*WriteRequest](in) resp, err := impl.WriteQC(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("proto.Storage.WriteMulticast", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*WriteRequest](in) diff --git a/internal/stream/gorums_message.go b/internal/stream/gorums_message.go new file mode 100644 index 000000000..b39269ad6 --- /dev/null +++ b/internal/stream/gorums_message.go @@ -0,0 +1,59 @@ +package stream + +import ( + "context" + + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +// NewMessage creates a new [Message] proto message for the given method and message ID. +// If a non-nil proto message is provided, it is marshaled and included in the message payload. +// This function also extracts any client-specific metadata from the context and appends +// it to the message, allowing client-specific metadata to be passed to the server. +// +// This method is intended for Gorums internal use. +// This function is used on the client-side to create outgoing request messages. +func NewMessage(ctx context.Context, msgID uint64, method string, msg proto.Message) (*Message, error) { + // Marshal the message to bytes (nil message returns nil bytes and no error) + payload, err := proto.Marshal(msg) + if err != nil { + return nil, err + } + msgBuilder := Message_builder{ + MessageSeqNo: msgID, + Method: method, + Payload: payload, + } + md, _ := metadata.FromOutgoingContext(ctx) + for k, vv := range md { + for _, v := range vv { + entry := MetadataEntry_builder{Key: k, Value: v}.Build() + msgBuilder.Entry = append(msgBuilder.Entry, entry) + } + } + return msgBuilder.Build(), nil +} + +// AppendToIncomingContext appends client-specific metadata from the [Message] proto message +// to the incoming gRPC context, allowing server implementations to extract and use said +// metadata directly from the server method's context. +// +// This method is intended for Gorums internal use. +func (x *Message) AppendToIncomingContext(ctx context.Context) context.Context { + existingMD, _ := metadata.FromIncomingContext(ctx) + newMD := existingMD.Copy() // copy to avoid mutating the original + for _, entry := range x.GetEntry() { + newMD.Append(entry.GetKey(), entry.GetValue()) + } + return metadata.NewIncomingContext(ctx, newMD) +} + +func (x *Message) ErrorStatus() error { + s := x.GetStatus() + if s == nil { + return nil + } + return status.ErrorProto(s) +} diff --git a/ordering/ordering.pb.go b/internal/stream/stream.pb.go similarity index 61% rename from ordering/ordering.pb.go rename to internal/stream/stream.pb.go index 362656ca6..f7671a6b5 100644 --- a/ordering/ordering.pb.go +++ b/internal/stream/stream.pb.go @@ -2,9 +2,9 @@ // versions: // protoc-gen-go v1.36.11 // protoc v6.33.4 -// source: ordering/ordering.proto +// source: internal/stream/stream.proto -package ordering +package stream import ( status "google.golang.org/genproto/googleapis/rpc/status" @@ -21,34 +21,34 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// Metadata is sent together with application-specific message types, +// Message is sent together with application-specific message types, // and contains information necessary for Gorums to handle the messages. -type Metadata struct { +type Message struct { state protoimpl.MessageState `protogen:"opaque.v1"` xxx_hidden_MessageSeqNo uint64 `protobuf:"varint,1,opt,name=message_seq_no,json=messageSeqNo"` xxx_hidden_Method string `protobuf:"bytes,2,opt,name=method"` xxx_hidden_Status *status.Status `protobuf:"bytes,3,opt,name=status"` xxx_hidden_Entry *[]*MetadataEntry `protobuf:"bytes,4,rep,name=entry"` - xxx_hidden_MessageData []byte `protobuf:"bytes,5,opt,name=message_data,json=messageData"` + xxx_hidden_Payload []byte `protobuf:"bytes,5,opt,name=payload"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *Metadata) Reset() { - *x = Metadata{} - mi := &file_ordering_ordering_proto_msgTypes[0] +func (x *Message) Reset() { + *x = Message{} + mi := &file_internal_stream_stream_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *Metadata) String() string { +func (x *Message) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Metadata) ProtoMessage() {} +func (*Message) ProtoMessage() {} -func (x *Metadata) ProtoReflect() protoreflect.Message { - mi := &file_ordering_ordering_proto_msgTypes[0] +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_internal_stream_stream_proto_msgTypes[0] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -59,28 +59,28 @@ func (x *Metadata) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -func (x *Metadata) GetMessageSeqNo() uint64 { +func (x *Message) GetMessageSeqNo() uint64 { if x != nil { return x.xxx_hidden_MessageSeqNo } return 0 } -func (x *Metadata) GetMethod() string { +func (x *Message) GetMethod() string { if x != nil { return x.xxx_hidden_Method } return "" } -func (x *Metadata) GetStatus() *status.Status { +func (x *Message) GetStatus() *status.Status { if x != nil { return x.xxx_hidden_Status } return nil } -func (x *Metadata) GetEntry() []*MetadataEntry { +func (x *Message) GetEntry() []*MetadataEntry { if x != nil { if x.xxx_hidden_Entry != nil { return *x.xxx_hidden_Entry @@ -89,66 +89,66 @@ func (x *Metadata) GetEntry() []*MetadataEntry { return nil } -func (x *Metadata) GetMessageData() []byte { +func (x *Message) GetPayload() []byte { if x != nil { - return x.xxx_hidden_MessageData + return x.xxx_hidden_Payload } return nil } -func (x *Metadata) SetMessageSeqNo(v uint64) { +func (x *Message) SetMessageSeqNo(v uint64) { x.xxx_hidden_MessageSeqNo = v } -func (x *Metadata) SetMethod(v string) { +func (x *Message) SetMethod(v string) { x.xxx_hidden_Method = v } -func (x *Metadata) SetStatus(v *status.Status) { +func (x *Message) SetStatus(v *status.Status) { x.xxx_hidden_Status = v } -func (x *Metadata) SetEntry(v []*MetadataEntry) { +func (x *Message) SetEntry(v []*MetadataEntry) { x.xxx_hidden_Entry = &v } -func (x *Metadata) SetMessageData(v []byte) { +func (x *Message) SetPayload(v []byte) { if v == nil { v = []byte{} } - x.xxx_hidden_MessageData = v + x.xxx_hidden_Payload = v } -func (x *Metadata) HasStatus() bool { +func (x *Message) HasStatus() bool { if x == nil { return false } return x.xxx_hidden_Status != nil } -func (x *Metadata) ClearStatus() { +func (x *Message) ClearStatus() { x.xxx_hidden_Status = nil } -type Metadata_builder struct { +type Message_builder struct { _ [0]func() // Prevents comparability and use of unkeyed literals for the builder. MessageSeqNo uint64 Method string Status *status.Status Entry []*MetadataEntry - MessageData []byte + Payload []byte } -func (b0 Metadata_builder) Build() *Metadata { - m0 := &Metadata{} +func (b0 Message_builder) Build() *Message { + m0 := &Message{} b, x := &b0, m0 _, _ = b, x x.xxx_hidden_MessageSeqNo = b.MessageSeqNo x.xxx_hidden_Method = b.Method x.xxx_hidden_Status = b.Status x.xxx_hidden_Entry = &b.Entry - x.xxx_hidden_MessageData = b.MessageData + x.xxx_hidden_Payload = b.Payload return m0 } @@ -163,7 +163,7 @@ type MetadataEntry struct { func (x *MetadataEntry) Reset() { *x = MetadataEntry{} - mi := &file_ordering_ordering_proto_msgTypes[1] + mi := &file_internal_stream_stream_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -175,7 +175,7 @@ func (x *MetadataEntry) String() string { func (*MetadataEntry) ProtoMessage() {} func (x *MetadataEntry) ProtoReflect() protoreflect.Message { - mi := &file_ordering_ordering_proto_msgTypes[1] + mi := &file_internal_stream_stream_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -224,35 +224,35 @@ func (b0 MetadataEntry_builder) Build() *MetadataEntry { return m0 } -var File_ordering_ordering_proto protoreflect.FileDescriptor +var File_internal_stream_stream_proto protoreflect.FileDescriptor -const file_ordering_ordering_proto_rawDesc = "" + +const file_internal_stream_stream_proto_rawDesc = "" + "\n" + - "\x17ordering/ordering.proto\x12\bordering\x1a\x17google/rpc/status.proto\"\xc6\x01\n" + - "\bMetadata\x12$\n" + + "\x1cinternal/stream/stream.proto\x12\x06stream\x1a\x17google/rpc/status.proto\"\xba\x01\n" + + "\aMessage\x12$\n" + "\x0emessage_seq_no\x18\x01 \x01(\x04R\fmessageSeqNo\x12\x16\n" + "\x06method\x18\x02 \x01(\tR\x06method\x12*\n" + - "\x06status\x18\x03 \x01(\v2\x12.google.rpc.StatusR\x06status\x12-\n" + - "\x05entry\x18\x04 \x03(\v2\x17.ordering.MetadataEntryR\x05entry\x12!\n" + - "\fmessage_data\x18\x05 \x01(\fR\vmessageData\"7\n" + + "\x06status\x18\x03 \x01(\v2\x12.google.rpc.StatusR\x06status\x12+\n" + + "\x05entry\x18\x04 \x03(\v2\x15.stream.MetadataEntryR\x05entry\x12\x18\n" + + "\apayload\x18\x05 \x01(\fR\apayload\"7\n" + "\rMetadataEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + - "\x05value\x18\x02 \x01(\tR\x05value2B\n" + - "\x06Gorums\x128\n" + + "\x05value\x18\x02 \x01(\tR\x05value2<\n" + + "\x06Gorums\x122\n" + "\n" + - "NodeStream\x12\x12.ordering.Metadata\x1a\x12.ordering.Metadata(\x010\x01B'Z github.com/relab/gorums/ordering\x92\x03\x02\b\x02b\beditionsp\xe8\a" + "NodeStream\x12\x0f.stream.Message\x1a\x0f.stream.Message(\x010\x01B.Z'github.com/relab/gorums/internal/stream\x92\x03\x02\b\x02b\beditionsp\xe8\a" -var file_ordering_ordering_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_ordering_ordering_proto_goTypes = []any{ - (*Metadata)(nil), // 0: ordering.Metadata - (*MetadataEntry)(nil), // 1: ordering.MetadataEntry +var file_internal_stream_stream_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_internal_stream_stream_proto_goTypes = []any{ + (*Message)(nil), // 0: stream.Message + (*MetadataEntry)(nil), // 1: stream.MetadataEntry (*status.Status)(nil), // 2: google.rpc.Status } -var file_ordering_ordering_proto_depIdxs = []int32{ - 2, // 0: ordering.Metadata.status:type_name -> google.rpc.Status - 1, // 1: ordering.Metadata.entry:type_name -> ordering.MetadataEntry - 0, // 2: ordering.Gorums.NodeStream:input_type -> ordering.Metadata - 0, // 3: ordering.Gorums.NodeStream:output_type -> ordering.Metadata +var file_internal_stream_stream_proto_depIdxs = []int32{ + 2, // 0: stream.Message.status:type_name -> google.rpc.Status + 1, // 1: stream.Message.entry:type_name -> stream.MetadataEntry + 0, // 2: stream.Gorums.NodeStream:input_type -> stream.Message + 0, // 3: stream.Gorums.NodeStream:output_type -> stream.Message 3, // [3:4] is the sub-list for method output_type 2, // [2:3] is the sub-list for method input_type 2, // [2:2] is the sub-list for extension type_name @@ -260,26 +260,26 @@ var file_ordering_ordering_proto_depIdxs = []int32{ 0, // [0:2] is the sub-list for field type_name } -func init() { file_ordering_ordering_proto_init() } -func file_ordering_ordering_proto_init() { - if File_ordering_ordering_proto != nil { +func init() { file_internal_stream_stream_proto_init() } +func file_internal_stream_stream_proto_init() { + if File_internal_stream_stream_proto != nil { return } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_ordering_ordering_proto_rawDesc), len(file_ordering_ordering_proto_rawDesc)), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_internal_stream_stream_proto_rawDesc), len(file_internal_stream_stream_proto_rawDesc)), NumEnums: 0, NumMessages: 2, NumExtensions: 0, NumServices: 1, }, - GoTypes: file_ordering_ordering_proto_goTypes, - DependencyIndexes: file_ordering_ordering_proto_depIdxs, - MessageInfos: file_ordering_ordering_proto_msgTypes, + GoTypes: file_internal_stream_stream_proto_goTypes, + DependencyIndexes: file_internal_stream_stream_proto_depIdxs, + MessageInfos: file_internal_stream_stream_proto_msgTypes, }.Build() - File_ordering_ordering_proto = out.File - file_ordering_ordering_proto_goTypes = nil - file_ordering_ordering_proto_depIdxs = nil + File_internal_stream_stream_proto = out.File + file_internal_stream_stream_proto_goTypes = nil + file_internal_stream_stream_proto_depIdxs = nil } diff --git a/ordering/ordering.proto b/internal/stream/stream.proto similarity index 76% rename from ordering/ordering.proto rename to internal/stream/stream.proto index 21a20d1a8..10143bad2 100644 --- a/ordering/ordering.proto +++ b/internal/stream/stream.proto @@ -1,7 +1,7 @@ edition = "2023"; -package ordering; -option go_package = "github.com/relab/gorums/ordering"; +package stream; +option go_package = "github.com/relab/gorums/internal/stream"; option features.field_presence = IMPLICIT; import "google/rpc/status.proto"; @@ -11,17 +11,17 @@ service Gorums { // NodeStream is a stream that connects a client to a Node. // The messages that are sent on the stream contain both Metadata // and an application-specific message. - rpc NodeStream(stream Metadata) returns (stream Metadata); + rpc NodeStream(stream Message) returns (stream Message); } -// Metadata is sent together with application-specific message types, +// Message is sent together with application-specific message types, // and contains information necessary for Gorums to handle the messages. -message Metadata { +message Message { uint64 message_seq_no = 1; // sequence number for this message string method = 2; // method name to invoke on the server google.rpc.Status status = 3; // status of an invocation (for responses) repeated MetadataEntry entry = 4; // per message client-generated metadata - bytes message_data = 5; // serialized application message + bytes payload = 5; // serialized application message } // MetadataEntry is a key-value pair for Metadata entries. diff --git a/ordering/ordering_grpc.pb.go b/internal/stream/stream_grpc.pb.go similarity index 86% rename from ordering/ordering_grpc.pb.go rename to internal/stream/stream_grpc.pb.go index a1b79e9f0..b692df2ee 100644 --- a/ordering/ordering_grpc.pb.go +++ b/internal/stream/stream_grpc.pb.go @@ -2,9 +2,9 @@ // versions: // - protoc-gen-go-grpc v1.6.0 // - protoc v6.33.4 -// source: ordering/ordering.proto +// source: internal/stream/stream.proto -package ordering +package stream import ( context "context" @@ -19,7 +19,7 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - Gorums_NodeStream_FullMethodName = "/ordering.Gorums/NodeStream" + Gorums_NodeStream_FullMethodName = "/stream.Gorums/NodeStream" ) // GorumsClient is the client API for Gorums service. @@ -31,7 +31,7 @@ type GorumsClient interface { // NodeStream is a stream that connects a client to a Node. // The messages that are sent on the stream contain both Metadata // and an application-specific message. - NodeStream(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Metadata, Metadata], error) + NodeStream(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Message, Message], error) } type gorumsClient struct { @@ -42,18 +42,18 @@ func NewGorumsClient(cc grpc.ClientConnInterface) GorumsClient { return &gorumsClient{cc} } -func (c *gorumsClient) NodeStream(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Metadata, Metadata], error) { +func (c *gorumsClient) NodeStream(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[Message, Message], error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &Gorums_ServiceDesc.Streams[0], Gorums_NodeStream_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[Metadata, Metadata]{ClientStream: stream} + x := &grpc.GenericClientStream[Message, Message]{ClientStream: stream} return x, nil } // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type Gorums_NodeStreamClient = grpc.BidiStreamingClient[Metadata, Metadata] +type Gorums_NodeStreamClient = grpc.BidiStreamingClient[Message, Message] // GorumsServer is the server API for Gorums service. // All implementations must embed UnimplementedGorumsServer @@ -64,7 +64,7 @@ type GorumsServer interface { // NodeStream is a stream that connects a client to a Node. // The messages that are sent on the stream contain both Metadata // and an application-specific message. - NodeStream(grpc.BidiStreamingServer[Metadata, Metadata]) error + NodeStream(grpc.BidiStreamingServer[Message, Message]) error mustEmbedUnimplementedGorumsServer() } @@ -75,7 +75,7 @@ type GorumsServer interface { // pointer dereference when methods are called. type UnimplementedGorumsServer struct{} -func (UnimplementedGorumsServer) NodeStream(grpc.BidiStreamingServer[Metadata, Metadata]) error { +func (UnimplementedGorumsServer) NodeStream(grpc.BidiStreamingServer[Message, Message]) error { return status.Error(codes.Unimplemented, "method NodeStream not implemented") } func (UnimplementedGorumsServer) mustEmbedUnimplementedGorumsServer() {} @@ -100,17 +100,17 @@ func RegisterGorumsServer(s grpc.ServiceRegistrar, srv GorumsServer) { } func _Gorums_NodeStream_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(GorumsServer).NodeStream(&grpc.GenericServerStream[Metadata, Metadata]{ServerStream: stream}) + return srv.(GorumsServer).NodeStream(&grpc.GenericServerStream[Message, Message]{ServerStream: stream}) } // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type Gorums_NodeStreamServer = grpc.BidiStreamingServer[Metadata, Metadata] +type Gorums_NodeStreamServer = grpc.BidiStreamingServer[Message, Message] // Gorums_ServiceDesc is the grpc.ServiceDesc for Gorums service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) var Gorums_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "ordering.Gorums", + ServiceName: "stream.Gorums", HandlerType: (*GorumsServer)(nil), Methods: []grpc.MethodDesc{}, Streams: []grpc.StreamDesc{ @@ -121,5 +121,5 @@ var Gorums_ServiceDesc = grpc.ServiceDesc{ ClientStreams: true, }, }, - Metadata: "ordering/ordering.proto", + Metadata: "internal/stream/stream.proto", } diff --git a/internal/tests/config/config_gorums.pb.go b/internal/tests/config/config_gorums.pb.go index 7d2f7de2c..dfac39e9a 100644 --- a/internal/tests/config/config_gorums.pb.go +++ b/internal/tests/config/config_gorums.pb.go @@ -101,6 +101,9 @@ func RegisterConfigTestServer(srv *gorums.Server, impl ConfigTestServer) { srv.RegisterHandler("config.ConfigTest.Config", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.Config(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) } diff --git a/internal/tests/correctable/correctable_gorums.pb.go b/internal/tests/correctable/correctable_gorums.pb.go index d20429209..a967da930 100644 --- a/internal/tests/correctable/correctable_gorums.pb.go +++ b/internal/tests/correctable/correctable_gorums.pb.go @@ -8,7 +8,6 @@ package correctable import ( gorums "github.com/relab/gorums" - proto "google.golang.org/protobuf/proto" ) const ( @@ -118,14 +117,16 @@ func RegisterCorrectableTestServer(srv *gorums.Server, impl CorrectableTestServe srv.RegisterHandler("correctable.CorrectableTest.Correctable", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.Correctable(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("correctable.CorrectableTest.CorrectableStream", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) err := impl.CorrectableStream(ctx, req, func(resp *Response) error { - // create a copy of the metadata, to avoid a data race between NewResponseMessage and SendMsg - md := proto.CloneOf(in.GetMetadata()) - return ctx.SendMessage(gorums.NewResponseMessage(md, resp)) + out := gorums.NewResponseMessage(in, resp) + return ctx.SendMessage(out) }) return nil, err }) diff --git a/internal/tests/metadata/metadata_gorums.pb.go b/internal/tests/metadata/metadata_gorums.pb.go index dadb7fa44..8dd8bda81 100644 --- a/internal/tests/metadata/metadata_gorums.pb.go +++ b/internal/tests/metadata/metadata_gorums.pb.go @@ -93,11 +93,17 @@ func RegisterMetadataTestServer(srv *gorums.Server, impl MetadataTestServer) { srv.RegisterHandler("metadata.MetadataTest.IDFromMD", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*emptypb.Empty](in) resp, err := impl.IDFromMD(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("metadata.MetadataTest.WhatIP", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*emptypb.Empty](in) resp, err := impl.WhatIP(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) } diff --git a/internal/tests/oneway/oneway_test.go b/internal/tests/oneway/oneway_test.go index 0d91c9725..915aa7c97 100644 --- a/internal/tests/oneway/oneway_test.go +++ b/internal/tests/oneway/oneway_test.go @@ -37,7 +37,7 @@ func (s *onewaySrv) Multicast(_ gorums.ServerCtx, r *oneway.Request) { } // setupWithNodeMap sets up servers and configuration with sequential node IDs -// (0, 1, 2, ...) matching the server array indices. This is needed for tests like +// (1, 2, 3, ...) matching the server array indices. This is needed for tests like // TestMulticastPerNode that verify per-node message transformations based on node ID. func setupWithNodeMap(t testing.TB, cfgSize int) (cfg oneway.Configuration, srvs []*onewaySrv) { t.Helper() @@ -45,7 +45,6 @@ func setupWithNodeMap(t testing.TB, cfgSize int) (cfg oneway.Configuration, srvs for i := range cfgSize { srvs[i] = &onewaySrv{received: make(chan *oneway.Request, numCalls)} } - cfg = gorums.TestConfiguration(t, cfgSize, func(i int) gorums.ServerIface { srv := gorums.NewServer() oneway.RegisterOnewayTestServer(srv, srvs[i]) @@ -70,14 +69,12 @@ func TestOnewayCalls(t *testing.T) { {name: "MulticastSendWaiting__", calls: numCalls, servers: 9, sendWait: true}, {name: "MulticastNoSendWaiting", calls: numCalls, servers: 9, sendWait: false}, } - for _, test := range tests { t.Run(fmt.Sprintf("%s/Servers=%d", test.name, test.servers), func(t *testing.T) { config, srvs := setupWithNodeMap(t, test.servers) for i := range srvs { srvs[i].wg.Add(test.calls) } - for c := 1; c <= test.calls; c++ { in := oneway.Request_builder{Num: uint64(c)}.Build() if config.Size() == 1 { @@ -110,12 +107,18 @@ func TestOnewayCalls(t *testing.T) { for i := range srvs { srvs[i].wg.Wait() close(srvs[i].received) - expected := uint64(1) + received := make([]uint64, 0, test.calls) for r := range srvs[i].received { - if expected != r.GetNum() { - t.Errorf("%s(%d) = %d, expected %d", test.name, expected, r.GetNum(), expected) + received = append(received, r.GetNum()) + } + // Sort received messages to avoid test flakiness + // due to message reordering in multicast tests + slices.Sort(received) + for j, got := range received { + want := uint64(j + 1) + if want != got { + t.Errorf("%s: received[%d] = %d, expected %d", test.name, j, got, want) } - expected++ } } }) @@ -163,7 +166,6 @@ func TestMulticastPerNode(t *testing.T) { {name: "MulticastPerNodeSendWaitingIgnoreNodes", calls: numCalls, servers: 3, sendWait: true, ignoreNodes: []uint32{0, 1}}, {name: "MulticastPerNodeSendWaitingIgnoreNodes", calls: numCalls, servers: 3, sendWait: true, ignoreNodes: []uint32{0, 1, 2}}, } - for _, test := range tests { t.Run(fmt.Sprintf("%s/Servers=%d/IgnoredNodes=%v", test.name, test.servers, test.ignoreNodes), func(t *testing.T) { config, srvs := setupWithNodeMap(t, test.servers) @@ -206,12 +208,18 @@ func TestMulticastPerNode(t *testing.T) { } srvs[i].wg.Wait() close(srvs[i].received) - expected := add(uint64(1), nodeIDs[i]) + received := make([]uint64, 0, test.calls) for r := range srvs[i].received { - if expected != r.GetNum() { - t.Errorf("%s -> %d, expected %d, nodeID=%d", test.name, r.GetNum(), expected, nodeIDs[i]) + received = append(received, r.GetNum()) + } + // Sort received messages to avoid test flakiness + // due to message reordering in multicast tests + slices.Sort(received) + for j, got := range received { + want := add(uint64(j+1), nodeIDs[i]) + if want != got { + t.Errorf("%s: received[%d] = %d, expected %d, nodeID=%d", test.name, j, got, want, nodeIDs[i]) } - expected++ } } }) diff --git a/internal/tests/ordering/order_gorums.pb.go b/internal/tests/ordering/order_gorums.pb.go index 5e72a50fc..54d0d4c87 100644 --- a/internal/tests/ordering/order_gorums.pb.go +++ b/internal/tests/ordering/order_gorums.pb.go @@ -107,11 +107,17 @@ func RegisterGorumsTestServer(srv *gorums.Server, impl GorumsTestServer) { srv.RegisterHandler("ordering.GorumsTest.QuorumCall", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.QuorumCall(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) srv.RegisterHandler("ordering.GorumsTest.UnaryRPC", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.UnaryRPC(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) } diff --git a/internal/tests/tls/tls_gorums.pb.go b/internal/tests/tls/tls_gorums.pb.go index a037a9120..e9b0efad8 100644 --- a/internal/tests/tls/tls_gorums.pb.go +++ b/internal/tests/tls/tls_gorums.pb.go @@ -86,6 +86,9 @@ func RegisterTLSServer(srv *gorums.Server, impl TLSServer) { srv.RegisterHandler("tls.TLS.TestTLS", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Request](in) resp, err := impl.TestTLS(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) } diff --git a/internal/tests/unresponsive/unresponsive_gorums.pb.go b/internal/tests/unresponsive/unresponsive_gorums.pb.go index 9fb964a8d..f1d6b1492 100644 --- a/internal/tests/unresponsive/unresponsive_gorums.pb.go +++ b/internal/tests/unresponsive/unresponsive_gorums.pb.go @@ -86,6 +86,9 @@ func RegisterUnresponsiveServer(srv *gorums.Server, impl UnresponsiveServer) { srv.RegisterHandler("unresponsive.Unresponsive.TestUnresponsive", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*Empty](in) resp, err := impl.TestUnresponsive(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) } diff --git a/mgr.go b/mgr.go index 3f407a149..18b0996bc 100644 --- a/mgr.go +++ b/mgr.go @@ -37,9 +37,6 @@ func NewManager(opts ...ManagerOption) *Manager { if m.opts.logger != nil { m.logger = m.opts.logger } - m.opts.grpcDialOpts = append(m.opts.grpcDialOpts, grpc.WithDefaultCallOptions( - grpc.CallContentSubtype(ContentSubtype), - )) if m.opts.backoff != backoff.DefaultConfig { m.opts.grpcDialOpts = append(m.opts.grpcDialOpts, grpc.WithConnectParams( grpc.ConnectParams{Backoff: m.opts.backoff}, diff --git a/mgr_test.go b/mgr_test.go deleted file mode 100644 index a012c2168..000000000 --- a/mgr_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorums - -import ( - "bytes" - "log" - "strings" - "testing" - - "google.golang.org/grpc/encoding" -) - -func init() { - if encoding.GetCodec(ContentSubtype) == nil { - encoding.RegisterCodec(NewCodec()) - } -} - -func TestManagerLogging(t *testing.T) { - var ( - buf bytes.Buffer - logger = log.New(&buf, "logger: ", log.Lshortfile) - ) - mgr := NewManager(InsecureDialOptions(t), WithLogger(logger)) - t.Cleanup(Closer(t, mgr)) - - want := "logger: mgr.go:49: ready" - if strings.TrimSpace(buf.String()) != want { - t.Errorf("logger: got %q, want %q", buf.String(), want) - } -} diff --git a/multicast.go b/multicast.go index a620aeecf..7ff2cd540 100644 --- a/multicast.go +++ b/multicast.go @@ -1,6 +1,8 @@ package gorums import ( + "errors" + "google.golang.org/protobuf/types/known/emptypb" ) @@ -31,10 +33,10 @@ func Multicast[Req msg](ctx *ConfigContext, req Req, method string, opts ...Call // If waiting for send completion, drain the reply channel and return the first error. if waitSendDone { var errs []nodeError - for range clientCtx.expectedReplies { + for range clientCtx.Size() { select { case r := <-clientCtx.replyChan: - if r.Err != nil { + if r.Err != nil && !errors.Is(r.Err, ErrSkipNode) { errs = append(errs, nodeError{cause: r.Err, nodeID: r.NodeID}) } case <-ctx.Done(): diff --git a/ordering/gorums_metadata.go b/ordering/gorums_metadata.go deleted file mode 100644 index 0a4716cb3..000000000 --- a/ordering/gorums_metadata.go +++ /dev/null @@ -1,40 +0,0 @@ -package ordering - -import ( - "context" - - "google.golang.org/grpc/metadata" -) - -// NewGorumsMetadata creates a new Gorums metadata object for the given method -// and client message ID. It also appends any client-specific metadata from the -// context to the Gorums metadata object. -// -// This is used to pass client-specific metadata to the server via Gorums. -// This method should be used by generated code only. -func NewGorumsMetadata(ctx context.Context, msgID uint64, method string) *Metadata { - gorumsMetadata := Metadata_builder{MessageSeqNo: msgID, Method: method} - md, _ := metadata.FromOutgoingContext(ctx) - for k, vv := range md { - for _, v := range vv { - entry := MetadataEntry_builder{Key: k, Value: v}.Build() - gorumsMetadata.Entry = append(gorumsMetadata.Entry, entry) - } - } - return gorumsMetadata.Build() -} - -// AppendToIncomingContext appends client-specific metadata from the -// Gorums metadata object to the incoming context. -// -// This is used to pass client-specific metadata from the Gorums runtime -// to the server implementation. -// This method should be used by generated code only. -func (x *Metadata) AppendToIncomingContext(ctx context.Context) context.Context { - existingMD, _ := metadata.FromIncomingContext(ctx) - newMD := existingMD.Copy() // copy to avoid mutating the original - for _, entry := range x.GetEntry() { - newMD.Append(entry.GetKey(), entry.GetValue()) - } - return metadata.NewIncomingContext(ctx, newMD) -} diff --git a/quorumcall.go b/quorumcall.go index afbc91ac8..80d168ee0 100644 --- a/quorumcall.go +++ b/quorumcall.go @@ -15,7 +15,7 @@ package gorums // or iterator method (like Seq) is called, applying any registered request transformations. // This lazy sending is necessary to allow interceptors to register transformations prior to dispatch. // -// This function should be used by generated code only. +// This function should only be used by generated code. func QuorumCall[Req, Resp msg]( ctx *ConfigContext, req Req, @@ -31,7 +31,7 @@ func QuorumCall[Req, Resp msg]( // In streaming mode, the response iterator continues indefinitely until the context // is canceled, allowing the server to send multiple responses over time. // -// This function should be used by generated code only. +// This function should only be used by generated code. func QuorumCallStream[Req, Resp msg]( ctx *ConfigContext, req Req, diff --git a/responses.go b/responses.go index 856e4b7ac..a6d7bf87f 100644 --- a/responses.go +++ b/responses.go @@ -1,6 +1,7 @@ package gorums import ( + "errors" "iter" "google.golang.org/protobuf/proto" @@ -192,7 +193,7 @@ func (r *Responses[Resp]) Threshold(threshold int) (resp Resp, err error) { errs []nodeError ) for result := range r.ResponseSeq { - if result.Err != nil { + if result.Err != nil && !errors.Is(result.Err, ErrSkipNode) { errs = append(errs, nodeError{nodeID: result.NodeID, cause: result.Err}) continue } diff --git a/responses_test.go b/responses_test.go index 9d4c81788..ad69333db 100644 --- a/responses_test.go +++ b/responses_test.go @@ -24,10 +24,9 @@ func makeClientCtx[Req, Resp msg](t *testing.T, numNodes int, responses []NodeRe } c := &ClientCtx[Req, Resp]{ - Context: t.Context(), - config: config, - replyChan: resultChan, - expectedReplies: numNodes, + Context: t.Context(), + config: config, + replyChan: resultChan, } // Mark sendOnce as done since test responses are already in the channel c.sendOnce.Do(func() {}) diff --git a/rpc.go b/rpc.go index e16f1c55a..12b804720 100644 --- a/rpc.go +++ b/rpc.go @@ -1,11 +1,18 @@ package gorums +import "github.com/relab/gorums/internal/stream" + // RPCCall executes a remote procedure call on the node. // // This method should be used by generated code only. func RPCCall[Req, Resp msg](ctx *NodeContext, req Req, method string) (Resp, error) { replyChan := make(chan NodeResponse[msg], 1) - ctx.enqueue(request{ctx: ctx, msg: NewRequest(ctx, ctx.nextMsgID(), method, req), responseChan: replyChan}) + reqMsg, err := stream.NewMessage(ctx, ctx.nextMsgID(), method, req) + if err != nil { + var zero Resp + return zero, err + } + ctx.enqueue(request{ctx: ctx, msg: reqMsg, responseChan: replyChan}) select { case r := <-replyChan: diff --git a/rpc_test.go b/rpc_test.go index 576662ea3..0be0feb2d 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -9,16 +9,9 @@ import ( "github.com/relab/gorums" "github.com/relab/gorums/internal/testutils/mock" - "google.golang.org/grpc/encoding" pb "google.golang.org/protobuf/types/known/wrapperspb" ) -func init() { - if encoding.GetCodec(gorums.ContentSubtype) == nil { - encoding.RegisterCodec(gorums.NewCodec()) - } -} - func TestRPCCallSuccess(t *testing.T) { node := gorums.TestNode(t, gorums.DefaultTestServer) diff --git a/server.go b/server.go index 497028476..0e1c31f92 100644 --- a/server.go +++ b/server.go @@ -5,8 +5,9 @@ import ( "net" "sync" - "github.com/relab/gorums/ordering" + "github.com/relab/gorums/internal/stream" "google.golang.org/grpc" + "google.golang.org/protobuf/proto" ) type ( @@ -15,18 +16,18 @@ type ( // a Handler representing the next element in the chain (either another // Interceptor or the actual server method). It returns a Message and an error. Interceptor func(ServerCtx, *Message, Handler) (*Message, error) - // Handler is a function that processes a request message and returns a response message. + // Handler is a function that processes a request and returns a response. Handler func(ServerCtx, *Message) (*Message, error) ) -type orderingServer struct { +type streamServer struct { handlers map[string]Handler opts *serverOptions - ordering.UnimplementedGorumsServer + stream.UnimplementedGorumsServer } -func newOrderingServer(opts *serverOptions) *orderingServer { - return &orderingServer{ +func newStreamServer(opts *serverOptions) *streamServer { + return &streamServer{ handlers: make(map[string]Handler), opts: opts, } @@ -34,9 +35,9 @@ func newOrderingServer(opts *serverOptions) *orderingServer { // NodeStream handles a connection to a single client. The stream is aborted if there // is any error with sending or receiving. -func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error { +func (s *streamServer) NodeStream(srv stream.Gorums_NodeStreamServer) error { var mut sync.Mutex // used to achieve mutex between request handlers - finished := make(chan *Message, s.opts.buffer) + finished := make(chan *stream.Message, s.opts.buffer) ctx := srv.Context() if s.opts.connectCallback != nil { @@ -48,9 +49,8 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error select { case <-ctx.Done(): return - case msg := <-finished: - err := srv.SendMsg(msg) - if err != nil { + case streamOut := <-finished: + if err := srv.Send(streamOut); err != nil { return } } @@ -62,12 +62,11 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error defer mut.Unlock() for { - req := newMessage(requestType) - err := srv.RecvMsg(req) + streamIn, err := srv.Recv() if err != nil { return err } - if handler, ok := s.handlers[req.GetMethod()]; ok { + if handler, ok := s.handlers[streamIn.GetMethod()]; ok { // We start the handler in a new goroutine in order to allow multiple handlers to run concurrently. // However, to preserve request ordering, the handler must unlock the shared mutex when it has either // finished, or when it is safe to start processing the next request. @@ -75,26 +74,26 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error // This func() is the default interceptor; it is the first and last handler in the chain. // It is responsible for releasing the mutex when the handler chain is done. go func() { - metadata := req.GetMetadata() - srvCtx := newServerCtx(metadata.AppendToIncomingContext(ctx), &mut, finished) + srvCtx := newServerCtx(streamIn.AppendToIncomingContext(ctx), &mut, finished) defer srvCtx.Release() - message, err := handler(srvCtx, req) - // If there is no message and no error, we do not send anything back to the client. + msg, err := unmarshalRequest(streamIn) + in := &Message{Msg: msg, Message: streamIn} + if err != nil { + _ = srvCtx.SendMessage(messageWithError(in, nil, err)) + return + } + + out, err := handler(srvCtx, in) + // If there is no response and no error, we do not send anything back to the client. // This corresponds to a unidirectional message from client to server, where clients // are not expected to receive a response. - if message == nil && err == nil { + if out == nil && err == nil { return } - // If there was an error in the interceptor chain or the method handler, they may return a nil message. - // Thus, we need to create a response message to send the error back to the client. - if message == nil { - message = NewResponseMessage(metadata, nil) - } - message.setError(err) - _ = srvCtx.SendMessage(message) // to the client + _ = srvCtx.SendMessage(messageWithError(in, out, err)) // We ignore the error from SendMessage here; it means that the stream is closed. - // The for-loop above will exit on the next RecvMsg call. + // The for-loop above will exit on the next Recv call. }() // Wait until the handler releases the mutex. mut.Lock() @@ -161,8 +160,8 @@ func chainInterceptors(final Handler, interceptors ...Interceptor) Handler { for i := len(interceptors) - 1; i >= 0; i-- { curr := interceptors[i] next := handler - handler = func(ctx ServerCtx, msg *Message) (*Message, error) { - return curr(ctx, msg, next) + handler = func(ctx ServerCtx, in *Message) (*Message, error) { + return curr(ctx, in, next) } } return handler @@ -170,7 +169,7 @@ func chainInterceptors(final Handler, interceptors ...Interceptor) Handler { // Server serves all ordering based RPCs using registered handlers. type Server struct { - srv *orderingServer + srv *streamServer grpcServer *grpc.Server interceptors []Interceptor } @@ -182,11 +181,11 @@ func NewServer(opts ...ServerOption) *Server { opt(&serverOpts) } s := &Server{ - srv: newOrderingServer(&serverOpts), + srv: newStreamServer(&serverOpts), grpcServer: grpc.NewServer(serverOpts.grpcOpts...), interceptors: serverOpts.interceptors, } - ordering.RegisterGorumsServer(s.grpcServer, s.srv) + stream.RegisterGorumsServer(s.grpcServer, s.srv) return s } @@ -219,11 +218,11 @@ type ServerCtx struct { context.Context once *sync.Once // must be a pointer to avoid passing ctx by value mut *sync.Mutex - c chan<- *Message + c chan<- *stream.Message } -// newServerCtx creates a new ServerCtx with the given context, mutex and message channel. -func newServerCtx(ctx context.Context, mut *sync.Mutex, c chan<- *Message) ServerCtx { +// newServerCtx creates a new ServerCtx with the given context, mutex and metadata channel. +func newServerCtx(ctx context.Context, mut *sync.Mutex, c chan<- *stream.Message) ServerCtx { return ServerCtx{ Context: ctx, once: new(sync.Once), @@ -242,10 +241,19 @@ func (ctx *ServerCtx) Release() { // SendMessage attempts to send the given message to the client. // This may fail if the stream was closed or the stream context got canceled. // -// This function should be used by generated code only. -func (ctx *ServerCtx) SendMessage(msg *Message) error { +// This function should only be used by generated code. +func (ctx *ServerCtx) SendMessage(out *Message) error { + // If Msg is set, marshal it to payload before sending. + if out.Msg != nil && len(out.GetPayload()) == 0 { + payload, err := proto.Marshal(out.Msg) + if err == nil { + out.SetPayload(payload) + } + // Return an error to the client if marshaling failed on the server side; don't close the stream. + out = messageWithError(nil, out, err) + } select { - case ctx.c <- msg: + case ctx.c <- out.Message: case <-ctx.Done(): return ctx.Err() } diff --git a/server_test.go b/server_test.go index c70abae30..ac449d352 100644 --- a/server_test.go +++ b/server_test.go @@ -8,17 +8,11 @@ import ( "github.com/relab/gorums" "github.com/relab/gorums/internal/testutils/mock" - "google.golang.org/grpc/encoding" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" pb "google.golang.org/protobuf/types/known/wrapperspb" ) -func init() { - if encoding.GetCodec(gorums.ContentSubtype) == nil { - encoding.RegisterCodec(gorums.NewCodec()) - } -} - func TestServerCallback(t *testing.T) { var message string signal := make(chan struct{}) @@ -45,17 +39,41 @@ func TestServerCallback(t *testing.T) { } } -func appendStringInterceptor(in, out string) gorums.Interceptor { - return func(ctx gorums.ServerCtx, inMsg *gorums.Message, next gorums.Handler) (*gorums.Message, error) { - req := gorums.AsProto[*pb.StringValue](inMsg) +func appendStringInterceptor(inStr, outStr string) gorums.Interceptor { + return func(ctx gorums.ServerCtx, in *gorums.Message, next gorums.Handler) (*gorums.Message, error) { + req := gorums.AsProto[*pb.StringValue](in) // update the underlying request gorums.Message's message field (pb.StringValue in this case) - req.Value += in + req.Value += inStr + + // TODO(meling): I did not like this change; I need to think more about this; + // can we make interceptors modify the request and responses in a cleaner way? + // Maybe we can add a method to stream.Message to modify/marshal the payload? + // Or maybe interceptors can operate on a InterceptorMessage wrapper that can + // hold the proto message that can be manipulated directly, and only on the way + // out of the interceptor chain, the wrapper is marshaled into the stream.Message.Payload. + + // re-marshal the modified request into the in message + reqPayload, err := proto.Marshal(req) + if err != nil { + return nil, err + } + in.SetPayload(reqPayload) + // call the next handler - outMsg, err := next(ctx, inMsg) - resp := gorums.AsProto[*pb.StringValue](outMsg) + out, err := next(ctx, in) + if err != nil { + return nil, err + } + resp := gorums.AsProto[*pb.StringValue](out) // update the underlying response gorums.Message's message field (pb.StringValue in this case) - resp.Value += out - return outMsg, err + resp.Value += outStr + // re-marshal the modified response into the out message + respPayload, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + out.SetPayload(respPayload) + return out, err } } @@ -77,7 +95,10 @@ func TestServerInterceptorsChain(t *testing.T) { s.RegisterHandler(mock.TestMethod, func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*pb.StringValue](in) resp, err := interceptorSrv.Test(ctx, req) - return gorums.NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil }) return s } @@ -104,7 +125,7 @@ func TestTCPReconnection(t *testing.T) { srv := gorums.NewServer() srv.RegisterHandler(mock.TestMethod, func(_ gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*pb.StringValue](in) - return gorums.NewResponseMessage(in.GetMetadata(), req), nil + return gorums.NewResponseMessage(in, req), nil }) lis, err := net.Listen("tcp", "127.0.0.1:0") @@ -160,7 +181,7 @@ func TestTCPReconnection(t *testing.T) { srv2 := gorums.NewServer() srv2.RegisterHandler(mock.TestMethod, func(_ gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { req := gorums.AsProto[*pb.StringValue](in) - return gorums.NewResponseMessage(in.GetMetadata(), req), nil + return gorums.NewResponseMessage(in, req), nil }) go func() { _ = srv2.Serve(lis2) diff --git a/testing_shared.go b/testing_shared.go index 5ac73831c..ccd737201 100644 --- a/testing_shared.go +++ b/testing_shared.go @@ -53,7 +53,7 @@ func TestQuorumCallError(_ testing.TB, nodeErrors map[uint32]error) QuorumCallEr // // Optional TestOptions can be provided to customize the manager, server, or configuration. // -// By default, nodes are assigned sequential IDs (0, 1, 2, ...) matching the server +// By default, nodes are assigned sequential IDs (1, 2, 3, ...) matching the server // creation order. This can be overridden by providing a NodeListOption. // // This is the recommended way to set up tests that need both servers and a configuration. @@ -247,12 +247,18 @@ func defaultTestServer(i int, opts ...ServerOption) ServerIface { srv.RegisterHandler(mock.TestMethod, func(ctx ServerCtx, in *Message) (*Message, error) { req := AsProto[*pb.StringValue](in) resp, err := ts.Test(ctx, req) - return NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return NewResponseMessage(in, resp), nil }) srv.RegisterHandler(mock.GetValueMethod, func(ctx ServerCtx, in *Message) (*Message, error) { req := AsProto[*pb.Int32Value](in) resp, err := ts.GetValue(ctx, req) - return NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return NewResponseMessage(in, resp), nil }) return srv } @@ -274,7 +280,10 @@ func EchoServerFn(_ int) ServerIface { srv.RegisterHandler(mock.TestMethod, func(ctx ServerCtx, in *Message) (*Message, error) { req := AsProto[*pb.StringValue](in) resp, err := echoSrv{}.Test(ctx, req) - return NewResponseMessage(in.GetMetadata(), resp), err + if err != nil { + return nil, err + } + return NewResponseMessage(in, resp), nil }) return srv @@ -296,8 +305,8 @@ func StreamServerFn(_ int) ServerIface { // Send 3 responses for i := 1; i <= 3; i++ { resp := pb.String(fmt.Sprintf("echo: %s-%d", val, i)) - msg := NewResponseMessage(in.GetMetadata(), resp) - if err := ctx.SendMessage(msg); err != nil { + out := NewResponseMessage(in, resp) + if err := ctx.SendMessage(out); err != nil { return nil, err } time.Sleep(10 * time.Millisecond) @@ -316,8 +325,8 @@ func StreamBenchmarkServerFn(_ int) ServerIface { // Send 3 responses for i := 1; i <= 3; i++ { resp := pb.String(fmt.Sprintf("echo: %s-%d", val, i)) - msg := NewResponseMessage(in.GetMetadata(), resp) - if err := ctx.SendMessage(msg); err != nil { + out := NewResponseMessage(in, resp) + if err := ctx.SendMessage(out); err != nil { return nil, err } } diff --git a/unicast.go b/unicast.go index 3aaad386a..a58f93493 100644 --- a/unicast.go +++ b/unicast.go @@ -1,5 +1,7 @@ package gorums +import "github.com/relab/gorums/internal/stream" + // Unicast is a one-way call; no replies are returned to the client. // // By default, this method blocks until the message has been sent to the node. @@ -13,18 +15,21 @@ package gorums // This method should be used by generated code only. func Unicast[Req msg](ctx *NodeContext, req Req, method string, opts ...CallOption) error { callOpts := getCallOptions(E_Unicast, opts...) - message := NewRequest(ctx, ctx.nextMsgID(), method, req) + reqMsg, err := stream.NewMessage(ctx, ctx.nextMsgID(), method, req) + if err != nil { + return err + } waitSendDone := callOpts.mustWaitSendDone() if !waitSendDone { // Fire-and-forget: enqueue and return immediately - ctx.enqueue(request{ctx: ctx, msg: message}) + ctx.enqueue(request{ctx: ctx, msg: reqMsg}) return nil } // Default: block until send completes replyChan := make(chan NodeResponse[msg], 1) - ctx.enqueue(request{ctx: ctx, msg: message, waitSendDone: true, responseChan: replyChan}) + ctx.enqueue(request{ctx: ctx, msg: reqMsg, waitSendDone: true, responseChan: replyChan}) // Wait for send confirmation select {