Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func (c *channel) ensureConnectedNodeStream() (err error) {
}

// getStream returns the current stream, or nil if no stream is available.
func (c *channel) getStream() grpc.ClientStream {
func (c *channel) getStream() ordering.Gorums_NodeStreamClient {
c.streamMut.Lock()
defer c.streamMut.Unlock()
return c.gorumsStream
Expand Down Expand Up @@ -280,14 +280,21 @@ func (c *channel) receiver() {
}
}

resp := newMessage(responseType)
if err := stream.RecvMsg(resp); err != nil {
md, err := stream.Recv()
if err != nil {
c.setLastErr(err)
c.cancelPendingMsgs(err)
c.clearStream()
} else {
err := resp.GetStatus().Err()
c.routeResponse(resp.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.id, Value: resp.GetProtoMessage(), Err: err})
resp, err := fromMetadata(md)
if err != nil {
c.setLastErr(err)
c.cancelPendingMsgs(err)
c.clearStream()
} else {
err := resp.GetStatus().Err()
c.routeResponse(resp.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.id, Value: resp.GetProtoMessage(), Err: err})
}
}

select {
Expand Down Expand Up @@ -354,7 +361,11 @@ func (c *channel) sendMsg(req request) (err error) {
}
}()

if err = stream.SendMsg(req.msg); err != nil {
md, err := req.msg.toMetadata()
if err != nil {
return err
}
if err = stream.Send(md); err != nil {
c.setLastErr(err)
c.clearStream()
}
Expand Down
15 changes: 7 additions & 8 deletions channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/relab/gorums/internal/testutils/mock"
"github.com/relab/gorums/ordering"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
pb "google.golang.org/protobuf/types/known/wrapperspb"
)
Expand Down Expand Up @@ -114,7 +113,7 @@ func routerExists(node *Node, msgID uint64) bool {
return exists
}

func getStream(node *Node) grpc.ClientStream {
func getStream(node *Node) ordering.Gorums_NodeStreamClient {
return node.channel.getStream()
}

Expand Down Expand Up @@ -324,7 +323,7 @@ func TestChannelEnsureStream(t *testing.T) {
}

// Helper to verify stream expectations
cmpStream := func(t *testing.T, first, second grpc.ClientStream, wantSame bool) {
cmpStream := func(t *testing.T, first, second ordering.Gorums_NodeStreamClient, wantSame bool) {
t.Helper()
// If second is nil, skip equality check (covered by UnconnectedNodeHasNoStream action)
if second == nil {
Expand All @@ -342,13 +341,13 @@ func TestChannelEnsureStream(t *testing.T) {
tests := []struct {
name string
setup func(t *testing.T) *Node
action func(node *Node) (first, second grpc.ClientStream)
action func(node *Node) (first, second ordering.Gorums_NodeStreamClient)
wantSame bool
}{
{
name: "UnconnectedNodeHasNoStream",
setup: func(t *testing.T) *Node { return testNodeWithoutServer(t) },
action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) {
action: func(node *Node) (ordering.Gorums_NodeStreamClient, ordering.Gorums_NodeStreamClient) {
if err := node.channel.ensureStream(); err == nil {
t.Error("ensureStream succeeded unexpectedly")
}
Expand All @@ -361,7 +360,7 @@ func TestChannelEnsureStream(t *testing.T) {
{
name: "CreatesStreamWhenConnected",
setup: newNodeWithoutStream,
action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) {
action: func(node *Node) (ordering.Gorums_NodeStreamClient, ordering.Gorums_NodeStreamClient) {
if err := node.channel.ensureStream(); err != nil {
t.Errorf("ensureStream failed: %v", err)
}
Expand All @@ -371,7 +370,7 @@ func TestChannelEnsureStream(t *testing.T) {
{
name: "RepeatedCallsReturnSameStream",
setup: newNodeWithoutStream,
action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) {
action: func(node *Node) (ordering.Gorums_NodeStreamClient, ordering.Gorums_NodeStreamClient) {
if err := node.channel.ensureStream(); err != nil {
t.Errorf("first ensureStream failed: %v", err)
}
Expand All @@ -386,7 +385,7 @@ func TestChannelEnsureStream(t *testing.T) {
{
name: "StreamDisconnectionCreatesNewStream",
setup: newNodeWithoutStream,
action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) {
action: func(node *Node) (ordering.Gorums_NodeStreamClient, ordering.Gorums_NodeStreamClient) {
if err := node.channel.ensureStream(); err != nil {
t.Errorf("initial ensureStream failed: %v", err)
}
Expand Down
7 changes: 0 additions & 7 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,9 @@ import (

"github.com/relab/gorums"
"github.com/relab/gorums/internal/testutils/mock"
"google.golang.org/grpc/encoding"
pb "google.golang.org/protobuf/types/known/wrapperspb"
)

func init() {
if encoding.GetCodec(gorums.ContentSubtype) == nil {
encoding.RegisterCodec(gorums.NewCodec())
}
}

var (
nodes = []string{"127.0.0.1:9081", "127.0.0.1:9082", "127.0.0.1:9083"}
nodeMap = map[uint32]testNode{
Expand Down
127 changes: 35 additions & 92 deletions encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,12 @@ import (

"github.com/relab/gorums/ordering"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)

func init() {
encoding.RegisterCodec(NewCodec())
}

// ContentSubtype is the subtype used by gorums when sending messages via gRPC.
const ContentSubtype = "gorums"

type gorumsMsgType uint8

const (
Expand All @@ -36,12 +27,6 @@ type Message struct {
msgType gorumsMsgType
}

// newMessage creates a new Message struct for unmarshaling.
// msgType specifies the message type to be unmarshaled.
func newMessage(msgType gorumsMsgType) *Message {
return &Message{metadata: &ordering.Metadata{}, msgType: msgType}
}

// NewRequestMessage creates a new Gorums Message for the given metadata and request message.
//
// This function should be used by generated code and tests only.
Expand Down Expand Up @@ -120,85 +105,39 @@ func (m *Message) setError(err error) {
m.metadata.SetStatus(errStatus.Proto())
}

// Codec is the gRPC codec used by gorums.
type Codec struct {
marshaler proto.MarshalOptions
unmarshaler proto.UnmarshalOptions
}

// NewCodec returns a new Codec.
func NewCodec() *Codec {
return &Codec{
marshaler: proto.MarshalOptions{AllowPartial: true},
unmarshaler: proto.UnmarshalOptions{AllowPartial: true},
}
}

// Name returns the name of the Codec.
func (Codec) Name() string {
return ContentSubtype
}

func (Codec) String() string {
return ContentSubtype
}

// Marshal marshals the message m into a byte slice.
func (c Codec) Marshal(m any) (b []byte, err error) {
switch msg := m.(type) {
case *Message:
return c.gorumsMarshal(msg)
case proto.Message:
return c.marshaler.Marshal(msg)
default:
return nil, fmt.Errorf("gorums: cannot marshal message of type '%T'", m)
// toMetadata serializes the application message into the metadata's payload
// field and returns the metadata, ready for sending via the type-safe Send method.
func (m *Message) toMetadata() (*ordering.Metadata, error) {
md := m.metadata
md.SetMsgType(uint32(m.msgType))
if m.message != nil {
b, err := proto.MarshalOptions{AllowPartial: true}.Marshal(m.message)
if err != nil {
return nil, fmt.Errorf("gorums: failed to marshal payload: %w", err)
}
md.SetPayload(b)
}
return md, nil
}

// gorumsMarshal marshals a metadata and a data message into a single byte slice.
func (c Codec) gorumsMarshal(msg *Message) (b []byte, err error) {
mdSize := c.marshaler.Size(msg.metadata)
b = protowire.AppendVarint(b, uint64(mdSize))
b, err = c.marshaler.MarshalAppend(b, msg.metadata)
if err != nil {
return nil, err
// fromMetadata reconstructs a Message from a received Metadata by deserializing
// the payload bytes into the appropriate protobuf message type, determined by
// the method descriptor and message type (request or response) in the metadata.
func fromMetadata(md *ordering.Metadata) (*Message, error) {
msg := &Message{
metadata: md,
msgType: gorumsMsgType(md.GetMsgType()),
}

msgSize := c.marshaler.Size(msg.message)
b = protowire.AppendVarint(b, uint64(msgSize))
b, err = c.marshaler.MarshalAppend(b, msg.message)
if err != nil {
return nil, err
}
return b, nil
}

// Unmarshal unmarshals a byte slice into m.
func (c Codec) Unmarshal(b []byte, m any) (err error) {
switch msg := m.(type) {
case *Message:
return c.gorumsUnmarshal(b, msg)
case proto.Message:
return c.unmarshaler.Unmarshal(b, msg)
default:
return fmt.Errorf("gorums: cannot unmarshal message of type '%T'", m)
}
}

// gorumsUnmarshal extracts metadata and message data from b and places the result in msg.
func (c Codec) gorumsUnmarshal(b []byte, msg *Message) (err error) {
// unmarshal metadata
mdBuf, mdLen := protowire.ConsumeBytes(b)
err = c.unmarshaler.Unmarshal(mdBuf, msg.metadata)
if err != nil {
return fmt.Errorf("gorums: could not unmarshal metadata: %w", err)
method := msg.GetMethod()
if method == "" || method == "nil" {
return msg, nil
}

// get method descriptor from registry
desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(msg.GetMethod()))
desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(method))
if err != nil {
// err is a NotFound error with no method name information; return a more informative error
return fmt.Errorf("gorums: could not find method descriptor for %s", msg.GetMethod())
return nil, fmt.Errorf("gorums: could not find method descriptor for %s", method)
}
methodDesc := desc.(protoreflect.MethodDescriptor)

Expand All @@ -210,18 +149,22 @@ func (c Codec) gorumsUnmarshal(b []byte, msg *Message) (err error) {
case responseType:
messageName = methodDesc.Output().FullName()
default:
return fmt.Errorf("gorums: unknown message type %d", msg.msgType)
return nil, fmt.Errorf("gorums: unknown message type %d", msg.msgType)
}

// now get the message type from the types registry
// get the message type from the types registry
msgType, err := protoregistry.GlobalTypes.FindMessageByName(messageName)
if err != nil {
// err is a NotFound error with no message name information; return a more informative error
return fmt.Errorf("gorums: could not find message type %s", messageName)
return nil, fmt.Errorf("gorums: could not find message type %s", messageName)
}
msg.message = msgType.New().Interface()

// unmarshal message
msgBuf, _ := protowire.ConsumeBytes(b[mdLen:])
return c.unmarshaler.Unmarshal(msgBuf, msg.message)
// unmarshal payload into the message
payload := md.GetPayload()
if len(payload) > 0 {
if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(payload, msg.message); err != nil {
return nil, fmt.Errorf("gorums: failed to unmarshal payload: %w", err)
}
}
return msg, nil
}
3 changes: 0 additions & 3 deletions mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ func NewManager(opts ...ManagerOption) *Manager {
if m.opts.logger != nil {
m.logger = m.opts.logger
}
m.opts.grpcDialOpts = append(m.opts.grpcDialOpts, grpc.WithDefaultCallOptions(
grpc.CallContentSubtype(ContentSubtype),
))
if m.opts.backoff != backoff.DefaultConfig {
m.opts.grpcDialOpts = append(m.opts.grpcDialOpts, grpc.WithConnectParams(
grpc.ConnectParams{Backoff: m.opts.backoff},
Expand Down
16 changes: 5 additions & 11 deletions mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,8 @@ import (
"log"
"strings"
"testing"

"google.golang.org/grpc/encoding"
)

func init() {
if encoding.GetCodec(ContentSubtype) == nil {
encoding.RegisterCodec(NewCodec())
}
}

func TestManagerLogging(t *testing.T) {
var (
buf bytes.Buffer
Expand All @@ -23,8 +15,10 @@ func TestManagerLogging(t *testing.T) {
mgr := NewManager(InsecureDialOptions(t), WithLogger(logger))
t.Cleanup(Closer(t, mgr))

want := "logger: mgr.go:49: ready"
if strings.TrimSpace(buf.String()) != want {
t.Errorf("logger: got %q, want %q", buf.String(), want)
got := strings.TrimSpace(buf.String())
wantPrefix := "logger: mgr.go:"
wantSuffix := ": ready"
if !strings.HasPrefix(got, wantPrefix) || !strings.HasSuffix(got, wantSuffix) {
t.Errorf("logger: got %q, want %q<line>%q", got, wantPrefix, wantSuffix)
}
}
Loading
Loading