diff --git a/authentication/authentication.go b/authentication/authentication.go index 0b6647cf0..7ac2e5a91 100644 --- a/authentication/authentication.go +++ b/authentication/authentication.go @@ -8,51 +8,43 @@ import ( "crypto/sha256" "crypto/x509" "encoding/pem" + "errors" "fmt" "net" ) -// Elliptic Curve Cryptography (ECC) is a key-based technique for encrypting data. -// ECC focuses on pairs of public and private keys for decryption and encryption of web traffic. -// ECC is frequently discussed in the context of the Rivest–Shamir–Adleman (RSA) cryptographic algorithm. -// -// https://pkg.go.dev/github.com/katzenpost/core/crypto/eddsa +// EllipticCurve represents an elliptic curve instance used for message authentication. type EllipticCurve struct { - addr net.Addr // used to identify self - pubKeyCurve elliptic.Curve // http://golang.org/pkg/crypto/elliptic/#P256 - privateKey *ecdsa.PrivateKey - publicKey *ecdsa.PublicKey + addr net.Addr // used to identify self + curve elliptic.Curve + privateKey *ecdsa.PrivateKey + publicKey *ecdsa.PublicKey } // New EllipticCurve instance func New(curve elliptic.Curve) *EllipticCurve { return &EllipticCurve{ - pubKeyCurve: curve, - privateKey: new(ecdsa.PrivateKey), + curve: curve, + privateKey: new(ecdsa.PrivateKey), } } -// GenerateKeys EllipticCurve public and private keys -func (ec *EllipticCurve) GenerateKeys() error { - privKey, err := ecdsa.GenerateKey(ec.pubKeyCurve, rand.Reader) +// NewWithAddr returns a new EllipticCurve instance with the given address. +// It generates a new public-private key pair for the specified elliptic curve. +func NewWithAddr(curve elliptic.Curve, addr net.Addr) (*EllipticCurve, error) { + if addr == nil { + return nil, errors.New("address cannot be nil") + } + priv, err := ecdsa.GenerateKey(curve, rand.Reader) if err != nil { - return err + return nil, err } - ec.privateKey = privKey - ec.publicKey = &privKey.PublicKey - return nil -} - -// RegisterKeys EllipticCurve public and private keys -func (ec *EllipticCurve) RegisterKeys(addr net.Addr, privKey *ecdsa.PrivateKey, pubKey *ecdsa.PublicKey) { - ec.addr = addr - ec.privateKey = privKey - ec.publicKey = pubKey -} - -// Returns the EllipticCurve public and private keys -func (ec *EllipticCurve) Keys() (*ecdsa.PrivateKey, *ecdsa.PublicKey) { - return ec.privateKey, ec.publicKey + return &EllipticCurve{ + addr: addr, + curve: curve, + privateKey: priv, + publicKey: &priv.PublicKey, + }, nil } // Returns the address @@ -93,7 +85,7 @@ func (ec *EllipticCurve) DecodePrivate(pemEncodedPriv string) (*ecdsa.PrivateKey func (ec *EllipticCurve) DecodePublic(pemEncodedPub string) (*ecdsa.PublicKey, error) { blockPub, _ := pem.Decode([]byte(pemEncodedPub)) if blockPub == nil { - return nil, fmt.Errorf("invalid publicKey") + return nil, errors.New("invalid public key") } x509EncodedPub := blockPub.Bytes genericPublicKey, err := x509.ParsePKIXPublicKey(x509EncodedPub) @@ -102,76 +94,40 @@ func (ec *EllipticCurve) DecodePublic(pemEncodedPub string) (*ecdsa.PublicKey, e } func (ec *EllipticCurve) Sign(msg []byte) ([]byte, error) { - h := sha256.Sum256(msg) - hash := h[:] - signature, err := ecdsa.SignASN1(rand.Reader, ec.privateKey, hash) - if err != nil { - return nil, err - } - return signature, nil + return ecdsa.SignASN1(rand.Reader, ec.privateKey, Hash(msg)) } -func (ec *EllipticCurve) Hash(msg []byte) []byte { - h := sha256.Sum256(msg) - hash := h[:] - return hash +func Hash(msg []byte) []byte { + hash := sha256.Sum256(msg) + return hash[:] } -// VerifySignature sign ecdsa style and verify signature -func (ec *EllipticCurve) VerifySignature(pemEncodedPub string, msg, signature []byte) (bool, error) { - h := sha256.Sum256(msg) - hash := h[:] - pubKey, err := ec.DecodePublic(pemEncodedPub) - if err != nil { - return false, err - } - ok := ecdsa.VerifyASN1(pubKey, hash, signature) - return ok, nil -} +var InvalidSignatureErr = errors.New("invalid signature") -// VerifySignature sign ecdsa style and verify signature -func (ec *EllipticCurve) SignAndVerify(privKey *ecdsa.PrivateKey, pubKey *ecdsa.PublicKey) ([]byte, bool, error) { - h := sha256.Sum256([]byte("test")) - hash := h[:] - signature, err := ecdsa.SignASN1(rand.Reader, privKey, hash) +// VerifySignature verifies the signature of the message's hash using the given PEM encoded +// public key. It returns an error if the signature is invalid or if there is an error +// decoding the public key. +func (ec *EllipticCurve) VerifySignature(pemEncodedPub string, msg, signature []byte) error { + pubKey, err := ec.DecodePublic(pemEncodedPub) if err != nil { - return nil, false, err + return err } - ok := ecdsa.VerifyASN1(pubKey, hash, signature) - return signature, ok, nil -} - -func (ec *EllipticCurve) EncodeMsg(msg any) ([]byte, error) { - return []byte(fmt.Sprintf("%v", msg)), nil - /*var encodedMsg bytes.Buffer - gob.Register(msg) - enc := gob.NewEncoder(&encodedMsg) - err := enc.Encode(msg) - if err != nil { - return nil, err + if valid := ecdsa.VerifyASN1(pubKey, Hash(msg), signature); !valid { + return InvalidSignatureErr } - return encodedMsg.Bytes(), nil*/ + return nil } -func encodeMsg(msg any) ([]byte, error) { - return []byte(fmt.Sprintf("%v", msg)), nil +func EncodeMsg(msg any) []byte { + return fmt.Appendf(nil, "%v", msg) } -func Verify(pemEncodedPub string, signature, digest []byte, msg any) (bool, error) { - encodedMsg, err := encodeMsg(msg) - if err != nil { - return false, err - } +func Verify(pemEncodedPub string, signature, digest []byte, msg any) error { + encodedMsg := EncodeMsg(msg) ec := New(elliptic.P256()) - h := sha256.Sum256(encodedMsg) - hash := h[:] + hash := Hash(encodedMsg) if !bytes.Equal(hash, digest) { - return false, fmt.Errorf("wrong digest") - } - pubKey, err := ec.DecodePublic(pemEncodedPub) - if err != nil { - return false, err + return errors.New("invalid digest for message") } - ok := ecdsa.VerifyASN1(pubKey, hash, signature) - return ok, nil + return ec.VerifySignature(pemEncodedPub, encodedMsg, signature) } diff --git a/authentication/authentication_test.go b/authentication/authentication_test.go index 579f30454..39b126ba4 100644 --- a/authentication/authentication_test.go +++ b/authentication/authentication_test.go @@ -3,38 +3,42 @@ package authentication import ( "crypto/elliptic" "errors" + "net" "reflect" "testing" ) func TestAuthentication(t *testing.T) { - ec := New(elliptic.P256()) - _ = ec.GenerateKeys() - err := ec.test() + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5000") + if err != nil { + t.Fatal(err) + } + ec, err := NewWithAddr(elliptic.P256(), addr) + if err != nil { + t.Fatal(err) + } + err = ec.test() if err != nil { t.Error(err) } } func TestSignAndVerify(t *testing.T) { - ec1 := New(elliptic.P256()) - err := ec1.GenerateKeys() + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5000") if err != nil { t.Fatal(err) } - - ec2 := New(elliptic.P256()) - err = ec2.GenerateKeys() + ec1, err := NewWithAddr(elliptic.P256(), addr) if err != nil { t.Fatal(err) } - - message := "This is a message" - - encodedMsg1, err := ec1.EncodeMsg(message) + ec2, err := NewWithAddr(elliptic.P256(), addr) if err != nil { - t.Error(err) + t.Fatal(err) } + + message := "This is a message" + encodedMsg1 := EncodeMsg(message) signature, err := ec1.Sign(encodedMsg1) if err != nil { t.Error(err) @@ -44,37 +48,29 @@ func TestSignAndVerify(t *testing.T) { t.Error(err) } - encodedMsg2, err := ec2.EncodeMsg(message) - if err != nil { - t.Error(err) - } - ok, err := ec2.VerifySignature(pemEncodedPub, encodedMsg2, signature) + encodedMsg2 := EncodeMsg(message) + err = ec2.VerifySignature(pemEncodedPub, encodedMsg2, signature) if err != nil { - t.Error(err) - } - if !ok { - t.Error("signature not ok!") + t.Errorf("VerifySignature() = %v, want nil", err) } } func TestVerifyWithWrongPubKey(t *testing.T) { - ec1 := New(elliptic.P256()) - err := ec1.GenerateKeys() + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5000") if err != nil { t.Fatal(err) } - - ec2 := New(elliptic.P256()) - err = ec2.GenerateKeys() + ec1, err := NewWithAddr(elliptic.P256(), addr) if err != nil { t.Fatal(err) } - - message := "This is a message" - encodedMsg1, err := ec1.EncodeMsg(message) + ec2, err := NewWithAddr(elliptic.P256(), addr) if err != nil { - t.Error(err) + t.Fatal(err) } + + message := "This is a message" + encodedMsg1 := EncodeMsg(message) signature, err := ec1.Sign(encodedMsg1) if err != nil { t.Error(err) @@ -86,16 +82,14 @@ func TestVerifyWithWrongPubKey(t *testing.T) { t.Error(err) } - encodedMsg2, err := ec2.EncodeMsg(message) - if err != nil { - t.Error(err) - } - ok, err := ec2.VerifySignature(pemEncodedPub, encodedMsg2, signature) - if err != nil { - t.Error(err) - } - if ok { - t.Error("signature should not be ok!") + encodedMsg2 := EncodeMsg(message) + err = ec2.VerifySignature(pemEncodedPub, encodedMsg2, signature) + if err == nil { + t.Errorf("VerifySignature() = nil, want %v", InvalidSignatureErr) + } else { + if !errors.Is(err, InvalidSignatureErr) { + t.Errorf("VerifySignature() = %v, want %v", err, InvalidSignatureErr) + } } } diff --git a/broadcast.go b/broadcast.go index 791b5cfe6..df3a3eb31 100644 --- a/broadcast.go +++ b/broadcast.go @@ -2,7 +2,6 @@ package gorums import ( "context" - "crypto/elliptic" "hash/fnv" "log/slog" "strings" @@ -19,11 +18,6 @@ import ( // exposing the log entry struct used for structured logging to the user type LogEntry logging.LogEntry -// exposing the ellipticCurve struct for the user -func NewAuth(curve elliptic.Curve) *authentication.EllipticCurve { - return authentication.New(curve) -} - type broadcastServer struct { viewMutex sync.RWMutex id uint32 @@ -75,11 +69,10 @@ type ( type ( defaultImplementationFunc[T protoreflect.ProtoMessage, V protoreflect.ProtoMessage] func(ServerCtx, T) (V, error) + implementationFunc[T protoreflect.ProtoMessage, V Broadcaster] func(ServerCtx, T, V) clientImplementationFunc[T protoreflect.ProtoMessage, V protoreflect.ProtoMessage] func(context.Context, T, uint64) (V, error) ) -type implementationFunc[T protoreflect.ProtoMessage, V Broadcaster] func(ServerCtx, T, V) - func CancelFunc(ServerCtx, protoreflect.ProtoMessage, Broadcaster) {} const Cancellation string = "cancel" @@ -163,7 +156,7 @@ func NewBroadcastOptions() broadcast.BroadcastOptions { } } -type Broadcaster interface{} +type Broadcaster any type BroadcastMetadata struct { BroadcastID uint64 @@ -208,6 +201,7 @@ func newBroadcastMetadata(md *ordering.Metadata) BroadcastMetadata { } } -func (md BroadcastMetadata) Verify(msg protoreflect.ProtoMessage) (bool, error) { +// TODO(meling): this method is never called +func (md BroadcastMetadata) Verify(msg protoreflect.ProtoMessage) error { return authentication.Verify(md.OriginPubKey, md.OriginSignature, md.OriginDigest, msg) } diff --git a/broadcast/manager.go b/broadcast/manager.go index be28222d2..a8762527d 100644 --- a/broadcast/manager.go +++ b/broadcast/manager.go @@ -40,9 +40,8 @@ func NewBroadcastManager(logger *slog.Logger, createClient func(addr string, dia } func (mgr *manager) Process(msg *Content) (context.Context, func(*Msg) error, error) { - _, shardID, _, _ := DecodeBroadcastID(msg.BroadcastID) - shardID = shardID % NumShards - shard := mgr.state.shards[shardID] + shardID := ExtractShardID(msg.BroadcastID) + shard := mgr.state.getShard(shardID) // we only need a single response receiveChan := make(chan shardResponse, 1) @@ -67,8 +66,7 @@ func (mgr *manager) Broadcast(broadcastID uint64, req protoreflect.ProtoMessage, return enqueueBroadcast(msg) } // slow path: communicate with the shard first - _, shardID, _, _ := DecodeBroadcastID(broadcastID) - shardID = shardID % NumShards + shardID := ExtractShardID(msg.BroadcastID) shard := mgr.state.getShard(shardID) shard.handleBMsg(msg) return nil @@ -88,8 +86,7 @@ func (mgr *manager) SendToClient(broadcastID uint64, resp protoreflect.ProtoMess return enqueueBroadcast(msg) } // slow path: communicate with the shard first - _, shardID, _, _ := DecodeBroadcastID(broadcastID) - shardID = shardID % NumShards + shardID := ExtractShardID(msg.BroadcastID) shard := mgr.state.getShard(shardID) shard.handleBMsg(msg) return nil @@ -107,8 +104,7 @@ func (mgr *manager) Cancel(broadcastID uint64, srvAddrs []string, enqueueBroadca if enqueueBroadcast != nil { return enqueueBroadcast(msg) } - _, shardID, _, _ := DecodeBroadcastID(broadcastID) - shardID = shardID % NumShards + shardID := ExtractShardID(msg.BroadcastID) shard := mgr.state.getShard(shardID) shard.handleBMsg(msg) return nil @@ -128,8 +124,7 @@ func (mgr *manager) Done(broadcastID uint64, enqueueBroadcast func(*Msg) error) _ = enqueueBroadcast(msg) return } - _, shardID, _, _ := DecodeBroadcastID(broadcastID) - shardID = shardID % NumShards + shardID := ExtractShardID(msg.BroadcastID) shard := mgr.state.getShard(shardID) shard.handleBMsg(msg) } diff --git a/broadcast/processor.go b/broadcast/processor.go index 55f22e5aa..a3663b836 100644 --- a/broadcast/processor.go +++ b/broadcast/processor.go @@ -3,6 +3,7 @@ package broadcast import ( "context" "log/slog" + "slices" "time" "github.com/relab/gorums/logging" @@ -116,7 +117,7 @@ func (p *BroadcastProcessor) handleCancellation(bMsg *Msg, metadata *metadata) b func (p *BroadcastProcessor) handleBroadcast(bMsg *Msg, methods []string, metadata *metadata) bool { // check if msg has already been broadcasted for this method - //if alreadyBroadcasted(p.metadata.Methods, bMsg.Method) { + // if alreadyBroadcasted(p.metadata.Methods, bMsg.Method) { if !bMsg.Msg.options.AllowDuplication && alreadyBroadcasted(methods, bMsg.Method) { return false } @@ -135,7 +136,7 @@ func (p *BroadcastProcessor) handleReply(bMsg *Msg, metadata *metadata) bool { err := p.router.Send(broadcastID, originAddr, originMethod, metadata.OriginDigest, metadata.OriginSignature, metadata.OriginPubKey, replyMsg) p.log("broadcast: sent reply to client", err, logging.Method(originMethod), logging.MsgType(bMsg.MsgType.String()), logging.Stopping(true), logging.IsBroadcastCall(metadata.isBroadcastCall())) }(p.broadcastID, metadata.OriginAddr, metadata.OriginMethod, bMsg.Reply) - // the request is done becuase we have sent a reply to the client + // the request is done because we have sent a reply to the client p.log("broadcast: sending reply to client", nil, logging.Method(metadata.OriginMethod), logging.MsgType(bMsg.MsgType.String()), logging.Stopping(true), logging.IsBroadcastCall(metadata.isBroadcastCall())) return true } @@ -161,7 +162,7 @@ func (p *BroadcastProcessor) handleReply(bMsg *Msg, metadata *metadata) bool { // is done. return metadata.hasReceivedClientRequest() } - // the request is done becuase we have sent a reply to the client + // the request is done because we have sent a reply to the client p.log("broadcast: sending reply to client", err, logging.Method(metadata.OriginMethod), logging.MsgType(bMsg.MsgType.String()), logging.Stopping(true), logging.IsBroadcastCall(metadata.isBroadcastCall())) return true } @@ -432,12 +433,7 @@ func (r *BroadcastProcessor) dispatchOutOfOrderMsgs() { } func alreadyBroadcasted(methods []string, method string) bool { - for _, m := range methods { - if m == method { - return true - } - } - return false + return slices.Contains(methods, method) } func (p *BroadcastProcessor) initialize(msg *Content, metadata *metadata) { diff --git a/broadcast/processor_test.go b/broadcast/processor_test.go index d38e7cc1b..08720f012 100644 --- a/broadcast/processor_test.go +++ b/broadcast/processor_test.go @@ -41,7 +41,7 @@ func TestHandleBroadcastOption1(t *testing.T) { snowflake := NewSnowflake(0) broadcastID := snowflake.NewBroadcastID() - var tests = []struct { + tests := []struct { in *Content out error }{ @@ -159,7 +159,7 @@ func TestHandleBroadcastCall1(t *testing.T) { snowflake := NewSnowflake(0) broadcastID := snowflake.NewBroadcastID() - var tests = []struct { + tests := []struct { in *Content out error }{ @@ -290,7 +290,7 @@ func BenchmarkHandleProcessor(b *testing.B) { b.ResetTimer() b.Run("ProcessorHandler", func(b *testing.B) { - for i := 0; i < b.N; i++ { + for b.Loop() { msg := &Content{ BroadcastID: broadcastID, IsBroadcastClient: true, diff --git a/broadcast/router.go b/broadcast/router.go index 420bbfd78..561afd1ea 100644 --- a/broadcast/router.go +++ b/broadcast/router.go @@ -61,7 +61,7 @@ func (r *BroadcastRouter) registerState(state *BroadcastState) { r.state = state } -type msg interface{} +type msg any func (r *BroadcastRouter) Send(broadcastID uint64, addr, method string, originDigest, originSignature []byte, originPubKey string, req msg) error { if r.addr == "" { diff --git a/broadcast/shard_test.go b/broadcast/shard_test.go index f2fa05257..6f2b7fcf5 100644 --- a/broadcast/shard_test.go +++ b/broadcast/shard_test.go @@ -40,8 +40,7 @@ func TestShard(t *testing.T) { returnError: false, } shardBuffer := 100 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() shard := &shard{ id: 0, parentCtx: ctx, diff --git a/broadcast/snowflake.go b/broadcast/snowflake.go index d3665f0e0..7ec84a70d 100644 --- a/broadcast/snowflake.go +++ b/broadcast/snowflake.go @@ -16,14 +16,22 @@ type Snowflake struct { } const ( - MaxMachineID = uint16(1 << 12) - maxShard = uint8(1 << 4) - maxSequenceNum = uint32(1 << 18) - bitMaskTimestamp = uint64((1<<30)-1) << 34 - bitMaskShardID = uint64((1<<4)-1) << 30 - bitMaskMachineID = uint64((1<<12)-1) << 18 - bitMaskSequenceNum = uint64((1 << 18) - 1) - epoch = "2024-01-01T00:00:00" + timestampBits = 30 // seconds since 01.01.2025 + shardIDBits = 4 // 16 different shards + machineIDBits = 12 // 4096 clients + sequenceNumBits = 18 // 262 144 messages + timestampBitsShift = 64 - timestampBits // 34 + + maxShard = uint8(1 << shardIDBits) + maxMachineID = uint16(1 << machineIDBits) + maxSequenceNum = uint32(1 << sequenceNumBits) + + bitMaskTimestamp = uint64((1<= uint64(MaxMachineID) { - id = uint64(rand.Int31n(int32(MaxMachineID))) + if id == 0 || id > uint64(maxMachineID) { + id = uint64(rand.Int31n(int32(maxMachineID))) + 1 // avoid 0 as the machine ID } return &Snowflake{ MachineID: id, @@ -43,10 +54,6 @@ func NewSnowflake(id uint64) *Snowflake { } func (s *Snowflake) NewBroadcastID() uint64 { - // timestamp: 30 bit -> seconds since 01.01.2024 - // shardID: 4 bit -> 16 different shards - // machineID: 12 bit -> 4096 clients - // sequenceNum: 18 bit -> 262 144 messages start: s.mut.Lock() timestamp := uint64(time.Since(s.epoch).Seconds()) @@ -63,17 +70,22 @@ start: s.SequenceNum = l s.mut.Unlock() - t := (timestamp << 34) & bitMaskTimestamp - shard := (uint64(rand.Int31n(int32(maxShard))) << 30) & bitMaskShardID - m := uint64(s.MachineID<<18) & bitMaskMachineID + t := (timestamp << timestampBitsShift) & bitMaskTimestamp + shard := (uint64(rand.Int31n(int32(maxShard))) << timestampBits) & bitMaskShardID + m := uint64(s.MachineID<> 34 - shard := (broadcastID & bitMaskShardID) >> 30 - m := (broadcastID & bitMaskMachineID) >> 18 + t := (broadcastID & bitMaskTimestamp) >> timestampBitsShift + shard := (broadcastID & bitMaskShardID) >> timestampBits + m := (broadcastID & bitMaskMachineID) >> sequenceNumBits n := (broadcastID & bitMaskSequenceNum) return uint32(t), uint16(shard), uint16(m), uint32(n) } + +// ExtractShardID extracts the shard ID from the broadcast ID. +func ExtractShardID(broadcastID uint64) uint16 { + return uint16((broadcastID & bitMaskShardID) >> timestampBits) +} diff --git a/broadcast/snowflake_test.go b/broadcast/snowflake_test.go index b39dadc48..199cb6271 100644 --- a/broadcast/snowflake_test.go +++ b/broadcast/snowflake_test.go @@ -1,12 +1,13 @@ package broadcast import ( + "fmt" "testing" ) func TestBroadcastID(t *testing.T) { - if MaxMachineID != 4096 { - t.Errorf("maxMachineID is hardcoded in test. want: %v, got: %v", 4096, MaxMachineID) + if maxMachineID != 4096 { + t.Errorf("maxMachineID is hardcoded in test. want: %v, got: %v", 4096, maxMachineID) } if maxSequenceNum != 262144 { t.Errorf("maxSequenceNum is hardcoded in test. want: %v, got: %v", 262144, maxSequenceNum) @@ -46,3 +47,31 @@ func TestBroadcastID(t *testing.T) { } } } + +func TestNewSnowflake(t *testing.T) { + const random = 0 + tests := []struct { + machineID uint64 + wantID uint64 + }{ + {0, random}, // should generate a random machine ID + {4097, random}, // should generate a random machine ID + {1, 1}, // use expected machine ID + {4096, 4096}, // use expected machine ID + {1234, 1234}, // use expected machine ID + } + for _, tt := range tests { + t.Run(fmt.Sprintf("MachineID=%d", tt.machineID), func(t *testing.T) { + snowflake := NewSnowflake(tt.machineID) + if tt.wantID == random { + if snowflake.MachineID == 0 || snowflake.MachineID > 4096 { + t.Errorf("NewSnowflake(%d) = %d, want random machine ID in range [1, 4096]", tt.machineID, snowflake.MachineID) + } + return + } + if snowflake.MachineID != tt.wantID { + t.Errorf("NewSnowflake should use the provided machine ID. got: %v", snowflake.MachineID) + } + }) + } +} diff --git a/broadcastcall.go b/broadcastcall.go index 2de90d516..fd2b38b71 100644 --- a/broadcastcall.go +++ b/broadcastcall.go @@ -2,6 +2,7 @@ package gorums import ( "context" + "slices" "github.com/relab/gorums/ordering" "google.golang.org/protobuf/reflect/protoreflect" @@ -25,18 +26,13 @@ type BroadcastCallData struct { SkipSelf bool } -// checks whether the given address is contained in the given subset +// inSubset returns true if the given address is in the given subset // of server addresses. Will return true if a subset is not given. func (bcd *BroadcastCallData) inSubset(addr string) bool { - if bcd.ServerAddresses == nil || len(bcd.ServerAddresses) <= 0 { + if bcd == nil || len(bcd.ServerAddresses) <= 0 { return true } - for _, srvAddr := range bcd.ServerAddresses { - if addr == srvAddr { - return true - } - } - return false + return slices.Contains(bcd.ServerAddresses, addr) } // BroadcastCall performs a broadcast on the configuration. diff --git a/channel_test.go b/channel_test.go index ae0b19875..f0a400c77 100644 --- a/channel_test.go +++ b/channel_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/relab/gorums/authentication" "github.com/relab/gorums/ordering" "github.com/relab/gorums/tests/mock" "google.golang.org/grpc" @@ -200,10 +201,10 @@ func TestAuthentication(t *testing.T) { if err != nil { t.Fatal(err) } - auth := NewAuth(elliptic.P256()) - _ = auth.GenerateKeys() - privKey, pubKey := auth.Keys() - auth.RegisterKeys(addr, privKey, pubKey) + auth, err := authentication.NewWithAddr(elliptic.P256(), addr) + if err != nil { + t.Fatal(err) + } mgr := NewRawManager(WithAuthentication(auth)) defer mgr.Close() node.mgr = mgr @@ -223,8 +224,8 @@ func TestAuthentication(t *testing.T) { msg := &Message{Metadata: md, Message: &mock.Request{}} msg1 := &Message{Metadata: md, Message: &mock.Request{}} - chEncodedMsg, _ := config.encodeMsg(msg1) - srvEncodedMsg, _ := srv.srv.encodeMsg(msg1) + chEncodedMsg := config.encodeMsg(msg1) + srvEncodedMsg := msg1.Encode() if !bytes.Equal(chEncodedMsg, srvEncodedMsg) { t.Fatalf("wrong encoding. want: %x, got: %x", chEncodedMsg, srvEncodedMsg) } @@ -236,19 +237,15 @@ func TestAuthentication(t *testing.T) { if err != nil { t.Fatal(err) } - valid, err := auth.VerifySignature(pemEncodedPub, chEncodedMsg, signature) + err = auth.VerifySignature(pemEncodedPub, chEncodedMsg, signature) if err != nil { - t.Fatal(err) + // channel encoded msg not valid + t.Fatalf("VerifySignature() = %v, want nil", err) } - if !valid { - t.Fatal("channel encoded msg not valid") - } - valid, err = auth.VerifySignature(pemEncodedPub, srvEncodedMsg, signature) + err = auth.VerifySignature(pemEncodedPub, srvEncodedMsg, signature) if err != nil { - t.Fatal(err) - } - if !valid { - t.Fatal("srv encoded msg not valid") + // srv encoded msg not valid + t.Fatalf("VerifySignature() = %v, want nil", err) } config.sign(msg) diff --git a/clientserver.go b/clientserver.go index 0fa377ef5..f37d844ea 100644 --- a/clientserver.go +++ b/clientserver.go @@ -36,7 +36,7 @@ type ClientServer struct { handlers map[string]requestHandler logger *slog.Logger auth *authentication.EllipticCurve - allowList map[string]string + allowList allowList ordering.UnimplementedGorumsServer } @@ -250,57 +250,18 @@ func (srv *ClientServer) Serve(listener net.Listener) error { return srv.grpcServer.Serve(listener) } -func (srv *ClientServer) encodeMsg(req *Message) ([]byte, error) { - // we must not consider the signature field when validating. - // also the msgType must be set to requestType. - signature := make([]byte, len(req.Metadata.GetAuthMsg().GetSignature())) - copy(signature, req.Metadata.GetAuthMsg().GetSignature()) - reqType := req.msgType - req.Metadata.GetAuthMsg().SetSignature(nil) - req.msgType = 0 - encodedMsg, err := srv.auth.EncodeMsg(*req) - req.Metadata.GetAuthMsg().SetSignature(make([]byte, len(signature))) - // TODO(meling): I think this is incorrect and should be done differently. - copy(req.Metadata.GetAuthMsg().GetSignature(), signature) - req.msgType = reqType - return encodedMsg, err -} - func (srv *ClientServer) verify(req *Message) error { if srv.auth == nil { return nil } - if req.Metadata.GetAuthMsg() == nil { - return fmt.Errorf("missing authMsg") - } - if req.Metadata.GetAuthMsg().GetSignature() == nil { - return fmt.Errorf("missing signature") - } - if req.Metadata.GetAuthMsg().GetPublicKey() == "" { - return fmt.Errorf("missing publicKey") - } - authMsg := req.Metadata.GetAuthMsg() - if srv.allowList != nil { - pemEncodedPub, ok := srv.allowList[authMsg.GetSender()] - if !ok { - return fmt.Errorf("not allowed") - } - if pemEncodedPub != authMsg.GetPublicKey() { - return fmt.Errorf("publicKey did not match") - } - } - encodedMsg, err := srv.encodeMsg(req) + authMsg, err := req.Metadata.GetValidAuthMsg() if err != nil { return err } - valid, err := srv.auth.VerifySignature(authMsg.GetPublicKey(), encodedMsg, authMsg.GetSignature()) - if err != nil { + if err := srv.allowList.Check(authMsg.GetSender(), authMsg.GetPublicKey()); err != nil { return err } - if !valid { - return fmt.Errorf("invalid signature") - } - return nil + return srv.auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) } func createClient(addr string, dialOpts []grpc.DialOption) (*broadcast.Client, error) { diff --git a/config.go b/config.go index ead550cf2..57c59966e 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package gorums import ( "fmt" + "github.com/relab/gorums/authentication" "github.com/relab/gorums/ordering" ) @@ -64,18 +65,16 @@ func (c RawConfiguration) getMsgID() uint64 { } func (c RawConfiguration) sign(msg *Message, signOrigin ...bool) { - if c[0].mgr.opts.auth != nil { + auth := c[0].mgr.opts.auth + if auth != nil { if len(signOrigin) > 0 && signOrigin[0] { - originMsg, err := c[0].mgr.opts.auth.EncodeMsg(msg.Message) - if err != nil { - panic(err) - } - digest := c[0].mgr.opts.auth.Hash(originMsg) - originSignature, err := c[0].mgr.opts.auth.Sign(originMsg) + originMsg := authentication.EncodeMsg(msg.Message) + digest := authentication.Hash(originMsg) + originSignature, err := auth.Sign(originMsg) if err != nil { panic(err) } - pubKey, err := c[0].mgr.opts.auth.EncodePublic() + pubKey, err := auth.EncodePublic() if err != nil { panic(err) } @@ -83,11 +82,8 @@ func (c RawConfiguration) sign(msg *Message, signOrigin ...bool) { msg.Metadata.GetBroadcastMsg().SetOriginPubKey(pubKey) msg.Metadata.GetBroadcastMsg().SetOriginSignature(originSignature) } - encodedMsg, err := c.encodeMsg(msg) - if err != nil { - panic(err) - } - signature, err := c[0].mgr.opts.auth.Sign(encodedMsg) + encodedMsg := c.encodeMsg(msg) + signature, err := auth.Sign(encodedMsg) if err != nil { panic(err) } @@ -95,7 +91,7 @@ func (c RawConfiguration) sign(msg *Message, signOrigin ...bool) { } } -func (c RawConfiguration) encodeMsg(msg *Message) ([]byte, error) { +func (c RawConfiguration) encodeMsg(msg *Message) []byte { // we do not want to include the signature field in the signature auth := c[0].mgr.opts.auth pubKey, err := auth.EncodePublic() @@ -107,5 +103,5 @@ func (c RawConfiguration) encodeMsg(msg *Message) ([]byte, error) { Signature: nil, Sender: auth.Addr(), }.Build()) - return auth.EncodeMsg(*msg) + return authentication.EncodeMsg(*msg) } diff --git a/encoding.go b/encoding.go index ce8fedb31..d69b4c6ae 100644 --- a/encoding.go +++ b/encoding.go @@ -1,8 +1,10 @@ package gorums import ( + "errors" "fmt" + "github.com/relab/gorums/authentication" "github.com/relab/gorums/ordering" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" @@ -35,6 +37,48 @@ func newMessage(msgType gorumsMsgType) *Message { return &Message{Metadata: &ordering.Metadata{}, msgType: msgType} } +// isValid returns an error if the Message is an invalid broadcast message. +func (m *Message) isValid() error { + if m == nil { + return errors.New("message cannot be nil") + } + if m.Metadata == nil { + return errors.New("message metadata cannot be nil") + } + broadcastMsg := m.Metadata.GetBroadcastMsg() + if broadcastMsg == nil { + return errors.New("broadcast message cannot be nil") + } + if broadcastMsg.GetBroadcastID() == 0 { + return errors.New("broadcastID cannot be 0") + } + return nil +} + +// Encode returns an encoded byte representation of the Message +// ignoring the message type and signature in the Auth message. +func (m *Message) Encode() []byte { + authMsg := m.Metadata.GetAuthMsg() + + // save the original signature and msgType. + origSignature := authMsg.GetSignature() + sigCopy := make([]byte, len(origSignature)) + copy(sigCopy, origSignature) + origMsgType := m.msgType + + // prepare for encoding: remove the signature and set msgType to 0 + authMsg.SetSignature(nil) + m.msgType = 0 + + encoded := authentication.EncodeMsg(*m) + + // restore the original values + authMsg.SetSignature(sigCopy) + m.msgType = origMsgType + + return encoded +} + // Codec is the gRPC codec used by gorums. type Codec struct { marshaler proto.MarshalOptions diff --git a/handler.go b/handler.go index 8264e0b9f..1d91efb2d 100644 --- a/handler.go +++ b/handler.go @@ -2,7 +2,6 @@ package gorums import ( "context" - "fmt" "github.com/relab/gorums/broadcast" "github.com/relab/gorums/ordering" @@ -35,9 +34,9 @@ func BroadcastHandler[T protoreflect.ProtoMessage, V Broadcaster](impl implement // guard: // - A broadcastID should be non-empty: // - Maybe the request should be unique? Remove duplicates of the same broadcast? <- Most likely no (up to the implementer) - if err := srv.broadcastSrv.validateMessage(in); err != nil { + if err := in.isValid(); err != nil { if srv.broadcastSrv.logger != nil { - srv.broadcastSrv.logger.Debug("broadcast request not valid", "metadata", in.Metadata, "err", err) + srv.broadcastSrv.logger.Debug("invalid broadcast request", "metadata", in.Metadata, "err", err) } return } @@ -65,14 +64,13 @@ func BroadcastHandler[T protoreflect.ProtoMessage, V Broadcaster](impl implement impl(ctx, req, broadcaster) } - msg := broadcast.Content{} - createRequest(&msg, ctx, in, finished, run) + msg := createRequest(ctx, in.Metadata, finished, run) // we are not interested in the server context as this is tied to the previous hop. // instead we want to check whether the client has cancelled the broadcast request // and if so, we return a cancelled context. This enables the implementer to listen // for cancels and do proper actions. - reqCtx, enqueueBroadcast, err := srv.broadcastSrv.manager.Process(&msg) + reqCtx, enqueueBroadcast, err := srv.broadcastSrv.manager.Process(msg) if err != nil { return } @@ -81,33 +79,31 @@ func BroadcastHandler[T protoreflect.ProtoMessage, V Broadcaster](impl implement } } -func createRequest(msg *broadcast.Content, ctx ServerCtx, in *Message, finished chan<- *Message, run func(context.Context, func(*broadcast.Msg) error)) { - msg.BroadcastID = in.Metadata.GetBroadcastMsg().GetBroadcastID() - msg.IsBroadcastClient = in.Metadata.GetBroadcastMsg().GetIsBroadcastClient() - msg.OriginAddr = in.Metadata.GetBroadcastMsg().GetOriginAddr() - msg.OriginMethod = in.Metadata.GetBroadcastMsg().GetOriginMethod() - msg.SenderAddr = in.Metadata.GetBroadcastMsg().GetSenderAddr() +func createRequest(ctx ServerCtx, md *ordering.Metadata, finished chan<- *Message, run func(context.Context, func(*broadcast.Msg) error)) *broadcast.Content { + broadcastMsg := md.GetBroadcastMsg() + msg := &broadcast.Content{ + BroadcastID: broadcastMsg.GetBroadcastID(), + IsBroadcastClient: broadcastMsg.GetIsBroadcastClient(), + OriginAddr: broadcastMsg.GetOriginAddr(), + OriginMethod: broadcastMsg.GetOriginMethod(), + SenderAddr: broadcastMsg.GetSenderAddr(), + OriginDigest: broadcastMsg.GetOriginDigest(), + OriginSignature: broadcastMsg.GetOriginSignature(), + OriginPubKey: broadcastMsg.GetOriginPubKey(), + CurrentMethod: md.GetMethod(), + Ctx: ctx.Context, + Run: run, + } if msg.SenderAddr == "" && msg.IsBroadcastClient { msg.SenderAddr = "client" } - if in.Metadata.GetBroadcastMsg().GetOriginDigest() != nil { - msg.OriginDigest = in.Metadata.GetBroadcastMsg().GetOriginDigest() - } - if in.Metadata.GetBroadcastMsg().GetOriginSignature() != nil { - msg.OriginSignature = in.Metadata.GetBroadcastMsg().GetOriginSignature() - } - if in.Metadata.GetBroadcastMsg().GetOriginPubKey() != "" { - msg.OriginPubKey = in.Metadata.GetBroadcastMsg().GetOriginPubKey() - } - msg.CurrentMethod = in.Metadata.GetMethod() - msg.Ctx = ctx.Context - msg.Run = run if msg.OriginAddr == "" && msg.IsBroadcastClient { - msg.SendFn = createSendFn(in.Metadata.GetMessageID(), in.Metadata.GetMethod(), finished, ctx) + msg.SendFn = createSendFn(md.GetMessageID(), md.GetMethod(), finished, ctx) } - if in.Metadata.GetMethod() == Cancellation { + if md.GetMethod() == Cancellation { msg.IsCancellation = true } + return msg } func createSendFn(msgID uint64, method string, finished chan<- *Message, ctx ServerCtx) func(resp protoreflect.ProtoMessage, err error) error { @@ -118,22 +114,6 @@ func createSendFn(msgID uint64, method string, finished chan<- *Message, ctx Ser } } -func (srv *broadcastServer) validateMessage(in *Message) error { - if in == nil { - return fmt.Errorf("message cannot be empty. got: %v", in) - } - if in.Metadata == nil { - return fmt.Errorf("metadata cannot be empty. got: %v", in.Metadata) - } - if in.Metadata.GetBroadcastMsg() == nil { - return fmt.Errorf("broadcastMsg cannot be empty. got: %v", in.Metadata.GetBroadcastMsg()) - } - if in.Metadata.GetBroadcastMsg().GetBroadcastID() <= 0 { - return fmt.Errorf("broadcastID cannot be empty. got: %v", in.Metadata.GetBroadcastMsg().GetBroadcastID()) - } - return nil -} - func (srv *Server) RegisterBroadcaster(broadcaster func(m BroadcastMetadata, o *BroadcastOrchestrator, e EnqueueBroadcast) Broadcaster) { srv.broadcastSrv.createBroadcaster = broadcaster srv.broadcastSrv.orchestrator = NewBroadcastOrchestrator(srv) diff --git a/mgr.go b/mgr.go index d4185f5a5..76736d034 100644 --- a/mgr.go +++ b/mgr.go @@ -43,9 +43,7 @@ func NewRawManager(opts ...ManagerOption) *RawManager { } snowflake := broadcast.NewSnowflake(m.opts.machineID) if m.opts.logger != nil { - // a random machineID will be generated if m.opts.machineID is invalid - mID := snowflake.MachineID - m.logger = m.opts.logger.With(slog.Uint64("MachineID", mID)) + m.logger = m.opts.logger.With(logging.MachineID(snowflake.MachineID)) } m.opts.grpcDialOpts = append(m.opts.grpcDialOpts, grpc.WithDefaultCallOptions( grpc.CallContentSubtype(ContentSubtype), diff --git a/opts.go b/opts.go index 27568dd8b..b2c908206 100644 --- a/opts.go +++ b/opts.go @@ -4,7 +4,6 @@ import ( "log/slog" "github.com/relab/gorums/authentication" - "github.com/relab/gorums/broadcast" "google.golang.org/grpc" "google.golang.org/grpc/backoff" "google.golang.org/grpc/metadata" @@ -26,11 +25,9 @@ type managerOptions struct { func newManagerOptions() managerOptions { return managerOptions{ - backoff: backoff.DefaultConfig, - sendBuffer: 100, - // Provide an illegal machineID to avoid unintentional collisions. - // 0 is a valid MachineID and should not be used as default. - machineID: uint64(broadcast.MaxMachineID) + 1, + backoff: backoff.DefaultConfig, + sendBuffer: 100, + machineID: 0, maxSendRetries: 0, maxConnRetries: -1, // no limit } @@ -104,10 +101,11 @@ func WithAuthentication(auth *authentication.EllipticCurve) ManagerOption { } } -// WithMachineID returns a ManagerOption that allows you to set a unique ID for the client. -// This ID will be embedded in broadcast request sent from the client, making the requests -// trackable by the whole cluster. A random ID will be generated if not set. This can cause -// collisions if there are many clients. MinID = 0 and MaxID = 4095. +// WithMachineID returns a ManagerOption that sets a unique ID for the client. +// The valid range for this ID is 1 to 4095, and it should be unique for each client. +// This ID will be embedded in broadcast requests sent from the client, making client +// requests trackable across the server replicas. A random ID will be generated if not set. +// This can cause collisions if there are many clients. func WithMachineID(id uint64) ManagerOption { return func(o *managerOptions) { o.machineID = id diff --git a/ordering/gorums_metadata.go b/ordering/gorums_metadata.go index cde158e94..90aa4a160 100644 --- a/ordering/gorums_metadata.go +++ b/ordering/gorums_metadata.go @@ -2,6 +2,7 @@ package ordering import ( "context" + "errors" "google.golang.org/grpc/metadata" ) @@ -38,3 +39,20 @@ func (x *Metadata) AppendToIncomingContext(ctx context.Context) context.Context } return metadata.NewIncomingContext(ctx, newMD) } + +func (x *Metadata) GetValidAuthMsg() (*AuthMsg, error) { + if x == nil { + return nil, errors.New("metadata cannot be nil") + } + authMsg := x.GetAuthMsg() + if authMsg == nil { + return nil, errors.New("missing AuthMsg") + } + if authMsg.GetSignature() == nil { + return nil, errors.New("missing signature") + } + if authMsg.GetPublicKey() == "" { + return nil, errors.New("missing publicKey") + } + return authMsg, nil +} diff --git a/server.go b/server.go index b8a00e873..aeb8dc915 100644 --- a/server.go +++ b/server.go @@ -10,7 +10,6 @@ import ( "time" "github.com/relab/gorums/authentication" - "github.com/relab/gorums/broadcast" "github.com/relab/gorums/ordering" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -38,58 +37,18 @@ func newOrderingServer(opts *serverOptions) *orderingServer { return s } -func (s *orderingServer) encodeMsg(req *Message) ([]byte, error) { - // we must not consider the signature field when validating. - // also the msgType must be set to requestType. - signature := make([]byte, len(req.Metadata.GetAuthMsg().GetSignature())) - copy(signature, req.Metadata.GetAuthMsg().GetSignature()) - reqType := req.msgType - req.Metadata.GetAuthMsg().SetSignature(nil) - req.msgType = 0 - encodedMsg, err := s.opts.auth.EncodeMsg(*req) - req.Metadata.GetAuthMsg().SetSignature(make([]byte, len(signature))) - // TODO(meling): I think this is incorrect and should be done differently. - copy(req.Metadata.GetAuthMsg().GetSignature(), signature) - req.msgType = reqType - return encodedMsg, err -} - func (s *orderingServer) verify(req *Message) error { if s.opts.auth == nil { return nil } - if req.Metadata.GetAuthMsg() == nil { - return fmt.Errorf("missing authMsg") - } - if req.Metadata.GetAuthMsg().GetSignature() == nil { - return fmt.Errorf("missing signature") - } - if req.Metadata.GetAuthMsg().GetPublicKey() == "" { - return fmt.Errorf("missing publicKey") - } - auth := s.opts.auth - authMsg := req.Metadata.GetAuthMsg() - if s.opts.allowList != nil { - pemEncodedPub, ok := s.opts.allowList[authMsg.GetSender()] - if !ok { - return fmt.Errorf("not allowed") - } - if pemEncodedPub != authMsg.GetPublicKey() { - return fmt.Errorf("publicKey did not match") - } - } - encodedMsg, err := s.encodeMsg(req) + authMsg, err := req.Metadata.GetValidAuthMsg() if err != nil { return err } - valid, err := auth.VerifySignature(authMsg.GetPublicKey(), encodedMsg, authMsg.GetSignature()) - if err != nil { + if err := s.opts.allowList.Check(authMsg.GetSender(), authMsg.GetPublicKey()); err != nil { return err } - if !valid { - return fmt.Errorf("invalid signature") - } - return nil + return s.opts.auth.VerifySignature(authMsg.GetPublicKey(), req.Encode(), authMsg.GetSignature()) } // SendMessage attempts to send a message on a channel. @@ -182,11 +141,32 @@ type serverOptions struct { // address and use forwarding from the host. If not this option is not given, // the listen address used on the gRPC listener will be used instead. listenAddr string - allowList map[string]string + allowList allowList auth *authentication.EllipticCurve grpcDialOpts []grpc.DialOption } +// allowList is a map of (address, public key) pairs. +type allowList map[string]string + +// Check checks if the address and public key are in the allow list. +// If the address is not in the allow list or the public key does not match, +// an error is returned. If the allow list is nil, no check is performed. +func (al allowList) Check(addr string, publicKey string) error { + if al == nil { + // bypass if no allow list specified + return nil + } + pemEncodedPub, ok := al[addr] + if !ok { + return fmt.Errorf("not allowed: %s", addr) + } + if pemEncodedPub != publicKey { + return fmt.Errorf("public key mismatch") + } + return nil +} + // ServerOption is used to change settings for the GorumsServer type ServerOption func(*serverOptions) @@ -292,9 +272,12 @@ func WithSLogger(logger *slog.Logger) ServerOption { } } -// WithSrvID sets the MachineID of the broadcast server. This ID is used to -// generate BroadcastIDs. This method should be used if a replica needs to -// initiate a broadcast request. +// WithSrvID returns a ServerOption that sets the MachineID of the broadcast server. +// This method should be used if a replica needs to initiate a broadcast request. +// The valid range for this ID is 1 to 4095, and it should be unique for each replica. +// This ID will be embedded in broadcast requests sent from the replica, making requests +// trackable across the replicas. A random ID will be generated if not set. +// This can cause collisions if there are many replicas. // // An example use case is in Paxos: // The designated leader sends a prepare and receives some promises it has @@ -307,14 +290,18 @@ func WithSrvID(machineID uint64) ServerOption { } } -// WithAllowList accepts (address, publicKey) pairs which is used to validate +// WithAllowList accepts (address, public-key) pairs which is used to authenticate // messages. Only nodes added to the allow list are allowed to send msgs to // the server. func WithAllowList(curve elliptic.Curve, allowed ...string) ServerOption { return func(o *serverOptions) { - o.allowList = make(map[string]string) + if len(allowed) == 0 { + o.auth = authentication.New(curve) + return + } + o.allowList = make(allowList) if len(allowed)%2 != 0 { - panic("must provide (address, publicKey) pairs to WithAllowList()") + panic("must provide (address, public-key) pairs to WithAllowList()") } for i := range allowed { if i%2 != 0 { @@ -358,9 +345,7 @@ func NewServer(opts ...ServerOption) *Server { shardBuffer: 200, sendBuffer: 30, reqTTL: 5 * time.Minute, - // Provide an illegal machineID to avoid unintentional collisions. - // 0 is a valid MachineID and should not be used as default. - machineID: uint64(broadcast.MaxMachineID) + 1, + machineID: 0, } for _, opt := range opts { opt(&serverOpts) diff --git a/tests/broadcast/broadcast_test.go b/tests/broadcast/broadcast_test.go index 867c79432..744a46605 100644 --- a/tests/broadcast/broadcast_test.go +++ b/tests/broadcast/broadcast_test.go @@ -17,7 +17,7 @@ func createAuthServers(numSrvs int) ([]*testServer, []string, func(), error) { skip := 0 srvs := make([]*testServer, 0, numSrvs) srvAddrs := make([]string, numSrvs) - for i := 0; i < numSrvs; i++ { + for i := range numSrvs { srvAddrs[i] = fmt.Sprintf("127.0.0.1:500%v", i) } for _, addr := range srvAddrs { @@ -54,7 +54,7 @@ func createSrvs(numSrvs int, down ...int) ([]*testServer, []string, func(), erro } srvs := make([]*testServer, 0, numSrvs) srvAddrs := make([]string, numSrvs) - for i := 0; i < numSrvs; i++ { + for i := range numSrvs { srvAddrs[i] = fmt.Sprintf("127.0.0.1:500%v", i) } for i, addr := range srvAddrs { @@ -64,9 +64,9 @@ func createSrvs(numSrvs int, down ...int) ([]*testServer, []string, func(), erro } var srv *testServer if ordering { - srv = newtestServer(addr, srvAddrs, i, true) + srv = newTestServer(addr, srvAddrs, i, true) } else { - srv = newtestServer(addr, srvAddrs, i) + srv = newTestServer(addr, srvAddrs, i) } lis, err := net.Listen("tcp4", srv.addr) if err != nil { @@ -93,7 +93,7 @@ func TestSimpleBroadcastCall(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -123,7 +123,7 @@ func TestSimpleBroadcastTo(t *testing.T) { defer srvCleanup() // only want a response from the leader, hence qsize = 1 - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080", 1) + config, clientCleanup, err := newClient(srvAddrs, 1) if err != nil { t.Error(err) } @@ -152,7 +152,7 @@ func TestSimpleBroadcastCancel(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -181,7 +181,7 @@ func TestBroadcastCancel(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -192,7 +192,7 @@ func TestBroadcastCancel(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) _, _ = config.LongRunningTask(ctx, &Request{Value: val}, true) cancel() - // wait until cancel has reaced the servers before asking for the result + // wait until cancel has reached the servers before asking for the result time.Sleep(1 * time.Second) ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) resp, err := config.GetVal(ctx, &Request{Value: val}) @@ -218,7 +218,7 @@ func TestBroadcastCancelOneSrvDown(t *testing.T) { defer srvCleanup() // only want response from the online servers - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080", numSrvs-1) + config, clientCleanup, err := newClient(srvAddrs, numSrvs-1) if err != nil { t.Error(err) } @@ -250,7 +250,7 @@ func TestBroadcastCancelOneSrvFails(t *testing.T) { defer srvCleanup() // only want response from the online servers - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080", numSrvs-1) + config, clientCleanup, err := newClient(srvAddrs, numSrvs-1) if err != nil { t.Error(err) } @@ -284,7 +284,7 @@ func TestBroadcastCancelOneClientFails(t *testing.T) { defer srvCleanup() // only want response from the online servers - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -299,7 +299,7 @@ func TestBroadcastCancelOneClientFails(t *testing.T) { clientCleanup() // only want response from the online servers - config2, clientCleanup2, err2 := newClient(srvAddrs, "127.0.0.1:8081") + config2, clientCleanup2, err2 := newClient(srvAddrs) defer clientCleanup2() if err2 != nil { t.Error(err2) @@ -326,7 +326,7 @@ func TestBroadcastCallOrderingSendToOneSrv(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[1:2], "127.0.0.1:8080", len(srvAddrs)) + config, clientCleanup, err := newClient(srvAddrs[1:2], len(srvAddrs)) if err != nil { t.Error(err) } @@ -356,7 +356,7 @@ func TestBroadcastCallOrderingSendToAllSrvs(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -386,7 +386,7 @@ func TestBroadcastCallOrderingDoesNotInterfereWithMethodsNotSpecifiedInOrder(t * } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -414,7 +414,7 @@ func TestBroadcastCallRace(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -446,7 +446,7 @@ func TestBroadcastCallClientKnowsOnlyOneServer(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[0:1], "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs[0:1]) if err != nil { t.Error(err) } @@ -480,15 +480,9 @@ func TestBroadcastCallOneServerIsDown(t *testing.T) { } defer srvCleanup() - start := skip - if start < 0 { - start = 0 - } - end := numSrvs - 1 - if end > len(srvAddrs) { - end = len(srvAddrs) - } - config, clientCleanup, err := newClient(srvAddrs[start:end], "127.0.0.1:8080") + start := max(skip, 0) + end := min(numSrvs-1, len(srvAddrs)) + config, clientCleanup, err := newClient(srvAddrs[start:end]) if err != nil { t.Error(err) } @@ -520,13 +514,13 @@ func TestBroadcastCallForward(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[1:2], "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs[1:2]) if err != nil { t.Error(err) } defer clientCleanup() - for i := 0; i < 10; i++ { + for i := range 10 { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() resp, err := config.BroadcastCallForward(ctx, &Request{Value: int64(i)}) @@ -546,13 +540,13 @@ func TestBroadcastCallForwardMultiple(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[1:], "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs[1:]) if err != nil { t.Error(err) } defer clientCleanup() - for i := 0; i < 10; i++ { + for i := range 10 { resp, err := config.BroadcastCallForward(context.Background(), &Request{Value: int64(i)}) if err != nil { t.Error(err) @@ -570,13 +564,13 @@ func TestBroadcastCallRaceTwoClients(t *testing.T) { } defer srvCleanup() - client1, clientCleanup1, err := newClient(srvAddrs, "127.0.0.1:8080") + client1, clientCleanup1, err := newClient(srvAddrs) if err != nil { t.Error(err) } defer clientCleanup1() - client2, clientCleanup2, err := newClient(srvAddrs, "127.0.0.1:8081") + client2, clientCleanup2, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -639,8 +633,8 @@ func TestBroadcastCallAsyncReqs(t *testing.T) { numClients := 10 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { - config, clientCleanup, err := newClient(srvAddrs, fmt.Sprintf("127.0.0.1:808%v", c), 3) + for c := range numClients { + config, clientCleanup, err := newClient(srvAddrs, 3) if err != nil { t.Error(err) } @@ -662,7 +656,7 @@ func TestBroadcastCallAsyncReqs(t *testing.T) { } var wg sync.WaitGroup - for i := 0; i < 1; i++ { + for i := range 1 { for _, client := range clients { go func(j int, c *Configuration) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -689,7 +683,7 @@ func TestQCBroadcastOptionRace(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -703,7 +697,7 @@ func TestQCBroadcastOptionRace(t *testing.T) { if resp.GetResult() != val { t.Fatalf("resp is wrong, got: %v, want: %v", resp.GetResult(), val) } - for i := 0; i < 100; i++ { + for i := range 100 { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) resp, err := config.QuorumCallWithBroadcast(ctx, &Request{Value: int64(i)}) if err != nil { @@ -723,7 +717,7 @@ func TestQCMulticastRace(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { t.Error(err) } @@ -737,7 +731,7 @@ func TestQCMulticastRace(t *testing.T) { if resp.GetResult() != val { t.Fatal("resp is wrong") } - for i := 0; i < 100; i++ { + for i := range 100 { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) resp, err := config.QuorumCallWithMulticast(ctx, &Request{Value: int64(i)}) if err != nil { @@ -757,7 +751,7 @@ func BenchmarkQuorumCall(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { b.Error(err) } @@ -777,7 +771,7 @@ func BenchmarkQuorumCall(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("QC_AllSuccessful_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := config.QuorumCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -798,7 +792,7 @@ func BenchmarkQCMulticast(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { b.Error(err) } @@ -818,7 +812,7 @@ func BenchmarkQCMulticast(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("QCM_AllSuccessful_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) resp, err := config.QuorumCallWithMulticast(ctx, &Request{Value: int64(i)}) if err != nil { @@ -841,7 +835,7 @@ func BenchmarkQCBroadcastOption(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { b.Error(err) } @@ -861,7 +855,7 @@ func BenchmarkQCBroadcastOption(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("QCB_AllSuccessful_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) resp, err := config.QuorumCallWithBroadcast(ctx, &Request{Value: int64(i)}) if err != nil { @@ -886,8 +880,8 @@ func BenchmarkQCBroadcastOptionManyClients(b *testing.B) { numClients := 10 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { - config, clientCleanup, err := newClient(srvAddrs, fmt.Sprintf("127.0.0.1:808%v", c), 3) + for c := range numClients { + config, clientCleanup, err := newClient(srvAddrs, 3) if err != nil { b.Error(err) } @@ -911,7 +905,7 @@ func BenchmarkQCBroadcastOptionManyClients(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("QCB_ManyClients_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { var wg sync.WaitGroup for _, client := range clients { wg.Add(1) @@ -940,7 +934,7 @@ func BenchmarkBroadcastCallAllServers(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newClient(srvAddrs) if err != nil { b.Error(err) } @@ -960,7 +954,7 @@ func BenchmarkBroadcastCallAllServers(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("BC_AllSuccessful_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := config.BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -981,7 +975,7 @@ func BenchmarkBroadcastCallToOneServer(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[0:1], "127.0.0.1:8080", 3) + config, clientCleanup, err := newClient(srvAddrs[0:1], 3) if err != nil { b.Error(err) } @@ -1001,7 +995,7 @@ func BenchmarkBroadcastCallToOneServer(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("BC_OneSrv_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := config.BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -1022,7 +1016,7 @@ func BenchmarkBroadcastCallOneFailedServer(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs, "127.0.0.1:8080", 2) + config, clientCleanup, err := newClient(srvAddrs, 2) if err != nil { b.Error(err) } @@ -1042,7 +1036,7 @@ func BenchmarkBroadcastCallOneFailedServer(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("BC_OneSrvDown_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := config.BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -1063,7 +1057,7 @@ func BenchmarkBroadcastCallOneDownSrvToOneSrv(b *testing.B) { } defer srvCleanup() - config, clientCleanup, err := newClient(srvAddrs[1:2], "127.0.0.1:8080", 2) + config, clientCleanup, err := newClient(srvAddrs[1:2], 2) if err != nil { b.Error(err) } @@ -1083,7 +1077,7 @@ func BenchmarkBroadcastCallOneDownSrvToOneSrv(b *testing.B) { _ = pprof.StartCPUProfile(cpuProfile) b.Run(fmt.Sprintf("BC_OneDownToOne_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := config.BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -1106,8 +1100,8 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { numClients := 10 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { - config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), 3) + for c := range numClients { + config, clientCleanup, err := newClient(srvAddrs[0:1], 3) if err != nil { b.Error(err) } @@ -1127,7 +1121,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { } b.Run(fmt.Sprintf("BC_OneClientOneReq_%d", 0), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { resp, err := clients[0].BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { b.Error(err) @@ -1138,7 +1132,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { } }) b.Run(fmt.Sprintf("BC_OneClientAsync_%d", 1), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { var wg sync.WaitGroup for j := range clients { go func(i, j int, c *Configuration) { @@ -1158,7 +1152,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { } }) b.Run(fmt.Sprintf("BC_OneClientSync_%d", 2), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { for j := range clients { val := i*100 + j resp, err := clients[0].BroadcastCall(context.Background(), &Request{Value: int64(val)}) @@ -1172,7 +1166,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { } }) b.Run(fmt.Sprintf("BC_ManyClientsAsync_%d", 3), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { var wg sync.WaitGroup for _, client := range clients { go func(i int, c *Configuration) { @@ -1191,7 +1185,7 @@ func BenchmarkBroadcastCallManyClients(b *testing.B) { } }) b.Run(fmt.Sprintf("BC_ManyClientsSync_%d", 4), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { for _, client := range clients { resp, err := client.BroadcastCall(context.Background(), &Request{Value: int64(i)}) if err != nil { @@ -1214,8 +1208,8 @@ func BenchmarkBroadcastCallTenClientsCPU(b *testing.B) { numClients := 10 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { - config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), 3) + for c := range numClients { + config, clientCleanup, err := newClient(srvAddrs[0:1], 3) if err != nil { b.Error(err) } @@ -1242,7 +1236,7 @@ func BenchmarkBroadcastCallTenClientsCPU(b *testing.B) { b.ResetTimer() b.Run(fmt.Sprintf("BC_%d", 3), func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := 0; b.Loop(); i++ { var wg sync.WaitGroup for _, client := range clients { go func(i int, c *Configuration) { @@ -1275,8 +1269,8 @@ func BenchmarkBroadcastCallTenClientsMEM(b *testing.B) { numClients := 1 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { - config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), numSrvs) + for c := range numClients { + config, clientCleanup, err := newClient(srvAddrs[0:1], numSrvs) if err != nil { b.Error(err) } @@ -1329,8 +1323,8 @@ func BenchmarkBroadcastCallTenClientsTRACE(b *testing.B) { numClients := 1 clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { - config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), numSrvs) + for c := range numClients { + config, clientCleanup, err := newClient(srvAddrs[0:1], numSrvs) if err != nil { b.Error(err) } @@ -1386,8 +1380,8 @@ func TestBroadcastCallManyRequestsAsync(t *testing.T) { } clients := make([]*Configuration, numClients) - for c := 0; c < numClients; c++ { - config, clientCleanup, err := newClient(srvAddrs[0:1], fmt.Sprintf("127.0.0.1:%v", 8080+c), numSrvs) + for c := range numClients { + config, clientCleanup, err := newClient(srvAddrs[0:1], numSrvs) if err != nil { t.Error(err) } @@ -1416,7 +1410,7 @@ func TestBroadcastCallManyRequestsAsync(t *testing.T) { time.Sleep(500 * time.Millisecond) var wg sync.WaitGroup - for r := 0; r < numReqs; r++ { + for range numReqs { for i, client := range clients { wg.Add(1) go func(i int, client *Configuration) { @@ -1446,7 +1440,7 @@ func TestAuthenticationBroadcastCall(t *testing.T) { } defer srvCleanup() - config, clientCleanup, err := newAuthClient(srvAddrs, "127.0.0.1:8080") + config, clientCleanup, err := newAuthClient(srvAddrs) if err != nil { t.Error(err) } diff --git a/tests/broadcast/client.go b/tests/broadcast/client.go index f447f6c14..fbd87570a 100644 --- a/tests/broadcast/client.go +++ b/tests/broadcast/client.go @@ -3,10 +3,11 @@ package broadcast import ( "crypto/elliptic" "log/slog" - net "net" + "net" - gorums "github.com/relab/gorums" - grpc "google.golang.org/grpc" + "github.com/relab/gorums" + "github.com/relab/gorums/authentication" + "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -127,7 +128,7 @@ func (qs *testQSpec) OrderQF(in *Request, replies []*Response) (*Response, bool) return nil, false } -func newClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configuration, func(), error) { +func newClient(srvAddrs []string, qsize ...int) (*Configuration, func(), error) { quorumSize := len(srvAddrs) if len(qsize) > 0 { quorumSize = qsize[0] @@ -137,15 +138,13 @@ func newClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configurati grpc.WithTransportCredentials(insecure.NewCredentials()), ), ) - if listenAddr != "" { - lis, err := net.Listen("tcp", "127.0.0.1:") - if err != nil { - return nil, nil, err - } - err = mgr.AddClientServer(lis, lis.Addr()) - if err != nil { - return nil, nil, err - } + lis, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + return nil, nil, err + } + err = mgr.AddClientServer(lis, lis.Addr()) + if err != nil { + return nil, nil, err } config, err := mgr.NewConfiguration( gorums.WithNodeList(srvAddrs), @@ -154,12 +153,10 @@ func newClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configurati if err != nil { return nil, nil, err } - return config, func() { - mgr.Close() - }, nil + return config, mgr.Close, nil } -func newAuthClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configuration, func(), error) { +func newAuthClient(srvAddrs []string, qsize ...int) (*Configuration, func(), error) { quorumSize := len(srvAddrs) if len(qsize) > 0 { quorumSize = qsize[0] @@ -168,27 +165,23 @@ func newAuthClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configu lis net.Listener err error ) - if listenAddr != "" { - lis, err = net.Listen("tcp", "127.0.0.1:") - if err != nil { - return nil, nil, err - } + lis, err = net.Listen("tcp", "127.0.0.1:") + if err != nil { + return nil, nil, err + } + auth, err := authentication.NewWithAddr(elliptic.P256(), lis.Addr()) + if err != nil { + return nil, nil, err } - auth := gorums.NewAuth(elliptic.P256()) - _ = auth.GenerateKeys() - privKey, pubKey := auth.Keys() - auth.RegisterKeys(lis.Addr(), privKey, pubKey) mgr := NewManager( gorums.WithAuthentication(auth), gorums.WithGrpcDialOptions( grpc.WithTransportCredentials(insecure.NewCredentials()), ), ) - if listenAddr != "" { - err = mgr.AddClientServer(lis, lis.Addr()) - if err != nil { - return nil, nil, err - } + err = mgr.AddClientServer(lis, lis.Addr()) + if err != nil { + return nil, nil, err } config, err := mgr.NewConfiguration( gorums.WithNodeList(srvAddrs), @@ -197,7 +190,5 @@ func newAuthClient(srvAddrs []string, listenAddr string, qsize ...int) (*Configu if err != nil { return nil, nil, err } - return config, func() { - mgr.Close() - }, nil + return config, mgr.Close, nil } diff --git a/tests/broadcast/server.go b/tests/broadcast/server.go index 83b6b409e..b484dc877 100644 --- a/tests/broadcast/server.go +++ b/tests/broadcast/server.go @@ -5,10 +5,12 @@ import ( "crypto/elliptic" "errors" net "net" + "slices" "sync" "time" gorums "github.com/relab/gorums" + "github.com/relab/gorums/authentication" grpc "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -34,7 +36,7 @@ type testServer struct { order []string } -func newtestServer(addr string, srvAddresses []string, _ int, withOrder ...bool) *testServer { +func newTestServer(addr string, srvAddresses []string, _ int, withOrder ...bool) *testServer { address, err := net.ResolveTCPAddr("tcp", addr) if err != nil { panic(err) @@ -87,10 +89,10 @@ func newAuthenticatedServer(addr string, srvAddresses []string) *testServer { if addr != leader { srv.processingTime = 100 * time.Millisecond } - auth := gorums.NewAuth(elliptic.P256()) - _ = auth.GenerateKeys() - privKey, pubKey := auth.Keys() - auth.RegisterKeys(address, privKey, pubKey) + auth, err := authentication.NewWithAddr(elliptic.P256(), address) + if err != nil { + panic(err) + } srv.mgr = NewManager( gorums.WithAuthentication(auth), gorums.WithGrpcDialOptions( @@ -245,13 +247,7 @@ func (srv *testServer) PrePrepare(ctx gorums.ServerCtx, req *Request, broadcast time.Sleep(200 * time.Millisecond) } srv.mut.Lock() - added := false - for _, m := range srv.order { - if m == "PrePrepare" { - added = true - break - } - } + added := slices.Contains(srv.order, "PrePrepare") if !added { srv.order = append(srv.order, "PrePrepare") } @@ -269,13 +265,7 @@ func (srv *testServer) Prepare(ctx gorums.ServerCtx, req *Request, broadcast *Br srv.mut.Unlock() return } - added := false - for _, m := range srv.order { - if m == "Prepare" { - added = true - break - } - } + added := slices.Contains(srv.order, "Prepare") if !added { srv.order = append(srv.order, "Prepare") }