diff --git a/src/AggregatorGateway.ts b/src/AggregatorGateway.ts index ba3bc99..ef101b6 100644 --- a/src/AggregatorGateway.ts +++ b/src/AggregatorGateway.ts @@ -284,7 +284,7 @@ export class AggregatorGateway { private static async setupSmt(smtStorage: ISmtStorage, aggregatorServerId: string): Promise { const smt = new SparseMerkleTree(new DataHasherFactory(HashAlgorithm.SHA256, NodeDataHasher)); - const smtWrapper = await Smt.create(smt); + const smtWrapper = new Smt(smt); let totalLeaves = 0; const chunkSize = 1000; @@ -465,8 +465,9 @@ export class AggregatorGateway { })); await this.smt.addLeaves(leavesToAdd); + const rootHash = await this.smt.rootHash(); - logger.info(`Updated in-memory SMT for follower node, new root hash: ${(this.smt.rootHash).toString()}`); + logger.info(`Updated in-memory SMT for follower node, new root hash: ${rootHash.toString()}`); }); logger.info(`BlockRecords change listener initialized for server ${this.serverId}`); diff --git a/src/AggregatorService.ts b/src/AggregatorService.ts index 263eda5..8842acb 100644 --- a/src/AggregatorService.ts +++ b/src/AggregatorService.ts @@ -42,7 +42,7 @@ export class AggregatorService { public async getInclusionProof(requestId: RequestId): Promise { const record = await this.recordStorage.get(requestId); - const merkleTreePath = this.smt.getPath(requestId.toBitString().toBigInt()); + const merkleTreePath = await this.smt.getPath(requestId.toBitString().toBigInt()); if (!record) { return new InclusionProof(merkleTreePath, null, null); diff --git a/src/RoundManager.ts b/src/RoundManager.ts index 59946d8..b330646 100644 --- a/src/RoundManager.ts +++ b/src/RoundManager.ts @@ -92,7 +92,7 @@ export class RoundManager { } let submitHashResponse; - const rootHash = this.smt.rootHash; + const rootHash = await this.smt.rootHash(); try { loggerWithMetadata.info(`Submitting hash to BFT: ${rootHash.toString()}...`); submitHashResponse = await this.bftClient.submitHash(rootHash); diff --git a/src/router/AggregatorRouter.ts b/src/router/AggregatorRouter.ts index a27b278..bc40d2b 100644 --- a/src/router/AggregatorRouter.ts +++ b/src/router/AggregatorRouter.ts @@ -38,7 +38,8 @@ export function setupRouter( app.get('/health', async (req: Request, res: Response) => { let smtRootHash: string | null = null; try { - smtRootHash = aggregatorService ? aggregatorService.getSmt().rootHash.toString() : null; + const hash = await aggregatorService?.getSmt().rootHash(); + smtRootHash = hash?.toString() ?? null; } catch (error) { logger.error('Error getting SMT root hash in health endpoint:', error); } diff --git a/src/smt/Smt.ts b/src/smt/Smt.ts index 64d73da..b036a87 100644 --- a/src/smt/Smt.ts +++ b/src/smt/Smt.ts @@ -1,140 +1,59 @@ import { DataHash } from '@unicitylabs/commons/lib/hash/DataHash.js'; import { LeafInBranchError } from '@unicitylabs/commons/lib/smt/LeafInBranchError.js'; import { MerkleTreePath } from '@unicitylabs/commons/lib/smt/MerkleTreePath.js'; -import { MerkleTreeRootNode } from '@unicitylabs/commons/lib/smt/MerkleTreeRootNode.js'; import { SparseMerkleTree } from '@unicitylabs/commons/lib/smt/SparseMerkleTree.js'; import logger from '../logger.js'; - /** * Wrapper for SparseMerkleTree that provides concurrency control * using a locking mechanism to ensure sequential execution of * asynchronous operations. */ export class Smt { - private smtUpdateLock: boolean = false; - private waitingPromises: Array<{ - resolve: () => void; - reject: (error: Error) => void; - timer: NodeJS.Timeout; - }> = []; - - // Lock timeout in milliseconds (10 seconds) - private readonly LOCK_TIMEOUT_MS = 10000; - /** * Creates a new SMT wrapper * @param smt The SparseMerkleTree to wrap - * @param _root SparseMerkleTreeRoot representing the current state of the tree */ - private constructor( - private readonly smt: SparseMerkleTree, - private _root: MerkleTreeRootNode, - ) {} + public constructor(private readonly smt: SparseMerkleTree) {} /** * Gets the root hash of the tree */ - public get rootHash(): DataHash { - return this._root.hash; - } - - public static async create(smt: SparseMerkleTree): Promise { - return new Smt(smt, await smt.calculateRoot()); + public async rootHash(): Promise { + const root = await this.smt.calculateRoot(); + return root.hash; } /** * Adds a leaf to the SMT with locking to prevent concurrent updates */ public addLeaf(path: bigint, value: Uint8Array): Promise { - return this.withSmtLock(async () => { - await this.smt.addLeaf(path, value); - this._root = await this.smt.calculateRoot(); - }); + return this.smt.addLeaf(path, value); } /** * Gets a proof path for a leaf with locking to ensure consistent view */ - public getPath(path: bigint): MerkleTreePath { - return this._root.getPath(path); + public async getPath(path: bigint): Promise { + const root = await this.smt.calculateRoot(); + return root.getPath(path); } /** * Adds multiple leaves atomically with a single lock */ - public addLeaves(leaves: Array<{ path: bigint; value: Uint8Array }>): Promise { - return this.withSmtLock(async () => { - await Promise.all( - leaves.map((leaf) => - this.smt.addLeaf(leaf.path, leaf.value).catch((error) => { - if (error instanceof LeafInBranchError) { - logger.warn(`Leaf already exists in tree for path ${leaf.path} - skipping`); - } else { - throw error; - } - }), - ), - ); - - this._root = await this.smt.calculateRoot(); - }); - } - - /** - * Acquires a lock for SMT updates with a timeout - * @returns A promise that resolves when the lock is acquired - */ - private acquireSmtLock(): Promise { - if (!this.smtUpdateLock) { - this.smtUpdateLock = true; - return Promise.resolve(); - } - - return new Promise((resolve, reject) => { - // Create a timeout that will reject the promise if the lock isn't acquired in time - const timer = setTimeout(() => { - // Remove this waiting promise from the queue - const index = this.waitingPromises.findIndex((p) => p.timer === timer); - if (index !== -1) { - this.waitingPromises.splice(index, 1); - } - - reject(new Error(`SMT lock acquisition timed out after ${this.LOCK_TIMEOUT_MS}ms`)); - }, this.LOCK_TIMEOUT_MS); - - this.waitingPromises.push({ resolve, reject, timer }); - }); - } - - /** - * Releases the SMT update lock and resolves the next waiting promise - */ - private releaseSmtLock(): void { - if (this.waitingPromises.length > 0) { - const next = this.waitingPromises.shift(); - // Clear the timeout since we're resolving this promise - if (next) { - clearTimeout(next.timer); - next.resolve(); - } - } else { - this.smtUpdateLock = false; - } - } - - /** - * Executes a function while holding the SMT lock - * @param fn The function to execute with the lock held - * @returns The result of the function - */ - public async withSmtLock(fn: () => Promise): Promise { - await this.acquireSmtLock(); - try { - return await fn(); - } finally { - this.releaseSmtLock(); - } + public async addLeaves(leaves: Array<{ path: bigint; value: Uint8Array }>): Promise { + await Promise.all( + leaves.map((leaf) => + this.smt.addLeaf(leaf.path, leaf.value).catch((error) => { + if (error instanceof LeafInBranchError) { + logger.warn(`Leaf already exists in tree for path ${leaf.path} - skipping`); + } else { + throw error; + } + }), + ), + ); } } diff --git a/tests/AggregatorServiceTest.ts b/tests/AggregatorServiceTest.ts index a519b17..ebc3a6a 100644 --- a/tests/AggregatorServiceTest.ts +++ b/tests/AggregatorServiceTest.ts @@ -79,7 +79,7 @@ describe('AggregatorService Tests', () => { initialBlockHash: '185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969', }, bftClient, - await Smt.create(smt), + new Smt(smt), {} as never, recordStorage, {} as never, @@ -95,7 +95,7 @@ describe('AggregatorService Tests', () => { aggregatorService = new AggregatorService( roundManager, - await Smt.create(smt), + new Smt(smt), recordStorage, blockStorage, blockRecordsStorage, diff --git a/tests/RoundManagerUnitTest.ts b/tests/RoundManagerUnitTest.ts index 029d7b3..05e8a74 100644 --- a/tests/RoundManagerUnitTest.ts +++ b/tests/RoundManagerUnitTest.ts @@ -1,5 +1,5 @@ -import { HashAlgorithm } from '@unicitylabs/commons/lib/hash/HashAlgorithm.js'; import { DataHasherFactory } from '@unicitylabs/commons/lib/hash/DataHasherFactory.js'; +import { HashAlgorithm } from '@unicitylabs/commons/lib/hash/HashAlgorithm.js'; import { NodeDataHasher } from '@unicitylabs/commons/lib/hash/NodeDataHasher.js'; import { SparseMerkleTree } from '@unicitylabs/commons/lib/smt/SparseMerkleTree.js'; import mongoose from 'mongoose'; @@ -9,7 +9,6 @@ import { CommitmentStorage } from '../src/commitment/CommitmentStorage.js'; import { MockBftClient } from './consensus/bft/MockBftClient.js'; import { connectToSharedMongo, disconnectFromSharedMongo, generateTestCommitments, clearAllCollections } from './TestUtils.js'; import { BlockStorage } from '../src/hashchain/BlockStorage.js'; -import logger from '../src/logger.js'; import { AggregatorRecordStorage } from '../src/records/AggregatorRecordStorage.js'; import { BlockRecordsStorage } from '../src/records/BlockRecordsStorage.js'; import { RoundManager } from '../src/RoundManager.js'; @@ -56,7 +55,7 @@ describe('Round Manager Tests', () => { blockRecordsStorage = await BlockRecordsStorage.create('test-server'); smtStorage = new SmtStorage(); smt = new SparseMerkleTree(new DataHasherFactory(HashAlgorithm.SHA256, NodeDataHasher)); - const smtWrapper = await Smt.create(smt); + const smtWrapper = new Smt(smt); roundManager = new RoundManager( config, diff --git a/tests/benchmarks/BlockCreationBenchmarkTest.ts b/tests/benchmarks/BlockCreationBenchmarkTest.ts index 636c2f2..fa7f7c3 100644 --- a/tests/benchmarks/BlockCreationBenchmarkTest.ts +++ b/tests/benchmarks/BlockCreationBenchmarkTest.ts @@ -177,7 +177,7 @@ describe('Block Creation Performance Benchmarks', () => { }); beforeEach(async () => { - smt = await Smt.create(new SparseMerkleTree(new DataHasherFactory(HashAlgorithm.SHA256, NodeDataHasher))); + smt = new Smt(new SparseMerkleTree(new DataHasherFactory(HashAlgorithm.SHA256, NodeDataHasher))); mockBftClient = new MockBftClient(); const originalSubmitHash = mockBftClient.submitHash; diff --git a/tests/benchmarks/SmtBenchmarkTest.ts b/tests/benchmarks/SmtBenchmarkTest.ts index 96bc178..c564611 100644 --- a/tests/benchmarks/SmtBenchmarkTest.ts +++ b/tests/benchmarks/SmtBenchmarkTest.ts @@ -10,7 +10,6 @@ import { v4 as uuidv4 } from 'uuid'; import logger from '../../src/logger.js'; - interface ISmtBenchmarkResult { testDescription: string; treeSize: number; diff --git a/tests/smt/SmtChunkedLoadingTest.ts b/tests/smt/SmtChunkedLoadingTest.ts index 12405fa..9cf2be8 100644 --- a/tests/smt/SmtChunkedLoadingTest.ts +++ b/tests/smt/SmtChunkedLoadingTest.ts @@ -116,7 +116,7 @@ describe('SMT Chunked Loading Tests', () => { logger.info('Verifying SMT root hash through round manager...'); const roundManager = gateway.getRoundManager(); - const actualRootHash = roundManager.smt.rootHash; + const actualRootHash = await roundManager.smt.rootHash(); logger.info(`Actual root hash from gateway: ${actualRootHash.toString()}`); logger.info(`Expected root hash: ${expectedRootHash.toString()}`); diff --git a/tests/smt/SmtTest.ts b/tests/smt/SmtTest.ts deleted file mode 100644 index 7822318..0000000 --- a/tests/smt/SmtTest.ts +++ /dev/null @@ -1,200 +0,0 @@ -import { DataHasherFactory } from '@unicitylabs/commons/lib/hash/DataHasherFactory.js'; -import { HashAlgorithm } from '@unicitylabs/commons/lib/hash/HashAlgorithm.js'; -import { NodeDataHasher } from '@unicitylabs/commons/lib/hash/NodeDataHasher.js'; -import { SparseMerkleTree } from '@unicitylabs/commons/lib/smt/SparseMerkleTree.js'; - -import logger from '../../src/logger.js'; -import { Smt } from '../../src/smt/Smt.js'; -import { delay } from '../TestUtils.js'; - -describe('SMT Wrapper Tests', () => { - jest.setTimeout(30000); - - let smt: SparseMerkleTree; - let smtWrapper: Smt; - - beforeEach(async () => { - smt = new SparseMerkleTree(new DataHasherFactory(HashAlgorithm.SHA256, NodeDataHasher)); - smtWrapper = await Smt.create(smt); - }); - - const createAddLeafOperation = - (id: number, duration: number) => async (): Promise<{ id: number; status: string }> => { - try { - const startTime = Date.now(); - logger.info(`Operation ${id}: Starting addLeaf (with ${duration}ms simulated processing)`); - - const path = BigInt(1000 + id); - const value = new Uint8Array([id % 256, (id * 2) % 256]); - - await smtWrapper.withSmtLock(async () => { - await delay(duration); - await smt.addLeaf(path, value); - }); - - const totalTime = Date.now() - startTime; - logger.info(`Operation ${id}: Completed addLeaf in ${totalTime}ms`); - - return { id, status: 'success' }; - } catch (error) { - logger.error(`Operation ${id}: Failed:`, error); - return { id, status: 'error' }; - } - }; - - const createReadOperation = (id: number, duration: number) => async (): Promise<{ id: number; status: string }> => { - try { - const startTime = Date.now(); - logger.info(`Operation ${id}: Starting read (with ${duration}ms simulated processing)`); - - const path = BigInt(id); - - await smtWrapper.withSmtLock(async () => { - await delay(duration); - return smt.calculateRoot().then((root) => root.getPath(path)); - }); - - const totalTime = Date.now() - startTime; - logger.info(`Operation ${id}: Completed read in ${totalTime}ms`); - - return { id, status: 'success' }; - } catch (error) { - logger.error(`Operation ${id}: Failed:`, error); - return { id, status: 'error' }; - } - }; - - it('should properly lock during concurrent operations', async () => { - const operations = [ - createAddLeafOperation(1, 300), - createReadOperation(2, 100), - createAddLeafOperation(3, 500), - createReadOperation(4, 200), - createAddLeafOperation(5, 150), - ]; - - const startTime = Date.now(); - const results = await Promise.all(operations.map((op) => op())); - const totalTime = Date.now() - startTime; - - logger.info(`All operations completed in ${totalTime}ms`); - logger.info('Results:', results); - - expect(results.every((r) => r.status === 'success')).toBe(true); - - const sumOfOperationTimes = 300 + 100 + 500 + 200 + 150; - expect(totalTime).toBeGreaterThanOrEqual(sumOfOperationTimes); - - expect(smtWrapper.rootHash).toBeDefined(); - }); - - it('should execute operations in FIFO order when waiting for lock', async () => { - const executionOrder: number[] = []; - - const createTrackedOperation = (id: number, duration: number) => async (): Promise => { - await smtWrapper.withSmtLock(async () => { - executionOrder.push(id); - await delay(duration); - }); - }; - - const longOperation = createTrackedOperation(1, 500); - const longOperationPromise = longOperation(); - - await delay(50); - - const op2Promise = createTrackedOperation(2, 50)(); - const op3Promise = createTrackedOperation(3, 50)(); - const op4Promise = createTrackedOperation(4, 50)(); - - await Promise.all([longOperationPromise, op2Promise, op3Promise, op4Promise]); - - expect(executionOrder).toEqual([1, 2, 3, 4]); - }); - - it('should timeout if lock acquisition takes too long', async () => { - Object.defineProperty(smtWrapper, 'LOCK_TIMEOUT_MS', { value: 100 }); - - const longOperation = async (): Promise => { - await smtWrapper.withSmtLock(async () => { - await delay(1000); - }); - }; - - const longOperationPromise = longOperation(); - await delay(50); - - try { - await smtWrapper.withSmtLock(() => { - fail('Should not reach this point'); - }); - fail('Expected lock acquisition to timeout'); - } catch (error) { - expect((error as Error).message).toContain('lock acquisition timed out'); - } - - await longOperationPromise; - }); - - it('should process batch addLeaves atomically', async () => { - const leavesToAdd = [ - { path: BigInt(1), value: new Uint8Array([1]) }, - { path: BigInt(2), value: new Uint8Array([2]) }, - ]; - - await smtWrapper.addLeaves(leavesToAdd); - const readRootDuringAdd = smtWrapper.rootHash; - - const tempSmt = new SparseMerkleTree(new DataHasherFactory(HashAlgorithm.SHA256, NodeDataHasher)); - await Promise.all([ - tempSmt.addLeaf(BigInt(1), new Uint8Array([1])), - tempSmt.addLeaf(BigInt(2), new Uint8Array([2])), - ]); - const expectedRoot = await tempSmt.calculateRoot(); - - expect(readRootDuringAdd).toBeDefined(); - expect(readRootDuringAdd!.equals(expectedRoot.hash)).toBe(true); - expect(smtWrapper.rootHash.equals(expectedRoot.hash)).toBe(true); - }); - - it('should skip duplicate leaves when using addLeaves batch function', async () => { - // Add initial leaf directly - const path = BigInt(42); - const value = new Uint8Array([1, 2, 3]); - - await smtWrapper.addLeaf(path, value); - const rootAfterFirstAdd = smtWrapper.rootHash; - - // Adding the same leaf with addLeaf should throw an error - try { - await smtWrapper.addLeaf(path, value); - fail('Expected error when adding duplicate leaf with addLeaf'); - } catch (error) { - expect((error as Error).message).toContain('Cannot add leaf inside branch'); - } - - // Root should remain unchanged after error - expect(smtWrapper.rootHash.equals(rootAfterFirstAdd)).toBe(true); - - // Now add a batch containing the duplicate leaf and a new leaf - const newPath = BigInt(43); - const newValue = new Uint8Array([4, 5, 6]); - - // This should succeed, skipping the duplicate but adding the new leaf - await smtWrapper.addLeaves([ - { path, value }, - { path: newPath, value: newValue }, - ]); - - // Root should change after adding the new leaf (duplicate was skipped) - const rootAfterBatchAdd = smtWrapper.rootHash; - expect(rootAfterBatchAdd.equals(rootAfterFirstAdd)).toBe(false); - - // Verify both leaves exist by getting their paths - const path1 = await smtWrapper.getPath(path); - const path2 = await smtWrapper.getPath(newPath); - - expect(path1).toBeDefined(); - expect(path2).toBeDefined(); - }); -});