From aada542bf9d7e8dea0fb04f6489525707e6e25de Mon Sep 17 00:00:00 2001 From: tdehoff Date: Fri, 20 Mar 2026 16:08:10 -0400 Subject: [PATCH 1/2] Added symv test --- test-level2.mojo | 101 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 99 insertions(+), 2 deletions(-) diff --git a/test-level2.mojo b/test-level2.mojo index ca94fd6..9f44af8 100644 --- a/test-level2.mojo +++ b/test-level2.mojo @@ -574,6 +574,92 @@ def sbmv_test[ ) assert_true(ok) +def symv_test[ + dtype: DType, + n: Int, + upper: Bool, +](): + with DeviceContext() as ctx: + A_d = ctx.enqueue_create_buffer[dtype](n * n) + A = ctx.enqueue_create_host_buffer[dtype](n * n) + x_d = ctx.enqueue_create_buffer[dtype](n) + x = ctx.enqueue_create_host_buffer[dtype](n) + y_d = ctx.enqueue_create_buffer[dtype](n) + y = ctx.enqueue_create_host_buffer[dtype](n) + + generate_random_arr[dtype](n * n, A.unsafe_ptr(), -1, 1) + + # Force symmetry: A = 0.5 * (A + A^T) + for i in range(n): + for j in range(i, n): + var sym = (A[i * n + j] + A[j * n + i]) * Scalar[dtype](0.5) + A[i * n + j] = sym + A[j * n + i] = sym + + generate_random_arr[dtype](n, x.unsafe_ptr(), -100, 100) + generate_random_arr[dtype](n, y.unsafe_ptr(), -100, 100) + + var alpha = generate_random_scalar[dtype](-100, 100) + var beta = generate_random_scalar[dtype](-100, 100) + + var norm_A = frobenius_norm[dtype](A.unsafe_ptr(), n * n) + var norm_x = frobenius_norm[dtype](x.unsafe_ptr(), n) + var norm_y = frobenius_norm[dtype](y.unsafe_ptr(), n) + + ctx.enqueue_copy(A_d, A) + ctx.enqueue_copy(x_d, x) + ctx.enqueue_copy(y_d, y) + ctx.synchronize() + + blas_symv[dtype]( + upper, + n, + alpha, + A_d.unsafe_ptr(), n, + x_d.unsafe_ptr(), 1, + beta, + y_d.unsafe_ptr(), 1, + ctx, + ) + + sp = Python.import_module("scipy") + np = Python.import_module("numpy") + sp_blas = sp.linalg.blas + + py_A = Python.list() + py_x = Python.list() + py_y = Python.list() + for i in range(n * n): + py_A.append(A[i]) + for i in range(n): + py_x.append(x[i]) + 
py_y.append(y[i]) + + var sp_res: PythonObject + + if dtype == DType.float32: + np_A = np.array(py_A, dtype=np.float32).reshape(n, n) + np_x = np.array(py_x, dtype=np.float32) + np_y = np.array(py_y, dtype=np.float32) + sp_res = sp_blas.ssymv(alpha, np_A, np_x, beta=beta, y=np_y, lower=0 if upper else 1) + elif dtype == DType.float64: + np_A = np.array(py_A, dtype=np.float64).reshape(n, n) + np_x = np.array(py_x, dtype=np.float64) + np_y = np.array(py_y, dtype=np.float64) + sp_res = sp_blas.dsymv(alpha, np_A, np_x, beta=beta, y=np_y, lower=0 if upper else 1) + else: + print("Unsupported type: ", dtype) + return + + with y_d.map_to_host() as res_mojo: + var norm_diff = Scalar[dtype](0) + for i in range(n): + var diff = res_mojo[i] - Scalar[dtype](py=sp_res[i]) + norm_diff += diff * diff + norm_diff = sqrt(norm_diff) + var ok = check_gemm_error[dtype](1, n, n, alpha, beta, norm_A, norm_x, norm_y, norm_diff) + assert_true(ok) + def trsv_test[ dtype: DType, n: Int, @@ -734,6 +820,16 @@ def test_sbmv(): sbmv_test[DType.float64, 512, 32, False]() sbmv_test[DType.float64, 512, 32, True]() +def test_symv(): + symv_test[DType.float32, 64, True]() + symv_test[DType.float32, 64, False]() + symv_test[DType.float64, 64, True]() + symv_test[DType.float64, 64, False]() + symv_test[DType.float32, 1024, True]() + symv_test[DType.float32, 1024, False]() + symv_test[DType.float64, 1024, True]() + symv_test[DType.float64, 1024, False]() + def test_trsv(): trsv_test[DType.float32, 64, True, 0, 0]() trsv_test[DType.float32, 64, True, 1, 0]() @@ -761,13 +857,14 @@ def main(): var suite = TestSuite(cli_args=List[StaticString]()) for i in range(1, len(args)): - if args[i] == "gemv": suite.test[test_gemv]() + if args[i] == "gemv": suite.test[test_gemv]() elif args[i] == "ger": suite.test[test_ger]() - elif args[i] == "sbmv": suite.test[test_sbmv]() + elif args[i] == "sbmv": suite.test[test_sbmv]() elif args[i] == "syr": suite.test[test_syr]() elif args[i] == "syr2": suite.test[test_syr2]() 
elif args[i] == "gbmv": suite.test[test_gbmv]() elif args[i] == "trsv": suite.test[test_trsv]() + elif args[i] == "symv": suite.test[test_symv]() else: print("unknown routine:", args[i]) suite^.run() From c1d07cf118f9a9a43d3d9fcddad4726ca098e5a1 Mon Sep 17 00:00:00 2001 From: tdehoff Date: Fri, 20 Mar 2026 16:08:56 -0400 Subject: [PATCH 2/2] symv impl --- src/level2/__init__.mojo | 1 + src/level2/symv_device.mojo | 116 ++++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+) create mode 100644 src/level2/symv_device.mojo diff --git a/src/level2/__init__.mojo b/src/level2/__init__.mojo index df63738..59539ef 100644 --- a/src/level2/__init__.mojo +++ b/src/level2/__init__.mojo @@ -5,3 +5,4 @@ from .syr2_device import * from .gbmv_device import * from .sbmv_device import * from .trsv_device import * +from .symv_device import * diff --git a/src/level2/symv_device.mojo b/src/level2/symv_device.mojo new file mode 100644 index 0000000..45a1e30 --- /dev/null +++ b/src/level2/symv_device.mojo @@ -0,0 +1,116 @@ +from gpu import thread_idx, block_idx, block_dim, grid_dim +from gpu.host import DeviceContext +from math import ceildiv + +comptime TBsize = 512 + +# level2.symv +# Performs matrix-vector multiplication of form +# y := alpha*A*x + beta*y +# where alpha and beta are scalars, x and y are vectors and A is an n by n symmetric matrix. 
+fn ssymv_device(
+    upper: Int,
+    n: Int,
+    alpha: Float32,
+    A: UnsafePointer[Float32, ImmutAnyOrigin],
+    lda: Int,
+    x: UnsafePointer[Float32, ImmutAnyOrigin],
+    incx: Int,
+    beta: Float32,
+    y: UnsafePointer[Float32, MutAnyOrigin],
+    incy: Int,
+):
+    """Single-precision kernel for y := alpha*A*x + beta*y with symmetric A.
+
+    Each thread produces the rows i = global_i, global_i + n_threads, ...
+    of the result. Element (i, j) of A is addressed as A[i * lda + j]; only
+    one triangle is ever read and the other half is obtained by mirroring
+    the indices.
+
+    Args:
+        upper: Non-zero to read the entries with j >= i, zero to read the
+            entries with j <= i.
+        n: Order of the matrix and length of x and y.
+        alpha: Scalar multiplying A*x.
+        A: Device pointer to the matrix storage.
+        lda: Distance, in elements, between consecutive rows of A.
+        x: Device pointer to the input vector.
+        incx: Element stride of x.
+        beta: Scalar multiplying the incoming y.
+        y: Device pointer to the vector updated in place.
+        incy: Element stride of y.
+    """
+    var global_i = block_dim.x * block_idx.x + thread_idx.x
+    var n_threads = grid_dim.x * block_dim.x
+
+    # Loop-invariant decisions, evaluated once per thread.
+    var use_upper = upper != 0
+    var beta_is_zero = beta == Scalar[DType.float32](0)
+
+    for i in range(global_i, n, n_threads):
+        var sum = Scalar[DType.float32](0)
+        for j in range(n):
+            # Read (i, j) directly when it lies in the stored triangle,
+            # otherwise read its mirror image (j, i).
+            var in_stored_triangle = (j >= i) if use_upper else (j <= i)
+            var a_ij = A[i * lda + j] if in_stored_triangle else A[j * lda + i]
+            sum += a_ij * x[j * incx]
+        if beta_is_zero:
+            # BLAS convention: with beta == 0 the incoming y is ignored, so
+            # NaN/Inf in an unset y must not leak into the result.
+            y[i * incy] = alpha * sum
+        else:
+            y[i * incy] = alpha * sum + beta * y[i * incy]
+
+
+fn dsymv_device(
+    upper: Int,
+    n: Int,
+    alpha: Float64,
+    A: UnsafePointer[Float64, ImmutAnyOrigin],
+    lda: Int,
+    x: UnsafePointer[Float64, ImmutAnyOrigin],
+    incx: Int,
+    beta: Float64,
+    y: UnsafePointer[Float64, MutAnyOrigin],
+    incy: Int,
+):
+    """Double-precision kernel for y := alpha*A*x + beta*y with symmetric A.
+
+    Each thread produces the rows i = global_i, global_i + n_threads, ...
+    of the result. Element (i, j) of A is addressed as A[i * lda + j]; only
+    one triangle is ever read and the other half is obtained by mirroring
+    the indices.
+
+    Args:
+        upper: Non-zero to read the entries with j >= i, zero to read the
+            entries with j <= i.
+        n: Order of the matrix and length of x and y.
+        alpha: Scalar multiplying A*x.
+        A: Device pointer to the matrix storage.
+        lda: Distance, in elements, between consecutive rows of A.
+        x: Device pointer to the input vector.
+        incx: Element stride of x.
+        beta: Scalar multiplying the incoming y.
+        y: Device pointer to the vector updated in place.
+        incy: Element stride of y.
+    """
+    var global_i = block_dim.x * block_idx.x + thread_idx.x
+    var n_threads = grid_dim.x * block_dim.x
+
+    # Loop-invariant decisions, evaluated once per thread.
+    var use_upper = upper != 0
+    var beta_is_zero = beta == Scalar[DType.float64](0)
+
+    for i in range(global_i, n, n_threads):
+        var sum = Scalar[DType.float64](0)
+        for j in range(n):
+            # Read (i, j) directly when it lies in the stored triangle,
+            # otherwise read its mirror image (j, i).
+            var in_stored_triangle = (j >= i) if use_upper else (j <= i)
+            var a_ij = A[i * lda + j] if in_stored_triangle else A[j * lda + i]
+            sum += a_ij * x[j * incx]
+        if beta_is_zero:
+            # BLAS convention: with beta == 0 the incoming y is ignored, so
+            # NaN/Inf in an unset y must not leak into the result.
+            y[i * incy] = alpha * sum
+        else:
+            y[i * incy] = alpha * sum + beta * y[i * incy]
+
+
+fn blas_symv[dtype: DType](
+    upper: Bool,
+    n: Int,
+    alpha: Scalar[dtype],
+    d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
+    lda: Int,
+    d_x: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
+    incx: Int,
+    beta: Scalar[dtype],
+    d_y: UnsafePointer[Scalar[dtype], MutAnyOrigin],
+    incy: Int,
+    ctx: DeviceContext,
+) raises:
+    """Host entry point for the symmetric matrix-vector product
+    y := alpha*A*x + beta*y on the device.
+
+    Validates the arguments, enqueues the float32 or float64 kernel on
+    `ctx` with TBsize threads per block, and synchronizes the context
+    before returning.
+
+    Args:
+        upper: True to read the entries of A with j >= i, False to read the
+            entries with j <= i (A is addressed as A[i * lda + j]).
+        n: Order of A and length of x and y; must be >= 0.
+        alpha: Scalar multiplying A*x.
+        d_A: Device pointer to the matrix storage.
+        lda: Distance, in elements, between consecutive rows of A; must be
+            at least max(1, n).
+        d_x: Device pointer to the input vector.
+        incx: Element stride of x; must be positive.
+        beta: Scalar multiplying the incoming y.
+        d_y: Device pointer to the vector updated in place.
+        incy: Element stride of y; must be positive.
+        ctx: Device context used for the launch and the synchronization.
+
+    Raises:
+        Error: If an argument is out of range, or if dtype is neither
+            float32 nor float64.
+    """
+    if n < 0:
+        raise Error("blas_symv: n must be non-negative")
+    if lda < 1 or lda < n:
+        raise Error("blas_symv: lda must be at least max(1, n)")
+    # The kernels index x and y from element 0 upwards, so zero or negative
+    # strides would alias or run out of bounds; reject them explicitly.
+    if incx <= 0:
+        raise Error("blas_symv: incx must be positive")
+    if incy <= 0:
+        raise Error("blas_symv: incy must be positive")
+
+    var upper_i = 1 if upper else 0
+
+    # An empty problem has nothing to compute, and ceildiv(0, TBsize) == 0
+    # would request a grid with zero blocks, which is not a valid launch.
+    var launch = n > 0
+
+    @parameter
+    if dtype == DType.float32:
+        if launch:
+            ctx.enqueue_function[ssymv_device, ssymv_device](
+                upper_i, n,
+                alpha, d_A, lda,
+                d_x, incx,
+                beta, d_y, incy,
+                grid_dim=ceildiv(n, TBsize),
+                block_dim=TBsize,
+            )
+    elif dtype == DType.float64:
+        if launch:
+            ctx.enqueue_function[dsymv_device, dsymv_device](
+                upper_i, n,
+                alpha, d_A, lda,
+                d_x, incx,
+                beta, d_y, incy,
+                grid_dim=ceildiv(n, TBsize),
+                block_dim=TBsize,
+            )
+    else:
+        raise Error("blas_symv: Unsupported type")
+
+    ctx.synchronize()