diff --git a/src/arith.ml b/src/arith.ml index 4eaa5b0..8f29744 100644 --- a/src/arith.ml +++ b/src/arith.ml @@ -56,7 +56,7 @@ module I = struct type 'a t = { clock : 'a ; clear : 'a - ; start : 'a + ; valid : 'a ; op : 'a [@bits 2] ; prime_sel : 'a ; reg_read_data_a : 'a [@bits Config.width] @@ -68,7 +68,7 @@ end module O = struct type 'a t = { busy : 'a - ; done_ : 'a + ; ready : 'a ; reg_write_data : 'a [@bits Config.width] ; inv_exists : 'a } @@ -98,13 +98,6 @@ let create scope (i : _ I.t) = let start_mul = Variable.reg spec ~width:1 in let start_inv = Variable.reg spec ~width:1 in - (* Result capture *) - let result_reg = Variable.reg spec ~width:Config.width in - let inv_exists_reg = Variable.reg spec ~width:1 in - - (* Output registers *) - let done_flag = Variable.reg spec ~width:1 in - (* Prime constants *) let prime_p_const = Signal.of_constant (Config.z_to_constant Config.prime_p) in let prime_n_const = Signal.of_constant (Config.z_to_constant Config.prime_n) in @@ -211,11 +204,10 @@ let create scope (i : _ I.t) = (* TODO move start_mul, start_inv clear to Compute step as well when updated *) start_mul <-- gnd; start_inv <-- gnd; - done_flag <-- gnd; sm.switch [ State.Idle, [ - when_ i.start [ + when_ i.valid [ (* Latch all inputs *) op_reg <-- i.op; prime_sel_reg <-- i.prime_sel; @@ -242,9 +234,6 @@ let create scope (i : _ I.t) = mod_add_valid <-- gnd; mod_sub_valid <-- gnd; - done_flag <-- vdd; - result_reg <-- op_result; - inv_exists_reg <-- mod_inv_out.exists; sm.set_next Idle; ]; ]; @@ -256,7 +245,7 @@ let create scope (i : _ I.t) = { O. busy = busy -- "busy" - ; done_ = done_flag.value -- "done" - ; reg_write_data = result_reg.value -- "reg_write_data" - ; inv_exists = inv_exists_reg.value -- "inv_exists" + ; ready = op_ready -- "done" + ; reg_write_data = op_result -- "reg_write_data" + ; inv_exists = mod_inv_out.exists -- "inv_exists" } \ No newline at end of file diff --git a/src/ecdsa.ml b/src/ecdsa.ml index 86da25b..21a7f12 100644 --- a/src/ecdsa.ml +++ b/src/ecdsa.ml @@ -33,12 +33,6 @@ open Signal - Q: public key (currently 2G for testing) - G+Q: precomputed sum (currently 3G) - State machine: - Idle -> Prep_op (3 ops) -> Loop/Load/Run_add (scalar mult) - -> Finalize_op (2 ops) -> Compare -> Done -> Idle - - Cycle count: ~1.5-2M cycles for typical 256-bit scalars - Note: Does not currently reduce x_affine mod n before comparison, or check that z, s, r are within range. @@ -126,19 +120,31 @@ end module State = struct type t = | Idle - | Prep_op + | Run_prep + | Sample_prep (* Separate step to sample the u1, u2 values when visible in the register file *) | Loop - | Load - | Run_add - | Finalize_op - | Compare - | Done + | Run_point_add + | Run_finalize + | Done (* Separate step to sample the result when visible in the register file *) [@@deriving sexp_of, compare, enumerate] end type instr = { op : int; src1 : int; src2 : int; dst : int } -let program = [| +(* Implements: *) +(* w = s^(-1) mod n *) +(* u1 = z * w mod n *) +(* u2 = r * w mod n *) +(* *) +(* assumes t0=s, t1=z, t2=r *) +(* u1 placed in t1, u2 in t2 *) +let program_prepare = [| + { op = Op.inv; src1 = Config.t0; src2 = Config.t0; dst = Config.t0 }; + { op = Op.mul; src1 = Config.t1; src2 = Config.t0; dst = Config.t1 }; + { op = Op.mul; src1 = Config.t2; src2 = Config.t0; dst = Config.t2 }; +|] + +let program_point_add = [| { op = Op.mul; src1 = Config.x1; src2 = Config.x2; dst = Config.t0 }; { op = Op.mul; src1 = Config.y1; src2 = Config.y2; dst = Config.t1 }; { op = Op.mul; src1 = Config.z1; src2 = Config.z2; dst = Config.t2 }; @@ -181,6 +187,30 @@ let program = [| { op = Op.add; src1 = Config.z3; src2 = Config.t0; dst = Config.z1 }; |] +(* Implements: *) +(* z_inv = Z1^(-1) mod p *) +(* x_affine = X1 * z_inv mod p *) +(* result = x_affine - r mod p *) +(* *) +(* assumes t2=r - consistent with prepare *) +(* result placed in t0 *) +(* if result == 0, then x_affine == r *) +let program_finalize = [| + { op = Op.inv; src1 = Config.z1; src2 = Config.z1; dst = Config.t0 }; + { op = Op.mul; src1 = Config.x1; src2 = Config.t0; dst = Config.t0 }; + { op = Op.sub; src1 = Config.t0; src2 = Config.t2; dst = Config.t0 }; +|] + +let program = Array.concat [program_prepare; program_point_add; program_finalize] + +(* Segment boundaries — compile-time constants *) +let prepare_start = 0 +let prepare_end = Array.length program_prepare - 1 +let point_add_start = Array.length program_prepare +let point_add_end = point_add_start + Array.length program_point_add - 1 +let finalize_start = point_add_end + 1 +let finalize_end = finalize_start + Array.length program_finalize - 1 + (* Helper to extract a bit from a signal using a signal index *) let bit_select_dynamic signal index_signal = let width = Signal.width signal in @@ -190,6 +220,7 @@ let bit_select_dynamic signal index_signal = let create scope (i : _ I.t) = let open Always in let width = Config.width in + let bit_cnt_width = Int.ceil_log2 width in let addr_width = Config.reg_addr_width in let spec = Reg_spec.create ~clock:i.clock ~clear:i.clear () in @@ -210,110 +241,58 @@ let create scope (i : _ I.t) = let const_gpqy = of_z ~width Config.gpq_y in let const_gpqz = of_z ~width Config.gpq_z in - (* Input registers *) - let z_reg = Variable.reg spec ~width in + (* Input register *) let r_reg = Variable.reg spec ~width in - let s_reg = Variable.reg spec ~width in - (* Prep phase registers *) - let w_reg = Variable.reg spec ~width in (* s^(-1) mod n *) - let u1_reg = Variable.reg spec ~width in (* z * w mod n *) - let u2_reg = Variable.reg spec ~width in (* r * w mod n *) - let prep_step = Variable.reg spec ~width:2 in - let prep_op_started = Variable.reg spec ~width:1 in + (* Scalar registers — captured at end of Prep_op, read throughout Loop *) + let u1_reg = Variable.reg spec ~width in (* z * s^(-1) mod n *) + let u2_reg = Variable.reg spec ~width in (* r * s^(-1) mod n *) + + (* Unified program counter *) + let pc = Variable.reg spec ~width:6 in (* Main loop registers *) - let bit_pos = Variable.reg spec ~width:8 in + let bit_pos = Variable.reg spec ~width:bit_cnt_width in let doubling = Variable.reg spec ~width:1 in let last_step = Variable.reg spec ~width:1 in - let step = Variable.reg spec ~width:6 in - let load_idx = Variable.reg spec ~width:2 in - let op_started = Variable.reg spec ~width:1 in - - (* Latched second operand registers *) - let x2_latched = Variable.reg spec ~width in - let y2_latched = Variable.reg spec ~width in - let z2_latched = Variable.reg spec ~width in - - (* Finalize phase registers *) - let finalize_step = Variable.reg spec ~width:1 in - let finalize_op_started = Variable.reg spec ~width:1 in - let z_inv_reg = Variable.reg spec ~width in - let x_affine_reg = Variable.reg spec ~width in - - (* Output registers *) - let out_x = Variable.reg spec ~width in - let out_y = Variable.reg spec ~width in - let out_z = Variable.reg spec ~width in - let valid_reg = Variable.reg spec ~width:1 in - let done_flag = Variable.reg spec ~width:1 in - - let current_bit_u = bit_select_dynamic u1_reg.value bit_pos.value in - let current_bit_v = bit_select_dynamic u2_reg.value bit_pos.value in + + (* Output wires *) + let done_w = Variable.wire ~default:gnd in + let valid_w = Variable.wire ~default:gnd in + + let current_bit_u1 = bit_select_dynamic u1_reg.value bit_pos.value in + let current_bit_u2 = bit_select_dynamic u2_reg.value bit_pos.value in (* Check if P is point at infinity (z1 = 0) *) let p_is_infinity = reg_file.(Config.z1).value ==:. 0 in - (* Instruction decode for point addition *) - let default_instr = of_int ~width:addr_width 0 in - let decode field = - mux step.value + (* Unified instruction decode over full program *) + let pc_decode field w = + mux pc.value (Array.to_list (Array.map program ~f:(fun instr -> - of_int ~width:addr_width (field instr)))) - |> fun s -> mux2 (step.value >=:. Config.num_steps) default_instr s + of_int ~width:w (field instr)))) in - let decode_op = - mux step.value - (Array.to_list (Array.map program ~f:(fun instr -> - of_int ~width:2 instr.op))) - |> fun s -> mux2 (step.value >=:. Config.num_steps) (zero 2) s - in - - let current_dst = decode (fun instr -> instr.dst) in - let current_src1 = decode (fun instr -> instr.src1) in - let current_src2 = decode (fun instr -> instr.src2) in - let current_op_from_program = decode_op in - - (* Arith control signals *) - let arith_start = Variable.wire ~default:gnd in - let arith_op = Variable.wire ~default:(zero 2) in - let arith_prime_sel = Variable.wire ~default:gnd in - let arith_a = Variable.wire ~default:(zero width) in - let arith_b = Variable.wire ~default:(zero width) in + let current_src1 = pc_decode (fun instr -> instr.src1) addr_width in + let current_src2 = pc_decode (fun instr -> instr.src2) addr_width in + let current_dst = pc_decode (fun instr -> instr.dst) addr_width in + let current_op = pc_decode (fun instr -> instr.op) 2 in (* Use latched values for x2/y2/z2 in point addition *) let reg_read idx = let base_values = Array.to_list (Array.map reg_file ~f:(fun r -> r.value)) in - let base = mux idx base_values in - mux2 (idx ==:. Config.x2) x2_latched.value - (mux2 (idx ==:. Config.y2) y2_latched.value - (mux2 (idx ==:. Config.z2) z2_latched.value base)) + mux idx base_values in - (* Select arith inputs based on current state *) - let in_prep_or_finalize = (sm.is Prep_op) |: (sm.is Finalize_op) in - - let arith_read_data_a = - mux2 in_prep_or_finalize arith_a.value (reg_read current_src1) - in - let arith_read_data_b = - mux2 in_prep_or_finalize arith_b.value (reg_read current_src2) - in - - let arith_op_selected = - mux2 in_prep_or_finalize arith_op.value current_op_from_program - in - - let arith_prime_sel_selected = - mux2 in_prep_or_finalize arith_prime_sel.value gnd - in + let arith_prime_sel_selected = (pc.value >=:. prepare_start) &&: (pc.value <=:. prepare_end) in (* mod n for prepare *) + let arith_read_data_a = reg_read current_src1 in + let arith_read_data_b = reg_read current_src2 in let arith_out = Arith.create (Scope.sub_scope scope "arith") { Arith.I. clock = i.clock ; clear = i.clear - ; start = arith_start.value - ; op = arith_op_selected + ; valid = sm.is (State.Run_prep) ||: sm.is State.Run_point_add ||: sm.is State.Run_finalize + ; op = current_op ; prime_sel = arith_prime_sel_selected ; reg_read_data_a = arith_read_data_a ; reg_read_data_b = arith_read_data_b @@ -321,199 +300,127 @@ let create scope (i : _ I.t) = in compile [ - done_flag <-- gnd; + (* default value of output wires, should be redundant with default value specified in the declaration but just in case *) + done_w <-- gnd; + valid_w <-- gnd; + + (* in either state, if an arithmetic operation completed, store the result and increment the PC *) + when_ arith_out.ready [ + pc <-- pc.value +:. 1; + proc (Array.to_list (Array.mapi reg_file ~f:(fun idx reg -> + when_ (current_dst ==:. idx) [ + reg <-- arith_out.reg_write_data ]))); + ]; sm.switch [ State.Idle, [ when_ i.start [ - z_reg <-- i.z; + (* capture r value to be used in Finalize state *) r_reg <-- i.r; - s_reg <-- i.s; - prep_step <--. 0; - prep_op_started <-- gnd; - reg_file.(Config.param_a) <-- i.param_a; + + reg_file.(Config.param_a) <-- i.param_a; reg_file.(Config.param_b3) <-- i.param_b3; - sm.set_next Prep_op; + + (* Initialize and start the Prepare calculations *) + reg_file.(Config.t0) <-- i.s; (* see program_prepare assumptions *) + reg_file.(Config.t1) <-- i.z; (* see program_prepare assumptions *) + reg_file.(Config.t2) <-- i.r; (* see program_prepare assumptions *) + pc <--. prepare_start; + sm.set_next Run_prep; ]; ]; - State.Prep_op, [ - (* Set up arith inputs based on prep_step *) - arith_prime_sel <-- vdd; (* All prep ops use mod n *) - - if_ (prep_step.value ==:. 0) [ - (* w = s^(-1) mod n *) - arith_op <--. Op.inv; - arith_a <-- s_reg.value; - arith_b <-- zero width; - ] @@ elif (prep_step.value ==:. 1) [ - (* u1 = z * w mod n *) - arith_op <--. Op.mul; - arith_a <-- z_reg.value; - arith_b <-- w_reg.value; - ] [ - (* u2 = r * w mod n *) - arith_op <--. Op.mul; - arith_a <-- r_reg.value; - arith_b <-- w_reg.value; + State.Run_prep, [ + (* wait for program to finish and proceed to next state *) + when_ (arith_out.ready &&: (pc.value ==:. prepare_end)) [ + sm.set_next Sample_prep; ]; + ]; - if_ (~:(prep_op_started.value)) [ - arith_start <-- vdd; - prep_op_started <-- vdd; - ] [ - when_ arith_out.done_ [ - (* Store result *) - if_ (prep_step.value ==:. 0) [ - w_reg <-- arith_out.reg_write_data; - ] @@ elif (prep_step.value ==:. 1) [ - u1_reg <-- arith_out.reg_write_data; - ] [ - u2_reg <-- arith_out.reg_write_data; - ]; - - if_ (prep_step.value ==:. 2) [ - (* Initialize for main loop *) - bit_pos <--. 255; - doubling <-- vdd; - last_step <-- gnd; - step <-- zero 6; - op_started <-- gnd; - reg_file.(Config.x1) <-- of_z ~width Config.infinity_x; - reg_file.(Config.y1) <-- of_z ~width Config.infinity_y; - reg_file.(Config.z1) <-- of_z ~width Config.infinity_z; - sm.set_next Loop; - ] [ - prep_step <-- prep_step.value +:. 1; - prep_op_started <-- gnd; - ]; - ]; - ]; + State.Sample_prep, [ + (* Capture u1/u2 only AFTER! the program has finished - u2 is not available in the register earlier! *) + u1_reg <-- reg_file.(Config.t1).value; (* see program_prepare results handling *) + u2_reg <-- reg_file.(Config.t2).value; (* see program_prepare results handling *) + + (* Initialize main point multiply calculation and move to next state *) + bit_pos <-- ones bit_cnt_width; + doubling <-- vdd; + last_step <-- gnd; + reg_file.(Config.x1) <-- of_z ~width Config.infinity_x; + reg_file.(Config.y1) <-- of_z ~width Config.infinity_y; + reg_file.(Config.z1) <-- of_z ~width Config.infinity_z; + + sm.set_next Loop; ]; State.Loop, [ - if_ last_step.value [ - (* Initialize for finalize phase *) - finalize_step <-- gnd; - finalize_op_started <-- gnd; - sm.set_next Finalize_op; - ] @@ elif doubling.value [ - (* Doubling phase *) + if_ last_step.value [ (* completion condition *) + (* Results are stored in x1, y1, z1, finalize will use those *) + (* Initialize and start finalize calculations *) + reg_file.(Config.t2) <-- r_reg.value; (* see program_finalize assumptions *) + pc <--. finalize_start; + sm.set_next Run_finalize; + ] + @@ elif doubling.value [ + (* In Doubling phase, next will be Adding *) doubling <-- gnd; - if_ p_is_infinity [ - sm.set_next Loop; + + (* only calculate if P != infinity, can be skipped otherwise *) + if_ (~:p_is_infinity) [ + (* Initialize and start point_add program *) + reg_file.(Config.x2) <-- reg_file.(Config.x1).value; + reg_file.(Config.y2) <-- reg_file.(Config.y1).value; + reg_file.(Config.z2) <-- reg_file.(Config.z1).value; + pc <--. point_add_start; + sm.set_next Run_point_add; ] [ - load_idx <--. 0; - sm.set_next Load; + sm.set_next Loop; ]; - ] [ - (* Adding phase *) + ] + @@ [ + (* In Adding phase, next will be Doubling *) doubling <-- vdd; - load_idx <-- (current_bit_v @: current_bit_u); + last_step <-- (bit_pos.value ==:. 0); bit_pos <-- bit_pos.value -:. 1; - if_ ((~: current_bit_u) &: (~: current_bit_v)) [ - sm.set_next Loop; + (* add only needed if either of the current u1 or u2 bits is set, skip otherwise *) + if_ (current_bit_u1 ||: current_bit_u2) [ + (* Initialize and start point_add program *) + (* Shamir's trick: adding G, Q or precomputed G+Q *) + (* Reminder: R = u1*G + u2*Q *) + reg_file.(Config.x2) <-- mux2 ~:current_bit_u2 const_gx (mux2 ~:current_bit_u1 const_qx const_gpqx); + reg_file.(Config.y2) <-- mux2 ~:current_bit_u2 const_gy (mux2 ~:current_bit_u1 const_qy const_gpqy); + reg_file.(Config.z2) <-- mux2 ~:current_bit_u2 const_gz (mux2 ~:current_bit_u1 const_qz const_gpqz); + pc <--. point_add_start; + sm.set_next Run_point_add; ] [ - sm.set_next Load; + sm.set_next Loop; ]; ]; ]; - State.Load, [ - if_ (load_idx.value ==:. 0) [ - x2_latched <-- reg_file.(Config.x1).value; - y2_latched <-- reg_file.(Config.y1).value; - z2_latched <-- reg_file.(Config.z1).value; - ] @@ elif (load_idx.value ==:. 1) [ - x2_latched <-- const_gx; - y2_latched <-- const_gy; - z2_latched <-- const_gz; - ] @@ elif (load_idx.value ==:. 2) [ - x2_latched <-- const_qx; - y2_latched <-- const_qy; - z2_latched <-- const_qz; - ] [ - x2_latched <-- const_gpqx; - y2_latched <-- const_gpqy; - z2_latched <-- const_gpqz; + State.Run_point_add, [ + (* wait for program to finish and proceed to next iteration *) + when_ (arith_out.ready &&: (pc.value ==:. point_add_end)) [ + sm.set_next Loop; ]; - step <-- zero 6; - sm.set_next Run_add; ]; - State.Run_add, [ - if_ (~:(op_started.value)) [ - arith_start <-- vdd; - op_started <-- vdd; - ] [ - when_ arith_out.done_ [ - proc (Array.to_list (Array.mapi reg_file ~f:(fun idx reg -> - when_ (current_dst ==:. idx) [ - reg <-- arith_out.reg_write_data; - ]))); - - if_ (step.value ==:. Config.num_steps - 1) [ - op_started <-- gnd; - sm.set_next Loop; - ] [ - step <-- step.value +:. 1; - op_started <-- gnd; - ]; - ]; + State.Run_finalize, [ + (* wait for program to finish and proceed to next state *) + when_ (arith_out.ready &&: (pc.value ==:. finalize_end)) [ + sm.set_next Done; ]; ]; - State.Finalize_op, [ - (* Set up arith inputs based on finalize_step *) - arith_prime_sel <-- gnd; (* Finalize ops use mod p *) - - if_ (~:(finalize_step.value)) [ - (* z_inv = Z1^(-1) mod p *) - arith_op <--. Op.inv; - arith_a <-- reg_file.(Config.z1).value; - arith_b <-- zero width; - ] [ - (* x_affine = X1 * z_inv mod p *) - arith_op <--. Op.mul; - arith_a <-- reg_file.(Config.x1).value; - arith_b <-- z_inv_reg.value; - ]; - - if_ (~:(finalize_op_started.value)) [ - arith_start <-- vdd; - finalize_op_started <-- vdd; - ] [ - when_ arith_out.done_ [ - (* Store result *) - if_ (~:(finalize_step.value)) [ - z_inv_reg <-- arith_out.reg_write_data; - ] [ - x_affine_reg <-- arith_out.reg_write_data; - ]; - - if_ finalize_step.value [ - sm.set_next Compare; - ] [ - finalize_step <-- vdd; - finalize_op_started <-- gnd; - ]; - ]; - ]; - ]; - - State.Compare, [ - (* Compare x_affine with r *) - valid_reg <-- (x_affine_reg.value ==: r_reg.value); - out_x <-- reg_file.(Config.x1).value; - out_y <-- reg_file.(Config.y1).value; - out_z <-- reg_file.(Config.z1).value; - sm.set_next Done; - ]; - State.Done, [ - done_flag <-- vdd; + (* Combinationally assign the done and valid outputs *) + done_w <-- vdd; + (* Check the result only AFTER! the program has finished - result is not available in the register earlier! *) + valid_w <-- (reg_file.(Config.t0).value ==: zero width); (* see program_finalize result handling *) + + (* Set next state *) sm.set_next Idle; ]; ]; @@ -521,9 +428,9 @@ let create scope (i : _ I.t) = { O. busy = ~:(sm.is Idle) - ; done_ = done_flag.value - ; valid = valid_reg.value - ; x = out_x.value - ; y = out_y.value - ; z_out = out_z.value - } \ No newline at end of file + ; done_ = done_w.value + ; valid = valid_w.value + ; x = reg_file.(Config.x1).value + ; y = reg_file.(Config.y1).value + ; z_out = reg_file.(Config.z1).value + } diff --git a/test/test_arith.ml b/test/test_arith.ml index d139a47..9989960 100644 --- a/test/test_arith.ml +++ b/test/test_arith.ml @@ -9,7 +9,7 @@ let () = let sim = Sim.create (Arith.create scope) in let inputs = Cyclesim.inputs sim in - let outputs = Cyclesim.outputs sim in + let outputs = Cyclesim.outputs ~clock_edge:Before sim in (* Simulated register file *) let registers = Array.init 32 ~f:(fun _ -> Z.zero) in @@ -24,7 +24,7 @@ let () = let reset () = inputs.clear := Bits.vdd; - inputs.start := Bits.gnd; + inputs.valid := Bits.gnd; inputs.op := Bits.zero 2; inputs.prime_sel := Bits.gnd; inputs.reg_read_data_a := Bits.zero 256; @@ -46,11 +46,10 @@ let () = inputs.reg_read_data_b := z_to_bits registers.(addr_b); (* Start operation *) - inputs.start := Bits.vdd; + inputs.valid := Bits.vdd; inputs.op := Bits.of_int ~width:2 op; inputs.prime_sel := if prime_sel then Bits.vdd else Bits.gnd; Cyclesim.cycle sim; - inputs.start := Bits.gnd; (* Run until done *) let max_cycles = 10_000 in @@ -59,9 +58,10 @@ let () = Stdio.printf " TIMEOUT after %d cycles\n" max_cycles; false end else begin - let is_done = Bits.to_bool !(outputs.done_) in + let is_done = Bits.to_bool !(outputs.ready) in if is_done then begin registers.(addr_out) <- bits_to_z !(outputs.reg_write_data); + inputs.valid := Bits.gnd; Cyclesim.cycle sim; (* let machine return to Idle *) true end else begin @@ -916,21 +916,21 @@ let val_d = Z.of_int 22222 in (* Helper to run op WITHOUT reset *) let run_op_no_reset ~op ~data_a ~data_b = - inputs.start := Bits.vdd; + inputs.valid := Bits.vdd; inputs.op := Bits.of_int ~width:2 op; inputs.prime_sel := Bits.gnd; inputs.reg_read_data_a := z_to_bits data_a; inputs.reg_read_data_b := z_to_bits data_b; Cyclesim.cycle sim; - inputs.start := Bits.gnd; let max_cycles = 3000 in let rec wait n = if n >= max_cycles then begin Stdio.printf " TIMEOUT after %d cycles\n" max_cycles; None - end else if Bits.to_bool !(outputs.done_) then begin + end else if Bits.to_bool !(outputs.ready) then begin let result = bits_to_z !(outputs.reg_write_data) in + inputs.valid := Bits.gnd; Cyclesim.cycle sim; (* let Done -> Idle transition complete *) Some result end else begin @@ -1010,11 +1010,10 @@ let run_pm_op ~src1 ~src2 ~dst = inputs.reg_read_data_b := z_to_bits pm_registers.(src2); (* Start multiplication *) - inputs.start := Bits.vdd; + inputs.valid := Bits.vdd; inputs.op := Bits.of_int ~width:2 2; (* mul *) inputs.prime_sel := Bits.gnd; Cyclesim.cycle sim; - inputs.start := Bits.gnd; (* Wait for completion *) let max_cycles = 300 in @@ -1023,12 +1022,13 @@ let run_pm_op ~src1 ~src2 ~dst = Stdio.printf " TIMEOUT after %d cycles!\n" max_cycles; Stdio.printf " busy=%d done=%d\n" (Bits.to_int !(outputs.busy)) - (Bits.to_int !(outputs.done_)); + (Bits.to_int !(outputs.ready)); None end else begin - if Bits.to_bool !(outputs.done_) then begin + if Bits.to_bool !(outputs.ready) then begin let result = bits_to_z !(outputs.reg_write_data) in pm_registers.(dst) <- result; + inputs.valid := Bits.gnd; Cyclesim.cycle sim; (* let Done -> Idle complete *) Stdio.printf " Completed in %d cycles\n" n; Some result @@ -1106,19 +1106,19 @@ let run_mul_verbose ~a ~b ~label = inputs.reg_read_data_a := z_to_bits a; inputs.reg_read_data_b := z_to_bits b; - inputs.start := Bits.vdd; + inputs.valid := Bits.vdd; inputs.op := Bits.of_int ~width:2 2; (* mul *) inputs.prime_sel := Bits.gnd; Cyclesim.cycle sim; - inputs.start := Bits.gnd; for cycle = 1 to 20 do let busy = Bits.to_int !(outputs.busy) in - let done_ = Bits.to_int !(outputs.done_) in - Stdio.printf " cycle %2d: busy=%d done=%d\n" cycle busy done_; + let ready = Bits.to_int !(outputs.ready) in + Stdio.printf " cycle %2d: busy=%d done=%d\n" cycle busy ready; - if done_ = 1 then begin + if ready = 1 then begin let result = bits_to_z !(outputs.reg_write_data) in + inputs.valid := Bits.gnd; Stdio.printf " Result: %s\n\n" (Z.to_string result); (* Cycle once more to clear done state *) Cyclesim.cycle sim; diff --git a/test/test_ecdsa.ml b/test/test_ecdsa.ml index b406208..b98736f 100644 --- a/test/test_ecdsa.ml +++ b/test/test_ecdsa.ml @@ -3,41 +3,41 @@ open Hardcaml let () = Stdio.printf "=== ECDSA Signature Verification Test ===\n\n"; - + let scope = Scope.create ~flatten_design:true () in let module Sim = Cyclesim.With_interface(Ecdsa.I)(Ecdsa.O) in let sim = Sim.create (Ecdsa.create scope) in - + let inputs = Cyclesim.inputs sim in let outputs = Cyclesim.outputs sim in - + let prime_p = Arith.Config.prime_p in let prime_n = Arith.Config.prime_n in - + (* secp256k1 curve parameters *) let param_a = Z.zero in let param_b3 = Z.of_int 21 in - + (* Generator point G *) let g_x = Z.of_string "0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" in let g_y = Z.of_string "0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8" in - + (* Private key d = 2, Public key Q = 2G *) let private_key = Z.of_int 2 in - + let z_to_bits z = let hex_str = Z.format "%x" z in let padded = String.pad_left hex_str ~len:64 ~char:'0' in Bits.of_hex ~width:256 padded in - + (* Modular arithmetic helpers *) let mod_mul_p a b = Z.((a * b) mod prime_p) in let mod_inv_p a = Z.invert a prime_p in let mod_inv_n a = Z.invert a prime_n in let mod_mul_n a b = Z.((a * b) mod prime_n) in let mod_add_n a b = Z.((a + b) mod prime_n) in - + (* Reference: affine point addition *) let affine_add (x1, y1) (x2, y2) = let mod_add a b = Z.((a + b) mod prime_p) in @@ -57,20 +57,20 @@ let () = Some (x3, y3) end in - + (* Reference: scalar multiplication *) let scalar_mult k (x, y) = let rec loop n acc pt = if Z.equal n Z.zero then acc else - let acc' = + let acc' = if Z.(equal (n land one) one) then match acc with | None -> Some pt | Some a -> affine_add a pt else acc in - let pt' = + let pt' = match affine_add pt pt with | None -> pt | Some p -> p @@ -79,7 +79,7 @@ let () = in loop k None (x, y) in - + (* Generate a valid ECDSA signature for message hash z using nonce k *) let sign ~z ~k = (* R = k*G *) @@ -95,7 +95,7 @@ let () = if Z.equal s Z.zero then None else Some (r, s) in - + let reset () = inputs.clear := Bits.vdd; inputs.start := Bits.gnd; @@ -108,10 +108,10 @@ let () = inputs.clear := Bits.gnd; Cyclesim.cycle sim in - + let run_verify ~z ~r ~s = reset (); - + inputs.z := z_to_bits z; inputs.r := z_to_bits r; inputs.s := z_to_bits s; @@ -120,7 +120,7 @@ let () = inputs.start := Bits.vdd; Cyclesim.cycle sim; inputs.start := Bits.gnd; - + let max_cycles = 15_000_000 in let rec wait n = if n >= max_cycles then begin @@ -137,21 +137,21 @@ let () = in wait 0 in - + let results = ref [] in let record result = results := result :: !results in - + (* ============================================== *) (* TEST 1: Valid signature with small values *) (* ============================================== *) - + Stdio.printf "Test 1: Valid signature (z=12345, k=7)\n"; - + let z = Z.of_int 12345 in let k = Z.of_int 7 in - + (match sign ~z ~k with - | None -> + | None -> Stdio.printf " Failed to generate signature\n"; record false | Some (r, s) -> @@ -164,18 +164,18 @@ let () = let pass = valid in Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass); - + (* ============================================== *) (* TEST 2: Valid signature with larger values *) (* ============================================== *) - + Stdio.printf "Test 2: Valid signature (z=0xDEADBEEF, k=0x123456)\n"; - + let z = Z.of_string "0xDEADBEEF" in let k = Z.of_string "0x123456" in - + (match sign ~z ~k with - | None -> + | None -> Stdio.printf " Failed to generate signature\n"; record false | Some (r, s) -> @@ -188,18 +188,18 @@ let () = let pass = valid in Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass); - + (* ============================================== *) (* TEST 3: Valid signature with 256-bit hash *) (* ============================================== *) - + Stdio.printf "Test 3: Valid signature (256-bit z, k)\n"; - + let z = Z.of_string "0xb94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9" in let k = Z.of_string "0x3b78ce563f89a0ed9414f5aa28ad0d96d6795f9c63" in - + (match sign ~z ~k with - | None -> + | None -> Stdio.printf " Failed to generate signature\n"; record false | Some (r, s) -> @@ -212,18 +212,18 @@ let () = let pass = valid in Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass); - + (* ============================================== *) (* TEST 4: Invalid signature - wrong z *) (* ============================================== *) - + Stdio.printf "Test 4: Invalid signature (wrong message hash)\n"; - + let z = Z.of_int 12345 in let k = Z.of_int 7 in - + (match sign ~z ~k with - | None -> + | None -> Stdio.printf " Failed to generate signature\n"; record false | Some (r, s) -> @@ -236,18 +236,18 @@ let () = let pass = not valid in Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass); - + (* ============================================== *) (* TEST 5: Invalid signature - wrong r *) (* ============================================== *) - + Stdio.printf "Test 5: Invalid signature (wrong r)\n"; - + let z = Z.of_int 12345 in let k = Z.of_int 7 in - + (match sign ~z ~k with - | None -> + | None -> Stdio.printf " Failed to generate signature\n"; record false | Some (_r, s) -> @@ -260,18 +260,18 @@ let () = let pass = not valid in Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass); - + (* ============================================== *) (* TEST 6: Invalid signature - wrong s *) (* ============================================== *) - + Stdio.printf "Test 6: Invalid signature (wrong s)\n"; - + let z = Z.of_int 12345 in let k = Z.of_int 7 in - + (match sign ~z ~k with - | None -> + | None -> Stdio.printf " Failed to generate signature\n"; record false | Some (r, _s) -> @@ -284,18 +284,18 @@ let () = let pass = not valid in Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass); - + (* ============================================== *) (* TEST 7: Valid signature - another random case *) (* ============================================== *) - + Stdio.printf "Test 7: Valid signature (z=0xCAFEBABE, k=0x999)\n"; - + let z = Z.of_string "0xCAFEBABE" in let k = Z.of_string "0x999" in - + (match sign ~z ~k with - | None -> + | None -> Stdio.printf " Failed to generate signature\n"; record false | Some (r, s) -> @@ -308,17 +308,17 @@ let () = let pass = valid in Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass); - + (* ============================================== *) (* TEST 8: Completely random signature (invalid) *) (* ============================================== *) - + Stdio.printf "Test 8: Random values (should be invalid)\n"; - + let z = Z.of_string "0x1111111111111111" in let r = Z.of_string "0x2222222222222222" in let s = Z.of_string "0x3333333333333333" in - + (match run_verify ~z ~r ~s with | None -> record false | Some valid -> @@ -326,18 +326,18 @@ let () = let pass = not valid in Stdio.printf " %s\n\n" (if pass then "PASS ✓" else "FAIL ✗"); record pass); - + (* ============================================== *) (* TEST SUMMARY *) (* ============================================== *) - + let results_list = List.rev !results in let passed = List.count results_list ~f:Fn.id in let total = List.length results_list in - + Stdio.printf "=== Test Summary ===\n"; Stdio.printf "Passed: %d/%d\n" passed total; - + if passed = total then begin Stdio.printf "\n"; Stdio.printf "███████╗██╗ ██╗ ██████╗ ██████╗███████╗███████╗███████╗\n";