diff --git a/src/decimo/bigint/bigint.mojo b/src/decimo/bigint/bigint.mojo index 78004389..1b11d69d 100644 --- a/src/decimo/bigint/bigint.mojo +++ b/src/decimo/bigint/bigint.mojo @@ -28,6 +28,7 @@ in little-endian order, and a separate sign bit. """ from std.memory import UnsafePointer, memcpy +from std.sys import size_of import decimo.bigint.arithmetics import decimo.bigint.bitwise @@ -179,7 +180,7 @@ struct BigInt( self = Self.from_string(value) @implicit - def __init__(out self, value: Scalar): + def __init__(out self, value: Scalar) where value.dtype.is_integral(): """Constructs a BigInt from an integral scalar. This includes all SIMD integral types, such as Int8, Int16, UInt32, etc. @@ -210,72 +211,34 @@ struct BigInt( if value < 0: sign = True - # Handle Int.MIN (two's complement asymmetry) - if value == Int.MIN: - # |Int.MIN| = Int.MAX + 1 - magnitude = UInt(Int.MAX) + 1 - else: - magnitude = UInt(-value) + magnitude = UInt(0) - UInt(value) else: sign = False magnitude = UInt(value) - # Split the magnitude into 32-bit words - # On 64-bit platforms, Int is 64 bits → at most 2 words - var words = List[UInt32](capacity=2) - while magnitude != 0: + comptime if size_of[Int]() == 4: + # 32-bit platform: magnitude fits in 1 word + return Self(raw_words=[UInt32(magnitude)], sign=sign) + elif size_of[Int]() == 8: + # 64-bit platform: at most 2 words + var words = List[UInt32](capacity=2) words.append(UInt32(magnitude & 0xFFFF_FFFF)) - magnitude >>= 32 - - return Self(raw_words=words^, sign=sign) - - @staticmethod - def from_uint64(value: UInt64) -> Self: - """Creates a BigInt from a UInt64. - - Args: - value: The unsigned 64-bit integer value. - - Returns: - The BigInt representation. - """ - if value == 0: - return Self() - - var words = List[UInt32](capacity=2) - var lo = UInt32(value & 0xFFFF_FFFF) - var hi = UInt32(value >> 32) - words.append(lo) - if hi != 0: - words.append(hi) - - return Self(raw_words=words^, sign=False) - - @staticmethod - def from_uint128(value: UInt128) -> Self: - """Creates a BigInt from a UInt128. - - Args: - value: The unsigned 128-bit integer value. - - Returns: - The BigInt representation. - """ - if value == 0: - return Self() - - var words = List[UInt32](capacity=4) - var remaining = value - while remaining != 0: - words.append(UInt32(remaining & 0xFFFF_FFFF)) - remaining >>= 32 - - return Self(raw_words=words^, sign=False) + var hi = UInt32(magnitude >> 32) + if hi != 0: + words.append(hi) + return Self(raw_words=words^, sign=sign) + else: + comptime assert False, "unsupported platform Int size" @staticmethod - def from_integral_scalar[dtype: DType, //](value: SIMD[dtype, 1]) -> Self: + def from_integral_scalar[ + dtype: DType, // + ](value: SIMD[dtype, 1]) -> Self where dtype.is_integral(): """Initializes a BigInt from an integral scalar. - This includes all SIMD integral types, such as Int8, Int16, UInt32, etc. + This includes all SIMD integral types: + Int8, Int16, Int32, Int64, Int128, Int256, + UInt8, UInt16, UInt32, UInt64, UInt128, UInt256, + and the platform-sized Int (DType.int) and UInt (DType.uint). Constraints: The dtype must be integral. @@ -287,34 +250,164 @@ struct BigInt( The BigInt representation of the Scalar value. """ - comptime assert dtype.is_integral(), "dtype must be integral." - if value == 0: return Self() - var sign: Bool - var magnitude: UInt64 + # --- Unsigned types: direct word extraction via bit ops --- + + comptime if dtype == DType.uint8 or dtype == DType.uint16: + # Fits in 1 word + return Self(raw_words=[UInt32(value)], sign=False) + + elif dtype == DType.uint32: + return Self(raw_words=[UInt32(value)], sign=False) + + elif dtype == DType.uint64: + var words = List[UInt32](capacity=2) + words.append(UInt32(value & 0xFFFF_FFFF)) + var hi = UInt32(value >> 32) + if hi != 0: + words.append(hi) + return Self(raw_words=words^, sign=False) + + elif dtype == DType.uint128: + var words = List[UInt32](capacity=4) + var remaining = value + while remaining != 0: + words.append(UInt32(remaining & 0xFFFF_FFFF)) + remaining >>= 32 + return Self(raw_words=words^, sign=False) + + elif dtype == DType.uint256: + var words = List[UInt32](capacity=8) + var remaining = value + while remaining != 0: + words.append(UInt32(remaining & 0xFFFF_FFFF)) + remaining >>= 32 + return Self(raw_words=words^, sign=False) + + # --- Platform-sized UInt (pointer width, 32- or 64-bit) --- + + elif dtype == DType.uint: + comptime if size_of[Scalar[DType.uint]]() == 4: + # 32-bit platform: same as uint32 + return Self(raw_words=[UInt32(value)], sign=False) + elif size_of[Scalar[DType.uint]]() == 8: + # 64-bit platform: same as uint64 + var words = List[UInt32](capacity=2) + words.append(UInt32(value & 0xFFFF_FFFF)) + var hi = UInt32(value >> 32) + if hi != 0: + words.append(hi) + return Self(raw_words=words^, sign=False) + else: + comptime assert False, "unsupported platform UInt size" - comptime if dtype.is_unsigned(): - sign = False - magnitude = UInt64(value) - else: + # --- Signed types <= 64 bits: convert magnitude to UInt64 --- + + elif dtype == DType.int8 or dtype == DType.int16: + # Magnitude fits in 1 word if value < 0: - sign = True - # Compute magnitude using explicit two's-complement conversion + return Self(raw_words=[UInt32(-Int32(value))], sign=True) + else: + return Self(raw_words=[UInt32(value)], sign=False) + + elif dtype == DType.int32: + if value < 0: + var magnitude = UInt64(0) - UInt64(value) + var words = List[UInt32](capacity=2) + words.append(UInt32(magnitude & 0xFFFF_FFFF)) + var hi = UInt32(magnitude >> 32) + if hi != 0: + words.append(hi) + return Self(raw_words=words^, sign=True) + else: + return Self(raw_words=[UInt32(value)], sign=False) + + elif dtype == DType.int64: + var sign = value < 0 + var magnitude: UInt64 + if sign: magnitude = UInt64(0) - UInt64(value) else: - sign = False magnitude = UInt64(value) + var words = List[UInt32](capacity=2) + words.append(UInt32(magnitude & 0xFFFF_FFFF)) + var hi = UInt32(magnitude >> 32) + if hi != 0: + words.append(hi) + return Self(raw_words=words^, sign=sign) + + # --- Platform-sized Int (pointer width, 32- or 64-bit) --- + + elif dtype == DType.int: + comptime if size_of[Scalar[DType.int]]() == 4: + # 32-bit platform: same as int32 + if value < 0: + var magnitude = UInt64(0) - UInt64(value) + var words = List[UInt32](capacity=2) + words.append(UInt32(magnitude & 0xFFFF_FFFF)) + var hi = UInt32(magnitude >> 32) + if hi != 0: + words.append(hi) + return Self(raw_words=words^, sign=True) + else: + return Self(raw_words=[UInt32(value)], sign=False) + elif size_of[Scalar[DType.int]]() == 8: + # 64-bit platform: same as int64 + var sign = value < 0 + var magnitude: UInt64 + if sign: + magnitude = UInt64(0) - UInt64(value) + else: + magnitude = UInt64(value) + var words = List[UInt32](capacity=2) + words.append(UInt32(magnitude & 0xFFFF_FFFF)) + var hi = UInt32(magnitude >> 32) + if hi != 0: + words.append(hi) + return Self(raw_words=words^, sign=sign) + else: + comptime assert False, "unsupported platform Int size" + + # --- Int128: use division to extract 32-bit chunks --- + + elif dtype == DType.int128: + var sign = value < 0 + var words = List[UInt32](capacity=4) + var rem = Int128(value) + if sign: + while rem != 0: + var quotient = rem // Int128(-0x1_0000_0000) + var word_val = rem % Int128(-0x1_0000_0000) + words.append(UInt32(-word_val)) + rem = -quotient + else: + while rem != 0: + words.append(UInt32(rem & 0xFFFF_FFFF)) + rem >>= 32 + return Self(raw_words=words^, sign=sign) + + # --- Int256: use division to extract 32-bit chunks --- + + elif dtype == DType.int256: + var sign = value < 0 + var words = List[UInt32](capacity=8) + var rem = Int256(value) + if sign: + while rem != 0: + var quotient = rem // Int256(-0x1_0000_0000) + var word_val = rem % Int256(-0x1_0000_0000) + words.append(UInt32(-word_val)) + rem = -quotient + else: + while rem != 0: + words.append(UInt32(rem & 0xFFFF_FFFF)) + rem >>= 32 + return Self(raw_words=words^, sign=sign) - var words = List[UInt32](capacity=2) - var lo = UInt32(magnitude & 0xFFFF_FFFF) - var hi = UInt32(magnitude >> 32) - words.append(lo) - if hi != 0: - words.append(hi) - - return Self(raw_words=words^, sign=sign) + else: + comptime assert False, "unsupported integral dtype" @staticmethod def from_string(value: String) raises -> Self: diff --git a/src/decimo/bigint/exponential.mojo b/src/decimo/bigint/exponential.mojo index 5571ba73..9671b867 100644 --- a/src/decimo/bigint/exponential.mojo +++ b/src/decimo/bigint/exponential.mojo @@ -315,7 +315,7 @@ def sqrt(x: BigInt) raises -> BigInt: guess -= 1 while (guess + 1) * (guess + 1) <= val: guess += 1 - return BigInt.from_uint64(guess) + return BigInt.from_integral_scalar(guess) # For all larger inputs: optimized precision-doubling with UInt64 fast path return _sqrt_precision_doubling_fast(x) @@ -400,7 +400,7 @@ def _sqrt_precision_doubling_fast(x: BigInt) raises -> BigInt: ) if decimo.bigint.arithmetics._compare_word_lists(a_sq, x.words) > 0: a_val -= 1 - return BigInt.from_uint64(a_val) + return BigInt.from_integral_scalar(a_val) # --- Phase 1.5: UInt128 arithmetic for 1-2 more iterations --- # Extends the native phase to cover e+d up to ~126 bits, avoiding diff --git a/tests/bigint/test_bigint_conversion.mojo b/tests/bigint/test_bigint_conversion.mojo index 4b1111a9..3df9e2cc 100644 --- a/tests/bigint/test_bigint_conversion.mojo +++ b/tests/bigint/test_bigint_conversion.mojo @@ -67,6 +67,32 @@ def test_to_int_overflow() raises: testing.assert_true(raised, "to_int should raise for very large number") +# ===----------------------------------------------------------------------=== # +# Test: from_int edge cases (Int.MIN, Int.MAX) +# ===----------------------------------------------------------------------=== # + + +def test_from_int_edge_cases() raises: + """Test from_int with Int.MIN, Int.MAX, and other edge values.""" + # Int.MAX via from_int path + var max_val = BigInt(Int.MAX) + testing.assert_equal(String(max_val), "9223372036854775807") + testing.assert_equal(Int(max_val), Int.MAX) + + # Int.MIN via from_int path — the critical edge case + var min_val = BigInt(Int.MIN) + testing.assert_equal(String(min_val), "-9223372036854775808") + testing.assert_equal(Int(min_val), Int.MIN) + + # -1 via from_int + var neg_one = BigInt(-1) + testing.assert_equal(String(neg_one), "-1") + + # 0 via from_int + var zero = BigInt(Int(0)) + testing.assert_equal(String(zero), "0") + + # ===----------------------------------------------------------------------=== # # Test: from_integral_scalar / Scalar constructor # ===----------------------------------------------------------------------=== # @@ -106,6 +132,95 @@ def test_from_integral_scalar() raises: var i64 = BigInt(Int64(-9223372036854775808)) testing.assert_equal(String(i64), "-9223372036854775808") + # UInt128 + var u128_small = BigInt(UInt128(12345)) + testing.assert_equal(String(u128_small), "12345") + + var u128_large = BigInt(UInt128(80554649779790687400)) + testing.assert_equal(String(u128_large), "80554649779790687400") + + # UInt128.MAX = 340282366920938463463374607431768211455 + var u128_max = BigInt(UInt128.MAX) + testing.assert_equal( + String(u128_max), "340282366920938463463374607431768211455" + ) + + # Int128 + var i128_pos = BigInt(Int128(80554649779790687400)) + testing.assert_equal(String(i128_pos), "80554649779790687400") + + var i128_neg = BigInt(Int128(-80554649779790687400)) + testing.assert_equal(String(i128_neg), "-80554649779790687400") + + # Int128.MIN = -170141183460469231731687303715884105728 + var i128_min = BigInt(Int128.MIN) + testing.assert_equal( + String(i128_min), "-170141183460469231731687303715884105728" + ) + + # Int128.MAX = 170141183460469231731687303715884105727 + var i128_max = BigInt(Int128.MAX) + testing.assert_equal( + String(i128_max), "170141183460469231731687303715884105727" + ) + + # UInt256 + var u256_small = BigInt(UInt256(12345)) + testing.assert_equal(String(u256_small), "12345") + + var u256_large = BigInt(UInt256(80554649779790687400)) + testing.assert_equal(String(u256_large), "80554649779790687400") + + # UInt256 value larger than UInt64.MAX + var u256_big = BigInt(UInt256(8055464977979068740023761289648172697)) + testing.assert_equal( + String(u256_big), "8055464977979068740023761289648172697" + ) + + # Int256 + var i256_pos = BigInt(Int256(8055464977979068740023761289648172697)) + testing.assert_equal( + String(i256_pos), "8055464977979068740023761289648172697" + ) + + var i256_neg = BigInt(Int256(-8055464977979068740023761289648172697)) + testing.assert_equal( + String(i256_neg), "-8055464977979068740023761289648172697" + ) + + # Int256.MIN + var i256_min = BigInt(Int256.MIN) + testing.assert_equal( + String(i256_min), + "-57896044618658097711785492504343953926634992332820282019728792003956564819968", + ) + + # Int256.MAX + var i256_max = BigInt(Int256.MAX) + testing.assert_equal( + String(i256_max), + "57896044618658097711785492504343953926634992332820282019728792003956564819967", + ) + + # Platform-sized UInt + var u_plat = BigInt(UInt(18446744073709551615)) + testing.assert_equal(String(u_plat), "18446744073709551615") + + # Platform-sized Int + var i_plat_pos = BigInt(Scalar[DType.int](1234567890)) + testing.assert_equal(String(i_plat_pos), "1234567890") + + var i_plat_neg = BigInt(Scalar[DType.int](-1234567890)) + testing.assert_equal(String(i_plat_neg), "-1234567890") + + # Zero for various types + testing.assert_equal(String(BigInt(UInt8(0))), "0") + testing.assert_equal(String(BigInt(Int32(0))), "0") + testing.assert_equal(String(BigInt(UInt64(0))), "0") + testing.assert_equal(String(BigInt(Int128(0))), "0") + testing.assert_equal(String(BigInt(UInt256(0))), "0") + testing.assert_equal(String(BigInt(Int256(0))), "0") + # ===----------------------------------------------------------------------=== # # Test: D&C from_string for large numbers