From c20fe0aec6bcf19c4a05a57603e14f312b0355f7 Mon Sep 17 00:00:00 2001 From: WestonJB Date: Fri, 20 Feb 2026 11:42:21 -0500 Subject: [PATCH 1/5] Added rotmg --- src/level1/__init__.mojo | 1 + src/level1/rotmg.mojo | 150 +++++++++++++++++++++++++++++++++++++++ test-level1.mojo | 89 ++++++++++------------- 3 files changed, 189 insertions(+), 51 deletions(-) create mode 100644 src/level1/rotmg.mojo diff --git a/src/level1/__init__.mojo b/src/level1/__init__.mojo index 80477f0..14d7d04 100644 --- a/src/level1/__init__.mojo +++ b/src/level1/__init__.mojo @@ -5,6 +5,7 @@ from .copy_device import * from .rot_device import * from .rotg import * from .rotm_device import * +from .rotmg import * from .swap_device import * from .dot_device import * from .dotc_device import * diff --git a/src/level1/rotmg.mojo b/src/level1/rotmg.mojo new file mode 100644 index 0000000..ee736b9 --- /dev/null +++ b/src/level1/rotmg.mojo @@ -0,0 +1,150 @@ +fn blas_rotmg[dtype: DType]( + d1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + d2: UnsafePointer[Scalar[dtype], MutAnyOrigin], + x1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + y1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + param: UnsafePointer[SIMD[dtype, 5], MutAnyOrigin] +): + + var flag: Scalar[dtype] + var h11: Scalar[dtype] = 0 + var h12: Scalar[dtype] = 0 + var h21: Scalar[dtype] = 0 + var h22: Scalar[dtype] = 0 + var p1: Scalar[dtype] + var p2: Scalar[dtype] + var q1: Scalar[dtype] + var q2: Scalar[dtype] + var temp: Scalar[dtype] + var u: Scalar[dtype] + var gam: Scalar[dtype] = 4096.0 + var gamsq: Scalar[dtype] = 16777216.0 + var rgamsq: Scalar[dtype] = 5.9604645e-8 + + if (d1[] < 0): + # GO 0-H-D-AND-x1.. + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + # CASE-d1-NONNEGATIVE + p2 = d2[]*y1[] + if (p2 == 0): + flag = -2 + param[0] = flag + return + + # REGULAR-CASE.. + p1 = d1[]*x1[] + q2 = p2*y1[] + q1 = p1*x1[] + # + if (abs(q1) > abs(q2)): + h21 = -y1[]/x1[] + h12 = p2/p1 + # + u = 1 - h12*h21 + # + if (u > 0): + flag = 0 + d1[] /= u + d2[] /= u + x1[] *= u + else: + # This code path if here for safety. We do not expect this + # condition to ever hold except in edge cases with rounding + # errors. See DOI: 10.1145/355841.355847 + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + # + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + if (q2 < 0): + # GO 0-H-D-AND-x1.. + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + # + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + flag = 1 + h11 = p1/p2 + h22 = x1[]/y1[] + u = 1 + h11*h22 + temp = d2[]/u + d2[] = d1[]/u + d1[] = temp + x1[] = y1[]*u + + # PROCEDURE..SCALE-CHECK + if (d1[] != 0): + while ((d1[] <= rgamsq) or (d1[] >= gamsq)): + if (flag == 0): + h11 = 1 + h22 = 1 + flag = -1 + else: + h21 = -1 + h12 = 1 + flag = -1 + + if (d1[] <= rgamsq): + d1[] *= gam**2 + x1[] /= gam + h11 /= gam + h12 /= gam + else: + d1[] /= gam**2 + x1[] *= gam + h11 *= gam + h12 *= gam + + if (d2[] != 0): + while ( (abs(d2[]) <= rgamsq) or (abs(d2[]) >= gamsq) ): + if (flag == 0): + h11 = 1 + h22 = 1 + flag = -1 + else: + h21 = -1 + h12 = 1 + flag = -1 + + if (abs(d2[]) <= rgamsq): + d2[] *= gam**2 + h21 /= gam + h22 /= gam + else: + d2[] *= gam**2 + h21 *= gam + h22 *= gam + + if (flag < 0): + param[1] = h11 + param[2] = h21 + param[3] = h12 + param[4] = h22 + elif (flag == 0): + param[2] = h21 + param[3] = h12 + else: + param[1] = h11 + param[4] = h22 + + param[0] = flag + return diff --git a/test-level1.mojo b/test-level1.mojo index 1820efc..be5c3f7 100644 --- a/test-level1.mojo +++ b/test-level1.mojo @@ -691,57 +691,44 @@ def rotm_test[ assert_almost_equal(y_result[i], expected_y, atol=atol) -# def rotmg_test[ -# dtype: DType, -# size: Int -# ](): -# with DeviceContext() as ctx: -# # d1 and d2 must be positive -# var d1 = generate_random_scalar[dtype](1, 10000) -# var d2 = generate_random_scalar[dtype](1, 10000) -# var x1 = generate_random_scalar[dtype](-10000, 10000) -# var y1 = generate_random_scalar[dtype](-10000, 10000) - -# d_d1 = ctx.enqueue_create_buffer[dtype](1) -# d_d1.enqueue_fill(d1) -# d_d2 = ctx.enqueue_create_buffer[dtype](1) -# d_d2.enqueue_fill(d2) -# d_x1 = ctx.enqueue_create_buffer[dtype](1) -# d_x1.enqueue_fill(x1) -# d_y1 = ctx.enqueue_create_buffer[dtype](1) -# d_y1.enqueue_fill(y1) -# d_param = ctx.enqueue_create_buffer[dtype](5) - -# # Launch Mojo BLAS kernel -# # NOTE: not implemented -# # blas_rotmg[dtype]( -# # d1.unsafe_ptr(), -# # d2.unsafe_ptr(), -# # x1.unsafe_ptr(), -# # x2.unsafe_ptr(), -# # d_param.unsafe_ptr(), -# # ctx -# # ) - -# # Import SciPy and numpy -# sp = Python.import_module("scipy") -# np = Python.import_module("numpy") -# sp_blas = sp.linalg.blas - -# # srotmg - float32, drotmg - float64 -# if dtype == DType.float32: -# py_p = sp_blas.srotmg(d1, d2, x1, y1) -# elif dtype == DType.float64: -# py_p = sp_blas.drotmg(d1, d2, x1, y1) -# else: -# print(dtype , " is not supported by SciPy") -# return - -# # Only compare param -# with d_param.map_to_host() as mojo_param: -# for i in range(5): -# var py_ref = Scalar[dtype](py=py_p[i]) -# assert_equal(mojo_param[i], py_ref) +def rotmg_test[ + dtype: DType, + size: Int +](): + # d1 and d2 must be positive + var d1 = generate_random_scalar[dtype](1, 10000) + var d2 = generate_random_scalar[dtype](1, 10000) + var x1 = generate_random_scalar[dtype](-10000, 10000) + var y1 = generate_random_scalar[dtype](-10000, 10000) + var param = SIMD[dtype, 5]() + + # Launch Mojo BLAS kernel + blas_rotmg[dtype]( + UnsafePointer(to=d1), + UnsafePointer(to=d2), + UnsafePointer(to=x1), + UnsafePointer(to=y1), + UnsafePointer(to=param), + ) + + # Import SciPy and numpy + sp = Python.import_module("scipy") + np = Python.import_module("numpy") + sp_blas = sp.linalg.blas + + # srotmg - float32, drotmg - float64 + if dtype == DType.float32: + py_p = sp_blas.srotmg(d1, d2, x1, y1) + elif dtype == DType.float64: + py_p = sp_blas.drotmg(d1, d2, x1, y1) + else: + print(dtype , " is not supported by SciPy") + return + + # Only compare param + for i in range(5): + var py_ref = Scalar[dtype](py=py_p[i]) + assert_equal(param[i], py_ref) def scal_test[ From 6355f9296af61aa88ccb66b9a310bf400466c788 Mon Sep 17 00:00:00 2001 From: WestonJB Date: Wed, 4 Mar 2026 17:58:15 -0500 Subject: [PATCH 2/5] modified rotmg implementation and testing --- src/level1/rotmg.mojo | 4 ++-- test-level1.mojo | 21 ++++++++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/level1/rotmg.mojo b/src/level1/rotmg.mojo index ee736b9..4fdbfbf 100644 --- a/src/level1/rotmg.mojo +++ b/src/level1/rotmg.mojo @@ -3,7 +3,7 @@ fn blas_rotmg[dtype: DType]( d2: UnsafePointer[Scalar[dtype], MutAnyOrigin], x1: UnsafePointer[Scalar[dtype], MutAnyOrigin], y1: UnsafePointer[Scalar[dtype], MutAnyOrigin], - param: UnsafePointer[SIMD[dtype, 5], MutAnyOrigin] + mut param: List[Scalar[dtype]] ): var flag: Scalar[dtype] @@ -130,7 +130,7 @@ fn blas_rotmg[dtype: DType]( h21 /= gam h22 /= gam else: - d2[] *= gam**2 + d2[] /= gam**2 h21 *= gam h22 *= gam diff --git a/test-level1.mojo b/test-level1.mojo index be5c3f7..548d8d6 100644 --- a/test-level1.mojo +++ b/test-level1.mojo @@ -700,7 +700,7 @@ def rotmg_test[ var d2 = generate_random_scalar[dtype](1, 10000) var x1 = generate_random_scalar[dtype](-10000, 10000) var y1 = generate_random_scalar[dtype](-10000, 10000) - var param = SIMD[dtype, 5]() + var param = List[Scalar[dtype]](length=5, fill=0.0) # Launch Mojo BLAS kernel blas_rotmg[dtype]( @@ -708,7 +708,7 @@ def rotmg_test[ UnsafePointer(to=d2), UnsafePointer(to=x1), UnsafePointer(to=y1), - UnsafePointer(to=param), + param ) # Import SciPy and numpy @@ -725,6 +725,13 @@ def rotmg_test[ print(dtype , " is not supported by SciPy") return + # print(d1) + # print(d2) + # print(x1) + # print(y1) + # print(param) + # print(py_p) + # Only compare param for i in range(5): var py_ref = Scalar[dtype](py=py_p[i]) @@ -880,11 +887,11 @@ def test_rotm(): rotm_test[DType.float64, 256]() rotm_test[DType.float64, 4096]() -# def test_rotmg(): -# rotmg_test[DType.float32, 256]() -# rotmg_test[DType.float32, 4096]() -# rotmg_test[DType.float64, 256]() -# rotmg_test[DType.float64, 4096]() +def test_rotmg(): + rotmg_test[DType.float32, 256]() + rotmg_test[DType.float32, 4096]() + rotmg_test[DType.float64, 256]() + rotmg_test[DType.float64, 4096]() def test_scal(): scal_test[DType.float32, 256]() From 09e4def1b4aa83fa6e223546560b3a4f6f1cbe8c Mon Sep 17 00:00:00 2001 From: tdehoff Date: Mon, 16 Mar 2026 11:01:26 -0400 Subject: [PATCH 3/5] Added trsv impl and tests --- src/level2/__init__.mojo | 1 + src/level2/trsv_device.mojo | 165 +++++++++++++++++++++++++++ src/testing_utils/testing_utils.mojo | 2 +- test-level2.mojo | 112 +++++++++++++++++- 4 files changed, 276 insertions(+), 4 deletions(-) create mode 100644 src/level2/trsv_device.mojo diff --git a/src/level2/__init__.mojo b/src/level2/__init__.mojo index 244b222..af90075 100644 --- a/src/level2/__init__.mojo +++ b/src/level2/__init__.mojo @@ -3,3 +3,4 @@ from .ger_device import * from .syr_device import * from .syr2_device import * from .gbmv_device import * +from .trsv_device import * diff --git a/src/level2/trsv_device.mojo b/src/level2/trsv_device.mojo new file mode 100644 index 0000000..e485170 --- /dev/null +++ b/src/level2/trsv_device.mojo @@ -0,0 +1,165 @@ +from gpu import thread_idx, block_idx, block_dim, grid_dim +from gpu.host import DeviceContext +from math import ceildiv + +# level2.trsv +# Triangular matrix-vector solve +# Solves one of the systems of equations +# A*x = b, or A**T*x = b, +# where b and x are n element vectors and A is an n by n unit, or +# non-unit, upper or lower triangular matrix. +# +# uplo: 0 = lower triangular, 1 = upper triangular +# trans: 0 = no transpose, 1 = transpose +# diag: 0 = non-unit diagonal, 1 = unit diagonal +fn strsv_device( + uplo: Int, + trans: Int, + diag: Int, + n: Int, + A: UnsafePointer[Float32, ImmutAnyOrigin], + lda: Int, + x: UnsafePointer[Float32, MutAnyOrigin], + incx: Int, +): + var tid = block_dim.x * block_idx.x + thread_idx.x + if tid != 0: + return + + if not trans: + if not uplo: + # Lower triangular: forward substitution + for i in range(n): + var temp = x[i * incx] + for j in range(i): + temp -= A[i * lda + j] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + # Upper triangular: backward substitution + for k in range(n): + var i = n - 1 - k + var temp = x[i * incx] + for j in range(i + 1, n): + temp -= A[i * lda + j] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + if not uplo: + # Transpose lower triangular (A^T is upper): backward substitution + for k in range(n): + var i = n - 1 - k + var temp = x[i * incx] + for j in range(i + 1, n): + temp -= A[j * lda + i] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + # Transpose upper triangular (A^T is lower): forward substitution + for i in range(n): + var temp = x[i * incx] + for j in range(i): + temp -= A[j * lda + i] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + + +fn dtrsv_device( + uplo: Int, + trans: Int, + diag: Int, + n: Int, + A: UnsafePointer[Float64, ImmutAnyOrigin], + lda: Int, + x: UnsafePointer[Float64, MutAnyOrigin], + incx: Int, +): + var tid = block_dim.x * block_idx.x + thread_idx.x + if tid != 0: + return + + if not trans: + if not uplo: + # Lower triangular: forward substitution + for i in range(n): + var temp = x[i * incx] + for j in range(i): + temp -= A[i * lda + j] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + # Upper triangular: backward substitution + for k in range(n): + var i = n - 1 - k + var temp = x[i * incx] + for j in range(i + 1, n): + temp -= A[i * lda + j] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + if not uplo: + # Transpose lower triangular (A^T is upper): backward substitution + for k in range(n): + var i = n - 1 - k + var temp = x[i * incx] + for j in range(i + 1, n): + temp -= A[j * lda + i] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + # Transpose upper triangular (A^T is lower): forward substitution + for i in range(n): + var temp = x[i * incx] + for j in range(i): + temp -= A[j * lda + i] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + + +fn blas_trsv[dtype: DType]( + uplo: Int, + trans: Bool, + diag: Int, + n: Int, + d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], + lda: Int, + d_x: UnsafePointer[Scalar[dtype], MutAnyOrigin], + incx: Int, + ctx: DeviceContext, +) raises: + + # NOTE: add error checking here? + # check n > 0 + # check incx > 0 + + var trans_i = 1 if trans else 0 + + @parameter + if dtype == DType.float32: + ctx.enqueue_function[strsv_device, strsv_device]( + uplo, trans_i, diag, + n, d_A, lda, + d_x, incx, + grid_dim=1, + block_dim=1, + ) + elif dtype == DType.float64: + ctx.enqueue_function[dtrsv_device, dtrsv_device]( + uplo, trans_i, diag, + n, d_A, lda, + d_x, incx, + grid_dim=1, + block_dim=1, + ) + else: + raise Error("blas_trsv: Unsupported type") + + ctx.synchronize() diff --git a/src/testing_utils/testing_utils.mojo b/src/testing_utils/testing_utils.mojo index 184eccd..9e2b6e7 100644 --- a/src/testing_utils/testing_utils.mojo +++ b/src/testing_utils/testing_utils.mojo @@ -37,7 +37,7 @@ def generate_random_scalar[ # Error check following BLAS++ check_gemm: -# NOTE: can't get epsilon value in Mojo; using tol32, tol64 +# NOTE: might use this for dot, gemv, ger, geru, gemm, symv, hemv, symm, trmv, trsv?, trmm, trsm? fn check_gemm_error[dtype: DType]( m: Int, n: Int, k: Int, alpha: Scalar[dtype], diff --git a/test-level2.mojo b/test-level2.mojo index acaa4fe..404ce2f 100644 --- a/test-level2.mojo +++ b/test-level2.mojo @@ -170,7 +170,6 @@ def ger_test[ for j in range(n): assert_almost_equal(Scalar[dtype](py=sp_res[i][j]), res_mojo[(i*n)+j], atol=atol) - def syr_test[ dtype: DType, n: Int, @@ -428,6 +427,93 @@ def gbmv_test[ ) assert_true(ok) +def trsv_test[ + dtype: DType, + n: Int, + trans: Bool, + uplo: Int, + diag: Int, +](): + with DeviceContext() as ctx: + A_d = ctx.enqueue_create_buffer[dtype](n * n) + A = ctx.enqueue_create_host_buffer[dtype](n * n) + x_d = ctx.enqueue_create_buffer[dtype](n) + x = ctx.enqueue_create_host_buffer[dtype](n) + + generate_random_arr[dtype](n * n, A.unsafe_ptr(), -1, 1) + generate_random_arr[dtype](n, x.unsafe_ptr(), -1, 1) + + # Make A diagonally dominant to reduce numerical instability + for i in range(n): + A[i * n + i] += 1000 + + ctx.enqueue_copy(A_d, A) + ctx.enqueue_copy(x_d, x) + ctx.synchronize() + + # Compute norms for error checks + var norm_A = frobenius_norm[dtype](A.unsafe_ptr(), n * n) + var norm_x = frobenius_norm[dtype](x.unsafe_ptr(), n) + + blas_trsv[dtype]( + uplo, + trans, + diag, + n, + A_d.unsafe_ptr(), n, + x_d.unsafe_ptr(), 1, + ctx, + ) + + # Import SciPy and numpy + sp = Python.import_module("scipy") + np = Python.import_module("numpy") + sp_blas = sp.linalg.blas + + py_A = Python.list() + py_x = Python.list() + for i in range(n * n): + py_A.append(A[i]) + for i in range(n): + py_x.append(x[i]) + + var sp_res: PythonObject + + if dtype == DType.float32: + np_A = np.array(py_A, dtype=np.float32).reshape(n, n) + np_x = np.array(py_x, dtype=np.float32) + sp_res = sp_blas.strsv(np_A, np_x, + lower=0 if uplo else 1, + trans=1 if trans else 0, + diag=1 if diag else 0 + ) + elif dtype == DType.float64: + np_A = np.array(py_A, dtype=np.float64).reshape(n, n) + np_x = np.array(py_x, dtype=np.float64) + sp_res = sp_blas.dtrsv(np_A, np_x, + lower=0 if uplo else 1, + trans=1 if trans else 0, + diag=1 if diag else 0 + ) + else: + print("Unsupported type: ", dtype) + return + + with x_d.map_to_host() as res_mojo: + var norm_diff = Scalar[dtype](0) + for i in range(n): + var diff = res_mojo[i] - Scalar[dtype](py=sp_res[i]) + norm_diff += diff * diff + norm_diff = sqrt(norm_diff) + + var ok = check_gemm_error[dtype]( + 1, n, n, + Scalar[dtype](1), Scalar[dtype](0), + norm_A, norm_x, Scalar[dtype](0), + norm_diff + ) + assert_true(ok) + def test_gemv(): gemv_test[DType.float32, 64, 64, False]() gemv_test[DType.float32, 64, 64, True]() @@ -466,6 +552,24 @@ def test_gbmv(): gbmv_test[DType.float64, 512, 64, 1, 2, False]() gbmv_test[DType.float64, 512, 64, 2, 2, True]() +def test_trsv(): + trsv_test[DType.float32, 64, True, 0, 0]() + trsv_test[DType.float32, 64, True, 1, 0]() + trsv_test[DType.float32, 64, True, 0, 1]() + trsv_test[DType.float32, 64, True, 1, 1]() + trsv_test[DType.float32, 64, False, 0, 0]() + trsv_test[DType.float32, 64, False, 1, 0]() + trsv_test[DType.float32, 64, False, 0, 1]() + trsv_test[DType.float32, 64, False, 1, 1]() + trsv_test[DType.float64, 64, True, 0, 0]() + trsv_test[DType.float64, 64, True, 1, 0]() + trsv_test[DType.float64, 64, True, 0, 1]() + trsv_test[DType.float64, 64, True, 1, 1]() + trsv_test[DType.float64, 64, False, 0, 0]() + trsv_test[DType.float64, 64, False, 1, 0]() + trsv_test[DType.float64, 64, False, 0, 1]() + trsv_test[DType.float64, 64, False, 1, 1]() + def main(): print("--- MojoBLAS Level 2 routines testing ---") var args = argv() @@ -477,8 +581,10 @@ def main(): for i in range(1, len(args)): if args[i] == "gemv": suite.test[test_gemv]() elif args[i] == "ger": suite.test[test_ger]() - # elif args[i] == "syr": suite.test[test_syr]() - # elif args[i] == "syr2": suite.test[test_syr2]() + elif args[i] == "syr": suite.test[test_syr]() + elif args[i] == "syr2": suite.test[test_syr2]() + elif args[i] == "gbmv": suite.test[test_gbmv]() + elif args[i] == "trsv": suite.test[test_trsv]() else: print("unknown routine:", args[i]) suite^.run() From 54ea7120a689d244ae80f7211b27dc0832a3d270 Mon Sep 17 00:00:00 2001 From: tdehoff Date: Mon, 16 Mar 2026 11:19:46 -0400 Subject: [PATCH 4/5] Saved initial d1, d2, x1 values in rotmg tests --- test-level1.mojo | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/test-level1.mojo b/test-level1.mojo index 548d8d6..cc1a683 100644 --- a/test-level1.mojo +++ b/test-level1.mojo @@ -702,6 +702,11 @@ def rotmg_test[ var y1 = generate_random_scalar[dtype](-10000, 10000) var param = List[Scalar[dtype]](length=5, fill=0.0) + # Save initial values before we modify d1, d2, x1 in-place + var d1_ini = d1 + var d2_ini = d2 + var x1_ini = x1 + # Launch Mojo BLAS kernel blas_rotmg[dtype]( UnsafePointer(to=d1), @@ -718,19 +723,13 @@ def rotmg_test[ # srotmg - float32, drotmg - float64 if dtype == DType.float32: - py_p = sp_blas.srotmg(d1, d2, x1, y1) + py_p = sp_blas.srotmg(d1_ini, d2_ini, x1_ini, y1) elif dtype == DType.float64: - py_p = sp_blas.drotmg(d1, d2, x1, y1) + py_p = sp_blas.drotmg(d1_ini, d2_ini, x1_ini, y1) else: print(dtype , " is not supported by SciPy") return - # print(d1) - # print(d2) - # print(x1) - # print(y1) - # print(param) - # print(py_p) # Only compare param for i in range(5): @@ -925,7 +924,7 @@ def main(): elif args[i] == "rot": suite.test[test_rot]() elif args[i] == "rotg": suite.test[test_rotg]() elif args[i] == "rotm": suite.test[test_rotm]() - # elif args[i] == "rotmg": suite.test[test_rotmg]() + elif args[i] == "rotmg": suite.test[test_rotmg]() elif args[i] == "scal": suite.test[test_scal]() elif args[i] == "swap": suite.test[test_swap]() else: print("unknown routine:", args[i]) From 250a9c9858616bb88fc589e9c07973dcea3e996d Mon Sep 17 00:00:00 2001 From: tdehoff Date: Wed, 18 Mar 2026 15:45:05 -0400 Subject: [PATCH 5/5] generate accurate imput matrix --- test-level2.mojo | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/test-level2.mojo b/test-level2.mojo index 404ce2f..7722d8c 100644 --- a/test-level2.mojo +++ b/test-level2.mojo @@ -443,9 +443,26 @@ def trsv_test[ generate_random_arr[dtype](n * n, A.unsafe_ptr(), -1, 1) generate_random_arr[dtype](n, x.unsafe_ptr(), -1, 1) - # Make A diagonally dominant to reduce numerical instability + # Zero out off-triangle elements to make A strictly triangular for i in range(n): - A[i * n + i] += 1000 + for j in range(n): + # upper triangular + if uplo == 0: + if i > j: + A[i * n + j] = 0 + # lower triangular + else: + if i < j: + A[i * n + j] = 0 + + # Handle diagonal based on diag parameter + for i in range(n): + # unit diagonal + if diag == 1: + A[i * n + i] = 1 + # non-unit diagonal: make diagonally dominant to reduce numerical instability + else: + A[i * n + i] += 1000 ctx.enqueue_copy(A_d, A) ctx.enqueue_copy(x_d, x)