diff --git a/README.md b/README.md index b8c566f..bf42658 100644 --- a/README.md +++ b/README.md @@ -310,47 +310,63 @@ type CertificationData struct { - `INVALID_TRANSACTION_HASH_FORMAT` - TransactionHash not in proper DataHash imprint format - `INVALID_SHARD` - The certification request was sent to the wrong shard -#### `get_inclusion_proof` -Retrieve the Sparse Merkle Tree inclusion proof for a specific state transition request. +#### `get_inclusion_proof.v2` +Retrieve the v2 inclusion proof for a submitted certification request. + +The `stateId` must be exactly 64 hex characters (32 raw bytes). **Request:** ```json { "jsonrpc": "2.0", - "method": "get_inclusion_proof", + "method": "get_inclusion_proof.v2", "params": { - "stateId": "0000c7aa6962316c0eeb1469dc3d7793e39e140c005e6eea0e188dcc73035d765937" + "stateId": "c7aa6962316c0eeb1469dc3d7793e39e140c005e6eea0e188dcc73035d765937" }, "id": 2 } ``` **Response:** + ```json { - "jsonrpc":"2.0", - "result":{ - "certificationData":{ - "publicKey":"027c4fdf89e8138b360397a7285ca99b863499d26f3c1652251fcf680f4d64882c", - "signature":"65ed0261e093aa2df02c0e8fb0aa46144e053ea705ce7053023745b3626c60550b2a5e90eacb93416df116af96872547608a31de1f8ef25dc5a79104e6b69c8d00", - "sourceStateHash":"0000539cb40d7450fa842ac13f4ea50a17e56c5b1ee544257d46b6ec8bb48a63e647", - "transactionHash":"0000c5f9a1f02e6475c599449250bb741b49bd8858afe8a42059ac1522bff47c6297" - }, - "merkleTreePath":{ - "root":"0000342d44bb4f43b2de5661cf3690254b95b49e46820b90f13fbe2798f428459ba4", - "steps":[ - { - "branch":["0000f00a106493f8bee8846b84325fe71411ea01b8a7c5d7cc0853888b1ef9cbf83b"], - "path":"7588619140208316325429861720569170962648734570557434545963804239233978322458521890", - "sibling":null - } - ] - } - }, - "id":2 + "jsonrpc": "2.0", + "result": "", + "id": 2 } ``` +The `result` field is a hex-encoded CBOR array: +``` +[blockNumber, [certificationData, certificateBytes, unicityCertificate]] +``` + +- `certificationData` is the certification data for 
inclusion proofs, or `null` for non-inclusion proofs.
+- For inclusion proofs, `certificateBytes` is the binary inclusion certificate: `bitmap[32] || sibling_1[32] || ... || sibling_n[32]`, where `n = popcount(bitmap)`. Siblings are in root-to-leaf order. For non-inclusion proofs, `certificateBytes` is an exclusion certificate: `k_l[32] || h_l[32] || bitmap[32] || siblings...` (exclusion proof generation is not yet implemented).
+- The expected SMT root is always taken from `UC.IR.h` (input record hash of the Unicity Certificate). No root field appears in the certificate itself.
+
+**Hash rules (Yellowpaper-aligned):**
+- Leaf: `H(0x00 || key || value)` where value is the raw transaction hash bytes
+- Inner node (two children): `H(0x01 || depth_byte || left || right)`
+- Inner node (one child): passthrough (child hash unchanged)
+
+**Key encoding:** 32 bytes, LSB-first bit addressing. `bit(key, d) = (key[d/8] >> (d%8)) & 1`.
+
+**Verification pseudocode:**
+```
+h = H(0x00 || key || value)
+j = len(siblings)
+for d from 255 down to 0:      # descending: d = 255, 254, ..., 1, 0
+  if bitmap bit d is not set: continue
+  j -= 1
+  if bit(key, d) == 1:
+    h = H(0x01 || d || siblings[j] || h)
+  else:
+    h = H(0x01 || d || h || siblings[j])
+assert j == 0 and h == UC.IR.h
+```
+
 #### `get_block_height`
 Retrieve the current blockchain height. 
diff --git a/cmd/aggregator/main.go b/cmd/aggregator/main.go index cb9ca3a..0ecbb08 100644 --- a/cmd/aggregator/main.go +++ b/cmd/aggregator/main.go @@ -168,9 +168,9 @@ func main() { var smtInstance *smt.SparseMerkleTree switch cfg.Sharding.Mode { case config.ShardingModeStandalone: - smtInstance = smt.NewSparseMerkleTree(api.SHA256, 16+256) + smtInstance = smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) case config.ShardingModeChild: - smtInstance = smt.NewChildSparseMerkleTree(api.SHA256, 16+256, cfg.Sharding.Child.ShardID) + smtInstance = smt.NewChildSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits, cfg.Sharding.Child.ShardID) case config.ShardingModeParent: smtInstance = smt.NewParentSparseMerkleTree(api.SHA256, cfg.Sharding.ShardIDLength) default: diff --git a/cmd/commitment/main.go b/cmd/commitment/main.go index 65a6b7f..02b21dc 100644 --- a/cmd/commitment/main.go +++ b/cmd/commitment/main.go @@ -9,7 +9,6 @@ import ( "flag" "fmt" "log" - "math/big" "net/http" "os" "time" @@ -67,112 +66,80 @@ func main() { }, } - commitReq := generateCommitmentRequest() + certReq := generateCertificationRequest() if *flagVerbose { - if payload, err := json.MarshalIndent(commitReq, "", " "); err == nil { - logger.Printf("submit_commitment request:\n%s", payload) + if payload, err := json.MarshalIndent(certReq, "", " "); err == nil { + logger.Printf("certification_request request:\n%s", payload) } } - logger.Printf("Submitting commitment to URL: %s", *flagURL) + logger.Printf("Submitting certification request to URL: %s", *flagURL) - submitResp, err := callJSONRPC(ctx, client, *flagURL, *flagAuth, "submit_commitment", commitReq) + submitResp, err := callJSONRPC(ctx, client, *flagURL, *flagAuth, "certification_request", certReq) if err != nil { - logger.Fatalf("submit_commitment call failed: %v", err) + logger.Fatalf("certification_request call failed: %v", err) } if *flagVerbose { if payload, err := json.MarshalIndent(submitResp, "", " "); err == nil { - 
logger.Printf("submit_commitment response:\n%s", payload) + logger.Printf("certification_request response:\n%s", payload) } } - var submitResult api.SubmitCommitmentResponse + var submitResult api.CertificationResponse if submitResp.Error != nil { - logger.Fatalf("submit_commitment returned error: %s (code %d)", submitResp.Error.Message, submitResp.Error.Code) + logger.Fatalf("certification_request returned error: %s (code %d)", submitResp.Error.Message, submitResp.Error.Code) } if err := json.Unmarshal(submitResp.Result, &submitResult); err != nil { - logger.Fatalf("failed to decode submit_commitment result: %v", err) + logger.Fatalf("failed to decode certification_request result: %v", err) } - if submitResult.Status != "SUCCESS" { - logger.Fatalf("submit_commitment status was %q", submitResult.Status) - } - - // Verify and display the receipt if present - if submitResult.Receipt != nil { - logger.Printf("Receipt received:") - logger.Printf(" Algorithm: %s", submitResult.Receipt.Algorithm) - logger.Printf(" PublicKey: %s", submitResult.Receipt.PublicKey) - logger.Printf(" Signature: %s", submitResult.Receipt.Signature) - logger.Printf(" Request.Service: %s", submitResult.Receipt.Request.Service) - logger.Printf(" Request.Method: %s", submitResult.Receipt.Request.Method) - logger.Printf(" Request.RequestID: %s", submitResult.Receipt.Request.RequestID) - - // Verify the receipt signature - requestBytes, err := json.Marshal(submitResult.Receipt.Request) - if err != nil { - logger.Printf("Warning: failed to marshal receipt request for verification: %v", err) - } else { - signingService := signing.NewSigningService() - valid, err := signingService.VerifyWithPublicKey(requestBytes, submitResult.Receipt.Signature, submitResult.Receipt.PublicKey) - if err != nil { - logger.Printf("Warning: receipt signature verification error: %v", err) - } else if valid { - logger.Printf("Receipt signature VERIFIED successfully!") - } else { - logger.Printf("Warning: receipt signature 
verification FAILED!") - } - } - } else { - logger.Printf("No receipt received (receipt was requested: %v)", *commitReq.Receipt) - } - - logger.Printf("Commitment %s accepted. Polling for inclusion proof...", commitReq.RequestID) - - path, err := commitReq.RequestID.GetPath() - if err != nil { - logger.Fatalf("failed to derive SMT path: %v", err) + if submitResult.Status != "SUCCESS" && submitResult.Status != "STATE_ID_EXISTS" { + logger.Fatalf("certification_request status was %q", submitResult.Status) } + logger.Printf("Certification request %s accepted. Polling for inclusion proof...", certReq.StateID) submittedAt := time.Now() - inclusionProof, verification, attempts, err := waitForInclusionProof(ctx, client, commitReq.RequestID, path, logger) + inclusionProof, attempts, err := waitForInclusionProof(ctx, client, certReq, logger) if err != nil { logger.Fatalf("failed to retrieve inclusion proof after %d attempt(s): %v", attempts, err) } if *flagVerbose { if payload, err := json.MarshalIndent(inclusionProof, "", " "); err == nil { - logger.Printf("get_inclusion_proof response:\n%s", payload) + logger.Printf("get_inclusion_proof.v2 response:\n%s", payload) } } - logger.Printf("Proof verification result: pathValid=%t pathIncluded=%t overall=%t", - verification.PathValid, verification.PathIncluded, verification.Result) + if err := inclusionProof.InclusionProof.Verify(certReq); err != nil { + logger.Fatalf("proof verification failed: %v", err) + } + logger.Printf("Proof verified successfully against block %d.", inclusionProof.BlockNumber) elapsed := time.Since(submittedAt) logger.Printf("Valid inclusion proof received in %s after %d attempt(s).", elapsed.Round(time.Millisecond), attempts) - logger.Printf("Commitment %s successfully submitted and verified.", commitReq.RequestID) + logger.Printf("Certification request %s successfully submitted and verified.", certReq.StateID) } -func generateCommitmentRequest() *api.SubmitCommitmentRequest { +func 
generateCertificationRequest() *api.CertificationRequest { privateKey, err := btcec.NewPrivateKey() if err != nil { panic(fmt.Sprintf("failed to generate private key: %v", err)) } publicKeyBytes := privateKey.PubKey().SerializeCompressed() + ownerPredicate := api.NewPayToPublicKeyPredicate(publicKeyBytes) stateData := make([]byte, 32) if _, err := rand.Read(stateData); err != nil { panic(fmt.Sprintf("failed to read random state bytes: %v", err)) } - stateHashImprint := signing.CreateDataHash(stateData) + sourceStateHash := signing.CreateDataHash(stateData) - requestID, err := api.CreateRequestID(publicKeyBytes, stateHashImprint) + stateID, err := api.CreateStateID(ownerPredicate, sourceStateHash) if err != nil { - panic(fmt.Sprintf("failed to create request ID: %v", err)) + panic(fmt.Sprintf("failed to create state ID: %v", err)) } transactionData := make([]byte, 32) @@ -180,25 +147,20 @@ func generateCommitmentRequest() *api.SubmitCommitmentRequest { panic(fmt.Sprintf("failed to read random transaction bytes: %v", err)) } - transactionHashImprint := signing.CreateDataHash(transactionData) - transactionHashBytes := transactionHashImprint.DataBytes() - - signature, err := signing.NewSigningService().SignHash(transactionHashBytes, privateKey.Serialize()) - if err != nil { - panic(fmt.Sprintf("failed to sign transaction hash: %v", err)) + transactionHash := signing.CreateDataHash(transactionData) + certData := api.CertificationData{ + OwnerPredicate: ownerPredicate, + SourceStateHash: sourceStateHash, + TransactionHash: transactionHash, + } + if err := signing.NewSigningService().SignCertData(&certData, privateKey.Serialize()); err != nil { + panic(fmt.Sprintf("failed to sign certification data: %v", err)) } - receipt := true - return &api.SubmitCommitmentRequest{ - RequestID: requestID, - TransactionHash: transactionHashImprint, - Authenticator: api.Authenticator{ - Algorithm: "secp256k1", - PublicKey: publicKeyBytes, - Signature: signature, - StateHash: 
stateHashImprint, - }, - Receipt: &receipt, + return &api.CertificationRequest{ + StateID: stateID, + CertificationData: certData, + AggregateRequestCount: 1, } } @@ -219,10 +181,9 @@ func callJSONRPC(ctx context.Context, client *http.Client, url, authHeader, meth } req.Header.Set("Content-Type", "application/json") - /*if authHeader != "" { - req.Header.Set("Authorization", "supersecret") - }*/ - req.Header.Set("Authorization", "Bearer supersecret") + if authHeader != "" { + req.Header.Set("Authorization", authHeader) + } resp, err := client.Do(req) if err != nil { @@ -238,7 +199,7 @@ func callJSONRPC(ctx context.Context, client *http.Client, url, authHeader, meth return &rpcResp, nil } -func waitForInclusionProof(ctx context.Context, client *http.Client, requestID api.RequestID, requestPath *big.Int, logger *log.Logger) (*api.GetInclusionProofResponseV2, *api.PathVerificationResult, int, error) { +func waitForInclusionProof(ctx context.Context, client *http.Client, req *api.CertificationRequest, logger *log.Logger) (*api.GetInclusionProofResponseV2, int, error) { deadline, ok := ctx.Deadline() if !ok { deadline = time.Now().Add(45 * time.Second) @@ -249,65 +210,46 @@ func waitForInclusionProof(ctx context.Context, client *http.Client, requestID a attempts++ select { case <-ctx.Done(): - return nil, nil, attempts, ctx.Err() + return nil, attempts, ctx.Err() default: } - proofResp, err := callJSONRPC(ctx, client, *flagURL, *flagAuth, "get_inclusion_proof", api.GetInclusionProofRequestV2{ - StateID: requestID, + proofResp, err := callJSONRPC(ctx, client, *flagURL, *flagAuth, "get_inclusion_proof.v2", api.GetInclusionProofRequestV2{ + StateID: req.StateID, }) if err != nil { - logger.Printf("get_inclusion_proof attempt %d failed: %v", attempts, err) + logger.Printf("get_inclusion_proof.v2 attempt %d failed: %v", attempts, err) time.Sleep(*flagPollInterval) continue } if proofResp.Error != nil { - logger.Printf("get_inclusion_proof attempt %d returned error: %s (code 
%d)", attempts, proofResp.Error.Message, proofResp.Error.Code) + logger.Printf("get_inclusion_proof.v2 attempt %d returned error: %s (code %d)", attempts, proofResp.Error.Message, proofResp.Error.Code) time.Sleep(*flagPollInterval) continue } var payload api.GetInclusionProofResponseV2 if err := json.Unmarshal(proofResp.Result, &payload); err != nil { - logger.Printf("get_inclusion_proof attempt %d decode error: %v", attempts, err) + logger.Printf("get_inclusion_proof.v2 attempt %d decode error: %v", attempts, err) time.Sleep(*flagPollInterval) continue } - if payload.InclusionProof == nil || payload.InclusionProof.MerkleTreePath == nil { - logger.Printf("get_inclusion_proof attempt %d: proof payload incomplete, retrying...", attempts) + if payload.InclusionProof == nil || len(payload.InclusionProof.UnicityCertificate) == 0 { + logger.Printf("get_inclusion_proof.v2 attempt %d: proof payload incomplete, retrying...", attempts) time.Sleep(*flagPollInterval) continue } - result, err := verifyProof(&payload, requestPath) - if err != nil { - logger.Printf("get_inclusion_proof attempt %d verification error: %v", attempts, err) + if err := payload.InclusionProof.Verify(req); err != nil { + logger.Printf("get_inclusion_proof.v2 attempt %d verification error: %v", attempts, err) time.Sleep(*flagPollInterval) continue } - if result.PathIncluded { - return &payload, result, attempts, nil - } - - logger.Printf("get_inclusion_proof attempt %d: proof returned but path not included yet (pathValid=%t). 
Waiting...", - attempts, result.PathValid) - time.Sleep(*flagPollInterval) - } - - return nil, nil, attempts, fmt.Errorf("timed out waiting for inclusion proof for request %s", requestID) -} - -func verifyProof(resp *api.GetInclusionProofResponseV2, path *big.Int) (*api.PathVerificationResult, error) { - if resp == nil || resp.InclusionProof == nil { - return nil, fmt.Errorf("inclusion proof payload was empty") - } - - if resp.InclusionProof.MerkleTreePath == nil { - return nil, fmt.Errorf("merkle tree path missing from inclusion proof") + return &payload, attempts, nil } - return resp.InclusionProof.MerkleTreePath.Verify(path) + return nil, attempts, fmt.Errorf("timed out waiting for inclusion proof for state ID %s", req.StateID) } diff --git a/cmd/performance-test/main.go b/cmd/performance-test/main.go index 4f6c18f..1fbdf95 100644 --- a/cmd/performance-test/main.go +++ b/cmd/performance-test/main.go @@ -5,12 +5,10 @@ import ( "bytes" "context" "crypto/rand" - "encoding/hex" "encoding/json" "fmt" "log" "math" - "math/big" "os" "path/filepath" "runtime" @@ -198,32 +196,19 @@ func selectShardIndex(requestID api.StateID, shardClients []*ShardClient) int { if len(imprint) == 0 { return 0 } - return int(imprint[len(imprint)-1]) % shardCount + keyBytes := requestID.DataBytes() + if len(keyBytes) == 0 { + return 0 + } + // Fallback only: keep distribution aligned with LSB-first key layout. 
+ return int(keyBytes[0]) % shardCount } func matchesShardMask(requestIDHex string, shardMask int) (bool, error) { if shardMask <= 0 { return false, nil } - - bytes, err := hex.DecodeString(requestIDHex) - if err != nil { - return false, fmt.Errorf("failed to decode request ID: %w", err) - } - - requestBig := new(big.Int).SetBytes(bytes) - maskBig := new(big.Int).SetInt64(int64(shardMask)) - - msbPos := maskBig.BitLen() - 1 - if msbPos < 0 { - return false, fmt.Errorf("invalid shard mask: %d", shardMask) - } - - compareMask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(msbPos)), big.NewInt(1)) - expected := new(big.Int).And(maskBig, compareMask) - requestLowBits := new(big.Int).And(requestBig, compareMask) - - return requestLowBits.Cmp(expected) == 0, nil + return api.MatchesShardPrefixFromHex(requestIDHex, shardMask) } // matchesAnyShardTarget checks if a request ID matches any of the configured shard targets @@ -253,7 +238,7 @@ func generateCommitmentRequest() *api.CertificationRequest { publicKeyBytes := privateKey.PubKey().SerializeCompressed() ownerPredicate := api.NewPayToPublicKeyPredicate(publicKeyBytes) - // Generate random state data and create DataHash imprint + // Generate random state data and hash it. stateData := make([]byte, 32) rand.Read(stateData) sourceStateHashImprint := signing.CreateDataHash(stateData) @@ -289,7 +274,7 @@ func generateCommitmentRequest() *api.CertificationRequest { sourceStateHashImprint = signing.CreateDataHash(stateData) } - // Generate random transaction data and create DataHash imprint + // Generate random transaction data and hash it. 
transactionData := make([]byte, 32) rand.Read(transactionData) transactionHashImprint := signing.CreateDataHash(transactionData) @@ -392,8 +377,8 @@ func commitmentWorker(ctx context.Context, shardClients []*ShardClient, metrics } switch submitResp.Status { - case "SUCCESS", "REQUEST_ID_EXISTS": - if submitResp.Status == "REQUEST_ID_EXISTS" { + case "SUCCESS", "STATE_ID_EXISTS": + if submitResp.Status == "STATE_ID_EXISTS" { atomic.AddInt64(&metrics.requestIdExistsErr, 1) if sm := metrics.shard(shardIdx); sm != nil { sm.requestIdExistsErr.Add(1) @@ -409,7 +394,7 @@ func commitmentWorker(ctx context.Context, shardClients []*ShardClient, metrics if proofQueue != nil { metrics.recordSubmissionTimestamp(requestIDStr) select { - case proofQueue <- proofJob{shardIdx: shardIdx, requestID: requestIDStr}: + case proofQueue <- proofJob{shardIdx: shardIdx, request: req}: default: // Queue full, skip proof verification for this one } @@ -435,10 +420,21 @@ func proofVerificationWorker(ctx context.Context, shardClients []*ShardClient, m case <-ctx.Done(): return case job := <-proofQueue: - go func(reqID string, shardIdx int) { + go func(job proofJob) { + if job.request == nil { + metrics.recordError("Missing original request for proof verification") + atomic.AddInt64(&metrics.proofVerifyFailed, 1) + if sm := metrics.shard(job.shardIdx); sm != nil { + sm.proofVerifyFailed.Add(1) + } + return + } + + reqID := normalizeRequestID(job.request.StateID.String()) + shardIdx := job.shardIdx time.Sleep(proofInitialDelay) startTime := time.Now() - normalizedID := normalizeRequestID(reqID) + normalizedID := reqID client := shardClients[shardIdx].proofClient // Use separate proof client pool for attempt := 0; attempt < proofMaxRetries; attempt++ { @@ -534,49 +530,22 @@ func proofVerificationWorker(ctx context.Context, shardClients []*ShardClient, m } metrics.addProofLatency(totalLatency) - apiPath := proofResp.InclusionProof.MerkleTreePath - - requestIDPath, err := 
api.RequireNewImprintV2(reqID).GetPath() - if err != nil { - metrics.recordError(fmt.Sprintf("Failed to get path for request ID: %v", err)) - atomic.AddInt64(&metrics.proofVerifyFailed, 1) - if sm := metrics.shard(shardIdx); sm != nil { - sm.proofVerifyFailed.Add(1) - } - return - } - - result, err := apiPath.Verify(requestIDPath) - if err != nil { - metrics.recordError(fmt.Sprintf("Proof verification error: %v", err)) - atomic.AddInt64(&metrics.proofVerifyFailed, 1) - if sm := metrics.shard(shardIdx); sm != nil { - sm.proofVerifyFailed.Add(1) - } - return - } - - if result.Result { - atomic.AddInt64(&metrics.proofVerified, 1) - if sm := metrics.shard(shardIdx); sm != nil { - sm.proofVerified.Add(1) - } - } else { - if !result.PathIncluded && attempt < proofMaxRetries-1 { + if err := proofResp.InclusionProof.Verify(job.request); err != nil { + if attempt < proofMaxRetries-1 { time.Sleep(proofRetryDelay) continue } - fmt.Printf("\n\n[FATAL ERROR] Proof verification failed for request %s after %d attempts\n", reqID, attempt+1) - fmt.Printf("PathValid: %v\n", result.PathValid) - fmt.Printf("PathIncluded: %v\n", result.PathIncluded) - fmt.Printf("Stopping test due to proof verification failure.\n\n") + metrics.recordError(fmt.Sprintf("Proof verification failed for request %s: %v", reqID, err)) atomic.AddInt64(&metrics.proofVerifyFailed, 1) if sm := metrics.shard(shardIdx); sm != nil { sm.proofVerifyFailed.Add(1) } - cancelTest() return } + atomic.AddInt64(&metrics.proofVerified, 1) + if sm := metrics.shard(shardIdx); sm != nil { + sm.proofVerified.Add(1) + } return } @@ -586,7 +555,7 @@ func proofVerificationWorker(ctx context.Context, shardClients []*ShardClient, m if sm := metrics.shard(shardIdx); sm != nil { sm.proofFailed.Add(1) } - }(job.requestID, job.shardIdx) + }(job) } } } diff --git a/cmd/performance-test/types.go b/cmd/performance-test/types.go index ff6c8ca..55b4246 100644 --- a/cmd/performance-test/types.go +++ b/cmd/performance-test/types.go @@ -17,6 
+17,7 @@ import ( "sync/atomic" "time" + "github.com/unicitynetwork/aggregator-go/pkg/api" "golang.org/x/net/http2" ) @@ -109,8 +110,8 @@ type ShardClient struct { } type proofJob struct { - shardIdx int - requestID string + shardIdx int + request *api.CertificationRequest } // RequestRateCounters tracks per-second client-side request activity. diff --git a/internal/bft/client_stub.go b/internal/bft/client_stub.go index 2f7d57a..c6cee0d 100644 --- a/internal/bft/client_stub.go +++ b/internal/bft/client_stub.go @@ -7,6 +7,7 @@ import ( "time" "github.com/unicitynetwork/bft-go-base/types" + "github.com/unicitynetwork/bft-go-base/types/hex" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/models" @@ -61,6 +62,7 @@ func (n *BFTClientStub) CertificationRequest(ctx context.Context, block *models. uc := types.UnicityCertificate{ InputRecord: &types.InputRecord{ RoundNumber: roundNumber, + Hash: hex.Bytes(block.RootHash), }, UnicitySeal: &types.UnicitySeal{ RootChainRoundNumber: roundNumber, diff --git a/internal/bft/client_stub_test.go b/internal/bft/client_stub_test.go new file mode 100644 index 0000000..a584d85 --- /dev/null +++ b/internal/bft/client_stub_test.go @@ -0,0 +1,61 @@ +package bft + +import ( + "context" + "math/big" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unicitynetwork/bft-go-base/types" + + "github.com/unicitynetwork/aggregator-go/internal/logger" + "github.com/unicitynetwork/aggregator-go/internal/models" + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +type stubRoundManager struct { + finalizedBlocks []*models.Block + startedRounds []*api.BigInt +} + +func (m *stubRoundManager) FinalizeBlock(ctx context.Context, block *models.Block) error { + m.finalizedBlocks = append(m.finalizedBlocks, block) + return nil +} + +func (m *stubRoundManager) FinalizeBlockWithRetry(ctx context.Context, block *models.Block) error { + return m.FinalizeBlock(ctx, block) +} + +func 
(m *stubRoundManager) StartNewRound(ctx context.Context, roundNumber *api.BigInt) error { + m.startedRounds = append(m.startedRounds, api.NewBigInt(new(big.Int).Set(roundNumber.Int))) + return nil +} + +func TestBFTClientStub_CertificationRequest_PopulatesSyntheticUC(t *testing.T) { + rm := &stubRoundManager{} + log, err := logger.New("warn", "json", "", false) + require.NoError(t, err) + + client := NewBFTClientStub(log, rm, api.NewBigIntFromUint64(1), 0) + block := models.NewBlock( + api.NewBigIntFromUint64(7), + "unicity", + 0, + "1.0", + "mainnet", + api.HexBytes("0123"), + nil, + nil, + ) + + err = client.CertificationRequest(t.Context(), block) + require.NoError(t, err) + require.Len(t, rm.finalizedBlocks, 1) + require.NotEmpty(t, block.UnicityCertificate) + + var uc types.UnicityCertificate + require.NoError(t, types.Cbor.Unmarshal(block.UnicityCertificate, &uc)) + require.EqualValues(t, 7, uc.GetRoundNumber()) + require.EqualValues(t, 7, uc.GetRootRoundNumber()) +} diff --git a/internal/gateway/docs.go b/internal/gateway/docs.go index 6767eb4..a74bd2e 100644 --- a/internal/gateway/docs.go +++ b/internal/gateway/docs.go @@ -170,7 +170,7 @@ func GenerateDocsHTML() string {
certification_request
-
Submit a state transition certification request to the aggregator. The example below uses a real secp256k1 signature that will pass validation. Note: All hash fields (stateId, transactionHash, sourceStateHash) start with "0000" (SHA256 algorithm prefix).
+
Submit a state transition certification request to the aggregator. The example below uses a real secp256k1 signature that will pass validation. In the v2 wire format, stateId, transactionHash, and sourceStateHash are raw 32-byte SHA-256 values with no algorithm-prefix bytes.
@@ -191,28 +191,28 @@ func GenerateDocsHTML() string {
- +
-
get_inclusion_proof
+
get_inclusion_proof.v2
-
Retrieve the inclusion proof for a submitted certification request.
+
Retrieve the v2 inclusion proof for a submitted certification request. The stateId must be the raw 32-byte key used in the aggregation tree.

Request Parameters

-
- - - + + +
-

Response

-
Click "Send Request" to see the response here...
+

Response

+
Click "Send Request" to see the response here...
diff --git a/internal/gateway/docs_test.go b/internal/gateway/docs_test.go index 643ca8a..4ce7934 100644 --- a/internal/gateway/docs_test.go +++ b/internal/gateway/docs_test.go @@ -26,7 +26,7 @@ func TestDocumentationExamplePayload(t *testing.T) { err = types.Cbor.Unmarshal(exampleJSON, &certRequest) require.NoError(t, err, "Failed to parse example CBOR") - // Verify DataHash imprint format (should start with 0000 for SHA256) + // Verify the example uses raw 32-byte hashes. certData := certRequest.CertificationData require.Equal(t, len(certData.SourceStateHash.Imprint()), 32, "State hash should be 32 bytes") require.GreaterOrEqual(t, len(certData.TransactionHash.Imprint()), 32, "Transaction hash should be 32 bytes") diff --git a/internal/gateway/handlers.go b/internal/gateway/handlers.go index c2ab5b4..2ea138b 100644 --- a/internal/gateway/handlers.go +++ b/internal/gateway/handlers.go @@ -50,7 +50,9 @@ func (s *Server) parseCertificationRequest(params json.RawMessage) (*api.Certifi return req, nil } -// handleGetInclusionProofV2 handles the get_inclusion_proof.v2 method +// handleGetInclusionProofV2 handles the get_inclusion_proof.v2 method. +// +// v2 requires stateId to be exactly 32 raw bytes with no algorithm prefix. 
func (s *Server) handleGetInclusionProofV2(ctx context.Context, params json.RawMessage) (interface{}, *jsonrpc.Error) { var req api.GetInclusionProofRequestV2 if err := json.Unmarshal(params, &req); err != nil { @@ -61,6 +63,11 @@ func (s *Server) handleGetInclusionProofV2(ctx context.Context, params json.RawM if req.StateID == nil { return nil, jsonrpc.NewValidationError("stateId is required") } + if len(req.StateID) != api.StateTreeKeyLengthBytes { + return nil, jsonrpc.NewValidationError(fmt.Sprintf( + "stateId must be exactly %d bytes (v2 wire format), got %d", + api.StateTreeKeyLengthBytes, len(req.StateID))) + } // Call service response, err := s.service.GetInclusionProofV2(ctx, &req) diff --git a/internal/gateway/handlers_v1.go b/internal/gateway/handlers_v1.go deleted file mode 100644 index 2f5a761..0000000 --- a/internal/gateway/handlers_v1.go +++ /dev/null @@ -1,80 +0,0 @@ -package gateway - -import ( - "context" - "encoding/json" - - "github.com/unicitynetwork/aggregator-go/pkg/api" - "github.com/unicitynetwork/aggregator-go/pkg/jsonrpc" -) - -// JSON-RPC method handlers - -// handleSubmitCommitment handles the submit_commitment method -func (s *Server) handleSubmitCommitment(ctx context.Context, params json.RawMessage) (interface{}, *jsonrpc.Error) { - var req api.SubmitCommitmentRequest - if err := json.Unmarshal(params, &req); err != nil { - return nil, jsonrpc.NewValidationError("Invalid parameters: " + err.Error()) - } - - // Validate required fields - if req.RequestID == nil { - return nil, jsonrpc.NewValidationError("requestId is required") - } - if req.TransactionHash == nil { - return nil, jsonrpc.NewValidationError("transactionHash is required") - } - - // Call service - response, err := s.service.SubmitCommitment(ctx, &req) - if err != nil { - s.logger.WithContext(ctx).Error("Failed to submit commitment", "error", err.Error()) - return nil, jsonrpc.NewError(jsonrpc.InternalErrorCode, "Failed to submit commitment", err.Error()) - } - - return 
response, nil -} - -// handleGetInclusionProofV1 handles the get_inclusion_proof method -func (s *Server) handleGetInclusionProofV1(ctx context.Context, params json.RawMessage) (interface{}, *jsonrpc.Error) { - var req api.GetInclusionProofRequestV1 - if err := json.Unmarshal(params, &req); err != nil { - return nil, jsonrpc.NewValidationError("Invalid parameters: " + err.Error()) - } - - // Validate required fields - if req.RequestID == nil { - return nil, jsonrpc.NewValidationError("requestId is required") - } - - // Call service - response, err := s.service.GetInclusionProofV1(ctx, &req) - if err != nil { - s.logger.WithContext(ctx).Error("Failed to get inclusion proof", "error", err.Error()) - return nil, jsonrpc.NewError(jsonrpc.InternalErrorCode, "Failed to get inclusion proof", err.Error()) - } - - return response, nil -} - -// handleGetBlockCommitments handles the get_block_commitments method -func (s *Server) handleGetBlockCommitments(ctx context.Context, params json.RawMessage) (interface{}, *jsonrpc.Error) { - var req api.GetBlockCommitmentsRequest - if err := json.Unmarshal(params, &req); err != nil { - return nil, jsonrpc.NewValidationError("Invalid parameters: " + err.Error()) - } - - // Validate required fields - if req.BlockNumber == nil { - return nil, jsonrpc.NewValidationError("blockNumber is required") - } - - // Call service - response, err := s.service.GetBlockCommitments(ctx, &req) - if err != nil { - s.logger.WithContext(ctx).Error("Failed to get block commitments", "error", err.Error()) - return nil, jsonrpc.NewError(jsonrpc.InternalErrorCode, "Failed to get block commitments", err.Error()) - } - - return response, nil -} diff --git a/internal/gateway/server.go b/internal/gateway/server.go index d534476..2f7ee98 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -20,24 +20,21 @@ import ( // Server represents the HTTP gateway server type Server struct { - config *config.Config - logger *logger.Logger - rpcServer 
*jsonrpc.Server - httpServer *http.Server - router *gin.Engine - service Service + config *config.Config + logger *logger.Logger + rpcServer *jsonrpc.Server + httpServer *http.Server + router *gin.Engine + service Service } // Service represents the business logic service interface type Service interface { - SubmitCommitment(ctx context.Context, req *api.SubmitCommitmentRequest) (*api.SubmitCommitmentResponse, error) CertificationRequest(ctx context.Context, req *api.CertificationRequest) (*api.CertificationResponse, error) - GetInclusionProofV1(ctx context.Context, req *api.GetInclusionProofRequestV1) (*api.GetInclusionProofResponseV1, error) GetInclusionProofV2(ctx context.Context, req *api.GetInclusionProofRequestV2) (*api.GetInclusionProofResponseV2, error) GetNoDeletionProof(ctx context.Context) (*api.GetNoDeletionProofResponse, error) GetBlockHeight(ctx context.Context) (*api.GetBlockHeightResponse, error) GetBlock(ctx context.Context, req *api.GetBlockRequest) (*api.GetBlockResponse, error) - GetBlockCommitments(ctx context.Context, req *api.GetBlockCommitmentsRequest) (*api.GetBlockCommitmentsResponse, error) GetBlockRecords(ctx context.Context, req *api.GetBlockRecords) (*api.GetBlockRecordsResponse, error) GetHealthStatus(ctx context.Context) (*api.HealthStatus, error) @@ -164,10 +161,7 @@ func (s *Server) setupJSONRPCHandlers() { s.rpcServer.RegisterMethod("get_shard_proof", s.handleGetShardProof) } else { // Standalone mode handlers (default) - s.rpcServer.RegisterMethod("submit_commitment", s.handleSubmitCommitment) // v1 - s.rpcServer.RegisterMethod("certification_request", s.handleCertificationRequest) // v2 - - s.rpcServer.RegisterMethod("get_inclusion_proof", s.handleGetInclusionProofV1) // v1 + s.rpcServer.RegisterMethod("certification_request", s.handleCertificationRequest) s.rpcServer.RegisterMethod("get_inclusion_proof.v2", s.handleGetInclusionProofV2) // v2 } @@ -175,9 +169,7 @@ func (s *Server) setupJSONRPCHandlers() { 
s.rpcServer.RegisterMethod("get_no_deletion_proof", s.handleGetNoDeletionProof) s.rpcServer.RegisterMethod("get_block_height", s.handleGetBlockHeight) s.rpcServer.RegisterMethod("get_block", s.handleGetBlock) - - s.rpcServer.RegisterMethod("get_block_commitments", s.handleGetBlockCommitments) // v1 - s.rpcServer.RegisterMethod("get_block_records", s.handleGetBlockRecords) // v2 + s.rpcServer.RegisterMethod("get_block_records", s.handleGetBlockRecords) } // Start starts the HTTP server diff --git a/internal/ha/block_syncer.go b/internal/ha/block_syncer.go index 9708678..266f438 100644 --- a/internal/ha/block_syncer.go +++ b/internal/ha/block_syncer.go @@ -139,7 +139,7 @@ func (bs *BlockSyncer) SyncToLatestBlock(ctx context.Context) error { return nil } -func (bs *BlockSyncer) verifySMTForBlock(ctx context.Context, smtRootHash string, blockNumber *api.BigInt) error { +func (bs *BlockSyncer) verifySMTForBlock(ctx context.Context, smtRootHash api.HexBytes, blockNumber *api.BigInt) error { block, err := bs.storage.BlockStorage().GetByNumber(ctx, blockNumber) if err != nil { return fmt.Errorf("failed to fetch block: %w", err) @@ -147,10 +147,10 @@ func (bs *BlockSyncer) verifySMTForBlock(ctx context.Context, smtRootHash string if block == nil { return fmt.Errorf("block not found for block number: %s", blockNumber.String()) } - expectedRootHash := block.RootHash.String() - if smtRootHash != expectedRootHash { + expectedRootHash := block.RootHash + if smtRootHash.String() != expectedRootHash.String() { return fmt.Errorf("smt root hash %s does not match latest block root hash %s", - smtRootHash, expectedRootHash) + smtRootHash.String(), expectedRootHash.String()) } return nil } @@ -167,11 +167,11 @@ func (bs *BlockSyncer) updateSMTForBlock(ctx context.Context, blockRecord *model } uniqueStateIds[key] = struct{}{} - path, err := stateID.GetPath() + keyBytes, err := stateID.GetTreeKey() if err != nil { - return fmt.Errorf("failed to get path: %w", err) + return 
fmt.Errorf("failed to get SMT key: %w", err) } - leafIDs = append(leafIDs, api.NewHexBytes(path.Bytes())) + leafIDs = append(leafIDs, api.NewHexBytes(keyBytes)) } // load smt nodes by ids smtNodes, err := bs.storage.SmtStorage().GetByKeys(ctx, leafIDs) @@ -181,16 +181,20 @@ func (bs *BlockSyncer) updateSMTForBlock(ctx context.Context, blockRecord *model // convert smt nodes to leaves leaves := make([]*smt.Leaf, 0, len(smtNodes)) for _, smtNode := range smtNodes { - leaves = append(leaves, smt.NewLeaf(new(big.Int).SetBytes(smtNode.Key), smtNode.Value)) + path, err := api.FixedBytesToPath(smtNode.Key, bs.smt.GetKeyLength()) + if err != nil { + return fmt.Errorf("failed to convert SMT key to path: %w", err) + } + leaves = append(leaves, smt.NewLeaf(path, smtNode.Value)) } // apply changes to smt snapshot snapshot := bs.smt.CreateSnapshot() - smtRootHash, err := snapshot.AddLeaves(leaves) - if err != nil { + if _, err := snapshot.AddLeaves(leaves); err != nil { return fmt.Errorf("failed to apply SMT updates for block %s: %w", blockRecord.BlockNumber.String(), err) } - // verify smt root hash matches block store root hash + smtRootHash := api.HexBytes(snapshot.GetRootHashRaw()) + // verify smt root hash matches the raw 32-byte block root hash if err := bs.verifySMTForBlock(ctx, smtRootHash, blockRecord.BlockNumber); err != nil { return fmt.Errorf("failed to verify SMT: %w", err) } diff --git a/internal/ha/block_syncer_test.go b/internal/ha/block_syncer_test.go index 8c5f448..80f22f0 100644 --- a/internal/ha/block_syncer_test.go +++ b/internal/ha/block_syncer_test.go @@ -51,7 +51,7 @@ func TestBlockSyncer(t *testing.T) { // initialize block syncer with isLeader=false mockLeader := &mockLeaderSelector{} - smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) stateTracker := state.NewSyncStateTracker() syncer := NewBlockSyncer(testLogger, 
mockLeader, storage, smtInstance, 0, cfg.Processing.RoundDuration, stateTracker) @@ -66,14 +66,14 @@ func TestBlockSyncer(t *testing.T) { time.Sleep(2 * cfg.Processing.RoundDuration) // SMT root hash should match persisted block root hash after block sync - require.Equal(t, rootHash.String(), smtInstance.GetRootHash()) + require.Equal(t, rootHash.String(), api.HexBytes(smtInstance.GetRootHashRaw()).String()) require.Equal(t, big.NewInt(1), stateTracker.GetLastSyncedBlock()) // verify the blocks are not synced if node is leader mockLeader.isLeader.Store(true) createBlock(t, storage, 2) time.Sleep(2 * cfg.Processing.RoundDuration) - require.Equal(t, rootHash.String(), smtInstance.GetRootHash()) + require.Equal(t, rootHash.String(), api.HexBytes(smtInstance.GetRootHashRaw()).String()) require.Equal(t, big.NewInt(1), stateTracker.GetLastSyncedBlock()) } @@ -94,7 +94,7 @@ func createBlock(t *testing.T, storage *mongodb.Storage, blockNum int64) api.Hex path, err := c.StateID.GetPath() require.NoError(t, err) - val, err := c.CertificationData.ToAPI().Hash() + val, err := c.LeafValue() require.NoError(t, err) leaves[i] = &smt.Leaf{Path: path, Value: val} @@ -107,7 +107,9 @@ func createBlock(t *testing.T, storage *mongodb.Storage, blockNum int64) api.Hex // persist smt nodes smtNodes := make([]*models.SmtNode, len(leaves)) for i, leaf := range leaves { - key := api.NewHexBytes(leaf.Path.Bytes()) + keyBytes, err := api.PathToFixedBytes(leaf.Path, api.StateTreeKeyLengthBits) + require.NoError(t, err) + key := api.NewHexBytes(keyBytes) value := api.NewHexBytes(leaf.Value) smtNodes[i] = models.NewSmtNode(key, value) } @@ -115,12 +117,12 @@ func createBlock(t *testing.T, storage *mongodb.Storage, blockNum int64) api.Hex require.NoError(t, err) // compute rootHash - tmpSMT := smt.NewSparseMerkleTree(api.SHA256, 16+256) + tmpSMT := smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) require.NoError(t, tmpSMT.AddLeaves(leaves)) - rootHash := 
api.NewHexBytes(tmpSMT.GetRootHash()) + rootHash := api.HexBytes(tmpSMT.GetRootHashRaw()) // persist block - block := models.NewBlock(blockNumber, "unicity", 0, "1.0", "mainnet", rootHash, nil, nil, nil) + block := models.NewBlock(blockNumber, "unicity", 0, "1.0", "mainnet", rootHash, nil, nil) block.Finalized = true // Mark as finalized so GetLatestNumber finds it err = storage.BlockStorage().Store(ctx, block) require.NoError(t, err) diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index d0cbeda..9572f3c 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -15,14 +15,11 @@ import ( // Any method name not in this set is normalized to "unknown" to prevent // unbounded label cardinality from arbitrary client input. var knownMethods = map[string]struct{}{ - "submit_commitment": {}, "certification_request": {}, - "get_inclusion_proof": {}, "get_inclusion_proof.v2": {}, "get_no_deletion_proof": {}, "get_block_height": {}, "get_block": {}, - "get_block_commitments": {}, "get_block_records": {}, "submit_shard_root": {}, "get_shard_proof": {}, diff --git a/internal/models/block.go b/internal/models/block.go index 7fa706b..72495d5 100644 --- a/internal/models/block.go +++ b/internal/models/block.go @@ -1,7 +1,6 @@ package models import ( - "encoding/json" "fmt" "time" @@ -12,18 +11,19 @@ import ( // Block represents a blockchain block type Block struct { - Index *api.BigInt `json:"index"` - ChainID string `json:"chainId"` - ShardID api.ShardID `json:"shardId"` - Version string `json:"version"` - ForkID string `json:"forkId"` - RootHash api.HexBytes `json:"rootHash"` - PreviousBlockHash api.HexBytes `json:"previousBlockHash"` - NoDeletionProofHash api.HexBytes `json:"noDeletionProofHash"` - CreatedAt *api.Timestamp `json:"createdAt"` - UnicityCertificate api.HexBytes `json:"unicityCertificate"` - ParentMerkleTreePath *api.MerkleTreePath `json:"parentMerkleTreePath,omitempty"` // child mode only - Finalized bool 
`json:"finalized"` // true when all data is persisted + Index *api.BigInt `json:"index"` + ChainID string `json:"chainId"` + ShardID api.ShardID `json:"shardId"` + Version string `json:"version"` + ForkID string `json:"forkId"` + RootHash api.HexBytes `json:"rootHash"` + PreviousBlockHash api.HexBytes `json:"previousBlockHash"` + NoDeletionProofHash api.HexBytes `json:"noDeletionProofHash"` + CreatedAt *api.Timestamp `json:"createdAt"` + UnicityCertificate api.HexBytes `json:"unicityCertificate"` + ParentFragment *api.ParentInclusionFragment `json:"parentFragment,omitempty"` // child mode only + ParentBlockNumber uint64 `json:"parentBlockNumber,omitempty"` // child mode only + Finalized bool `json:"finalized"` // true when all data is persisted } // BlockBSON represents the BSON version of Block for MongoDB storage @@ -38,23 +38,29 @@ type BlockBSON struct { NoDeletionProofHash string `bson:"noDeletionProofHash,omitempty"` CreatedAt time.Time `bson:"createdAt"` UnicityCertificate string `bson:"unicityCertificate"` - MerkleTreePath string `bson:"merkleTreePath,omitempty"` // child mode only + ParentFragment *ParentFragmentBSON `bson:"parentFragment,omitempty"` // child mode only + ParentBlockNumber uint64 `bson:"parentBlockNumber,omitempty"` Finalized bool `bson:"finalized"` } +// ParentFragmentBSON is the BSON representation of ParentInclusionFragment. 
+type ParentFragmentBSON struct { + CertificateBytes []byte `bson:"certificateBytes"` + ShardLeafValue []byte `bson:"shardLeafValue"` +} + // ToBSON converts Block to BlockBSON for MongoDB storage func (b *Block) ToBSON() (*BlockBSON, error) { indexDecimal, err := primitive.ParseDecimal128(b.Index.String()) if err != nil { return nil, fmt.Errorf("error converting block index to decimal-128: %w", err) } - var merkleTreePath string - if b.ParentMerkleTreePath != nil { - merkleTreePathJson, err := json.Marshal(b.ParentMerkleTreePath) - if err != nil { - return nil, fmt.Errorf("failed to marshal parent merkle tree path: %w", err) + var parentFragment *ParentFragmentBSON + if b.ParentFragment != nil { + parentFragment = &ParentFragmentBSON{ + CertificateBytes: append([]byte(nil), b.ParentFragment.CertificateBytes...), + ShardLeafValue: append([]byte(nil), b.ParentFragment.ShardLeafValue...), } - merkleTreePath = api.NewHexBytes(merkleTreePathJson).String() } return &BlockBSON{ Index: indexDecimal, @@ -67,7 +73,8 @@ func (b *Block) ToBSON() (*BlockBSON, error) { NoDeletionProofHash: b.NoDeletionProofHash.String(), CreatedAt: b.CreatedAt.Time, UnicityCertificate: b.UnicityCertificate.String(), - MerkleTreePath: merkleTreePath, + ParentFragment: parentFragment, + ParentBlockNumber: b.ParentBlockNumber, Finalized: b.Finalized, }, nil } @@ -94,15 +101,11 @@ func (bb *BlockBSON) FromBSON() (*Block, error) { return nil, fmt.Errorf("failed to parse unicityCertificate: %w", err) } - var parentMerkleTreePath *api.MerkleTreePath - if bb.MerkleTreePath != "" { - hexBytes, err := api.NewHexBytesFromString(bb.MerkleTreePath) - if err != nil { - return nil, fmt.Errorf("failed to parse parentMerkleTreePath: %w", err) - } - parentMerkleTreePath = &api.MerkleTreePath{} - if err := json.Unmarshal(hexBytes, parentMerkleTreePath); err != nil { - return nil, fmt.Errorf("failed to parse parentMerkleTreePath: %w", err) + var parentFragment *api.ParentInclusionFragment + if bb.ParentFragment != 
nil { + parentFragment = &api.ParentInclusionFragment{ + CertificateBytes: append([]byte(nil), bb.ParentFragment.CertificateBytes...), + ShardLeafValue: append([]byte(nil), bb.ParentFragment.ShardLeafValue...), } } @@ -112,33 +115,41 @@ func (bb *BlockBSON) FromBSON() (*Block, error) { } return &Block{ - Index: index, - ChainID: bb.ChainID, - ShardID: bb.ShardID, - Version: bb.Version, - ForkID: bb.ForkID, - RootHash: rootHash, - PreviousBlockHash: previousBlockHash, - NoDeletionProofHash: noDeletionProofHash, - CreatedAt: api.NewTimestamp(bb.CreatedAt), - UnicityCertificate: unicityCertificate, - ParentMerkleTreePath: parentMerkleTreePath, - Finalized: bb.Finalized, + Index: index, + ChainID: bb.ChainID, + ShardID: bb.ShardID, + Version: bb.Version, + ForkID: bb.ForkID, + RootHash: rootHash, + PreviousBlockHash: previousBlockHash, + NoDeletionProofHash: noDeletionProofHash, + CreatedAt: api.NewTimestamp(bb.CreatedAt), + UnicityCertificate: unicityCertificate, + ParentFragment: parentFragment, + ParentBlockNumber: bb.ParentBlockNumber, + Finalized: bb.Finalized, }, nil } // NewBlock creates a new block -func NewBlock(index *api.BigInt, chainID string, shardID api.ShardID, version, forkID string, rootHash, previousBlockHash, uc api.HexBytes, parentMerkleTreePath *api.MerkleTreePath) *Block { +func NewBlock(index *api.BigInt, chainID string, shardID api.ShardID, version, forkID string, rootHash, previousBlockHash, uc api.HexBytes) *Block { return &Block{ - Index: index, - ChainID: chainID, - ShardID: shardID, - Version: version, - ForkID: forkID, - RootHash: rootHash, - PreviousBlockHash: previousBlockHash, - CreatedAt: api.Now(), - UnicityCertificate: uc, - ParentMerkleTreePath: parentMerkleTreePath, + Index: index, + ChainID: chainID, + ShardID: shardID, + Version: version, + ForkID: forkID, + RootHash: rootHash, + PreviousBlockHash: previousBlockHash, + CreatedAt: api.Now(), + UnicityCertificate: uc, } } + +// NewChildBlock creates a block for child mode with 
required parent proof metadata. +func NewChildBlock(index *api.BigInt, chainID string, shardID api.ShardID, version, forkID string, rootHash, previousBlockHash, uc api.HexBytes, parentFragment *api.ParentInclusionFragment, parentBlockNumber uint64) *Block { + block := NewBlock(index, chainID, shardID, version, forkID, rootHash, previousBlockHash, uc) + block.ParentFragment = parentFragment + block.ParentBlockNumber = parentBlockNumber + return block +} diff --git a/internal/models/block_test.go b/internal/models/block_test.go index 777962b..83431c5 100644 --- a/internal/models/block_test.go +++ b/internal/models/block_test.go @@ -35,8 +35,10 @@ func createTestBlock() *Block { NoDeletionProofHash: randomHash, CreatedAt: api.Now(), UnicityCertificate: randomHash, - ParentMerkleTreePath: &api.MerkleTreePath{ - Root: randomHash.String(), + ParentFragment: &api.ParentInclusionFragment{ + CertificateBytes: randomHash, + ShardLeafValue: randomHash, }, + ParentBlockNumber: 7, } } diff --git a/internal/models/certification_request.go b/internal/models/certification_request.go index 8e2446a..4d18878 100644 --- a/internal/models/certification_request.go +++ b/internal/models/certification_request.go @@ -112,7 +112,8 @@ func (c *CertificationRequest) LeafValue() ([]byte, error) { case 0, 1: return c.ToV1().ToAPI().CreateLeafValue() case 2: - return c.CertificationData.ToAPI().Hash() + // v2 semantics: leaf value is the transaction hash bytes. 
+ return c.CertificationData.TransactionHash.DataBytes(), nil default: return nil, fmt.Errorf("invalid version: %d", c.Version) } diff --git a/internal/models/certification_request_leafvalue_test.go b/internal/models/certification_request_leafvalue_test.go new file mode 100644 index 0000000..a9889b5 --- /dev/null +++ b/internal/models/certification_request_leafvalue_test.go @@ -0,0 +1,34 @@ +package models + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +func TestCertificationRequestLeafValue_V2UsesTransactionHashBytes(t *testing.T) { + txRaw := "11223344556677889900aabbccddeeff00112233445566778899aabbccddeeff" + txLegacy := "0000" + txRaw + + reqRaw := &CertificationRequest{ + Version: 2, + CertificationData: CertificationData{ + TransactionHash: api.RequireNewImprintV2(txRaw), + }, + } + leafRaw, err := reqRaw.LeafValue() + require.NoError(t, err) + require.Equal(t, api.RequireNewImprintV2(txRaw), api.ImprintV2(leafRaw)) + + reqLegacy := &CertificationRequest{ + Version: 2, + CertificationData: CertificationData{ + TransactionHash: api.RequireNewImprintV2(txLegacy), + }, + } + leafLegacy, err := reqLegacy.LeafValue() + require.NoError(t, err) + require.Equal(t, api.RequireNewImprintV2(txRaw), api.ImprintV2(leafLegacy)) +} diff --git a/internal/round/batch_processor.go b/internal/round/batch_processor.go index e3598cb..269caa8 100644 --- a/internal/round/batch_processor.go +++ b/internal/round/batch_processor.go @@ -77,11 +77,13 @@ type leafAddResult struct { rejected []interfaces.CertificationRequestAck } -// ProposeBlock creates and proposes a new block with the given data -func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigInt, rootHash string) error { +// ProposeBlock creates and proposes a new block with the given data. 
+// rootHash is the raw 32-byte SMT root (no algorithm-id prefix) — the +// block, UC.IR.h and V2 proof wire all bind against this raw form. +func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigInt, rootHash api.HexBytes) error { rm.logger.WithContext(ctx).Info("proposeBlock called", "blockNumber", blockNumber.String(), - "rootHash", rootHash) + "rootHash", rootHash.String()) rm.roundMutex.Lock() if rm.currentRound != nil { @@ -94,7 +96,7 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn rm.logger.WithContext(ctx).Info("Creating block proposal", "blockNumber", blockNumber.String(), - "rootHash", rootHash) + "rootHash", rootHash.String()) // Get parent block hash var parentHash api.HexBytes @@ -114,12 +116,6 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn } } - // Create block (simplified for now) - rootHashBytes, err := api.NewHexBytesFromString(rootHash) - if err != nil { - return fmt.Errorf("failed to parse root hash %s: %w", rootHash, err) - } - switch rm.config.Sharding.Mode { case config.ShardingModeStandalone: block := models.NewBlock( @@ -128,10 +124,9 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn 0, rm.config.Chain.Version, rm.config.Chain.ForkID, - rootHashBytes, + rootHash, parentHash, nil, - nil, ) rm.roundMutex.RLock() if rm.currentRound != nil && !rm.currentRound.StartTime.IsZero() { @@ -151,22 +146,16 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn "blockNumber", blockNumber.String()) return nil case config.ShardingModeChild: - rm.logger.WithContext(ctx).Info("Submitting root hash to parent shard", "rootHash", rootHash) + rm.logger.WithContext(ctx).Info("Submitting root hash to parent shard", "rootHash", rootHash.String()) - // Strip algorithm prefix (first 2 bytes) before sending to parent - // Parent SMT stores raw 32-byte hashes, not the full 34-byte format with algorithm ID - // 
This is required for JoinPaths to work correctly when combining child and parent proofs - if len(rootHashBytes) < 2 { - return fmt.Errorf("root hash too short: expected at least 2 bytes for algorithm prefix, got %d", len(rootHashBytes)) - } - rootHashRaw := rootHashBytes[2:] // Remove algorithm identifier - if len(rootHashRaw) != 32 { - return fmt.Errorf("child root hash has invalid length after stripping prefix: expected 32 bytes, got %d", len(rootHashRaw)) + if len(rootHash) != api.StateTreeKeyLengthBytes { + return fmt.Errorf("child root hash has invalid length: expected %d bytes, got %d", + api.StateTreeKeyLengthBytes, len(rootHash)) } request := &api.SubmitShardRootRequest{ ShardID: rm.config.Sharding.Child.ShardID, - RootHash: rootHashRaw, + RootHash: rootHash, } submitStart := time.Now() if err := rm.submitShardRootWithRetry(ctx, request); err != nil { @@ -175,22 +164,23 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn submissionDuration := time.Since(submitStart) metrics.ParentRootSubmissionDuration.Observe(submissionDuration.Seconds()) rm.logger.WithContext(ctx).Info("Root hash submitted to parent, polling for inclusion proof...", - "rootHash", rootHashRaw.String(), + "rootHash", rootHash.String(), "submissionDuration", submissionDuration) proofWaitStart := time.Now() var ( proof *api.RootShardInclusionProof parentUC *types.UnicityCertificate + err error ) for { - proof, parentUC, err = rm.pollForParentProof(ctx, rootHashRaw.String()) + proof, parentUC, err = rm.pollForParentProof(ctx, rootHash) if err == nil { break } if errors.Is(err, ErrParentProofPollTimeout) { rm.logger.WithContext(ctx).Warn("Parent shard proof poll timed out, continuing to poll", - "rootHash", rootHashRaw.String(), + "rootHash", rootHash.String(), "timeout", rm.config.Sharding.Child.ParentPollTimeout) continue } @@ -198,7 +188,7 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn } proofWait := 
time.Since(proofWaitStart) rm.logger.WithContext(ctx).Info("Parent shard proof received", - "rootHash", rootHashRaw.String(), + "rootHash", rootHash.String(), "proofWait", proofWait, "submissionToProof", submissionDuration+proofWait) rm.roundMutex.Lock() @@ -208,16 +198,21 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn } rm.roundMutex.Unlock() - block := models.NewBlock( + if proof.ParentFragment == nil { + return fmt.Errorf("parent shard proof missing native parent fragment") + } + + block := models.NewChildBlock( blockNumber, rm.config.Chain.ID, request.ShardID, rm.config.Chain.Version, rm.config.Chain.ForkID, - rootHashBytes, + rootHash, parentHash, proof.UnicityCertificate, - proof.MerkleTreePath, + proof.ParentFragment, + proof.BlockNumber, ) if err := rm.FinalizeBlockWithRetry(ctx, block); err != nil { return fmt.Errorf("failed to finalize block after retries: %w", err) @@ -259,7 +254,7 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn } } -func (rm *RoundManager) pollForParentProof(ctx context.Context, rootHash string) (*api.RootShardInclusionProof, *types.UnicityCertificate, error) { +func (rm *RoundManager) pollForParentProof(ctx context.Context, rootHash api.HexBytes) (*api.RootShardInclusionProof, *types.UnicityCertificate, error) { pollingCtx, cancel := context.WithTimeout(ctx, rm.config.Sharding.Child.ParentPollTimeout) defer cancel() @@ -275,27 +270,27 @@ func (rm *RoundManager) pollForParentProof(ctx context.Context, rootHash string) } metrics.ParentProofErrorsTotal.Inc() rm.logger.WithContext(ctx).Warn("Timed out waiting for parent shard inclusion proof", - "rootHash", rootHash, + "rootHash", rootHash.String(), "pollDuration", time.Since(pollStart)) - return nil, nil, fmt.Errorf("%w: %s", ErrParentProofPollTimeout, rootHash) + return nil, nil, fmt.Errorf("%w: %s", ErrParentProofPollTimeout, rootHash.String()) case <-ticker.C: request := &api.GetShardProofRequest{ShardID: 
rm.config.Sharding.Child.ShardID} proof, err := rm.rootClient.GetShardProof(pollingCtx, request) if err != nil { metrics.ParentProofErrorsTotal.Inc() rm.logger.WithContext(ctx).Warn("Failed to fetch parent shard inclusion proof, retrying", - "rootHash", rootHash, + "rootHash", rootHash.String(), "error", err.Error()) continue } - if proof == nil || !proof.IsValid(rootHash) { + if proof == nil || !proof.IsValid(rm.config.Sharding.Child.ShardID, rm.config.Sharding.ShardIDLength, rootHash) { continue } parentUC, err := decodeUnicityCertificate(proof.UnicityCertificate) if err != nil { rm.logger.WithContext(ctx).Warn("Failed to decode parent shard proof UC, retrying", - "rootHash", rootHash, + "rootHash", rootHash.String(), "error", err.Error()) continue } @@ -303,7 +298,7 @@ func (rm *RoundManager) pollForParentProof(ctx context.Context, rootHash string) lastParentRound := rm.lastAcceptedParentUC() if parentUC.GetRoundNumber() <= lastParentRound { rm.logger.WithContext(ctx).Debug("Ignoring stale parent shard proof", - "rootHash", rootHash, + "rootHash", rootHash.String(), "proofParentRound", parentUC.GetRoundNumber(), "lastAcceptedParentRound", lastParentRound) continue @@ -393,6 +388,10 @@ func (rm *RoundManager) FinalizeBlockWithRetry(ctx context.Context, block *model // FinalizeBlock creates and persists a new block with the given data func (rm *RoundManager) FinalizeBlock(ctx context.Context, block *models.Block) error { + if err := rm.validateBlockForMode(block); err != nil { + return err + } + rm.logger.WithContext(ctx).Info("FinalizeBlock called", "blockNumber", block.Index.String(), "rootHash", block.RootHash.String(), @@ -443,7 +442,10 @@ func (rm *RoundManager) FinalizeBlock(ctx context.Context, block *models.Block) } persistDataStart = time.Now() - smtNodes := rm.convertLeavesToNodes(pendingLeaves) + smtNodes, err := rm.convertLeavesToNodes(pendingLeaves) + if err != nil { + return fmt.Errorf("failed to convert leaves to storage nodes: %w", err) + } records 
:= rm.convertCommitmentsToRecords(pendingCommitments, block.Index) block.Finalized = false @@ -486,7 +488,7 @@ func (rm *RoundManager) FinalizeBlock(ctx context.Context, block *models.Block) rm.roundMutex.Lock() if rm.currentRound != nil { rm.currentRound.Block = block - rm.currentRound.PendingRootHash = "" + rm.currentRound.PendingRootHash = nil rm.currentRound.PendingLeaves = nil rm.currentRound.PendingCommitments = nil rm.currentRound.Snapshot = nil @@ -590,20 +592,39 @@ func (rm *RoundManager) FinalizeBlock(ctx context.Context, block *models.Block) return nil } +func (rm *RoundManager) validateBlockForMode(block *models.Block) error { + if block == nil { + return errors.New("block is nil") + } + if rm.config.Sharding.Mode.IsChild() { + if block.ParentFragment == nil { + return errors.New("child-mode block missing parent fragment") + } + if block.ParentBlockNumber == 0 { + return errors.New("child-mode block missing parent block number") + } + } + return nil +} + // convertLeavesToNodes converts SMT leaves to storage models -func (rm *RoundManager) convertLeavesToNodes(leaves []*smt.Leaf) []*models.SmtNode { +func (rm *RoundManager) convertLeavesToNodes(leaves []*smt.Leaf) ([]*models.SmtNode, error) { if len(leaves) == 0 { - return nil + return nil, nil } - smtNodes := make([]*models.SmtNode, len(leaves)) - for i, leaf := range leaves { - keyBytes := leaf.Path.Bytes() + keyLength := rm.smt.GetKeyLength() + smtNodes := make([]*models.SmtNode, 0, len(leaves)) + for _, leaf := range leaves { + keyBytes, err := api.PathToFixedBytes(leaf.Path, keyLength) + if err != nil { + return nil, fmt.Errorf("failed to convert leaf path %s to SMT storage key: %w", leaf.Path.String(), err) + } key := api.NewHexBytes(keyBytes) value := api.NewHexBytes(leaf.Value) - smtNodes[i] = models.NewSmtNode(key, value) + smtNodes = append(smtNodes, models.NewSmtNode(key, value)) } - return smtNodes + return smtNodes, nil } // convertCommitmentsToRecords converts commitments to aggregator 
records diff --git a/internal/round/finalize_duplicate_test.go b/internal/round/finalize_duplicate_test.go index 2ff7c6a..ac54932 100644 --- a/internal/round/finalize_duplicate_test.go +++ b/internal/round/finalize_duplicate_test.go @@ -62,7 +62,7 @@ func (s *FinalizeDuplicateTestSuite) Test1_DuplicateRecovery() { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - smtInstance := smt.NewSparseMerkleTree(api.SHA256, 16+256) + smtInstance := smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) threadSafeSMT := smt.NewThreadSafeSMT(smtInstance) rm, err := NewRoundManager(ctx, s.cfg, testLogger, s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), threadSafeSMT, nil) @@ -91,7 +91,8 @@ func (s *FinalizeDuplicateTestSuite) Test1_DuplicateRecovery() { // Pre-populate storage with 2 out of 5 records (simulating partial write before crash) partialLeaves := rm.currentRound.PendingLeaves[:2] - preExistingNodes := rm.convertLeavesToNodes(partialLeaves) + preExistingNodes, err := rm.convertLeavesToNodes(partialLeaves) + require.NoError(t, err) err = s.storage.SmtStorage().StoreBatch(ctx, preExistingNodes) require.NoError(t, err, "Pre-populating SMT nodes should succeed") @@ -121,7 +122,6 @@ func (s *FinalizeDuplicateTestSuite) Test1_DuplicateRecovery() { rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) // FinalizeBlock should succeed despite duplicates @@ -147,7 +147,7 @@ func (s *FinalizeDuplicateTestSuite) Test2_NoDuplicates() { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - smtInstance := smt.NewSparseMerkleTree(api.SHA256, 16+256) + smtInstance := smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) threadSafeSMT := smt.NewThreadSafeSMT(smtInstance) rm, err := NewRoundManager(ctx, s.cfg, testLogger, s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), 
threadSafeSMT, nil) @@ -180,7 +180,6 @@ func (s *FinalizeDuplicateTestSuite) Test2_NoDuplicates() { rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) // Should succeed on first try (no duplicates) @@ -199,7 +198,7 @@ func (s *FinalizeDuplicateTestSuite) Test3_AllDuplicates() { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - smtInstance := smt.NewSparseMerkleTree(api.SHA256, 16+256) + smtInstance := smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) threadSafeSMT := smt.NewThreadSafeSMT(smtInstance) rm, err := NewRoundManager(ctx, s.cfg, testLogger, s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), threadSafeSMT, nil) @@ -224,7 +223,8 @@ func (s *FinalizeDuplicateTestSuite) Test3_AllDuplicates() { recordCountBefore, _ := s.storage.AggregatorRecordStorage().Count(ctx) // Pre-populate ALL SMT nodes and aggregator records - allNodes := rm.convertLeavesToNodes(rm.currentRound.PendingLeaves) + allNodes, err := rm.convertLeavesToNodes(rm.currentRound.PendingLeaves) + require.NoError(t, err) err = s.storage.SmtStorage().StoreBatch(ctx, allNodes) require.NoError(t, err) @@ -245,7 +245,6 @@ func (s *FinalizeDuplicateTestSuite) Test3_AllDuplicates() { rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) // Should succeed even when all records are duplicates @@ -272,7 +271,7 @@ func (s *FinalizeDuplicateTestSuite) Test4_DuplicateBlock() { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - threadSafeSMT := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + threadSafeSMT := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) rm, err := NewRoundManager(ctx, s.cfg, testLogger, s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), threadSafeSMT, nil) require.NoError(t, err) @@ -303,7 +302,6 @@ func (s 
*FinalizeDuplicateTestSuite) Test4_DuplicateBlock() { rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) // Pre-store the block (simulating previous attempt that stored block but failed on MarkProcessed) @@ -354,7 +352,7 @@ func (s *FinalizeDuplicateTestSuite) Test5_DuplicateBlockAlreadyFinalized() { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - threadSafeSMT := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + threadSafeSMT := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) rm, err := NewRoundManager(ctx, s.cfg, testLogger, s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), threadSafeSMT, nil) require.NoError(t, err) @@ -385,7 +383,6 @@ func (s *FinalizeDuplicateTestSuite) Test5_DuplicateBlockAlreadyFinalized() { rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) // Pre-store the block as FINALIZED (simulating previous successful attempt except MarkProcessed) @@ -402,7 +399,8 @@ func (s *FinalizeDuplicateTestSuite) Test5_DuplicateBlockAlreadyFinalized() { require.NoError(t, err, "Pre-storing block records should succeed") // Pre-store all SMT nodes and records (simulating full previous attempt) - allNodes := rm.convertLeavesToNodes(rm.currentRound.PendingLeaves) + allNodes, err := rm.convertLeavesToNodes(rm.currentRound.PendingLeaves) + require.NoError(t, err) err = s.storage.SmtStorage().StoreBatch(ctx, allNodes) require.NoError(t, err) @@ -431,7 +429,7 @@ func (s *FinalizeDuplicateTestSuite) Test6_BlockRecordsMatchPendingCommitmentsOn testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - threadSafeSMT := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + threadSafeSMT := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) rm, err := NewRoundManager(ctx, s.cfg, testLogger, s.storage.CommitmentQueue(), s.storage, 
nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), threadSafeSMT, nil) require.NoError(t, err) @@ -484,7 +482,6 @@ func (s *FinalizeDuplicateTestSuite) Test6_BlockRecordsMatchPendingCommitmentsOn rootHashBytes, api.HexBytes{}, api.HexBytes{}, - nil, ) err = rm.FinalizeBlock(ctx, block) diff --git a/internal/round/parent_round_manager.go b/internal/round/parent_round_manager.go index 714b08f..dc87da8 100644 --- a/internal/round/parent_round_manager.go +++ b/internal/round/parent_round_manager.go @@ -1,6 +1,7 @@ package round import ( + "bytes" "context" "encoding/binary" "fmt" @@ -295,12 +296,7 @@ func (prm *ParentRoundManager) processRound(ctx context.Context, round *ParentRo var parentRootHash api.HexBytes if len(round.ProcessedShardUpdates) == 0 { - rootHashHex := round.Snapshot.GetRootHash() - parsedRoot, err := api.NewHexBytesFromString(rootHashHex) - if err != nil { - return fmt.Errorf("failed to parse parent SMT root hash %q: %w", rootHashHex, err) - } - parentRootHash = parsedRoot + parentRootHash = round.Snapshot.GetRootHashRaw() prm.logger.WithContext(ctx).Info("Empty parent round, using current SMT root hash", "rootHash", parentRootHash.String()) } else { @@ -313,18 +309,13 @@ func (prm *ParentRoundManager) processRound(ctx context.Context, round *ParentRo } smtStart := time.Now() - rootHashStr, err := round.Snapshot.AddLeaves(leaves) - metrics.SMTAddLeavesDuration.Observe(time.Since(smtStart).Seconds()) - if err != nil { + if _, err := round.Snapshot.AddLeaves(leaves); err != nil { + metrics.SMTAddLeavesDuration.Observe(time.Since(smtStart).Seconds()) return fmt.Errorf("failed to add shard leaves to parent SMT snapshot: %w", err) } + metrics.SMTAddLeavesDuration.Observe(time.Since(smtStart).Seconds()) - parsedRoot, err := api.NewHexBytesFromString(rootHashStr) - if err != nil { - return fmt.Errorf("failed to parse updated parent SMT root hash %q: %w", rootHashStr, err) - } - - parentRootHash = parsedRoot + parentRootHash = 
round.Snapshot.GetRootHashRaw() prm.logger.WithContext(ctx).Info("Added shard updates to parent SMT snapshot", "shardCount", len(round.ProcessedShardUpdates), "newRootHash", parentRootHash.String()) @@ -354,7 +345,6 @@ func (prm *ParentRoundManager) processRound(ctx context.Context, round *ParentRo parentRootHash, previousBlockHash, nil, - nil, ) round.Block = block @@ -639,22 +629,23 @@ func (prm *ParentRoundManager) reconstructParentSMT(ctx context.Context) error { prm.logger.WithContext(ctx).Info("Successfully reconstructed parent SMT", "leafCount", len(leaves), - "rootHash", prm.parentSMT.GetRootHash()) + "rootHash", api.HexBytes(prm.parentSMT.GetRootHashRaw()).String()) } - // Verify reconstructed root hash matches latest block (safety check) + // Verify reconstructed root hash matches latest block (safety check). + // Both sides are raw 32-byte roots (matching UC.IR.h / V2 wire format). latestBlock, err := prm.storage.BlockStorage().GetLatest(ctx) if err != nil { return fmt.Errorf("failed to get latest block for verification: %w", err) } if latestBlock != nil { - reconstructedHash := prm.parentSMT.GetRootHash() - expectedHash := latestBlock.RootHash.String() - if reconstructedHash != expectedHash { - return fmt.Errorf("parent SMT reconstruction failed: root hash mismatch (got %s, expected %s)", reconstructedHash, expectedHash) + reconstructedHash := api.HexBytes(prm.parentSMT.GetRootHashRaw()) + if !bytes.Equal(reconstructedHash, latestBlock.RootHash) { + return fmt.Errorf("parent SMT reconstruction failed: root hash mismatch (got %s, expected %s)", + reconstructedHash.String(), latestBlock.RootHash.String()) } prm.logger.WithContext(ctx).Info("Parent SMT root hash verification passed", - "rootHash", reconstructedHash, + "rootHash", reconstructedHash.String(), "blockNumber", latestBlock.Index.String()) } diff --git a/internal/round/parent_round_manager_test.go b/internal/round/parent_round_manager_test.go index bff969e..74b56ec 100644 --- 
a/internal/round/parent_round_manager_test.go +++ b/internal/round/parent_round_manager_test.go @@ -1,6 +1,7 @@ package round import ( + "bytes" "context" "testing" "time" @@ -433,12 +434,13 @@ func (suite *ParentRoundManagerTestSuite) TestBlockRootMatchesSMTRoot() { err = prm.SubmitShardRoot(ctx, update) suite.Require().NoError(err) - // Wait for the SMT to include the shard update (root hash will change from empty) - var expectedRootHex string + // Wait for the SMT to include the shard update (raw 32-byte root matching UC.IR.h) + emptyRoot := api.HexBytes(smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength).GetRootHashRaw()) + var expectedRoot api.HexBytes suite.Require().Eventually(func() bool { - expectedRootHex = prm.GetSMT().GetRootHash() + expectedRoot = api.HexBytes(prm.GetSMT().GetRootHashRaw()) // Empty SMT has a specific hash, wait for it to change after shard update is processed - return expectedRootHex != smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength).GetRootHashHex() + return !bytes.Equal(expectedRoot, emptyRoot) }, 5*time.Second, 50*time.Millisecond, "expected SMT to include shard update") // Now wait for a block with this root hash to be stored @@ -450,10 +452,10 @@ func (suite *ParentRoundManagerTestSuite) TestBlockRootMatchesSMTRoot() { } latestBlock = block // Check if this block has the expected root (contains our shard update) - return latestBlock.RootHash.String() == expectedRootHex + return bytes.Equal(latestBlock.RootHash, expectedRoot) }, 5*time.Second, 50*time.Millisecond, "expected block with shard update to be stored") - suite.Assert().Equal(expectedRootHex, latestBlock.RootHash.String(), "stored block root should match SMT root") + suite.Assert().Equal(expectedRoot, latestBlock.RootHash, "stored block root should match SMT root") } // TestParentRoundManagerSuite runs the test suite diff --git a/internal/round/precollection_test.go b/internal/round/precollection_test.go index 
52967b3..9b7019a 100644 --- a/internal/round/precollection_test.go +++ b/internal/round/precollection_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/unicitynetwork/bft-go-base/types" + "github.com/unicitynetwork/bft-go-base/types/hex" "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/events" @@ -175,7 +176,6 @@ func (c *staleThenFreshRootAggregatorClient) GetShardProof(ctx context.Context, } c.proofPolls++ - root := c.submittedRoot.String() parentRound := c.staleParentRound rootRound := c.staleRootRound @@ -186,16 +186,18 @@ func (c *staleThenFreshRootAggregatorClient) GetShardProof(ctx context.Context, default: } - uc, err := testProofUC(parentRound, rootRound) + uc, err := testProofUC(parentRound, rootRound, c.submittedRoot) if err != nil { return nil, err } return &api.RootShardInclusionProof{ - UnicityCertificate: uc, - MerkleTreePath: &api.MerkleTreePath{ - Steps: []api.MerkleTreeStep{{Data: &root}}, + ParentFragment: &api.ParentInclusionFragment{ + CertificateBytes: api.NewHexBytes(make([]byte, api.BitmapSize)), + ShardLeafValue: api.NewHexBytes(c.submittedRoot), }, + BlockNumber: parentRound, + UnicityCertificate: uc, }, nil } @@ -209,10 +211,11 @@ func (c *staleThenFreshRootAggregatorClient) ProofPolls() int { return c.proofPolls } -func testProofUC(parentRound, rootRound uint64) (api.HexBytes, error) { +func testProofUC(parentRound, rootRound uint64, rootHash api.HexBytes) (api.HexBytes, error) { uc := types.UnicityCertificate{ InputRecord: &types.InputRecord{ RoundNumber: parentRound, + Hash: hex.Bytes(rootHash), }, UnicitySeal: &types.UnicitySeal{ RootChainRoundNumber: rootRound, @@ -265,7 +268,7 @@ func newTestLogger(t *testing.T) *logger.Logger { func newTestPrecollector(t *testing.T, stream chan *models.CertificationRequest, maxPerRound int) (*childPrecollector, *smt.ThreadSafeSMT) { t.Helper() log := newTestLogger(t) - 
smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) if maxPerRound <= 0 { maxPerRound = 10000 } @@ -550,7 +553,7 @@ func TestPreCollectionReparenting(t *testing.T) { defer cancel() testLogger := newTestLogger(t) - smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) initialMainRootHash := smtInstance.GetRootHash() // Round N: Create snapshot and add a leaf @@ -624,7 +627,8 @@ func TestChildPrecollector_DeactivateDuringInFlightRound(t *testing.T) { MaxCommitmentsPerRound: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, @@ -636,7 +640,7 @@ func TestChildPrecollector_DeactivateDuringInFlightRound(t *testing.T) { testLogger := newTestLogger(t) rootClient := newBlockingProofRootAggregatorClient() - smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) rm, err := NewRoundManager( ctx, @@ -697,7 +701,8 @@ func TestChildRound_ParentProofTimeoutIsRetriable(t *testing.T) { MaxCommitmentsPerRound: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 100 * time.Millisecond, @@ -709,7 +714,7 @@ func TestChildRound_ParentProofTimeoutIsRetriable(t *testing.T) { testLogger := newTestLogger(t) rootClient := newBlockingProofRootAggregatorClient() - smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 
api.StateTreeKeyLengthBits)) rm, err := NewRoundManager( ctx, @@ -776,7 +781,7 @@ func TestStartNewRoundWithSnapshot(t *testing.T) { storage := testutil.SetupTestStorage(t, cfg) testLogger := newTestLogger(t) - smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) rm, err := NewRoundManager( ctx, @@ -833,7 +838,8 @@ func TestPipelinedChildModeFlow(t *testing.T) { MaxCommitmentsPerRound: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, @@ -845,7 +851,7 @@ func TestPipelinedChildModeFlow(t *testing.T) { testLogger := newTestLogger(t) rootAggregatorClient := testsharding.NewRootAggregatorClientStub() - smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) rm, err := NewRoundManager( ctx, @@ -903,7 +909,8 @@ func TestChildPreCollection_CommitmentAfterProofBeforeRoundEnd_ShouldBeInNextRou MaxCommitmentsPerRound: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, @@ -915,7 +922,7 @@ func TestChildPreCollection_CommitmentAfterProofBeforeRoundEnd_ShouldBeInNextRou testLogger := newTestLogger(t) rootAggregatorClient := testsharding.NewRootAggregatorClientStub() - smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) rm, err := NewRoundManager( ctx, @@ -972,7 +979,8 @@ func TestChildMode_RequiresFreshParentProof(t *testing.T) { MaxCommitmentsPerRound: 1000, }, Sharding: 
config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, @@ -984,11 +992,11 @@ func TestChildMode_RequiresFreshParentProof(t *testing.T) { testLogger := newTestLogger(t) rootAggregatorClient := newStaleThenFreshRootAggregatorClient(1, 10, 2, 13) - smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) rootHash, err := api.NewHexBytesFromString(smtInstance.GetRootHash()) require.NoError(t, err) - initialUC, err := testProofUC(1, 10) + initialUC, err := testProofUC(1, 10, rootHash) require.NoError(t, err) initialBlock := models.NewBlock( api.NewBigInt(big.NewInt(1)), @@ -999,7 +1007,6 @@ func TestChildMode_RequiresFreshParentProof(t *testing.T) { rootHash, api.HexBytes{}, initialUC, - nil, ) initialBlock.Finalized = true require.NoError(t, storage.BlockStorage().Store(ctx, initialBlock)) diff --git a/internal/round/recovery.go b/internal/round/recovery.go index f6951a7..01e8227 100644 --- a/internal/round/recovery.go +++ b/internal/round/recovery.go @@ -3,7 +3,6 @@ package round import ( "context" "fmt" - "math/big" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/models" @@ -98,11 +97,11 @@ func recoverBlock( smtKeyStrings := make([]string, len(requestIDs)) for i, reqID := range requestIDs { - path, err := reqID.GetPath() + keyBytes, err := reqID.GetTreeKey() if err != nil { - return nil, fmt.Errorf("failed to get path for requestID: %w", err) + return nil, fmt.Errorf("failed to get SMT key for requestID: %w", err) } - smtKeyStrings[i] = api.HexBytes(path.Bytes()).String() + smtKeyStrings[i] = api.HexBytes(keyBytes).String() } existingSmtKeys, err := storage.SmtStorage().GetExistingKeys(ctx, smtKeyStrings) if err != nil { @@ -219,11 +218,11 @@ func 
recoverMissingData( for _, reqID := range missingSmtKeys { commitment, ok := commitmentMap[reqID.String()] if !ok { - path, err := reqID.GetPath() + keyBytes, err := reqID.GetTreeKey() if err != nil { - return fmt.Errorf("failed to get path for reqID: %w", err) + return fmt.Errorf("failed to get SMT key for reqID: %w", err) } - existingNode, err := storage.SmtStorage().GetByKey(ctx, path.Bytes()) + existingNode, err := storage.SmtStorage().GetByKey(ctx, keyBytes) if err != nil { return fmt.Errorf("failed to check existing SMT node: %w", err) } @@ -233,15 +232,15 @@ func recoverMissingData( return fmt.Errorf("FATAL: commitment not found for SMT key %s", reqID) } - path, err := commitment.StateID.GetPath() + keyBytes, err := commitment.StateID.GetTreeKey() if err != nil { - return fmt.Errorf("failed to get path for commitment: %w", err) + return fmt.Errorf("failed to get SMT key for commitment: %w", err) } leafValue, err := commitment.LeafValue() if err != nil { return fmt.Errorf("failed to create leaf value: %w", err) } - nodes = append(nodes, models.NewSmtNode(path.Bytes(), leafValue)) + nodes = append(nodes, models.NewSmtNode(keyBytes, leafValue)) } if len(nodes) > 0 { @@ -280,11 +279,11 @@ func LoadRecoveredNodesIntoSMT( continue } seen[key] = struct{}{} - path, err := reqID.GetPath() + keyBytes, err := reqID.GetTreeKey() if err != nil { - return fmt.Errorf("failed to get path for requestID %s: %w", reqID, err) + return fmt.Errorf("failed to get SMT key for requestID %s: %w", reqID, err) } - keys = append(keys, api.HexBytes(path.Bytes())) + keys = append(keys, api.HexBytes(keyBytes)) } nodes, err := storage.SmtStorage().GetByKeys(ctx, keys) @@ -298,7 +297,10 @@ func LoadRecoveredNodesIntoSMT( leaves := make([]*smt.Leaf, len(nodes)) for i, node := range nodes { - path := new(big.Int).SetBytes(node.Key) + path, err := api.FixedBytesToPath(node.Key, smtTree.GetKeyLength()) + if err != nil { + return fmt.Errorf("failed to convert SMT key to path: %w", err) + } 
leaves[i] = smt.NewLeaf(path, node.Value) } diff --git a/internal/round/recovery_test.go b/internal/round/recovery_test.go index 2453303..3254987 100644 --- a/internal/round/recovery_test.go +++ b/internal/round/recovery_test.go @@ -135,12 +135,12 @@ func (s *RecoveryTestSuite) createTestData(blockNum int64, commitmentCount int, } // Compute SMT root hash - smtTree := smt.NewSparseMerkleTree(api.SHA256, 16+256) + smtTree := smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) leaves := make([]*smt.Leaf, len(commitments)) for i, c := range commitments { path, err := c.StateID.GetPath() require.NoError(t, err) - leafValue, err := c.CertificationData.ToAPI().Hash() + leafValue, err := c.LeafValue() require.NoError(t, err) leaves[i] = smt.NewLeaf(path, leafValue) } @@ -149,7 +149,7 @@ func (s *RecoveryTestSuite) createTestData(blockNum int64, commitmentCount int, rootHashBytes := smtTree.GetRootHash() // Create block (unfinalized) - block := models.NewBlock(blockNumber, "unicity", 0, "1.0", "mainnet", api.HexBytes(rootHashBytes), nil, nil, nil) + block := models.NewBlock(blockNumber, "unicity", 0, "1.0", "mainnet", api.HexBytes(rootHashBytes), nil, nil) block.Finalized = false return commitments, block, requestIDs @@ -169,11 +169,11 @@ func (s *RecoveryTestSuite) storeCommitmentsInRedis(commitments []*models.Certif func (s *RecoveryTestSuite) storeSmtNodes(commitments []*models.CertificationRequest) { nodes := make([]*models.SmtNode, len(commitments)) for i, c := range commitments { - path, err := c.StateID.GetPath() + keyBytes, err := c.StateID.GetTreeKey() s.Require().NoError(err) - leafValue, err := c.CertificationData.ToAPI().Hash() + leafValue, err := c.LeafValue() s.Require().NoError(err) - nodes[i] = models.NewSmtNode(api.HexBytes(path.Bytes()), leafValue) + nodes[i] = models.NewSmtNode(api.HexBytes(keyBytes), leafValue) } err := s.storage.SmtStorage().StoreBatch(s.ctx, nodes) s.Require().NoError(err) @@ -590,11 +590,11 @@ func (s *RecoveryTestSuite) 
Test10_PartialSmtNodes_CorrectDetection() { existingIndices := []int{0, 1, 4} existingNodes := make([]*models.SmtNode, len(existingIndices)) for i, idx := range existingIndices { - path, err := commitments[idx].StateID.GetPath() + keyBytes, err := commitments[idx].StateID.GetTreeKey() require.NoError(t, err) - leafValue, err := commitments[idx].CertificationData.ToAPI().Hash() + leafValue, err := commitments[idx].LeafValue() require.NoError(t, err) - existingNodes[i] = models.NewSmtNode(api.HexBytes(path.Bytes()), leafValue) + existingNodes[i] = models.NewSmtNode(api.HexBytes(keyBytes), leafValue) } err = s.storage.SmtStorage().StoreBatch(s.ctx, existingNodes) require.NoError(t, err) @@ -639,12 +639,12 @@ func (s *RecoveryTestSuite) Test11_LoadRecoveredNodesIntoSMT() { } // Compute expected root hash - expectedSMT := smt.NewSparseMerkleTree(api.SHA256, 16+256) + expectedSMT := smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) leaves := make([]*smt.Leaf, len(commitments)) for i, c := range commitments { path, err := c.StateID.GetPath() require.NoError(t, err) - leafValue, err := c.CertificationData.ToAPI().Hash() + leafValue, err := c.LeafValue() require.NoError(t, err) leaves[i] = smt.NewLeaf(path, leafValue) } @@ -656,7 +656,7 @@ func (s *RecoveryTestSuite) Test11_LoadRecoveredNodesIntoSMT() { s.storeSmtNodes(commitments) // Create empty SMT to load into - targetSMT := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + targetSMT := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) require.NotEqual(t, expectedRootHash, targetSMT.GetRootHash(), "SMT should be empty initially") // Load recovered nodes into SMT diff --git a/internal/round/round_manager.go b/internal/round/round_manager.go index 5ad00a3..05d4827 100644 --- a/internal/round/round_manager.go +++ b/internal/round/round_manager.go @@ -1,6 +1,7 @@ package round import ( + "bytes" "context" "errors" "fmt" @@ -54,8 +55,9 @@ type Round struct { 
State RoundState Commitments []*models.CertificationRequest Block *models.Block - // Track commitments that have been added to SMT but not yet finalized in a block - PendingRootHash string + // Track commitments that have been added to SMT but not yet finalized in a block. + // Raw 32-byte SMT root (no algorithm-id prefix), matching the V2 wire format. + PendingRootHash api.HexBytes // SMT snapshot for this round - allows accumulating changes before committing Snapshot *smt.ThreadSafeSmtSnapshot // Store data for persistence during FinalizeBlock @@ -484,9 +486,9 @@ func (rm *RoundManager) processRound(ctx context.Context) error { } rm.roundMutex.Lock() commitmentCount := len(rm.currentRound.Commitments) - var rootHash string + var rootHash api.HexBytes if rm.currentRound.Snapshot != nil { - rootHash = rm.currentRound.Snapshot.GetRootHash() + rootHash = rm.currentRound.Snapshot.GetRootHashRaw() } rm.currentRound.PendingRootHash = rootHash rm.currentRound.ProposalTime = time.Now() @@ -495,7 +497,7 @@ func (rm *RoundManager) processRound(ctx context.Context) error { rm.logger.WithContext(ctx).Info("processRound called", "roundNumber", roundNumber.String(), "commitments", commitmentCount, - "rootHash", rootHash) + "rootHash", rootHash.String()) if err := rm.proposeBlock(ctx, roundNumber, rootHash); err != nil { return fmt.Errorf("failed to propose block: %w", err) @@ -664,8 +666,14 @@ func (rm *RoundManager) restoreSmtFromStorage(ctx context.Context) (*api.BigInt, // Convert storage nodes to SMT leaves leaves := make([]*smt.Leaf, len(nodes)) for i, node := range nodes { - // Convert key bytes back to big.Int path - path := new(big.Int).SetBytes(node.Key) + // Restore the sentinel-bit path from the fixed-bytes storage key. + // Keys were written via api.PathToFixedBytes (which clears the + // sentinel bit); FixedBytesToPath is the inverse and the SMT + // AddLeaf path strictly requires the sentinel bit to be set. 
+ path, err := api.FixedBytesToPath(node.Key, rm.smt.GetKeyLength()) + if err != nil { + return nil, fmt.Errorf("failed to convert SMT key to path: %w", err) + } leaves[i] = smt.NewLeaf(path, node.Value) } @@ -688,12 +696,12 @@ func (rm *RoundManager) restoreSmtFromStorage(ctx context.Context) (*api.BigInt, } } - // Log final state - finalRootHash := rm.smt.GetRootHash() + // Log final state (raw 32-byte root matching UC.IR.h / V2 wire format) + finalRootHash := api.HexBytes(rm.smt.GetRootHashRaw()) rm.logger.Info("SMT restoration complete", "restoredNodes", restoredCount, "totalNodes", totalCount, - "finalRootHash", finalRootHash) + "finalRootHash", finalRootHash.String()) if restoredCount != int(totalCount) { rm.logger.Warn("SMT restoration count mismatch", @@ -709,17 +717,16 @@ func (rm *RoundManager) restoreSmtFromStorage(ctx context.Context) (*api.BigInt, rm.logger.Info("No latest block found, skipping SMT verification") return nil, nil } else { - expectedRootHash := latestBlock.RootHash.String() - if finalRootHash != expectedRootHash { + if !bytes.Equal(finalRootHash, latestBlock.RootHash) { rm.logger.Error("SMT restoration verification failed - root hash mismatch", - "restoredRootHash", finalRootHash, - "expectedRootHash", expectedRootHash, + "restoredRootHash", finalRootHash.String(), + "expectedRootHash", latestBlock.RootHash.String(), "latestBlockNumber", latestBlock.Index.String()) return nil, fmt.Errorf("SMT restoration verification failed: restored root hash %s does not match latest block root hash %s", - finalRootHash, expectedRootHash) + finalRootHash.String(), latestBlock.RootHash.String()) } rm.logger.Info("SMT restoration verified successfully - root hash matches latest block", - "rootHash", finalRootHash, + "rootHash", finalRootHash.String(), "latestBlockNumber", latestBlock.Index.String()) rm.stateTracker.SetLastSyncedBlock(latestBlock.Index.Int) diff --git a/internal/round/round_manager_test.go b/internal/round/round_manager_test.go index 
78dc236..6166359 100644 --- a/internal/round/round_manager_test.go +++ b/internal/round/round_manager_test.go @@ -30,7 +30,8 @@ func TestParentShardIntegration_GoodCase(t *testing.T) { BatchLimit: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, @@ -46,7 +47,7 @@ func TestParentShardIntegration_GoodCase(t *testing.T) { // create round manager rm, err := NewRoundManager(ctx, &cfg, testLogger, storage.CommitmentQueue(), storage, rootAggregatorClient, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), - smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), nil) + smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err) // start round manager @@ -79,7 +80,8 @@ func TestParentShardIntegration_RoundProcessingError(t *testing.T) { BatchLimit: 1000, }, Sharding: config.ShardingConfig{ - Mode: config.ShardingModeChild, + Mode: config.ShardingModeChild, + ShardIDLength: 1, Child: config.ChildConfig{ ShardID: 0b11, ParentPollTimeout: 5 * time.Second, @@ -98,7 +100,7 @@ func TestParentShardIntegration_RoundProcessingError(t *testing.T) { // create round manager rm, err := NewRoundManager(ctx, &cfg, testLogger, storage.CommitmentQueue(), storage, rootAggregatorClient, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), - smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), nil) + smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err) require.NoError(t, rm.Start(ctx)) diff --git a/internal/round/smt_persistence_integration_test.go b/internal/round/smt_persistence_integration_test.go index f60802e..e7cd120 100644 --- a/internal/round/smt_persistence_integration_test.go +++ b/internal/round/smt_persistence_integration_test.go @@ -60,7 +60,7 @@ func 
TestSmtPersistenceAndRestoration(t *testing.T) { {Path: big.NewInt(15), Value: []byte("test_value_15")}, {Path: big.NewInt(16), Value: []byte("test_value_16")}, } - keyLen := 16 + 256 + keyLen := api.StateTreeKeyLengthBits for _, t := range testLeaves { t.Path = new(big.Int).SetBit(t.Path, keyLen, 1) } @@ -73,11 +73,12 @@ func TestSmtPersistenceAndRestoration(t *testing.T) { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), nil) + rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err, "Should create RoundManager") // Test persistence - smtNodes := rm.convertLeavesToNodes(testLeaves) + smtNodes, err := rm.convertLeavesToNodes(testLeaves) + require.NoError(t, err) err = storage.SmtStorage().StoreBatch(ctx, smtNodes) require.NoError(t, err, "Should persist SMT nodes") @@ -86,20 +87,20 @@ func TestSmtPersistenceAndRestoration(t *testing.T) { require.NoError(t, err) assert.Equal(t, int64(len(testLeaves)), count, "Should have stored all SMT nodes") - // Test restoration produces same root hash as fresh SMT + // Test restoration produces same root hash as fresh SMT (raw 32-byte form) freshSmt := smt.NewSparseMerkleTree(api.SHA256, keyLen) err = freshSmt.AddLeaves(testLeaves) require.NoError(t, err, "Fresh SMT should accept leaves") - freshHash := freshSmt.GetRootHashHex() + freshHash := freshSmt.GetRootHashRaw() // Create RoundManager and call Start() to trigger restoration - restoredRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), 
nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), nil) + restoredRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err, "Should create RoundManager") err = restoredRm.Start(ctx) require.NoError(t, err, "SMT restoration should succeed") defer restoredRm.Stop(ctx) - restoredHash := restoredRm.smt.GetRootHash() + restoredHash := restoredRm.smt.GetRootHashRaw() assert.Equal(t, freshHash, restoredHash, "Restored SMT should have same root hash as fresh SMT") @@ -127,28 +128,29 @@ func TestLargeSmtRestoration(t *testing.T) { RoundDuration: time.Second, }, } - rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), nil) + rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err, "Should create RoundManager") const testNodeCount = 12000 // Ensure multiple chunks (chunkSize = 10000 in round_manager.go) // Create large dataset with non-sequential paths to test ordering testLeaves := make([]*smt.Leaf, testNodeCount) - keyLen := 16 + 256 + keyLen := api.StateTreeKeyLengthBits for i := 0; i < testNodeCount; i++ { path := new(big.Int).SetBit(big.NewInt(int64((i+1)*700000)), keyLen, 1) value := []byte(fmt.Sprintf("large_test_value_%d", i)) testLeaves[i] = smt.NewLeaf(path, value) } - // Create fresh SMT for comparison + // Create fresh SMT for comparison (raw 32-byte form) freshSmt := smt.NewSparseMerkleTree(api.SHA256, keyLen) err = 
freshSmt.AddLeaves(testLeaves) require.NoError(t, err, "Fresh SMT AddLeaves should succeed") - freshHash := freshSmt.GetRootHashHex() + freshHash := freshSmt.GetRootHashRaw() // Persist leaves to storage - smtNodes := rm.convertLeavesToNodes(testLeaves) + smtNodes, err := rm.convertLeavesToNodes(testLeaves) + require.NoError(t, err) err = storage.SmtStorage().StoreBatch(ctx, smtNodes) require.NoError(t, err, "Should persist large number of SMT nodes") @@ -158,14 +160,14 @@ func TestLargeSmtRestoration(t *testing.T) { require.Equal(t, int64(testNodeCount), count, "Should have stored all nodes") // Create new RoundManager and call Start() to restore from storage (uses multiple chunks) - newRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), nil) + newRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err, "Should create new RoundManager") err = newRm.Start(ctx) require.NoError(t, err, "Large SMT restoration should succeed") defer newRm.Stop(ctx) - restoredHash := newRm.smt.GetRootHash() + restoredHash := newRm.smt.GetRootHashRaw() // Critical test: multi-chunk restoration should match single-batch creation assert.Equal(t, freshHash, restoredHash, "Multi-chunk restoration should produce same hash as fresh SMT") @@ -191,7 +193,7 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), 
nil) + rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err, "Should create RoundManager") rm.currentRound = &Round{ @@ -212,15 +214,12 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { rm.roundMutex.Unlock() require.NoError(t, err, "processMiniBatch should succeed") - // Get the root hash from the snapshot - rootHash := rm.currentRound.Snapshot.GetRootHash() - require.NotEmpty(t, rootHash, "Root hash should not be empty") + // Get the raw 32-byte root hash from the snapshot (matches UC.IR.h / V2 wire format) + rootHashBytes := api.HexBytes(rm.currentRound.Snapshot.GetRootHashRaw()) + require.NotEmpty(t, rootHashBytes, "Root hash should not be empty") // After processBatch, SMT nodes are not yet persisted - they're stored in round state // We need to finalize a block to trigger persistence - rootHashBytes, err := api.NewHexBytesFromString(rootHash) - require.NoError(t, err, "Should parse root hash") - block := models.NewBlock( blockNumber, "unicity", @@ -230,7 +229,6 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { rootHashBytes, api.HexBytes{}, nil, - nil, ) err = rm.FinalizeBlock(ctx, block) @@ -243,7 +241,7 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { // Simulate service restart with new round manager cfg = &config.Config{Processing: config.ProcessingConfig{RoundDuration: time.Second}} - newRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), nil) + newRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), 
smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err, "NewRoundManager should succeed after restart") // Call Start() to trigger SMT restoration @@ -251,8 +249,8 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { require.NoError(t, err, "Start should succeed and restore SMT") defer newRm.Stop(ctx) - // Verify restored SMT has correct data - restoredRootHash := newRm.smt.GetRootHash() + // Verify restored SMT has correct data (raw 32-byte form) + restoredRootHash := newRm.smt.GetRootHashRaw() assert.NotEmpty(t, restoredRootHash, "Restored SMT should have non-empty root hash") // Verify inclusion proofs work after restart @@ -284,32 +282,30 @@ func TestSmtRestorationWithBlockVerification(t *testing.T) { {Path: big.NewInt(0x120), Value: []byte("block_test_value_20")}, {Path: big.NewInt(0x130), Value: []byte("block_test_value_30")}, } - keyLen := 16 + 256 + keyLen := api.StateTreeKeyLengthBits for _, t := range testLeaves { t.Path = new(big.Int).SetBit(t.Path, keyLen, 1) } - // Create fresh SMT to get expected root hash + // Create fresh SMT to get expected root hash (raw 32-byte form matches UC.IR.h) freshSmt := smt.NewSparseMerkleTree(api.SHA256, keyLen) err = freshSmt.AddLeaves(testLeaves) require.NoError(t, err, "Fresh SMT should accept leaves") - expectedRootHash := freshSmt.GetRootHashHex() - expectedRootHashBytes := freshSmt.GetRootHash() + expectedRootHashRaw := freshSmt.GetRootHashRaw() - // Create a block with the expected root hash + // Create a block with the expected raw root hash block := &models.Block{ - Index: api.NewBigInt(big.NewInt(1)), - ChainID: "test-chain", - ShardID: 0, - Version: "1.0.0", - ForkID: "test-fork", - RootHash: api.HexBytes(expectedRootHashBytes), // Use bytes, not hex string - PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000000"), - NoDeletionProofHash: api.HexBytes(""), - CreatedAt: api.NewTimestamp(time.Now()), - 
UnicityCertificate: api.HexBytes("certificate_data"), - ParentMerkleTreePath: nil, - Finalized: true, // Mark as finalized for test + Index: api.NewBigInt(big.NewInt(1)), + ChainID: "test-chain", + ShardID: 0, + Version: "1.0.0", + ForkID: "test-fork", + RootHash: api.HexBytes(expectedRootHashRaw), + PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000000"), + NoDeletionProofHash: api.HexBytes(""), + CreatedAt: api.NewTimestamp(time.Now()), + UnicityCertificate: api.HexBytes("certificate_data"), + Finalized: true, // Mark as finalized for test } // Store the block @@ -320,51 +316,51 @@ func TestSmtRestorationWithBlockVerification(t *testing.T) { cfg := &config.Config{ Processing: config.ProcessingConfig{RoundDuration: time.Second}, } - rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), nil) + rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err, "Should create RoundManager") // Persist SMT nodes to storage - smtNodes := rm.convertLeavesToNodes(testLeaves) + smtNodes, err := rm.convertLeavesToNodes(testLeaves) + require.NoError(t, err) err = storage.SmtStorage().StoreBatch(ctx, smtNodes) require.NoError(t, err, "Should persist SMT nodes") // Test 1: Successful verification (matching root hash) t.Run("SuccessfulVerification", func(t *testing.T) { - successRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), nil) + successRm, err := NewRoundManager(ctx, cfg, testLogger, 
storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err, "Should create RoundManager") err = successRm.Start(ctx) require.NoError(t, err, "SMT restoration should succeed when root hashes match") defer successRm.Stop(ctx) - // Verify the restored SMT has the correct hash - restoredHash := successRm.smt.GetRootHash() - assert.Equal(t, expectedRootHash, restoredHash, "Restored SMT should have expected root hash") + // Verify the restored SMT has the correct hash (raw 32-byte form) + restoredHash := successRm.smt.GetRootHashRaw() + assert.Equal(t, expectedRootHashRaw, restoredHash, "Restored SMT should have expected root hash") }) // Test 2: Failed verification (mismatched root hash) t.Run("FailedVerification", func(t *testing.T) { // Create a block with a different root hash to simulate mismatch wrongBlock := &models.Block{ - Index: api.NewBigInt(big.NewInt(2)), - ChainID: "test-chain", - ShardID: 0, - Version: "1.0.0", - ForkID: "test-fork", - RootHash: api.HexBytes("wrong_root_hash_value"), // Intentionally wrong hash - PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000001"), - NoDeletionProofHash: api.HexBytes(""), - CreatedAt: api.NewTimestamp(time.Now()), - UnicityCertificate: api.HexBytes("certificate_data"), - ParentMerkleTreePath: nil, - Finalized: true, // Mark as finalized for test + Index: api.NewBigInt(big.NewInt(2)), + ChainID: "test-chain", + ShardID: 0, + Version: "1.0.0", + ForkID: "test-fork", + RootHash: api.HexBytes("wrong_root_hash_value"), // Intentionally wrong hash + PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000001"), + NoDeletionProofHash: api.HexBytes(""), + CreatedAt: api.NewTimestamp(time.Now()), + UnicityCertificate: api.HexBytes("certificate_data"), + Finalized: true, // Mark as 
finalized for test } // Store the wrong block (this will become the "latest" block) err = storage.BlockStorage().Store(ctx, wrongBlock) require.NoError(t, err, "Should store wrong test block") - failRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), nil) + failRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err, "Should create RoundManager") // This should fail because the restored SMT root hash doesn't match the latest block diff --git a/internal/service/parent_service.go b/internal/service/parent_service.go index af4353d..1c4f382 100644 --- a/internal/service/parent_service.go +++ b/internal/service/parent_service.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "math/big" "github.com/unicitynetwork/bft-go-base/types" @@ -163,6 +162,9 @@ func (pas *ParentAggregatorService) GetShardProof(ctx context.Context, req *api. pas.logger.WithContext(ctx).Debug("Serving shard proof while node is follower") } } + if !pas.parentRoundManager.IsReady() { + return nil, errors.New("parent round manager is not ready to serve shard proofs") + } if err := pas.validateShardID(req.ShardID); err != nil { pas.logger.WithContext(ctx).Warn("Invalid shard ID", "shardId", req.ShardID, "error", err.Error()) @@ -171,31 +173,24 @@ func (pas *ParentAggregatorService) GetShardProof(ctx context.Context, req *api. 
pas.logger.WithContext(ctx).Debug("Shard proof requested", "shardId", req.ShardID) - shardPath := new(big.Int).SetInt64(int64(req.ShardID)) - merkleTreePath, err := pas.parentRoundManager.GetSMT().GetPath(shardPath) + parentFragment, rootHashRaw, err := pas.parentRoundManager.GetSMT().GetShardInclusionFragmentWithRoot(req.ShardID) if err != nil { - return nil, fmt.Errorf("failed to get merkle tree path: %w", err) + return nil, fmt.Errorf("failed to get shard inclusion fragment: %w", err) } - - var proofPath *api.MerkleTreePath - if len(merkleTreePath.Steps) > 0 && merkleTreePath.Steps[0].Data != nil { - proofPath = merkleTreePath - pas.logger.WithContext(ctx).Info("Generated shard proof from current state", - "shardId", req.ShardID, - "rootHash", merkleTreePath.Root) - } else { - proofPath = nil + if parentFragment == nil { pas.logger.WithContext(ctx).Info("Shard has not submitted root yet, returning nil proof", "shardId", req.ShardID) } var unicityCertificate api.HexBytes - if proofPath != nil { - rootHash, err := api.NewHexBytesFromString(merkleTreePath.Root) - if err != nil { - return nil, fmt.Errorf("failed to parse root hash: %w", err) - } - + var blockNumber uint64 + if parentFragment != nil { + // GetShardInclusionFragmentWithRoot snapshots fragment+root under one RLock. + // Finalization persists block data before committing the SMT snapshot, + // so a visible root is expected to have a corresponding finalized block. + // Blocks are stored under the raw 32-byte SMT root matching + // UC.IR.h, so look up the block by the raw root hash. + rootHash := api.HexBytes(rootHashRaw) block, err := pas.storage.BlockStorage().GetLatestByRootHash(ctx, rootHash) if err != nil { return nil, fmt.Errorf("failed to get latest block by root hash: %w", err) @@ -204,11 +199,13 @@ func (pas *ParentAggregatorService) GetShardProof(ctx context.Context, req *api. 
return nil, fmt.Errorf("no block found with root hash %s", rootHash.String()) } unicityCertificate = block.UnicityCertificate + blockNumber = block.Index.Uint64() } return &api.GetShardProofResponse{ - MerkleTreePath: proofPath, + ParentFragment: parentFragment, UnicityCertificate: unicityCertificate, + BlockNumber: blockNumber, }, nil } diff --git a/internal/service/parent_service_test.go b/internal/service/parent_service_test.go index 54790f9..32f844c 100644 --- a/internal/service/parent_service_test.go +++ b/internal/service/parent_service_test.go @@ -136,7 +136,7 @@ func (suite *ParentServiceTestSuite) waitForShardToExist(ctx context.Context, sh suite.T().Fatalf("Timeout waiting for shard %d to be processed", shardID) case <-ticker.C: response, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shardID}) - if err == nil && response.MerkleTreePath != nil { + if err == nil && response.ParentFragment != nil { return // Shard exists and has a proof! } // Continue polling @@ -283,7 +283,8 @@ func (suite *ParentServiceTestSuite) TestSubmitShardRoot_UpdateQueued() { suite.T().Log("✓ Shard update was queued and processed correctly") } -// GetShardProof - Success - Full E2E test with child SMT, proof joining, and verification +// GetShardProof - Success - Full E2E test with child SMT root submission and +// native parent fragment verification. func (suite *ParentServiceTestSuite) TestGetShardProof_Success() { ctx := context.Background() @@ -298,16 +299,9 @@ func (suite *ParentServiceTestSuite) TestGetShardProof_Success() { err := childSMT.AddLeaf(testLeafPath, testLeafValue) suite.Require().NoError(err, "Should add leaf to child SMT") - // 2. Extract child root hash (what child aggregator would submit to parent) - childRootHash := childSMT.GetRootHash() - suite.Require().NotEmpty(childRootHash, "Child SMT should have root hash") - suite.T().Logf("Child SMT root hash: %x", childRootHash) - - // 3. 
Submit child root to parent (strip algorithm prefix - first 2 bytes) - // This is required for JoinPaths to work: parent stores raw 32-byte hashes - childRootRaw := childRootHash[2:] // Remove algorithm identifier (2 bytes) - suite.Require().True(len(childRootRaw) == 32, "Root hash should be 32 bytes after stripping prefix") - suite.T().Logf("Sending %d bytes to parent (WITHOUT algorithm prefix)", len(childRootRaw)) + // 2. Extract child root hash (what child aggregator submits to the parent). + childRootRaw := childSMT.GetRootHashRaw() + suite.Require().Len(childRootRaw, 32, "Child SMT raw root hash must be 32 bytes") submitReq := &api.SubmitShardRootRequest{ ShardID: shard0ID, @@ -319,42 +313,27 @@ func (suite *ParentServiceTestSuite) TestGetShardProof_Success() { // 4. Wait for round to process suite.waitForShardToExist(ctx, shard0ID) - // 5. Get child proof from child SMT - childProof, err := childSMT.GetPath(testLeafPath) - suite.Require().NoError(err, "Should get child proof") - suite.Require().NotNil(childProof, "Child proof should not be nil") - suite.T().Logf("Child proof has %d steps", len(childProof.Steps)) - - // 6. Request parent proof from parent aggregator + // 5. Request parent proof from parent aggregator proofReq := &api.GetShardProofRequest{ ShardID: shard0ID, } parentResponse, err := suite.service.GetShardProof(ctx, proofReq) suite.Require().NoError(err, "Should get parent proof successfully") suite.Require().NotNil(parentResponse, "Parent response should not be nil") - suite.Require().NotNil(parentResponse.MerkleTreePath, "Parent proof should not be nil") - suite.T().Logf("Parent proof has %d steps", len(parentResponse.MerkleTreePath.Steps)) - - // 7. 
Join child and parent proofs - joinedProof, err := smt.JoinPaths(childProof, parentResponse.MerkleTreePath) - suite.Require().NoError(err, "Should join proofs successfully") - suite.Require().NotNil(joinedProof, "Joined proof should not be nil") - suite.T().Logf("Joined proof has %d steps", len(joinedProof.Steps)) - - // 8. Verify the joined proof - result, err := joinedProof.Verify(testLeafPath) - suite.Require().NoError(err, "Proof verification should not error") - suite.Require().NotNil(result, "Verification result should not be nil") - - // Both PathValid and PathIncluded should be true - suite.Assert().True(result.PathValid, "Joined proof path must be valid") - suite.Assert().True(result.PathIncluded, "Joined proof should show path is included") - suite.Assert().True(result.Result, "Overall verification result should be true") - - suite.T().Log("✓ End-to-end test: child SMT → parent submission → proof joining → verification SUCCESS") + suite.Require().NotNil(parentResponse.ParentFragment, "Parent fragment should not be nil") + + // 6. Verify the native parent fragment against the returned parent UC root. 
+ suite.Assert().Equal(api.NewHexBytes(childRootRaw), parentResponse.ParentFragment.ShardLeafValue) + suite.Assert().True((&api.RootShardInclusionProof{ + ParentFragment: parentResponse.ParentFragment, + UnicityCertificate: parentResponse.UnicityCertificate, + BlockNumber: parentResponse.BlockNumber, + }).IsValid(shard0ID, suite.cfg.Sharding.ShardIDLength, api.NewHexBytes(childRootRaw))) + + suite.T().Log("✓ End-to-end test: child SMT → parent submission → native parent fragment verification SUCCESS") } -// GetShardProof - Non-existent Shard (returns nil MerkleTreePath) +// GetShardProof - Non-existent Shard (returns nil fragment) func (suite *ParentServiceTestSuite) TestGetShardProof_NonExistentShard() { ctx := context.Background() @@ -379,9 +358,9 @@ func (suite *ParentServiceTestSuite) TestGetShardProof_NonExistentShard() { response, err := suite.service.GetShardProof(ctx, proofReq) suite.Require().NoError(err, "Should not return error for non-existent shard") suite.Require().NotNil(response, "Response should not be nil") - suite.Assert().Nil(response.MerkleTreePath, "MerkleTreePath should be nil for non-existent shard") + suite.Assert().Nil(response.ParentFragment, "Parent fragment should be nil for non-existent shard") - suite.T().Log("✓ GetShardProof returns nil MerkleTreePath for non-existent shard") + suite.T().Log("✓ GetShardProof returns nil parent fragment for non-existent shard") } // GetShardProof - Empty Tree (no shards submitted yet) @@ -397,9 +376,9 @@ func (suite *ParentServiceTestSuite) TestGetShardProof_EmptyTree() { response, err := suite.service.GetShardProof(ctx, proofReq) suite.Require().NoError(err, "Should not return error for empty tree") suite.Require().NotNil(response, "Response should not be nil") - suite.Assert().Nil(response.MerkleTreePath, "MerkleTreePath should be nil when no shards submitted") + suite.Assert().Nil(response.ParentFragment, "Parent fragment should be nil when no shards submitted") - suite.T().Log("✓ GetShardProof 
returns nil MerkleTreePath for empty tree") + suite.T().Log("✓ GetShardProof returns nil parent fragment for empty tree") } // GetShardProof - Multiple Shards (verify each has correct proof) @@ -437,21 +416,33 @@ func (suite *ParentServiceTestSuite) TestGetShardProof_MultipleShards() { // Get proofs for all 3 shards proof0, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shard0ID}) suite.Require().NoError(err, "Should get proof for shard 0") - suite.Assert().NotNil(proof0.MerkleTreePath, "Proof 0 should not be nil") + suite.Assert().NotNil(proof0.ParentFragment, "Fragment 0 should not be nil") proof1, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shard1ID}) suite.Require().NoError(err, "Should get proof for shard 1") - suite.Assert().NotNil(proof1.MerkleTreePath, "Proof 1 should not be nil") + suite.Assert().NotNil(proof1.ParentFragment, "Fragment 1 should not be nil") proof2, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shard2ID}) suite.Require().NoError(err, "Should get proof for shard 2") - suite.Assert().NotNil(proof2.MerkleTreePath, "Proof 2 should not be nil") - - // All proofs should have the same root (same parent SMT) - suite.Assert().Equal(proof0.MerkleTreePath.Root, proof1.MerkleTreePath.Root, "All proofs should have same root") - suite.Assert().Equal(proof0.MerkleTreePath.Root, proof2.MerkleTreePath.Root, "All proofs should have same root") - - suite.T().Log("✓ GetShardProof returns valid proofs for multiple shards with same root") + suite.Assert().NotNil(proof2.ParentFragment, "Fragment 2 should not be nil") + + suite.Assert().True((&api.RootShardInclusionProof{ + ParentFragment: proof0.ParentFragment, + UnicityCertificate: proof0.UnicityCertificate, + BlockNumber: proof0.BlockNumber, + }).IsValid(shard0ID, suite.cfg.Sharding.ShardIDLength, makeTestHash(0xAA))) + suite.Assert().True((&api.RootShardInclusionProof{ + ParentFragment: proof1.ParentFragment, + 
UnicityCertificate: proof1.UnicityCertificate, + BlockNumber: proof1.BlockNumber, + }).IsValid(shard1ID, suite.cfg.Sharding.ShardIDLength, makeTestHash(0xBB))) + suite.Assert().True((&api.RootShardInclusionProof{ + ParentFragment: proof2.ParentFragment, + UnicityCertificate: proof2.UnicityCertificate, + BlockNumber: proof2.BlockNumber, + }).IsValid(shard2ID, suite.cfg.Sharding.ShardIDLength, makeTestHash(0xCC))) + + suite.T().Log("✓ GetShardProof returns valid parent fragments for multiple shards") } // TestParentServiceSuite runs the test suite diff --git a/internal/service/service.go b/internal/service/service.go index 7bae978..56f75df 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -14,7 +14,6 @@ import ( "github.com/unicitynetwork/aggregator-go/internal/round" "github.com/unicitynetwork/aggregator-go/internal/signing" signingV1 "github.com/unicitynetwork/aggregator-go/internal/signing/v1" - "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" "github.com/unicitynetwork/aggregator-go/internal/trustbase" "github.com/unicitynetwork/aggregator-go/pkg/api" @@ -126,17 +125,18 @@ func modelToAPIAggregatorRecordV1(modelRecord *models.AggregatorRecord) *api.Agg func modelToAPIBlock(modelBlock *models.Block) *api.Block { return &api.Block{ - Index: modelBlock.Index, - ChainID: modelBlock.ChainID, - ShardID: modelBlock.ShardID, - Version: modelBlock.Version, - ForkID: modelBlock.ForkID, - RootHash: modelBlock.RootHash, - PreviousBlockHash: modelBlock.PreviousBlockHash, - NoDeletionProofHash: modelBlock.NoDeletionProofHash, - CreatedAt: modelBlock.CreatedAt, - UnicityCertificate: modelBlock.UnicityCertificate, - ParentMerkleTreePath: modelBlock.ParentMerkleTreePath, + Index: modelBlock.Index, + ChainID: modelBlock.ChainID, + ShardID: modelBlock.ShardID, + Version: modelBlock.Version, + ForkID: modelBlock.ForkID, + RootHash: modelBlock.RootHash, + PreviousBlockHash: 
modelBlock.PreviousBlockHash, + NoDeletionProofHash: modelBlock.NoDeletionProofHash, + CreatedAt: modelBlock.CreatedAt, + UnicityCertificate: modelBlock.UnicityCertificate, + ParentFragment: modelBlock.ParentFragment, + ParentBlockNumber: modelBlock.ParentBlockNumber, } } @@ -168,7 +168,7 @@ func NewAggregatorService(cfg *config.Config, leaderSelector: leaderSelector, commitmentValidator: signingV1.NewCommitmentValidator(cfg.Sharding), certificationRequestValidator: signing.NewCertificationRequestValidator(cfg.Sharding), - trustBaseValidator: trustbase.NewTrustBaseValidator(storage.TrustBaseStorage()), + trustBaseValidator: trustbase.NewTrustBaseValidator(storage.TrustBaseStorage()), receiptSigner: receiptSigner, } } @@ -326,6 +326,10 @@ func (as *AggregatorService) GetInclusionProofV1(ctx context.Context, req *api.G unlock := as.roundManager.FinalizationReadLock() defer unlock() + if as.config.Sharding.Mode == config.ShardingModeChild { + return nil, fmt.Errorf("legacy inclusion proof v1 is not supported in child mode") + } + // verify that the request ID matches the shard ID of this aggregator if err := as.commitmentValidator.ValidateShardID(req.RequestID); err != nil { return nil, fmt.Errorf("request ID validation failed: %w", err) @@ -349,25 +353,15 @@ func (as *AggregatorService) GetInclusionProofV1(ctx context.Context, req *api.G return nil, fmt.Errorf("failed to get inclusion proof for request ID %s: %w", req.RequestID, err) } - // Find the latest block that matches the current SMT root hash - rootHash, err := api.NewHexBytesFromString(merkleTreePath.Root) - if err != nil { - return nil, fmt.Errorf("failed to parse root hash: %w", err) - } + // Find the latest block that matches the current SMT root hash. Blocks + // are stored under the raw 32-byte root matching UC.IR.h. 
+ rootHash := api.HexBytes(smtInstance.GetRootHashRaw()) block, err := as.storage.BlockStorage().GetLatestByRootHash(ctx, rootHash) if err != nil { return nil, fmt.Errorf("failed to get latest block by root hash: %w", err) } if block == nil { - return nil, fmt.Errorf("no block found with root hash %s", rootHash) - } - - // Join parent and child SMT paths if sharding mode is enabled - if as.config.Sharding.Mode == config.ShardingModeChild { - merkleTreePath, err = smt.JoinPaths(merkleTreePath, block.ParentMerkleTreePath) - if err != nil { - return nil, fmt.Errorf("failed to join parent and child aggregator paths: %w", err) - } + return nil, fmt.Errorf("no block found with root hash %s", rootHash.String()) } // Check if commitment exists in aggregator records (finalized) @@ -405,67 +399,63 @@ func (as *AggregatorService) GetInclusionProofV1(ctx context.Context, req *api.G }, nil } -// GetInclusionProofV2 retrieves inclusion proof for a commitment +// GetInclusionProofV2 retrieves a v2 inclusion proof for the given stateId. +// Both standalone and child mode serve proofs against the current certified +// SMT root. In child mode, the local child cert is composed with the stored +// parent fragment and bound to the parent UC.IR.h. func (as *AggregatorService) GetInclusionProofV2(ctx context.Context, req *api.GetInclusionProofRequestV2) (*api.GetInclusionProofResponseV2, error) { unlock := as.roundManager.FinalizationReadLock() defer unlock() - // verify that the state ID matches the shard ID of this aggregator if err := as.certificationRequestValidator.ValidateShardID(req.StateID); err != nil { return nil, fmt.Errorf("state ID validation failed: %w", err) } - path, err := req.StateID.GetPath() + // v2 stateId must be exactly 32 bytes. 
+ if len(req.StateID) != api.StateTreeKeyLengthBytes { + return nil, fmt.Errorf("invalid state ID length: expected %d bytes (v2 wire format), got %d", + api.StateTreeKeyLengthBytes, len(req.StateID)) + } + key, err := req.StateID.GetTreeKey() if err != nil { - return nil, fmt.Errorf("failed to get path for state ID %s: %w", req.StateID, err) + return nil, fmt.Errorf("invalid state ID: %w", err) } smtInstance := as.roundManager.GetSMT() if smtInstance == nil { return nil, fmt.Errorf("merkle tree not initialized") } - if keyLen := smtInstance.GetKeyLength(); path.BitLen()-1 != keyLen { - return nil, fmt.Errorf("request path length %d does not match SMT key length %d", path.BitLen()-1, keyLen) - } - - merkleTreePath, err := as.roundManager.GetSMT().GetPath(path) - if err != nil { - return nil, fmt.Errorf("failed to get inclusion proof for state ID %s: %w", req.StateID, err) + if keyLen := smtInstance.GetKeyLength(); keyLen != api.StateTreeKeyLengthBits { + return nil, fmt.Errorf("unexpected SMT key length: got %d bits, want %d", keyLen, api.StateTreeKeyLengthBits) } - // Find the latest block that matches the current SMT root hash - rootHash, err := api.NewHexBytesFromString(merkleTreePath.Root) - if err != nil { - return nil, fmt.Errorf("failed to parse root hash: %w", err) - } - block, err := as.storage.BlockStorage().GetLatestByRootHash(ctx, rootHash) + // Bind the UC via the block whose stored rootHash matches the current + // raw 32-byte SMT root (which also lives in UC.IR.h). 
+ rootHashRaw := api.HexBytes(smtInstance.GetRootHashRaw()) + block, err := as.storage.BlockStorage().GetLatestByRootHash(ctx, rootHashRaw) if err != nil { return nil, fmt.Errorf("failed to get latest block by root hash: %w", err) } if block == nil { - return nil, fmt.Errorf("no block found with root hash %s", rootHash) + return nil, fmt.Errorf("no block found with root hash %s", rootHashRaw.String()) } - - // Join parent and child SMT paths if sharding mode is enabled - if as.config.Sharding.Mode == config.ShardingModeChild { - merkleTreePath, err = smt.JoinPaths(merkleTreePath, block.ParentMerkleTreePath) - if err != nil { - return nil, fmt.Errorf("failed to join parent and child aggregator paths: %w", err) - } + responseBlockNumber, err := proofBundleBlockNumber(as.config.Sharding.Mode, block) + if err != nil { + return nil, err } - // Check if certification request exists in aggregator records (finalized) record, err := as.storage.AggregatorRecordStorage().GetByStateID(ctx, req.StateID) if err != nil { return nil, fmt.Errorf("failed to get aggregator record: %w", err) } if record == nil || record.BlockNumber.Cmp(block.Index.Int) > 0 { - // Non-inclusion proof: either record doesn't exist, or it belongs to a - // newer block whose SMT state hasn't been committed yet. + // Non-inclusion is not implemented yet. Return an empty v2 proof + // payload so verifiers short-circuit with ErrExclusionNotImpl. 
return &api.GetInclusionProofResponseV2{ + BlockNumber: responseBlockNumber, InclusionProof: &api.InclusionProofV2{ CertificationData: nil, - MerkleTreePath: merkleTreePath, + CertificateBytes: nil, UnicityCertificate: types.RawCBOR(block.UnicityCertificate), }, }, nil @@ -473,16 +463,60 @@ func (as *AggregatorService) GetInclusionProofV2(ctx context.Context, req *api.G if record.Version != 2 { return nil, fmt.Errorf("invalid aggregator record version got %d expected 2", record.Version) } + + childCert, err := smtInstance.GetInclusionCert(key) + if err != nil { + return nil, fmt.Errorf("failed to build inclusion cert for state ID %s: %w", req.StateID, err) + } + + cert := childCert + if as.config.Sharding.Mode == config.ShardingModeChild { + if block.ParentFragment == nil { + return nil, fmt.Errorf("current child block %s is missing parent fragment", block.Index.String()) + } + childRoot := smtInstance.GetRootHashRaw() + cert, err = api.ComposeInclusionCert(block.ParentFragment, childCert, childRoot) + if err != nil { + return nil, fmt.Errorf("failed to compose child and parent inclusion certs: %w", err) + } + } + + certBytes, err := cert.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("failed to marshal inclusion cert: %w", err) + } + + proof := &api.InclusionProofV2{ + CertificationData: record.CertificationData.ToAPI(), + CertificateBytes: certBytes, + UnicityCertificate: types.RawCBOR(block.UnicityCertificate), + } + if err := proof.Verify(&api.CertificationRequest{ + StateID: req.StateID, + CertificationData: *record.CertificationData.ToAPI(), + }); err != nil { + return nil, fmt.Errorf("generated inclusion proof failed self-verification: %w", err) + } + return &api.GetInclusionProofResponseV2{ - BlockNumber: record.BlockNumber.Uint64(), - InclusionProof: &api.InclusionProofV2{ - CertificationData: record.CertificationData.ToAPI(), - MerkleTreePath: merkleTreePath, - UnicityCertificate: types.RawCBOR(block.UnicityCertificate), - }, + BlockNumber: 
responseBlockNumber, + InclusionProof: proof, }, nil } +func proofBundleBlockNumber(mode config.ShardingMode, block *models.Block) (uint64, error) { + if block == nil { + return 0, fmt.Errorf("missing block for proof bundle") + } + if mode != config.ShardingModeChild { + return block.Index.Uint64(), nil + } + if block.ParentBlockNumber == 0 { + return 0, fmt.Errorf("current child block %s is missing parent block number", block.Index.String()) + } + return block.ParentBlockNumber, nil +} + // GetNoDeletionProof retrieves the global no-deletion proof func (as *AggregatorService) GetNoDeletionProof(ctx context.Context) (*api.GetNoDeletionProofResponse, error) { // TODO: Implement no-deletion proof generation diff --git a/internal/service/service_test.go b/internal/service/service_test.go index aad94d4..a6c7c7d 100644 --- a/internal/service/service_test.go +++ b/internal/service/service_test.go @@ -3,8 +3,10 @@ package service import ( "bytes" "context" + "encoding/hex" "encoding/json" "fmt" + "math/big" "net" "net/http" "net/url" @@ -21,17 +23,21 @@ import ( "github.com/stretchr/testify/suite" "github.com/testcontainers/testcontainers-go/modules/mongodb" redisContainer "github.com/testcontainers/testcontainers-go/modules/redis" + bfttypes "github.com/unicitynetwork/bft-go-base/types" + bfthex "github.com/unicitynetwork/bft-go-base/types/hex" "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/gateway" "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" + "github.com/unicitynetwork/aggregator-go/internal/models" "github.com/unicitynetwork/aggregator-go/internal/round" "github.com/unicitynetwork/aggregator-go/internal/sharding" "github.com/unicitynetwork/aggregator-go/internal/signing" "github.com/unicitynetwork/aggregator-go/internal/smt" 
"github.com/unicitynetwork/aggregator-go/internal/storage" + "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" "github.com/unicitynetwork/aggregator-go/pkg/api" "github.com/unicitynetwork/aggregator-go/pkg/jsonrpc" ) @@ -129,7 +135,7 @@ func setupMongoDBAndAggregator(t *testing.T, ctx context.Context) (string, func( // Initialize round manager rootAggregatorClient := sharding.NewRootAggregatorClientStub() - roundManager, err := round.NewRoundManager(ctx, cfg, log, commitmentQueue, mongoStorage, rootAggregatorClient, state.NewSyncStateTracker(), nil, events.NewEventBus(log), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)), nil) + roundManager, err := round.NewRoundManager(ctx, cfg, log, commitmentQueue, mongoStorage, rootAggregatorClient, state.NewSyncStateTracker(), nil, events.NewEventBus(log), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)), nil) require.NoError(t, err) // Start the round manager (restores SMT) @@ -226,60 +232,31 @@ func makeJSONRPCRequest[T any](t *testing.T, serverAddr, method, requestID strin return result } -// Helper function to validate inclusion proof structure and encoding -func validateInclusionProof(t *testing.T, proof *api.InclusionProofV2, stateID api.StateID) { +// validateInclusionProof validates the v2 inclusion proof wire structure +// and performs end-to-end verification against the originating +// CertificationRequest via the new InclusionCert + UC.IR.h binding. 
+func validateInclusionProof(t *testing.T, proof *api.InclusionProofV2, req *api.CertificationRequest) { assert.NotNil(t, proof.CertificationData, "Should have certification data") - assert.NotNil(t, proof.MerkleTreePath, "Should have merkle tree path") - - // Validate unicity certificate field - if len(proof.UnicityCertificate) > 0 { - assert.NotEmpty(t, proof.UnicityCertificate, "Unicity certificate should not be empty") - } - - // Validate certification data encoding - if proof.CertificationData != nil { - assert.NotEmpty(t, proof.CertificationData.OwnerPredicate, "CertificationData should have owner predicate") - assert.NotEmpty(t, proof.CertificationData.Witness, "CertificationData should have signature") - assert.NotEmpty(t, proof.CertificationData.SourceStateHash, "CertificationData should have source state hash") - assert.NotEmpty(t, proof.CertificationData.TransactionHash, "CertificationData should have transaction hash") - - // Verify CBOR encoding of certification data - certDataBytes, err := cbor.Marshal(proof.CertificationData) - require.NoError(t, err, "CertificationData should be CBOR encodable") - assert.NotEmpty(t, certDataBytes, "CBOR encoded certification data should not be empty") - - // Verify we can decode it back - var decodedAuth api.CertificationData - err = cbor.Unmarshal(certDataBytes, &decodedAuth) - require.NoError(t, err, "Should be able to decode CBOR certification data") - } - - // Validate merkle tree path encoding - if proof.MerkleTreePath != nil { - assert.NotEmpty(t, proof.MerkleTreePath.Root, "Merkle path should have root") - assert.NotNil(t, proof.MerkleTreePath.Steps, "Merkle path should have steps") - - // Verify Merkle tree path with state ID - stateIDBigInt, err := stateID.GetPath() - require.NoError(t, err, "Should be able to get path from stateID") - - verificationResult, err := proof.MerkleTreePath.Verify(stateIDBigInt) - require.NoError(t, err, "Merkle tree path verification should not error") - assert.True(t, 
verificationResult.PathValid, "Merkle tree path should be valid") - assert.True(t, verificationResult.PathIncluded, "Request should be included in the Merkle tree") - assert.True(t, verificationResult.Result, "Overall verification result should be true") - - // Verify CBOR encoding of merkle tree path - pathBytes, err := cbor.Marshal(proof.MerkleTreePath) - require.NoError(t, err, "Merkle tree path should be CBOR encodable") - assert.NotEmpty(t, pathBytes, "CBOR encoded merkle path should not be empty") - - // Verify we can decode it back - var decodedPath api.MerkleTreePath - err = cbor.Unmarshal(pathBytes, &decodedPath) - require.NoError(t, err, "Should be able to decode CBOR merkle path") - assert.Equal(t, proof.MerkleTreePath.Root, decodedPath.Root, "Decoded merkle path should match original") - } + assert.NotEmpty(t, proof.CertificateBytes, "Should have inclusion cert bytes") + assert.NotEmpty(t, proof.UnicityCertificate, "Unicity certificate should not be empty") + + assert.NotEmpty(t, proof.CertificationData.OwnerPredicate, "CertificationData should have owner predicate") + assert.NotEmpty(t, proof.CertificationData.Witness, "CertificationData should have signature") + assert.NotEmpty(t, proof.CertificationData.SourceStateHash, "CertificationData should have source state hash") + assert.NotEmpty(t, proof.CertificationData.TransactionHash, "CertificationData should have transaction hash") + + // Verify CBOR round-trip of certification data. + certDataBytes, err := cbor.Marshal(proof.CertificationData) + require.NoError(t, err, "CertificationData should be CBOR encodable") + assert.NotEmpty(t, certDataBytes, "CBOR encoded certification data should not be empty") + var decodedAuth api.CertificationData + require.NoError(t, cbor.Unmarshal(certDataBytes, &decodedAuth)) + + // Wire-decode the inclusion certificate and perform full v2 verification. 
+ var cert api.InclusionCert + require.NoError(t, cert.UnmarshalBinary(proof.CertificateBytes), "InclusionCert must decode") + + require.NoError(t, proof.Verify(req), "v2 inclusion proof must verify") } // TestInclusionProofMissingRecord tests getting inclusion proof for non-existent record @@ -294,11 +271,9 @@ func (suite *AggregatorTestSuite) TestInclusionProofMissingRecord() { // Wait for block processing time.Sleep(3 * time.Second) - // Now test non-inclusion proof for a different state ID - stateId := "" - for i := 0; i < 2+32; i++ { - stateId = stateId + "00" - } + // Now test non-inclusion proof for a raw 32-byte v2 state ID that has + // never been submitted. + stateId := strings.Repeat("00", api.StateTreeKeyLengthBytes) inclusionProof := makeJSONRPCRequest[api.GetInclusionProofResponseV2](suite.T(), suite.serverAddr, "get_inclusion_proof.v2", @@ -306,13 +281,12 @@ func (suite *AggregatorTestSuite) TestInclusionProofMissingRecord() { &api.GetInclusionProofRequestV2{StateID: api.RequireNewImprintV2(stateId)}, ) - // Validate non-inclusion proof structure + // v2 non-inclusion: ExclusionCert wire path is not yet implemented. The + // service returns an empty-cert payload with CertificationData == nil + // plus a bound UnicityCertificate for round-trip. 
suite.Nil(inclusionProof.InclusionProof.CertificationData) - suite.NotNil(inclusionProof.InclusionProof.MerkleTreePath) - - // Verify that UnicityCertificate is included in non-inclusion proof - suite.NotNil(inclusionProof.InclusionProof.UnicityCertificate, "Non-inclusion proof should include UnicityCertificate") - suite.NotEmpty(inclusionProof.InclusionProof.UnicityCertificate, "UnicityCertificate should not be empty") + suite.Empty(inclusionProof.InclusionProof.CertificateBytes) + suite.NotEmpty(inclusionProof.InclusionProof.UnicityCertificate, "Non-inclusion proof should include UnicityCertificate") } func TestGetInclusionProofShardMismatch(t *testing.T) { @@ -322,10 +296,11 @@ func TestGetInclusionProofShardMismatch(t *testing.T) { ShardID: 4, }, } - tree := smt.NewChildSparseMerkleTree(api.SHA256, 16+256, shardingCfg.Child.ShardID) + tree := smt.NewChildSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits, shardingCfg.Child.ShardID) service := newAggregatorServiceForTest(t, shardingCfg, tree) - invalidShardID := api.RequireNewImprintV2(strings.Repeat("00", 33) + "01") + // Raw 32-byte v2 stateId whose shard-prefix bits don't match shard 4 (=0b100). 
+ invalidShardID := api.RequireNewImprintV2("01" + strings.Repeat("00", api.StateTreeKeyLengthBytes-1)) _, err := service.GetInclusionProofV2(context.Background(), &api.GetInclusionProofRequestV2{StateID: invalidShardID}) require.Error(t, err) assert.Contains(t, err.Error(), "state ID validation failed") @@ -338,7 +313,7 @@ func TestGetInclusionProofInvalidRequestFormat(t *testing.T) { ShardID: 4, }, } - tree := smt.NewChildSparseMerkleTree(api.SHA256, 16+256, shardingCfg.Child.ShardID) + tree := smt.NewChildSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits, shardingCfg.Child.ShardID) service := newAggregatorServiceForTest(t, shardingCfg, tree) _, err := service.GetInclusionProofV2(context.Background(), &api.GetInclusionProofRequestV2{StateID: api.ImprintV2([]byte("zz"))}) @@ -355,7 +330,8 @@ func TestGetInclusionProofSMTUnavailable(t *testing.T) { } service := newAggregatorServiceForTest(t, shardingCfg, nil) - validID := api.RequireNewImprintV2(strings.Repeat("00", 34)) + // Raw 32-byte v2 stateId; all-zero shard-prefix bits match shard 4 (expected = 0b00). 
+ validID := api.RequireNewImprintV2(strings.Repeat("00", api.StateTreeKeyLengthBytes)) _, err := service.GetInclusionProofV2(t.Context(), &api.GetInclusionProofRequestV2{StateID: validID}) require.Error(t, err) assert.Contains(t, err.Error(), "merkle tree not initialized") @@ -365,50 +341,172 @@ func TestInclusionProofInvalidPathLength(t *testing.T) { shardingCfg := config.ShardingConfig{ Mode: config.ShardingModeStandalone, } - tree := smt.NewSparseMerkleTree(api.SHA256, 16+256) + tree := smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) service := newAggregatorServiceForTest(t, shardingCfg, tree) - validID := createTestCertificationRequests(t, 1)[0].StateID.String() - require.Greater(t, len(validID), 2) - badID := api.RequireNewImprintV2(validID[2:]) + // Drop one byte from the canonical 32-byte stateId — v2 strict enforcement + // must reject it at length validation, before any SMT traversal. + validID := createTestCertificationRequests(t, 1)[0].StateID + require.Len(t, validID, api.StateTreeKeyLengthBytes) + badID := api.ImprintV2(append([]byte(nil), validID[:len(validID)-1]...)) _, err := service.GetInclusionProofV2(context.Background(), &api.GetInclusionProofRequestV2{StateID: badID}) require.Error(t, err) - assert.Contains(t, err.Error(), "path length") + assert.Contains(t, err.Error(), "invalid state ID length") +} + +func TestGetInclusionProofV2Child_ComposesParentFragment(t *testing.T) { + shardingCfg := config.ShardingConfig{ + Mode: config.ShardingModeChild, + ShardIDLength: 2, + Child: config.ChildConfig{ + ShardID: 4, // shard prefix bits 00 + }, + } + + stateID := api.RequireNewImprintV2(strings.Repeat("00", api.StateTreeKeyLengthBytes)) + sourceStateHash := api.RequireNewImprintV2("10" + strings.Repeat("00", api.StateTreeKeyLengthBytes-1)) + transactionHash := api.RequireNewImprintV2("20" + strings.Repeat("00", api.StateTreeKeyLengthBytes-1)) + + childTree := smt.NewChildSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits, 
shardingCfg.Child.ShardID) + path, err := stateID.GetPath() + require.NoError(t, err) + require.NoError(t, childTree.AddLeaf(path, transactionHash.DataBytes())) + childRoot := childTree.GetRootHashRaw() + + parentTree := smt.NewParentSparseMerkleTree(api.SHA256, shardingCfg.ShardIDLength) + require.NoError(t, smt.NewThreadSafeSMT(parentTree).AddPreHashedLeaf(big.NewInt(int64(shardingCfg.Child.ShardID)), childRoot)) + parentRoot := parentTree.GetRootHashRaw() + parentFragment, err := parentTree.GetShardInclusionFragment(shardingCfg.Child.ShardID) + require.NoError(t, err) + require.NotNil(t, parentFragment) + + parentUC := testChildProofUC(t, 9, parentRoot) + block := models.NewBlock( + api.NewBigIntFromUint64(3), + "test-chain", + shardingCfg.Child.ShardID, + "v", + "f", + api.HexBytes(childRoot), + nil, + parentUC, + ) + block.ParentFragment = parentFragment + block.ParentBlockNumber = 9 + block.Finalized = true + + record := &models.AggregatorRecord{ + Version: 2, + StateID: stateID, + CertificationData: models.CertificationData{ + OwnerPredicate: api.Predicate{Engine: 1, Code: []byte{0x01}, Params: []byte{0x02}}, + SourceStateHash: sourceStateHash, + TransactionHash: transactionHash, + Witness: []byte{0x01, 0x02}, + }, + BlockNumber: api.NewBigIntFromUint64(1), + LeafIndex: api.NewBigIntFromUint64(0), + CreatedAt: api.Now(), + FinalizedAt: api.Now(), + } + + service := newAggregatorServiceForTest(t, shardingCfg, childTree) + service.storage = &testStorage{ + blockStorage: &testBlockStorage{latestByRoot: map[string]*models.Block{block.RootHash.String(): block}}, + recordStorage: &testAggregatorRecordStorage{ + byStateID: map[string]*models.AggregatorRecord{stateID.String(): record}, + }, + } + + resp, err := service.GetInclusionProofV2(context.Background(), &api.GetInclusionProofRequestV2{StateID: stateID}) + require.NoError(t, err) + require.Equal(t, uint64(9), resp.BlockNumber) + require.NotNil(t, resp.InclusionProof) + require.Equal(t, parentUC, 
api.HexBytes(resp.InclusionProof.UnicityCertificate)) + + req := &api.CertificationRequest{ + StateID: stateID, + CertificationData: api.CertificationData{ + OwnerPredicate: record.CertificationData.OwnerPredicate, + SourceStateHash: record.CertificationData.SourceStateHash, + TransactionHash: record.CertificationData.TransactionHash, + Witness: record.CertificationData.Witness, + }, + } + validateInclusionProof(t, resp.InclusionProof, req) +} + +func TestGetInclusionProofV2Child_NonInclusionUsesParentBundleMetadata(t *testing.T) { + shardingCfg := config.ShardingConfig{ + Mode: config.ShardingModeChild, + ShardIDLength: 2, + Child: config.ChildConfig{ + ShardID: 4, + }, + } + + childTree := smt.NewChildSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits, shardingCfg.Child.ShardID) + parentUC := testChildProofUC(t, 12, childTree.GetRootHashRaw()) + block := models.NewBlock( + api.NewBigIntFromUint64(4), + "test-chain", + shardingCfg.Child.ShardID, + "v", + "f", + api.HexBytes(childTree.GetRootHashRaw()), + nil, + parentUC, + ) + block.ParentBlockNumber = 12 + block.Finalized = true + + service := newAggregatorServiceForTest(t, shardingCfg, childTree) + service.storage = &testStorage{ + blockStorage: &testBlockStorage{latestByRoot: map[string]*models.Block{block.RootHash.String(): block}}, + recordStorage: &testAggregatorRecordStorage{byStateID: map[string]*models.AggregatorRecord{}}, + } + + stateID := api.RequireNewImprintV2(strings.Repeat("00", api.StateTreeKeyLengthBytes)) + resp, err := service.GetInclusionProofV2(context.Background(), &api.GetInclusionProofRequestV2{StateID: stateID}) + require.NoError(t, err) + require.Equal(t, uint64(12), resp.BlockNumber) + require.Nil(t, resp.InclusionProof.CertificationData) + require.Empty(t, resp.InclusionProof.CertificateBytes) + require.Equal(t, parentUC, api.HexBytes(resp.InclusionProof.UnicityCertificate)) } // TestInclusionProof tests the complete inclusion proof workflow func (suite *AggregatorTestSuite) 
TestInclusionProof() { // 1) Send commitments testRequests := createTestCertificationRequests(suite.T(), 3) - var submittedStateIDs []api.StateID for i, req := range testRequests { submitResponse := makeJSONRPCRequest[api.CertificationResponse]( suite.T(), suite.serverAddr, "certification_request", fmt.Sprintf("submit-%d", i), req) suite.Equal("SUCCESS", submitResponse.Status, "Should return SUCCESS status") - submittedStateIDs = append(submittedStateIDs, req.StateID) } // Wait for block processing time.Sleep(3 * time.Second) - // 2) Verify inclusion proofs and store root hashes - firstBatchRootHashes := make(map[string]string) // stateID => rootHash + // 2) Verify inclusion proofs and store bound UC roots for the stability + // check below. With the v2 wire, the authoritative root identity is + // UC.IR.h, sourced from the proof's UnicityCertificate. + firstBatchRoots := make(map[string]string) // stateID => hex(UC.IR.h) - for _, stateID := range submittedStateIDs { - proofRequest := &api.GetInclusionProofRequestV2{StateID: stateID} + for _, req := range testRequests { + proofRequest := &api.GetInclusionProofRequestV2{StateID: req.StateID} proofResponse := makeJSONRPCRequest[api.GetInclusionProofResponseV2]( suite.T(), suite.serverAddr, "get_inclusion_proof.v2", "get-proof", proofRequest) - // Validate inclusion proof structure and encoding - validateInclusionProof(suite.T(), proofResponse.InclusionProof, stateID) + // Validate inclusion proof structure and encoding + end-to-end verify. 
+ validateInclusionProof(suite.T(), proofResponse.InclusionProof, req) - // Store root hash for later stability check - suite.Require().NotNil(proofResponse.InclusionProof.MerkleTreePath) - rootHash := proofResponse.InclusionProof.MerkleTreePath.Root - firstBatchRootHashes[stateID.String()] = rootHash + rootRaw, err := proofResponse.InclusionProof.UCInputRecordHashRaw() + suite.Require().NoError(err) + firstBatchRoots[req.StateID.String()] = hex.EncodeToString(rootRaw) } // 3) Send more requests @@ -422,24 +520,25 @@ func (suite *AggregatorTestSuite) TestInclusionProof() { // Wait for new block processing time.Sleep(3 * time.Second) - // 4) Verify original commitments now reference current root hash - for _, stateID := range submittedStateIDs { - proofRequest := &api.GetInclusionProofRequestV2{StateID: stateID} + // 4) Verify original commitments now reference current root hash (UC.IR.h). + for _, req := range testRequests { + proofRequest := &api.GetInclusionProofRequestV2{StateID: req.StateID} proofResponse := makeJSONRPCRequest[api.GetInclusionProofResponseV2]( suite.T(), suite.serverAddr, "get_inclusion_proof.v2", "stability-check", proofRequest) - suite.Require().NotNil(proofResponse.InclusionProof.MerkleTreePath) - currentRootHash := proofResponse.InclusionProof.MerkleTreePath.Root - originalRootHash, exists := firstBatchRootHashes[stateID.String()] + currentRootRaw, err := proofResponse.InclusionProof.UCInputRecordHashRaw() + suite.Require().NoError(err) + currentRootHash := hex.EncodeToString(currentRootRaw) - suite.True(exists, "Should have stored original root hash for %s", stateID) + originalRootHash, exists := firstBatchRoots[req.StateID.String()] + suite.True(exists, "Should have stored original root hash for %s", req.StateID) // With on-demand proof generation, the root hash should now be different (current SMT state) suite.NotEqual(originalRootHash, currentRootHash, "Root hash in inclusion proof should be current (not original) for StateID %s. 
Original: %s, Current: %s", - stateID, originalRootHash, currentRootHash) + req.StateID, originalRootHash, currentRootHash) // Validate that the proof is still valid for the commitment - validateInclusionProof(suite.T(), proofResponse.InclusionProof, stateID) + validateInclusionProof(suite.T(), proofResponse.InclusionProof, req) } } @@ -521,3 +620,88 @@ func newAggregatorServiceForTest(t *testing.T, shardingCfg config.ShardingConfig certificationRequestValidator: signing.NewCertificationRequestValidator(shardingCfg), } } + +type testStorage struct { + blockStorage interfaces.BlockStorage + recordStorage interfaces.AggregatorRecordStorage +} + +func (s *testStorage) AggregatorRecordStorage() interfaces.AggregatorRecordStorage { + return s.recordStorage +} +func (s *testStorage) BlockStorage() interfaces.BlockStorage { return s.blockStorage } +func (s *testStorage) SmtStorage() interfaces.SmtStorage { return nil } +func (s *testStorage) BlockRecordsStorage() interfaces.BlockRecordsStorage { return nil } +func (s *testStorage) LeadershipStorage() interfaces.LeadershipStorage { return nil } +func (s *testStorage) TrustBaseStorage() interfaces.TrustBaseStorage { return nil } +func (s *testStorage) Initialize(context.Context) error { return nil } +func (s *testStorage) Ping(context.Context) error { return nil } +func (s *testStorage) Close(context.Context) error { return nil } +func (s *testStorage) WithTransaction(ctx context.Context, fn func(context.Context) error) error { + return fn(ctx) +} + +type testBlockStorage struct { + latestByRoot map[string]*models.Block +} + +func (s *testBlockStorage) Store(context.Context, *models.Block) error { return nil } +func (s *testBlockStorage) GetByNumber(context.Context, *api.BigInt) (*models.Block, error) { + return nil, nil +} +func (s *testBlockStorage) GetLatest(context.Context) (*models.Block, error) { return nil, nil } +func (s *testBlockStorage) GetLatestNumber(context.Context) (*api.BigInt, error) { return nil, nil } 
+func (s *testBlockStorage) Count(context.Context) (int64, error) { return 0, nil } +func (s *testBlockStorage) GetRange(context.Context, *api.BigInt, *api.BigInt) ([]*models.Block, error) { + return nil, nil +} +func (s *testBlockStorage) SetFinalized(context.Context, *api.BigInt, bool) error { return nil } +func (s *testBlockStorage) GetUnfinalized(context.Context) ([]*models.Block, error) { + return nil, nil +} +func (s *testBlockStorage) GetLatestByRootHash(ctx context.Context, rootHash api.HexBytes) (*models.Block, error) { + if s == nil { + return nil, nil + } + return s.latestByRoot[rootHash.String()], nil +} + +type testAggregatorRecordStorage struct { + byStateID map[string]*models.AggregatorRecord +} + +func (s *testAggregatorRecordStorage) Store(context.Context, *models.AggregatorRecord) error { + return nil +} +func (s *testAggregatorRecordStorage) StoreBatch(context.Context, []*models.AggregatorRecord) error { + return nil +} +func (s *testAggregatorRecordStorage) GetByBlockNumber(context.Context, *api.BigInt) ([]*models.AggregatorRecord, error) { + return nil, nil +} +func (s *testAggregatorRecordStorage) Count(context.Context) (int64, error) { return 0, nil } +func (s *testAggregatorRecordStorage) GetExistingRequestIDs(context.Context, []string) (map[string]bool, error) { + return nil, nil +} +func (s *testAggregatorRecordStorage) GetByStateID(ctx context.Context, stateID api.StateID) (*models.AggregatorRecord, error) { + if s == nil { + return nil, nil + } + return s.byStateID[stateID.String()], nil +} + +func testChildProofUC(t *testing.T, roundNumber uint64, rootHash []byte) api.HexBytes { + t.Helper() + uc := bfttypes.UnicityCertificate{ + InputRecord: &bfttypes.InputRecord{ + RoundNumber: roundNumber, + Hash: bfthex.Bytes(rootHash), + }, + UnicitySeal: &bfttypes.UnicitySeal{ + RootChainRoundNumber: roundNumber, + }, + } + ucBytes, err := bfttypes.Cbor.Marshal(uc) + require.NoError(t, err) + return api.NewHexBytes(ucBytes) +} diff --git 
a/internal/service/v2_surface_test.go b/internal/service/v2_surface_test.go new file mode 100644 index 0000000..5e1cf02 --- /dev/null +++ b/internal/service/v2_surface_test.go @@ -0,0 +1,42 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/unicitynetwork/aggregator-go/pkg/jsonrpc" +) + +func TestRemovedV1MethodsReturnMethodNotFound(t *testing.T) { + serverAddr, cleanup := setupMongoDBAndAggregator(t, context.Background()) + defer cleanup() + + for _, method := range []string{ + "submit_commitment", + "get_inclusion_proof", + "get_block_commitments", + } { + t.Run(method, func(t *testing.T) { + request, err := jsonrpc.NewRequest(method, map[string]any{}, method) + require.NoError(t, err) + + body, err := json.Marshal(request) + require.NoError(t, err) + + httpResp, err := http.Post(serverAddr, "application/json", bytes.NewReader(body)) + require.NoError(t, err) + defer httpResp.Body.Close() + + var response jsonrpc.Response + require.NoError(t, json.NewDecoder(httpResp.Body).Decode(&response)) + require.NotNil(t, response.Error) + require.Equal(t, jsonrpc.MethodNotFoundCode, response.Error.Code) + require.Equal(t, jsonrpc.ErrMethodNotFound.Message, response.Error.Message) + }) + } +} diff --git a/internal/sharding/root_aggregator_client_stub.go b/internal/sharding/root_aggregator_client_stub.go index e1f5e13..b304f9d 100644 --- a/internal/sharding/root_aggregator_client_stub.go +++ b/internal/sharding/root_aggregator_client_stub.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/unicitynetwork/bft-go-base/types" + "github.com/unicitynetwork/bft-go-base/types/hex" "github.com/unicitynetwork/aggregator-go/pkg/api" ) @@ -46,16 +47,18 @@ func (m *RootAggregatorClientStub) GetShardProof(ctx context.Context, request *a if m.submissions[request.ShardID] != nil { m.returnedProofCount++ - submittedRootHash := m.submittedRootHash.String() - ucBytes, err := 
stubProofUC(uint64(m.returnedProofCount), uint64(m.returnedProofCount)) + ucBytes, err := stubProofUC(uint64(m.returnedProofCount), uint64(m.returnedProofCount), m.submittedRootHash) if err != nil { return nil, err } + fragment := &api.ParentInclusionFragment{ + CertificateBytes: api.NewHexBytes(make([]byte, api.BitmapSize)), + ShardLeafValue: api.NewHexBytes(m.submittedRootHash), + } return &api.RootShardInclusionProof{ + ParentFragment: fragment, + BlockNumber: uint64(m.returnedProofCount), UnicityCertificate: ucBytes, - MerkleTreePath: &api.MerkleTreePath{ - Steps: []api.MerkleTreeStep{{Data: &submittedRootHash}}, - }, }, nil } return nil, nil @@ -91,10 +94,11 @@ func (m *RootAggregatorClientStub) SetSubmissionError(err error) { m.submissionError = err } -func stubProofUC(parentRound, rootRound uint64) (api.HexBytes, error) { +func stubProofUC(parentRound, rootRound uint64, rootHash api.HexBytes) (api.HexBytes, error) { uc := types.UnicityCertificate{ InputRecord: &types.InputRecord{ RoundNumber: parentRound, + Hash: hex.Bytes(rootHash), }, UnicitySeal: &types.UnicitySeal{ RootChainRoundNumber: rootRound, diff --git a/internal/sharding/root_aggregator_client_test.go b/internal/sharding/root_aggregator_client_test.go index 7a7245f..fba9e86 100644 --- a/internal/sharding/root_aggregator_client_test.go +++ b/internal/sharding/root_aggregator_client_test.go @@ -59,7 +59,11 @@ func TestRootAggregatorClient_GetShardProof(t *testing.T) { require.Equal(t, float64(4), params["shardId"]) proof := &api.RootShardInclusionProof{ - MerkleTreePath: &api.MerkleTreePath{Root: "0x1234"}, + ParentFragment: &api.ParentInclusionFragment{ + CertificateBytes: api.NewHexBytes(make([]byte, api.BitmapSize)), + ShardLeafValue: api.HexBytes("0x1234"), + }, + BlockNumber: 9, UnicityCertificate: api.HexBytes("0xabcdef"), } @@ -80,7 +84,8 @@ func TestRootAggregatorClient_GetShardProof(t *testing.T) { require.NoError(t, err) require.NotNil(t, proof) - require.Equal(t, "0x1234", 
proof.MerkleTreePath.Root) + require.Equal(t, api.HexBytes("0x1234"), proof.ParentFragment.ShardLeafValue) + require.Equal(t, uint64(9), proof.BlockNumber) require.Equal(t, api.HexBytes("0xabcdef"), proof.UnicityCertificate) } diff --git a/internal/signing/certification_request_validator.go b/internal/signing/certification_request_validator.go index a34f7e9..7f38e6b 100644 --- a/internal/signing/certification_request_validator.go +++ b/internal/signing/certification_request_validator.go @@ -1,10 +1,8 @@ package signing import ( - "encoding/hex" "errors" "fmt" - "math/big" "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/models" @@ -139,8 +137,7 @@ func (v *CertificationRequestValidator) Validate(commitment *models.Certificatio if len(transactionHash) != 32 { return ValidationResult{ Status: ValidationStatusInvalidTransactionHashFormat, - Error: fmt.Errorf("transaction hash imprint must have at least 3 bytes (2 algorithm + 1 data), "+ - "got %d", len(transactionHash)), + Error: fmt.Errorf("transaction hash must be exactly 32 bytes, got %d", len(transactionHash)), } } @@ -186,7 +183,7 @@ func (v *CertificationRequestValidator) ValidateShardID(stateID api.StateID) err if !v.shardConfig.Mode.IsChild() { return nil } - ok, err := verifyShardID(stateID.String(), v.shardConfig.Child.ShardID) + ok, err := api.MatchesShardPrefixFromHex(stateID.String(), v.shardConfig.Child.ShardID) if err != nil { return fmt.Errorf("error verifying shard id: %w", err) } @@ -195,34 +192,3 @@ func (v *CertificationRequestValidator) ValidateShardID(stateID api.StateID) err } return nil } - -// verifyShardID Checks if commitmentID's least significant bits match the shard bitmask. 
-func verifyShardID(commitmentID string, shardBitmask int) (bool, error) { - // convert to big.Ints - bytes, err := hex.DecodeString(commitmentID) - if err != nil { - return false, fmt.Errorf("failed to decode certification state ID: %w", err) - } - commitmentIdBigInt := new(big.Int).SetBytes(bytes) - shardBitmaskBigInt := new(big.Int).SetInt64(int64(shardBitmask)) - - // find position of MSB e.g. - // 0b111 -> BitLen=3 -> 3-1=2 - msbPos := shardBitmaskBigInt.BitLen() - 1 - - // build a mask covering bits below MSB e.g. - // 1<<2=0b100; 0b100-1=0b11; compareMask=0b11 - compareMask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(msbPos)), big.NewInt(1)) - - // remove MSB from shardBitmask to get expected value e.g. - // 0b111 & 0b11 = 0b11 - expected := new(big.Int).And(shardBitmaskBigInt, compareMask) - - // extract low bits from certification request e.g. - // commitment=0b11111111 & 0b11 = 0b11 - commitmentLowBits := new(big.Int).And(commitmentIdBigInt, compareMask) - - // return true if the certification request low bits match bitmask bits e.g. 
- // 0b11 == 0b11 - return commitmentLowBits.Cmp(expected) == 0, nil -} diff --git a/internal/signing/certification_request_validator_test.go b/internal/signing/certification_request_validator_test.go index 50db8c2..9ff5c43 100644 --- a/internal/signing/certification_request_validator_test.go +++ b/internal/signing/certification_request_validator_test.go @@ -149,6 +149,13 @@ func TestValidator_StateIDMismatch(t *testing.T) { } func TestValidator_ShardID(t *testing.T) { + makeShardTestID := func(firstByte, lastByte byte) string { + key := make([]byte, api.StateTreeKeyLengthBytes) + key[0] = firstByte + key[len(key)-1] = lastByte + return hex.EncodeToString(key) + } + tests := []struct { commitmentID string shardBitmask int @@ -158,25 +165,29 @@ func TestValidator_ShardID(t *testing.T) { // shard1=bitmask 0b10 // shard2=bitmask 0b11 - // certification request ending with 0b00000000 belongs to shard1 - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b10, true}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b11, false}, + // certification request with key bit 0 = 0 belongs to shard1 + {makeShardTestID(0x00, 0x00), 0b10, true}, + {makeShardTestID(0x00, 0x00), 0b11, false}, + + // certification request with key bit 0 = 1 belongs to shard2 + {makeShardTestID(0x01, 0x00), 0b10, false}, + {makeShardTestID(0x01, 0x00), 0b11, true}, - // certification request ending with 0b00000001 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b10, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b11, true}, + // certification request with first byte 0b00000010 still belongs to shard1 + {makeShardTestID(0x02, 0x00), 0b10, true}, + {makeShardTestID(0x02, 0x00), 0b11, false}, - // certification request ending with 0b00000010 belongs to shard1 - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b10, true}, - 
{"00000000000000000000000000000000000000000000000000000000000000000002", 0b11, false}, + // certification request with first byte 0b00000011 belongs to shard2 + {makeShardTestID(0x03, 0x00), 0b10, false}, + {makeShardTestID(0x03, 0x00), 0b11, true}, - // certification request ending with 0b00000011 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b10, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b11, true}, + // certification request with first byte 0b11111111 belongs to shard2 + {makeShardTestID(0xFF, 0x00), 0b10, false}, + {makeShardTestID(0xFF, 0x00), 0b11, true}, - // certification request ending with 0b11111111 belongs to shard2 - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b10, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b11, true}, + // the last byte no longer affects shard routing under LSB-first byte order + {makeShardTestID(0x00, 0xFF), 0b10, true}, + {makeShardTestID(0x00, 0xFF), 0b11, false}, // === END TWO SHARD CONFIG === @@ -186,40 +197,40 @@ func TestValidator_ShardID(t *testing.T) { // shard3=0b101 // shard4=0b111 - // certification request ending with 0b00000000 belongs to shard1 - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b101, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b100, true}, - - // certification request ending with 0b00000010 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b100, false}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b101, false}, - 
{"00000000000000000000000000000000000000000000000000000000000000000002", 0b110, true}, - - // certification request ending with 0b00000001 belongs to shard3 - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b101, true}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b100, false}, - - // certification request ending with 0b00000011 belongs to shard4 - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b111, true}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b101, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b100, false}, - - // certification request ending with 0b11111111 belongs to shard4 - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b111, true}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b101, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b110, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b100, false}, + // key bits 1:0 = 00 belong to shard1 + {makeShardTestID(0x00, 0x00), 0b111, false}, + {makeShardTestID(0x00, 0x00), 0b101, false}, + {makeShardTestID(0x00, 0x00), 0b110, false}, + {makeShardTestID(0x00, 0x00), 0b100, true}, + + // key bits 1:0 = 10 belong to shard2 + {makeShardTestID(0x02, 0x00), 0b111, false}, + {makeShardTestID(0x02, 0x00), 0b100, false}, + {makeShardTestID(0x02, 0x00), 0b101, false}, + {makeShardTestID(0x02, 0x00), 0b110, true}, + + // key bits 1:0 = 01 belong to shard3 + {makeShardTestID(0x01, 0x00), 0b111, false}, + {makeShardTestID(0x01, 0x00), 0b101, true}, + {makeShardTestID(0x01, 0x00), 
0b110, false}, + {makeShardTestID(0x01, 0x00), 0b100, false}, + + // key bits 1:0 = 11 belong to shard4 + {makeShardTestID(0x03, 0x00), 0b111, true}, + {makeShardTestID(0x03, 0x00), 0b101, false}, + {makeShardTestID(0x03, 0x00), 0b110, false}, + {makeShardTestID(0x03, 0x00), 0b100, false}, + + // key bits 1:0 = 11 still belong to shard4 when the whole first byte is set + {makeShardTestID(0xFF, 0x00), 0b111, true}, + {makeShardTestID(0xFF, 0x00), 0b101, false}, + {makeShardTestID(0xFF, 0x00), 0b110, false}, + {makeShardTestID(0xFF, 0x00), 0b100, false}, // === END FOUR SHARD CONFIG === } for _, tc := range tests { - match, err := verifyShardID(tc.commitmentID, tc.shardBitmask) + match, err := api.MatchesShardPrefixFromHex(tc.commitmentID, tc.shardBitmask) require.NoError(t, err) if match != tc.match { t.Errorf("commitmentID=%s shardBitmask=%b expected %v got %v", tc.commitmentID, tc.shardBitmask, tc.match, match) @@ -227,6 +238,17 @@ func TestValidator_ShardID(t *testing.T) { } } +func TestValidator_ShardIDRejectsPrefixedStateID(t *testing.T) { + rawKey := make([]byte, api.StateTreeKeyLengthBytes) + rawKey[0] = 0x01 + prefixed := append([]byte{0x00, 0x00}, rawKey...) 
+ + match, err := api.MatchesShardPrefixFromHex(hex.EncodeToString(prefixed), 0b11) + require.Error(t, err) + require.False(t, match) + require.Contains(t, err.Error(), "must be exactly 32 bytes") +} + func TestValidator_InvalidSignatureFormat(t *testing.T) { validator := newDefaultCertificationRequestValidator() diff --git a/internal/signing/v1/commitment_validator.go b/internal/signing/v1/commitment_validator.go index 7d1606c..ce9293d 100644 --- a/internal/signing/v1/commitment_validator.go +++ b/internal/signing/v1/commitment_validator.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "errors" "fmt" - "math/big" "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/models/v1" @@ -212,7 +211,16 @@ func (v *CommitmentValidator) ValidateShardID(requestID api.RequestID) error { if !v.shardConfig.Mode.IsChild() { return nil } - ok, err := verifyShardID(requestID.String(), v.shardConfig.Child.ShardID) + keyHex := requestID.String() + keyBytes, err := hex.DecodeString(keyHex) + if err != nil { + return fmt.Errorf("error decoding request ID: %w", err) + } + // V1 request IDs may carry a 2-byte algorithm prefix; strip it. + if len(keyBytes) == api.StateTreeKeyLengthBytes+2 { + keyBytes = keyBytes[2:] + } + ok, err := api.MatchesShardPrefix(keyBytes, v.shardConfig.Child.ShardID) if err != nil { return fmt.Errorf("error verifying shard id: %w", err) } @@ -222,37 +230,6 @@ func (v *CommitmentValidator) ValidateShardID(requestID api.RequestID) error { return nil } -// verifyShardID Checks if commitmentID's least significant bits match the shard bitmask. 
-func verifyShardID(commitmentID string, shardBitmask int) (bool, error) { - // convert to big.Ints - bytes, err := hex.DecodeString(commitmentID) - if err != nil { - return false, fmt.Errorf("failed to decode certification state ID: %w", err) - } - commitmentIdBigInt := new(big.Int).SetBytes(bytes) - shardBitmaskBigInt := new(big.Int).SetInt64(int64(shardBitmask)) - - // find position of MSB e.g. - // 0b111 -> BitLen=3 -> 3-1=2 - msbPos := shardBitmaskBigInt.BitLen() - 1 - - // build a mask covering bits below MSB e.g. - // 1<<2=0b100; 0b100-1=0b11; compareMask=0b11 - compareMask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(msbPos)), big.NewInt(1)) - - // remove MSB from shardBitmask to get expected value e.g. - // 0b111 & 0b11 = 0b11 - expected := new(big.Int).And(shardBitmaskBigInt, compareMask) - - // extract low bits from certification request e.g. - // commitment=0b11111111 & 0b11 = 0b11 - commitmentLowBits := new(big.Int).And(commitmentIdBigInt, compareMask) - - // return true if the certification request low bits match bitmask bits e.g. 
- // 0b11 == 0b11 - return commitmentLowBits.Cmp(expected) == 0, nil -} - // CreateDataHashImprint creates a DataHash imprint in the Unicity format: // 2 bytes algorithm (big-endian) + actual hash bytes // For SHA256: algorithm = 0, so prefix is [0x00, 0x00] diff --git a/internal/signing/v1/commitment_validator_test.go b/internal/signing/v1/commitment_validator_test.go index 98e2076..a35872a 100644 --- a/internal/signing/v1/commitment_validator_test.go +++ b/internal/signing/v1/commitment_validator_test.go @@ -161,6 +161,13 @@ func TestValidator_RequestIDMismatch(t *testing.T) { } func TestValidator_ShardID(t *testing.T) { + makeShardTestID := func(firstByte, lastByte byte) string { + key := make([]byte, api.StateTreeKeyLengthBytes) + key[0] = firstByte + key[len(key)-1] = lastByte + return hex.EncodeToString(key) + } + tests := []struct { commitmentID string shardBitmask int @@ -170,25 +177,29 @@ func TestValidator_ShardID(t *testing.T) { // shard1=bitmask 0b10 // shard2=bitmask 0b11 - // commitment ending with 0b00000000 belongs to shard1 - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b10, true}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b11, false}, + // request key bit 0 = 0 belongs to shard1 + {makeShardTestID(0x00, 0x00), 0b10, true}, + {makeShardTestID(0x00, 0x00), 0b11, false}, + + // request key bit 0 = 1 belongs to shard2 + {makeShardTestID(0x01, 0x00), 0b10, false}, + {makeShardTestID(0x01, 0x00), 0b11, true}, - // commitment ending with 0b00000001 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b10, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b11, true}, + // request first byte 0b00000010 still belongs to shard1 + {makeShardTestID(0x02, 0x00), 0b10, true}, + {makeShardTestID(0x02, 0x00), 0b11, false}, - // commitment ending with 0b00000010 belongs to shard1 - 
{"00000000000000000000000000000000000000000000000000000000000000000002", 0b10, true}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b11, false}, + // request first byte 0b00000011 belongs to shard2 + {makeShardTestID(0x03, 0x00), 0b10, false}, + {makeShardTestID(0x03, 0x00), 0b11, true}, - // commitment ending with 0b00000011 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b10, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b11, true}, + // request first byte 0b11111111 belongs to shard2 + {makeShardTestID(0xFF, 0x00), 0b10, false}, + {makeShardTestID(0xFF, 0x00), 0b11, true}, - // commitment ending with 0b11111111 belongs to shard2 - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b10, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b11, true}, + // the final byte no longer affects shard routing + {makeShardTestID(0x00, 0xFF), 0b10, true}, + {makeShardTestID(0x00, 0xFF), 0b11, false}, // === END TWO SHARD CONFIG === @@ -198,40 +209,40 @@ func TestValidator_ShardID(t *testing.T) { // shard3=0b101 // shard4=0b111 - // commitment ending with 0b00000000 belongs to shard1 - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b101, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000000", 0b100, true}, - - // commitment ending with 0b00000010 belongs to shard2 - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b100, false}, - {"00000000000000000000000000000000000000000000000000000000000000000002", 0b101, false}, - 
{"00000000000000000000000000000000000000000000000000000000000000000002", 0b110, true}, - - // commitment ending with 0b00000001 belongs to shard3 - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b111, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b101, true}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000001", 0b100, false}, - - // commitment ending with 0b00000011 belongs to shard4 - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b111, true}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b101, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b110, false}, - {"00000000000000000000000000000000000000000000000000000000000000000003", 0b100, false}, - - // commitment ending with 0b11111111 belongs to shard4 - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b111, true}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b101, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b110, false}, - {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b100, false}, + // request key bits 1:0 = 00 belong to shard1 + {makeShardTestID(0x00, 0x00), 0b111, false}, + {makeShardTestID(0x00, 0x00), 0b101, false}, + {makeShardTestID(0x00, 0x00), 0b110, false}, + {makeShardTestID(0x00, 0x00), 0b100, true}, + + // request key bits 1:0 = 10 belong to shard2 + {makeShardTestID(0x02, 0x00), 0b111, false}, + {makeShardTestID(0x02, 0x00), 0b100, false}, + {makeShardTestID(0x02, 0x00), 0b101, false}, + {makeShardTestID(0x02, 0x00), 0b110, true}, + + // request key bits 1:0 = 01 belong to shard3 + {makeShardTestID(0x01, 0x00), 0b111, false}, + {makeShardTestID(0x01, 0x00), 0b101, true}, + {makeShardTestID(0x01, 0x00), 0b110, 
false}, + {makeShardTestID(0x01, 0x00), 0b100, false}, + + // request key bits 1:0 = 11 belong to shard4 + {makeShardTestID(0x03, 0x00), 0b111, true}, + {makeShardTestID(0x03, 0x00), 0b101, false}, + {makeShardTestID(0x03, 0x00), 0b110, false}, + {makeShardTestID(0x03, 0x00), 0b100, false}, + + // request key bits 1:0 = 11 still belong to shard4 when the whole first byte is set + {makeShardTestID(0xFF, 0x00), 0b111, true}, + {makeShardTestID(0xFF, 0x00), 0b101, false}, + {makeShardTestID(0xFF, 0x00), 0b110, false}, + {makeShardTestID(0xFF, 0x00), 0b100, false}, // === END FOUR SHARD CONFIG === } for _, tc := range tests { - match, err := verifyShardID(tc.commitmentID, tc.shardBitmask) + match, err := api.MatchesShardPrefixFromHex(tc.commitmentID, tc.shardBitmask) require.NoError(t, err) if match != tc.match { t.Errorf("commitmentID=%s shardBitmask=%b expected %v got %v", tc.commitmentID, tc.shardBitmask, tc.match, match) @@ -239,6 +250,23 @@ func TestValidator_ShardID(t *testing.T) { } } +func TestValidator_ValidateShardID_AcceptsPrefixedRequestID(t *testing.T) { + key := make([]byte, api.StateTreeKeyLengthBytes) + key[0] = 0x01 // bit 0 set -> shard 0b11 + + validator := &CommitmentValidator{ + shardConfig: config.ShardingConfig{ + Mode: config.ShardingModeChild, + Child: config.ChildConfig{ + ShardID: 0b11, + }, + }, + } + + requestID := api.RequireNewImprintV2("0000" + hex.EncodeToString(key)) + require.NoError(t, validator.ValidateShardID(requestID)) +} + func TestValidator_InvalidSignatureFormat(t *testing.T) { validator := newDefaultCommitmentValidator() diff --git a/internal/smt/golden_vectors_test.go b/internal/smt/golden_vectors_test.go new file mode 100644 index 0000000..2e326d6 --- /dev/null +++ b/internal/smt/golden_vectors_test.go @@ -0,0 +1,151 @@ +package smt + +import ( + "encoding/hex" + "fmt" + "math/big" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +func 
TestGoldenVector_RootMatches(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + + k1 := mustPathFromKeyHex(t, "0100000000000000000000000000000000000000000000000000000000000000") + k2 := mustPathFromKeyHex(t, "0200000000000000000000000000000000000000000000000000000000000000") + + require.NoError(t, tree.AddLeaf(k1, []byte("value-one"))) + require.NoError(t, tree.AddLeaf(k2, []byte("value-two"))) + + const expectedRoot = "000020563433422d651813394a07697b9c09f9c2ab2ddb95eaa8ed2dc3211de3e869" + require.Equal(t, expectedRoot, tree.GetRootHashHex()) +} + +func TestGoldenVector_ProofBitmapAndSiblingsMatch(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + + k1 := mustPathFromKeyHex(t, "0100000000000000000000000000000000000000000000000000000000000000") + k2 := mustPathFromKeyHex(t, "0300000000000000000000000000000000000000000000000000000000000000") + k3 := mustPathFromKeyHex(t, "0800000000000000000000000000000000000000000000000000000000000000") + + v1 := mustHex(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + v2 := mustHex(t, "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") + v3 := mustHex(t, "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc") + + require.NoError(t, tree.AddLeaf(k1, v1)) + require.NoError(t, tree.AddLeaf(k2, v2)) + require.NoError(t, tree.AddLeaf(k3, v3)) + + const expectedRoot = "0000b08cae8f98a168a4b39dced99fc3ea2833291c8c53a0eb447e0056044dee598a" + require.Equal(t, expectedRoot, tree.GetRootHashHex()) + + path, err := tree.GetPath(k2) + require.NoError(t, err) + require.NotNil(t, path) + + vr, err := path.Verify(k2) + require.NoError(t, err) + require.True(t, vr.PathValid) + require.True(t, vr.PathIncluded) + require.True(t, vr.Result) + + bitmap, siblings, err := pathToBitmapAndSiblings(path, k2.BitLen()-1) + require.NoError(t, err) + + const expectedBitmap = 
"0300000000000000000000000000000000000000000000000000000000000000" + require.Equal(t, expectedBitmap, hex.EncodeToString(bitmap[:])) + + expectedSiblings := []string{ + "4a67e1a8224ab28c6a641ea0a66def8366990856d3b6c6176319c66062821ea6", + "40c73c53d4a8ba73ccc286520242668a11e66f253da145cd54c5a465ff640552", + } + require.Len(t, siblings, len(expectedSiblings)) + for i, s := range siblings { + require.Equal(t, expectedSiblings[i], hex.EncodeToString(s)) + } +} + +func pathToBitmapAndSiblings(path *api.MerkleTreePath, fullKeyBits int) ([32]byte, [][]byte, error) { + var bitmap [32]byte + if path == nil { + return bitmap, nil, fmt.Errorf("nil path") + } + if len(path.Steps) == 0 { + return bitmap, nil, fmt.Errorf("empty path") + } + + currentPath, ok := new(big.Int).SetString(path.Steps[0].Path, 10) + if !ok || currentPath.Sign() < 0 { + return bitmap, nil, fmt.Errorf("invalid first path segment") + } + + siblings := make([][]byte, 0, len(path.Steps)) + for i := 1; i < len(path.Steps); i++ { + step := path.Steps[i] + stepPath, ok := new(big.Int).SetString(step.Path, 10) + if !ok || stepPath.Sign() < 0 { + return bitmap, nil, fmt.Errorf("invalid step path at index %d", i) + } + + var depth int + if currentPath.BitLen() >= 2 { + depth = fullKeyBits - (currentPath.BitLen() - 1) + } else { + depth = stepPath.BitLen() - 1 + } + if depth < 0 || depth > 255 { + return bitmap, nil, fmt.Errorf("invalid depth %d at index %d", depth, i) + } + + if step.Data != nil { + sibling, err := hex.DecodeString(*step.Data) + if err != nil { + return bitmap, nil, fmt.Errorf("invalid sibling hex at index %d: %w", i, err) + } + if len(sibling) != 32 { + return bitmap, nil, fmt.Errorf("invalid sibling length at index %d: %d", i, len(sibling)) + } + bitmap[depth/8] |= 1 << (depth % 8) + siblings = append(siblings, sibling) + } + + if currentPath.BitLen() < 2 { + currentPath = big.NewInt(1) + } + pathLen := stepPath.BitLen() - 1 + if pathLen < 0 { + return bitmap, nil, fmt.Errorf("invalid path 
length at index %d", i) + } + mask := new(big.Int).SetBit(new(big.Int).Set(stepPath), pathLen, 0) + currentPath.Lsh(currentPath, uint(pathLen)) + currentPath.Or(currentPath, mask) + } + + // Current Go proof API (`MerkleTreePath.Steps`) is emitted leaf-to-root. + // Rugregator bitmap proofs serialize siblings root-to-leaf. + // This reverse is only for test-vector normalization between the two wire + // formats; it does not change tree/hash semantics. + for i, j := 0, len(siblings)-1; i < j; i, j = i+1, j-1 { + siblings[i], siblings[j] = siblings[j], siblings[i] + } + + return bitmap, siblings, nil +} + +func mustPathFromKeyHex(t *testing.T, hexKey string) *big.Int { + t.Helper() + key := mustHex(t, hexKey) + p, err := api.FixedBytesToPath(key, api.StateTreeKeyLengthBits) + require.NoError(t, err) + return p +} + +func mustHex(t *testing.T, s string) []byte { + t.Helper() + b, err := hex.DecodeString(s) + require.NoError(t, err) + return b +} diff --git a/internal/smt/inclusion_cert_test.go b/internal/smt/inclusion_cert_test.go new file mode 100644 index 0000000..124c994 --- /dev/null +++ b/internal/smt/inclusion_cert_test.go @@ -0,0 +1,344 @@ +package smt + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "math/big" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +// TestGetInclusionCert_SingleLeaf exercises the single-leaf / unary-root +// edge case: bitmap is all zeros, no siblings, and verification reduces to +// H_leaf(key, value) == root. 
+func TestGetInclusionCert_SingleLeaf(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + + key := mustHex(t, "0100000000000000000000000000000000000000000000000000000000000000") + value := []byte("solo-leaf-value") + + path, err := api.FixedBytesToPath(key, api.StateTreeKeyLengthBits) + require.NoError(t, err) + require.NoError(t, tree.AddLeaf(path, value)) + + cert, err := tree.GetInclusionCert(key) + require.NoError(t, err) + require.NotNil(t, cert) + require.Equal(t, 0, bitmapPopcountForTest(&cert.Bitmap), "single-leaf cert bitmap must be empty") + require.Len(t, cert.Siblings, 0, "single-leaf cert must have zero siblings") + + root := tree.GetRootHashRaw() + require.Len(t, root, api.SiblingSize) + + require.NoError(t, cert.Verify(key, value, root, api.SHA256)) +} + +// TestGetInclusionCert_TwoLeaves covers a two-leaf tree: each proof should +// verify against the same root and the two bitmaps must be identical (shared +// split depth), with exactly one sibling each. 
+func TestGetInclusionCert_TwoLeaves(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + + keyA := mustHex(t, "0100000000000000000000000000000000000000000000000000000000000000") + keyB := mustHex(t, "0200000000000000000000000000000000000000000000000000000000000000") + valA := []byte("value-a") + valB := []byte("value-b") + + pathA, err := api.FixedBytesToPath(keyA, api.StateTreeKeyLengthBits) + require.NoError(t, err) + pathB, err := api.FixedBytesToPath(keyB, api.StateTreeKeyLengthBits) + require.NoError(t, err) + + require.NoError(t, tree.AddLeaf(pathA, valA)) + require.NoError(t, tree.AddLeaf(pathB, valB)) + + root := tree.GetRootHashRaw() + + certA, err := tree.GetInclusionCert(keyA) + require.NoError(t, err) + certB, err := tree.GetInclusionCert(keyB) + require.NoError(t, err) + + require.Equal(t, certA.Bitmap, certB.Bitmap, "shared split depth → identical bitmaps") + require.Len(t, certA.Siblings, 1, "two-leaf cert has exactly one sibling") + require.Len(t, certB.Siblings, 1, "two-leaf cert has exactly one sibling") + require.Equal(t, 1, bitmapPopcountForTest(&certA.Bitmap)) + + require.NoError(t, certA.Verify(keyA, valA, root, api.SHA256)) + require.NoError(t, certB.Verify(keyB, valB, root, api.SHA256)) +} + +// TestGetInclusionCert_GoldenVector cross-checks the Go inclusion cert +// generator against the frozen golden vector already used by +// golden_vectors_test.go. The same three keys, values and root must produce +// the same bitmap and sibling list under the new wire format. 
+func TestGetInclusionCert_GoldenVector(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + + k1 := mustHex(t, "0100000000000000000000000000000000000000000000000000000000000000") + k2 := mustHex(t, "0300000000000000000000000000000000000000000000000000000000000000") + k3 := mustHex(t, "0800000000000000000000000000000000000000000000000000000000000000") + + v1 := mustHex(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + v2 := mustHex(t, "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") + v3 := mustHex(t, "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc") + + addLeaf(t, tree, k1, v1) + addLeaf(t, tree, k2, v2) + addLeaf(t, tree, k3, v3) + + const expectedRoot = "0000b08cae8f98a168a4b39dced99fc3ea2833291c8c53a0eb447e0056044dee598a" + require.Equal(t, expectedRoot, tree.GetRootHashHex()) + + cert, err := tree.GetInclusionCert(k2) + require.NoError(t, err) + + const expectedBitmap = "0300000000000000000000000000000000000000000000000000000000000000" + require.Equal(t, expectedBitmap, hex.EncodeToString(cert.Bitmap[:]), + "cert bitmap must match golden vector") + + expectedSiblings := []string{ + "4a67e1a8224ab28c6a641ea0a66def8366990856d3b6c6176319c66062821ea6", + "40c73c53d4a8ba73ccc286520242668a11e66f253da145cd54c5a465ff640552", + } + require.Len(t, cert.Siblings, len(expectedSiblings)) + for i, sib := range cert.Siblings { + require.Equal(t, expectedSiblings[i], hex.EncodeToString(sib[:]), + "sibling %d must match golden vector", i) + } + + root := tree.GetRootHashRaw() + require.NoError(t, cert.Verify(k2, v2, root, api.SHA256)) +} + +// TestGetInclusionCert_AllKeysVerify builds a tree with several leaves and +// verifies that every leaf's cert round-trips through +// Marshal → Unmarshal → Verify. 
+func TestGetInclusionCert_AllKeysVerify(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + + keys := [][]byte{ + mustHex(t, "0100000000000000000000000000000000000000000000000000000000000000"), + mustHex(t, "0300000000000000000000000000000000000000000000000000000000000000"), + mustHex(t, "0800000000000000000000000000000000000000000000000000000000000000"), + mustHex(t, "ff00000000000000000000000000000000000000000000000000000000000000"), + mustHex(t, "aa55aa55aa55aa55aa55aa55aa55aa55aa55aa55aa55aa55aa55aa55aa55aa55"), + } + values := [][]byte{ + []byte("value-1"), + []byte("value-2"), + []byte("value-3"), + []byte("value-4"), + []byte("value-5"), + } + for i := range keys { + addLeaf(t, tree, keys[i], values[i]) + } + + root := tree.GetRootHashRaw() + + for i := range keys { + cert, err := tree.GetInclusionCert(keys[i]) + require.NoError(t, err, "GetInclusionCert key=%d", i) + + // Round-trip through wire encoding. + wire, err := cert.MarshalBinary() + require.NoError(t, err) + var decoded api.InclusionCert + require.NoError(t, decoded.UnmarshalBinary(wire)) + require.Equal(t, cert.Bitmap, decoded.Bitmap) + require.Equal(t, len(cert.Siblings), len(decoded.Siblings)) + + require.NoError(t, decoded.Verify(keys[i], values[i], root, api.SHA256), + "round-tripped cert must verify for key %d", i) + } +} + +// TestGetInclusionCert_RandomLeaves fuzzes a moderately sized tree with +// random keys and values, and verifies that every cert verifies against the +// SMT root. Uses a deterministic seed via crypto/rand (non-reproducible but +// exercises many tree shapes). 
+func TestGetInclusionCert_RandomLeaves(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + + const n = 64 + keys := make([][]byte, 0, n) + values := make([][]byte, 0, n) + seen := make(map[string]struct{}, n) + + for len(keys) < n { + key := make([]byte, api.StateTreeKeyLengthBytes) + _, err := rand.Read(key) + require.NoError(t, err) + if _, dup := seen[string(key)]; dup { + continue + } + seen[string(key)] = struct{}{} + val := make([]byte, 16) + _, err = rand.Read(val) + require.NoError(t, err) + addLeaf(t, tree, key, val) + keys = append(keys, key) + values = append(values, val) + } + + root := tree.GetRootHashRaw() + + for i := range keys { + cert, err := tree.GetInclusionCert(keys[i]) + require.NoError(t, err) + require.NoError(t, cert.Verify(keys[i], values[i], root, api.SHA256), + "random cert must verify for key %d", i) + } +} + +// TestGetInclusionCert_WrongValueFails verifies that a cert generated for +// key K cannot be used to attest an incorrect value. +func TestGetInclusionCert_WrongValueFails(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + + k1 := mustHex(t, "0100000000000000000000000000000000000000000000000000000000000000") + k2 := mustHex(t, "0300000000000000000000000000000000000000000000000000000000000000") + v1 := []byte("truth") + v2 := []byte("other") + + addLeaf(t, tree, k1, v1) + addLeaf(t, tree, k2, v2) + root := tree.GetRootHashRaw() + + cert, err := tree.GetInclusionCert(k1) + require.NoError(t, err) + + err = cert.Verify(k1, []byte("lie"), root, api.SHA256) + require.Error(t, err, "verifying with wrong value must fail") +} + +// TestGetInclusionCert_MissingKey ensures that requesting a cert for a key +// that does not exist in the tree returns an error rather than producing an +// invalid proof. 
+func TestGetInclusionCert_MissingKey(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + present := mustHex(t, "0100000000000000000000000000000000000000000000000000000000000000") + addLeaf(t, tree, present, []byte("v")) + + missing := mustHex(t, "0200000000000000000000000000000000000000000000000000000000000000") + cert, err := tree.GetInclusionCert(missing) + require.Error(t, err) + require.Nil(t, cert) +} + +// TestGetInclusionCert_EmptyTree ensures that calling GetInclusionCert on +// an empty tree fails rather than returning a spurious cert. +func TestGetInclusionCert_EmptyTree(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + key := mustHex(t, "0100000000000000000000000000000000000000000000000000000000000000") + + cert, err := tree.GetInclusionCert(key) + require.Error(t, err) + require.Nil(t, cert) +} + +// TestGetInclusionCert_WrongKeyLength ensures short/long keys are rejected +// by the Go SMT generator before any traversal starts. +func TestGetInclusionCert_WrongKeyLength(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + addLeaf(t, tree, + mustHex(t, "0100000000000000000000000000000000000000000000000000000000000000"), + []byte("v")) + + cases := []struct { + name string + key []byte + }{ + {"nil", nil}, + {"short", make([]byte, api.StateTreeKeyLengthBytes-1)}, + {"long", make([]byte, api.StateTreeKeyLengthBytes+1)}, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + cert, err := tree.GetInclusionCert(tc.key) + require.Error(t, err) + require.ErrorIs(t, err, api.ErrCertKeyLength) + require.Nil(t, cert) + }) + } +} + +// TestGetShardInclusionFragment_SkipsNilSiblingAfterNonNilSibling exercises +// the parent-tree traversal branch where a shallower sibling exists but a +// deeper sibling subtree is empty (nil hash). 
That depth must be skipped as a +// unary passthrough while keeping previously recorded siblings valid. +func TestGetShardInclusionFragment_SkipsNilSiblingAfterNonNilSibling(t *testing.T) { + parent := NewParentSparseMerkleTree(api.SHA256, 2) + + // Update two shards on opposite root sides so root-level sibling is present. + leaf4 := bytes.Repeat([]byte{0xA4}, api.SiblingSize) + leaf5 := bytes.Repeat([]byte{0xB5}, api.SiblingSize) + require.NoError(t, parent.AddLeaf(big.NewInt(0b100), leaf4)) + require.NoError(t, parent.AddLeaf(big.NewInt(0b101), leaf5)) + + fragment, err := parent.GetShardInclusionFragment(0b100) + require.NoError(t, err) + require.NotNil(t, fragment) + + var cert api.InclusionCert + require.NoError(t, cert.UnmarshalBinary(fragment.CertificateBytes)) + require.Equal(t, 1, bitmapPopcountForTest(&cert.Bitmap), "one deeper depth should be skipped as unary passthrough") + require.Len(t, cert.Siblings, 1, "only the non-empty shallower sibling should be present") + + root := parent.GetRootHashRaw() + require.NoError(t, fragment.Verify(0b100, 2, leaf4, root, api.SHA256)) +} + +// TestGetRootHashRaw_MatchesHex confirms that GetRootHashRaw produces +// the 32-byte hash portion of GetRootHashHex. +func TestGetRootHashRaw_MatchesHex(t *testing.T) { + tree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + addLeaf(t, tree, + mustHex(t, "0100000000000000000000000000000000000000000000000000000000000000"), + []byte("v1")) + addLeaf(t, tree, + mustHex(t, "0300000000000000000000000000000000000000000000000000000000000000"), + []byte("v2")) + + rawHex := hex.EncodeToString(tree.GetRootHashRaw()) + require.Len(t, rawHex, 64, "raw root must be 32 bytes hex-encoded") + + fullHex := tree.GetRootHashHex() + require.Equal(t, "0000"+rawHex, fullHex, "raw root must match the hash portion of the hex root") +} + +// addLeaf is a small helper that converts a 32-byte key to its path form +// and inserts it into the tree. 
+func addLeaf(t *testing.T, tree *SparseMerkleTree, key, value []byte) { + t.Helper() + path, err := api.FixedBytesToPath(key, api.StateTreeKeyLengthBits) + require.NoError(t, err) + require.NoError(t, tree.AddLeaf(path, value)) +} + +// bitmapPopcountForTest counts set bits in a bitmap. Lives in this test +// file so we don't export the unexported helper from pkg/api. +func bitmapPopcountForTest(b *[api.BitmapSize]byte) int { + total := 0 + for _, byteVal := range b { + total += bits8(byteVal) + } + return total +} + +// bits8 counts set bits in a single byte without pulling math/bits into +// the test file. +func bits8(x byte) int { + n := 0 + for x != 0 { + n += int(x & 1) + x >>= 1 + } + return n +} diff --git a/internal/smt/smt.go b/internal/smt/smt.go index b751bc4..3d7ce0e 100644 --- a/internal/smt/smt.go +++ b/internal/smt/smt.go @@ -11,10 +11,10 @@ import ( ) var ( - ErrDuplicateLeaf = errors.New("smt: duplicate leaf") - ErrLeafModification = errors.New("smt: attempt to modify an existing leaf") - ErrKeyLength = errors.New("smt: invalid key length") - ErrWrongShard = errors.New("smt: key does not belong in this shard") + ErrDuplicateLeaf = errors.New("smt: duplicate leaf") + ErrLeafModification = errors.New("smt: attempt to modify an existing leaf") + ErrKeyLength = errors.New("smt: invalid key length") + ErrWrongShard = errors.New("smt: key does not belong in this shard") ErrInvalidChildHashLength = errors.New("smt: child hash value length does not match hash algorithm") ) @@ -74,14 +74,17 @@ func NewSparseMerkleTree(algorithm api.HashAlgorithm, keyLength int) *SparseMerk panic(fmt.Sprintf("smt: hash algorithm output (%d bytes) exceeds inline cache size (%d) — update smtCachedHashBytes", hashLen(algorithm), smtCachedHashBytes)) } - return &SparseMerkleTree{ + tree := &SparseMerkleTree{ parentMode: false, keyLength: keyLength, algorithm: algorithm, - root: newRootBranch(big.NewInt(1), nil, nil), + root: newRootBranch(big.NewInt(1), nil, nil, 0), 
isSnapshot: false, original: nil, } + // Prime the root hash to eliminate races on the first concurrent reads + tree.root.calculateHash(api.NewDataHasher(algorithm)) + return tree } // NewChildSparseMerkleTree creates a new sparse Merkle tree for a child aggregator in sharded setup @@ -100,14 +103,17 @@ func NewChildSparseMerkleTree(algorithm api.HashAlgorithm, keyLength int, shardI if path.BitLen() > keyLength { panic("Shard ID must be shorter than SMT key length") } - return &SparseMerkleTree{ + tree := &SparseMerkleTree{ parentMode: false, keyLength: keyLength, algorithm: algorithm, - root: newRootBranch(path, nil, nil), + root: newRootBranch(path, nil, nil, path.BitLen()-1), isSnapshot: false, original: nil, } + // Prime the root hash to eliminate races on the first concurrent reads + tree.root.calculateHash(api.NewDataHasher(algorithm)) + return tree } // NewParentSparseMerkleTree creates a new sparse Merkle tree for the parent aggregator in sharded setup @@ -124,19 +130,24 @@ func NewParentSparseMerkleTree(algorithm api.HashAlgorithm, keyLength int) *Spar // better to ensure all the leaves exist; otherwise the hash values // of siblings of the missing nodes would not match the structure of // the tree and the corresponding inclusion proofs would fail to verify - tree.root.Left = populate(0b10, keyLength) - tree.root.Right = populate(0b11, keyLength) + tree.root.Left = populate(0b10, keyLength, 1) + tree.root.Right = populate(0b11, keyLength, 1) + + // Mutation above invalidated the root hash primed by NewSparseMerkleTree. + // We reset and re-prime. 
+ tree.root.hashSet = false + tree.root.calculateHash(api.NewDataHasher(algorithm)) return tree } -func populate(path, levels int) branch { +func populate(path, levels, depth int) branch { if levels == 1 { return newChildLeafBranch(big.NewInt(int64(path)), nil) } - left := populate(0b10, levels-1) - right := populate(0b11, levels-1) - return newNodeBranch(big.NewInt(int64(path)), left, right) + left := populate(0b10, levels-1, depth+1) + right := populate(0b11, levels-1, depth+1) + return newNodeBranchWithDepth(big.NewInt(int64(path)), left, right, depth) } // CreateSnapshot creates a snapshot of the current SMT state @@ -206,7 +217,7 @@ func (smt *SparseMerkleTree) CanModify() bool { // copyOnWriteRoot creates a new root if this snapshot is sharing it with the original func (smt *SparseMerkleTree) copyOnWriteRoot() *NodeBranch { if smt.original != nil && smt.root == smt.original.root { - return newRootBranch(smt.root.Path, smt.root.Left, smt.root.Right) + return newRootBranch(smt.root.Path, smt.root.Left, smt.root.Right, int(smt.root.Depth)) } return smt.root } @@ -219,13 +230,13 @@ func (smt *SparseMerkleTree) cloneBranch(branch branch) branch { if branch.isLeaf() { leafBranch := branch.(*LeafBranch) - cloned := newLeafBranch(leafBranch.Path, leafBranch.Value) + cloned := newLeafBranchWithKey(leafBranch.Path, leafBranch.Key, leafBranch.Value) // Preserve the isChild flag for parent mode trees cloned.isChild = leafBranch.isChild return cloned } else { nodeBranch := branch.(*NodeBranch) - return newNodeBranch(nodeBranch.Path, nodeBranch.Left, nodeBranch.Right) + return newNodeBranchWithDepth(nodeBranch.Path, nodeBranch.Left, nodeBranch.Right, int(nodeBranch.Depth)) } } @@ -239,6 +250,7 @@ type branch interface { // LeafBranch represents a leaf node type LeafBranch struct { Path *big.Int + Key []byte Value []byte rawHash [smtCachedHashBytes]byte // inline hash cache; valid when hashSet == true hashSet bool @@ -248,6 +260,7 @@ type LeafBranch struct { // NodeBranch 
represents an internal node type NodeBranch struct { Path *big.Int + Depth uint8 Left branch Right branch rawHash [smtCachedHashBytes]byte // inline hash cache; valid when hashSet == true @@ -257,8 +270,17 @@ type NodeBranch struct { // NewLeafBranch creates a regular leaf branch func newLeafBranch(path *big.Int, value []byte) *LeafBranch { + key, err := api.PathToFixedBytes(path, path.BitLen()-1) + if err != nil { + panic(fmt.Sprintf("smt: failed to derive key bytes for leaf path: %v", err)) + } + return newLeafBranchWithKey(path, key, value) +} + +func newLeafBranchWithKey(path *big.Int, key, value []byte) *LeafBranch { return &LeafBranch{ Path: new(big.Int).Set(path), + Key: append([]byte(nil), key...), Value: append([]byte(nil), value...), isChild: false, // Hash will be computed on demand @@ -267,11 +289,19 @@ func newLeafBranch(path *big.Int, value []byte) *LeafBranch { // NewChildLeafBranch creates a parent tree leaf containing the root hash of a child tree func newChildLeafBranch(path *big.Int, value []byte) *LeafBranch { + return newChildLeafBranchWithKey(path, nil, value) +} + +func newChildLeafBranchWithKey(path *big.Int, key, value []byte) *LeafBranch { if value != nil { value = append([]byte(nil), value...) } + if key != nil { + key = append([]byte(nil), key...) + } return &LeafBranch{ Path: new(big.Int).Set(path), + Key: key, Value: value, isChild: true, // Hash will be set on demand @@ -290,9 +320,12 @@ func (l *LeafBranch) calculateHash(hasher *api.DataHasher) []byte { l.hashSet = true return l.rawHash[:] } - pathBytes := api.BigintEncode(l.Path) - hasher.Reset().AddData(api.CborArray(2)). - AddCborBytes(pathBytes).AddCborBytes(l.Value) + // v2 leaf hashing with domain separation: + // H(0x00 || key || value) + hasher.Reset(). + AddData([]byte{0x00}). + AddData(l.Key). 
+ AddData(l.Value) hasher.SumRaw(l.rawHash[:0]) l.hashSet = true return l.rawHash[:] @@ -308,23 +341,33 @@ func (l *LeafBranch) isLeaf() bool { // NewNodeBranch creates a regular node branch func newNodeBranch(path *big.Int, left, right branch) *NodeBranch { + return newNodeBranchWithDepth(path, left, right, path.BitLen()-1) +} + +func newNodeBranchWithDepth(path *big.Int, left, right branch, depth int) *NodeBranch { + if depth < 0 || depth > 255 { + panic(fmt.Sprintf("smt: node depth %d out of uint8 range [0, 255]", depth)) + } return &NodeBranch{ Path: new(big.Int).Set(path), + Depth: uint8(depth), Left: left, Right: right, isRoot: false, - // Hash will be computed on demand } } -// NewRootBranch creates a root node branch -func newRootBranch(path *big.Int, left, right branch) *NodeBranch { +// newRootBranch creates a root node branch +func newRootBranch(path *big.Int, left, right branch, depth int) *NodeBranch { + if depth < 0 || depth > 255 { + panic(fmt.Sprintf("smt: root depth %d out of uint8 range [0, 255]", depth)) + } return &NodeBranch{ Path: new(big.Int).Set(path), + Depth: uint8(depth), Left: left, Right: right, isRoot: true, - // Hash will be computed on demand } } @@ -342,29 +385,28 @@ func (n *NodeBranch) calculateHash(hasher *api.DataHasher) []byte { rightHash = n.Right.calculateHash(hasher) } - hasher.Reset().AddData(api.CborArray(3)) - - if n.isRoot && n.Path.BitLen() > 1 { - // This is root of a child tree in sharded setup - // The path to add is the last bit of the shard ID - pos := n.Path.BitLen() - 2 - path := big.NewInt(int64(2 + n.Path.Bit(pos))) - hasher.AddCborBytes(api.BigintEncode(path)) - } else { - // In all other cases we just add the actual path - hasher.AddCborBytes(api.BigintEncode(n.Path)) + // v2 SMT semantics for unary nodes: hash is child hash. 
+ if leftHash == nil && rightHash != nil { + copy(n.rawHash[:], rightHash) + n.hashSet = true + return n.rawHash[:] } - - if leftHash == nil { - hasher.AddCborNull() - } else { - hasher.AddCborBytes(leftHash) + if rightHash == nil && leftHash != nil { + copy(n.rawHash[:], leftHash) + n.hashSet = true + return n.rawHash[:] } - if rightHash == nil { - hasher.AddCborNull() - } else { - hasher.AddCborBytes(rightHash) + // Keep root hash stable for empty trees by hashing domain+level when both + // children are empty. + hasher.Reset(). + AddData([]byte{0x01}). + AddData([]byte{n.Depth}) + if leftHash != nil { + hasher.AddData(leftHash) + } + if rightHash != nil { + hasher.AddData(rightHash) } hasher.SumRaw(n.rawHash[:0]) @@ -388,6 +430,10 @@ func (smt *SparseMerkleTree) AddLeaf(path *big.Int, value []byte) error { if calculateCommonPath(path, smt.root.Path).BitLen() != smt.root.Path.BitLen() { return ErrWrongShard } + leafKey, err := api.PathToFixedBytes(path, smt.keyLength) + if err != nil { + return fmt.Errorf("failed to derive leaf key bytes from path: %w", err) + } if smt.parentMode && value != nil && len(value) != hashLen(smt.algorithm) { return ErrInvalidChildHashLength } @@ -412,13 +458,13 @@ func (smt *SparseMerkleTree) AddLeaf(path *big.Int, value []byte) error { } else { rightBranch = smt.root.Right } - newRight, err := smt.buildTree(rightBranch, shifted, value) + newRight, err := smt.buildTree(rightBranch, shifted, leafKey, value, int(smt.root.Depth)) if err != nil { return err } right = newRight } else { - right = newLeafBranch(shifted, value) + right = newLeafBranchWithKey(shifted, leafKey, value) } } else { if smt.root.Left != nil { @@ -429,18 +475,18 @@ func (smt *SparseMerkleTree) AddLeaf(path *big.Int, value []byte) error { } else { leftBranch = smt.root.Left } - newLeft, err := smt.buildTree(leftBranch, shifted, value) + newLeft, err := smt.buildTree(leftBranch, shifted, leafKey, value, int(smt.root.Depth)) if err != nil { return err } left = newLeft } 
else { - left = newLeafBranch(shifted, value) + left = newLeafBranchWithKey(shifted, leafKey, value) } right = smt.root.Right } - smt.root = newRootBranch(smt.root.Path, left, right) + smt.root = newRootBranch(smt.root.Path, left, right, int(smt.root.Depth)) return nil } @@ -464,6 +510,14 @@ func (smt *SparseMerkleTree) AddLeaves(leaves []*Leaf) error { return nil } +// ensureHashes eagerly computes all node hashes in the tree. Must be called +// under a write lock so that subsequent readers under RLock find every +// hashSet flag already true and never mutate node state. +func (smt *SparseMerkleTree) ensureHashes() { + hasher := api.NewDataHasher(smt.algorithm) + smt.root.calculateHash(hasher) +} + // GetRootHash returns the root hash as imprint func (smt *SparseMerkleTree) GetRootHash() []byte { // Create a new hasher to ensure thread safety @@ -478,6 +532,167 @@ func (smt *SparseMerkleTree) GetRootHashHex() string { return fmt.Sprintf("%x", buildImprint(smt.algorithm, smt.root.calculateHash(hasher))) } +// GetRootHashRaw returns the raw 32-byte root hash without the algorithm +// prefix. This is the canonical v2 root hash consumed by +// api.InclusionCert verification and by UC.IR.h binding. +func (smt *SparseMerkleTree) GetRootHashRaw() []byte { + hasher := api.NewDataHasher(smt.algorithm) + raw := smt.root.calculateHash(hasher) + out := make([]byte, len(raw)) + copy(out, raw) + return out +} + +// GetInclusionCert builds a v2 inclusion certificate for the leaf at the +// given raw 32-byte key. Verifier consumes bitmap + siblings in root-to-leaf +// wire order. +// +// Returns an error if no leaf exists at the key. Non-inclusion certificates +// are produced by a separate path (not yet implemented). 
+func (smt *SparseMerkleTree) GetInclusionCert(key []byte) (*api.InclusionCert, error) { + if len(key) != api.StateTreeKeyLengthBytes { + return nil, fmt.Errorf("%w: got %d, want %d", api.ErrCertKeyLength, len(key), api.StateTreeKeyLengthBytes) + } + + // Prime the hash cache by hashing the root. This cascades through every + // node reachable from the root, so sibling reads below are cache hits. + hasher := api.NewDataHasher(smt.algorithm) + _ = smt.root.calculateHash(hasher) + + var cert api.InclusionCert + if _, err := smt.generateInclusionCertWithLeafValue(hasher, key, smt.root, &cert); err != nil { + return nil, err + } + return &cert, nil +} + +// GetShardInclusionFragment builds the native parent proof fragment used by a +// child aggregator to later compose a full v2 proof locally. The fragment uses +// the same bitmap+sibling wire shape as InclusionCert but only contains the +// shallow parent-tree depths. The returned shard leaf value must equal the +// child tree root for composition to proceed. 
+func (smt *SparseMerkleTree) GetShardInclusionFragment(shardID api.ShardID) (*api.ParentInclusionFragment, error) { + if !smt.parentMode { + return nil, fmt.Errorf("smt: shard inclusion fragment only valid for parent trees") + } + + path := big.NewInt(int64(shardID)) + if path.BitLen()-1 != smt.keyLength { + return nil, ErrKeyLength + } + + key, err := api.PathToFixedBytes(path, smt.keyLength) + if err != nil { + return nil, fmt.Errorf("failed to derive shard key bytes: %w", err) + } + + hasher := api.NewDataHasher(smt.algorithm) + _ = smt.root.calculateHash(hasher) + + var cert api.InclusionCert + leafValue, err := smt.generateInclusionCertWithLeafValue(hasher, key, smt.root, &cert) + if err != nil { + return nil, err + } + if len(leafValue) == 0 { + return nil, nil + } + + certBytes, err := cert.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("failed to encode parent fragment cert: %w", err) + } + return &api.ParentInclusionFragment{ + CertificateBytes: api.NewHexBytes(certBytes), + ShardLeafValue: api.NewHexBytes(leafValue), + }, nil +} + +// generateInclusionCert recursively walks from the current node toward the +// leaf matching key, appending siblings and setting bitmap bits at every +// 2-child branching node along the path. Unary passthrough nodes contribute +// nothing to either the bitmap or the sibling list. +func (smt *SparseMerkleTree) generateInclusionCert(hasher *api.DataHasher, key []byte, current branch, cert *api.InclusionCert) error { + _, err := smt.generateInclusionCertWithLeafValue(hasher, key, current, cert) + return err +} + +func (smt *SparseMerkleTree) generateInclusionCertWithLeafValue(hasher *api.DataHasher, key []byte, current branch, cert *api.InclusionCert) ([]byte, error) { + if current == nil { + return nil, fmt.Errorf("smt: inclusion cert traversal reached nil subtree") + } + if current.isLeaf() { + leaf := current.(*LeafBranch) + // Parent-mode populate() creates placeholder leaves with nil Key and nil + // Value. 
For those placeholders, key equality is validated only after the + // first real AddLeaf replaces them with a keyed leaf. + if leaf.Key != nil && !bytes.Equal(leaf.Key, key) { + return nil, fmt.Errorf("smt: leaf not found for key %x", key) + } + if leaf.Value == nil { + return nil, nil + } + return append([]byte(nil), leaf.Value...), nil + } + + node := current.(*NodeBranch) + if node.Left != nil && node.Right != nil { + depth := int(node.Depth) + if depth < 0 || depth >= api.StateTreeKeyLengthBits { + return nil, fmt.Errorf("smt: invalid branch depth %d", depth) + } + + var sibling, child branch + if keyBit(key, depth) == 0 { + sibling = node.Right + child = node.Left + } else { + sibling = node.Left + child = node.Right + } + + childHash := child.calculateHash(hasher) + if childHash == nil { + return nil, nil + } + + sibHash := sibling.calculateHash(hasher) + if sibHash == nil { + // The sibling subtree is empty (all-placeholder / no submitted leaf), + // so this depth is a unary passthrough: no bitmap bit and no sibling. + // Any already recorded shallower siblings remain valid. + return smt.generateInclusionCertWithLeafValue(hasher, key, child, cert) + } + + cert.Bitmap[depth/8] |= 1 << (uint(depth) % 8) + if len(sibHash) != api.SiblingSize { + return nil, fmt.Errorf("smt: sibling hash unexpected length: got %d, want %d", len(sibHash), api.SiblingSize) + } + var sib [api.SiblingSize]byte + copy(sib[:], sibHash) + cert.Siblings = append(cert.Siblings, sib) + + return smt.generateInclusionCertWithLeafValue(hasher, key, child, cert) + } + + // Unary passthrough: typically only at the root when the tree holds + // ≤1 leaves. The v2 rule says a single-child node's hash equals the + // child's hash, so no bitmap bit and no sibling are added. 
+ if node.Left != nil { + return smt.generateInclusionCertWithLeafValue(hasher, key, node.Left, cert) + } + if node.Right != nil { + return smt.generateInclusionCertWithLeafValue(hasher, key, node.Right, cert) + } + return nil, fmt.Errorf("smt: reached empty subtree in inclusion cert traversal") +} + +// keyBit returns bit d of the raw key under LSB-first byte layout. +// Matches api.keyBitAt. +func keyBit(key []byte, d int) byte { + return (key[d/8] >> (uint(d) % 8)) & 1 +} + // GetLeaf retrieves a leaf by path (for compatibility) func (smt *SparseMerkleTree) GetLeaf(path *big.Int) (*LeafBranch, error) { return smt.findLeafInBranch(smt.root, path) @@ -519,13 +734,14 @@ func (smt *SparseMerkleTree) findLeafInBranch(branch branch, targetPath *big.Int } } -// buildTree matches TypeScript buildTree logic exactly -func (smt *SparseMerkleTree) buildTree(branch branch, remainingPath *big.Int, value []byte) (branch, error) { +// buildTree updates a subtree while preserving the leaf full key bytes and +// tracking absolute node depth (0..255) for v2 node hashing. 
+func (smt *SparseMerkleTree) buildTree(branch branch, remainingPath *big.Int, leafKey, value []byte, depthOffset int) (branch, error) { // Special checks for adding a leaf that already exists in the tree if branch.isLeaf() && branch.getPath().Cmp(remainingPath) == 0 { leafBranch := branch.(*LeafBranch) if leafBranch.isChild { - return newChildLeafBranch(leafBranch.Path, value), nil + return newChildLeafBranchWithKey(leafBranch.Path, leafKey, value), nil } else if bytes.Equal(leafBranch.Value, value) { return nil, ErrDuplicateLeaf } else { @@ -550,16 +766,18 @@ func (smt *SparseMerkleTree) buildTree(branch branch, remainingPath *big.Int, va // TypeScript: branch.path >> commonPath.length oldBranchPath := new(big.Int).Rsh(leafBranch.Path, uint(commonPath.BitLen()-1)) - oldBranch := newLeafBranch(oldBranchPath, leafBranch.Value) + oldBranch := newLeafBranchWithKey(oldBranchPath, leafBranch.Key, leafBranch.Value) // TypeScript: remainingPath >> commonPath.length newBranchPath := new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)) - newBranch := newLeafBranch(newBranchPath, value) + newBranch := newLeafBranchWithKey(newBranchPath, leafKey, value) + + nodeDepth := depthOffset + (commonPath.BitLen() - 1) if isRight { - return newNodeBranch(commonPath, oldBranch, newBranch), nil + return newNodeBranchWithDepth(commonPath, oldBranch, newBranch, nodeDepth), nil } else { - return newNodeBranch(commonPath, newBranch, oldBranch), nil + return newNodeBranchWithDepth(commonPath, newBranch, oldBranch, nodeDepth), nil } } @@ -567,30 +785,33 @@ func (smt *SparseMerkleTree) buildTree(branch branch, remainingPath *big.Int, va nodeBranch := branch.(*NodeBranch) if commonPath.Cmp(nodeBranch.Path) < 0 { newBranchPath := new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)) - newBranch := newLeafBranch(newBranchPath, value) + newBranch := newLeafBranchWithKey(newBranchPath, leafKey, value) oldBranchPath := new(big.Int).Rsh(nodeBranch.Path, uint(commonPath.BitLen()-1)) - 
oldBranch := newNodeBranch(oldBranchPath, nodeBranch.Left, nodeBranch.Right) + oldBranch := newNodeBranchWithDepth(oldBranchPath, nodeBranch.Left, nodeBranch.Right, int(nodeBranch.Depth)) + + nodeDepth := depthOffset + (commonPath.BitLen() - 1) if isRight { - return newNodeBranch(commonPath, oldBranch, newBranch), nil + return newNodeBranchWithDepth(commonPath, oldBranch, newBranch, nodeDepth), nil } else { - return newNodeBranch(commonPath, newBranch, oldBranch), nil + return newNodeBranchWithDepth(commonPath, newBranch, oldBranch, nodeDepth), nil } } + nextDepthOffset := depthOffset + (commonPath.BitLen() - 1) if isRight { - newRight, err := smt.buildTree(nodeBranch.Right, new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)), value) + newRight, err := smt.buildTree(nodeBranch.Right, new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)), leafKey, value, nextDepthOffset) if err != nil { return nil, err } - return newNodeBranch(nodeBranch.Path, nodeBranch.Left, newRight), nil + return newNodeBranchWithDepth(nodeBranch.Path, nodeBranch.Left, newRight, int(nodeBranch.Depth)), nil } else { - newLeft, err := smt.buildTree(nodeBranch.Left, new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)), value) + newLeft, err := smt.buildTree(nodeBranch.Left, new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)), leafKey, value, nextDepthOffset) if err != nil { return nil, err } - return newNodeBranch(nodeBranch.Path, newLeft, nodeBranch.Right), nil + return newNodeBranchWithDepth(nodeBranch.Path, newLeft, nodeBranch.Right, int(nodeBranch.Depth)), nil } } @@ -733,40 +954,3 @@ func NewLeaf(path *big.Int, value []byte) *Leaf { Value: append([]byte(nil), value...), } } - -// JoinPaths joins the hash proofs from a child and parent in sharded setting -func JoinPaths(child, parent *api.MerkleTreePath) (*api.MerkleTreePath, error) { - if child == nil { - return nil, errors.New("nil child path") - } - if parent == nil { - return nil, errors.New("nil parent 
path") - } - - // Root hashes are hex-encoded imprints, the first 4 characters are hash function identifiers - if len(child.Root) < 4 { - return nil, errors.New("invalid child root hash format") - } - if len(parent.Root) < 4 { - return nil, errors.New("invalid parent root hash format") - } - if child.Root[:4] != parent.Root[:4] { - return nil, errors.New("can't join paths: child hash algorithm does not match parent") - } - - if len(parent.Steps) == 0 { - return nil, errors.New("empty parent hash steps") - } - if parent.Steps[0].Data == nil || *parent.Steps[0].Data != child.Root[4:] { - return nil, errors.New("can't join paths: child root hash does not match parent input hash") - } - - steps := make([]api.MerkleTreeStep, len(child.Steps)+len(parent.Steps)-1) - copy(steps, child.Steps) - copy(steps[len(child.Steps):], parent.Steps[1:]) - - return &api.MerkleTreePath{ - Root: parent.Root, - Steps: steps, - }, nil -} diff --git a/internal/smt/smt_debug_test.go b/internal/smt/smt_debug_test.go index 927b8f0..e7037b9 100644 --- a/internal/smt/smt_debug_test.go +++ b/internal/smt/smt_debug_test.go @@ -46,7 +46,7 @@ func TestAddLeaves_DebugInvalidPath(t *testing.T) { return rh } - _smt := NewSparseMerkleTree(api.SHA256, 16+256) + _smt := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) { // mint commitment commJson := map[string]interface{}{ "stateId": "00007d535ade796772c5088b095e79a18e282437ee8d8238f5aa9d9c61694948ba9e", diff --git a/internal/smt/smt_memory_benchmark_test.go b/internal/smt/smt_memory_benchmark_test.go index 37883d5..fbc18f6 100644 --- a/internal/smt/smt_memory_benchmark_test.go +++ b/internal/smt/smt_memory_benchmark_test.go @@ -76,7 +76,7 @@ func BenchmarkSMTMemoryUsageRealistic(b *testing.B) { runtime.ReadMemStats(&memBefore) // Create SMT with realistic key length (16 bits shard prefix + 256 bits hash) - smtTree := NewSparseMerkleTree(api.SHA256, 16+256) + smtTree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) // Generate 
realistic commitments and add as leaves leaves := make([]*Leaf, size) @@ -141,7 +141,7 @@ func BenchmarkSMTOperationsWithLoad(b *testing.B) { // Pre-populate with 100k leaves const preloadSize = 100_000 - smtTree := NewSparseMerkleTree(api.SHA256, 16+256) + smtTree := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) leaves := make([]*Leaf, preloadSize) paths := make([]*big.Int, preloadSize) diff --git a/internal/smt/smt_test.go b/internal/smt/smt_test.go index d1106b2..ab05d15 100644 --- a/internal/smt/smt_test.go +++ b/internal/smt/smt_test.go @@ -14,44 +14,56 @@ import ( "github.com/stretchr/testify/require" ) +func normalizeLegacyPath(t *testing.T, raw string) *big.Int { + t.Helper() + + legacyPath, ok := new(big.Int).SetString(raw, 10) + require.True(t, ok, "failed to parse path") + + key, err := api.PathToFixedBytes(legacyPath, legacyPath.BitLen()-1) + require.NoError(t, err) + key, err = api.ImprintV2(key).GetTreeKey() + require.NoError(t, err) + + path, err := api.FixedBytesToPath(key, api.StateTreeKeyLengthBits) + require.NoError(t, err) + return path +} + // TestSMTGetRoot test basic SMT root hash computation func TestSMTGetRoot(t *testing.T) { - // "Singleton" example from the spec + // v2 reference values for basic tree shapes. 
t.Run("EmptyTree", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) - expected := "00001e54402898172f2948615fb17627733abbd120a85381c624ad060d28321be672" + expected := "000047dc540c94ceb704a23875c11273e16bb0b8a87aed84de911f2133568115f254" require.Equal(t, expected, smt.GetRootHashHex()) }) - // "Left Child Only" example from the spec t.Run("LeftLeaf", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) smt.AddLeaf(big.NewInt(0b100), []byte{0x61}) - expected := "0000ccd73506d27518c983860a47a6a323d41038a74f9339f5302798563cb168f12f" + expected := "0000d4cb5334dcabbcaff56bfc78706f041b72c0d29337db87d8c85d4e1aaf9fea3a" require.Equal(t, expected, smt.GetRootHashHex()) }) - // "Right Child Only" example from the spec t.Run("RightLeaf", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) smt.AddLeaf(big.NewInt(0b111), []byte{0x62}) - expected := "00005219d2dac90ad497a82a5231f10cffaf5a12dc65b762be39a6d739b4159136a3" + expected := "000064a2f31a60210df058e75a10312c486538f8874e4681de085e3e2d9985b5fd50" require.Equal(t, expected, smt.GetRootHashHex()) }) - // "Two Leaves" example from the spec t.Run("TwoLeaves", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) smt.AddLeaf(big.NewInt(0b100), []byte{0x61}) smt.AddLeaf(big.NewInt(0b111), []byte{0x62}) - expected := "0000b5fcdedf0f5e9cdaec060d8963b5ea86fcd16b7a48fa8607a3347a213316b857" + expected := "0000f0698f0230044b700c1e5e433f7776b8af113199905b6122b19504274dd77111" require.Equal(t, expected, smt.GetRootHashHex()) }) - // "Four Leaves" example from the spec t.Run("FourLeaves", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 3) smt.AddLeaf(big.NewInt(0b1000), []byte{0x61}) @@ -59,53 +71,48 @@ func TestSMTGetRoot(t *testing.T) { smt.AddLeaf(big.NewInt(0b1011), []byte{0x63}) smt.AddLeaf(big.NewInt(0b1111), []byte{0x64}) - expected := "000095005e568fdac5cc01a3a091c70ce89ab2da98c36b254dd2ddf29bd568c377ab" + expected := 
"0000728a4e5f71d239df87b57bdf1e3bd5ca3383d2b0d16758a9b3f2aedff02e4c24" require.Equal(t, expected, smt.GetRootHashHex()) }) } func TestChildSMTGetRoot(t *testing.T) { - // Left child of the "Two Leaves, Sharded" example from the spec t.Run("LeftOfTwoLeaves", func(t *testing.T) { smt := NewChildSparseMerkleTree(api.SHA256, 2, 0b10) smt.AddLeaf(big.NewInt(0b100), []byte{0x61}) - expected := "0000256aedd9f31e69a4b0803616beab77234bae5dff519a10e519a0753be49f0534" + expected := "0000d4cb5334dcabbcaff56bfc78706f041b72c0d29337db87d8c85d4e1aaf9fea3a" require.Equal(t, expected, smt.GetRootHashHex()) }) - // Right child of the "Two Leaves, Sharded" example from the spec t.Run("RightOfTwoLeaves", func(t *testing.T) { smt := NewChildSparseMerkleTree(api.SHA256, 2, 0b11) smt.AddLeaf(big.NewInt(0b111), []byte{0x62}) - expected := "0000e777763b4ce391c2f8acdf480dd64758bc8063a3aa5f62670a499a61d3bc7b9a" + expected := "000064a2f31a60210df058e75a10312c486538f8874e4681de085e3e2d9985b5fd50" require.Equal(t, expected, smt.GetRootHashHex()) }) - // Left child of the "Four Leaves, Sharded" example from the spec t.Run("LeftOfFourLeaves", func(t *testing.T) { smt := NewChildSparseMerkleTree(api.SHA256, 4, 0b110) smt.AddLeaf(big.NewInt(0b10010), []byte{0x61}) smt.AddLeaf(big.NewInt(0b11010), []byte{0x62}) - expected := "000010c1dc89e30d51613f2c1a182d16f87fe6709b9735db612adaadaa91955bdaf0" + expected := "0000564b213cf6cee27badc130c7b9c7f06c27b76e8bbe25149e1412646d24027d2d" require.Equal(t, expected, smt.GetRootHashHex()) }) - // Right child of the "Four Leaves, Sharded" example from the spec t.Run("RightOfFourLeaves", func(t *testing.T) { smt := NewChildSparseMerkleTree(api.SHA256, 4, 0b101) smt.AddLeaf(big.NewInt(0b10101), []byte{0x63}) smt.AddLeaf(big.NewInt(0b11101), []byte{0x64}) - expected := "0000981d2f4e01189506c5a36430e7774e3f9498c1c4cc27801d8e6400d4965a8860" + expected := "0000c5f0538e97bb172a7e423848673faa84141b2201cc803b328f1824299f24dd7f" require.Equal(t, expected, 
smt.GetRootHashHex()) }) } func TestParentSMTGetRoot(t *testing.T) { - // Parent of the "Two Leaves, Sharded" example from the spec t.Run("TwoLeaves", func(t *testing.T) { left, _ := hex.DecodeString("256aedd9f31e69a4b0803616beab77234bae5dff519a10e519a0753be49f0534") right, _ := hex.DecodeString("e777763b4ce391c2f8acdf480dd64758bc8063a3aa5f62670a499a61d3bc7b9a") @@ -113,11 +120,10 @@ func TestParentSMTGetRoot(t *testing.T) { smt.AddLeaf(big.NewInt(0b10), left) smt.AddLeaf(big.NewInt(0b11), right) - expected := "0000413b961d0069adfea0b4e122cf6dbf98e0a01ef7fd573d68c084ddfa03e4f9d6" + expected := "0000245915b6e866e0dfa36eb5c1323325c6663bd0ea7fe9ea7c60efe54700901577" require.Equal(t, expected, smt.GetRootHashHex()) }) - // Parent of the "Four Leaves, Sharded" example from the spec t.Run("FourLeaves", func(t *testing.T) { left, _ := hex.DecodeString("10c1dc89e30d51613f2c1a182d16f87fe6709b9735db612adaadaa91955bdaf0") right, _ := hex.DecodeString("981d2f4e01189506c5a36430e7774e3f9498c1c4cc27801d8e6400d4965a8860") @@ -125,7 +131,7 @@ func TestParentSMTGetRoot(t *testing.T) { smt.AddLeaf(big.NewInt(0b110), left) smt.AddLeaf(big.NewInt(0b101), right) - expected := "0000eb1a95574056c988f441a50bd18d0555f038276aecf3d155eb9e008a72afcb45" + expected := "00001f52283972b0b30de79673b0a889357af74859504f70a657c7516ab77b698302" require.Equal(t, expected, smt.GetRootHashHex()) }) } @@ -237,7 +243,7 @@ func TestSMTBatchOperations(t *testing.T) { // TestSMTRootHashRegressionFixture pins an implementation reference root hash // for a fixed leaf set, so refactors cannot accidentally change hash behavior. 
func TestSMTRootHashRegressionFixture(t *testing.T) { - const expectedRoot = "0000273e7c3e2d1a8d9194babbff44377e27facc2007893f0dd0db09f6406e04390f" + const expectedRoot = "00008f12d069a0a8d02649dae4485d97ea1d98f2742b5c22de64a4f331b6f0b7b7dd" leaves := []*Leaf{ NewLeaf(big.NewInt(0b110010000), []byte("value00010000")), // 400 @@ -264,9 +270,9 @@ func TestSMTRootHashRegressionFixture(t *testing.T) { // deterministic child-root inputs. func TestParentSMTRootHashRegressionFixture(t *testing.T) { const ( - expectedEmpty = "00001ef37149189b0c122d9ffc5eac0652c0d136ba71061a98dab823e4ba2544ee3f" - expectedOneUpdate = "0000be056f14098551f711bcf4ed2fa9becff6467565faffeaf612602d8a3fccdc9b" - expectedTwoUpdates = "000062a64635665e45b753e21e0489443a8589c8561094c3137201a2c275540d1aa8" + expectedEmpty = "0000cd123cd6893ea82539bbce16cd69f196ad3770a1a16806acd624061684f04c22" + expectedOneUpdate = "0000b3bf509ebc9114647fd69f72b817b257b07b8bd32ed82f6b85b8f5b19dedcfc8" + expectedTwoUpdates = "00000eb669e4b5572cd9c2cf3b4a18b491354f67c0591d549cc2879a63defe0a7759" ) make32 := func(start byte) []byte { @@ -546,11 +552,11 @@ func TestSMTProductionTiming(t *testing.T) { func TestSMTGetPath(t *testing.T) { t.Run("ExpectedPath", func(t *testing.T) { - smt := NewSparseMerkleTree(api.SHA256, 272) + smt := NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) - // Add some test data - path := big.NewInt(0) - path.SetString("7588617121771513359933852905331119149238064034818011809301695587375759386505263024", 10) + // Normalize the old 272-bit sentinel-prefixed fixture to the current + // 256-bit tree-key path form. 
+ path := normalizeLegacyPath(t, "7588617121771513359933852905331119149238064034818011809301695587375759386505263024") leafValue, err := hex.DecodeString("00000777e81da35187bc52073e96a10f89d7fe9aa826693982c8e748a96a3cc7d7b7") require.NoError(t, err) @@ -566,7 +572,7 @@ func TestSMTGetPath(t *testing.T) { require.NotNil(t, merklePath.Steps, "Steps should not be nil") require.Equal(t, 2, len(merklePath.Steps), "There should be exactly two steps in the path") // First step should be the LeafNode hash step - require.Equal(t, "7588617121771513359933852905331119149238064034818011809301695587375759386505263024", merklePath.Steps[0].Path, "Leaf step path should match") + require.Equal(t, path.String(), merklePath.Steps[0].Path, "Leaf step path should match") require.Equal(t, "00000777e81da35187bc52073e96a10f89d7fe9aa826693982c8e748a96a3cc7d7b7", *merklePath.Steps[0].Data, "Leaf step value should match") // Second step should be the root hash step require.Equal(t, "1", merklePath.Steps[1].Path, "Root step path should be empty") @@ -1375,155 +1381,6 @@ func TestSMTAddingLeafAboveNode(t *testing.T) { require.Error(t, smt2.AddLeaves(leaves2), "SMT should not allow adding leaves above existing nodes, even in a batch") } -func TestJoinPaths(t *testing.T) { - // "Two Leaves, Sharded" example from the spec - t.Run("TwoLeaves", func(t *testing.T) { - left := NewChildSparseMerkleTree(api.SHA256, 2, 0b10) - left.AddLeaf(big.NewInt(0b100), []byte{0x61}) - - right := NewChildSparseMerkleTree(api.SHA256, 2, 0b11) - right.AddLeaf(big.NewInt(0b111), []byte{0x62}) - - parent := NewParentSparseMerkleTree(api.SHA256, 1) - parent.AddLeaf(big.NewInt(0b10), left.GetRootHash()[2:]) - parent.AddLeaf(big.NewInt(0b11), right.GetRootHash()[2:]) - - leftChild, _ := left.GetPath(big.NewInt(0b100)) - leftParent, _ := parent.GetPath(big.NewInt(0b10)) - leftPath, err := JoinPaths(leftChild, leftParent) - assert.Nil(t, err) - assert.NotNil(t, leftPath) - leftRes, err := leftPath.Verify(big.NewInt(0b100)) 
- assert.Nil(t, err) - assert.NotNil(t, leftRes) - assert.True(t, leftRes.PathValid) - assert.True(t, leftRes.PathIncluded) - - rightChild, _ := right.GetPath(big.NewInt(0b111)) - rightParent, _ := parent.GetPath(big.NewInt(0b11)) - rightPath, err := JoinPaths(rightChild, rightParent) - assert.Nil(t, err) - assert.NotNil(t, rightPath) - rightRes, err := rightPath.Verify(big.NewInt(0b111)) - assert.Nil(t, err) - assert.NotNil(t, rightRes) - assert.True(t, rightRes.PathValid) - assert.True(t, rightRes.PathIncluded) - }) - - // "Four Leaves, Sharded" example from the spec - t.Run("FourLeaves", func(t *testing.T) { - left := NewChildSparseMerkleTree(api.SHA256, 4, 0b110) - left.AddLeaf(big.NewInt(0b10010), []byte{0x61}) - left.AddLeaf(big.NewInt(0b11010), []byte{0x62}) - - right := NewChildSparseMerkleTree(api.SHA256, 4, 0b101) - right.AddLeaf(big.NewInt(0b10101), []byte{0x63}) - right.AddLeaf(big.NewInt(0b11101), []byte{0x64}) - - parent := NewParentSparseMerkleTree(api.SHA256, 2) - parent.AddLeaf(big.NewInt(0b110), left.GetRootHash()[2:]) - parent.AddLeaf(big.NewInt(0b101), right.GetRootHash()[2:]) - - child1, _ := left.GetPath(big.NewInt(0b10010)) - parent1, _ := parent.GetPath(big.NewInt(0b110)) - path1, err := JoinPaths(child1, parent1) - assert.Nil(t, err) - assert.NotNil(t, path1) - res1, err := path1.Verify(big.NewInt(0b10010)) - assert.Nil(t, err) - assert.NotNil(t, res1) - assert.True(t, res1.PathValid) - assert.True(t, res1.PathIncluded) - - child2, _ := left.GetPath(big.NewInt(0b11010)) - parent2, _ := parent.GetPath(big.NewInt(0b110)) - path2, err := JoinPaths(child2, parent2) - assert.Nil(t, err) - assert.NotNil(t, path2) - res2, err := path2.Verify(big.NewInt(0b11010)) - assert.Nil(t, err) - assert.NotNil(t, res2) - assert.True(t, res2.PathValid) - assert.True(t, res2.PathIncluded) - - child3, _ := right.GetPath(big.NewInt(0b10101)) - parent3, _ := parent.GetPath(big.NewInt(0b101)) - path3, err := JoinPaths(child3, parent3) - assert.Nil(t, err) - 
assert.NotNil(t, path3) - res3, err := path3.Verify(big.NewInt(0b10101)) - assert.Nil(t, err) - assert.NotNil(t, res3) - assert.True(t, res3.PathValid) - assert.True(t, res3.PathIncluded) - - child4, _ := right.GetPath(big.NewInt(0b11101)) - parent4, _ := parent.GetPath(big.NewInt(0b101)) - path4, err := JoinPaths(child4, parent4) - assert.Nil(t, err) - assert.NotNil(t, path4) - res4, err := path4.Verify(big.NewInt(0b11101)) - assert.Nil(t, err) - assert.NotNil(t, res4) - assert.True(t, res4.PathValid) - assert.True(t, res4.PathIncluded) - }) - - t.Run("NilInputDoesNotPanic", func(t *testing.T) { - dummy := &api.MerkleTreePath{Root: "0000"} - - joinedPath, err := JoinPaths(nil, dummy) - assert.ErrorContains(t, err, "nil child path") - assert.Nil(t, joinedPath) - - joinedPath, err = JoinPaths(dummy, nil) - assert.ErrorContains(t, err, "nil parent path") - assert.Nil(t, joinedPath) - }) - - t.Run("NilRootDoesNotPanic", func(t *testing.T) { - dummyNil := &api.MerkleTreePath{} - dummyShort := &api.MerkleTreePath{Root: ""} - dummyOK := &api.MerkleTreePath{Root: "0000"} - - joinedPath, err := JoinPaths(dummyNil, dummyOK) - assert.ErrorContains(t, err, "invalid child root hash format") - assert.Nil(t, joinedPath) - - joinedPath, err = JoinPaths(dummyShort, dummyOK) - assert.ErrorContains(t, err, "invalid child root hash format") - assert.Nil(t, joinedPath) - - joinedPath, err = JoinPaths(dummyOK, dummyNil) - assert.ErrorContains(t, err, "invalid parent root hash format") - assert.Nil(t, joinedPath) - - joinedPath, err = JoinPaths(dummyOK, dummyShort) - assert.ErrorContains(t, err, "invalid parent root hash format") - assert.Nil(t, joinedPath) - }) - - t.Run("HashFunctionMismatch", func(t *testing.T) { - child := &api.MerkleTreePath{Root: "0000"} - parent := &api.MerkleTreePath{Root: "0001"} - - joinedPath, err := JoinPaths(child, parent) - assert.ErrorContains(t, err, "child hash algorithm does not match parent") - assert.Nil(t, joinedPath) - }) - - 
t.Run("HashValueMismatch", func(t *testing.T) { - smt := NewSparseMerkleTree(api.SHA256, 1) - smt.AddLeaf(big.NewInt(0b10), []byte{0}) - path, _ := smt.GetPath(big.NewInt(0b10)) - - joinedPath, err := JoinPaths(path, path) - assert.ErrorContains(t, err, "child root hash does not match parent input hash") - assert.Nil(t, joinedPath) - }) -} - // TestParentSMTSnapshotUpdateLeaf tests that parent SMT snapshots can update pre-populated leaves func TestParentSMTSnapshotUpdateLeaf(t *testing.T) { // Create parent SMT with ShardIDLength=1 (creates pre-populated leaves at paths 2 and 3) diff --git a/internal/smt/thread_safe_smt.go b/internal/smt/thread_safe_smt.go index fa5da1f..5ead109 100644 --- a/internal/smt/thread_safe_smt.go +++ b/internal/smt/thread_safe_smt.go @@ -17,6 +17,9 @@ type ThreadSafeSMT struct { // NewThreadSafeSMT creates a new thread-safe SMT wrapper func NewThreadSafeSMT(smtInstance *SparseMerkleTree) *ThreadSafeSMT { + // Prime hash caches before publishing the wrapper so no later read path + // under RLock needs to mutate node state on a freshly constructed tree. 
+ smtInstance.ensureHashes() return &ThreadSafeSMT{ smt: smtInstance, } @@ -42,7 +45,11 @@ func (ts *ThreadSafeSMT) AddLeaf(path *big.Int, value []byte) error { ts.rwMux.Lock() defer ts.rwMux.Unlock() - return ts.smt.AddLeaf(path, value) + if err := ts.smt.AddLeaf(path, value); err != nil { + return err + } + ts.smt.ensureHashes() + return nil } // AddPreHashedLeaf adds a leaf where the value is already a hash calculated externally @@ -51,8 +58,10 @@ func (ts *ThreadSafeSMT) AddPreHashedLeaf(path *big.Int, hash []byte) error { ts.rwMux.Lock() defer ts.rwMux.Unlock() - // TODO(SMT): Implement AddPreHashedLeaf in SparseMerkleTree - //return ts.smt.AddPreHashedLeaf(path, hash) + if err := ts.smt.AddLeaf(path, hash); err != nil { + return err + } + ts.smt.ensureHashes() return nil } @@ -82,6 +91,49 @@ func (ts *ThreadSafeSMT) GetPath(path *big.Int) (*api.MerkleTreePath, error) { return ts.smt.GetPath(path) } +// GetRootHashRaw returns the raw 32-byte root hash without algorithm prefix. +// This is a read operation and allows concurrent access. +func (ts *ThreadSafeSMT) GetRootHashRaw() []byte { + ts.rwMux.RLock() + defer ts.rwMux.RUnlock() + return ts.smt.GetRootHashRaw() +} + +// GetInclusionCert builds a v2 inclusion certificate for the leaf +// at the given raw 32-byte key. This is a read operation and allows +// concurrent access. +func (ts *ThreadSafeSMT) GetInclusionCert(key []byte) (*api.InclusionCert, error) { + ts.rwMux.RLock() + defer ts.rwMux.RUnlock() + return ts.smt.GetInclusionCert(key) +} + +// GetShardInclusionFragment builds the native parent proof fragment for the +// given shard ID. This is only valid on parent-mode SMT instances. 
+func (ts *ThreadSafeSMT) GetShardInclusionFragment(shardID api.ShardID) (*api.ParentInclusionFragment, error) { + ts.rwMux.RLock() + defer ts.rwMux.RUnlock() + return ts.smt.GetShardInclusionFragment(shardID) +} + +// GetShardInclusionFragmentWithRoot atomically reads the parent fragment and +// the raw SMT root from the same in-memory snapshot. This avoids serving a +// fragment from one root and looking up a block for a newer root in a later +// read section. +func (ts *ThreadSafeSMT) GetShardInclusionFragmentWithRoot(shardID api.ShardID) (*api.ParentInclusionFragment, []byte, error) { + ts.rwMux.RLock() + defer ts.rwMux.RUnlock() + + parentFragment, err := ts.smt.GetShardInclusionFragment(shardID) + if err != nil { + return nil, nil, err + } + if parentFragment == nil { + return nil, nil, nil + } + return parentFragment, ts.smt.GetRootHashRaw(), nil +} + // GetKeyLength exposes the configured SMT key length. func (ts *ThreadSafeSMT) GetKeyLength() int { ts.rwMux.RLock() @@ -124,7 +176,11 @@ func (ts *ThreadSafeSMT) WithReadLock(fn func() error) error { func (ts *ThreadSafeSMT) WithWriteLock(fn func() error) error { ts.rwMux.Lock() defer ts.rwMux.Unlock() - return fn() + if err := fn(); err != nil { + return err + } + ts.smt.ensureHashes() + return nil } // CreateSnapshot creates a thread-safe snapshot of the current SMT state diff --git a/internal/smt/thread_safe_smt_snapshot.go b/internal/smt/thread_safe_smt_snapshot.go index 3c7c78b..7b482f3 100644 --- a/internal/smt/thread_safe_smt_snapshot.go +++ b/internal/smt/thread_safe_smt_snapshot.go @@ -18,6 +18,9 @@ type ThreadSafeSmtSnapshot struct { // NewThreadSafeSmtSnapshot creates a new thread-safe SMT snapshot wrapper func NewThreadSafeSmtSnapshot(snapshot *SmtSnapshot) *ThreadSafeSmtSnapshot { + // Prime hash caches before publishing the snapshot wrapper so concurrent + // readers never need to populate cache state under RLock. 
+ snapshot.ensureHashes() return &ThreadSafeSmtSnapshot{ snapshot: snapshot, } @@ -53,7 +56,11 @@ func (tss *ThreadSafeSmtSnapshot) AddLeaf(path *big.Int, value []byte) error { // addLeafUnsafe adds a single leaf without acquiring locks (internal use) func (tss *ThreadSafeSmtSnapshot) addLeafUnsafe(path *big.Int, value []byte) error { - return tss.snapshot.AddLeaf(path, value) + if err := tss.snapshot.AddLeaf(path, value); err != nil { + return err + } + tss.snapshot.ensureHashes() + return nil } // GetRootHash returns the current root hash of the snapshot @@ -72,6 +79,25 @@ func (tss *ThreadSafeSmtSnapshot) GetPath(path *big.Int) (*api.MerkleTreePath, e return tss.snapshot.GetPath(path) } +// GetRootHashRaw returns the raw 32-byte root hash of the snapshot. +// This is a read operation that can be performed concurrently. +func (tss *ThreadSafeSmtSnapshot) GetRootHashRaw() []byte { + tss.rwMux.RLock() + defer tss.rwMux.RUnlock() + + return tss.snapshot.GetRootHashRaw() +} + +// GetInclusionCert builds a v2 inclusion certificate for the leaf +// at the given raw 32-byte key in this snapshot. This is a read operation +// that can be performed concurrently. +func (tss *ThreadSafeSmtSnapshot) GetInclusionCert(key []byte) (*api.InclusionCert, error) { + tss.rwMux.RLock() + defer tss.rwMux.RUnlock() + + return tss.snapshot.GetInclusionCert(key) +} + // GetStats returns statistics about the snapshot // This is a read operation that can be performed concurrently func (tss *ThreadSafeSmtSnapshot) GetStats() map[string]interface{} { @@ -100,6 +126,7 @@ func (tss *ThreadSafeSmtSnapshot) Commit(originalSMT *ThreadSafeSMT) { defer originalSMT.rwMux.Unlock() tss.snapshot.Commit() + originalSMT.smt.ensureHashes() } // SetCommitTarget changes the target tree for snapshot chaining. 
@@ -127,5 +154,9 @@ func (tss *ThreadSafeSmtSnapshot) CreateSnapshot() *ThreadSafeSmtSnapshot { func (tss *ThreadSafeSmtSnapshot) WithWriteLock(fn func(*SmtSnapshot) error) error { tss.rwMux.Lock() defer tss.rwMux.Unlock() - return fn(tss.snapshot) + if err := fn(tss.snapshot); err != nil { + return err + } + tss.snapshot.ensureHashes() + return nil } diff --git a/internal/smt/thread_safe_smt_test.go b/internal/smt/thread_safe_smt_test.go new file mode 100644 index 0000000..cb60e77 --- /dev/null +++ b/internal/smt/thread_safe_smt_test.go @@ -0,0 +1,72 @@ +package smt + +import ( + "bytes" + "math/big" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +func TestThreadSafeSMT_AddPreHashedLeaf_StoresChildRoot(t *testing.T) { + tree := NewThreadSafeSMT(NewParentSparseMerkleTree(api.SHA256, 2)) + + path := big.NewInt(4) // shard ID with sentinel bit for a 2-bit parent tree + hash := bytes.Repeat([]byte{0xab}, 32) + + require.NoError(t, tree.AddPreHashedLeaf(path, hash)) + + leaf, err := tree.GetLeaf(path) + require.NoError(t, err) + require.Equal(t, hash, leaf.Value) + require.True(t, leaf.isChild) +} + +func TestThreadSafeSMT_GetShardInclusionFragment_ReturnsNativeParentFragment(t *testing.T) { + tree := NewThreadSafeSMT(NewParentSparseMerkleTree(api.SHA256, 2)) + + path := big.NewInt(4) // shard ID with sentinel bit for a 2-bit parent tree + hash := bytes.Repeat([]byte{0xcd}, 32) + + require.NoError(t, tree.AddPreHashedLeaf(path, hash)) + + fragment, err := tree.GetShardInclusionFragment(4) + require.NoError(t, err) + require.NotNil(t, fragment) + require.Equal(t, hash, []byte(fragment.ShardLeafValue)) + require.GreaterOrEqual(t, len(fragment.CertificateBytes), api.BitmapSize) + require.Equal(t, 0, len(fragment.CertificateBytes)%api.SiblingSize) + require.NoError(t, fragment.Verify(4, 2, hash, tree.GetRootHashRaw(), api.SHA256)) +} + +func TestThreadSafeSMT_PrimesHashesOnConstruction(t *testing.T) { + tree 
:= NewThreadSafeSMT(NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) + requireBranchHashesPrimed(t, tree.smt.root) +} + +func TestThreadSafeSMTSnapshot_PrimesHashesOnConstruction(t *testing.T) { + tree := NewThreadSafeSMT(NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits)) + snapshot := tree.CreateSnapshot() + requireBranchHashesPrimed(t, snapshot.snapshot.root) +} + +func requireBranchHashesPrimed(t *testing.T, b branch) { + t.Helper() + + switch node := b.(type) { + case *LeafBranch: + require.True(t, node.hashSet) + case *NodeBranch: + require.True(t, node.hashSet) + if node.Left != nil { + requireBranchHashesPrimed(t, node.Left) + } + if node.Right != nil { + requireBranchHashesPrimed(t, node.Right) + } + default: + require.Failf(t, "unexpected branch type", "%T", b) + } +} diff --git a/internal/smt/yellowpaper_hash_semantics_test.go b/internal/smt/yellowpaper_hash_semantics_test.go new file mode 100644 index 0000000..d0efa46 --- /dev/null +++ b/internal/smt/yellowpaper_hash_semantics_test.go @@ -0,0 +1,54 @@ +package smt + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +func TestLeafHash_DomainSeparated(t *testing.T) { + leaf := newLeafBranch(big.NewInt(0b10101), []byte("value")) + hasher := api.NewDataHasher(api.SHA256) + + got := leaf.calculateHash(hasher) + + expectedHasher := api.NewDataHasher(api.SHA256) + expectedHasher.Reset(). + AddData([]byte{0x00}). + AddData(leaf.Key). 
+ AddData(leaf.Value) + expected := expectedHasher.GetHash().RawHash + + require.Equal(t, expected, got) +} + +func TestNodeHash_UnaryPassthrough(t *testing.T) { + leftLeaf := newLeafBranch(big.NewInt(0b101), []byte("left")) + node := newNodeBranch(big.NewInt(0b10), leftLeaf, nil) + + hasher := api.NewDataHasher(api.SHA256) + got := node.calculateHash(hasher) + + require.Equal(t, leftLeaf.calculateHash(api.NewDataHasher(api.SHA256)), got) +} + +func TestNodeHash_BinaryDomainSeparated(t *testing.T) { + leftLeaf := newLeafBranch(big.NewInt(0b10), []byte("left")) + rightLeaf := newLeafBranch(big.NewInt(0b11), []byte("right")) + node := newNodeBranch(big.NewInt(0b10), leftLeaf, rightLeaf) + + hasher := api.NewDataHasher(api.SHA256) + got := node.calculateHash(hasher) + + expectedHasher := api.NewDataHasher(api.SHA256) + expectedHasher.Reset(). + AddData([]byte{0x01, node.Depth}). + AddData(leftLeaf.calculateHash(api.NewDataHasher(api.SHA256))). + AddData(rightLeaf.calculateHash(api.NewDataHasher(api.SHA256))) + expected := expectedHasher.GetHash().RawHash + + require.Equal(t, expected, got) +} diff --git a/pkg/api/certification_request.go b/pkg/api/certification_request.go index c9af6c0..ef458d1 100644 --- a/pkg/api/certification_request.go +++ b/pkg/api/certification_request.go @@ -12,9 +12,9 @@ import ( type CertificationRequest struct { _ struct{} `cbor:",toarray"` - // StateID is the unique identifier of the certification request, used as a key in the state tree. - // Calculated as hash of CBOR array [CertificationData.OwnerPredicate, CertificationData.SourceStateHashImprint], - // prefixed by two bytes that define the hashing algorithm (two zero bytes in case of SHA2_256). + // StateID is the unique identifier of the certification request, used as a + // key in the state tree. In v2 it is the raw 32-byte hash of the CBOR array + // [CertificationData.OwnerPredicate, CertificationData.SourceStateHash]. 
StateID StateID // CertificationData contains the necessary cryptographic data needed for the CertificationRequest. @@ -58,12 +58,10 @@ type CertificationData struct { // - params = 5821 000102..20 (byte array of length 33 containing the raw bytes of the public key value) OwnerPredicate Predicate `json:"ownerPredicate"` - // SourceStateHash is the source data (token) hash, - // prefixed by two bytes that define the hashing algorithm (two zero bytes in case of SHA2_256). + // SourceStateHash is the raw 32-byte hash of the source data. SourceStateHash SourceStateHash `json:"sourceStateHash"` - // TransactionHash is the entire transaction data hash (including the source data), - // prefixed by two bytes that define the hashing algorithm (two zero bytes in case of SHA2_256). + // TransactionHash is the raw 32-byte hash of the transaction data. TransactionHash TransactionHash `json:"transactionHash"` // Witness is the "unlocking part" of owner predicate. In case of PayToPublicKey owner predicate the witness must be @@ -73,13 +71,13 @@ type CertificationData struct { } // SigDataHash returns the data hash used for signature generation. -// The hash is calculated as CBOR array of [sourceStateHashImprint, transactionHashImprint]. +// The hash is calculated as the CBOR array [SourceStateHash, TransactionHash]. func (c CertificationData) SigDataHash() (*DataHash, error) { return SigDataHash(c.SourceStateHash, c.TransactionHash), nil } // SigDataHash returns the data hash used for signature generation. -// The hash is calculated as CBOR array of [sourceStateHashImprint, transactionHashImprint]. +// The hash is calculated as the CBOR array [sourceStateHash, transactionHash]. func SigDataHash(sourceStateHash []byte, transactionHash []byte) *DataHash { return NewDataHasher(SHA256).AddData( CborArray(2)). 
@@ -88,9 +86,9 @@ func SigDataHash(sourceStateHash []byte, transactionHash []byte) *DataHash { GetHash() } -// Hash returns the data hash of certification data, used as a key in the state tree. -// The hash is calculated as CBOR array of [OwnerPredicate, SourceStateHashImprint, TransactionHashImprint, Witness] and -// the value returned is in DataHash imprint format (2-byte algorithm prefix + hash of cbor array). +// Hash returns the data hash of certification data. +// The hash is calculated as the CBOR array +// [OwnerPredicate, SourceStateHash, TransactionHash, Witness]. func (c CertificationData) Hash() ([]byte, error) { dataHash, err := CertDataHash(c.OwnerPredicate, c.SourceStateHash, c.TransactionHash, c.Witness) if err != nil { @@ -103,8 +101,9 @@ func (c CertificationData) CreateStateID() (StateID, error) { return CreateStateID(c.OwnerPredicate, c.SourceStateHash) } -// CertDataHash returns the data hash of certification data, used as a key in the state tree. -// The hash is calculated as CBOR array of [OwnerPredicate, SourceStateHashImprint, TransactionHashImprint, Witness]. +// CertDataHash returns the data hash of certification data. +// The hash is calculated as the CBOR array +// [OwnerPredicate, SourceStateHash, TransactionHash, Witness]. func CertDataHash(ownerPredicate Predicate, sourceStateHash, transactionHash, signature []byte) (*DataHash, error) { predicateBytes, err := types.Cbor.Marshal(ownerPredicate) if err != nil { diff --git a/pkg/api/inclusion_cert.go b/pkg/api/inclusion_cert.go new file mode 100644 index 0000000..7ad3ec9 --- /dev/null +++ b/pkg/api/inclusion_cert.go @@ -0,0 +1,250 @@ +package api + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "math/bits" +) + +// SiblingSize is the fixed byte length of each sibling hash and of the +// leaf key / value hashes in an InclusionCert or ExclusionCert wire +// encoding. All supported SMT hash algorithms (SHA-256, SHA-3-256) +// produce 32-byte digests. 
+const SiblingSize = 32 + +// BitmapSize is the fixed byte length of the depth bitmap. The SMT +// key is StateTreeKeyLengthBits (256) bits, so the bitmap is 32 bytes. +const BitmapSize = StateTreeKeyLengthBytes + +// maxDepth is the exclusive upper bound on tree depth indices. Depth +// values walked by the verifier are in [0, maxDepth). +const maxDepth = StateTreeKeyLengthBits + +// Errors returned by certificate decoding and verification. +var ( + ErrCertTruncated = errors.New("inclusion cert: truncated") + ErrCertMisalignedSibs = errors.New("inclusion cert: sibling bytes not aligned to 32") + ErrCertBitmapMismatch = errors.New("inclusion cert: sibling count does not match bitmap popcount") + ErrCertRootMismatch = errors.New("inclusion cert: root mismatch") + ErrCertSiblingUnderflow = errors.New("inclusion cert: sibling underflow during verification") + ErrCertKeyLength = errors.New("inclusion cert: invalid key length") + ErrCertRootLength = errors.New("inclusion cert: invalid root length") + ErrCertUnknownAlgo = errors.New("inclusion cert: unknown hash algorithm") + ErrExclusionNotImpl = errors.New("exclusion cert: verification not yet implemented") +) + +// InclusionCert is the decoded v2 inclusion certificate. +// +// Wire format (raw binary, no framing): +// +// bitmap[32] || s_1[32] || ... || s_n[32] +// +// where n = popcount(bitmap). Siblings are in generation order +// (root-to-leaf): s_1 is the sibling at the shallowest depth with a +// bitmap bit set, s_n at the deepest. Verification walks depths +// 255..0 and consumes siblings from the end of the slice. +// +// The certificate carries no root, no key, and no value. Verification +// requires these to be supplied from the outer proof tuple: +// - key (sid) — from the RPC request parameter. +// - value (txhash) — from CertificationData.TransactionHash. +// - root — from UC.IR.h. +// +// See docs/inclusion-proof-wire.md for the full specification. 
+type InclusionCert struct { + Bitmap [BitmapSize]byte + Siblings [][SiblingSize]byte +} + +// MarshalBinary encodes the certificate to its wire form. +func (c *InclusionCert) MarshalBinary() ([]byte, error) { + out := make([]byte, 0, BitmapSize+len(c.Siblings)*SiblingSize) + out = append(out, c.Bitmap[:]...) + for i := range c.Siblings { + out = append(out, c.Siblings[i][:]...) + } + return out, nil +} + +// UnmarshalBinary decodes the wire form into the certificate. The +// sibling count is validated against the bitmap popcount. +func (c *InclusionCert) UnmarshalBinary(data []byte) error { + if len(data) < BitmapSize { + return ErrCertTruncated + } + copy(c.Bitmap[:], data[:BitmapSize]) + rest := data[BitmapSize:] + if len(rest)%SiblingSize != 0 { + return ErrCertMisalignedSibs + } + actual := len(rest) / SiblingSize + expected := bitmapPopcount(&c.Bitmap) + if actual != expected { + return fmt.Errorf("%w: have %d, want %d", ErrCertBitmapMismatch, actual, expected) + } + c.Siblings = make([][SiblingSize]byte, actual) + for i := 0; i < actual; i++ { + copy(c.Siblings[i][:], rest[i*SiblingSize:(i+1)*SiblingSize]) + } + return nil +} + +// Verify checks that applying the bitmap + siblings path on top of +// H_leaf(key, value) reproduces expectedRoot under the given hash +// algorithm. +// +// Parameters: +// - key: 32-byte SMT key, LSB-first layout. +// - value: raw leaf value bytes (v2 inclusion proofs use the tx hash). +// - expectedRoot: raw 32-byte root hash, taken from UC.IR.h. +// - algo: hash algorithm used by the SMT. +func (c *InclusionCert) Verify(key, value, expectedRoot []byte, algo HashAlgorithm) error { + if len(key) != StateTreeKeyLengthBytes { + return fmt.Errorf("%w: got %d, want %d", ErrCertKeyLength, len(key), StateTreeKeyLengthBytes) + } + + // Leaf hash: H(0x00 || key || value). + hasher := NewDataHasher(algo) + if hasher == nil { + return fmt.Errorf("%w: %d", ErrCertUnknownAlgo, algo) + } + hasher.Reset(). + AddData([]byte{0x00}). 
+ AddData(key). + AddData(value) + h := hasher.GetHash().RawHash + + return verifyBitmapPath(&c.Bitmap, c.Siblings, key, h, expectedRoot, algo) +} + +func verifyBitmapPath(bitmap *[BitmapSize]byte, siblings [][SiblingSize]byte, key, startHash, expectedRoot []byte, algo HashAlgorithm) error { + if len(startHash) != SiblingSize { + return fmt.Errorf("%w: got %d, want %d", ErrCertRootLength, len(startHash), SiblingSize) + } + if len(expectedRoot) != SiblingSize { + return fmt.Errorf("%w: got %d, want %d", ErrCertRootLength, len(expectedRoot), SiblingSize) + } + hasher := NewDataHasher(algo) + if hasher == nil { + return fmt.Errorf("%w: %d", ErrCertUnknownAlgo, algo) + } + + // Walk depths from deepest to shallowest, consuming siblings from + // the end of the slice. Depths with bitmap bit clear are skipped + // (unary passthrough or off-path). + h := append([]byte(nil), startHash...) + j := len(siblings) + for d := maxDepth - 1; d >= 0; d-- { + if ((*bitmap)[d/8]>>(uint(d)%8))&1 == 0 { + continue + } + if d/8 >= len(key) { + return fmt.Errorf("%w: key too short for depth %d", ErrCertKeyLength, d) + } + if j == 0 { + return ErrCertSiblingUnderflow + } + j-- + sibling := siblings[j][:] + + hasher.Reset().AddData([]byte{0x01, byte(d)}) + if keyBitAt(key, d) == 1 { + // Descent went right at depth d → sibling is the left child. + hasher.AddData(sibling).AddData(h) + } else { + // Descent went left at depth d → sibling is the right child. + hasher.AddData(h).AddData(sibling) + } + h = hasher.GetHash().RawHash + } + if j != 0 { + return fmt.Errorf("%w: %d siblings unused", ErrCertBitmapMismatch, j) + } + if !bytes.Equal(h, expectedRoot) { + return ErrCertRootMismatch + } + return nil +} + +// ExclusionCert is the decoded v2 non-inclusion certificate. +// +// Wire format (raw binary, no framing): +// +// k_l[32] || h_l[32] || bitmap[32] || s_1[32] || ... 
|| s_n[32] +// +// (k_l, h_l) is the witness leaf present in the tree at the position +// reached when routing the query key. bitmap + siblings describe the +// proof path from the root to that position, under the same root-to- +// leaf sibling ordering as InclusionCert. +// +// Verification semantics are not yet implemented in Go. The type and +// codec are frozen so clients can decode today; see +// docs/inclusion-proof-wire.md. +type ExclusionCert struct { + KL [SiblingSize]byte + HL [SiblingSize]byte + Bitmap [BitmapSize]byte + Siblings [][SiblingSize]byte +} + +// MarshalBinary encodes the exclusion certificate to its wire form. +func (c *ExclusionCert) MarshalBinary() ([]byte, error) { + out := make([]byte, 0, 2*SiblingSize+BitmapSize+len(c.Siblings)*SiblingSize) + out = append(out, c.KL[:]...) + out = append(out, c.HL[:]...) + out = append(out, c.Bitmap[:]...) + for i := range c.Siblings { + out = append(out, c.Siblings[i][:]...) + } + return out, nil +} + +// UnmarshalBinary decodes the wire form into the exclusion certificate. +// The sibling count is validated against the bitmap popcount. +func (c *ExclusionCert) UnmarshalBinary(data []byte) error { + const head = 2*SiblingSize + BitmapSize + if len(data) < head { + return ErrCertTruncated + } + copy(c.KL[:], data[:SiblingSize]) + copy(c.HL[:], data[SiblingSize:2*SiblingSize]) + copy(c.Bitmap[:], data[2*SiblingSize:head]) + rest := data[head:] + if len(rest)%SiblingSize != 0 { + return ErrCertMisalignedSibs + } + actual := len(rest) / SiblingSize + expected := bitmapPopcount(&c.Bitmap) + if actual != expected { + return fmt.Errorf("%w: have %d, want %d", ErrCertBitmapMismatch, actual, expected) + } + c.Siblings = make([][SiblingSize]byte, actual) + for i := 0; i < actual; i++ { + copy(c.Siblings[i][:], rest[i*SiblingSize:(i+1)*SiblingSize]) + } + return nil +} + +// Verify is not yet implemented for exclusion certificates. The wire +// schema is frozen so clients can decode. 
+func (c *ExclusionCert) Verify(queryKey, expectedRoot []byte, algo HashAlgorithm) error { + return ErrExclusionNotImpl +} + +// bitmapPopcount counts the set bits in the 32-byte depth bitmap. +func bitmapPopcount(b *[BitmapSize]byte) int { + total := 0 + for i := 0; i < BitmapSize; i += 8 { + total += bits.OnesCount64(binary.LittleEndian.Uint64(b[i:])) + } + return total +} + +// keyBitAt returns bit d of key under LSB-first byte layout: +// bit d is bit (d mod 8) of key[d / 8]. Matches PathToFixedBytes / +// FixedBytesToPath in state_id.go. +func keyBitAt(key []byte, d int) byte { + return (key[d/8] >> (uint(d) % 8)) & 1 +} diff --git a/pkg/api/inclusion_cert_compose.go b/pkg/api/inclusion_cert_compose.go new file mode 100644 index 0000000..d2b01c1 --- /dev/null +++ b/pkg/api/inclusion_cert_compose.go @@ -0,0 +1,101 @@ +package api + +import ( + "bytes" + "errors" + "fmt" +) + +var ( + ErrCertDepthOverlap = errors.New("inclusion cert: parent and child cert overlap in depth") + ErrCertDepthOrder = errors.New("inclusion cert: parent cert depths must be shallower than child cert depths") + ErrCertChildRootMismatch = errors.New("inclusion cert: parent fragment shard leaf value does not match child root") + ErrCertMissingChild = errors.New("inclusion cert: missing child cert") + ErrCertMissingParent = errors.New("inclusion cert: missing parent fragment") +) + +// ComposeInclusionCert merges a child inclusion certificate with the stored +// parent proof fragment for the child shard. The result is a single public +// InclusionCert that can later be verified against the parent UC.IR.h. 
+// +// Invariants enforced here: +// - parentFragment.ShardLeafValue must equal childRoot +// - parent fragment certificate bytes must decode as a valid InclusionCert +// - parent and child bitmaps must not overlap in depth +// - if both certs contain siblings, every parent depth must be shallower +// than every child depth +// - merged siblings stay in root-to-leaf order: parent first, then child +func ComposeInclusionCert(parentFragment *ParentInclusionFragment, child *InclusionCert, childRoot []byte) (*InclusionCert, error) { + if parentFragment == nil { + return nil, ErrCertMissingParent + } + if child == nil { + return nil, ErrCertMissingChild + } + if len(childRoot) != SiblingSize { + return nil, fmt.Errorf("%w: got %d, want %d", ErrCertRootLength, len(childRoot), SiblingSize) + } + if len(parentFragment.ShardLeafValue) != SiblingSize { + return nil, fmt.Errorf("invalid parent fragment shard leaf value length: got %d, want %d", + len(parentFragment.ShardLeafValue), SiblingSize) + } + if !bytes.Equal(parentFragment.ShardLeafValue, childRoot) { + return nil, ErrCertChildRootMismatch + } + + var parent InclusionCert + if err := parent.UnmarshalBinary(parentFragment.CertificateBytes); err != nil { + return nil, fmt.Errorf("failed to decode parent fragment cert: %w", err) + } + + if bitmapOverlap(&parent.Bitmap, &child.Bitmap) { + return nil, ErrCertDepthOverlap + } + + parentHasDepths, _, parentMax := bitmapDepthRange(&parent.Bitmap) + childHasDepths, childMin, _ := bitmapDepthRange(&child.Bitmap) + if parentHasDepths && childHasDepths && parentMax >= childMin { + return nil, fmt.Errorf("%w: deepest parent depth %d, shallowest child depth %d", + ErrCertDepthOrder, parentMax, childMin) + } + + out := &InclusionCert{} + for i := range out.Bitmap { + out.Bitmap[i] = parent.Bitmap[i] | child.Bitmap[i] + } + out.Siblings = make([][SiblingSize]byte, 0, len(parent.Siblings)+len(child.Siblings)) + out.Siblings = append(out.Siblings, parent.Siblings...) 
+ out.Siblings = append(out.Siblings, child.Siblings...) + + if len(out.Siblings) != bitmapPopcount(&out.Bitmap) { + return nil, fmt.Errorf("%w: have %d, want %d", + ErrCertBitmapMismatch, len(out.Siblings), bitmapPopcount(&out.Bitmap)) + } + return out, nil +} + +func bitmapOverlap(a, b *[BitmapSize]byte) bool { + for i := 0; i < BitmapSize; i++ { + if a[i]&b[i] != 0 { + return true + } + } + return false +} + +func bitmapDepthRange(bitmap *[BitmapSize]byte) (ok bool, minDepth, maxDepthSeen int) { + minDepth = BitmapSize * 8 + for depth := 0; depth < BitmapSize*8; depth++ { + if (bitmap[depth/8]>>(uint(depth)%8))&1 == 0 { + continue + } + if !ok { + minDepth = depth + maxDepthSeen = depth + ok = true + continue + } + maxDepthSeen = depth + } + return ok, minDepth, maxDepthSeen +} diff --git a/pkg/api/inclusion_cert_test.go b/pkg/api/inclusion_cert_test.go new file mode 100644 index 0000000..ccebec7 --- /dev/null +++ b/pkg/api/inclusion_cert_test.go @@ -0,0 +1,523 @@ +package api + +import ( + "bytes" + "errors" + "testing" +) + +// hashLeafRaw computes H(0x00 || key || value) for test fixtures. +func hashLeafRaw(t *testing.T, algo HashAlgorithm, key, value []byte) []byte { + t.Helper() + h := NewDataHasher(algo) + h.AddData([]byte{0x00}).AddData(key).AddData(value) + return h.GetHash().RawHash +} + +// hashNodeRaw computes H(0x01 || depth || left || right) for test fixtures. +func hashNodeRaw(t *testing.T, algo HashAlgorithm, depth byte, left, right []byte) []byte { + t.Helper() + h := NewDataHasher(algo) + h.AddData([]byte{0x01, depth}).AddData(left).AddData(right) + return h.GetHash().RawHash +} + +func TestInclusionCertVerify_EmptyBitmap_SingleLeafTree(t *testing.T) { + // Single-leaf tree: no internal nodes, no siblings. Root is the + // leaf hash directly. 
+ key := bytes.Repeat([]byte{0x01}, StateTreeKeyLengthBytes) + value := []byte("hello") + root := hashLeafRaw(t, SHA256, key, value) + + cert := &InclusionCert{} // bitmap zero, siblings nil + if err := cert.Verify(key, value, root, SHA256); err != nil { + t.Fatalf("verify empty-bitmap cert: %v", err) + } +} + +func TestInclusionCertVerify_SingleSiblingAtDepth0(t *testing.T) { + // Two-leaf tree diverging at depth 0. Our key has bit 0 = 0, so + // we went left at depth 0. Sibling is the right child. + key := make([]byte, StateTreeKeyLengthBytes) // all zeros → bit 0 = 0 + value := []byte("left leaf") + siblingHash := bytes.Repeat([]byte{0xAB}, SiblingSize) + + leafHash := hashLeafRaw(t, SHA256, key, value) + root := hashNodeRaw(t, SHA256, 0, leafHash, siblingHash) + + cert := &InclusionCert{} + cert.Bitmap[0] = 0x01 // depth 0 + var s [SiblingSize]byte + copy(s[:], siblingHash) + cert.Siblings = append(cert.Siblings, s) + + if err := cert.Verify(key, value, root, SHA256); err != nil { + t.Fatalf("verify single-sibling cert: %v", err) + } +} + +func TestInclusionCertVerify_TwoSiblingsRootToLeafWireOrder(t *testing.T) { + // Proof path touches depth 3 (shallower) and depth 7 (deeper). + // Wire order must be root → leaf: siblings[0] at depth 3, + // siblings[1] at depth 7. Verification consumes from the end: + // depth 7 first (sib7), then depth 3 (sib3). + + // Key byte 0 = 0b0000_1000 → bit 3 = 1, bit 7 = 0. 
+ key := make([]byte, StateTreeKeyLengthBytes) + key[0] = 0b0000_1000 + value := []byte("v") + + sib3 := bytes.Repeat([]byte{0x33}, SiblingSize) + sib7 := bytes.Repeat([]byte{0x77}, SiblingSize) + + leaf := hashLeafRaw(t, SHA256, key, value) + // d=7: bit 7 = 0 → went left → sibling right + h7 := hashNodeRaw(t, SHA256, 7, leaf, sib7) + // d=3: bit 3 = 1 → went right → sibling left + h3 := hashNodeRaw(t, SHA256, 3, sib3, h7) + root := h3 + + cert := &InclusionCert{} + cert.Bitmap[0] = 0b1000_1000 // bits 3 and 7 + var s3, s7 [SiblingSize]byte + copy(s3[:], sib3) + copy(s7[:], sib7) + // Canonical root-to-leaf wire order: shallowest first. + cert.Siblings = append(cert.Siblings, s3, s7) + + if err := cert.Verify(key, value, root, SHA256); err != nil { + t.Fatalf("verify two-sibling cert: %v", err) + } +} + +func TestInclusionCertVerify_WrongSiblingOrderFails(t *testing.T) { + // Same setup as the two-sibling test but with siblings swapped + // into leaf-to-root order. Must fail. + key := make([]byte, StateTreeKeyLengthBytes) + key[0] = 0b0000_1000 + value := []byte("v") + + sib3 := bytes.Repeat([]byte{0x33}, SiblingSize) + sib7 := bytes.Repeat([]byte{0x77}, SiblingSize) + + leaf := hashLeafRaw(t, SHA256, key, value) + h7 := hashNodeRaw(t, SHA256, 7, leaf, sib7) + h3 := hashNodeRaw(t, SHA256, 3, sib3, h7) + root := h3 + + cert := &InclusionCert{} + cert.Bitmap[0] = 0b1000_1000 + var s3, s7 [SiblingSize]byte + copy(s3[:], sib3) + copy(s7[:], sib7) + // Wrong order: deepest first. + cert.Siblings = append(cert.Siblings, s7, s3) + + if err := cert.Verify(key, value, root, SHA256); err == nil { + t.Fatal("expected verify to fail with wrong sibling order") + } +} + +func TestInclusionCertVerify_DepthSpanning8Bytes(t *testing.T) { + // Single sibling at depth 200 to exercise cross-byte bitmap + // addressing and large depth encoding. + const depth = 200 + + // Set bit 200 of the key so we went right at depth 200. 
+ key := make([]byte, StateTreeKeyLengthBytes) + key[depth/8] |= 1 << (depth % 8) + value := []byte("deep") + + siblingHash := bytes.Repeat([]byte{0x5A}, SiblingSize) + leafHash := hashLeafRaw(t, SHA256, key, value) + // bit 200 = 1 → went right → sibling is left + root := hashNodeRaw(t, SHA256, byte(depth), siblingHash, leafHash) + + cert := &InclusionCert{} + cert.Bitmap[depth/8] = 1 << (depth % 8) + var s [SiblingSize]byte + copy(s[:], siblingHash) + cert.Siblings = append(cert.Siblings, s) + + if err := cert.Verify(key, value, root, SHA256); err != nil { + t.Fatalf("verify depth-%d cert: %v", depth, err) + } +} + +func TestInclusionCertVerify_WrongValueFails(t *testing.T) { + key := bytes.Repeat([]byte{0x01}, StateTreeKeyLengthBytes) + root := hashLeafRaw(t, SHA256, key, []byte("right")) + + cert := &InclusionCert{} + if err := cert.Verify(key, []byte("wrong"), root, SHA256); err == nil { + t.Fatal("expected verify to fail on wrong value") + } +} + +func TestInclusionCertVerify_WrongRootFails(t *testing.T) { + key := bytes.Repeat([]byte{0x01}, StateTreeKeyLengthBytes) + value := []byte("v") + fakeRoot := bytes.Repeat([]byte{0xFF}, SiblingSize) + + cert := &InclusionCert{} + err := cert.Verify(key, value, fakeRoot, SHA256) + if !errors.Is(err, ErrCertRootMismatch) { + t.Fatalf("expected ErrCertRootMismatch, got %v", err) + } +} + +func TestInclusionCertVerify_InvalidKeyLength(t *testing.T) { + cert := &InclusionCert{} + err := cert.Verify([]byte{1, 2, 3}, []byte("v"), bytes.Repeat([]byte{0}, 32), SHA256) + if !errors.Is(err, ErrCertKeyLength) { + t.Fatalf("expected ErrCertKeyLength, got %v", err) + } +} + +func TestInclusionCertVerify_InvalidRootLength(t *testing.T) { + key := bytes.Repeat([]byte{0x01}, StateTreeKeyLengthBytes) + cert := &InclusionCert{} + err := cert.Verify(key, []byte("v"), []byte{0}, SHA256) + if !errors.Is(err, ErrCertRootLength) { + t.Fatalf("expected ErrCertRootLength, got %v", err) + } +} + +func TestInclusionCertVerify_UnknownAlgorithm(t 
*testing.T) { + key := bytes.Repeat([]byte{0x01}, StateTreeKeyLengthBytes) + root := bytes.Repeat([]byte{0}, SiblingSize) + cert := &InclusionCert{} + err := cert.Verify(key, []byte("v"), root, HashAlgorithm(99)) + if !errors.Is(err, ErrCertUnknownAlgo) { + t.Fatalf("expected ErrCertUnknownAlgo, got %v", err) + } +} + +func TestInclusionCertRoundTrip(t *testing.T) { + cert := &InclusionCert{} + cert.Bitmap[0] = 0b1010_1010 + cert.Bitmap[5] = 0xFF + cert.Bitmap[31] = 0x80 + + n := bitmapPopcount(&cert.Bitmap) + cert.Siblings = make([][SiblingSize]byte, n) + for i := 0; i < n; i++ { + for j := range cert.Siblings[i] { + cert.Siblings[i][j] = byte(i*7 + j) + } + } + + wire, err := cert.MarshalBinary() + if err != nil { + t.Fatalf("marshal: %v", err) + } + wantLen := BitmapSize + n*SiblingSize + if len(wire) != wantLen { + t.Fatalf("wire length: got %d want %d", len(wire), wantLen) + } + + var got InclusionCert + if err := got.UnmarshalBinary(wire); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if got.Bitmap != cert.Bitmap { + t.Fatalf("bitmap mismatch") + } + if len(got.Siblings) != n { + t.Fatalf("sibling count: got %d want %d", len(got.Siblings), n) + } + for i := range cert.Siblings { + if got.Siblings[i] != cert.Siblings[i] { + t.Fatalf("sibling %d mismatch", i) + } + } +} + +func TestInclusionCertRoundTrip_EmptyBitmap(t *testing.T) { + cert := &InclusionCert{} + wire, err := cert.MarshalBinary() + if err != nil { + t.Fatalf("marshal: %v", err) + } + if len(wire) != BitmapSize { + t.Fatalf("empty cert wire length: got %d want %d", len(wire), BitmapSize) + } + var got InclusionCert + if err := got.UnmarshalBinary(wire); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(got.Siblings) != 0 { + t.Fatalf("expected 0 siblings, got %d", len(got.Siblings)) + } +} + +func TestInclusionCertDecode_Truncated(t *testing.T) { + short := make([]byte, 10) + var cert InclusionCert + if err := cert.UnmarshalBinary(short); !errors.Is(err, ErrCertTruncated) { + 
t.Fatalf("expected ErrCertTruncated, got %v", err) + } +} + +func TestInclusionCertDecode_MisalignedSiblings(t *testing.T) { + data := make([]byte, BitmapSize+17) // bitmap + partial sibling + var cert InclusionCert + if err := cert.UnmarshalBinary(data); !errors.Is(err, ErrCertMisalignedSibs) { + t.Fatalf("expected ErrCertMisalignedSibs, got %v", err) + } +} + +func TestInclusionCertDecode_BitmapMismatch(t *testing.T) { + // Bitmap claims 1 set bit; wire carries 2 siblings. + data := make([]byte, BitmapSize+2*SiblingSize) + data[0] = 0x01 + var cert InclusionCert + if err := cert.UnmarshalBinary(data); !errors.Is(err, ErrCertBitmapMismatch) { + t.Fatalf("expected ErrCertBitmapMismatch, got %v", err) + } +} + +func TestBitmapPopcount(t *testing.T) { + cases := []struct { + name string + set map[int]bool + want int + }{ + {"empty", nil, 0}, + {"one bit byte 0", map[int]bool{0: true}, 1}, + {"one bit byte 31", map[int]bool{255: true}, 1}, + {"multi", map[int]bool{0: true, 7: true, 8: true, 200: true, 255: true}, 5}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var b [BitmapSize]byte + for d := range tc.set { + b[d/8] |= 1 << (d % 8) + } + got := bitmapPopcount(&b) + if got != tc.want { + t.Fatalf("got %d want %d", got, tc.want) + } + }) + } +} + +func TestKeyBitAt(t *testing.T) { + key := make([]byte, StateTreeKeyLengthBytes) + key[0] = 0b1010_0101 + key[1] = 0x01 // bit 8 set + key[31] = 0x80 // bit 255 set + + checks := []struct { + pos int + want byte + }{ + {0, 1}, {1, 0}, {2, 1}, {5, 1}, {7, 1}, + {8, 1}, {9, 0}, + {255, 1}, + } + for _, c := range checks { + if got := keyBitAt(key, c.pos); got != c.want { + t.Errorf("keyBitAt(%d) = %d, want %d", c.pos, got, c.want) + } + } +} + +func TestExclusionCertRoundTrip(t *testing.T) { + cert := &ExclusionCert{} + for i := range cert.KL { + cert.KL[i] = byte(i) + } + for i := range cert.HL { + cert.HL[i] = byte(i + 100) + } + cert.Bitmap[2] = 0b0000_0101 // 2 set bits + cert.Siblings = 
make([][SiblingSize]byte, 2) + for i := range cert.Siblings[0] { + cert.Siblings[0][i] = 0xAA + } + for i := range cert.Siblings[1] { + cert.Siblings[1][i] = 0xBB + } + + wire, err := cert.MarshalBinary() + if err != nil { + t.Fatalf("marshal: %v", err) + } + wantLen := 2*SiblingSize + BitmapSize + 2*SiblingSize + if len(wire) != wantLen { + t.Fatalf("wire length: got %d want %d", len(wire), wantLen) + } + + var got ExclusionCert + if err := got.UnmarshalBinary(wire); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if got.KL != cert.KL { + t.Fatalf("KL mismatch") + } + if got.HL != cert.HL { + t.Fatalf("HL mismatch") + } + if got.Bitmap != cert.Bitmap { + t.Fatalf("bitmap mismatch") + } + if len(got.Siblings) != 2 { + t.Fatalf("sibling count: got %d want 2", len(got.Siblings)) + } + for i := range cert.Siblings { + if got.Siblings[i] != cert.Siblings[i] { + t.Fatalf("sibling %d mismatch", i) + } + } +} + +func TestExclusionCertDecode_Truncated(t *testing.T) { + short := make([]byte, 2*SiblingSize+BitmapSize-1) // 1 byte short of header + var cert ExclusionCert + if err := cert.UnmarshalBinary(short); !errors.Is(err, ErrCertTruncated) { + t.Fatalf("expected ErrCertTruncated, got %v", err) + } +} + +func TestExclusionCertVerify_NotImplemented(t *testing.T) { + cert := &ExclusionCert{} + key := bytes.Repeat([]byte{0}, StateTreeKeyLengthBytes) + root := bytes.Repeat([]byte{0}, SiblingSize) + if err := cert.Verify(key, root, SHA256); !errors.Is(err, ErrExclusionNotImpl) { + t.Fatalf("expected ErrExclusionNotImpl, got %v", err) + } +} + +func TestComposeInclusionCert_Success(t *testing.T) { + key := make([]byte, StateTreeKeyLengthBytes) // shard prefix 00, child bits also 0 + value := []byte("leaf-value") + + childSibling := bytes.Repeat([]byte{0x55}, SiblingSize) + parentSibling := bytes.Repeat([]byte{0x99}, SiblingSize) + + leaf := hashLeafRaw(t, SHA256, key, value) + childRoot := hashNodeRaw(t, SHA256, 5, leaf, childSibling) + parentRoot := hashNodeRaw(t, SHA256, 
1, childRoot, parentSibling) + + child := &InclusionCert{} + child.Bitmap[5/8] |= 1 << (5 % 8) + var childS [SiblingSize]byte + copy(childS[:], childSibling) + child.Siblings = append(child.Siblings, childS) + + parent := &InclusionCert{} + parent.Bitmap[1/8] |= 1 << (1 % 8) + var parentS [SiblingSize]byte + copy(parentS[:], parentSibling) + parent.Siblings = append(parent.Siblings, parentS) + parentBytes, err := parent.MarshalBinary() + if err != nil { + t.Fatalf("marshal parent cert: %v", err) + } + + fragment := &ParentInclusionFragment{ + CertificateBytes: parentBytes, + ShardLeafValue: append([]byte(nil), childRoot...), + } + + composed, err := ComposeInclusionCert(fragment, child, childRoot) + if err != nil { + t.Fatalf("compose inclusion cert: %v", err) + } + if err := composed.Verify(key, value, parentRoot, SHA256); err != nil { + t.Fatalf("verify composed cert: %v", err) + } + if got := len(composed.Siblings); got != 2 { + t.Fatalf("composed sibling count: got %d want 2", got) + } + if composed.Siblings[0] != parentS { + t.Fatalf("parent sibling not first in composed cert") + } + if composed.Siblings[1] != childS { + t.Fatalf("child sibling not second in composed cert") + } +} + +func TestComposeInclusionCert_RejectsMalformedParentFragment(t *testing.T) { + child := &InclusionCert{} + fragment := &ParentInclusionFragment{ + CertificateBytes: []byte{0x01, 0x02}, + ShardLeafValue: bytes.Repeat([]byte{0xAA}, SiblingSize), + } + _, err := ComposeInclusionCert(fragment, child, bytes.Repeat([]byte{0xAA}, SiblingSize)) + if err == nil { + t.Fatal("expected malformed parent fragment to fail") + } + if !errors.Is(err, ErrCertTruncated) { + t.Fatalf("expected ErrCertTruncated, got %v", err) + } +} + +func TestComposeInclusionCert_RejectsChildRootMismatch(t *testing.T) { + parent := &InclusionCert{} + parentBytes, err := parent.MarshalBinary() + if err != nil { + t.Fatalf("marshal parent cert: %v", err) + } + fragment := &ParentInclusionFragment{ + CertificateBytes: 
parentBytes, + ShardLeafValue: bytes.Repeat([]byte{0xAA}, SiblingSize), + } + _, err = ComposeInclusionCert(fragment, &InclusionCert{}, bytes.Repeat([]byte{0xBB}, SiblingSize)) + if !errors.Is(err, ErrCertChildRootMismatch) { + t.Fatalf("expected ErrCertChildRootMismatch, got %v", err) + } +} + +func TestComposeInclusionCert_RejectsDepthOverlap(t *testing.T) { + child := &InclusionCert{} + child.Bitmap[5/8] |= 1 << (5 % 8) + var childS [SiblingSize]byte + child.Siblings = append(child.Siblings, childS) + + parent := &InclusionCert{} + parent.Bitmap[5/8] |= 1 << (5 % 8) + var parentS [SiblingSize]byte + parent.Siblings = append(parent.Siblings, parentS) + parentBytes, err := parent.MarshalBinary() + if err != nil { + t.Fatalf("marshal parent cert: %v", err) + } + + childRoot := bytes.Repeat([]byte{0x11}, SiblingSize) + fragment := &ParentInclusionFragment{ + CertificateBytes: parentBytes, + ShardLeafValue: append([]byte(nil), childRoot...), + } + _, err = ComposeInclusionCert(fragment, child, childRoot) + if !errors.Is(err, ErrCertDepthOverlap) { + t.Fatalf("expected ErrCertDepthOverlap, got %v", err) + } +} + +func TestComposeInclusionCert_RejectsParentDeeperThanChild(t *testing.T) { + child := &InclusionCert{} + child.Bitmap[3/8] |= 1 << (3 % 8) + var childS [SiblingSize]byte + child.Siblings = append(child.Siblings, childS) + + parent := &InclusionCert{} + parent.Bitmap[7/8] |= 1 << (7 % 8) + var parentS [SiblingSize]byte + parent.Siblings = append(parent.Siblings, parentS) + parentBytes, err := parent.MarshalBinary() + if err != nil { + t.Fatalf("marshal parent cert: %v", err) + } + + childRoot := bytes.Repeat([]byte{0x22}, SiblingSize) + fragment := &ParentInclusionFragment{ + CertificateBytes: parentBytes, + ShardLeafValue: append([]byte(nil), childRoot...), + } + _, err = ComposeInclusionCert(fragment, child, childRoot) + if !errors.Is(err, ErrCertDepthOrder) { + t.Fatalf("expected ErrCertDepthOrder, got %v", err) + } +} diff --git 
a/pkg/api/inclusion_proof_v2_verify_test.go b/pkg/api/inclusion_proof_v2_verify_test.go new file mode 100644 index 0000000..d8d6133 --- /dev/null +++ b/pkg/api/inclusion_proof_v2_verify_test.go @@ -0,0 +1,245 @@ +package api + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unicitynetwork/bft-go-base/types" +) + +// TestInclusionProofV2Verify_SingleLeafHappyPath builds the minimal valid +// v2 inclusion proof — a single-leaf tree whose root is simply +// H_leaf(key, value) = H(0x00 || key || value), an empty InclusionCert +// (zero bitmap, zero siblings), and a UnicityCertificate whose IR.Hash is +// the raw 32-byte root. Verification must succeed. +func TestInclusionProofV2Verify_SingleLeafHappyPath(t *testing.T) { + stateID := RequireNewImprintV2("1111111111111111111111111111111111111111111111111111111111111111") + txHash := RequireNewImprintV2("2222222222222222222222222222222222222222222222222222222222222222") + + req := &CertificationRequest{ + StateID: stateID, + CertificationData: CertificationData{ + TransactionHash: txHash, + }, + } + + key, err := stateID.GetTreeKey() + require.NoError(t, err) + value := txHash.DataBytes() + + // H_leaf(key, value) under InclusionProofV2HashAlgorithm. + hasher := NewDataHasher(InclusionProofV2HashAlgorithm) + hasher.Reset(). + AddData([]byte{0x00}). + AddData(key). + AddData(value) + leafRoot := append([]byte(nil), hasher.GetHash().RawHash...) + require.Len(t, leafRoot, SiblingSize) + + // Empty InclusionCert — single-leaf edge case. 
+ cert := &InclusionCert{} + certBytes, err := cert.MarshalBinary() + require.NoError(t, err) + require.Len(t, certBytes, BitmapSize) + + ucBytes, err := types.Cbor.Marshal(types.UnicityCertificate{ + InputRecord: &types.InputRecord{ + Hash: leafRoot, + }, + }) + require.NoError(t, err) + + proof := &InclusionProofV2{ + CertificationData: &req.CertificationData, + CertificateBytes: certBytes, + UnicityCertificate: ucBytes, + } + + require.NoError(t, proof.Verify(req)) +} + +// TestInclusionProofV2Verify_WrongRootFails confirms that a mismatch +// between UC.IR.h and the leaf's computed hash is surfaced as a verify +// error. +func TestInclusionProofV2Verify_WrongRootFails(t *testing.T) { + stateID := RequireNewImprintV2("1111111111111111111111111111111111111111111111111111111111111111") + txHash := RequireNewImprintV2("2222222222222222222222222222222222222222222222222222222222222222") + req := &CertificationRequest{ + StateID: stateID, + CertificationData: CertificationData{ + TransactionHash: txHash, + }, + } + + // Wrong root: all-zeros instead of the actual leaf hash. + wrongRoot := make([]byte, SiblingSize) + + cert := &InclusionCert{} + certBytes, err := cert.MarshalBinary() + require.NoError(t, err) + + ucBytes, err := types.Cbor.Marshal(types.UnicityCertificate{ + InputRecord: &types.InputRecord{ + Hash: wrongRoot, + }, + }) + require.NoError(t, err) + + proof := &InclusionProofV2{ + CertificationData: &req.CertificationData, + CertificateBytes: certBytes, + UnicityCertificate: ucBytes, + } + + err = proof.Verify(req) + require.Error(t, err) + require.ErrorIs(t, err, ErrCertRootMismatch) +} + +// TestInclusionProofV2Verify_NonInclusionShortCircuits confirms that when +// CertificationData is nil the Verify method short-circuits with +// ErrExclusionNotImpl, before attempting to decode the certificate or UC. 
+func TestInclusionProofV2Verify_NonInclusionShortCircuits(t *testing.T) { + stateID := RequireNewImprintV2("1111111111111111111111111111111111111111111111111111111111111111") + req := &CertificationRequest{StateID: stateID} + + proof := &InclusionProofV2{ + CertificationData: nil, + CertificateBytes: nil, // intentionally invalid — must not be touched + UnicityCertificate: nil, // intentionally invalid — must not be touched + } + + err := proof.Verify(req) + require.Error(t, err) + require.True(t, errors.Is(err, ErrExclusionNotImpl)) +} + +// TestInclusionProofV2Verify_MissingRequestTxHash checks that malformed +// outer requests fail fast with a clear error instead of relying on deeper +// cert verification. +func TestInclusionProofV2Verify_MissingRequestTxHash(t *testing.T) { + stateID := RequireNewImprintV2("1111111111111111111111111111111111111111111111111111111111111111") + txHash := RequireNewImprintV2("2222222222222222222222222222222222222222222222222222222222222222") + + // Build a valid proof envelope first. + key, err := stateID.GetTreeKey() + require.NoError(t, err) + hasher := NewDataHasher(InclusionProofV2HashAlgorithm) + hasher.Reset(). + AddData([]byte{0x00}). + AddData(key). + AddData(txHash.DataBytes()) + root := append([]byte(nil), hasher.GetHash().RawHash...) + + cert := &InclusionCert{} + certBytes, err := cert.MarshalBinary() + require.NoError(t, err) + ucBytes, err := types.Cbor.Marshal(types.UnicityCertificate{ + InputRecord: &types.InputRecord{Hash: root}, + }) + require.NoError(t, err) + + proof := &InclusionProofV2{ + CertificationData: &CertificationData{ + TransactionHash: txHash, + }, + CertificateBytes: certBytes, + UnicityCertificate: ucBytes, + } + + // Malformed request: missing tx hash. 
+ req := &CertificationRequest{ + StateID: stateID, + CertificationData: CertificationData{ + TransactionHash: nil, + }, + } + + err = proof.Verify(req) + require.Error(t, err) + require.Contains(t, err.Error(), "missing certification request transaction hash") +} + +// TestInclusionProofV2Verify_MismatchedProofTxHashFails ensures the proof +// payload cannot carry a different tx hash than the outer request while +// still verifying against the request's leaf value. +func TestInclusionProofV2Verify_MismatchedProofTxHashFails(t *testing.T) { + stateID := RequireNewImprintV2("1111111111111111111111111111111111111111111111111111111111111111") + reqTxHash := RequireNewImprintV2("2222222222222222222222222222222222222222222222222222222222222222") + proofTxHash := RequireNewImprintV2("3333333333333333333333333333333333333333333333333333333333333333") + + req := &CertificationRequest{ + StateID: stateID, + CertificationData: CertificationData{ + TransactionHash: reqTxHash, + }, + } + + // Build a valid root for the request tx hash so the only failure is the + // proof/request consistency check. + key, err := stateID.GetTreeKey() + require.NoError(t, err) + hasher := NewDataHasher(InclusionProofV2HashAlgorithm) + hasher.Reset(). + AddData([]byte{0x00}). + AddData(key). + AddData(reqTxHash.DataBytes()) + root := append([]byte(nil), hasher.GetHash().RawHash...) 
+ + cert := &InclusionCert{} + certBytes, err := cert.MarshalBinary() + require.NoError(t, err) + ucBytes, err := types.Cbor.Marshal(types.UnicityCertificate{ + InputRecord: &types.InputRecord{Hash: root}, + }) + require.NoError(t, err) + + proof := &InclusionProofV2{ + CertificationData: &CertificationData{ + TransactionHash: proofTxHash, + }, + CertificateBytes: certBytes, + UnicityCertificate: ucBytes, + } + + err = proof.Verify(req) + require.Error(t, err) + require.Contains(t, err.Error(), "proof certification data transaction hash does not match") +} + +// TestInclusionProofV2Verify_RejectsInvalidUCInputRecordHash confirms that +// v2 requires UC.IR.h to be exactly 32 bytes. +func TestInclusionProofV2Verify_RejectsInvalidUCInputRecordHash(t *testing.T) { + stateID := RequireNewImprintV2("1111111111111111111111111111111111111111111111111111111111111111") + txHash := RequireNewImprintV2("2222222222222222222222222222222222222222222222222222222222222222") + req := &CertificationRequest{ + StateID: stateID, + CertificationData: CertificationData{ + TransactionHash: txHash, + }, + } + + legacyRoot := make([]byte, SiblingSize+2) + + cert := &InclusionCert{} + certBytes, err := cert.MarshalBinary() + require.NoError(t, err) + + ucBytes, err := types.Cbor.Marshal(types.UnicityCertificate{ + InputRecord: &types.InputRecord{ + Hash: legacyRoot, + }, + }) + require.NoError(t, err) + + proof := &InclusionProofV2{ + CertificationData: &req.CertificationData, + CertificateBytes: certBytes, + UnicityCertificate: ucBytes, + } + + err = proof.Verify(req) + require.Error(t, err) + require.Contains(t, err.Error(), "UC.IR.h length") +} diff --git a/pkg/api/merkle_tree_path_verify_test.go b/pkg/api/merkle_tree_path_verify_test.go index b1a2fd1..f28233d 100644 --- a/pkg/api/merkle_tree_path_verify_test.go +++ b/pkg/api/merkle_tree_path_verify_test.go @@ -20,6 +20,22 @@ func createLeaf(path int64, value []byte) *smt.Leaf { } } +func normalizeLegacyPath(t *testing.T, raw string) 
*big.Int { + t.Helper() + + legacyPath, ok := new(big.Int).SetString(raw, 10) + require.True(t, ok, "failed to parse path") + + key, err := api.PathToFixedBytes(legacyPath, legacyPath.BitLen()-1) + require.NoError(t, err) + key, err = api.ImprintV2(key).GetTreeKey() + require.NoError(t, err) + + path, err := api.FixedBytesToPath(key, api.StateTreeKeyLengthBits) + require.NoError(t, err) + return path +} + // TestMerkleTreePathVerify tests comprehensive verification scenarios func TestMerkleTreePathVerify(t *testing.T) { t.Run("SingleLeaf", func(t *testing.T) { @@ -90,11 +106,12 @@ func TestMerkleTreePathVerify(t *testing.T) { }) t.Run("LargePaths", func(t *testing.T) { - tree := smt.NewSparseMerkleTree(api.SHA256, 272) + tree := smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) - // Test with the actual large paths from the failing test - mintPath, _ := new(big.Int).SetString("7588607046638288532898314259371162887598150843702815116345200719347816808430746270", 10) - transferPath, _ := new(big.Int).SetString("7588595804959218369815512972651793411311840553453637142956782535261123804631684864", 10) + // Normalize these captured path fixtures to the current 256-bit path + // representation before inserting them into the tree. + mintPath := normalizeLegacyPath(t, "7588607046638288532898314259371162887598150843702815116345200719347816808430746270") + transferPath := normalizeLegacyPath(t, "7588595804959218369815512972651793411311840553453637142956782535261123804631684864") leaves := []*smt.Leaf{ {Path: mintPath, Value: []byte("mint")}, @@ -212,10 +229,10 @@ func TestMerkleTreePathVerify(t *testing.T) { }) t.Run("RealStateIDs", func(t *testing.T) { - // Test with actual stateID format (34-byte with algorithm prefix) - tree := smt.NewSparseMerkleTree(api.SHA256, 16+256) + // Test with concrete stateId values. 
+ tree := smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) - // Create stateIDs with proper format + // Create stateIds stateID1 := "00007d535ade796772c5088b095e79a18e282437ee8d8238f5aa9d9c61694948ba9e" stateID2 := "00006478ca42f6949cfbd4b9e4a41b9a384ea78261c1776808da70cf21e98c345700" diff --git a/pkg/api/shard_match.go b/pkg/api/shard_match.go new file mode 100644 index 0000000..18fc895 --- /dev/null +++ b/pkg/api/shard_match.go @@ -0,0 +1,48 @@ +package api + +import ( + "encoding/hex" + "fmt" + "math/bits" +) + +// MatchesShardPrefix checks whether the LSB-first bits of keyBytes match the +// shard prefix defined by shardBitmask. The bitmask encodes a sentinel-prefixed +// shard ID (e.g. 0b100 = shard 0 in a 2-bit tree). keyBytes must be at least +// ceil(shardDepth/8) bytes long. +func MatchesShardPrefix(keyBytes []byte, shardBitmask int) (bool, error) { + shardDepth := bits.Len(uint(shardBitmask)) - 1 + if shardDepth < 0 { + return false, fmt.Errorf("invalid shard bitmask: %d", shardBitmask) + } + if len(keyBytes) < (shardDepth+7)/8 { + return false, fmt.Errorf("key too short for shard depth %d: got %d bytes", shardDepth, len(keyBytes)) + } + + for d := 0; d < shardDepth; d++ { + expected := byte((uint(shardBitmask) >> uint(d)) & 1) + actual := (keyBytes[d/8] >> (uint(d) % 8)) & 1 + if actual != expected { + return false, nil + } + } + return true, nil +} + +// MatchesShardPrefixFromHex decodes a hex-encoded 32-byte state key and +// applies MatchesShardPrefix. 
+func MatchesShardPrefixFromHex(keyHex string, shardBitmask int) (bool, error) { + keyBytes, err := hex.DecodeString(keyHex) + if err != nil { + return false, fmt.Errorf("failed to decode state key: %w", err) + } + + if len(keyBytes) != StateTreeKeyLengthBytes { + return false, fmt.Errorf( + "state key must be exactly %d bytes, got %d", + StateTreeKeyLengthBytes, len(keyBytes), + ) + } + + return MatchesShardPrefix(keyBytes, shardBitmask) +} diff --git a/pkg/api/smt.go b/pkg/api/smt.go index 6ac58f3..a57bd9b 100644 --- a/pkg/api/smt.go +++ b/pkg/api/smt.go @@ -45,36 +45,72 @@ func (m *MerkleTreePath) Verify(stateID *big.Int) (*PathVerificationResult, erro var currentPath *big.Int var currentData *[]byte + stepPaths := make([]*big.Int, len(m.Steps)) + stepData := make([]*[]byte, len(m.Steps)) for i, step := range m.Steps { - stepPath, ok := new(big.Int).SetString(step.Path, 10) - if !ok || stepPath.Sign() < 0 { + parsedPath, ok := new(big.Int).SetString(step.Path, 10) + if !ok || parsedPath.Sign() < 0 { return nil, fmt.Errorf("invalid step path '%s'", step.Path) } - var stepData *[]byte + var parsedData *[]byte if step.Data != nil { data, err := hex.DecodeString(*step.Data) if err != nil { return nil, fmt.Errorf("invalid step data '%s': %w", *step.Data, err) } - stepData = &data + parsedData = &data } + stepPaths[i] = parsedPath + stepData[i] = parsedData + } + + // Resolve the leaf path represented by the proof (may differ from stateID + // for exclusion proofs). We need it to hash leaves as H(0x00 || key || value). 
+ var leafPath *big.Int + if len(stepPaths) > 0 { + leafPath = new(big.Int).Set(stepPaths[0]) + for i := 1; i < len(stepPaths); i++ { + if leafPath.BitLen() < 2 { + leafPath = big.NewInt(1) + } + pathLen := stepPaths[i].BitLen() - 1 + if pathLen < 0 { + return nil, fmt.Errorf("invalid path '%s' on step %d", m.Steps[i].Path, i+1) + } + mask := new(big.Int).SetBit(new(big.Int).Set(stepPaths[i]), pathLen, 0) + leafPath.Lsh(leafPath, uint(pathLen)) + leafPath.Or(leafPath, mask) + } + } + + var leafKey []byte + if len(stepPaths) > 0 && stepPaths[0].BitLen() >= 2 { + var err error + leafKey, err = PathToFixedBytes(leafPath, leafPath.BitLen()-1) + if err != nil { + return nil, fmt.Errorf("invalid leaf path encoding in proof: %w", err) + } + } + + fullKeyBits := stateID.BitLen() - 1 + for i := range stepPaths { + stepPath := stepPaths[i] + stepBytes := stepData[i] if i == 0 { if stepPath.BitLen() >= 2 { - // First step, normal case: data is the value in the leaf, apply the leaf hashing rule - hasher.Reset().AddData(CborArray(2)) - hasher.AddCborBytes(BigintEncode(stepPath)) - if stepData == nil { - hasher.AddCborNull() - } else { - hasher.AddCborBytes(*stepData) + // First step, normal case: data is the value in the leaf, apply + // the current SMT leaf hashing rule H(0x00 || key || value). 
+ hasher.Reset().AddData([]byte{0x00}).AddData(leafKey) + if stepBytes != nil { + hasher.AddData(*stepBytes) } currentData = &hasher.GetHash().RawHash } else { // First step, special case: data is the "our branch" hash value for the next step // Note that in this case stepPath is a "naked" direction bit - currentData = stepData + currentData = stepBytes } currentPath = stepPath } else { @@ -83,26 +119,43 @@ func (m *MerkleTreePath) Verify(stateID *big.Int) (*PathVerificationResult, erro if currentPath.Bit(0) == 0 { // Our branch on the left, sibling on the right left = currentData - right = stepData + right = stepBytes } else { // Sibling on the left, our branch on the right - left = stepData + left = stepBytes right = currentData } - hasher.Reset().AddData(CborArray(3)) - hasher.AddCborBytes(BigintEncode(stepPath)) - if left == nil { - hasher.AddCborNull() - } else { - hasher.AddCborBytes(*left) - } - if right == nil { - hasher.AddCborNull() + // Under the current SMT hashing rules, a unary node hash is the child hash. + if left == nil && right != nil { + currentData = right + } else if right == nil && left != nil { + currentData = left } else { - hasher.AddCborBytes(*right) + var depth int + if currentPath.BitLen() >= 2 { + depth = fullKeyBits - (currentPath.BitLen() - 1) + } else { + // Legacy MerkleTreePath may encode the current branch as + // only a raw direction bit ("0" or "1"), not a full + // sentinel-encoded path. In that case derive depth from + // the next step path as a fallback. + depth = stepPath.BitLen() - 1 + } + if depth < 0 || depth > 255 { + return nil, fmt.Errorf("invalid node depth %d on step %d", depth, i+1) + } + + hasher.Reset(). 
+ AddData([]byte{0x01, byte(depth)}) + if left != nil { + hasher.AddData(*left) + } + if right != nil { + hasher.AddData(*right) + } + currentData = &hasher.GetHash().RawHash } - currentData = &hasher.GetHash().RawHash // Initialization for when currentPath is a "naked" direction bit if currentPath.BitLen() < 2 { @@ -111,9 +164,9 @@ func (m *MerkleTreePath) Verify(stateID *big.Int) (*PathVerificationResult, erro // Append step path bits to current path pathLen := stepPath.BitLen() - 1 if pathLen < 0 { - return nil, fmt.Errorf("invalid path '%s' on step %d", step.Path, i+1) + return nil, fmt.Errorf("invalid path '%s' on step %d", m.Steps[i].Path, i+1) } - mask := new(big.Int).SetBit(stepPath, pathLen, 0) + mask := new(big.Int).SetBit(new(big.Int).Set(stepPath), pathLen, 0) currentPath.Lsh(currentPath, uint(pathLen)) currentPath.Or(currentPath, mask) } diff --git a/pkg/api/state_id.go b/pkg/api/state_id.go index bad0120..440492a 100644 --- a/pkg/api/state_id.go +++ b/pkg/api/state_id.go @@ -8,7 +8,7 @@ import ( "github.com/unicitynetwork/bft-go-base/types" ) -// ImprintV2 is the unified type for both V1 (imprints with algorithm prefix) and V2 (raw bytes) +// ImprintV2 stores hash-like identifiers used by the public API. type ImprintV2 HexBytes type StateID = ImprintV2 @@ -16,6 +16,13 @@ type SourceStateHash = ImprintV2 type TransactionHash = ImprintV2 type RequestID = ImprintV2 +const ( + // StateTreeKeyLengthBits is the v2 SMT key size. + // The key is the raw 32-byte hash value (no per-key algorithm prefix). 
+ StateTreeKeyLengthBits = 256 + StateTreeKeyLengthBytes = StateTreeKeyLengthBits / 8 +) + func NewImprintV2(s string) (ImprintV2, error) { b, err := NewHexBytesFromString(s) if err != nil { @@ -44,20 +51,85 @@ func (r ImprintV2) IsV1() bool { } func (r ImprintV2) GetPath() (*big.Int, error) { - // pad v2 imprints with two zero bytes to maintain consistency in smt - var path []byte - if len(r) == 32 { - path = make([]byte, 34) - copy(path[2:], r) - } else { - path = r[:] - } - - // Converts StateID hex string to a big.Int for use as an SMT path. - // Prefixes with "0x01" to preserve leading zero bits in the original hex string, - // ensuring consistent path representation in the Sparse Merkle Tree. - b := append([]byte{0x01}, path...) - return new(big.Int).SetBytes(b), nil + key, err := r.GetTreeKey() + if err != nil { + return nil, err + } + return FixedBytesToPath(key, StateTreeKeyLengthBits) +} + +// GetTreeKey returns the canonical SMT key bytes (32 bytes, no algorithm prefix). +func (r ImprintV2) GetTreeKey() ([]byte, error) { + key := r.DataBytes() + if len(key) != StateTreeKeyLengthBytes { + return nil, fmt.Errorf("invalid imprint length for SMT key: expected %d bytes key data, got %d", StateTreeKeyLengthBytes, len(key)) + } + return append([]byte(nil), key...), nil +} + +// PathToFixedBytes converts a sentinel-prefixed SMT path into fixed-width key bytes. +// Byte order follows the v2 SMT bit layout: +// key bit d is bit (d%8) of key[d/8] (LSB-first across bytes). 
+func PathToFixedBytes(path *big.Int, keyLengthBits int) ([]byte, error) { + if keyLengthBits <= 0 { + return nil, fmt.Errorf("invalid key length: %d", keyLengthBits) + } + if path == nil || path.Sign() <= 0 { + return nil, fmt.Errorf("invalid path: must be positive") + } + if path.BitLen()-1 != keyLengthBits { + return nil, fmt.Errorf("invalid path length: expected %d bits, got %d", keyLengthBits, path.BitLen()-1) + } + + keyLengthBytes := (keyLengthBits + 7) / 8 + keyInt := new(big.Int).Set(path) + // Clear sentinel bit (the highest bit at index keyLengthBits). + keyInt.SetBit(keyInt, keyLengthBits, 0) + + beKey := keyInt.Bytes() + if len(beKey) > keyLengthBytes { + return nil, fmt.Errorf("path bytes too long: expected at most %d bytes, got %d", keyLengthBytes, len(beKey)) + } + + bePadded := make([]byte, keyLengthBytes) + copy(bePadded[keyLengthBytes-len(beKey):], beKey) + + out := make([]byte, keyLengthBytes) + for i := range out { + // Convert from big-endian integer bytes to LSB-first SMT key byte order. + out[i] = bePadded[keyLengthBytes-1-i] + } + return out, nil +} + +// FixedBytesToPath converts fixed-width SMT key bytes into sentinel-prefixed path form. +func FixedBytesToPath(key []byte, keyLengthBits int) (*big.Int, error) { + if keyLengthBits <= 0 { + return nil, fmt.Errorf("invalid key length: %d", keyLengthBits) + } + keyLengthBytes := (keyLengthBits + 7) / 8 + if len(key) != keyLengthBytes { + return nil, fmt.Errorf("invalid key length in bytes: expected %d, got %d", keyLengthBytes, len(key)) + } + + // For non-byte-aligned keys, ensure unused high bits are zero in the last + // (highest-index) byte under LSB-first key-byte ordering. + if rem := keyLengthBits % 8; rem != 0 { + mask := byte(0xFF << rem) + if key[keyLengthBytes-1]&mask != 0 { + return nil, fmt.Errorf("invalid key: unused high bits must be zero") + } + } + + // Convert from LSB-first SMT key byte order to big-endian integer bytes. 
+ be := make([]byte, keyLengthBytes) + for i := range be { + be[i] = key[keyLengthBytes-1-i] + } + + path := new(big.Int).SetBytes(be) + path.SetBit(path, keyLengthBits, 1) + return path, nil } func (r ImprintV2) Algorithm() []byte { diff --git a/pkg/api/state_id_bitorder_test.go b/pkg/api/state_id_bitorder_test.go new file mode 100644 index 0000000..3e8c3f4 --- /dev/null +++ b/pkg/api/state_id_bitorder_test.go @@ -0,0 +1,36 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFixedBytesToPath_UsesLSBFirstBitAddressing(t *testing.T) { + key := make([]byte, StateTreeKeyLengthBytes) + key[0] = 0x01 + + path, err := FixedBytesToPath(key, StateTreeKeyLengthBits) + require.NoError(t, err) + + // v2 SMT semantics: depth 0 is bit 0 of byte 0. + require.Equal(t, uint(1), path.Bit(0)) + require.Equal(t, uint(0), path.Bit(248)) +} + +func TestPathToFixedBytes_RoundtripLSBFirst(t *testing.T) { + key := []byte{ + 0x8d, 0x17, 0x23, 0x41, 0x99, 0xfe, 0x00, 0x7c, + 0x11, 0xaa, 0x52, 0x02, 0x7f, 0x03, 0x10, 0x20, + 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, + 0xb0, 0xc0, 0xd0, 0xe0, 0xf0, 0x12, 0x34, 0x56, + } + require.Len(t, key, StateTreeKeyLengthBytes) + + path, err := FixedBytesToPath(key, StateTreeKeyLengthBits) + require.NoError(t, err) + + got, err := PathToFixedBytes(path, StateTreeKeyLengthBits) + require.NoError(t, err) + require.Equal(t, key, got) +} diff --git a/pkg/api/types.go b/pkg/api/types.go index 062650b..3f02611 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -68,17 +68,18 @@ type AggregatorRecord struct { // Block represents a blockchain block type Block struct { - Index *BigInt `json:"index"` - ChainID string `json:"chainId"` - ShardID ShardID `json:"shardId"` - Version string `json:"version"` - ForkID string `json:"forkId"` - RootHash HexBytes `json:"rootHash"` - PreviousBlockHash HexBytes `json:"previousBlockHash"` - NoDeletionProofHash HexBytes `json:"noDeletionProofHash"` - CreatedAt *Timestamp 
`json:"createdAt"` - UnicityCertificate HexBytes `json:"unicityCertificate"` - ParentMerkleTreePath *MerkleTreePath `json:"parentMerkleTreePath,omitempty"` // child mode only + Index *BigInt `json:"index"` + ChainID string `json:"chainId"` + ShardID ShardID `json:"shardId"` + Version string `json:"version"` + ForkID string `json:"forkId"` + RootHash HexBytes `json:"rootHash"` + PreviousBlockHash HexBytes `json:"previousBlockHash"` + NoDeletionProofHash HexBytes `json:"noDeletionProofHash"` + CreatedAt *Timestamp `json:"createdAt"` + UnicityCertificate HexBytes `json:"unicityCertificate"` + ParentFragment *ParentInclusionFragment `json:"parentFragment,omitempty"` // child mode only + ParentBlockNumber uint64 `json:"parentBlockNumber,omitempty"` // child mode only } // NoDeletionProof represents a no-deletion proof @@ -109,16 +110,74 @@ type GetInclusionProofResponseV2 struct { InclusionProof *InclusionProofV2 `json:"inclusionProof"` } +// InclusionProofV2 is the canonical v2 inclusion proof payload. +// +// Wire form (CBOR toarray): +// +// [ +// certificationDataOrNull, +// certificateBytes: bstr, // InclusionCert or ExclusionCert raw wire form +// unicityCertificate: raw CBOR +// ] +// +// Discriminator: +// - CertificationData != nil → inclusion. CertificateBytes is an +// InclusionCert wire payload. The SMT key comes from the outer RPC +// request (stateId); the leaf value is CertificationData.TransactionHash. +// - CertificationData == nil → non-inclusion. CertificateBytes is an +// ExclusionCert wire payload. Non-inclusion verification is not yet +// implemented in Go. +// +// The expected SMT root is ALWAYS taken from UC.IR.h (input record hash +// of the embedded Unicity Certificate). No root field appears here. +// +// See docs/inclusion-proof-wire.md for the frozen specification. 
type InclusionProofV2 struct { _ struct{} `cbor:",toarray"` CertificationData *CertificationData `json:"certificationData"` - MerkleTreePath *MerkleTreePath `json:"merkleTreePath"` + CertificateBytes HexBytes `json:"certificateBytes"` UnicityCertificate types.RawCBOR `json:"unicityCertificate"` } +// ParentInclusionFragment is the internal parent-tree proof fragment stored on +// finalized child blocks and returned by get_shard_proof. CertificateBytes +// uses the same bitmap+sibling wire shape as InclusionCert; ShardLeafValue is +// the parent leaf value proven by that fragment and must equal the child SMT +// root before later composition. +type ParentInclusionFragment struct { + CertificateBytes HexBytes `json:"certificateBytes"` + ShardLeafValue HexBytes `json:"shardLeafValue"` +} + +func (f *ParentInclusionFragment) Verify(shardID ShardID, keyLength int, expectedLeafValue, expectedRoot []byte, algo HashAlgorithm) error { + if f == nil { + return errors.New("missing parent fragment") + } + if len(f.ShardLeafValue) != SiblingSize { + return fmt.Errorf("invalid parent fragment shard leaf value length: got %d, want %d", len(f.ShardLeafValue), SiblingSize) + } + if !bytes.Equal(f.ShardLeafValue, expectedLeafValue) { + return errors.New("parent fragment shard leaf value does not match expected child root") + } + + var cert InclusionCert + if err := cert.UnmarshalBinary(f.CertificateBytes); err != nil { + return fmt.Errorf("failed to decode parent fragment cert: %w", err) + } + + path := big.NewInt(int64(shardID)) + key, err := PathToFixedBytes(path, keyLength) + if err != nil { + return fmt.Errorf("failed to derive parent fragment key: %w", err) + } + + return verifyBitmapPath(&cert.Bitmap, cert.Siblings, key, f.ShardLeafValue, expectedRoot, algo) +} + type RootShardInclusionProof struct { - MerkleTreePath *MerkleTreePath `json:"merkleTreePath"` - UnicityCertificate HexBytes `json:"unicityCertificate"` + ParentFragment *ParentInclusionFragment 
`json:"parentFragment,omitempty"` + UnicityCertificate HexBytes `json:"unicityCertificate"` + BlockNumber uint64 `json:"blockNumber,omitempty"` } // GetNoDeletionProofResponse represents the get_no_deletion_proof JSON-RPC response @@ -187,8 +246,9 @@ type GetShardProofRequest struct { // GetShardProofResponse represents the get_shard_proof JSON-RPC response type GetShardProofResponse struct { - MerkleTreePath *MerkleTreePath `json:"merkleTreePath"` // Proof path for the shard - UnicityCertificate HexBytes `json:"unicityCertificate"` // Unicity Certificate from the finalized block + ParentFragment *ParentInclusionFragment `json:"parentFragment,omitempty"` // native parent fragment for child v2 composition + UnicityCertificate HexBytes `json:"unicityCertificate"` // Unicity Certificate from the finalized block + BlockNumber uint64 `json:"blockNumber,omitempty"` } // HealthStatus represents the health status of the service @@ -218,9 +278,15 @@ func (h *HealthStatus) AddDetail(key, value string) { h.Details[key] = value } -func (r *RootShardInclusionProof) IsValid(shardRootHash string) bool { - return r.MerkleTreePath != nil && len(r.UnicityCertificate) > 0 && - len(r.MerkleTreePath.Steps) > 0 && r.MerkleTreePath.Steps[0].Data != nil && *r.MerkleTreePath.Steps[0].Data == shardRootHash +func (r *RootShardInclusionProof) IsValid(shardID ShardID, keyLength int, shardRootHash HexBytes) bool { + if r == nil || len(r.UnicityCertificate) == 0 || r.ParentFragment == nil { + return false + } + rootRaw, err := ucInputRecordHashRaw(r.UnicityCertificate) + if err != nil { + return false + } + return r.ParentFragment.Verify(shardID, keyLength, shardRootHash, rootRaw, InclusionProofV2HashAlgorithm) == nil } type Sharding struct { @@ -247,16 +313,77 @@ func (c *GetInclusionProofResponseV2) UnmarshalJSON(data []byte) error { return types.Cbor.Unmarshal(hb, c) } +// InclusionProofV2HashAlgorithm is the SMT hash algorithm locked in by the +// v2 inclusion proof wire contract. 
It is fixed to SHA-256; changing it +// requires a format version bump. +const InclusionProofV2HashAlgorithm = SHA256 + +// Verify checks a v2 inclusion proof against the outer +// CertificationRequest. The expected SMT root is sourced exclusively from +// UC.IR.h (strictly 32 raw bytes); CertificateBytes never carries a root. +// Non-inclusion proofs short-circuit with ErrExclusionNotImpl — v2 +// exclusion verification is not yet implemented in Go. func (p *InclusionProofV2) Verify(v2 *CertificationRequest) error { - path, err := v2.StateID.GetPath() + if p == nil { + return errors.New("nil inclusion proof") + } + if v2 == nil { + return errors.New("nil certification request") + } + if p.CertificationData == nil { + return ErrExclusionNotImpl + } + if len(v2.CertificationData.TransactionHash) == 0 { + return errors.New("missing certification request transaction hash") + } + if !bytes.Equal( + p.CertificationData.TransactionHash.DataBytes(), + v2.CertificationData.TransactionHash.DataBytes(), + ) { + return errors.New("proof certification data transaction hash does not match certification request transaction hash") + } + + rootRaw, err := p.UCInputRecordHashRaw() if err != nil { - return fmt.Errorf("failed to get path: %w", err) + return err + } + + var cert InclusionCert + if err := cert.UnmarshalBinary(p.CertificateBytes); err != nil { + return fmt.Errorf("failed to decode inclusion cert: %w", err) } - expectedLeafValue, err := v2.CertificationData.Hash() + key, err := v2.StateID.GetTreeKey() if err != nil { - return fmt.Errorf("failed to get leaf value: %w", err) + return fmt.Errorf("failed to derive SMT key from stateId: %w", err) + } + // v2 leaf value is the raw transaction hash. + value := v2.CertificationData.TransactionHash.DataBytes() + return cert.Verify(key, value, rootRaw, InclusionProofV2HashAlgorithm) +} + +// UCInputRecordHashRaw decodes the embedded Unicity Certificate and +// returns UC.IR.h as a raw 32-byte hash. Any other length is rejected. 
+func (p *InclusionProofV2) UCInputRecordHashRaw() ([]byte, error) { + return ucInputRecordHashRaw(p.UnicityCertificate) +} + +func ucInputRecordHashRaw(raw []byte) ([]byte, error) { + if len(raw) == 0 { + return nil, errors.New("missing unicity certificate") + } + var uc types.UnicityCertificate + if err := types.Cbor.Unmarshal(raw, &uc); err != nil { + return nil, fmt.Errorf("failed to decode unicity certificate: %w", err) + } + if uc.InputRecord == nil { + return nil, errors.New("unicity certificate missing input record") + } + ir := uc.InputRecord.Hash + if len(ir) != StateTreeKeyLengthBytes { + return nil, fmt.Errorf("invalid UC.IR.h length: got %d, want %d", + len(ir), StateTreeKeyLengthBytes) } - return verify(p.MerkleTreePath, path, expectedLeafValue) + return append([]byte(nil), ir...), nil } func verify(p *MerkleTreePath, path *big.Int, expectedLeafValue []byte) error { diff --git a/pkg/api/types_test.go b/pkg/api/types_test.go index cc21aaf..77b8d4f 100644 --- a/pkg/api/types_test.go +++ b/pkg/api/types_test.go @@ -9,7 +9,7 @@ import ( ) func TestStateIDMarshalJSON(t *testing.T) { - stateID := RequireNewImprintV2("0000cfe84a1828e2edd0a7d9533b23e519f746069a938d549a150e07e14dc0f9cf00") + stateID := RequireNewImprintV2("cfe84a1828e2edd0a7d9533b23e519f746069a938d549a150e07e14dc0f9cf00") data, err := json.Marshal(stateID) require.NoError(t, err, "Failed to marshal StateID") @@ -39,7 +39,7 @@ func TestHexBytesMarshalJSON(t *testing.T) { } func TestImprintHexStringMarshalJSON(t *testing.T) { - imprint := RequireNewImprintV2("0000cd60") + imprint := RequireNewImprintV2("cd60a4ad038d834f4ef0fefc4a9f4b5a8f4e1dd51c79f0f4bbcb5c39f4c8d8a1") data, err := json.Marshal(imprint) require.NoError(t, err, "Failed to marshal ImprintHexString") diff --git a/test/integration/commitment_versions_compatibility_test.go b/test/integration/commitment_versions_compatibility_test.go deleted file mode 100644 index a866130..0000000 --- 
a/test/integration/commitment_versions_compatibility_test.go +++ /dev/null @@ -1,208 +0,0 @@ -package integration - -import ( - "encoding/json" - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/require" - mongoContainer "github.com/testcontainers/testcontainers-go/modules/mongodb" - redisContainer "github.com/testcontainers/testcontainers-go/modules/redis" - - "github.com/unicitynetwork/aggregator-go/internal/config" - "github.com/unicitynetwork/aggregator-go/internal/testutil" - "github.com/unicitynetwork/aggregator-go/pkg/api" -) - -// TestCompatibilityV2 tests compatibility of v1 and v2 commitments and proofs -func TestCompatibilityV2(t *testing.T) { - // phase 1: - // submit commitment_v1 - // submit commitment_v2 - // - // phase 2: - // verify inclusion_proof_v1 - // verify inclusion_proof_v2 - // - // phase 3: - // verify inclusion_proof_v2 for commitment_v1 returns error - // verify inclusion_proof_v1 for commitment_v2 returns error - // - // phase 4: - // verify block records v1 and v2 - ctx := t.Context() - - // Start containers (shared MongoDB with different databases per aggregator) - redis, err := redisContainer.Run(ctx, "redis:7") - require.NoError(t, err) - defer redis.Terminate(ctx) - redisURI, _ := redis.ConnectionString(ctx) - - mongo, err := mongoContainer.Run(ctx, "mongo:7.0", mongoContainer.WithReplicaSet("rs0")) - require.NoError(t, err) - defer mongo.Terminate(ctx) - mongoURI, _ := mongo.ConnectionString(ctx) - mongoURI += "&directConnection=true" - - // Start standalone aggregator - url := "http://localhost:9100" - nodeCleanup := startAggregator(t, ctx, "backwards_compatibility_Test", "9100", mongoURI, redisURI, config.ShardingModeStandalone, 0) - defer nodeCleanup() - waitForBlock(t, url, 1, 15*time.Second) - - t.Log("Phase 1: Submitting commitments...") - v1 := testutil.CreateTestCommitment(t, fmt.Sprintf("commitment_v1")) - v1Resp, err := submitCommitment(url, v1.ToAPI()) - require.NoError(t, err) - require.Equal(t, "SUCCESS", 
v1Resp.Status) - - v2 := testutil.CreateTestCertificationRequest(t, fmt.Sprintf("commitment_v2")) - submitCertificationRequest(t, url, v2.ToAPI()) - t.Logf("Commitments submitted successfully") - - t.Log("Phase 2: Verifying proofs...") - v1ProofResp := waitForProofAvailableV1(t, url, v1.RequestID.String(), 5*time.Second) - verifyProofV1(t, v1ProofResp, v1.ToAPI()) - v2ProofResp := waitForProofAvailableV2(t, url, v2.StateID.String(), 5*time.Second) - verifyProofV2(t, v2ProofResp, v2.ToAPI()) - t.Logf("Proofs verified successfully") - - t.Log("Phase 3: Verifying endpoint compatibility...") - // try to fetch v2 proof from v1 api - invalidProofV1, err := getInclusionProofV1(t, url, v2.StateID.String()) - require.Nil(t, invalidProofV1) - require.ErrorContains(t, err, "Failed to get inclusion proof") - - // try to fetch v1 proof from v2 api - invalidProofV2, err := getInclusionProofV2(t, url, v1.RequestID.String()) - require.Nil(t, invalidProofV2) - require.ErrorContains(t, err, "Failed to get inclusion proof") - - t.Log("Phase 4: Verifying block records endpoint...") - blockRecordsResponse := getBlockRecords(t, err, url, v2ProofResp.BlockNumber) - require.Len(t, blockRecordsResponse.AggregatorRecords, 2) - - blockCommitmentsResponse := getBlockCommitments(t, err, url, v2ProofResp.BlockNumber) - require.Len(t, blockCommitmentsResponse.Commitments, 2) -} - -func verifyProofV1(t *testing.T, v1ProofResp *api.GetInclusionProofResponseV1, v1 *api.Commitment) { - require.NotNil(t, v1ProofResp) - require.NotNil(t, v1ProofResp.InclusionProof) - require.NotNil(t, v1ProofResp.InclusionProof.MerkleTreePath) - require.NoError(t, v1ProofResp.InclusionProof.Verify(v1)) -} - -func verifyProofV2(t *testing.T, v2ProofResp *api.GetInclusionProofResponseV2, v2 *api.CertificationRequest) { - require.NotNil(t, v2ProofResp) - require.NotNil(t, v2ProofResp.InclusionProof) - require.NotNil(t, v2ProofResp.InclusionProof.MerkleTreePath) - require.NoError(t, v2ProofResp.InclusionProof.Verify(v2)) 
-} - -func submitCommitment(url string, commitment *api.Commitment) (*api.SubmitCommitmentResponse, error) { - result, err := rpcCall(url, "submit_commitment", commitment) - if err != nil { - return nil, err - } - - var response api.SubmitCommitmentResponse - if err := json.Unmarshal(result, &response); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &response, nil -} - -// waitForProofAvailableV1 waits for a VALID inclusion proof to become available -func waitForProofAvailableV1(t *testing.T, url, stateIDStr string, timeout time.Duration) *api.GetInclusionProofResponseV1 { - deadline := time.Now().Add(timeout) - requestID := api.RequireNewImprintV2(stateIDStr) - - for time.Now().Before(deadline) { - resp, err := getInclusionProofV1(t, url, stateIDStr) - require.NoError(t, err) - if resp.InclusionProof != nil && resp.InclusionProof.Authenticator != nil { - return resp // only return non inclusion proofs - } - time.Sleep(50 * time.Millisecond) - } - t.Fatalf("Timeout waiting for valid proof for requestID %s at %s", requestID, url) - return nil -} - -// waitForProofAvailableV2 waits for a VALID inclusion proof to become available -// This includes waiting for the parent proof to be received and joined -func waitForProofAvailableV2(t *testing.T, url, stateIDStr string, timeout time.Duration) *api.GetInclusionProofResponseV2 { - deadline := time.Now().Add(timeout) - stateID := api.RequireNewImprintV2(stateIDStr) - - for time.Now().Before(deadline) { - resp, err := getInclusionProofV2(t, url, stateIDStr) - require.NoError(t, err) - if resp.InclusionProof != nil && resp.InclusionProof.CertificationData != nil { - return resp // only return non inclusion proofs - } - time.Sleep(50 * time.Millisecond) - } - t.Fatalf("Timeout waiting for valid proof for stateID %s at %s", stateID, url) - return nil -} - -func getInclusionProofV1(t *testing.T, url string, requestID string) (*api.GetInclusionProofResponseV1, error) { - params := 
map[string]string{"requestId": requestID} - result, err := rpcCall(url, "get_inclusion_proof", params) - if err != nil { - return nil, err - } - - var response api.GetInclusionProofResponseV1 - if err := json.Unmarshal(result, &response); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &response, nil -} - -func getInclusionProofV2(t *testing.T, url string, stateID string) (*api.GetInclusionProofResponseV2, error) { - params := map[string]string{"stateId": stateID} - result, err := rpcCall(url, "get_inclusion_proof.v2", params) - if err != nil { - return nil, err - } - - var response api.GetInclusionProofResponseV2 - if err := json.Unmarshal(result, &response); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &response, nil -} - -func getBlockCommitments(t *testing.T, err error, url string, blockNumber uint64) api.GetBlockCommitmentsResponse { - blockCommitmentsResponseJSON, err := rpcCall( - url, - "get_block_commitments", - api.GetBlockCommitmentsRequest{BlockNumber: api.NewBigIntFromUint64(blockNumber)}, - ) - require.NoError(t, err) - - var blockCommitmentsResponse api.GetBlockCommitmentsResponse - require.NoError(t, json.Unmarshal(blockCommitmentsResponseJSON, &blockCommitmentsResponse)) - return blockCommitmentsResponse -} - -func getBlockRecords(t *testing.T, err error, url string, blockNumber uint64) api.GetBlockRecordsResponse { - blockRecordsResponseJSON, err := rpcCall( - url, - "get_block_records", - api.GetBlockCommitmentsRequest{BlockNumber: api.NewBigIntFromUint64(blockNumber)}, - ) - require.NoError(t, err) - - var blockRecordsResponse api.GetBlockRecordsResponse - require.NoError(t, json.Unmarshal(blockRecordsResponseJSON, &blockRecordsResponse)) - - return blockRecordsResponse -} diff --git a/test/integration/sharding_e2e_test.go b/test/integration/sharding_e2e_test.go index 92393c4..a09ce96 100644 --- a/test/integration/sharding_e2e_test.go +++ 
b/test/integration/sharding_e2e_test.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "math/big" "net/http" "net/url" "strconv" @@ -23,6 +22,7 @@ import ( "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/round" "github.com/unicitynetwork/aggregator-go/internal/service" + "github.com/unicitynetwork/aggregator-go/internal/signing" "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/internal/storage" "github.com/unicitynetwork/aggregator-go/internal/testutil" @@ -61,17 +61,17 @@ func TestShardingE2E(t *testing.T) { waitForBlock(t, "http://localhost:9002", 1, 15*time.Second) // Submit commitments over multiple rounds - var shard2ReqIDs, shard3ReqIDs []string + var shard2Requests, shard3Requests []*api.CertificationRequest // Round 1: submit 2 commitments to each shard for i := 0; i < 2; i++ { - c, reqID := createCommitmentForShard(t, 2) + c := createCommitmentForShard(t, cfgForShard(2)) submitCertificationRequest(t, "http://localhost:9001", c) - shard2ReqIDs = append(shard2ReqIDs, reqID) + shard2Requests = append(shard2Requests, c) - c, reqID = createCommitmentForShard(t, 3) + c = createCommitmentForShard(t, cfgForShard(3)) submitCertificationRequest(t, "http://localhost:9002", c) - shard3ReqIDs = append(shard3ReqIDs, reqID) + shard3Requests = append(shard3Requests, c) } waitForBlock(t, "http://localhost:9001", 2, 15*time.Second) @@ -79,36 +79,36 @@ func TestShardingE2E(t *testing.T) { // Round 2: submit 2 more commitments to each shard for i := 0; i < 2; i++ { - c, reqID := createCommitmentForShard(t, 2) + c := createCommitmentForShard(t, cfgForShard(2)) submitCertificationRequest(t, "http://localhost:9001", c) - shard2ReqIDs = append(shard2ReqIDs, reqID) + shard2Requests = append(shard2Requests, c) - c, reqID = createCommitmentForShard(t, 3) + c = createCommitmentForShard(t, cfgForShard(3)) submitCertificationRequest(t, "http://localhost:9002", c) - 
shard3ReqIDs = append(shard3ReqIDs, reqID) + shard3Requests = append(shard3Requests, c) } waitForBlock(t, "http://localhost:9001", 3, 15*time.Second) waitForBlock(t, "http://localhost:9002", 3, 15*time.Second) // Round 3: submit 1 more commitment to each shard - c, reqID := createCommitmentForShard(t, 2) + c := createCommitmentForShard(t, cfgForShard(2)) submitCertificationRequest(t, "http://localhost:9001", c) - shard2ReqIDs = append(shard2ReqIDs, reqID) + shard2Requests = append(shard2Requests, c) - c, reqID = createCommitmentForShard(t, 3) + c = createCommitmentForShard(t, cfgForShard(3)) submitCertificationRequest(t, "http://localhost:9002", c) - shard3ReqIDs = append(shard3ReqIDs, reqID) + shard3Requests = append(shard3Requests, c) waitForBlock(t, "http://localhost:9001", 4, 15*time.Second) waitForBlock(t, "http://localhost:9002", 4, 15*time.Second) // Verify all proofs - for _, reqID := range shard2ReqIDs { - waitForValidProof(t, "http://localhost:9001", reqID, 15*time.Second) + for _, req := range shard2Requests { + waitForValidProof(t, "http://localhost:9001", req, 15*time.Second) } - for _, reqID := range shard3ReqIDs { - waitForValidProof(t, "http://localhost:9002", reqID, 15*time.Second) + for _, req := range shard3Requests { + waitForValidProof(t, "http://localhost:9002", req, 15*time.Second) } } @@ -164,8 +164,10 @@ func startAggregator(t *testing.T, ctx context.Context, name, port, mongoURI, re // Create SMT instance based on sharding mode var smtInstance *smt.SparseMerkleTree switch cfg.Sharding.Mode { - case config.ShardingModeStandalone, config.ShardingModeChild: - smtInstance = smt.NewSparseMerkleTree(api.SHA256, 16+256) + case config.ShardingModeStandalone: + smtInstance = smt.NewSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits) + case config.ShardingModeChild: + smtInstance = smt.NewChildSparseMerkleTree(api.SHA256, api.StateTreeKeyLengthBits, cfg.Sharding.Child.ShardID) case config.ShardingModeParent: smtInstance = 
smt.NewParentSparseMerkleTree(api.SHA256, cfg.Sharding.ShardIDLength) } @@ -236,21 +238,18 @@ func waitForBlock(t *testing.T, url string, minBlock int64, timeout time.Duratio t.Fatalf("Timeout waiting for block %d at %s", minBlock, url) } -func waitForValidProof(t *testing.T, url, reqID string, timeout time.Duration) { +func waitForValidProof(t *testing.T, url string, req *api.CertificationRequest, timeout time.Duration) { deadline := time.Now().Add(timeout) - reqIDObj := api.RequireNewImprintV2(reqID) - path, _ := reqIDObj.GetPath() for time.Now().Before(deadline) { - result, err := rpcCall(url, "get_inclusion_proof.v2", map[string]string{"stateId": reqID}) + result, err := rpcCall(url, "get_inclusion_proof.v2", &api.GetInclusionProofRequestV2{StateID: req.StateID}) if err == nil { var resp api.GetInclusionProofResponseV2 - json.Unmarshal(result, &resp) - if resp.InclusionProof != nil && resp.InclusionProof.MerkleTreePath != nil { - verifyResult, _ := resp.InclusionProof.MerkleTreePath.Verify(path) - if verifyResult != nil && verifyResult.Result { - return - } + require.NoError(t, json.Unmarshal(result, &resp)) + if resp.InclusionProof != nil && len(resp.InclusionProof.CertificateBytes) > 0 { + require.Greater(t, resp.BlockNumber, uint64(0)) + require.NoError(t, resp.InclusionProof.Verify(req)) + return } } time.Sleep(200 * time.Millisecond) @@ -258,17 +257,24 @@ func waitForValidProof(t *testing.T, url, reqID string, timeout time.Duration) { t.Fatalf("Timeout waiting for valid proof at %s", url) } -func createCommitmentForShard(t *testing.T, shardID api.ShardID) (*api.CertificationRequest, string) { - mask := big.NewInt(1) - expected := big.NewInt(int64(shardID & 1)) - +func createCommitmentForShard(t *testing.T, shardingCfg config.ShardingConfig) *api.CertificationRequest { + validator := signing.NewCertificationRequestValidator(shardingCfg) for i := 0; i < 1000; i++ { - c := testutil.CreateTestCertificationRequest(t, fmt.Sprintf("shard%d_%d_%d", shardID, i, 
time.Now().UnixNano())) - if new(big.Int).And(new(big.Int).SetBytes(c.StateID.Bytes()), mask).Cmp(expected) == 0 { - certificationRequest := c.ToAPI() - return certificationRequest, c.StateID.String() + c := testutil.CreateTestCertificationRequest(t, fmt.Sprintf("shard%d_%d_%d", shardingCfg.Child.ShardID, i, time.Now().UnixNano())) + if err := validator.ValidateShardID(c.StateID); err == nil { + return c.ToAPI() } } t.Fatal("Failed to generate commitment for shard") - return nil, "" + return nil +} + +func cfgForShard(shardID api.ShardID) config.ShardingConfig { + return config.ShardingConfig{ + Mode: config.ShardingModeChild, + ShardIDLength: 1, + Child: config.ChildConfig{ + ShardID: shardID, + }, + } }