diff --git a/channel.go b/channel.go index 6e032bcf9..7db8aa244 100644 --- a/channel.go +++ b/channel.go @@ -155,7 +155,7 @@ func (c *channel) ensureConnectedNodeStream() (err error) { } // getStream returns the current stream, or nil if no stream is available. -func (c *channel) getStream() grpc.ClientStream { +func (c *channel) getStream() ordering.Gorums_NodeStreamClient { c.streamMut.Lock() defer c.streamMut.Unlock() return c.gorumsStream @@ -280,14 +280,21 @@ func (c *channel) receiver() { } } - resp := newMessage(responseType) - if err := stream.RecvMsg(resp); err != nil { + md, err := stream.Recv() + if err != nil { c.setLastErr(err) c.cancelPendingMsgs(err) c.clearStream() } else { - err := resp.GetStatus().Err() - c.routeResponse(resp.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.id, Value: resp.GetProtoMessage(), Err: err}) + resp, err := fromMetadata(md) + if err != nil { + c.setLastErr(err) + c.cancelPendingMsgs(err) + c.clearStream() + } else { + err := resp.GetStatus().Err() + c.routeResponse(resp.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.id, Value: resp.GetProtoMessage(), Err: err}) + } } select { @@ -354,7 +361,11 @@ func (c *channel) sendMsg(req request) (err error) { } }() - if err = stream.SendMsg(req.msg); err != nil { + md, err := req.msg.toMetadata() + if err != nil { + return err + } + if err = stream.Send(md); err != nil { c.setLastErr(err) c.clearStream() } diff --git a/channel_test.go b/channel_test.go index 30bccdffc..b222f57e6 100644 --- a/channel_test.go +++ b/channel_test.go @@ -11,7 +11,6 @@ import ( "github.com/relab/gorums/internal/testutils/mock" "github.com/relab/gorums/ordering" - "google.golang.org/grpc" "google.golang.org/protobuf/proto" pb "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -114,7 +113,7 @@ func routerExists(node *Node, msgID uint64) bool { return exists } -func getStream(node *Node) grpc.ClientStream { +func getStream(node *Node) ordering.Gorums_NodeStreamClient { return node.channel.getStream() } @@ -324,7 +323,7 @@ func TestChannelEnsureStream(t *testing.T) { } // Helper to verify stream expectations - cmpStream := func(t *testing.T, first, second grpc.ClientStream, wantSame bool) { + cmpStream := func(t *testing.T, first, second ordering.Gorums_NodeStreamClient, wantSame bool) { t.Helper() // If second is nil, skip equality check (covered by UnconnectedNodeHasNoStream action) if second == nil { @@ -342,13 +341,13 @@ func TestChannelEnsureStream(t *testing.T) { tests := []struct { name string setup func(t *testing.T) *Node - action func(node *Node) (first, second grpc.ClientStream) + action func(node *Node) (first, second ordering.Gorums_NodeStreamClient) wantSame bool }{ { name: "UnconnectedNodeHasNoStream", setup: func(t *testing.T) *Node { return testNodeWithoutServer(t) }, - action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) { + action: func(node *Node) (ordering.Gorums_NodeStreamClient, ordering.Gorums_NodeStreamClient) { if err := node.channel.ensureStream(); err == nil { t.Error("ensureStream succeeded unexpectedly") } @@ -361,7 +360,7 @@ func TestChannelEnsureStream(t *testing.T) { { name: "CreatesStreamWhenConnected", setup: newNodeWithoutStream, - action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) { + action: func(node *Node) (ordering.Gorums_NodeStreamClient, ordering.Gorums_NodeStreamClient) { if err := node.channel.ensureStream(); err != nil { t.Errorf("ensureStream failed: %v", err) } @@ -371,7 +370,7 @@ func TestChannelEnsureStream(t *testing.T) { { name: "RepeatedCallsReturnSameStream", setup: newNodeWithoutStream, - action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) { + action: func(node *Node) (ordering.Gorums_NodeStreamClient, ordering.Gorums_NodeStreamClient) { if err := node.channel.ensureStream(); err != nil { t.Errorf("first ensureStream failed: %v", err) } @@ -386,7 +385,7 @@ func TestChannelEnsureStream(t *testing.T) { { name: "StreamDisconnectionCreatesNewStream", setup: newNodeWithoutStream, - action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) { + action: func(node *Node) (ordering.Gorums_NodeStreamClient, ordering.Gorums_NodeStreamClient) { if err := node.channel.ensureStream(); err != nil { t.Errorf("initial ensureStream failed: %v", err) } diff --git a/config_test.go b/config_test.go index b14c97e22..99e67effe 100644 --- a/config_test.go +++ b/config_test.go @@ -7,16 +7,9 @@ import ( "github.com/relab/gorums" "github.com/relab/gorums/internal/testutils/mock" - "google.golang.org/grpc/encoding" pb "google.golang.org/protobuf/types/known/wrapperspb" ) -func init() { - if encoding.GetCodec(gorums.ContentSubtype) == nil { - encoding.RegisterCodec(gorums.NewCodec()) - } -} - var ( nodes = []string{"127.0.0.1:9081", "127.0.0.1:9082", "127.0.0.1:9083"} nodeMap = map[uint32]testNode{ diff --git a/encoding.go b/encoding.go index 39e8b5e59..635288e53 100644 --- a/encoding.go +++ b/encoding.go @@ -5,21 +5,12 @@ import ( "github.com/relab/gorums/ordering" "google.golang.org/grpc/codes" - "google.golang.org/grpc/encoding" "google.golang.org/grpc/status" - "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" ) -func init() { - encoding.RegisterCodec(NewCodec()) -} - -// ContentSubtype is the subtype used by gorums when sending messages via gRPC. -const ContentSubtype = "gorums" - type gorumsMsgType uint8 const ( @@ -36,12 +27,6 @@ type Message struct { msgType gorumsMsgType } -// newMessage creates a new Message struct for unmarshaling. -// msgType specifies the message type to be unmarshaled. -func newMessage(msgType gorumsMsgType) *Message { - return &Message{metadata: &ordering.Metadata{}, msgType: msgType} -} - // NewRequestMessage creates a new Gorums Message for the given metadata and request message. // // This function should be used by generated code and tests only. @@ -120,85 +105,39 @@ func (m *Message) setError(err error) { m.metadata.SetStatus(errStatus.Proto()) } -// Codec is the gRPC codec used by gorums. -type Codec struct { - marshaler proto.MarshalOptions - unmarshaler proto.UnmarshalOptions -} - -// NewCodec returns a new Codec. -func NewCodec() *Codec { - return &Codec{ - marshaler: proto.MarshalOptions{AllowPartial: true}, - unmarshaler: proto.UnmarshalOptions{AllowPartial: true}, - } -} - -// Name returns the name of the Codec. -func (Codec) Name() string { - return ContentSubtype -} - -func (Codec) String() string { - return ContentSubtype -} - -// Marshal marshals the message m into a byte slice. -func (c Codec) Marshal(m any) (b []byte, err error) { - switch msg := m.(type) { - case *Message: - return c.gorumsMarshal(msg) - case proto.Message: - return c.marshaler.Marshal(msg) - default: - return nil, fmt.Errorf("gorums: cannot marshal message of type '%T'", m) +// toMetadata serializes the application message into the metadata's payload +// field and returns the metadata, ready for sending via the type-safe Send method. +func (m *Message) toMetadata() (*ordering.Metadata, error) { + md := m.metadata + md.SetMsgType(uint32(m.msgType)) + if m.message != nil { + b, err := proto.MarshalOptions{AllowPartial: true}.Marshal(m.message) + if err != nil { + return nil, fmt.Errorf("gorums: failed to marshal payload: %w", err) + } + md.SetPayload(b) } + return md, nil } -// gorumsMarshal marshals a metadata and a data message into a single byte slice. -func (c Codec) gorumsMarshal(msg *Message) (b []byte, err error) { - mdSize := c.marshaler.Size(msg.metadata) - b = protowire.AppendVarint(b, uint64(mdSize)) - b, err = c.marshaler.MarshalAppend(b, msg.metadata) - if err != nil { - return nil, err +// fromMetadata reconstructs a Message from a received Metadata by deserializing +// the payload bytes into the appropriate protobuf message type, determined by +// the method descriptor and message type (request or response) in the metadata. +func fromMetadata(md *ordering.Metadata) (*Message, error) { + msg := &Message{ + metadata: md, + msgType: gorumsMsgType(md.GetMsgType()), } - msgSize := c.marshaler.Size(msg.message) - b = protowire.AppendVarint(b, uint64(msgSize)) - b, err = c.marshaler.MarshalAppend(b, msg.message) - if err != nil { - return nil, err - } - return b, nil -} - -// Unmarshal unmarshals a byte slice into m. -func (c Codec) Unmarshal(b []byte, m any) (err error) { - switch msg := m.(type) { - case *Message: - return c.gorumsUnmarshal(b, msg) - case proto.Message: - return c.unmarshaler.Unmarshal(b, msg) - default: - return fmt.Errorf("gorums: cannot unmarshal message of type '%T'", m) - } -} - -// gorumsUnmarshal extracts metadata and message data from b and places the result in msg. -func (c Codec) gorumsUnmarshal(b []byte, msg *Message) (err error) { - // unmarshal metadata - mdBuf, mdLen := protowire.ConsumeBytes(b) - err = c.unmarshaler.Unmarshal(mdBuf, msg.metadata) - if err != nil { - return fmt.Errorf("gorums: could not unmarshal metadata: %w", err) + method := msg.GetMethod() + if method == "" || method == "nil" { + return msg, nil } // get method descriptor from registry - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(msg.GetMethod())) + desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(method)) if err != nil { - // err is a NotFound error with no method name information; return a more informative error - return fmt.Errorf("gorums: could not find method descriptor for %s", msg.GetMethod()) + return nil, fmt.Errorf("gorums: could not find method descriptor for %s", method) } methodDesc := desc.(protoreflect.MethodDescriptor) @@ -210,18 +149,22 @@ func (c Codec) gorumsUnmarshal(b []byte, msg *Message) (err error) { case responseType: messageName = methodDesc.Output().FullName() default: - return fmt.Errorf("gorums: unknown message type %d", msg.msgType) + return nil, fmt.Errorf("gorums: unknown message type %d", msg.msgType) } - // now get the message type from the types registry + // get the message type from the types registry msgType, err := protoregistry.GlobalTypes.FindMessageByName(messageName) if err != nil { - // err is a NotFound error with no message name information; return a more informative error - return fmt.Errorf("gorums: could not find message type %s", messageName) + return nil, fmt.Errorf("gorums: could not find message type %s", messageName) } msg.message = msgType.New().Interface() - // unmarshal message - msgBuf, _ := protowire.ConsumeBytes(b[mdLen:]) - return c.unmarshaler.Unmarshal(msgBuf, msg.message) + // unmarshal payload into the message + payload := md.GetPayload() + if len(payload) > 0 { + if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(payload, msg.message); err != nil { + return nil, fmt.Errorf("gorums: failed to unmarshal payload: %w", err) + } + } + return msg, nil } diff --git a/mgr.go b/mgr.go index 3f407a149..18b0996bc 100644 --- a/mgr.go +++ b/mgr.go @@ -37,9 +37,6 @@ func NewManager(opts ...ManagerOption) *Manager { if m.opts.logger != nil { m.logger = m.opts.logger } - m.opts.grpcDialOpts = append(m.opts.grpcDialOpts, grpc.WithDefaultCallOptions( - grpc.CallContentSubtype(ContentSubtype), - )) if m.opts.backoff != backoff.DefaultConfig { m.opts.grpcDialOpts = append(m.opts.grpcDialOpts, grpc.WithConnectParams( grpc.ConnectParams{Backoff: m.opts.backoff}, diff --git a/mgr_test.go b/mgr_test.go index a012c2168..c27fa3631 100644 --- a/mgr_test.go +++ b/mgr_test.go @@ -5,16 +5,8 @@ import ( "log" "strings" "testing" - - "google.golang.org/grpc/encoding" ) -func init() { - if encoding.GetCodec(ContentSubtype) == nil { - encoding.RegisterCodec(NewCodec()) - } -} - func TestManagerLogging(t *testing.T) { var ( buf bytes.Buffer @@ -23,8 +15,10 @@ func TestManagerLogging(t *testing.T) { mgr := NewManager(InsecureDialOptions(t), WithLogger(logger)) t.Cleanup(Closer(t, mgr)) - want := "logger: mgr.go:49: ready" - if strings.TrimSpace(buf.String()) != want { - t.Errorf("logger: got %q, want %q", buf.String(), want) + got := strings.TrimSpace(buf.String()) + wantPrefix := "logger: mgr.go:" + wantSuffix := ": ready" + if !strings.HasPrefix(got, wantPrefix) || !strings.HasSuffix(got, wantSuffix) { + t.Errorf("logger: got %q, want %q%q", got, wantPrefix, wantSuffix) } } diff --git a/ordering/ordering.pb.go b/ordering/ordering.pb.go index bd60796da..7a354893b 100644 --- a/ordering/ordering.pb.go +++ b/ordering/ordering.pb.go @@ -29,6 +29,8 @@ type Metadata struct { xxx_hidden_Method string `protobuf:"bytes,2,opt,name=method"` xxx_hidden_Status *status.Status `protobuf:"bytes,3,opt,name=status"` xxx_hidden_Entry *[]*MetadataEntry `protobuf:"bytes,4,rep,name=entry"` + xxx_hidden_Payload []byte `protobuf:"bytes,5,opt,name=payload"` + xxx_hidden_MsgType uint32 `protobuf:"varint,6,opt,name=msg_type,json=msgType"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -88,6 +90,20 @@ func (x *Metadata) GetEntry() []*MetadataEntry { return nil } +func (x *Metadata) GetPayload() []byte { + if x != nil { + return x.xxx_hidden_Payload + } + return nil +} + +func (x *Metadata) GetMsgType() uint32 { + if x != nil { + return x.xxx_hidden_MsgType + } + return 0 +} + func (x *Metadata) SetMessageSeqNo(v uint64) { x.xxx_hidden_MessageSeqNo = v } @@ -104,6 +120,17 @@ func (x *Metadata) SetEntry(v []*MetadataEntry) { x.xxx_hidden_Entry = &v } +func (x *Metadata) SetPayload(v []byte) { + if v == nil { + v = []byte{} + } + x.xxx_hidden_Payload = v +} + +func (x *Metadata) SetMsgType(v uint32) { + x.xxx_hidden_MsgType = v +} + func (x *Metadata) HasStatus() bool { if x == nil { return false @@ -122,6 +149,8 @@ type Metadata_builder struct { Method string Status *status.Status Entry []*MetadataEntry + Payload []byte + MsgType uint32 } func (b0 Metadata_builder) Build() *Metadata { @@ -132,6 +161,8 @@ func (b0 Metadata_builder) Build() *Metadata { x.xxx_hidden_Method = b.Method x.xxx_hidden_Status = b.Status x.xxx_hidden_Entry = &b.Entry + x.xxx_hidden_Payload = b.Payload + x.xxx_hidden_MsgType = b.MsgType return m0 } @@ -211,12 +242,14 @@ var File_ordering_ordering_proto protoreflect.FileDescriptor const file_ordering_ordering_proto_rawDesc = "" + "\n" + - "\x17ordering/ordering.proto\x12\bordering\x1a\x17google/rpc/status.proto\"\xa3\x01\n" + + "\x17ordering/ordering.proto\x12\bordering\x1a\x17google/rpc/status.proto\"\xd8\x01\n" + "\bMetadata\x12$\n" + "\x0emessage_seq_no\x18\x01 \x01(\x04R\fmessageSeqNo\x12\x16\n" + "\x06method\x18\x02 \x01(\tR\x06method\x12*\n" + "\x06status\x18\x03 \x01(\v2\x12.google.rpc.StatusR\x06status\x12-\n" + - "\x05entry\x18\x04 \x03(\v2\x17.ordering.MetadataEntryR\x05entry\"7\n" + + "\x05entry\x18\x04 \x03(\v2\x17.ordering.MetadataEntryR\x05entry\x12\x18\n" + + "\apayload\x18\x05 \x01(\fR\apayload\x12\x19\n" + + "\bmsg_type\x18\x06 \x01(\rR\amsgType\"7\n" + "\rMetadataEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + "\x05value\x18\x02 \x01(\tR\x05value2B\n" + diff --git a/ordering/ordering.proto b/ordering/ordering.proto index 34447afac..bb36bf078 100644 --- a/ordering/ordering.proto +++ b/ordering/ordering.proto @@ -21,6 +21,8 @@ message Metadata { string method = 2; // method name to invoke on the server google.rpc.Status status = 3; // status of an invocation (for responses) repeated MetadataEntry entry = 4; // per message client-generated metadata + bytes payload = 5; // serialized application-specific message + uint32 msg_type = 6; // message type: 1 = request, 2 = response } // MetadataEntry is a key-value pair for Metadata entries. diff --git a/rpc_test.go b/rpc_test.go index da87997ef..d1f4b4e75 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -8,15 +8,9 @@ import ( "github.com/relab/gorums" "github.com/relab/gorums/internal/testutils/mock" - "google.golang.org/grpc/encoding" pb "google.golang.org/protobuf/types/known/wrapperspb" ) -func init() { - if encoding.GetCodec(gorums.ContentSubtype) == nil { - encoding.RegisterCodec(gorums.NewCodec()) - } -} func TestRPCCallSuccess(t *testing.T) { node := gorums.TestNode(t, gorums.DefaultTestServer) diff --git a/server.go b/server.go index 497028476..752b13753 100644 --- a/server.go +++ b/server.go @@ -49,10 +49,13 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error case <-ctx.Done(): return case msg := <-finished: - err := srv.SendMsg(msg) + md, err := msg.toMetadata() if err != nil { return } + if err := srv.Send(md); err != nil { + return + } } } }() @@ -62,8 +65,11 @@ func (s *orderingServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error defer mut.Unlock() for { - req := newMessage(requestType) - err := srv.RecvMsg(req) + md, err := srv.Recv() + if err != nil { + return err + } + req, err := fromMetadata(md) if err != nil { return err } diff --git a/server_test.go b/server_test.go index 49626797c..546fac873 100644 --- a/server_test.go +++ b/server_test.go @@ -8,18 +8,11 @@ import ( "github.com/relab/gorums" "github.com/relab/gorums/internal/testutils/mock" - "google.golang.org/grpc/encoding" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" pb "google.golang.org/protobuf/types/known/wrapperspb" ) -func init() { - if encoding.GetCodec(gorums.ContentSubtype) == nil { - encoding.RegisterCodec(gorums.NewCodec()) - } -} - func TestServerCallback(t *testing.T) { var message string signal := make(chan struct{})