diff --git a/txsystem/fc/unit_data_types.go b/txsystem/fc/unit_data_types.go index ac5e408..bd41bb5 100644 --- a/txsystem/fc/unit_data_types.go +++ b/txsystem/fc/unit_data_types.go @@ -24,6 +24,7 @@ type FeeCreditRecord struct { func NewFeeCreditRecord(balance uint64, ownerPredicate []byte, minLifetime uint64) *FeeCreditRecord { return &FeeCreditRecord{ + Version: 1, Balance: balance, OwnerPredicate: ownerPredicate, MinLifetime: minLifetime, @@ -67,18 +68,16 @@ func (b *FeeCreditRecord) GetVersion() types.Version { func (b *FeeCreditRecord) MarshalCBOR() ([]byte, error) { type alias FeeCreditRecord - if b.Version == 0 { - b.Version = b.GetVersion() + cp := *b + if cp.Version == 0 { + cp.Version = 1 } - return types.Cbor.Marshal((*alias)(b)) + return types.Cbor.Marshal((*alias)(&cp)) } func (b *FeeCreditRecord) UnmarshalCBOR(data []byte) error { type alias FeeCreditRecord - if err := types.Cbor.Unmarshal(data, (*alias)(b)); err != nil { - return err - } - return types.EnsureVersion(b, b.Version, 1) + return types.UnmarshalVersioned(1, data, (*alias)(b), b) } func (b *FeeCreditRecord) IsExpired(currentRoundNumber uint64) bool { diff --git a/txsystem/money/unit_data_types.go b/txsystem/money/unit_data_types.go index b6f0034..a0bfdb4 100644 --- a/txsystem/money/unit_data_types.go +++ b/txsystem/money/unit_data_types.go @@ -38,6 +38,7 @@ func NewUnitData(unitID types.UnitID, unitTypeExtractor types.UnitTypeExtractor) func NewBillData(value uint64, ownerPredicate []byte) *BillData { return &BillData{ + Version: 1, Value: value, OwnerPredicate: ownerPredicate, } @@ -72,16 +73,14 @@ func (b *BillData) GetVersion() types.Version { func (b *BillData) MarshalCBOR() ([]byte, error) { type alias BillData - if b.Version == 0 { - b.Version = b.GetVersion() + cp := *b + if cp.Version == 0 { + cp.Version = 1 } - return types.Cbor.Marshal((*alias)(b)) + return types.Cbor.Marshal((*alias)(&cp)) } func (b *BillData) UnmarshalCBOR(data []byte) error { type alias BillData - if err := types.Cbor.Unmarshal(data, (*alias)(b)); err != nil { - return err - } - return types.EnsureVersion(b, b.Version, 1) + return types.UnmarshalVersioned(1, data, (*alias)(b), b) } diff --git a/txsystem/orchestration/unit_data_types.go b/txsystem/orchestration/unit_data_types.go index 07ffc0e..43a16ed 100644 --- a/txsystem/orchestration/unit_data_types.go +++ b/txsystem/orchestration/unit_data_types.go @@ -39,18 +39,20 @@ func (b *VarData) GetVersion() types.Version { return 1 } +func NewVarData(epochNumber uint64) *VarData { + return &VarData{Version: 1, EpochNumber: epochNumber} +} + func (b *VarData) MarshalCBOR() ([]byte, error) { type alias VarData - if b.Version == 0 { - b.Version = b.GetVersion() + cp := *b + if cp.Version == 0 { + cp.Version = 1 } - return types.Cbor.Marshal((*alias)(b)) + return types.Cbor.Marshal((*alias)(&cp)) } func (b *VarData) UnmarshalCBOR(data []byte) error { type alias VarData - if err := types.Cbor.Unmarshal(data, (*alias)(b)); err != nil { - return err - } - return types.EnsureVersion(b, b.Version, 1) + return types.UnmarshalVersioned(1, data, (*alias)(b), b) } diff --git a/txsystem/tokens/unit_data_types.go b/txsystem/tokens/unit_data_types.go index 6ba231c..18f7c00 100644 --- a/txsystem/tokens/unit_data_types.go +++ b/txsystem/tokens/unit_data_types.go @@ -64,6 +64,7 @@ type FungibleTokenData struct { func NewFungibleTokenTypeData(attr *DefineFungibleTokenAttributes) types.UnitData { return &FungibleTokenTypeData{ + Version: 1, Symbol: attr.Symbol, Name: attr.Name, Icon: attr.Icon, @@ -77,6 +78,7 @@ func NewFungibleTokenTypeData(attr *DefineFungibleTokenAttributes) types.UnitDat func NewNonFungibleTokenTypeData(attr *DefineNonFungibleTokenAttributes) types.UnitData { return &NonFungibleTokenTypeData{ + Version: 1, Symbol: attr.Symbol, Name: attr.Name, Icon: attr.Icon, @@ -90,6 +92,7 @@ func NewNonFungibleTokenTypeData(attr *DefineNonFungibleTokenAttributes) types.U func NewNonFungibleTokenData(typeID types.UnitID, attr *MintNonFungibleTokenAttributes) types.UnitData { return &NonFungibleTokenData{ + Version: 1, TypeID: typeID, Name: attr.Name, URI: attr.URI, @@ -101,6 +104,7 @@ func NewNonFungibleTokenData(typeID types.UnitID, attr *MintNonFungibleTokenAttr func NewFungibleTokenData(typeID types.UnitID, value uint64, ownerPredicate []byte, minLifetime uint64) types.UnitData { return &FungibleTokenData{ + Version: 1, TypeID: typeID, Value: value, OwnerPredicate: ownerPredicate, @@ -141,18 +145,16 @@ func (n *NonFungibleTokenTypeData) GetVersion() types.Version { func (n *NonFungibleTokenTypeData) MarshalCBOR() ([]byte, error) { type alias NonFungibleTokenTypeData - if n.Version == 0 { - n.Version = n.GetVersion() + cp := *n + if cp.Version == 0 { + cp.Version = 1 } - return types.Cbor.Marshal((*alias)(n)) + return types.Cbor.Marshal((*alias)(&cp)) } func (n *NonFungibleTokenTypeData) UnmarshalCBOR(data []byte) error { type alias NonFungibleTokenTypeData - if err := types.Cbor.Unmarshal(data, (*alias)(n)); err != nil { - return err - } - return types.EnsureVersion(n, n.Version, 1) + return types.UnmarshalVersioned(1, data, (*alias)(n), n) } func (n *NonFungibleTokenTypeData) Owner() []byte { @@ -191,18 +193,16 @@ func (n *NonFungibleTokenData) GetVersion() types.Version { func (n *NonFungibleTokenData) MarshalCBOR() ([]byte, error) { type alias NonFungibleTokenData - if n.Version == 0 { - n.Version = n.GetVersion() + cp := *n + if cp.Version == 0 { + cp.Version = 1 } - return types.Cbor.Marshal((*alias)(n)) + return types.Cbor.Marshal((*alias)(&cp)) } func (n *NonFungibleTokenData) UnmarshalCBOR(data []byte) error { type alias NonFungibleTokenData - if err := types.Cbor.Unmarshal(data, (*alias)(n)); err != nil { - return err - } - return types.EnsureVersion(n, n.Version, 1) + return types.UnmarshalVersioned(1, data, (*alias)(n), n) } func (n *NonFungibleTokenData) GetCounter() uint64 { @@ -250,18 +250,16 @@ func (f *FungibleTokenTypeData) GetVersion() types.Version { func (b *FungibleTokenTypeData) MarshalCBOR() ([]byte, error) { type alias FungibleTokenTypeData - if b.Version == 0 { - b.Version = b.GetVersion() + cp := *b + if cp.Version == 0 { + cp.Version = 1 } - return types.Cbor.Marshal((*alias)(b)) + return types.Cbor.Marshal((*alias)(&cp)) } func (b *FungibleTokenTypeData) UnmarshalCBOR(data []byte) error { type alias FungibleTokenTypeData - if err := types.Cbor.Unmarshal(data, (*alias)(b)); err != nil { - return err - } - return types.EnsureVersion(b, b.Version, 1) + return types.UnmarshalVersioned(1, data, (*alias)(b), b) } func (f *FungibleTokenData) Write(hasher abhash.Hasher) { @@ -302,16 +300,14 @@ func (f *FungibleTokenData) GetVersion() types.Version { func (f *FungibleTokenData) MarshalCBOR() ([]byte, error) { type alias FungibleTokenData - if f.Version == 0 { - f.Version = f.GetVersion() + cp := *f + if cp.Version == 0 { + cp.Version = 1 } - return types.Cbor.Marshal((*alias)(f)) + return types.Cbor.Marshal((*alias)(&cp)) } func (f *FungibleTokenData) UnmarshalCBOR(data []byte) error { type alias FungibleTokenData - if err := types.Cbor.Unmarshal(data, (*alias)(f)); err != nil { - return err - } - return types.EnsureVersion(f, f.Version, 1) + return types.UnmarshalVersioned(1, data, (*alias)(f), f) } diff --git a/types/block.go b/types/block.go index f715298..331f750 100644 --- a/types/block.go +++ b/types/block.go @@ -207,20 +207,25 @@ func (h *Header) GetVersion() Version { return 1 } +func NewHeader() *Header { + return &Header{Version: 1} +} + func (h *Header) MarshalCBOR() ([]byte, error) { type alias Header - if h.Version == 0 { - h.Version = h.GetVersion() + cp := *h + if cp.Version == 0 { + cp.Version = 1 } - return Cbor.MarshalTaggedValue(BlockTag, (*alias)(h)) + return Cbor.MarshalTaggedValue(BlockTag, (*alias)(&cp)) } func (h *Header) UnmarshalCBOR(data []byte) error { type alias Header - if err := Cbor.UnmarshalTaggedValue(BlockTag, data, (*alias)(h)); err != nil { + if err := UnmarshalTaggedVersioned(BlockTag, 1, data, (*alias)(h), h); err != nil { return fmt.Errorf("failed to unmarshal block header: %w", err) } - return EnsureVersion(h, h.Version, 1) + return nil } func (h *Header) Hash(algorithm crypto.Hash) ([]byte, error) { diff --git a/types/block_test.go b/types/block_test.go index a5ead98..450d996 100644 --- a/types/block_test.go +++ b/types/block_test.go @@ -516,11 +516,14 @@ func TestBlock_CBOR(t *testing.T) { }) t.Run("block with unicity certificate", func(t *testing.T) { uc := &UnicityCertificate{ + Version: 1, InputRecord: &InputRecord{ - Version: 1, // if version is not set here, the test fails (despite the fact it's a pointer) + Version: 1, Hash: []byte{1, 1, 1}, PreviousHash: []byte{1, 1, 1}, - }} + }, + ShardTreeCertificate: NewShardTreeCertificate(), + } ucBytes, err := (uc).MarshalCBOR() require.NoError(t, err) b := Block{ diff --git a/types/input_record.go b/types/input_record.go index 3c32ee9..91cfeb1 100644 --- a/types/input_record.go +++ b/types/input_record.go @@ -137,18 +137,20 @@ func (x *InputRecord) GetVersion() Version { return 1 } +func NewInputRecord() *InputRecord { + return &InputRecord{Version: 1} +} + func (x *InputRecord) MarshalCBOR() ([]byte, error) { type alias InputRecord - if x.Version == 0 { - x.Version = x.GetVersion() + cp := *x + if cp.Version == 0 { + cp.Version = 1 } - return Cbor.MarshalTaggedValue(InputRecordTag, (*alias)(x)) + return Cbor.MarshalTaggedValue(InputRecordTag, (*alias)(&cp)) } func (x *InputRecord) UnmarshalCBOR(data []byte) error { type alias InputRecord - if err := Cbor.UnmarshalTaggedValue(InputRecordTag, data, (*alias)(x)); err != nil { - return err - } - return EnsureVersion(x, x.Version, 1) + return UnmarshalTaggedVersioned(InputRecordTag, 1, data, (*alias)(x), x) } diff --git a/types/input_record_test.go b/types/input_record_test.go index b811de1..b6307bf 100644 --- a/types/input_record_test.go +++ b/types/input_record_test.go @@ -126,8 +126,8 @@ func TestInputRecord_AddToHasher(t *testing.T) { ir.AddToHasher(abhasher) hash := hasher.Sum(nil) - expectedHash := []byte{0x65, 0x33, 0x67, 0xb2, 0xf2, 0xff, 0x9d, 0xa5, 0x2, 0x86, 0x2, 0x65, 0x46, 0xf6, 0x62, - 0x77, 0x89, 0x83, 0x10, 0x63, 0x60, 0x6b, 0x23, 0x60, 0xf2, 0x16, 0x61, 0x5a, 0x60, 0x16, 0x1, 0xbf} + expectedHash := []byte{0xc0, 0x8e, 0x82, 0x36, 0x7b, 0xe0, 0x7b, 0xd, 0xf7, 0xfd, 0xa4, 0x72, 0xde, 0xf0, 0xed, + 0xb5, 0xd9, 0xc5, 0x63, 0x48, 0xc1, 0x65, 0xfb, 0xfc, 0x12, 0x8e, 0x6, 0xd6, 0x67, 0xad, 0xbb, 0xe8} require.Equal(t, expectedHash, hash) } diff --git a/types/partition_description.go b/types/partition_description.go index d06f3c8..833efb1 100644 --- a/types/partition_description.go +++ b/types/partition_description.go @@ -236,20 +236,25 @@ func (pdr *PartitionDescriptionRecord) ExtractUnitType(id UnitID) (uint32, error return v & mask, nil } +func NewPartitionDescriptionRecord() *PartitionDescriptionRecord { + return &PartitionDescriptionRecord{Version: 1} +} + func (pdr *PartitionDescriptionRecord) MarshalCBOR() ([]byte, error) { type alias PartitionDescriptionRecord - if pdr.Version == 0 { - pdr.Version = pdr.GetVersion() + cp := *pdr + if cp.Version == 0 { + cp.Version = 1 } - return Cbor.MarshalTaggedValue(PartitionDescriptionRecordTag, (*alias)(pdr)) + return Cbor.MarshalTaggedValue(PartitionDescriptionRecordTag, (*alias)(&cp)) } func (pdr *PartitionDescriptionRecord) UnmarshalCBOR(data []byte) error { type alias PartitionDescriptionRecord - if err := Cbor.UnmarshalTaggedValue(PartitionDescriptionRecordTag, data, (*alias)(pdr)); err != nil { + if err := UnmarshalTaggedVersioned(PartitionDescriptionRecordTag, 1, data, (*alias)(pdr), pdr); err != nil { return fmt.Errorf("failed to unmarshal partition description record: %w", err) } - return EnsureVersion(pdr, pdr.Version, 1) + return nil } func (pdr *PartitionDescriptionRecord) FindValidator(nodeID string) *NodeInfo { diff --git a/types/root_trust_base.go b/types/root_trust_base.go index 3db8c1c..ecb6b32 100644 --- a/types/root_trust_base.go +++ b/types/root_trust_base.go @@ -272,20 +272,25 @@ func (r *RootTrustBaseV1) GetEpochStart() uint64 { return r.EpochStart } +func NewRootTrustBaseV1() *RootTrustBaseV1 { + return &RootTrustBaseV1{Version: 1} +} + func (r *RootTrustBaseV1) MarshalCBOR() ([]byte, error) { type alias RootTrustBaseV1 - if r.Version == 0 { - r.Version = r.GetVersion() + cp := *r + if cp.Version == 0 { + cp.Version = 1 } - return Cbor.MarshalTaggedValue(RootTrustBaseTag, (*alias)(r)) + return Cbor.MarshalTaggedValue(UnicityTrustBaseTag, (*alias)(&cp)) } func (r *RootTrustBaseV1) UnmarshalCBOR(data []byte) error { type alias RootTrustBaseV1 - if err := Cbor.UnmarshalTaggedValue(RootTrustBaseTag, data, (*alias)(r)); err != nil { + if err := UnmarshalTaggedVersioned(UnicityTrustBaseTag, 1, data, (*alias)(r), r); err != nil { return fmt.Errorf("failed to unmarshal root trust base: %w", err) } - return EnsureVersion(r, r.Version, 1) + return nil } func (r *RootTrustBaseV1) getRootNode(nodeID string) *NodeInfo { diff --git a/types/shard_certificate.go b/types/shard_certificate.go index 2f2ba18..3dfd603 100644 --- a/types/shard_certificate.go +++ b/types/shard_certificate.go @@ -10,11 +10,39 @@ import ( type ShardTreeCertificate struct { _ struct{} `cbor:",toarray"` + Version Version Shard ShardID SiblingHashes [][]byte } +func NewShardTreeCertificate() ShardTreeCertificate { + return ShardTreeCertificate{Version: 1} +} + +func (cert ShardTreeCertificate) GetVersion() Version { + if cert.Version > 0 { + return cert.Version + } + return 1 +} + +func (cert ShardTreeCertificate) MarshalCBOR() ([]byte, error) { + type alias ShardTreeCertificate + if cert.Version == 0 { + cert.Version = 1 + } + return Cbor.MarshalTaggedValue(ShardTreeCertificateTag, alias(cert)) +} + +func (cert *ShardTreeCertificate) UnmarshalCBOR(data []byte) error { + type alias ShardTreeCertificate + return UnmarshalTaggedVersioned(ShardTreeCertificateTag, 1, data, (*alias)(cert), cert) +} + func (cert ShardTreeCertificate) IsValid() error { + if cert.Version != 1 { + return ErrInvalidVersion(cert) + } if cnt := uint(len(cert.SiblingHashes)); cnt != cert.Shard.Length() { return fmt.Errorf("shard ID is %d bits but got %d sibling hashes", cert.Shard.Length(), cnt) } @@ -190,6 +218,7 @@ func (tree ShardTree) Certificate(shardID ShardID) (ShardTreeCertificate, error) return ShardTreeCertificate{}, fmt.Errorf("shard %q is not in the tree", shardID) } return ShardTreeCertificate{ + Version: 1, Shard: shardID, SiblingHashes: tree.siblingHashes(shardID), }, nil diff --git a/types/tx_order.go b/types/tx_order.go index 4f0fb71..53b5573 100644 --- a/types/tx_order.go +++ b/types/tx_order.go @@ -231,20 +231,22 @@ func (t *TransactionOrder) GetVersion() Version { return t.Version } +func NewTransactionOrder() *TransactionOrder { + return &TransactionOrder{Version: 1} +} + func (t *TransactionOrder) MarshalCBOR() ([]byte, error) { type alias TransactionOrder - if t.Version == 0 { - t.Version = t.GetVersion() + cp := *t + if cp.Version == 0 { + cp.Version = 1 } - return Cbor.MarshalTaggedValue(TransactionOrderTag, (*alias)(t)) + return Cbor.MarshalTaggedValue(TransactionOrderTag, (*alias)(&cp)) } func (t *TransactionOrder) UnmarshalCBOR(data []byte) error { type alias TransactionOrder - if err := Cbor.UnmarshalTaggedValue(TransactionOrderTag, data, (*alias)(t)); err != nil { - return err - } - return EnsureVersion(t, t.Version, 1) + return UnmarshalTaggedVersioned(TransactionOrderTag, 1, data, (*alias)(t), t) } func (t *TransactionOrder) AddStateUnlockCommitProof(unlockProof []byte) { diff --git a/types/tx_proof.go b/types/tx_proof.go index fc7ea7a..e1ec2f0 100644 --- a/types/tx_proof.go +++ b/types/tx_proof.go @@ -162,18 +162,20 @@ func (p *TxProof) GetVersion() Version { return 1 } +func NewTxProof() *TxProof { + return &TxProof{Version: 1} +} + func (p *TxProof) MarshalCBOR() ([]byte, error) { type alias TxProof - if p.Version == 0 { - p.Version = p.GetVersion() + cp := *p + if cp.Version == 0 { + cp.Version = 1 } - return Cbor.MarshalTaggedValue(TxProofTag, (*alias)(p)) + return Cbor.MarshalTaggedValue(TxProofTag, (*alias)(&cp)) } func (p *TxProof) UnmarshalCBOR(data []byte) error { type alias TxProof - if err := Cbor.UnmarshalTaggedValue(TxProofTag, data, (*alias)(p)); err != nil { - return err - } - return EnsureVersion(p, p.Version, 1) + return UnmarshalTaggedVersioned(TxProofTag, 1, data, (*alias)(p), p) } diff --git a/types/tx_record.go b/types/tx_record.go index aa983b1..29afb9f 100644 --- a/types/tx_record.go +++ b/types/tx_record.go @@ -130,20 +130,22 @@ func (t *TransactionRecord) GetVersion() Version { return t.Version } +func NewTransactionRecord() *TransactionRecord { + return &TransactionRecord{Version: 1} +} + func (t *TransactionRecord) MarshalCBOR() ([]byte, error) { type alias TransactionRecord - if t.Version == 0 { - t.Version = t.GetVersion() + cp := *t + if cp.Version == 0 { + cp.Version = 1 } - return Cbor.MarshalTaggedValue(TransactionRecordTag, (*alias)(t)) + return Cbor.MarshalTaggedValue(TransactionRecordTag, (*alias)(&cp)) } func (t *TransactionRecord) UnmarshalCBOR(data []byte) error { type alias TransactionRecord - if err := Cbor.UnmarshalTaggedValue(TransactionRecordTag, data, (*alias)(t)); err != nil { - return err - } - return EnsureVersion(t, t.Version, 1) + return UnmarshalTaggedVersioned(TransactionRecordTag, 1, data, (*alias)(t), t) } func (sm *ServerMetadata) GetActualFee() uint64 { diff --git a/types/unicity_certificate.go b/types/unicity_certificate.go index f3ff10f..c114888 100644 --- a/types/unicity_certificate.go +++ b/types/unicity_certificate.go @@ -265,18 +265,20 @@ func (x *UnicityCertificate) GetVersion() Version { return 1 } +func NewUnicityCertificate() *UnicityCertificate { + return &UnicityCertificate{Version: 1} +} + func (x *UnicityCertificate) MarshalCBOR() ([]byte, error) { type alias UnicityCertificate - if x.Version == 0 { - x.Version = x.GetVersion() + cp := *x + if cp.Version == 0 { + cp.Version = 1 } - return Cbor.MarshalTaggedValue(UnicityCertificateTag, (*alias)(x)) + return Cbor.MarshalTaggedValue(UnicityCertificateTag, (*alias)(&cp)) } func (x *UnicityCertificate) UnmarshalCBOR(data []byte) error { type alias UnicityCertificate - if err := Cbor.UnmarshalTaggedValue(UnicityCertificateTag, data, (*alias)(x)); err != nil { - return err - } - return EnsureVersion(x, x.Version, 1) + return UnmarshalTaggedVersioned(UnicityCertificateTag, 1, data, (*alias)(x), x) } diff --git a/types/unicity_certificate_test.go b/types/unicity_certificate_test.go index 2732a5b..a946353 100644 --- a/types/unicity_certificate_test.go +++ b/types/unicity_certificate_test.go @@ -228,7 +228,7 @@ func TestUnicityCertificate_Verify(t *testing.T) { uc := validUC(t, sid0, &ir0, trHash0, shardConf0Hash) uc.UnicitySeal.Hash = []byte{1, 2, 3} require.EqualError(t, uc.Verify(tb, crypto.SHA256, shardConf0.PartitionID, shardConf0.ShardID, shardConf0Hash), - "unicity seal hash 010203 does not match with the root hash of the unicity tree F06B596575FAE5F211C9738A657C55A13D06F7E22CE40F02A4682FDA7C1FD44F") + "unicity seal hash 010203 does not match with the root hash of the unicity tree BD0AC73BEA6F2A497A6AECB5DD013BB3700916DB075371F9C8C20A86C4B38D59") }) } @@ -836,7 +836,7 @@ func Test_UnicityCertificate_Cbor(t *testing.T) { Version: 1, InputRecord: &InputRecord{Version: 1}, TRHash: []byte{1, 2, 3, 4, 5}, - ShardTreeCertificate: ShardTreeCertificate{Shard: ShardID{}}, + ShardTreeCertificate: ShardTreeCertificate{Version: 1, Shard: ShardID{}}, UnicityTreeCertificate: &UnicityTreeCertificate{Version: 1}, UnicitySeal: &UnicitySeal{ Version: 1, @@ -858,7 +858,7 @@ func Test_UnicityCertificate_Cbor(t *testing.T) { //uc := &UnicityCertificate{InputRecord: &InputRecord{}, TRHash: []byte{1}, UnicityTreeCertificate: &UnicityTreeCertificate{}, UnicitySeal: &UnicitySeal{}} //_ucData, _ := uc.MarshalCBOR() //fmt.Printf("ucData: 0x%X\n", _ucData) - ucData, err := hex.Decode([]byte("0xD903EF8701D903F08A010000F6F6F600F600F64101F6824180F6D903F6830100F6D903E9880100000000F6F6F6")) + ucData, err := hex.Decode([]byte("0xD998598701D9985A8A010000F6F6F600F600F64101F6D9985B83014180F6D9985C830100F6D9985D880100000000F6F6F6")) require.NoError(t, err) uc1 := &UnicityCertificate{} diff --git a/types/unicity_seal.go b/types/unicity_seal.go index be5ac04..55f9da7 100644 --- a/types/unicity_seal.go +++ b/types/unicity_seal.go @@ -128,12 +128,17 @@ func (x *UnicitySeal) AddToHasher(hasher abhash.Hasher) { hasher.Write(x) } +func NewUnicitySeal() *UnicitySeal { + return &UnicitySeal{Version: 1} +} + func (x *UnicitySeal) MarshalCBOR() ([]byte, error) { type alias UnicitySeal - if x.Version == 0 { - x.Version = x.GetVersion() + cp := *x + if cp.Version == 0 { + cp.Version = 1 } - return Cbor.MarshalTaggedValue(UnicitySealTag, (*alias)(x)) + return Cbor.MarshalTaggedValue(UnicitySealTag, (*alias)(&cp)) } func (x *UnicitySeal) UnmarshalCBOR(b []byte) (err error) { diff --git a/types/unicity_seal_test.go b/types/unicity_seal_test.go index 1fb45d6..622a66d 100644 --- a/types/unicity_seal_test.go +++ b/types/unicity_seal_test.go @@ -268,7 +268,7 @@ func TestUnicitySeal_UnmarshalCBOR(t *testing.T) { require.NoError(t, err) seal := &UnicitySeal{} err = seal.UnmarshalCBOR(data) - require.EqualError(t, err, "unmarshaling UnicitySeal: expected tag 1001, got 1000") + require.EqualError(t, err, "unmarshaling UnicitySeal: expected tag 39005, got 1000") }) t.Run("Invalid encoding", func(t *testing.T) { diff --git a/types/unicity_tree_certificate.go b/types/unicity_tree_certificate.go index 4c01ce3..4796d60 100644 --- a/types/unicity_tree_certificate.go +++ b/types/unicity_tree_certificate.go @@ -92,20 +92,22 @@ func (utc *UnicityTreeCertificate) GetVersion() Version { return 1 } +func NewUnicityTreeCertificate() *UnicityTreeCertificate { + return &UnicityTreeCertificate{Version: 1} +} + func (utc *UnicityTreeCertificate) MarshalCBOR() ([]byte, error) { type alias UnicityTreeCertificate - if utc.Version == 0 { - utc.Version = utc.GetVersion() + cp := *utc + if cp.Version == 0 { + cp.Version = 1 } - return Cbor.MarshalTaggedValue(UnicityTreeCertificateTag, (*alias)(utc)) + return Cbor.MarshalTaggedValue(UnicityTreeCertificateTag, (*alias)(&cp)) } func (utc *UnicityTreeCertificate) UnmarshalCBOR(data []byte) error { type alias UnicityTreeCertificate - if err := Cbor.UnmarshalTaggedValue(UnicityTreeCertificateTag, data, (*alias)(utc)); err != nil { - return err - } - return EnsureVersion(utc, utc.Version, 1) + return UnmarshalTaggedVersioned(UnicityTreeCertificateTag, 1, data, (*alias)(utc), utc) } func (p *PathItem) ToIMTPathItem() *imt.PathItem { diff --git a/types/unit_proof.go b/types/unit_proof.go index 11cd315..d8fc3f7 100644 --- a/types/unit_proof.go +++ b/types/unit_proof.go @@ -235,18 +235,20 @@ func (u *UnitStateProof) GetVersion() Version { return 1 } +func NewUnitStateProof() *UnitStateProof { + return &UnitStateProof{Version: 1} +} + func (u *UnitStateProof) MarshalCBOR() ([]byte, error) { type alias UnitStateProof - if u.Version == 0 { - u.Version = u.GetVersion() + cp := *u + if cp.Version == 0 { + cp.Version = 1 } - return Cbor.MarshalTaggedValue(UnitStateProofTag, (*alias)(u)) + return Cbor.MarshalTaggedValue(UnitStateProofTag, (*alias)(&cp)) } func (u *UnitStateProof) UnmarshalCBOR(data []byte) error { type alias UnitStateProof - if err := Cbor.UnmarshalTaggedValue(UnitStateProofTag, data, (*alias)(u)); err != nil { - return err - } - return EnsureVersion(u, u.Version, 1) + return UnmarshalTaggedVersioned(UnitStateProofTag, 1, data, (*alias)(u), u) } diff --git a/types/versions.go b/types/versions.go index e7562a0..3ce963f 100644 --- a/types/versions.go +++ b/types/versions.go @@ -16,25 +16,21 @@ type Versioned interface { } const ( - _ = iota + CborTag(1000) - UnicitySealTag - RootGenesisTag - GenesisRootRecordTag - ConsensusParamsTag - GenesisPartitionRecordTag - PartitionNodeTag - UnicityCertificateTag - InputRecordTag - TxProofTag - UnitStateProofTag - PartitionDescriptionRecordTag - BlockTag - RootTrustBaseTag - UnicityTreeCertificateTag - TransactionRecordTag - TransactionOrderTag - RootPartitionBlockDataTag - RootPartitionRoundInfoTag + // https://github.com/unicitynetwork/unicity-ids/blob/main/cbor-tags.json + UnicityTrustBaseTag CborTag = 39000 + UnicityCertificateTag CborTag = 39001 + InputRecordTag CborTag = 39002 + ShardTreeCertificateTag CborTag = 39003 + UnicityTreeCertificateTag CborTag = 39004 + UnicitySealTag CborTag = 39005 + RootPartitionBlockDataTag CborTag = 39006 + RootPartitionRoundInfoTag CborTag = 39007 + PartitionDescriptionRecordTag CborTag = 39008 + BlockTag CborTag = 39009 + TransactionRecordTag CborTag = 39010 + TransactionOrderTag CborTag = 39011 + TxProofTag CborTag = 39012 + UnitStateProofTag CborTag = 39013 ) func ErrInvalidVersion(s Versioned) error { @@ -49,6 +45,32 @@ func EnsureVersion(data Versioned, actual, expected Version) error { return nil } +// UnmarshalTaggedVersioned decodes tagged CBOR into aliasPtr and verifies that +// v.GetVersion() equals expectedVersion. It centralizes the decode-then-check +// pattern used by every Versioned type in this package. +func UnmarshalTaggedVersioned[A any](tag CborTag, expectedVersion Version, data []byte, aliasPtr *A, v Versioned) error { + if err := Cbor.UnmarshalTaggedValue(tag, data, aliasPtr); err != nil { + return err + } + if got := v.GetVersion(); got != expectedVersion { + return fmt.Errorf("invalid version (type %T), expected %d, got %d", v, expectedVersion, got) + } + return nil +} + +// UnmarshalVersioned decodes untagged CBOR into aliasPtr and verifies that +// v.GetVersion() equals expectedVersion. Used by unit data types that are +// CBOR-embedded inside a larger structure and carry no outer tag. +func UnmarshalVersioned[A any](expectedVersion Version, data []byte, aliasPtr *A, v Versioned) error { + if err := Cbor.Unmarshal(data, aliasPtr); err != nil { + return err + } + if got := v.GetVersion(); got != expectedVersion { + return fmt.Errorf("invalid version (type %T), expected %d, got %d", v, expectedVersion, got) + } + return nil +} + func parseTaggedCBOR(b []byte, objID CborTag) (Version, []any, error) { tag, arr, err := Cbor.UnmarshalTagged(b) if err != nil {