diff --git a/README.md b/README.md
index 4179898..3bfc018 100644
--- a/README.md
+++ b/README.md
@@ -71,24 +71,24 @@ flowchart TB
subgraph external_left[" "]
AUTH["License
Authority"]:::external
end
-
+
subgraph SECURITY_BLOCK["SECURITY BLOCK"]
direction TB
SL["Security Logic
(State Machine)"]:::security
-
+
subgraph submodules[" "]
direction LR
TRNG["TRNG
256-bit"]:::trng
ECDSA["ECDSA
secp256k1"]:::ecdsa
ALLOW["Allowance
64-bit"]:::allowance
end
-
+
subgraph datapath[" "]
direction LR
ADDER["Int8 Add"]:::adder
AND["AND Gate"]:::andgate
end
-
+
SL -->|request_new| TRNG
TRNG -->|"nonce, valid"| SL
SL -->|start| ECDSA
@@ -97,17 +97,17 @@ flowchart TB
ALLOW -->|enabled| AND
ADDER --> AND
end
-
+
subgraph external_io[" "]
direction LR
WIN["Workload
Input"]:::external
WOUT["Workload
Output"]:::external
end
-
+
AUTH <-->|"license_submit, r, s
nonce, ready"| SL
WIN --> ADDER
AND --> WOUT
-
+
classDef external fill:#fff,stroke:#333,stroke-dasharray: 5 5
classDef security fill:#cce5ff,stroke:#004085
classDef trng fill:#c3e6cb,stroke:#155724
@@ -309,22 +309,22 @@ Here's an expanded section on the ECDSA and modular arithmetic architecture to a
flowchart TB
subgraph ECDSA["ECDSA Verification Block"]
direction LR
-
+
subgraph left[" "]
direction TB
-
+
subgraph SM["State Machine"]
direction TB
SM_PREP["Prep Phase
u1, u2 computation"]
SM_LOOP["Scalar Mult Loop
256 iterations"]
SM_FIN["Finalize
projective to affine"]
SM_CMP["Compare
x_affine == r ?"]
-
+
SM_PREP --> SM_LOOP
SM_LOOP --> SM_FIN
SM_FIN --> SM_CMP
end
-
+
subgraph REGS["Register File --- 17 x 256-bit"]
direction LR
R_PT["Point Coords
X1 Y1 Z1
X2 Y2 Z2
X3 Y3 Z3"]
@@ -332,40 +332,40 @@ flowchart TB
R_PRM["Params
a, b3"]
end
end
-
+
subgraph right[" "]
direction TB
-
+
subgraph ARITH["Modular Arithmetic Unit"]
direction TB
-
- subgraph ops[" "]
- direction LR
- INV["Inverse
Ext Euclidean"]
- MUL["Multiply
shift-and-add"]
- ADDSUB["Add - Sub"]
+
+ subgraph INV["Inverse
Ext Euclidean"]
+ direction TB
+ end
+
+ subgraph MUL["Multiply
shift-and-add"]
+ direction TB
end
-
- subgraph shared["Shared Datapath"]
+
+ subgraph ADDSUB["Add - Sub"]
direction TB
MOD["Modulus Select
prime p or order n"]
ADD256["256-bit Adder"]
MOD --> ADD256
end
-
- INV --> shared
- MUL --> shared
- ADDSUB --> shared
+
+ INV --> ADDSUB
+ MUL --> ADDSUB
end
end
-
+
SM <-->|"start, op
done"| ARITH
REGS <-->|"read A B
write result"| ARITH
end
-
+
EXT_IN["Inputs:
z, r, s"] --> ECDSA
ECDSA --> EXT_OUT["Output:
valid"]
-
+
classDef outer fill:#f0f7ff,stroke:#2563eb,stroke-width:2px,color:#1e40af
classDef arithbox fill:#fef9e7,stroke:#b7950b,stroke-width:2px,color:#7d6608
classDef smbox fill:#e8f8f5,stroke:#1abc9c,stroke-width:2px,color:#0e6655
@@ -375,7 +375,7 @@ flowchart TB
classDef subunit fill:#fdebd0,stroke:#e67e22,stroke-width:1px,color:#a04000
classDef sharedbox fill:#fcf3cf,stroke:#d4ac0d,stroke-width:1px,color:#9a7d0a
classDef external fill:#ffffff,stroke:#5d6d7e,stroke-width:1px,stroke-dasharray: 5 5,color:#2c3e50
-
+
class ECDSA outer
class ARITH arithbox
class SM smbox
@@ -458,8 +458,8 @@ All operations work over 256-bit operands and can use either the field prime `p`
The arithmetic unit interfaces with a 17-register file. Operations are started with a pulse and signal completion via `done_`. Typical cycle counts:
- Add/Sub: 2–3 cycles
-- Mul: ~250-750 cycles (bit-serial, varies with y input)
-- Inv: ~1000–1500 cycles (varies with input)
+- Mul: ~500-1000 cycles (bit-serial, varies with y input)
+- Inv: ~2000–3000 cycles (varies with input)
### State Machine Overview
@@ -487,7 +487,7 @@ Idle → Prep_op → Loop ⟷ Load → Run_add → Finalize_op → Compare → D
### Cycle Count
-Total verification takes approximately 3–4 million cycles, dominated by the ~256 point operations in the scalar multiplication loop. At 1 GHz, this is in milliseconds—negligible compared to the licensing interval (minutes to days).
+Total verification takes approximately 5 million cycles, dominated by the ~256 point operations in the scalar multiplication loop. At 1 GHz, this is in milliseconds—negligible compared to the licensing interval (minutes to days).
### Hardcoded Constants
diff --git a/src/arith.ml b/src/arith.ml
index 234e624..3cbde80 100644
--- a/src/arith.ml
+++ b/src/arith.ml
@@ -3,24 +3,24 @@ open Signal
(* Arith - Modular arithmetic unit for secp256k1 field operations
-
+
Performs add, sub, mul, inv modulo either field prime p or curve order n.
Operands are read from and results written to an external 32×256-bit register file.
-
+
Operations (op input):
0 = add: r[addr_out] <- r[addr_a] + r[addr_b] mod m
1 = sub: r[addr_out] <- r[addr_a] - r[addr_b] mod m
2 = mul: r[addr_out] <- r[addr_a] * r[addr_b] mod m
3 = inv: r[addr_out] <- r[addr_a]^(-1) mod m (addr_b ignored)
-
+
Modulus selection (prime_sel): 0 = prime_p, 1 = prime_n
-
+
Protocol:
1. Set addr_a, addr_b, addr_out, op, prime_sel; pulse start
2. Provide reg_read_data_a/b in response to reg_read_addr_a/b
3. Wait for done_ pulse; result written via reg_write_* signals
4. For inv, check inv_exists to confirm inverse was found
-
+
State machine: Idle -> Load -> Capture -> Compute -> Write -> Done -> Idle
*)
@@ -28,12 +28,12 @@ module Config = struct
let width = 256
let num_registers = 32
let reg_addr_width = 5
-
+
(* secp256k1 field prime: p = 2^256 - 2^32 - 977 *)
let prime_p = Z.of_string "115792089237316195423570985008687907853269984665640564039457584007908834671663"
(* secp256k1 curve order - CORRECTED *)
let prime_n = Z.of_string "115792089237316195423570985008687907852837564279074904382605163141518161494337"
-
+
let z_to_constant z =
let hex_str = Z.format "%x" z in
let padded = String.make ((width / 4) - String.length hex_str) '0' ^ hex_str in
@@ -91,9 +91,9 @@ end
let create scope (i : _ I.t) =
let open Always in
let ( -- ) = Scope.naming scope in
-
+
let spec = Reg_spec.create ~clock:i.clock ~clear:i.clear () in
-
+
(* State machine *)
let sm = State_machine.create (module State) spec ~enable:vdd in
@@ -103,47 +103,54 @@ let create scope (i : _ I.t) =
let addr_a_reg = Variable.reg spec ~width:Config.reg_addr_width in
let addr_b_reg = Variable.reg spec ~width:Config.reg_addr_width in
let addr_out_reg = Variable.reg spec ~width:Config.reg_addr_width in
-
+
(* Captured operands *)
let operand_a = Variable.reg spec ~width:Config.width in
let operand_b = Variable.reg spec ~width:Config.width in
-
+
(* Operation start signals *)
- let start_add = Variable.reg spec ~width:1 in
- let start_sub = Variable.reg spec ~width:1 in
+ let mod_add_valid = Variable.reg spec ~width:1 in
+ let mod_sub_valid = Variable.reg spec ~width:1 in
let start_mul = Variable.reg spec ~width:1 in
let start_inv = Variable.reg spec ~width:1 in
-
+
(* Result capture *)
let result_reg = Variable.reg spec ~width:Config.width in
let inv_exists_reg = Variable.reg spec ~width:1 in
-
+
(* Output registers *)
let reg_write_enable = Variable.reg spec ~width:1 in
let done_flag = Variable.reg spec ~width:1 in
-
+
(* Prime constants *)
let prime_p_const = Signal.of_constant (Config.z_to_constant Config.prime_p) in
let prime_n_const = Signal.of_constant (Config.z_to_constant Config.prime_n) in
let selected_prime = mux2 prime_sel_reg.value prime_n_const prime_p_const in
-
+
(* Count significant bits in operand_b for multiplication optimization *)
(* Simple approach: use full width, or implement leading zero count *)
let num_bits_for_mul = of_int ~width:9 Config.width in
-
+
(* Instantiate arithmetic modules *)
- let mod_addsub_out = Mod_add.ModAdd.create (Scope.sub_scope scope "mod_add")
+
+ (* Forward declaration wires for mod_add inputs, driven below after mod_mul is available *)
+ let mod_add_valid_w = Signal.wire 1 in
+ let mod_add_a_w = Signal.wire Config.width in
+ let mod_add_b_w = Signal.wire Config.width in
+ let mod_add_subtract_w = Signal.wire 1 in
+
+ let mod_add_out = Mod_add.ModAdd.create (Scope.sub_scope scope "mod_add")
{ Mod_add.ModAdd.I.
- clock = i.clock
- ; clear = i.clear
- ; start = start_add.value |: start_sub.value
- ; a = operand_a.value
- ; b = operand_b.value
- ; modulus = selected_prime
- ; subtract = start_sub.value
+ clock = i.clock
+ ; clear = i.clear
+ ; valid = mod_add_valid_w
+ ; a = mod_add_a_w
+ ; b = mod_add_b_w
+ ; modulus = selected_prime
+ ; subtract = mod_add_subtract_w
}
in
-
+
let mod_mul_out = Mod_mul.ModMul.create (Scope.sub_scope scope "mod_mul")
{ Mod_mul.ModMul.I.
clock = i.clock
@@ -153,9 +160,11 @@ let create scope (i : _ I.t) =
; y = operand_b.value
; modulus = selected_prime
; num_bits = num_bits_for_mul
+ ; mod_add_result = mod_add_out.result
+ ; mod_add_ready = mod_add_out.ready
}
in
-
+
let mod_inv_out = Mod_inv.ModInv.create (Scope.sub_scope scope "mod_inv")
{ Mod_inv.ModInv.I.
clock = i.clock
@@ -163,37 +172,65 @@ let create scope (i : _ I.t) =
; start = start_inv.value
; x = operand_a.value
; modulus = selected_prime
+ ; mod_add_result = mod_add_out.result
+ ; mod_add_ready = mod_add_out.ready
+ ; mod_add_adjusted = mod_add_out.adjusted
}
in
-
+
+ (* Drive mod_add inputs, muxed by operation *)
+ assign mod_add_valid_w (mux op_reg.value [
+ mod_add_valid.value; (* add *)
+ mod_sub_valid.value; (* sub *)
+ mod_mul_out.mod_add_valid; (* mul *)
+ mod_inv_out.mod_add_valid; (* inv *)
+ ]);
+ assign mod_add_a_w (mux op_reg.value [
+ operand_a.value; (* add *)
+ operand_a.value; (* sub *)
+ mod_mul_out.mod_add_a; (* mul *)
+ mod_inv_out.mod_add_a; (* inv *)
+ ]);
+ assign mod_add_b_w (mux op_reg.value [
+ operand_b.value; (* add *)
+ operand_b.value; (* sub *)
+ mod_mul_out.mod_add_b; (* mul *)
+ mod_inv_out.mod_add_b; (* inv *)
+ ]);
+ assign mod_add_subtract_w (mux op_reg.value [
+ gnd; (* add *)
+ vdd; (* sub *)
+ mod_mul_out.mod_add_subtract; (* mul *)
+ mod_inv_out.mod_add_subtract; (* inv *)
+ ]);
+
(* Mux results based on operation *)
- let op_result =
+ let op_result =
mux op_reg.value [
- mod_addsub_out.result;
- mod_addsub_out.result;
+ mod_add_out.result;
+ mod_add_out.result;
mod_mul_out.result;
mod_inv_out.result;
]
in
-
- let op_valid =
+
+ let op_ready =
mux op_reg.value [
- mod_addsub_out.valid;
- mod_addsub_out.valid;
+ mod_add_out.ready;
+ mod_add_out.ready;
mod_mul_out.valid;
mod_inv_out.valid;
]
in
-
+
compile [
(* Default: clear pulse signals *)
- start_add <-- gnd;
- start_sub <-- gnd;
+ (* TODO move start_mul, start_inv clear to Compute step as well when updated *)
start_mul <-- gnd;
start_inv <-- gnd;
reg_write_enable <-- gnd;
done_flag <-- gnd;
-
+
sm.switch [
State.Idle, [
when_ i.start [
@@ -206,43 +243,47 @@ let create scope (i : _ I.t) =
sm.set_next Load;
];
];
-
+
State.Load, [
(* Wait one cycle for register file to provide data *)
sm.set_next Capture;
];
-
+
State.Capture, [
operand_a <-- i.reg_read_data_a;
operand_b <-- i.reg_read_data_b;
-
+
switch op_reg.value [
- of_int ~width:2 Op.add, [ start_add <-- vdd ];
- of_int ~width:2 Op.sub, [ start_sub <-- vdd ];
+ of_int ~width:2 Op.add, [ mod_add_valid <-- vdd ];
+ of_int ~width:2 Op.sub, [ mod_sub_valid <-- vdd ];
of_int ~width:2 Op.mul, [ start_mul <-- vdd ];
of_int ~width:2 Op.inv, [ start_inv <-- vdd ];
];
-
+
sm.set_next Compute;
];
State.Compute, [
(* Simply wait for operation to complete *)
- when_ op_valid [
+ when_ op_ready [
+ (* Clear valid signals *)
+ mod_add_valid <-- gnd;
+ mod_sub_valid <-- gnd;
+
result_reg <-- op_result;
inv_exists_reg <-- mod_inv_out.exists;
sm.set_next Write;
];
];
-
+
State.Write, [
(* Write result to register file *)
reg_write_enable <-- vdd;
sm.set_next Done;
];
-
+
State.Done, [
done_flag <-- vdd;
if_ i.start [
@@ -259,10 +300,10 @@ State.Done, [
];
];
];
-
+
(* Busy when not idle *)
let busy = ~:(sm.is Idle) in
-
+
{ O.
busy = busy -- "busy"
; done_ = done_flag.value -- "done"
diff --git a/src/mod_add.ml b/src/mod_add.ml
index 53911e0..fc677fd 100644
--- a/src/mod_add.ml
+++ b/src/mod_add.ml
@@ -7,18 +7,18 @@ module Config = struct
let width = 256
end
-(* Modular Addition/Subtraction Module (combinatorial adder variant)
+(* Modular Addition/Subtraction Module
Performs (a ± b) mod n where n is provided as an input.
- - Uses comb_add for the initial a ± b in a single cycle
- - 2-cycle operation: Free (latch + add) → Modulus_adjust (modular reduction)
+ - Uses an internal comb_add for arithmetic
+ - 2-cycle operation: Add (latch + add) → Adjust (modular reduction)
- Automatic modular reduction
*)
module ModAdd = struct
module State = struct
type t =
- | Free
- | Modulus_adjust
+ | Add
+ | Adjust
[@@deriving sexp_of, compare, enumerate]
end
@@ -26,7 +26,7 @@ module ModAdd = struct
type 'a t =
{ clock : 'a
; clear : 'a
- ; start : 'a
+ ; valid : 'a
; a : 'a [@bits Config.width] (* First operand *)
; b : 'a [@bits Config.width] (* Second operand *)
; modulus : 'a [@bits Config.width] (* Modulus n *)
@@ -37,8 +37,9 @@ module ModAdd = struct
module O = struct
type 'a t =
- { result : 'a [@bits Config.width]
- ; valid : 'a
+ { result : 'a [@bits Config.width]
+ ; ready : 'a (* 1 result is valid *)
+ ; adjusted : 'a (* 1 if modular correction was applied this cycle, valid with ready *)
}
[@@deriving sexp_of, hardcaml]
end
@@ -46,7 +47,7 @@ module ModAdd = struct
let create scope (i : _ I.t) =
let open Always in
let ( -- ) = Scope.naming scope in
-
+
let spec = Reg_spec.create ~clock:i.clock ~clear:i.clear () in
(* State machine *)
@@ -55,44 +56,40 @@ module ModAdd = struct
(* Registers for computation *)
let result_ab = Variable.reg spec ~width:Config.width in
let carry_ab = Variable.reg spec ~width:1 in
- let modulus_reg = Variable.reg spec ~width:Config.width in
- let is_subtract = Variable.reg spec ~width:1 in
-
- (* Output registers *)
- let result = Variable.reg spec ~width:Config.width in
- let valid = Variable.reg spec ~width:1 in
-
- (* Mux comb_add inputs by state:
- - Free: i.a ± i.b (initial operation)
- - Modulus_adjust: result_ab ± modulus (correction: add modulus if sub underflow,
- subtract modulus if sum higher than modulus) *)
- let in_modulus_adjust = sm.is Modulus_adjust in
- let comb_a = mux2 in_modulus_adjust result_ab.value i.a in
- let comb_b = mux2 in_modulus_adjust modulus_reg.value i.b in
+
+ (* Adder inputs muxed by state:
+ - Add: i.a ± i.b (initial operation)
+ - Adjust: result_ab ± modulus (correction step) *)
+ let in_adjust = sm.is Adjust in
+ let adder_a = mux2 in_adjust result_ab.value i.a in
+ let adder_b = mux2 in_adjust i.modulus i.b in
(* modulus adjust uses the opposite of the original operation - add if a-b, subtract if a+b *)
- let comb_subtract = mux2 in_modulus_adjust ~:(is_subtract.value) i.subtract in
+ let adder_subtract = mux2 in_adjust ~:(i.subtract) i.subtract in
+ let adder_ready = vdd in (* currently using combinatorial adder, so no delay *)
let comb_add_out =
Comb_add.CombAdd.create (Scope.sub_scope scope "comb_add")
- { Comb_add.CombAdd.I.a = comb_a; b = comb_b; subtract = comb_subtract }
+ { Comb_add.CombAdd.I.a = adder_a; b = adder_b; subtract = adder_subtract }
in
+ let result_w = Variable.wire ~default:(zero Config.width) in
+ let ready_w = Variable.wire ~default:gnd in
+ let adjusted_w = Variable.wire ~default:gnd in
+
compile [
sm.switch [
- State.Free, [
- valid <-- gnd;
- when_ i.start [
+ State.Add, [
+ when_ (i.valid &: adder_ready) [
+ (* update register values and select next state *)
result_ab <-- comb_add_out.result;
carry_ab <-- comb_add_out.carry_out;
- modulus_reg <-- i.modulus;
- is_subtract <-- i.subtract;
- sm.set_next Modulus_adjust;
+ sm.set_next Adjust;
];
];
-
- State.Modulus_adjust, [
-
- (* comb_add is computing result_ab ± modulus_reg here. *)
+
+ State.Adjust, [
+
+ (* adder is computing result_ab ± modulus here *)
(* For add:
-- reduce if carry_ab=1 (a+b overflowed 256 bits, so a+b >= 2^256 > n)
NOTE: comb_add_out.carry_out is not valid in this case
@@ -104,22 +101,26 @@ module ModAdd = struct
let sub_needs_adjust = carry_ab.value in
let final_result =
- mux2 is_subtract.value
+ mux2 i.subtract
(mux2 sub_needs_adjust comb_add_out.result result_ab.value)
(mux2 add_needs_adjust comb_add_out.result result_ab.value)
in
- proc [
- result <-- final_result;
- valid <-- vdd;
- sm.set_next Free;
+ when_ adder_ready [
+ (* combinatorial: drive result output *)
+ result_w <-- final_result;
+ ready_w <-- vdd; (* No clear needed, Variable.wire ~default takes care *)
+ adjusted_w <-- mux2 i.subtract sub_needs_adjust add_needs_adjust;
+ (* select next state *)
+ sm.set_next Add;
];
];
];
];
- { O.result = result.value -- "result"
- ; valid = valid.value -- "valid"
+ { O.result = result_w.value -- "result"
+ ; ready = ready_w.value -- "ready"
+ ; adjusted = adjusted_w.value -- "adjusted"
}
end
diff --git a/src/mod_inv.ml b/src/mod_inv.ml
index f2c0bc3..276b023 100644
--- a/src/mod_inv.ml
+++ b/src/mod_inv.ml
@@ -6,25 +6,26 @@ module Config = struct
let width = 256
end
-(* Binary Extended GCD for modular inverse computation
-
+(* Binary Extended GCD for modular inverse computation
+
REQUIREMENT: This implementation assumes the modulus is an odd prime.
- For odd prime modulus and any x coprime to it, gcd(x, modulus) = 1
- Since modulus is odd, at most one of {x, modulus} can be even
- Therefore we never need to handle the "both even" case or factor out common powers of 2
- This simplifies the algorithm significantly compared to the general case
+
+ All arithmetic is driven through an external mod_add instance (shared with add/sub/mul).
*)
module ModInv = struct
module State = struct
type t =
| Idle
| Op_sel (* decide which operation to start next *)
- | Div2_xs (* divide x and s by 2, s might need adjusting *)
- | Div2_yu (* divide y and u by 2, u might need adjusting *)
- | Sub_rems (* subtract remainders *)
- | Sub_rems_reverse (* reverse subtract remainders, we don't know in advance which direction needed *)
+ | Div2_add (* divide appropriate remainder and corresponding coeffiecient by 2, start adjusting the coefficient (c>>1 + mod>>1) if needed *)
+ | Div2_p1 (* finish adjusting the coefficient (c=c+1) if needed *)
+ | Sub_rems (* subtract remainders: x - y *)
+ | Sub_rems_reverse (* reverse subtract: y - x (when x < y) *)
| Sub_coeffs (* subtract coefficients *)
- | Adjust_coeff (* adjust the reduced coefficient after subtraction *)
| Done
[@@deriving sexp_of, compare, enumerate]
end
@@ -36,70 +37,61 @@ module ModInv = struct
; start : 'a
; x : 'a [@bits Config.width]
; modulus : 'a [@bits Config.width]
+ ; mod_add_result : 'a [@bits Config.width]
+ ; mod_add_ready : 'a
+ ; mod_add_adjusted : 'a (* underflow/overflow flag, valid when mod_add_ready *)
}
[@@deriving sexp_of, hardcaml]
end
-module O = struct
- type 'a t =
- { result : 'a [@bits Config.width]
- ; valid : 'a
- ; exists : 'a
- (* Debug outputs *)
- ; dbg_x : 'a [@bits Config.width]
- ; dbg_y : 'a [@bits Config.width]
- ; dbg_s : 'a [@bits Config.width]
- ; dbg_u : 'a [@bits Config.width]
- }
- [@@deriving sexp_of, hardcaml]
-end
+ module O = struct
+ type 'a t =
+ { result : 'a [@bits Config.width]
+ ; valid : 'a
+ ; exists : 'a
+ ; mod_add_valid : 'a
+ ; mod_add_a : 'a [@bits Config.width]
+ ; mod_add_b : 'a [@bits Config.width]
+ ; mod_add_subtract : 'a
+ (* Debug outputs *)
+ ; dbg_x : 'a [@bits Config.width]
+ ; dbg_y : 'a [@bits Config.width]
+ ; dbg_s : 'a [@bits Config.width]
+ ; dbg_u : 'a [@bits Config.width]
+ }
+ [@@deriving sexp_of, hardcaml]
+ end
let create scope (i : _ I.t) =
let open Always in
let ( -- ) = Scope.naming scope in
let width = Config.width in
-
+
let spec = Reg_spec.create ~clock:i.clock ~clear:i.clear () in
-
+
let sm = State_machine.create (module State) spec ~enable:vdd in
let x = Variable.reg spec ~width in (* remainder *)
let y = Variable.reg spec ~width in (* remainder *)
let s = Variable.reg spec ~width in (* coefficient for x *)
let u = Variable.reg spec ~width in (* coefficient for y *)
-
+
let modulus_reg = Variable.reg spec ~width:width in
- (* helper reg for the FSM *)
- let reduced_xny = Variable.reg spec ~width:1 in (* 1 = x was reduced (x-y), 0 = y was reduced (y-x) *)
-
+ (* helper regs for the FSM *)
+ let reduced_xny = Variable.reg spec ~width:1 in (* 1 = x was reduced (x-y), 0 = y was reduced (y-x) *)
+ let div2_xny = Variable.reg spec ~width:1 in (* 1 = divide x, 0 = divide y *)
+ let div2_coeff_odd = Variable.reg spec ~width:1 in (* was original coefficient odd? *)
+
let result = Variable.reg spec ~width in
let valid = Variable.reg spec ~width:1 in
let exists = Variable.reg spec ~width:1 in
- (* 256 bit adder instance *)
- let comb_add_a = Variable.wire ~default:(zero width) in
- let comb_add_b = Variable.wire ~default:(zero width) in
- let comb_subtract = Variable.wire ~default:gnd in
- let comb_add_out =
- Comb_add.CombAdd.create (Scope.sub_scope scope "comb_add")
- { Comb_add.CombAdd.I.a = comb_add_a.value;
- b = comb_add_b.value;
- subtract = comb_subtract.value }
- in
-
- (* Compute new coefficient after a div2 step. *)
- (* if c is even: *)
- (* c = c/2 *)
- (* else: *)
- (* c = (c + mod)/2 *)
- let div2_new_coeff c =
- (* c + mod, using adder instance *)
- let sum_c_mod = comb_add_out.carry_out @: comb_add_out.result in (* sum might be +1 bit wide, use carry *)
- (* * if c is even: c = c/2 else: c = (c + mod)/2 *)
- mux2 ~:(lsb c) (srl c 1)
- (sel_top sum_c_mod width) (* using sel_top instead of srl as sum is +1 bit wide *)
- in
+ (* Output wires for driving external mod_add *)
+ let mod_add_valid_w = Variable.wire ~default:gnd in
+ let mod_add_a_w = Variable.wire ~default:(zero Config.width) in
+ let mod_add_b_w = Variable.wire ~default:(zero Config.width) in
+ let mod_add_subtract_w = Variable.wire ~default:gnd in
compile [
sm.switch [
@@ -128,7 +120,6 @@ end
modulus_reg <-- i.modulus;
s <-- of_int ~width:width 1;
u <-- zero width;
-
sm.set_next Op_sel;
];
];
@@ -147,60 +138,76 @@ end
valid <-- vdd;
sm.set_next Done;
] (* else, next operation selection below *)
- @@ elif (~:(lsb x.value)) [sm.set_next Div2_xs] (* if x even -> divide x *)
- @@ elif (~:(lsb y.value)) [sm.set_next Div2_yu] (* if y even -> divide y *)
- @@ [sm.set_next Sub_rems]; (* else -> subtract remainders *)
+ @@ elif (~:(lsb x.value)) [ (* if x even -> divide x (xs pair) *)
+ div2_xny <-- vdd;
+ div2_coeff_odd <-- lsb s.value; (* needed in Div2_p1 *)
+ sm.set_next Div2_add;
+ ]
+ @@ elif (~:(lsb y.value)) [ (* if y even -> divide y (yu pair) *)
+ div2_xny <-- gnd;
+ div2_coeff_odd <-- lsb u.value; (* needed in Div2_p1 *)
+ sm.set_next Div2_add;
+ ]
+ @@ [ (* else -> subtract remainders *)
+ sm.set_next Sub_rems
+ ];
];
- State.Div2_xs, [
- (* x = x/2 *)
- (* if s is even: *)
- (* s = s/2 *)
- (* else: *)
- (* s = (s + mod)/2 *)
+ State.Div2_add, [
+ (* r = r/2 r = x or y *)
+ (* if c is even: c = s or u *)
+ (* c = c/2 *)
+ (* else: *)
+ (* c = (c + mod)/2 *)
+ (* *)
+ (* CAUTION! (c + mod)/2 has 2 steps: *)
+ (* 1. Div2_add: *)
+ (* c = (c >> 1) + (mod >> 1) *)
+ (* 2. Div2_p1: *)
+ (* c = c + 1 *)
+
+ let coeff = mux2 div2_xny.value s.value u.value in
- (* x = x/2 *)
- let new_x = srl x.value 1 in
-
- (* see div2_new_coeff *)
- let new_s = div2_new_coeff s.value in
-
proc [
- (* combinatoinal drive the adder inputs *)
- comb_add_a <-- s.value;
- comb_add_b <-- modulus_reg.value;
- comb_subtract <-- gnd;
-
- (* update register values and select next state *)
- x <-- new_x;
- s <-- new_s;
- sm.set_next Op_sel;
+ (* Combinationally drive mod_add: c>>1 + mod>>1 *)
+ mod_add_valid_w <-- vdd;
+ mod_add_a_w <-- srl coeff 1;
+ mod_add_b_w <-- srl modulus_reg.value 1;
+ mod_add_subtract_w <-- gnd;
+
+ when_ i.mod_add_ready [
+ (* Shift the remainder (fires exactly once) *)
+ if_ div2_xny.value [
+ x <-- srl x.value 1;
+ s <-- mux2 (lsb s.value) i.mod_add_result (srl s.value 1);
+ ] [
+ y <-- srl y.value 1;
+ u <-- mux2 (lsb u.value) i.mod_add_result (srl u.value 1);
+ ];
+ div2_coeff_odd <-- lsb coeff; (* needed for Div2_p1 *)
+ sm.set_next Div2_p1;
+ ];
];
];
- State.Div2_yu, [
- (* y = y/2 *)
- (* if u is even: *)
- (* u = u/2 *)
- (* else: *)
- (* u = (u + mod)/2 *)
-
- (* y = y/2 *)
- let new_y = srl y.value 1 in
-
- (* see div2_new_coeff *)
- let new_u = div2_new_coeff u.value in
+ State.Div2_p1, [
+ (* see Div2_add for what this state implements *)
proc [
- (* combinatoinal drive the adder inputs *)
- comb_add_a <-- u.value;
- comb_add_b <-- modulus_reg.value;
- comb_subtract <-- gnd;
-
- (* update register values and select next state *)
- y <-- new_y;
- u <-- new_u;
- sm.set_next Op_sel;
+ (* Combinationally drive mod_add: c + 1 *)
+ mod_add_valid_w <-- vdd;
+ mod_add_a_w <-- mux2 div2_xny.value s.value u.value;
+ mod_add_b_w <-- of_int ~width:Config.width 1;
+ mod_add_subtract_w <-- gnd;
+
+ when_ i.mod_add_ready [
+ if_ div2_xny.value [
+ s <-- mux2 div2_coeff_odd.value i.mod_add_result s.value;
+ ] [
+ u <-- mux2 div2_coeff_odd.value i.mod_add_result u.value;
+ ];
+ sm.set_next Op_sel;
+ ];
];
];
@@ -209,22 +216,23 @@ end
(* x = x - y *)
(* else: *)
(* Sub_rems_reverse *)
-
+
(* combinatoinal drive the adder inputs *)
- comb_add_a <-- x.value;
- comb_add_b <-- y.value;
- comb_subtract <-- vdd;
-
- (* update register values and select next state *)
- if_ ~:(comb_add_out.carry_out) [
- (* carry_out = 0 means no underflow (x was >= y), store results and move to coefficients *)
- x <-- comb_add_out.result;
- reduced_xny <-- vdd; (* Subtract reduced x *)
- (* REVISIT is it okay to skip a state conditionally? *)
- sm.set_next Sub_coeffs;
- ] [
- (* underflow, need Sub_rems_reverse instead *)
- sm.set_next Sub_rems_reverse;
+ mod_add_valid_w <-- vdd;
+ mod_add_a_w <-- x.value;
+ mod_add_b_w <-- y.value;
+ mod_add_subtract_w <-- vdd;
+
+ when_ i.mod_add_ready [
+ if_ ~:(i.mod_add_adjusted) [
+ (* No underflow: x >= y, result is x - y *)
+ x <-- i.mod_add_result;
+ reduced_xny <-- vdd; (* x was reduced *)
+ sm.set_next Sub_coeffs;
+ ] [
+ (* Underflow: x < y, need reverse subtraction *)
+ sm.set_next Sub_rems_reverse;
+ ];
];
];
@@ -233,14 +241,16 @@ end
(* y = y - x *)
(* combinatoinal drive the adder inputs *)
- comb_add_a <-- y.value;
- comb_add_b <-- x.value;
- comb_subtract <-- vdd;
-
- (* update register values and select next state *)
- y <-- comb_add_out.result;
- reduced_xny <-- gnd; (* Subtract reduced y *)
- sm.set_next Sub_coeffs;
+ mod_add_valid_w <-- vdd;
+ mod_add_a_w <-- y.value;
+ mod_add_b_w <-- x.value;
+ mod_add_subtract_w <-- vdd;
+
+ when_ i.mod_add_ready [
+ y <-- i.mod_add_result;
+ reduced_xny <-- gnd; (* y was reduced *)
+ sm.set_next Sub_coeffs;
+ ];
];
State.Sub_coeffs, [
@@ -251,62 +261,39 @@ end
(* u = u-s *)
(* combinatorial drive the adder inputs *)
- (* if x was reduced: attempt s - u, else: attempt u - s *)
- comb_add_a <-- mux2 reduced_xny.value s.value u.value;
- comb_add_b <-- mux2 reduced_xny.value u.value s.value;
- comb_subtract <-- vdd;
-
- (* update register values and select next state *)
- (* the result is always valid, even if underflowed, Adjust_coeff will fix *)
- if_ reduced_xny.value [
- s <-- comb_add_out.result;
- ] [
- u <-- comb_add_out.result;
- ];
- (* if underflow, go to Adjust_coeff to add modulus *)
- if_ comb_add_out.carry_out [
- sm.set_next Adjust_coeff;
- ] [
+ (* if x was reduced: do s - u, else: u - s *)
+ mod_add_valid_w <-- vdd;
+ mod_add_a_w <-- mux2 reduced_xny.value s.value u.value;
+ mod_add_b_w <-- mux2 reduced_xny.value u.value s.value;
+ mod_add_subtract_w <-- vdd;
+
+ when_ i.mod_add_ready [
+ if_ reduced_xny.value [
+ s <-- i.mod_add_result;
+ ] [
+ u <-- i.mod_add_result;
+ ];
sm.set_next Op_sel;
];
];
- State.Adjust_coeff, [
- (* _old might have changed by now *)
- (* if x_old >= y_old: *)
- (* if s_old < u_old: *)
- (* s = s_old - u_old + mod *)
- (* else: *)
- (* if u_old < s_old: *)
- (* u = u_old - s_old + mod *)
-
- (* combinatorial drive the adder inputs *)
- comb_add_a <-- mux2 reduced_xny.value s.value u.value;
- comb_add_b <-- modulus_reg.value;
- comb_subtract <-- gnd;
-
- (* update register values and select next state *)
- if_ reduced_xny.value [
- s <-- comb_add_out.result;
- ] [
- u <-- comb_add_out.result;
- ];
- sm.set_next Op_sel;
+ State.Done, [
+ valid <-- gnd;
+ sm.set_next Idle; (* Return to Idle after one cycle *)
];
-
-State.Done, [
- valid <-- gnd;
- sm.set_next Idle; (* Return to Idle after one cycle *)
-];
];
];
-{ O.result = result.value -- "result"
-; valid = valid.value -- "valid"
-; exists = exists.value -- "exists"
-; dbg_x = x.value
-; dbg_y = y.value
-; dbg_s = s.value
-; dbg_u = u.value
-}
-end
\ No newline at end of file
+ { O.result = result.value -- "result"
+ ; valid = valid.value -- "valid"
+ ; exists = exists.value -- "exists"
+ ; mod_add_valid = mod_add_valid_w.value -- "mod_add_valid"
+ ; mod_add_a = mod_add_a_w.value -- "mod_add_a"
+ ; mod_add_b = mod_add_b_w.value -- "mod_add_b"
+ ; mod_add_subtract = mod_add_subtract_w.value -- "mod_add_subtract"
+ ; dbg_x = x.value
+ ; dbg_y = y.value
+ ; dbg_s = s.value
+ ; dbg_u = u.value
+ }
+end
diff --git a/src/mod_mul.ml b/src/mod_mul.ml
index d9197fe..5cd03d3 100644
--- a/src/mod_mul.ml
+++ b/src/mod_mul.ml
@@ -9,18 +9,18 @@ module Config = struct
end
(* Simple Modular Multiplication
-
+
Computes (x * y) mod r where r is the modulus.
-
- Uses the standard shift-and-add algorithm with modular reduction at each step. *)
+
+ Uses the standard shift-and-add algorithm with modular reduction at each step.
+ Drives an external mod_add instance for all arithmetic. *)
module ModMul = struct
module State = struct
type t =
| Idle
| Init
| Add
- | Adjust
- | Double_adjust
+ | Double
| Done
[@@deriving sexp_of, compare, enumerate]
end
@@ -34,6 +34,8 @@ module ModMul = struct
; y : 'a [@bits Config.width] (* Second multiplicand *)
; modulus : 'a [@bits Config.width] (* Modulus r *)
; num_bits : 'a [@bits Config.bit_count_width] (* Number of bits to process in y (0-256) *)
+ ; mod_add_result : 'a [@bits Config.width] (* result from external mod_add *)
+ ; mod_add_ready : 'a (* ready from external mod_add *)
}
[@@deriving sexp_of, hardcaml]
end
@@ -42,6 +44,10 @@ module ModMul = struct
type 'a t =
{ result : 'a [@bits Config.width] (* (x * y) mod r *)
; valid : 'a (* High when result is valid *)
+ ; mod_add_valid : 'a
+ ; mod_add_a : 'a [@bits Config.width]
+ ; mod_add_b : 'a [@bits Config.width]
+ ; mod_add_subtract : 'a
}
[@@deriving sexp_of, hardcaml]
end
@@ -59,7 +65,6 @@ module ModMul = struct
(* Result accumulator *)
let result_acc = Variable.reg spec ~width:width in
- let add_carry = Variable.reg spec ~width:1 in
(* Multiplier (y) and bit counter *)
let multiplier = Variable.reg spec ~width in
@@ -67,27 +72,18 @@ module ModMul = struct
(* Current value of x (gets doubled each iteration) *)
let x_current = Variable.reg spec ~width:width in
-
- (* Stored modulus *)
- let modulus_reg = Variable.reg spec ~width:width in
+
let num_bits_orig = Variable.reg spec ~width:bit_count_width in
(* Output registers *)
let result = Variable.reg spec ~width in
let valid = Variable.reg spec ~width:1 in
- (* 256 bit adder instance *)
- let in_add_state = sm.is Add in
- let in_double_adj_state = sm.is Double_adjust in
- let x_doubled = sll x_current.value 1 in
-
- let comb_add_a = mux2 ~:in_double_adj_state result_acc.value x_doubled in
- let comb_add_b = mux2 in_add_state x_current.value modulus_reg.value in
- let comb_subtract = mux2 in_add_state gnd vdd in
- let comb_add_out =
- Comb_add.CombAdd.create (Scope.sub_scope scope "comb_add")
- { Comb_add.CombAdd.I.a = comb_add_a; b = comb_add_b; subtract = comb_subtract }
- in
+ (* Output wires for driving external mod_add *)
+ let mod_add_valid_w = Variable.wire ~default:gnd in
+ let mod_add_a_w = Variable.wire ~default:(zero Config.width) in
+ let mod_add_b_w = Variable.wire ~default:(zero Config.width) in
+ let mod_add_subtract_w = Variable.wire ~default:gnd in
compile [
sm.switch [
@@ -109,7 +105,6 @@ module ModMul = struct
sm.set_next Done;
] [
x_current <-- i.x;
- modulus_reg <-- i.modulus;
multiplier <-- i.y;
num_bits_orig <-- i.num_bits;
sm.set_next Init;
@@ -125,65 +120,46 @@ module ModMul = struct
];
State.Add, [
- (* Check LSB of multiplier *)
- let current_bit = lsb multiplier.value in
-
- (* If bit is set, add current x to result *)
- let after_add = mux2 current_bit comb_add_out.result result_acc.value in
-
- proc [
- valid <-- gnd;
- result_acc <-- after_add;
- add_carry <-- mux2 current_bit comb_add_out.carry_out gnd;
- sm.set_next Adjust;
- ];
- ];
-
- State.Adjust, [
-
- (* Reduce if >= modulus:
- -- if the addition had overflow (add_carry=1)
- -- if subtracting the modulus does not underflow (comb_add_out.carry_out=0) *)
- let add_reduce_needed = add_carry.value |: ~:(comb_add_out.carry_out) in
- let after_add_reduce = mux2 add_reduce_needed comb_add_out.result result_acc.value in
-
- proc [
- valid <-- gnd;
- result_acc <-- after_add_reduce;
- add_carry <-- gnd; (* Clear carry as it is N/A here *)
- sm.set_next Double_adjust;
+ if_ (lsb multiplier.value) [
+ (* Combinatorially drive mod_add *)
+ mod_add_valid_w <-- vdd;
+ mod_add_a_w <-- result_acc.value;
+ mod_add_b_w <-- x_current.value;
+ mod_add_subtract_w <-- gnd;
+
+ when_ i.mod_add_ready [
+ result_acc <-- i.mod_add_result;
+ sm.set_next Double;
+ ];
+ ] [
+ sm.set_next Double; (* LSB = 0: skip addition *)
];
];
- State.Double_adjust, [
-
- (* Double x for next iteration and
- reduce new x if >= modulus:
- -- if the doubling caused overflow (MSB=1)
- -- if subtracting the modulus does not underflow (comb_add_out.carry_out=0) *)
- let double_reduce_needed = msb x_current.value |: ~:(comb_add_out.carry_out) in
- let new_x = mux2 double_reduce_needed comb_add_out.result x_doubled in
- let new_multiplier = srl multiplier.value 1 in
- let new_bit_count = bit_count.value +:. 1 in
-
- proc [
- valid <-- gnd;
- x_current <-- new_x;
- multiplier <-- new_multiplier;
- bit_count <-- new_bit_count;
-
- if_ ((new_bit_count ==: num_bits_orig.value) |: (new_multiplier ==:. 0)) [
- (* Exit when done OR when no more bits to process *)
- result <-- result_acc.value;
- valid <-- vdd;
- sm.set_next Done;
- ] (* REVISIT is it okay to just skip Add and Adjust if LSB=0? *)
- @@ elif ( ~:(lsb new_multiplier) ) [
- (* Skip Add if next bit is 0 *)
- sm.set_next Double_adjust;
- ]
- @@ [
- sm.set_next Add;
+ State.Double, [
+ (* Combinatorially drive mod_add: x_current + x_current (double) *)
+ mod_add_valid_w <-- vdd;
+ mod_add_a_w <-- x_current.value;
+ mod_add_b_w <-- x_current.value;
+ mod_add_subtract_w <-- gnd;
+
+ when_ i.mod_add_ready [
+ let new_multiplier = srl multiplier.value 1 in
+ let new_bit_count = bit_count.value +:. 1 in
+ proc [
+ x_current <-- i.mod_add_result;
+ multiplier <-- new_multiplier;
+ bit_count <-- new_bit_count;
+
+ if_ ((new_bit_count ==: num_bits_orig.value) |: (new_multiplier ==:. 0)) [
+ result <-- result_acc.value;
+ valid <-- vdd;
+ sm.set_next Done;
+ ] @@ elif (~:(lsb new_multiplier)) [
+ sm.set_next Double; (* skip Add when next bit = 0 *)
+ ] @@ [
+ sm.set_next Add;
+ ];
];
];
];
@@ -195,8 +171,11 @@ module ModMul = struct
];
];
- { O.result = result.value -- "result"
- ; valid = valid.value -- "valid"
+ { O.result = result.value -- "result"
+ ; valid = valid.value -- "valid"
+ ; mod_add_valid = mod_add_valid_w.value -- "mod_add_valid"
+ ; mod_add_a = mod_add_a_w.value -- "mod_add_a"
+ ; mod_add_b = mod_add_b_w.value -- "mod_add_b"
+ ; mod_add_subtract = mod_add_subtract_w.value -- "mod_add_subtract"
}
end
-
diff --git a/test/test_mod_add.ml b/test/test_mod_add.ml
index 4fe5434..ecbde41 100644
--- a/test/test_mod_add.ml
+++ b/test/test_mod_add.ml
@@ -2,10 +2,10 @@ open Base
open Hardcaml
let test () =
- let test_modulus_z =
+ let test_modulus_z =
Z.of_string "115792089237316195423570985008687907853269984665640564039457584007908834671663"
in
-
+
Stdio.printf "=== ModAdd Hardware Test (256-bit modular add/sub) ===\n\n";
Stdio.printf "Test Modulus (n) = %s\n\n" (Z.to_string test_modulus_z);
@@ -14,7 +14,7 @@ let test () =
let sim = Sim.create (Mod_add.ModAdd.create scope) in
let inputs = Cyclesim.inputs sim in
- let outputs = Cyclesim.outputs sim in
+ let outputs = Cyclesim.outputs ~clock_edge:Before sim in
let z_to_bits z =
let hex_str = Z.format "%x" z in
@@ -35,14 +35,14 @@ let test () =
Stdio.printf " b = %s\n" (Z.to_string b_z);
Stdio.printf " n = %s\n" (Z.to_string modulus_z);
- let expected_z =
+ let expected_z =
let raw = if is_sub then Z.(a_z - b_z) else Z.(a_z + b_z) in
Z.(erem raw modulus_z)
in
Stdio.printf " expected = %s\n" (Z.to_string expected_z);
inputs.clear := Bits.vdd;
- inputs.start := Bits.gnd;
+ inputs.valid := Bits.gnd;
inputs.a := Bits.zero Mod_add.Config.width;
inputs.b := Bits.zero Mod_add.Config.width;
inputs.modulus := Bits.zero Mod_add.Config.width;
@@ -56,30 +56,28 @@ let test () =
inputs.b := z_to_bits b_z;
inputs.modulus := z_to_bits modulus_z;
inputs.subtract := if is_sub then Bits.vdd else Bits.gnd;
- inputs.start := Bits.vdd;
- Cyclesim.cycle sim;
-
- inputs.start := Bits.gnd;
+ inputs.valid := Bits.vdd;
let max_cycles = 20 in
let rec wait cycle_count =
if cycle_count >= max_cycles then begin
Stdio.printf " ERROR: Timeout after %d cycles\n\n" max_cycles;
false
- end else if Bits.to_bool !(outputs.valid) then begin
+ end else if Bits.to_bool !(outputs.ready) then begin
+ inputs.valid := Bits.gnd;
let result_z = bits_to_z !(outputs.result) in
-
+
Stdio.printf " Completed in %d cycles\n" cycle_count;
Stdio.printf " result = %s\n" (Z.to_string result_z);
-
+
let verified = Z.equal result_z expected_z in
-
+
if verified then
Stdio.printf " Verification: PASS ✓\n"
else
- Stdio.printf " Verification: FAIL ✗ (diff: %s)\n"
+ Stdio.printf " Verification: FAIL ✗ (diff: %s)\n"
(Z.to_string Z.(result_z - expected_z));
-
+
Stdio.printf "\n";
verified
end else begin
@@ -87,7 +85,8 @@ let test () =
wait (cycle_count + 1)
end
in
- wait 0
+ Cyclesim.cycle sim;
+ wait 1
in
let test_with_default_modulus name a_z b_z is_sub =
@@ -97,29 +96,29 @@ let test () =
let results = [
test_with_default_modulus "Simple addition"
(Z.of_int 100) (Z.of_int 200) false;
-
+
test_with_default_modulus "Addition with modular wrap"
(Z.of_string "115792089237316195423570985008687907853269984665640564039457584007908834671600")
(Z.of_int 100) false;
-
+
test_with_default_modulus "Simple subtraction"
(Z.of_int 500) (Z.of_int 300) true;
-
+
test_with_default_modulus "Subtraction requiring modular correction"
(Z.of_int 100) (Z.of_int 200) true;
-
+
test_with_default_modulus "Subtraction from modulus-1"
Z.(test_modulus_z - one) (Z.of_int 10) true;
-
+
test_with_default_modulus "Add zero"
(Z.of_int 12345) Z.zero false;
-
+
test_with_default_modulus "Subtract zero"
(Z.of_int 12345) Z.zero true;
-
+
test_case "Different modulus: small prime 997"
(Z.of_int 500) (Z.of_int 600) (Z.of_int 997) false;
-
+
test_case "BN254 subtraction wrap"
(Z.of_int 10)
(Z.of_int 20)
@@ -138,4 +137,4 @@ let test () =
else
Stdio.printf "\n✗ Some tests failed\n"
-let () = test ()
\ No newline at end of file
+let () = test ()
diff --git a/test/test_mod_inv.ml b/test/test_mod_inv.ml
index 1f047c7..d933d7b 100644
--- a/test/test_mod_inv.ml
+++ b/test/test_mod_inv.ml
@@ -1,13 +1,49 @@
open Base
open Hardcaml
+module ModInvWithModAdd = struct
+ module I = Mod_inv.ModInv.I
+ module O = Mod_inv.ModInv.O
+
+ let create scope (i : _ I.t) =
+ let mod_add_result_w = Signal.wire Mod_inv.Config.width in
+ let mod_add_ready_w = Signal.wire 1 in
+ let mod_add_adjusted_w = Signal.wire 1 in
+
+ let inv_out = Mod_inv.ModInv.create (Scope.sub_scope scope "mod_inv")
+ { i with
+ mod_add_result = mod_add_result_w
+ ; mod_add_ready = mod_add_ready_w
+ ; mod_add_adjusted = mod_add_adjusted_w
+ }
+ in
+
+ let add_out = Mod_add.ModAdd.create (Scope.sub_scope scope "mod_add")
+ { Mod_add.ModAdd.I.
+ clock = i.clock
+ ; clear = i.clear
+ ; valid = inv_out.mod_add_valid
+ ; a = inv_out.mod_add_a
+ ; b = inv_out.mod_add_b
+ ; modulus = i.modulus
+ ; subtract = inv_out.mod_add_subtract
+ }
+ in
+
+ Signal.assign mod_add_result_w add_out.result;
+ Signal.assign mod_add_ready_w add_out.ready;
+ Signal.assign mod_add_adjusted_w add_out.adjusted;
+
+ inv_out
+end
+
let test () =
Stdio.printf "=== ModInv Hardware Test (256-bit with Zarith) ===\n";
Stdio.printf "=== Assuming odd prime modulus ===\n\n";
let scope = Scope.create ~flatten_design:true () in
- let module Sim = Cyclesim.With_interface(Mod_inv.ModInv.I)(Mod_inv.ModInv.O) in
- let sim = Sim.create (Mod_inv.ModInv.create scope) in
+ let module Sim = Cyclesim.With_interface(ModInvWithModAdd.I)(ModInvWithModAdd.O) in
+ let sim = Sim.create (ModInvWithModAdd.create scope) in
let inputs = Cyclesim.inputs sim in
let outputs = Cyclesim.outputs sim in
diff --git a/test/test_mod_mul.ml b/test/test_mod_mul.ml
index 9d0ed2f..1301363 100644
--- a/test/test_mod_mul.ml
+++ b/test/test_mod_mul.ml
@@ -1,12 +1,45 @@
open Base
open Hardcaml
+module ModMulWithModAdd = struct
+ module I = Mod_mul.ModMul.I
+ module O = Mod_mul.ModMul.O
+
+ let create scope (i : _ I.t) =
+ let mod_add_result_w = Signal.wire Mod_mul.Config.width in
+ let mod_add_ready_w = Signal.wire 1 in
+
+ let mul_out = Mod_mul.ModMul.create (Scope.sub_scope scope "mod_mul")
+ { i with
+ mod_add_result = mod_add_result_w
+ ; mod_add_ready = mod_add_ready_w
+ }
+ in
+
+ let add_out = Mod_add.ModAdd.create (Scope.sub_scope scope "mod_add")
+ { Mod_add.ModAdd.I.
+ clock = i.clock
+ ; clear = i.clear
+ ; valid = mul_out.mod_add_valid
+ ; a = mul_out.mod_add_a
+ ; b = mul_out.mod_add_b
+ ; modulus = i.modulus
+ ; subtract = mul_out.mod_add_subtract
+ }
+ in
+
+ Signal.(mod_add_result_w <== add_out.result);
+ Signal.(mod_add_ready_w <== add_out.ready);
+
+ mul_out
+end
+
let test () =
Stdio.printf "=== ModMul Hardware Test (256-bit with Zarith and num_bits) ===\n\n";
let scope = Scope.create ~flatten_design:true () in
- let module Sim = Cyclesim.With_interface(Mod_mul.ModMul.I)(Mod_mul.ModMul.O) in
- let sim = Sim.create (Mod_mul.ModMul.create scope) in
+ let module Sim = Cyclesim.With_interface(ModMulWithModAdd.I)(ModMulWithModAdd.O) in
+ let sim = Sim.create (ModMulWithModAdd.create scope) in
let inputs = Cyclesim.inputs sim in
let outputs = Cyclesim.outputs sim in
@@ -35,8 +68,8 @@ let test () =
let y_bits = z_to_bits y_z in
let mod_bits = z_to_bits mod_z in
let num_bits = z_num_bits y_z in
-
- Stdio.printf " y bits = %d (optimization: %dx speedup)\n"
+
+ Stdio.printf " y bits = %d (optimization: %dx speedup)\n"
num_bits (Mod_mul.Config.width / (max 1 num_bits));
(* Reset *)
@@ -69,21 +102,21 @@ let test () =
false
end else if Bits.to_bool !(outputs.valid) then begin
let result_bits = !(outputs.result) in
- let result_z =
+ let result_z =
let bin_str = Bits.to_bstr result_bits in
Z.of_string_base 2 bin_str
in
-
+
Stdio.printf " Completed in %d cycles\n" cycle_count;
Stdio.printf " result = %s\n" (Z.to_string result_z);
-
+
let verified = Z.equal result_z expected_z in
-
+
if verified then
Stdio.printf " Verification: PASS ✓\n"
else
Stdio.printf " Verification: FAIL ✗\n";
-
+
Stdio.printf "\n";
verified
end else begin
@@ -97,32 +130,32 @@ let test () =
let results = [
test_case "3 * 5 mod 7"
(Z.of_int 3) (Z.of_int 5) (Z.of_int 7);
-
+
test_case "12 * 15 mod 17"
(Z.of_int 12) (Z.of_int 15) (Z.of_int 17);
-
+
test_case "123 * 456 mod 1009"
(Z.of_int 123) (Z.of_int 456) (Z.of_int 1009);
-
+
test_case "Edge: x * 0 mod m"
(Z.of_int 12345) (Z.of_int 0) (Z.of_int 7919);
-
+
test_case "Edge: x * 1 mod m"
(Z.of_int 12345) (Z.of_int 1) (Z.of_int 7919);
-
+
test_case "Large: 123456 * 789012 mod 1000003"
(Z.of_int 123456) (Z.of_int 789012) (Z.of_int 1000003);
-
+
test_case "64-bit multiplication"
(Z.of_string "12345678901234")
(Z.of_string "98765432109876")
(Z.of_string "999999999999999989");
-
+
test_case "128-bit multiplication"
(Z.of_string "123456789012345678901234567890")
(Z.of_string "987654321098765432109876543210")
(Z.of_string "340282366920938463463374607431768211297");
-
+
test_case "256-bit: multiplication mod secp256k1 prime"
(Z.of_string "123456789012345678901234567890123456789")
(Z.of_string "987654321098765432109876543210987654321")
@@ -152,4 +185,4 @@ let test () =
end else
Stdio.printf "Some tests failed ✗\n"
-let () = test ()
\ No newline at end of file
+let () = test ()