From b8b4aac56b79249cc1f96d9b9aa3d86d4527da60 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 13 Apr 2025 13:28:13 +0200 Subject: [PATCH 01/27] refactor: remove unused SignAndVerify method from EllipticCurve --- authentication/authentication.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/authentication/authentication.go b/authentication/authentication.go index 0b6647cf0..ba4702a12 100644 --- a/authentication/authentication.go +++ b/authentication/authentication.go @@ -129,18 +129,6 @@ func (ec *EllipticCurve) VerifySignature(pemEncodedPub string, msg, signature [] return ok, nil } -// VerifySignature sign ecdsa style and verify signature -func (ec *EllipticCurve) SignAndVerify(privKey *ecdsa.PrivateKey, pubKey *ecdsa.PublicKey) ([]byte, bool, error) { - h := sha256.Sum256([]byte("test")) - hash := h[:] - signature, err := ecdsa.SignASN1(rand.Reader, privKey, hash) - if err != nil { - return nil, false, err - } - ok := ecdsa.VerifyASN1(pubKey, hash, signature) - return signature, ok, nil -} - func (ec *EllipticCurve) EncodeMsg(msg any) ([]byte, error) { return []byte(fmt.Sprintf("%v", msg)), nil /*var encodedMsg bytes.Buffer From e1a2a7a193147aad4589dbedf6389843519cce33 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 13 Apr 2025 13:30:26 +0200 Subject: [PATCH 02/27] refactor: remove commented gob-based encoder in EncodeMsg --- authentication/authentication.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/authentication/authentication.go b/authentication/authentication.go index ba4702a12..4a3c69560 100644 --- a/authentication/authentication.go +++ b/authentication/authentication.go @@ -131,14 +131,6 @@ func (ec *EllipticCurve) VerifySignature(pemEncodedPub string, msg, signature [] func (ec *EllipticCurve) EncodeMsg(msg any) ([]byte, error) { return []byte(fmt.Sprintf("%v", msg)), nil - /*var encodedMsg bytes.Buffer - gob.Register(msg) - enc := gob.NewEncoder(&encodedMsg) - err := enc.Encode(msg) - if err != nil { - return nil, err - } - return encodedMsg.Bytes(), nil*/ } func encodeMsg(msg any) ([]byte, error) { From fa07fb2f78c0d47c1c975840a42737843c1a806d Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 13 Apr 2025 13:47:10 +0200 Subject: [PATCH 03/27] refactor: remove unnecessary error from encodeMsg function --- authentication/authentication.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/authentication/authentication.go b/authentication/authentication.go index 4a3c69560..9019186f5 100644 --- a/authentication/authentication.go +++ b/authentication/authentication.go @@ -133,15 +133,12 @@ func (ec *EllipticCurve) EncodeMsg(msg any) ([]byte, error) { return []byte(fmt.Sprintf("%v", msg)), nil } -func encodeMsg(msg any) ([]byte, error) { - return []byte(fmt.Sprintf("%v", msg)), nil +func encodeMsg(msg any) []byte { + return fmt.Appendf(nil, "%v", msg) } func Verify(pemEncodedPub string, signature, digest []byte, msg any) (bool, error) { - encodedMsg, err := encodeMsg(msg) - if err != nil { - return false, err - } + encodedMsg := encodeMsg(msg) ec := New(elliptic.P256()) h := sha256.Sum256(encodedMsg) hash := h[:] From 875cb6e8a3505a43b365a3fcce30478ed7feb673 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 13 Apr 2025 14:00:00 +0200 Subject: [PATCH 04/27] refactor: replace method with EncodeMsg function and remove error The EncodeMsg method does not need to return an error, nor does it need to be a method since it doesn't use the EllipticCurve type for anything. --- authentication/authentication.go | 8 ++------ authentication/authentication_test.go | 20 ++++---------------- channel_test.go | 4 ++-- clientserver.go | 11 ++++------- config.go | 15 +++++---------- server.go | 11 ++++------- 6 files changed, 21 insertions(+), 48 deletions(-) diff --git a/authentication/authentication.go b/authentication/authentication.go index 9019186f5..71542da13 100644 --- a/authentication/authentication.go +++ b/authentication/authentication.go @@ -129,16 +129,12 @@ func (ec *EllipticCurve) VerifySignature(pemEncodedPub string, msg, signature [] return ok, nil } -func (ec *EllipticCurve) EncodeMsg(msg any) ([]byte, error) { - return []byte(fmt.Sprintf("%v", msg)), nil -} - -func encodeMsg(msg any) []byte { +func EncodeMsg(msg any) []byte { return fmt.Appendf(nil, "%v", msg) } func Verify(pemEncodedPub string, signature, digest []byte, msg any) (bool, error) { - encodedMsg := encodeMsg(msg) + encodedMsg := EncodeMsg(msg) ec := New(elliptic.P256()) h := sha256.Sum256(encodedMsg) hash := h[:] diff --git a/authentication/authentication_test.go b/authentication/authentication_test.go index 579f30454..2f83767c1 100644 --- a/authentication/authentication_test.go +++ b/authentication/authentication_test.go @@ -31,10 +31,7 @@ func TestSignAndVerify(t *testing.T) { message := "This is a message" - encodedMsg1, err := ec1.EncodeMsg(message) - if err != nil { - t.Error(err) - } + encodedMsg1 := EncodeMsg(message) signature, err := ec1.Sign(encodedMsg1) if err != nil { t.Error(err) @@ -44,10 +41,7 @@ func TestSignAndVerify(t *testing.T) { t.Error(err) } - encodedMsg2, err := ec2.EncodeMsg(message) - if err != nil { - t.Error(err) - } + encodedMsg2 := EncodeMsg(message) ok, err := ec2.VerifySignature(pemEncodedPub, encodedMsg2, signature) if err != nil { t.Error(err) @@ -71,10 +65,7 @@ func TestVerifyWithWrongPubKey(t *testing.T) { } message := "This is a message" - encodedMsg1, err := ec1.EncodeMsg(message) - if err != nil { - t.Error(err) - } + encodedMsg1 := EncodeMsg(message) signature, err := ec1.Sign(encodedMsg1) if err != nil { t.Error(err) @@ -86,10 +77,7 @@ func TestVerifyWithWrongPubKey(t *testing.T) { t.Error(err) } - encodedMsg2, err := ec2.EncodeMsg(message) - if err != nil { - t.Error(err) - } + encodedMsg2 := EncodeMsg(message) ok, err := ec2.VerifySignature(pemEncodedPub, encodedMsg2, signature) if err != nil { t.Error(err) diff --git a/channel_test.go b/channel_test.go index ae0b19875..cbb1272cc 100644 --- a/channel_test.go +++ b/channel_test.go @@ -223,8 +223,8 @@ func TestAuthentication(t *testing.T) { msg := &Message{Metadata: md, Message: &mock.Request{}} msg1 := &Message{Metadata: md, Message: &mock.Request{}} - chEncodedMsg, _ := config.encodeMsg(msg1) - srvEncodedMsg, _ := srv.srv.encodeMsg(msg1) + chEncodedMsg := config.encodeMsg(msg1) + srvEncodedMsg := srv.srv.encodeMsg(msg1) if !bytes.Equal(chEncodedMsg, srvEncodedMsg) { t.Fatalf("wrong encoding. want: %x, got: %x", chEncodedMsg, srvEncodedMsg) } diff --git a/clientserver.go b/clientserver.go index 0fa377ef5..b22273ce0 100644 --- a/clientserver.go +++ b/clientserver.go @@ -250,7 +250,7 @@ func (srv *ClientServer) Serve(listener net.Listener) error { return srv.grpcServer.Serve(listener) } -func (srv *ClientServer) encodeMsg(req *Message) ([]byte, error) { +func (srv *ClientServer) encodeMsg(req *Message) []byte { // we must not consider the signature field when validating. // also the msgType must be set to requestType. signature := make([]byte, len(req.Metadata.GetAuthMsg().GetSignature())) @@ -258,12 +258,12 @@ func (srv *ClientServer) encodeMsg(req *Message) ([]byte, error) { reqType := req.msgType req.Metadata.GetAuthMsg().SetSignature(nil) req.msgType = 0 - encodedMsg, err := srv.auth.EncodeMsg(*req) + encodedMsg := authentication.EncodeMsg(*req) req.Metadata.GetAuthMsg().SetSignature(make([]byte, len(signature))) // TODO(meling): I think this is incorrect and should be done differently. copy(req.Metadata.GetAuthMsg().GetSignature(), signature) req.msgType = reqType - return encodedMsg, err + return encodedMsg } func (srv *ClientServer) verify(req *Message) error { @@ -289,10 +289,7 @@ func (srv *ClientServer) verify(req *Message) error { return fmt.Errorf("publicKey did not match") } } - encodedMsg, err := srv.encodeMsg(req) - if err != nil { - return err - } + encodedMsg := srv.encodeMsg(req) valid, err := srv.auth.VerifySignature(authMsg.GetPublicKey(), encodedMsg, authMsg.GetSignature()) if err != nil { return err diff --git a/config.go b/config.go index ead550cf2..a9b673e13 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package gorums import ( "fmt" + "github.com/relab/gorums/authentication" "github.com/relab/gorums/ordering" ) @@ -66,10 +67,7 @@ func (c RawConfiguration) getMsgID() uint64 { func (c RawConfiguration) sign(msg *Message, signOrigin ...bool) { if c[0].mgr.opts.auth != nil { if len(signOrigin) > 0 && signOrigin[0] { - originMsg, err := c[0].mgr.opts.auth.EncodeMsg(msg.Message) - if err != nil { - panic(err) - } + originMsg := authentication.EncodeMsg(msg.Message) digest := c[0].mgr.opts.auth.Hash(originMsg) originSignature, err := c[0].mgr.opts.auth.Sign(originMsg) if err != nil { @@ -83,10 +81,7 @@ func (c RawConfiguration) sign(msg *Message, signOrigin ...bool) { msg.Metadata.GetBroadcastMsg().SetOriginPubKey(pubKey) msg.Metadata.GetBroadcastMsg().SetOriginSignature(originSignature) } - encodedMsg, err := c.encodeMsg(msg) - if err != nil { - panic(err) - } + encodedMsg := c.encodeMsg(msg) signature, err := c[0].mgr.opts.auth.Sign(encodedMsg) if err != nil { panic(err) @@ -95,7 +90,7 @@ func (c RawConfiguration) sign(msg *Message, signOrigin ...bool) { } } -func (c RawConfiguration) encodeMsg(msg *Message) ([]byte, error) { +func (c RawConfiguration) encodeMsg(msg *Message) []byte { // we do not want to include the signature field in the signature auth := c[0].mgr.opts.auth pubKey, err := auth.EncodePublic() @@ -107,5 +102,5 @@ func (c RawConfiguration) encodeMsg(msg *Message) ([]byte, error) { Signature: nil, Sender: auth.Addr(), }.Build()) - return auth.EncodeMsg(*msg) + return authentication.EncodeMsg(*msg) } diff --git a/server.go b/server.go index b8a00e873..bdd922679 100644 --- a/server.go +++ b/server.go @@ -38,7 +38,7 @@ func newOrderingServer(opts *serverOptions) *orderingServer { return s } -func (s *orderingServer) encodeMsg(req *Message) ([]byte, error) { +func (s *orderingServer) encodeMsg(req *Message) []byte { // we must not consider the signature field when validating. // also the msgType must be set to requestType. signature := make([]byte, len(req.Metadata.GetAuthMsg().GetSignature())) @@ -46,12 +46,12 @@ func (s *orderingServer) encodeMsg(req *Message) ([]byte, error) { reqType := req.msgType req.Metadata.GetAuthMsg().SetSignature(nil) req.msgType = 0 - encodedMsg, err := s.opts.auth.EncodeMsg(*req) + encodedMsg := authentication.EncodeMsg(*req) req.Metadata.GetAuthMsg().SetSignature(make([]byte, len(signature))) // TODO(meling): I think this is incorrect and should be done differently. copy(req.Metadata.GetAuthMsg().GetSignature(), signature) req.msgType = reqType - return encodedMsg, err + return encodedMsg } func (s *orderingServer) verify(req *Message) error { @@ -78,10 +78,7 @@ func (s *orderingServer) verify(req *Message) error { return fmt.Errorf("publicKey did not match") } } - encodedMsg, err := s.encodeMsg(req) - if err != nil { - return err - } + encodedMsg := s.encodeMsg(req) valid, err := auth.VerifySignature(authMsg.GetPublicKey(), encodedMsg, authMsg.GetSignature()) if err != nil { return err From 8cc1be8c73f74e808ee4c565f52f6ff79e241173 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 13 Apr 2025 14:48:23 +0200 Subject: [PATCH 05/27] refactor: replace encodeMsg with Encode method on gorums.Message type --- channel_test.go | 2 +- clientserver.go | 19 +------------------ encoding.go | 25 +++++++++++++++++++++++++ server.go | 19 +------------------ 4 files changed, 28 insertions(+), 37 deletions(-) diff --git a/channel_test.go b/channel_test.go index cbb1272cc..ad1ad1efb 100644 --- a/channel_test.go +++ b/channel_test.go @@ -224,7 +224,7 @@ func TestAuthentication(t *testing.T) { msg1 := &Message{Metadata: md, Message: &mock.Request{}} chEncodedMsg := config.encodeMsg(msg1) - srvEncodedMsg := srv.srv.encodeMsg(msg1) + srvEncodedMsg := msg1.Encode() if !bytes.Equal(chEncodedMsg, srvEncodedMsg) { t.Fatalf("wrong encoding. want: %x, got: %x", chEncodedMsg, srvEncodedMsg) } diff --git a/clientserver.go b/clientserver.go index b22273ce0..ab6fa3021 100644 --- a/clientserver.go +++ b/clientserver.go @@ -250,22 +250,6 @@ func (srv *ClientServer) Serve(listener net.Listener) error { return srv.grpcServer.Serve(listener) } -func (srv *ClientServer) encodeMsg(req *Message) []byte { - // we must not consider the signature field when validating. - // also the msgType must be set to requestType. - signature := make([]byte, len(req.Metadata.GetAuthMsg().GetSignature())) - copy(signature, req.Metadata.GetAuthMsg().GetSignature()) - reqType := req.msgType - req.Metadata.GetAuthMsg().SetSignature(nil) - req.msgType = 0 - encodedMsg := authentication.EncodeMsg(*req) - req.Metadata.GetAuthMsg().SetSignature(make([]byte, len(signature))) - // TODO(meling): I think this is incorrect and should be done differently. - copy(req.Metadata.GetAuthMsg().GetSignature(), signature) - req.msgType = reqType - return encodedMsg -} - func (srv *ClientServer) verify(req *Message) error { if srv.auth == nil { return nil @@ -289,8 +273,7 @@ func (srv *ClientServer) verify(req *Message) error { return fmt.Errorf("publicKey did not match") } } - encodedMsg := srv.encodeMsg(req) - valid, err := srv.auth.VerifySignature(authMsg.GetPublicKey(), encodedMsg, authMsg.GetSignature()) + valid, err := srv.auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) if err != nil { return err } diff --git a/encoding.go b/encoding.go index ce8fedb31..0e8b4e15c 100644 --- a/encoding.go +++ b/encoding.go @@ -3,6 +3,7 @@ package gorums import ( "fmt" + "github.com/relab/gorums/authentication" "github.com/relab/gorums/ordering" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" @@ -35,6 +36,30 @@ func newMessage(msgType gorumsMsgType) *Message { return &Message{Metadata: &ordering.Metadata{}, msgType: msgType} } +// Encode returns an encoded byte representation of the Message +// ignoring the message type and signature in the Auth message. +func (m *Message) Encode() []byte { + authMsg := m.Metadata.GetAuthMsg() + + // save the original signature and msgType. + origSignature := authMsg.GetSignature() + sigCopy := make([]byte, len(origSignature)) + copy(sigCopy, origSignature) + origMsgType := m.msgType + + // prepare for encoding: remove the signature and set msgType to 0 + authMsg.SetSignature(nil) + m.msgType = 0 + + encoded := authentication.EncodeMsg(*m) + + // restore the original values + authMsg.SetSignature(sigCopy) + m.msgType = origMsgType + + return encoded +} + // Codec is the gRPC codec used by gorums. type Codec struct { marshaler proto.MarshalOptions diff --git a/server.go b/server.go index bdd922679..344dff7d3 100644 --- a/server.go +++ b/server.go @@ -38,22 +38,6 @@ func newOrderingServer(opts *serverOptions) *orderingServer { return s } -func (s *orderingServer) encodeMsg(req *Message) []byte { - // we must not consider the signature field when validating. - // also the msgType must be set to requestType. - signature := make([]byte, len(req.Metadata.GetAuthMsg().GetSignature())) - copy(signature, req.Metadata.GetAuthMsg().GetSignature()) - reqType := req.msgType - req.Metadata.GetAuthMsg().SetSignature(nil) - req.msgType = 0 - encodedMsg := authentication.EncodeMsg(*req) - req.Metadata.GetAuthMsg().SetSignature(make([]byte, len(signature))) - // TODO(meling): I think this is incorrect and should be done differently. - copy(req.Metadata.GetAuthMsg().GetSignature(), signature) - req.msgType = reqType - return encodedMsg -} - func (s *orderingServer) verify(req *Message) error { if s.opts.auth == nil { return nil @@ -78,8 +62,7 @@ func (s *orderingServer) verify(req *Message) error { return fmt.Errorf("publicKey did not match") } } - encodedMsg := s.encodeMsg(req) - valid, err := auth.VerifySignature(authMsg.GetPublicKey(), encodedMsg, authMsg.GetSignature()) + valid, err := auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) if err != nil { return err } From f0f465ea105a1a7bb0218bab2646e1ad2ca10f08 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 13 Apr 2025 16:26:31 +0200 Subject: [PATCH 06/27] refactor: use Hash function in Sign and VerifySignature methods --- authentication/authentication.go | 18 ++++++------------ config.go | 2 +- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/authentication/authentication.go b/authentication/authentication.go index 71542da13..654389bb3 100644 --- a/authentication/authentication.go +++ b/authentication/authentication.go @@ -102,30 +102,25 @@ func (ec *EllipticCurve) DecodePublic(pemEncodedPub string) (*ecdsa.PublicKey, e } func (ec *EllipticCurve) Sign(msg []byte) ([]byte, error) { - h := sha256.Sum256(msg) - hash := h[:] - signature, err := ecdsa.SignASN1(rand.Reader, ec.privateKey, hash) + signature, err := ecdsa.SignASN1(rand.Reader, ec.privateKey, Hash(msg)) if err != nil { return nil, err } return signature, nil } -func (ec *EllipticCurve) Hash(msg []byte) []byte { - h := sha256.Sum256(msg) - hash := h[:] - return hash +func Hash(msg []byte) []byte { + hash := sha256.Sum256(msg) + return hash[:] } // VerifySignature sign ecdsa style and verify signature func (ec *EllipticCurve) VerifySignature(pemEncodedPub string, msg, signature []byte) (bool, error) { - h := sha256.Sum256(msg) - hash := h[:] pubKey, err := ec.DecodePublic(pemEncodedPub) if err != nil { return false, err } - ok := ecdsa.VerifyASN1(pubKey, hash, signature) + ok := ecdsa.VerifyASN1(pubKey, Hash(msg), signature) return ok, nil } @@ -136,8 +131,7 @@ func EncodeMsg(msg any) []byte { func Verify(pemEncodedPub string, signature, digest []byte, msg any) (bool, error) { encodedMsg := EncodeMsg(msg) ec := New(elliptic.P256()) - h := sha256.Sum256(encodedMsg) - hash := h[:] + hash := Hash(encodedMsg) if !bytes.Equal(hash, digest) { return false, fmt.Errorf("wrong digest") } diff --git a/config.go b/config.go index a9b673e13..d605b45e9 100644 --- a/config.go +++ b/config.go @@ -68,7 +68,7 @@ func (c RawConfiguration) sign(msg *Message, signOrigin ...bool) { if c[0].mgr.opts.auth != nil { if len(signOrigin) > 0 && signOrigin[0] { originMsg := authentication.EncodeMsg(msg.Message) - digest := c[0].mgr.opts.auth.Hash(originMsg) + digest := authentication.Hash(originMsg) originSignature, err := c[0].mgr.opts.auth.Sign(originMsg) if err != nil { panic(err) From 3daf7ed5dd3a87457a12f3a9760d3ae824630908 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sun, 13 Apr 2025 16:28:20 +0200 Subject: [PATCH 07/27] chore: simplify var name in sign method --- config.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/config.go b/config.go index d605b45e9..57c59966e 100644 --- a/config.go +++ b/config.go @@ -65,15 +65,16 @@ func (c RawConfiguration) getMsgID() uint64 { } func (c RawConfiguration) sign(msg *Message, signOrigin ...bool) { - if c[0].mgr.opts.auth != nil { + auth := c[0].mgr.opts.auth + if auth != nil { if len(signOrigin) > 0 && signOrigin[0] { originMsg := authentication.EncodeMsg(msg.Message) digest := authentication.Hash(originMsg) - originSignature, err := c[0].mgr.opts.auth.Sign(originMsg) + originSignature, err := auth.Sign(originMsg) if err != nil { panic(err) } - pubKey, err := c[0].mgr.opts.auth.EncodePublic() + pubKey, err := auth.EncodePublic() if err != nil { panic(err) } @@ -82,7 +83,7 @@ func (c RawConfiguration) sign(msg *Message, signOrigin ...bool) { msg.Metadata.GetBroadcastMsg().SetOriginSignature(originSignature) } encodedMsg := c.encodeMsg(msg) - signature, err := c[0].mgr.opts.auth.Sign(encodedMsg) + signature, err := auth.Sign(encodedMsg) if err != nil { panic(err) } From e0aaae5610846d0082f26df5ea481bcd3c679717 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 14 Apr 2025 19:31:03 +0200 Subject: [PATCH 08/27] refactor: move message validation to Message.isValid method No need to attach the message validation check to the broadcastSrv. --- broadcast.go | 3 +-- encoding.go | 19 +++++++++++++++++++ handler.go | 21 ++------------------- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/broadcast.go b/broadcast.go index 791b5cfe6..e2cc1c7f5 100644 --- a/broadcast.go +++ b/broadcast.go @@ -75,11 +75,10 @@ type ( type ( defaultImplementationFunc[T protoreflect.ProtoMessage, V protoreflect.ProtoMessage] func(ServerCtx, T) (V, error) + implementationFunc[T protoreflect.ProtoMessage, V Broadcaster] func(ServerCtx, T, V) clientImplementationFunc[T protoreflect.ProtoMessage, V protoreflect.ProtoMessage] func(context.Context, T, uint64) (V, error) ) -type implementationFunc[T protoreflect.ProtoMessage, V Broadcaster] func(ServerCtx, T, V) - func CancelFunc(ServerCtx, protoreflect.ProtoMessage, Broadcaster) {} const Cancellation string = "cancel" diff --git a/encoding.go b/encoding.go index 0e8b4e15c..d69b4c6ae 100644 --- a/encoding.go +++ b/encoding.go @@ -1,6 +1,7 @@ package gorums import ( + "errors" "fmt" "github.com/relab/gorums/authentication" @@ -36,6 +37,24 @@ func newMessage(msgType gorumsMsgType) *Message { return &Message{Metadata: &ordering.Metadata{}, msgType: msgType} } +// isValid returns an error if the Message is an invalid broadcast message. +func (m *Message) isValid() error { + if m == nil { + return errors.New("message cannot be nil") + } + if m.Metadata == nil { + return errors.New("message metadata cannot be nil") + } + broadcastMsg := m.Metadata.GetBroadcastMsg() + if broadcastMsg == nil { + return errors.New("broadcast message cannot be nil") + } + if broadcastMsg.GetBroadcastID() == 0 { + return errors.New("broadcastID cannot be 0") + } + return nil +} + // Encode returns an encoded byte representation of the Message // ignoring the message type and signature in the Auth message. func (m *Message) Encode() []byte { diff --git a/handler.go b/handler.go index 8264e0b9f..5acacddce 100644 --- a/handler.go +++ b/handler.go @@ -2,7 +2,6 @@ package gorums import ( "context" - "fmt" "github.com/relab/gorums/broadcast" "github.com/relab/gorums/ordering" @@ -35,9 +34,9 @@ func BroadcastHandler[T protoreflect.ProtoMessage, V Broadcaster](impl implement // guard: // - A broadcastID should be non-empty: // - Maybe the request should be unique? Remove duplicates of the same broadcast? <- Most likely no (up to the implementer) - if err := srv.broadcastSrv.validateMessage(in); err != nil { + if err := in.isValid(); err != nil { if srv.broadcastSrv.logger != nil { - srv.broadcastSrv.logger.Debug("broadcast request not valid", "metadata", in.Metadata, "err", err) + srv.broadcastSrv.logger.Debug("invalid broadcast request", "metadata", in.Metadata, "err", err) } return } @@ -118,22 +117,6 @@ func createSendFn(msgID uint64, method string, finished chan<- *Message, ctx Ser } } -func (srv *broadcastServer) validateMessage(in *Message) error { - if in == nil { - return fmt.Errorf("message cannot be empty. got: %v", in) - } - if in.Metadata == nil { - return fmt.Errorf("metadata cannot be empty. got: %v", in.Metadata) - } - if in.Metadata.GetBroadcastMsg() == nil { - return fmt.Errorf("broadcastMsg cannot be empty. got: %v", in.Metadata.GetBroadcastMsg()) - } - if in.Metadata.GetBroadcastMsg().GetBroadcastID() <= 0 { - return fmt.Errorf("broadcastID cannot be empty. got: %v", in.Metadata.GetBroadcastMsg().GetBroadcastID()) - } - return nil -} - func (srv *Server) RegisterBroadcaster(broadcaster func(m BroadcastMetadata, o *BroadcastOrchestrator, e EnqueueBroadcast) Broadcaster) { srv.broadcastSrv.createBroadcaster = broadcaster srv.broadcastSrv.orchestrator = NewBroadcastOrchestrator(srv) From a695bdc2472166b28c859ed3f33141ad72d9aaf1 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 14 Apr 2025 19:58:20 +0200 Subject: [PATCH 09/27] refactor: streamline createRequest function --- handler.go | 43 ++++++++++++++++++++----------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/handler.go b/handler.go index 5acacddce..1d91efb2d 100644 --- a/handler.go +++ b/handler.go @@ -64,14 +64,13 @@ func BroadcastHandler[T protoreflect.ProtoMessage, V Broadcaster](impl implement impl(ctx, req, broadcaster) } - msg := broadcast.Content{} - createRequest(&msg, ctx, in, finished, run) + msg := createRequest(ctx, in.Metadata, finished, run) // we are not interested in the server context as this is tied to the previous hop. // instead we want to check whether the client has cancelled the broadcast request // and if so, we return a cancelled context. This enables the implementer to listen // for cancels and do proper actions. - reqCtx, enqueueBroadcast, err := srv.broadcastSrv.manager.Process(&msg) + reqCtx, enqueueBroadcast, err := srv.broadcastSrv.manager.Process(msg) if err != nil { return } @@ -80,33 +79,31 @@ func BroadcastHandler[T protoreflect.ProtoMessage, V Broadcaster](impl implement } } -func createRequest(msg *broadcast.Content, ctx ServerCtx, in *Message, finished chan<- *Message, run func(context.Context, func(*broadcast.Msg) error)) { - msg.BroadcastID = in.Metadata.GetBroadcastMsg().GetBroadcastID() - msg.IsBroadcastClient = in.Metadata.GetBroadcastMsg().GetIsBroadcastClient() - msg.OriginAddr = in.Metadata.GetBroadcastMsg().GetOriginAddr() - msg.OriginMethod = in.Metadata.GetBroadcastMsg().GetOriginMethod() - msg.SenderAddr = in.Metadata.GetBroadcastMsg().GetSenderAddr() +func createRequest(ctx ServerCtx, md *ordering.Metadata, finished chan<- *Message, run func(context.Context, func(*broadcast.Msg) error)) *broadcast.Content { + broadcastMsg := md.GetBroadcastMsg() + msg := &broadcast.Content{ + BroadcastID: broadcastMsg.GetBroadcastID(), + IsBroadcastClient: broadcastMsg.GetIsBroadcastClient(), + OriginAddr: broadcastMsg.GetOriginAddr(), + OriginMethod: broadcastMsg.GetOriginMethod(), + SenderAddr: broadcastMsg.GetSenderAddr(), + OriginDigest: broadcastMsg.GetOriginDigest(), + OriginSignature: broadcastMsg.GetOriginSignature(), + OriginPubKey: broadcastMsg.GetOriginPubKey(), + CurrentMethod: md.GetMethod(), + Ctx: ctx.Context, + Run: run, + } if msg.SenderAddr == "" && msg.IsBroadcastClient { msg.SenderAddr = "client" } - if in.Metadata.GetBroadcastMsg().GetOriginDigest() != nil { - msg.OriginDigest = in.Metadata.GetBroadcastMsg().GetOriginDigest() - } - if in.Metadata.GetBroadcastMsg().GetOriginSignature() != nil { - msg.OriginSignature = in.Metadata.GetBroadcastMsg().GetOriginSignature() - } - if in.Metadata.GetBroadcastMsg().GetOriginPubKey() != "" { - msg.OriginPubKey = in.Metadata.GetBroadcastMsg().GetOriginPubKey() - } - msg.CurrentMethod = in.Metadata.GetMethod() - msg.Ctx = ctx.Context - msg.Run = run if msg.OriginAddr == "" && msg.IsBroadcastClient { - msg.SendFn = createSendFn(in.Metadata.GetMessageID(), in.Metadata.GetMethod(), finished, ctx) + msg.SendFn = createSendFn(md.GetMessageID(), md.GetMethod(), finished, ctx) } - if in.Metadata.GetMethod() == Cancellation { + if md.GetMethod() == Cancellation { msg.IsCancellation = true } + return msg } func createSendFn(msgID uint64, method string, finished chan<- *Message, ctx ServerCtx) func(resp protoreflect.ProtoMessage, err error) error { From 51085a8c2e495c79d5016df36e9120ee2409b95b Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 14 Apr 2025 20:03:45 +0200 Subject: [PATCH 10/27] refactor(modernize): replace interface{} with any --- broadcast.go | 2 +- broadcast/router.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/broadcast.go b/broadcast.go index e2cc1c7f5..342ea864a 100644 --- a/broadcast.go +++ b/broadcast.go @@ -162,7 +162,7 @@ func NewBroadcastOptions() broadcast.BroadcastOptions { } } -type Broadcaster interface{} +type Broadcaster any type BroadcastMetadata struct { BroadcastID uint64 diff --git a/broadcast/router.go b/broadcast/router.go index 420bbfd78..561afd1ea 100644 --- a/broadcast/router.go +++ b/broadcast/router.go @@ -61,7 +61,7 @@ func (r *BroadcastRouter) registerState(state *BroadcastState) { r.state = state } -type msg interface{} +type msg any func (r *BroadcastRouter) Send(broadcastID uint64, addr, method string, originDigest, originSignature []byte, originPubKey string, req msg) error { if r.addr == "" { From 46f2b47f3a92b85fd90c352e55e9a7e2122b4011 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 14 Apr 2025 20:23:51 +0200 Subject: [PATCH 11/27] refactor(modernize): use slices.Contains instead of manual loop check --- broadcast/processor.go | 14 +++++--------- broadcastcall.go | 12 ++++-------- tests/broadcast/server.go | 17 +++-------------- 3 files changed, 12 insertions(+), 31 deletions(-) diff --git a/broadcast/processor.go b/broadcast/processor.go index 55f22e5aa..a3663b836 100644 --- a/broadcast/processor.go +++ b/broadcast/processor.go @@ -3,6 +3,7 @@ package broadcast import ( "context" "log/slog" + "slices" "time" "github.com/relab/gorums/logging" @@ -116,7 +117,7 @@ func (p *BroadcastProcessor) handleCancellation(bMsg *Msg, metadata *metadata) b func (p *BroadcastProcessor) handleBroadcast(bMsg *Msg, methods []string, metadata *metadata) bool { // check if msg has already been broadcasted for this method - //if alreadyBroadcasted(p.metadata.Methods, bMsg.Method) { + // if alreadyBroadcasted(p.metadata.Methods, bMsg.Method) { if !bMsg.Msg.options.AllowDuplication && alreadyBroadcasted(methods, bMsg.Method) { return false } @@ -135,7 +136,7 @@ func (p *BroadcastProcessor) handleReply(bMsg *Msg, metadata *metadata) bool { err := p.router.Send(broadcastID, originAddr, originMethod, metadata.OriginDigest, metadata.OriginSignature, metadata.OriginPubKey, replyMsg) p.log("broadcast: sent reply to client", err, logging.Method(originMethod), logging.MsgType(bMsg.MsgType.String()), logging.Stopping(true), logging.IsBroadcastCall(metadata.isBroadcastCall())) }(p.broadcastID, metadata.OriginAddr, metadata.OriginMethod, bMsg.Reply) - // the request is done becuase we have sent a reply to the client + // the request is done because we have sent a reply to the client p.log("broadcast: sending reply to client", nil, logging.Method(metadata.OriginMethod), logging.MsgType(bMsg.MsgType.String()), logging.Stopping(true), logging.IsBroadcastCall(metadata.isBroadcastCall())) return true } @@ -161,7 +162,7 @@ func (p *BroadcastProcessor) handleReply(bMsg *Msg, metadata *metadata) bool { // is done. return metadata.hasReceivedClientRequest() } - // the request is done becuase we have sent a reply to the client + // the request is done because we have sent a reply to the client p.log("broadcast: sending reply to client", err, logging.Method(metadata.OriginMethod), logging.MsgType(bMsg.MsgType.String()), logging.Stopping(true), logging.IsBroadcastCall(metadata.isBroadcastCall())) return true } @@ -432,12 +433,7 @@ func (r *BroadcastProcessor) dispatchOutOfOrderMsgs() { } func alreadyBroadcasted(methods []string, method string) bool { - for _, m := range methods { - if m == method { - return true - } - } - return false + return slices.Contains(methods, method) } func (p *BroadcastProcessor) initialize(msg *Content, metadata *metadata) { diff --git a/broadcastcall.go b/broadcastcall.go index 2de90d516..fd2b38b71 100644 --- a/broadcastcall.go +++ b/broadcastcall.go @@ -2,6 +2,7 @@ package gorums import ( "context" + "slices" "github.com/relab/gorums/ordering" "google.golang.org/protobuf/reflect/protoreflect" @@ -25,18 +26,13 @@ type BroadcastCallData struct { SkipSelf bool } -// checks whether the given address is contained in the given subset +// inSubset returns true if the given address is in the given subset // of server addresses. Will return true if a subset is not given. func (bcd *BroadcastCallData) inSubset(addr string) bool { - if bcd.ServerAddresses == nil || len(bcd.ServerAddresses) <= 0 { + if bcd == nil || len(bcd.ServerAddresses) <= 0 { return true } - for _, srvAddr := range bcd.ServerAddresses { - if addr == srvAddr { - return true - } - } - return false + return slices.Contains(bcd.ServerAddresses, addr) } // BroadcastCall performs a broadcast on the configuration. diff --git a/tests/broadcast/server.go b/tests/broadcast/server.go index 83b6b409e..f8f24830b 100644 --- a/tests/broadcast/server.go +++ b/tests/broadcast/server.go @@ -11,6 +11,7 @@ import ( gorums "github.com/relab/gorums" grpc "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "slices" ) var leader = "127.0.0.1:5000" @@ -245,13 +246,7 @@ func (srv *testServer) PrePrepare(ctx gorums.ServerCtx, req *Request, broadcast time.Sleep(200 * time.Millisecond) } srv.mut.Lock() - added := false - for _, m := range srv.order { - if m == "PrePrepare" { - added = true - break - } - } + added := slices.Contains(srv.order, "PrePrepare") if !added { srv.order = append(srv.order, "PrePrepare") } @@ -269,13 +264,7 @@ func (srv *testServer) Prepare(ctx gorums.ServerCtx, req *Request, broadcast *Br srv.mut.Unlock() return } - added := false - for _, m := range srv.order { - if m == "Prepare" { - added = true - break - } - } + added := slices.Contains(srv.order, "Prepare") if !added { srv.order = append(srv.order, "Prepare") } From 8bae2d6bf2b65864cbdf89902194875dcafb928f Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 14 Apr 2025 20:24:47 +0200 Subject: [PATCH 12/27] refactor(modernize): replace context.WithCancel with t.Context --- broadcast/shard_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/broadcast/shard_test.go b/broadcast/shard_test.go index f2fa05257..6f2b7fcf5 100644 --- a/broadcast/shard_test.go +++ b/broadcast/shard_test.go @@ -40,8 +40,7 @@ func TestShard(t *testing.T) { returnError: false, } shardBuffer := 100 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() shard := &shard{ id: 0, parentCtx: ctx, From 504908337c531c2c48c82367a72ad92887714c2e Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 14 Apr 2025 20:35:29 +0200 Subject: [PATCH 13/27] refactor(modernize): replace b.N loops with b.Loop() in benchmarks --- broadcast/processor_test.go | 6 +++--- tests/broadcast/broadcast_test.go | 28 ++++++++++++++-------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/broadcast/processor_test.go b/broadcast/processor_test.go index d38e7cc1b..08720f012 100644 --- a/broadcast/processor_test.go +++ b/broadcast/processor_test.go @@ -41,7 +41,7 @@ func TestHandleBroadcastOption1(t *testing.T) { snowflake := NewSnowflake(0) broadcastID := snowflake.NewBroadcastID() - var tests = []struct { + tests := []struct { in *Content out error }{ @@ -159,7 +159,7 @@ func TestHandleBroadcastCall1(t *testing.T) { snowflake := NewSnowflake(0) broadcastID := snowflake.NewBroadcastID() - var tests = []struct { + tests := []struct { in *Content out error }{ @@ -290,7 +290,7 @@ func BenchmarkHandleProcessor(b *testing.B) { b.ResetTimer() b.Run("ProcessorHandler", func(b *testing.B) { - for i := 0; i < b.N; i++ { + for b.Loop() { msg := &Content{ BroadcastID: broadcastID, IsBroadcastClient: true, diff --git a/tests/broadcast/broadcast_test.go b/tests/broadcast/broadcast_test.go index 867c79432..848feb072 100644 --- a/tests/broadcast/broadcast_test.go +++ b/tests/broadcast/broadcast_test.go @@ -777,7 +777,7 @@ func BenchmarkQuorumCall(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("QC_AllSuccessful_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := config.QuorumCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -818,7 +818,7 @@ func BenchmarkQCMulticast(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("QCM_AllSuccessful_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) resp, err := config.QuorumCallWithMulticast(ctx, &Request{Value: int64(i)}) if err != nil { @@ -861,7 +861,7 @@ func BenchmarkQCBroadcastOption(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("QCB_AllSuccessful_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) resp, err := config.QuorumCallWithBroadcast(ctx, &Request{Value: int64(i)}) if err != nil { @@ -911,7 +911,7 @@ func BenchmarkQCBroadcastOptionManyClients(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("QCB_ManyClients_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { var wg sync.WaitGroup for _, client := range clients { wg.Add(1) @@ -960,7 +960,7 @@ func BenchmarkBroadcastCallAllServers(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("BC_AllSuccessful_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := config.BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -1001,7 +1001,7 @@ func BenchmarkBroadcastCallToOneServer(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("BC_OneSrv_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := config.BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -1042,7 +1042,7 @@ func BenchmarkBroadcastCallOneFailedServer(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("BC_OneSrvDown_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := config.BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -1083,7 +1083,7 @@ func BenchmarkBroadcastCallOneDownSrvToOneSrv(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("BC_OneDownToOne_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := config.BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -1127,7 +1127,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { } b.Run(fmt.Sprintf("BC_OneClientOneReq_%d", 0), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := clients[0].BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -1138,7 +1138,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { } }) b.Run(fmt.Sprintf("BC_OneClientAsync_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { var wg sync.WaitGroup for j := range clients { go func(i, j int, c *Configuration) { @@ -1158,7 +1158,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { } }) b.Run(fmt.Sprintf("BC_OneClientSync_%d", 2), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { for j := range clients { val := i*100 + j resp, err := clients[0].BroadcastCall(context.Background(), &Request{Value: int64(val)}) @@ -1172,7 +1172,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { } }) b.Run(fmt.Sprintf("BC_ManyClientsAsync_%d", 3), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { var wg sync.WaitGroup for _, client := range clients { go func(i int, c *Configuration) { @@ -1191,7 +1191,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { } }) b.Run(fmt.Sprintf("BC_ManyClientsSync_%d", 4), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { for _, client := range clients { resp, err := client.BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { @@ -1242,7 +1242,7 @@ func BenchmarkBroadcastCallTenClientsCPU(b *testing.B) { b.ResetTimer() b.Run(fmt.Sprintf("BC_%d", 3), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { var wg sync.WaitGroup for _, client := range clients { go func(i int, c *Configuration) { From 8be741c7a9c8fdaf603f69881f53e9096f7f73af Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 14 Apr 2025 20:41:26 +0200 Subject: [PATCH 14/27] refactor(modernize): compute start and end index using max and min --- tests/broadcast/broadcast_test.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/broadcast/broadcast_test.go b/tests/broadcast/broadcast_test.go index 848feb072..3670548b6 100644 --- a/tests/broadcast/broadcast_test.go +++ b/tests/broadcast/broadcast_test.go @@ -480,14 +480,8 @@ func TestBroadcastCallOneServerIsDown(t *testing.T) { } defer srvCleanup() - start := skip - if start < 0 { - start = 0 - } - end := numSrvs - 1 - if end > len(srvAddrs) { - end = len(srvAddrs) - } + start := max(skip, 0) + end := min(numSrvs-1, len(srvAddrs)) config, clientCleanup, err := newClient(srvAddrs[start:end], "127.0.0.1:8080") if err != nil { t.Error(err) From 69876bce16c8131644f56f9f87e18d659d246f17 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 14 Apr 2025 20:44:21 +0200 Subject: [PATCH 15/27] refactor(modernize): replace 3-clause for loops with range loops --- tests/broadcast/broadcast_test.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/broadcast/broadcast_test.go b/tests/broadcast/broadcast_test.go index 3670548b6..7aa8a6025 100644 --- a/tests/broadcast/broadcast_test.go +++ b/tests/broadcast/broadcast_test.go @@ -17,7 +17,7 @@ func createAuthServers(numSrvs int) ([]*testServer, []string, func(), error) { skip := 0 srvs := make([]*testServer, 0, numSrvs) srvAddrs := make([]string, numSrvs) - for i := 0; i < numSrvs; i++ { + for i := range numSrvs { srvAddrs[i] = fmt.Sprintf("127.0.0.1:500%v", i) } for _, addr := range srvAddrs { @@ -54,7 +54,7 @@ func createSrvs(numSrvs int, down ...int) ([]*testServer, []string, func(), erro } srvs := make([]*testServer, 0, numSrvs) srvAddrs := make([]string, numSrvs) - for i := 0; i < numSrvs; i++ { + for i := range numSrvs { srvAddrs[i] = fmt.Sprintf("127.0.0.1:500%v", i) } for i, addr := range srvAddrs { @@ -520,7 +520,7 @@ func TestBroadcastCallForward(t *testing.T) { } defer clientCleanup() - for i := 0; i < 10; i++ { + for i := range 10 { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() resp, err := config.BroadcastCallForward(ctx, &Request{Value: int64(i)}) @@ -546,7 +546,7 @@ func TestBroadcastCallForwardMultiple(t *testing.T) { } defer clientCleanup() - for i := 0; i < 10; i++ { + for i := range 10 { resp, err := config.BroadcastCallForward(context.Background(), &Request{Value: int64(i)}) if err != nil { t.Error(err) @@ -633,7 +633,7 @@ func TestBroadcastCallAsyncReqs(t *testing.T) { numClients := 10 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { + for c := range numClients { config, clientCleanup, err := newClient(srvAddrs, fmt.Sprintf("127.0.0.1:808%v", c), 3) if err != nil { t.Error(err) @@ -656,7 +656,7 @@ func TestBroadcastCallAsyncReqs(t *testing.T) { } var wg sync.WaitGroup - for i := 0; i < 1; i++ { + for i := range 1 { for _, client := range clients { go func(j int, c *Configuration) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -697,7 +697,7 @@ func TestQCBroadcastOptionRace(t *testing.T) { if resp.GetResult() != val { t.Fatalf("resp is wrong, got: %v, want: %v", resp.GetResult(), val) } - for i := 0; i < 100; i++ { + for i := range 100 { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) resp, err := config.QuorumCallWithBroadcast(ctx, &Request{Value: int64(i)}) if err != nil { @@ -731,7 +731,7 @@ func TestQCMulticastRace(t *testing.T) { if resp.GetResult() != val { t.Fatal("resp is wrong") } - for i := 0; i < 100; i++ { + for i := range 100 { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) resp, err := config.QuorumCallWithMulticast(ctx, &Request{Value: int64(i)}) if err != nil { @@ -880,7 +880,7 @@ func BenchmarkQCBroadcastOptionManyClients(b *testing.B) { numClients := 10 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { + for c := range numClients { config, clientCleanup, err := newClient(srvAddrs, fmt.Sprintf("127.0.0.1:808%v", c), 3) if err != nil { b.Error(err) @@ -1100,7 +1100,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { numClients := 10 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { + for c := range numClients { config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), 3) if err != nil { b.Error(err) @@ -1208,7 +1208,7 @@ func BenchmarkBroadcastCallTenClientsCPU(b *testing.B) { numClients := 10 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { + for c := range numClients { config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), 3) if err != nil { b.Error(err) @@ -1269,7 +1269,7 @@ func BenchmarkBroadcastCallTenClientsMEM(b *testing.B) { numClients := 1 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { + for c := range numClients { config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), numSrvs) if err != nil { b.Error(err) @@ -1323,7 +1323,7 @@ func BenchmarkBroadcastCallTenClientsTRACE(b *testing.B) { numClients := 1 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { + for c := range numClients { config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), numSrvs) if err != nil { b.Error(err) @@ -1380,7 +1380,7 @@ func TestBroadcastCallManyRequestsAsync(t *testing.T) { } clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { + for c := range numClients { config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), numSrvs) if err != nil { t.Error(err) @@ -1410,7 +1410,7 @@ func TestBroadcastCallManyRequestsAsync(t *testing.T) { time.Sleep(500 * time.Millisecond) var wg sync.WaitGroup - for r := 0; r < numReqs; r++ { + for range numReqs { for i, client := range clients { wg.Add(1) go func(i int, client *Configuration) { From de05a05bc354e06dbfbdb2c5d44195748e5c0d52 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 14 Apr 2025 21:11:56 +0200 Subject: [PATCH 16/27] refactor: move auth message validation to method on Metadata type --- clientserver.go | 12 +++--------- ordering/gorums_metadata.go | 18 ++++++++++++++++++ server.go | 12 +++--------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/clientserver.go b/clientserver.go index ab6fa3021..3664d5270 100644 --- a/clientserver.go +++ b/clientserver.go @@ -254,16 +254,10 @@ func (srv *ClientServer) verify(req *Message) error { if srv.auth == nil { return nil } - if req.Metadata.GetAuthMsg() == nil { - return fmt.Errorf("missing authMsg") - } - if req.Metadata.GetAuthMsg().GetSignature() == nil { - return fmt.Errorf("missing signature") - } - if req.Metadata.GetAuthMsg().GetPublicKey() == "" { - return fmt.Errorf("missing publicKey") + authMsg, err := req.Metadata.GetValidAuthMsg() + if err != nil { + return err } - authMsg := req.Metadata.GetAuthMsg() if srv.allowList != nil { pemEncodedPub, ok := srv.allowList[authMsg.GetSender()] if !ok { diff --git a/ordering/gorums_metadata.go b/ordering/gorums_metadata.go index cde158e94..90aa4a160 100644 --- a/ordering/gorums_metadata.go +++ b/ordering/gorums_metadata.go @@ -2,6 +2,7 @@ package ordering import ( "context" + "errors" "google.golang.org/grpc/metadata" ) @@ -38,3 +39,20 @@ func (x *Metadata) AppendToIncomingContext(ctx context.Context) context.Context } return metadata.NewIncomingContext(ctx, newMD) } + +func (x *Metadata) GetValidAuthMsg() (*AuthMsg, error) { + if x == nil { + return nil, errors.New("metadata cannot be nil") + } + authMsg := x.GetAuthMsg() + if authMsg == nil { + return nil, errors.New("missing AuthMsg") + } + if authMsg.GetSignature() == nil { + return nil, errors.New("missing signature") + } + if authMsg.GetPublicKey() == "" { + return nil, errors.New("missing publicKey") + } + return authMsg, nil +} diff --git a/server.go b/server.go index 344dff7d3..85c295732 100644 --- a/server.go +++ b/server.go @@ -42,17 +42,11 @@ func (s *orderingServer) verify(req *Message) error { if s.opts.auth == nil { return nil } - if req.Metadata.GetAuthMsg() == nil { - return fmt.Errorf("missing authMsg") - } - if req.Metadata.GetAuthMsg().GetSignature() == nil { - return fmt.Errorf("missing signature") - } - if req.Metadata.GetAuthMsg().GetPublicKey() == "" { - return fmt.Errorf("missing publicKey") + authMsg, err := req.Metadata.GetValidAuthMsg() + if err != nil { + return err } auth := s.opts.auth - authMsg := req.Metadata.GetAuthMsg() if s.opts.allowList != nil { pemEncodedPub, ok := s.opts.allowList[authMsg.GetSender()] if !ok { From 67c89ac49ce615f16837ff6e08d6a9f8b4ac20b9 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 14 Apr 2025 21:35:15 +0200 Subject: [PATCH 17/27] refactor: add allowList type and encapsulate validation logic This adds a new allowList type and a Check method that performs the validation logic common for message verification in both server.go and in clientserver.go. --- clientserver.go | 12 +++--------- server.go | 35 +++++++++++++++++++++++++---------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/clientserver.go b/clientserver.go index 3664d5270..990b9ad82 100644 --- a/clientserver.go +++ b/clientserver.go @@ -36,7 +36,7 @@ type ClientServer struct { handlers map[string]requestHandler logger *slog.Logger auth *authentication.EllipticCurve - allowList map[string]string + allowList allowList ordering.UnimplementedGorumsServer } @@ -258,14 +258,8 @@ func (srv *ClientServer) verify(req *Message) error { if err != nil { return err } - if srv.allowList != nil { - pemEncodedPub, ok := srv.allowList[authMsg.GetSender()] - if !ok { - return fmt.Errorf("not allowed") - } - if pemEncodedPub != authMsg.GetPublicKey() { - return fmt.Errorf("publicKey did not match") - } + if err := srv.allowList.Check(authMsg.GetSender(), authMsg.GetPublicKey()); err != nil { + return err } valid, err := srv.auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) if err != nil { diff --git a/server.go b/server.go index 85c295732..04a1a5d88 100644 --- a/server.go +++ b/server.go @@ -47,14 +47,8 @@ func (s *orderingServer) verify(req *Message) error { return err } auth := s.opts.auth - if s.opts.allowList != nil { - pemEncodedPub, ok := s.opts.allowList[authMsg.GetSender()] - if !ok { - return fmt.Errorf("not allowed") - } - if pemEncodedPub != authMsg.GetPublicKey() { - return fmt.Errorf("publicKey did not match") - } + if err := s.opts.allowList.Check(authMsg.GetSender(), authMsg.GetPublicKey()); err != nil { + return err } valid, err := auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) if err != nil { @@ -156,11 +150,32 @@ type serverOptions struct { // address and use forwarding from the host. If not this option is not given, // the listen address used on the gRPC listener will be used instead. listenAddr string - allowList map[string]string + allowList allowList auth *authentication.EllipticCurve grpcDialOpts []grpc.DialOption } +// allowList is a map of (address, public key) pairs. +type allowList map[string]string + +// Check checks if the address and public key are in the allow list. +// If the address is not in the allow list or the public key does not match, +// an error is returned. If the allow list is nil, no check is performed. +func (al allowList) Check(addr string, publicKey string) error { + if al == nil { + // bypass if no allow list specified + return nil + } + pemEncodedPub, ok := al[addr] + if !ok { + return fmt.Errorf("not allowed: %s", addr) + } + if pemEncodedPub != publicKey { + return fmt.Errorf("public key mismatch") + } + return nil +} + // ServerOption is used to change settings for the GorumsServer type ServerOption func(*serverOptions) @@ -286,7 +301,7 @@ func WithSrvID(machineID uint64) ServerOption { // the server. func WithAllowList(curve elliptic.Curve, allowed ...string) ServerOption { return func(o *serverOptions) { - o.allowList = make(map[string]string) + o.allowList = make(allowList) if len(allowed)%2 != 0 { panic("must provide (address, publicKey) pairs to WithAllowList()") } From ddd62d5f26432c046dc15582e22ed201a24763a2 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 14 Apr 2025 22:31:30 +0200 Subject: [PATCH 18/27] refactor: simplify verification by returning invalid signature error This returns an error also for invalid signatures, avoiding the extra check of the valid bool outside the VerifySignature method just to generate an invalid signature error. Thus, we can simply return the error directly. The returned error can be checked if it was an InvalidSignatureErr if necessary (as is done in tests). This is not used in non-test code yet. --- authentication/authentication.go | 21 ++++++++++++++------- authentication/authentication_test.go | 20 +++++++++----------- channel_test.go | 16 ++++++---------- clientserver.go | 9 +-------- server.go | 9 +-------- 5 files changed, 31 insertions(+), 44 deletions(-) diff --git a/authentication/authentication.go b/authentication/authentication.go index 654389bb3..e1bee8000 100644 --- a/authentication/authentication.go +++ b/authentication/authentication.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/pem" + "errors" "fmt" "net" ) @@ -93,7 +94,7 @@ func (ec *EllipticCurve) DecodePrivate(pemEncodedPriv string) (*ecdsa.PrivateKey func (ec *EllipticCurve) DecodePublic(pemEncodedPub string) (*ecdsa.PublicKey, error) { blockPub, _ := pem.Decode([]byte(pemEncodedPub)) if blockPub == nil { - return nil, fmt.Errorf("invalid publicKey") + return nil, errors.New("invalid public key") } x509EncodedPub := blockPub.Bytes genericPublicKey, err := x509.ParsePKIXPublicKey(x509EncodedPub) @@ -114,14 +115,20 @@ func Hash(msg []byte) []byte { return hash[:] } -// VerifySignature sign ecdsa style and verify signature -func (ec *EllipticCurve) VerifySignature(pemEncodedPub string, msg, signature []byte) (bool, error) { +var InvalidSignatureErr = errors.New("invalid signature") + +// VerifySignature verifies the signature of the message's hash using the given PEM encoded +// public key. It returns an error if the signature is invalid or if there is an error +// decoding the public key. +func (ec *EllipticCurve) VerifySignature(pemEncodedPub string, msg, signature []byte) error { pubKey, err := ec.DecodePublic(pemEncodedPub) if err != nil { - return false, err + return err } - ok := ecdsa.VerifyASN1(pubKey, Hash(msg), signature) - return ok, nil + if valid := ecdsa.VerifyASN1(pubKey, Hash(msg), signature); !valid { + return InvalidSignatureErr + } + return nil } func EncodeMsg(msg any) []byte { @@ -133,7 +140,7 @@ func Verify(pemEncodedPub string, signature, digest []byte, msg any) (bool, erro ec := New(elliptic.P256()) hash := Hash(encodedMsg) if !bytes.Equal(hash, digest) { - return false, fmt.Errorf("wrong digest") + return false, errors.New("invalid digest for message") } pubKey, err := ec.DecodePublic(pemEncodedPub) if err != nil { diff --git a/authentication/authentication_test.go b/authentication/authentication_test.go index 2f83767c1..4198b47ee 100644 --- a/authentication/authentication_test.go +++ b/authentication/authentication_test.go @@ -42,12 +42,9 @@ func TestSignAndVerify(t *testing.T) { } encodedMsg2 := EncodeMsg(message) - ok, err := ec2.VerifySignature(pemEncodedPub, encodedMsg2, signature) + err = ec2.VerifySignature(pemEncodedPub, encodedMsg2, signature) if err != nil { - t.Error(err) - } - if !ok { - t.Error("signature not ok!") + t.Errorf("VerifySignature() = %v, want nil", err) } } @@ -78,12 +75,13 @@ func TestVerifyWithWrongPubKey(t *testing.T) { } encodedMsg2 := EncodeMsg(message) - ok, err := ec2.VerifySignature(pemEncodedPub, encodedMsg2, signature) - if err != nil { - t.Error(err) - } - if ok { - t.Error("signature should not be ok!") + err = ec2.VerifySignature(pemEncodedPub, encodedMsg2, signature) + if err == nil { + t.Errorf("VerifySignature() = nil, want %v", InvalidSignatureErr) + } else { + if !errors.Is(err, InvalidSignatureErr) { + t.Errorf("VerifySignature() = %v, want %v", err, InvalidSignatureErr) + } } } diff --git a/channel_test.go b/channel_test.go index ad1ad1efb..ab8b28eba 100644 --- a/channel_test.go +++ b/channel_test.go @@ -236,19 +236,15 @@ func TestAuthentication(t *testing.T) { if err != nil { t.Fatal(err) } - valid, err := auth.VerifySignature(pemEncodedPub, chEncodedMsg, signature) + err = auth.VerifySignature(pemEncodedPub, chEncodedMsg, signature) if err != nil { - t.Fatal(err) - } - if !valid { - t.Fatal("channel encoded msg not valid") + // channel encoded msg not valid + t.Fatalf("VerifySignature() = %v, want nil", err) } - valid, err = auth.VerifySignature(pemEncodedPub, srvEncodedMsg, signature) + err = auth.VerifySignature(pemEncodedPub, srvEncodedMsg, signature) if err != nil { - t.Fatal(err) - } - if !valid { - t.Fatal("srv encoded msg not valid") + // srv encoded msg not valid + t.Fatalf("VerifySignature() = %v, want nil", err) } config.sign(msg) diff --git a/clientserver.go b/clientserver.go index 990b9ad82..f37d844ea 100644 --- a/clientserver.go +++ b/clientserver.go @@ -261,14 +261,7 @@ func (srv *ClientServer) verify(req *Message) error { if err := srv.allowList.Check(authMsg.GetSender(), authMsg.GetPublicKey()); err != nil { return err } - valid, err := srv.auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) - if err != nil { - return err - } - if !valid { - return fmt.Errorf("invalid signature") - } - return nil + return srv.auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) } func createClient(addr string, dialOpts []grpc.DialOption) (*broadcast.Client, error) { diff --git a/server.go b/server.go index 04a1a5d88..6598cc90e 100644 --- a/server.go +++ b/server.go @@ -50,14 +50,7 @@ func (s *orderingServer) verify(req *Message) error { if err := s.opts.allowList.Check(authMsg.GetSender(), authMsg.GetPublicKey()); err != nil { return err } - valid, err := auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) - if err != nil { - return err - } - if !valid { - return fmt.Errorf("invalid signature") - } - return nil + return auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) } // SendMessage attempts to send a message on a channel. From 03cdce70e90d297804aec7374f0ae800491ab82a Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Tue, 15 Apr 2025 10:48:01 +0200 Subject: [PATCH 19/27] chore: remove unnecessary error handling in Sign method --- authentication/authentication.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/authentication/authentication.go b/authentication/authentication.go index e1bee8000..51dafb1dd 100644 --- a/authentication/authentication.go +++ b/authentication/authentication.go @@ -103,11 +103,7 @@ func (ec *EllipticCurve) DecodePublic(pemEncodedPub string) (*ecdsa.PublicKey, e } func (ec *EllipticCurve) Sign(msg []byte) ([]byte, error) { - signature, err := ecdsa.SignASN1(rand.Reader, ec.privateKey, Hash(msg)) - if err != nil { - return nil, err - } - return signature, nil + return ecdsa.SignASN1(rand.Reader, ec.privateKey, Hash(msg)) } func Hash(msg []byte) []byte { From dd4c619f5c0e0c3affd0c7bbefb7a48b0e59da82 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Tue, 15 Apr 2025 10:52:44 +0200 Subject: [PATCH 20/27] refactor: simplify the Verify function to return error directly This reuses VerifySignature, but hashes the message twice. However, the Verify function is never used, so we need to decide if it is necessary to keep it. --- authentication/authentication.go | 11 +++-------- broadcast.go | 3 ++- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/authentication/authentication.go b/authentication/authentication.go index 51dafb1dd..146bbd6dd 100644 --- a/authentication/authentication.go +++ b/authentication/authentication.go @@ -131,17 +131,12 @@ func EncodeMsg(msg any) []byte { return fmt.Appendf(nil, "%v", msg) } -func Verify(pemEncodedPub string, signature, digest []byte, msg any) (bool, error) { +func Verify(pemEncodedPub string, signature, digest []byte, msg any) error { encodedMsg := EncodeMsg(msg) ec := New(elliptic.P256()) hash := Hash(encodedMsg) if !bytes.Equal(hash, digest) { - return false, errors.New("invalid digest for message") + return errors.New("invalid digest for message") } - pubKey, err := ec.DecodePublic(pemEncodedPub) - if err != nil { - return false, err - } - ok := ecdsa.VerifyASN1(pubKey, hash, signature) - return ok, nil + return ec.VerifySignature(pemEncodedPub, encodedMsg, signature) } diff --git a/broadcast.go b/broadcast.go index 342ea864a..65aafa5eb 100644 --- a/broadcast.go +++ b/broadcast.go @@ -207,6 +207,7 @@ func newBroadcastMetadata(md *ordering.Metadata) BroadcastMetadata { } } -func (md BroadcastMetadata) Verify(msg protoreflect.ProtoMessage) (bool, error) { +// TODO(meling): this method is never called +func (md BroadcastMetadata) Verify(msg protoreflect.ProtoMessage) error { return authentication.Verify(md.OriginPubKey, md.OriginSignature, md.OriginDigest, msg) } From 74f4cb119bc23d0b0608eb21b75dea5d055dc78a Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Tue, 15 Apr 2025 16:33:31 +0200 Subject: [PATCH 21/27] chore: WithAllowList could replace EnforceAuthentication This change allows us to remove EnforceAuthentication and we could use WithAllowList instead. --- server.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/server.go b/server.go index 6598cc90e..d9072b64d 100644 --- a/server.go +++ b/server.go @@ -46,11 +46,10 @@ func (s *orderingServer) verify(req *Message) error { if err != nil { return err } - auth := s.opts.auth if err := s.opts.allowList.Check(authMsg.GetSender(), authMsg.GetPublicKey()); err != nil { return err } - return auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) + return s.opts.auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) } // SendMessage attempts to send a message on a channel. @@ -289,14 +288,18 @@ func WithSrvID(machineID uint64) ServerOption { } } -// WithAllowList accepts (address, publicKey) pairs which is used to validate +// WithAllowList accepts (address, public-key) pairs which is used to authenticate // messages. Only nodes added to the allow list are allowed to send msgs to // the server. func WithAllowList(curve elliptic.Curve, allowed ...string) ServerOption { return func(o *serverOptions) { + if len(allowed) == 0 { + o.auth = authentication.New(curve) + return + } o.allowList = make(allowList) if len(allowed)%2 != 0 { - panic("must provide (address, publicKey) pairs to WithAllowList()") + panic("must provide (address, public-key) pairs to WithAllowList()") } for i := range allowed { if i%2 != 0 { From c1c372ac0cda44212f72d869108fd308b5988b2d Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Tue, 15 Apr 2025 16:38:09 +0200 Subject: [PATCH 22/27] refactor: replace New method with NewWithAddr for EllipticCurve creation This simplifies creation of elliptic curve for message auth by avoiding that users have to generating keys, registering keys, and querying the keys. --- authentication/authentication.go | 51 +++++++++++---------------- authentication/authentication_test.go | 36 +++++++++++-------- broadcast.go | 6 ---- channel_test.go | 9 ++--- tests/broadcast/client.go | 9 ++--- tests/broadcast/server.go | 11 +++--- 6 files changed, 59 insertions(+), 63 deletions(-) diff --git a/authentication/authentication.go b/authentication/authentication.go index 146bbd6dd..7ac2e5a91 100644 --- a/authentication/authentication.go +++ b/authentication/authentication.go @@ -13,47 +13,38 @@ import ( "net" ) -// Elliptic Curve Cryptography (ECC) is a key-based technique for encrypting data. -// ECC focuses on pairs of public and private keys for decryption and encryption of web traffic. -// ECC is frequently discussed in the context of the Rivest–Shamir–Adleman (RSA) cryptographic algorithm. -// -// https://pkg.go.dev/github.com/katzenpost/core/crypto/eddsa +// EllipticCurve represents an elliptic curve instance used for message authentication. type EllipticCurve struct { - addr net.Addr // used to identify self - pubKeyCurve elliptic.Curve // http://golang.org/pkg/crypto/elliptic/#P256 - privateKey *ecdsa.PrivateKey - publicKey *ecdsa.PublicKey + addr net.Addr // used to identify self + curve elliptic.Curve + privateKey *ecdsa.PrivateKey + publicKey *ecdsa.PublicKey } // New EllipticCurve instance func New(curve elliptic.Curve) *EllipticCurve { return &EllipticCurve{ - pubKeyCurve: curve, - privateKey: new(ecdsa.PrivateKey), + curve: curve, + privateKey: new(ecdsa.PrivateKey), } } -// GenerateKeys EllipticCurve public and private keys -func (ec *EllipticCurve) GenerateKeys() error { - privKey, err := ecdsa.GenerateKey(ec.pubKeyCurve, rand.Reader) +// NewWithAddr returns a new EllipticCurve instance with the given address. +// It generates a new public-private key pair for the specified elliptic curve. +func NewWithAddr(curve elliptic.Curve, addr net.Addr) (*EllipticCurve, error) { + if addr == nil { + return nil, errors.New("address cannot be nil") + } + priv, err := ecdsa.GenerateKey(curve, rand.Reader) if err != nil { - return err + return nil, err } - ec.privateKey = privKey - ec.publicKey = &privKey.PublicKey - return nil -} - -// RegisterKeys EllipticCurve public and private keys -func (ec *EllipticCurve) RegisterKeys(addr net.Addr, privKey *ecdsa.PrivateKey, pubKey *ecdsa.PublicKey) { - ec.addr = addr - ec.privateKey = privKey - ec.publicKey = pubKey -} - -// Returns the EllipticCurve public and private keys -func (ec *EllipticCurve) Keys() (*ecdsa.PrivateKey, *ecdsa.PublicKey) { - return ec.privateKey, ec.publicKey + return &EllipticCurve{ + addr: addr, + curve: curve, + privateKey: priv, + publicKey: &priv.PublicKey, + }, nil } // Returns the address diff --git a/authentication/authentication_test.go b/authentication/authentication_test.go index 4198b47ee..39b126ba4 100644 --- a/authentication/authentication_test.go +++ b/authentication/authentication_test.go @@ -3,34 +3,41 @@ package authentication import ( "crypto/elliptic" "errors" + "net" "reflect" "testing" ) func TestAuthentication(t *testing.T) { - ec := New(elliptic.P256()) - _ = ec.GenerateKeys() - err := ec.test() + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5000") + if err != nil { + t.Fatal(err) + } + ec, err := NewWithAddr(elliptic.P256(), addr) + if err != nil { + t.Fatal(err) + } + err = ec.test() if err != nil { t.Error(err) } } func TestSignAndVerify(t *testing.T) { - ec1 := New(elliptic.P256()) - err := ec1.GenerateKeys() + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5000") if err != nil { t.Fatal(err) } - - ec2 := New(elliptic.P256()) - err = ec2.GenerateKeys() + ec1, err := NewWithAddr(elliptic.P256(), addr) + if err != nil { + t.Fatal(err) + } + ec2, err := NewWithAddr(elliptic.P256(), addr) if err != nil { t.Fatal(err) } message := "This is a message" - encodedMsg1 := EncodeMsg(message) signature, err := ec1.Sign(encodedMsg1) if err != nil { @@ -49,14 +56,15 @@ func TestSignAndVerify(t *testing.T) { } func TestVerifyWithWrongPubKey(t *testing.T) { - ec1 := New(elliptic.P256()) - err := ec1.GenerateKeys() + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5000") if err != nil { t.Fatal(err) } - - ec2 := New(elliptic.P256()) - err = ec2.GenerateKeys() + ec1, err := NewWithAddr(elliptic.P256(), addr) + if err != nil { + t.Fatal(err) + } + ec2, err := NewWithAddr(elliptic.P256(), addr) if err != nil { t.Fatal(err) } diff --git a/broadcast.go b/broadcast.go index 65aafa5eb..df3a3eb31 100644 --- a/broadcast.go +++ b/broadcast.go @@ -2,7 +2,6 @@ package gorums import ( "context" - "crypto/elliptic" "hash/fnv" "log/slog" "strings" @@ -19,11 +18,6 @@ import ( // exposing the log entry struct used for structured logging to the user type LogEntry logging.LogEntry -// exposing the ellipticCurve struct for the user -func NewAuth(curve elliptic.Curve) *authentication.EllipticCurve { - return authentication.New(curve) -} - type broadcastServer struct { viewMutex sync.RWMutex id uint32 diff --git a/channel_test.go b/channel_test.go index ab8b28eba..f0a400c77 100644 --- a/channel_test.go +++ b/channel_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/relab/gorums/authentication" "github.com/relab/gorums/ordering" "github.com/relab/gorums/tests/mock" "google.golang.org/grpc" @@ -200,10 +201,10 @@ func TestAuthentication(t *testing.T) { if err != nil { t.Fatal(err) } - auth := NewAuth(elliptic.P256()) - _ = auth.GenerateKeys() - privKey, pubKey := auth.Keys() - auth.RegisterKeys(addr, privKey, pubKey) + auth, err := authentication.NewWithAddr(elliptic.P256(), addr) + if err != nil { + t.Fatal(err) + } mgr := NewRawManager(WithAuthentication(auth)) defer mgr.Close() node.mgr = mgr diff --git a/tests/broadcast/client.go b/tests/broadcast/client.go index f447f6c14..e993baed2 100644 --- a/tests/broadcast/client.go +++ b/tests/broadcast/client.go @@ -6,6 +6,7 @@ import ( net "net" gorums "github.com/relab/gorums" + "github.com/relab/gorums/authentication" grpc "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -174,10 +175,10 @@ func newAuthClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configu return nil, nil, err } } - auth := gorums.NewAuth(elliptic.P256()) - _ = auth.GenerateKeys() - privKey, pubKey := auth.Keys() - auth.RegisterKeys(lis.Addr(), privKey, pubKey) + auth, err := authentication.NewWithAddr(elliptic.P256(), lis.Addr()) + if err != nil { + return nil, nil, err + } mgr := NewManager( gorums.WithAuthentication(auth), gorums.WithGrpcDialOptions( diff --git a/tests/broadcast/server.go b/tests/broadcast/server.go index f8f24830b..61a780d30 100644 --- a/tests/broadcast/server.go +++ b/tests/broadcast/server.go @@ -5,13 +5,14 @@ import ( "crypto/elliptic" "errors" net "net" + "slices" "sync" "time" gorums "github.com/relab/gorums" + "github.com/relab/gorums/authentication" grpc "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - "slices" ) var leader = "127.0.0.1:5000" @@ -88,10 +89,10 @@ func newAuthenticatedServer(addr string, srvAddresses []string) *testServer { if addr != leader { srv.processingTime = 100 * time.Millisecond } - auth := gorums.NewAuth(elliptic.P256()) - _ = auth.GenerateKeys() - privKey, pubKey := auth.Keys() - auth.RegisterKeys(address, privKey, pubKey) + auth, err := authentication.NewWithAddr(elliptic.P256(), address) + if err != nil { + panic(err) + } srv.mgr = NewManager( gorums.WithAuthentication(auth), gorums.WithGrpcDialOptions( From 40b7e639e3836f1cf71582bf4d077f85f581e9df Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Fri, 13 Jun 2025 17:34:47 +0200 Subject: [PATCH 23/27] refactor(snowflake): use consts for bit lengths; adds InvalidMachineID This uses const values for bit lengths of timestamp, shard, machine ID, and sequenceNum. Thus, the different max values and bit masks are calculated based on the const bit lengths. This also adds the InvalidMachineID function to allow initializing a snowflake instance to an invalid machine ID that must be set afterwards. --- broadcast/snowflake.go | 53 +++++++++++++++++++++++-------------- broadcast/snowflake_test.go | 4 +-- opts.go | 8 +++--- server.go | 4 +-- 4 files changed, 39 insertions(+), 30 deletions(-) diff --git a/broadcast/snowflake.go b/broadcast/snowflake.go index d3665f0e0..b56233725 100644 --- a/broadcast/snowflake.go +++ b/broadcast/snowflake.go @@ -16,14 +16,22 @@ type Snowflake struct { } const ( - MaxMachineID = uint16(1 << 12) - maxShard = uint8(1 << 4) - maxSequenceNum = uint32(1 << 18) - bitMaskTimestamp = uint64((1<<30)-1) << 34 - bitMaskShardID = uint64((1<<4)-1) << 30 - bitMaskMachineID = uint64((1<<12)-1) << 18 - bitMaskSequenceNum = uint64((1 << 18) - 1) - epoch = "2024-01-01T00:00:00" + timestampBits = 30 // seconds since 01.01.2025 + shardIDBits = 4 // 16 different shards + machineIDBits = 12 // 4096 clients + sequenceNumBits = 18 // 262 144 messages + timestampBitsShift = 64 - timestampBits // 34 + + maxShard = uint8(1 << shardIDBits) + maxMachineID = uint16(1 << machineIDBits) + maxSequenceNum = uint32(1 << sequenceNumBits) + + bitMaskTimestamp = uint64((1<= uint64(MaxMachineID) { - id = uint64(rand.Int31n(int32(MaxMachineID))) + if id >= uint64(maxMachineID) { + id = uint64(rand.Int31n(int32(maxMachineID))) } return &Snowflake{ MachineID: id, @@ -43,10 +51,6 @@ func NewSnowflake(id uint64) *Snowflake { } func (s *Snowflake) NewBroadcastID() uint64 { - // timestamp: 30 bit -> seconds since 01.01.2024 - // shardID: 4 bit -> 16 different shards - // machineID: 12 bit -> 4096 clients - // sequenceNum: 18 bit -> 262 144 messages start: s.mut.Lock() timestamp := uint64(time.Since(s.epoch).Seconds()) @@ -63,17 +67,26 @@ start: s.SequenceNum = l s.mut.Unlock() - t := (timestamp << 34) & bitMaskTimestamp - shard := (uint64(rand.Int31n(int32(maxShard))) << 30) & bitMaskShardID - m := uint64(s.MachineID<<18) & bitMaskMachineID + t := (timestamp << timestampBitsShift) & bitMaskTimestamp + shard := (uint64(rand.Int31n(int32(maxShard))) << timestampBits) & bitMaskShardID + m := uint64(s.MachineID<> 34 - shard := (broadcastID & bitMaskShardID) >> 30 - m := (broadcastID & bitMaskMachineID) >> 18 + t := (broadcastID & bitMaskTimestamp) >> timestampBitsShift + shard := (broadcastID & bitMaskShardID) >> timestampBits + m := (broadcastID & bitMaskMachineID) >> sequenceNumBits n := (broadcastID & bitMaskSequenceNum) return uint32(t), uint16(shard), uint16(m), uint32(n) } + +// InvalidMachineID returns an invalid machine ID. +// This can be used to initialize a Snowflake instance to avoid unintentional +// collisions with valid machine IDs. This is necessary because 0 is a valid +// machine ID and should not be used as the default. +// TODO(meling): make the zero value be the invalid machine ID instead. +func InvalidMachineID() uint64 { + return uint64(maxMachineID) + 1 +} diff --git a/broadcast/snowflake_test.go b/broadcast/snowflake_test.go index b39dadc48..776e8cb85 100644 --- a/broadcast/snowflake_test.go +++ b/broadcast/snowflake_test.go @@ -5,8 +5,8 @@ import ( ) func TestBroadcastID(t *testing.T) { - if MaxMachineID != 4096 { - t.Errorf("maxMachineID is hardcoded in test. want: %v, got: %v", 4096, MaxMachineID) + if maxMachineID != 4096 { + t.Errorf("maxMachineID is hardcoded in test. want: %v, got: %v", 4096, maxMachineID) } if maxSequenceNum != 262144 { t.Errorf("maxSequenceNum is hardcoded in test. want: %v, got: %v", 262144, maxSequenceNum) diff --git a/opts.go b/opts.go index 27568dd8b..06afd8adb 100644 --- a/opts.go +++ b/opts.go @@ -26,11 +26,9 @@ type managerOptions struct { func newManagerOptions() managerOptions { return managerOptions{ - backoff: backoff.DefaultConfig, - sendBuffer: 100, - // Provide an illegal machineID to avoid unintentional collisions. - // 0 is a valid MachineID and should not be used as default. - machineID: uint64(broadcast.MaxMachineID) + 1, + backoff: backoff.DefaultConfig, + sendBuffer: 100, + machineID: broadcast.InvalidMachineID(), maxSendRetries: 0, maxConnRetries: -1, // no limit } diff --git a/server.go b/server.go index d9072b64d..6a3f4990a 100644 --- a/server.go +++ b/server.go @@ -343,9 +343,7 @@ func NewServer(opts ...ServerOption) *Server { shardBuffer: 200, sendBuffer: 30, reqTTL: 5 * time.Minute, - // Provide an illegal machineID to avoid unintentional collisions. - // 0 is a valid MachineID and should not be used as default. - machineID: uint64(broadcast.MaxMachineID) + 1, + machineID: broadcast.InvalidMachineID(), } for _, opt := range opts { opt(&serverOpts) From 519c3f0a16dcf12e6c5c3c704b5228cda81e6962 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sat, 14 Jun 2025 01:06:42 +0200 Subject: [PATCH 24/27] refactor: use logging.MachineID helper to init logger --- mgr.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mgr.go b/mgr.go index d4185f5a5..76736d034 100644 --- a/mgr.go +++ b/mgr.go @@ -43,9 +43,7 @@ func NewRawManager(opts ...ManagerOption) *RawManager { } snowflake := broadcast.NewSnowflake(m.opts.machineID) if m.opts.logger != nil { - // a random machineID will be generated if m.opts.machineID is invalid - mID := snowflake.MachineID - m.logger = m.opts.logger.With(slog.Uint64("MachineID", mID)) + m.logger = m.opts.logger.With(logging.MachineID(snowflake.MachineID)) } m.opts.grpcDialOpts = append(m.opts.grpcDialOpts, grpc.WithDefaultCallOptions( grpc.CallContentSubtype(ContentSubtype), From 425421708751c9ce7736a0738895da387939c648 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sat, 14 Jun 2025 01:50:18 +0200 Subject: [PATCH 25/27] refactor: update Snowflake machine ID handling and improve docs --- broadcast/snowflake.go | 18 ++++++------------ broadcast/snowflake_test.go | 29 +++++++++++++++++++++++++++++ opts.go | 12 ++++++------ server.go | 12 +++++++----- 4 files changed, 48 insertions(+), 23 deletions(-) diff --git a/broadcast/snowflake.go b/broadcast/snowflake.go index b56233725..dd8969217 100644 --- a/broadcast/snowflake.go +++ b/broadcast/snowflake.go @@ -29,7 +29,7 @@ const ( bitMaskTimestamp = uint64((1<= uint64(maxMachineID) { - id = uint64(rand.Int31n(int32(maxMachineID))) + if id == 0 || id > uint64(maxMachineID) { + id = uint64(rand.Int31n(int32(maxMachineID))) + 1 // avoid 0 as the machine ID } return &Snowflake{ MachineID: id, @@ -81,12 +84,3 @@ func DecodeBroadcastID(broadcastID uint64) (timestamp uint32, shardID uint16, ma n := (broadcastID & bitMaskSequenceNum) return uint32(t), uint16(shard), uint16(m), uint32(n) } - -// InvalidMachineID returns an invalid machine ID. -// This can be used to initialize a Snowflake instance to avoid unintentional -// collisions with valid machine IDs. This is necessary because 0 is a valid -// machine ID and should not be used as the default. -// TODO(meling): make the zero value be the invalid machine ID instead. -func InvalidMachineID() uint64 { - return uint64(maxMachineID) + 1 -} diff --git a/broadcast/snowflake_test.go b/broadcast/snowflake_test.go index 776e8cb85..199cb6271 100644 --- a/broadcast/snowflake_test.go +++ b/broadcast/snowflake_test.go @@ -1,6 +1,7 @@ package broadcast import ( + "fmt" "testing" ) @@ -46,3 +47,31 @@ func TestBroadcastID(t *testing.T) { } } } + +func TestNewSnowflake(t *testing.T) { + const random = 0 + tests := []struct { + machineID uint64 + wantID uint64 + }{ + {0, random}, // should generate a random machine ID + {4097, random}, // should generate a random machine ID + {1, 1}, // use expected machine ID + {4096, 4096}, // use expected machine ID + {1234, 1234}, // use expected machine ID + } + for _, tt := range tests { + t.Run(fmt.Sprintf("MachineID=%d", tt.machineID), func(t *testing.T) { + snowflake := NewSnowflake(tt.machineID) + if tt.wantID == random { + if snowflake.MachineID == 0 || snowflake.MachineID > 4096 { + t.Errorf("NewSnowflake(%d) = %d, want random machine ID in range [1, 4096]", tt.machineID, snowflake.MachineID) + } + return + } + if snowflake.MachineID != tt.wantID { + t.Errorf("NewSnowflake should use the provided machine ID. got: %v", snowflake.MachineID) + } + }) + } +} diff --git a/opts.go b/opts.go index 06afd8adb..b2c908206 100644 --- a/opts.go +++ b/opts.go @@ -4,7 +4,6 @@ import ( "log/slog" "github.com/relab/gorums/authentication" - "github.com/relab/gorums/broadcast" "google.golang.org/grpc" "google.golang.org/grpc/backoff" "google.golang.org/grpc/metadata" @@ -28,7 +27,7 @@ func newManagerOptions() managerOptions { return managerOptions{ backoff: backoff.DefaultConfig, sendBuffer: 100, - machineID: broadcast.InvalidMachineID(), + machineID: 0, maxSendRetries: 0, maxConnRetries: -1, // no limit } @@ -102,10 +101,11 @@ func WithAuthentication(auth *authentication.EllipticCurve) ManagerOption { } } -// WithMachineID returns a ManagerOption that allows you to set a unique ID for the client. -// This ID will be embedded in broadcast request sent from the client, making the requests -// trackable by the whole cluster. A random ID will be generated if not set. This can cause -// collisions if there are many clients. MinID = 0 and MaxID = 4095. +// WithMachineID returns a ManagerOption that sets a unique ID for the client. +// The valid range for this ID is 1 to 4095, and it should be unique for each client. +// This ID will be embedded in broadcast requests sent from the client, making client +// requests trackable across the server replicas. A random ID will be generated if not set. +// This can cause collisions if there are many clients. func WithMachineID(id uint64) ManagerOption { return func(o *managerOptions) { o.machineID = id diff --git a/server.go b/server.go index 6a3f4990a..aeb8dc915 100644 --- a/server.go +++ b/server.go @@ -10,7 +10,6 @@ import ( "time" "github.com/relab/gorums/authentication" - "github.com/relab/gorums/broadcast" "github.com/relab/gorums/ordering" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -273,9 +272,12 @@ func WithSLogger(logger *slog.Logger) ServerOption { } } -// WithSrvID sets the MachineID of the broadcast server. This ID is used to -// generate BroadcastIDs. This method should be used if a replica needs to -// initiate a broadcast request. +// WithSrvID returns a ServerOption that sets the MachineID of the broadcast server. +// This method should be used if a replica needs to initiate a broadcast request. +// The valid range for this ID is 1 to 4095, and it should be unique for each replica. +// This ID will be embedded in broadcast requests sent from the replica, making requests +// trackable across the replicas. A random ID will be generated if not set. +// This can cause collisions if there are many replicas. // // An example use case is in Paxos: // The designated leader sends a prepare and receives some promises it has @@ -343,7 +345,7 @@ func NewServer(opts ...ServerOption) *Server { shardBuffer: 200, sendBuffer: 30, reqTTL: 5 * time.Minute, - machineID: broadcast.InvalidMachineID(), + machineID: 0, } for _, opt := range opts { opt(&serverOpts) From bff8a28e55429ee099c86f88bd3d480f4a5bb4a1 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sat, 14 Jun 2025 01:55:33 +0200 Subject: [PATCH 26/27] refactor: use ExtractShardID for shard ID extraction Most uses of DecodeBroadcastID only need the shardID, so there is no need to extract the other parts of the broadcast ID. --- broadcast/manager.go | 17 ++++++----------- broadcast/snowflake.go | 5 +++++ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/broadcast/manager.go b/broadcast/manager.go index be28222d2..a8762527d 100644 --- a/broadcast/manager.go +++ b/broadcast/manager.go @@ -40,9 +40,8 @@ func NewBroadcastManager(logger *slog.Logger, createClient func(addr string, dia } func (mgr *manager) Process(msg *Content) (context.Context, func(*Msg) error, error) { - _, shardID, _, _ := DecodeBroadcastID(msg.BroadcastID) - shardID = shardID % NumShards - shard := mgr.state.shards[shardID] + shardID := ExtractShardID(msg.BroadcastID) + shard := mgr.state.getShard(shardID) // we only need a single response receiveChan := make(chan shardResponse, 1) @@ -67,8 +66,7 @@ func (mgr *manager) Broadcast(broadcastID uint64, req protoreflect.ProtoMessage, return enqueueBroadcast(msg) } // slow path: communicate with the shard first - _, shardID, _, _ := DecodeBroadcastID(broadcastID) - shardID = shardID % NumShards + shardID := ExtractShardID(msg.BroadcastID) shard := mgr.state.getShard(shardID) shard.handleBMsg(msg) return nil @@ -88,8 +86,7 @@ func (mgr *manager) SendToClient(broadcastID uint64, resp protoreflect.ProtoMess return enqueueBroadcast(msg) } // slow path: communicate with the shard first - _, shardID, _, _ := DecodeBroadcastID(broadcastID) - shardID = shardID % NumShards + shardID := ExtractShardID(msg.BroadcastID) shard := mgr.state.getShard(shardID) shard.handleBMsg(msg) return nil @@ -107,8 +104,7 @@ func (mgr *manager) Cancel(broadcastID uint64, srvAddrs []string, enqueueBroadca if enqueueBroadcast != nil { return enqueueBroadcast(msg) } - _, shardID, _, _ := DecodeBroadcastID(broadcastID) - shardID = shardID % NumShards + shardID := ExtractShardID(msg.BroadcastID) shard := mgr.state.getShard(shardID) shard.handleBMsg(msg) return nil @@ -128,8 +124,7 @@ func (mgr *manager) Done(broadcastID uint64, enqueueBroadcast func(*Msg) error) _ = enqueueBroadcast(msg) return } - _, shardID, _, _ := DecodeBroadcastID(broadcastID) - shardID = shardID % NumShards + shardID := ExtractShardID(msg.BroadcastID) shard := mgr.state.getShard(shardID) shard.handleBMsg(msg) } diff --git a/broadcast/snowflake.go b/broadcast/snowflake.go index dd8969217..7ec84a70d 100644 --- a/broadcast/snowflake.go +++ b/broadcast/snowflake.go @@ -84,3 +84,8 @@ func DecodeBroadcastID(broadcastID uint64) (timestamp uint32, shardID uint16, ma n := (broadcastID & bitMaskSequenceNum) return uint32(t), uint16(shard), uint16(m), uint32(n) } + +// ExtractShardID extracts the shard ID from the broadcast ID. +func ExtractShardID(broadcastID uint64) uint16 { + return uint16((broadcastID & bitMaskShardID) >> timestampBits) +} From bdb498915093843e437a8ac72d42e6ed635c16c5 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Sat, 14 Jun 2025 02:24:26 +0200 Subject: [PATCH 27/27] refactor: remove unused listenAddr field in newClient The listenAddr field in newClient and newAuthClient was only used as a boolean to determine whether or not to create a listener on an arbitrary port; not the port specified in the listenAddr. This just removes the listenAddr field and uses the arbitrary localhost port. --- tests/broadcast/broadcast_test.go | 76 +++++++++++++++---------------- tests/broadcast/client.go | 50 ++++++++------------ tests/broadcast/server.go | 2 +- 3 files changed, 59 insertions(+), 69 deletions(-) diff --git a/tests/broadcast/broadcast_test.go b/tests/broadcast/broadcast_test.go index 7aa8a6025..744a46605 100644 --- a/tests/broadcast/broadcast_test.go +++ b/tests/broadcast/broadcast_test.go @@ -64,9 +64,9 @@ func createSrvs(numSrvs int, down ...int) ([]*testServer, []string, func(), erro } var srv *testServer if ordering { - srv = newtestServer(addr, srvAddrs, i, true) + srv = newTestServer(addr, srvAddrs, i, true) } else { - srv = newtestServer(addr, srvAddrs, i) + srv = newTestServer(addr, srvAddrs, i) } lis, err := net.Listen("tcp4", srv.addr) if err != nil { @@ -93,7 +93,7 @@ func TestSimpleBroadcastCall(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -123,7 +123,7 @@ func TestSimpleBroadcastTo(t *testing.T) { defer srvCleanup() // only want a response from the leader, hence qsize = 1 - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080", 1) + config, clientCleanup, err := newClient(srvAddrs, 1) if err != nil { t.Error(err) } @@ -152,7 +152,7 @@ func TestSimpleBroadcastCancel(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -181,7 +181,7 @@ func TestBroadcastCancel(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -192,7 +192,7 @@ func TestBroadcastCancel(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) _, _ = config.LongRunningTask(ctx, &Request{Value: val}, true) cancel() - // wait until cancel has reaced the servers before asking for the result + // wait until cancel has reached the servers before asking for the result time.Sleep(1 * time.Second) ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) resp, err := config.GetVal(ctx, &Request{Value: val}) @@ -218,7 +218,7 @@ func TestBroadcastCancelOneSrvDown(t *testing.T) { defer srvCleanup() // only want response from the online servers - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080", numSrvs-1) + config, clientCleanup, err := newClient(srvAddrs, numSrvs-1) if err != nil { t.Error(err) } @@ -250,7 +250,7 @@ func TestBroadcastCancelOneSrvFails(t *testing.T) { defer srvCleanup() // only want response from the online servers - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080", numSrvs-1) + config, clientCleanup, err := newClient(srvAddrs, numSrvs-1) if err != nil { t.Error(err) } @@ -284,7 +284,7 @@ func TestBroadcastCancelOneClientFails(t *testing.T) { defer srvCleanup() // only want response from the online servers - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -299,7 +299,7 @@ func TestBroadcastCancelOneClientFails(t *testing.T) { clientCleanup() // only want response from the online servers - config2, clientCleanup2, err2 := newClient(srvAddrs, "127.0.0.1:8081") + config2, clientCleanup2, err2 := newClient(srvAddrs) defer clientCleanup2() if err2 != nil { t.Error(err2) @@ -326,7 +326,7 @@ func TestBroadcastCallOrderingSendToOneSrv(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[1:2], "127.0.0.1:8080", len(srvAddrs)) + config, clientCleanup, err := newClient(srvAddrs[1:2], len(srvAddrs)) if err != nil { t.Error(err) } @@ -356,7 +356,7 @@ func TestBroadcastCallOrderingSendToAllSrvs(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -386,7 +386,7 @@ func TestBroadcastCallOrderingDoesNotInterfereWithMethodsNotSpecifiedInOrder(t * } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -414,7 +414,7 @@ func TestBroadcastCallRace(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -446,7 +446,7 @@ func TestBroadcastCallClientKnowsOnlyOneServer(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[0:1], "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs[0:1]) if err != nil { t.Error(err) } @@ -482,7 +482,7 @@ func TestBroadcastCallOneServerIsDown(t *testing.T) { start := max(skip, 0) end := min(numSrvs-1, len(srvAddrs)) - config, clientCleanup, err := newClient(srvAddrs[start:end], "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs[start:end]) if err != nil { t.Error(err) } @@ -514,7 +514,7 @@ func TestBroadcastCallForward(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[1:2], "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs[1:2]) if err != nil { t.Error(err) } @@ -540,7 +540,7 @@ func TestBroadcastCallForwardMultiple(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[1:], "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs[1:]) if err != nil { t.Error(err) } @@ -564,13 +564,13 @@ func TestBroadcastCallRaceTwoClients(t *testing.T) { } defer srvCleanup() - client1, clientCleanup1, err := newClient(srvAddrs, "127.0.0.1:8080") + client1, clientCleanup1, err := newClient(srvAddrs) if err != nil { t.Error(err) } defer clientCleanup1() - client2, clientCleanup2, err := newClient(srvAddrs, "127.0.0.1:8081") + client2, clientCleanup2, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -634,7 +634,7 @@ func TestBroadcastCallAsyncReqs(t *testing.T) { numClients := 10 clients := make([]*Configuration, numClients) for c := range numClients { - config, clientCleanup, err := newClient(srvAddrs, fmt.Sprintf("127.0.0.1:808%v", c), 3) + config, clientCleanup, err := newClient(srvAddrs, 3) if err != nil { t.Error(err) } @@ -683,7 +683,7 @@ func TestQCBroadcastOptionRace(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -717,7 +717,7 @@ func TestQCMulticastRace(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -751,7 +751,7 @@ func BenchmarkQuorumCall(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { b.Error(err) } @@ -792,7 +792,7 @@ func BenchmarkQCMulticast(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { b.Error(err) } @@ -835,7 +835,7 @@ func BenchmarkQCBroadcastOption(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { b.Error(err) } @@ -881,7 +881,7 @@ func BenchmarkQCBroadcastOptionManyClients(b *testing.B) { numClients := 10 clients := make([]*Configuration, numClients) for c := range numClients { - config, clientCleanup, err := newClient(srvAddrs, fmt.Sprintf("127.0.0.1:808%v", c), 3) + config, clientCleanup, err := newClient(srvAddrs, 3) if err != nil { b.Error(err) } @@ -934,7 +934,7 @@ func BenchmarkBroadcastCallAllServers(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { b.Error(err) } @@ -975,7 +975,7 @@ func BenchmarkBroadcastCallToOneServer(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[0:1], "127.0.0.1:8080", 3) + config, clientCleanup, err := newClient(srvAddrs[0:1], 3) if err != nil { b.Error(err) } @@ -1016,7 +1016,7 @@ func BenchmarkBroadcastCallOneFailedServer(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080", 2) + config, clientCleanup, err := newClient(srvAddrs, 2) if err != nil { b.Error(err) } @@ -1057,7 +1057,7 @@ func BenchmarkBroadcastCallOneDownSrvToOneSrv(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[1:2], "127.0.0.1:8080", 2) + config, clientCleanup, err := newClient(srvAddrs[1:2], 2) if err != nil { b.Error(err) } @@ -1101,7 +1101,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { numClients := 10 clients := make([]*Configuration, numClients) for c := range numClients { - config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), 3) + config, clientCleanup, err := newClient(srvAddrs[0:1], 3) if err != nil { b.Error(err) } @@ -1209,7 +1209,7 @@ func BenchmarkBroadcastCallTenClientsCPU(b *testing.B) { numClients := 10 clients := make([]*Configuration, numClients) for c := range numClients { - config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), 3) + config, clientCleanup, err := newClient(srvAddrs[0:1], 3) if err != nil { b.Error(err) } @@ -1270,7 +1270,7 @@ func BenchmarkBroadcastCallTenClientsMEM(b *testing.B) { numClients := 1 clients := make([]*Configuration, numClients) for c := range numClients { - config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), numSrvs) + config, clientCleanup, err := newClient(srvAddrs[0:1], numSrvs) if err != nil { b.Error(err) } @@ -1324,7 +1324,7 @@ func BenchmarkBroadcastCallTenClientsTRACE(b *testing.B) { numClients := 1 clients := make([]*Configuration, numClients) for c := range numClients { - config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), numSrvs) + config, clientCleanup, err := newClient(srvAddrs[0:1], numSrvs) if err != nil { b.Error(err) } @@ -1381,7 +1381,7 @@ func TestBroadcastCallManyRequestsAsync(t *testing.T) { clients := make([]*Configuration, numClients) for c := range numClients { - config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), numSrvs) + config, clientCleanup, err := newClient(srvAddrs[0:1], numSrvs) if err != nil { t.Error(err) } @@ -1440,7 +1440,7 @@ func TestAuthenticationBroadcastCall(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newAuthClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newAuthClient(srvAddrs) if err != nil { t.Error(err) } diff --git a/tests/broadcast/client.go b/tests/broadcast/client.go index e993baed2..fbd87570a 100644 --- a/tests/broadcast/client.go +++ b/tests/broadcast/client.go @@ -3,11 +3,11 @@ package broadcast import ( "crypto/elliptic" "log/slog" - net "net" + "net" - gorums "github.com/relab/gorums" + "github.com/relab/gorums" "github.com/relab/gorums/authentication" - grpc "google.golang.org/grpc" + "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -128,7 +128,7 @@ func (qs *testQSpec) OrderQF(in *Request, replies []*Response) (*Response, bool) return nil, false } -func newClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configuration, func(), error) { +func newClient(srvAddrs []string, qsize ...int) (*Configuration, func(), error) { quorumSize := len(srvAddrs) if len(qsize) > 0 { quorumSize = qsize[0] @@ -138,15 +138,13 @@ func newClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configurati grpc.WithTransportCredentials(insecure.NewCredentials()), ), ) - if listenAddr != "" { - lis, err := net.Listen("tcp", "127.0.0.1:") - if err != nil { - return nil, nil, err - } - err = mgr.AddClientServer(lis, lis.Addr()) - if err != nil { - return nil, nil, err - } + lis, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + return nil, nil, err + } + err = mgr.AddClientServer(lis, lis.Addr()) + if err != nil { + return nil, nil, err } config, err := mgr.NewConfiguration( gorums.WithNodeList(srvAddrs), @@ -155,12 +153,10 @@ func newClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configurati if err != nil { return nil, nil, err } - return config, func() { - mgr.Close() - }, nil + return config, mgr.Close, nil } -func newAuthClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configuration, func(), error) { +func newAuthClient(srvAddrs []string, qsize ...int) (*Configuration, func(), error) { quorumSize := len(srvAddrs) if len(qsize) > 0 { quorumSize = qsize[0] @@ -169,11 +165,9 @@ func newAuthClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configu lis net.Listener err error ) - if listenAddr != "" { - lis, err = net.Listen("tcp", "127.0.0.1:") - if err != nil { - return nil, nil, err - } + lis, err = net.Listen("tcp", "127.0.0.1:") + if err != nil { + return nil, nil, err } auth, err := authentication.NewWithAddr(elliptic.P256(), lis.Addr()) if err != nil { @@ -185,11 +179,9 @@ func newAuthClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configu grpc.WithTransportCredentials(insecure.NewCredentials()), ), ) - if listenAddr != "" { - err = mgr.AddClientServer(lis, lis.Addr()) - if err != nil { - return nil, nil, err - } + err = mgr.AddClientServer(lis, lis.Addr()) + if err != nil { + return nil, nil, err } config, err := mgr.NewConfiguration( gorums.WithNodeList(srvAddrs), @@ -198,7 +190,5 @@ func newAuthClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configu if err != nil { return nil, nil, err } - return config, func() { - mgr.Close() - }, nil + return config, mgr.Close, nil } diff --git a/tests/broadcast/server.go b/tests/broadcast/server.go index 61a780d30..b484dc877 100644 --- a/tests/broadcast/server.go +++ b/tests/broadcast/server.go @@ -36,7 +36,7 @@ type testServer struct { order []string } -func newtestServer(addr string, srvAddresses []string, _ int, withOrder ...bool) *testServer { +func newTestServer(addr string, srvAddresses []string, _ int, withOrder ...bool) *testServer { address, err := net.ResolveTCPAddr("tcp", addr) if err != nil { panic(err)