diff --git a/src/level2/__init__.mojo b/src/level2/__init__.mojo
index 59539ef..c94e2ea 100644
--- a/src/level2/__init__.mojo
+++ b/src/level2/__init__.mojo
@@ -6,3 +6,4 @@ from .gbmv_device import *
 from .sbmv_device import *
 from .trsv_device import *
 from .symv_device import *
+from .tpmv_device import *
\ No newline at end of file
diff --git a/src/level2/tpmv_device.mojo b/src/level2/tpmv_device.mojo
new file mode 100644
index 0000000..1dbdd85
--- /dev/null
+++ b/src/level2/tpmv_device.mojo
@@ -0,0 +1,180 @@
+from gpu import thread_idx, block_idx, block_dim, grid_dim
+from gpu.host import DeviceContext
+from math import ceildiv
+
+comptime TBsize = 512
+
+# level2.tpmv
+# Performs matrix-vector multiplication of form
+# x := A*x, or x := A**T*x
+# where x is a vector and A is an n by n unit or non-unit, upper or lower
+# triangular matrix supplied in packed form: the stored triangle is laid out
+# column by column in AP, n * (n + 1) / 2 elements in total.
+#
+# uplo != 0 selects the upper triangle, trans != 0 selects A**T*x and
+# diag != 0 treats the diagonal as all ones without reading it from AP.
+# Each thread handles rows i = global_i, global_i + n_threads, ... and stores
+# its result in workspace[i * incx]; x is only read.
+fn stpmv_device(
+    uplo: Int,
+    trans: Int,
+    diag: Int,
+    n: Int,
+    AP: UnsafePointer[Float32, ImmutAnyOrigin],
+    x: UnsafePointer[Float32, ImmutAnyOrigin],
+    incx: Int,
+    workspace: UnsafePointer[Float32, MutAnyOrigin],
+):
+    var global_i = block_dim.x * block_idx.x + thread_idx.x
+    var n_threads = grid_dim.x * block_dim.x
+
+    for i in range(global_i, n, n_threads):
+        var sum = x[i * incx]
+        if trans:
+            # (A**T*x)[i] uses column i of A, which is contiguous in AP.
+            if uplo:
+                var index = (i * (i + 1)) // 2 + i
+                if not diag:
+                    sum *= AP[index]
+                index -= i
+                for j in range(0, i):
+                    sum += AP[index] * x[j * incx]
+                    index += 1
+            else:
+                var index = i * n - ((i - 1) * i) // 2
+                if not diag:
+                    sum *= AP[index]
+                for j in range(i + 1, n):
+                    index += 1
+                    sum += AP[index] * x[j * incx]
+        else:
+            # (A*x)[i] uses row i of A: one element from each stored column.
+            if uplo:
+                var index = (i * (i + 1)) // 2 + i
+                if not diag:
+                    sum *= AP[index]
+                for j in range(i + 1, n):
+                    # Upper column j starts j entries after column j - 1.
+                    index += j
+                    sum += AP[index] * x[j * incx]
+            else:
+                if not diag:
+                    sum *= AP[i * n - ((i - 1) * i) // 2]
+                var index = i
+                for j in range(0, i):
+                    sum += AP[index] * x[j * incx]
+                    # Lower column j holds n - j entries and the next column
+                    # starts one row lower, so row i moves by n - j - 1.
+                    index += n - j - 1
+        workspace[i * incx] = sum
+
+
+# Float64 counterpart of stpmv_device: same arguments, layout and indexing.
+fn dtpmv_device(
+    uplo: Int,
+    trans: Int,
+    diag: Int,
+    n: Int,
+    AP: UnsafePointer[Float64, ImmutAnyOrigin],
+    x: UnsafePointer[Float64, ImmutAnyOrigin],
+    incx: Int,
+    workspace: UnsafePointer[Float64, MutAnyOrigin],
+):
+    var global_i = block_dim.x * block_idx.x + thread_idx.x
+    var n_threads = grid_dim.x * block_dim.x
+
+    for i in range(global_i, n, n_threads):
+        var sum = x[i * incx]
+        if trans:
+            # (A**T*x)[i] uses column i of A, which is contiguous in AP.
+            if uplo:
+                var index = (i * (i + 1)) // 2 + i
+                if not diag:
+                    sum *= AP[index]
+                index -= i
+                for j in range(0, i):
+                    sum += AP[index] * x[j * incx]
+                    index += 1
+            else:
+                var index = i * n - ((i - 1) * i) // 2
+                if not diag:
+                    sum *= AP[index]
+                for j in range(i + 1, n):
+                    index += 1
+                    sum += AP[index] * x[j * incx]
+        else:
+            # (A*x)[i] uses row i of A: one element from each stored column.
+            if uplo:
+                var index = (i * (i + 1)) // 2 + i
+                if not diag:
+                    sum *= AP[index]
+                for j in range(i + 1, n):
+                    # Upper column j starts j entries after column j - 1.
+                    index += j
+                    sum += AP[index] * x[j * incx]
+            else:
+                if not diag:
+                    sum *= AP[i * n - ((i - 1) * i) // 2]
+                var index = i
+                for j in range(0, i):
+                    sum += AP[index] * x[j * incx]
+                    # Lower column j holds n - j entries and the next column
+                    # starts one row lower, so row i moves by n - j - 1.
+                    index += n - j - 1
+        workspace[i * incx] = sum
+
+
+# Host wrapper for level2.tpmv: overwrites the n-element device vector d_x
+# with A*x (trans == False) or A**T*x (trans == True), where d_AP holds the
+# packed triangle, and calls ctx.synchronize() before returning.
+# Raises if dtype is not float32/float64, if n is negative, or if incx != 1.
+fn blas_tpmv[dtype: DType](
+    uplo: Int,
+    trans: Bool,
+    diag: Int,
+    n: Int,
+    d_AP: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
+    d_x: UnsafePointer[Scalar[dtype], MutAnyOrigin],
+    incx: Int,
+    ctx: DeviceContext,
+) raises:
+    if n < 0:
+        raise Error("blas_tpmv: n must be non-negative")
+    # The result is staged in an n-element workspace and copied back over d_x
+    # as one contiguous block, which is only correct for a unit stride.
+    if incx != 1:
+        raise Error("blas_tpmv: only incx == 1 is supported")
+    # Nothing to do; also avoids an empty allocation and a zero-block launch.
+    if n == 0:
+        return
+
+    var trans_i = 1 if trans else 0
+
+    var workspace = ctx.enqueue_create_buffer[dtype](n)
+
+    @parameter
+    if dtype == DType.float32:
+        ctx.enqueue_function[stpmv_device, stpmv_device](
+            uplo, trans_i, diag,
+            n, d_AP,
+            d_x, incx,
+            workspace,
+            grid_dim=ceildiv(n, TBsize),
+            block_dim=TBsize,
+        )
+    elif dtype == DType.float64:
+        ctx.enqueue_function[dtpmv_device, dtpmv_device](
+            uplo, trans_i, diag,
+            n, d_AP,
+            d_x, incx,
+            workspace,
+            grid_dim=ceildiv(n, TBsize),
+            block_dim=TBsize,
+        )
+    else:
+        raise Error("blas_tpmv: Unsupported type")
+
+    # The kernel reads d_x while filling workspace, so d_x is overwritten here.
+    ctx.enqueue_copy(d_x, workspace)
+
+    ctx.synchronize()
diff --git a/src/testing_utils/testing_utils.mojo b/src/testing_utils/testing_utils.mojo
index 9f0e5ef..1d08eae 100644
--- a/src/testing_utils/testing_utils.mojo
+++ b/src/testing_utils/testing_utils.mojo
@@ -255,3 +255,28 @@ fn dense_to_sym_band_cm[dtype: DType](
                 val = A_band[(j - i) + i * lda]
             A_dense[i * n + j] = val
 
+
+# Packs one triangle of the n x n matrix A_dense into A_band, column by column
+# (the packed layout read by the tpmv kernels): uplo != 0 copies the upper
+# triangle, uplo == 0 the lower one. Element (i, j) is read from
+# A_dense[i * n + j]; the first n * (n + 1) / 2 elements of A_band are written.
+fn dense_to_tri_band[dtype: DType](
+    A_dense: UnsafePointer[Scalar[dtype], MutAnyOrigin],
+    A_band: UnsafePointer[Scalar[dtype], MutAnyOrigin],
+    n: Int,
+    uplo: Int,
+):
+    var index = 0
+    for j in range(n):
+        if uplo:
+            for i in range(0, j+1):
+                A_band[index] = A_dense[i * n + j]
+                index += 1
+        else:
+            for i in range(j, n):
+                A_band[index] = A_dense[i * n + j]
+                index += 1
+
+    var n_packed = n * (n + 1) // 2
+    for i in range(index, n_packed):
+        A_band[i] = 0
\ No newline at end of file
diff --git a/test-level2.mojo b/test-level2.mojo
index 9f44af8..1b46041 100644
--- a/test-level2.mojo
+++ b/test-level2.mojo
@@ -764,6 +764,123 @@ def trsv_test[
             )
             assert_true(ok)
 
+
+# Checks blas_tpmv against SciPy's stpmv/dtpmv on a random n x n triangular
+# matrix in packed storage. uplo != 0 means upper triangular, diag != 0 means
+# unit diagonal, and trans selects A**T*x instead of A*x.
+def tpmv_test[
+    dtype: DType,
+    n: Int,
+    trans: Bool,
+    uplo: Int,
+    diag: Int,
+]():
+    with DeviceContext() as ctx:
+        A_dense = ctx.enqueue_create_host_buffer[dtype](n * n)
+
+        AP_d = ctx.enqueue_create_buffer[dtype](n * n)
+        AP = ctx.enqueue_create_host_buffer[dtype](n * n)
+        x_d = ctx.enqueue_create_buffer[dtype](n)
+        x = ctx.enqueue_create_host_buffer[dtype](n)
+
+        generate_random_arr[dtype](n * n, A_dense.unsafe_ptr(), -1, 1)
+        generate_random_arr[dtype](n, x.unsafe_ptr(), -1, 1)
+
+        # Zero out the triangle that is not stored so that A_dense matches the
+        # packed matrix built below (uplo != 0 means upper, as in blas_tpmv)
+        for i in range(n):
+            for j in range(n):
+                # upper triangular
+                if uplo:
+                    if i > j:
+                        A_dense[i * n + j] = 0
+                # lower triangular
+                else:
+                    if i < j:
+                        A_dense[i * n + j] = 0
+
+        # Handle diagonal based on diag parameter
+        for i in range(n):
+            # unit diagonal
+            if diag == 1:
+                A_dense[i * n + i] = 1
+            # non-unit diagonal: make diagonally dominant to reduce numerical instability
+            else:
+                A_dense[i * n + i] += 1000
+
+        dense_to_tri_band(A_dense.unsafe_ptr(), AP.unsafe_ptr(), n, uplo)
+
+        ctx.enqueue_copy(AP_d, AP)
+        ctx.enqueue_copy(x_d, x)
+        ctx.synchronize()
+
+        # Compute norms for error checks
+        var norm_A = frobenius_norm[dtype](A_dense.unsafe_ptr(), n * n)
+        var norm_x = frobenius_norm[dtype](x.unsafe_ptr(), n)
+
+        blas_tpmv[dtype](
+            uplo,
+            trans,
+            diag,
+            n,
+            AP_d.unsafe_ptr(),
+            x_d.unsafe_ptr(), 1,
+            ctx,
+        )
+
+        # Import SciPy and numpy
+        sp = Python.import_module("scipy")
+        np = Python.import_module("numpy")
+        sp_blas = sp.linalg.blas
+
+        py_A = Python.list()
+        py_x = Python.list()
+        # Only the packed triangle of AP was written by dense_to_tri_band
+        for i in range(n * (n + 1) // 2):
+            py_A.append(AP[i])
+        for i in range(n):
+            py_x.append(x[i])
+
+        var sp_res: PythonObject
+
+        if dtype == DType.float32:
+            np_A = np.array(py_A, dtype=np.float32)
+            np_x = np.array(py_x, dtype=np.float32)
+            sp_res = sp_blas.stpmv(n, np_A, np_x,
+                lower=0 if uplo else 1,
+                trans=1 if trans else 0,
+                diag=1 if diag else 0
+            )
+        elif dtype == DType.float64:
+            np_A = np.array(py_A, dtype=np.float64)
+            np_x = np.array(py_x, dtype=np.float64)
+            sp_res = sp_blas.dtpmv(n, np_A, np_x,
+                lower=0 if uplo else 1,
+                trans=1 if trans else 0,
+                diag=1 if diag else 0
+            )
+        else:
+            print("Unsupported type: ", dtype)
+            return
+
+        with x_d.map_to_host() as res_mojo:
+
+            var norm_diff = Scalar[dtype](0)
+            for i in range(n):
+                var diff = res_mojo[i] - Scalar[dtype](py=sp_res[i])
+                print(res_mojo[i], Scalar[dtype](py=sp_res[i]))
+                norm_diff += diff * diff
+            norm_diff = sqrt(norm_diff)
+            print(norm_diff)
+
+            var ok = check_gemm_error[dtype](
+                1, n, n,
+                Scalar[dtype](1), Scalar[dtype](0),
+                norm_A, norm_x, Scalar[dtype](0),
+                norm_diff
+            )
+            assert_true(ok)
+
 def test_gemv():
     gemv_test[DType.float32, 64, 64, False]()
     gemv_test[DType.float32, 64, 64, True]()
@@ -848,6 +965,24 @@ def test_trsv():
     trsv_test[DType.float64, 64, False, 0, 1]()
     trsv_test[DType.float64, 64, False, 1, 1]()
 
+def test_tpmv():
+    tpmv_test[DType.float32, 64, True, 0, 0]()
+    tpmv_test[DType.float32, 64, True, 1, 0]()
+    tpmv_test[DType.float32, 64, True, 0, 1]()
+    tpmv_test[DType.float32, 64, True, 1, 1]()
+    tpmv_test[DType.float32, 64, False, 0, 0]()
+    tpmv_test[DType.float32, 64, False, 1, 0]()
+    tpmv_test[DType.float32, 64, False, 0, 1]()
+    tpmv_test[DType.float32, 64, False, 1, 1]()
+    tpmv_test[DType.float64, 64, True, 0, 0]()
+    tpmv_test[DType.float64, 64, True, 1, 0]()
+    tpmv_test[DType.float64, 64, True, 0, 1]()
+    tpmv_test[DType.float64, 64, True, 1, 1]()
+    tpmv_test[DType.float64, 64, False, 0, 0]()
+    tpmv_test[DType.float64, 64, False, 1, 0]()
+    tpmv_test[DType.float64, 64, False, 0, 1]()
+    tpmv_test[DType.float64, 64, False, 1, 1]()
+
 def main():
     print("--- MojoBLAS Level 2 routines testing ---")
     var args = argv()