diff --git a/README.md b/README.md index b5e25cf..bf42658 100644 --- a/README.md +++ b/README.md @@ -366,7 +366,6 @@ for d in 255..=0: h = H(0x01 || d || h || siblings[j]) assert j == 0 and h == UC.IR.h ``` -``` #### `get_block_height` Retrieve the current blockchain height. diff --git a/cmd/performance-test/main.go b/cmd/performance-test/main.go index e3f2c0d..1fbdf95 100644 --- a/cmd/performance-test/main.go +++ b/cmd/performance-test/main.go @@ -5,12 +5,10 @@ import ( "bytes" "context" "crypto/rand" - "encoding/hex" "encoding/json" "fmt" "log" "math" - "math/big" "os" "path/filepath" "runtime" @@ -198,32 +196,19 @@ func selectShardIndex(requestID api.StateID, shardClients []*ShardClient) int { if len(imprint) == 0 { return 0 } - return int(imprint[len(imprint)-1]) % shardCount + keyBytes := requestID.DataBytes() + if len(keyBytes) == 0 { + return 0 + } + // Fallback only: keep distribution aligned with LSB-first key layout. + return int(keyBytes[0]) % shardCount } func matchesShardMask(requestIDHex string, shardMask int) (bool, error) { if shardMask <= 0 { return false, nil } - - bytes, err := hex.DecodeString(requestIDHex) - if err != nil { - return false, fmt.Errorf("failed to decode request ID: %w", err) - } - - requestBig := new(big.Int).SetBytes(bytes) - maskBig := new(big.Int).SetInt64(int64(shardMask)) - - msbPos := maskBig.BitLen() - 1 - if msbPos < 0 { - return false, fmt.Errorf("invalid shard mask: %d", shardMask) - } - - compareMask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(msbPos)), big.NewInt(1)) - expected := new(big.Int).And(maskBig, compareMask) - requestLowBits := new(big.Int).And(requestBig, compareMask) - - return requestLowBits.Cmp(expected) == 0, nil + return api.MatchesShardPrefixFromHex(requestIDHex, shardMask) } // matchesAnyShardTarget checks if a request ID matches any of the configured shard targets @@ -409,7 +394,7 @@ func commitmentWorker(ctx context.Context, shardClients []*ShardClient, metrics if proofQueue != nil { metrics.recordSubmissionTimestamp(requestIDStr) select { - case proofQueue <- proofJob{shardIdx: shardIdx, requestID: requestIDStr}: + case proofQueue <- proofJob{shardIdx: shardIdx, request: req}: default: // Queue full, skip proof verification for this one } @@ -435,10 +420,21 @@ func proofVerificationWorker(ctx context.Context, shardClients []*ShardClient, m case <-ctx.Done(): return case job := <-proofQueue: - go func(reqID string, shardIdx int) { + go func(job proofJob) { + if job.request == nil { + metrics.recordError("Missing original request for proof verification") + atomic.AddInt64(&metrics.proofVerifyFailed, 1) + if sm := metrics.shard(job.shardIdx); sm != nil { + sm.proofVerifyFailed.Add(1) + } + return + } + + reqID := normalizeRequestID(job.request.StateID.String()) + shardIdx := job.shardIdx time.Sleep(proofInitialDelay) startTime := time.Now() - normalizedID := normalizeRequestID(reqID) + normalizedID := reqID client := shardClients[shardIdx].proofClient // Use separate proof client pool for attempt := 0; attempt < proofMaxRetries; attempt++ { @@ -534,16 +530,12 @@ func proofVerificationWorker(ctx context.Context, shardClients []*ShardClient, m } metrics.addProofLatency(totalLatency) - // TODO: Wire api.InclusionProofV2.Verify(*CertificationRequest) - // verification here. For now we only check that the response carries a - // non-empty inclusion cert; perf tests care about throughput, not - // cryptographic verification correctness. - if len(proofResp.InclusionProof.CertificateBytes) == 0 { + if err := proofResp.InclusionProof.Verify(job.request); err != nil { if attempt < proofMaxRetries-1 { time.Sleep(proofRetryDelay) continue } - metrics.recordError(fmt.Sprintf("Empty certificate bytes for request %s", reqID)) + metrics.recordError(fmt.Sprintf("Proof verification failed for request %s: %v", reqID, err)) atomic.AddInt64(&metrics.proofVerifyFailed, 1) if sm := metrics.shard(shardIdx); sm != nil { sm.proofVerifyFailed.Add(1) @@ -563,7 +555,7 @@ func proofVerificationWorker(ctx context.Context, shardClients []*ShardClient, m if sm := metrics.shard(shardIdx); sm != nil { sm.proofFailed.Add(1) } - }(job.requestID, job.shardIdx) + }(job) } } } diff --git a/cmd/performance-test/types.go b/cmd/performance-test/types.go index ff6c8ca..55b4246 100644 --- a/cmd/performance-test/types.go +++ b/cmd/performance-test/types.go @@ -17,6 +17,7 @@ import ( "sync/atomic" "time" + "github.com/unicitynetwork/aggregator-go/pkg/api" "golang.org/x/net/http2" ) @@ -109,8 +110,8 @@ type ShardClient struct { } type proofJob struct { - shardIdx int - requestID string + shardIdx int + request *api.CertificationRequest } // RequestRateCounters tracks per-second client-side request activity. diff --git a/internal/bft/client_stub_test.go b/internal/bft/client_stub_test.go index b6a176c..a584d85 100644 --- a/internal/bft/client_stub_test.go +++ b/internal/bft/client_stub_test.go @@ -47,7 +47,6 @@ func TestBFTClientStub_CertificationRequest_PopulatesSyntheticUC(t *testing.T) { api.HexBytes("0123"), nil, nil, - nil, ) err = client.CertificationRequest(t.Context(), block) diff --git a/internal/ha/block_syncer_test.go b/internal/ha/block_syncer_test.go index eb87fc8..80f22f0 100644 --- a/internal/ha/block_syncer_test.go +++ b/internal/ha/block_syncer_test.go @@ -122,7 +122,7 @@ func createBlock(t *testing.T, storage *mongodb.Storage, blockNum int64) api.Hex rootHash := api.HexBytes(tmpSMT.GetRootHashRaw()) // persist block - block := models.NewBlock(blockNumber, "unicity", 0, "1.0", "mainnet", rootHash, nil, nil, nil) + block := models.NewBlock(blockNumber, "unicity", 0, "1.0", "mainnet", rootHash, nil, nil) block.Finalized = true // Mark as finalized so GetLatestNumber finds it err = storage.BlockStorage().Store(ctx, block) require.NoError(t, err) diff --git a/internal/models/block.go b/internal/models/block.go index 7fa706b..72495d5 100644 --- a/internal/models/block.go +++ b/internal/models/block.go @@ -1,7 +1,6 @@ package models import ( - "encoding/json" "fmt" "time" @@ -12,18 +11,19 @@ import ( // Block represents a blockchain block type Block struct { - Index *api.BigInt `json:"index"` - ChainID string `json:"chainId"` - ShardID api.ShardID `json:"shardId"` - Version string `json:"version"` - ForkID string `json:"forkId"` - RootHash api.HexBytes `json:"rootHash"` - PreviousBlockHash api.HexBytes `json:"previousBlockHash"` - NoDeletionProofHash api.HexBytes `json:"noDeletionProofHash"` - CreatedAt *api.Timestamp `json:"createdAt"` - UnicityCertificate api.HexBytes `json:"unicityCertificate"` - ParentMerkleTreePath *api.MerkleTreePath `json:"parentMerkleTreePath,omitempty"` // child mode only - Finalized bool `json:"finalized"` // true when all data is persisted + Index *api.BigInt `json:"index"` + ChainID string `json:"chainId"` + ShardID api.ShardID `json:"shardId"` + Version string `json:"version"` + ForkID string `json:"forkId"` + RootHash api.HexBytes `json:"rootHash"` + PreviousBlockHash api.HexBytes `json:"previousBlockHash"` + NoDeletionProofHash api.HexBytes `json:"noDeletionProofHash"` + CreatedAt *api.Timestamp `json:"createdAt"` + UnicityCertificate api.HexBytes `json:"unicityCertificate"` + ParentFragment *api.ParentInclusionFragment `json:"parentFragment,omitempty"` // child mode only + ParentBlockNumber uint64 `json:"parentBlockNumber,omitempty"` // child mode only + Finalized bool `json:"finalized"` // true when all data is persisted } // BlockBSON represents the BSON version of Block for MongoDB storage @@ -38,23 +38,29 @@ type BlockBSON struct { NoDeletionProofHash string `bson:"noDeletionProofHash,omitempty"` CreatedAt time.Time `bson:"createdAt"` UnicityCertificate string `bson:"unicityCertificate"` - MerkleTreePath string `bson:"merkleTreePath,omitempty"` // child mode only + ParentFragment *ParentFragmentBSON `bson:"parentFragment,omitempty"` // child mode only + ParentBlockNumber uint64 `bson:"parentBlockNumber,omitempty"` Finalized bool `bson:"finalized"` } +// ParentFragmentBSON is the BSON representation of ParentInclusionFragment. +type ParentFragmentBSON struct { + CertificateBytes []byte `bson:"certificateBytes"` + ShardLeafValue []byte `bson:"shardLeafValue"` +} + // ToBSON converts Block to BlockBSON for MongoDB storage func (b *Block) ToBSON() (*BlockBSON, error) { indexDecimal, err := primitive.ParseDecimal128(b.Index.String()) if err != nil { return nil, fmt.Errorf("error converting block index to decimal-128: %w", err) } - var merkleTreePath string - if b.ParentMerkleTreePath != nil { - merkleTreePathJson, err := json.Marshal(b.ParentMerkleTreePath) - if err != nil { - return nil, fmt.Errorf("failed to marshal parent merkle tree path: %w", err) + var parentFragment *ParentFragmentBSON + if b.ParentFragment != nil { + parentFragment = &ParentFragmentBSON{ + CertificateBytes: append([]byte(nil), b.ParentFragment.CertificateBytes...), + ShardLeafValue: append([]byte(nil), b.ParentFragment.ShardLeafValue...), } - merkleTreePath = api.NewHexBytes(merkleTreePathJson).String() } return &BlockBSON{ Index: indexDecimal, @@ -67,7 +73,8 @@ func (b *Block) ToBSON() (*BlockBSON, error) { NoDeletionProofHash: b.NoDeletionProofHash.String(), CreatedAt: b.CreatedAt.Time, UnicityCertificate: b.UnicityCertificate.String(), - MerkleTreePath: merkleTreePath, + ParentFragment: parentFragment, + ParentBlockNumber: b.ParentBlockNumber, Finalized: b.Finalized, }, nil } @@ -94,15 +101,11 @@ func (bb *BlockBSON) FromBSON() (*Block, error) { return nil, fmt.Errorf("failed to parse unicityCertificate: %w", err) } - var parentMerkleTreePath *api.MerkleTreePath - if bb.MerkleTreePath != "" { - hexBytes, err := api.NewHexBytesFromString(bb.MerkleTreePath) - if err != nil { - return nil, fmt.Errorf("failed to parse parentMerkleTreePath: %w", err) - } - parentMerkleTreePath = &api.MerkleTreePath{} - if err := json.Unmarshal(hexBytes, parentMerkleTreePath); err != nil { - return nil, fmt.Errorf("failed to parse parentMerkleTreePath: %w", err) + var parentFragment *api.ParentInclusionFragment + if bb.ParentFragment != nil { + parentFragment = &api.ParentInclusionFragment{ + CertificateBytes: append([]byte(nil), bb.ParentFragment.CertificateBytes...), + ShardLeafValue: append([]byte(nil), bb.ParentFragment.ShardLeafValue...), } } @@ -112,33 +115,41 @@ func (bb *BlockBSON) FromBSON() (*Block, error) { } return &Block{ - Index: index, - ChainID: bb.ChainID, - ShardID: bb.ShardID, - Version: bb.Version, - ForkID: bb.ForkID, - RootHash: rootHash, - PreviousBlockHash: previousBlockHash, - NoDeletionProofHash: noDeletionProofHash, - CreatedAt: api.NewTimestamp(bb.CreatedAt), - UnicityCertificate: unicityCertificate, - ParentMerkleTreePath: parentMerkleTreePath, - Finalized: bb.Finalized, + Index: index, + ChainID: bb.ChainID, + ShardID: bb.ShardID, + Version: bb.Version, + ForkID: bb.ForkID, + RootHash: rootHash, + PreviousBlockHash: previousBlockHash, + NoDeletionProofHash: noDeletionProofHash, + CreatedAt: api.NewTimestamp(bb.CreatedAt), + UnicityCertificate: unicityCertificate, + ParentFragment: parentFragment, + ParentBlockNumber: bb.ParentBlockNumber, + Finalized: bb.Finalized, }, nil } // NewBlock creates a new block -func NewBlock(index *api.BigInt, chainID string, shardID api.ShardID, version, forkID string, rootHash, previousBlockHash, uc api.HexBytes, parentMerkleTreePath *api.MerkleTreePath) *Block { +func NewBlock(index *api.BigInt, chainID string, shardID api.ShardID, version, forkID string, rootHash, previousBlockHash, uc api.HexBytes) *Block { return &Block{ - Index: index, - ChainID: chainID, - ShardID: shardID, - Version: version, - ForkID: forkID, - RootHash: rootHash, - PreviousBlockHash: previousBlockHash, - CreatedAt: api.Now(), - UnicityCertificate: uc, - ParentMerkleTreePath: parentMerkleTreePath, + Index: index, + ChainID: chainID, + ShardID: shardID, + Version: version, + ForkID: forkID, + RootHash: rootHash, + PreviousBlockHash: previousBlockHash, + CreatedAt: api.Now(), + UnicityCertificate: uc, } } + +// NewChildBlock creates a block for child mode with required parent proof metadata. +func NewChildBlock(index *api.BigInt, chainID string, shardID api.ShardID, version, forkID string, rootHash, previousBlockHash, uc api.HexBytes, parentFragment *api.ParentInclusionFragment, parentBlockNumber uint64) *Block { + block := NewBlock(index, chainID, shardID, version, forkID, rootHash, previousBlockHash, uc) + block.ParentFragment = parentFragment + block.ParentBlockNumber = parentBlockNumber + return block +} diff --git a/internal/models/block_test.go b/internal/models/block_test.go index 777962b..83431c5 100644 --- a/internal/models/block_test.go +++ b/internal/models/block_test.go @@ -35,8 +35,10 @@ func createTestBlock() *Block { NoDeletionProofHash: randomHash, CreatedAt: api.Now(), UnicityCertificate: randomHash, - ParentMerkleTreePath: &api.MerkleTreePath{ - Root: randomHash.String(), + ParentFragment: &api.ParentInclusionFragment{ + CertificateBytes: randomHash, + ShardLeafValue: randomHash, }, + ParentBlockNumber: 7, } } diff --git a/internal/round/batch_processor.go b/internal/round/batch_processor.go index e40f449..269caa8 100644 --- a/internal/round/batch_processor.go +++ b/internal/round/batch_processor.go @@ -127,7 +127,6 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn rootHash, parentHash, nil, - nil, ) rm.roundMutex.RLock() if rm.currentRound != nil && !rm.currentRound.StartTime.IsZero() { @@ -175,7 +174,7 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn err error ) for { - proof, parentUC, err = rm.pollForParentProof(ctx, rootHash.String()) + proof, parentUC, err = rm.pollForParentProof(ctx, rootHash) if err == nil { break } @@ -199,7 +198,11 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn } rm.roundMutex.Unlock() - block := models.NewBlock( + if proof.ParentFragment == nil { + return fmt.Errorf("parent shard proof missing native parent fragment") + } + + block := models.NewChildBlock( blockNumber, rm.config.Chain.ID, request.ShardID, @@ -208,7 +211,8 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn rootHash, parentHash, proof.UnicityCertificate, - proof.MerkleTreePath, + proof.ParentFragment, + proof.BlockNumber, ) if err := rm.FinalizeBlockWithRetry(ctx, block); err != nil { return fmt.Errorf("failed to finalize block after retries: %w", err) @@ -250,7 +254,7 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn } } -func (rm *RoundManager) pollForParentProof(ctx context.Context, rootHash string) (*api.RootShardInclusionProof, *types.UnicityCertificate, error) { +func (rm *RoundManager) pollForParentProof(ctx context.Context, rootHash api.HexBytes) (*api.RootShardInclusionProof, *types.UnicityCertificate, error) { pollingCtx, cancel := context.WithTimeout(ctx, rm.config.Sharding.Child.ParentPollTimeout) defer cancel() @@ -266,27 +270,27 @@ func (rm *RoundManager) pollForParentProof(ctx context.Context, rootHash string) } metrics.ParentProofErrorsTotal.Inc() rm.logger.WithContext(ctx).Warn("Timed out waiting for parent shard inclusion proof", - "rootHash", rootHash, + "rootHash", rootHash.String(), "pollDuration", time.Since(pollStart)) - return nil, nil, fmt.Errorf("%w: %s", ErrParentProofPollTimeout, rootHash) + return nil, nil, fmt.Errorf("%w: %s", ErrParentProofPollTimeout, rootHash.String()) case <-ticker.C: request := &api.GetShardProofRequest{ShardID: rm.config.Sharding.Child.ShardID} proof, err := rm.rootClient.GetShardProof(pollingCtx, request) if err != nil { metrics.ParentProofErrorsTotal.Inc() rm.logger.WithContext(ctx).Warn("Failed to fetch parent shard inclusion proof, retrying", - "rootHash", rootHash, + "rootHash", rootHash.String(), "error", err.Error()) continue } - if proof == nil || !proof.IsValid(rootHash) { + if proof == nil || !proof.IsValid(rm.config.Sharding.Child.ShardID, rm.config.Sharding.ShardIDLength, rootHash) { continue } parentUC, err := decodeUnicityCertificate(proof.UnicityCertificate) if err != nil { rm.logger.WithContext(ctx).Warn("Failed to decode parent shard proof UC, retrying", - "rootHash", rootHash, + "rootHash", rootHash.String(), "error", err.Error()) continue } @@ -294,7 +298,7 @@ func (rm *RoundManager) pollForParentProof(ctx context.Context, rootHash string) lastParentRound := rm.lastAcceptedParentUC() if parentUC.GetRoundNumber() <= lastParentRound { rm.logger.WithContext(ctx).Debug("Ignoring stale parent shard proof", - "rootHash", rootHash, + "rootHash", rootHash.String(), "proofParentRound", parentUC.GetRoundNumber(), "lastAcceptedParentRound", lastParentRound) continue @@ -384,6 +388,10 @@ func (rm *RoundManager) FinalizeBlockWithRetry(ctx context.Context, block *model // FinalizeBlock creates and persists a new block with the given data func (rm *RoundManager) FinalizeBlock(ctx context.Context, block *models.Block) error { + if err := rm.validateBlockForMode(block); err != nil { + return err + } + rm.logger.WithContext(ctx).Info("FinalizeBlock called", "blockNumber", block.Index.String(), "rootHash", block.RootHash.String(), @@ -584,6 +592,21 @@ func (rm *RoundManager) FinalizeBlock(ctx context.Context, block *models.Block) return nil } +func (rm *RoundManager) validateBlockForMode(block *models.Block) error { + if block == nil { + return errors.New("block is nil") + } + if rm.config.Sharding.Mode.IsChild() { + if block.ParentFragment == nil { + return errors.New("child-mode block missing parent fragment") + } + if block.ParentBlockNumber == 0 { + return errors.New("child-mode block missing parent block number") + } + } + return nil +} + // convertLeavesToNodes converts SMT leaves to storage models func (rm *RoundManager) convertLeavesToNodes(leaves []*smt.Leaf) ([]*models.SmtNode, error) { if len(leaves) == 0 { diff --git a/internal/round/finalize_duplicate_test.go b/internal/round/finalize_duplicate_test.go index 270bb14..ac54932 100644 --- a/internal/round/finalize_duplicate_test.go +++ b/internal/round/finalize_duplicate_test.go @@ -122,7 +122,6 @@ func (s *FinalizeDuplicateTestSuite) Test1_DuplicateRecovery() { rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) // FinalizeBlock should succeed despite duplicates @@ -181,7 +180,6 @@ func (s *FinalizeDuplicateTestSuite) Test2_NoDuplicates() { rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) // Should succeed on first try (no duplicates) @@ -247,7 +245,6 @@ func (s *FinalizeDuplicateTestSuite) Test3_AllDuplicates() { rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) // Should succeed even when all records are duplicates @@ -305,7 +302,6 @@ func (s *FinalizeDuplicateTestSuite) Test4_DuplicateBlock() { rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) // Pre-store the block (simulating previous attempt that stored block but failed on MarkProcessed) @@ -387,7 +383,6 @@ func (s *FinalizeDuplicateTestSuite) Test5_DuplicateBlockAlreadyFinalized() { rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) // Pre-store the block as FINALIZED (simulating previous successful attempt except MarkProcessed) @@ -487,7 +482,6 @@ func (s *FinalizeDuplicateTestSuite) Test6_BlockRecordsMatchPendingCommitmentsOn rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) err = rm.FinalizeBlock(ctx, block) diff --git a/internal/round/parent_round_manager.go b/internal/round/parent_round_manager.go index 14b05ad..dc87da8 100644 --- a/internal/round/parent_round_manager.go +++ b/internal/round/parent_round_manager.go @@ -345,7 +345,6 @@ func (prm *ParentRoundManager) processRound(ctx context.Context, round *ParentRo parentRootHash, previousBlockHash, nil, - nil, ) round.Block = block diff --git a/internal/round/precollection_test.go b/internal/round/precollection_test.go index 09aab78..9b7019a 100644 --- a/internal/round/precollection_test.go +++ b/internal/round/precollection_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/unicitynetwork/bft-go-base/types" + "github.com/unicitynetwork/bft-go-base/types/hex" "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/events" @@ -175,7 +176,6 @@ func (c *staleThenFreshRootAggregatorClient) GetShardProof(ctx context.Context, } c.proofPolls++ - root := c.submittedRoot.String() parentRound := c.staleParentRound rootRound := c.staleRootRound @@ -186,16 +186,18 @@ func (c *staleThenFreshRootAggregatorClient) GetShardProof(ctx context.Context, default: } - uc, err := testProofUC(parentRound, rootRound) + uc, err := testProofUC(parentRound, rootRound, c.submittedRoot) if err != nil { return nil, err } return &api.RootShardInclusionProof{ - UnicityCertificate: uc, - MerkleTreePath: &api.MerkleTreePath{ - Steps: []api.MerkleTreeStep{{Data: &root}}, + ParentFragment: &api.ParentInclusionFragment{ + CertificateBytes: api.NewHexBytes(make([]byte, api.BitmapSize)), + ShardLeafValue: api.NewHexBytes(c.submittedRoot), }, + BlockNumber: parentRound, + UnicityCertificate: uc, }, nil } @@ -209,10 +211,11 @@ func (c *staleThenFreshRootAggregatorClient) ProofPolls() int { return c.proofPolls } -func testProofUC(parentRound, rootRound uint64) (api.HexBytes, error) { +func testProofUC(parentRound, rootRound uint64, rootHash api.HexBytes) (api.HexBytes, error) { uc := types.UnicityCertificate{ InputRecord: &types.InputRecord{ RoundNumber: parentRound, + Hash: hex.Bytes(rootHash), }, UnicitySeal: &types.UnicitySeal{ RootChainRoundNumber: rootRound, @@ -624,7 +627,8 @@ func TestChildPrecollector_DeactivateDuringInFlightRound(t *testing.T) { MaxCommitmentsPerRound: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, @@ -697,7 +701,8 @@ func TestChildRound_ParentProofTimeoutIsRetriable(t *testing.T) { MaxCommitmentsPerRound: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 100 * time.Millisecond, @@ -833,7 +838,8 @@ func TestPipelinedChildModeFlow(t *testing.T) { MaxCommitmentsPerRound: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, @@ -903,7 +909,8 @@ func TestChildPreCollection_CommitmentAfterProofBeforeRoundEnd_ShouldBeInNextRou MaxCommitmentsPerRound: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, @@ -972,7 +979,8 @@ func TestChildMode_RequiresFreshParentProof(t *testing.T) { MaxCommitmentsPerRound: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, @@ -988,7 +996,7 @@ func TestChildMode_RequiresFreshParentProof(t *testing.T) { rootHash, err := api.NewHexBytesFromString(smtInstance.GetRootHash()) require.NoError(t, err) - initialUC, err := testProofUC(1, 10) + initialUC, err := testProofUC(1, 10, rootHash) require.NoError(t, err) initialBlock := models.NewBlock( api.NewBigInt(big.NewInt(1)), @@ -999,7 +1007,6 @@ func TestChildMode_RequiresFreshParentProof(t *testing.T) { rootHash, api.HexBytes{}, initialUC, - nil, ) initialBlock.Finalized = true require.NoError(t, storage.BlockStorage().Store(ctx, initialBlock)) diff --git a/internal/round/recovery_test.go b/internal/round/recovery_test.go index 72d43cb..3254987 100644 --- a/internal/round/recovery_test.go +++ b/internal/round/recovery_test.go @@ -149,7 +149,7 @@ func (s *RecoveryTestSuite) createTestData(blockNum int64, commitmentCount int, rootHashBytes := smtTree.GetRootHash() // Create block (unfinalized) - block := models.NewBlock(blockNumber, "unicity", 0, "1.0", "mainnet", api.HexBytes(rootHashBytes), nil, nil, nil) + block := models.NewBlock(blockNumber, "unicity", 0, "1.0", "mainnet", api.HexBytes(rootHashBytes), nil, nil) block.Finalized = false return commitments, block, requestIDs diff --git a/internal/round/round_manager_test.go b/internal/round/round_manager_test.go index f261048..6166359 100644 --- a/internal/round/round_manager_test.go +++ b/internal/round/round_manager_test.go @@ -30,7 +30,8 @@ func TestParentShardIntegration_GoodCase(t *testing.T) { BatchLimit: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, @@ -79,7 +80,8 @@ func TestParentShardIntegration_RoundProcessingError(t *testing.T) { BatchLimit: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, diff --git a/internal/round/smt_persistence_integration_test.go b/internal/round/smt_persistence_integration_test.go index 74b3408..e7cd120 100644 --- a/internal/round/smt_persistence_integration_test.go +++ b/internal/round/smt_persistence_integration_test.go @@ -229,7 +229,6 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { rootHashBytes, api.HexBytes{}, nil, - nil, ) err = rm.FinalizeBlock(ctx, block) @@ -296,18 +295,17 @@ func TestSmtRestorationWithBlockVerification(t *testing.T) { // Create a block with the expected raw root hash block := &models.Block{ - Index: api.NewBigInt(big.NewInt(1)), - ChainID: "test-chain", - ShardID: 0, - Version: "1.0.0", - ForkID: "test-fork", - RootHash: api.HexBytes(expectedRootHashRaw), - PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000000"), - NoDeletionProofHash: api.HexBytes(""), - CreatedAt: api.NewTimestamp(time.Now()), - UnicityCertificate: api.HexBytes("certificate_data"), - ParentMerkleTreePath: nil, - Finalized: true, // Mark as finalized for test + Index: api.NewBigInt(big.NewInt(1)), + ChainID: "test-chain", + ShardID: 0, + Version: "1.0.0", + ForkID: "test-fork", + RootHash: api.HexBytes(expectedRootHashRaw), + PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000000"), + NoDeletionProofHash: api.HexBytes(""), + CreatedAt: api.NewTimestamp(time.Now()), + UnicityCertificate: api.HexBytes("certificate_data"), + Finalized: true, // Mark as finalized for test } // Store the block @@ -345,18 +343,17 @@ func TestSmtRestorationWithBlockVerification(t *testing.T) { t.Run("FailedVerification", func(t *testing.T) { // Create a block with a different root hash to simulate mismatch wrongBlock := &models.Block{ - Index: api.NewBigInt(big.NewInt(2)), - ChainID: "test-chain", - ShardID: 0, - Version: "1.0.0", - ForkID: "test-fork", - RootHash: api.HexBytes("wrong_root_hash_value"), // Intentionally wrong hash - PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000001"), - NoDeletionProofHash: api.HexBytes(""), - CreatedAt: api.NewTimestamp(time.Now()), - UnicityCertificate: api.HexBytes("certificate_data"), - ParentMerkleTreePath: nil, - Finalized: true, // Mark as finalized for test + Index: api.NewBigInt(big.NewInt(2)), + ChainID: "test-chain", + ShardID: 0, + Version: "1.0.0", + ForkID: "test-fork", + RootHash: api.HexBytes("wrong_root_hash_value"), // Intentionally wrong hash + PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000001"), + NoDeletionProofHash: api.HexBytes(""), + CreatedAt: api.NewTimestamp(time.Now()), + UnicityCertificate: api.HexBytes("certificate_data"), + Finalized: true, // Mark as finalized for test } // Store the wrong block (this will become the "latest" block) diff --git a/internal/service/parent_service.go b/internal/service/parent_service.go index d274051..1c4f382 100644 --- a/internal/service/parent_service.go +++ b/internal/service/parent_service.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "math/big" "github.com/unicitynetwork/bft-go-base/types" @@ -163,6 +162,9 @@ func (pas *ParentAggregatorService) GetShardProof(ctx context.Context, req *api. pas.logger.WithContext(ctx).Debug("Serving shard proof while node is follower") } } + if !pas.parentRoundManager.IsReady() { + return nil, errors.New("parent round manager is not ready to serve shard proofs") + } if err := pas.validateShardID(req.ShardID); err != nil { pas.logger.WithContext(ctx).Warn("Invalid shard ID", "shardId", req.ShardID, "error", err.Error()) @@ -171,29 +173,24 @@ func (pas *ParentAggregatorService) GetShardProof(ctx context.Context, req *api. pas.logger.WithContext(ctx).Debug("Shard proof requested", "shardId", req.ShardID) - shardPath := new(big.Int).SetInt64(int64(req.ShardID)) - merkleTreePath, err := pas.parentRoundManager.GetSMT().GetPath(shardPath) + parentFragment, rootHashRaw, err := pas.parentRoundManager.GetSMT().GetShardInclusionFragmentWithRoot(req.ShardID) if err != nil { - return nil, fmt.Errorf("failed to get merkle tree path: %w", err) + return nil, fmt.Errorf("failed to get shard inclusion fragment: %w", err) } - - var proofPath *api.MerkleTreePath - if len(merkleTreePath.Steps) > 0 && merkleTreePath.Steps[0].Data != nil { - proofPath = merkleTreePath - pas.logger.WithContext(ctx).Info("Generated shard proof from current state", - "shardId", req.ShardID, - "rootHash", merkleTreePath.Root) - } else { - proofPath = nil + if parentFragment == nil { pas.logger.WithContext(ctx).Info("Shard has not submitted root yet, returning nil proof", "shardId", req.ShardID) } var unicityCertificate api.HexBytes - if proofPath != nil { + var blockNumber uint64 + if parentFragment != nil { + // GetShardInclusionFragmentWithRoot snapshots fragment+root under one RLock. + // Finalization persists block data before committing the SMT snapshot, + // so a visible root is expected to have a corresponding finalized block. // Blocks are stored under the raw 32-byte SMT root matching // UC.IR.h, so look up the block by the raw root hash. - rootHash := api.HexBytes(pas.parentRoundManager.GetSMT().GetRootHashRaw()) + rootHash := api.HexBytes(rootHashRaw) block, err := pas.storage.BlockStorage().GetLatestByRootHash(ctx, rootHash) if err != nil { return nil, fmt.Errorf("failed to get latest block by root hash: %w", err) @@ -202,11 +199,13 @@ func (pas *ParentAggregatorService) GetShardProof(ctx context.Context, req *api. return nil, fmt.Errorf("no block found with root hash %s", rootHash.String()) } unicityCertificate = block.UnicityCertificate + blockNumber = block.Index.Uint64() } return &api.GetShardProofResponse{ - MerkleTreePath: proofPath, + ParentFragment: parentFragment, UnicityCertificate: unicityCertificate, + BlockNumber: blockNumber, }, nil } diff --git a/internal/service/parent_service_test.go b/internal/service/parent_service_test.go index 54790f9..32f844c 100644 --- a/internal/service/parent_service_test.go +++ b/internal/service/parent_service_test.go @@ -136,7 +136,7 @@ func (suite *ParentServiceTestSuite) waitForShardToExist(ctx context.Context, sh suite.T().Fatalf("Timeout waiting for shard %d to be processed", shardID) case <-ticker.C: response, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shardID}) - if err == nil && response.MerkleTreePath != nil { + if err == nil && response.ParentFragment != nil { return // Shard exists and has a proof! } // Continue polling @@ -283,7 +283,8 @@ func (suite *ParentServiceTestSuite) TestSubmitShardRoot_UpdateQueued() { suite.T().Log("✓ Shard update was queued and processed correctly") } -// GetShardProof - Success - Full E2E test with child SMT, proof joining, and verification +// GetShardProof - Success - Full E2E test with child SMT root submission and +// native parent fragment verification. func (suite *ParentServiceTestSuite) TestGetShardProof_Success() { ctx := context.Background() @@ -298,16 +299,9 @@ func (suite *ParentServiceTestSuite) TestGetShardProof_Success() { err := childSMT.AddLeaf(testLeafPath, testLeafValue) suite.Require().NoError(err, "Should add leaf to child SMT") - // 2. Extract child root hash (what child aggregator would submit to parent) - childRootHash := childSMT.GetRootHash() - suite.Require().NotEmpty(childRootHash, "Child SMT should have root hash") - suite.T().Logf("Child SMT root hash: %x", childRootHash) - - // 3. Submit child root to parent (strip algorithm prefix - first 2 bytes) - // This is required for JoinPaths to work: parent stores raw 32-byte hashes - childRootRaw := childRootHash[2:] // Remove algorithm identifier (2 bytes) - suite.Require().True(len(childRootRaw) == 32, "Root hash should be 32 bytes after stripping prefix") - suite.T().Logf("Sending %d bytes to parent (WITHOUT algorithm prefix)", len(childRootRaw)) + // 2. Extract child root hash (what child aggregator submits to the parent). + childRootRaw := childSMT.GetRootHashRaw() + suite.Require().Len(childRootRaw, 32, "Child SMT raw root hash must be 32 bytes") submitReq := &api.SubmitShardRootRequest{ ShardID: shard0ID, @@ -319,42 +313,27 @@ func (suite *ParentServiceTestSuite) TestGetShardProof_Success() { // 4. Wait for round to process suite.waitForShardToExist(ctx, shard0ID) - // 5. Get child proof from child SMT - childProof, err := childSMT.GetPath(testLeafPath) - suite.Require().NoError(err, "Should get child proof") - suite.Require().NotNil(childProof, "Child proof should not be nil") - suite.T().Logf("Child proof has %d steps", len(childProof.Steps)) - - // 6. Request parent proof from parent aggregator + // 5. Request parent proof from parent aggregator proofReq := &api.GetShardProofRequest{ ShardID: shard0ID, } parentResponse, err := suite.service.GetShardProof(ctx, proofReq) suite.Require().NoError(err, "Should get parent proof successfully") suite.Require().NotNil(parentResponse, "Parent response should not be nil") - suite.Require().NotNil(parentResponse.MerkleTreePath, "Parent proof should not be nil") - suite.T().Logf("Parent proof has %d steps", len(parentResponse.MerkleTreePath.Steps)) - - // 7. Join child and parent proofs - joinedProof, err := smt.JoinPaths(childProof, parentResponse.MerkleTreePath) - suite.Require().NoError(err, "Should join proofs successfully") - suite.Require().NotNil(joinedProof, "Joined proof should not be nil") - suite.T().Logf("Joined proof has %d steps", len(joinedProof.Steps)) - - // 8. Verify the joined proof - result, err := joinedProof.Verify(testLeafPath) - suite.Require().NoError(err, "Proof verification should not error") - suite.Require().NotNil(result, "Verification result should not be nil") - - // Both PathValid and PathIncluded should be true - suite.Assert().True(result.PathValid, "Joined proof path must be valid") - suite.Assert().True(result.PathIncluded, "Joined proof should show path is included") - suite.Assert().True(result.Result, "Overall verification result should be true") - - suite.T().Log("✓ End-to-end test: child SMT → parent submission → proof joining → verification SUCCESS") + suite.Require().NotNil(parentResponse.ParentFragment, "Parent fragment should not be nil") + + // 6. Verify the native parent fragment against the returned parent UC root. + suite.Assert().Equal(api.NewHexBytes(childRootRaw), parentResponse.ParentFragment.ShardLeafValue) + suite.Assert().True((&api.RootShardInclusionProof{ + ParentFragment: parentResponse.ParentFragment, + UnicityCertificate: parentResponse.UnicityCertificate, + BlockNumber: parentResponse.BlockNumber, + }).IsValid(shard0ID, suite.cfg.Sharding.ShardIDLength, api.NewHexBytes(childRootRaw))) + + suite.T().Log("✓ End-to-end test: child SMT → parent submission → native parent fragment verification SUCCESS") } -// GetShardProof - Non-existent Shard (returns nil MerkleTreePath) +// GetShardProof - Non-existent Shard (returns nil fragment) func (suite *ParentServiceTestSuite) TestGetShardProof_NonExistentShard() { ctx := context.Background() @@ -379,9 +358,9 @@ func (suite *ParentServiceTestSuite) TestGetShardProof_NonExistentShard() { response, err := suite.service.GetShardProof(ctx, proofReq) suite.Require().NoError(err, "Should not return error for non-existent shard") suite.Require().NotNil(response, "Response should not be nil") - suite.Assert().Nil(response.MerkleTreePath, "MerkleTreePath should be nil for non-existent shard") + suite.Assert().Nil(response.ParentFragment, "Parent fragment should be nil for non-existent shard") - suite.T().Log("✓ GetShardProof returns nil MerkleTreePath for non-existent shard") + suite.T().Log("✓ GetShardProof returns nil parent fragment for non-existent shard") } // GetShardProof - Empty Tree (no shards submitted yet) @@ -397,9 +376,9 @@ func (suite *ParentServiceTestSuite) TestGetShardProof_EmptyTree() { response, err := suite.service.GetShardProof(ctx, proofReq) suite.Require().NoError(err, "Should not return error for empty tree") suite.Require().NotNil(response, "Response should not be nil") - suite.Assert().Nil(response.MerkleTreePath, "MerkleTreePath should be nil when no shards submitted") + suite.Assert().Nil(response.ParentFragment, "Parent fragment should be nil when no shards submitted") - suite.T().Log("✓ GetShardProof returns nil MerkleTreePath for empty tree") + suite.T().Log("✓ GetShardProof returns nil parent fragment for empty tree") } // GetShardProof - Multiple Shards (verify each has correct proof) @@ -437,21 +416,33 @@ func (suite *ParentServiceTestSuite) TestGetShardProof_MultipleShards() { // Get proofs for all 3 shards proof0, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shard0ID}) suite.Require().NoError(err, "Should get proof for shard 0") - suite.Assert().NotNil(proof0.MerkleTreePath, "Proof 0 should not be nil") + suite.Assert().NotNil(proof0.ParentFragment, "Fragment 0 should not be nil") proof1, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shard1ID}) suite.Require().NoError(err, "Should get proof for shard 1") - suite.Assert().NotNil(proof1.MerkleTreePath, "Proof 1 should not be nil") + suite.Assert().NotNil(proof1.ParentFragment, "Fragment 1 should not be nil") proof2, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shard2ID}) suite.Require().NoError(err, "Should get proof for shard 2") - suite.Assert().NotNil(proof2.MerkleTreePath, "Proof 2 should not be nil") - - // All proofs should have the same root (same parent SMT) - suite.Assert().Equal(proof0.MerkleTreePath.Root, proof1.MerkleTreePath.Root, "All proofs should have same root") - suite.Assert().Equal(proof0.MerkleTreePath.Root, proof2.MerkleTreePath.Root, "All proofs should have same root") - - suite.T().Log("✓ GetShardProof returns valid proofs for multiple shards with same root") + suite.Assert().NotNil(proof2.ParentFragment, "Fragment 2 should not be nil") + + suite.Assert().True((&api.RootShardInclusionProof{ + ParentFragment: proof0.ParentFragment, + UnicityCertificate: proof0.UnicityCertificate, + BlockNumber: proof0.BlockNumber, + }).IsValid(shard0ID, suite.cfg.Sharding.ShardIDLength, makeTestHash(0xAA))) + suite.Assert().True((&api.RootShardInclusionProof{ + ParentFragment: proof1.ParentFragment, + UnicityCertificate: proof1.UnicityCertificate, + BlockNumber: proof1.BlockNumber, + }).IsValid(shard1ID, suite.cfg.Sharding.ShardIDLength, makeTestHash(0xBB))) + suite.Assert().True((&api.RootShardInclusionProof{ + ParentFragment: proof2.ParentFragment, + UnicityCertificate: proof2.UnicityCertificate, + BlockNumber: proof2.BlockNumber, + }).IsValid(shard2ID, suite.cfg.Sharding.ShardIDLength, makeTestHash(0xCC))) + + suite.T().Log("✓ GetShardProof returns valid parent fragments for multiple shards") } // TestParentServiceSuite runs the test suite diff --git a/internal/service/service.go b/internal/service/service.go index 594a6d3..56f75df 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -14,7 +14,6 @@ import ( "github.com/unicitynetwork/aggregator-go/internal/round" "github.com/unicitynetwork/aggregator-go/internal/signing" signingV1 "github.com/unicitynetwork/aggregator-go/internal/signing/v1" - "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" "github.com/unicitynetwork/aggregator-go/internal/trustbase" "github.com/unicitynetwork/aggregator-go/pkg/api" @@ -126,17 +125,18 @@ func modelToAPIAggregatorRecordV1(modelRecord *models.AggregatorRecord) *api.Agg func modelToAPIBlock(modelBlock *models.Block) *api.Block { return &api.Block{ - Index: modelBlock.Index, - ChainID: modelBlock.ChainID, - ShardID: modelBlock.ShardID, - Version: modelBlock.Version, - ForkID: modelBlock.ForkID, - RootHash: modelBlock.RootHash, - PreviousBlockHash: modelBlock.PreviousBlockHash, - NoDeletionProofHash: modelBlock.NoDeletionProofHash, - CreatedAt: modelBlock.CreatedAt, - UnicityCertificate: modelBlock.UnicityCertificate, - ParentMerkleTreePath: modelBlock.ParentMerkleTreePath, + Index: modelBlock.Index, + ChainID: modelBlock.ChainID, + ShardID: modelBlock.ShardID, + Version: modelBlock.Version, + ForkID: modelBlock.ForkID, + RootHash: modelBlock.RootHash, + PreviousBlockHash: modelBlock.PreviousBlockHash, + NoDeletionProofHash: modelBlock.NoDeletionProofHash, + CreatedAt: modelBlock.CreatedAt, + UnicityCertificate: modelBlock.UnicityCertificate, + ParentFragment: modelBlock.ParentFragment, + ParentBlockNumber: modelBlock.ParentBlockNumber, } } @@ -326,6 +326,10 @@ func (as *AggregatorService) GetInclusionProofV1(ctx context.Context, req *api.G unlock := as.roundManager.FinalizationReadLock() defer unlock() + if as.config.Sharding.Mode == config.ShardingModeChild { + return nil, fmt.Errorf("legacy inclusion proof v1 is not supported in child mode") + } + // verify that the request ID matches the shard ID of this aggregator if err := as.commitmentValidator.ValidateShardID(req.RequestID); err != nil { return nil, fmt.Errorf("request ID validation failed: %w", err) @@ -360,14 +364,6 @@ func (as *AggregatorService) GetInclusionProofV1(ctx context.Context, req *api.G return nil, fmt.Errorf("no block found with root hash %s", rootHash.String()) } - // Join parent and child SMT paths if sharding mode is enabled - if as.config.Sharding.Mode == config.ShardingModeChild { - merkleTreePath, err = smt.JoinPaths(merkleTreePath, block.ParentMerkleTreePath) - if err != nil { - return nil, fmt.Errorf("failed to join parent and child aggregator paths: %w", err) - } - } - // Check if commitment exists in aggregator records (finalized) record, err := as.storage.AggregatorRecordStorage().GetByStateID(ctx, req.RequestID) if err != nil { @@ -404,8 +400,9 @@ func (as *AggregatorService) GetInclusionProofV1(ctx context.Context, req *api.G } // GetInclusionProofV2 retrieves a v2 inclusion proof for the given stateId. -// In standalone mode it returns an InclusionCert bound to the current SMT -// root via UC.IR.h. Child mode is not implemented yet. +// Both standalone and child mode serve proofs against the current certified +// SMT root. In child mode, the local child cert is composed with the stored +// parent fragment and bound to the parent UC.IR.h. func (as *AggregatorService) GetInclusionProofV2(ctx context.Context, req *api.GetInclusionProofRequestV2) (*api.GetInclusionProofResponseV2, error) { unlock := as.roundManager.FinalizationReadLock() defer unlock() @@ -432,10 +429,6 @@ func (as *AggregatorService) GetInclusionProofV2(ctx context.Context, req *api.G return nil, fmt.Errorf("unexpected SMT key length: got %d bits, want %d", keyLen, api.StateTreeKeyLengthBits) } - if as.config.Sharding.Mode == config.ShardingModeChild { - return nil, fmt.Errorf("inclusion proof v2 not yet migrated for child mode") - } - // Bind the UC via the block whose stored rootHash matches the current // raw 32-byte SMT root (which also lives in UC.IR.h). rootHashRaw := api.HexBytes(smtInstance.GetRootHashRaw()) @@ -446,6 +439,10 @@ func (as *AggregatorService) GetInclusionProofV2(ctx context.Context, req *api.G if block == nil { return nil, fmt.Errorf("no block found with root hash %s", rootHashRaw.String()) } + responseBlockNumber, err := proofBundleBlockNumber(as.config.Sharding.Mode, block) + if err != nil { + return nil, err + } record, err := as.storage.AggregatorRecordStorage().GetByStateID(ctx, req.StateID) if err != nil { @@ -455,7 +452,7 @@ func (as *AggregatorService) GetInclusionProofV2(ctx context.Context, req *api.G // Non-inclusion is not implemented yet. Return an empty v2 proof // payload so verifiers short-circuit with ErrExclusionNotImpl. return &api.GetInclusionProofResponseV2{ - BlockNumber: block.Index.Uint64(), + BlockNumber: responseBlockNumber, InclusionProof: &api.InclusionProofV2{ CertificationData: nil, CertificateBytes: nil, @@ -467,28 +464,59 @@ func (as *AggregatorService) GetInclusionProofV2(ctx context.Context, req *api.G return nil, fmt.Errorf("invalid aggregator record version got %d expected 2", record.Version) } - cert, err := smtInstance.GetInclusionCert(key) + childCert, err := smtInstance.GetInclusionCert(key) if err != nil { return nil, fmt.Errorf("failed to build inclusion cert for state ID %s: %w", req.StateID, err) } + + cert := childCert + if as.config.Sharding.Mode == config.ShardingModeChild { + if block.ParentFragment == nil { + return nil, fmt.Errorf("current child block %s is missing parent fragment", block.Index.String()) + } + childRoot := smtInstance.GetRootHashRaw() + cert, err = api.ComposeInclusionCert(block.ParentFragment, childCert, childRoot) + if err != nil { + return nil, fmt.Errorf("failed to compose child and parent inclusion certs: %w", err) + } + } + certBytes, err := cert.MarshalBinary() if err != nil { return nil, fmt.Errorf("failed to marshal inclusion cert: %w", err) } + proof := &api.InclusionProofV2{ + CertificationData: record.CertificationData.ToAPI(), + CertificateBytes: certBytes, + UnicityCertificate: types.RawCBOR(block.UnicityCertificate), + } + if err := proof.Verify(&api.CertificationRequest{ + StateID: req.StateID, + CertificationData: *record.CertificationData.ToAPI(), + }); err != nil { + return nil, fmt.Errorf("generated inclusion proof failed self-verification: %w", err) + } + return &api.GetInclusionProofResponseV2{ - // BlockNumber must describe the returned proof bundle - // (CertificateBytes + UnicityCertificate), which is bound to the - // current SMT root looked up above. - BlockNumber: block.Index.Uint64(), - InclusionProof: &api.InclusionProofV2{ - CertificationData: record.CertificationData.ToAPI(), - CertificateBytes: certBytes, - UnicityCertificate: types.RawCBOR(block.UnicityCertificate), - }, + BlockNumber: responseBlockNumber, + InclusionProof: proof, }, nil } +func proofBundleBlockNumber(mode config.ShardingMode, block *models.Block) (uint64, error) { + if block == nil { + return 0, fmt.Errorf("missing block for proof bundle") + } + if mode != config.ShardingModeChild { + return block.Index.Uint64(), nil + } + if block.ParentBlockNumber == 0 { + return 0, fmt.Errorf("current child block %s is missing parent block number", block.Index.String()) + } + return block.ParentBlockNumber, nil +} + // GetNoDeletionProof retrieves the global no-deletion proof func (as *AggregatorService) GetNoDeletionProof(ctx context.Context) (*api.GetNoDeletionProofResponse, error) { // TODO: Implement no-deletion proof generation diff --git a/internal/service/service_test.go b/internal/service/service_test.go index ac86848..a6c7c7d 100644 --- a/internal/service/service_test.go +++ b/internal/service/service_test.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "math/big" "net" "net/http" "net/url" @@ -22,17 +23,21 @@ import ( "github.com/stretchr/testify/suite" "github.com/testcontainers/testcontainers-go/modules/mongodb" redisContainer "github.com/testcontainers/testcontainers-go/modules/redis" + bfttypes "github.com/unicitynetwork/bft-go-base/types" + bfthex "github.com/unicitynetwork/bft-go-base/types/hex" "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/gateway" "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" + "github.com/unicitynetwork/aggregator-go/internal/models" "github.com/unicitynetwork/aggregator-go/internal/round" "github.com/unicitynetwork/aggregator-go/internal/sharding" "github.com/unicitynetwork/aggregator-go/internal/signing" "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/internal/storage" + "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" "github.com/unicitynetwork/aggregator-go/pkg/api" "github.com/unicitynetwork/aggregator-go/pkg/jsonrpc" ) @@ -294,8 +299,8 @@ func TestGetInclusionProofShardMismatch(t *testing.T) { tree := smt.NewChildSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits, shardingCfg.Child.ShardID) service := newAggregatorServiceForTest(t, shardingCfg, tree) - // Raw 32-byte v2 stateId whose low bits don't match shard 4 (=0b100). - invalidShardID := api.RequireNewImprintV2(strings.Repeat("00", api.StateTreeKeyLengthBytes-1) + "01") + // Raw 32-byte v2 stateId whose shard-prefix bits don't match shard 4 (=0b100). + invalidShardID := api.RequireNewImprintV2("01" + strings.Repeat("00", api.StateTreeKeyLengthBytes-1)) _, err := service.GetInclusionProofV2(context.Background(), &api.GetInclusionProofRequestV2{StateID: invalidShardID}) require.Error(t, err) assert.Contains(t, err.Error(), "state ID validation failed") @@ -325,7 +330,7 @@ func TestGetInclusionProofSMTUnavailable(t *testing.T) { } service := newAggregatorServiceForTest(t, shardingCfg, nil) - // Raw 32-byte v2 stateId; all-zero low bits match shard 4 (expected = 0b00). + // Raw 32-byte v2 stateId; all-zero shard-prefix bits match shard 4 (expected = 0b00). validID := api.RequireNewImprintV2(strings.Repeat("00", api.StateTreeKeyLengthBytes)) _, err := service.GetInclusionProofV2(t.Context(), &api.GetInclusionProofRequestV2{StateID: validID}) require.Error(t, err) @@ -350,6 +355,127 @@ func TestInclusionProofInvalidPathLength(t *testing.T) { assert.Contains(t, err.Error(), "invalid state ID length") } +func TestGetInclusionProofV2Child_ComposesParentFragment(t *testing.T) { + shardingCfg := config.ShardingConfig{ + Mode: config.ShardingModeChild, + ShardIDLength: 2, + Child: config.ChildConfig{ + ShardID: 4, // shard prefix bits 00 + }, + } + + stateID := api.RequireNewImprintV2(strings.Repeat("00", api.StateTreeKeyLengthBytes)) + sourceStateHash := api.RequireNewImprintV2("10" + strings.Repeat("00", api.StateTreeKeyLengthBytes-1)) + transactionHash := api.RequireNewImprintV2("20" + strings.Repeat("00", api.StateTreeKeyLengthBytes-1)) + + childTree := smt.NewChildSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits, shardingCfg.Child.ShardID) + path, err := stateID.GetPath() + require.NoError(t, err) + require.NoError(t, childTree.AddLeaf(path, transactionHash.DataBytes())) + childRoot := childTree.GetRootHashRaw() + + parentTree := smt.NewParentSparseMerkleTree(api.SHA256, shardingCfg.ShardIDLength) + require.NoError(t, smt.NewThreadSafeSMT(parentTree).AddPreHashedLeaf(big.NewInt(int64(shardingCfg.Child.ShardID)), childRoot)) + parentRoot := parentTree.GetRootHashRaw() + parentFragment, err := parentTree.GetShardInclusionFragment(shardingCfg.Child.ShardID) + require.NoError(t, err) + require.NotNil(t, parentFragment) + + parentUC := testChildProofUC(t, 9, parentRoot) + block := models.NewBlock( + api.NewBigIntFromUint64(3), + "test-chain", + shardingCfg.Child.ShardID, + "v", + "f", + api.HexBytes(childRoot), + nil, + parentUC, + ) + block.ParentFragment = parentFragment + block.ParentBlockNumber = 9 + block.Finalized = true + + record := &models.AggregatorRecord{ + Version: 2, + StateID: stateID, + CertificationData: models.CertificationData{ + OwnerPredicate: api.Predicate{Engine: 1, Code: []byte{0x01}, Params: []byte{0x02}}, + SourceStateHash: sourceStateHash, + TransactionHash: transactionHash, + Witness: []byte{0x01, 0x02}, + }, + BlockNumber: api.NewBigIntFromUint64(1), + LeafIndex: api.NewBigIntFromUint64(0), + CreatedAt: api.Now(), + FinalizedAt: api.Now(), + } + + service := newAggregatorServiceForTest(t, shardingCfg, childTree) + service.storage = &testStorage{ + blockStorage: &testBlockStorage{latestByRoot: map[string]*models.Block{block.RootHash.String(): block}}, + recordStorage: &testAggregatorRecordStorage{ + byStateID: map[string]*models.AggregatorRecord{stateID.String(): record}, + }, + } + + resp, err := service.GetInclusionProofV2(context.Background(), &api.GetInclusionProofRequestV2{StateID: stateID}) + require.NoError(t, err) + require.Equal(t, uint64(9), resp.BlockNumber) + require.NotNil(t, resp.InclusionProof) + require.Equal(t, parentUC, api.HexBytes(resp.InclusionProof.UnicityCertificate)) + + req := &api.CertificationRequest{ + StateID: stateID, + CertificationData: api.CertificationData{ + OwnerPredicate: record.CertificationData.OwnerPredicate, + SourceStateHash: record.CertificationData.SourceStateHash, + TransactionHash: record.CertificationData.TransactionHash, + Witness: record.CertificationData.Witness, + }, + } + validateInclusionProof(t, resp.InclusionProof, req) +} + +func TestGetInclusionProofV2Child_NonInclusionUsesParentBundleMetadata(t *testing.T) { + shardingCfg := config.ShardingConfig{ + Mode: config.ShardingModeChild, + ShardIDLength: 2, + Child: config.ChildConfig{ + ShardID: 4, + }, + } + + childTree := smt.NewChildSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits, shardingCfg.Child.ShardID) + parentUC := testChildProofUC(t, 12, childTree.GetRootHashRaw()) + block := models.NewBlock( + api.NewBigIntFromUint64(4), + "test-chain", + shardingCfg.Child.ShardID, + "v", + "f", + api.HexBytes(childTree.GetRootHashRaw()), + nil, + parentUC, + ) + block.ParentBlockNumber = 12 + block.Finalized = true + + service := newAggregatorServiceForTest(t, shardingCfg, childTree) + service.storage = &testStorage{ + blockStorage: &testBlockStorage{latestByRoot: map[string]*models.Block{block.RootHash.String(): block}}, + recordStorage: &testAggregatorRecordStorage{byStateID: map[string]*models.AggregatorRecord{}}, + } + + stateID := api.RequireNewImprintV2(strings.Repeat("00", api.StateTreeKeyLengthBytes)) + resp, err := service.GetInclusionProofV2(context.Background(), &api.GetInclusionProofRequestV2{StateID: stateID}) + require.NoError(t, err) + require.Equal(t, uint64(12), resp.BlockNumber) + require.Nil(t, resp.InclusionProof.CertificationData) + require.Empty(t, resp.InclusionProof.CertificateBytes) + require.Equal(t, parentUC, api.HexBytes(resp.InclusionProof.UnicityCertificate)) +} + // TestInclusionProof tests the complete inclusion proof workflow func (suite *AggregatorTestSuite) TestInclusionProof() { // 1) Send commitments @@ -494,3 +620,88 @@ func newAggregatorServiceForTest(t *testing.T, shardingCfg config.ShardingConfig certificationRequestValidator: signing.NewCertificationRequestValidator(shardingCfg), } } + +type testStorage struct { + blockStorage interfaces.BlockStorage + recordStorage interfaces.AggregatorRecordStorage +} + +func (s *testStorage) AggregatorRecordStorage() interfaces.AggregatorRecordStorage { + return s.recordStorage +} +func (s *testStorage) BlockStorage() interfaces.BlockStorage { return s.blockStorage } +func (s *testStorage) SmtStorage() interfaces.SmtStorage { return nil } +func (s *testStorage) BlockRecordsStorage() interfaces.BlockRecordsStorage { return nil } +func (s *testStorage) LeadershipStorage() interfaces.LeadershipStorage { return nil } +func (s *testStorage) TrustBaseStorage() interfaces.TrustBaseStorage { return nil } +func (s *testStorage) Initialize(context.Context) error { return nil } +func (s *testStorage) Ping(context.Context) error { return nil } +func (s *testStorage) Close(context.Context) error { return nil } +func (s *testStorage) WithTransaction(ctx context.Context, fn func(context.Context) error) error { + return fn(ctx) +} + +type testBlockStorage struct { + latestByRoot map[string]*models.Block +} + +func (s *testBlockStorage) Store(context.Context, *models.Block) error { return nil } +func (s *testBlockStorage) GetByNumber(context.Context, *api.BigInt) (*models.Block, error) { + return nil, nil +} +func (s *testBlockStorage) GetLatest(context.Context) (*models.Block, error) { return nil, nil } +func (s *testBlockStorage) GetLatestNumber(context.Context) (*api.BigInt, error) { return nil, nil } +func (s *testBlockStorage) Count(context.Context) (int64, error) { return 0, nil } +func (s *testBlockStorage) GetRange(context.Context, *api.BigInt, *api.BigInt) ([]*models.Block, error) { + return nil, nil +} +func (s *testBlockStorage) SetFinalized(context.Context, *api.BigInt, bool) error { return nil } +func (s *testBlockStorage) GetUnfinalized(context.Context) ([]*models.Block, error) { + return nil, nil +} +func (s *testBlockStorage) GetLatestByRootHash(ctx context.Context, rootHash api.HexBytes) (*models.Block, error) { + if s == nil { + return nil, nil + } + return s.latestByRoot[rootHash.String()], nil +} + +type testAggregatorRecordStorage struct { + byStateID map[string]*models.AggregatorRecord +} + +func (s *testAggregatorRecordStorage) Store(context.Context, *models.AggregatorRecord) error { + return nil +} +func (s *testAggregatorRecordStorage) StoreBatch(context.Context, []*models.AggregatorRecord) error { + return nil +} +func (s *testAggregatorRecordStorage) GetByBlockNumber(context.Context, *api.BigInt) ([]*models.AggregatorRecord, error) { + return nil, nil +} +func (s *testAggregatorRecordStorage) Count(context.Context) (int64, error) { return 0, nil } +func (s *testAggregatorRecordStorage) GetExistingRequestIDs(context.Context, []string) (map[string]bool, error) { + return nil, nil +} +func (s *testAggregatorRecordStorage) GetByStateID(ctx context.Context, stateID api.StateID) (*models.AggregatorRecord, error) { + if s == nil { + return nil, nil + } + return s.byStateID[stateID.String()], nil +} + +func testChildProofUC(t *testing.T, roundNumber uint64, rootHash []byte) api.HexBytes { + t.Helper() + uc := bfttypes.UnicityCertificate{ + InputRecord: &bfttypes.InputRecord{ + RoundNumber: roundNumber, + Hash: bfthex.Bytes(rootHash), + }, + UnicitySeal: &bfttypes.UnicitySeal{ + RootChainRoundNumber: roundNumber, + }, + } + ucBytes, err := bfttypes.Cbor.Marshal(uc) + require.NoError(t, err) + return api.NewHexBytes(ucBytes) +} diff --git a/internal/sharding/root_aggregator_client_stub.go b/internal/sharding/root_aggregator_client_stub.go index e1f5e13..b304f9d 100644 --- a/internal/sharding/root_aggregator_client_stub.go +++ b/internal/sharding/root_aggregator_client_stub.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/unicitynetwork/bft-go-base/types" + "github.com/unicitynetwork/bft-go-base/types/hex" "github.com/unicitynetwork/aggregator-go/pkg/api" ) @@ -46,16 +47,18 @@ func (m *RootAggregatorClientStub) GetShardProof(ctx context.Context, request *a if m.submissions[request.ShardID] != nil { m.returnedProofCount++ - submittedRootHash := m.submittedRootHash.String() - ucBytes, err := stubProofUC(uint64(m.returnedProofCount), uint64(m.returnedProofCount)) + ucBytes, err := stubProofUC(uint64(m.returnedProofCount), uint64(m.returnedProofCount), m.submittedRootHash) if err != nil { return nil, err } + fragment := &api.ParentInclusionFragment{ + CertificateBytes: api.NewHexBytes(make([]byte, api.BitmapSize)), + ShardLeafValue: api.NewHexBytes(m.submittedRootHash), + } return &api.RootShardInclusionProof{ + ParentFragment: fragment, + BlockNumber: uint64(m.returnedProofCount), UnicityCertificate: ucBytes, - MerkleTreePath: &api.MerkleTreePath{ - Steps: []api.MerkleTreeStep{{Data: &submittedRootHash}}, - }, }, nil } return nil, nil @@ -91,10 +94,11 @@ func (m *RootAggregatorClientStub) SetSubmissionError(err error) { m.submissionError = err } -func stubProofUC(parentRound, rootRound uint64) (api.HexBytes, error) { +func stubProofUC(parentRound, rootRound uint64, rootHash api.HexBytes) (api.HexBytes, error) { uc := types.UnicityCertificate{ InputRecord: &types.InputRecord{ RoundNumber: parentRound, + Hash: hex.Bytes(rootHash), }, UnicitySeal: &types.UnicitySeal{ RootChainRoundNumber: rootRound, diff --git a/internal/sharding/root_aggregator_client_test.go b/internal/sharding/root_aggregator_client_test.go index 7a7245f..fba9e86 100644 --- a/internal/sharding/root_aggregator_client_test.go +++ b/internal/sharding/root_aggregator_client_test.go @@ -59,7 +59,11 @@ func TestRootAggregatorClient_GetShardProof(t *testing.T) { require.Equal(t, float64(4), params["shardId"]) proof := &api.RootShardInclusionProof{ - MerkleTreePath: &api.MerkleTreePath{Root: "0x1234"}, + ParentFragment: &api.ParentInclusionFragment{ + CertificateBytes: api.NewHexBytes(make([]byte, api.BitmapSize)), + ShardLeafValue: api.HexBytes("0x1234"), + }, + BlockNumber: 9, UnicityCertificate: api.HexBytes("0xabcdef"), } @@ -80,7 +84,8 @@ func TestRootAggregatorClient_GetShardProof(t *testing.T) { require.NoError(t, err) require.NotNil(t, proof) - require.Equal(t, "0x1234", proof.MerkleTreePath.Root) + require.Equal(t, api.HexBytes("0x1234"), proof.ParentFragment.ShardLeafValue) + require.Equal(t, uint64(9), proof.BlockNumber) require.Equal(t, api.HexBytes("0xabcdef"), proof.UnicityCertificate) } diff --git a/internal/signing/certification_request_validator.go b/internal/signing/certification_request_validator.go index 936c00b..7f38e6b 100644 --- a/internal/signing/certification_request_validator.go +++ b/internal/signing/certification_request_validator.go @@ -1,10 +1,8 @@ package signing import ( - "encoding/hex" "errors" "fmt" - "math/big" "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/models" @@ -185,7 +183,7 @@ func (v *CertificationRequestValidator) ValidateShardID(stateID api.StateID) err if !v.shardConfig.Mode.IsChild() { return nil } - ok, err := verifyShardID(stateID.String(), v.shardConfig.Child.ShardID) + ok, err := api.MatchesShardPrefixFromHex(stateID.String(), v.shardConfig.Child.ShardID) if err != nil { return fmt.Errorf("error verifying shard id: %w", err) } @@ -194,34 +192,3 @@ func (v *CertificationRequestValidator) ValidateShardID(stateID api.StateID) err } return nil } - -// verifyShardID Checks if commitmentID's least significant bits match the shard bitmask. -func verifyShardID(commitmentID string, shardBitmask int) (bool, error) { - // convert to big.Ints - bytes, err := hex.DecodeString(commitmentID) - if err != nil { - return false, fmt.Errorf("failed to decode certification state ID: %w", err) - } - commitmentIdBigInt := new(big.Int).SetBytes(bytes) - shardBitmaskBigInt := new(big.Int).SetInt64(int64(shardBitmask)) - - // find position of MSB e.g. - // 0b111 -> BitLen=3 -> 3-1=2 - msbPos := shardBitmaskBigInt.BitLen() - 1 - - // build a mask covering bits below MSB e.g. - // 1<<2=0b100; 0b100-1=0b11; compareMask=0b11 - compareMask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(msbPos)), big.NewInt(1)) - - // remove MSB from shardBitmask to get expected value e.g. - // 0b111 & 0b11 = 0b11 - expected := new(big.Int).And(shardBitmaskBigInt, compareMask) - - // extract low bits from certification request e.g. - // commitment=0b11111111 & 0b11 = 0b11 - commitmentLowBits := new(big.Int).And(commitmentIdBigInt, compareMask) - - // return true if the certification request low bits match bitmask bits e.g. - // 0b11 == 0b11 - return commitmentLowBits.Cmp(expected) == 0, nil -} diff --git a/internal/signing/certification_request_validator_test.go b/internal/signing/certification_request_validator_test.go index 50db8c2..9ff5c43 100644 --- a/internal/signing/certification_request_validator_test.go +++ b/internal/signing/certification_request_validator_test.go @@ -149,6 +149,13 @@ func TestValidator_StateIDMismatch(t *testing.T) { } func TestValidator_ShardID(t *testing.T) { + makeShardTestID := func(firstByte, lastByte byte) string { + key := make([]byte, api.StateTreeKeyLengthBytes) + key[0] = firstByte + key[len(key)-1] = lastByte + return hex.EncodeToString(key) + } + tests := []struct { commitmentID string shardBitmask int @@ -158,25 +165,29 @@ func TestValidator_ShardID(t *testing.T) { // shard1=bitmask 0b10 // shard2=bitmask 0b11 - // certification request ending with 0b00000000 belongs to shard1 - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b10, true}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b11, false}, + // certification request with key bit 0 = 0 belongs to shard1 + {makeShardTestID(0x00, 0x00), 0b10, true}, + {makeShardTestID(0x00, 0x00), 0b11, false}, + + // certification request with key bit 0 = 1 belongs to shard2 + {makeShardTestID(0x01, 0x00), 0b10, false}, + {makeShardTestID(0x01, 0x00), 0b11, true}, - // certification request ending with 0b00000001 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b10, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b11, true}, + // certification request with first byte 0b00000010 still belongs to shard1 + {makeShardTestID(0x02, 0x00), 0b10, true}, + {makeShardTestID(0x02, 0x00), 0b11, false}, - // certification request ending with 0b00000010 belongs to shard1 - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b10, true}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b11, false}, + // certification request with first byte 0b00000011 belongs to shard2 + {makeShardTestID(0x03, 0x00), 0b10, false}, + {makeShardTestID(0x03, 0x00), 0b11, true}, - // certification request ending with 0b00000011 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b10, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b11, true}, + // certification request with first byte 0b11111111 belongs to shard2 + {makeShardTestID(0xFF, 0x00), 0b10, false}, + {makeShardTestID(0xFF, 0x00), 0b11, true}, - // certification request ending with 0b11111111 belongs to shard2 - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b10, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b11, true}, + // the last byte no longer affects shard routing under LSB-first byte order + {makeShardTestID(0x00, 0xFF), 0b10, true}, + {makeShardTestID(0x00, 0xFF), 0b11, false}, // === END TWO SHARD CONFIG === @@ -186,40 +197,40 @@ func TestValidator_ShardID(t *testing.T) { // shard3=0b101 // shard4=0b111 - // certification request ending with 0b00000000 belongs to shard1 - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b101, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b100, true}, - - // certification request ending with 0b00000010 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b100, false}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b101, false}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b110, true}, - - // certification request ending with 0b00000001 belongs to shard3 - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b101, true}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b100, false}, - - // certification request ending with 0b00000011 belongs to shard4 - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b111, true}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b101, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b100, false}, - - // certification request ending with 0b11111111 belongs to shard4 - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b111, true}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b101, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b110, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b100, false}, + // key bits 1:0 = 00 belong to shard1 + {makeShardTestID(0x00, 0x00), 0b111, false}, + {makeShardTestID(0x00, 0x00), 0b101, false}, + {makeShardTestID(0x00, 0x00), 0b110, false}, + {makeShardTestID(0x00, 0x00), 0b100, true}, + + // key bits 1:0 = 10 belong to shard2 + {makeShardTestID(0x02, 0x00), 0b111, false}, + {makeShardTestID(0x02, 0x00), 0b100, false}, + {makeShardTestID(0x02, 0x00), 0b101, false}, + {makeShardTestID(0x02, 0x00), 0b110, true}, + + // key bits 1:0 = 01 belong to shard3 + {makeShardTestID(0x01, 0x00), 0b111, false}, + {makeShardTestID(0x01, 0x00), 0b101, true}, + {makeShardTestID(0x01, 0x00), 0b110, false}, + {makeShardTestID(0x01, 0x00), 0b100, false}, + + // key bits 1:0 = 11 belong to shard4 + {makeShardTestID(0x03, 0x00), 0b111, true}, + {makeShardTestID(0x03, 0x00), 0b101, false}, + {makeShardTestID(0x03, 0x00), 0b110, false}, + {makeShardTestID(0x03, 0x00), 0b100, false}, + + // key bits 1:0 = 11 still belong to shard4 when the whole first byte is set + {makeShardTestID(0xFF, 0x00), 0b111, true}, + {makeShardTestID(0xFF, 0x00), 0b101, false}, + {makeShardTestID(0xFF, 0x00), 0b110, false}, + {makeShardTestID(0xFF, 0x00), 0b100, false}, // === END FOUR SHARD CONFIG === } for _, tc := range tests { - match, err := verifyShardID(tc.commitmentID, tc.shardBitmask) + match, err := api.MatchesShardPrefixFromHex(tc.commitmentID, tc.shardBitmask) require.NoError(t, err) if match != tc.match { t.Errorf("commitmentID=%s shardBitmask=%b expected %v got %v", tc.commitmentID, tc.shardBitmask, tc.match, match) @@ -227,6 +238,17 @@ func TestValidator_ShardID(t *testing.T) { } } +func TestValidator_ShardIDRejectsPrefixedStateID(t *testing.T) { + rawKey := make([]byte, api.StateTreeKeyLengthBytes) + rawKey[0] = 0x01 + prefixed := append([]byte{0x00, 0x00}, rawKey...) + + match, err := api.MatchesShardPrefixFromHex(hex.EncodeToString(prefixed), 0b11) + require.Error(t, err) + require.False(t, match) + require.Contains(t, err.Error(), "must be exactly 32 bytes") +} + func TestValidator_InvalidSignatureFormat(t *testing.T) { validator := newDefaultCertificationRequestValidator() diff --git a/internal/signing/v1/commitment_validator.go b/internal/signing/v1/commitment_validator.go index 7d1606c..ce9293d 100644 --- a/internal/signing/v1/commitment_validator.go +++ b/internal/signing/v1/commitment_validator.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "errors" "fmt" - "math/big" "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/models/v1" @@ -212,7 +211,16 @@ func (v *CommitmentValidator) ValidateShardID(requestID api.RequestID) error { if !v.shardConfig.Mode.IsChild() { return nil } - ok, err := verifyShardID(requestID.String(), v.shardConfig.Child.ShardID) + keyHex := requestID.String() + keyBytes, err := hex.DecodeString(keyHex) + if err != nil { + return fmt.Errorf("error decoding request ID: %w", err) + } + // V1 request IDs may carry a 2-byte algorithm prefix; strip it. + if len(keyBytes) == api.StateTreeKeyLengthBytes+2 { + keyBytes = keyBytes[2:] + } + ok, err := api.MatchesShardPrefix(keyBytes, v.shardConfig.Child.ShardID) if err != nil { return fmt.Errorf("error verifying shard id: %w", err) } @@ -222,37 +230,6 @@ func (v *CommitmentValidator) ValidateShardID(requestID api.RequestID) error { return nil } -// verifyShardID Checks if commitmentID's least significant bits match the shard bitmask. -func verifyShardID(commitmentID string, shardBitmask int) (bool, error) { - // convert to big.Ints - bytes, err := hex.DecodeString(commitmentID) - if err != nil { - return false, fmt.Errorf("failed to decode certification state ID: %w", err) - } - commitmentIdBigInt := new(big.Int).SetBytes(bytes) - shardBitmaskBigInt := new(big.Int).SetInt64(int64(shardBitmask)) - - // find position of MSB e.g. - // 0b111 -> BitLen=3 -> 3-1=2 - msbPos := shardBitmaskBigInt.BitLen() - 1 - - // build a mask covering bits below MSB e.g. - // 1<<2=0b100; 0b100-1=0b11; compareMask=0b11 - compareMask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(msbPos)), big.NewInt(1)) - - // remove MSB from shardBitmask to get expected value e.g. - // 0b111 & 0b11 = 0b11 - expected := new(big.Int).And(shardBitmaskBigInt, compareMask) - - // extract low bits from certification request e.g. - // commitment=0b11111111 & 0b11 = 0b11 - commitmentLowBits := new(big.Int).And(commitmentIdBigInt, compareMask) - - // return true if the certification request low bits match bitmask bits e.g. - // 0b11 == 0b11 - return commitmentLowBits.Cmp(expected) == 0, nil -} - // CreateDataHashImprint creates a DataHash imprint in the Unicity format: // 2 bytes algorithm (big-endian) + actual hash bytes // For SHA256: algorithm = 0, so prefix is [0x00, 0x00] diff --git a/internal/signing/v1/commitment_validator_test.go b/internal/signing/v1/commitment_validator_test.go index 98e2076..a35872a 100644 --- a/internal/signing/v1/commitment_validator_test.go +++ b/internal/signing/v1/commitment_validator_test.go @@ -161,6 +161,13 @@ func TestValidator_RequestIDMismatch(t *testing.T) { } func TestValidator_ShardID(t *testing.T) { + makeShardTestID := func(firstByte, lastByte byte) string { + key := make([]byte, api.StateTreeKeyLengthBytes) + key[0] = firstByte + key[len(key)-1] = lastByte + return hex.EncodeToString(key) + } + tests := []struct { commitmentID string shardBitmask int @@ -170,25 +177,29 @@ func TestValidator_ShardID(t *testing.T) { // shard1=bitmask 0b10 // shard2=bitmask 0b11 - // commitment ending with 0b00000000 belongs to shard1 - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b10, true}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b11, false}, + // request key bit 0 = 0 belongs to shard1 + {makeShardTestID(0x00, 0x00), 0b10, true}, + {makeShardTestID(0x00, 0x00), 0b11, false}, + + // request key bit 0 = 1 belongs to shard2 + {makeShardTestID(0x01, 0x00), 0b10, false}, + {makeShardTestID(0x01, 0x00), 0b11, true}, - // commitment ending with 0b00000001 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b10, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b11, true}, + // request first byte 0b00000010 still belongs to shard1 + {makeShardTestID(0x02, 0x00), 0b10, true}, + {makeShardTestID(0x02, 0x00), 0b11, false}, - // commitment ending with 0b00000010 belongs to shard1 - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b10, true}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b11, false}, + // request first byte 0b00000011 belongs to shard2 + {makeShardTestID(0x03, 0x00), 0b10, false}, + {makeShardTestID(0x03, 0x00), 0b11, true}, - // commitment ending with 0b00000011 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b10, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b11, true}, + // request first byte 0b11111111 belongs to shard2 + {makeShardTestID(0xFF, 0x00), 0b10, false}, + {makeShardTestID(0xFF, 0x00), 0b11, true}, - // commitment ending with 0b11111111 belongs to shard2 - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b10, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b11, true}, + // the final byte no longer affects shard routing + {makeShardTestID(0x00, 0xFF), 0b10, true}, + {makeShardTestID(0x00, 0xFF), 0b11, false}, // === END TWO SHARD CONFIG === @@ -198,40 +209,40 @@ func TestValidator_ShardID(t *testing.T) { // shard3=0b101 // shard4=0b111 - // commitment ending with 0b00000000 belongs to shard1 - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b101, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b100, true}, - - // commitment ending with 0b00000010 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b100, false}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b101, false}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b110, true}, - - // commitment ending with 0b00000001 belongs to shard3 - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b101, true}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b100, false}, - - // commitment ending with 0b00000011 belongs to shard4 - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b111, true}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b101, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b100, false}, - - // commitment ending with 0b11111111 belongs to shard4 - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b111, true}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b101, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b110, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b100, false}, + // request key bits 1:0 = 00 belong to shard1 + {makeShardTestID(0x00, 0x00), 0b111, false}, + {makeShardTestID(0x00, 0x00), 0b101, false}, + {makeShardTestID(0x00, 0x00), 0b110, false}, + {makeShardTestID(0x00, 0x00), 0b100, true}, + + // request key bits 1:0 = 10 belong to shard2 + {makeShardTestID(0x02, 0x00), 0b111, false}, + {makeShardTestID(0x02, 0x00), 0b100, false}, + {makeShardTestID(0x02, 0x00), 0b101, false}, + {makeShardTestID(0x02, 0x00), 0b110, true}, + + // request key bits 1:0 = 01 belong to shard3 + {makeShardTestID(0x01, 0x00), 0b111, false}, + {makeShardTestID(0x01, 0x00), 0b101, true}, + {makeShardTestID(0x01, 0x00), 0b110, false}, + {makeShardTestID(0x01, 0x00), 0b100, false}, + + // request key bits 1:0 = 11 belong to shard4 + {makeShardTestID(0x03, 0x00), 0b111, true}, + {makeShardTestID(0x03, 0x00), 0b101, false}, + {makeShardTestID(0x03, 0x00), 0b110, false}, + {makeShardTestID(0x03, 0x00), 0b100, false}, + + // request key bits 1:0 = 11 still belong to shard4 when the whole first byte is set + {makeShardTestID(0xFF, 0x00), 0b111, true}, + {makeShardTestID(0xFF, 0x00), 0b101, false}, + {makeShardTestID(0xFF, 0x00), 0b110, false}, + {makeShardTestID(0xFF, 0x00), 0b100, false}, // === END FOUR SHARD CONFIG === } for _, tc := range tests { - match, err := verifyShardID(tc.commitmentID, tc.shardBitmask) + match, err := api.MatchesShardPrefixFromHex(tc.commitmentID, tc.shardBitmask) require.NoError(t, err) if match != tc.match { t.Errorf("commitmentID=%s shardBitmask=%b expected %v got %v", tc.commitmentID, tc.shardBitmask, tc.match, match) @@ -239,6 +250,23 @@ func TestValidator_ShardID(t *testing.T) { } } +func TestValidator_ValidateShardID_AcceptsPrefixedRequestID(t *testing.T) { + key := make([]byte, api.StateTreeKeyLengthBytes) + key[0] = 0x01 // bit 0 set -> shard 0b11 + + validator := &CommitmentValidator{ + shardConfig: config.ShardingConfig{ + Mode: config.ShardingModeChild, + Child: config.ChildConfig{ + ShardID: 0b11, + }, + }, + } + + requestID := api.RequireNewImprintV2("0000" + hex.EncodeToString(key)) + require.NoError(t, validator.ValidateShardID(requestID)) +} + func TestValidator_InvalidSignatureFormat(t *testing.T) { validator := newDefaultCommitmentValidator() diff --git a/internal/smt/inclusion_cert_test.go b/internal/smt/inclusion_cert_test.go index a3ca888..124c994 100644 --- a/internal/smt/inclusion_cert_test.go +++ b/internal/smt/inclusion_cert_test.go @@ -1,8 +1,10 @@ package smt import ( + "bytes" "crypto/rand" "encoding/hex" + "math/big" "testing" "github.com/stretchr/testify/require" @@ -267,6 +269,32 @@ func TestGetInclusionCert_WrongKeyLength(t *testing.T) { } } +// TestGetShardInclusionFragment_SkipsNilSiblingAfterNonNilSibling exercises +// the parent-tree traversal branch where a shallower sibling exists but a +// deeper sibling subtree is empty (nil hash). That depth must be skipped as a +// unary passthrough while keeping previously recorded siblings valid. +func TestGetShardInclusionFragment_SkipsNilSiblingAfterNonNilSibling(t *testing.T) { + parent := NewParentSparseMerkleTree(api.SHA256, 2) + + // Update two shards on opposite root sides so root-level sibling is present. + leaf4 := bytes.Repeat([]byte{0xA4}, api.SiblingSize) + leaf5 := bytes.Repeat([]byte{0xB5}, api.SiblingSize) + require.NoError(t, parent.AddLeaf(big.NewInt(0b100), leaf4)) + require.NoError(t, parent.AddLeaf(big.NewInt(0b101), leaf5)) + + fragment, err := parent.GetShardInclusionFragment(0b100) + require.NoError(t, err) + require.NotNil(t, fragment) + + var cert api.InclusionCert + require.NoError(t, cert.UnmarshalBinary(fragment.CertificateBytes)) + require.Equal(t, 1, bitmapPopcountForTest(&cert.Bitmap), "one deeper depth should be skipped as unary passthrough") + require.Len(t, cert.Siblings, 1, "only the non-empty shallower sibling should be present") + + root := parent.GetRootHashRaw() + require.NoError(t, fragment.Verify(0b100, 2, leaf4, root, api.SHA256)) +} + // TestGetRootHashRaw_MatchesHex confirms that GetRootHashRaw produces // the 32-byte hash portion of GetRootHashHex. func TestGetRootHashRaw_MatchesHex(t *testing.T) { diff --git a/internal/smt/smt.go b/internal/smt/smt.go index 284990a..3d7ce0e 100644 --- a/internal/smt/smt.go +++ b/internal/smt/smt.go @@ -560,35 +560,87 @@ func (smt *SparseMerkleTree) GetInclusionCert(key []byte) (*api.InclusionCert, e _ = smt.root.calculateHash(hasher) var cert api.InclusionCert - if err := smt.generateInclusionCert(hasher, key, smt.root, &cert); err != nil { + if _, err := smt.generateInclusionCertWithLeafValue(hasher, key, smt.root, &cert); err != nil { return nil, err } return &cert, nil } +// GetShardInclusionFragment builds the native parent proof fragment used by a +// child aggregator to later compose a full v2 proof locally. The fragment uses +// the same bitmap+sibling wire shape as InclusionCert but only contains the +// shallow parent-tree depths. The returned shard leaf value must equal the +// child tree root for composition to proceed. +func (smt *SparseMerkleTree) GetShardInclusionFragment(shardID api.ShardID) (*api.ParentInclusionFragment, error) { + if !smt.parentMode { + return nil, fmt.Errorf("smt: shard inclusion fragment only valid for parent trees") + } + + path := big.NewInt(int64(shardID)) + if path.BitLen()-1 != smt.keyLength { + return nil, ErrKeyLength + } + + key, err := api.PathToFixedBytes(path, smt.keyLength) + if err != nil { + return nil, fmt.Errorf("failed to derive shard key bytes: %w", err) + } + + hasher := api.NewDataHasher(smt.algorithm) + _ = smt.root.calculateHash(hasher) + + var cert api.InclusionCert + leafValue, err := smt.generateInclusionCertWithLeafValue(hasher, key, smt.root, &cert) + if err != nil { + return nil, err + } + if len(leafValue) == 0 { + return nil, nil + } + + certBytes, err := cert.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("failed to encode parent fragment cert: %w", err) + } + return &api.ParentInclusionFragment{ + CertificateBytes: api.NewHexBytes(certBytes), + ShardLeafValue: api.NewHexBytes(leafValue), + }, nil +} + // generateInclusionCert recursively walks from the current node toward the // leaf matching key, appending siblings and setting bitmap bits at every // 2-child branching node along the path. Unary passthrough nodes contribute // nothing to either the bitmap or the sibling list. func (smt *SparseMerkleTree) generateInclusionCert(hasher *api.DataHasher, key []byte, current branch, cert *api.InclusionCert) error { + _, err := smt.generateInclusionCertWithLeafValue(hasher, key, current, cert) + return err +} + +func (smt *SparseMerkleTree) generateInclusionCertWithLeafValue(hasher *api.DataHasher, key []byte, current branch, cert *api.InclusionCert) ([]byte, error) { if current == nil { - return fmt.Errorf("smt: inclusion cert traversal reached nil subtree") + return nil, fmt.Errorf("smt: inclusion cert traversal reached nil subtree") } if current.isLeaf() { leaf := current.(*LeafBranch) - if leaf.Key == nil || !bytes.Equal(leaf.Key, key) { - return fmt.Errorf("smt: leaf not found for key %x", key) + // Parent-mode populate() creates placeholder leaves with nil Key and nil + // Value. For those placeholders, key equality is validated only after the + // first real AddLeaf replaces them with a keyed leaf. + if leaf.Key != nil && !bytes.Equal(leaf.Key, key) { + return nil, fmt.Errorf("smt: leaf not found for key %x", key) } - return nil + if leaf.Value == nil { + return nil, nil + } + return append([]byte(nil), leaf.Value...), nil } node := current.(*NodeBranch) if node.Left != nil && node.Right != nil { depth := int(node.Depth) if depth < 0 || depth >= api.StateTreeKeyLengthBits { - return fmt.Errorf("smt: invalid branch depth %d", depth) + return nil, fmt.Errorf("smt: invalid branch depth %d", depth) } - cert.Bitmap[depth/8] |= 1 << (uint(depth) % 8) var sibling, child branch if keyBit(key, depth) == 0 { @@ -599,27 +651,40 @@ func (smt *SparseMerkleTree) generateInclusionCert(hasher *api.DataHasher, key [ child = node.Right } + childHash := child.calculateHash(hasher) + if childHash == nil { + return nil, nil + } + sibHash := sibling.calculateHash(hasher) + if sibHash == nil { + // The sibling subtree is empty (all-placeholder / no submitted leaf), + // so this depth is a unary passthrough: no bitmap bit and no sibling. + // Any already recorded shallower siblings remain valid. + return smt.generateInclusionCertWithLeafValue(hasher, key, child, cert) + } + + cert.Bitmap[depth/8] |= 1 << (uint(depth) % 8) if len(sibHash) != api.SiblingSize { - return fmt.Errorf("smt: sibling hash unexpected length: got %d, want %d", len(sibHash), api.SiblingSize) + return nil, fmt.Errorf("smt: sibling hash unexpected length: got %d, want %d", len(sibHash), api.SiblingSize) } var sib [api.SiblingSize]byte copy(sib[:], sibHash) cert.Siblings = append(cert.Siblings, sib) - return smt.generateInclusionCert(hasher, key, child, cert) + return smt.generateInclusionCertWithLeafValue(hasher, key, child, cert) } // Unary passthrough: typically only at the root when the tree holds // ≤1 leaves. The v2 rule says a single-child node's hash equals the // child's hash, so no bitmap bit and no sibling are added. if node.Left != nil { - return smt.generateInclusionCert(hasher, key, node.Left, cert) + return smt.generateInclusionCertWithLeafValue(hasher, key, node.Left, cert) } if node.Right != nil { - return smt.generateInclusionCert(hasher, key, node.Right, cert) + return smt.generateInclusionCertWithLeafValue(hasher, key, node.Right, cert) } - return fmt.Errorf("smt: reached empty subtree in inclusion cert traversal") + return nil, fmt.Errorf("smt: reached empty subtree in inclusion cert traversal") } // keyBit returns bit d of the raw key under LSB-first byte layout. @@ -676,7 +741,7 @@ func (smt *SparseMerkleTree) buildTree(branch branch, remainingPath *big.Int, le if branch.isLeaf() && branch.getPath().Cmp(remainingPath) == 0 { leafBranch := branch.(*LeafBranch) if leafBranch.isChild { - return newChildLeafBranchWithKey(leafBranch.Path, leafBranch.Key, value), nil + return newChildLeafBranchWithKey(leafBranch.Path, leafKey, value), nil } else if bytes.Equal(leafBranch.Value, value) { return nil, ErrDuplicateLeaf } else { @@ -889,40 +954,3 @@ func NewLeaf(path *big.Int, value []byte) *Leaf { Value: append([]byte(nil), value...), } } - -// JoinPaths joins the hash proofs from a child and parent in sharded setting -func JoinPaths(child, parent *api.MerkleTreePath) (*api.MerkleTreePath, error) { - if child == nil { - return nil, errors.New("nil child path") - } - if parent == nil { - return nil, errors.New("nil parent path") - } - - // Root hashes are hex-encoded imprints, the first 4 characters are hash function identifiers - if len(child.Root) < 4 { - return nil, errors.New("invalid child root hash format") - } - if len(parent.Root) < 4 { - return nil, errors.New("invalid parent root hash format") - } - if child.Root[:4] != parent.Root[:4] { - return nil, errors.New("can't join paths: child hash algorithm does not match parent") - } - - if len(parent.Steps) == 0 { - return nil, errors.New("empty parent hash steps") - } - if parent.Steps[0].Data == nil || *parent.Steps[0].Data != child.Root[4:] { - return nil, errors.New("can't join paths: child root hash does not match parent input hash") - } - - steps := make([]api.MerkleTreeStep, len(child.Steps)+len(parent.Steps)-1) - copy(steps, child.Steps) - copy(steps[len(child.Steps):], parent.Steps[1:]) - - return &api.MerkleTreePath{ - Root: parent.Root, - Steps: steps, - }, nil -} diff --git a/internal/smt/smt_test.go b/internal/smt/smt_test.go index c1d3be4..ab05d15 100644 --- a/internal/smt/smt_test.go +++ b/internal/smt/smt_test.go @@ -1381,155 +1381,6 @@ func TestSMTAddingLeafAboveNode(t *testing.T) { require.Error(t, smt2.AddLeaves(leaves2), "SMT should not allow adding leaves above existing nodes, even in a batch") } -func TestJoinPaths(t *testing.T) { - // "Two Leaves, Sharded" example from the spec - t.Run("TwoLeaves", func(t *testing.T) { - left := NewChildSparseMerkleTree(api.SHA256, 2, 0b10) - left.AddLeaf(big.NewInt(0b100), []byte{0x61}) - - right := NewChildSparseMerkleTree(api.SHA256, 2, 0b11) - right.AddLeaf(big.NewInt(0b111), []byte{0x62}) - - parent := NewParentSparseMerkleTree(api.SHA256, 1) - parent.AddLeaf(big.NewInt(0b10), left.GetRootHash()[2:]) - parent.AddLeaf(big.NewInt(0b11), right.GetRootHash()[2:]) - - leftChild, _ := left.GetPath(big.NewInt(0b100)) - leftParent, _ := parent.GetPath(big.NewInt(0b10)) - leftPath, err := JoinPaths(leftChild, leftParent) - assert.Nil(t, err) - assert.NotNil(t, leftPath) - leftRes, err := leftPath.Verify(big.NewInt(0b100)) - assert.Nil(t, err) - assert.NotNil(t, leftRes) - assert.True(t, leftRes.PathValid) - assert.True(t, leftRes.PathIncluded) - - rightChild, _ := right.GetPath(big.NewInt(0b111)) - rightParent, _ := parent.GetPath(big.NewInt(0b11)) - rightPath, err := JoinPaths(rightChild, rightParent) - assert.Nil(t, err) - assert.NotNil(t, rightPath) - rightRes, err := rightPath.Verify(big.NewInt(0b111)) - assert.Nil(t, err) - assert.NotNil(t, rightRes) - assert.True(t, rightRes.PathValid) - assert.True(t, rightRes.PathIncluded) - }) - - // "Four Leaves, Sharded" example from the spec - t.Run("FourLeaves", func(t *testing.T) { - left := NewChildSparseMerkleTree(api.SHA256, 4, 0b110) - left.AddLeaf(big.NewInt(0b10010), []byte{0x61}) - left.AddLeaf(big.NewInt(0b11010), []byte{0x62}) - - right := NewChildSparseMerkleTree(api.SHA256, 4, 0b101) - right.AddLeaf(big.NewInt(0b10101), []byte{0x63}) - right.AddLeaf(big.NewInt(0b11101), []byte{0x64}) - - parent := NewParentSparseMerkleTree(api.SHA256, 2) - parent.AddLeaf(big.NewInt(0b110), left.GetRootHash()[2:]) - parent.AddLeaf(big.NewInt(0b101), right.GetRootHash()[2:]) - - child1, _ := left.GetPath(big.NewInt(0b10010)) - parent1, _ := parent.GetPath(big.NewInt(0b110)) - path1, err := JoinPaths(child1, parent1) - assert.Nil(t, err) - assert.NotNil(t, path1) - res1, err := path1.Verify(big.NewInt(0b10010)) - assert.Nil(t, err) - assert.NotNil(t, res1) - assert.True(t, res1.PathValid) - assert.True(t, res1.PathIncluded) - - child2, _ := left.GetPath(big.NewInt(0b11010)) - parent2, _ := parent.GetPath(big.NewInt(0b110)) - path2, err := JoinPaths(child2, parent2) - assert.Nil(t, err) - assert.NotNil(t, path2) - res2, err := path2.Verify(big.NewInt(0b11010)) - assert.Nil(t, err) - assert.NotNil(t, res2) - assert.True(t, res2.PathValid) - assert.True(t, res2.PathIncluded) - - child3, _ := right.GetPath(big.NewInt(0b10101)) - parent3, _ := parent.GetPath(big.NewInt(0b101)) - path3, err := JoinPaths(child3, parent3) - assert.Nil(t, err) - assert.NotNil(t, path3) - res3, err := path3.Verify(big.NewInt(0b10101)) - assert.Nil(t, err) - assert.NotNil(t, res3) - assert.True(t, res3.PathValid) - assert.True(t, res3.PathIncluded) - - child4, _ := right.GetPath(big.NewInt(0b11101)) - parent4, _ := parent.GetPath(big.NewInt(0b101)) - path4, err := JoinPaths(child4, parent4) - assert.Nil(t, err) - assert.NotNil(t, path4) - res4, err := path4.Verify(big.NewInt(0b11101)) - assert.Nil(t, err) - assert.NotNil(t, res4) - assert.True(t, res4.PathValid) - assert.True(t, res4.PathIncluded) - }) - - t.Run("NilInputDoesNotPanic", func(t *testing.T) { - dummy := &api.MerkleTreePath{Root: "0000"} - - joinedPath, err := JoinPaths(nil, dummy) - assert.ErrorContains(t, err, "nil child path") - assert.Nil(t, joinedPath) - - joinedPath, err = JoinPaths(dummy, nil) - assert.ErrorContains(t, err, "nil parent path") - assert.Nil(t, joinedPath) - }) - - t.Run("NilRootDoesNotPanic", func(t *testing.T) { - dummyNil := &api.MerkleTreePath{} - dummyShort := &api.MerkleTreePath{Root: ""} - dummyOK := &api.MerkleTreePath{Root: "0000"} - - joinedPath, err := JoinPaths(dummyNil, dummyOK) - assert.ErrorContains(t, err, "invalid child root hash format") - assert.Nil(t, joinedPath) - - joinedPath, err = JoinPaths(dummyShort, dummyOK) - assert.ErrorContains(t, err, "invalid child root hash format") - assert.Nil(t, joinedPath) - - joinedPath, err = JoinPaths(dummyOK, dummyNil) - assert.ErrorContains(t, err, "invalid parent root hash format") - assert.Nil(t, joinedPath) - - joinedPath, err = JoinPaths(dummyOK, dummyShort) - assert.ErrorContains(t, err, "invalid parent root hash format") - assert.Nil(t, joinedPath) - }) - - t.Run("HashFunctionMismatch", func(t *testing.T) { - child := &api.MerkleTreePath{Root: "0000"} - parent := &api.MerkleTreePath{Root: "0001"} - - joinedPath, err := JoinPaths(child, parent) - assert.ErrorContains(t, err, "child hash algorithm does not match parent") - assert.Nil(t, joinedPath) - }) - - t.Run("HashValueMismatch", func(t *testing.T) { - smt := NewSparseMerkleTree(api.SHA256, 1) - smt.AddLeaf(big.NewInt(0b10), []byte{0}) - path, _ := smt.GetPath(big.NewInt(0b10)) - - joinedPath, err := JoinPaths(path, path) - assert.ErrorContains(t, err, "child root hash does not match parent input hash") - assert.Nil(t, joinedPath) - }) -} - // TestParentSMTSnapshotUpdateLeaf tests that parent SMT snapshots can update pre-populated leaves func TestParentSMTSnapshotUpdateLeaf(t *testing.T) { // Create parent SMT with ShardIDLength=1 (creates pre-populated leaves at paths 2 and 3) diff --git a/internal/smt/thread_safe_smt.go b/internal/smt/thread_safe_smt.go index c12fc77..5ead109 100644 --- a/internal/smt/thread_safe_smt.go +++ b/internal/smt/thread_safe_smt.go @@ -108,6 +108,32 @@ func (ts *ThreadSafeSMT) GetInclusionCert(key []byte) (*api.InclusionCert, error return ts.smt.GetInclusionCert(key) } +// GetShardInclusionFragment builds the native parent proof fragment for the +// given shard ID. This is only valid on parent-mode SMT instances. +func (ts *ThreadSafeSMT) GetShardInclusionFragment(shardID api.ShardID) (*api.ParentInclusionFragment, error) { + ts.rwMux.RLock() + defer ts.rwMux.RUnlock() + return ts.smt.GetShardInclusionFragment(shardID) +} + +// GetShardInclusionFragmentWithRoot atomically reads the parent fragment and +// the raw SMT root from the same in-memory snapshot. This avoids serving a +// fragment from one root and looking up a block for a newer root in a later +// read section. +func (ts *ThreadSafeSMT) GetShardInclusionFragmentWithRoot(shardID api.ShardID) (*api.ParentInclusionFragment, []byte, error) { + ts.rwMux.RLock() + defer ts.rwMux.RUnlock() + + parentFragment, err := ts.smt.GetShardInclusionFragment(shardID) + if err != nil { + return nil, nil, err + } + if parentFragment == nil { + return nil, nil, nil + } + return parentFragment, ts.smt.GetRootHashRaw(), nil +} + // GetKeyLength exposes the configured SMT key length. func (ts *ThreadSafeSMT) GetKeyLength() int { ts.rwMux.RLock() diff --git a/internal/smt/thread_safe_smt_test.go b/internal/smt/thread_safe_smt_test.go index 22a5b76..cb60e77 100644 --- a/internal/smt/thread_safe_smt_test.go +++ b/internal/smt/thread_safe_smt_test.go @@ -24,6 +24,23 @@ func TestThreadSafeSMT_AddPreHashedLeaf_StoresChildRoot(t *testing.T) { require.True(t, leaf.isChild) } +func TestThreadSafeSMT_GetShardInclusionFragment_ReturnsNativeParentFragment(t *testing.T) { + tree := NewThreadSafeSMT(NewParentSparseMerkleTree(api.SHA256, 2)) + + path := big.NewInt(4) // shard ID with sentinel bit for a 2-bit parent tree + hash := bytes.Repeat([]byte{0xcd}, 32) + + require.NoError(t, tree.AddPreHashedLeaf(path, hash)) + + fragment, err := tree.GetShardInclusionFragment(4) + require.NoError(t, err) + require.NotNil(t, fragment) + require.Equal(t, hash, []byte(fragment.ShardLeafValue)) + require.GreaterOrEqual(t, len(fragment.CertificateBytes), api.BitmapSize) + require.Equal(t, 0, len(fragment.CertificateBytes)%api.SiblingSize) + require.NoError(t, fragment.Verify(4, 2, hash, tree.GetRootHashRaw(), api.SHA256)) +} + func TestThreadSafeSMT_PrimesHashesOnConstruction(t *testing.T) { tree := NewThreadSafeSMT(NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) requireBranchHashesPrimed(t, tree.smt.root) diff --git a/pkg/api/inclusion_cert.go b/pkg/api/inclusion_cert.go index 5ce138e..7ad3ec9 100644 --- a/pkg/api/inclusion_cert.go +++ b/pkg/api/inclusion_cert.go @@ -104,34 +104,50 @@ func (c *InclusionCert) Verify(key, value, expectedRoot []byte, algo HashAlgorit if len(key) != StateTreeKeyLengthBytes { return fmt.Errorf("%w: got %d, want %d", ErrCertKeyLength, len(key), StateTreeKeyLengthBytes) } - if len(expectedRoot) != SiblingSize { - return fmt.Errorf("%w: got %d, want %d", ErrCertRootLength, len(expectedRoot), SiblingSize) - } + + // Leaf hash: H(0x00 || key || value). hasher := NewDataHasher(algo) if hasher == nil { return fmt.Errorf("%w: %d", ErrCertUnknownAlgo, algo) } - - // Leaf hash: H(0x00 || key || value). hasher.Reset(). AddData([]byte{0x00}). AddData(key). AddData(value) h := hasher.GetHash().RawHash + return verifyBitmapPath(&c.Bitmap, c.Siblings, key, h, expectedRoot, algo) +} + +func verifyBitmapPath(bitmap *[BitmapSize]byte, siblings [][SiblingSize]byte, key, startHash, expectedRoot []byte, algo HashAlgorithm) error { + if len(startHash) != SiblingSize { + return fmt.Errorf("%w: got %d, want %d", ErrCertRootLength, len(startHash), SiblingSize) + } + if len(expectedRoot) != SiblingSize { + return fmt.Errorf("%w: got %d, want %d", ErrCertRootLength, len(expectedRoot), SiblingSize) + } + hasher := NewDataHasher(algo) + if hasher == nil { + return fmt.Errorf("%w: %d", ErrCertUnknownAlgo, algo) + } + // Walk depths from deepest to shallowest, consuming siblings from // the end of the slice. Depths with bitmap bit clear are skipped // (unary passthrough or off-path). - j := len(c.Siblings) + h := append([]byte(nil), startHash...) + j := len(siblings) for d := maxDepth - 1; d >= 0; d-- { - if (c.Bitmap[d/8]>>(uint(d)%8))&1 == 0 { + if ((*bitmap)[d/8]>>(uint(d)%8))&1 == 0 { continue } + if d/8 >= len(key) { + return fmt.Errorf("%w: key too short for depth %d", ErrCertKeyLength, d) + } if j == 0 { return ErrCertSiblingUnderflow } j-- - sibling := c.Siblings[j][:] + sibling := siblings[j][:] hasher.Reset().AddData([]byte{0x01, byte(d)}) if keyBitAt(key, d) == 1 { diff --git a/pkg/api/inclusion_cert_compose.go b/pkg/api/inclusion_cert_compose.go new file mode 100644 index 0000000..d2b01c1 --- /dev/null +++ b/pkg/api/inclusion_cert_compose.go @@ -0,0 +1,101 @@ +package api + +import ( + "bytes" + "errors" + "fmt" +) + +var ( + ErrCertDepthOverlap = errors.New("inclusion cert: parent and child cert overlap in depth") + ErrCertDepthOrder = errors.New("inclusion cert: parent cert depths must be shallower than child cert depths") + ErrCertChildRootMismatch = errors.New("inclusion cert: parent fragment shard leaf value does not match child root") + ErrCertMissingChild = errors.New("inclusion cert: missing child cert") + ErrCertMissingParent = errors.New("inclusion cert: missing parent fragment") +) + +// ComposeInclusionCert merges a child inclusion certificate with the stored +// parent proof fragment for the child shard. The result is a single public +// InclusionCert that can later be verified against the parent UC.IR.h. +// +// Invariants enforced here: +// - parentFragment.ShardLeafValue must equal childRoot +// - parent fragment certificate bytes must decode as a valid InclusionCert +// - parent and child bitmaps must not overlap in depth +// - if both certs contain siblings, every parent depth must be shallower +// than every child depth +// - merged siblings stay in root-to-leaf order: parent first, then child +func ComposeInclusionCert(parentFragment *ParentInclusionFragment, child *InclusionCert, childRoot []byte) (*InclusionCert, error) { + if parentFragment == nil { + return nil, ErrCertMissingParent + } + if child == nil { + return nil, ErrCertMissingChild + } + if len(childRoot) != SiblingSize { + return nil, fmt.Errorf("%w: got %d, want %d", ErrCertRootLength, len(childRoot), SiblingSize) + } + if len(parentFragment.ShardLeafValue) != SiblingSize { + return nil, fmt.Errorf("invalid parent fragment shard leaf value length: got %d, want %d", + len(parentFragment.ShardLeafValue), SiblingSize) + } + if !bytes.Equal(parentFragment.ShardLeafValue, childRoot) { + return nil, ErrCertChildRootMismatch + } + + var parent InclusionCert + if err := parent.UnmarshalBinary(parentFragment.CertificateBytes); err != nil { + return nil, fmt.Errorf("failed to decode parent fragment cert: %w", err) + } + + if bitmapOverlap(&parent.Bitmap, &child.Bitmap) { + return nil, ErrCertDepthOverlap + } + + parentHasDepths, _, parentMax := bitmapDepthRange(&parent.Bitmap) + childHasDepths, childMin, _ := bitmapDepthRange(&child.Bitmap) + if parentHasDepths && childHasDepths && parentMax >= childMin { + return nil, fmt.Errorf("%w: deepest parent depth %d, shallowest child depth %d", + ErrCertDepthOrder, parentMax, childMin) + } + + out := &InclusionCert{} + for i := range out.Bitmap { + out.Bitmap[i] = parent.Bitmap[i] | child.Bitmap[i] + } + out.Siblings = make([][SiblingSize]byte, 0, len(parent.Siblings)+len(child.Siblings)) + out.Siblings = append(out.Siblings, parent.Siblings...) + out.Siblings = append(out.Siblings, child.Siblings...) + + if len(out.Siblings) != bitmapPopcount(&out.Bitmap) { + return nil, fmt.Errorf("%w: have %d, want %d", + ErrCertBitmapMismatch, len(out.Siblings), bitmapPopcount(&out.Bitmap)) + } + return out, nil +} + +func bitmapOverlap(a, b *[BitmapSize]byte) bool { + for i := 0; i < BitmapSize; i++ { + if a[i]&b[i] != 0 { + return true + } + } + return false +} + +func bitmapDepthRange(bitmap *[BitmapSize]byte) (ok bool, minDepth, maxDepthSeen int) { + minDepth = BitmapSize * 8 + for depth := 0; depth < BitmapSize*8; depth++ { + if (bitmap[depth/8]>>(uint(depth)%8))&1 == 0 { + continue + } + if !ok { + minDepth = depth + maxDepthSeen = depth + ok = true + continue + } + maxDepthSeen = depth + } + return ok, minDepth, maxDepthSeen +} diff --git a/pkg/api/inclusion_cert_test.go b/pkg/api/inclusion_cert_test.go index d08f76a..ccebec7 100644 --- a/pkg/api/inclusion_cert_test.go +++ b/pkg/api/inclusion_cert_test.go @@ -305,7 +305,7 @@ func TestBitmapPopcount(t *testing.T) { func TestKeyBitAt(t *testing.T) { key := make([]byte, StateTreeKeyLengthBytes) key[0] = 0b1010_0101 - key[1] = 0x01 // bit 8 set + key[1] = 0x01 // bit 8 set key[31] = 0x80 // bit 255 set checks := []struct { @@ -388,3 +388,136 @@ func TestExclusionCertVerify_NotImplemented(t *testing.T) { t.Fatalf("expected ErrExclusionNotImpl, got %v", err) } } + +func TestComposeInclusionCert_Success(t *testing.T) { + key := make([]byte, StateTreeKeyLengthBytes) // shard prefix 00, child bits also 0 + value := []byte("leaf-value") + + childSibling := bytes.Repeat([]byte{0x55}, SiblingSize) + parentSibling := bytes.Repeat([]byte{0x99}, SiblingSize) + + leaf := hashLeafRaw(t, SHA256, key, value) + childRoot := hashNodeRaw(t, SHA256, 5, leaf, childSibling) + parentRoot := hashNodeRaw(t, SHA256, 1, childRoot, parentSibling) + + child := &InclusionCert{} + child.Bitmap[5/8] |= 1 << (5 % 8) + var childS [SiblingSize]byte + copy(childS[:], childSibling) + child.Siblings = append(child.Siblings, childS) + + parent := &InclusionCert{} + parent.Bitmap[1/8] |= 1 << (1 % 8) + var parentS [SiblingSize]byte + copy(parentS[:], parentSibling) + parent.Siblings = append(parent.Siblings, parentS) + parentBytes, err := parent.MarshalBinary() + if err != nil { + t.Fatalf("marshal parent cert: %v", err) + } + + fragment := &ParentInclusionFragment{ + CertificateBytes: parentBytes, + ShardLeafValue: append([]byte(nil), childRoot...), + } + + composed, err := ComposeInclusionCert(fragment, child, childRoot) + if err != nil { + t.Fatalf("compose inclusion cert: %v", err) + } + if err := composed.Verify(key, value, parentRoot, SHA256); err != nil { + t.Fatalf("verify composed cert: %v", err) + } + if got := len(composed.Siblings); got != 2 { + t.Fatalf("composed sibling count: got %d want 2", got) + } + if composed.Siblings[0] != parentS { + t.Fatalf("parent sibling not first in composed cert") + } + if composed.Siblings[1] != childS { + t.Fatalf("child sibling not second in composed cert") + } +} + +func TestComposeInclusionCert_RejectsMalformedParentFragment(t *testing.T) { + child := &InclusionCert{} + fragment := &ParentInclusionFragment{ + CertificateBytes: []byte{0x01, 0x02}, + ShardLeafValue: bytes.Repeat([]byte{0xAA}, SiblingSize), + } + _, err := ComposeInclusionCert(fragment, child, bytes.Repeat([]byte{0xAA}, SiblingSize)) + if err == nil { + t.Fatal("expected malformed parent fragment to fail") + } + if !errors.Is(err, ErrCertTruncated) { + t.Fatalf("expected ErrCertTruncated, got %v", err) + } +} + +func TestComposeInclusionCert_RejectsChildRootMismatch(t *testing.T) { + parent := &InclusionCert{} + parentBytes, err := parent.MarshalBinary() + if err != nil { + t.Fatalf("marshal parent cert: %v", err) + } + fragment := &ParentInclusionFragment{ + CertificateBytes: parentBytes, + ShardLeafValue: bytes.Repeat([]byte{0xAA}, SiblingSize), + } + _, err = ComposeInclusionCert(fragment, &InclusionCert{}, bytes.Repeat([]byte{0xBB}, SiblingSize)) + if !errors.Is(err, ErrCertChildRootMismatch) { + t.Fatalf("expected ErrCertChildRootMismatch, got %v", err) + } +} + +func TestComposeInclusionCert_RejectsDepthOverlap(t *testing.T) { + child := &InclusionCert{} + child.Bitmap[5/8] |= 1 << (5 % 8) + var childS [SiblingSize]byte + child.Siblings = append(child.Siblings, childS) + + parent := &InclusionCert{} + parent.Bitmap[5/8] |= 1 << (5 % 8) + var parentS [SiblingSize]byte + parent.Siblings = append(parent.Siblings, parentS) + parentBytes, err := parent.MarshalBinary() + if err != nil { + t.Fatalf("marshal parent cert: %v", err) + } + + childRoot := bytes.Repeat([]byte{0x11}, SiblingSize) + fragment := &ParentInclusionFragment{ + CertificateBytes: parentBytes, + ShardLeafValue: append([]byte(nil), childRoot...), + } + _, err = ComposeInclusionCert(fragment, child, childRoot) + if !errors.Is(err, ErrCertDepthOverlap) { + t.Fatalf("expected ErrCertDepthOverlap, got %v", err) + } +} + +func TestComposeInclusionCert_RejectsParentDeeperThanChild(t *testing.T) { + child := &InclusionCert{} + child.Bitmap[3/8] |= 1 << (3 % 8) + var childS [SiblingSize]byte + child.Siblings = append(child.Siblings, childS) + + parent := &InclusionCert{} + parent.Bitmap[7/8] |= 1 << (7 % 8) + var parentS [SiblingSize]byte + parent.Siblings = append(parent.Siblings, parentS) + parentBytes, err := parent.MarshalBinary() + if err != nil { + t.Fatalf("marshal parent cert: %v", err) + } + + childRoot := bytes.Repeat([]byte{0x22}, SiblingSize) + fragment := &ParentInclusionFragment{ + CertificateBytes: parentBytes, + ShardLeafValue: append([]byte(nil), childRoot...), + } + _, err = ComposeInclusionCert(fragment, child, childRoot) + if !errors.Is(err, ErrCertDepthOrder) { + t.Fatalf("expected ErrCertDepthOrder, got %v", err) + } +} diff --git a/pkg/api/shard_match.go b/pkg/api/shard_match.go new file mode 100644 index 0000000..18fc895 --- /dev/null +++ b/pkg/api/shard_match.go @@ -0,0 +1,48 @@ +package api + +import ( + "encoding/hex" + "fmt" + "math/bits" +) + +// MatchesShardPrefix checks whether the LSB-first bits of keyBytes match the +// shard prefix defined by shardBitmask. The bitmask encodes a sentinel-prefixed +// shard ID (e.g. 0b100 = shard 0 in a 2-bit tree). keyBytes must be at least +// ceil(shardDepth/8) bytes long. +func MatchesShardPrefix(keyBytes []byte, shardBitmask int) (bool, error) { + shardDepth := bits.Len(uint(shardBitmask)) - 1 + if shardDepth < 0 { + return false, fmt.Errorf("invalid shard bitmask: %d", shardBitmask) + } + if len(keyBytes) < (shardDepth+7)/8 { + return false, fmt.Errorf("key too short for shard depth %d: got %d bytes", shardDepth, len(keyBytes)) + } + + for d := 0; d < shardDepth; d++ { + expected := byte((uint(shardBitmask) >> uint(d)) & 1) + actual := (keyBytes[d/8] >> (uint(d) % 8)) & 1 + if actual != expected { + return false, nil + } + } + return true, nil +} + +// MatchesShardPrefixFromHex decodes a hex-encoded 32-byte state key and +// applies MatchesShardPrefix. +func MatchesShardPrefixFromHex(keyHex string, shardBitmask int) (bool, error) { + keyBytes, err := hex.DecodeString(keyHex) + if err != nil { + return false, fmt.Errorf("failed to decode state key: %w", err) + } + + if len(keyBytes) != StateTreeKeyLengthBytes { + return false, fmt.Errorf( + "state key must be exactly %d bytes, got %d", + StateTreeKeyLengthBytes, len(keyBytes), + ) + } + + return MatchesShardPrefix(keyBytes, shardBitmask) +} diff --git a/pkg/api/state_id_path_test.go b/pkg/api/state_id_path_test.go deleted file mode 100644 index aaca346..0000000 --- a/pkg/api/state_id_path_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package api - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestImprintV2GetPath_UsesRawHashBytesOnly(t *testing.T) { - // 32-byte raw hash - raw := "11223344556677889900aabbccddeeff00112233445566778899aabbccddeeff" - rawID, err := NewImprintV2(raw) - require.NoError(t, err) - - // Legacy-prefixed form of the same hash - legacyID, err := NewImprintV2("0000" + raw) - require.NoError(t, err) - - rawPath, err := rawID.GetPath() - require.NoError(t, err) - legacyPath, err := legacyID.GetPath() - require.NoError(t, err) - - // Both encodings must map to the same 256-bit key path (+sentinel bit). - require.Equal(t, rawPath, legacyPath) - require.Equal(t, StateTreeKeyLengthBits, rawPath.BitLen()-1) -} diff --git a/pkg/api/types.go b/pkg/api/types.go index dd268e5..3f02611 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -68,17 +68,18 @@ type AggregatorRecord struct { // Block represents a blockchain block type Block struct { - Index *BigInt `json:"index"` - ChainID string `json:"chainId"` - ShardID ShardID `json:"shardId"` - Version string `json:"version"` - ForkID string `json:"forkId"` - RootHash HexBytes `json:"rootHash"` - PreviousBlockHash HexBytes `json:"previousBlockHash"` - NoDeletionProofHash HexBytes `json:"noDeletionProofHash"` - CreatedAt *Timestamp `json:"createdAt"` - UnicityCertificate HexBytes `json:"unicityCertificate"` - ParentMerkleTreePath *MerkleTreePath `json:"parentMerkleTreePath,omitempty"` // child mode only + Index *BigInt `json:"index"` + ChainID string `json:"chainId"` + ShardID ShardID `json:"shardId"` + Version string `json:"version"` + ForkID string `json:"forkId"` + RootHash HexBytes `json:"rootHash"` + PreviousBlockHash HexBytes `json:"previousBlockHash"` + NoDeletionProofHash HexBytes `json:"noDeletionProofHash"` + CreatedAt *Timestamp `json:"createdAt"` + UnicityCertificate HexBytes `json:"unicityCertificate"` + ParentFragment *ParentInclusionFragment `json:"parentFragment,omitempty"` // child mode only + ParentBlockNumber uint64 `json:"parentBlockNumber,omitempty"` // child mode only } // NoDeletionProof represents a no-deletion proof @@ -138,9 +139,45 @@ type InclusionProofV2 struct { UnicityCertificate types.RawCBOR `json:"unicityCertificate"` } +// ParentInclusionFragment is the internal parent-tree proof fragment stored on +// finalized child blocks and returned by get_shard_proof. CertificateBytes +// uses the same bitmap+sibling wire shape as InclusionCert; ShardLeafValue is +// the parent leaf value proven by that fragment and must equal the child SMT +// root before later composition. +type ParentInclusionFragment struct { + CertificateBytes HexBytes `json:"certificateBytes"` + ShardLeafValue HexBytes `json:"shardLeafValue"` +} + +func (f *ParentInclusionFragment) Verify(shardID ShardID, keyLength int, expectedLeafValue, expectedRoot []byte, algo HashAlgorithm) error { + if f == nil { + return errors.New("missing parent fragment") + } + if len(f.ShardLeafValue) != SiblingSize { + return fmt.Errorf("invalid parent fragment shard leaf value length: got %d, want %d", len(f.ShardLeafValue), SiblingSize) + } + if !bytes.Equal(f.ShardLeafValue, expectedLeafValue) { + return errors.New("parent fragment shard leaf value does not match expected child root") + } + + var cert InclusionCert + if err := cert.UnmarshalBinary(f.CertificateBytes); err != nil { + return fmt.Errorf("failed to decode parent fragment cert: %w", err) + } + + path := big.NewInt(int64(shardID)) + key, err := PathToFixedBytes(path, keyLength) + if err != nil { + return fmt.Errorf("failed to derive parent fragment key: %w", err) + } + + return verifyBitmapPath(&cert.Bitmap, cert.Siblings, key, f.ShardLeafValue, expectedRoot, algo) +} + type RootShardInclusionProof struct { - MerkleTreePath *MerkleTreePath `json:"merkleTreePath"` - UnicityCertificate HexBytes `json:"unicityCertificate"` + ParentFragment *ParentInclusionFragment `json:"parentFragment,omitempty"` + UnicityCertificate HexBytes `json:"unicityCertificate"` + BlockNumber uint64 `json:"blockNumber,omitempty"` } // GetNoDeletionProofResponse represents the get_no_deletion_proof JSON-RPC response @@ -209,8 +246,9 @@ type GetShardProofRequest struct { // GetShardProofResponse represents the get_shard_proof JSON-RPC response type GetShardProofResponse struct { - MerkleTreePath *MerkleTreePath `json:"merkleTreePath"` // Proof path for the shard - UnicityCertificate HexBytes `json:"unicityCertificate"` // Unicity Certificate from the finalized block + ParentFragment *ParentInclusionFragment `json:"parentFragment,omitempty"` // native parent fragment for child v2 composition + UnicityCertificate HexBytes `json:"unicityCertificate"` // Unicity Certificate from the finalized block + BlockNumber uint64 `json:"blockNumber,omitempty"` } // HealthStatus represents the health status of the service @@ -240,9 +278,15 @@ func (h *HealthStatus) AddDetail(key, value string) { h.Details[key] = value } -func (r *RootShardInclusionProof) IsValid(shardRootHash string) bool { - return r.MerkleTreePath != nil && len(r.UnicityCertificate) > 0 && - len(r.MerkleTreePath.Steps) > 0 && r.MerkleTreePath.Steps[0].Data != nil && *r.MerkleTreePath.Steps[0].Data == shardRootHash +func (r *RootShardInclusionProof) IsValid(shardID ShardID, keyLength int, shardRootHash HexBytes) bool { + if r == nil || len(r.UnicityCertificate) == 0 || r.ParentFragment == nil { + return false + } + rootRaw, err := ucInputRecordHashRaw(r.UnicityCertificate) + if err != nil { + return false + } + return r.ParentFragment.Verify(shardID, keyLength, shardRootHash, rootRaw, InclusionProofV2HashAlgorithm) == nil } type Sharding struct { @@ -320,11 +364,15 @@ func (p *InclusionProofV2) Verify(v2 *CertificationRequest) error { // UCInputRecordHashRaw decodes the embedded Unicity Certificate and // returns UC.IR.h as a raw 32-byte hash. Any other length is rejected. func (p *InclusionProofV2) UCInputRecordHashRaw() ([]byte, error) { - if len(p.UnicityCertificate) == 0 { + return ucInputRecordHashRaw(p.UnicityCertificate) +} + +func ucInputRecordHashRaw(raw []byte) ([]byte, error) { + if len(raw) == 0 { return nil, errors.New("missing unicity certificate") } var uc types.UnicityCertificate - if err := types.Cbor.Unmarshal(p.UnicityCertificate, &uc); err != nil { + if err := types.Cbor.Unmarshal(raw, &uc); err != nil { return nil, fmt.Errorf("failed to decode unicity certificate: %w", err) } if uc.InputRecord == nil { diff --git a/test/integration/sharding_e2e_test.go b/test/integration/sharding_e2e_test.go index 963c87b..a09ce96 100644 --- a/test/integration/sharding_e2e_test.go +++ b/test/integration/sharding_e2e_test.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "math/big" "net/http" "net/url" "strconv" @@ -23,6 +22,7 @@ import ( "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/round" "github.com/unicitynetwork/aggregator-go/internal/service" + "github.com/unicitynetwork/aggregator-go/internal/signing" "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/internal/storage" "github.com/unicitynetwork/aggregator-go/internal/testutil" @@ -32,9 +32,6 @@ import ( // TestShardingE2E tests the full sharding flow: parent + 2 child shards // submitting commitments and verifying inclusion proofs. func TestShardingE2E(t *testing.T) { - // Child-mode v2 inclusion proof generation has not been reimplemented - // yet. Re-enable this test once child mode inclusion_proof_v2 is migrated. - t.Skip("child-mode inclusion_proof_v2 not yet migrated") ctx := t.Context() // Start containers (shared MongoDB with different databases per aggregator) @@ -64,17 +61,17 @@ func TestShardingE2E(t *testing.T) { waitForBlock(t, "http://localhost:9002", 1, 15*time.Second) // Submit commitments over multiple rounds - var shard2ReqIDs, shard3ReqIDs []string + var shard2Requests, shard3Requests []*api.CertificationRequest // Round 1: submit 2 commitments to each shard for i := 0; i < 2; i++ { - c, reqID := createCommitmentForShard(t, 2) + c := createCommitmentForShard(t, cfgForShard(2)) submitCertificationRequest(t, "http://localhost:9001", c) - shard2ReqIDs = append(shard2ReqIDs, reqID) + shard2Requests = append(shard2Requests, c) - c, reqID = createCommitmentForShard(t, 3) + c = createCommitmentForShard(t, cfgForShard(3)) submitCertificationRequest(t, "http://localhost:9002", c) - shard3ReqIDs = append(shard3ReqIDs, reqID) + shard3Requests = append(shard3Requests, c) } waitForBlock(t, "http://localhost:9001", 2, 15*time.Second) @@ -82,36 +79,36 @@ func TestShardingE2E(t *testing.T) { // Round 2: submit 2 more commitments to each shard for i := 0; i < 2; i++ { - c, reqID := createCommitmentForShard(t, 2) + c := createCommitmentForShard(t, cfgForShard(2)) submitCertificationRequest(t, "http://localhost:9001", c) - shard2ReqIDs = append(shard2ReqIDs, reqID) + shard2Requests = append(shard2Requests, c) - c, reqID = createCommitmentForShard(t, 3) + c = createCommitmentForShard(t, cfgForShard(3)) submitCertificationRequest(t, "http://localhost:9002", c) - shard3ReqIDs = append(shard3ReqIDs, reqID) + shard3Requests = append(shard3Requests, c) } waitForBlock(t, "http://localhost:9001", 3, 15*time.Second) waitForBlock(t, "http://localhost:9002", 3, 15*time.Second) // Round 3: submit 1 more commitment to each shard - c, reqID := createCommitmentForShard(t, 2) + c := createCommitmentForShard(t, cfgForShard(2)) submitCertificationRequest(t, "http://localhost:9001", c) - shard2ReqIDs = append(shard2ReqIDs, reqID) + shard2Requests = append(shard2Requests, c) - c, reqID = createCommitmentForShard(t, 3) + c = createCommitmentForShard(t, cfgForShard(3)) submitCertificationRequest(t, "http://localhost:9002", c) - shard3ReqIDs = append(shard3ReqIDs, reqID) + shard3Requests = append(shard3Requests, c) waitForBlock(t, "http://localhost:9001", 4, 15*time.Second) waitForBlock(t, "http://localhost:9002", 4, 15*time.Second) // Verify all proofs - for _, reqID := range shard2ReqIDs { - waitForValidProof(t, "http://localhost:9001", reqID, 15*time.Second) + for _, req := range shard2Requests { + waitForValidProof(t, "http://localhost:9001", req, 15*time.Second) } - for _, reqID := range shard3ReqIDs { - waitForValidProof(t, "http://localhost:9002", reqID, 15*time.Second) + for _, req := range shard3Requests { + waitForValidProof(t, "http://localhost:9002", req, 15*time.Second) } } @@ -167,8 +164,10 @@ func startAggregator(t *testing.T, ctx context.Context, name, port, mongoURI, re // Create SMT instance based on sharding mode var smtInstance *smt.SparseMerkleTree switch cfg.Sharding.Mode { - case config.ShardingModeStandalone, config.ShardingModeChild: + case config.ShardingModeStandalone: smtInstance = smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + case config.ShardingModeChild: + smtInstance = smt.NewChildSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits, cfg.Sharding.Child.ShardID) case config.ShardingModeParent: smtInstance = smt.NewParentSparseMerkleTree(api.SHA256, cfg.Sharding.ShardIDLength) } @@ -239,19 +238,17 @@ func waitForBlock(t *testing.T, url string, minBlock int64, timeout time.Duratio t.Fatalf("Timeout waiting for block %d at %s", minBlock, url) } -func waitForValidProof(t *testing.T, url, reqID string, timeout time.Duration) { +func waitForValidProof(t *testing.T, url string, req *api.CertificationRequest, timeout time.Duration) { deadline := time.Now().Add(timeout) - reqIDObj := api.RequireNewImprintV2(reqID) for time.Now().Before(deadline) { - result, err := rpcCall(url, "get_inclusion_proof.v2", map[string]string{"stateId": reqID}) + result, err := rpcCall(url, "get_inclusion_proof.v2", &api.GetInclusionProofRequestV2{StateID: req.StateID}) if err == nil { var resp api.GetInclusionProofResponseV2 - json.Unmarshal(result, &resp) + require.NoError(t, json.Unmarshal(result, &resp)) if resp.InclusionProof != nil && len(resp.InclusionProof.CertificateBytes) > 0 { - // Reconstruct a minimal CertificationRequest for v2 verification. - // (Test skipped until child-mode v2 is migrated — see t.Skip above.) - _ = reqIDObj + require.Greater(t, resp.BlockNumber, uint64(0)) + require.NoError(t, resp.InclusionProof.Verify(req)) return } } @@ -260,17 +257,24 @@ func waitForValidProof(t *testing.T, url, reqID string, timeout time.Duration) { t.Fatalf("Timeout waiting for valid proof at %s", url) } -func createCommitmentForShard(t *testing.T, shardID api.ShardID) (*api.CertificationRequest, string) { - mask := big.NewInt(1) - expected := big.NewInt(int64(shardID & 1)) - +func createCommitmentForShard(t *testing.T, shardingCfg config.ShardingConfig) *api.CertificationRequest { + validator := signing.NewCertificationRequestValidator(shardingCfg) for i := 0; i < 1000; i++ { - c := testutil.CreateTestCertificationRequest(t, fmt.Sprintf("shard%d_%d_%d", shardID, i, time.Now().UnixNano())) - if new(big.Int).And(new(big.Int).SetBytes(c.StateID.Bytes()), mask).Cmp(expected) == 0 { - certificationRequest := c.ToAPI() - return certificationRequest, c.StateID.String() + c := testutil.CreateTestCertificationRequest(t, fmt.Sprintf("shard%d_%d_%d", shardingCfg.Child.ShardID, i, time.Now().UnixNano())) + if err := validator.ValidateShardID(c.StateID); err == nil { + return c.ToAPI() } } t.Fatal("Failed to generate commitment for shard") - return nil, "" + return nil +} + +func cfgForShard(shardID api.ShardID) config.ShardingConfig { + return config.ShardingConfig{ + Mode: config.ShardingModeChild, + ShardIDLength: 1, + Child: config.ChildConfig{ + ShardID: shardID, + }, + } }