diff --git a/testutils/tokens/id.go b/testutils/tokens/id.go index 83ddf69..339412c 100644 --- a/testutils/tokens/id.go +++ b/testutils/tokens/id.go @@ -25,8 +25,9 @@ generator functions in this package. Prefer to create test specific PDR and use it's ComposeUnitID method! */ -func PDR() types.PartitionDescriptionRecord { - return testPDR +func PDR() *types.PartitionDescriptionRecord { + copy := testPDR + return © } /* diff --git a/txsystem/money/unit_data_types.go b/txsystem/money/unit_data_types.go index da22f33..b6f0034 100644 --- a/txsystem/money/unit_data_types.go +++ b/txsystem/money/unit_data_types.go @@ -20,8 +20,8 @@ type BillData struct { Counter uint64 `json:"counter,string"` // The transaction counter of this bill } -func NewUnitData(unitID types.UnitID, pdr *types.PartitionDescriptionRecord) (types.UnitData, error) { - typeID, err := pdr.ExtractUnitType(unitID) +func NewUnitData(unitID types.UnitID, unitTypeExtractor types.UnitTypeExtractor) (types.UnitData, error) { + typeID, err := unitTypeExtractor(unitID) if err != nil { return nil, fmt.Errorf("extracting unit type: %w", err) } diff --git a/txsystem/orchestration/unit_types.go b/txsystem/orchestration/unit_types.go index ddfad86..ebbb8fa 100644 --- a/txsystem/orchestration/unit_types.go +++ b/txsystem/orchestration/unit_types.go @@ -10,8 +10,8 @@ const ( VarUnitType = 1 ) -func NewUnitData(unitID types.UnitID, pdr *types.PartitionDescriptionRecord) (types.UnitData, error) { - typeID, err := pdr.ExtractUnitType(unitID) +func NewUnitData(unitID types.UnitID, unitTypeExtractor types.UnitTypeExtractor) (types.UnitData, error) { + typeID, err := unitTypeExtractor(unitID) if err != nil { return nil, fmt.Errorf("extracting type ID: %w", err) } diff --git a/txsystem/tokens/unit_types.go b/txsystem/tokens/unit_types.go index c720651..f41dfa1 100644 --- a/txsystem/tokens/unit_types.go +++ b/txsystem/tokens/unit_types.go @@ -51,8 +51,8 @@ func GenerateUnitID(txo *types.TransactionOrder, shardConf *types.PartitionDescr return nil } -func NewUnitData(unitID types.UnitID, pdr *types.PartitionDescriptionRecord) (types.UnitData, error) { - typeID, err := pdr.ExtractUnitType(unitID) +func NewUnitData(unitID types.UnitID, unitTypeExtractor types.UnitTypeExtractor) (types.UnitData, error) { + typeID, err := unitTypeExtractor(unitID) if err != nil { return nil, fmt.Errorf("extracting type ID: %w", err) } diff --git a/types/identifiers.go b/types/identifiers.go index 0281513..ec42668 100644 --- a/types/identifiers.go +++ b/types/identifiers.go @@ -31,6 +31,9 @@ type ( // UnitID is the extended identifier, combining the type and the unit identifiers. UnitID []byte + + // UnitTypeExtractor is a function that extracts unit type from UnitID + UnitTypeExtractor func(UnitID) (uint32, error) ) func (uid UnitID) Compare(key UnitID) int { @@ -45,8 +48,8 @@ func (uid UnitID) Eq(id UnitID) bool { return bytes.Equal(uid, id) } -func (uid UnitID) TypeMustBe(typeID uint32, pdr *PartitionDescriptionRecord) error { - tid, err := pdr.ExtractUnitType(uid) +func (uid UnitID) TypeMustBe(typeID uint32, unitTypeExtractor UnitTypeExtractor) error { + tid, err := unitTypeExtractor(uid) if err != nil { return fmt.Errorf("extracting unit type from unit ID: %w", err) } diff --git a/types/partition_description.go b/types/partition_description.go index 6f80aef..1ef3fa1 100644 --- a/types/partition_description.go +++ b/types/partition_description.go @@ -115,6 +115,9 @@ func (pdr *PartitionDescriptionRecord) Verify(prev *PartitionDescriptionRecord) if pdr.NetworkID != prev.NetworkID { return fmt.Errorf("invalid network id, provided %d previous %d", pdr.NetworkID, prev.NetworkID) } + if pdr.PartitionTypeID != prev.PartitionTypeID { + return fmt.Errorf("invalid partition type id, provided %d previous %d", pdr.PartitionTypeID, prev.PartitionTypeID) + } if pdr.PartitionID != prev.PartitionID { return fmt.Errorf("invalid partition id, provided %d previous %d", pdr.PartitionID, prev.PartitionID) } @@ -137,14 +140,33 @@ func (pdr *PartitionDescriptionRecord) Hash(hashAlgorithm crypto.Hash) ([]byte, return hasher.Sum() } +func (pdr *PartitionDescriptionRecord) GetVersion() Version { + if pdr == nil || pdr.Version == 0 { + return 1 + } + return pdr.Version +} + func (pdr *PartitionDescriptionRecord) GetNetworkID() NetworkID { return pdr.NetworkID } +func (pdr *PartitionDescriptionRecord) GetPartitionTypeID() PartitionTypeID { + return pdr.PartitionTypeID +} + func (pdr *PartitionDescriptionRecord) GetPartitionID() PartitionID { return pdr.PartitionID } +func (pdr *PartitionDescriptionRecord) GetShardID() ShardID { + return pdr.ShardID +} + +func (pdr *PartitionDescriptionRecord) GetPartitionParams() map[string]string { + return pdr.PartitionParams +} + /* UnitIDValidator returns function which checks that unit ID passed as argument has correct length and that the unit belongs into the given shard. @@ -207,20 +229,13 @@ func (pdr *PartitionDescriptionRecord) ExtractUnitType(id UnitID) (uint32, error return 0, fmt.Errorf("expected unit ID length %d bytes, got %d bytes", idLen, len(id)) } - // we relay on the fact that valid PDR has "pdr.UnitIDLen >= 64" ie it's safe to read four bytes + // we rely on the fact that valid PDR has "pdr.UnitIDLen >= 64" ie it's safe to read four bytes idx := len(id) - 1 v := uint32(id[idx]) | (uint32(id[idx-1]) << 8) | (uint32(id[idx-2]) << 16) | (uint32(id[idx-3]) << 24) mask := uint32(0xFFFFFFFF) >> (32 - pdr.TypeIDLen) return v & mask, nil } -func (pdr *PartitionDescriptionRecord) GetVersion() Version { - if pdr == nil || pdr.Version == 0 { - return 1 - } - return pdr.Version -} - func (pdr *PartitionDescriptionRecord) MarshalCBOR() ([]byte, error) { type alias PartitionDescriptionRecord if pdr.Version == 0 { @@ -236,3 +251,12 @@ func (pdr *PartitionDescriptionRecord) UnmarshalCBOR(data []byte) error { } return EnsureVersion(pdr, pdr.Version, 1) } + +func (pdr *PartitionDescriptionRecord) FindValidator(nodeID string) *NodeInfo { + for _, validator := range pdr.Validators { + if validator.NodeID == nodeID { + return validator + } + } + return nil +} diff --git a/types/root_trust_base.go b/types/root_trust_base.go index a308940..3db8c1c 100644 --- a/types/root_trust_base.go +++ b/types/root_trust_base.go @@ -1,6 +1,7 @@ package types import ( + "bytes" "cmp" "crypto" "errors" @@ -15,7 +16,10 @@ import ( type ( RootTrustBase interface { + GetVersion() Version GetNetworkID() NetworkID + GetEpoch() uint64 + GetEpochStart() uint64 VerifyQuorumSignatures(data []byte, signatures map[string]hex.Bytes) error VerifySignature(data []byte, sig []byte, nodeID string) (uint64, error) GetQuorumThreshold() uint64 @@ -28,7 +32,7 @@ type ( Version Version `json:"version"` NetworkID NetworkID `json:"networkId"` Epoch uint64 `json:"epoch"` // current epoch number - EpochStartRound uint64 `json:"epochStartRound"` // root chain round number when the epoch begins + EpochStart uint64 `json:"epochStartRound"` // root chain round number when the epoch begins RootNodes []*NodeInfo `json:"rootNodes"` // list of all root nodes for the current epoch QuorumThreshold uint64 `json:"quorumThreshold"` // amount of coins required to reach consensus, currently each node gets equal amount of voting power i.e. +1 for each node StateHash hex.Bytes `json:"stateHash"` // unicity tree root hash @@ -51,18 +55,21 @@ type ( Option func(c *trustBaseConf) trustBaseConf struct { - quorumThreshold uint64 + epoch uint64 + epochStart uint64 + quorumThreshold uint64 + previousTrustBaseHash hex.Bytes } ) -// NewTrustBaseGenesis creates new unsigned root trust base with default parameters. -func NewTrustBaseGenesis(networkID NetworkID, rootNodes []*NodeInfo, opts ...Option) (*RootTrustBaseV1, error) { +// NewTrustBase creates new unsigned root trust base. +func NewTrustBase(networkID NetworkID, rootNodes []*NodeInfo, opts ...Option) (*RootTrustBaseV1, error) { if len(rootNodes) == 0 { return nil, errors.New("nodes list is empty") } // init config - c := &trustBaseConf{} + c := &trustBaseConf{epoch: 1} for _, opt := range opts { opt(c) } @@ -93,13 +100,13 @@ func NewTrustBaseGenesis(networkID NetworkID, rootNodes []*NodeInfo, opts ...Opt return &RootTrustBaseV1{ Version: 1, NetworkID: networkID, - Epoch: 1, - EpochStartRound: 1, + Epoch: c.epoch, + EpochStart: c.epochStart, RootNodes: rootNodes, QuorumThreshold: c.quorumThreshold, StateHash: nil, ChangeRecordHash: nil, - PreviousEntryHash: nil, + PreviousEntryHash: c.previousTrustBaseHash, Signatures: make(map[string]hex.Bytes), }, nil } @@ -111,6 +118,24 @@ func WithQuorumThreshold(threshold uint64) Option { } } +func WithEpoch(epoch uint64) Option { + return func(c *trustBaseConf) { + c.epoch = epoch + } +} + +func WithEpochStart(epochStart uint64) Option { + return func(c *trustBaseConf) { + c.epochStart = epochStart + } +} + +func WithPreviousTrustBaseHash(previousTrustBaseHash hex.Bytes) Option { + return func(c *trustBaseConf) { + c.previousTrustBaseHash = previousTrustBaseHash + } +} + // IsValid validates that all fields are correctly set and public keys are correct. func (n *NodeInfo) IsValid() error { if n == nil { @@ -171,7 +196,7 @@ func (r *RootTrustBaseV1) Hash(hashAlgo crypto.Hash) ([]byte, error) { return hasher.Sum() } -// SigBytes serializes all fields expect for the signatures field. +// SigBytes serializes all fields expect for the Signatures field itself. func (r RootTrustBaseV1) SigBytes() ([]byte, error) { r.Signatures = nil bs, err := r.MarshalCBOR() @@ -239,6 +264,14 @@ func (r *RootTrustBaseV1) GetNetworkID() NetworkID { return r.NetworkID } +func (r *RootTrustBaseV1) GetEpoch() uint64 { + return r.Epoch +} + +func (r *RootTrustBaseV1) GetEpochStart() uint64 { + return r.EpochStart +} + func (r *RootTrustBaseV1) MarshalCBOR() ([]byte, error) { type alias RootTrustBaseV1 if r.Version == 0 { @@ -264,3 +297,75 @@ func (r *RootTrustBaseV1) getRootNode(nodeID string) *NodeInfo { } return nil } + +// Verify verifies the trust base, including the signatures. +// +// Common for all trust bases: +// - The current epoch signatures must be valid and reach quorum. +// +// Genesis trust base: +// - Epoch must be equal to 1. +// +// Non-genesis trust base must extend previous trust base: +// - The network identifiers must match. +// - The epoch number must be strictly greater than the previous epoch number. +// - The epoch start round must be strictly greater than the previous epoch start round. +// - The hash of the previous trust must match the previousEntryHash. +// - The previous epoch signatures must be valid and reach quorum. +func (r *RootTrustBaseV1) Verify(prev *RootTrustBaseV1) error { + if err := r.IsValid(prev); err != nil { + return err + } + return r.VerifySignatures(prev) +} + +// IsValid verifies the trust base without verifying the signatures. +// Use VerifySignatures to verify the signatures. +func (r *RootTrustBaseV1) IsValid(prev *RootTrustBaseV1) error { + if prev == nil { + if r.Epoch != 1 { + return fmt.Errorf("genesis trust base epoch must be 1, got %d", r.Epoch) + } + return nil + } + if r.NetworkID != prev.NetworkID { + return fmt.Errorf("invalid network id, got %d previous %d", r.NetworkID, prev.NetworkID) + } + if r.Epoch != prev.Epoch+1 { + return fmt.Errorf("invalid epoch, got %d previous %d", r.Epoch, prev.Epoch) + } + if r.EpochStart <= prev.EpochStart { + return fmt.Errorf("invalid epoch start, got %d previous %d", r.EpochStart, prev.EpochStart) + } + prevHash, err := prev.Hash(crypto.SHA256) + if err != nil { + return fmt.Errorf("failed to calculate previous trust base hash: %w", err) + } + if !bytes.Equal(r.PreviousEntryHash, prevHash) { + return errors.New("previous trust base hash does not match") + } + return nil +} + +// VerifySignatures verifies that the trust base is signed by the previous +// epoch's validators. For the genesis trust base (epoch 1), the trust base +// must be self-signed by the genesis (epoch 1) validators. +func (r *RootTrustBaseV1) VerifySignatures(prev *RootTrustBaseV1) error { + sigBytes, err := r.SigBytes() + if err != nil { + return fmt.Errorf("failed to get previous epoch sig bytes: %w", err) + } + var tb *RootTrustBaseV1 + if r.Epoch == 1 { + tb = r + } else { + if prev == nil { + return errors.New("previous trust base is nil") + } + tb = prev + } + if err := tb.VerifyQuorumSignatures(sigBytes, r.Signatures); err != nil { + return fmt.Errorf("failed to verify signatures: %w", err) + } + return nil +} diff --git a/types/root_trust_base_test.go b/types/root_trust_base_test.go index 1f3343e..20dd4d5 100644 --- a/types/root_trust_base_test.go +++ b/types/root_trust_base_test.go @@ -1,13 +1,15 @@ package types import ( + "crypto" "fmt" "strconv" "testing" + "github.com/stretchr/testify/require" + abcrypto "github.com/unicitynetwork/bft-go-base/crypto" "github.com/unicitynetwork/bft-go-base/types/hex" - "github.com/stretchr/testify/require" ) func TestNodeInfo_IsValid(t *testing.T) { @@ -64,16 +66,16 @@ func TestNewTrustBaseGenesis(t *testing.T) { name: "default settings ok", args: args{ nodes: []*NodeInfo{ - &NodeInfo{NodeID: "1", SigKey: keys["1"].publicKey, Stake: 1}, - &NodeInfo{NodeID: "2", SigKey: keys["2"].publicKey, Stake: 1}, - &NodeInfo{NodeID: "3", SigKey: keys["3"].publicKey, Stake: 1}, + {NodeID: "1", SigKey: keys["1"].publicKey, Stake: 1}, + {NodeID: "2", SigKey: keys["2"].publicKey, Stake: 1}, + {NodeID: "3", SigKey: keys["3"].publicKey, Stake: 1}, }, unicityTreeRootHash: []byte{1}, }, verifyFunc: func(t *testing.T, tb *RootTrustBaseV1) { // verify values require.EqualValues(t, 1, tb.Epoch) - require.EqualValues(t, 1, tb.EpochStartRound) + require.EqualValues(t, 0, tb.EpochStart) require.Len(t, tb.RootNodes, 3) require.EqualValues(t, 3, tb.QuorumThreshold) require.EqualValues(t, hex.Bytes(nil), tb.StateHash) @@ -186,7 +188,7 @@ func TestNewTrustBaseGenesis(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tb, err := NewTrustBaseGenesis(NetworkMainNet, tt.args.nodes, tt.args.opts...) + tb, err := NewTrustBase(NetworkMainNet, tt.args.nodes, tt.args.opts...) if tt.wantErrStr != "" { require.ErrorContains(t, err, tt.wantErrStr) require.Nil(t, tb) @@ -203,9 +205,9 @@ func TestNewTrustBaseGenesis(t *testing.T) { func TestSignAndVerify(t *testing.T) { keys := genKeys(1) - tb, err := NewTrustBaseGenesis( + tb, err := NewTrustBase( NetworkMainNet, - []*NodeInfo{&NodeInfo{NodeID: "1", SigKey: keys["1"].publicKey, Stake: 1}}, + []*NodeInfo{{NodeID: "1", SigKey: keys["1"].publicKey, Stake: 1}}, ) require.NoError(t, err) @@ -222,12 +224,12 @@ func TestSignAndVerify(t *testing.T) { func Test_RootTrustBaseV1_CBOR(t *testing.T) { keys := genKeys(3) - tb, err := NewTrustBaseGenesis( + tb, err := NewTrustBase( NetworkMainNet, []*NodeInfo{ - &NodeInfo{NodeID: "1", SigKey: keys["1"].publicKey, Stake: 1}, - &NodeInfo{NodeID: "2", SigKey: keys["2"].publicKey, Stake: 1}, - &NodeInfo{NodeID: "3", SigKey: keys["3"].publicKey, Stake: 1}, + {NodeID: "1", SigKey: keys["1"].publicKey, Stake: 1}, + {NodeID: "2", SigKey: keys["2"].publicKey, Stake: 1}, + {NodeID: "3", SigKey: keys["3"].publicKey, Stake: 1}, }, ) require.NoError(t, err) @@ -287,14 +289,14 @@ func genKeys(count int) map[string]key { return keys } -func NewTrustBase(t *testing.T, verifiers ...abcrypto.Verifier) RootTrustBase { +func NewTrustBaseT(t *testing.T, verifiers ...abcrypto.Verifier) RootTrustBase { var nodes []*NodeInfo for _, v := range verifiers { sigKey, err := v.MarshalPublicKey() require.NoError(t, err) nodes = append(nodes, &NodeInfo{NodeID: "test", SigKey: sigKey, Stake: 1}) } - tb, err := NewTrustBaseGenesis(NetworkMainNet, nodes) + tb, err := NewTrustBase(NetworkMainNet, nodes) require.NoError(t, err) return tb } @@ -306,7 +308,203 @@ func NewTrustBaseFromVerifiers(t *testing.T, verifiers map[string]abcrypto.Verif require.NoError(t, err) nodes = append(nodes, &NodeInfo{NodeID: nodeID, SigKey: sigKey, Stake: 1}) } - tb, err := NewTrustBaseGenesis(NetworkMainNet, nodes) + tb, err := NewTrustBase(NetworkMainNet, nodes) require.NoError(t, err) return tb } + +func TestRootTrustBaseV1_Verify(t *testing.T) { + // epoch 1 = nodes 1-3 + keys := genKeys(3) + nodes := make([]*NodeInfo, 0, len(keys)) + for i := 1; i <= len(keys); i++ { + nodeID := strconv.Itoa(i) + nodes = append(nodes, &NodeInfo{NodeID: nodeID, SigKey: keys[nodeID].publicKey, Stake: 1}) + } + + // create trust base for epoch 1 + tb1Signed, err := NewTrustBase(NetworkLocal, nodes, + WithEpoch(1), + WithEpochStart(5), + ) + require.NoError(t, err) + require.EqualValues(t, 3, tb1Signed.QuorumThreshold) + + // sign trust base for epoch 1 + for i := 1; i <= 3; i++ { + nodeID := strconv.Itoa(i) + require.NoError(t, tb1Signed.Sign(nodeID, keys[nodeID].signer)) + } + + // calculate signed trust base hash + tb0Hash, err := tb1Signed.Hash(crypto.SHA256) + require.NoError(t, err) + + // create a valid trust base for epoch 2, signed by previous validators + keys1 := genKeys(3) + nodes1 := make([]*NodeInfo, 0, len(keys1)) + for i := 1 + 10; i <= len(keys)+10; i++ { + nodeID := strconv.Itoa(i) + nodes1 = append(nodes1, &NodeInfo{NodeID: nodeID, SigKey: keys[nodeID].publicKey, Stake: 1}) + } + tb2Signed, err := NewTrustBase(NetworkLocal, nodes1, + WithEpoch(2), + WithEpochStart(50), + WithPreviousTrustBaseHash(tb0Hash), + ) + require.NoError(t, err) + require.EqualValues(t, 3, tb1Signed.QuorumThreshold) + + // sign tb2 with previous epoch keys + for i := 1; i <= 3; i++ { + nodeID := strconv.Itoa(i) + require.NoError(t, tb2Signed.Sign(nodeID, keys[nodeID].signer)) + } + + tests := []struct { + name string + prev *RootTrustBaseV1 + curr *RootTrustBaseV1 + wantErr string + }{ + { + name: "genesis trust base with epoch zero", + prev: nil, + curr: tb1Signed, + }, + { + name: "genesis trust base without signatures", + prev: nil, + curr: func() *RootTrustBaseV1 { + // create unsigned trust base for epoch 1 + tb1Unsigned, err := NewTrustBase(NetworkLocal, nodes, + WithEpoch(1), + WithEpochStart(5), + ) + require.NoError(t, err) + return tb1Unsigned + }(), + wantErr: "failed to verify signatures: quorum not reached, signed_votes=0 quorum_threshold=3", + }, + { + name: "genesis trust base with zero epoch", + prev: nil, + curr: func() *RootTrustBaseV1 { + g := *tb1Signed + g.Epoch = 0 + return &g + }(), + wantErr: "genesis trust base epoch must be 1, got 0", + }, + { + name: "extend", + prev: tb1Signed, + curr: tb2Signed, + }, + { + name: "extend with different network id", + prev: tb1Signed, + curr: func() *RootTrustBaseV1 { + b := *tb2Signed + b.NetworkID = b.NetworkID + 1 + return &b + }(), + wantErr: "invalid network id, got 4 previous 3", + }, + { + name: "extend with same epoch", + prev: tb1Signed, + curr: func() *RootTrustBaseV1 { + b := *tb2Signed + b.Epoch = 1 + return &b + }(), + wantErr: "invalid epoch, got 1 previous 1", + }, + { + name: "extend with epoch not incremented by 1", + prev: tb1Signed, + curr: func() *RootTrustBaseV1 { + b := *tb2Signed + b.Epoch = 3 + return &b + }(), + wantErr: "invalid epoch, got 3 previous 1", + }, + { + name: "extend with same epoch start", + prev: tb1Signed, + curr: func() *RootTrustBaseV1 { + b := *tb2Signed + b.EpochStart = 5 + return &b + }(), + wantErr: "invalid epoch start, got 5 previous 5", + }, + { + name: "extend with smaller epoch start", + prev: tb1Signed, + curr: func() *RootTrustBaseV1 { + b := *tb2Signed + b.EpochStart = 4 + return &b + }(), + wantErr: "invalid epoch start, got 4 previous 5", + }, + { + name: "extend with invalid previous hash", + prev: tb1Signed, + curr: func() *RootTrustBaseV1 { + b := *tb2Signed + b.PreviousEntryHash = []byte{1, 2, 3} + return &b + }(), + wantErr: "previous trust base hash does not match", + }, + { + name: "extend without previous epoch signatures", + prev: tb1Signed, + curr: func() *RootTrustBaseV1 { + tb, err := NewTrustBase(NetworkLocal, nodes, + WithEpoch(2), + WithEpochStart(50), + WithPreviousTrustBaseHash(tb0Hash), + ) + require.NoError(t, err) + return tb + }(), + wantErr: "failed to verify signatures: quorum not reached, signed_votes=0 quorum_threshold=3", + }, + { + name: "extend with not enough previous epoch signatures", + prev: tb1Signed, + curr: func() *RootTrustBaseV1 { + tb, err := NewTrustBase(NetworkLocal, nodes, + WithEpoch(2), + WithEpochStart(50), + WithPreviousTrustBaseHash(tb0Hash), + ) + require.NoError(t, err) + + // sign with 2 of 3 required previous epoch keys + for i := 1; i <= 2; i++ { + nodeID := strconv.Itoa(i) + require.NoError(t, tb.Sign(nodeID, keys[nodeID].signer)) + } + return tb + }(), + wantErr: "failed to verify signatures: quorum not reached, signed_votes=2 quorum_threshold=3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.curr.Verify(tt.prev) + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/types/tx_proof.go b/types/tx_proof.go index 96b125a..fc7ea7a 100644 --- a/types/tx_proof.go +++ b/types/tx_proof.go @@ -107,7 +107,8 @@ func VerifyTxInclusion(txRecordProof *TxRecordProof, tb RootTrustBase, hashAlgor return fmt.Errorf("failed to get transaction order: %w", err) } - if err := uc.Verify(tb, hashAlgorithm, txo.PartitionID, nil); err != nil { + // TODO: actual shardID extracted + if err := uc.Verify(tb, hashAlgorithm, txo.PartitionID, ShardID{}, nil); err != nil { return fmt.Errorf("invalid unicity certificate: %w", err) } // h ← plain_tree_output(C, H(P)) diff --git a/types/tx_proof_test.go b/types/tx_proof_test.go index 5ed5edc..6857068 100644 --- a/types/tx_proof_test.go +++ b/types/tx_proof_test.go @@ -5,9 +5,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + abcrypto "github.com/unicitynetwork/bft-go-base/crypto" testsig "github.com/unicitynetwork/bft-go-base/testutils/sig" - "github.com/stretchr/testify/require" ) func TestNewTxProof(t *testing.T) { @@ -45,21 +46,21 @@ func TestVerifyInc(t *testing.T) { block := createBlock(t, "test", signer, createTx(t)) proof, err := NewTxRecordProof(block, 0, crypto.SHA256) require.NoError(t, err) - tb := NewTrustBase(t, verifier) + tb := NewTrustBaseT(t, verifier) require.NoError(t, VerifyTxInclusion(proof, tb, crypto.SHA256)) }) t.Run("Test tx record proof is nil", func(t *testing.T) { _, verifier := testsig.CreateSignerAndVerifier(t) - tb := NewTrustBase(t, verifier) + tb := NewTrustBaseT(t, verifier) require.EqualError(t, VerifyTxInclusion(nil, tb, crypto.SHA256), "transaction record proof is nil") }) t.Run("Test tx record is nil", func(t *testing.T) { _, verifier := testsig.CreateSignerAndVerifier(t) - tb := NewTrustBase(t, verifier) + tb := NewTrustBaseT(t, verifier) proof := &TxRecordProof{TxProof: &TxProof{Version: 1}} require.EqualError(t, VerifyTxInclusion(proof, tb, crypto.SHA256), "transaction record is nil") @@ -67,7 +68,7 @@ func TestVerifyInc(t *testing.T) { t.Run("Test tx order is nil", func(t *testing.T) { _, verifier := testsig.CreateSignerAndVerifier(t) - tb := NewTrustBase(t, verifier) + tb := NewTrustBaseT(t, verifier) txr := &TransactionRecord{Version: 1, ServerMetadata: &ServerMetadata{SuccessIndicator: TxStatusSuccessful}} proof := &TxRecordProof{TxRecord: txr, TxProof: &TxProof{Version: 1}} @@ -79,7 +80,7 @@ func TestVerifyInc(t *testing.T) { block := createBlock(t, "test", signer, createTx(t)) proof, err := NewTxRecordProof(block, 0, crypto.SHA256) require.NoError(t, err) - tb := NewTrustBase(t, verifier) + tb := NewTrustBaseT(t, verifier) uc, err := proof.TxProof.GetUC() require.NoError(t, err) uc.UnicityTreeCertificate.Partition = 1 @@ -94,7 +95,7 @@ func TestVerifyInc(t *testing.T) { block := createBlock(t, "test", signer, createTx(t)) proof, err := NewTxRecordProof(block, 0, crypto.SHA256) require.NoError(t, err) - tb := NewTrustBase(t, verifier) + tb := NewTrustBaseT(t, verifier) proof.TxProof.BlockHeaderHash = make([]byte, 32) require.EqualError(t, VerifyTxInclusion(proof, tb, crypto.SHA256), "proof block hash does not match to block hash in unicity certificate") @@ -107,7 +108,7 @@ func TestVerifyTxProof(t *testing.T) { block := createBlock(t, "test", signer, createTx(t)) proof, err := NewTxRecordProof(block, 0, crypto.SHA256) require.NoError(t, err) - tb := NewTrustBase(t, verifier) + tb := NewTrustBaseT(t, verifier) require.NoError(t, VerifyTxInclusion(proof, tb, crypto.SHA256)) }) @@ -119,7 +120,7 @@ func TestVerifyTxProof(t *testing.T) { block := createBlock(t, "test", signer, txr) proof, err := NewTxRecordProof(block, 0, crypto.SHA256) require.NoError(t, err) - tb := NewTrustBase(t, verifier) + tb := NewTrustBaseT(t, verifier) require.EqualError(t, VerifyTxProof(proof, tb, crypto.SHA256), "transaction failed") }) @@ -131,7 +132,7 @@ func TestVerifyTxProof(t *testing.T) { block := createBlock(t, "test", signer, txr) proof, err := NewTxRecordProof(block, 0, crypto.SHA256) require.NoError(t, err) - tb := NewTrustBase(t, verifier) + tb := NewTrustBaseT(t, verifier) require.EqualError(t, VerifyTxProof(proof, tb, crypto.SHA256), "transaction failed") }) diff --git a/types/tx_record.go b/types/tx_record.go index 042b1e3..aa983b1 100644 --- a/types/tx_record.go +++ b/types/tx_record.go @@ -207,7 +207,7 @@ func (t *TxRecordProof) Verify(getTrustBase func(epoch uint64) (RootTrustBase, e return errors.New("invalid UC: missing UnicitySeal") } trustBase, err := getTrustBase(uc.UnicitySeal.Epoch) - if err != nil { + if err != nil || trustBase == nil { return fmt.Errorf("acquiring trust base: %w", err) } return VerifyTxProof(t, trustBase, crypto.SHA256) diff --git a/types/unicity_certificate.go b/types/unicity_certificate.go index 4c2e1f8..588182c 100644 --- a/types/unicity_certificate.go +++ b/types/unicity_certificate.go @@ -51,7 +51,8 @@ func (x *UnicityCertificate) IsValid(partitionID PartitionID, shardConfHash []by return nil } -func (x *UnicityCertificate) Verify(tb RootTrustBase, algorithm crypto.Hash, partitionID PartitionID, shardConfHash []byte) error { +// TODO: verify shardID also +func (x *UnicityCertificate) Verify(tb RootTrustBase, algorithm crypto.Hash, partitionID PartitionID, shardID ShardID, shardConfHash []byte) error { if err := x.IsValid(partitionID, shardConfHash); err != nil { return fmt.Errorf("invalid unicity certificate: %w", err) } @@ -145,6 +146,20 @@ func (x *UnicityCertificate) GetShardID() ShardID { return x.ShardTreeCertificate.Shard } +func (x *UnicityCertificate) GetShardEpoch() uint64 { + if x != nil && x.InputRecord != nil { + return x.InputRecord.Epoch + } + return 0 +} + +func (x *UnicityCertificate) GetRootEpoch() uint64 { + if x != nil && x.UnicitySeal != nil { + return x.UnicitySeal.Epoch + } + return 0 +} + // CheckNonEquivocatingCertificates checks if provided certificates are equivocating // NB! order is important, also it is assumed that validity of both UCs is checked before // The algorithm is based on Yellowpaper: "Algorithm 6 Checking two UC-s for equivocation" diff --git a/types/unicity_certificate_test.go b/types/unicity_certificate_test.go index 8e25c03..2732a5b 100644 --- a/types/unicity_certificate_test.go +++ b/types/unicity_certificate_test.go @@ -75,7 +75,7 @@ func TestUnicityCertificate_IsValid(t *testing.T) { require.EqualValues(t, 0, uc.GetRoundNumber()) require.EqualValues(t, 0, uc.GetRootRoundNumber()) require.Nil(t, uc.GetStateHash()) - require.ErrorIs(t, uc.Verify(nil, crypto.SHA256, 0, nil), ErrUnicityCertificateIsNil) + require.ErrorIs(t, uc.Verify(nil, crypto.SHA256, 0, ShardID{}, nil), ErrUnicityCertificateIsNil) }) t.Run("invalid input record", func(t *testing.T) { @@ -149,7 +149,7 @@ func TestUnicityCertificate_Verify(t *testing.T) { trHash1 := bytes.Repeat([]byte{11}, 32) signer, verifier := testsig.CreateSignerAndVerifier(t) - tb := NewTrustBase(t, verifier) + tb := NewTrustBaseT(t, verifier) // must use const timestamp to have deterministic UC hash const curTimestamp uint64 = 1731504540 @@ -209,25 +209,25 @@ func TestUnicityCertificate_Verify(t *testing.T) { } } - require.NoError(t, validUC(t, sid0, &ir0, trHash0, shardConf0Hash).Verify(tb, crypto.SHA256, shardConf0.PartitionID, shardConf0Hash)) - require.NoError(t, validUC(t, sid1, &ir1, trHash1, shardConf1Hash).Verify(tb, crypto.SHA256, shardConf0.PartitionID, shardConf1Hash)) + require.NoError(t, validUC(t, sid0, &ir0, trHash0, shardConf0Hash).Verify(tb, crypto.SHA256, shardConf0.PartitionID, shardConf0.ShardID, shardConf0Hash)) + require.NoError(t, validUC(t, sid1, &ir1, trHash1, shardConf1Hash).Verify(tb, crypto.SHA256, shardConf0.PartitionID, shardConf0.ShardID, shardConf1Hash)) t.Run("IsValid", func(t *testing.T) { // check that IsValid is called uc := UnicityCertificate{Version: 1} - require.EqualError(t, uc.Verify(nil, crypto.SHA256, 0, nil), + require.EqualError(t, uc.Verify(nil, crypto.SHA256, 0, ShardID{}, nil), "invalid unicity certificate: invalid input record: input record is nil") }) t.Run("tb is nil", func(t *testing.T) { uc := validUC(t, sid0, &ir0, trHash0, shardConf0Hash) - require.EqualError(t, uc.Verify(nil, crypto.SHA256, shardConf0.PartitionID, shardConf0Hash), "verifying unicity seal: root node info is missing") + require.EqualError(t, uc.Verify(nil, crypto.SHA256, shardConf0.PartitionID, shardConf0.ShardID, shardConf0Hash), "verifying unicity seal: root node info is missing") }) t.Run("invalid root hash", func(t *testing.T) { uc := validUC(t, sid0, &ir0, trHash0, shardConf0Hash) uc.UnicitySeal.Hash = []byte{1, 2, 3} - require.EqualError(t, uc.Verify(tb, crypto.SHA256, shardConf0.PartitionID, shardConf0Hash), + require.EqualError(t, uc.Verify(tb, crypto.SHA256, shardConf0.PartitionID, shardConf0.ShardID, shardConf0Hash), "unicity seal hash 010203 does not match with the root hash of the unicity tree F06B596575FAE5F211C9738A657C55A13D06F7E22CE40F02A4682FDA7C1FD44F") }) } diff --git a/types/unicity_seal_test.go b/types/unicity_seal_test.go index abe65e4..1fb45d6 100644 --- a/types/unicity_seal_test.go +++ b/types/unicity_seal_test.go @@ -80,7 +80,7 @@ func TestUnicitySeal_IsValid(t *testing.T) { func TestUnicitySeal_Verify(t *testing.T) { signer, verifier := testsig.CreateSignerAndVerifier(t) - trustBase := NewTrustBase(t, verifier) + trustBase := NewTrustBaseT(t, verifier) randomHash := test.RandomBytes(32) // createUS returns UnicitySeal which is not signed but otherwise valid(ish) @@ -197,7 +197,7 @@ func TestUnicitySeal_cbor(t *testing.T) { err := seal.Sign("test", signer) require.NoError(t, err) - tb := NewTrustBase(t, verifier) + tb := NewTrustBaseT(t, verifier) err = seal.Verify(tb) require.NoError(t, err) diff --git a/types/unit_proof.go b/types/unit_proof.go index bc516ac..11cd315 100644 --- a/types/unit_proof.go +++ b/types/unit_proof.go @@ -62,7 +62,7 @@ type ( } UnicityCertificateValidator interface { - Validate(uc *UnicityCertificate, shardConfHash []byte) error + Validate(uc *UnicityCertificate, shardConf *PartitionDescriptionRecord, trustBase RootTrustBase) error } ) @@ -93,7 +93,7 @@ func (u *UnitStateProof) getUCv1() (*UnicityCertificate, error) { return uc, nil } -func (u *UnitStateProof) Verify(algorithm crypto.Hash, unitState *UnitState, ucv UnicityCertificateValidator, shardConfHash []byte) error { +func (u *UnitStateProof) Verify(algorithm crypto.Hash, unitState *UnitState, ucv UnicityCertificateValidator, shardConf *PartitionDescriptionRecord, trustBase RootTrustBase) error { if err := u.IsValid(); err != nil { return fmt.Errorf("invalid unit state proof: %w", err) } @@ -105,7 +105,7 @@ func (u *UnitStateProof) Verify(algorithm crypto.Hash, unitState *UnitState, ucv if err != nil { return fmt.Errorf("failed to get unicity certificate: %w", err) } - if err := ucv.Validate(uc, shardConfHash); err != nil { + if err := ucv.Validate(uc, shardConf, trustBase); err != nil { return fmt.Errorf("invalid unicity certificate: %w", err) } diff --git a/types/unit_proof_test.go b/types/unit_proof_test.go index 50005c8..9535405 100644 --- a/types/unit_proof_test.go +++ b/types/unit_proof_test.go @@ -12,11 +12,11 @@ import ( type alwaysValid struct{} type alwaysInvalid struct{} -func (a *alwaysValid) Validate(*UnicityCertificate, []byte) error { +func (a *alwaysValid) Validate(uc *UnicityCertificate, shardConf *PartitionDescriptionRecord, trustBase RootTrustBase) error { return nil } -func (a alwaysInvalid) Validate(*UnicityCertificate, []byte) error { +func (a *alwaysInvalid) Validate(uc *UnicityCertificate, shardConf *PartitionDescriptionRecord, trustBase RootTrustBase) error { return errors.New("invalid uc") } @@ -27,13 +27,13 @@ func TestVerifyUnitStateProof(t *testing.T) { t.Run("unit state proof is nil", func(t *testing.T) { data := &UnitState{} var usp *UnitStateProof - require.EqualError(t, usp.Verify(crypto.SHA256, data, &alwaysValid{}, nil), "invalid unit state proof: unit state proof is nil") + require.EqualError(t, usp.Verify(crypto.SHA256, data, &alwaysValid{}, nil, nil), "invalid unit state proof: unit state proof is nil") }) t.Run("unit ID missing", func(t *testing.T) { data := &UnitState{} usp := &UnitStateProof{} - require.EqualError(t, usp.Verify(crypto.SHA256, data, &alwaysValid{}, nil), "invalid unit state proof: unit ID is unassigned") + require.EqualError(t, usp.Verify(crypto.SHA256, data, &alwaysValid{}, nil, nil), "invalid unit state proof: unit ID is unassigned") }) t.Run("unit tree cert missing", func(t *testing.T) { @@ -41,7 +41,7 @@ func TestVerifyUnitStateProof(t *testing.T) { UnitID: []byte{0}, } data := &UnitState{} - require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil), "invalid unit state proof: unit tree cert is nil") + require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil, nil), "invalid unit state proof: unit tree cert is nil") }) t.Run("state tree cert missing", func(t *testing.T) { @@ -50,7 +50,7 @@ func TestVerifyUnitStateProof(t *testing.T) { UnitTreeCert: &UnitTreeCert{}, } data := &UnitState{} - require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil), "invalid unit state proof: state tree cert is nil") + require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil, nil), "invalid unit state proof: state tree cert is nil") }) t.Run("unicity certificate missing", func(t *testing.T) { @@ -60,7 +60,7 @@ func TestVerifyUnitStateProof(t *testing.T) { StateTreeCert: &StateTreeCert{}, } data := &UnitState{} - require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil), "invalid unit state proof: unicity certificate is nil") + require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil, nil), "invalid unit state proof: unicity certificate is nil") }) t.Run("invalid unicity certificate", func(t *testing.T) { @@ -71,7 +71,7 @@ func TestVerifyUnitStateProof(t *testing.T) { UnicityCertificate: emptyUC, } data := &UnitState{} - require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysInvalid{}, nil), "invalid unicity certificate: invalid uc") + require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysInvalid{}, nil, nil), "invalid unicity certificate: invalid uc") }) t.Run("missing unit data", func(t *testing.T) { @@ -81,7 +81,7 @@ func TestVerifyUnitStateProof(t *testing.T) { StateTreeCert: &StateTreeCert{}, UnicityCertificate: emptyUC, } - require.EqualError(t, proof.Verify(crypto.SHA256, nil, &alwaysValid{}, nil), "unit state is nil") + require.EqualError(t, proof.Verify(crypto.SHA256, nil, &alwaysValid{}, nil, nil), "unit state is nil") }) t.Run("unit data hash invalid", func(t *testing.T) { @@ -92,7 +92,7 @@ func TestVerifyUnitStateProof(t *testing.T) { UnicityCertificate: emptyUC, } data := &UnitState{} - require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil), "unit state hash does not match unit state hash in unit tree cert") + require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil, nil), "unit state hash does not match unit state hash in unit tree cert") }) t.Run("invalid summary value", func(t *testing.T) { @@ -109,7 +109,7 @@ func TestVerifyUnitStateProof(t *testing.T) { uc.InputRecord = &InputRecord{SummaryValue: []byte{1}} proof.UnicityCertificate, err = uc.MarshalCBOR() require.NoError(t, err) - require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil), "invalid summary value: expected 01, got 0000000000000000") + require.EqualError(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil, nil), "invalid summary value: expected 01, got 0000000000000000") }) t.Run("invalid state root hash", func(t *testing.T) { @@ -126,7 +126,7 @@ func TestVerifyUnitStateProof(t *testing.T) { uc.InputRecord = &InputRecord{SummaryValue: []byte{0, 0, 0, 0, 0, 0, 0, 0}} proof.UnicityCertificate, err = uc.MarshalCBOR() require.NoError(t, err) - require.ErrorContains(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil), "invalid state root hash") + require.ErrorContains(t, proof.Verify(crypto.SHA256, data, &alwaysValid{}, nil, nil), "invalid state root hash") }) t.Run("verify - ok", func(t *testing.T) { @@ -146,7 +146,7 @@ func TestVerifyUnitStateProof(t *testing.T) { uc.InputRecord.Hash = hash proof.UnicityCertificate, err = uc.MarshalCBOR() require.NoError(t, err) - require.NoError(t, proof.Verify(crypto.SHA256, unitState, &alwaysValid{}, nil), "unexpected error") + require.NoError(t, proof.Verify(crypto.SHA256, unitState, &alwaysValid{}, nil, nil), "unexpected error") }) }