Skip to content

Commit 5d86a69

Browse files
committed
Add EGCD.
Fix some comments in GCD. Make ml_kem use lcm and egcd from std/math.
1 parent b0842c3 commit 5d86a69

File tree

3 files changed

+120
-30
lines changed

3 files changed

+120
-30
lines changed

lib/std/crypto/ml_kem.zig

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -634,33 +634,11 @@ test "invNTTReductions bounds" {
634634
}
635635
}
636636

637-
// Extended euclidean algorithm.
638-
//
639-
// For a, b finds x, y such that x a + y b = gcd(a, b). Used to compute
640-
// modular inverse.
641-
fn eea(a: anytype, b: @TypeOf(a)) EeaResult(@TypeOf(a)) {
642-
if (a == 0) {
643-
return .{ .gcd = b, .x = 0, .y = 1 };
644-
}
645-
const r = eea(@rem(b, a), a);
646-
return .{ .gcd = r.gcd, .x = r.y - @divTrunc(b, a) * r.x, .y = r.x };
647-
}
648-
649-
fn EeaResult(comptime T: type) type {
650-
return struct { gcd: T, x: T, y: T };
651-
}
652-
653-
// Returns least common multiple of a and b.
654-
fn lcm(a: anytype, b: @TypeOf(a)) @TypeOf(a) {
655-
const r = eea(a, b);
656-
return a * b / r.gcd;
657-
}
658-
659637
// Invert modulo p.
660638
fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) {
661-
const r = eea(a, p);
662-
assert(r.gcd == 1);
663-
return r.x;
639+
const gcd, const x, _ = std.math.egcd(a, p);
640+
assert(gcd == 1);
641+
return x;
664642
}
665643

666644
// Reduce mod q for testing.
@@ -1054,7 +1032,7 @@ const Poly = struct {
10541032
var in_off: usize = 0;
10551033
var out_off: usize = 0;
10561034

1057-
const batch_size: usize = comptime lcm(@as(i16, d), 8);
1035+
const batch_size: usize = comptime std.math.lcm(@as(i16, d), 8);
10581036
const in_batch_size: usize = comptime batch_size / d;
10591037
const out_batch_size: usize = comptime batch_size / 8;
10601038

@@ -1118,7 +1096,7 @@ const Poly = struct {
11181096
var in_off: usize = 0;
11191097
var out_off: usize = 0;
11201098

1121-
const batch_size: usize = comptime lcm(@as(i16, d), 8);
1099+
const batch_size: usize = comptime std.math.lcm(@as(i16, d), 8);
11221100
const in_batch_size: usize = comptime batch_size / 8;
11231101
const out_batch_size: usize = comptime batch_size / d;
11241102

lib/std/math/egcd.zig

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
//! Extended Greatest Common Divisor (https://mathworld.wolfram.com/ExtendedGreatestCommonDivisor.html)
2+
const std = @import("../std.zig");
3+
4+
/// Result type of `egcd`.
5+
pub fn ExtendedCommonDivisor(S: anytype) type {
6+
if (@typeInfo(S) != .int or @typeInfo(S).int.signedness != .signed) {
7+
@compileError("`S` must be a signed integer.");
8+
}
9+
10+
return struct {
11+
gcd: S,
12+
bezout_coeff_1: S,
13+
bezout_coeff_2: S,
14+
};
15+
}
16+
17+
/// Returns the Extended Greatest Common Divisor (EGCD) of two signed integers (`a` and `b`) which are not both zero.
18+
pub fn egcd(a: anytype, b: anytype) ExtendedCommonDivisor(@TypeOf(a, b)) {
19+
const S = switch (@TypeOf(a, b)) {
20+
// convert comptime_int to some sized int type for @ctz
21+
comptime_int => b: {
22+
const n = @max(@abs(a), @abs(b));
23+
break :b std.math.IntFittingRange(-n, n);
24+
},
25+
else => |T| T,
26+
};
27+
28+
if (@typeInfo(S) != .int or @typeInfo(S).int.signedness != .signed) {
29+
@compileError("`a` and `b` must be signed integers");
30+
}
31+
32+
std.debug.assert(a != 0 or b != 0);
33+
34+
var x: S = @intCast(@abs(a));
35+
var y: S = @intCast(@abs(b));
36+
37+
// Mantain a = s * x + t * y.
38+
var s: S = std.math.sign(a);
39+
var t: S = 0;
40+
41+
// Mantain b = u * x + v * y.
42+
var u: S = 0;
43+
var v: S = std.math.sign(b);
44+
45+
while (x != 0) {
46+
const q = @divTrunc(y, x);
47+
const old_x = x;
48+
const old_s = s;
49+
const old_t = t;
50+
x = y - q * x;
51+
s = u - q * s;
52+
t = v - q * t;
53+
y = old_x;
54+
u = old_s;
55+
v = old_t;
56+
}
57+
58+
return .{ .gcd = y, .bezout_coeff_1 = u, .bezout_coeff_2 = v };
59+
}
60+
61+
test {
62+
{
63+
const a: i32 = 0;
64+
const b: i32 = 5;
65+
const s, const t = egcd(a, b);
66+
const g = std.math.gcd(@abs(a), @abs(b));
67+
try std.testing.expect(s * a + t * b == g);
68+
}
69+
{
70+
const a: i32 = 5;
71+
const b: i32 = 0;
72+
const s, const t = egcd(a, b);
73+
const g = std.math.gcd(@as(u32, @intCast(a)), @as(u32, @intCast(b)));
74+
try std.testing.expect(s * a + t * b == g);
75+
}
76+
77+
{
78+
const a: i32 = 21;
79+
const b: i32 = 15;
80+
const s, const t = egcd(a, b);
81+
const g = std.math.gcd(@abs(a), @abs(b));
82+
try std.testing.expect(s * a + t * b == g);
83+
}
84+
{
85+
const a: i32 = -21;
86+
const b: i32 = 15;
87+
const s, const t = egcd(a, b);
88+
const g = std.math.gcd(@abs(a), @abs(b));
89+
try std.testing.expect(s * a + t * b == g);
90+
}
91+
{
92+
const a = -21;
93+
const b = 15;
94+
const s, const t = egcd(a, b);
95+
const g = std.math.gcd(@abs(a), @abs(b));
96+
try std.testing.expect(s * a + t * b == g);
97+
}
98+
{
99+
const a = 927372692193078999176;
100+
const b = 573147844013817084101;
101+
const s, const t = egcd(a, b);
102+
const g = std.math.gcd(@abs(a), @abs(b));
103+
try std.testing.expect(s * a + t * b == g);
104+
}
105+
{
106+
const a = 453973694165307953197296969697410619233826;
107+
const b = 280571172992510140037611932413038677189525;
108+
const s, const t = egcd(a, b);
109+
const g = std.math.gcd(@abs(a), @abs(b));
110+
try std.testing.expect(s * a + t * b == g);
111+
}
112+
}

lib/std/math/gcd.zig

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
//! Greatest common divisor (https://mathworld.wolfram.com/GreatestCommonDivisor.html)
2-
const std = @import("std");
1+
//! Greatest Common Divisor (https://mathworld.wolfram.com/GreatestCommonDivisor.html)
2+
const std = @import("../std.zig");
33

4-
/// Returns the greatest common divisor (GCD) of two unsigned integers (`a` and `b`) which are not both zero.
4+
/// Returns the Greatest Common Divisor (GCD) of two unsigned integers (`a` and `b`) which are not both zero.
55
/// For example, the GCD of `8` and `12` is `4`, that is, `gcd(8, 12) == 4`.
66
pub fn gcd(a: anytype, b: anytype) @TypeOf(a, b) {
77
const N = switch (@TypeOf(a, b)) {

0 commit comments

Comments
 (0)