@@ -3,58 +3,117 @@ const std = @import("../std.zig");
33
44/// Result type of `egcd`.
55pub fn ExtendedGreatestCommonDivisor (S : anytype ) type {
6+ const N = switch (S ) {
7+ comptime_int = > comptime_int ,
8+ else = > | T | std .meta .Int (.unsigned , @bitSizeOf (T )),
9+ };
10+
611 return struct {
7- gcd : S ,
12+ gcd : N ,
813 bezout_coeff_1 : S ,
914 bezout_coeff_2 : S ,
1015 };
1116}
1217
18+ fn egcd_helper (other : anytype , odd : anytype , shift : anytype ) [3 ]@TypeOf (other , odd ) {
19+ const S = @TypeOf (other , odd );
20+ const toinv = @shrExact (other , @intCast (shift ));
21+ const ctrl = @shrExact (odd , @intCast (shift ));
22+
23+ var s : S = std .math .sign (toinv );
24+ var t : S = 0 ;
25+
26+ var x = @abs (toinv );
27+ var y = @abs (ctrl );
28+
29+ while (x & 1 == 0 ) {
30+ x = @shrExact (x , 1 );
31+ s = @shrExact (if (s & 1 == 0 ) s else s + ctrl , 1 );
32+ }
33+
34+ var y_minus_x = y -% x ;
35+ while (y_minus_x != 0 ) : (y_minus_x = y -% x ) {
36+ const t_minus_s = t - s ;
37+ const copy_x = x ;
38+ const copy_s = s ;
39+
40+ s -= t ;
41+ const carry = x < y ;
42+ x -%= y ;
43+ if (carry ) {
44+ x = y_minus_x ;
45+ y = copy_x ;
46+ s = t_minus_s ;
47+ t = copy_s ;
48+ }
49+ while (x & 1 == 0 ) {
50+ x = @shrExact (x , 1 );
51+ s = @shrExact (if (s & 1 == 0 ) s else s + ctrl , 1 );
52+ }
53+ }
54+
55+ y = @shlExact (y , @intCast (shift ));
56+ s = @shlExact (s , @intCast (shift ));
57+ // Using integer widening is only a temporary solution.
58+ const W = std .meta .Int (.signed , @bitSizeOf (S ) * 2 );
59+ t = @intCast (@divExact (y - @as (W , s ) * toinv , ctrl ));
60+ return .{ @bitCast (y ), s , t };
61+ }
62+
1363/// Returns the Extended Greatest Common Divisor (EGCD) of two signed integers (`a` and `b`) which are not both zero.
1464pub fn egcd (a : anytype , b : anytype ) ExtendedGreatestCommonDivisor (@TypeOf (a , b )) {
1565 const S = switch (@TypeOf (a , b )) {
16- // convert comptime_int to some sized int type for @ctz
1766 comptime_int = > b : {
1867 const n = @max (@abs (a ), @abs (b ));
1968 break :b std .math .IntFittingRange (- n , n );
2069 },
2170 else = > | T | T ,
2271 };
23-
2472 if (@typeInfo (S ) != .int or @typeInfo (S ).int .signedness != .signed ) {
2573 @compileError ("`a` and `b` must be signed integers" );
2674 }
2775
2876 std .debug .assert (a != 0 or b != 0 );
2977
30- var x : S = @intCast ( @ abs (a )) ;
31- var y : S = @intCast ( @ abs (b )) ;
78+ if ( a == 0 ) return .{ . gcd = @abs (b ), . bezout_coeff_1 = 0 , . bezout_coeff_2 = std . math . sign ( b ) } ;
79+ if ( b == 0 ) return .{ . gcd = @abs (a ), . bezout_coeff_1 = std . math . sign ( a ), . bezout_coeff_2 = 0 } ;
3280
33- // Mantain a = s * x + t * y.
34- var s : S = std .math .sign (a );
35- var t : S = 0 ;
81+ const x : S = a ;
82+ const y : S = b ;
3683
37- // Mantain b = u * x + v * y.
38- var u : S = 0 ;
39- var v : S = std .math .sign (b );
40-
41- while (x != 0 ) {
42- const q = @divTrunc (y , x );
43- const old_x = x ;
44- const old_s = s ;
45- const old_t = t ;
46- x = y - q * x ;
47- s = u - q * s ;
48- t = v - q * t ;
49- y = old_x ;
50- u = old_s ;
51- v = old_t ;
52- }
84+ const xz = @ctz (x );
85+ const yz = @ctz (y );
86+ const shift = @min (xz , yz );
5387
54- return .{ .gcd = y , .bezout_coeff_1 = u , .bezout_coeff_2 = v };
88+ if (xz < yz ) {
89+ const gcd , const t , const s = egcd_helper (y , x , shift );
90+ return .{ .gcd = @intCast (gcd ), .bezout_coeff_1 = s , .bezout_coeff_2 = t };
91+ } else {
92+ const gcd , const s , const t = egcd_helper (x , y , shift );
93+ return .{ .gcd = @intCast (gcd ), .bezout_coeff_1 = s , .bezout_coeff_2 = t };
94+ }
5595}
5696
5797test {
98+ {
99+ const a : i4 = -8 ;
100+ const b : i4 = 1 ;
101+ const r = egcd (a , b );
102+ const g = r .gcd ;
103+ const s = r .bezout_coeff_1 ;
104+ const t = r .bezout_coeff_2 ;
105+ try std .testing .expect (s * a + t * b == g );
106+ }
107+ {
108+ const a : i4 = -8 ;
109+ const b : i4 = 5 ;
110+ const r = egcd (a , b );
111+ const g = r .gcd ;
112+ // Avoid overflow in assert.
113+ const s : i8 = r .bezout_coeff_1 ;
114+ const t : i8 = r .bezout_coeff_2 ;
115+ try std .testing .expect (s * a + t * b == g );
116+ }
58117 {
59118 const a : i32 = 0 ;
60119 const b : i32 = 5 ;
0 commit comments