diff --git a/src/level1/__init__.mojo b/src/level1/__init__.mojo index 80477f0..14d7d04 100644 --- a/src/level1/__init__.mojo +++ b/src/level1/__init__.mojo @@ -5,6 +5,7 @@ from .copy_device import * from .rot_device import * from .rotg import * from .rotm_device import * +from .rotmg import * from .swap_device import * from .dot_device import * from .dotc_device import * diff --git a/src/level1/rotmg.mojo b/src/level1/rotmg.mojo new file mode 100644 index 0000000..4fdbfbf --- /dev/null +++ b/src/level1/rotmg.mojo @@ -0,0 +1,150 @@ +fn blas_rotmg[dtype: DType]( + d1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + d2: UnsafePointer[Scalar[dtype], MutAnyOrigin], + x1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + y1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + mut param: List[Scalar[dtype]] +): + + var flag: Scalar[dtype] + var h11: Scalar[dtype] = 0 + var h12: Scalar[dtype] = 0 + var h21: Scalar[dtype] = 0 + var h22: Scalar[dtype] = 0 + var p1: Scalar[dtype] + var p2: Scalar[dtype] + var q1: Scalar[dtype] + var q2: Scalar[dtype] + var temp: Scalar[dtype] + var u: Scalar[dtype] + var gam: Scalar[dtype] = 4096.0 + var gamsq: Scalar[dtype] = 16777216.0 + var rgamsq: Scalar[dtype] = 5.9604645e-8 + + if (d1[] < 0): + # GO 0-H-D-AND-x1.. + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + # CASE-d1-NONNEGATIVE + p2 = d2[]*y1[] + if (p2 == 0): + flag = -2 + param[0] = flag + return + + # REGULAR-CASE.. + p1 = d1[]*x1[] + q2 = p2*y1[] + q1 = p1*x1[] + # + if (abs(q1) > abs(q2)): + h21 = -y1[]/x1[] + h12 = p2/p1 + # + u = 1 - h12*h21 + # + if (u > 0): + flag = 0 + d1[] /= u + d2[] /= u + x1[] *= u + else: + # This code path if here for safety. We do not expect this + # condition to ever hold except in edge cases with rounding + # errors. See DOI: 10.1145/355841.355847 + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + # + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + if (q2 < 0): + # GO 0-H-D-AND-x1.. + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + # + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + flag = 1 + h11 = p1/p2 + h22 = x1[]/y1[] + u = 1 + h11*h22 + temp = d2[]/u + d2[] = d1[]/u + d1[] = temp + x1[] = y1[]*u + + # PROCEDURE..SCALE-CHECK + if (d1[] != 0): + while ((d1[] <= rgamsq) or (d1[] >= gamsq)): + if (flag == 0): + h11 = 1 + h22 = 1 + flag = -1 + else: + h21 = -1 + h12 = 1 + flag = -1 + + if (d1[] <= rgamsq): + d1[] *= gam**2 + x1[] /= gam + h11 /= gam + h12 /= gam + else: + d1[] /= gam**2 + x1[] *= gam + h11 *= gam + h12 *= gam + + if (d2[] != 0): + while ( (abs(d2[]) <= rgamsq) or (abs(d2[]) >= gamsq) ): + if (flag == 0): + h11 = 1 + h22 = 1 + flag = -1 + else: + h21 = -1 + h12 = 1 + flag = -1 + + if (abs(d2[]) <= rgamsq): + d2[] *= gam**2 + h21 /= gam + h22 /= gam + else: + d2[] /= gam**2 + h21 *= gam + h22 *= gam + + if (flag < 0): + param[1] = h11 + param[2] = h21 + param[3] = h12 + param[4] = h22 + elif (flag == 0): + param[2] = h21 + param[3] = h12 + else: + param[1] = h11 + param[4] = h22 + + param[0] = flag + return diff --git a/src/level2/__init__.mojo b/src/level2/__init__.mojo index 244b222..af90075 100644 --- a/src/level2/__init__.mojo +++ b/src/level2/__init__.mojo @@ -3,3 +3,4 @@ from .ger_device import * from .syr_device import * from .syr2_device import * from .gbmv_device import * +from .trsv_device import * diff --git a/src/level2/trsv_device.mojo b/src/level2/trsv_device.mojo new file mode 100644 index 0000000..e485170 --- /dev/null +++ b/src/level2/trsv_device.mojo @@ -0,0 +1,165 @@ +from gpu import thread_idx, block_idx, block_dim, grid_dim +from gpu.host import DeviceContext +from math import ceildiv + +# level2.trsv +# Triangular matrix-vector solve +# Solves one of the systems of equations +# A*x = b, or A**T*x = b, +# where b and x are n element vectors and A is an n by n unit, or +# non-unit, upper or lower triangular matrix. +# +# uplo: 0 = lower triangular, 1 = upper triangular +# trans: 0 = no transpose, 1 = transpose +# diag: 0 = non-unit diagonal, 1 = unit diagonal +fn strsv_device( + uplo: Int, + trans: Int, + diag: Int, + n: Int, + A: UnsafePointer[Float32, ImmutAnyOrigin], + lda: Int, + x: UnsafePointer[Float32, MutAnyOrigin], + incx: Int, +): + var tid = block_dim.x * block_idx.x + thread_idx.x + if tid != 0: + return + + if not trans: + if not uplo: + # Lower triangular: forward substitution + for i in range(n): + var temp = x[i * incx] + for j in range(i): + temp -= A[i * lda + j] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + # Upper triangular: backward substitution + for k in range(n): + var i = n - 1 - k + var temp = x[i * incx] + for j in range(i + 1, n): + temp -= A[i * lda + j] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + if not uplo: + # Transpose lower triangular (A^T is upper): backward substitution + for k in range(n): + var i = n - 1 - k + var temp = x[i * incx] + for j in range(i + 1, n): + temp -= A[j * lda + i] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + # Transpose upper triangular (A^T is lower): forward substitution + for i in range(n): + var temp = x[i * incx] + for j in range(i): + temp -= A[j * lda + i] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + + +fn dtrsv_device( + uplo: Int, + trans: Int, + diag: Int, + n: Int, + A: UnsafePointer[Float64, ImmutAnyOrigin], + lda: Int, + x: UnsafePointer[Float64, MutAnyOrigin], + incx: Int, +): + var tid = block_dim.x * block_idx.x + thread_idx.x + if tid != 0: + return + + if not trans: + if not uplo: + # Lower triangular: forward substitution + for i in range(n): + var temp = x[i * incx] + for j in range(i): + temp -= A[i * lda + j] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + # Upper triangular: backward substitution + for k in range(n): + var i = n - 1 - k + var temp = x[i * incx] + for j in range(i + 1, n): + temp -= A[i * lda + j] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + if not uplo: + # Transpose lower triangular (A^T is upper): backward substitution + for k in range(n): + var i = n - 1 - k + var temp = x[i * incx] + for j in range(i + 1, n): + temp -= A[j * lda + i] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + else: + # Transpose upper triangular (A^T is lower): forward substitution + for i in range(n): + var temp = x[i * incx] + for j in range(i): + temp -= A[j * lda + i] * x[j * incx] + if not diag: + temp /= A[i * lda + i] + x[i * incx] = temp + + +fn blas_trsv[dtype: DType]( + uplo: Int, + trans: Bool, + diag: Int, + n: Int, + d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], + lda: Int, + d_x: UnsafePointer[Scalar[dtype], MutAnyOrigin], + incx: Int, + ctx: DeviceContext, +) raises: + + # NOTE: add error checking here? + # check n > 0 + # check incx > 0 + + var trans_i = 1 if trans else 0 + + @parameter + if dtype == DType.float32: + ctx.enqueue_function[strsv_device, strsv_device]( + uplo, trans_i, diag, + n, d_A, lda, + d_x, incx, + grid_dim=1, + block_dim=1, + ) + elif dtype == DType.float64: + ctx.enqueue_function[dtrsv_device, dtrsv_device]( + uplo, trans_i, diag, + n, d_A, lda, + d_x, incx, + grid_dim=1, + block_dim=1, + ) + else: + raise Error("blas_trsv: Unsupported type") + + ctx.synchronize() diff --git a/src/testing_utils/testing_utils.mojo b/src/testing_utils/testing_utils.mojo index 184eccd..9e2b6e7 100644 --- a/src/testing_utils/testing_utils.mojo +++ b/src/testing_utils/testing_utils.mojo @@ -37,7 +37,7 @@ def generate_random_scalar[ # Error check following BLAS++ check_gemm: -# NOTE: can't get epsilon value in Mojo; using tol32, tol64 +# NOTE: might use this for dot, gemv, ger, geru, gemm, symv, hemv, symm, trmv, trsv?, trmm, trsm? fn check_gemm_error[dtype: DType]( m: Int, n: Int, k: Int, alpha: Scalar[dtype], diff --git a/test-level1.mojo b/test-level1.mojo index 1820efc..cc1a683 100644 --- a/test-level1.mojo +++ b/test-level1.mojo @@ -691,57 +691,50 @@ def rotm_test[ assert_almost_equal(y_result[i], expected_y, atol=atol) -# def rotmg_test[ -# dtype: DType, -# size: Int -# ](): -# with DeviceContext() as ctx: -# # d1 and d2 must be positive -# var d1 = generate_random_scalar[dtype](1, 10000) -# var d2 = generate_random_scalar[dtype](1, 10000) -# var x1 = generate_random_scalar[dtype](-10000, 10000) -# var y1 = generate_random_scalar[dtype](-10000, 10000) - -# d_d1 = ctx.enqueue_create_buffer[dtype](1) -# d_d1.enqueue_fill(d1) -# d_d2 = ctx.enqueue_create_buffer[dtype](1) -# d_d2.enqueue_fill(d2) -# d_x1 = ctx.enqueue_create_buffer[dtype](1) -# d_x1.enqueue_fill(x1) -# d_y1 = ctx.enqueue_create_buffer[dtype](1) -# d_y1.enqueue_fill(y1) -# d_param = ctx.enqueue_create_buffer[dtype](5) - -# # Launch Mojo BLAS kernel -# # NOTE: not implemented -# # blas_rotmg[dtype]( -# # d1.unsafe_ptr(), -# # d2.unsafe_ptr(), -# # x1.unsafe_ptr(), -# # x2.unsafe_ptr(), -# # d_param.unsafe_ptr(), -# # ctx -# # ) - -# # Import SciPy and numpy -# sp = Python.import_module("scipy") -# np = Python.import_module("numpy") -# sp_blas = sp.linalg.blas - -# # srotmg - float32, drotmg - float64 -# if dtype == DType.float32: -# py_p = sp_blas.srotmg(d1, d2, x1, y1) -# elif dtype == DType.float64: -# py_p = sp_blas.drotmg(d1, d2, x1, y1) -# else: -# print(dtype , " is not supported by SciPy") -# return - -# # Only compare param -# with d_param.map_to_host() as mojo_param: -# for i in range(5): -# var py_ref = Scalar[dtype](py=py_p[i]) -# assert_equal(mojo_param[i], py_ref) +def rotmg_test[ + dtype: DType, + size: Int +](): + # d1 and d2 must be positive + var d1 = generate_random_scalar[dtype](1, 10000) + var d2 = generate_random_scalar[dtype](1, 10000) + var x1 = generate_random_scalar[dtype](-10000, 10000) + var y1 = generate_random_scalar[dtype](-10000, 10000) + var param = List[Scalar[dtype]](length=5, fill=0.0) + + # Save initial values before we modify d1, d2, x1 in-place + var d1_ini = d1 + var d2_ini = d2 + var x1_ini = x1 + + # Launch Mojo BLAS kernel + blas_rotmg[dtype]( + UnsafePointer(to=d1), + UnsafePointer(to=d2), + UnsafePointer(to=x1), + UnsafePointer(to=y1), + param + ) + + # Import SciPy and numpy + sp = Python.import_module("scipy") + np = Python.import_module("numpy") + sp_blas = sp.linalg.blas + + # srotmg - float32, drotmg - float64 + if dtype == DType.float32: + py_p = sp_blas.srotmg(d1_ini, d2_ini, x1_ini, y1) + elif dtype == DType.float64: + py_p = sp_blas.drotmg(d1_ini, d2_ini, x1_ini, y1) + else: + print(dtype , " is not supported by SciPy") + return + + + # Only compare param + for i in range(5): + var py_ref = Scalar[dtype](py=py_p[i]) + assert_equal(param[i], py_ref) def scal_test[ @@ -893,11 +886,11 @@ def test_rotm(): rotm_test[DType.float64, 256]() rotm_test[DType.float64, 4096]() -# def test_rotmg(): -# rotmg_test[DType.float32, 256]() -# rotmg_test[DType.float32, 4096]() -# rotmg_test[DType.float64, 256]() -# rotmg_test[DType.float64, 4096]() +def test_rotmg(): + rotmg_test[DType.float32, 256]() + rotmg_test[DType.float32, 4096]() + rotmg_test[DType.float64, 256]() + rotmg_test[DType.float64, 4096]() def test_scal(): scal_test[DType.float32, 256]() @@ -931,7 +924,7 @@ def main(): elif args[i] == "rot": suite.test[test_rot]() elif args[i] == "rotg": suite.test[test_rotg]() elif args[i] == "rotm": suite.test[test_rotm]() - # elif args[i] == "rotmg": suite.test[test_rotmg]() + elif args[i] == "rotmg": suite.test[test_rotmg]() elif args[i] == "scal": suite.test[test_scal]() elif args[i] == "swap": suite.test[test_swap]() else: print("unknown routine:", args[i]) diff --git a/test-level2.mojo b/test-level2.mojo index acaa4fe..7722d8c 100644 --- a/test-level2.mojo +++ b/test-level2.mojo @@ -170,7 +170,6 @@ def ger_test[ for j in range(n): assert_almost_equal(Scalar[dtype](py=sp_res[i][j]), res_mojo[(i*n)+j], atol=atol) - def syr_test[ dtype: DType, n: Int, @@ -428,6 +427,110 @@ def gbmv_test[ ) assert_true(ok) +def trsv_test[ + dtype: DType, + n: Int, + trans: Bool, + uplo: Int, + diag: Int, +](): + with DeviceContext() as ctx: + A_d = ctx.enqueue_create_buffer[dtype](n * n) + A = ctx.enqueue_create_host_buffer[dtype](n * n) + x_d = ctx.enqueue_create_buffer[dtype](n) + x = ctx.enqueue_create_host_buffer[dtype](n) + + generate_random_arr[dtype](n * n, A.unsafe_ptr(), -1, 1) + generate_random_arr[dtype](n, x.unsafe_ptr(), -1, 1) + + # Zero out off-triangle elements to make A strictly triangular + for i in range(n): + for j in range(n): + # upper triangular + if uplo == 0: + if i > j: + A[i * n + j] = 0 + # lower triangular + else: + if i < j: + A[i * n + j] = 0 + + # Handle diagonal based on diag parameter + for i in range(n): + # unit diagonal + if diag == 1: + A[i * n + i] = 1 + # non-unit diagonal: make diagonally dominant to reduce numerical instability + else: + A[i * n + i] += 1000 + + ctx.enqueue_copy(A_d, A) + ctx.enqueue_copy(x_d, x) + ctx.synchronize() + + # Compute norms for error checks + var norm_A = frobenius_norm[dtype](A.unsafe_ptr(), n * n) + var norm_x = frobenius_norm[dtype](x.unsafe_ptr(), n) + + blas_trsv[dtype]( + uplo, + trans, + diag, + n, + A_d.unsafe_ptr(), n, + x_d.unsafe_ptr(), 1, + ctx, + ) + + # Import SciPy and numpy + sp = Python.import_module("scipy") + np = Python.import_module("numpy") + sp_blas = sp.linalg.blas + + py_A = Python.list() + py_x = Python.list() + for i in range(n * n): + py_A.append(A[i]) + for i in range(n): + py_x.append(x[i]) + + var sp_res: PythonObject + + if dtype == DType.float32: + np_A = np.array(py_A, dtype=np.float32).reshape(n, n) + np_x = np.array(py_x, dtype=np.float32) + sp_res = sp_blas.strsv(np_A, np_x, + lower=0 if uplo else 1, + trans=1 if trans else 0, + diag=1 if diag else 0 + ) + elif dtype == DType.float64: + np_A = np.array(py_A, dtype=np.float64).reshape(n, n) + np_x = np.array(py_x, dtype=np.float64) + sp_res = sp_blas.dtrsv(np_A, np_x, + lower=0 if uplo else 1, + trans=1 if trans else 0, + diag=1 if diag else 0 + ) + else: + print("Unsupported type: ", dtype) + return + + with x_d.map_to_host() as res_mojo: + var norm_diff = Scalar[dtype](0) + for i in range(n): + var diff = res_mojo[i] - Scalar[dtype](py=sp_res[i]) + norm_diff += diff * diff + norm_diff = sqrt(norm_diff) + + var ok = check_gemm_error[dtype]( + 1, n, n, + Scalar[dtype](1), Scalar[dtype](0), + norm_A, norm_x, Scalar[dtype](0), + norm_diff + ) + assert_true(ok) + def test_gemv(): gemv_test[DType.float32, 64, 64, False]() gemv_test[DType.float32, 64, 64, True]() @@ -466,6 +569,24 @@ def test_gbmv(): gbmv_test[DType.float64, 512, 64, 1, 2, False]() gbmv_test[DType.float64, 512, 64, 2, 2, True]() +def test_trsv(): + trsv_test[DType.float32, 64, True, 0, 0]() + trsv_test[DType.float32, 64, True, 1, 0]() + trsv_test[DType.float32, 64, True, 0, 1]() + trsv_test[DType.float32, 64, True, 1, 1]() + trsv_test[DType.float32, 64, False, 0, 0]() + trsv_test[DType.float32, 64, False, 1, 0]() + trsv_test[DType.float32, 64, False, 0, 1]() + trsv_test[DType.float32, 64, False, 1, 1]() + trsv_test[DType.float64, 64, True, 0, 0]() + trsv_test[DType.float64, 64, True, 1, 0]() + trsv_test[DType.float64, 64, True, 0, 1]() + trsv_test[DType.float64, 64, True, 1, 1]() + trsv_test[DType.float64, 64, False, 0, 0]() + trsv_test[DType.float64, 64, False, 1, 0]() + trsv_test[DType.float64, 64, False, 0, 1]() + trsv_test[DType.float64, 64, False, 1, 1]() + def main(): print("--- MojoBLAS Level 2 routines testing ---") var args = argv() @@ -477,8 +598,10 @@ def main(): for i in range(1, len(args)): if args[i] == "gemv": suite.test[test_gemv]() elif args[i] == "ger": suite.test[test_ger]() - # elif args[i] == "syr": suite.test[test_syr]() - # elif args[i] == "syr2": suite.test[test_syr2]() + elif args[i] == "syr": suite.test[test_syr]() + elif args[i] == "syr2": suite.test[test_syr2]() + elif args[i] == "gbmv": suite.test[test_gbmv]() + elif args[i] == "trsv": suite.test[test_trsv]() else: print("unknown routine:", args[i]) suite^.run()