From c20fe0aec6bcf19c4a05a57603e14f312b0355f7 Mon Sep 17 00:00:00 2001 From: WestonJB Date: Fri, 20 Feb 2026 11:42:21 -0500 Subject: [PATCH 1/2] Added rotmg --- src/level1/__init__.mojo | 1 + src/level1/rotmg.mojo | 150 +++++++++++++++++++++++++++++++++++++++ test-level1.mojo | 89 ++++++++++------------- 3 files changed, 189 insertions(+), 51 deletions(-) create mode 100644 src/level1/rotmg.mojo diff --git a/src/level1/__init__.mojo b/src/level1/__init__.mojo index 80477f0..14d7d04 100644 --- a/src/level1/__init__.mojo +++ b/src/level1/__init__.mojo @@ -5,6 +5,7 @@ from .copy_device import * from .rot_device import * from .rotg import * from .rotm_device import * +from .rotmg import * from .swap_device import * from .dot_device import * from .dotc_device import * diff --git a/src/level1/rotmg.mojo b/src/level1/rotmg.mojo new file mode 100644 index 0000000..ee736b9 --- /dev/null +++ b/src/level1/rotmg.mojo @@ -0,0 +1,150 @@ +fn blas_rotmg[dtype: DType]( + d1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + d2: UnsafePointer[Scalar[dtype], MutAnyOrigin], + x1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + y1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + param: UnsafePointer[SIMD[dtype, 5], MutAnyOrigin] +): + + var flag: Scalar[dtype] + var h11: Scalar[dtype] = 0 + var h12: Scalar[dtype] = 0 + var h21: Scalar[dtype] = 0 + var h22: Scalar[dtype] = 0 + var p1: Scalar[dtype] + var p2: Scalar[dtype] + var q1: Scalar[dtype] + var q2: Scalar[dtype] + var temp: Scalar[dtype] + var u: Scalar[dtype] + var gam: Scalar[dtype] = 4096.0 + var gamsq: Scalar[dtype] = 16777216.0 + var rgamsq: Scalar[dtype] = 5.9604645e-8 + + if (d1[] < 0): + # GO 0-H-D-AND-x1.. + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + # CASE-d1-NONNEGATIVE + p2 = d2[]*y1[] + if (p2 == 0): + flag = -2 + param[0] = flag + return + + # REGULAR-CASE.. 
+ p1 = d1[]*x1[] + q2 = p2*y1[] + q1 = p1*x1[] + # + if (abs(q1) > abs(q2)): + h21 = -y1[]/x1[] + h12 = p2/p1 + # + u = 1 - h12*h21 + # + if (u > 0): + flag = 0 + d1[] /= u + d2[] /= u + x1[] *= u + else: + # This code path is here for safety. We do not expect this + # condition to ever hold except in edge cases with rounding + # errors. See DOI: 10.1145/355841.355847 + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + # + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + if (q2 < 0): + # GO 0-H-D-AND-x1.. + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + # + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + flag = 1 + h11 = p1/p2 + h22 = x1[]/y1[] + u = 1 + h11*h22 + temp = d2[]/u + d2[] = d1[]/u + d1[] = temp + x1[] = y1[]*u + + # PROCEDURE..SCALE-CHECK + if (d1[] != 0): + while ((d1[] <= rgamsq) or (d1[] >= gamsq)): + if (flag == 0): + h11 = 1 + h22 = 1 + flag = -1 + else: + h21 = -1 + h12 = 1 + flag = -1 + + if (d1[] <= rgamsq): + d1[] *= gam**2 + x1[] /= gam + h11 /= gam + h12 /= gam + else: + d1[] /= gam**2 + x1[] *= gam + h11 *= gam + h12 *= gam + + if (d2[] != 0): + while ( (abs(d2[]) <= rgamsq) or (abs(d2[]) >= gamsq) ): + if (flag == 0): + h11 = 1 + h22 = 1 + flag = -1 + else: + h21 = -1 + h12 = 1 + flag = -1 + + if (abs(d2[]) <= rgamsq): + d2[] *= gam**2 + h21 /= gam + h22 /= gam + else: + d2[] *= gam**2 + h21 *= gam + h22 *= gam + + if (flag < 0): + param[1] = h11 + param[2] = h21 + param[3] = h12 + param[4] = h22 + elif (flag == 0): + param[2] = h21 + param[3] = h12 + else: + param[1] = h11 + param[4] = h22 + + param[0] = flag + return diff --git a/test-level1.mojo b/test-level1.mojo index 1820efc..be5c3f7 100644 --- a/test-level1.mojo +++ b/test-level1.mojo @@ -691,57 +691,44 @@ def rotm_test[ assert_almost_equal(y_result[i], expected_y, atol=atol) -# def rotmg_test[ -# dtype: DType, -# size: Int -# ](): -# with DeviceContext() as ctx: -# # d1 and d2 must be positive -# var d1 = generate_random_scalar[dtype](1, 10000) -# var d2 = generate_random_scalar[dtype](1, 
10000) -# var x1 = generate_random_scalar[dtype](-10000, 10000) -# var y1 = generate_random_scalar[dtype](-10000, 10000) - -# d_d1 = ctx.enqueue_create_buffer[dtype](1) -# d_d1.enqueue_fill(d1) -# d_d2 = ctx.enqueue_create_buffer[dtype](1) -# d_d2.enqueue_fill(d2) -# d_x1 = ctx.enqueue_create_buffer[dtype](1) -# d_x1.enqueue_fill(x1) -# d_y1 = ctx.enqueue_create_buffer[dtype](1) -# d_y1.enqueue_fill(y1) -# d_param = ctx.enqueue_create_buffer[dtype](5) - -# # Launch Mojo BLAS kernel -# # NOTE: not implemented -# # blas_rotmg[dtype]( -# # d1.unsafe_ptr(), -# # d2.unsafe_ptr(), -# # x1.unsafe_ptr(), -# # x2.unsafe_ptr(), -# # d_param.unsafe_ptr(), -# # ctx -# # ) - -# # Import SciPy and numpy -# sp = Python.import_module("scipy") -# np = Python.import_module("numpy") -# sp_blas = sp.linalg.blas - -# # srotmg - float32, drotmg - float64 -# if dtype == DType.float32: -# py_p = sp_blas.srotmg(d1, d2, x1, y1) -# elif dtype == DType.float64: -# py_p = sp_blas.drotmg(d1, d2, x1, y1) -# else: -# print(dtype , " is not supported by SciPy") -# return - -# # Only compare param -# with d_param.map_to_host() as mojo_param: -# for i in range(5): -# var py_ref = Scalar[dtype](py=py_p[i]) -# assert_equal(mojo_param[i], py_ref) +def rotmg_test[ + dtype: DType, + size: Int +](): + # d1 and d2 must be positive + var d1 = generate_random_scalar[dtype](1, 10000) + var d2 = generate_random_scalar[dtype](1, 10000) + var x1 = generate_random_scalar[dtype](-10000, 10000) + var y1 = generate_random_scalar[dtype](-10000, 10000) + var param = SIMD[dtype, 5]() + + # Launch Mojo BLAS kernel + blas_rotmg[dtype]( + UnsafePointer(to=d1), + UnsafePointer(to=d2), + UnsafePointer(to=x1), + UnsafePointer(to=y1), + UnsafePointer(to=param), + ) + + # Import SciPy and numpy + sp = Python.import_module("scipy") + np = Python.import_module("numpy") + sp_blas = sp.linalg.blas + + # srotmg - float32, drotmg - float64 + if dtype == DType.float32: + py_p = sp_blas.srotmg(d1, d2, x1, y1) + elif dtype == 
DType.float64: + py_p = sp_blas.drotmg(d1, d2, x1, y1) + else: + print(dtype , " is not supported by SciPy") + return + + # Only compare param + for i in range(5): + var py_ref = Scalar[dtype](py=py_p[i]) + assert_equal(param[i], py_ref) def scal_test[ From 6355f9296af61aa88ccb66b9a310bf400466c788 Mon Sep 17 00:00:00 2001 From: WestonJB Date: Wed, 4 Mar 2026 17:58:15 -0500 Subject: [PATCH 2/2] modified rotmg implementation and testing --- src/level1/rotmg.mojo | 4 ++-- test-level1.mojo | 21 ++++++++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/level1/rotmg.mojo b/src/level1/rotmg.mojo index ee736b9..4fdbfbf 100644 --- a/src/level1/rotmg.mojo +++ b/src/level1/rotmg.mojo @@ -3,7 +3,7 @@ fn blas_rotmg[dtype: DType]( d2: UnsafePointer[Scalar[dtype], MutAnyOrigin], x1: UnsafePointer[Scalar[dtype], MutAnyOrigin], y1: UnsafePointer[Scalar[dtype], MutAnyOrigin], - param: UnsafePointer[SIMD[dtype, 5], MutAnyOrigin] + mut param: List[Scalar[dtype]] ): var flag: Scalar[dtype] @@ -130,7 +130,7 @@ fn blas_rotmg[dtype: DType]( h21 /= gam h22 /= gam else: - d2[] *= gam**2 + d2[] /= gam**2 h21 *= gam h22 *= gam diff --git a/test-level1.mojo b/test-level1.mojo index be5c3f7..548d8d6 100644 --- a/test-level1.mojo +++ b/test-level1.mojo @@ -700,7 +700,7 @@ def rotmg_test[ var d2 = generate_random_scalar[dtype](1, 10000) var x1 = generate_random_scalar[dtype](-10000, 10000) var y1 = generate_random_scalar[dtype](-10000, 10000) - var param = SIMD[dtype, 5]() + var param = List[Scalar[dtype]](length=5, fill=0.0) # Launch Mojo BLAS kernel blas_rotmg[dtype]( @@ -708,7 +708,7 @@ def rotmg_test[ UnsafePointer(to=d2), UnsafePointer(to=x1), UnsafePointer(to=y1), - UnsafePointer(to=param), + param ) # Import SciPy and numpy @@ -725,6 +725,13 @@ def rotmg_test[ print(dtype , " is not supported by SciPy") return + # print(d1) + # print(d2) + # print(x1) + # print(y1) + # print(param) + # print(py_p) + # Only compare param for i in range(5): var py_ref = 
Scalar[dtype](py=py_p[i]) @@ -880,11 +887,11 @@ def test_rotm(): rotm_test[DType.float64, 256]() rotm_test[DType.float64, 4096]() -# def test_rotmg(): -# rotmg_test[DType.float32, 256]() -# rotmg_test[DType.float32, 4096]() -# rotmg_test[DType.float64, 256]() -# rotmg_test[DType.float64, 4096]() +def test_rotmg(): + rotmg_test[DType.float32, 256]() + rotmg_test[DType.float32, 4096]() + rotmg_test[DType.float64, 256]() + rotmg_test[DType.float64, 4096]() def test_scal(): scal_test[DType.float32, 256]()