From 5477fa0b3a5b1d6ad731c8424a56d0b5002d1ed9 Mon Sep 17 00:00:00 2001 From: Koko Bhadra Date: Tue, 10 Mar 2026 16:48:53 -0400 Subject: [PATCH 1/5] Add batch eth_call + pure Zig DEX math (V2/V3/Router) Batch RPC (#11): - HttpTransport.requestBatch() for JSON-RPC batch requests - BatchCaller with addCall/execute/reset and per-call error handling - Response matching by id, partial failure support DEX Math (#16): - dex/v2.zig: getAmountOut/In with configurable fees, multi-hop, arb profit - dex/v3.zig: TickMath, SqrtPriceMath, SwapMath, multi-tick simulation - dex/router.zig: mixed V2/V3 routing, binary search arb detection - uint256.zig: mulDivRoundingUp, getAmountOut delegates to dex/v2 Benchmarks: V2 25ns, V3 swap step 1.5us, 3-hop 88ns Closes #11, closes #16 --- bench/u256_bench.zig | 54 +++ docs/content/docs/batch-calls.mdx | 86 ++++ docs/content/docs/dex-math.mdx | 182 +++++++ docs/content/docs/meta.json | 2 + src/dex/router.zig | 295 ++++++++++++ src/dex/v2.zig | 230 +++++++++ src/dex/v3.zig | 757 ++++++++++++++++++++++++++++++ src/http_transport.zig | 95 ++++ src/provider.zig | 266 ++++++++++- src/root.zig | 9 + src/uint256.zig | 46 +- 11 files changed, 2004 insertions(+), 18 deletions(-) create mode 100644 docs/content/docs/batch-calls.mdx create mode 100644 docs/content/docs/dex-math.mdx create mode 100644 src/dex/router.zig create mode 100644 src/dex/v2.zig create mode 100644 src/dex/v3.zig diff --git a/bench/u256_bench.zig b/bench/u256_bench.zig index 0a8a05c..402ec4d 100644 --- a/bench/u256_bench.zig +++ b/bench/u256_bench.zig @@ -156,6 +156,51 @@ fn benchMulDiv() void { std.mem.doNotOptimizeAway(&result); } +// DEX V2 getAmountOut with configurable fee (dex/v2.zig) +fn benchDexV2AmountOut() void { + var amount_in: u256 = ONE_ETH; + var reserve_in: u256 = RESERVE_IN; + var reserve_out: u256 = RESERVE_OUT; + std.mem.doNotOptimizeAway(&amount_in); + std.mem.doNotOptimizeAway(&reserve_in); + std.mem.doNotOptimizeAway(&reserve_out); + + const amount_out = eth.dex_v2.getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000); + std.mem.doNotOptimizeAway(&amount_out); +} + +// DEX V2 multi-hop: 3-hop path +fn benchDexV2MultiHop() void { + const path = [_]eth.dex_v2.Pair{ + .{ .reserve_in = RESERVE_IN, .reserve_out = RESERVE_OUT }, + .{ .reserve_in = 200_000_000_000, .reserve_out = 50_000_000_000_000_000_000 }, + .{ .reserve_in = 50_000_000_000_000_000_000, .reserve_out = 100_000_000_000 }, + }; + var amount_in: u256 = ONE_ETH; + std.mem.doNotOptimizeAway(&amount_in); + + const result = eth.dex_v2.getAmountsOut(amount_in, &path); + std.mem.doNotOptimizeAway(&result); +} + +// DEX V3 getSqrtRatioAtTick +fn benchDexV3TickToPrice() void { + var tick: i24 = 10000; + std.mem.doNotOptimizeAway(&tick); + const result = eth.dex_v3.getSqrtRatioAtTick(tick); + std.mem.doNotOptimizeAway(&result); +} + +// DEX V3 computeSwapStep +fn benchDexV3SwapStep() void { + var current: u256 = SQRT_PRICE; + var target: u256 = eth.dex_v3.getSqrtRatioAtTick(-100).?; + std.mem.doNotOptimizeAway(¤t); + std.mem.doNotOptimizeAway(&target); + const result = eth.dex_v3.computeSwapStep(current, target, 1_000_000_000_000_000_000, 1_000_000, 3000); + std.mem.doNotOptimizeAway(&result); +} + // UniswapV4 getNextSqrtPriceFromAmount0RoundingUp fn benchUniswapV4Swap() void { var liquidity: u256 = ONE_ETH; @@ -202,6 +247,11 @@ pub fn main() !void { try runAndPrint("u256_uniswapv2_optimized", benchUniswapV2Optimized, stdout); try runAndPrint("u256_mulDiv", benchMulDiv, stdout); try runAndPrint("u256_uniswapv4_swap", benchUniswapV4Swap, stdout); + // DEX benchmarks + try runAndPrint("dex_v2_amount_out", benchDexV2AmountOut, stdout); + try runAndPrint("dex_v2_multi_hop_3", benchDexV2MultiHop, stdout); + try runAndPrint("dex_v3_tick_to_price", benchDexV3TickToPrice, stdout); + try runAndPrint("dex_v3_swap_step", benchDexV3SwapStep, stdout); try stdout.print("\n", .{}); @@ -215,6 +265,10 @@ pub fn main() !void { try runAndJson("u256_uniswapv2_optimized", benchUniswapV2Optimized, stdout); try runAndJson("u256_mulDiv", benchMulDiv, stdout); try runAndJson("u256_uniswapv4_swap", benchUniswapV4Swap, stdout); + try runAndJson("dex_v2_amount_out", benchDexV2AmountOut, stdout); + try runAndJson("dex_v2_multi_hop_3", benchDexV2MultiHop, stdout); + try runAndJson("dex_v3_tick_to_price", benchDexV3TickToPrice, stdout); + try runAndJson("dex_v3_swap_step", benchDexV3SwapStep, stdout); try stdout.flush(); } diff --git a/docs/content/docs/batch-calls.mdx b/docs/content/docs/batch-calls.mdx new file mode 100644 index 0000000..6456239 --- /dev/null +++ b/docs/content/docs/batch-calls.mdx @@ -0,0 +1,86 @@ +--- +title: Batch RPC Calls +description: Send multiple eth_call requests in a single JSON-RPC round-trip for MEV and high-throughput applications. +--- + +eth.zig supports [JSON-RPC batch requests](https://www.jsonrpc.org/specification#batch), sending multiple `eth_call` requests in a single HTTP POST. This dramatically reduces latency when evaluating many candidates per block. + +## Basic Usage + +```zig +const eth = @import("eth"); + +// Set up provider +var transport = eth.http_transport.HttpTransport.init(allocator, "https://rpc.example.com"); +defer transport.deinit(); +var provider = eth.provider.Provider.init(allocator, &transport); + +// Create a batch +var batch = eth.provider.BatchCaller.init(allocator, &provider); +defer batch.deinit(); + +// Add calls (returns index for result retrieval) +const idx0 = try batch.addCall(pool_a_address, quote_calldata_a); +const idx1 = try batch.addCall(pool_b_address, quote_calldata_b); +const idx2 = try batch.addCall(pool_c_address, quote_calldata_c); + +// Execute all in one HTTP request +const results = try batch.execute(); +defer eth.provider.freeBatchResults(allocator, results); + +// Each result is either .success or .rpc_error +switch (results[idx0]) { + .success => |data| { + // data contains the decoded bytes from the RPC response + // Decode as needed (e.g., ABI decode a uint256) + }, + .rpc_error => |err| { + // err.code and err.message describe what went wrong + // (e.g., execution reverted, invalid params) + }, +} +``` + +## How It Works + +Under the hood, `BatchCaller` uses the [JSON-RPC batch spec](https://www.jsonrpc.org/specification#batch): +- Each call is formatted as an individual JSON-RPC request with a unique `id` +- All requests are wrapped in a JSON array and sent as a single HTTP POST +- Responses may arrive in any order -- `BatchCaller` matches them by `id` and returns results in the original `addCall` order + +## Per-Call Error Handling + +Some calls in a batch may succeed while others revert. Each `BatchCallResult` is independent: + +```zig +for (results, 0..) |result, i| { + switch (result) { + .success => |data| std.debug.print("Call {d}: {d} bytes\n", .{ i, data.len }), + .rpc_error => |err| std.debug.print("Call {d}: error {d} - {s}\n", .{ i, err.code, err.message }), + } +} +``` + +## Reusing a Batch + +Call `reset()` to clear pending calls and reuse the `BatchCaller`: + +```zig +batch.reset(); +// Add new calls for the next block... +_ = try batch.addCall(new_target, new_calldata); +const new_results = try batch.execute(); +defer eth.provider.freeBatchResults(allocator, new_results); +``` + +## When to Use Batch vs Multicall + +| Feature | BatchCaller | Multicall3 | +|---------|-------------|------------| +| Protocol | JSON-RPC batch | On-chain contract call | +| Atomicity | Independent calls | Single transaction | +| Node support | All JSON-RPC nodes | Requires Multicall3 deployment | +| Gas overhead | None | Contract execution gas | +| Best for | Mixed RPC methods | Same-block state consistency | + +For MEV searchers: use `BatchCaller` when you need to query multiple pools across different blocks or need raw `eth_call` flexibility. Use `Multicall3` when you need atomic same-block reads. diff --git a/docs/content/docs/dex-math.mdx b/docs/content/docs/dex-math.mdx new file mode 100644 index 0000000..24bc289 --- /dev/null +++ b/docs/content/docs/dex-math.mdx @@ -0,0 +1,182 @@ +--- +title: DEX Math +description: Pure Zig Uniswap V2/V3 price computation for off-chain MEV simulation. +--- + +eth.zig includes pure Zig implementations of Uniswap V2 and V3 math, enabling off-chain price computation without RPC calls. This is critical for MEV searchers who need to evaluate hundreds of arb paths per block at sub-microsecond latency. + +## Uniswap V2 + +### Single Swap + +```zig +const eth = @import("eth"); + +// Standard Uniswap V2 (0.3% fee = 997/1000) +const amount_out = eth.dex_v2.getAmountOut( + 1_000_000_000_000_000_000, // 1 ETH input + 100_000_000_000_000_000_000, // 100 ETH reserve_in + 200_000_000_000, // 200k USDC reserve_out + 997, 1000, // fee: 0.3% +); + +// SushiSwap / PancakeSwap (0.25% fee = 9975/10000) +const sushi_out = eth.dex_v2.getAmountOut( + amount_in, reserve_in, reserve_out, + 9975, 10000, // fee: 0.25% +); +``` + +### Inverse: Required Input + +```zig +// How much ETH do I need to get exactly 1000 USDC out? +const required_input = eth.dex_v2.getAmountIn( + 1_000_000_000, // 1000 USDC desired output + reserve_eth, reserve_usdc, + 997, 1000, +) orelse { + // null = insufficient liquidity (amount_out >= reserve_out) + return error.InsufficientLiquidity; +}; +``` + +### Multi-Hop Paths + +```zig +const path = [_]eth.dex_v2.Pair{ + .{ .reserve_in = 100e18, .reserve_out = 200_000e6 }, // ETH -> USDC + .{ .reserve_in = 300_000e6, .reserve_out = 50e18 }, // USDC -> DAI +}; + +// Forward: how much DAI for 1 ETH? +const output = eth.dex_v2.getAmountsOut(1e18, &path); + +// Reverse: how much ETH for 10 DAI? +const input = eth.dex_v2.getAmountsIn(10e18, &path); +``` + +### Arbitrage Profit + +```zig +// Circular path: buy on pool A, sell on pool B +const path = [_]eth.dex_v2.Pair{ + .{ .reserve_in = 1_000_000, .reserve_out = 2_000_000_000 }, + .{ .reserve_in = 2_000_000_000, .reserve_out = 2_000_000 }, +}; + +if (eth.dex_v2.calculateProfit(1000, &path)) |profit| { + // profit = output - input (positive means arb exists) +} +``` + +## Uniswap V3 + +### Tick/Price Conversion + +```zig +// Convert tick index to sqrtPriceX96 (Q96 fixed-point) +const sqrt_price = eth.dex_v3.getSqrtRatioAtTick(100).?; // tick 100 ~ price 1.01 +const sqrt_price_0 = eth.dex_v3.getSqrtRatioAtTick(0).?; // tick 0 = price 1.0 = Q96 + +// Convert sqrtPriceX96 back to tick +const tick = eth.dex_v3.getTickAtSqrtRatio(sqrt_price).?; // = 100 +``` + +### Token Amount Deltas + +```zig +const sqrt_a = eth.dex_v3.getSqrtRatioAtTick(0).?; +const sqrt_b = eth.dex_v3.getSqrtRatioAtTick(100).?; +const liquidity: u128 = 1_000_000_000_000_000_000; + +// How much token0 for a price move from tick 0 to tick 100? +const token0_amount = eth.dex_v3.getAmount0Delta(sqrt_a, sqrt_b, liquidity, true).?; + +// How much token1? +const token1_amount = eth.dex_v3.getAmount1Delta(sqrt_a, sqrt_b, liquidity, true).?; +``` + +### Swap Step Simulation + +```zig +// Simulate a single swap step within one tick range +const step = eth.dex_v3.computeSwapStep( + current_sqrt_price, // current pool price + target_sqrt_price, // next initialized tick boundary + pool_liquidity, // current active liquidity + @as(i256, amount_in), // positive = exact input, negative = exact output + 3000, // fee: 0.3% (3000 pips) +); + +// step.sqrt_ratio_next_x96 -- price after this step +// step.amount_in -- tokens consumed +// step.amount_out -- tokens produced +// step.fee_amount -- fee charged +``` + +### Full Multi-Tick Swap Simulation + +```zig +const ticks = [_]eth.dex_v3.TickInfo{ + .{ .tick = -100, .liquidity_net = 500e18, .sqrt_price_x96 = eth.dex_v3.getSqrtRatioAtTick(-100).? }, + .{ .tick = -200, .liquidity_net = 300e18, .sqrt_price_x96 = eth.dex_v3.getSqrtRatioAtTick(-200).? }, +}; + +const result = eth.dex_v3.simulateSwap( + current_sqrt_price, + current_liquidity, + &ticks, + amount_in, + true, // zero_for_one (selling token0) + 3000, // 0.3% fee +); + +// result.amount_in_consumed -- total input consumed +// result.amount_out -- total output received +// result.sqrt_price_final_x96 -- final pool price +// result.ticks_crossed -- number of tick boundaries crossed +``` + +## Cross-DEX Router + +Route through mixed V2/V3 pools: + +```zig +const hops = [_]eth.dex_router.Pool{ + .{ .v2 = .{ .reserve_in = 100e18, .reserve_out = 200_000e6 } }, + .{ .v3 = .{ + .sqrt_price_x96 = current_sqrt, + .liquidity = pool_liquidity, + .ticks = &tick_array, + .fee_pips = 3000, + .zero_for_one = true, + } }, +}; + +const output = eth.dex_router.quoteExactInput(1e18, &hops); +``` + +### Arbitrage Detection + +```zig +if (eth.dex_router.findArbOpportunity(&circular_path, max_input)) |arb| { + // arb.profit -- expected profit + // arb.optimal_input -- optimal trade size (found via binary search) +} +``` + +## Fee Configurations + +| DEX | fee_numerator | fee_denominator | Fee | +|-----|--------------|-----------------|-----| +| Uniswap V2 | 997 | 1000 | 0.30% | +| SushiSwap | 997 | 1000 | 0.30% | +| PancakeSwap | 9975 | 10000 | 0.25% | +| Uniswap V3 | fee_pips=500 | -- | 0.05% | +| Uniswap V3 | fee_pips=3000 | -- | 0.30% | +| Uniswap V3 | fee_pips=10000 | -- | 1.00% | + +## Performance + +All math uses eth.zig's limb-native u256 arithmetic -- no heap allocation, no LLVM software division routines. Run `zig build bench-u256` to measure on your hardware. diff --git a/docs/content/docs/meta.json b/docs/content/docs/meta.json index d1d68fb..f4a8d28 100644 --- a/docs/content/docs/meta.json +++ b/docs/content/docs/meta.json @@ -12,6 +12,8 @@ "comptime", "ens", "websockets", + "batch-calls", + "dex-math", "---Reference---", "modules", "benchmarks", diff --git a/src/dex/router.zig b/src/dex/router.zig new file mode 100644 index 0000000..9626a44 --- /dev/null +++ b/src/dex/router.zig @@ -0,0 +1,295 @@ +const std = @import("std"); +const v2 = @import("v2.zig"); +const v3 = @import("v3.zig"); + +// ============================================================================ +// Types +// ============================================================================ + +/// A pool in a multi-hop path. Can be V2 (constant-product) or V3 (concentrated liquidity). +pub const Pool = union(enum) { + v2: V2Pool, + v3: V3Pool, + + pub const V2Pool = struct { + reserve_in: u256, + reserve_out: u256, + fee_numerator: u64 = 997, + fee_denominator: u64 = 1000, + }; + + pub const V3Pool = struct { + sqrt_price_x96: u256, + liquidity: u128, + ticks: []const v3.TickInfo, + fee_pips: u24, + zero_for_one: bool, + }; +}; + +pub const ArbOpportunity = struct { + profit: u256, + optimal_input: u256, +}; + +// ============================================================================ +// Routing +// ============================================================================ + +/// Quote exact input through a mixed V2/V3 path. +/// Returns the final output amount, or null if any hop fails. +pub fn quoteExactInput(amount_in: u256, hops: []const Pool) ?u256 { + if (hops.len == 0) return null; + if (amount_in == 0) return @as(u256, 0); + + var current = amount_in; + for (hops) |pool| { + current = quotePool(current, pool) orelse return null; + if (current == 0) return null; + } + return current; +} + +/// Quote exact output through a mixed V2/V3 path (reverse). +/// Returns the required input amount, or null if any hop fails. +/// Only supports V2 hops (V3 reverse quoting requires tick state traversal). +pub fn quoteExactOutput(amount_out: u256, hops: []const Pool) ?u256 { + if (hops.len == 0) return null; + + var current = amount_out; + var i: usize = hops.len; + while (i > 0) { + i -= 1; + switch (hops[i]) { + .v2 => |p| { + current = v2.getAmountIn(current, p.reserve_in, p.reserve_out, p.fee_numerator, p.fee_denominator) orelse return null; + }, + .v3 => { + // V3 reverse quoting is complex (requires iterating ticks in reverse) + // Return null to signal unsupported for now + return null; + }, + } + } + return current; +} + +/// Find the optimal input amount for a circular arb path using binary search. +/// Profit is concave for constant-product AMMs, so binary search on the derivative works. +/// Returns null if no profitable opportunity exists. +pub fn findArbOpportunity(hops: []const Pool, max_input: u256) ?ArbOpportunity { + if (hops.len == 0) return null; + + // Check if there's any profit at all with a small amount + const small_amount: u256 = 1000; + const small_output = quoteExactInput(small_amount, hops) orelse return null; + if (small_output <= small_amount) return null; + + // Binary search for optimal input + // The profit function is concave, so we search for the peak + var lo: u256 = 1; + var hi: u256 = max_input; + + // Run binary search for ~100 iterations (enough for u256 precision) + var iterations: u32 = 0; + while (lo < hi and iterations < 128) : (iterations += 1) { + // Avoid overflow in midpoint calculation + const mid = lo + (hi - lo) / 2; + if (mid == lo) break; + + const mid_output = quoteExactInput(mid, hops) orelse break; + if (mid == std.math.maxInt(u256)) break; + const mid_plus = quoteExactInput(mid + 1, hops) orelse break; + + // Check marginal profit at mid: is f(mid+1) - f(mid) > 1? + // If marginal output > marginal input (1), we can increase input + if (mid_plus > mid_output and mid_plus - mid_output > 1) { + // Still profitable to increase - marginal output > marginal input + lo = mid; + } else { + hi = mid; + } + } + + // Evaluate profit at the found optimal point + const optimal = lo; + const output = quoteExactInput(optimal, hops) orelse return null; + if (output <= optimal) return null; + + return .{ + .profit = output - optimal, + .optimal_input = optimal, + }; +} + +// ============================================================================ +// Internal helpers +// ============================================================================ + +/// Quote a single pool hop. +fn quotePool(amount_in: u256, pool: Pool) ?u256 { + switch (pool) { + .v2 => |p| { + const result = v2.getAmountOut(amount_in, p.reserve_in, p.reserve_out, p.fee_numerator, p.fee_denominator); + return if (result == 0) null else result; + }, + .v3 => |p| { + const result = v3.simulateSwap( + p.sqrt_price_x96, + p.liquidity, + p.ticks, + amount_in, + p.zero_for_one, + p.fee_pips, + ); + return if (result.amount_out == 0) null else result.amount_out; + }, + } +} + +// ============================================================================ +// Tests +// ============================================================================ + +test "quoteExactInput V2 single hop" { + const hops = [_]Pool{ + .{ .v2 = .{ + .reserve_in = 100_000_000_000_000_000_000, + .reserve_out = 200_000_000_000, + } }, + }; + + const result = quoteExactInput(1_000_000_000_000_000_000, &hops); + try std.testing.expect(result != null); + + // Should match direct V2 calculation + const direct = v2.getAmountOut(1_000_000_000_000_000_000, 100_000_000_000_000_000_000, 200_000_000_000, 997, 1000); + try std.testing.expectEqual(direct, result.?); +} + +test "quoteExactInput V2 multi-hop" { + const hops = [_]Pool{ + .{ .v2 = .{ + .reserve_in = 100_000_000_000_000_000_000, + .reserve_out = 200_000_000_000, + } }, + .{ .v2 = .{ + .reserve_in = 300_000_000_000, + .reserve_out = 50_000_000_000_000_000_000, + } }, + }; + + const result = quoteExactInput(1_000_000_000_000_000_000, &hops); + try std.testing.expect(result != null); + try std.testing.expect(result.? > 0); +} + +test "quoteExactInput V3 single hop" { + const current_sqrt = v3.getSqrtRatioAtTick(0).?; + const tick_info = [_]v3.TickInfo{ + .{ + .tick = -100, + .liquidity_net = 500_000_000_000_000_000, + .sqrt_price_x96 = v3.getSqrtRatioAtTick(-100).?, + }, + }; + + const hops = [_]Pool{ + .{ .v3 = .{ + .sqrt_price_x96 = current_sqrt, + .liquidity = 1_000_000_000_000_000_000, + .ticks = &tick_info, + .fee_pips = 3000, + .zero_for_one = true, + } }, + }; + + const result = quoteExactInput(1_000_000, &hops); + try std.testing.expect(result != null); + try std.testing.expect(result.? > 0); +} + +test "quoteExactInput mixed V2+V3" { + const current_sqrt = v3.getSqrtRatioAtTick(0).?; + const tick_info = [_]v3.TickInfo{ + .{ + .tick = -100, + .liquidity_net = 500_000_000_000_000_000, + .sqrt_price_x96 = v3.getSqrtRatioAtTick(-100).?, + }, + }; + + const hops = [_]Pool{ + // V2 hop first + .{ .v2 = .{ + .reserve_in = 100_000_000_000_000_000_000, + .reserve_out = 200_000_000_000, + } }, + // V3 hop second + .{ .v3 = .{ + .sqrt_price_x96 = current_sqrt, + .liquidity = 1_000_000_000_000_000_000, + .ticks = &tick_info, + .fee_pips = 3000, + .zero_for_one = true, + } }, + }; + + const result = quoteExactInput(1_000_000_000_000_000_000, &hops); + try std.testing.expect(result != null); + try std.testing.expect(result.? > 0); +} + +test "quoteExactInput zero amount" { + const hops = [_]Pool{ + .{ .v2 = .{ + .reserve_in = 100_000_000_000, + .reserve_out = 200_000_000_000, + } }, + }; + + const result = quoteExactInput(0, &hops); + try std.testing.expectEqual(@as(?u256, 0), result); +} + +test "quoteExactInput empty path" { + const result = quoteExactInput(1000, &.{}); + try std.testing.expectEqual(@as(?u256, null), result); +} + +test "quoteExactOutput V2 single hop" { + const hops = [_]Pool{ + .{ .v2 = .{ + .reserve_in = 100_000_000_000_000_000_000, + .reserve_out = 200_000_000_000, + } }, + }; + + const result = quoteExactOutput(1_000_000_000, &hops); + try std.testing.expect(result != null); + try std.testing.expect(result.? > 0); +} + +test "findArbOpportunity profitable V2" { + // Imbalanced pools create arb opportunity + const hops = [_]Pool{ + .{ .v2 = .{ .reserve_in = 1_000_000, .reserve_out = 2_000_000_000 } }, + .{ .v2 = .{ .reserve_in = 2_000_000_000, .reserve_out = 2_000_000 } }, + }; + + const result = findArbOpportunity(&hops, 1_000_000); + try std.testing.expect(result != null); + try std.testing.expect(result.?.profit > 0); + try std.testing.expect(result.?.optimal_input > 0); +} + +test "findArbOpportunity unprofitable" { + // Equal pools with fees = no arb + const hops = [_]Pool{ + .{ .v2 = .{ .reserve_in = 1_000_000_000, .reserve_out = 1_000_000_000 } }, + .{ .v2 = .{ .reserve_in = 1_000_000_000, .reserve_out = 1_000_000_000 } }, + }; + + const result = findArbOpportunity(&hops, 1_000_000_000); + try std.testing.expectEqual(@as(?ArbOpportunity, null), result); +} diff --git a/src/dex/v2.zig b/src/dex/v2.zig new file mode 100644 index 0000000..88459f5 --- /dev/null +++ b/src/dex/v2.zig @@ -0,0 +1,230 @@ +const std = @import("std"); +const uint256_mod = @import("../uint256.zig"); + +const u256ToLimbs = uint256_mod.u256ToLimbs; +const limbsToU256 = uint256_mod.limbsToU256; +const mulLimbs = uint256_mod.mulLimbs; +const mulLimbScalar = uint256_mod.mulLimbScalar; +const divLimbsDirect = uint256_mod.divLimbsDirect; +const addLimbs = uint256_mod.addLimbs; + +// ============================================================================ +// Types +// ============================================================================ + +pub const Pair = struct { + reserve_in: u256, + reserve_out: u256, + fee_numerator: u64 = 997, + fee_denominator: u64 = 1000, +}; + +// ============================================================================ +// Core Functions +// ============================================================================ + +/// Compute UniswapV2 getAmountOut with configurable fee, entirely in u64-limb space. +/// Formula: (amountIn * feeNum * reserveOut) / (reserveIn * feeDenom + amountIn * feeNum) +pub fn getAmountOut(amount_in: u256, reserve_in: u256, reserve_out: u256, fee_numerator: u64, fee_denominator: u64) u256 { + if (amount_in == 0) return 0; + + const ai = u256ToLimbs(amount_in); + const ri = u256ToLimbs(reserve_in); + const ro = u256ToLimbs(reserve_out); + + const amount_in_with_fee = mulLimbScalar(ai, fee_numerator); + const numerator = mulLimbs(amount_in_with_fee, ro); + const denominator = addLimbs(mulLimbScalar(ri, fee_denominator), amount_in_with_fee); + + if (denominator[0] == 0 and denominator[1] == 0 and denominator[2] == 0 and denominator[3] == 0) { + @panic("getAmountOut: denominator is zero (invalid reserves)"); + } + + return limbsToU256(divLimbsDirect(numerator, denominator)); +} + +/// Compute UniswapV2 getAmountIn with configurable fee. +/// Formula: (reserveIn * amountOut * feeDenom) / ((reserveOut - amountOut) * feeNum) + 1 +/// Returns null if amount_out >= reserve_out (insufficient liquidity). +pub fn getAmountIn(amount_out: u256, reserve_in: u256, reserve_out: u256, fee_numerator: u64, fee_denominator: u64) ?u256 { + if (amount_out == 0) return 0; + if (amount_out >= reserve_out) return null; + + const reserve_diff = reserve_out - amount_out; + + const ri = u256ToLimbs(reserve_in); + const ao = u256ToLimbs(amount_out); + const rd = u256ToLimbs(reserve_diff); + + // numerator = reserveIn * amountOut * feeDenom + const ri_times_ao = mulLimbs(ri, ao); + const numerator = mulLimbScalar(ri_times_ao, fee_denominator); + + // denominator = (reserveOut - amountOut) * feeNum + const denominator = mulLimbScalar(rd, fee_numerator); + + if (denominator[0] == 0 and denominator[1] == 0 and denominator[2] == 0 and denominator[3] == 0) { + @panic("getAmountIn: denominator is zero"); + } + + // Uniswap V2 always adds 1 (ceiling) + const quotient = limbsToU256(divLimbsDirect(numerator, denominator)); + return quotient + 1; +} + +// ============================================================================ +// Multi-hop Functions +// ============================================================================ + +/// Chain forward swaps: each hop's output feeds the next hop's input. +/// Returns null if any intermediate output is 0. +pub fn getAmountsOut(amount_in: u256, path: []const Pair) ?u256 { + if (path.len == 0) return null; + + var current = amount_in; + for (path) |pair| { + current = getAmountOut(current, pair.reserve_in, pair.reserve_out, pair.fee_numerator, pair.fee_denominator); + if (current == 0) return null; + } + return current; +} + +/// Chain backward swaps: start from desired output, work backwards to find required input. +/// Returns null if any getAmountIn returns null. +pub fn getAmountsIn(amount_out: u256, path: []const Pair) ?u256 { + if (path.len == 0) return null; + + var current = amount_out; + var i: usize = path.len; + while (i > 0) { + i -= 1; + current = getAmountIn(current, path[i].reserve_in, path[i].reserve_out, path[i].fee_numerator, path[i].fee_denominator) orelse return null; + } + return current; +} + +/// Run getAmountsOut and return profit (output - input) if positive, else null. +pub fn calculateProfit(amount_in: u256, path: []const Pair) ?u256 { + const output = getAmountsOut(amount_in, path) orelse return null; + if (output > amount_in) { + return output - amount_in; + } + return null; +} + +// ============================================================================ +// Tests +// ============================================================================ + +test "getAmountOut matches legacy" { + const amount_in: u256 = 1_000_000_000_000_000_000; // 1 ETH + const reserve_in: u256 = 100_000_000_000_000_000_000; // 100 ETH + const reserve_out: u256 = 200_000_000_000; // 200k USDC (6 decimals) + + const v2_result = getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000); + const legacy_result = uint256_mod.getAmountOut(amount_in, reserve_in, reserve_out); + try std.testing.expectEqual(legacy_result, v2_result); +} + +test "getAmountOut different fees" { + const amount_in: u256 = 1_000_000_000_000_000_000; + const reserve_in: u256 = 100_000_000_000_000_000_000; + const reserve_out: u256 = 200_000_000_000; + + // PancakeSwap uses 9975/10000 (0.25% fee) vs Uniswap 997/1000 (0.3% fee) + const pancake = getAmountOut(amount_in, reserve_in, reserve_out, 9975, 10000); + const uniswap = getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000); + + // Lower fee => more output + try std.testing.expect(pancake > uniswap); +} + +test "getAmountOut zero input" { + const result = getAmountOut(0, 100_000, 200_000, 997, 1000); + try std.testing.expectEqual(@as(u256, 0), result); +} + +test "getAmountOut result less than reserve" { + const amounts = [_]u256{ 1, 1000, 1_000_000_000_000_000_000, 50_000_000_000_000_000_000 }; + const reserve_in: u256 = 100_000_000_000_000_000_000; + const reserve_out: u256 = 200_000_000_000; + + for (amounts) |amount_in| { + const result = getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000); + try std.testing.expect(result < reserve_out); + } +} + +test "getAmountIn inverse" { + const amount_in: u256 = 1_000_000_000_000_000_000; + const reserve_in: u256 = 100_000_000_000_000_000_000; + const reserve_out: u256 = 200_000_000_000; + + const output = getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000); + const recovered_input = getAmountIn(output, reserve_in, reserve_out, 997, 1000) orelse unreachable; + + // Due to ceiling division (+1), recovered_input >= amount_in + // Should be within 2 units + try std.testing.expect(recovered_input >= amount_in); + try std.testing.expect(recovered_input - amount_in <= 2); +} + +test "getAmountIn insufficient liquidity" { + const reserve_in: u256 = 100_000_000_000_000_000_000; + const reserve_out: u256 = 200_000_000_000; + + // amount_out == reserve_out + try std.testing.expectEqual(@as(?u256, null), getAmountIn(reserve_out, reserve_in, reserve_out, 997, 1000)); + + // amount_out > reserve_out + try std.testing.expectEqual(@as(?u256, null), getAmountIn(reserve_out + 1, reserve_in, reserve_out, 997, 1000)); +} + +test "getAmountsOut multi-hop" { + const path = [_]Pair{ + .{ .reserve_in = 100_000_000_000_000_000_000, .reserve_out = 200_000_000_000 }, // ETH -> USDC + .{ .reserve_in = 300_000_000_000, .reserve_out = 50_000_000_000_000_000_000 }, // USDC -> DAI + }; + + const amount_in: u256 = 1_000_000_000_000_000_000; // 1 ETH + const result = getAmountsOut(amount_in, &path); + try std.testing.expect(result != null); + try std.testing.expect(result.? > 0); +} + +test "getAmountsIn multi-hop" { + const path = [_]Pair{ + .{ .reserve_in = 100_000_000_000_000_000_000, .reserve_out = 200_000_000_000 }, + .{ .reserve_in = 300_000_000_000, .reserve_out = 50_000_000_000_000_000_000 }, + }; + + const desired_output: u256 = 500_000_000_000_000_000; // 0.5 DAI + const required_input = getAmountsIn(desired_output, &path); + try std.testing.expect(required_input != null); + try std.testing.expect(required_input.? > 0); +} + +test "calculateProfit unprofitable" { + // Equal reserves with 0.3% fee each way => guaranteed loss + const path = [_]Pair{ + .{ .reserve_in = 1_000_000_000, .reserve_out = 1_000_000_000 }, + .{ .reserve_in = 1_000_000_000, .reserve_out = 1_000_000_000 }, + }; + + const result = calculateProfit(1_000_000, &path); + try std.testing.expectEqual(@as(?u256, null), result); +} + +test "calculateProfit arithmetic" { + // Imbalanced pools create arbitrage opportunity: + // Pool 1: buy cheap (low reserve_in, high reserve_out) + // Pool 2: sell expensive (high reserve_in, low reserve_out relative to what we got) + const path = [_]Pair{ + .{ .reserve_in = 1_000_000, .reserve_out = 2_000_000_000 }, // very cheap + .{ .reserve_in = 2_000_000_000, .reserve_out = 2_000_000 }, // sell back + }; + + const result = calculateProfit(1000, &path); + try std.testing.expect(result != null); + try std.testing.expect(result.? > 0); +} diff --git a/src/dex/v3.zig b/src/dex/v3.zig new file mode 100644 index 0000000..fd98a5a --- /dev/null +++ b/src/dex/v3.zig @@ -0,0 +1,757 @@ +const std = @import("std"); +const uint256_mod = @import("../uint256.zig"); + +const mulDiv = uint256_mod.mulDiv; +const mulDivRoundingUp = uint256_mod.mulDivRoundingUp; +const u256ToLimbs = uint256_mod.u256ToLimbs; +const limbsToU256 = uint256_mod.limbsToU256; +const mulLimbs = uint256_mod.mulLimbs; +const mulLimbScalar = uint256_mod.mulLimbScalar; +const divLimbsDirect = uint256_mod.divLimbsDirect; +const addLimbs = uint256_mod.addLimbs; +const fastMul = uint256_mod.fastMul; +const fastDiv = uint256_mod.fastDiv; +const mulWide = uint256_mod.mulWide; +const MAX = uint256_mod.MAX; +const ZERO = uint256_mod.ZERO; + +// ============================================================================ +// Constants +// ============================================================================ + +pub const Q96: u256 = uint256_mod.Q96; // 1 << 96 +pub const Q128: u256 = @as(u256, 1) << 128; +pub const MIN_TICK: i24 = -887272; +pub const MAX_TICK: i24 = 887272; +pub const MIN_SQRT_RATIO: u256 = 4295128739; +pub const MAX_SQRT_RATIO: u256 = 1461446703485210103287273052203988822378723970342; + +// ============================================================================ +// TickMath +// ============================================================================ + +/// Precomputed magic constants: sqrt(1.0001^(2^i)) in Q128.128 format. +/// From Uniswap V3 TickMath.sol. +const TICK_RATIOS = [20]u256{ + 0xfffcb933bd6fad37aa2d162d1a594001, // bit 0: sqrt(1.0001^1) + 0xfff97272373d413259a46990580e213a, // bit 1: sqrt(1.0001^2) + 0xfff2e50f5f656932ef12357cf3c7fdcc, // bit 2: sqrt(1.0001^4) + 0xffe5caca7e10e4e61c3624eaa0941cd0, // bit 3: sqrt(1.0001^8) + 0xffcb9843d60f6159c9db58835c926644, // bit 4: sqrt(1.0001^16) + 0xff973b41fa98c081472e6896dfb254c0, // bit 5: sqrt(1.0001^32) + 0xff2ea16466c96a3843ec78b326b52861, // bit 6: sqrt(1.0001^64) + 0xfe5dee046a99a2a811c461f1969c3053, // bit 7: sqrt(1.0001^128) + 0xfcbe86c7900a88aedcffc83b479aa3a4, // bit 8: sqrt(1.0001^256) + 0xf987a7253ac413176f2b074cf7815e54, // bit 9: sqrt(1.0001^512) + 0xf3392b0822b70005940c7a398e4b70f3, // bit 10: sqrt(1.0001^1024) + 0xe7159475a2c29b7443b29c7fa6e889d9, // bit 11: sqrt(1.0001^2048) + 0xd097f3bdfd2022b8845ad8f792aa5825, // bit 12: sqrt(1.0001^4096) + 0xa9f746462d870fdf8a65dc1f90e061e5, // bit 13: sqrt(1.0001^8192) + 0x70d869a156d2a1b890bb3df62baf32f7, // bit 14: sqrt(1.0001^16384) + 0x31be135f97d08fd981231505542fcfa6, // bit 15: sqrt(1.0001^32768) + 0x9aa508b5b7a84e1c677de54f3e99bc9, // bit 16: sqrt(1.0001^65536) + 0x5d6af8dedb81196699c329225ee604, // bit 17: sqrt(1.0001^131072) + 0x2216e584f5fa1ea926041bedfe98, // bit 18: sqrt(1.0001^262144) + 0x48a170391f7dc42444e8fa2, // bit 19: sqrt(1.0001^524288) +}; + +/// Get sqrtPriceX96 from a tick index. +/// Port of Uniswap V3 TickMath.getSqrtRatioAtTick. +pub fn getSqrtRatioAtTick(tick: i24) ?u256 { + // Validate tick range + if (tick < MIN_TICK or tick > MAX_TICK) return null; + + const abs_tick: u24 = @abs(tick); + + // Initialize ratio based on bit 0 + var ratio: u256 = if (abs_tick & 0x1 != 0) + TICK_RATIOS[0] + else + 0x100000000000000000000000000000000; // Q128.128 representation of 1.0 + + // Apply conditional multiplications for bits 1-19 + if (abs_tick & 0x2 != 0) ratio = mulDiv(ratio, TICK_RATIOS[1], Q128) orelse return null; + if (abs_tick & 0x4 != 0) ratio = mulDiv(ratio, TICK_RATIOS[2], Q128) orelse return null; + if (abs_tick & 0x8 != 0) ratio = mulDiv(ratio, TICK_RATIOS[3], Q128) orelse return null; + if (abs_tick & 0x10 != 0) ratio = mulDiv(ratio, TICK_RATIOS[4], Q128) orelse return null; + if (abs_tick & 0x20 != 0) ratio = mulDiv(ratio, TICK_RATIOS[5], Q128) orelse return null; + if (abs_tick & 0x40 != 0) ratio = mulDiv(ratio, TICK_RATIOS[6], Q128) orelse return null; + if (abs_tick & 0x80 != 0) ratio = mulDiv(ratio, TICK_RATIOS[7], Q128) orelse return null; + if (abs_tick & 0x100 != 0) ratio = mulDiv(ratio, TICK_RATIOS[8], Q128) orelse return null; + if (abs_tick & 0x200 != 0) ratio = mulDiv(ratio, TICK_RATIOS[9], Q128) orelse return null; + if (abs_tick & 0x400 != 0) ratio = mulDiv(ratio, TICK_RATIOS[10], Q128) orelse return null; + if (abs_tick & 0x800 != 0) ratio = mulDiv(ratio, TICK_RATIOS[11], Q128) orelse return null; + if (abs_tick & 0x1000 != 0) ratio = mulDiv(ratio, TICK_RATIOS[12], Q128) orelse return null; + if (abs_tick & 0x2000 != 0) ratio = mulDiv(ratio, TICK_RATIOS[13], Q128) orelse return null; + if (abs_tick & 0x4000 != 0) ratio = mulDiv(ratio, TICK_RATIOS[14], Q128) orelse return null; + if (abs_tick & 0x8000 != 0) ratio = mulDiv(ratio, TICK_RATIOS[15], Q128) orelse return null; + if (abs_tick & 0x10000 != 0) ratio = mulDiv(ratio, TICK_RATIOS[16], Q128) orelse return null; + if (abs_tick & 0x20000 != 0) ratio = mulDiv(ratio, TICK_RATIOS[17], Q128) orelse return null; + if (abs_tick & 0x40000 != 0) ratio = mulDiv(ratio, TICK_RATIOS[18], Q128) orelse return null; + if (abs_tick & 0x80000 != 0) ratio = mulDiv(ratio, TICK_RATIOS[19], Q128) orelse return null; + + // If tick > 0, invert the ratio + if (tick > 0) { + ratio = MAX / ratio; + } + + // Convert from Q128.128 to Q96.96: right-shift by 32, rounding up if remainder + const remainder = ratio & ((@as(u256, 1) << 32) - 1); + const shifted = ratio >> 32; + return shifted + if (remainder == 0) @as(u256, 0) else @as(u256, 1); +} + +/// Get tick index from sqrtPriceX96. +/// Port of Uniswap V3 TickMath.getTickAtSqrtRatio. +pub fn getTickAtSqrtRatio(sqrt_price_x96: u256) ?i24 { + // Validate: MIN_SQRT_RATIO <= sqrt_price_x96 < MAX_SQRT_RATIO + if (sqrt_price_x96 < MIN_SQRT_RATIO or sqrt_price_x96 >= MAX_SQRT_RATIO) return null; + + // Convert Q96 to Q128 + const ratio: u256 = sqrt_price_x96 << 32; + + var r: u256 = ratio; + var msb: u256 = 0; + + // Find MSB via binary search (8 stages) + { + const f: u256 = if (r > 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF) @as(u256, 128) else 0; + msb |= f; + r >>= @intCast(f); + } + { + const f: u256 = if (r > 0xFFFFFFFFFFFFFFFF) @as(u256, 64) else 0; + msb |= f; + r >>= @intCast(f); + } + { + const f: u256 = if (r > 0xFFFFFFFF) @as(u256, 32) else 0; + msb |= f; + r >>= @intCast(f); + } + { + const f: u256 = if (r > 0xFFFF) @as(u256, 16) else 0; + msb |= f; + r >>= @intCast(f); + } + { + const f: u256 = if (r > 0xFF) @as(u256, 8) else 0; + msb |= f; + r >>= @intCast(f); + } + { + const f: u256 = if (r > 0xF) @as(u256, 4) else 0; + msb |= f; + r >>= @intCast(f); + } + { + const f: u256 = if (r > 0x3) @as(u256, 2) else 0; + msb |= f; + r >>= @intCast(f); + } + { + const f: u256 = if (r > 0x1) @as(u256, 1) else 0; + msb |= f; + } + + // Normalize r so MSB is at bit 127 + if (msb >= 128) { + r = ratio >> @intCast(msb - 127); + } else { + r = ratio << @intCast(127 - msb); + } + + // Initialize log_2 as signed Q64.64 + // log_2 = (msb - 128) << 64 + var log_2: i256 = (@as(i256, @intCast(msb)) - 128) << 64; + + // 14 iterations of squaring to compute fractional bits of log2 + // Each iteration: r = (r * r) >> 127, f = r >> 128, log_2 |= f << (63 - i), r >>= f + // NOTE: Solidity uses wrapping mul (mod 2^256) then shr, NOT full-precision mulDiv. + inline for (0..14) |iteration| { + // r = (r *% r) >> 127 (wrapping multiply, matching Solidity assembly `mul`) + r = fastMul(r, r) >> 127; + const f: u256 = r >> 128; + log_2 = log_2 | @as(i256, @intCast(f << (63 - iteration))); + r >>= @intCast(f); + } + + // Convert log_2 (base 2) to log_sqrt10001 (base sqrt(1.0001)) + const log_sqrt10001: i256 = log_2 * 255738958999603826347141; + + // Compute tick bounds + const tick_low: i24 = @intCast(@as(i256, (log_sqrt10001 - 3402992956809132418596140100660247210) >> 128)); + const tick_high: i24 = @intCast(@as(i256, (log_sqrt10001 + 291339464771989622907027621153398088495) >> 128)); + + if (tick_low == tick_high) { + return tick_low; + } + + // Check which tick is correct + const sqrt_ratio_at_high = getSqrtRatioAtTick(tick_high) orelse return null; + if (sqrt_ratio_at_high <= sqrt_price_x96) { + return tick_high; + } else { + return tick_low; + } +} + +// ============================================================================ +// SqrtPriceMath +// ============================================================================ + +/// Unsafe division rounding up: ceil(a / b). Assumes b != 0. +fn divRoundingUp(a: u256, b: u256) u256 { + const quotient = fastDiv(a, b); + if (a % b != 0) { + return quotient + 1; + } + return quotient; +} + +/// Amount of token0 for a price move. Uses mulDiv for precision. +/// Calculates: liquidity * (sqrtB - sqrtA) / (sqrtA * sqrtB) in Q96 space. +pub fn getAmount0Delta(sqrt_ratio_a_x96: u256, sqrt_ratio_b_x96: u256, liquidity: u128, round_up: bool) ?u256 { + var lower = sqrt_ratio_a_x96; + var upper = sqrt_ratio_b_x96; + if (lower > upper) { + const tmp = lower; + lower = upper; + upper = tmp; + } + if (lower == 0) return null; + + const numerator1: u256 = @as(u256, liquidity) << 96; + const numerator2: u256 = upper - lower; + + if (round_up) { + const inner = mulDivRoundingUp(numerator1, numerator2, upper) orelse return null; + return divRoundingUp(inner, lower); + } else { + const inner = mulDiv(numerator1, numerator2, upper) orelse return null; + return inner / lower; + } +} + +/// Amount of token1 for a price move. +/// Calculates: liquidity * (sqrtB - sqrtA) / Q96. +pub fn getAmount1Delta(sqrt_ratio_a_x96: u256, sqrt_ratio_b_x96: u256, liquidity: u128, round_up: bool) ?u256 { + var lower = sqrt_ratio_a_x96; + var upper = sqrt_ratio_b_x96; + if (lower > upper) { + const tmp = lower; + lower = upper; + upper = tmp; + } + + const diff = upper - lower; + + if (round_up) { + return mulDivRoundingUp(@as(u256, liquidity), diff, Q96); + } else { + return mulDiv(@as(u256, liquidity), diff, Q96); + } +} + +/// Get next sqrt price from token0 amount change, rounding up. +/// When add=true (input token0), price goes down. +/// When add=false (output token0), price goes up. +pub fn getNextSqrtPriceFromAmount0RoundingUp(sqrt_price_x96: u256, liquidity: u128, amount: u256, add: bool) ?u256 { + if (amount == 0) return sqrt_price_x96; + const numerator1: u256 = @as(u256, liquidity) << 96; + + if (add) { + // product = amount * sqrt_price_x96 -- check for overflow + const ov = @mulWithOverflow(amount, sqrt_price_x96); + if (ov[1] == 0) { + const product = ov[0]; + // Verify: product / amount == sqrt_price_x96 (no truncation) + if (product / amount == sqrt_price_x96) { + const denominator_ov = @addWithOverflow(numerator1, product); + if (denominator_ov[1] == 0 and denominator_ov[0] >= numerator1) { + return mulDivRoundingUp(numerator1, sqrt_price_x96, denominator_ov[0]); + } + } + } + // Fallback for overflow: numerator1 / (numerator1 / sqrt_price_x96 + amount) + const div_result = fastDiv(numerator1, sqrt_price_x96); + const sum_ov = @addWithOverflow(div_result, amount); + if (sum_ov[1] != 0) return null; + return divRoundingUp(numerator1, sum_ov[0]); + } else { + // Subtract: price goes up + const ov = @mulWithOverflow(amount, sqrt_price_x96); + if (ov[1] != 0) return null; + const product = ov[0]; + // Verify no truncation + if (product / amount != sqrt_price_x96) return null; + if (numerator1 <= product) return null; + const denominator = numerator1 - product; + return mulDivRoundingUp(numerator1, sqrt_price_x96, denominator); + } +} + +/// Get next sqrt price from token1 amount change, rounding down. +/// When add=true (input token1), price goes up. +/// When add=false (output token1), price goes down. +pub fn getNextSqrtPriceFromAmount1RoundingDown(sqrt_price_x96: u256, liquidity: u128, amount: u256, add: bool) ?u256 { + if (add) { + // quotient = amount * Q96 / liquidity (or amount << 96 / liquidity if fits) + const quotient: u256 = if (amount <= (@as(u256, 1) << 160) - 1) + (amount << 96) / @as(u256, liquidity) + else + mulDiv(amount, Q96, @as(u256, liquidity)) orelse return null; + + const result_ov = @addWithOverflow(sqrt_price_x96, quotient); + if (result_ov[1] != 0) return null; + return result_ov[0]; + } else { + // quotient = mulDivRoundingUp(amount, Q96, liquidity) or divRoundingUp(amount << 96, liquidity) if fits + const quotient: u256 = if (amount <= (@as(u256, 1) << 160) - 1) + divRoundingUp(amount << 96, @as(u256, liquidity)) + else + mulDivRoundingUp(amount, Q96, @as(u256, liquidity)) orelse return null; + + if (sqrt_price_x96 <= quotient) return null; + return sqrt_price_x96 - quotient; + } +} + +/// Get next sqrt price from input amount. +pub fn getNextSqrtPriceFromInput(sqrt_price_x96: u256, liquidity: u128, amount_in: u256, zero_for_one: bool) ?u256 { + if (sqrt_price_x96 == 0) return null; + if (liquidity == 0) return null; + + return if (zero_for_one) + getNextSqrtPriceFromAmount0RoundingUp(sqrt_price_x96, liquidity, amount_in, true) + else + getNextSqrtPriceFromAmount1RoundingDown(sqrt_price_x96, liquidity, amount_in, true); +} + +/// Get next sqrt price from output amount. +pub fn getNextSqrtPriceFromOutput(sqrt_price_x96: u256, liquidity: u128, amount_out: u256, zero_for_one: bool) ?u256 { + if (sqrt_price_x96 == 0) return null; + if (liquidity == 0) return null; + + return if (zero_for_one) + getNextSqrtPriceFromAmount1RoundingDown(sqrt_price_x96, liquidity, amount_out, false) + else + getNextSqrtPriceFromAmount0RoundingUp(sqrt_price_x96, liquidity, amount_out, false); +} + +// ============================================================================ +// SwapMath +// ============================================================================ + +pub const SwapStepResult = struct { + sqrt_ratio_next_x96: u256, + amount_in: u256, + amount_out: u256, + fee_amount: u256, +}; + +/// Compute a single swap step within one tick range. +/// Port of Uniswap V3 SwapMath.computeSwapStep. +pub fn computeSwapStep( + sqrt_ratio_current_x96: u256, + sqrt_ratio_target_x96: u256, + liquidity: u128, + amount_remaining: i256, + fee_pips: u24, // e.g. 3000 = 0.3% +) SwapStepResult { + const zero_for_one = sqrt_ratio_current_x96 >= sqrt_ratio_target_x96; + const exact_in = amount_remaining >= 0; + + var sqrt_ratio_next_x96: u256 = 0; + var amount_in: u256 = 0; + var amount_out: u256 = 0; + var fee_amount: u256 = 0; + + if (exact_in) { + const amount_remaining_u: u256 = @intCast(amount_remaining); + const amount_remaining_less_fee = mulDiv(amount_remaining_u, 1_000_000 - @as(u256, fee_pips), 1_000_000) orelse 0; + + amount_in = if (zero_for_one) + getAmount0Delta(sqrt_ratio_target_x96, sqrt_ratio_current_x96, liquidity, true) orelse 0 + else + getAmount1Delta(sqrt_ratio_current_x96, sqrt_ratio_target_x96, liquidity, true) orelse 0; + + if (amount_remaining_less_fee >= amount_in) { + sqrt_ratio_next_x96 = sqrt_ratio_target_x96; + } else { + sqrt_ratio_next_x96 = getNextSqrtPriceFromInput( + sqrt_ratio_current_x96, + liquidity, + amount_remaining_less_fee, + zero_for_one, + ) orelse sqrt_ratio_current_x96; + } + } else { + // exact_out: amount_remaining is negative + const neg_amount: u256 = @intCast(-amount_remaining); + + amount_out = if (zero_for_one) + getAmount1Delta(sqrt_ratio_target_x96, sqrt_ratio_current_x96, liquidity, false) orelse 0 + else + getAmount0Delta(sqrt_ratio_current_x96, sqrt_ratio_target_x96, liquidity, false) orelse 0; + + if (neg_amount >= amount_out) { + sqrt_ratio_next_x96 = sqrt_ratio_target_x96; + } else { + sqrt_ratio_next_x96 = getNextSqrtPriceFromOutput( + sqrt_ratio_current_x96, + liquidity, + neg_amount, + zero_for_one, + ) orelse sqrt_ratio_current_x96; + } + } + + const max = sqrt_ratio_target_x96 == sqrt_ratio_next_x96; + + // Recalculate amounts based on whether we hit the target + if (zero_for_one) { + if (!(max and exact_in)) { + amount_in = getAmount0Delta(sqrt_ratio_next_x96, sqrt_ratio_current_x96, liquidity, true) orelse 0; + } + if (!(max and !exact_in)) { + amount_out = getAmount1Delta(sqrt_ratio_next_x96, sqrt_ratio_current_x96, liquidity, false) orelse 0; + } + } else { + if (!(max and exact_in)) { + amount_in = getAmount1Delta(sqrt_ratio_current_x96, sqrt_ratio_next_x96, liquidity, true) orelse 0; + } + if (!(max and !exact_in)) { + amount_out = getAmount0Delta(sqrt_ratio_current_x96, sqrt_ratio_next_x96, liquidity, false) orelse 0; + } + } + + // Cap output amount to not exceed remaining output amount + if (!exact_in) { + const neg_amount: u256 = @intCast(-amount_remaining); + if (amount_out > neg_amount) { + amount_out = neg_amount; + } + } + + // Fee computation + if (exact_in and sqrt_ratio_next_x96 != sqrt_ratio_target_x96) { + // Didn't reach target: remainder of input goes to fee + const amount_remaining_u: u256 = @intCast(amount_remaining); + fee_amount = amount_remaining_u - amount_in; + } else { + fee_amount = mulDivRoundingUp(amount_in, @as(u256, fee_pips), 1_000_000 - @as(u256, fee_pips)) orelse 0; + } + + return .{ + .sqrt_ratio_next_x96 = sqrt_ratio_next_x96, + .amount_in = amount_in, + .amount_out = amount_out, + .fee_amount = fee_amount, + }; +} + +// ============================================================================ +// Tick-crossing simulation +// ============================================================================ + +pub const TickInfo = struct { + tick: i24, + liquidity_net: i128, + sqrt_price_x96: u256, // precomputed getSqrtRatioAtTick(tick) +}; + +pub const SwapResult = struct { + amount_in_consumed: u256, + amount_out: u256, + sqrt_price_final_x96: u256, + ticks_crossed: usize, +}; + +/// Simulate a full swap across multiple ticks. +/// ticks must be sorted (ascending for !zero_for_one, descending for zero_for_one). +pub fn simulateSwap( + sqrt_price_x96: u256, + liquidity: u128, + ticks: []const TickInfo, + amount_in: u256, + zero_for_one: bool, + fee_pips: u24, +) SwapResult { + var current_sqrt_price: u256 = sqrt_price_x96; + var current_liquidity: u128 = liquidity; + var amount_remaining: u256 = amount_in; + var total_amount_out: u256 = 0; + var ticks_crossed: usize = 0; + + for (ticks) |tick_info| { + if (amount_remaining == 0) break; + + const step = computeSwapStep( + current_sqrt_price, + tick_info.sqrt_price_x96, + current_liquidity, + @as(i256, @intCast(amount_remaining)), // exact_in (positive) + fee_pips, + ); + + // Deduct consumed amount + fee from remaining + const consumed = step.amount_in + step.fee_amount; + if (consumed >= amount_remaining) { + amount_remaining = 0; + } else { + amount_remaining -= consumed; + } + total_amount_out += step.amount_out; + current_sqrt_price = step.sqrt_ratio_next_x96; + + // If we reached the tick boundary, update liquidity + if (current_sqrt_price == tick_info.sqrt_price_x96) { + if (zero_for_one) { + // Moving left: subtract liquidity_net + if (tick_info.liquidity_net < 0) { + current_liquidity += @as(u128, @abs(tick_info.liquidity_net)); + } else { + const net_u: u128 = @intCast(tick_info.liquidity_net); + if (net_u > current_liquidity) { + current_liquidity = 0; + } else { + current_liquidity -= net_u; + } + } + } else { + // Moving right: add liquidity_net + if (tick_info.liquidity_net >= 0) { + current_liquidity += @intCast(tick_info.liquidity_net); + } else { + const net_u: u128 = @abs(tick_info.liquidity_net); + if (net_u > current_liquidity) { + current_liquidity = 0; + } else { + current_liquidity -= net_u; + } + } + } + ticks_crossed += 1; + } + } + + return .{ + .amount_in_consumed = amount_in - amount_remaining, + .amount_out = total_amount_out, + .sqrt_price_final_x96 = current_sqrt_price, + .ticks_crossed = ticks_crossed, + }; +} + +// ============================================================================ +// Tests +// ============================================================================ + +test "getSqrtRatioAtTick(0) == Q96" { + // tick 0 = price 1.0, sqrtPrice = 1.0 * 2^96 + try std.testing.expectEqual(Q96, getSqrtRatioAtTick(0).?); +} + +test "getSqrtRatioAtTick(MIN_TICK) == MIN_SQRT_RATIO" { + try std.testing.expectEqual(MIN_SQRT_RATIO, getSqrtRatioAtTick(MIN_TICK).?); +} + +test "getSqrtRatioAtTick(MAX_TICK) == MAX_SQRT_RATIO" { + try std.testing.expectEqual(MAX_SQRT_RATIO, getSqrtRatioAtTick(MAX_TICK).?); +} + +test "getSqrtRatioAtTick out of range" { + try std.testing.expectEqual(@as(?u256, null), getSqrtRatioAtTick(MIN_TICK - 1)); + try std.testing.expectEqual(@as(?u256, null), getSqrtRatioAtTick(MAX_TICK + 1)); +} + +test "getTickAtSqrtRatio roundtrip" { + // Test various ticks. Note: MAX_TICK is excluded because getSqrtRatioAtTick(MAX_TICK) == MAX_SQRT_RATIO, + // which is outside getTickAtSqrtRatio's domain [MIN_SQRT_RATIO, MAX_SQRT_RATIO). + const test_ticks = [_]i24{ 0, 1, -1, 100, -100, -887272, 50, -50, 10000, -10000, 887271 }; + for (test_ticks) |tick| { + const sqrt_ratio = getSqrtRatioAtTick(tick).?; + const recovered_tick = getTickAtSqrtRatio(sqrt_ratio).?; + try std.testing.expectEqual(tick, recovered_tick); + } +} + +test "getTickAtSqrtRatio boundary" { + try std.testing.expectEqual(@as(?i24, null), getTickAtSqrtRatio(MIN_SQRT_RATIO - 1)); + try std.testing.expectEqual(@as(?i24, null), getTickAtSqrtRatio(MAX_SQRT_RATIO)); +} + +test "getAmount0Delta known value" { + // Known V3 test vector: liquidity = 1e18, price range 1.0 to 1.01 + const sqrt_a = getSqrtRatioAtTick(0).?; // price 1.0 + const sqrt_b = getSqrtRatioAtTick(100).?; // price ~1.01 + const result = getAmount0Delta(sqrt_a, sqrt_b, 1_000_000_000_000_000_000, true); + try std.testing.expect(result != null); + try std.testing.expect(result.? > 0); +} + +test "getAmount1Delta known value" { + const sqrt_a = getSqrtRatioAtTick(0).?; + const sqrt_b = getSqrtRatioAtTick(100).?; + const result = getAmount1Delta(sqrt_a, sqrt_b, 1_000_000_000_000_000_000, true); + try std.testing.expect(result != null); + try std.testing.expect(result.? > 0); +} + +test "getAmount0Delta symmetry" { + // getAmount0Delta(a, b, ...) == getAmount0Delta(b, a, ...) (auto-sorts) + const sqrt_a = getSqrtRatioAtTick(0).?; + const sqrt_b = getSqrtRatioAtTick(100).?; + const liq: u128 = 1_000_000_000_000_000_000; + const result1 = getAmount0Delta(sqrt_a, sqrt_b, liq, true); + const result2 = getAmount0Delta(sqrt_b, sqrt_a, liq, true); + try std.testing.expectEqual(result1, result2); +} + +test "getAmount1Delta symmetry" { + const sqrt_a = getSqrtRatioAtTick(0).?; + const sqrt_b = getSqrtRatioAtTick(100).?; + const liq: u128 = 1_000_000_000_000_000_000; + const result1 = getAmount1Delta(sqrt_a, sqrt_b, liq, true); + const result2 = getAmount1Delta(sqrt_b, sqrt_a, liq, true); + try std.testing.expectEqual(result1, result2); +} + +test "getNextSqrtPriceFromInput zero_for_one" { + const current = getSqrtRatioAtTick(0).?; + const liq: u128 = 1_000_000_000_000_000_000; + const result = getNextSqrtPriceFromInput(current, liq, 100_000, true); + try std.testing.expect(result != null); + // Selling token0 should decrease price + try std.testing.expect(result.? < current); +} + +test "getNextSqrtPriceFromInput !zero_for_one" { + const current = getSqrtRatioAtTick(0).?; + const liq: u128 = 1_000_000_000_000_000_000; + const result = getNextSqrtPriceFromInput(current, liq, 100_000, false); + try std.testing.expect(result != null); + // Selling token1 should increase price + try std.testing.expect(result.? > current); +} + +test "getNextSqrtPriceFromOutput" { + const current = getSqrtRatioAtTick(0).?; + const liq: u128 = 1_000_000_000_000_000_000; + // zero_for_one output: buying token1 + const result = getNextSqrtPriceFromOutput(current, liq, 100_000, true); + try std.testing.expect(result != null); + try std.testing.expect(result.? < current); +} + +test "computeSwapStep exact input within range" { + // Simple swap that stays within one tick range + const current = getSqrtRatioAtTick(0).?; + const target = getSqrtRatioAtTick(-100).?; + const result = computeSwapStep(current, target, 1_000_000_000_000_000_000, 100_000, 3000); + try std.testing.expect(result.amount_in > 0); + try std.testing.expect(result.amount_out > 0); + try std.testing.expect(result.fee_amount > 0); +} + +test "computeSwapStep exact output" { + const current = getSqrtRatioAtTick(0).?; + const target = getSqrtRatioAtTick(-100).?; + // Negative amount_remaining means exact output + const result = computeSwapStep(current, target, 1_000_000_000_000_000_000, -50_000, 3000); + try std.testing.expect(result.amount_in > 0); + try std.testing.expect(result.amount_out > 0); +} + +test "computeSwapStep hits target exactly" { + // Small amount that can fully consume to target + const current = getSqrtRatioAtTick(0).?; + const target = getSqrtRatioAtTick(-1).?; + const liq: u128 = 100; // very small liquidity + // Very large input should hit the target + const result = computeSwapStep(current, target, liq, 1_000_000_000_000_000_000, 3000); + try std.testing.expectEqual(target, result.sqrt_ratio_next_x96); +} + +test "simulateSwap basic" { + const current_tick: i24 = 0; + const current_sqrt_price = getSqrtRatioAtTick(current_tick).?; + const liq: u128 = 1_000_000_000_000_000_000; + + // Set up two ticks below current price (for zero_for_one swap) + const tick1: i24 = -100; + const tick2: i24 = -200; + const ticks = [_]TickInfo{ + .{ + .tick = tick1, + .liquidity_net = 500_000_000_000_000_000, + .sqrt_price_x96 = getSqrtRatioAtTick(tick1).?, + }, + .{ + .tick = tick2, + .liquidity_net = 500_000_000_000_000_000, + .sqrt_price_x96 = getSqrtRatioAtTick(tick2).?, + }, + }; + + const result = simulateSwap( + current_sqrt_price, + liq, + &ticks, + 1_000_000, + true, // zero_for_one + 3000, // 0.3% fee + ); + + try std.testing.expect(result.amount_in_consumed > 0); + try std.testing.expect(result.amount_out > 0); + try std.testing.expect(result.sqrt_price_final_x96 < current_sqrt_price); +} + +test "simulateSwap zero amount" { + const current_sqrt_price = getSqrtRatioAtTick(0).?; + const ticks = [_]TickInfo{}; + + const result = simulateSwap( + current_sqrt_price, + 1_000_000_000_000_000_000, + &ticks, + 0, + true, + 3000, + ); + + try std.testing.expectEqual(@as(u256, 0), result.amount_in_consumed); + try std.testing.expectEqual(@as(u256, 0), result.amount_out); + try std.testing.expectEqual(@as(usize, 0), result.ticks_crossed); +} + +test "mulDivRoundingUp basic" { + // 7 * 1 / 2 = 3.5 -> rounds up to 4 + try std.testing.expectEqual(@as(?u256, 4), mulDivRoundingUp(7, 1, 2)); + // 6 * 1 / 2 = 3.0 -> exact, no rounding + try std.testing.expectEqual(@as(?u256, 3), mulDivRoundingUp(6, 1, 2)); + // div by zero + try std.testing.expectEqual(@as(?u256, null), mulDivRoundingUp(1, 1, 0)); +} + +test "getTickAtSqrtRatio known values" { + // MIN_SQRT_RATIO should give MIN_TICK + try std.testing.expectEqual(@as(?i24, MIN_TICK), getTickAtSqrtRatio(MIN_SQRT_RATIO)); + // MAX_SQRT_RATIO - 1 should give MAX_TICK - 1 + const result = getTickAtSqrtRatio(MAX_SQRT_RATIO - 1); + try std.testing.expect(result != null); + try std.testing.expectEqual(@as(i24, MAX_TICK - 1), result.?); +} + +test "getSqrtRatioAtTick monotonic" { + // Verify that higher ticks produce higher sqrt ratios + var prev: u256 = 0; + const test_ticks = [_]i24{ -887272, -10000, -1000, -100, -10, -1, 0, 1, 10, 100, 1000, 10000, 887272 }; + for (test_ticks) |tick| { + const ratio = getSqrtRatioAtTick(tick).?; + try std.testing.expect(ratio > prev); + prev = ratio; + } +} diff --git a/src/http_transport.zig b/src/http_transport.zig index 899f711..954f6f1 100644 --- a/src/http_transport.zig +++ b/src/http_transport.zig @@ -42,6 +42,55 @@ pub const HttpTransport = struct { return buf.toOwnedSlice(allocator); } + /// Build a JSON-RPC 2.0 batch request body from individual request bodies. + /// Wraps them in a JSON array: [body1,body2,...bodyN] + /// Caller owns the returned memory. + pub fn buildBatchBody(allocator: std.mem.Allocator, bodies: []const []const u8) ![]u8 { + var buf: std.ArrayList(u8) = .empty; + errdefer buf.deinit(allocator); + + try buf.append(allocator, '['); + for (bodies, 0..) |body, i| { + if (i > 0) try buf.append(allocator, ','); + try buf.appendSlice(allocator, body); + } + try buf.append(allocator, ']'); + + return buf.toOwnedSlice(allocator); + } + + /// Send a batch JSON-RPC request and return the raw response body. + /// Caller owns the returned memory. + pub fn requestBatch(self: *HttpTransport, bodies: []const []const u8) ![]u8 { + const batch_body = try buildBatchBody(self.allocator, bodies); + defer self.allocator.free(batch_body); + + // Use an allocating writer to collect the response body. + var response_body: std.Io.Writer.Allocating = .init(self.allocator); + errdefer response_body.deinit(); + + const result = self.client.fetch(.{ + .location = .{ .url = self.url }, + .method = .POST, + .payload = batch_body, + .extra_headers = &.{ + .{ .name = "Content-Type", .value = "application/json" }, + }, + .response_writer = &response_body.writer, + }); + + if (result) |res| { + if (res.status != .ok) { + response_body.deinit(); + return error.HttpError; + } + return response_body.toOwnedSlice(); + } else |_| { + response_body.deinit(); + return error.ConnectionFailed; + } + } + /// Send a JSON-RPC request and return the raw response body. /// Caller owns the returned memory. pub fn request(self: *HttpTransport, method: []const u8, params_json: []const u8, id: u64) ![]u8 { @@ -140,3 +189,49 @@ test "init and deinit" { try std.testing.expectEqualStrings("http://localhost:8545", transport.url); } + +test "buildBatchBody - empty" { + const allocator = std.testing.allocator; + const body = try HttpTransport.buildBatchBody(allocator, &.{}); + defer allocator.free(body); + try std.testing.expectEqualStrings("[]", body); +} + +test "buildBatchBody - single request" { + const allocator = std.testing.allocator; + const req = try HttpTransport.buildRequestBody(allocator, "eth_chainId", "[]", 1); + defer allocator.free(req); + const bodies: []const []const u8 = &.{req}; + const batch = try HttpTransport.buildBatchBody(allocator, bodies); + defer allocator.free(batch); + // Should be [{"jsonrpc":"2.0","method":"eth_chainId","params":[],"id":1}] + try std.testing.expect(batch[0] == '['); + try std.testing.expect(batch[batch.len - 1] == ']'); + try std.testing.expectEqualStrings(req, batch[1 .. batch.len - 1]); +} + +test "buildBatchBody - multiple requests" { + const allocator = std.testing.allocator; + const req1 = try HttpTransport.buildRequestBody(allocator, "eth_chainId", "[]", 1); + defer allocator.free(req1); + const req2 = try HttpTransport.buildRequestBody(allocator, "eth_blockNumber", "[]", 2); + defer allocator.free(req2); + const bodies: []const []const u8 = &.{ req1, req2 }; + const batch = try HttpTransport.buildBatchBody(allocator, bodies); + defer allocator.free(batch); + try std.testing.expect(batch[0] == '['); + try std.testing.expect(batch[batch.len - 1] == ']'); + // Should contain a comma between requests + const comma_count = blk: { + var count: usize = 0; + // Count commas outside braces to find separator + var depth: i32 = 0; + for (batch) |c| { + if (c == '{') depth += 1; + if (c == '}') depth -= 1; + if (c == ',' and depth == 0) count += 1; + } + break :blk count; + }; + try std.testing.expectEqual(@as(usize, 1), comma_count); +} diff --git a/src/provider.zig b/src/provider.zig index c05bc64..1ea0835 100644 --- a/src/provider.zig +++ b/src/provider.zig @@ -281,7 +281,7 @@ pub const Provider = struct { return buf.toOwnedSlice(self.allocator); } - fn formatCallParams(self: *Provider, to: [20]u8, data: []const u8, from: ?[20]u8) ![]u8 { + pub fn formatCallParams(self: *Provider, to: [20]u8, data: []const u8, from: ?[20]u8) ![]u8 { const to_hex = primitives.addressToHex(&to); const data_hex = try hex_mod.bytesToHex(self.allocator, data); defer self.allocator.free(data_hex); @@ -307,6 +307,173 @@ pub const Provider = struct { } }; +// ============================================================================ +// Batch eth_call support +// ============================================================================ + +pub const BatchCallResult = union(enum) { + success: []u8, + rpc_error: RpcErrorData, + + pub const RpcErrorData = struct { + code: i64, + message: []const u8, + }; +}; + +pub const BatchCaller = struct { + provider: *Provider, + allocator: std.mem.Allocator, + targets: std.ArrayList([20]u8), + calldata: std.ArrayList([]const u8), + + pub fn init(allocator: std.mem.Allocator, prov: *Provider) BatchCaller { + return .{ + .provider = prov, + .allocator = allocator, + .targets = .empty, + .calldata = .empty, + }; + } + + pub fn deinit(self: *BatchCaller) void { + self.targets.deinit(self.allocator); + self.calldata.deinit(self.allocator); + } + + pub fn addCall(self: *BatchCaller, to: [20]u8, data: []const u8) !usize { + const index = self.targets.items.len; + try self.targets.append(self.allocator, to); + try self.calldata.append(self.allocator, data); + return index; + } + + pub fn reset(self: *BatchCaller) void { + self.targets.clearRetainingCapacity(); + self.calldata.clearRetainingCapacity(); + } + + pub fn execute(self: *BatchCaller) ![]BatchCallResult { + const n = self.targets.items.len; + if (n == 0) return try self.allocator.alloc(BatchCallResult, 0); + + // Build individual request bodies + const bodies = try self.allocator.alloc([]u8, n); + defer { + for (bodies) |b| self.allocator.free(b); + self.allocator.free(bodies); + } + + const ids = try self.allocator.alloc(u64, n); + defer self.allocator.free(ids); + + const base_id = self.provider.next_id; + self.provider.next_id += n; + + for (0..n) |i| { + ids[i] = base_id + i; + const params = try self.provider.formatCallParams(self.targets.items[i], self.calldata.items[i], null); + defer self.allocator.free(params); + bodies[i] = try HttpTransport.buildRequestBody(self.allocator, json_rpc.Method.eth_call, params, ids[i]); + } + + // Build const slice for requestBatch + const const_bodies = try self.allocator.alloc([]const u8, n); + defer self.allocator.free(const_bodies); + for (bodies, 0..) |b, i| { + const_bodies[i] = b; + } + + const raw = try self.provider.transport.requestBatch(const_bodies); + defer self.allocator.free(raw); + + return try parseBatchResponse(self.allocator, raw, ids); + } +}; + +fn parseBatchResponse(allocator: std.mem.Allocator, raw: []const u8, ids: []const u64) ![]BatchCallResult { + const n = ids.len; + var results = try allocator.alloc(BatchCallResult, n); + // Initialize all to a sentinel so we can detect missing responses + for (results) |*r| r.* = .{ .rpc_error = .{ .code = -1, .message = "" } }; + + // Parse JSON array + const parsed = std.json.parseFromSlice(std.json.Value, allocator, raw, .{}) catch { + return error.InvalidResponse; + }; + defer parsed.deinit(); + + const arr = switch (parsed.value) { + .array => |a| a, + else => return error.InvalidResponse, + }; + + // Match each response to its request by id + for (arr.items) |item| { + const obj = switch (item) { + .object => |o| o, + else => continue, + }; + + // Get id + const id_val = obj.get("id") orelse continue; + const id: u64 = switch (id_val) { + .integer => |i| @intCast(i), + else => continue, + }; + + // Find index for this id + var idx: ?usize = null; + for (ids, 0..) |expected_id, i| { + if (expected_id == id) { + idx = i; + break; + } + } + const index = idx orelse continue; + + // Check for error + if (obj.get("error")) |err_val| { + if (err_val == .object) { + const code = if (err_val.object.get("code")) |c| switch (c) { + .integer => |ci| @as(i64, @intCast(ci)), + else => @as(i64, 0), + } else 0; + const message = if (err_val.object.get("message")) |m| switch (m) { + .string => |s| s, + else => "unknown error", + } else "unknown error"; + // Dupe message since parsed will be freed + const msg_copy = try allocator.dupe(u8, message); + results[index] = .{ .rpc_error = .{ .code = code, .message = msg_copy } }; + continue; + } + } + + // Get result + const result_val = obj.get("result") orelse continue; + switch (result_val) { + .string => |s| { + const decoded = try parseHexBytes(allocator, s); + results[index] = .{ .success = decoded }; + }, + else => {}, + } + } + + return results; +} + +pub fn freeBatchResults(allocator: std.mem.Allocator, results: []BatchCallResult) void { + for (results) |r| { + switch (r) { + .success => |data| if (data.len > 0) allocator.free(data), + .rpc_error => |e| if (e.message.len > 0) allocator.free(@constCast(e.message)), + } + } + allocator.free(results); +} + // ============================================================================ // JSON response parsing // ============================================================================ @@ -315,7 +482,7 @@ pub const Provider = struct { /// Handles both quoted string results and null. /// Caller owns the returned memory. fn extractResultString(allocator: std.mem.Allocator, raw: []const u8) ![]u8 { - const parsed = std.json.parseFromSlice(std.json.Value, std.heap.page_allocator, raw, .{}) catch { + const parsed = std.json.parseFromSlice(std.json.Value, allocator, raw, .{}) catch { return error.InvalidResponse; }; defer parsed.deinit(); @@ -1048,3 +1215,98 @@ test "Provider.init" { const provider = Provider.init(allocator, &transport); try std.testing.expectEqual(@as(u64, 1), provider.next_id); } + +test "BatchCaller.init and deinit" { + const allocator = std.testing.allocator; + var transport = HttpTransport.init(allocator, "http://localhost:8545"); + defer transport.deinit(); + var prov = Provider.init(allocator, &transport); + var batch = BatchCaller.init(allocator, &prov); + defer batch.deinit(); + try std.testing.expectEqual(@as(usize, 0), batch.targets.items.len); +} + +test "BatchCaller.addCall accumulates" { + const allocator = std.testing.allocator; + var transport = HttpTransport.init(allocator, "http://localhost:8545"); + defer transport.deinit(); + var prov = Provider.init(allocator, &transport); + var batch = BatchCaller.init(allocator, &prov); + defer batch.deinit(); + + const idx0 = try batch.addCall([_]u8{0x11} ** 20, &.{ 0x01, 0x02 }); + const idx1 = try batch.addCall([_]u8{0x22} ** 20, &.{ 0x03, 0x04 }); + + try std.testing.expectEqual(@as(usize, 0), idx0); + try std.testing.expectEqual(@as(usize, 1), idx1); + try std.testing.expectEqual(@as(usize, 2), batch.targets.items.len); +} + +test "BatchCaller.reset clears" { + const allocator = std.testing.allocator; + var transport = HttpTransport.init(allocator, "http://localhost:8545"); + defer transport.deinit(); + var prov = Provider.init(allocator, &transport); + var batch = BatchCaller.init(allocator, &prov); + defer batch.deinit(); + + _ = try batch.addCall([_]u8{0x11} ** 20, &.{0x01}); + try std.testing.expectEqual(@as(usize, 1), batch.targets.items.len); + batch.reset(); + try std.testing.expectEqual(@as(usize, 0), batch.targets.items.len); +} + +test "parseBatchResponse in order" { + const allocator = std.testing.allocator; + const raw = "[{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xdead\"},{\"jsonrpc\":\"2.0\",\"id\":2,\"result\":\"0xbeef\"}]"; + const ids = [_]u64{ 1, 2 }; + const results = try parseBatchResponse(allocator, raw, &ids); + defer freeBatchResults(allocator, results); + + try std.testing.expectEqual(@as(usize, 2), results.len); + switch (results[0]) { + .success => |data| try std.testing.expectEqualSlices(u8, &.{ 0xde, 0xad }, data), + else => return error.TestUnexpectedResult, + } + switch (results[1]) { + .success => |data| try std.testing.expectEqualSlices(u8, &.{ 0xbe, 0xef }, data), + else => return error.TestUnexpectedResult, + } +} + +test "parseBatchResponse out of order" { + const allocator = std.testing.allocator; + const raw = "[{\"jsonrpc\":\"2.0\",\"id\":2,\"result\":\"0xbeef\"},{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xdead\"}]"; + const ids = [_]u64{ 1, 2 }; + const results = try parseBatchResponse(allocator, raw, &ids); + defer freeBatchResults(allocator, results); + + // Results should be in original order (by id), not response order + switch (results[0]) { + .success => |data| try std.testing.expectEqualSlices(u8, &.{ 0xde, 0xad }, data), + else => return error.TestUnexpectedResult, + } + switch (results[1]) { + .success => |data| try std.testing.expectEqualSlices(u8, &.{ 0xbe, 0xef }, data), + else => return error.TestUnexpectedResult, + } +} + +test "parseBatchResponse partial failure" { + const allocator = std.testing.allocator; + const raw = "[{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xdead\"},{\"jsonrpc\":\"2.0\",\"id\":2,\"error\":{\"code\":3,\"message\":\"execution reverted\"}}]"; + const ids = [_]u64{ 1, 2 }; + const results = try parseBatchResponse(allocator, raw, &ids); + defer freeBatchResults(allocator, results); + + switch (results[0]) { + .success => |data| try std.testing.expectEqualSlices(u8, &.{ 0xde, 0xad }, data), + else => return error.TestUnexpectedResult, + } + switch (results[1]) { + .rpc_error => |e| { + try std.testing.expectEqual(@as(i64, 3), e.code); + }, + else => return error.TestUnexpectedResult, + } +} diff --git a/src/root.zig b/src/root.zig index cfd684b..b38dce7 100644 --- a/src/root.zig +++ b/src/root.zig @@ -51,6 +51,11 @@ pub const event = @import("event.zig"); pub const erc20 = @import("erc20.zig"); pub const erc721 = @import("erc721.zig"); +// -- DEX Math -- +pub const dex_v2 = @import("dex/v2.zig"); +pub const dex_v3 = @import("dex/v3.zig"); +pub const dex_router = @import("dex/router.zig"); + // -- Layer 9: Standards -- pub const eip712 = @import("eip712.zig"); pub const abi_json = @import("abi_json.zig"); @@ -124,4 +129,8 @@ test { _ = @import("ens/reverse.zig"); // Utils _ = @import("utils/units.zig"); + // DEX Math + _ = @import("dex/v2.zig"); + _ = @import("dex/v3.zig"); + _ = @import("dex/router.zig"); } diff --git a/src/uint256.zig b/src/uint256.zig index 4590f4a..90796dd 100644 --- a/src/uint256.zig +++ b/src/uint256.zig @@ -631,25 +631,39 @@ pub inline fn mulDiv(a: u256, b: u256, denominator: u256) ?u256 { return if (divWide(wide, d_limbs)) |q| limbsToU256(q) else null; } +/// mulDiv with rounding up: ceil(a * b / denominator) +pub fn mulDivRoundingUp(a: u256, b: u256, denominator: u256) ?u256 { + const result = mulDiv(a, b, denominator) orelse return null; + if (denominator == 0) return null; + // Check remainder: if result * denominator != a * b, round up + const a_limbs = u256ToLimbs(a); + const b_limbs = u256ToLimbs(b); + const wide_ab = mulWide(a_limbs, b_limbs); + const result_limbs = u256ToLimbs(result); + const d_limbs = u256ToLimbs(denominator); + const wide_rd = mulWide(result_limbs, d_limbs); + var has_remainder = false; + var i: usize = 7; + while (true) : (i -= 1) { + if (wide_ab[i] != wide_rd[i]) { + has_remainder = wide_ab[i] > wide_rd[i]; + break; + } + if (i == 0) break; + } + if (has_remainder) { + if (result == MAX) return null; + return result + 1; + } + return result; +} + /// Compute UniswapV2 getAmountOut entirely in u64-limb space. /// Formula: (amountIn * 997 * reserveOut) / (reserveIn * 1000 + amountIn * 997) -/// Uses limb arithmetic + div128by64 to avoid __udivti3 (u128/u128 software division). +/// Delegates to dex/v2.zig with the standard Uniswap V2 fee (997/1000). pub fn getAmountOut(amount_in: u256, reserve_in: u256, reserve_out: u256) u256 { - if (amount_in == 0) return 0; - - const ai = u256ToLimbs(amount_in); - const ri = u256ToLimbs(reserve_in); - const ro = u256ToLimbs(reserve_out); - - const amount_in_with_fee = mulLimbScalar(ai, 997); - const numerator = mulLimbs(amount_in_with_fee, ro); - const denominator = addLimbs(mulLimbScalar(ri, 1000), amount_in_with_fee); - - if (denominator[0] == 0 and denominator[1] == 0 and denominator[2] == 0 and denominator[3] == 0) { - @panic("getAmountOut: denominator is zero (invalid reserves)"); - } - - return limbsToU256(divLimbsDirect(numerator, denominator)); + const dex_v2 = @import("dex/v2.zig"); + return dex_v2.getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000); } /// Q96 constant (2^96) used in UniswapV3/V4 fixed-point arithmetic. From 6080a00de8b332c45753021ff9657f3e9c484afb Mon Sep 17 00:00:00 2001 From: Koko Bhadra Date: Tue, 10 Mar 2026 17:10:42 -0400 Subject: [PATCH 2/5] Address CodeRabbit review findings - v2: add zero-reserve/fee_denominator guards in getAmountOut/getAmountIn - v2: replace tautological legacy test with hardcoded expected value - v3: add liquidity/sqrt_price guards in getNextSqrtPrice helpers - v3: validate fee_pips <= 1M in computeSwapStep - v3: add terminal swap step in simulateSwap for remaining input - v3: break tick loop when current_liquidity reaches zero - router: fix findArbOpportunity for max_input < 1000 and max_input == 0 - router: evaluate both binary search endpoints for optimal result - uint256: remove redundant denominator==0 check in mulDivRoundingUp - provider: add doc comments for addCall borrowing and sentinel contract - docs: clarify token decimals and null returns in dex-math.mdx --- docs/content/docs/dex-math.mdx | 7 +++++-- src/dex/router.zig | 36 ++++++++++++++++------------------ src/dex/v2.zig | 24 +++++++++++++++-------- src/dex/v3.zig | 29 +++++++++++++++++++++++++++ src/provider.zig | 6 +++++- src/uint256.zig | 1 - 6 files changed, 72 insertions(+), 31 deletions(-) diff --git a/docs/content/docs/dex-math.mdx b/docs/content/docs/dex-math.mdx index 24bc289..9e4350b 100644 --- a/docs/content/docs/dex-math.mdx +++ b/docs/content/docs/dex-math.mdx @@ -45,8 +45,8 @@ const required_input = eth.dex_v2.getAmountIn( ```zig const path = [_]eth.dex_v2.Pair{ - .{ .reserve_in = 100e18, .reserve_out = 200_000e6 }, // ETH -> USDC - .{ .reserve_in = 300_000e6, .reserve_out = 50e18 }, // USDC -> DAI + .{ .reserve_in = 100e18, .reserve_out = 200_000e6 }, // ETH (18 dec) -> USDC (6 dec) + .{ .reserve_in = 300_000e6, .reserve_out = 50e18 }, // USDC (6 dec) -> DAI (18 dec) }; // Forward: how much DAI for 1 ETH? @@ -83,6 +83,9 @@ const sqrt_price_0 = eth.dex_v3.getSqrtRatioAtTick(0).?; // tick 0 = price 1 const tick = eth.dex_v3.getTickAtSqrtRatio(sqrt_price).?; // = 100 ``` +> These functions return `null` for out-of-range ticks (outside `MIN_TICK`..`MAX_TICK`) +> or invalid sqrt prices. Always check the result before unwrapping with `.?`. + ### Token Amount Deltas ```zig diff --git a/src/dex/router.zig b/src/dex/router.zig index 9626a44..58a695c 100644 --- a/src/dex/router.zig +++ b/src/dex/router.zig @@ -79,47 +79,45 @@ pub fn quoteExactOutput(amount_out: u256, hops: []const Pool) ?u256 { /// Returns null if no profitable opportunity exists. pub fn findArbOpportunity(hops: []const Pool, max_input: u256) ?ArbOpportunity { if (hops.len == 0) return null; + if (max_input == 0) return null; - // Check if there's any profit at all with a small amount - const small_amount: u256 = 1000; + // Probe with a small amount to check if arb exists + const small_amount = @min(@as(u256, 1000), max_input); const small_output = quoteExactInput(small_amount, hops) orelse return null; if (small_output <= small_amount) return null; - // Binary search for optimal input - // The profit function is concave, so we search for the peak + // Binary search for optimal input (profit is concave for constant-product AMMs) var lo: u256 = 1; var hi: u256 = max_input; - // Run binary search for ~100 iterations (enough for u256 precision) var iterations: u32 = 0; - while (lo < hi and iterations < 128) : (iterations += 1) { - // Avoid overflow in midpoint calculation + while (lo + 1 < hi and iterations < 128) : (iterations += 1) { const mid = lo + (hi - lo) / 2; - if (mid == lo) break; const mid_output = quoteExactInput(mid, hops) orelse break; if (mid == std.math.maxInt(u256)) break; const mid_plus = quoteExactInput(mid + 1, hops) orelse break; - // Check marginal profit at mid: is f(mid+1) - f(mid) > 1? - // If marginal output > marginal input (1), we can increase input + // Check marginal profit: is f(mid+1) - f(mid) > 1? if (mid_plus > mid_output and mid_plus - mid_output > 1) { - // Still profitable to increase - marginal output > marginal input lo = mid; } else { hi = mid; } } - // Evaluate profit at the found optimal point - const optimal = lo; - const output = quoteExactInput(optimal, hops) orelse return null; - if (output <= optimal) return null; + // Evaluate both endpoints, pick the better one + const lo_output = quoteExactInput(lo, hops); + const hi_output = quoteExactInput(hi, hops); + const lo_profit: u256 = if (lo_output) |o| (if (o > lo) o - lo else 0) else 0; + const hi_profit: u256 = if (hi_output) |o| (if (o > hi) o - hi else 0) else 0; - return .{ - .profit = output - optimal, - .optimal_input = optimal, - }; + if (lo_profit == 0 and hi_profit == 0) return null; + + return if (hi_profit > lo_profit) + .{ .profit = hi_profit, .optimal_input = hi } + else + .{ .profit = lo_profit, .optimal_input = lo }; } // ============================================================================ diff --git a/src/dex/v2.zig b/src/dex/v2.zig index 88459f5..4e0197d 100644 --- a/src/dex/v2.zig +++ b/src/dex/v2.zig @@ -27,6 +27,8 @@ pub const Pair = struct { /// Formula: (amountIn * feeNum * reserveOut) / (reserveIn * feeDenom + amountIn * feeNum) pub fn getAmountOut(amount_in: u256, reserve_in: u256, reserve_out: u256, fee_numerator: u64, fee_denominator: u64) u256 { if (amount_in == 0) return 0; + if (reserve_in == 0 or reserve_out == 0) return 0; + if (fee_denominator == 0) return 0; const ai = u256ToLimbs(amount_in); const ri = u256ToLimbs(reserve_in); @@ -47,7 +49,10 @@ pub fn getAmountOut(amount_in: u256, reserve_in: u256, reserve_out: u256, fee_nu /// Formula: (reserveIn * amountOut * feeDenom) / ((reserveOut - amountOut) * feeNum) + 1 /// Returns null if amount_out >= reserve_out (insufficient liquidity). pub fn getAmountIn(amount_out: u256, reserve_in: u256, reserve_out: u256, fee_numerator: u64, fee_denominator: u64) ?u256 { - if (amount_out == 0) return 0; + if (amount_out == 0) return @as(u256, 0); + if (reserve_in == 0 or reserve_out == 0) return null; + if (fee_denominator == 0) return null; + if (fee_numerator == 0) return null; if (amount_out >= reserve_out) return null; const reserve_diff = reserve_out - amount_out; @@ -116,14 +121,17 @@ pub fn calculateProfit(amount_in: u256, path: []const Pair) ?u256 { // Tests // ============================================================================ -test "getAmountOut matches legacy" { - const amount_in: u256 = 1_000_000_000_000_000_000; // 1 ETH - const reserve_in: u256 = 100_000_000_000_000_000_000; // 100 ETH - const reserve_out: u256 = 200_000_000_000; // 200k USDC (6 decimals) +test "getAmountOut known value" { + // 1 ETH in, 100 ETH / 200k USDC pool, 0.3% fee + // Expected: (1e18 * 997 * 200e9) / (100e18 * 1000 + 1e18 * 997) = 1_974_316_068 + const v2_result = getAmountOut(1_000_000_000_000_000_000, 100_000_000_000_000_000_000, 200_000_000_000, 997, 1000); + try std.testing.expectEqual(@as(u256, 1_974_316_068), v2_result); +} - const v2_result = getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000); - const legacy_result = uint256_mod.getAmountOut(amount_in, reserve_in, reserve_out); - try std.testing.expectEqual(legacy_result, v2_result); +test "getAmountOut zero reserves" { + try std.testing.expectEqual(@as(u256, 0), getAmountOut(1000, 0, 200_000, 997, 1000)); + try std.testing.expectEqual(@as(u256, 0), getAmountOut(1000, 100_000, 0, 997, 1000)); + try std.testing.expectEqual(@as(u256, 0), getAmountOut(1000, 100_000, 200_000, 997, 0)); } test "getAmountOut different fees" { diff --git a/src/dex/v3.zig b/src/dex/v3.zig index fd98a5a..9ce9d57 100644 --- a/src/dex/v3.zig +++ b/src/dex/v3.zig @@ -257,6 +257,7 @@ pub fn getAmount1Delta(sqrt_ratio_a_x96: u256, sqrt_ratio_b_x96: u256, liquidity /// When add=true (input token0), price goes down. /// When add=false (output token0), price goes up. pub fn getNextSqrtPriceFromAmount0RoundingUp(sqrt_price_x96: u256, liquidity: u128, amount: u256, add: bool) ?u256 { + if (liquidity == 0 or sqrt_price_x96 == 0) return null; if (amount == 0) return sqrt_price_x96; const numerator1: u256 = @as(u256, liquidity) << 96; @@ -295,6 +296,7 @@ pub fn getNextSqrtPriceFromAmount0RoundingUp(sqrt_price_x96: u256, liquidity: u1 /// When add=true (input token1), price goes up. /// When add=false (output token1), price goes down. pub fn getNextSqrtPriceFromAmount1RoundingDown(sqrt_price_x96: u256, liquidity: u128, amount: u256, add: bool) ?u256 { + if (liquidity == 0 or sqrt_price_x96 == 0) return null; if (add) { // quotient = amount * Q96 / liquidity (or amount << 96 / liquidity if fits) const quotient: u256 = if (amount <= (@as(u256, 1) << 160) - 1) @@ -359,6 +361,12 @@ pub fn computeSwapStep( amount_remaining: i256, fee_pips: u24, // e.g. 3000 = 0.3% ) SwapStepResult { + if (fee_pips > 1_000_000) return .{ + .sqrt_ratio_next_x96 = sqrt_ratio_current_x96, + .amount_in = 0, + .amount_out = 0, + .fee_amount = 0, + }; const zero_for_one = sqrt_ratio_current_x96 >= sqrt_ratio_target_x96; const exact_in = amount_remaining >= 0; @@ -533,7 +541,28 @@ pub fn simulateSwap( } } ticks_crossed += 1; + if (current_liquidity == 0) break; + } + } + + // Terminal step: consume remaining input in unbounded range if liquidity available + if (amount_remaining > 0 and current_liquidity > 0) { + const terminal_sqrt: u256 = if (zero_for_one) MIN_SQRT_RATIO + 1 else MAX_SQRT_RATIO - 1; + const step = computeSwapStep( + current_sqrt_price, + terminal_sqrt, + current_liquidity, + @as(i256, @intCast(amount_remaining)), + fee_pips, + ); + const consumed = step.amount_in + step.fee_amount; + if (consumed >= amount_remaining) { + amount_remaining = 0; + } else { + amount_remaining -= consumed; } + total_amount_out += step.amount_out; + current_sqrt_price = step.sqrt_ratio_next_x96; } return .{ diff --git a/src/provider.zig b/src/provider.zig index 1ea0835..44b5d74 100644 --- a/src/provider.zig +++ b/src/provider.zig @@ -341,6 +341,8 @@ pub const BatchCaller = struct { self.calldata.deinit(self.allocator); } + /// Add an eth_call to the batch. Returns the index for result retrieval. + /// `data` is borrowed (not copied) -- caller must keep it valid until `execute()` returns. pub fn addCall(self: *BatchCaller, to: [20]u8, data: []const u8) !usize { const index = self.targets.items.len; try self.targets.append(self.allocator, to); @@ -391,10 +393,12 @@ pub const BatchCaller = struct { } }; +/// Parse a JSON-RPC batch response, matching results to request IDs. +/// Unmatched IDs are left as sentinel values: `.rpc_error = .{ .code = -1, .message = "" }`. +/// Callers can detect missing responses by checking for `code == -1` and `message.len == 0`. fn parseBatchResponse(allocator: std.mem.Allocator, raw: []const u8, ids: []const u64) ![]BatchCallResult { const n = ids.len; var results = try allocator.alloc(BatchCallResult, n); - // Initialize all to a sentinel so we can detect missing responses for (results) |*r| r.* = .{ .rpc_error = .{ .code = -1, .message = "" } }; // Parse JSON array diff --git a/src/uint256.zig b/src/uint256.zig index 90796dd..1e0bdc0 100644 --- a/src/uint256.zig +++ b/src/uint256.zig @@ -634,7 +634,6 @@ pub inline fn mulDiv(a: u256, b: u256, denominator: u256) ?u256 { /// mulDiv with rounding up: ceil(a * b / denominator) pub fn mulDivRoundingUp(a: u256, b: u256, denominator: u256) ?u256 { const result = mulDiv(a, b, denominator) orelse return null; - if (denominator == 0) return null; // Check remainder: if result * denominator != a * b, round up const a_limbs = u256ToLimbs(a); const b_limbs = u256ToLimbs(b); From 35209e7c734f94bd7adad1f730b9a8057c8a9c5b Mon Sep 17 00:00:00 2001 From: Koko Bhadra Date: Tue, 10 Mar 2026 17:14:57 -0400 Subject: [PATCH 3/5] Guard against negative JSON-RPC ids and fix zero-length free - Skip negative id values in batch response parsing instead of panicking - Always free zero-length success data (was leaking empty eth_call results) --- src/provider.zig | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/provider.zig b/src/provider.zig index 44b5d74..750b086 100644 --- a/src/provider.zig +++ b/src/provider.zig @@ -422,7 +422,7 @@ fn parseBatchResponse(allocator: std.mem.Allocator, raw: []const u8, ids: []cons // Get id const id_val = obj.get("id") orelse continue; const id: u64 = switch (id_val) { - .integer => |i| @intCast(i), + .integer => |i| if (i >= 0) @as(u64, @intCast(i)) else continue, else => continue, }; @@ -471,7 +471,7 @@ fn parseBatchResponse(allocator: std.mem.Allocator, raw: []const u8, ids: []cons pub fn freeBatchResults(allocator: std.mem.Allocator, results: []BatchCallResult) void { for (results) |r| { switch (r) { - .success => |data| if (data.len > 0) allocator.free(data), + .success => |data| allocator.free(data), .rpc_error => |e| if (e.message.len > 0) allocator.free(@constCast(e.message)), } } From f8e45f5cca36eabe32842518a0e9dd69502aee95 Mon Sep 17 00:00:00 2001 From: Koko Bhadra Date: Tue, 10 Mar 2026 17:31:17 -0400 Subject: [PATCH 4/5] Fix RPC error message leak and bodies partial-init UB - Change RpcErrorData.message from []const u8 to ?[]u8 - Sentinel uses null (not string literal), eliminating @constCast - freeBatchResults frees message only when non-null - Initialize bodies array to empty slices before loop to prevent freeing uninitialized pointers on partial failure --- src/provider.zig | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/provider.zig b/src/provider.zig index 750b086..0aa02ef 100644 --- a/src/provider.zig +++ b/src/provider.zig @@ -317,7 +317,7 @@ pub const BatchCallResult = union(enum) { pub const RpcErrorData = struct { code: i64, - message: []const u8, + message: ?[]u8, }; }; @@ -361,8 +361,9 @@ pub const BatchCaller = struct { // Build individual request bodies const bodies = try self.allocator.alloc([]u8, n); + @memset(bodies, &.{}); defer { - for (bodies) |b| self.allocator.free(b); + for (bodies) |b| if (b.len > 0) self.allocator.free(b); self.allocator.free(bodies); } @@ -394,12 +395,12 @@ pub const BatchCaller = struct { }; /// Parse a JSON-RPC batch response, matching results to request IDs. -/// Unmatched IDs are left as sentinel values: `.rpc_error = .{ .code = -1, .message = "" }`. -/// Callers can detect missing responses by checking for `code == -1` and `message.len == 0`. +/// Unmatched IDs are left as sentinel values: `.rpc_error = .{ .code = -1, .message = null }`. +/// Callers can detect missing responses by checking for `code == -1` and `message == null`. fn parseBatchResponse(allocator: std.mem.Allocator, raw: []const u8, ids: []const u64) ![]BatchCallResult { const n = ids.len; var results = try allocator.alloc(BatchCallResult, n); - for (results) |*r| r.* = .{ .rpc_error = .{ .code = -1, .message = "" } }; + for (results) |*r| r.* = .{ .rpc_error = .{ .code = -1, .message = null } }; // Parse JSON array const parsed = std.json.parseFromSlice(std.json.Value, allocator, raw, .{}) catch { @@ -447,8 +448,8 @@ fn parseBatchResponse(allocator: std.mem.Allocator, raw: []const u8, ids: []cons .string => |s| s, else => "unknown error", } else "unknown error"; - // Dupe message since parsed will be freed - const msg_copy = try allocator.dupe(u8, message); + // Dupe message since parsed JSON will be freed + const msg_copy: []u8 = try allocator.dupe(u8, message); results[index] = .{ .rpc_error = .{ .code = code, .message = msg_copy } }; continue; } @@ -472,7 +473,7 @@ pub fn freeBatchResults(allocator: std.mem.Allocator, results: []BatchCallResult for (results) |r| { switch (r) { .success => |data| allocator.free(data), - .rpc_error => |e| if (e.message.len > 0) allocator.free(@constCast(e.message)), + .rpc_error => |e| if (e.message) |msg| allocator.free(msg), } } allocator.free(results); From ebd7dd09e09ecb9e56cdd30d60e577e38d508d0f Mon Sep 17 00:00:00 2001 From: Koko Bhadra Date: Tue, 10 Mar 2026 17:41:39 -0400 Subject: [PATCH 5/5] Add errdefer to free results on parse failure in parseBatchResponse Prevents leaking the results array and any already-duped messages when JSON parsing or hex decoding fails mid-function. --- src/provider.zig | 1 + 1 file changed, 1 insertion(+) diff --git a/src/provider.zig b/src/provider.zig index 0aa02ef..2a3653e 100644 --- a/src/provider.zig +++ b/src/provider.zig @@ -401,6 +401,7 @@ fn parseBatchResponse(allocator: std.mem.Allocator, raw: []const u8, ids: []cons const n = ids.len; var results = try allocator.alloc(BatchCallResult, n); for (results) |*r| r.* = .{ .rpc_error = .{ .code = -1, .message = null } }; + errdefer freeBatchResults(allocator, results); // Parse JSON array const parsed = std.json.parseFromSlice(std.json.Value, allocator, raw, .{}) catch {