From 5157e2947099262b167fed435056267cdc895264 Mon Sep 17 00:00:00 2001 From: Peter Drotos Date: Tue, 3 Mar 2026 17:12:36 +0100 Subject: [PATCH 1/2] Changed mod_add to valid/ready interface --- src/arith.ml | 99 +++++++++++++++++++++++--------------------- src/mod_add.ml | 79 +++++++++++++++++------------------ test/test_mod_add.ml | 49 +++++++++++----------- 3 files changed, 113 insertions(+), 114 deletions(-) diff --git a/src/arith.ml b/src/arith.ml index 234e624..af02ebe 100644 --- a/src/arith.ml +++ b/src/arith.ml @@ -3,24 +3,24 @@ open Signal (* Arith - Modular arithmetic unit for secp256k1 field operations - + Performs add, sub, mul, inv modulo either field prime p or curve order n. Operands are read from and results written to an external 32×256-bit register file. - + Operations (op input): 0 = add: r[addr_out] <- r[addr_a] + r[addr_b] mod m 1 = sub: r[addr_out] <- r[addr_a] - r[addr_b] mod m 2 = mul: r[addr_out] <- r[addr_a] * r[addr_b] mod m 3 = inv: r[addr_out] <- r[addr_a]^(-1) mod m (addr_b ignored) - + Modulus selection (prime_sel): 0 = prime_p, 1 = prime_n - + Protocol: 1. Set addr_a, addr_b, addr_out, op, prime_sel; pulse start 2. Provide reg_read_data_a/b in response to reg_read_addr_a/b 3. Wait for done_ pulse; result written via reg_write_* signals 4. 
For inv, check inv_exists to confirm inverse was found - + State machine: Idle -> Load -> Capture -> Compute -> Write -> Done -> Idle *) @@ -28,12 +28,12 @@ module Config = struct let width = 256 let num_registers = 32 let reg_addr_width = 5 - + (* secp256k1 field prime: p = 2^256 - 2^32 - 977 *) let prime_p = Z.of_string "115792089237316195423570985008687907853269984665640564039457584007908834671663" (* secp256k1 curve order - CORRECTED *) let prime_n = Z.of_string "115792089237316195423570985008687907852837564279074904382605163141518161494337" - + let z_to_constant z = let hex_str = Z.format "%x" z in let padded = String.make ((width / 4) - String.length hex_str) '0' ^ hex_str in @@ -91,9 +91,9 @@ end let create scope (i : _ I.t) = let open Always in let ( -- ) = Scope.naming scope in - + let spec = Reg_spec.create ~clock:i.clock ~clear:i.clear () in - + (* State machine *) let sm = State_machine.create (module State) spec ~enable:vdd in @@ -103,47 +103,47 @@ let create scope (i : _ I.t) = let addr_a_reg = Variable.reg spec ~width:Config.reg_addr_width in let addr_b_reg = Variable.reg spec ~width:Config.reg_addr_width in let addr_out_reg = Variable.reg spec ~width:Config.reg_addr_width in - + (* Captured operands *) let operand_a = Variable.reg spec ~width:Config.width in let operand_b = Variable.reg spec ~width:Config.width in - + (* Operation start signals *) - let start_add = Variable.reg spec ~width:1 in - let start_sub = Variable.reg spec ~width:1 in + let mod_add_valid = Variable.reg spec ~width:1 in + let mod_sub_valid = Variable.reg spec ~width:1 in let start_mul = Variable.reg spec ~width:1 in let start_inv = Variable.reg spec ~width:1 in - + (* Result capture *) let result_reg = Variable.reg spec ~width:Config.width in let inv_exists_reg = Variable.reg spec ~width:1 in - + (* Output registers *) let reg_write_enable = Variable.reg spec ~width:1 in let done_flag = Variable.reg spec ~width:1 in - + (* Prime constants *) let prime_p_const = 
Signal.of_constant (Config.z_to_constant Config.prime_p) in let prime_n_const = Signal.of_constant (Config.z_to_constant Config.prime_n) in let selected_prime = mux2 prime_sel_reg.value prime_n_const prime_p_const in - + (* Count significant bits in operand_b for multiplication optimization *) (* Simple approach: use full width, or implement leading zero count *) let num_bits_for_mul = of_int ~width:9 Config.width in - + (* Instantiate arithmetic modules *) let mod_addsub_out = Mod_add.ModAdd.create (Scope.sub_scope scope "mod_add") { Mod_add.ModAdd.I. - clock = i.clock - ; clear = i.clear - ; start = start_add.value |: start_sub.value - ; a = operand_a.value - ; b = operand_b.value - ; modulus = selected_prime - ; subtract = start_sub.value + clock = i.clock + ; clear = i.clear + ; valid = mod_add_valid.value |: mod_sub_valid.value + ; a = operand_a.value + ; b = operand_b.value + ; modulus = selected_prime + ; subtract = mux2 (op_reg.value ==: of_int ~width:2 Op.sub) vdd gnd } in - + let mod_mul_out = Mod_mul.ModMul.create (Scope.sub_scope scope "mod_mul") { Mod_mul.ModMul.I. clock = i.clock @@ -155,7 +155,7 @@ let create scope (i : _ I.t) = ; num_bits = num_bits_for_mul } in - + let mod_inv_out = Mod_inv.ModInv.create (Scope.sub_scope scope "mod_inv") { Mod_inv.ModInv.I. 
clock = i.clock @@ -165,9 +165,9 @@ let create scope (i : _ I.t) = ; modulus = selected_prime } in - + (* Mux results based on operation *) - let op_result = + let op_result = mux op_reg.value [ mod_addsub_out.result; mod_addsub_out.result; @@ -175,25 +175,24 @@ let create scope (i : _ I.t) = mod_inv_out.result; ] in - - let op_valid = + + let op_ready = mux op_reg.value [ - mod_addsub_out.valid; - mod_addsub_out.valid; + mod_addsub_out.ready; + mod_addsub_out.ready; mod_mul_out.valid; mod_inv_out.valid; ] in - + compile [ (* Default: clear pulse signals *) - start_add <-- gnd; - start_sub <-- gnd; + (* TODO move start_mul, start_inv clear to Compute step as well when updated *) start_mul <-- gnd; start_inv <-- gnd; reg_write_enable <-- gnd; done_flag <-- gnd; - + sm.switch [ State.Idle, [ when_ i.start [ @@ -206,43 +205,47 @@ let create scope (i : _ I.t) = sm.set_next Load; ]; ]; - + State.Load, [ (* Wait one cycle for register file to provide data *) sm.set_next Capture; ]; - + State.Capture, [ operand_a <-- i.reg_read_data_a; operand_b <-- i.reg_read_data_b; - + switch op_reg.value [ - of_int ~width:2 Op.add, [ start_add <-- vdd ]; - of_int ~width:2 Op.sub, [ start_sub <-- vdd ]; + of_int ~width:2 Op.add, [ mod_add_valid <-- vdd ]; + of_int ~width:2 Op.sub, [ mod_sub_valid <-- vdd ]; of_int ~width:2 Op.mul, [ start_mul <-- vdd ]; of_int ~width:2 Op.inv, [ start_inv <-- vdd ]; ]; - + sm.set_next Compute; ]; State.Compute, [ (* Simply wait for operation to complete *) - when_ op_valid [ + when_ op_ready [ + (* Clear valid signals *) + mod_add_valid <-- gnd; + mod_sub_valid <-- gnd; + result_reg <-- op_result; inv_exists_reg <-- mod_inv_out.exists; sm.set_next Write; ]; ]; - + State.Write, [ (* Write result to register file *) reg_write_enable <-- vdd; sm.set_next Done; ]; - + State.Done, [ done_flag <-- vdd; if_ i.start [ @@ -259,10 +262,10 @@ State.Done, [ ]; ]; ]; - + (* Busy when not idle *) let busy = ~:(sm.is Idle) in - + { O. 
busy = busy -- "busy" ; done_ = done_flag.value -- "done" diff --git a/src/mod_add.ml b/src/mod_add.ml index 53911e0..0560eb4 100644 --- a/src/mod_add.ml +++ b/src/mod_add.ml @@ -7,18 +7,18 @@ module Config = struct let width = 256 end -(* Modular Addition/Subtraction Module (combinatorial adder variant) +(* Modular Addition/Subtraction Module Performs (a ± b) mod n where n is provided as an input. - - Uses comb_add for the initial a ± b in a single cycle - - 2-cycle operation: Free (latch + add) → Modulus_adjust (modular reduction) + - Uses an internal comb_add for arithmetic + - 2-cycle operation: Add (latch + add) → Adjust (modular reduction) - Automatic modular reduction *) module ModAdd = struct module State = struct type t = - | Free - | Modulus_adjust + | Add + | Adjust [@@deriving sexp_of, compare, enumerate] end @@ -26,7 +26,7 @@ module ModAdd = struct type 'a t = { clock : 'a ; clear : 'a - ; start : 'a + ; valid : 'a ; a : 'a [@bits Config.width] (* First operand *) ; b : 'a [@bits Config.width] (* Second operand *) ; modulus : 'a [@bits Config.width] (* Modulus n *) @@ -38,7 +38,7 @@ module ModAdd = struct module O = struct type 'a t = { result : 'a [@bits Config.width] - ; valid : 'a + ; ready : 'a (* 1 result is valid *) } [@@deriving sexp_of, hardcaml] end @@ -46,7 +46,7 @@ module ModAdd = struct let create scope (i : _ I.t) = let open Always in let ( -- ) = Scope.naming scope in - + let spec = Reg_spec.create ~clock:i.clock ~clear:i.clear () in (* State machine *) @@ -55,44 +55,39 @@ module ModAdd = struct (* Registers for computation *) let result_ab = Variable.reg spec ~width:Config.width in let carry_ab = Variable.reg spec ~width:1 in - let modulus_reg = Variable.reg spec ~width:Config.width in - let is_subtract = Variable.reg spec ~width:1 in - - (* Output registers *) - let result = Variable.reg spec ~width:Config.width in - let valid = Variable.reg spec ~width:1 in - - (* Mux comb_add inputs by state: - - Free: i.a ± i.b (initial operation) - 
- Modulus_adjust: result_ab ± modulus (correction: add modulus if sub underflow, - subtract modulus if sum higher than modulus) *) - let in_modulus_adjust = sm.is Modulus_adjust in - let comb_a = mux2 in_modulus_adjust result_ab.value i.a in - let comb_b = mux2 in_modulus_adjust modulus_reg.value i.b in + + (* Adder inputs muxed by state: + - Add: i.a ± i.b (initial operation) + - Adjust: result_ab ± modulus (correction step) *) + let in_adjust = sm.is Adjust in + let adder_a = mux2 in_adjust result_ab.value i.a in + let adder_b = mux2 in_adjust i.modulus i.b in (* modulus adjust uses the opposite of the original operation - add if a-b, subtract if a+b *) - let comb_subtract = mux2 in_modulus_adjust ~:(is_subtract.value) i.subtract in + let adder_subtract = mux2 in_adjust ~:(i.subtract) i.subtract in + let adder_ready = vdd in (* currently using combinatorial adder, so no delay *) let comb_add_out = Comb_add.CombAdd.create (Scope.sub_scope scope "comb_add") - { Comb_add.CombAdd.I.a = comb_a; b = comb_b; subtract = comb_subtract } + { Comb_add.CombAdd.I.a = adder_a; b = adder_b; subtract = adder_subtract } in + let result_w = Variable.wire ~default:(zero Config.width) in + let ready_w = Variable.wire ~default:gnd in + compile [ sm.switch [ - State.Free, [ - valid <-- gnd; - when_ i.start [ + State.Add, [ + when_ (i.valid &: adder_ready) [ + (* update register values and select next state *) result_ab <-- comb_add_out.result; carry_ab <-- comb_add_out.carry_out; - modulus_reg <-- i.modulus; - is_subtract <-- i.subtract; - sm.set_next Modulus_adjust; + sm.set_next Adjust; ]; ]; - - State.Modulus_adjust, [ - - (* comb_add is computing result_ab ± modulus_reg here. 
*) + + State.Adjust, [ + + (* adder is computing result_ab ± modulus here *) (* For add: -- reduce if carry_ab=1 (a+b overflowed 256 bits, so a+b >= 2^256 > n) NOTE: comb_add_out.carry_out is not valid in this case @@ -104,22 +99,24 @@ module ModAdd = struct let sub_needs_adjust = carry_ab.value in let final_result = - mux2 is_subtract.value + mux2 i.subtract (mux2 sub_needs_adjust comb_add_out.result result_ab.value) (mux2 add_needs_adjust comb_add_out.result result_ab.value) in - proc [ - result <-- final_result; - valid <-- vdd; - sm.set_next Free; + when_ adder_ready [ + (* combinatorial: drive result output *) + result_w <-- final_result; + ready_w <-- vdd; (* No clear needed, Variable.wire ~default takes care *) + (* select next state *) + sm.set_next Add; ]; ]; ]; ]; - { O.result = result.value -- "result" - ; valid = valid.value -- "valid" + { O.result = result_w.value -- "result" + ; ready = ready_w.value -- "ready" } end diff --git a/test/test_mod_add.ml b/test/test_mod_add.ml index 4fe5434..ecbde41 100644 --- a/test/test_mod_add.ml +++ b/test/test_mod_add.ml @@ -2,10 +2,10 @@ open Base open Hardcaml let test () = - let test_modulus_z = + let test_modulus_z = Z.of_string "115792089237316195423570985008687907853269984665640564039457584007908834671663" in - + Stdio.printf "=== ModAdd Hardware Test (256-bit modular add/sub) ===\n\n"; Stdio.printf "Test Modulus (n) = %s\n\n" (Z.to_string test_modulus_z); @@ -14,7 +14,7 @@ let test () = let sim = Sim.create (Mod_add.ModAdd.create scope) in let inputs = Cyclesim.inputs sim in - let outputs = Cyclesim.outputs sim in + let outputs = Cyclesim.outputs ~clock_edge:Before sim in let z_to_bits z = let hex_str = Z.format "%x" z in @@ -35,14 +35,14 @@ let test () = Stdio.printf " b = %s\n" (Z.to_string b_z); Stdio.printf " n = %s\n" (Z.to_string modulus_z); - let expected_z = + let expected_z = let raw = if is_sub then Z.(a_z - b_z) else Z.(a_z + b_z) in Z.(erem raw modulus_z) in Stdio.printf " expected = %s\n" 
(Z.to_string expected_z); inputs.clear := Bits.vdd; - inputs.start := Bits.gnd; + inputs.valid := Bits.gnd; inputs.a := Bits.zero Mod_add.Config.width; inputs.b := Bits.zero Mod_add.Config.width; inputs.modulus := Bits.zero Mod_add.Config.width; @@ -56,30 +56,28 @@ let test () = inputs.b := z_to_bits b_z; inputs.modulus := z_to_bits modulus_z; inputs.subtract := if is_sub then Bits.vdd else Bits.gnd; - inputs.start := Bits.vdd; - Cyclesim.cycle sim; - - inputs.start := Bits.gnd; + inputs.valid := Bits.vdd; let max_cycles = 20 in let rec wait cycle_count = if cycle_count >= max_cycles then begin Stdio.printf " ERROR: Timeout after %d cycles\n\n" max_cycles; false - end else if Bits.to_bool !(outputs.valid) then begin + end else if Bits.to_bool !(outputs.ready) then begin + inputs.valid := Bits.gnd; let result_z = bits_to_z !(outputs.result) in - + Stdio.printf " Completed in %d cycles\n" cycle_count; Stdio.printf " result = %s\n" (Z.to_string result_z); - + let verified = Z.equal result_z expected_z in - + if verified then Stdio.printf " Verification: PASS ✓\n" else - Stdio.printf " Verification: FAIL ✗ (diff: %s)\n" + Stdio.printf " Verification: FAIL ✗ (diff: %s)\n" (Z.to_string Z.(result_z - expected_z)); - + Stdio.printf "\n"; verified end else begin @@ -87,7 +85,8 @@ let test () = wait (cycle_count + 1) end in - wait 0 + Cyclesim.cycle sim; + wait 1 in let test_with_default_modulus name a_z b_z is_sub = @@ -97,29 +96,29 @@ let test () = let results = [ test_with_default_modulus "Simple addition" (Z.of_int 100) (Z.of_int 200) false; - + test_with_default_modulus "Addition with modular wrap" (Z.of_string "115792089237316195423570985008687907853269984665640564039457584007908834671600") (Z.of_int 100) false; - + test_with_default_modulus "Simple subtraction" (Z.of_int 500) (Z.of_int 300) true; - + test_with_default_modulus "Subtraction requiring modular correction" (Z.of_int 100) (Z.of_int 200) true; - + test_with_default_modulus "Subtraction from modulus-1" 
Z.(test_modulus_z - one) (Z.of_int 10) true; - + test_with_default_modulus "Add zero" (Z.of_int 12345) Z.zero false; - + test_with_default_modulus "Subtract zero" (Z.of_int 12345) Z.zero true; - + test_case "Different modulus: small prime 997" (Z.of_int 500) (Z.of_int 600) (Z.of_int 997) false; - + test_case "BN254 subtraction wrap" (Z.of_int 10) (Z.of_int 20) @@ -138,4 +137,4 @@ let test () = else Stdio.printf "\n✗ Some tests failed\n" -let () = test () \ No newline at end of file +let () = test () From 6c439a27b008d3e73cabbf09af8c59efbec58f47 Mon Sep 17 00:00:00 2001 From: Peter Drotos Date: Wed, 4 Mar 2026 15:06:35 +0100 Subject: [PATCH 2/2] Update mod_mul & mod_inv to use mod_add instance --- README.md | 64 ++++----- src/arith.ml | 56 ++++++-- src/mod_add.ml | 20 +-- src/mod_inv.ml | 325 +++++++++++++++++++++---------------------- src/mod_mul.ml | 141 ++++++++----------- test/test_mod_inv.ml | 40 +++++- test/test_mod_mul.ml | 69 ++++++--- 7 files changed, 396 insertions(+), 319 deletions(-) diff --git a/README.md b/README.md index 4179898..3bfc018 100644 --- a/README.md +++ b/README.md @@ -71,24 +71,24 @@ flowchart TB subgraph external_left[" "] AUTH["License
Authority"]:::external end - + subgraph SECURITY_BLOCK["SECURITY BLOCK"] direction TB SL["Security Logic
(State Machine)"]:::security - + subgraph submodules[" "] direction LR TRNG["TRNG
256-bit"]:::trng ECDSA["ECDSA
secp256k1"]:::ecdsa ALLOW["Allowance
64-bit"]:::allowance end - + subgraph datapath[" "] direction LR ADDER["Int8 Add"]:::adder AND["AND Gate"]:::andgate end - + SL -->|request_new| TRNG TRNG -->|"nonce, valid"| SL SL -->|start| ECDSA @@ -97,17 +97,17 @@ flowchart TB ALLOW -->|enabled| AND ADDER --> AND end - + subgraph external_io[" "] direction LR WIN["Workload
Input"]:::external WOUT["Workload
Output"]:::external end - + AUTH <-->|"license_submit, r, s
nonce, ready"| SL WIN --> ADDER AND --> WOUT - + classDef external fill:#fff,stroke:#333,stroke-dasharray: 5 5 classDef security fill:#cce5ff,stroke:#004085 classDef trng fill:#c3e6cb,stroke:#155724 @@ -309,22 +309,22 @@ Here's an expanded section on the ECDSA and modular arithmetic architecture to a flowchart TB subgraph ECDSA["ECDSA Verification Block"] direction LR - + subgraph left[" "] direction TB - + subgraph SM["State Machine"] direction TB SM_PREP["Prep Phase
u1, u2 computation"] SM_LOOP["Scalar Mult Loop
256 iterations"] SM_FIN["Finalize
projective to affine"] SM_CMP["Compare
x_affine == r ?"] - + SM_PREP --> SM_LOOP SM_LOOP --> SM_FIN SM_FIN --> SM_CMP end - + subgraph REGS["Register File --- 17 x 256-bit"] direction LR R_PT["Point Coords
X1 Y1 Z1
X2 Y2 Z2
X3 Y3 Z3"] @@ -332,40 +332,40 @@ flowchart TB R_PRM["Params
a, b3"] end end - + subgraph right[" "] direction TB - + subgraph ARITH["Modular Arithmetic Unit"] direction TB - - subgraph ops[" "] - direction LR - INV["Inverse
Ext Euclidean"] - MUL["Multiply
shift-and-add"] - ADDSUB["Add - Sub"] + + subgraph INV["Inverse
Ext Euclidean"] + direction TB + end + + subgraph MUL["Multiply
shift-and-add"] + direction TB end - - subgraph shared["Shared Datapath"] + + subgraph ADDSUB["Add - Sub"] direction TB MOD["Modulus Select
prime p or order n"] ADD256["256-bit Adder"] MOD --> ADD256 end - - INV --> shared - MUL --> shared - ADDSUB --> shared + + INV --> ADDSUB + MUL --> ADDSUB end end - + SM <-->|"start, op
done"| ARITH REGS <-->|"read A B
write result"| ARITH end - + EXT_IN["Inputs:
z, r, s"] --> ECDSA ECDSA --> EXT_OUT["Output:
valid"] - + classDef outer fill:#f0f7ff,stroke:#2563eb,stroke-width:2px,color:#1e40af classDef arithbox fill:#fef9e7,stroke:#b7950b,stroke-width:2px,color:#7d6608 classDef smbox fill:#e8f8f5,stroke:#1abc9c,stroke-width:2px,color:#0e6655 @@ -375,7 +375,7 @@ flowchart TB classDef subunit fill:#fdebd0,stroke:#e67e22,stroke-width:1px,color:#a04000 classDef sharedbox fill:#fcf3cf,stroke:#d4ac0d,stroke-width:1px,color:#9a7d0a classDef external fill:#ffffff,stroke:#5d6d7e,stroke-width:1px,stroke-dasharray: 5 5,color:#2c3e50 - + class ECDSA outer class ARITH arithbox class SM smbox @@ -458,8 +458,8 @@ All operations work over 256-bit operands and can use either the field prime `p` The arithmetic unit interfaces with a 17-register file. Operations are started with a pulse and signal completion via `done_`. Typical cycle counts: - Add/Sub: 2–3 cycles -- Mul: ~250-750 cycles (bit-serial, varies with y input) -- Inv: ~1000–1500 cycles (varies with input) +- Mul: ~500-1000 cycles (bit-serial, varies with y input) +- Inv: ~2000–3000 cycles (varies with input) ### State Machine Overview @@ -487,7 +487,7 @@ Idle → Prep_op → Loop ⟷ Load → Run_add → Finalize_op → Compare → D ### Cycle Count -Total verification takes approximately 3–4 million cycles, dominated by the ~256 point operations in the scalar multiplication loop. At 1 GHz, this is in milliseconds—negligible compared to the licensing interval (minutes to days). +Total verification takes approximately 5 million cycles, dominated by the ~256 point operations in the scalar multiplication loop. At 1 GHz, this is in milliseconds—negligible compared to the licensing interval (minutes to days). 
### Hardcoded Constants diff --git a/src/arith.ml b/src/arith.ml index af02ebe..3cbde80 100644 --- a/src/arith.ml +++ b/src/arith.ml @@ -132,15 +132,22 @@ let create scope (i : _ I.t) = let num_bits_for_mul = of_int ~width:9 Config.width in (* Instantiate arithmetic modules *) - let mod_addsub_out = Mod_add.ModAdd.create (Scope.sub_scope scope "mod_add") + + (* Forward declaration wires for mod_add inputs, driven below after mod_mul is available *) + let mod_add_valid_w = Signal.wire 1 in + let mod_add_a_w = Signal.wire Config.width in + let mod_add_b_w = Signal.wire Config.width in + let mod_add_subtract_w = Signal.wire 1 in + + let mod_add_out = Mod_add.ModAdd.create (Scope.sub_scope scope "mod_add") { Mod_add.ModAdd.I. clock = i.clock ; clear = i.clear - ; valid = mod_add_valid.value |: mod_sub_valid.value - ; a = operand_a.value - ; b = operand_b.value + ; valid = mod_add_valid_w + ; a = mod_add_a_w + ; b = mod_add_b_w ; modulus = selected_prime - ; subtract = mux2 (op_reg.value ==: of_int ~width:2 Op.sub) vdd gnd + ; subtract = mod_add_subtract_w } in @@ -153,6 +160,8 @@ let create scope (i : _ I.t) = ; y = operand_b.value ; modulus = selected_prime ; num_bits = num_bits_for_mul + ; mod_add_result = mod_add_out.result + ; mod_add_ready = mod_add_out.ready } in @@ -163,14 +172,43 @@ let create scope (i : _ I.t) = ; start = start_inv.value ; x = operand_a.value ; modulus = selected_prime + ; mod_add_result = mod_add_out.result + ; mod_add_ready = mod_add_out.ready + ; mod_add_adjusted = mod_add_out.adjusted } in + (* Drive mod_add inputs, muxed by operation *) + assign mod_add_valid_w (mux op_reg.value [ + mod_add_valid.value; (* add *) + mod_sub_valid.value; (* sub *) + mod_mul_out.mod_add_valid; (* mul *) + mod_inv_out.mod_add_valid; (* inv *) + ]); + assign mod_add_a_w (mux op_reg.value [ + operand_a.value; (* add *) + operand_a.value; (* sub *) + mod_mul_out.mod_add_a; (* mul *) + mod_inv_out.mod_add_a; (* inv *) + ]); + assign mod_add_b_w (mux op_reg.value 
[ + operand_b.value; (* add *) + operand_b.value; (* sub *) + mod_mul_out.mod_add_b; (* mul *) + mod_inv_out.mod_add_b; (* inv *) + ]); + assign mod_add_subtract_w (mux op_reg.value [ + gnd; (* add *) + vdd; (* sub *) + mod_mul_out.mod_add_subtract; (* mul *) + mod_inv_out.mod_add_subtract; (* inv *) + ]); + (* Mux results based on operation *) let op_result = mux op_reg.value [ - mod_addsub_out.result; - mod_addsub_out.result; + mod_add_out.result; + mod_add_out.result; mod_mul_out.result; mod_inv_out.result; ] @@ -178,8 +216,8 @@ let create scope (i : _ I.t) = let op_ready = mux op_reg.value [ - mod_addsub_out.ready; - mod_addsub_out.ready; + mod_add_out.ready; + mod_add_out.ready; mod_mul_out.valid; mod_inv_out.valid; ] diff --git a/src/mod_add.ml b/src/mod_add.ml index 0560eb4..fc677fd 100644 --- a/src/mod_add.ml +++ b/src/mod_add.ml @@ -37,8 +37,9 @@ module ModAdd = struct module O = struct type 'a t = - { result : 'a [@bits Config.width] - ; ready : 'a (* 1 result is valid *) + { result : 'a [@bits Config.width] + ; ready : 'a (* 1 result is valid *) + ; adjusted : 'a (* 1 if modular correction was applied this cycle, valid with ready *) } [@@deriving sexp_of, hardcaml] end @@ -71,8 +72,9 @@ module ModAdd = struct { Comb_add.CombAdd.I.a = adder_a; b = adder_b; subtract = adder_subtract } in - let result_w = Variable.wire ~default:(zero Config.width) in - let ready_w = Variable.wire ~default:gnd in + let result_w = Variable.wire ~default:(zero Config.width) in + let ready_w = Variable.wire ~default:gnd in + let adjusted_w = Variable.wire ~default:gnd in compile [ sm.switch [ @@ -106,8 +108,9 @@ module ModAdd = struct when_ adder_ready [ (* combinatorial: drive result output *) - result_w <-- final_result; - ready_w <-- vdd; (* No clear needed, Variable.wire ~default takes care *) + result_w <-- final_result; + ready_w <-- vdd; (* No clear needed, Variable.wire ~default takes care *) + adjusted_w <-- mux2 i.subtract sub_needs_adjust add_needs_adjust; (* select 
next state *) sm.set_next Add; ]; @@ -115,8 +118,9 @@ module ModAdd = struct ]; ]; - { O.result = result_w.value -- "result" - ; ready = ready_w.value -- "ready" + { O.result = result_w.value -- "result" + ; ready = ready_w.value -- "ready" + ; adjusted = adjusted_w.value -- "adjusted" } end diff --git a/src/mod_inv.ml b/src/mod_inv.ml index f2c0bc3..276b023 100644 --- a/src/mod_inv.ml +++ b/src/mod_inv.ml @@ -6,25 +6,26 @@ module Config = struct let width = 256 end -(* Binary Extended GCD for modular inverse computation - +(* Binary Extended GCD for modular inverse computation + REQUIREMENT: This implementation assumes the modulus is an odd prime. - For odd prime modulus and any x coprime to it, gcd(x, modulus) = 1 - Since modulus is odd, at most one of {x, modulus} can be even - Therefore we never need to handle the "both even" case or factor out common powers of 2 - This simplifies the algorithm significantly compared to the general case + + All arithmetic is driven through an external mod_add instance (shared with add/sub/mul). 
*) module ModInv = struct module State = struct type t = | Idle | Op_sel (* decide which operation to start next *) - | Div2_xs (* divide x and s by 2, s might need adjusting *) - | Div2_yu (* divide y and u by 2, u might need adjusting *) - | Sub_rems (* subtract remainders *) - | Sub_rems_reverse (* reverse subtract remainders, we don't know in advance which direction needed *) + | Div2_add (* divide appropriate remainder and corresponding coefficient by 2, start adjusting the coefficient (c>>1 + mod>>1) if needed *) + | Div2_p1 (* finish adjusting the coefficient (c=c+1) if needed *) + | Sub_rems (* subtract remainders: x - y *) + | Sub_rems_reverse (* reverse subtract: y - x (when x < y) *) | Sub_coeffs (* subtract coefficients *) - | Adjust_coeff (* adjust the reduced coefficient after subtraction *) | Done [@@deriving sexp_of, compare, enumerate] end @@ -36,70 +37,61 @@ module ModInv = struct ; start : 'a ; x : 'a [@bits Config.width] ; modulus : 'a [@bits Config.width] + ; mod_add_result : 'a [@bits Config.width] + ; mod_add_ready : 'a + ; mod_add_adjusted : 'a (* underflow/overflow flag, valid when mod_add_ready *) } [@@deriving sexp_of, hardcaml] end -module O = struct - type 'a t = - { result : 'a [@bits Config.width] - ; valid : 'a - ; exists : 'a - (* Debug outputs *) - ; dbg_x : 'a [@bits Config.width] - ; dbg_y : 'a [@bits Config.width] - ; dbg_s : 'a [@bits Config.width] - ; dbg_u : 'a [@bits Config.width] - } - [@@deriving sexp_of, hardcaml] -end + module O = struct + type 'a t = + { result : 'a [@bits Config.width] + ; valid : 'a + ; exists : 'a + ; mod_add_valid : 'a + ; mod_add_a : 'a [@bits Config.width] + ; mod_add_b : 'a [@bits Config.width] + ; mod_add_subtract : 'a + (* Debug outputs *) + ; dbg_x : 'a [@bits Config.width] + ; dbg_y : 'a [@bits Config.width] + ; dbg_s : 'a [@bits Config.width] + ; dbg_u : 'a [@bits Config.width] + } + [@@deriving sexp_of, hardcaml] + end let create scope (i : _ I.t) = let open Always in let ( -- ) = 
Scope.naming scope in let width = Config.width in - + let spec = Reg_spec.create ~clock:i.clock ~clear:i.clear () in - + let sm = State_machine.create (module State) spec ~enable:vdd in let x = Variable.reg spec ~width in (* remainder *) let y = Variable.reg spec ~width in (* remainder *) let s = Variable.reg spec ~width in (* coefficient for x *) let u = Variable.reg spec ~width in (* coefficient for y *) - + let modulus_reg = Variable.reg spec ~width:width in - (* helper reg for the FSM *) - let reduced_xny = Variable.reg spec ~width:1 in (* 1 = x was reduced (x-y), 0 = y was reduced (y-x) *) - + (* helper regs for the FSM *) + let reduced_xny = Variable.reg spec ~width:1 in (* 1 = x was reduced (x-y), 0 = y was reduced (y-x) *) + let div2_xny = Variable.reg spec ~width:1 in (* 1 = divide x, 0 = divide y *) + let div2_coeff_odd = Variable.reg spec ~width:1 in (* was original coefficient odd? *) + let result = Variable.reg spec ~width in let valid = Variable.reg spec ~width:1 in let exists = Variable.reg spec ~width:1 in - (* 256 bit adder instance *) - let comb_add_a = Variable.wire ~default:(zero width) in - let comb_add_b = Variable.wire ~default:(zero width) in - let comb_subtract = Variable.wire ~default:gnd in - let comb_add_out = - Comb_add.CombAdd.create (Scope.sub_scope scope "comb_add") - { Comb_add.CombAdd.I.a = comb_add_a.value; - b = comb_add_b.value; - subtract = comb_subtract.value } - in - - (* Compute new coefficient after a div2 step. 
*) - (* if c is even: *) - (* c = c/2 *) - (* else: *) - (* c = (c + mod)/2 *) - let div2_new_coeff c = - (* c + mod, using adder instance *) - let sum_c_mod = comb_add_out.carry_out @: comb_add_out.result in (* sum might be +1 bit wide, use carry *) - (* * if c is even: c = c/2 else: c = (c + mod)/2 *) - mux2 ~:(lsb c) (srl c 1) - (sel_top sum_c_mod width) (* using sel_top instead of srl as sum is +1 bit wide *) - in + (* Output wires for driving external mod_add *) + let mod_add_valid_w = Variable.wire ~default:gnd in + let mod_add_a_w = Variable.wire ~default:(zero Config.width) in + let mod_add_b_w = Variable.wire ~default:(zero Config.width) in + let mod_add_subtract_w = Variable.wire ~default:gnd in compile [ sm.switch [ @@ -128,7 +120,6 @@ end modulus_reg <-- i.modulus; s <-- of_int ~width:width 1; u <-- zero width; - sm.set_next Op_sel; ]; ]; @@ -147,60 +138,76 @@ end valid <-- vdd; sm.set_next Done; ] (* else, next operation selection below *) - @@ elif (~:(lsb x.value)) [sm.set_next Div2_xs] (* if x even -> divide x *) - @@ elif (~:(lsb y.value)) [sm.set_next Div2_yu] (* if y even -> divide y *) - @@ [sm.set_next Sub_rems]; (* else -> subtract remainders *) + @@ elif (~:(lsb x.value)) [ (* if x even -> divide x (xs pair) *) + div2_xny <-- vdd; + div2_coeff_odd <-- lsb s.value; (* needed in Div2_p1 *) + sm.set_next Div2_add; + ] + @@ elif (~:(lsb y.value)) [ (* if y even -> divide y (yu pair) *) + div2_xny <-- gnd; + div2_coeff_odd <-- lsb u.value; (* needed in Div2_p1 *) + sm.set_next Div2_add; + ] + @@ [ (* else -> subtract remainders *) + sm.set_next Sub_rems + ]; ]; - State.Div2_xs, [ - (* x = x/2 *) - (* if s is even: *) - (* s = s/2 *) - (* else: *) - (* s = (s + mod)/2 *) + State.Div2_add, [ + (* r = r/2 r = x or y *) + (* if c is even: c = s or u *) + (* c = c/2 *) + (* else: *) + (* c = (c + mod)/2 *) + (* *) + (* CAUTION! (c + mod)/2 has 2 steps: *) + (* 1. Div2_add: *) + (* c = (c >> 1) + (mod >> 1) *) + (* 2. 
Div2_p1: *) + (* c = c + 1 *) + + let coeff = mux2 div2_xny.value s.value u.value in - (* x = x/2 *) - let new_x = srl x.value 1 in - - (* see div2_new_coeff *) - let new_s = div2_new_coeff s.value in - proc [ - (* combinatoinal drive the adder inputs *) - comb_add_a <-- s.value; - comb_add_b <-- modulus_reg.value; - comb_subtract <-- gnd; - - (* update register values and select next state *) - x <-- new_x; - s <-- new_s; - sm.set_next Op_sel; + (* Combinationally drive mod_add: c>>1 + mod>>1 *) + mod_add_valid_w <-- vdd; + mod_add_a_w <-- srl coeff 1; + mod_add_b_w <-- srl modulus_reg.value 1; + mod_add_subtract_w <-- gnd; + + when_ i.mod_add_ready [ + (* Shift the remainder (fires exactly once) *) + if_ div2_xny.value [ + x <-- srl x.value 1; + s <-- mux2 (lsb s.value) i.mod_add_result (srl s.value 1); + ] [ + y <-- srl y.value 1; + u <-- mux2 (lsb u.value) i.mod_add_result (srl u.value 1); + ]; + div2_coeff_odd <-- lsb coeff; (* needed for Div2_p1 *) + sm.set_next Div2_p1; + ]; ]; ]; - State.Div2_yu, [ - (* y = y/2 *) - (* if u is even: *) - (* u = u/2 *) - (* else: *) - (* u = (u + mod)/2 *) - - (* y = y/2 *) - let new_y = srl y.value 1 in - - (* see div2_new_coeff *) - let new_u = div2_new_coeff u.value in + State.Div2_p1, [ + (* see Div2_add for what this state implements *) proc [ - (* combinatoinal drive the adder inputs *) - comb_add_a <-- u.value; - comb_add_b <-- modulus_reg.value; - comb_subtract <-- gnd; - - (* update register values and select next state *) - y <-- new_y; - u <-- new_u; - sm.set_next Op_sel; + (* Combinationally drive mod_add: c + 1 *) + mod_add_valid_w <-- vdd; + mod_add_a_w <-- mux2 div2_xny.value s.value u.value; + mod_add_b_w <-- of_int ~width:Config.width 1; + mod_add_subtract_w <-- gnd; + + when_ i.mod_add_ready [ + if_ div2_xny.value [ + s <-- mux2 div2_coeff_odd.value i.mod_add_result s.value; + ] [ + u <-- mux2 div2_coeff_odd.value i.mod_add_result u.value; + ]; + sm.set_next Op_sel; + ]; ]; ]; @@ -209,22 +216,23 @@ end (* x 
= x - y *) (* else: *) (* Sub_rems_reverse *) - + (* combinatoinal drive the adder inputs *) - comb_add_a <-- x.value; - comb_add_b <-- y.value; - comb_subtract <-- vdd; - - (* update register values and select next state *) - if_ ~:(comb_add_out.carry_out) [ - (* carry_out = 0 means no underflow (x was >= y), store results and move to coefficients *) - x <-- comb_add_out.result; - reduced_xny <-- vdd; (* Subtract reduced x *) - (* REVISIT is it okay to skip a state conditionally? *) - sm.set_next Sub_coeffs; - ] [ - (* underflow, need Sub_rems_reverse instead *) - sm.set_next Sub_rems_reverse; + mod_add_valid_w <-- vdd; + mod_add_a_w <-- x.value; + mod_add_b_w <-- y.value; + mod_add_subtract_w <-- vdd; + + when_ i.mod_add_ready [ + if_ ~:(i.mod_add_adjusted) [ + (* No underflow: x >= y, result is x - y *) + x <-- i.mod_add_result; + reduced_xny <-- vdd; (* x was reduced *) + sm.set_next Sub_coeffs; + ] [ + (* Underflow: x < y, need reverse subtraction *) + sm.set_next Sub_rems_reverse; + ]; ]; ]; @@ -233,14 +241,16 @@ end (* y = y - x *) (* combinatoinal drive the adder inputs *) - comb_add_a <-- y.value; - comb_add_b <-- x.value; - comb_subtract <-- vdd; - - (* update register values and select next state *) - y <-- comb_add_out.result; - reduced_xny <-- gnd; (* Subtract reduced y *) - sm.set_next Sub_coeffs; + mod_add_valid_w <-- vdd; + mod_add_a_w <-- y.value; + mod_add_b_w <-- x.value; + mod_add_subtract_w <-- vdd; + + when_ i.mod_add_ready [ + y <-- i.mod_add_result; + reduced_xny <-- gnd; (* y was reduced *) + sm.set_next Sub_coeffs; + ]; ]; State.Sub_coeffs, [ @@ -251,62 +261,39 @@ end (* u = u-s *) (* combinatorial drive the adder inputs *) - (* if x was reduced: attempt s - u, else: attempt u - s *) - comb_add_a <-- mux2 reduced_xny.value s.value u.value; - comb_add_b <-- mux2 reduced_xny.value u.value s.value; - comb_subtract <-- vdd; - - (* update register values and select next state *) - (* the result is always valid, even if underflowed, Adjust_coeff 
will fix *) - if_ reduced_xny.value [ - s <-- comb_add_out.result; - ] [ - u <-- comb_add_out.result; - ]; - (* if underflow, go to Adjust_coeff to add modulus *) - if_ comb_add_out.carry_out [ - sm.set_next Adjust_coeff; - ] [ + (* if x was reduced: do s - u, else: u - s *) + mod_add_valid_w <-- vdd; + mod_add_a_w <-- mux2 reduced_xny.value s.value u.value; + mod_add_b_w <-- mux2 reduced_xny.value u.value s.value; + mod_add_subtract_w <-- vdd; + + when_ i.mod_add_ready [ + if_ reduced_xny.value [ + s <-- i.mod_add_result; + ] [ + u <-- i.mod_add_result; + ]; sm.set_next Op_sel; ]; ]; - State.Adjust_coeff, [ - (* _old might have changed by now *) - (* if x_old >= y_old: *) - (* if s_old < u_old: *) - (* s = s_old - u_old + mod *) - (* else: *) - (* if u_old < s_old: *) - (* u = u_old - s_old + mod *) - - (* combinatorial drive the adder inputs *) - comb_add_a <-- mux2 reduced_xny.value s.value u.value; - comb_add_b <-- modulus_reg.value; - comb_subtract <-- gnd; - - (* update register values and select next state *) - if_ reduced_xny.value [ - s <-- comb_add_out.result; - ] [ - u <-- comb_add_out.result; - ]; - sm.set_next Op_sel; + State.Done, [ + valid <-- gnd; + sm.set_next Idle; (* Return to Idle after one cycle *) ]; - -State.Done, [ - valid <-- gnd; - sm.set_next Idle; (* Return to Idle after one cycle *) -]; ]; ]; -{ O.result = result.value -- "result" -; valid = valid.value -- "valid" -; exists = exists.value -- "exists" -; dbg_x = x.value -; dbg_y = y.value -; dbg_s = s.value -; dbg_u = u.value -} -end \ No newline at end of file + { O.result = result.value -- "result" + ; valid = valid.value -- "valid" + ; exists = exists.value -- "exists" + ; mod_add_valid = mod_add_valid_w.value -- "mod_add_valid" + ; mod_add_a = mod_add_a_w.value -- "mod_add_a" + ; mod_add_b = mod_add_b_w.value -- "mod_add_b" + ; mod_add_subtract = mod_add_subtract_w.value -- "mod_add_subtract" + ; dbg_x = x.value + ; dbg_y = y.value + ; dbg_s = s.value + ; dbg_u = u.value + } +end 
diff --git a/src/mod_mul.ml b/src/mod_mul.ml index d9197fe..5cd03d3 100644 --- a/src/mod_mul.ml +++ b/src/mod_mul.ml @@ -9,18 +9,18 @@ module Config = struct end (* Simple Modular Multiplication - + Computes (x * y) mod r where r is the modulus. - - Uses the standard shift-and-add algorithm with modular reduction at each step. *) + + Uses the standard shift-and-add algorithm with modular reduction at each step. + Drives an external mod_add instance for all arithmetic. *) module ModMul = struct module State = struct type t = | Idle | Init | Add - | Adjust - | Double_adjust + | Double | Done [@@deriving sexp_of, compare, enumerate] end @@ -34,6 +34,8 @@ module ModMul = struct ; y : 'a [@bits Config.width] (* Second multiplicand *) ; modulus : 'a [@bits Config.width] (* Modulus r *) ; num_bits : 'a [@bits Config.bit_count_width] (* Number of bits to process in y (0-256) *) + ; mod_add_result : 'a [@bits Config.width] (* result from external mod_add *) + ; mod_add_ready : 'a (* ready from external mod_add *) } [@@deriving sexp_of, hardcaml] end @@ -42,6 +44,10 @@ module ModMul = struct type 'a t = { result : 'a [@bits Config.width] (* (x * y) mod r *) ; valid : 'a (* High when result is valid *) + ; mod_add_valid : 'a + ; mod_add_a : 'a [@bits Config.width] + ; mod_add_b : 'a [@bits Config.width] + ; mod_add_subtract : 'a } [@@deriving sexp_of, hardcaml] end @@ -59,7 +65,6 @@ module ModMul = struct (* Result accumulator *) let result_acc = Variable.reg spec ~width:width in - let add_carry = Variable.reg spec ~width:1 in (* Multiplier (y) and bit counter *) let multiplier = Variable.reg spec ~width in @@ -67,27 +72,18 @@ module ModMul = struct (* Current value of x (gets doubled each iteration) *) let x_current = Variable.reg spec ~width:width in - - (* Stored modulus *) - let modulus_reg = Variable.reg spec ~width:width in + let num_bits_orig = Variable.reg spec ~width:bit_count_width in (* Output registers *) let result = Variable.reg spec ~width in let valid = 
Variable.reg spec ~width:1 in - (* 256 bit adder instance *) - let in_add_state = sm.is Add in - let in_double_adj_state = sm.is Double_adjust in - let x_doubled = sll x_current.value 1 in - - let comb_add_a = mux2 ~:in_double_adj_state result_acc.value x_doubled in - let comb_add_b = mux2 in_add_state x_current.value modulus_reg.value in - let comb_subtract = mux2 in_add_state gnd vdd in - let comb_add_out = - Comb_add.CombAdd.create (Scope.sub_scope scope "comb_add") - { Comb_add.CombAdd.I.a = comb_add_a; b = comb_add_b; subtract = comb_subtract } - in + (* Output wires for driving external mod_add *) + let mod_add_valid_w = Variable.wire ~default:gnd in + let mod_add_a_w = Variable.wire ~default:(zero Config.width) in + let mod_add_b_w = Variable.wire ~default:(zero Config.width) in + let mod_add_subtract_w = Variable.wire ~default:gnd in compile [ sm.switch [ @@ -109,7 +105,6 @@ module ModMul = struct sm.set_next Done; ] [ x_current <-- i.x; - modulus_reg <-- i.modulus; multiplier <-- i.y; num_bits_orig <-- i.num_bits; sm.set_next Init; @@ -125,65 +120,46 @@ module ModMul = struct ]; State.Add, [ - (* Check LSB of multiplier *) - let current_bit = lsb multiplier.value in - - (* If bit is set, add current x to result *) - let after_add = mux2 current_bit comb_add_out.result result_acc.value in - - proc [ - valid <-- gnd; - result_acc <-- after_add; - add_carry <-- mux2 current_bit comb_add_out.carry_out gnd; - sm.set_next Adjust; - ]; - ]; - - State.Adjust, [ - - (* Reduce if >= modulus: - -- if the addition had overflow (add_carry=1) - -- if subtracting the modulus does not underflow (comb_add_out.carry_out=0) *) - let add_reduce_needed = add_carry.value |: ~:(comb_add_out.carry_out) in - let after_add_reduce = mux2 add_reduce_needed comb_add_out.result result_acc.value in - - proc [ - valid <-- gnd; - result_acc <-- after_add_reduce; - add_carry <-- gnd; (* Clear carry as it is N/A here *) - sm.set_next Double_adjust; + if_ (lsb multiplier.value) [ + (* 
Combinatorially drive mod_add *) + mod_add_valid_w <-- vdd; + mod_add_a_w <-- result_acc.value; + mod_add_b_w <-- x_current.value; + mod_add_subtract_w <-- gnd; + + when_ i.mod_add_ready [ + result_acc <-- i.mod_add_result; + sm.set_next Double; + ]; + ] [ + sm.set_next Double; (* LSB = 0: skip addition *) ]; ]; - State.Double_adjust, [ - - (* Double x for next iteration and - reduce new x if >= modulus: - -- if the doubling caused overflow (MSB=1) - -- if subtracting the modulus does not underflow (comb_add_out.carry_out=0) *) - let double_reduce_needed = msb x_current.value |: ~:(comb_add_out.carry_out) in - let new_x = mux2 double_reduce_needed comb_add_out.result x_doubled in - let new_multiplier = srl multiplier.value 1 in - let new_bit_count = bit_count.value +:. 1 in - - proc [ - valid <-- gnd; - x_current <-- new_x; - multiplier <-- new_multiplier; - bit_count <-- new_bit_count; - - if_ ((new_bit_count ==: num_bits_orig.value) |: (new_multiplier ==:. 0)) [ - (* Exit when done OR when no more bits to process *) - result <-- result_acc.value; - valid <-- vdd; - sm.set_next Done; - ] (* REVISIT is it okay to just skip Add and Adjust if LSB=0? *) - @@ elif ( ~:(lsb new_multiplier) ) [ - (* Skip Add if next bit is 0 *) - sm.set_next Double_adjust; - ] - @@ [ - sm.set_next Add; + State.Double, [ + (* Combinatorially drive mod_add: x_current + x_current (double) *) + mod_add_valid_w <-- vdd; + mod_add_a_w <-- x_current.value; + mod_add_b_w <-- x_current.value; + mod_add_subtract_w <-- gnd; + + when_ i.mod_add_ready [ + let new_multiplier = srl multiplier.value 1 in + let new_bit_count = bit_count.value +:. 1 in + proc [ + x_current <-- i.mod_add_result; + multiplier <-- new_multiplier; + bit_count <-- new_bit_count; + + if_ ((new_bit_count ==: num_bits_orig.value) |: (new_multiplier ==:. 
0)) [ + result <-- result_acc.value; + valid <-- vdd; + sm.set_next Done; + ] @@ elif (~:(lsb new_multiplier)) [ + sm.set_next Double; (* skip Add when next bit = 0 *) + ] @@ [ + sm.set_next Add; + ]; ]; ]; ]; @@ -195,8 +171,11 @@ module ModMul = struct ]; ]; - { O.result = result.value -- "result" - ; valid = valid.value -- "valid" + { O.result = result.value -- "result" + ; valid = valid.value -- "valid" + ; mod_add_valid = mod_add_valid_w.value -- "mod_add_valid" + ; mod_add_a = mod_add_a_w.value -- "mod_add_a" + ; mod_add_b = mod_add_b_w.value -- "mod_add_b" + ; mod_add_subtract = mod_add_subtract_w.value -- "mod_add_subtract" } end - diff --git a/test/test_mod_inv.ml b/test/test_mod_inv.ml index 1f047c7..d933d7b 100644 --- a/test/test_mod_inv.ml +++ b/test/test_mod_inv.ml @@ -1,13 +1,49 @@ open Base open Hardcaml +module ModInvWithModAdd = struct + module I = Mod_inv.ModInv.I + module O = Mod_inv.ModInv.O + + let create scope (i : _ I.t) = + let mod_add_result_w = Signal.wire Mod_inv.Config.width in + let mod_add_ready_w = Signal.wire 1 in + let mod_add_adjusted_w = Signal.wire 1 in + + let inv_out = Mod_inv.ModInv.create (Scope.sub_scope scope "mod_inv") + { i with + mod_add_result = mod_add_result_w + ; mod_add_ready = mod_add_ready_w + ; mod_add_adjusted = mod_add_adjusted_w + } + in + + let add_out = Mod_add.ModAdd.create (Scope.sub_scope scope "mod_add") + { Mod_add.ModAdd.I. 
+ clock = i.clock + ; clear = i.clear + ; valid = inv_out.mod_add_valid + ; a = inv_out.mod_add_a + ; b = inv_out.mod_add_b + ; modulus = i.modulus + ; subtract = inv_out.mod_add_subtract + } + in + + Signal.assign mod_add_result_w add_out.result; + Signal.assign mod_add_ready_w add_out.ready; + Signal.assign mod_add_adjusted_w add_out.adjusted; + + inv_out +end + let test () = Stdio.printf "=== ModInv Hardware Test (256-bit with Zarith) ===\n"; Stdio.printf "=== Assuming odd prime modulus ===\n\n"; let scope = Scope.create ~flatten_design:true () in - let module Sim = Cyclesim.With_interface(Mod_inv.ModInv.I)(Mod_inv.ModInv.O) in - let sim = Sim.create (Mod_inv.ModInv.create scope) in + let module Sim = Cyclesim.With_interface(ModInvWithModAdd.I)(ModInvWithModAdd.O) in + let sim = Sim.create (ModInvWithModAdd.create scope) in let inputs = Cyclesim.inputs sim in let outputs = Cyclesim.outputs sim in diff --git a/test/test_mod_mul.ml b/test/test_mod_mul.ml index 9d0ed2f..1301363 100644 --- a/test/test_mod_mul.ml +++ b/test/test_mod_mul.ml @@ -1,12 +1,45 @@ open Base open Hardcaml +module ModMulWithModAdd = struct + module I = Mod_mul.ModMul.I + module O = Mod_mul.ModMul.O + + let create scope (i : _ I.t) = + let mod_add_result_w = Signal.wire Mod_mul.Config.width in + let mod_add_ready_w = Signal.wire 1 in + + let mul_out = Mod_mul.ModMul.create (Scope.sub_scope scope "mod_mul") + { i with + mod_add_result = mod_add_result_w + ; mod_add_ready = mod_add_ready_w + } + in + + let add_out = Mod_add.ModAdd.create (Scope.sub_scope scope "mod_add") + { Mod_add.ModAdd.I. 
+ clock = i.clock + ; clear = i.clear + ; valid = mul_out.mod_add_valid + ; a = mul_out.mod_add_a + ; b = mul_out.mod_add_b + ; modulus = i.modulus + ; subtract = mul_out.mod_add_subtract + } + in + + Signal.(mod_add_result_w <== add_out.result); + Signal.(mod_add_ready_w <== add_out.ready); + + mul_out +end + let test () = Stdio.printf "=== ModMul Hardware Test (256-bit with Zarith and num_bits) ===\n\n"; let scope = Scope.create ~flatten_design:true () in - let module Sim = Cyclesim.With_interface(Mod_mul.ModMul.I)(Mod_mul.ModMul.O) in - let sim = Sim.create (Mod_mul.ModMul.create scope) in + let module Sim = Cyclesim.With_interface(ModMulWithModAdd.I)(ModMulWithModAdd.O) in + let sim = Sim.create (ModMulWithModAdd.create scope) in let inputs = Cyclesim.inputs sim in let outputs = Cyclesim.outputs sim in @@ -35,8 +68,8 @@ let test () = let y_bits = z_to_bits y_z in let mod_bits = z_to_bits mod_z in let num_bits = z_num_bits y_z in - - Stdio.printf " y bits = %d (optimization: %dx speedup)\n" + + Stdio.printf " y bits = %d (optimization: %dx speedup)\n" num_bits (Mod_mul.Config.width / (max 1 num_bits)); (* Reset *) @@ -69,21 +102,21 @@ let test () = false end else if Bits.to_bool !(outputs.valid) then begin let result_bits = !(outputs.result) in - let result_z = + let result_z = let bin_str = Bits.to_bstr result_bits in Z.of_string_base 2 bin_str in - + Stdio.printf " Completed in %d cycles\n" cycle_count; Stdio.printf " result = %s\n" (Z.to_string result_z); - + let verified = Z.equal result_z expected_z in - + if verified then Stdio.printf " Verification: PASS ✓\n" else Stdio.printf " Verification: FAIL ✗\n"; - + Stdio.printf "\n"; verified end else begin @@ -97,32 +130,32 @@ let test () = let results = [ test_case "3 * 5 mod 7" (Z.of_int 3) (Z.of_int 5) (Z.of_int 7); - + test_case "12 * 15 mod 17" (Z.of_int 12) (Z.of_int 15) (Z.of_int 17); - + test_case "123 * 456 mod 1009" (Z.of_int 123) (Z.of_int 456) (Z.of_int 1009); - + test_case "Edge: x * 0 mod m" 
(Z.of_int 12345) (Z.of_int 0) (Z.of_int 7919); - + test_case "Edge: x * 1 mod m" (Z.of_int 12345) (Z.of_int 1) (Z.of_int 7919); - + test_case "Large: 123456 * 789012 mod 1000003" (Z.of_int 123456) (Z.of_int 789012) (Z.of_int 1000003); - + test_case "64-bit multiplication" (Z.of_string "12345678901234") (Z.of_string "98765432109876") (Z.of_string "999999999999999989"); - + test_case "128-bit multiplication" (Z.of_string "123456789012345678901234567890") (Z.of_string "987654321098765432109876543210") (Z.of_string "340282366920938463463374607431768211297"); - + test_case "256-bit: multiplication mod secp256k1 prime" (Z.of_string "123456789012345678901234567890123456789") (Z.of_string "987654321098765432109876543210987654321") @@ -152,4 +185,4 @@ let test () = end else Stdio.printf "Some tests failed ✗\n" -let () = test () \ No newline at end of file +let () = test ()