Skip to content

Commit 600fed7

Browse files
committed
Use binary gcd and make overflow safe.
1 parent 102b106 commit 600fed7

File tree

1 file changed

+84
-25
lines changed

1 file changed

+84
-25
lines changed

lib/std/math/egcd.zig

Lines changed: 84 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,58 +3,117 @@ const std = @import("../std.zig");
33

44
/// Result type of `egcd`.
55
pub fn ExtendedGreatestCommonDivisor(S: anytype) type {
6+
const N = switch (S) {
7+
comptime_int => comptime_int,
8+
else => |T| std.meta.Int(.unsigned, @bitSizeOf(T)),
9+
};
10+
611
return struct {
7-
gcd: S,
12+
gcd: N,
813
bezout_coeff_1: S,
914
bezout_coeff_2: S,
1015
};
1116
}
1217

18+
fn egcd_helper(other: anytype, odd: anytype, shift: anytype) [3]@TypeOf(other, odd) {
19+
const S = @TypeOf(other, odd);
20+
const toinv = @shrExact(other, @intCast(shift));
21+
const ctrl = @shrExact(odd, @intCast(shift));
22+
23+
var s: S = std.math.sign(toinv);
24+
var t: S = 0;
25+
26+
var x = @abs(toinv);
27+
var y = @abs(ctrl);
28+
29+
while (x & 1 == 0) {
30+
x = @shrExact(x, 1);
31+
s = @shrExact(if (s & 1 == 0) s else s + ctrl, 1);
32+
}
33+
34+
var y_minus_x = y -% x;
35+
while (y_minus_x != 0) : (y_minus_x = y -% x) {
36+
const t_minus_s = t - s;
37+
const copy_x = x;
38+
const copy_s = s;
39+
40+
s -= t;
41+
const carry = x < y;
42+
x -%= y;
43+
if (carry) {
44+
x = y_minus_x;
45+
y = copy_x;
46+
s = t_minus_s;
47+
t = copy_s;
48+
}
49+
while (x & 1 == 0) {
50+
x = @shrExact(x, 1);
51+
s = @shrExact(if (s & 1 == 0) s else s + ctrl, 1);
52+
}
53+
}
54+
55+
y = @shlExact(y, @intCast(shift));
56+
s = @shlExact(s, @intCast(shift));
57+
// Using integer widening is only a temporary solution.
58+
const W = std.meta.Int(.signed, @bitSizeOf(S) * 2);
59+
t = @intCast(@divExact(y - @as(W, s) * toinv, ctrl));
60+
return .{ @bitCast(y), s, t };
61+
}
62+
1363
/// Returns the Extended Greatest Common Divisor (EGCD) of two signed integers (`a` and `b`) which are not both zero.
1464
pub fn egcd(a: anytype, b: anytype) ExtendedGreatestCommonDivisor(@TypeOf(a, b)) {
1565
const S = switch (@TypeOf(a, b)) {
16-
// convert comptime_int to some sized int type for @ctz
1766
comptime_int => b: {
1867
const n = @max(@abs(a), @abs(b));
1968
break :b std.math.IntFittingRange(-n, n);
2069
},
2170
else => |T| T,
2271
};
23-
2472
if (@typeInfo(S) != .int or @typeInfo(S).int.signedness != .signed) {
2573
@compileError("`a` and `b` must be signed integers");
2674
}
2775

2876
std.debug.assert(a != 0 or b != 0);
2977

30-
var x: S = @intCast(@abs(a));
31-
var y: S = @intCast(@abs(b));
78+
if (a == 0) return .{ .gcd = @abs(b), .bezout_coeff_1 = 0, .bezout_coeff_2 = std.math.sign(b) };
79+
if (b == 0) return .{ .gcd = @abs(a), .bezout_coeff_1 = std.math.sign(a), .bezout_coeff_2 = 0 };
3280

33-
// Mantain a = s * x + t * y.
34-
var s: S = std.math.sign(a);
35-
var t: S = 0;
81+
const x: S = a;
82+
const y: S = b;
3683

37-
// Mantain b = u * x + v * y.
38-
var u: S = 0;
39-
var v: S = std.math.sign(b);
40-
41-
while (x != 0) {
42-
const q = @divTrunc(y, x);
43-
const old_x = x;
44-
const old_s = s;
45-
const old_t = t;
46-
x = y - q * x;
47-
s = u - q * s;
48-
t = v - q * t;
49-
y = old_x;
50-
u = old_s;
51-
v = old_t;
52-
}
84+
const xz = @ctz(x);
85+
const yz = @ctz(y);
86+
const shift = @min(xz, yz);
5387

54-
return .{ .gcd = y, .bezout_coeff_1 = u, .bezout_coeff_2 = v };
88+
if (xz < yz) {
89+
const gcd, const t, const s = egcd_helper(y, x, shift);
90+
return .{ .gcd = @intCast(gcd), .bezout_coeff_1 = s, .bezout_coeff_2 = t };
91+
} else {
92+
const gcd, const s, const t = egcd_helper(x, y, shift);
93+
return .{ .gcd = @intCast(gcd), .bezout_coeff_1 = s, .bezout_coeff_2 = t };
94+
}
5595
}
5696

5797
test {
98+
{
99+
const a: i4 = -8;
100+
const b: i4 = 1;
101+
const r = egcd(a, b);
102+
const g = r.gcd;
103+
const s = r.bezout_coeff_1;
104+
const t = r.bezout_coeff_2;
105+
try std.testing.expect(s * a + t * b == g);
106+
}
107+
{
108+
const a: i4 = -8;
109+
const b: i4 = 5;
110+
const r = egcd(a, b);
111+
const g = r.gcd;
112+
// Avoid overflow in assert.
113+
const s: i8 = r.bezout_coeff_1;
114+
const t: i8 = r.bezout_coeff_2;
115+
try std.testing.expect(s * a + t * b == g);
116+
}
58117
{
59118
const a: i32 = 0;
60119
const b: i32 = 5;

0 commit comments

Comments
 (0)