Skip to content

Commit 7a3d838

Browse files
committed
Add EGCD.
Fix some comments in GCD. Make ml_kem use lcm and egcd from std/math. Fix name. Add egcd function. Don't destructure. Use binary gcd and make overflow safe. Force inlining, use ctz to reduce dependency in loop.
1 parent 50ba48f commit 7a3d838

File tree

4 files changed

+190
-29
lines changed

4 files changed

+190
-29
lines changed

lib/std/crypto/ml_kem.zig

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -634,33 +634,11 @@ test "invNTTReductions bounds" {
634634
}
635635
}
636636

637-
// Extended euclidean algorithm.
638-
//
639-
// For a, b finds x, y such that x a + y b = gcd(a, b). Used to compute
640-
// modular inverse.
641-
fn eea(a: anytype, b: @TypeOf(a)) EeaResult(@TypeOf(a)) {
642-
if (a == 0) {
643-
return .{ .gcd = b, .x = 0, .y = 1 };
644-
}
645-
const r = eea(@rem(b, a), a);
646-
return .{ .gcd = r.gcd, .x = r.y - @divTrunc(b, a) * r.x, .y = r.x };
647-
}
648-
649-
fn EeaResult(comptime T: type) type {
650-
return struct { gcd: T, x: T, y: T };
651-
}
652-
653-
// Returns least common multiple of a and b.
654-
fn lcm(a: anytype, b: @TypeOf(a)) @TypeOf(a) {
655-
const r = eea(a, b);
656-
return a * b / r.gcd;
657-
}
658-
659637
// Invert modulo p.
660638
fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) {
661-
const r = eea(a, p);
639+
const r = std.math.egcd(a, p);
662640
assert(r.gcd == 1);
663-
return r.x;
641+
return r.bezout_coeff_1;
664642
}
665643

666644
// Reduce mod q for testing.
@@ -1054,7 +1032,7 @@ const Poly = struct {
10541032
var in_off: usize = 0;
10551033
var out_off: usize = 0;
10561034

1057-
const batch_size: usize = comptime lcm(@as(i16, d), 8);
1035+
const batch_size: usize = comptime std.math.lcm(@as(i16, d), 8);
10581036
const in_batch_size: usize = comptime batch_size / d;
10591037
const out_batch_size: usize = comptime batch_size / 8;
10601038

@@ -1118,7 +1096,7 @@ const Poly = struct {
11181096
var in_off: usize = 0;
11191097
var out_off: usize = 0;
11201098

1121-
const batch_size: usize = comptime lcm(@as(i16, d), 8);
1099+
const batch_size: usize = comptime std.math.lcm(@as(i16, d), 8);
11221100
const in_batch_size: usize = comptime batch_size / 8;
11231101
const out_batch_size: usize = comptime batch_size / d;
11241102

lib/std/math.zig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ pub const sinh = @import("math/sinh.zig").sinh;
238238
pub const cosh = @import("math/cosh.zig").cosh;
239239
pub const tanh = @import("math/tanh.zig").tanh;
240240
pub const gcd = @import("math/gcd.zig").gcd;
241+
pub const egcd = @import("math/egcd.zig").egcd;
241242
pub const lcm = @import("math/lcm.zig").lcm;
242243
pub const gamma = @import("math/gamma.zig").gamma;
243244
pub const lgamma = @import("math/gamma.zig").lgamma;

lib/std/math/egcd.zig

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
//! Extended Greatest Common Divisor (https://mathworld.wolfram.com/ExtendedGreatestCommonDivisor.html)
2+
const std = @import("../std.zig");
3+
4+
/// Result type of `egcd`.
5+
pub fn ExtendedGreatestCommonDivisor(S: anytype) type {
6+
const N = switch (S) {
7+
comptime_int => comptime_int,
8+
else => |T| std.meta.Int(.unsigned, @bitSizeOf(T)),
9+
};
10+
11+
return struct {
12+
gcd: N,
13+
bezout_coeff_1: S,
14+
bezout_coeff_2: S,
15+
};
16+
}
17+
18+
inline fn egcd_helper(other: anytype, odd: anytype, shift: anytype) [3]@TypeOf(other, odd) {
19+
const S = @TypeOf(other, odd);
20+
const toinv = @shrExact(other, @intCast(shift));
21+
const ctrl = @shrExact(odd, @intCast(shift));
22+
23+
var s: S = std.math.sign(toinv);
24+
var t: S = 0;
25+
26+
var x = @abs(toinv);
27+
var y = @abs(ctrl);
28+
29+
{
30+
const xz = @ctz(x);
31+
x = @shrExact(x, @intCast(xz));
32+
for (0..xz) |_|
33+
s = @shrExact(if (s & 1 == 0) s else s + ctrl, 1);
34+
}
35+
36+
var y_minus_x = y -% x;
37+
while (y_minus_x != 0) : (y_minus_x = y -% x) {
38+
const t_minus_s = t - s;
39+
const copy_x = x;
40+
const copy_s = s;
41+
const xz = @ctz(y_minus_x);
42+
43+
s -= t;
44+
const carry = x < y;
45+
x -%= y;
46+
if (carry) {
47+
x = y_minus_x;
48+
y = copy_x;
49+
s = t_minus_s;
50+
t = copy_s;
51+
}
52+
x = @shrExact(x, @intCast(xz));
53+
for (0..xz) |_|
54+
s = @shrExact(if (s & 1 == 0) s else s + ctrl, 1);
55+
}
56+
57+
y = @shlExact(y, @intCast(shift));
58+
s = @shlExact(s, @intCast(shift));
59+
// Using integer widening is only a temporary solution.
60+
const W = std.meta.Int(.signed, @bitSizeOf(S) * 2);
61+
t = @intCast(@divExact(y - @as(W, s) * toinv, ctrl));
62+
return .{ @bitCast(y), s, t };
63+
}
64+
65+
/// Returns the Extended Greatest Common Divisor (EGCD) of two signed integers (`a` and `b`) which are not both zero.
66+
pub fn egcd(a: anytype, b: anytype) ExtendedGreatestCommonDivisor(@TypeOf(a, b)) {
67+
const S = switch (@TypeOf(a, b)) {
68+
comptime_int => b: {
69+
const n = @max(@abs(a), @abs(b));
70+
break :b std.math.IntFittingRange(-n, n);
71+
},
72+
else => |T| T,
73+
};
74+
if (@typeInfo(S) != .int or @typeInfo(S).int.signedness != .signed) {
75+
@compileError("`a` and `b` must be signed integers");
76+
}
77+
78+
std.debug.assert(a != 0 or b != 0);
79+
80+
if (a == 0) return .{ .gcd = @abs(b), .bezout_coeff_1 = 0, .bezout_coeff_2 = std.math.sign(b) };
81+
if (b == 0) return .{ .gcd = @abs(a), .bezout_coeff_1 = std.math.sign(a), .bezout_coeff_2 = 0 };
82+
83+
const x: S = a;
84+
const y: S = b;
85+
86+
const xz = @ctz(x);
87+
const yz = @ctz(y);
88+
89+
if (xz < yz) {
90+
const gcd, const t, const s = egcd_helper(y, x, xz);
91+
return .{ .gcd = @intCast(gcd), .bezout_coeff_1 = s, .bezout_coeff_2 = t };
92+
} else {
93+
const gcd, const s, const t = egcd_helper(x, y, yz);
94+
return .{ .gcd = @intCast(gcd), .bezout_coeff_1 = s, .bezout_coeff_2 = t };
95+
}
96+
}
97+
98+
test {
99+
{
100+
const a: i4 = -8;
101+
const b: i4 = 1;
102+
const r = egcd(a, b);
103+
const g = r.gcd;
104+
const s = r.bezout_coeff_1;
105+
const t = r.bezout_coeff_2;
106+
try std.testing.expect(s * a + t * b == g);
107+
}
108+
{
109+
const a: i4 = -8;
110+
const b: i4 = 5;
111+
const r = egcd(a, b);
112+
const g = r.gcd;
113+
// Avoid overflow in assert.
114+
const s: i8 = r.bezout_coeff_1;
115+
const t: i8 = r.bezout_coeff_2;
116+
try std.testing.expect(s * a + t * b == g);
117+
}
118+
{
119+
const a: i32 = 0;
120+
const b: i32 = 5;
121+
const r = egcd(a, b);
122+
const g = r.gcd;
123+
const s = r.bezout_coeff_1;
124+
const t = r.bezout_coeff_2;
125+
try std.testing.expect(s * a + t * b == g);
126+
}
127+
{
128+
const a: i32 = 5;
129+
const b: i32 = 0;
130+
const r = egcd(a, b);
131+
const g = r.gcd;
132+
const s = r.bezout_coeff_1;
133+
const t = r.bezout_coeff_2;
134+
try std.testing.expect(s * a + t * b == g);
135+
}
136+
137+
{
138+
const a: i32 = 21;
139+
const b: i32 = 15;
140+
const r = egcd(a, b);
141+
const g = r.gcd;
142+
const s = r.bezout_coeff_1;
143+
const t = r.bezout_coeff_2;
144+
try std.testing.expect(s * a + t * b == g);
145+
}
146+
{
147+
const a: i32 = -21;
148+
const b: i32 = 15;
149+
const r = egcd(a, b);
150+
const g = r.gcd;
151+
const s = r.bezout_coeff_1;
152+
const t = r.bezout_coeff_2;
153+
try std.testing.expect(s * a + t * b == g);
154+
}
155+
{
156+
const a = -21;
157+
const b = 15;
158+
const r = egcd(a, b);
159+
const g = r.gcd;
160+
const s = r.bezout_coeff_1;
161+
const t = r.bezout_coeff_2;
162+
try std.testing.expect(s * a + t * b == g);
163+
}
164+
{
165+
const a = 927372692193078999176;
166+
const b = 573147844013817084101;
167+
const r = egcd(a, b);
168+
const g = r.gcd;
169+
const s = r.bezout_coeff_1;
170+
const t = r.bezout_coeff_2;
171+
try std.testing.expect(s * a + t * b == g);
172+
}
173+
{
174+
const a = 453973694165307953197296969697410619233826;
175+
const b = 280571172992510140037611932413038677189525;
176+
const r = egcd(a, b);
177+
const g = r.gcd;
178+
const s = r.bezout_coeff_1;
179+
const t = r.bezout_coeff_2;
180+
try std.testing.expect(s * a + t * b == g);
181+
}
182+
}

lib/std/math/gcd.zig

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
//! Greatest common divisor (https://mathworld.wolfram.com/GreatestCommonDivisor.html)
2-
const std = @import("std");
1+
//! Greatest Common Divisor (https://mathworld.wolfram.com/GreatestCommonDivisor.html)
2+
const std = @import("../std.zig");
33

4-
/// Returns the greatest common divisor (GCD) of two unsigned integers (`a` and `b`) which are not both zero.
4+
/// Returns the Greatest Common Divisor (GCD) of two unsigned integers (`a` and `b`) which are not both zero.
55
/// For example, the GCD of `8` and `12` is `4`, that is, `gcd(8, 12) == 4`.
66
pub fn gcd(a: anytype, b: anytype) @TypeOf(a, b) {
77
const N = switch (@TypeOf(a, b)) {

0 commit comments

Comments
 (0)