diff --git a/src/abi_decode.zig b/src/abi_decode.zig index fce7d5c..81db3b8 100644 --- a/src/abi_decode.zig +++ b/src/abi_decode.zig @@ -86,8 +86,9 @@ fn decodeValuesAt(data: []const u8, base: usize, types: []const AbiType, allocat } var result = try allocator.alloc(AbiValue, n); + var decoded_count: usize = 0; errdefer { - for (result[0..n]) |*val| { + for (result[0..decoded_count]) |*val| { freeValue(val, allocator); } allocator.free(result); @@ -106,6 +107,7 @@ fn decodeValuesAt(data: []const u8, base: usize, types: []const AbiType, allocat } else { result[i] = try decodeStaticValue(data[head_offset..][0..32], abi_type, allocator); } + decoded_count += 1; } return result; diff --git a/src/abi_encode.zig b/src/abi_encode.zig index 9c29862..6045506 100644 --- a/src/abi_encode.zig +++ b/src/abi_encode.zig @@ -1,6 +1,8 @@ const std = @import("std"); const uint256_mod = @import("uint256.zig"); +const max_tuple_values = 32; + /// Tagged union representing any ABI-encodable value. pub const AbiValue = union(enum) { /// Unsigned 256-bit integer (covers uint8 through uint256). @@ -54,11 +56,13 @@ pub const AbiValue = union(enum) { /// Errors during ABI encoding. pub const EncodeError = error{ OutOfMemory, + TooManyValues, }; /// Encode a slice of ABI values according to the Solidity ABI specification. /// Returns the encoded bytes. Caller owns the returned memory. pub fn encodeValues(allocator: std.mem.Allocator, values: []const AbiValue) EncodeError![]u8 { + if (values.len > max_tuple_values) return error.TooManyValues; const total = calcEncodedSize(values); const buf = try allocator.alloc(u8, total); errdefer allocator.free(buf); @@ -69,6 +73,7 @@ pub fn encodeValues(allocator: std.mem.Allocator, values: []const AbiValue) Enco /// Encode a function call: 4-byte selector followed by ABI-encoded arguments. /// Returns the encoded bytes. Caller owns the returned memory. pub fn encodeFunctionCall(allocator: std.mem.Allocator, selector: [4]u8, values: []const AbiValue) EncodeError![]u8 { + if (values.len > max_tuple_values) return error.TooManyValues; const total = 4 + calcEncodedSize(values); const buf = try allocator.alloc(u8, total); errdefer allocator.free(buf); @@ -131,7 +136,8 @@ fn encodeValuesInto(allocator: std.mem.Allocator, buf: *std.ArrayList(u8), value // First pass: calculate tail offsets for dynamic values // and pre-compute the offset each dynamic value will be at - var offsets: [32]usize = undefined; // max 32 values in a single tuple + std.debug.assert(values.len <= max_tuple_values); + var offsets: [max_tuple_values]usize = undefined; for (values, 0..) |val, i| { if (val.isDynamic()) { offsets[i] = tail_offset; @@ -151,7 +157,7 @@ fn encodeValuesInto(allocator: std.mem.Allocator, buf: *std.ArrayList(u8), value // Third pass: write tail section directly into buf (no temp allocations) for (values) |val| { if (val.isDynamic()) { - encodeDynamicValueInto(allocator, buf, val); + encodeDynamicValueInto(buf, val); } } } @@ -202,49 +208,8 @@ fn encodeStaticValueNoAlloc(buf: *std.ArrayList(u8), val: AbiValue) void { } } -/// Encode a static value directly as a 32-byte word (allocating variant for backward compat). -fn encodeStaticValue(allocator: std.mem.Allocator, buf: *std.ArrayList(u8), val: AbiValue) EncodeError!void { - switch (val) { - .uint256 => |v| { - try writeUint256(allocator, buf, v); - }, - .int256 => |v| { - try writeInt256(allocator, buf, v); - }, - .address => |v| { - var word: [32]u8 = [_]u8{0} ** 32; - @memcpy(word[12..32], &v); - try buf.appendSlice(allocator, &word); - }, - .boolean => |v| { - var word: [32]u8 = [_]u8{0} ** 32; - if (v) word[31] = 1; - try buf.appendSlice(allocator, &word); - }, - .fixed_bytes => |v| { - var word: [32]u8 = [_]u8{0} ** 32; - const size: usize = @intCast(v.len); - @memcpy(word[0..size], v.data[0..size]); - try buf.appendSlice(allocator, &word); - }, - .fixed_array => |items| { - for (items) |item| { - try encodeStaticValue(allocator, buf, item); - } - }, - .tuple => |items| { - for (items) |item| { - try encodeStaticValue(allocator, buf, item); - } - }, - else => unreachable, - } -} - /// Encode a dynamic value directly into the output buffer (no temp allocation). -fn encodeDynamicValueInto(allocator: std.mem.Allocator, buf: *std.ArrayList(u8), val: AbiValue) void { - _ = allocator; - +fn encodeDynamicValueInto(buf: *std.ArrayList(u8), val: AbiValue) void { switch (val) { .bytes => |data| { writeUint256NoAlloc(buf, @intCast(data.len)); @@ -281,7 +246,8 @@ fn encodeValuesIntoNoAlloc(buf: *std.ArrayList(u8), values: []const AbiValue) vo var tail_offset: usize = head_size; // Calculate offsets for dynamic values - var offsets: [32]usize = undefined; + std.debug.assert(values.len <= max_tuple_values); + var offsets: [max_tuple_values]usize = undefined; for (values, 0..) |val, i| { if (val.isDynamic()) { offsets[i] = tail_offset; @@ -301,7 +267,7 @@ fn encodeValuesIntoNoAlloc(buf: *std.ArrayList(u8), values: []const AbiValue) vo // Write tails for (values) |val| { if (val.isDynamic()) { - encodeDynamicValueInto(undefined, buf, val); + encodeDynamicValueInto(buf, val); } } } @@ -336,7 +302,8 @@ fn writeValuesDirect(buf: []u8, values: []const AbiValue) void { } var tail_offset: usize = head_size; - var offsets: [32]usize = undefined; + std.debug.assert(values.len <= max_tuple_values); + var offsets: [max_tuple_values]usize = undefined; for (values, 0..) |val, i| { if (val.isDynamic()) { offsets[i] = tail_offset; @@ -421,19 +388,6 @@ fn writeDynamicValueDirect(buf: []u8, val: AbiValue) usize { } } -/// Write a u256 as a big-endian 32-byte word. -fn writeUint256(allocator: std.mem.Allocator, buf: *std.ArrayList(u8), value: u256) EncodeError!void { - const bytes = uint256_mod.toBigEndianBytes(value); - try buf.appendSlice(allocator, &bytes); -} - -/// Write an i256 as a big-endian 32-byte two's complement word. -fn writeInt256(allocator: std.mem.Allocator, buf: *std.ArrayList(u8), value: i256) EncodeError!void { - // Two's complement: cast to u256 bit pattern, then write as big-endian. - const unsigned: u256 = @bitCast(value); - try writeUint256(allocator, buf, unsigned); -} - // ============================================================================ // Tests // ============================================================================ diff --git a/src/abi_json.zig b/src/abi_json.zig index 9a73717..9785fa4 100644 --- a/src/abi_json.zig +++ b/src/abi_json.zig @@ -187,7 +187,7 @@ fn parseParam(allocator: std.mem.Allocator, obj: std.json.ObjectMap) !AbiParam { const name_str = jsonGetString(obj, "name") orelse ""; const indexed = jsonGetBool(obj, "indexed") orelse false; - const abi_type = parseType(type_str); + const abi_type = parseType(type_str) orelse return error.UnknownType; const name = if (name_str.len > 0) try allocator.dupe(u8, name_str) else name_str; errdefer if (name.len > 0) allocator.free(name); @@ -214,7 +214,7 @@ fn parseMutability(obj: std.json.ObjectMap) StateMutability { } /// Parse a Solidity type string into an AbiType. -pub fn parseType(type_str: []const u8) AbiType { +pub fn parseType(type_str: []const u8) ?AbiType { // Handle array suffixes if (std.mem.endsWith(u8, type_str, "[]")) return .dynamic_array; @@ -243,7 +243,7 @@ pub fn parseType(type_str: []const u8) AbiType { return parseBytesType(type_str) orelse .bytes; } - return .uint256; // fallback + return null; // unknown type } fn parseUintType(type_str: []const u8) ?AbiType { @@ -413,6 +413,11 @@ test "parseType - int without bits defaults to int256" { try std.testing.expectEqual(AbiType.int256, parseType("int")); } +test "parseType - unknown type returns null" { + try std.testing.expectEqual(@as(?AbiType, null), parseType("foobar")); + try std.testing.expectEqual(@as(?AbiType, null), parseType("custom_type")); +} + test "ContractAbi.fromJson - ERC20 ABI" { const allocator = std.testing.allocator; const json = diff --git a/src/chains/chain.zig b/src/chains/chain.zig index 84d1d0b..2594edf 100644 --- a/src/chains/chain.zig +++ b/src/chains/chain.zig @@ -37,10 +37,10 @@ pub const Chain = struct { testnet: bool = false, }; -/// Parse a hex address string into a 20-byte address. -/// Works at both comptime and runtime. -pub fn addressFromHex(hex_str: []const u8) Address { - return hex_mod.hexToBytesFixed(20, hex_str) catch unreachable; +/// Parse a hex address string into a 20-byte address at comptime. +/// Compile error if the hex string is invalid. +pub fn addressFromHex(comptime hex_str: []const u8) Address { + return comptime hex_mod.hexToBytesFixed(20, hex_str) catch @compileError("invalid hex address: " ++ hex_str); } /// Look up a chain by ID. diff --git a/src/dex/router.zig b/src/dex/router.zig index 58a695c..a6647f3 100644 --- a/src/dex/router.zig +++ b/src/dex/router.zig @@ -128,7 +128,7 @@ pub fn findArbOpportunity(hops: []const Pool, max_input: u256) ?ArbOpportunity { fn quotePool(amount_in: u256, pool: Pool) ?u256 { switch (pool) { .v2 => |p| { - const result = v2.getAmountOut(amount_in, p.reserve_in, p.reserve_out, p.fee_numerator, p.fee_denominator); + const result = v2.getAmountOut(amount_in, p.reserve_in, p.reserve_out, p.fee_numerator, p.fee_denominator) orelse return null; return if (result == 0) null else result; }, .v3 => |p| { @@ -161,7 +161,7 @@ test "quoteExactInput V2 single hop" { try std.testing.expect(result != null); // Should match direct V2 calculation - const direct = v2.getAmountOut(1_000_000_000_000_000_000, 100_000_000_000_000_000_000, 200_000_000_000, 997, 1000); + const direct = v2.getAmountOut(1_000_000_000_000_000_000, 100_000_000_000_000_000_000, 200_000_000_000, 997, 1000).?; try std.testing.expectEqual(direct, result.?); } diff --git a/src/dex/v2.zig b/src/dex/v2.zig index 4e0197d..b33cbd1 100644 --- a/src/dex/v2.zig +++ b/src/dex/v2.zig @@ -25,10 +25,10 @@ pub const Pair = struct { /// Compute UniswapV2 getAmountOut with configurable fee, entirely in u64-limb space. /// Formula: (amountIn * feeNum * reserveOut) / (reserveIn * feeDenom + amountIn * feeNum) -pub fn getAmountOut(amount_in: u256, reserve_in: u256, reserve_out: u256, fee_numerator: u64, fee_denominator: u64) u256 { +pub fn getAmountOut(amount_in: u256, reserve_in: u256, reserve_out: u256, fee_numerator: u64, fee_denominator: u64) ?u256 { if (amount_in == 0) return 0; if (reserve_in == 0 or reserve_out == 0) return 0; - if (fee_denominator == 0) return 0; + if (fee_denominator == 0) return null; const ai = u256ToLimbs(amount_in); const ri = u256ToLimbs(reserve_in); @@ -39,7 +39,7 @@ pub fn getAmountOut(amount_in: u256, reserve_in: u256, reserve_out: u256, fee_nu const denominator = addLimbs(mulLimbScalar(ri, fee_denominator), amount_in_with_fee); if (denominator[0] == 0 and denominator[1] == 0 and denominator[2] == 0 and denominator[3] == 0) { - @panic("getAmountOut: denominator is zero (invalid reserves)"); + return null; } return limbsToU256(divLimbsDirect(numerator, denominator)); @@ -69,7 +69,7 @@ pub fn getAmountIn(amount_out: u256, reserve_in: u256, reserve_out: u256, fee_nu const denominator = mulLimbScalar(rd, fee_numerator); if (denominator[0] == 0 and denominator[1] == 0 and denominator[2] == 0 and denominator[3] == 0) { - @panic("getAmountIn: denominator is zero"); + return null; } // Uniswap V2 always adds 1 (ceiling) @@ -88,7 +88,7 @@ pub fn getAmountsOut(amount_in: u256, path: []const Pair) ?u256 { var current = amount_in; for (path) |pair| { - current = getAmountOut(current, pair.reserve_in, pair.reserve_out, pair.fee_numerator, pair.fee_denominator); + current = getAmountOut(current, pair.reserve_in, pair.reserve_out, pair.fee_numerator, pair.fee_denominator) orelse return null; if (current == 0) return null; } return current; @@ -124,14 +124,14 @@ pub fn calculateProfit(amount_in: u256, path: []const Pair) ?u256 { test "getAmountOut known value" { // 1 ETH in, 100 ETH / 200k USDC pool, 0.3% fee // Expected: (1e18 * 997 * 200e9) / (100e18 * 1000 + 1e18 * 997) = 1_974_316_068 - const v2_result = getAmountOut(1_000_000_000_000_000_000, 100_000_000_000_000_000_000, 200_000_000_000, 997, 1000); + const v2_result = getAmountOut(1_000_000_000_000_000_000, 100_000_000_000_000_000_000, 200_000_000_000, 997, 1000).?; try std.testing.expectEqual(@as(u256, 1_974_316_068), v2_result); } test "getAmountOut zero reserves" { - try std.testing.expectEqual(@as(u256, 0), getAmountOut(1000, 0, 200_000, 997, 1000)); - try std.testing.expectEqual(@as(u256, 0), getAmountOut(1000, 100_000, 0, 997, 1000)); - try std.testing.expectEqual(@as(u256, 0), getAmountOut(1000, 100_000, 200_000, 997, 0)); + try std.testing.expectEqual(@as(?u256, 0), getAmountOut(1000, 0, 200_000, 997, 1000)); + try std.testing.expectEqual(@as(?u256, 0), getAmountOut(1000, 100_000, 0, 997, 1000)); + try std.testing.expectEqual(@as(?u256, null), getAmountOut(1000, 100_000, 200_000, 997, 0)); } test "getAmountOut different fees" { @@ -140,8 +140,8 @@ test "getAmountOut different fees" { const reserve_out: u256 = 200_000_000_000; // PancakeSwap uses 9975/10000 (0.25% fee) vs Uniswap 997/1000 (0.3% fee) - const pancake = getAmountOut(amount_in, reserve_in, reserve_out, 9975, 10000); - const uniswap = getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000); + const pancake = getAmountOut(amount_in, reserve_in, reserve_out, 9975, 10000).?; + const uniswap = getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000).?; // Lower fee => more output try std.testing.expect(pancake > uniswap); @@ -149,7 +149,7 @@ test "getAmountOut different fees" { test "getAmountOut zero input" { const result = getAmountOut(0, 100_000, 200_000, 997, 1000); - try std.testing.expectEqual(@as(u256, 0), result); + try std.testing.expectEqual(@as(?u256, 0), result); } test "getAmountOut result less than reserve" { @@ -158,7 +158,7 @@ test "getAmountOut result less than reserve" { const reserve_out: u256 = 200_000_000_000; for (amounts) |amount_in| { - const result = getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000); + const result = getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000).?; try std.testing.expect(result < reserve_out); } } @@ -168,7 +168,7 @@ test "getAmountIn inverse" { const reserve_in: u256 = 100_000_000_000_000_000_000; const reserve_out: u256 = 200_000_000_000; - const output = getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000); + const output = getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000).?; const recovered_input = getAmountIn(output, reserve_in, reserve_out, 997, 1000) orelse unreachable; // Due to ceiling division (+1), recovered_input >= amount_in diff --git a/src/eip155.zig b/src/eip155.zig index 3cf4719..dda1bc4 100644 --- a/src/eip155.zig +++ b/src/eip155.zig @@ -18,11 +18,11 @@ pub fn applyEip155(v: u8, chain_id: u64) u256 { /// Extract the recovery ID (0 or 1) from an EIP-155 encoded v value. /// Reverses the formula: recovery_id = v - chain_id * 2 - 35 -pub fn recoverFromEip155V(v: u256, chain_id: u64) u8 { +pub fn recoverFromEip155V(v: u256, chain_id: u64) ?u8 { const base: u256 = @as(u256, chain_id) * 2 + 35; - if (v < base) return 0; + if (v < base) return null; const recovery_id = v - base; - if (recovery_id > 1) return 0; + if (recovery_id > 1) return null; return @intCast(recovery_id); } @@ -86,13 +86,18 @@ test "applyEip155 - Arbitrum (chain_id=42161)" { } test "recoverFromEip155V - Ethereum mainnet" { - try std.testing.expectEqual(@as(u8, 0), recoverFromEip155V(37, 1)); - try std.testing.expectEqual(@as(u8, 1), recoverFromEip155V(38, 1)); + try std.testing.expectEqual(@as(?u8, 0), recoverFromEip155V(37, 1)); + try std.testing.expectEqual(@as(?u8, 1), recoverFromEip155V(38, 1)); } test "recoverFromEip155V - BSC" { - try std.testing.expectEqual(@as(u8, 0), recoverFromEip155V(147, 56)); - try std.testing.expectEqual(@as(u8, 1), recoverFromEip155V(148, 56)); + try std.testing.expectEqual(@as(?u8, 0), recoverFromEip155V(147, 56)); + try std.testing.expectEqual(@as(?u8, 1), recoverFromEip155V(148, 56)); +} + +test "recoverFromEip155V - invalid v returns null" { + try std.testing.expectEqual(@as(?u8, null), recoverFromEip155V(5, 1)); + try std.testing.expectEqual(@as(?u8, null), recoverFromEip155V(100, 1)); } test "recoverFromEip155V roundtrip" { @@ -100,7 +105,7 @@ test "recoverFromEip155V roundtrip" { for (chain_ids) |chain_id| { for ([_]u8{ 0, 1 }) |recovery_id| { const eip155_v = applyEip155(recovery_id, chain_id); - const recovered = recoverFromEip155V(eip155_v, chain_id); + const recovered = recoverFromEip155V(eip155_v, chain_id).?; try std.testing.expectEqual(recovery_id, recovered); } } diff --git a/src/eip712.zig b/src/eip712.zig index 5d2ebcd..1261bca 100644 --- a/src/eip712.zig +++ b/src/eip712.zig @@ -195,7 +195,9 @@ pub fn hashStruct( defer allocator.free(ref_types); // Compute typeHash - const fields = if (td) |t| t.fields else fieldDefsFromFieldValues(struct_val.fields); + const inferred_fields = if (td == null) try fieldDefsFromFieldValues(allocator, struct_val.fields) else null; + defer if (inferred_fields) |f| allocator.free(f); + const fields = if (td) |t| t.fields else inferred_fields.?; const type_hash = try hashType(allocator, struct_val.type_name, fields, ref_types); // Build the data: typeHash || encodeData(field1) || encodeData(field2) || ... @@ -333,24 +335,12 @@ fn findTypeDef(type_defs: []const TypeDef, name: []const u8) ?TypeDef { } /// Build FieldDefs from FieldValues (for when we don't have an explicit TypeDef). -fn fieldDefsFromFieldValues(fields: []const FieldValue) []const FieldDef { - // FieldValue has the same layout prefix as FieldDef (name, type_str), - // but they are different types. We cannot simply reinterpret. Instead, - // we note that since FieldValue starts with { name, type_str, value }, - // and FieldDef is { name, type_str }, we need to produce a FieldDef slice. - // - // Since this function cannot allocate and we need a slice, we use a - // static buffer approach with a reasonable max. For production use, - // callers should always provide TypeDefs. - const max_fields = 32; - const S = struct { - var buf: [max_fields]FieldDef = undefined; - }; - const count = @min(fields.len, max_fields); - for (0..count) |i| { - S.buf[i] = .{ .name = fields[i].name, .type_str = fields[i].type_str }; +fn fieldDefsFromFieldValues(allocator: std.mem.Allocator, fields: []const FieldValue) std.mem.Allocator.Error![]const FieldDef { + const defs = try allocator.alloc(FieldDef, fields.len); + for (fields, 0..) |field, i| { + defs[i] = .{ .name = field.name, .type_str = field.type_str }; } - return S.buf[0..count]; + return defs; } /// Check if a type string refers to a struct type (i.e., not a built-in Solidity type). diff --git a/src/erc20.zig b/src/erc20.zig index c933982..98722f1 100644 --- a/src/erc20.zig +++ b/src/erc20.zig @@ -54,7 +54,7 @@ pub const ERC20 = struct { &.{}, &.{.string}, ); - defer self.contract.allocator.free(result); + defer contract_mod.freeReturnValues(result, self.contract.allocator); if (result.len == 0) return error.InvalidResponse; const s = result[0].string; // Copy the string so we can free the decode result @@ -74,7 +74,7 @@ pub const ERC20 = struct { &.{}, &.{.string}, ); - defer self.contract.allocator.free(result); + defer contract_mod.freeReturnValues(result, self.contract.allocator); if (result.len == 0) return error.InvalidResponse; const s = result[0].string; const owned = try self.contract.allocator.alloc(u8, s.len); diff --git a/src/erc721.zig b/src/erc721.zig index 4ca6aa6..f82c893 100644 --- a/src/erc721.zig +++ b/src/erc721.zig @@ -57,7 +57,7 @@ pub const ERC721 = struct { &.{}, &.{.string}, ); - defer self.contract.allocator.free(result); + defer contract_mod.freeReturnValues(result, self.contract.allocator); if (result.len == 0) return error.InvalidResponse; const s = result[0].string; const owned = try self.contract.allocator.alloc(u8, s.len); @@ -76,7 +76,7 @@ pub const ERC721 = struct { &.{}, &.{.string}, ); - defer self.contract.allocator.free(result); + defer contract_mod.freeReturnValues(result, self.contract.allocator); if (result.len == 0) return error.InvalidResponse; const s = result[0].string; const owned = try self.contract.allocator.alloc(u8, s.len); @@ -96,7 +96,7 @@ pub const ERC721 = struct { &args, &.{.string}, ); - defer self.contract.allocator.free(result); + defer contract_mod.freeReturnValues(result, self.contract.allocator); if (result.len == 0) return error.InvalidResponse; const s = result[0].string; const owned = try self.contract.allocator.alloc(u8, s.len); diff --git a/src/hd_wallet.zig b/src/hd_wallet.zig index fafd0f4..70bacd3 100644 --- a/src/hd_wallet.zig +++ b/src/hd_wallet.zig @@ -17,16 +17,16 @@ pub const ExtendedKey = struct { chain_code: [32]u8, /// Derive the Ethereum address from this private key. - pub fn toAddress(self: ExtendedKey) [20]u8 { + pub fn toAddress(self: ExtendedKey) HdWalletError!primitives.Address { // Get public key from private key const Secp256k1 = std.crypto.ecc.Secp256k1; - const privkey_scalar = Secp256k1.scalar.Scalar.fromBytes(self.key, .big) catch return std.mem.zeroes([20]u8); - const pubkey_point = Secp256k1.basePoint.mul(privkey_scalar.toBytes(.big), .big) catch return std.mem.zeroes([20]u8); + const privkey_scalar = Secp256k1.scalar.Scalar.fromBytes(self.key, .big) catch return error.DerivationFailed; + const pubkey_point = Secp256k1.basePoint.mul(privkey_scalar.toBytes(.big), .big) catch return error.DerivationFailed; const pubkey_bytes = pubkey_point.toUncompressedSec1(); // Address = last 20 bytes of keccak256(pubkey[1..]) const hash = keccak.hash(pubkey_bytes[1..]); - var addr: [20]u8 = undefined; + var addr: primitives.Address = undefined; @memcpy(&addr, hash[12..32]); return addr; } @@ -39,11 +39,13 @@ pub const HARDENED: u32 = 0x80000000; pub const ETH_COIN_TYPE: u32 = 60; const HmacSha512 = std.crypto.auth.hmac.sha2.HmacSha512; +const secureZero = @import("utils/constants.zig").secureZero; /// Derive the master key from a BIP-39 seed. pub fn masterKeyFromSeed(seed: [64]u8) HdWalletError!ExtendedKey { // HMAC-SHA512 with key "Bitcoin seed" var mac: [64]u8 = undefined; + defer secureZero(&mac); HmacSha512.create(&mac, &seed, "Bitcoin seed"); const key = mac[0..32].*; @@ -60,6 +62,7 @@ pub fn masterKeyFromSeed(seed: [64]u8) HdWalletError!ExtendedKey { /// Use index | HARDENED for hardened derivation. pub fn deriveChild(parent: ExtendedKey, index: u32) HdWalletError!ExtendedKey { var data: [37]u8 = undefined; + defer secureZero(&data); if (index >= HARDENED) { // Hardened: 0x00 || private_key || index @@ -79,6 +82,7 @@ pub fn deriveChild(parent: ExtendedKey, index: u32) HdWalletError!ExtendedKey { // HMAC-SHA512 var mac: [64]u8 = undefined; + defer secureZero(&mac); HmacSha512.create(&mac, &data, &parent.chain_code); const il = mac[0..32].*; @@ -177,7 +181,7 @@ test "derivePath just m returns master key" { test "toAddress produces 20-byte address" { const seed = [_]u8{0xef} ** 64; const key = try deriveEthAccount(seed, 0); - const addr = key.toAddress(); + const addr = try key.toAddress(); // Just verify it's not all zeros var all_zero = true; for (addr) |b| { @@ -193,8 +197,8 @@ test "different accounts produce different addresses" { const seed = [_]u8{0x11} ** 64; const key0 = try deriveEthAccount(seed, 0); const key1 = try deriveEthAccount(seed, 1); - const addr0 = key0.toAddress(); - const addr1 = key1.toAddress(); + const addr0 = try key0.toAddress(); + const addr1 = try key1.toAddress(); try std.testing.expect(!std.mem.eql(u8, &addr0, &addr1)); } @@ -209,7 +213,7 @@ test "known BIP-39 mnemonic to address" { const seed = try mnemonic_mod.toSeed(&words, ""); const key = try deriveEthAccount(seed, 0); - const addr = key.toAddress(); + const addr = try key.toAddress(); // The address should be deterministic - verify non-zero var all_zero = true; @@ -223,7 +227,7 @@ test "known BIP-39 mnemonic to address" { // Verify it matches when derived again const key2 = try deriveEthAccount(seed, 0); - const addr2 = key2.toAddress(); + const addr2 = try key2.toAddress(); try std.testing.expectEqualSlices(u8, &addr, &addr2); } @@ -253,7 +257,7 @@ test "known mnemonic abandon...about to exact address" { }; const seed = try mnemonic_mod.toSeed(&words, ""); const key = try deriveEthAccount(seed, 0); - const addr = key.toAddress(); + const addr = try key.toAddress(); const addr_hex = primitives.addressToChecksum(&addr); // Well-known address for this mnemonic const expected = try hex_mod.hexToBytesFixed(20, "0x9858EfFD232B4033E47d90003D41EC34EcaEda94"); @@ -274,8 +278,8 @@ test "BIP-44 second account index produces different key and address" { // Keys must differ try std.testing.expect(!std.mem.eql(u8, &key0.key, &key1.key)); // Addresses must differ - const addr0 = key0.toAddress(); - const addr1 = key1.toAddress(); + const addr0 = try key0.toAddress(); + const addr1 = try key1.toAddress(); try std.testing.expect(!std.mem.eql(u8, &addr0, &addr1)); } @@ -286,7 +290,7 @@ test "deriveEthAccount key bytes deterministic" { const key2 = try deriveEthAccount(seed, 0); try std.testing.expectEqualSlices(u8, &key1.key, &key2.key); try std.testing.expectEqualSlices(u8, &key1.chain_code, &key2.chain_code); - const addr1 = key1.toAddress(); - const addr2 = key2.toAddress(); + const addr1 = try key1.toAddress(); + const addr2 = try key2.toAddress(); try std.testing.expectEqualSlices(u8, &addr1, &addr2); } diff --git a/src/keccak_xkcp.zig b/src/keccak_xkcp.zig index 92c1454..dba6365 100644 --- a/src/keccak_xkcp.zig +++ b/src/keccak_xkcp.zig @@ -40,7 +40,7 @@ extern fn KeccakWidth1600_Sponge(r: c_uint, c: c_uint, input: [*]const u8, input pub fn hash(data: []const u8) Hash { var result: Hash = undefined; const ret = KeccakWidth1600_Sponge(rate, capacity, data.ptr, data.len, delimited_suffix, &result, 32); - std.debug.assert(ret == 0); + if (ret != 0) @panic("KeccakWidth1600_Sponge failed"); return result; } @@ -51,19 +51,19 @@ pub const Hasher = struct { pub fn init() Hasher { var self: Hasher = undefined; const ret = Keccak_HashInitialize(&self.instance, rate, capacity, hash_bit_len, delimited_suffix); - std.debug.assert(ret == .success); + if (ret != .success) @panic("Keccak_HashInitialize failed"); return self; } pub fn update(self: *Hasher, data: []const u8) void { const ret = Keccak_HashUpdate(&self.instance, data.ptr, data.len * 8); - std.debug.assert(ret == .success); + if (ret != .success) @panic("Keccak_HashUpdate failed"); } pub fn final(self: *Hasher) Hash { var result: Hash = undefined; const ret = Keccak_HashFinal(&self.instance, &result); - std.debug.assert(ret == .success); + if (ret != .success) @panic("Keccak_HashFinal failed"); return result; } }; diff --git a/src/mnemonic.zig b/src/mnemonic.zig index bfec8ac..5a251ac 100644 --- a/src/mnemonic.zig +++ b/src/mnemonic.zig @@ -172,8 +172,11 @@ fn wordToIndex(word: []const u8) ?u16 { /// Convert mnemonic to seed using PBKDF2-HMAC-SHA512. /// The passphrase is optional (empty string if not provided). pub fn toSeed(words: []const []const u8, passphrase: []const u8) ![64]u8 { + const secureZero = @import("utils/constants.zig").secureZero; + // Build mnemonic string: words joined by spaces var mnemonic_buf: [1024]u8 = undefined; + defer secureZero(&mnemonic_buf); var mnemonic_len: usize = 0; for (words, 0..) |word, i| { @@ -189,6 +192,7 @@ pub fn toSeed(words: []const []const u8, passphrase: []const u8) ![64]u8 { // Salt = "mnemonic" + passphrase var salt_buf: [256]u8 = undefined; + defer secureZero(&salt_buf); const prefix = "mnemonic"; @memcpy(salt_buf[0..prefix.len], prefix); @memcpy(salt_buf[prefix.len .. prefix.len + passphrase.len], passphrase); diff --git a/src/multicall.zig b/src/multicall.zig index aee51e3..658dc2d 100644 --- a/src/multicall.zig +++ b/src/multicall.zig @@ -167,11 +167,11 @@ pub const Multicall = struct { /// Decode the ABI-encoded result of aggregate3: (bool, bytes)[] /// Returns a slice of Result structs. Caller owns all returned memory. pub fn decodeAggregate3Results(allocator: std.mem.Allocator, data: []const u8) ![]Result { - if (data.len < 64) return error.OutOfMemory; + if (data.len < 64) return error.InvalidAbiData; // First word: offset to array data (should be 0x20) const array_offset = readWord(data[0..32]); - if (array_offset + 32 > data.len) return error.OutOfMemory; + if (array_offset + 32 > data.len) return error.InvalidAbiData; // Array length const array_len = readWord(data[array_offset .. array_offset + 32]); @@ -188,7 +188,7 @@ pub fn decodeAggregate3Results(allocator: std.mem.Allocator, data: []const u8) ! // Read offsets for each result tuple for (0..array_len) |i| { const offset_pos = array_data_start + i * 32; - if (offset_pos + 32 > data.len) return error.OutOfMemory; + if (offset_pos + 32 > data.len) return error.InvalidAbiData; const tuple_offset = readWord(data[offset_pos .. offset_pos + 32]); const tuple_start = array_data_start + tuple_offset; @@ -196,17 +196,17 @@ pub fn decodeAggregate3Results(allocator: std.mem.Allocator, data: []const u8) ! // word 0: success (bool) // word 1: offset to returnData within the tuple // At that offset: length word + data - if (tuple_start + 64 > data.len) return error.OutOfMemory; + if (tuple_start + 64 > data.len) return error.InvalidAbiData; const success_word = readWord(data[tuple_start .. tuple_start + 32]); const return_data_offset = readWord(data[tuple_start + 32 .. tuple_start + 64]); const return_data_abs = tuple_start + return_data_offset; - if (return_data_abs + 32 > data.len) return error.OutOfMemory; + if (return_data_abs + 32 > data.len) return error.InvalidAbiData; const return_data_len = readWord(data[return_data_abs .. return_data_abs + 32]); const return_data_start = return_data_abs + 32; - if (return_data_start + return_data_len > data.len) return error.OutOfMemory; + if (return_data_start + return_data_len > data.len) return error.InvalidAbiData; var return_data: []const u8 = &.{}; if (return_data_len > 0) { diff --git a/src/provider.zig b/src/provider.zig index 2a3653e..0f96947 100644 --- a/src/provider.zig +++ b/src/provider.zig @@ -649,6 +649,13 @@ fn parseTransactionReceipt(allocator: std.mem.Allocator, raw: []const u8) !?rece // Parse logs array const logs = try parseLogsArray(allocator, obj); + errdefer { + for (logs) |log| { + allocator.free(log.data); + if (log.topics.len > 0) allocator.free(log.topics); + } + if (logs.len > 0) allocator.free(logs); + } // Parse required fields const tx_hash = try parseHash(jsonGetString(obj, "transactionHash") orelse return error.InvalidResponse); @@ -690,11 +697,19 @@ fn parseLogsArray(allocator: std.mem.Allocator, obj: std.json.ObjectMap) ![]cons if (arr.items.len == 0) return &.{}; const logs = try allocator.alloc(receipt_mod.Log, arr.items.len); - errdefer allocator.free(logs); + var parsed_count: usize = 0; + errdefer { + for (logs[0..parsed_count]) |log| { + allocator.free(log.data); + if (log.topics.len > 0) allocator.free(log.topics); + } + allocator.free(logs); + } for (arr.items, 0..) |item, i| { if (item != .object) return error.InvalidResponse; logs[i] = try parseSingleLog(allocator, item.object); + parsed_count += 1; } return logs; @@ -705,9 +720,11 @@ fn parseSingleLog(allocator: std.mem.Allocator, obj: std.json.ObjectMap) !receip const address = (try parseOptionalAddress(jsonGetString(obj, "address"))) orelse return error.InvalidResponse; const data_str = jsonGetString(obj, "data") orelse "0x"; const data = try parseHexBytes(allocator, data_str); + errdefer allocator.free(data); // Parse topics array const topics = try parseTopics(allocator, obj); + errdefer if (topics.len > 0) allocator.free(topics); const block_number = try parseOptionalHexU64(jsonGetString(obj, "blockNumber")); const tx_hash = try parseOptionalHash(jsonGetString(obj, "transactionHash")); @@ -780,11 +797,19 @@ fn parseLogsResponse(allocator: std.mem.Allocator, raw: []const u8) ![]receipt_m } const logs = try allocator.alloc(receipt_mod.Log, arr.items.len); - errdefer allocator.free(logs); + var parsed_count: usize = 0; + errdefer { + for (logs[0..parsed_count]) |log| { + allocator.free(log.data); + if (log.topics.len > 0) allocator.free(log.topics); + } + allocator.free(logs); + } for (arr.items, 0..) |item, i| { if (item != .object) return error.InvalidResponse; logs[i] = try parseSingleLog(allocator, item.object); + parsed_count += 1; } return logs; @@ -838,6 +863,7 @@ fn parseBlockHeader(allocator: std.mem.Allocator, raw: []const u8) !?block_mod.B // Parse extraData const extra_data_str = jsonGetString(obj, "extraData") orelse "0x"; const extra_data = try parseHexBytes(allocator, extra_data_str); + errdefer allocator.free(extra_data); // Optional EIP-1559 / EIP-4844 fields const base_fee: ?u256 = if (jsonGetString(obj, "baseFeePerGas")) |s| diff --git a/src/rlp.zig b/src/rlp.zig index 4524abb..c2d0846 100644 --- a/src/rlp.zig +++ b/src/rlp.zig @@ -96,7 +96,7 @@ pub fn encodedLength(value: anytype) usize { return 1; // 0x80 } }, - else => return 0, + else => @compileError("unsupported type for RLP encoding"), } } @@ -375,7 +375,7 @@ pub fn writeDirect(buf: []u8, value: anytype) usize { return 1; } }, - else => return 0, + else => @compileError("unsupported type for RLP encoding"), } } diff --git a/src/secp256k1.zig b/src/secp256k1.zig index 0c3d420..8c358fb 100644 --- a/src/secp256k1.zig +++ b/src/secp256k1.zig @@ -274,12 +274,16 @@ fn builtinDerivePublicKey(private_key: [32]u8) SignError![65]u8 { /// RFC 6979 deterministic nonce generation using HMAC-SHA256. /// Implements the algorithm from RFC 6979 Section 3.2. fn generateRfc6979Nonce(private_key: [32]u8, message_hash: [32]u8) Scalar { + const secureZero = @import("utils/constants.zig").secureZero; + // Step a: h1 = message_hash (already provided, 32 bytes) // Step b: V = 0x01 0x01 ... 0x01 (32 bytes of 0x01) var v: [32]u8 = [_]u8{0x01} ** 32; + defer secureZero(&v); // Step c: K = 0x00 0x00 ... 0x00 (32 bytes of 0x00) var k: [32]u8 = [_]u8{0x00} ** 32; + defer secureZero(&k); // Step d: K = HMAC_K(V || 0x00 || int2octets(x) || bits2octets(h1)) var hmac_d = HmacSha256.init(&k); diff --git a/src/secp256k1_c.zig b/src/secp256k1_c.zig index e17fd40..bd85a97 100644 --- a/src/secp256k1_c.zig +++ b/src/secp256k1_c.zig @@ -30,6 +30,7 @@ const SECP256K1_CONTEXT_NONE: c_uint = 1; const SECP256K1_EC_UNCOMPRESSED: c_uint = 2; extern fn secp256k1_context_create(flags: c_uint) ?*secp256k1_context; +extern fn secp256k1_context_destroy(ctx: *secp256k1_context) void; extern fn secp256k1_context_randomize(ctx: *secp256k1_context, seed32: ?[*]const u8) c_int; extern fn secp256k1_ecdsa_sign_recoverable( @@ -91,8 +92,12 @@ fn getContext() *secp256k1_context { if (@atomicLoad(?*secp256k1_context, &global_ctx, .acquire)) |ctx| return ctx; const ctx = secp256k1_context_create(SECP256K1_CONTEXT_NONE) orelse @panic("secp256k1_context_create failed"); - @atomicStore(?*secp256k1_context, &global_ctx, ctx, .release); - return ctx; + // Use cmpxchg to avoid TOCTOU race: if another thread won, use their context + if (@cmpxchgStrong(?*secp256k1_context, &global_ctx, null, ctx, .release, .acquire)) |_| { + // Another thread already initialized; destroy our redundant context. + secp256k1_context_destroy(ctx); + } + return @atomicLoad(?*secp256k1_context, &global_ctx, .acquire).?; } // ============================================================================ diff --git a/src/signer.zig b/src/signer.zig index 0848688..5b80f45 100644 --- a/src/signer.zig +++ b/src/signer.zig @@ -9,11 +9,18 @@ const Signature = @import("signature.zig").Signature; pub const Signer = struct { private_key: [32]u8, + const secureZero = @import("utils/constants.zig").secureZero; + /// Create a new Signer from a 32-byte private key. pub fn init(private_key: [32]u8) Signer { return .{ .private_key = private_key }; } + /// Securely zero the private key. Call when the Signer is no longer needed. + pub fn deinit(self: *Signer) void { + secureZero(&self.private_key); + } + /// Derive the Ethereum address corresponding to this signer's private key. /// pubkey -> keccak256(pubkey_xy) -> last 20 bytes pub fn address(self: Signer) secp256k1.SignError!primitives.Address { diff --git a/src/subscription.zig b/src/subscription.zig index 8727946..c71a1f5 100644 --- a/src/subscription.zig +++ b/src/subscription.zig @@ -247,13 +247,17 @@ pub fn formatHash(hash: [32]u8) [66]u8 { fn extractResultString(allocator: std.mem.Allocator, json: []const u8) ![]u8 { // Look for "result":" pattern const needle = "\"result\":\""; - const start = indexOfSubstring(json, needle) orelse return error.InvalidResponse; + const start = std.mem.indexOf(u8, json, needle) orelse return error.InvalidResponse; const value_start = start + needle.len; // Find closing quote - const value_end = indexOfFrom(json, value_start, '"') orelse return error.InvalidResponse; + const value_end = if (value_start < json.len) + if (std.mem.indexOfScalar(u8, json[value_start..], '"')) |idx| idx + value_start else null + else + null; + const end = value_end orelse return error.InvalidResponse; - const value = json[value_start..value_end]; + const value = json[value_start..end]; const result = try allocator.alloc(u8, value.len); @memcpy(result, value); return result; @@ -262,12 +266,12 @@ fn extractResultString(allocator: std.mem.Allocator, json: []const u8) ![]u8 { /// Check if a JSON message is a subscription notification for the given ID. fn isSubscriptionNotification(json: []const u8, subscription_id: []const u8) bool { // Must contain "eth_subscription" method - if (!containsSubstring(json, "\"eth_subscription\"")) return false; + if (std.mem.indexOf(u8, json, "\"eth_subscription\"") == null) return false; // Must contain our subscription ID // Look for "subscription":"" pattern const needle_prefix = "\"subscription\":\""; - const prefix_pos = indexOfSubstring(json, needle_prefix) orelse return false; + const prefix_pos = std.mem.indexOf(u8, json, needle_prefix) orelse return false; const id_start = prefix_pos + needle_prefix.len; if (id_start + subscription_id.len > json.len) return false; @@ -276,32 +280,6 @@ fn isSubscriptionNotification(json: []const u8, subscription_id: []const u8) boo return std.mem.eql(u8, candidate, subscription_id); } -// --------------------------------------------------------------------------- -// String utilities -// --------------------------------------------------------------------------- - -fn containsSubstring(haystack: []const u8, needle: []const u8) bool { - return indexOfSubstring(haystack, needle) != null; -} - -fn indexOfSubstring(haystack: []const u8, needle: []const u8) ?usize { - if (needle.len > haystack.len) return null; - if (needle.len == 0) return 0; - const limit = haystack.len - needle.len + 1; - for (0..limit) |i| { - if (std.mem.eql(u8, haystack[i .. i + needle.len], needle)) { - return i; - } - } - return null; -} - -fn indexOfFrom(haystack: []const u8, start: usize, needle: u8) ?usize { - if (start >= haystack.len) return null; - const idx = std.mem.indexOfScalar(u8, haystack[start..], needle); - return if (idx) |i| i + start else null; -} - // ============================================================================ // Tests // ============================================================================ @@ -353,10 +331,10 @@ test "buildSubscribeParams - logs with address only" { const result = try buildSubscribeParams(allocator, params); defer allocator.free(result); - try std.testing.expect(containsSubstring(result, "[\"logs\",{")); - try std.testing.expect(containsSubstring(result, "\"address\":\"0x")); - try std.testing.expect(containsSubstring(result, "dededededededededededededededededededededede")); - try std.testing.expect(containsSubstring(result, "}]")); + try std.testing.expect(std.mem.indexOf(u8, result, "[\"logs\",{") != null); + try std.testing.expect(std.mem.indexOf(u8, result, "\"address\":\"0x") != null); + try std.testing.expect(std.mem.indexOf(u8, result, "dededededededededededededededededededededede") != null); + try std.testing.expect(std.mem.indexOf(u8, result, "}]") != null); } test "buildSubscribeParams - logs with topics" { @@ -372,8 +350,8 @@ test "buildSubscribeParams - logs with topics" { const result = try buildSubscribeParams(allocator, params); defer allocator.free(result); - try std.testing.expect(containsSubstring(result, "\"topics\":[")); - try std.testing.expect(containsSubstring(result, "\"0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\"")); + try std.testing.expect(std.mem.indexOf(u8, result, "\"topics\":[") != null); + try std.testing.expect(std.mem.indexOf(u8, result, "\"0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\"") != null); } test "buildSubscribeParams - logs with null topic" { @@ -388,7 +366,7 @@ test "buildSubscribeParams - logs with null topic" { const result = try buildSubscribeParams(allocator, params); defer allocator.free(result); - try std.testing.expect(containsSubstring(result, "\"topics\":[null]")); + try std.testing.expect(std.mem.indexOf(u8, result, "\"topics\":[null]") != null); } test "buildSubscribeParams - logs with address and topics" { @@ -406,10 +384,10 @@ test "buildSubscribeParams - logs with address and topics" { const result = try buildSubscribeParams(allocator, params); defer allocator.free(result); - try std.testing.expect(containsSubstring(result, "\"address\":\"0x")); - try std.testing.expect(containsSubstring(result, "\"topics\":[\"0x")); - try std.testing.expect(containsSubstring(result, "null")); - try std.testing.expect(containsSubstring(result, "}]")); + try std.testing.expect(std.mem.indexOf(u8, result, "\"address\":\"0x") != null); + try std.testing.expect(std.mem.indexOf(u8, result, "\"topics\":[\"0x") != null); + try std.testing.expect(std.mem.indexOf(u8, result, "null") != null); + try std.testing.expect(std.mem.indexOf(u8, result, "}]") != null); } test "extractResultString - valid response" { diff --git a/src/transaction.zig b/src/transaction.zig index 0f38129..47df11f 100644 --- a/src/transaction.zig +++ b/src/transaction.zig @@ -106,7 +106,7 @@ pub fn hashForSigning(allocator: std.mem.Allocator, tx: Transaction) ![32]u8 { /// - EIP-4844: 0x03 ++ RLP([chainId, nonce, maxPriorityFeePerGas, maxFeePerGas, gasLimit, to, value, data, accessList, maxFeePerBlobGas, blobVersionedHashes, v, r, s]) /// /// Caller owns the returned slice. -pub fn serializeSigned(allocator: std.mem.Allocator, tx: Transaction, r: [32]u8, s: [32]u8, v: u8) ![]u8 { +pub fn serializeSigned(allocator: std.mem.Allocator, tx: Transaction, r: [32]u8, s: [32]u8, v: u256) ![]u8 { switch (tx) { .legacy => |legacy| return serializeLegacySigned(allocator, legacy, r, s, v), .eip2930 => |eip2930| return serializeTypedSigned(allocator, 0x01, eip2930, r, s, v), @@ -383,7 +383,7 @@ fn encodeLengthAssumeCapacity(list: *std.ArrayList(u8), len: usize, offset: u8) } /// Serialize a signed legacy transaction. -fn serializeLegacySigned(allocator: std.mem.Allocator, legacy: LegacyTransaction, r: [32]u8, s: [32]u8, v: u8) ![]u8 { +fn serializeLegacySigned(allocator: std.mem.Allocator, legacy: LegacyTransaction, r: [32]u8, s: [32]u8, v: u256) ![]u8 { // Calculate payload length var payload_len: usize = 0; payload_len += rlp.encodedLength(legacy.nonce); @@ -412,7 +412,7 @@ fn serializeLegacySigned(allocator: std.mem.Allocator, legacy: LegacyTransaction } /// Serialize a signed typed transaction. -fn serializeTypedSigned(allocator: std.mem.Allocator, type_byte: u8, tx: anytype, r: [32]u8, s: [32]u8, v: u8) ![]u8 { +fn serializeTypedSigned(allocator: std.mem.Allocator, type_byte: u8, tx: anytype, r: [32]u8, s: [32]u8, v: u256) ![]u8 { // Pre-calculate total size var payload_len = calculateTypedFieldsLength(tx); payload_len += rlp.encodedLength(v); diff --git a/src/uint256.zig b/src/uint256.zig index 1e0bdc0..9d57b99 100644 --- a/src/uint256.zig +++ b/src/uint256.zig @@ -660,7 +660,7 @@ pub fn mulDivRoundingUp(a: u256, b: u256, denominator: u256) ?u256 { /// Compute UniswapV2 getAmountOut entirely in u64-limb space. /// Formula: (amountIn * 997 * reserveOut) / (reserveIn * 1000 + amountIn * 997) /// Delegates to dex/v2.zig with the standard Uniswap V2 fee (997/1000). -pub fn getAmountOut(amount_in: u256, reserve_in: u256, reserve_out: u256) u256 { +pub fn getAmountOut(amount_in: u256, reserve_in: u256, reserve_out: u256) ?u256 { const dex_v2 = @import("dex/v2.zig"); return dex_v2.getAmountOut(amount_in, reserve_in, reserve_out, 997, 1000); } @@ -935,7 +935,7 @@ test "getAmountOut correctness" { const denominator = fastMul(reserve_in, 1000) +% amount_in_with_fee; const expected = fastDiv(numerator, denominator); - const result = getAmountOut(amount_in, reserve_in, reserve_out); + const result = getAmountOut(amount_in, reserve_in, reserve_out).?; try std.testing.expectEqual(expected, result); try std.testing.expect(result > 0); try std.testing.expect(result < reserve_out); @@ -943,11 +943,11 @@ test "getAmountOut correctness" { test "getAmountOut edge cases" { // Small amount in - const r1 = getAmountOut(1, 1_000_000, 1_000_000); + const r1 = getAmountOut(1, 1_000_000, 1_000_000).?; try std.testing.expect(r1 < 1_000_000); // Equal reserves - const r2 = getAmountOut(1_000_000, 1_000_000_000, 1_000_000_000); + const r2 = getAmountOut(1_000_000, 1_000_000_000, 1_000_000_000).?; try std.testing.expect(r2 > 0); try std.testing.expect(r2 < 1_000_000); } diff --git a/src/utils/constants.zig b/src/utils/constants.zig index 9fc2033..5e8d720 100644 --- a/src/utils/constants.zig +++ b/src/utils/constants.zig @@ -11,3 +11,12 @@ pub const ZERO_ADDRESS: primitives.Address = primitives.ZERO_ADDRESS; /// Zero hash (0x0000...0000). pub const ZERO_HASH: primitives.Hash = primitives.ZERO_HASH; + +/// Securely zero a buffer using volatile writes to prevent compiler elision. +/// Use this to clear sensitive data (private keys, seeds, nonces) from memory. +pub fn secureZero(buf: []u8) void { + for (buf) |*b| { + const volatile_ptr: *volatile u8 = b; + volatile_ptr.* = 0; + } +} diff --git a/src/utils/units.zig b/src/utils/units.zig index b02ec5c..9f9e975 100644 --- a/src/utils/units.zig +++ b/src/utils/units.zig @@ -15,7 +15,9 @@ const ETHER_F64: f64 = @as(f64, @floatFromInt(ETHER)); const GWEI_F64: f64 = @as(f64, @floatFromInt(GWEI)); const TWO_POW_128_F64: f64 = 2.0 * @as(f64, @floatFromInt(@as(u128, 1) << 127)); -inline fn f64ToU256(value: f64) u256 { +inline fn f64ToU256(value: f64) ?u256 { + if (value < 0.0 or !std.math.isFinite(value)) return null; + if (value >= TWO_POW_128_F64) return null; return @as(u256, @as(u128, @intFromFloat(value))); } @@ -28,13 +30,13 @@ inline fn u256ToF64(value: u256) f64 { return @as(f64, @floatFromInt(hi)) * TWO_POW_128_F64 + @as(f64, @floatFromInt(lo)); } -/// Convert ether (as f64) to wei (u256). -pub fn parseEther(ether: f64) u256 { +/// Convert ether (as f64) to wei (u256). Returns null for negative, non-finite, or overflow input. +pub fn parseEther(ether: f64) ?u256 { return f64ToU256(ether * ETHER_F64); } -/// Convert gwei (as f64) to wei (u256). -pub fn parseGwei(gwei: f64) u256 { +/// Convert gwei (as f64) to wei (u256). Returns null for negative, non-finite, or overflow input. +pub fn parseGwei(gwei: f64) ?u256 { return f64ToU256(gwei * GWEI_F64); } @@ -50,13 +52,13 @@ pub fn formatGwei(wei: u256) f64 { // Tests test "parseEther" { - try std.testing.expectEqual(@as(u256, 1_000_000_000_000_000_000), parseEther(1.0)); - try std.testing.expectEqual(@as(u256, 500_000_000_000_000_000), parseEther(0.5)); + try std.testing.expectEqual(@as(?u256, 1_000_000_000_000_000_000), parseEther(1.0)); + try std.testing.expectEqual(@as(?u256, 500_000_000_000_000_000), parseEther(0.5)); } test "parseGwei" { - try std.testing.expectEqual(@as(u256, 1_000_000_000), parseGwei(1.0)); - try std.testing.expectEqual(@as(u256, 20_000_000_000), parseGwei(20.0)); + try std.testing.expectEqual(@as(?u256, 1_000_000_000), parseGwei(1.0)); + try std.testing.expectEqual(@as(?u256, 20_000_000_000), parseGwei(20.0)); } test "formatEther" { @@ -68,13 +70,28 @@ test "formatGwei" { } test "parseEther zero" { - try std.testing.expectEqual(@as(u256, 0), parseEther(0.0)); + try std.testing.expectEqual(@as(?u256, 0), parseEther(0.0)); +} + +test "parseEther negative returns null" { + try std.testing.expectEqual(@as(?u256, null), parseEther(-1.0)); + try std.testing.expectEqual(@as(?u256, null), parseGwei(-1.0)); +} + +test "parseEther non-finite returns null" { + try std.testing.expectEqual(@as(?u256, null), parseEther(std.math.inf(f64))); + try std.testing.expectEqual(@as(?u256, null), parseEther(std.math.nan(f64))); +} + +test "parseEther overflow returns null" { + // 1e30 ether = 1e48 wei exceeds u128 max + try std.testing.expectEqual(@as(?u256, null), parseEther(1e30)); } test "parseEther large value" { // 9007.0 is exact in f64, but 9007.0 * 1e18 exceeds f64 mantissa precision. // Allow up to 1 ULP of error at this magnitude (~2^20 = 1_048_576). - const result = parseEther(9007.0); + const result = parseEther(9007.0).?; const expected: u256 = 9007_000_000_000_000_000_000; const diff = if (result > expected) result - expected else expected - result; try std.testing.expect(diff < 1_048_576); @@ -89,11 +106,11 @@ test "formatGwei zero" { } test "parseEther formatEther roundtrip" { - try std.testing.expectApproxEqAbs(@as(f64, 1.5), formatEther(parseEther(1.5)), 1e-6); + try std.testing.expectApproxEqAbs(@as(f64, 1.5), formatEther(parseEther(1.5).?), 1e-6); } test "parseGwei formatGwei roundtrip" { - try std.testing.expectApproxEqAbs(@as(f64, 30.0), formatGwei(parseGwei(30.0)), 1e-6); + try std.testing.expectApproxEqAbs(@as(f64, 30.0), formatGwei(parseGwei(30.0).?), 1e-6); } test "formatEther is finite and monotonic for very large u256 values" { diff --git a/src/wallet.zig b/src/wallet.zig index a62282d..0c4c7d1 100644 --- a/src/wallet.zig +++ b/src/wallet.zig @@ -45,6 +45,11 @@ pub const Wallet = struct { }; } + /// Securely zero the private key. Call when the Wallet is no longer needed. + pub fn deinit(self: *Wallet) void { + self.signer_instance.deinit(); + } + /// Return the Ethereum address derived from this wallet's private key. pub fn address(self: *const Wallet) ![20]u8 { return try self.signer_instance.address(); diff --git a/src/ws_transport.zig b/src/ws_transport.zig index a24e57a..9bf90cc 100644 --- a/src/ws_transport.zig +++ b/src/ws_transport.zig @@ -28,10 +28,10 @@ pub fn parseUrl(url: []const u8) UrlError!ParsedUrl { var rest: []const u8 = undefined; var is_tls: bool = false; - if (startsWith(url, "wss://")) { + if (std.mem.startsWith(u8, url, "wss://")) { rest = url[6..]; is_tls = true; - } else if (startsWith(url, "ws://")) { + } else if (std.mem.startsWith(u8, url, "ws://")) { rest = url[5..]; is_tls = false; } else { @@ -41,7 +41,7 @@ pub fn parseUrl(url: []const u8) UrlError!ParsedUrl { if (rest.len == 0) return error.MissingHost; // Split host+port from path. - const path_start = indexOf(rest, '/') orelse rest.len; + const path_start = std.mem.indexOfScalar(u8, rest, '/') orelse rest.len; const host_port = rest[0..path_start]; const path = if (path_start < rest.len) rest[path_start..] else "/"; @@ -49,7 +49,7 @@ pub fn parseUrl(url: []const u8) UrlError!ParsedUrl { var host: []const u8 = undefined; var port: u16 = if (is_tls) 443 else 80; - if (indexOf(host_port, ':')) |colon| { + if (std.mem.indexOfScalar(u8, host_port, ':')) |colon| { host = host_port[0..colon]; const port_str = host_port[colon + 1 ..]; port = std.fmt.parseInt(u16, port_str, 10) catch { @@ -253,21 +253,24 @@ pub fn computeAcceptKey(ws_key: []const u8) [28]u8 { /// correct Sec-WebSocket-Accept header. pub fn validateHandshakeResponse(response: []const u8, expected_accept: []const u8) bool { // Check for HTTP 101 status - if (!containsSubstring(response, "101")) return false; + if (std.mem.indexOf(u8, response, "101") == null) return false; // Find Sec-WebSocket-Accept header (case-insensitive search) const accept_header = "sec-websocket-accept: "; var lower_buf: [4096]u8 = undefined; const check_len = @min(response.len, lower_buf.len); for (response[0..check_len], 0..) |c, i| { - lower_buf[i] = toLower(c); + lower_buf[i] = std.ascii.toLower(c); } const lower_response = lower_buf[0..check_len]; - if (indexOfSubstring(lower_response, accept_header)) |header_start| { + if (std.mem.indexOf(u8, lower_response, accept_header)) |header_start| { const value_start = header_start + accept_header.len; // Find end of header value (terminated by \r\n) - const value_end = indexOfFrom(response, value_start, '\r') orelse response.len; + const value_end = blk: { + const slice = response[value_start..]; + break :blk if (std.mem.indexOfScalar(u8, slice, '\r')) |idx| idx + value_start else null; + } orelse response.len; const accept_value = response[value_start..value_end]; // Trim whitespace @@ -441,7 +444,7 @@ pub const WsTransport = struct { var id_buf: [32]u8 = undefined; const id_str = std.fmt.bufPrint(&id_buf, "\"id\":{d}", .{id}) catch unreachable; - if (containsSubstring(frame_data, id_str)) { + if (std.mem.indexOf(u8, frame_data, id_str) != null) { return frame_data; } @@ -575,7 +578,7 @@ pub const WsTransport = struct { // Check if we have the full response (ends with \r\n\r\n) if (total_read >= 4) { - if (indexOfSubstring(response_buf[0..total_read], "\r\n\r\n") != null) { + if (std.mem.indexOf(u8, response_buf[0..total_read], "\r\n\r\n") != null) { break; } } @@ -641,46 +644,6 @@ pub const WsTransport = struct { } }; -// --------------------------------------------------------------------------- -// String utility helpers -// --------------------------------------------------------------------------- - -fn startsWith(haystack: []const u8, prefix: []const u8) bool { - if (haystack.len < prefix.len) return false; - return std.mem.eql(u8, haystack[0..prefix.len], prefix); -} - -fn indexOf(haystack: []const u8, needle: u8) ?usize { - return std.mem.indexOfScalar(u8, haystack, needle); -} - -fn indexOfFrom(haystack: []const u8, start: usize, needle: u8) ?usize { - if (start >= haystack.len) return null; - const idx = std.mem.indexOfScalar(u8, haystack[start..], needle); - return if (idx) |i| i + start else null; -} - -fn containsSubstring(haystack: []const u8, needle: []const u8) bool { - return indexOfSubstring(haystack, needle) != null; -} - -fn indexOfSubstring(haystack: []const u8, needle: []const u8) ?usize { - if (needle.len > haystack.len) return null; - if (needle.len == 0) return 0; - const limit = haystack.len - needle.len + 1; - for (0..limit) |i| { - if (std.mem.eql(u8, haystack[i .. i + needle.len], needle)) { - return i; - } - } - return null; -} - -fn toLower(c: u8) u8 { - if (c >= 'A' and c <= 'Z') return c + 32; - return c; -} - // ============================================================================ // Tests // ============================================================================ @@ -1045,13 +1008,13 @@ test "buildHandshakeRequest - basic" { const req = try buildHandshakeRequest(allocator, "localhost", 8545, "/ws", "dGhlIHNhbXBsZSBub25jZQ=="); defer allocator.free(req); - try std.testing.expect(containsSubstring(req, "GET /ws HTTP/1.1\r\n")); - try std.testing.expect(containsSubstring(req, "Host: localhost:8545\r\n")); - try std.testing.expect(containsSubstring(req, "Upgrade: websocket\r\n")); - try std.testing.expect(containsSubstring(req, "Connection: Upgrade\r\n")); - try std.testing.expect(containsSubstring(req, "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n")); - try std.testing.expect(containsSubstring(req, "Sec-WebSocket-Version: 13\r\n")); - try std.testing.expect(containsSubstring(req, "\r\n\r\n")); + try std.testing.expect(std.mem.indexOf(u8, req, "GET /ws HTTP/1.1\r\n") != null); + try std.testing.expect(std.mem.indexOf(u8, req, "Host: localhost:8545\r\n") != null); + try std.testing.expect(std.mem.indexOf(u8, req, "Upgrade: websocket\r\n") != null); + try std.testing.expect(std.mem.indexOf(u8, req, "Connection: Upgrade\r\n") != null); + try std.testing.expect(std.mem.indexOf(u8, req, "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n") != null); + try std.testing.expect(std.mem.indexOf(u8, req, "Sec-WebSocket-Version: 13\r\n") != null); + try std.testing.expect(std.mem.indexOf(u8, req, "\r\n\r\n") != null); } test "buildHandshakeRequest - standard port 80 omitted" { @@ -1059,8 +1022,8 @@ test "buildHandshakeRequest - standard port 80 omitted" { const req = try buildHandshakeRequest(allocator, "example.com", 80, "/", "abc="); defer allocator.free(req); - try std.testing.expect(containsSubstring(req, "Host: example.com\r\n")); - try std.testing.expect(!containsSubstring(req, ":80")); + try std.testing.expect(std.mem.indexOf(u8, req, "Host: example.com\r\n") != null); + try std.testing.expect(std.mem.indexOf(u8, req, ":80") == null); } test "buildHandshakeRequest - standard port 443 omitted" { @@ -1068,8 +1031,8 @@ test "buildHandshakeRequest - standard port 443 omitted" { const req = try buildHandshakeRequest(allocator, "example.com", 443, "/", "abc="); defer allocator.free(req); - try std.testing.expect(containsSubstring(req, "Host: example.com\r\n")); - try std.testing.expect(!containsSubstring(req, "443")); + try std.testing.expect(std.mem.indexOf(u8, req, "Host: example.com\r\n") != null); + try std.testing.expect(std.mem.indexOf(u8, req, "443") == null); } test "buildHandshakeRequest - deep path" { @@ -1077,8 +1040,8 @@ test "buildHandshakeRequest - deep path" { const req = try buildHandshakeRequest(allocator, "node.io", 9546, "/v2/my-key", "key="); defer allocator.free(req); - try std.testing.expect(containsSubstring(req, "GET /v2/my-key HTTP/1.1\r\n")); - try std.testing.expect(containsSubstring(req, "Host: node.io:9546\r\n")); + try std.testing.expect(std.mem.indexOf(u8, req, "GET /v2/my-key HTTP/1.1\r\n") != null); + try std.testing.expect(std.mem.indexOf(u8, req, "Host: node.io:9546\r\n") != null); } test "validateHandshakeResponse - valid" { @@ -1189,32 +1152,6 @@ test "encodeFrame - exactly 126 bytes (triggers extended 16-bit length)" { try std.testing.expectEqual(@as(usize, 2 + 2 + 4 + 126), frame.len); } -test "string helpers - startsWith" { - try std.testing.expect(startsWith("ws://hello", "ws://")); - try std.testing.expect(startsWith("wss://hello", "wss://")); - try std.testing.expect(!startsWith("http://hello", "ws://")); - try std.testing.expect(!startsWith("w", "ws://")); -} - -test "string helpers - indexOf" { - try std.testing.expectEqual(@as(?usize, 3), indexOf("abc:def", ':')); - try std.testing.expectEqual(@as(?usize, null), indexOf("abcdef", ':')); -} - -test "string helpers - containsSubstring" { - try std.testing.expect(containsSubstring("Hello, World!", "World")); - try std.testing.expect(!containsSubstring("Hello, World!", "xyz")); - try std.testing.expect(containsSubstring("abc", "abc")); - try std.testing.expect(containsSubstring("abc", "")); -} - -test "string helpers - toLower" { - try std.testing.expectEqual(@as(u8, 'a'), toLower('A')); - try std.testing.expectEqual(@as(u8, 'z'), toLower('Z')); - try std.testing.expectEqual(@as(u8, 'a'), toLower('a')); - try std.testing.expectEqual(@as(u8, '1'), toLower('1')); -} - test "Opcode values" { try std.testing.expectEqual(@as(u4, 0x1), @intFromEnum(Opcode.text)); try std.testing.expectEqual(@as(u4, 0x2), @intFromEnum(Opcode.binary)); diff --git a/tests/unit_tests.zig b/tests/unit_tests.zig index a48ee9e..ad7ab64 100644 --- a/tests/unit_tests.zig +++ b/tests/unit_tests.zig @@ -14,6 +14,9 @@ test { _ = eth.abi_decode; // Layer 3: Crypto _ = eth.signature; + _ = eth.secp256k1; + _ = eth.signer; + _ = eth.eip155; // Layer 4: Types _ = eth.access_list; _ = eth.transaction; @@ -25,11 +28,17 @@ test { _ = eth.hd_wallet; // Layer 6: Transport _ = eth.json_rpc; + _ = eth.http_transport; _ = eth.ws_transport; _ = eth.subscription; _ = eth.provider; - // Layer 7: Client + // Layer 7: ENS + _ = eth.ens_namehash; + _ = eth.ens_resolver; + _ = eth.ens_reverse; + // Layer 8: Client _ = eth.wallet; + _ = eth.flashbots; _ = eth.contract; _ = eth.multicall; _ = eth.event; @@ -40,6 +49,11 @@ test { _ = eth.abi_json; // Layer 10: Chains _ = eth.chains; + // DEX Math + _ = eth.dex_v2; + _ = eth.dex_v3; + _ = eth.dex_router; // Utils _ = eth.units; + _ = eth.constants; }