Skip to content
5 changes: 3 additions & 2 deletions testutils/tokens/id.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ generator functions in this package.

Prefer to create test specific PDR and use it's ComposeUnitID method!
*/
func PDR() types.PartitionDescriptionRecord {
return testPDR
func PDR() *types.PartitionDescriptionRecord {
copy := testPDR
return &copy
}

/*
Expand Down
4 changes: 2 additions & 2 deletions txsystem/money/unit_data_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ type BillData struct {
Counter uint64 `json:"counter,string"` // The transaction counter of this bill
}

func NewUnitData(unitID types.UnitID, pdr *types.PartitionDescriptionRecord) (types.UnitData, error) {
typeID, err := pdr.ExtractUnitType(unitID)
func NewUnitData(unitID types.UnitID, unitTypeExtractor types.UnitTypeExtractor) (types.UnitData, error) {
typeID, err := unitTypeExtractor(unitID)
if err != nil {
return nil, fmt.Errorf("extracting unit type: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions txsystem/orchestration/unit_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ const (
VarUnitType = 1
)

func NewUnitData(unitID types.UnitID, pdr *types.PartitionDescriptionRecord) (types.UnitData, error) {
typeID, err := pdr.ExtractUnitType(unitID)
func NewUnitData(unitID types.UnitID, unitTypeExtractor types.UnitTypeExtractor) (types.UnitData, error) {
typeID, err := unitTypeExtractor(unitID)
if err != nil {
return nil, fmt.Errorf("extracting type ID: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions txsystem/tokens/unit_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ func GenerateUnitID(txo *types.TransactionOrder, shardConf *types.PartitionDescr
return nil
}

func NewUnitData(unitID types.UnitID, pdr *types.PartitionDescriptionRecord) (types.UnitData, error) {
typeID, err := pdr.ExtractUnitType(unitID)
func NewUnitData(unitID types.UnitID, unitTypeExtractor types.UnitTypeExtractor) (types.UnitData, error) {
typeID, err := unitTypeExtractor(unitID)
if err != nil {
return nil, fmt.Errorf("extracting type ID: %w", err)
}
Expand Down
7 changes: 5 additions & 2 deletions types/identifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type (

// UnitID is the extended identifier, combining the type and the unit identifiers.
UnitID []byte

// UnitTypeExtractor is a function that extracts unit type from UnitID
UnitTypeExtractor func(UnitID) (uint32, error)
)

func (uid UnitID) Compare(key UnitID) int {
Expand All @@ -45,8 +48,8 @@ func (uid UnitID) Eq(id UnitID) bool {
return bytes.Equal(uid, id)
}

func (uid UnitID) TypeMustBe(typeID uint32, pdr *PartitionDescriptionRecord) error {
tid, err := pdr.ExtractUnitType(uid)
func (uid UnitID) TypeMustBe(typeID uint32, unitTypeExtractor UnitTypeExtractor) error {
tid, err := unitTypeExtractor(uid)
if err != nil {
return fmt.Errorf("extracting unit type from unit ID: %w", err)
}
Expand Down
40 changes: 32 additions & 8 deletions types/partition_description.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ func (pdr *PartitionDescriptionRecord) Verify(prev *PartitionDescriptionRecord)
if pdr.NetworkID != prev.NetworkID {
return fmt.Errorf("invalid network id, provided %d previous %d", pdr.NetworkID, prev.NetworkID)
}
if pdr.PartitionTypeID != prev.PartitionTypeID {
return fmt.Errorf("invalid partition type id, provided %d previous %d", pdr.PartitionTypeID, prev.PartitionTypeID)
}
if pdr.PartitionID != prev.PartitionID {
return fmt.Errorf("invalid partition id, provided %d previous %d", pdr.PartitionID, prev.PartitionID)
}
Expand All @@ -137,14 +140,33 @@ func (pdr *PartitionDescriptionRecord) Hash(hashAlgorithm crypto.Hash) ([]byte,
return hasher.Sum()
}

func (pdr *PartitionDescriptionRecord) GetVersion() Version {
if pdr == nil || pdr.Version == 0 {
return 1
}
return pdr.Version
}

func (pdr *PartitionDescriptionRecord) GetNetworkID() NetworkID {
return pdr.NetworkID
}

func (pdr *PartitionDescriptionRecord) GetPartitionTypeID() PartitionTypeID {
return pdr.PartitionTypeID
}

func (pdr *PartitionDescriptionRecord) GetPartitionID() PartitionID {
return pdr.PartitionID
}

func (pdr *PartitionDescriptionRecord) GetShardID() ShardID {
return pdr.ShardID
}

func (pdr *PartitionDescriptionRecord) GetPartitionParams() map[string]string {
return pdr.PartitionParams
}

/*
UnitIDValidator returns function which checks that unit ID passed as argument
has correct length and that the unit belongs into the given shard.
Expand Down Expand Up @@ -207,20 +229,13 @@ func (pdr *PartitionDescriptionRecord) ExtractUnitType(id UnitID) (uint32, error
return 0, fmt.Errorf("expected unit ID length %d bytes, got %d bytes", idLen, len(id))
}

// we relay on the fact that valid PDR has "pdr.UnitIDLen >= 64" ie it's safe to read four bytes
// we rely on the fact that valid PDR has "pdr.UnitIDLen >= 64" ie it's safe to read four bytes
idx := len(id) - 1
v := uint32(id[idx]) | (uint32(id[idx-1]) << 8) | (uint32(id[idx-2]) << 16) | (uint32(id[idx-3]) << 24)
mask := uint32(0xFFFFFFFF) >> (32 - pdr.TypeIDLen)
return v & mask, nil
}

func (pdr *PartitionDescriptionRecord) GetVersion() Version {
if pdr == nil || pdr.Version == 0 {
return 1
}
return pdr.Version
}

func (pdr *PartitionDescriptionRecord) MarshalCBOR() ([]byte, error) {
type alias PartitionDescriptionRecord
if pdr.Version == 0 {
Expand All @@ -236,3 +251,12 @@ func (pdr *PartitionDescriptionRecord) UnmarshalCBOR(data []byte) error {
}
return EnsureVersion(pdr, pdr.Version, 1)
}

func (pdr *PartitionDescriptionRecord) FindValidator(nodeID string) *NodeInfo {
for _, validator := range pdr.Validators {
if validator.NodeID == nodeID {
return validator
}
}
return nil
}
123 changes: 114 additions & 9 deletions types/root_trust_base.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package types

import (
"bytes"
"cmp"
"crypto"
"errors"
Expand All @@ -15,7 +16,10 @@ import (

type (
RootTrustBase interface {
GetVersion() Version
GetNetworkID() NetworkID
GetEpoch() uint64
GetEpochStart() uint64
VerifyQuorumSignatures(data []byte, signatures map[string]hex.Bytes) error
VerifySignature(data []byte, sig []byte, nodeID string) (uint64, error)
GetQuorumThreshold() uint64
Expand All @@ -28,7 +32,7 @@ type (
Version Version `json:"version"`
NetworkID NetworkID `json:"networkId"`
Epoch uint64 `json:"epoch"` // current epoch number
EpochStartRound uint64 `json:"epochStartRound"` // root chain round number when the epoch begins
EpochStart uint64 `json:"epochStartRound"` // root chain round number when the epoch begins
RootNodes []*NodeInfo `json:"rootNodes"` // list of all root nodes for the current epoch
QuorumThreshold uint64 `json:"quorumThreshold"` // amount of coins required to reach consensus, currently each node gets equal amount of voting power i.e. +1 for each node
StateHash hex.Bytes `json:"stateHash"` // unicity tree root hash
Expand All @@ -51,18 +55,21 @@ type (
Option func(c *trustBaseConf)

trustBaseConf struct {
quorumThreshold uint64
epoch uint64
epochStart uint64
quorumThreshold uint64
previousTrustBaseHash hex.Bytes
}
)

// NewTrustBaseGenesis creates new unsigned root trust base with default parameters.
func NewTrustBaseGenesis(networkID NetworkID, rootNodes []*NodeInfo, opts ...Option) (*RootTrustBaseV1, error) {
// NewTrustBase creates new unsigned root trust base.
func NewTrustBase(networkID NetworkID, rootNodes []*NodeInfo, opts ...Option) (*RootTrustBaseV1, error) {
if len(rootNodes) == 0 {
return nil, errors.New("nodes list is empty")
}

// init config
c := &trustBaseConf{}
c := &trustBaseConf{epoch: 1}
for _, opt := range opts {
opt(c)
}
Expand Down Expand Up @@ -93,13 +100,13 @@ func NewTrustBaseGenesis(networkID NetworkID, rootNodes []*NodeInfo, opts ...Opt
return &RootTrustBaseV1{
Version: 1,
NetworkID: networkID,
Epoch: 1,
EpochStartRound: 1,
Epoch: c.epoch,
EpochStart: c.epochStart,
RootNodes: rootNodes,
QuorumThreshold: c.quorumThreshold,
StateHash: nil,
ChangeRecordHash: nil,
PreviousEntryHash: nil,
PreviousEntryHash: c.previousTrustBaseHash,
Signatures: make(map[string]hex.Bytes),
}, nil
}
Expand All @@ -111,6 +118,24 @@ func WithQuorumThreshold(threshold uint64) Option {
}
}

func WithEpoch(epoch uint64) Option {
return func(c *trustBaseConf) {
c.epoch = epoch
}
}

func WithEpochStart(epochStart uint64) Option {
return func(c *trustBaseConf) {
c.epochStart = epochStart
}
}

func WithPreviousTrustBaseHash(previousTrustBaseHash hex.Bytes) Option {
return func(c *trustBaseConf) {
c.previousTrustBaseHash = previousTrustBaseHash
}
}

// IsValid validates that all fields are correctly set and public keys are correct.
func (n *NodeInfo) IsValid() error {
if n == nil {
Expand Down Expand Up @@ -171,7 +196,7 @@ func (r *RootTrustBaseV1) Hash(hashAlgo crypto.Hash) ([]byte, error) {
return hasher.Sum()
}

// SigBytes serializes all fields expect for the signatures field.
// SigBytes serializes all fields expect for the Signatures field itself.
func (r RootTrustBaseV1) SigBytes() ([]byte, error) {
r.Signatures = nil
bs, err := r.MarshalCBOR()
Expand Down Expand Up @@ -239,6 +264,14 @@ func (r *RootTrustBaseV1) GetNetworkID() NetworkID {
return r.NetworkID
}

func (r *RootTrustBaseV1) GetEpoch() uint64 {
return r.Epoch
}

func (r *RootTrustBaseV1) GetEpochStart() uint64 {
return r.EpochStart
}

func (r *RootTrustBaseV1) MarshalCBOR() ([]byte, error) {
type alias RootTrustBaseV1
if r.Version == 0 {
Expand All @@ -264,3 +297,75 @@ func (r *RootTrustBaseV1) getRootNode(nodeID string) *NodeInfo {
}
return nil
}

// Verify verifies the trust base, including the signatures.
//
// Common for all trust bases:
// - The current epoch signatures must be valid and reach quorum.
//
// Genesis trust base:
// - Epoch must be equal to 1.
//
// Non-genesis trust base must extend previous trust base:
// - The network identifiers must match.
// - The epoch number must be strictly greater than the previous epoch number.
// - The epoch start round must be strictly greater than the previous epoch start round.
// - The hash of the previous trust must match the previousEntryHash.
// - The previous epoch signatures must be valid and reach quorum.
func (r *RootTrustBaseV1) Verify(prev *RootTrustBaseV1) error {
if err := r.IsValid(prev); err != nil {
return err
}
return r.VerifySignatures(prev)
}

// IsValid verifies the trust base without verifying the signatures.
// Use VerifySignatures to verify the signatures.
func (r *RootTrustBaseV1) IsValid(prev *RootTrustBaseV1) error {
if prev == nil {
if r.Epoch != 1 {
return fmt.Errorf("genesis trust base epoch must be 1, got %d", r.Epoch)
}
return nil
}
if r.NetworkID != prev.NetworkID {
return fmt.Errorf("invalid network id, got %d previous %d", r.NetworkID, prev.NetworkID)
}
if r.Epoch != prev.Epoch+1 {
return fmt.Errorf("invalid epoch, got %d previous %d", r.Epoch, prev.Epoch)
}
if r.EpochStart <= prev.EpochStart {
return fmt.Errorf("invalid epoch start, got %d previous %d", r.EpochStart, prev.EpochStart)
}
prevHash, err := prev.Hash(crypto.SHA256)
if err != nil {
return fmt.Errorf("failed to calculate previous trust base hash: %w", err)
}
if !bytes.Equal(r.PreviousEntryHash, prevHash) {
return errors.New("previous trust base hash does not match")
}
return nil
}

// VerifySignatures verifies that the trust base is signed by the previous
// epoch's validators. For the genesis trust base (epoch 1), the trust base
// must be self-signed by the genesis (epoch 1) validators.
func (r *RootTrustBaseV1) VerifySignatures(prev *RootTrustBaseV1) error {
sigBytes, err := r.SigBytes()
if err != nil {
return fmt.Errorf("failed to get previous epoch sig bytes: %w", err)
}
var tb *RootTrustBaseV1
if r.Epoch == 1 {
tb = r
} else {
if prev == nil {
return errors.New("previous trust base is nil")
}
tb = prev
}
if err := tb.VerifyQuorumSignatures(sigBytes, r.Signatures); err != nil {
return fmt.Errorf("failed to verify signatures: %w", err)
}
return nil
}
Loading