diff --git a/src/arith.ml b/src/arith.ml index 3cbde80..4eaa5b0 100644 --- a/src/arith.ml +++ b/src/arith.ml @@ -8,26 +8,24 @@ open Signal Operands are read from and results written to an external 32×256-bit register file. Operations (op input): - 0 = add: r[addr_out] <- r[addr_a] + r[addr_b] mod m - 1 = sub: r[addr_out] <- r[addr_a] - r[addr_b] mod m - 2 = mul: r[addr_out] <- r[addr_a] * r[addr_b] mod m - 3 = inv: r[addr_out] <- r[addr_a]^(-1) mod m (addr_b ignored) + 0 = add: f <- a + b mod m + 1 = sub: f <- a - b mod m + 2 = mul: f <- a * b mod m + 3 = inv: f <- a^(-1) mod m (b ignored) Modulus selection (prime_sel): 0 = prime_p, 1 = prime_n Protocol: - 1. Set addr_a, addr_b, addr_out, op, prime_sel; pulse start - 2. Provide reg_read_data_a/b in response to reg_read_addr_a/b - 3. Wait for done_ pulse; result written via reg_write_* signals - 4. For inv, check inv_exists to confirm inverse was found + 1. Set a, b, op, prime_sel; pulse start + 2. Wait for done_ pulse; result written via reg_write_data signal + 3. For inv, check inv_exists to confirm inverse was found - State machine: Idle -> Load -> Capture -> Compute -> Write -> Done -> Idle + State machine: Idle -> Compute -> Idle *) module Config = struct let width = 256 let num_registers = 32 - let reg_addr_width = 5 (* secp256k1 field prime: p = 2^256 - 2^32 - 977 *) let prime_p = Z.of_string "115792089237316195423570985008687907853269984665640564039457584007908834671663" @@ -50,11 +48,7 @@ end module State = struct type t = | Idle - | Load - | Capture | Compute - | Write - | Done [@@deriving sexp_of, compare, enumerate] end @@ -65,9 +59,6 @@ module I = struct ; start : 'a ; op : 'a [@bits 2] ; prime_sel : 'a - ; addr_a : 'a [@bits Config.reg_addr_width] - ; addr_b : 'a [@bits Config.reg_addr_width] - ; addr_out : 'a [@bits Config.reg_addr_width] ; reg_read_data_a : 'a [@bits Config.width] ; reg_read_data_b : 'a [@bits Config.width] } @@ -78,11 +69,7 @@ module O = struct type 'a t = { busy : 'a ; done_ : 'a - ; reg_write_enable : 'a - ; reg_write_addr : 'a [@bits Config.reg_addr_width] ; reg_write_data : 'a [@bits Config.width] - ; reg_read_addr_a : 'a [@bits Config.reg_addr_width] - ; reg_read_addr_b : 'a [@bits Config.reg_addr_width] ; inv_exists : 'a } [@@deriving sexp_of, hardcaml] @@ -100,9 +87,6 @@ let create scope (i : _ I.t) = (* Latched inputs *) let op_reg = Variable.reg spec ~width:2 in let prime_sel_reg = Variable.reg spec ~width:1 in - let addr_a_reg = Variable.reg spec ~width:Config.reg_addr_width in - let addr_b_reg = Variable.reg spec ~width:Config.reg_addr_width in - let addr_out_reg = Variable.reg spec ~width:Config.reg_addr_width in (* Captured operands *) let operand_a = Variable.reg spec ~width:Config.width in @@ -119,7 +103,6 @@ let create scope (i : _ I.t) = let inv_exists_reg = Variable.reg spec ~width:1 in (* Output registers *) - let reg_write_enable = Variable.reg spec ~width:1 in let done_flag = Variable.reg spec ~width:1 in (* Prime constants *) @@ -228,7 +211,6 @@ let create scope (i : _ I.t) = (* TODO move start_mul, start_inv clear to Compute step as well when updated *) start_mul <-- gnd; start_inv <-- gnd; - reg_write_enable <-- gnd; done_flag <-- gnd; sm.switch [ @@ -237,67 +219,35 @@ let create scope (i : _ I.t) = (* Latch all inputs *) op_reg <-- i.op; prime_sel_reg <-- i.prime_sel; - addr_a_reg <-- i.addr_a; - addr_b_reg <-- i.addr_b; - addr_out_reg <-- i.addr_out; - sm.set_next Load; + operand_a <-- i.reg_read_data_a; + operand_b <-- i.reg_read_data_b; + + (* Start the required operation *) + switch i.op [ + of_int ~width:2 Op.add, [ mod_add_valid <-- vdd ]; + of_int ~width:2 Op.sub, [ mod_sub_valid <-- vdd ]; + of_int ~width:2 Op.mul, [ start_mul <-- vdd ]; + of_int ~width:2 Op.inv, [ start_inv <-- vdd ]; + ]; + + (* Move to next state *) + sm.set_next Compute; ]; ]; - State.Load, [ - (* Wait one cycle for register file to provide data *) - sm.set_next Capture; - ]; - - - -State.Capture, [ - operand_a <-- i.reg_read_data_a; - operand_b <-- i.reg_read_data_b; - - switch op_reg.value [ - of_int ~width:2 Op.add, [ mod_add_valid <-- vdd ]; - of_int ~width:2 Op.sub, [ mod_sub_valid <-- vdd ]; - of_int ~width:2 Op.mul, [ start_mul <-- vdd ]; - of_int ~width:2 Op.inv, [ start_inv <-- vdd ]; - ]; - - sm.set_next Compute; -]; - -State.Compute, [ - (* Simply wait for operation to complete *) - when_ op_ready [ - (* Clear valid signals *) - mod_add_valid <-- gnd; - mod_sub_valid <-- gnd; - - result_reg <-- op_result; - inv_exists_reg <-- mod_inv_out.exists; - sm.set_next Write; - ]; -]; - - State.Write, [ - (* Write result to register file *) - reg_write_enable <-- vdd; - sm.set_next Done; + State.Compute, [ + (* Simply wait for operation to complete *) + when_ op_ready [ + (* Clear valid signals *) + mod_add_valid <-- gnd; + mod_sub_valid <-- gnd; + + done_flag <-- vdd; + result_reg <-- op_result; + inv_exists_reg <-- mod_inv_out.exists; + sm.set_next Idle; + ]; ]; - -State.Done, [ - done_flag <-- vdd; - if_ i.start [ - (* New operation starting immediately - latch inputs and go to Load *) - op_reg <-- i.op; - prime_sel_reg <-- i.prime_sel; - addr_a_reg <-- i.addr_a; - addr_b_reg <-- i.addr_b; - addr_out_reg <-- i.addr_out; - sm.set_next Load; - ] [ - sm.set_next Idle; - ]; -]; ]; ]; @@ -307,10 +257,6 @@ State.Done, [ { O. busy = busy -- "busy" ; done_ = done_flag.value -- "done" - ; reg_write_enable = reg_write_enable.value -- "reg_write_enable" - ; reg_write_addr = addr_out_reg.value -- "reg_write_addr" ; reg_write_data = result_reg.value -- "reg_write_data" - ; reg_read_addr_a = addr_a_reg.value -- "reg_read_addr_a" - ; reg_read_addr_b = addr_b_reg.value -- "reg_read_addr_b" ; inv_exists = inv_exists_reg.value -- "inv_exists" } \ No newline at end of file diff --git a/src/dune b/src/dune index a2fdf19..73dc3b9 100644 --- a/src/dune +++ b/src/dune @@ -2,6 +2,6 @@ (name off_switch) (public_name off_switch) (wrapped false) - (modules trng security_block ecdsa arith mod_add mod_mul mod_inv comb_add point_add point_mul) + (modules trng security_block ecdsa arith mod_add mod_mul mod_inv comb_add) (libraries base hardcaml zarith) (preprocess (pps ppx_jane ppx_hardcaml))) diff --git a/src/ecdsa.ml b/src/ecdsa.ml index ede76b4..86da25b 100644 --- a/src/ecdsa.ml +++ b/src/ecdsa.ml @@ -4,42 +4,42 @@ open Signal (* ECDSA Signature Verification for secp256k1 - + Verifies ECDSA signatures using the equation: R = u1*G + u2*Q where: u1 = z * s^(-1) mod n u2 = r * s^(-1) mod n - + Signature is valid if R.x mod n == r Uses Renes, Costello and Batina's complete addition formula in projective coordinates Uses the Arith module for modular {add, sub, mul, inv}. - + Inputs: - z: message hash (256 bits) - r, s: signature components (256 bits each) - param_a, param_b3: curve parameters (a=0, b3=21 for secp256k1) - + Outputs: - valid: 1 if signature is valid, 0 otherwise - done_: pulses high for one cycle when verification completes - busy: high while verification is in progress - x, y, z_out: final point coordinates (for debugging) - + Hardcoded: - G: generator point - Q: public key (currently 2G for testing) - G+Q: precomputed sum (currently 3G) - + State machine: - Idle -> Prep_op (3 ops) -> Loop/Load/Run_add (scalar mult) + Idle -> Prep_op (3 ops) -> Loop/Load/Run_add (scalar mult) -> Finalize_op (2 ops) -> Compare -> Done -> Idle - + Cycle count: ~1.5-2M cycles for typical 256-bit scalars - - Note: Does not currently reduce x_affine mod n before comparison, + + Note: Does not currently reduce x_affine mod n before comparison, or check that z, s, r are within range. *) @@ -48,7 +48,7 @@ module Config = struct let width = 256 let reg_addr_width = 5 let num_steps = 40 - + let t0 = 0 let t1 = 1 let t2 = 2 @@ -66,24 +66,24 @@ module Config = struct let z2 = 14 let param_a = 15 let param_b3 = 16 - + let num_regs = 17 - + (* Generator point G *) let generator_x = Z.of_string "0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" let generator_y = Z.of_string "0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8" let generator_z = Z.one - + (* Public key Q = 2G *) let q_x = Z.of_string "0xc6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b95c709ee5" let q_y = Z.of_string "12158399299693830322967808612713398636155367887041628176798871954788371653930" let q_z = Z.one - + (* Precomputed G + Q = 3G *) let gpq_x = Z.of_string "0xf9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9" let gpq_y = Z.of_string "0x388f7b0f632de8140fe337e62a37f3566500a99934c2231b6cb9fd7584b8e672" let gpq_z = Z.one - + (* Point at infinity *) let infinity_x = Z.zero let infinity_y = Z.one @@ -191,37 +191,37 @@ let create scope (i : _ I.t) = let open Always in let width = Config.width in let addr_width = Config.reg_addr_width in - + let spec = Reg_spec.create ~clock:i.clock ~clear:i.clear () in let sm = State_machine.create (module State) spec ~enable:vdd in - + let reg_file = Array.init Config.num_regs ~f:(fun _ -> Variable.reg spec ~width) in - + (* Hardcoded point constants *) let const_gx = of_z ~width Config.generator_x in let const_gy = of_z ~width Config.generator_y in let const_gz = of_z ~width Config.generator_z in - + let const_qx = of_z ~width Config.q_x in let const_qy = of_z ~width Config.q_y in let const_qz = of_z ~width Config.q_z in - + let const_gpqx = of_z ~width Config.gpq_x in let const_gpqy = of_z ~width Config.gpq_y in let const_gpqz = of_z ~width Config.gpq_z in - + (* Input registers *) let z_reg = Variable.reg spec ~width in let r_reg = Variable.reg spec ~width in let s_reg = Variable.reg spec ~width in - + (* Prep phase registers *) let w_reg = Variable.reg spec ~width in (* s^(-1) mod n *) let u1_reg = Variable.reg spec ~width in (* z * w mod n *) let u2_reg = Variable.reg spec ~width in (* r * w mod n *) let prep_step = Variable.reg spec ~width:2 in let prep_op_started = Variable.reg spec ~width:1 in - + (* Main loop registers *) let bit_pos = Variable.reg spec ~width:8 in let doubling = Variable.reg spec ~width:1 in @@ -229,36 +229,36 @@ let create scope (i : _ I.t) = let step = Variable.reg spec ~width:6 in let load_idx = Variable.reg spec ~width:2 in let op_started = Variable.reg spec ~width:1 in - + (* Latched second operand registers *) let x2_latched = Variable.reg spec ~width in let y2_latched = Variable.reg spec ~width in let z2_latched = Variable.reg spec ~width in - + (* Finalize phase registers *) let finalize_step = Variable.reg spec ~width:1 in let finalize_op_started = Variable.reg spec ~width:1 in let z_inv_reg = Variable.reg spec ~width in let x_affine_reg = Variable.reg spec ~width in - + (* Output registers *) let out_x = Variable.reg spec ~width in let out_y = Variable.reg spec ~width in let out_z = Variable.reg spec ~width in let valid_reg = Variable.reg spec ~width:1 in let done_flag = Variable.reg spec ~width:1 in - + let current_bit_u = bit_select_dynamic u1_reg.value bit_pos.value in let current_bit_v = bit_select_dynamic u2_reg.value bit_pos.value in - + (* Check if P is point at infinity (z1 = 0) *) let p_is_infinity = reg_file.(Config.z1).value ==:. 0 in - + (* Instruction decode for point addition *) let default_instr = of_int ~width:addr_width 0 in let decode field = - mux step.value - (Array.to_list (Array.map program ~f:(fun instr -> + mux step.value + (Array.to_list (Array.map program ~f:(fun instr -> of_int ~width:addr_width (field instr)))) |> fun s -> mux2 (step.value >=:. Config.num_steps) default_instr s in @@ -268,19 +268,19 @@ let create scope (i : _ I.t) = of_int ~width:2 instr.op))) |> fun s -> mux2 (step.value >=:. Config.num_steps) (zero 2) s in - + let current_dst = decode (fun instr -> instr.dst) in let current_src1 = decode (fun instr -> instr.src1) in let current_src2 = decode (fun instr -> instr.src2) in let current_op_from_program = decode_op in - + (* Arith control signals *) let arith_start = Variable.wire ~default:gnd in let arith_op = Variable.wire ~default:(zero 2) in let arith_prime_sel = Variable.wire ~default:gnd in let arith_a = Variable.wire ~default:(zero width) in let arith_b = Variable.wire ~default:(zero width) in - + (* Use latched values for x2/y2/z2 in point addition *) let reg_read idx = let base_values = Array.to_list (Array.map reg_file ~f:(fun r -> r.value)) in @@ -289,25 +289,25 @@ let create scope (i : _ I.t) = (mux2 (idx ==:. Config.y2) y2_latched.value (mux2 (idx ==:. Config.z2) z2_latched.value base)) in - + (* Select arith inputs based on current state *) let in_prep_or_finalize = (sm.is Prep_op) |: (sm.is Finalize_op) in - - let arith_read_data_a = + + let arith_read_data_a = mux2 in_prep_or_finalize arith_a.value (reg_read current_src1) in - let arith_read_data_b = + let arith_read_data_b = mux2 in_prep_or_finalize arith_b.value (reg_read current_src2) in - + let arith_op_selected = mux2 in_prep_or_finalize arith_op.value current_op_from_program in - + let arith_prime_sel_selected = mux2 in_prep_or_finalize arith_prime_sel.value gnd in - + let arith_out = Arith.create (Scope.sub_scope scope "arith") { Arith.I. clock = i.clock @@ -315,17 +315,14 @@ let create scope (i : _ I.t) = ; start = arith_start.value ; op = arith_op_selected ; prime_sel = arith_prime_sel_selected - ; addr_a = current_src1 - ; addr_b = current_src2 - ; addr_out = current_dst ; reg_read_data_a = arith_read_data_a ; reg_read_data_b = arith_read_data_b } in - + compile [ done_flag <-- gnd; - + sm.switch [ State.Idle, [ when_ i.start [ @@ -339,11 +336,11 @@ let create scope (i : _ I.t) = sm.set_next Prep_op; ]; ]; - + State.Prep_op, [ (* Set up arith inputs based on prep_step *) arith_prime_sel <-- vdd; (* All prep ops use mod n *) - + if_ (prep_step.value ==:. 0) [ (* w = s^(-1) mod n *) arith_op <--. Op.inv; @@ -360,7 +357,7 @@ let create scope (i : _ I.t) = arith_a <-- r_reg.value; arith_b <-- w_reg.value; ]; - + if_ (~:(prep_op_started.value)) [ arith_start <-- vdd; prep_op_started <-- vdd; @@ -374,7 +371,7 @@ let create scope (i : _ I.t) = ] [ u2_reg <-- arith_out.reg_write_data; ]; - + if_ (prep_step.value ==:. 2) [ (* Initialize for main loop *) bit_pos <--. 255; @@ -393,7 +390,7 @@ let create scope (i : _ I.t) = ]; ]; ]; - + State.Loop, [ if_ last_step.value [ (* Initialize for finalize phase *) @@ -415,7 +412,7 @@ let create scope (i : _ I.t) = load_idx <-- (current_bit_v @: current_bit_u); last_step <-- (bit_pos.value ==:. 0); bit_pos <-- bit_pos.value -:. 1; - + if_ ((~: current_bit_u) &: (~: current_bit_v)) [ sm.set_next Loop; ] [ @@ -423,7 +420,7 @@ let create scope (i : _ I.t) = ]; ]; ]; - + State.Load, [ if_ (load_idx.value ==:. 0) [ x2_latched <-- reg_file.(Config.x1).value; @@ -445,7 +442,7 @@ let create scope (i : _ I.t) = step <-- zero 6; sm.set_next Run_add; ]; - + State.Run_add, [ if_ (~:(op_started.value)) [ arith_start <-- vdd; @@ -456,7 +453,7 @@ let create scope (i : _ I.t) = when_ (current_dst ==:. idx) [ reg <-- arith_out.reg_write_data; ]))); - + if_ (step.value ==:. Config.num_steps - 1) [ op_started <-- gnd; sm.set_next Loop; @@ -467,11 +464,11 @@ let create scope (i : _ I.t) = ]; ]; ]; - + State.Finalize_op, [ (* Set up arith inputs based on finalize_step *) arith_prime_sel <-- gnd; (* Finalize ops use mod p *) - + if_ (~:(finalize_step.value)) [ (* z_inv = Z1^(-1) mod p *) arith_op <--. Op.inv; @@ -483,7 +480,7 @@ let create scope (i : _ I.t) = arith_a <-- reg_file.(Config.x1).value; arith_b <-- z_inv_reg.value; ]; - + if_ (~:(finalize_op_started.value)) [ arith_start <-- vdd; finalize_op_started <-- vdd; @@ -495,7 +492,7 @@ let create scope (i : _ I.t) = ] [ x_affine_reg <-- arith_out.reg_write_data; ]; - + if_ finalize_step.value [ sm.set_next Compare; ] [ @@ -505,7 +502,7 @@ let create scope (i : _ I.t) = ]; ]; ]; - + State.Compare, [ (* Compare x_affine with r *) valid_reg <-- (x_affine_reg.value ==: r_reg.value); @@ -514,14 +511,14 @@ let create scope (i : _ I.t) = out_z <-- reg_file.(Config.z1).value; sm.set_next Done; ]; - + State.Done, [ done_flag <-- vdd; sm.set_next Idle; ]; ]; ]; - + { O. busy = ~:(sm.is Idle) ; done_ = done_flag.value diff --git a/src/point_add.ml b/src/point_add.ml deleted file mode 100644 index bf7a356..0000000 --- a/src/point_add.ml +++ /dev/null @@ -1,275 +0,0 @@ -open Base -open Hardcaml -open Signal - -(* Point addition on short Weierstrass curves using projective coordinates. - - Implements complete addition formula for E/Fq: y² = x³ + ax + b - Input: P = (X1:Y1:Z1), Q = (X2:Y2:Z2) in projective coordinates - Output: P + Q = (X3:Y3:Z3) - - Uses Arith module for field operations (add, sub, mul) mod prime_p. - Executes 40-step algorithm with step counter and combinational decode. -*) - -module Config = struct - let width = 256 - let reg_addr_width = 5 - let num_steps = 40 - - (* Register allocation *) - let t0 = 0 - let t1 = 1 - let t2 = 2 - let t3 = 3 - let t4 = 4 - let t5 = 5 - let x3 = 6 - let y3 = 7 - let z3 = 8 - let x1 = 9 - let y1 = 10 - let z1 = 11 - let x2 = 12 - let y2 = 13 - let z2 = 14 - let param_a = 15 - let param_b3 = 16 -end - -module Op = struct - let add = 0 - let sub = 1 - let mul = 2 -end - -module I = struct - type 'a t = - { clock : 'a - ; clear : 'a - ; start : 'a - ; x1 : 'a [@bits Config.width] - ; y1 : 'a [@bits Config.width] - ; z1 : 'a [@bits Config.width] - ; x2 : 'a [@bits Config.width] - ; y2 : 'a [@bits Config.width] - ; z2 : 'a [@bits Config.width] - ; param_a : 'a [@bits Config.width] - ; param_b3 : 'a [@bits Config.width] - } - [@@deriving sexp_of, hardcaml] -end - -module O = struct - type 'a t = - { busy : 'a - ; done_ : 'a - ; x3 : 'a [@bits Config.width] - ; y3 : 'a [@bits Config.width] - ; z3 : 'a [@bits Config.width] - } - [@@deriving sexp_of, hardcaml] -end - -module State = struct - type t = - | Idle - | Load_inputs - | Wait_load (* NEW: wait for register writes to take effect *) - | Run_step - | Output - | Done - [@@deriving sexp_of, compare, enumerate] -end - -(* Instruction encoding *) -type instr = { op : int; src1 : int; src2 : int; dst : int } - -let program = [| - { op = Op.mul; src1 = Config.x1; src2 = Config.x2; dst = Config.t0 }; (* 0: t0 <- X1 * X2 *) - { op = Op.mul; src1 = Config.y1; src2 = Config.y2; dst = Config.t1 }; (* 1: t1 <- Y1 * Y2 *) - { op = Op.mul; src1 = Config.z1; src2 = Config.z2; dst = Config.t2 }; (* 2: t2 <- Z1 * Z2 *) - { op = Op.add; src1 = Config.x1; src2 = Config.y1; dst = Config.t3 }; (* 3: t3 <- X1 + Y1 *) - { op = Op.add; src1 = Config.x2; src2 = Config.y2; dst = Config.t4 }; (* 4: t4 <- X2 + Y2 *) - { op = Op.mul; src1 = Config.t3; src2 = Config.t4; dst = Config.t3 }; (* 5: t3 <- t3 * t4 *) - { op = Op.add; src1 = Config.t0; src2 = Config.t1; dst = Config.t4 }; (* 6: t4 <- t0 + t1 *) - { op = Op.sub; src1 = Config.t3; src2 = Config.t4; dst = Config.t3 }; (* 7: t3 <- t3 - t4 *) - { op = Op.add; src1 = Config.x1; src2 = Config.z1; dst = Config.t4 }; (* 8: t4 <- X1 + Z1 *) - { op = Op.add; src1 = Config.x2; src2 = Config.z2; dst = Config.t5 }; (* 9: t5 <- X2 + Z2 *) - { op = Op.mul; src1 = Config.t4; src2 = Config.t5; dst = Config.t4 }; (* 10: t4 <- t4 * t5 *) - { op = Op.add; src1 = Config.t0; src2 = Config.t2; dst = Config.t5 }; (* 11: t5 <- t0 + t2 *) - { op = Op.sub; src1 = Config.t4; src2 = Config.t5; dst = Config.t4 }; (* 12: t4 <- t4 - t5 *) - { op = Op.add; src1 = Config.y1; src2 = Config.z1; dst = Config.t5 }; (* 13: t5 <- Y1 + Z1 *) - { op = Op.add; src1 = Config.y2; src2 = Config.z2; dst = Config.x3 }; (* 14: X3 <- Y2 + Z2 *) - { op = Op.mul; src1 = Config.t5; src2 = Config.x3; dst = Config.t5 }; (* 15: t5 <- t5 * X3 *) - { op = Op.add; src1 = Config.t1; src2 = Config.t2; dst = Config.x3 }; (* 16: X3 <- t1 + t2 *) - { op = Op.sub; src1 = Config.t5; src2 = Config.x3; dst = Config.t5 }; (* 17: t5 <- t5 - X3 *) - { op = Op.mul; src1 = Config.param_a; src2 = Config.t4; dst = Config.z3 }; (* 18: Z3 <- a * t4 *) - { op = Op.mul; src1 = Config.param_b3; src2 = Config.t2; dst = Config.x3 }; (* 19: X3 <- b3 * t2 *) - { op = Op.add; src1 = Config.x3; src2 = Config.z3; dst = Config.z3 }; (* 20: Z3 <- X3 + Z3 *) - { op = Op.sub; src1 = Config.t1; src2 = Config.z3; dst = Config.x3 }; (* 21: X3 <- t1 - Z3 *) - { op = Op.add; src1 = Config.t1; src2 = Config.z3; dst = Config.z3 }; (* 22: Z3 <- t1 + Z3 *) - { op = Op.mul; src1 = Config.x3; src2 = Config.z3; dst = Config.y3 }; (* 23: Y3 <- X3 * Z3 *) - { op = Op.add; src1 = Config.t0; src2 = Config.t0; dst = Config.t1 }; (* 24: t1 <- t0 + t0 *) - { op = Op.add; src1 = Config.t1; src2 = Config.t0; dst = Config.t1 }; (* 25: t1 <- t1 + t0 *) - { op = Op.mul; src1 = Config.param_a; src2 = Config.t2; dst = Config.t2 }; (* 26: t2 <- a * t2 *) - { op = Op.mul; src1 = Config.param_b3; src2 = Config.t4; dst = Config.t4 }; (* 27: t4 <- b3 * t4 *) - { op = Op.add; src1 = Config.t1; src2 = Config.t2; dst = Config.t1 }; (* 28: t1 <- t1 + t2 *) - { op = Op.sub; src1 = Config.t0; src2 = Config.t2; dst = Config.t2 }; (* 29: t2 <- t0 - t2 *) - { op = Op.mul; src1 = Config.param_a; src2 = Config.t2; dst = Config.t2 }; (* 30: t2 <- a * t2 *) - { op = Op.add; src1 = Config.t4; src2 = Config.t2; dst = Config.t4 }; (* 31: t4 <- t4 + t2 *) - { op = Op.mul; src1 = Config.t1; src2 = Config.t4; dst = Config.t0 }; (* 32: t0 <- t1 * t4 *) - { op = Op.add; src1 = Config.y3; src2 = Config.t0; dst = Config.y3 }; (* 33: Y3 <- Y3 + t0 *) - { op = Op.mul; src1 = Config.t5; src2 = Config.t4; dst = Config.t0 }; (* 34: t0 <- t5 * t4 *) - { op = Op.mul; src1 = Config.t3; src2 = Config.x3; dst = Config.x3 }; (* 35: X3 <- t3 * X3 *) - { op = Op.sub; src1 = Config.x3; src2 = Config.t0; dst = Config.x3 }; (* 36: X3 <- X3 - t0 *) - { op = Op.mul; src1 = Config.t3; src2 = Config.t1; dst = Config.t0 }; (* 37: t0 <- t3 * t1 *) - { op = Op.mul; src1 = Config.t5; src2 = Config.z3; dst = Config.z3 }; (* 38: Z3 <- t5 * Z3 *) - { op = Op.add; src1 = Config.z3; src2 = Config.t0; dst = Config.z3 }; (* 39: Z3 <- Z3 + t0 *) -|] - -let create scope (i : _ I.t) = - let open Always in - let ( -- ) = Scope.naming scope in - let width = Config.width in - let addr_width = Config.reg_addr_width in - - let spec = Reg_spec.create ~clock:i.clock ~clear:i.clear () in - let sm = State_machine.create (module State) spec ~enable:vdd in - - (* Internal register file *) - let reg_file = Array.init 17 ~f:(fun _ -> Variable.reg spec ~width) in - - (* Step counter *) - let step = Variable.reg spec ~width:6 in - - (* Flag: has operation been started this step? *) - let op_started = Variable.reg spec ~width:1 in - - (* Output registers *) - let out_x3 = Variable.reg spec ~width in - let out_y3 = Variable.reg spec ~width in - let out_z3 = Variable.reg spec ~width in - let done_flag = Variable.reg spec ~width:1 in - - (* Build instruction decode ROM *) - let default_instr = of_int ~width:addr_width 0 in - let decode field = - mux step.value - (Array.to_list (Array.map program ~f:(fun instr -> - of_int ~width:addr_width (field instr)))) - |> fun s -> mux2 (step.value >=:. Config.num_steps) default_instr s - in - let decode_op = - mux step.value - (Array.to_list (Array.map program ~f:(fun instr -> - of_int ~width:2 instr.op))) - |> fun s -> mux2 (step.value >=:. Config.num_steps) (zero 2) s - in - - let current_dst = decode (fun instr -> instr.dst) in - let current_src1 = decode (fun instr -> instr.src1) in - let current_src2 = decode (fun instr -> instr.src2) in - let current_op = decode_op in - - (* Arith interface - directly from decode *) - let arith_start = Variable.wire ~default:gnd in - - let arith_read_data_a = - mux current_src1 (Array.to_list (Array.map reg_file ~f:(fun r -> r.value))) - in - let arith_read_data_b = - mux current_src2 (Array.to_list (Array.map reg_file ~f:(fun r -> r.value))) - in - - let arith_out = Arith.create (Scope.sub_scope scope "arith") - { Arith.I. - clock = i.clock - ; clear = i.clear - ; start = arith_start.value - ; op = current_op - ; prime_sel = gnd - ; addr_a = current_src1 - ; addr_b = current_src2 - ; addr_out = current_dst - ; reg_read_data_a = arith_read_data_a - ; reg_read_data_b = arith_read_data_b - } - in - - compile [ - done_flag <-- gnd; - - sm.switch [ - State.Idle, [ - when_ i.start [ - step <-- zero 6; - op_started <-- gnd; - sm.set_next Load_inputs; - ]; - ]; - -State.Load_inputs, [ - reg_file.(Config.x1) <-- i.x1; - reg_file.(Config.y1) <-- i.y1; - reg_file.(Config.z1) <-- i.z1; - reg_file.(Config.x2) <-- i.x2; - reg_file.(Config.y2) <-- i.y2; - reg_file.(Config.z2) <-- i.z2; - reg_file.(Config.param_a) <-- i.param_a; - reg_file.(Config.param_b3) <-- i.param_b3; - op_started <-- gnd; - sm.set_next Wait_load; (* Changed from Run_step *) -]; - -State.Wait_load, [ - (* Wait one cycle for register writes to take effect *) - sm.set_next Run_step; -]; - -State.Run_step, [ - if_ (~:(op_started.value)) [ - (* First cycle of step: start operation *) - arith_start <-- vdd; - op_started <-- vdd; - ] [ - (* Wait for operation to complete *) - when_ arith_out.done_ [ - (* Write result to destination register *) - proc (Array.to_list (Array.mapi reg_file ~f:(fun idx reg -> - when_ (current_dst ==:. idx) [ - reg <-- arith_out.reg_write_data; - ]))); - - (* Move to next step or finish *) - if_ (step.value ==:. Config.num_steps - 1) [ - sm.set_next Output; - ] [ - step <-- step.value +:. 1; - op_started <-- gnd; - ]; - ]; - ]; -]; - - State.Output, [ - out_x3 <-- reg_file.(Config.x3).value; - out_y3 <-- reg_file.(Config.y3).value; - out_z3 <-- reg_file.(Config.z3).value; - sm.set_next Done; - ]; - - State.Done, [ - done_flag <-- vdd; - sm.set_next Idle; - ]; - ]; - ]; - - { O. - busy = ~:(sm.is Idle) -- "busy" - ; done_ = done_flag.value -- "done" - ; x3 = out_x3.value -- "x3" - ; y3 = out_y3.value -- "y3" - ; z3 = out_z3.value -- "z3" - } \ No newline at end of file diff --git a/src/point_mul.ml b/src/point_mul.ml deleted file mode 100644 index ca91799..0000000 --- a/src/point_mul.ml +++ /dev/null @@ -1,315 +0,0 @@ -open Base -open Hardcaml -open Signal - -module Config = struct - let width = 256 - let reg_addr_width = 5 - let num_steps = 40 - - let t0 = 0 - let t1 = 1 - let t2 = 2 - let t3 = 3 - let t4 = 4 - let t5 = 5 - let x3 = 6 - let y3 = 7 - let z3 = 8 - let x1 = 9 - let y1 = 10 - let z1 = 11 - let x2 = 12 - let y2 = 13 - let z2 = 14 - let param_a = 15 - let param_b3 = 16 - - let num_regs = 17 - - let generator_x = Z.of_string "0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" - let generator_y = Z.of_string "0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8" - let generator_z = Z.one - - let infinity_x = Z.zero - let infinity_y = Z.one - let infinity_z = Z.zero -end - -module Op = struct - let add = 0 - let sub = 1 - let mul = 2 -end - -module I = struct - type 'a t = - { clock : 'a - ; clear : 'a - ; start : 'a - ; scalar : 'a [@bits Config.width] - ; param_a : 'a [@bits Config.width] - ; param_b3 : 'a [@bits Config.width] - } - [@@deriving sexp_of, hardcaml] -end - -module O = struct - type 'a t = - { busy : 'a - ; done_ : 'a - ; x : 'a [@bits Config.width] - ; y : 'a [@bits Config.width] - ; z : 'a [@bits Config.width] - } - [@@deriving sexp_of, hardcaml] -end - -module State = struct - type t = - | Idle - | Loop - | Load - | Run_add - | Done - [@@deriving sexp_of, compare, enumerate] -end - -type instr = { op : int; src1 : int; src2 : int; dst : int } - -let program = [| - { op = Op.mul; src1 = Config.x1; src2 = Config.x2; dst = Config.t0 }; - { op = Op.mul; src1 = Config.y1; src2 = Config.y2; dst = Config.t1 }; - { op = Op.mul; src1 = Config.z1; src2 = Config.z2; dst = Config.t2 }; - { op = Op.add; src1 = Config.x1; src2 = Config.y1; dst = Config.t3 }; - { op = Op.add; src1 = Config.x2; src2 = Config.y2; dst = Config.t4 }; - { op = Op.mul; src1 = Config.t3; src2 = Config.t4; dst = Config.t3 }; - { op = Op.add; src1 = Config.t0; src2 = Config.t1; dst = Config.t4 }; - { op = Op.sub; src1 = Config.t3; src2 = Config.t4; dst = Config.t3 }; - { op = Op.add; src1 = Config.x1; src2 = Config.z1; dst = Config.t4 }; - { op = Op.add; src1 = Config.x2; src2 = Config.z2; dst = Config.t5 }; - { op = Op.mul; src1 = Config.t4; src2 = Config.t5; dst = Config.t4 }; - { op = Op.add; src1 = Config.t0; src2 = Config.t2; dst = Config.t5 }; - { op = Op.sub; src1 = Config.t4; src2 = Config.t5; dst = Config.t4 }; - { op = Op.add; src1 = Config.y1; src2 = Config.z1; dst = Config.t5 }; - { op = Op.add; src1 = Config.y2; src2 = Config.z2; dst = Config.x3 }; - { op = Op.mul; src1 = Config.t5; src2 = Config.x3; dst = Config.t5 }; - { op = Op.add; src1 = Config.t1; src2 = Config.t2; dst = Config.x3 }; - { op = Op.sub; src1 = Config.t5; src2 = Config.x3; dst = Config.t5 }; - { op = Op.mul; src1 = Config.param_a; src2 = Config.t4; dst = Config.z3 }; - { op = Op.mul; src1 = Config.param_b3; src2 = Config.t2; dst = Config.x3 }; - { op = Op.add; src1 = Config.x3; src2 = Config.z3; dst = Config.z3 }; - { op = Op.sub; src1 = Config.t1; src2 = Config.z3; dst = Config.x3 }; - { op = Op.add; src1 = Config.t1; src2 = Config.z3; dst = Config.z3 }; - { op = Op.mul; src1 = Config.x3; src2 = Config.z3; dst = Config.y3 }; - { op = Op.add; src1 = Config.t0; src2 = Config.t0; dst = Config.t1 }; - { op = Op.add; src1 = Config.t1; src2 = Config.t0; dst = Config.t1 }; - { op = Op.mul; src1 = Config.param_a; src2 = Config.t2; dst = Config.t2 }; - { op = Op.mul; src1 = Config.param_b3; src2 = Config.t4; dst = Config.t4 }; - { op = Op.add; src1 = Config.t1; src2 = Config.t2; dst = Config.t1 }; - { op = Op.sub; src1 = Config.t0; src2 = Config.t2; dst = Config.t2 }; - { op = Op.mul; src1 = Config.param_a; src2 = Config.t2; dst = Config.t2 }; - { op = Op.add; src1 = Config.t4; src2 = Config.t2; dst = Config.t4 }; - { op = Op.mul; src1 = Config.t1; src2 = Config.t4; dst = Config.t0 }; - { op = Op.add; src1 = Config.y3; src2 = Config.t0; dst = Config.y1 }; - { op = Op.mul; src1 = Config.t5; src2 = Config.t4; dst = Config.t0 }; - { op = Op.mul; src1 = Config.t3; src2 = Config.x3; dst = Config.x3 }; - { op = Op.sub; src1 = Config.x3; src2 = Config.t0; dst = Config.x1 }; - { op = Op.mul; src1 = Config.t3; src2 = Config.t1; dst = Config.t0 }; - { op = Op.mul; src1 = Config.t5; src2 = Config.z3; dst = Config.z3 }; - { op = Op.add; src1 = Config.z3; src2 = Config.t0; dst = Config.z1 }; -|] - -(* Helper to extract a bit from a signal using a signal index *) -let bit_select_dynamic signal index_signal = - let width = Signal.width signal in - let bits = List.init width ~f:(fun i -> bit signal i) in - mux index_signal bits - -let create scope (i : _ I.t) = - let open Always in - let width = Config.width in - let addr_width = Config.reg_addr_width in - - let spec = Reg_spec.create ~clock:i.clock ~clear:i.clear () in - let sm = State_machine.create (module State) spec ~enable:vdd in - - let reg_file = Array.init Config.num_regs ~f:(fun _ -> Variable.reg spec ~width) in - - let const_gx = of_z ~width Config.generator_x in - let const_gy = of_z ~width Config.generator_y in - let const_gz = of_z ~width Config.generator_z in - - let scalar_reg = Variable.reg spec ~width in - let bit_pos = Variable.reg spec ~width:8 in - let doubling = Variable.reg spec ~width:1 in - let last_step = Variable.reg spec ~width:1 in - let step = Variable.reg spec ~width:6 in - let load_idx = Variable.reg spec ~width:1 in (* 0 = P, 1 = G *) - - (* Latched second operand registers *) - let x2_latched = Variable.reg spec ~width in - let y2_latched = Variable.reg spec ~width in - let z2_latched = Variable.reg spec ~width in - - let out_x = Variable.reg spec ~width in - let out_y = Variable.reg spec ~width in - let out_z = Variable.reg spec ~width in - let done_flag = Variable.reg spec ~width:1 in - - let current_bit = bit_select_dynamic scalar_reg.value bit_pos.value in - - (* Check if P is point at infinity (z1 = 0) *) - let p_is_infinity = reg_file.(Config.z1).value ==:. 0 in - - let default_instr = of_int ~width:addr_width 0 in - let decode field = - mux step.value - (Array.to_list (Array.map program ~f:(fun instr -> - of_int ~width:addr_width (field instr)))) - |> fun s -> mux2 (step.value >=:. Config.num_steps) default_instr s - in - let decode_op = - mux step.value - (Array.to_list (Array.map program ~f:(fun instr -> - of_int ~width:2 instr.op))) - |> fun s -> mux2 (step.value >=:. Config.num_steps) (zero 2) s - in - - let current_dst = decode (fun instr -> instr.dst) in - let current_src1 = decode (fun instr -> instr.src1) in - let current_src2 = decode (fun instr -> instr.src2) in - let current_op = decode_op in - - let arith_start = Variable.wire ~default:gnd in - - (* Use latched values for x2/y2/z2 *) - let reg_read idx = - let base_values = Array.to_list (Array.map reg_file ~f:(fun r -> r.value)) in - let base = mux idx base_values in - mux2 (idx ==:. Config.x2) x2_latched.value - (mux2 (idx ==:. Config.y2) y2_latched.value - (mux2 (idx ==:. Config.z2) z2_latched.value base)) - in - - let arith_read_data_a = reg_read current_src1 in - let arith_read_data_b = reg_read current_src2 in - - let arith_out = Arith.create (Scope.sub_scope scope "arith") - { Arith.I. - clock = i.clock - ; clear = i.clear - ; start = arith_start.value - ; op = current_op - ; prime_sel = gnd - ; addr_a = current_src1 - ; addr_b = current_src2 - ; addr_out = current_dst - ; reg_read_data_a = arith_read_data_a - ; reg_read_data_b = arith_read_data_b - } - in - - compile [ - done_flag <-- gnd; - - sm.switch [ - State.Idle, [ - when_ i.start [ - scalar_reg <-- i.scalar; - bit_pos <--. 255; - doubling <-- vdd; - last_step <-- gnd; - step <-- zero 6; - reg_file.(Config.x1) <-- of_z ~width Config.infinity_x; - reg_file.(Config.y1) <-- of_z ~width Config.infinity_y; - reg_file.(Config.z1) <-- of_z ~width Config.infinity_z; - reg_file.(Config.param_a) <-- i.param_a; - reg_file.(Config.param_b3) <-- i.param_b3; - sm.set_next Loop; - ]; - ]; - - State.Loop, [ - if_ last_step.value [ - sm.set_next Done; - ] @@ elif doubling.value [ - (* Doubling phase *) - doubling <-- gnd; (* Next time go to adding step *) - if_ p_is_infinity [ - (* Skip doubling when P is infinity *) - sm.set_next Loop; - ] [ - load_idx <-- gnd; (* Load P *) - sm.set_next Load; - ]; - ] [ - (* Adding phase *) - doubling <-- vdd; (* Next time go to doubling step *) - load_idx <-- current_bit; (* Load G if bit is set, else 0 *) - last_step <-- (bit_pos.value ==:. 0); - bit_pos <-- bit_pos.value -:. 1; - - if_ (~: current_bit) [ - (* Skip add when bit is 0 *) - sm.set_next Loop; - ] [ - sm.set_next Load; - ]; - ]; - ]; - - State.Load, [ - (* Latch second operand based on load_idx *) - if_ load_idx.value [ - (* load_idx = 1: Load G *) - x2_latched <-- const_gx; - y2_latched <-- const_gy; - z2_latched <-- const_gz; - ] [ - (* load_idx = 0: Load P *) - x2_latched <-- reg_file.(Config.x1).value; - y2_latched <-- reg_file.(Config.y1).value; - z2_latched <-- reg_file.(Config.z1).value; - ]; - step <-- zero 6; - sm.set_next Run_add; - ]; - - State.Run_add, [ - (* Only start arith on the first cycle of each step *) - when_ (~:(arith_out.busy) &: (~:(arith_out.done_))) [ - arith_start <-- vdd; - ]; - when_ arith_out.done_ [ - (* Write result to destination register *) - proc (Array.to_list (Array.mapi reg_file ~f:(fun idx reg -> - when_ (current_dst ==:. idx) [ - reg <-- arith_out.reg_write_data; - ]))); - - if_ (step.value ==:. Config.num_steps - 1) [ - sm.set_next Loop; - ] [ - step <-- step.value +:. 1; - ]; - ]; -]; - - State.Done, [ - out_x <-- reg_file.(Config.x1).value; - out_y <-- reg_file.(Config.y1).value; - out_z <-- reg_file.(Config.z1).value; - done_flag <-- vdd; - sm.set_next Idle; - ]; - ]; - ]; - - { O. - busy = ~:(sm.is Idle) - ; done_ = done_flag.value - ; x = out_x.value - ; y = out_y.value - ; z = out_z.value - } \ No newline at end of file diff --git a/test/dune b/test/dune index aae88e4..a4951d1 100644 --- a/test/dune +++ b/test/dune @@ -1,5 +1,5 @@ (tests - (names test_arith test_mod_add test_mod_mul test_mod_inv - test_point_add test_point_mul test_ecdsa test_security_block) + (names test_arith test_mod_add test_mod_mul test_mod_inv + test_ecdsa test_security_block) (libraries base hardcaml stdio zarith off_switch) (preprocess (pps ppx_jane ppx_hardcaml))) diff --git a/test/test_arith.ml b/test/test_arith.ml index fcbe9cf..d139a47 100644 --- a/test/test_arith.ml +++ b/test/test_arith.ml @@ -3,61 +3,55 @@ open Hardcaml let () = Stdio.printf "=== Arith Unit Test ===\n\n"; - + let scope = Scope.create ~flatten_design:true () in let module Sim = Cyclesim.With_interface(Arith.I)(Arith.O) in let sim = Sim.create (Arith.create scope) in - + let inputs = Cyclesim.inputs sim in let outputs = Cyclesim.outputs sim in - + (* Simulated register file *) let registers = Array.init 32 ~f:(fun _ -> Z.zero) in - + let z_to_bits z = let hex_str = Z.format "%x" z in let padded = String.pad_left hex_str ~len:64 ~char:'0' in Bits.of_hex ~width:256 padded in - + let bits_to_z bits = Z.of_string_base 2 (Bits.to_bstr bits) in - + let reset () = inputs.clear := Bits.vdd; inputs.start := Bits.gnd; inputs.op := Bits.zero 2; inputs.prime_sel := Bits.gnd; - inputs.addr_a := Bits.zero 5; - inputs.addr_b := Bits.zero 5; - inputs.addr_out := Bits.zero 5; inputs.reg_read_data_a := Bits.zero 256; inputs.reg_read_data_b := Bits.zero 256; Cyclesim.cycle sim; inputs.clear := Bits.gnd; Cyclesim.cycle sim in - + let run_op ~op ~prime_sel ~addr_a ~addr_b ~addr_out = (* Reset to clear any leftover state *) inputs.clear := Bits.vdd; Cyclesim.cycle sim; inputs.clear := Bits.gnd; Cyclesim.cycle sim; - + (* Provide initial register data for the requested addresses *) inputs.reg_read_data_a := z_to_bits registers.(addr_a); inputs.reg_read_data_b := z_to_bits registers.(addr_b); - + (* Start operation *) inputs.start := Bits.vdd; inputs.op := Bits.of_int ~width:2 op; inputs.prime_sel := if prime_sel then Bits.vdd else Bits.gnd; - inputs.addr_a := Bits.of_int ~width:5 addr_a; - inputs.addr_b := Bits.of_int ~width:5 addr_b; - inputs.addr_out := Bits.of_int ~width:5 addr_out; Cyclesim.cycle sim; inputs.start := Bits.gnd; - + (* Run until done *) let max_cycles = 10_000 in let rec wait n = @@ -65,25 +59,10 @@ let () = Stdio.printf " TIMEOUT after %d cycles\n" max_cycles; false end else begin - (* Update register file read data based on requested addresses *) - let read_addr_a = Bits.to_int !(outputs.reg_read_addr_a) in - let read_addr_b = Bits.to_int !(outputs.reg_read_addr_b) in - inputs.reg_read_data_a := z_to_bits registers.(read_addr_a); - inputs.reg_read_data_b := z_to_bits registers.(read_addr_b); - - (* Check for write and done BEFORE cycling - they're combinational now *) let is_done = Bits.to_bool !(outputs.done_) in - let is_write = Bits.to_bool !(outputs.reg_write_enable) in - - if is_write then begin - let write_addr = Bits.to_int !(outputs.reg_write_addr) in - let write_data = bits_to_z !(outputs.reg_write_data) in - registers.(write_addr) <- write_data - end; - if is_done then begin - (* Cycle once more to return to Idle *) - Cyclesim.cycle sim; + registers.(addr_out) <- bits_to_z !(outputs.reg_write_data); + Cyclesim.cycle sim; (* let machine return to Idle *) true end else begin Cyclesim.cycle sim; @@ -93,20 +72,20 @@ let () = in wait 0 in - + let prime_p = Arith.Config.prime_p in let prime_n = Arith.Config.prime_n in - + (* Test result tracking *) let results = ref [] in let record result = results := result :: !results in - + (* ============================================ *) (* ADDITION TESTS (adapted from mod_add tests) *) (* ============================================ *) - + Stdio.printf "=== Addition Tests ===\n\n"; - + (* Test: Simple addition *) registers.(0) <- Z.of_int 100; registers.(1) <- Z.of_int 200; @@ -122,7 +101,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Addition with modular wrap *) registers.(0) <- Z.(prime_p - of_int 63); registers.(1) <- Z.of_int 100; @@ -138,7 +117,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Add zero *) registers.(0) <- Z.of_int 12345; registers.(1) <- Z.zero; @@ -154,7 +133,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Addition with prime_n (curve order) *) registers.(0) <- Z.of_int 100; registers.(1) <- Z.of_int 200; @@ -170,13 +149,13 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* =============================================== *) (* SUBTRACTION TESTS (adapted from mod_add tests) *) (* =============================================== *) - + Stdio.printf "=== Subtraction Tests ===\n\n"; - + (* Test: Simple subtraction *) registers.(0) <- Z.of_int 500; registers.(1) <- Z.of_int 300; @@ -192,7 +171,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Subtraction requiring modular correction *) registers.(0) <- Z.of_int 100; registers.(1) <- Z.of_int 200; @@ -208,7 +187,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Subtraction from modulus-1 *) registers.(0) <- Z.(prime_p - one); registers.(1) <- Z.of_int 10; @@ -224,7 +203,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Subtract zero *) registers.(0) <- Z.of_int 12345; registers.(1) <- Z.zero; @@ -240,7 +219,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Subtraction with prime_n *) registers.(0) <- Z.of_int 10; registers.(1) <- Z.of_int 20; @@ -256,13 +235,13 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* ================================================= *) (* MULTIPLICATION TESTS (adapted from mod_mul tests) *) (* ================================================= *) - + Stdio.printf "=== Multiplication Tests ===\n\n"; - + (* Test: Simple multiplication *) registers.(0) <- Z.of_int 3; registers.(1) <- Z.of_int 5; @@ -278,7 +257,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Multiply by zero *) registers.(0) <- Z.of_int 12345; registers.(1) <- Z.zero; @@ -294,7 +273,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Multiply by one *) registers.(0) <- Z.of_int 12345; registers.(1) <- Z.one; @@ -310,7 +289,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Medium multiplication *) registers.(0) <- Z.of_int 123456; registers.(1) <- Z.of_int 789012; @@ -326,7 +305,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: 64-bit multiplication *) registers.(0) <- Z.of_string "12345678901234"; registers.(1) <- Z.of_string "98765432109876"; @@ -342,7 +321,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: 128-bit multiplication *) registers.(0) <- Z.of_string "123456789012345678901234567890"; registers.(1) <- Z.of_string "987654321098765432109876543210"; @@ -358,7 +337,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: 256-bit multiplication *) registers.(0) <- Z.of_string "123456789012345678901234567890123456789"; registers.(1) <- Z.of_string "987654321098765432109876543210987654321"; @@ -374,7 +353,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Multiplication with prime_n *) registers.(0) <- Z.of_int 12345; registers.(1) <- Z.of_int 67890; @@ -390,13 +369,13 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* ============================================== *) (* INVERSION TESTS (adapted from mod_inv tests) *) (* ============================================== *) - + Stdio.printf "=== Inversion Tests ===\n\n"; - + (* Test: Simple inversion 3^(-1) mod p *) registers.(0) <- Z.of_int 3; reset (); @@ -414,7 +393,7 @@ let () = record pass end else record false end else record false; - + (* Test: Inversion of 1 *) registers.(0) <- Z.one; reset (); @@ -428,7 +407,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Inversion of (p-1) should be (p-1) since (p-1)^2 = 1 mod p *) registers.(0) <- Z.(prime_p - one); reset (); @@ -446,7 +425,7 @@ let () = record pass end else record false end else record false; - + (* Test: 64-bit inversion *) registers.(0) <- Z.of_string "123456789012345"; reset (); @@ -464,7 +443,7 @@ let () = record pass end else record false end else record false; - + (* Test: 128-bit inversion *) registers.(0) <- Z.of_string "123456789012345678901234567890"; reset (); @@ -482,7 +461,7 @@ let () = record pass end else record false end else record false; - + (* Test: 256-bit inversion *) registers.(0) <- Z.of_string "12345678901234567890123456789012345678901234567890"; reset (); @@ -500,7 +479,7 @@ let () = record pass end else record false end else record false; - + (* Test: Inversion with prime_n (curve order) *) registers.(0) <- Z.of_string "999999999999999999"; reset (); @@ -518,24 +497,24 @@ let () = record pass end else record false end else record false; - + (* ============================================== *) (* CHAINED OPERATIONS TEST *) (* ============================================== *) - + Stdio.printf "=== Chained Operations Test ===\n\n"; - + (* Test a sequence: (a + b) * c mod p *) registers.(0) <- Z.of_int 100; registers.(1) <- Z.of_int 200; registers.(2) <- Z.of_int 50; reset (); - + Stdio.printf "Test: Chained operations ((a + b) * c mod p)\n"; Stdio.printf " a = %s\n" (Z.to_string registers.(0)); Stdio.printf " b = %s\n" (Z.to_string registers.(1)); Stdio.printf " c = %s\n" (Z.to_string registers.(2)); - + (* First: r[20] = r[0] + r[1] *) let chain_pass = ref true in if run_op ~op:0 ~prime_sel:false ~addr_a:0 ~addr_b:1 ~addr_out:20 then begin @@ -554,9 +533,9 @@ let () = (* ============================================== *) (* EDGE CASE TESTS *) (* ============================================== *) - + Stdio.printf "=== Edge Case Tests ===\n\n"; - + (* Test: Inverse of zero should not exist *) registers.(0) <- Z.zero; reset (); @@ -569,7 +548,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: x - x = 0 *) registers.(0) <- Z.of_string "98765432109876543210"; reset (); @@ -583,7 +562,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: (p-1) + 1 mod p = 0 *) registers.(0) <- Z.(prime_p - one); registers.(1) <- Z.one; @@ -599,7 +578,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: 0 - 1 mod p = p-1 *) registers.(0) <- Z.zero; registers.(1) <- Z.one; @@ -615,7 +594,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: (p-1) * (p-1) mod p = 1 since (-1)*(-1) = 1 *) registers.(0) <- Z.(prime_p - one); reset (); @@ -629,7 +608,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Squaring via same register (addr_a = addr_b) *) registers.(0) <- Z.of_int 12345; reset (); @@ -643,7 +622,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: In-place operation (addr_a = addr_out) *) registers.(5) <- Z.of_int 100; registers.(6) <- Z.of_int 50; @@ -660,7 +639,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Large values mod n *) registers.(0) <- Z.of_string "115792089237316195423570985008687907852837564279074904382605163141518161494000"; registers.(1) <- Z.of_string "500"; @@ -676,7 +655,7 @@ let () = Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass end else record false; - + (* Test: Large multiplication mod n *) registers.(0) <- Z.of_string "57896044618658097711785492504343953926418782139537452191302581570759080747168"; registers.(1) <- Z.of_string "2"; @@ -696,52 +675,52 @@ let () = (* ============================================== *) (* BACK-TO-BACK OPERATIONS TEST *) (* ============================================== *) - + Stdio.printf "=== Back-to-Back Operations Test ===\n\n"; - + registers.(0) <- Z.of_int 7; registers.(1) <- Z.of_int 11; reset (); - + Stdio.printf "Test: Rapid back-to-back operations\n"; let btb_pass = ref true in - + (* Op 1: add *) if run_op ~op:0 ~prime_sel:false ~addr_a:0 ~addr_b:1 ~addr_out:25 then begin let exp1 = Z.((of_int 7 + of_int 11) mod prime_p) in if not (Z.equal registers.(25) exp1) then btb_pass := false; - Stdio.printf " Op 1 (add): %s (expected %s) %s\n" + Stdio.printf " Op 1 (add): %s (expected %s) %s\n" (Z.to_string registers.(25)) (Z.to_string exp1) (if Z.equal registers.(25) exp1 then "✓" else "✗") end else btb_pass := false; - + (* Op 2: sub *) if run_op ~op:1 ~prime_sel:false ~addr_a:0 ~addr_b:1 ~addr_out:26 then begin let exp2 = Z.(erem (of_int 7 - of_int 11) prime_p) in if not (Z.equal registers.(26) exp2) then btb_pass := false; - Stdio.printf " Op 2 (sub): %s (expected %s) %s\n" + Stdio.printf " Op 2 (sub): %s (expected %s) %s\n" (Z.to_string registers.(26)) (Z.to_string exp2) (if Z.equal registers.(26) exp2 then "✓" else "✗") end else btb_pass := false; - + (* Op 3: mul *) if run_op ~op:2 ~prime_sel:false ~addr_a:0 ~addr_b:1 ~addr_out:27 then begin let exp3 = Z.((of_int 7 * of_int 11) mod prime_p) in if not (Z.equal registers.(27) exp3) then btb_pass := false; - Stdio.printf " Op 3 (mul): %s (expected %s) %s\n" + Stdio.printf " Op 3 (mul): %s (expected %s) %s\n" (Z.to_string registers.(27)) (Z.to_string exp3) (if Z.equal registers.(27) exp3 then "✓" else "✗") end else btb_pass := false; - + (* Op 4: inv *) if run_op ~op:3 ~prime_sel:false ~addr_a:0 ~addr_b:0 ~addr_out:28 then begin let product = Z.((of_int 7 * registers.(28)) mod prime_p) in if not (Z.equal product Z.one) then btb_pass := false; - Stdio.printf " Op 4 (inv): 7 * %s mod p = %s %s\n" + Stdio.printf " Op 4 (inv): 7 * %s mod p = %s %s\n" (Z.to_string registers.(28)) (Z.to_string product) (if Z.equal product Z.one then "✓" else "✗") end else btb_pass := false; - + Stdio.printf " %s\n\n" (if !btb_pass then "PASS ✓" else "FAIL ✗"); record !btb_pass; @@ -767,7 +746,7 @@ Stdio.printf " r[2] = r[3] = G_y (secp256k1 generator y)\n\n"; (* First multiplication: r[10] = r[0] * r[1] = G_x * G_x *) Stdio.printf "Step 1: r[10] = r[0] * r[1] (G_x * G_x)\n"; -let mul1_pass = +let mul1_pass = if run_op ~op:2 ~prime_sel:false ~addr_a:0 ~addr_b:1 ~addr_out:10 then begin let expected1 = Z.((registers.(0) * registers.(1)) mod prime_p) in Stdio.printf " result = %s...\n" (String.prefix (Z.to_string registers.(10)) 40); @@ -798,7 +777,7 @@ Stdio.printf " r[10] != r[11]: %b\n" different; Stdio.printf " %s\n\n" (if different then "PASS ✓" else "FAIL ✗ (results should differ!)"); let btb_mul_pass = mul1_pass && mul2_pass && different in -Stdio.printf "Back-to-back multiplication test: %s\n\n" +Stdio.printf "Back-to-back multiplication test: %s\n\n" (if btb_mul_pass then "PASS ✓" else "FAIL ✗"); record btb_mul_pass; @@ -824,7 +803,7 @@ Stdio.printf " Mul 1: r[20] = r[0] * r[1] = 12345 * 67890\n"; if run_op ~op:2 ~prime_sel:false ~addr_a:0 ~addr_b:1 ~addr_out:20 then begin let exp1 = Z.((of_int 12345 * of_int 67890) mod prime_p) in if not (Z.equal registers.(20) exp1) then triple_pass := false; - Stdio.printf " result=%s expected=%s %s\n" + Stdio.printf " result=%s expected=%s %s\n" (Z.to_string registers.(20)) (Z.to_string exp1) (if Z.equal registers.(20) exp1 then "✓" else "✗") end else triple_pass := false; @@ -834,7 +813,7 @@ Stdio.printf " Mul 2: r[21] = r[2] * r[3] = 11111 * 22222\n"; if run_op ~op:2 ~prime_sel:false ~addr_a:2 ~addr_b:3 ~addr_out:21 then begin let exp2 = Z.((of_int 11111 * of_int 22222) mod prime_p) in if not (Z.equal registers.(21) exp2) then triple_pass := false; - Stdio.printf " result=%s expected=%s %s\n" + Stdio.printf " result=%s expected=%s %s\n" (Z.to_string registers.(21)) (Z.to_string exp2) (if Z.equal registers.(21) exp2 then "✓" else "✗") end else triple_pass := false; @@ -844,12 +823,12 @@ Stdio.printf " Mul 3: r[22] = r[4] * r[5] = 33333 * 44444\n"; if run_op ~op:2 ~prime_sel:false ~addr_a:4 ~addr_b:5 ~addr_out:22 then begin let exp3 = Z.((of_int 33333 * of_int 44444) mod prime_p) in if not (Z.equal registers.(22) exp3) then triple_pass := false; - Stdio.printf " result=%s expected=%s %s\n" + Stdio.printf " result=%s expected=%s %s\n" (Z.to_string registers.(22)) (Z.to_string exp3) (if Z.equal registers.(22) exp3 then "✓" else "✗") end else triple_pass := false; -Stdio.printf " Triple multiplication test: %s\n\n" +Stdio.printf " Triple multiplication test: %s\n\n" (if !triple_pass then "PASS ✓" else "FAIL ✗"); record !triple_pass; @@ -913,10 +892,10 @@ if run_op ~op:2 ~prime_sel:false ~addr_a:13 ~addr_b:1 ~addr_out:14 then begin (if Z.equal registers.(14) exp then "✓" else "✗") end else mixed_pass := false; -Stdio.printf " Mixed operation test: %s\n\n" +Stdio.printf " Mixed operation test: %s\n\n" (if !mixed_pass then "PASS ✓" else "FAIL ✗"); record !mixed_pass; - + (* ============================================== *) (* TRUE BACK-TO-BACK TEST (no reset between ops) *) (* ============================================== *) @@ -940,14 +919,11 @@ let run_op_no_reset ~op ~data_a ~data_b = inputs.start := Bits.vdd; inputs.op := Bits.of_int ~width:2 op; inputs.prime_sel := Bits.gnd; - inputs.addr_a := Bits.zero 5; - inputs.addr_b := Bits.zero 5; - inputs.addr_out := Bits.zero 5; inputs.reg_read_data_a := z_to_bits data_a; inputs.reg_read_data_b := z_to_bits data_b; Cyclesim.cycle sim; inputs.start := Bits.gnd; - + let max_cycles = 3000 in let rec wait n = if n >= max_cycles then begin @@ -955,6 +931,7 @@ let run_op_no_reset ~op ~data_a ~data_b = None end else if Bits.to_bool !(outputs.done_) then begin let result = bits_to_z !(outputs.reg_write_data) in + Cyclesim.cycle sim; (* let Done -> Idle transition complete *) Some result end else begin Cyclesim.cycle sim; @@ -967,24 +944,24 @@ in Stdio.printf "Op 1: 12345 * 67890\n"; let result1 = run_op_no_reset ~op:2 ~data_a:val_a ~data_b:val_b in (match result1 with -| Some r -> +| Some r -> let expected = Z.((val_a * val_b) mod prime_p) in let pass = Z.equal r expected in - Stdio.printf " result=%s expected=%s %s\n" + Stdio.printf " result=%s expected=%s %s\n" (Z.to_string r) (Z.to_string expected) (if pass then "✓" else "✗") | None -> ()); Stdio.printf "Op 2 (immediate): 11111 * 22222\n"; let result2 = run_op_no_reset ~op:2 ~data_a:val_c ~data_b:val_d in (match result2 with -| Some r -> +| Some r -> let expected = Z.((val_c * val_d) mod prime_p) in let pass = Z.equal r expected in - Stdio.printf " result=%s expected=%s %s\n" + Stdio.printf " result=%s expected=%s %s\n" (Z.to_string r) (Z.to_string expected) (if pass then "✓" else "✗") | None -> ()); -let true_btb_pass = +let true_btb_pass = match result1, result2 with | Some r1, Some r2 -> let exp1 = Z.((val_a * val_b) mod prime_p) in @@ -992,7 +969,7 @@ let true_btb_pass = Z.equal r1 exp1 && Z.equal r2 exp2 | _ -> false in -Stdio.printf "True back-to-back test: %s\n\n" +Stdio.printf "True back-to-back test: %s\n\n" (if true_btb_pass then "PASS ✓" else "FAIL ✗"); record true_btb_pass; @@ -1015,7 +992,7 @@ Cyclesim.cycle sim; (* Simulate Point_mul's register file state after Init: x1 = 0 (infinity_x) - y1 = 1 (infinity_y) + y1 = 1 (infinity_y) z1 = 0 (infinity_z) When j=0 (doubling), x2=x1, y2=y1, z2=z1 *) @@ -1031,38 +1008,28 @@ let run_pm_op ~src1 ~src2 ~dst = (* Provide register data *) inputs.reg_read_data_a := z_to_bits pm_registers.(src1); inputs.reg_read_data_b := z_to_bits pm_registers.(src2); - + (* Start multiplication *) inputs.start := Bits.vdd; inputs.op := Bits.of_int ~width:2 2; (* mul *) inputs.prime_sel := Bits.gnd; - inputs.addr_a := Bits.of_int ~width:5 src1; - inputs.addr_b := Bits.of_int ~width:5 src2; - inputs.addr_out := Bits.of_int ~width:5 dst; Cyclesim.cycle sim; inputs.start := Bits.gnd; - - (* Wait for completion, updating reg reads as requested *) + + (* Wait for completion *) let max_cycles = 300 in let rec wait n = if n >= max_cycles then begin Stdio.printf " TIMEOUT after %d cycles!\n" max_cycles; - Stdio.printf " busy=%d done=%d\n" + Stdio.printf " busy=%d done=%d\n" (Bits.to_int !(outputs.busy)) (Bits.to_int !(outputs.done_)); None end else begin - let read_addr_a = Bits.to_int !(outputs.reg_read_addr_a) in - let read_addr_b = Bits.to_int !(outputs.reg_read_addr_b) in - inputs.reg_read_data_a := z_to_bits pm_registers.(read_addr_a); - inputs.reg_read_data_b := z_to_bits pm_registers.(read_addr_b); - if Bits.to_bool !(outputs.done_) then begin let result = bits_to_z !(outputs.reg_write_data) in - if Bits.to_bool !(outputs.reg_write_enable) then begin - let write_addr = Bits.to_int !(outputs.reg_write_addr) in - pm_registers.(write_addr) <- result - end; + pm_registers.(dst) <- result; + Cyclesim.cycle sim; (* let Done -> Idle complete *) Stdio.printf " Completed in %d cycles\n" n; Some result end else begin @@ -1079,9 +1046,9 @@ Stdio.printf "Op 0: t0 <- x1 * x2 (src1=9, src2=12, dst=0)\n"; Stdio.printf " x1 = %s, x2 = %s\n" (Z.to_string pm_registers.(9)) (Z.to_string pm_registers.(12)); let op0_result = run_pm_op ~src1:9 ~src2:12 ~dst:0 in (match op0_result with -| Some r -> +| Some r -> let expected = Z.zero in - Stdio.printf " result = %s, expected = %s %s\n\n" + Stdio.printf " result = %s, expected = %s %s\n\n" (Z.to_string r) (Z.to_string expected) (if Z.equal r expected then "✓" else "✗") | None -> Stdio.printf "\n"); @@ -1091,9 +1058,9 @@ Stdio.printf "Op 1: t1 <- y1 * y2 (src1=10, src2=13, dst=1)\n"; Stdio.printf " y1 = %s, y2 = %s\n" (Z.to_string pm_registers.(10)) (Z.to_string pm_registers.(13)); let op1_result = run_pm_op ~src1:10 ~src2:13 ~dst:1 in (match op1_result with -| Some r -> +| Some r -> let expected = Z.one in - Stdio.printf " result = %s, expected = %s %s\n\n" + Stdio.printf " result = %s, expected = %s %s\n\n" (Z.to_string r) (Z.to_string expected) (if Z.equal r expected then "✓" else "✗") | None -> Stdio.printf "\n"); @@ -1103,20 +1070,20 @@ Stdio.printf "Op 2: t2 <- z1 * z2 (src1=11, src2=14, dst=2)\n"; Stdio.printf " z1 = %s, z2 = %s\n" (Z.to_string pm_registers.(11)) (Z.to_string pm_registers.(14)); let op2_result = run_pm_op ~src1:11 ~src2:14 ~dst:2 in (match op2_result with -| Some r -> +| Some r -> let expected = Z.zero in - Stdio.printf " result = %s, expected = %s %s\n\n" + Stdio.printf " result = %s, expected = %s %s\n\n" (Z.to_string r) (Z.to_string expected) (if Z.equal r expected then "✓" else "✗") | None -> Stdio.printf "\n"); -let pm_pattern_pass = +let pm_pattern_pass = match op0_result, op1_result, op2_result with | Some r0, Some r1, Some r2 -> Z.equal r0 Z.zero && Z.equal r1 Z.one && Z.equal r2 Z.zero | _ -> false in -Stdio.printf "Point_mul pattern test: %s\n\n" +Stdio.printf "Point_mul pattern test: %s\n\n" (if pm_pattern_pass then "PASS ✓" else "FAIL ✗"); record pm_pattern_pass; @@ -1136,23 +1103,20 @@ Cyclesim.cycle sim; let run_mul_verbose ~a ~b ~label = Stdio.printf "%s: %s * %s\n" label (Z.to_string a) (Z.to_string b); - + inputs.reg_read_data_a := z_to_bits a; inputs.reg_read_data_b := z_to_bits b; inputs.start := Bits.vdd; inputs.op := Bits.of_int ~width:2 2; (* mul *) inputs.prime_sel := Bits.gnd; - inputs.addr_a := Bits.zero 5; - inputs.addr_b := Bits.zero 5; - inputs.addr_out := Bits.zero 5; Cyclesim.cycle sim; inputs.start := Bits.gnd; - + for cycle = 1 to 20 do let busy = Bits.to_int !(outputs.busy) in let done_ = Bits.to_int !(outputs.done_) in Stdio.printf " cycle %2d: busy=%d done=%d\n" cycle busy done_; - + if done_ = 1 then begin let result = bits_to_z !(outputs.reg_write_data) in Stdio.printf " Result: %s\n\n" (Z.to_string result); @@ -1214,14 +1178,14 @@ Stdio.printf "=== End Montgomery Debug Test ===\n\n"; (* ============================================== *) (* TEST SUMMARY *) (* ============================================== *) - + let results_list = List.rev !results in let passed = List.count results_list ~f:Fn.id in let total = List.length results_list in - + Stdio.printf "=== Test Summary ===\n"; Stdio.printf "Passed: %d/%d\n" passed total; - + if passed = total then begin Stdio.printf "\n"; Stdio.printf "███████╗██╗ ██╗ ██████╗ ██████╗███████╗███████╗███████╗\n"; diff --git a/test/test_point_add.ml b/test/test_point_add.ml deleted file mode 100644 index 7ffd466..0000000 --- a/test/test_point_add.ml +++ /dev/null @@ -1,758 +0,0 @@ -open Base -open Hardcaml - -let () = - Stdio.printf "=== Point Addition Unit Test ===\n\n"; - - let scope = Scope.create ~flatten_design:true () in - let module Sim = Cyclesim.With_interface(Point_add.I)(Point_add.O) in - let sim = Sim.create (Point_add.create scope) in - - let inputs = Cyclesim.inputs sim in - let outputs = Cyclesim.outputs sim in - - let prime_p = Arith.Config.prime_p in - - (* secp256k1 curve parameters: y² = x³ + 7, so a = 0, b = 7, b3 = 21 *) - let param_a = Z.zero in - let param_b3 = Z.of_int 21 in - - (* Generator point G for secp256k1 *) - let g_x = Z.of_string_base 16 "79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" in - let g_y = Z.of_string_base 16 "483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8" in - - let z_to_bits z = - let hex_str = Z.format "%x" z in - let padded = String.pad_left hex_str ~len:64 ~char:'0' in - Bits.of_hex ~width:256 padded - in - - let bits_to_z bits = Z.of_string_base 2 (Bits.to_bstr bits) in - - (* Modular arithmetic helpers *) - let mod_add a b = Z.((a + b) mod prime_p) in - let mod_sub a b = Z.(erem (a - b) prime_p) in - let mod_mul a b = Z.((a * b) mod prime_p) in - let mod_inv a = Z.invert a prime_p in - - (* Convert projective (X:Y:Z) to affine (x, y) *) - let proj_to_affine x y z = - if Z.equal z Z.zero then - None (* Point at infinity *) - else - let z_inv = mod_inv z in - let aff_x = mod_mul x z_inv in - let aff_y = mod_mul y z_inv in - Some (aff_x, aff_y) - in - - (* Check if projective point equals affine point *) - let proj_equals_affine ~proj_x ~proj_y ~proj_z ~aff_x ~aff_y = - match proj_to_affine proj_x proj_y proj_z with - | None -> false - | Some (x, y) -> Z.equal x aff_x && Z.equal y aff_y - in - - (* Reference implementation: affine point addition *) - let affine_add (x1, y1) (x2, y2) = - if Z.equal x1 x2 then begin - if Z.equal y1 y2 then - (* Point doubling *) - let lambda = mod_mul (mod_mul (Z.of_int 3) (mod_mul x1 x1)) (mod_inv (mod_mul (Z.of_int 2) y1)) in - let x3 = mod_sub (mod_mul lambda lambda) (mod_add x1 x2) in - let y3 = mod_sub (mod_mul lambda (mod_sub x1 x3)) y1 in - Some (x3, y3) - else - (* P + (-P) = O *) - None - end else begin - let lambda = mod_mul (mod_sub y2 y1) (mod_inv (mod_sub x2 x1)) in - let x3 = mod_sub (mod_sub (mod_mul lambda lambda) x1) x2 in - let y3 = mod_sub (mod_mul lambda (mod_sub x1 x3)) y1 in - Some (x3, y3) - end - in - - (* Scalar multiplication using double-and-add *) - let scalar_mult k (x, y) = - let rec loop n acc pt = - if Z.equal n Z.zero then acc - else - let acc' = - if Z.(equal (n land one) one) then - match acc with - | None -> Some pt - | Some a -> affine_add a pt - else acc - in - let pt' = - match affine_add pt pt with - | None -> pt (* shouldn't happen for valid points *) - | Some p -> p - in - loop Z.(n asr 1) acc' pt' - in - loop k None (x, y) - in - - let reset () = - inputs.clear := Bits.vdd; - inputs.start := Bits.gnd; - inputs.x1 := Bits.zero 256; - inputs.y1 := Bits.zero 256; - inputs.z1 := Bits.zero 256; - inputs.x2 := Bits.zero 256; - inputs.y2 := Bits.zero 256; - inputs.z2 := Bits.zero 256; - inputs.param_a := Bits.zero 256; - inputs.param_b3 := Bits.zero 256; - Cyclesim.cycle sim; - inputs.clear := Bits.gnd; - Cyclesim.cycle sim - in - - let run_point_add ~x1 ~y1 ~z1 ~x2 ~y2 ~z2 = - reset (); - - inputs.x1 := z_to_bits x1; - inputs.y1 := z_to_bits y1; - inputs.z1 := z_to_bits z1; - inputs.x2 := z_to_bits x2; - inputs.y2 := z_to_bits y2; - inputs.z2 := z_to_bits z2; - inputs.param_a := z_to_bits param_a; - inputs.param_b3 := z_to_bits param_b3; - inputs.start := Bits.vdd; - Cyclesim.cycle sim; - inputs.start := Bits.gnd; - - let max_cycles = 50000 in - let rec wait n = - if n >= max_cycles then begin - Stdio.printf " TIMEOUT after %d cycles\n" max_cycles; - None - end else if Bits.to_bool !(outputs.done_) then begin - let result_x = bits_to_z !(outputs.x3) in - let result_y = bits_to_z !(outputs.y3) in - let result_z = bits_to_z !(outputs.z3) in - Stdio.printf " Completed in %d cycles\n" n; - Some (result_x, result_y, result_z) - end else begin - Cyclesim.cycle sim; - wait (n + 1) - end - in - wait 0 - in - - let results = ref [] in - let record result = results := result :: !results in - - (* ============================================== *) - (* TEST 1: G + G = 2G (Point Doubling) *) - (* ============================================== *) - - Stdio.printf "Test 1: G + G = 2G (Point Doubling)\n"; - Stdio.printf " G_x = %s\n" (Z.to_string g_x); - Stdio.printf " G_y = %s\n" (Z.to_string g_y); - - let expected_2g = affine_add (g_x, g_y) (g_x, g_y) in - - (match run_point_add ~x1:g_x ~y1:g_y ~z1:Z.one ~x2:g_x ~y2:g_y ~z2:Z.one with - | None -> record false - | Some (rx, ry, rz) -> - Stdio.printf " Result (projective): X=%s..., Y=%s..., Z=%s...\n" - (String.prefix (Z.to_string rx) 20) - (String.prefix (Z.to_string ry) 20) - (String.prefix (Z.to_string rz) 20); - match expected_2g with - | None -> - Stdio.printf " Expected: point at infinity\n"; - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected (affine): x=%s..., y=%s...\n" - (String.prefix (Z.to_string ex) 20) - (String.prefix (Z.to_string ey) 20); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 2: G + 2G = 3G *) - (* ============================================== *) - - Stdio.printf "Test 2: G + 2G = 3G\n"; - - let p2g = affine_add (g_x, g_y) (g_x, g_y) in - (match p2g with - | None -> - Stdio.printf " ERROR: 2G is point at infinity\n"; - record false - | Some (x_2g, y_2g) -> - Stdio.printf " 2G_x = %s...\n" (String.prefix (Z.to_string x_2g) 30); - Stdio.printf " 2G_y = %s...\n" (String.prefix (Z.to_string y_2g) 30); - - let expected_3g = affine_add (g_x, g_y) (x_2g, y_2g) in - - (match run_point_add ~x1:g_x ~y1:g_y ~z1:Z.one ~x2:x_2g ~y2:y_2g ~z2:Z.one with - | None -> record false - | Some (rx, ry, rz) -> - match expected_3g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 3G_x = %s...\n" (String.prefix (Z.to_string ex) 30); - Stdio.printf " Expected 3G_y = %s...\n" (String.prefix (Z.to_string ey) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass)); - - (* ============================================== *) - (* TEST 3: 2G + 2G = 4G *) - (* ============================================== *) - - Stdio.printf "Test 3: 2G + 2G = 4G\n"; - - (match p2g with - | None -> record false - | Some (x_2g, y_2g) -> - let expected_4g = affine_add (x_2g, y_2g) (x_2g, y_2g) in - - (match run_point_add ~x1:x_2g ~y1:y_2g ~z1:Z.one ~x2:x_2g ~y2:y_2g ~z2:Z.one with - | None -> record false - | Some (rx, ry, rz) -> - match expected_4g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 4G_x = %s...\n" (String.prefix (Z.to_string ex) 30); - Stdio.printf " Expected 4G_y = %s...\n" (String.prefix (Z.to_string ey) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass)); - - (* ============================================== *) - (* TEST 4: Compute 5G via additions *) - (* ============================================== *) - - Stdio.printf "Test 4: 2G + 3G = 5G\n"; - - let p3g = - match p2g with - | None -> None - | Some (x_2g, y_2g) -> affine_add (g_x, g_y) (x_2g, y_2g) - in - - (match p2g, p3g with - | Some (x_2g, y_2g), Some (x_3g, y_3g) -> - let expected_5g = affine_add (x_2g, y_2g) (x_3g, y_3g) in - - (match run_point_add ~x1:x_2g ~y1:y_2g ~z1:Z.one ~x2:x_3g ~y2:y_3g ~z2:Z.one with - | None -> record false - | Some (rx, ry, rz) -> - match expected_5g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 5G_x = %s...\n" (String.prefix (Z.to_string ex) 30); - Stdio.printf " Expected 5G_y = %s...\n" (String.prefix (Z.to_string ey) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass) - | _ -> - Stdio.printf " ERROR: Could not compute 2G or 3G\n"; - record false); - - (* ============================================== *) - (* TEST 5: P + O = P (identity element) *) - (* ============================================== *) - - Stdio.printf "Test 5: G + O = G (adding point at infinity)\n"; - Stdio.printf " Using O = (0:1:0) as point at infinity\n"; - - (* Point at infinity in projective: (0:1:0) *) - (match run_point_add ~x1:g_x ~y1:g_y ~z1:Z.one ~x2:Z.zero ~y2:Z.one ~z2:Z.zero with - | None -> record false - | Some (rx, ry, rz) -> - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:g_x ~aff_y:g_y in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 6: O + P = P (identity element) *) - (* ============================================== *) - - Stdio.printf "Test 6: O + G = G (point at infinity + G)\n"; - - (match run_point_add ~x1:Z.zero ~y1:Z.one ~z1:Z.zero ~x2:g_x ~y2:g_y ~z2:Z.one with - | None -> record false - | Some (rx, ry, rz) -> - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:g_x ~aff_y:g_y in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 7: P + (-P) = O (inverse) *) - (* ============================================== *) - - Stdio.printf "Test 7: G + (-G) = O (point plus its inverse)\n"; - - let neg_g_y = mod_sub Z.zero g_y in (* -G has same x, negated y *) - Stdio.printf " -G_y = %s...\n" (String.prefix (Z.to_string neg_g_y) 30); - - (match run_point_add ~x1:g_x ~y1:g_y ~z1:Z.one ~x2:g_x ~y2:neg_g_y ~z2:Z.one with - | None -> record false - | Some (_rx, _ry, rz) -> - Stdio.printf " Result Z = %s\n" (Z.to_string rz); - (* Result should be point at infinity: Z = 0 *) - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 8: Non-trivial projective coordinates *) - (* ============================================== *) - - Stdio.printf "Test 8: Addition with non-trivial Z coordinates\n"; - - (* Represent G as (2*G_x : 2*G_y : 2) which equals (G_x : G_y : 1) *) - let z_val = Z.of_int 2 in - let x1_proj = mod_mul g_x z_val in - let y1_proj = mod_mul g_y z_val in - - (* Represent 2G similarly *) - (match p2g with - | None -> record false - | Some (x_2g, y_2g) -> - let z2_val = Z.of_int 3 in - let x2_proj = mod_mul x_2g z2_val in - let y2_proj = mod_mul y_2g z2_val in - - let expected_3g = affine_add (g_x, g_y) (x_2g, y_2g) in - - (match run_point_add ~x1:x1_proj ~y1:y1_proj ~z1:z_val ~x2:x2_proj ~y2:y2_proj ~z2:z2_val with - | None -> record false - | Some (rx, ry, rz) -> - match expected_3g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 3G (affine): x=%s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass)); - - (* ============================================== *) - (* TEST 9: Verify 7G via scalar multiplication *) - (* ============================================== *) - - Stdio.printf "Test 9: Compute 7G = 3G + 4G and verify\n"; - - let p4g = - match p2g with - | None -> None - | Some (x_2g, y_2g) -> affine_add (x_2g, y_2g) (x_2g, y_2g) - in - - (match p3g, p4g with - | Some (x_3g, y_3g), Some (x_4g, y_4g) -> - let expected_7g = scalar_mult (Z.of_int 7) (g_x, g_y) in - - (match run_point_add ~x1:x_3g ~y1:y_3g ~z1:Z.one ~x2:x_4g ~y2:y_4g ~z2:Z.one with - | None -> record false - | Some (rx, ry, rz) -> - match expected_7g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 7G_x = %s...\n" (String.prefix (Z.to_string ex) 30); - Stdio.printf " Expected 7G_y = %s...\n" (String.prefix (Z.to_string ey) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass) - | _ -> - Stdio.printf " ERROR: Could not compute 3G or 4G\n"; - record false); - - (* ============================================== *) - (* TEST 10: Larger scalar - 100G *) - (* ============================================== *) - - Stdio.printf "Test 10: Verify 100G = 64G + 32G + 4G\n"; - - (* Compute powers of 2 times G *) - let compute_2pow_g n = - let rec loop i pt = - if i >= n then Some pt - else - match affine_add pt pt with - | None -> None - | Some p -> loop (i + 1) p - in - loop 0 (g_x, g_y) - in - - let p4g_direct = compute_2pow_g 2 in (* 4G = 2^2 * G *) - let p32g = compute_2pow_g 5 in (* 32G = 2^5 * G *) - let p64g = compute_2pow_g 6 in (* 64G = 2^6 * G *) - - (match p4g_direct, p32g, p64g with - | Some (x_4g, y_4g), Some (x_32g, y_32g), Some (x_64g, y_64g) -> - (* First compute 64G + 32G = 96G *) - (match run_point_add ~x1:x_64g ~y1:y_64g ~z1:Z.one ~x2:x_32g ~y2:y_32g ~z2:Z.one with - | None -> record false - | Some (rx_96, ry_96, rz_96) -> - (* Then compute 96G + 4G = 100G *) - (match run_point_add ~x1:rx_96 ~y1:ry_96 ~z1:rz_96 ~x2:x_4g ~y2:y_4g ~z2:Z.one with - | None -> record false - | Some (rx, ry, rz) -> - let expected_100g = scalar_mult (Z.of_int 100) (g_x, g_y) in - match expected_100g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 100G_x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass)) - | _ -> - Stdio.printf " ERROR: Could not compute required points\n"; - record false); - - - (* ============================================== *) -(* REPEATED OPERATIONS TEST *) -(* ============================================== *) - -Stdio.printf "=== Repeated Operations Test ===\n\n"; - -(* Compute 10G by doing G + G + G + ... (10 additions) *) -Stdio.printf "Test: Compute 10G via 9 sequential additions\n"; - -let repeated_pass = ref true in -let current_x = ref g_x in -let current_y = ref g_y in -let current_z = ref Z.one in - -for i = 1 to 9 do - (* current += G *) - match run_point_add - ~x1:!current_x ~y1:!current_y ~z1:!current_z - ~x2:g_x ~y2:g_y ~z2:Z.one - with - | None -> - Stdio.printf " Addition %d: TIMEOUT\n" i; - repeated_pass := false - | Some (rx, ry, rz) -> - current_x := rx; - current_y := ry; - current_z := rz; - if i % 3 = 0 then - Stdio.printf " Addition %d complete\n" i -done; - -(* Verify final result is 10G *) -let expected_10g = scalar_mult (Z.of_int 10) (g_x, g_y) in -(match expected_10g with -| None -> - Stdio.printf " Expected: point at infinity\n"; - repeated_pass := false -| Some (ex, ey) -> - let pass = proj_equals_affine ~proj_x:!current_x ~proj_y:!current_y ~proj_z:!current_z ~aff_x:ex ~aff_y:ey in - Stdio.printf " Final result matches 10G: %b\n" pass; - if not pass then repeated_pass := false); - -Stdio.printf " %s\n\n" (if !repeated_pass then "PASS ✓" else "FAIL ✗"); -record !repeated_pass; - -(* ============================================== *) -(* ALTERNATING DOUBLING AND ADDITION TEST *) -(* ============================================== *) - -Stdio.printf "=== Double-and-Add Pattern Test ===\n\n"; - -(* Simulate scalar multiplication pattern: double, conditionally add *) -(* Compute 11G = 1011 in binary: G -> 2G -> 2*2G+G=5G -> 2*5G+G=11G *) -Stdio.printf "Test: Compute 11G using double-and-add pattern\n"; - -let dbl_add_pass = ref true in - -(* Start with G *) -let acc_x = ref g_x in -let acc_y = ref g_y in -let acc_z = ref Z.one in - -(* Process bits of 11 = 1011 from MSB-1 downward: 0, 1, 1 *) -let bits = [0; 1; 1] in - -List.iteri bits ~f:(fun i bit -> - (* Double *) - Stdio.printf " Step %d: Double\n" (i * 2); - (match run_point_add - ~x1:!acc_x ~y1:!acc_y ~z1:!acc_z - ~x2:!acc_x ~y2:!acc_y ~z2:!acc_z - with - | None -> - Stdio.printf " TIMEOUT on double\n"; - dbl_add_pass := false - | Some (rx, ry, rz) -> - acc_x := rx; - acc_y := ry; - acc_z := rz); - - (* Conditionally add G if bit is 1 *) - if bit = 1 then begin - Stdio.printf " Step %d: Add G (bit=1)\n" (i * 2 + 1); - (match run_point_add - ~x1:!acc_x ~y1:!acc_y ~z1:!acc_z - ~x2:g_x ~y2:g_y ~z2:Z.one - with - | None -> - Stdio.printf " TIMEOUT on add\n"; - dbl_add_pass := false - | Some (rx, ry, rz) -> - acc_x := rx; - acc_y := ry; - acc_z := rz) - end else - Stdio.printf " Step %d: Skip add (bit=0)\n" (i * 2 + 1) -); - -(* Verify result is 11G *) -let expected_11g = scalar_mult (Z.of_int 11) (g_x, g_y) in -(match expected_11g with -| None -> - Stdio.printf " Expected: point at infinity\n"; - dbl_add_pass := false -| Some (ex, ey) -> - let pass = proj_equals_affine ~proj_x:!acc_x ~proj_y:!acc_y ~proj_z:!acc_z ~aff_x:ex ~aff_y:ey in - Stdio.printf " Result matches 11G: %b\n" pass; - if not pass then dbl_add_pass := false); - -Stdio.printf " %s\n\n" (if !dbl_add_pass then "PASS ✓" else "FAIL ✗"); -record !dbl_add_pass; - -(* ============================================== *) -(* MIXED POINT SOURCES TEST *) -(* ============================================== *) - -Stdio.printf "=== Mixed Point Sources Test ===\n\n"; - -(* Simulate table lookup pattern - alternate between different source points *) -Stdio.printf "Test: Alternating between different precomputed points\n"; - -(* Precompute 1G, 2G, 3G as our "table" *) -let table = Array.create ~len:4 (Z.zero, Z.zero) in -table.(0) <- (g_x, g_y); (* 1G *) - -(match scalar_mult (Z.of_int 2) (g_x, g_y) with -| Some p -> table.(1) <- p -| None -> ()); - -(match scalar_mult (Z.of_int 3) (g_x, g_y) with -| Some p -> table.(2) <- p -| None -> ()); - -(* Compute: G + 2G + 3G + G + 2G = 9G *) -let indices = [0; 1; 2; 0; 1] in (* Table indices to use *) -let mixed_pass = ref true in - -let acc_x = ref Z.zero in -let acc_y = ref Z.one in -let acc_z = ref Z.zero in (* Start at point at infinity *) - -List.iteri indices ~f:(fun i idx -> - let (tx, ty) = table.(idx) in - Stdio.printf " Op %d: Add table[%d] (%dG)\n" i idx (idx + 1); - - (match run_point_add - ~x1:!acc_x ~y1:!acc_y ~z1:!acc_z - ~x2:tx ~y2:ty ~z2:Z.one - with - | None -> - Stdio.printf " TIMEOUT\n"; - mixed_pass := false - | Some (rx, ry, rz) -> - acc_x := rx; - acc_y := ry; - acc_z := rz) -); - -(* 1 + 2 + 3 + 1 + 2 = 9G *) -let expected_9g = scalar_mult (Z.of_int 9) (g_x, g_y) in -(match expected_9g with -| None -> - Stdio.printf " Expected: point at infinity\n"; - mixed_pass := false -| Some (ex, ey) -> - let pass = proj_equals_affine ~proj_x:!acc_x ~proj_y:!acc_y ~proj_z:!acc_z ~aff_x:ex ~aff_y:ey in - Stdio.printf " Result matches 9G: %b\n" pass; - if not pass then mixed_pass := false); - -Stdio.printf " %s\n\n" (if !mixed_pass then "PASS ✓" else "FAIL ✗"); -record !mixed_pass; - -(* ============================================== *) -(* ACCUMULATOR GROWS THEN RESETS TEST *) -(* ============================================== *) - -Stdio.printf "=== Accumulator Reset Pattern Test ===\n\n"; - -(* Pattern: build up accumulator, verify, reset, build again *) -Stdio.printf "Test: Multiple independent scalar multiplications\n"; - -let reset_pass = ref true in - -(* First: compute 5G *) -Stdio.printf " Computing 5G...\n"; -let acc1_x = ref g_x in -let acc1_y = ref g_y in -let acc1_z = ref Z.one in - -for _ = 1 to 4 do - match run_point_add - ~x1:!acc1_x ~y1:!acc1_y ~z1:!acc1_z - ~x2:g_x ~y2:g_y ~z2:Z.one - with - | None -> reset_pass := false - | Some (rx, ry, rz) -> - acc1_x := rx; - acc1_y := ry; - acc1_z := rz -done; - -let expected_5g = scalar_mult (Z.of_int 5) (g_x, g_y) in -(match expected_5g with -| Some (ex, ey) -> - let pass = proj_equals_affine ~proj_x:!acc1_x ~proj_y:!acc1_y ~proj_z:!acc1_z ~aff_x:ex ~aff_y:ey in - Stdio.printf " 5G correct: %b\n" pass; - if not pass then reset_pass := false -| None -> reset_pass := false); - -(* Second: fresh start, compute 7G *) -Stdio.printf " Computing 7G (fresh start)...\n"; -let acc2_x = ref g_x in -let acc2_y = ref g_y in -let acc2_z = ref Z.one in - -for _ = 1 to 6 do - match run_point_add - ~x1:!acc2_x ~y1:!acc2_y ~z1:!acc2_z - ~x2:g_x ~y2:g_y ~z2:Z.one - with - | None -> reset_pass := false - | Some (rx, ry, rz) -> - acc2_x := rx; - acc2_y := ry; - acc2_z := rz -done; - -let expected_7g = scalar_mult (Z.of_int 7) (g_x, g_y) in -(match expected_7g with -| Some (ex, ey) -> - let pass = proj_equals_affine ~proj_x:!acc2_x ~proj_y:!acc2_y ~proj_z:!acc2_z ~aff_x:ex ~aff_y:ey in - Stdio.printf " 7G correct: %b\n" pass; - if not pass then reset_pass := false -| None -> reset_pass := false); - -(* Third: compute 5G + 7G = 12G using results from above *) -Stdio.printf " Computing 5G + 7G = 12G...\n"; -(match run_point_add - ~x1:!acc1_x ~y1:!acc1_y ~z1:!acc1_z - ~x2:!acc2_x ~y2:!acc2_y ~z2:!acc2_z -with -| None -> - Stdio.printf " TIMEOUT\n"; - reset_pass := false -| Some (rx, ry, rz) -> - let expected_12g = scalar_mult (Z.of_int 12) (g_x, g_y) in - (match expected_12g with - | Some (ex, ey) -> - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " 12G correct: %b\n" pass; - if not pass then reset_pass := false - | None -> reset_pass := false)); - -Stdio.printf " %s\n\n" (if !reset_pass then "PASS ✓" else "FAIL ✗"); -record !reset_pass; - -(* ============================================== *) -(* RAPID SMALL OPERATIONS TEST *) -(* ============================================== *) - -Stdio.printf "=== Rapid Small Operations Test ===\n\n"; - -(* Many quick additions with small coordinates to stress control flow *) -Stdio.printf "Test: 20 rapid point additions\n"; - -let rapid_pass = ref true in - -(* Actually, let's just use multiples of G to ensure valid points *) -let acc_x = ref g_x in -let acc_y = ref g_y in -let acc_z = ref Z.one in - -for i = 1 to 20 do - match run_point_add - ~x1:!acc_x ~y1:!acc_y ~z1:!acc_z - ~x2:g_x ~y2:g_y ~z2:Z.one - with - | None -> - Stdio.printf " Op %d: TIMEOUT\n" i; - rapid_pass := false - | Some (rx, ry, rz) -> - acc_x := rx; - acc_y := ry; - acc_z := rz -done; - -(* Verify: should be 21G *) -let expected_21g = scalar_mult (Z.of_int 21) (g_x, g_y) in -(match expected_21g with -| None -> - Stdio.printf " Expected: point at infinity\n"; - rapid_pass := false -| Some (ex, ey) -> - let pass = proj_equals_affine ~proj_x:!acc_x ~proj_y:!acc_y ~proj_z:!acc_z ~aff_x:ex ~aff_y:ey in - Stdio.printf " After 20 additions, result matches 21G: %b\n" pass; - if not pass then rapid_pass := false); - -Stdio.printf " %s\n\n" (if !rapid_pass then "PASS ✓" else "FAIL ✗"); -record !rapid_pass; - - (* ============================================== *) - (* TEST SUMMARY *) - (* ============================================== *) - - let results_list = List.rev !results in - let passed = List.count results_list ~f:Fn.id in - let total = List.length results_list in - - Stdio.printf "=== Test Summary ===\n"; - Stdio.printf "Passed: %d/%d\n" passed total; - - if passed = total then begin - Stdio.printf "\n"; - Stdio.printf "███████╗██╗ ██╗ ██████╗ ██████╗███████╗███████╗███████╗\n"; - Stdio.printf "██╔════╝██║ ██║██╔════╝██╔════╝██╔════╝██╔════╝██╔════╝\n"; - Stdio.printf "███████╗██║ ██║██║ ██║ █████╗ ███████╗███████╗\n"; - Stdio.printf "╚════██║██║ ██║██║ ██║ ██╔══╝ ╚════██║╚════██║\n"; - Stdio.printf "███████║╚██████╔╝╚██████╗╚██████╗███████╗███████║███████║\n"; - Stdio.printf "╚══════╝ ╚═════╝ ╚═════╝ ╚═════╝╚══════╝╚══════╝╚══════╝\n"; - Stdio.printf "\nAll point addition tests passed! ✓✓✓\n" - end else - Stdio.printf "\n✗ Some tests failed - review above for details\n" \ No newline at end of file diff --git a/test/test_point_mul.ml b/test/test_point_mul.ml deleted file mode 100644 index 2cd4d7e..0000000 --- a/test/test_point_mul.ml +++ /dev/null @@ -1,642 +0,0 @@ -open Base -open Hardcaml - -let () = - Stdio.printf "=== Scalar Multiplication Unit Test ===\n\n"; - - let scope = Scope.create ~flatten_design:true () in - let module Sim = Cyclesim.With_interface(Point_mul.I)(Point_mul.O) in - let sim = Sim.create (Point_mul.create scope) in - - let inputs = Cyclesim.inputs sim in - let outputs = Cyclesim.outputs sim in - - let prime_p = Arith.Config.prime_p in - - (* secp256k1 curve parameters: y² = x³ + 7, so a = 0, b = 7, b3 = 21 *) - let param_a = Z.zero in - let param_b3 = Z.of_int 21 in - - (* Generator point G for secp256k1 - must match Config *) - let g_x = Z.of_string_base 16 "79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" in - let g_y = Z.of_string_base 16 "483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8" in - - (* Curve order n for secp256k1 *) - let curve_order = Z.of_string_base 16 "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141" in - - let z_to_bits z = - let hex_str = Z.format "%x" z in - let padded = String.pad_left hex_str ~len:64 ~char:'0' in - Bits.of_hex ~width:256 padded - in - - let bits_to_z bits = Z.of_string_base 2 (Bits.to_bstr bits) in - - (* Modular arithmetic helpers *) - let mod_add a b = Z.((a + b) mod prime_p) in - let mod_sub a b = Z.(erem (a - b) prime_p) in - let mod_mul a b = Z.((a * b) mod prime_p) in - let mod_inv a = Z.invert a prime_p in - - (* Convert projective (X:Y:Z) to affine (x, y) *) - let proj_to_affine x y z = - if Z.equal z Z.zero then - None (* Point at infinity *) - else - let z_inv = mod_inv z in - let aff_x = mod_mul x z_inv in - let aff_y = mod_mul y z_inv in - Some (aff_x, aff_y) - in - - (* Check if projective point equals affine point *) - let proj_equals_affine ~proj_x ~proj_y ~proj_z ~aff_x ~aff_y = - match proj_to_affine proj_x proj_y proj_z with - | None -> false - | Some (x, y) -> Z.equal x aff_x && Z.equal y aff_y - in - - (* Reference implementation: affine point addition *) - let affine_add (x1, y1) (x2, y2) = - if Z.equal x1 x2 then begin - if Z.equal y1 y2 then - (* Point doubling *) - let lambda = mod_mul (mod_mul (Z.of_int 3) (mod_mul x1 x1)) (mod_inv (mod_mul (Z.of_int 2) y1)) in - let x3 = mod_sub (mod_mul lambda lambda) (mod_add x1 x2) in - let y3 = mod_sub (mod_mul lambda (mod_sub x1 x3)) y1 in - Some (x3, y3) - else - (* P + (-P) = O *) - None - end else begin - let lambda = mod_mul (mod_sub y2 y1) (mod_inv (mod_sub x2 x1)) in - let x3 = mod_sub (mod_sub (mod_mul lambda lambda) x1) x2 in - let y3 = mod_sub (mod_mul lambda (mod_sub x1 x3)) y1 in - Some (x3, y3) - end - in - - (* Reference scalar multiplication using double-and-add *) - let scalar_mult k (x, y) = - let rec loop n acc pt = - if Z.equal n Z.zero then acc - else - let acc' = - if Z.(equal (n land one) one) then - match acc with - | None -> Some pt - | Some a -> affine_add a pt - else acc - in - let pt' = - match affine_add pt pt with - | None -> pt - | Some p -> p - in - loop Z.(n asr 1) acc' pt' - in - loop k None (x, y) - in - - let reset () = - inputs.clear := Bits.vdd; - inputs.start := Bits.gnd; - inputs.scalar := Bits.zero 256; - inputs.param_a := Bits.zero 256; - inputs.param_b3 := Bits.zero 256; - Cyclesim.cycle sim; - inputs.clear := Bits.gnd; - Cyclesim.cycle sim - in - - let run_scalar_mult ~scalar = - reset (); - - inputs.scalar := z_to_bits scalar; - inputs.param_a := z_to_bits param_a; - inputs.param_b3 := z_to_bits param_b3; - inputs.start := Bits.vdd; - Cyclesim.cycle sim; - inputs.start := Bits.gnd; - - let max_cycles = 6000000 in - let rec wait n = - if n >= max_cycles then begin - Stdio.printf " TIMEOUT after %d cycles\n" max_cycles; - None - end else if Bits.to_bool !(outputs.done_) then begin - let result_x = bits_to_z !(outputs.x) in - let result_y = bits_to_z !(outputs.y) in - let result_z = bits_to_z !(outputs.z) in - Stdio.printf " Completed in %d cycles\n" n; - Some (result_x, result_y, result_z) - end else begin - Cyclesim.cycle sim; - wait (n + 1) - end - in - wait 0 - in - - let results = ref [] in - let record result = results := result :: !results in - - (* ============================================== *) - (* TEST 1: [1]G = G *) - (* ============================================== *) - - Stdio.printf "Test 1: [1]G = G\n"; - Stdio.printf " Scalar = 1\n"; - - (match run_scalar_mult ~scalar:Z.one with - | None -> record false - | Some (rx, ry, rz) -> - Stdio.printf " Result (projective): X=%s..., Z=%s...\n" - (String.prefix (Z.to_string rx) 20) - (String.prefix (Z.to_string rz) 20); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:g_x ~aff_y:g_y in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 2: [2]G = 2G *) - (* ============================================== *) - - Stdio.printf "Test 2: [2]G = 2G\n"; - Stdio.printf " Scalar = 2\n"; - - let expected_2g = scalar_mult (Z.of_int 2) (g_x, g_y) in - - (match run_scalar_mult ~scalar:(Z.of_int 2) with - | None -> record false - | Some (rx, ry, rz) -> - match expected_2g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected (affine): x=%s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 3: [3]G = 3G *) - (* ============================================== *) - - Stdio.printf "Test 3: [3]G = 3G\n"; - Stdio.printf " Scalar = 3 (binary: 11)\n"; - - let expected_3g = scalar_mult (Z.of_int 3) (g_x, g_y) in - - (match run_scalar_mult ~scalar:(Z.of_int 3) with - | None -> record false - | Some (rx, ry, rz) -> - match expected_3g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 3G_x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 4: [7]G = 7G *) - (* ============================================== *) - - Stdio.printf "Test 4: [7]G = 7G\n"; - Stdio.printf " Scalar = 7 (binary: 111)\n"; - - let expected_7g = scalar_mult (Z.of_int 7) (g_x, g_y) in - - (match run_scalar_mult ~scalar:(Z.of_int 7) with - | None -> record false - | Some (rx, ry, rz) -> - match expected_7g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 7G_x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 5: [10]G = 10G *) - (* ============================================== *) - - Stdio.printf "Test 5: [10]G = 10G\n"; - Stdio.printf " Scalar = 10 (binary: 1010)\n"; - - let expected_10g = scalar_mult (Z.of_int 10) (g_x, g_y) in - - (match run_scalar_mult ~scalar:(Z.of_int 10) with - | None -> record false - | Some (rx, ry, rz) -> - match expected_10g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 10G_x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 6: [100]G = 100G *) - (* ============================================== *) - - Stdio.printf "Test 6: [100]G = 100G\n"; - Stdio.printf " Scalar = 100 (binary: 1100100)\n"; - - let expected_100g = scalar_mult (Z.of_int 100) (g_x, g_y) in - - (match run_scalar_mult ~scalar:(Z.of_int 100) with - | None -> record false - | Some (rx, ry, rz) -> - match expected_100g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 100G_x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 7: [255]G = 255G *) - (* ============================================== *) - - Stdio.printf "Test 7: [255]G = 255G\n"; - Stdio.printf " Scalar = 255 (binary: 11111111)\n"; - - let expected_255g = scalar_mult (Z.of_int 255) (g_x, g_y) in - - (match run_scalar_mult ~scalar:(Z.of_int 255) with - | None -> record false - | Some (rx, ry, rz) -> - match expected_255g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 255G_x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 8: Power of 2 - [256]G *) - (* ============================================== *) - - Stdio.printf "Test 8: [256]G = 256G\n"; - Stdio.printf " Scalar = 256 (binary: 100000000)\n"; - - let expected_256g = scalar_mult (Z.of_int 256) (g_x, g_y) in - - (match run_scalar_mult ~scalar:(Z.of_int 256) with - | None -> record false - | Some (rx, ry, rz) -> - match expected_256g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected 256G_x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 9: Large power of 2 - [2^16]G *) - (* ============================================== *) - - Stdio.printf "Test 9: [2^16]G = 65536G\n"; - Stdio.printf " Scalar = 65536\n"; - - let scalar_2_16 = Z.shift_left Z.one 16 in - let expected_2_16_g = scalar_mult scalar_2_16 (g_x, g_y) in - - (match run_scalar_mult ~scalar:scalar_2_16 with - | None -> record false - | Some (rx, ry, rz) -> - match expected_2_16_g with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 10: Sparse bits - [2^100 + 1]G *) - (* ============================================== *) - - Stdio.printf "Test 10: [2^100 + 1]G\n"; - Stdio.printf " Scalar has bits set at positions 0 and 100 only\n"; - - let scalar_sparse = Z.(shift_left one 100 + one) in - let expected_sparse = scalar_mult scalar_sparse (g_x, g_y) in - - (match run_scalar_mult ~scalar:scalar_sparse with - | None -> record false - | Some (rx, ry, rz) -> - match expected_sparse with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 11: Dense bits in low byte *) - (* ============================================== *) - - Stdio.printf "Test 11: [0xFF]G = 255G (all low bits set)\n"; - - let scalar_ff = Z.of_int 0xFF in - let expected_ff = scalar_mult scalar_ff (g_x, g_y) in - - (match run_scalar_mult ~scalar:scalar_ff with - | None -> record false - | Some (rx, ry, rz) -> - match expected_ff with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 12: Alternating bits pattern *) - (* ============================================== *) - - Stdio.printf "Test 12: [0xAAAA]G (alternating bits: 1010...)\n"; - - let scalar_aaaa = Z.of_int 0xAAAA in - let expected_aaaa = scalar_mult scalar_aaaa (g_x, g_y) in - - (match run_scalar_mult ~scalar:scalar_aaaa with - | None -> record false - | Some (rx, ry, rz) -> - match expected_aaaa with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 13: Known ECDSA-like scalar *) - (* ============================================== *) - - Stdio.printf "Test 13: Large random-looking scalar\n"; - - let scalar_large = Z.of_string_base 16 "DEADBEEFCAFEBABE0123456789ABCDEF" in - Stdio.printf " Scalar = 0x%s\n" (Z.format "%X" scalar_large); - - let expected_large = scalar_mult scalar_large (g_x, g_y) in - - (match run_scalar_mult ~scalar:scalar_large with - | None -> record false - | Some (rx, ry, rz) -> - match expected_large with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 14: 128-bit scalar *) - (* ============================================== *) - - Stdio.printf "Test 14: 128-bit scalar\n"; - - let scalar_128 = Z.of_string_base 16 "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" in - Stdio.printf " Scalar = 2^128 - 1\n"; - - let expected_128 = scalar_mult scalar_128 (g_x, g_y) in - - (match run_scalar_mult ~scalar:scalar_128 with - | None -> record false - | Some (rx, ry, rz) -> - match expected_128 with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 15: Full 256-bit scalar (max value) *) - (* ============================================== *) - - Stdio.printf "Test 15: Full 256-bit scalar (2^256 - 1)\n"; - - let scalar_max = Z.(shift_left one 256 - one) in - Stdio.printf " Scalar = 2^256 - 1\n"; - - let expected_max = scalar_mult scalar_max (g_x, g_y) in - - (match run_scalar_mult ~scalar:scalar_max with - | None -> record false - | Some (rx, ry, rz) -> - match expected_max with - | None -> - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass - | Some (ex, ey) -> - Stdio.printf " Expected x = %s...\n" (String.prefix (Z.to_string ex) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 16: [n-1]G where n is curve order *) - (* ============================================== *) - - Stdio.printf "Test 16: [n-1]G = -G (one less than curve order)\n"; - - let scalar_n_minus_1 = Z.(curve_order - one) in - Stdio.printf " Scalar = n - 1\n"; - - (* [n-1]G = -G, which has same x as G but negated y *) - let expected_neg_g_y = mod_sub Z.zero g_y in - - (match run_scalar_mult ~scalar:scalar_n_minus_1 with - | None -> record false - | Some (rx, ry, rz) -> - Stdio.printf " Result x = %s...\n" (String.prefix (Z.to_string rx) 30); - let pass = proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:g_x ~aff_y:expected_neg_g_y in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 17: [n]G = O (curve order gives infinity)*) - (* ============================================== *) - - Stdio.printf "Test 17: [n]G = O (curve order gives point at infinity)\n"; - - Stdio.printf " Scalar = n (curve order)\n"; - - (match run_scalar_mult ~scalar:curve_order with - | None -> record false - | Some (_rx, _ry, rz) -> - Stdio.printf " Result Z = %s\n" (Z.to_string rz); - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 18: [0]G = O (zero scalar) *) - (* ============================================== *) - - Stdio.printf "Test 18: [0]G = O (zero scalar gives point at infinity)\n"; - - (match run_scalar_mult ~scalar:Z.zero with - | None -> record false - | Some (_rx, _ry, rz) -> - Stdio.printf " Result Z = %s\n" (Z.to_string rz); - let pass = Z.equal rz Z.zero in - Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); - record pass); - - (* ============================================== *) - (* TEST 19: Repeated scalar multiplications *) - (* ============================================== *) - - Stdio.printf "Test 19: Repeated scalar multiplications (consistency check)\n"; - - let repeated_pass = ref true in - let test_scalars = [5; 17; 42; 100; 1000] in - - List.iter test_scalars ~f:(fun s -> - let scalar = Z.of_int s in - let expected = scalar_mult scalar (g_x, g_y) in - - match run_scalar_mult ~scalar with - | None -> - Stdio.printf " [%d]G: TIMEOUT\n" s; - repeated_pass := false - | Some (rx, ry, rz) -> - match expected with - | None -> - if not (Z.equal rz Z.zero) then begin - Stdio.printf " [%d]G: expected infinity, got finite point\n" s; - repeated_pass := false - end else - Stdio.printf " [%d]G: OK (infinity)\n" s - | Some (ex, ey) -> - if proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey then - Stdio.printf " [%d]G: OK\n" s - else begin - Stdio.printf " [%d]G: MISMATCH\n" s; - repeated_pass := false - end); - - Stdio.printf " %s\n\n" (if !repeated_pass then "PASS ✓" else "FAIL ✗"); - record !repeated_pass; - - (* ============================================== *) - (* TEST 20: Bit pattern edge cases *) - (* ============================================== *) - - Stdio.printf "Test 20: Bit pattern edge cases\n"; - - let edge_pass = ref true in - - (* Single high bit *) - let scalar_high = Z.shift_left Z.one 255 in - Stdio.printf " Testing 2^255 (single high bit)...\n"; - let expected_high = scalar_mult scalar_high (g_x, g_y) in - - (match run_scalar_mult ~scalar:scalar_high with - | None -> - Stdio.printf " TIMEOUT\n"; - edge_pass := false - | Some (rx, ry, rz) -> - match expected_high with - | None -> - if not (Z.equal rz Z.zero) then edge_pass := false - | Some (ex, ey) -> - if not (proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey) then - edge_pass := false); - - (* High bit and low bit *) - let scalar_high_low = Z.(shift_left one 255 + one) in - Stdio.printf " Testing 2^255 + 1 (high and low bits)...\n"; - let expected_high_low = scalar_mult scalar_high_low (g_x, g_y) in - - (match run_scalar_mult ~scalar:scalar_high_low with - | None -> - Stdio.printf " TIMEOUT\n"; - edge_pass := false - | Some (rx, ry, rz) -> - match expected_high_low with - | None -> - if not (Z.equal rz Z.zero) then edge_pass := false - | Some (ex, ey) -> - if not (proj_equals_affine ~proj_x:rx ~proj_y:ry ~proj_z:rz ~aff_x:ex ~aff_y:ey) then - edge_pass := false); - - Stdio.printf " %s\n\n" (if !edge_pass then "PASS ✓" else "FAIL ✗"); - record !edge_pass; - - (* ============================================== *) - (* TEST SUMMARY *) - (* ============================================== *) - - let results_list = List.rev !results in - let passed = List.count results_list ~f:Fn.id in - let total = List.length results_list in - - Stdio.printf "=== Test Summary ===\n"; - Stdio.printf "Passed: %d/%d\n" passed total; - - if passed = total then begin - Stdio.printf "\n"; - Stdio.printf "███████╗██╗ ██╗ ██████╗ ██████╗███████╗███████╗███████╗\n"; - Stdio.printf "██╔════╝██║ ██║██╔════╝██╔════╝██╔════╝██╔════╝██╔════╝\n"; - Stdio.printf "███████╗██║ ██║██║ ██║ █████╗ ███████╗███████╗\n"; - Stdio.printf "╚════██║██║ ██║██║ ██║ ██╔══╝ ╚════██║╚════██║\n"; - Stdio.printf "███████║╚██████╔╝╚██████╗╚██████╗███████║███████║███████║\n"; - Stdio.printf "╚══════╝ ╚═════╝ ╚═════╝ ╚═════╝╚══════╝╚══════╝╚══════╝\n"; - Stdio.printf "\nAll scalar multiplication tests passed! ✓✓✓\n" - end else - Stdio.printf "\n✗ Some tests failed - review above for details\n" \ No newline at end of file