From 0b15b8f52336b4dea1ad2971e120ac94db3e6d3d Mon Sep 17 00:00:00 2001 From: Jackson Mowry Date: Tue, 3 Mar 2026 23:53:09 +0000 Subject: [PATCH 1/2] gbmv implementation --- src/level2/__init__.mojo | 1 + src/level2/gbmv_device.mojo | 154 +++++++++++++++++++++++++++ src/testing_utils/testing_utils.mojo | 21 ++++ test-level2.mojo | 116 ++++++++++++++++++++ 4 files changed, 292 insertions(+) create mode 100644 src/level2/gbmv_device.mojo diff --git a/src/level2/__init__.mojo b/src/level2/__init__.mojo index 101e319..244b222 100644 --- a/src/level2/__init__.mojo +++ b/src/level2/__init__.mojo @@ -2,3 +2,4 @@ from .gemv_device import * from .ger_device import * from .syr_device import * from .syr2_device import * +from .gbmv_device import * diff --git a/src/level2/gbmv_device.mojo b/src/level2/gbmv_device.mojo new file mode 100644 index 0000000..9654f69 --- /dev/null +++ b/src/level2/gbmv_device.mojo @@ -0,0 +1,154 @@ +from gpu import thread_idx, block_idx, block_dim, grid_dim +from gpu.host import DeviceContext +from math import ceildiv + +comptime TBsize = 512 + +# level2.gbmv +# Performs band matrix-vector multiplication of form +# y := alpha*A*x + beta*y, or y := alpha*A**T*x + beta*y, +# where A is an m by n band matrix with kl sub-diagonals and ku super-diagonals. + +fn sgbmv_device( + trans: Int, + m: Int, + n: Int, + kl: Int, + ku: Int, + alpha: Float32, + A: UnsafePointer[Float32, ImmutAnyOrigin], + lda: Int, + x: UnsafePointer[Float32, ImmutAnyOrigin], + incx: Int, + beta: Float32, + y: UnsafePointer[Float32, MutAnyOrigin], + incy: Int, +): + var global_i = block_dim.x * block_idx.x + thread_idx.x + var n_threads = grid_dim.x * block_dim.x + + if not trans: + # y(i) = alpha * sum_j A(i,j)*x(j) + beta*y(i) + for i in range(global_i, m, n_threads): + var sum = Scalar[DType.float32](0) + + var j_start = (i - kl) if (i - kl > 0) else 0 + var j_end = (i + ku) if (i + ku < n-1) else n-1 + + for j in range(j_start, j_end+1): + var band_col = ku + j - i + sum += A[i * lda + band_col] * x[j * incx] + + y[i * incy] = alpha * sum + beta * y[i * incy] + + else: + # y(j) = alpha * sum_i A(i,j)*x(i) + beta*y(j) + for j in range(global_i, n, n_threads): + var sum = Scalar[DType.float32](0) + + var i_start = (j - ku) if (j - ku > 0) else 0 + var i_end = (j + kl) if (j + kl < m-1) else m-1 + + for i in range(i_start, i_end+1): + var band_col = ku + j - i + sum += A[i * lda + band_col] * x[i * incx] + + y[j * incy] = alpha * sum + beta * y[j * incy] + + +fn dgbmv_device( + trans: Int, + m: Int, + n: Int, + kl: Int, + ku: Int, + alpha: Float64, + A: UnsafePointer[Float64, ImmutAnyOrigin], + lda: Int, + x: UnsafePointer[Float64, ImmutAnyOrigin], + incx: Int, + beta: Float64, + y: UnsafePointer[Float64, MutAnyOrigin], + incy: Int, +): + var global_i = block_dim.x * block_idx.x + thread_idx.x + var n_threads = grid_dim.x * block_dim.x + + if not trans: + # y(i) = alpha * sum_j A(i,j)*x(j) + beta*y(i) + for i in range(global_i, m, n_threads): + var sum = Scalar[DType.float64](0) + + var j_start = (i - kl) if (i - kl > 0) else 0 + var j_end = (i + ku) if (i + ku < n-1) else n-1 + + for j in range(j_start, j_end+1): + var band_col = ku + j - i + sum += A[i * lda + band_col] * x[j * incx] + + y[i * incy] = alpha * sum + beta * y[i * incy] + + else: + # y(j) = alpha * sum_i A(i,j)*x(i) + beta*y(j) + for j in range(global_i, n, n_threads): + var sum = Scalar[DType.float64](0) + + var i_start = (j - ku) if (j - ku > 0) else 0 + var i_end = (j + kl) if (j + kl < m-1) else m-1 + + for i in range(i_start, i_end+1): + var band_col = ku + j - i + sum += A[i * lda + band_col] * x[i * incx] + + y[j * incy] = alpha * sum + beta * y[j * incy] + + +fn blas_gbmv[dtype: DType]( + trans: Bool, + m: Int, + n: Int, + kl: Int, + ku: Int, + alpha: Scalar[dtype], + d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], + lda: Int, + d_x: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], + incx: Int, + beta: Scalar[dtype], + d_y: UnsafePointer[Scalar[dtype], MutAnyOrigin], + incy: Int, + ctx: DeviceContext, +) raises: + + # NOTE: add error checking here? + # check m, n > 0 + # check kl, ku >= 0 + # check lda >= kl + ku + 1 + # check incx, incy > 0 + + var out_len = m if not trans else n + var trans_i = 1 if trans else 0 + + @parameter + if dtype == DType.float32: + ctx.enqueue_function[sgbmv_device, sgbmv_device]( + trans_i, m, n, kl, ku, + alpha, d_A, lda, + d_x, incx, + beta, d_y, incy, + grid_dim=ceildiv(out_len, TBsize), + block_dim=TBsize, + ) + elif dtype == DType.float64: + ctx.enqueue_function[dgbmv_device, dgbmv_device]( + trans_i, m, n, kl, ku, + alpha, d_A, lda, + d_x, incx, + beta, d_y, incy, + grid_dim=ceildiv(out_len, TBsize), + block_dim=TBsize, + ) + else: + raise Error("blas_gbmv: Unsupported type") + + ctx.synchronize() diff --git a/src/testing_utils/testing_utils.mojo b/src/testing_utils/testing_utils.mojo index 44ea985..c30e81a 100644 --- a/src/testing_utils/testing_utils.mojo +++ b/src/testing_utils/testing_utils.mojo @@ -134,3 +134,24 @@ fn frobenius_norm_symmetric[dtype: DType]( return sqrt(sum) + +fn dense_to_band[dtype: DType]( + A: UnsafePointer[Scalar[dtype], MutAnyOrigin], + B: UnsafePointer[Scalar[dtype], MutAnyOrigin], + m: Int, + n: Int, + kl: Int, + ku: Int +): + band_width = kl + ku + 1 + + for i in range(m): + for j in range(band_width): + B[i * band_width + j] = 0.0 + + for j in range(n): + if j >= i - kl and j <= i + ku: + band_col = j - (i - kl) + B[i * band_width + band_col] = A[i * n + j] + else: + A[i * n + j] = Scalar[dtype](0) diff --git a/test-level2.mojo b/test-level2.mojo index 88ba02f..4882183 100644 --- a/test-level2.mojo +++ b/test-level2.mojo @@ -320,6 +320,112 @@ def syr2_test[ assert_true(passed) +def gbmv_test[ + dtype: DType, + m: Int, + n: Int, + kl: Int, + ku: Int, + trans: Bool, +](): + comptime x_len = n if not trans else m + comptime y_len = m if not trans else n + comptime lda = kl + ku + 1 + + with DeviceContext() as ctx: + # Dense host matrix (for reference + band extraction) + A_dense = ctx.enqueue_create_host_buffer[dtype](m * n) + + # Band storage + A_band = ctx.enqueue_create_host_buffer[dtype](lda * m) + A_d = ctx.enqueue_create_buffer[dtype](lda * m) + + x = ctx.enqueue_create_host_buffer[dtype](x_len) + x_d = ctx.enqueue_create_buffer[dtype](x_len) + + y = ctx.enqueue_create_host_buffer[dtype](y_len) + y_d = ctx.enqueue_create_buffer[dtype](y_len) + + generate_random_arr[dtype, m * n](A_dense.unsafe_ptr(), -1, 1) + generate_random_arr[dtype, x_len](x.unsafe_ptr(), -100, 100) + generate_random_arr[dtype, y_len](y.unsafe_ptr(), -100, 100) + + dense_to_band(A_dense.unsafe_ptr(), A_band.unsafe_ptr(), m, n, kl, ku) + + ctx.enqueue_copy(A_d, A_band) + ctx.enqueue_copy(x_d, x) + ctx.enqueue_copy(y_d, y) + ctx.synchronize() + + var alpha = generate_random_scalar[dtype](-100, 100) + var beta = generate_random_scalar[dtype](-100, 100) + + var norm_A = frobenius_norm[dtype](A_dense.unsafe_ptr(), m * n) + var norm_x = frobenius_norm[dtype](x.unsafe_ptr(), x_len) + var norm_y = frobenius_norm[dtype](y.unsafe_ptr(), y_len) + + blas_gbmv[dtype]( + trans, + m, n, + kl, ku, + alpha, + A_d.unsafe_ptr(), lda, + x_d.unsafe_ptr(), 1, + beta, + y_d.unsafe_ptr(), 1, + ctx, + ) + + # Python reference + sp = Python.import_module("scipy") + np = Python.import_module("numpy") + sp_blas = sp.linalg.blas + + py_A = Python.list() + py_x = Python.list() + py_y = Python.list() + + for i in range(m * n): + py_A.append(A_dense[i]) + for i in range(x_len): + py_x.append(x[i]) + for i in range(y_len): + py_y.append(y[i]) + + var sp_res: PythonObject + + if dtype == DType.float32: + np_A = np.array(py_A, dtype=np.float32).reshape(m, n) + np_x = np.array(py_x, dtype=np.float32) + np_y = np.array(py_y, dtype=np.float32) + + sp_res = sp_blas.sgemv(alpha, np_A, np_x, beta=beta, y=np_y, trans=1 if trans else 0) + + elif dtype == DType.float64: + np_A = np.array(py_A, dtype=np.float64).reshape(m, n) + np_x = np.array(py_x, dtype=np.float64) + np_y = np.array(py_y, dtype=np.float64) + + sp_res = sp_blas.dgemv(alpha, np_A, np_x, beta=beta, y=np_y, trans=1 if trans else 0) + else: + print("Unsupported type") + return + + with y_d.map_to_host() as res_mojo: + var norm_diff = Scalar[dtype](0) + for i in range(y_len): + var diff = res_mojo[i] - Scalar[dtype](py=sp_res[i]) + + norm_diff += diff * diff + norm_diff = sqrt(norm_diff) + + var ok = check_gemm_error[dtype]( + 1, y_len, x_len, + alpha, beta, + norm_A, norm_x, norm_y, + norm_diff + ) + assert_true(ok) def test_gemv(): gemv_test[DType.float32, 64, 64, False]() @@ -349,6 +455,16 @@ def test_syr2(): syr2_test[DType.float64, 512, 0]() syr2_test[DType.float64, 512, 1]() +def test_gbmv(): + gbmv_test[DType.float32, 64, 64, 1, 1, False]() + gbmv_test[DType.float32, 64, 64, 2, 2, True]() + gbmv_test[DType.float64, 64, 64, 1, 1, False]() + gbmv_test[DType.float64, 64, 64, 2, 2, True]() + gbmv_test[DType.float32, 512, 64, 1, 1, False]() + gbmv_test[DType.float32, 512, 64, 2, 2, True]() + gbmv_test[DType.float64, 512, 64, 1, 1, False]() + gbmv_test[DType.float64, 512, 64, 2, 2, True]() + def main(): print("--- MojoBLAS Level 2 routines testing ---") TestSuite.discover_tests[__functions_in_module()]().run() From 72dd9929169c2ba5ef4af0e455d58baad0409680 Mon Sep 17 00:00:00 2001 From: Jackson Mowry Date: Wed, 4 Mar 2026 17:06:28 +0000 Subject: [PATCH 2/2] PR feedback fixes --- src/level2/gbmv_device.mojo | 8 ++++---- test-level2.mojo | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/level2/gbmv_device.mojo b/src/level2/gbmv_device.mojo index 9654f69..9b1b4c4 100644 --- a/src/level2/gbmv_device.mojo +++ b/src/level2/gbmv_device.mojo @@ -36,7 +36,7 @@ fn sgbmv_device( var j_end = (i + ku) if (i + ku < n-1) else n-1 for j in range(j_start, j_end+1): - var band_col = ku + j - i + var band_col = kl + j - i sum += A[i * lda + band_col] * x[j * incx] y[i * incy] = alpha * sum + beta * y[i * incy] @@ -50,7 +50,7 @@ fn sgbmv_device( var i_end = (j + kl) if (j + kl < m-1) else m-1 for i in range(i_start, i_end+1): - var band_col = ku + j - i + var band_col = kl + j - i sum += A[i * lda + band_col] * x[i * incx] y[j * incy] = alpha * sum + beta * y[j * incy] @@ -83,7 +83,7 @@ fn dgbmv_device( var j_end = (i + ku) if (i + ku < n-1) else n-1 for j in range(j_start, j_end+1): - var band_col = ku + j - i + var band_col = kl + j - i sum += A[i * lda + band_col] * x[j * incx] y[i * incy] = alpha * sum + beta * y[i * incy] @@ -97,7 +97,7 @@ fn dgbmv_device( var i_end = (j + kl) if (j + kl < m-1) else m-1 for i in range(i_start, i_end+1): - var band_col = ku + j - i + var band_col = kl + j - i sum += A[i * lda + band_col] * x[i * incx] y[j * incy] = alpha * sum + beta * y[j * incy] diff --git a/test-level2.mojo b/test-level2.mojo index 4882183..cbaa7d3 100644 --- a/test-level2.mojo +++ b/test-level2.mojo @@ -456,13 +456,13 @@ def test_syr2(): syr2_test[DType.float64, 512, 1]() def test_gbmv(): - gbmv_test[DType.float32, 64, 64, 1, 1, False]() + gbmv_test[DType.float32, 64, 64, 1, 2, False]() gbmv_test[DType.float32, 64, 64, 2, 2, True]() - gbmv_test[DType.float64, 64, 64, 1, 1, False]() + gbmv_test[DType.float64, 64, 64, 1, 2, False]() gbmv_test[DType.float64, 64, 64, 2, 2, True]() - gbmv_test[DType.float32, 512, 64, 1, 1, False]() + gbmv_test[DType.float32, 512, 64, 1, 2, False]() gbmv_test[DType.float32, 512, 64, 2, 2, True]() - gbmv_test[DType.float64, 512, 64, 1, 1, False]() + gbmv_test[DType.float64, 512, 64, 1, 2, False]() gbmv_test[DType.float64, 512, 64, 2, 2, True]() def main():