diff --git a/src/testing_utils/testing_utils.mojo b/src/testing_utils/testing_utils.mojo index 775106a..9f0e5ef 100644 --- a/src/testing_utils/testing_utils.mojo +++ b/src/testing_utils/testing_utils.mojo @@ -156,7 +156,8 @@ fn dense_to_band[dtype: DType]( else: A[i * n + j] = Scalar[dtype](0) -fn dense_to_sym_band[dtype: DType]( +# Packs row-major dense matrix to row-major band buffer +fn dense_to_sym_band_rm[dtype: DType]( A_dense: UnsafePointer[Scalar[dtype], MutAnyOrigin], A_band: UnsafePointer[Scalar[dtype], MutAnyOrigin], n: Int, @@ -204,3 +205,53 @@ fn dense_to_sym_band[dtype: DType]( val = A_band[j * lda + (j - i)] A_dense[i * n + j] = val + +# Packs row-major dense matrix to column-major band buffer +fn dense_to_sym_band_cm[dtype: DType]( + A_dense: UnsafePointer[Scalar[dtype], MutAnyOrigin], + A_band: UnsafePointer[Scalar[dtype], MutAnyOrigin], + n: Int, + k: Int, + upper: Bool, +): + var lda = k + 1 + + # zero initialize + for i in range(n): + for b in range(lda): + A_band[i * lda + b] = 0 + + if upper: + # store A[i,j] for j >= i + for j in range(n): + var m = k - j + var i_start = j - k if (j - k > 0) else 0 + for i in range(i_start, j + 1): + A_band[j * lda + m + i] = A_dense[j * n + i] + else: + # store A[i,j] for j <= i + for j in range(n): + var i_end = (j + k) if (j + k < n - 1) else (n - 1) + for i in range(j, i_end + 1): + var band_col = i - j + A_band[j * lda + band_col] = A_dense[i * n + j] + + # Overwrite original matrix with band reconstruction + for i in range(n): + for j in range(n): + var val = Scalar[dtype](0) + + if upper: + if j >= i and j <= i + k: + val = A_band[(k - (j - i)) + j * lda] + elif i >= j and i <= j + k: + # symmetric mirror + val = A_band[(k - (i - j)) + i * lda] + else: + if j <= i and j >= i - k: + val = A_band[(i - j) + j * lda] + elif j >= i and j <= i + k: + # symmetric mirror + val = A_band[(j - i) + i * lda] + + A_dense[i * n + j] = val diff --git a/test-level2.mojo b/test-level2.mojo index 53b7ed8..ca94fd6 100644 --- a/test-level2.mojo +++ b/test-level2.mojo @@ -440,7 +440,8 @@ def sbmv_test[ A_dense = ctx.enqueue_create_host_buffer[dtype](n * n) # Band storage (symmetric) - A_band = ctx.enqueue_create_host_buffer[dtype](lda * n) + A_band_rm = ctx.enqueue_create_host_buffer[dtype](lda * n) + A_band_cm = ctx.enqueue_create_host_buffer[dtype](n * lda) A_d = ctx.enqueue_create_buffer[dtype](lda * n) x = ctx.enqueue_create_host_buffer[dtype](n) @@ -464,16 +465,25 @@ def sbmv_test[ generate_random_arr[dtype](n, x.unsafe_ptr(), -100, 100) generate_random_arr[dtype](n, y.unsafe_ptr(), -100, 100) - # Convert dense -> symmetric band - dense_to_sym_band( + # Convert dense -> symmetric row-major band for MojoBLAS routine + dense_to_sym_band_rm( A_dense.unsafe_ptr(), - A_band.unsafe_ptr(), + A_band_rm.unsafe_ptr(), n, k, upper ) - ctx.enqueue_copy(A_d, A_band) + # Convert dense -> symmetric col-major band for SciPy routine + dense_to_sym_band_cm( + A_dense.unsafe_ptr(), + A_band_cm.unsafe_ptr(), + n, + k, + upper + ) + + ctx.enqueue_copy(A_d, A_band_rm) ctx.enqueue_copy(x_d, x) ctx.enqueue_copy(y_d, y) ctx.synchronize() @@ -505,8 +515,8 @@ def sbmv_test[ py_x = Python.list() py_y = Python.list() - for i in range(n * n): - py_A.append(A_dense[i]) + for i in range(n * lda): + py_A.append(A_band_cm[i]) for i in range(n): py_x.append(x[i]) py_y.append(y[i]) @@ -514,11 +524,13 @@ def sbmv_test[ var sp_res: PythonObject if dtype == DType.float32: - np_A = np.array(py_A, dtype=np.float32).reshape(n, n) + # order='F' is col-major, 'C' is row-major + np_A = np.array(py_A, dtype=np.float32).reshape(lda, n, order='F') np_x = np.array(py_x, dtype=np.float32) np_y = np.array(py_y, dtype=np.float32) - sp_res = sp_blas.ssymv( + sp_res = sp_blas.ssbmv( + k, alpha, np_A, np_x, @@ -528,11 +540,12 @@ def sbmv_test[ ) elif dtype == DType.float64: - np_A = np.array(py_A, dtype=np.float64).reshape(n, n) + np_A = np.array(py_A, dtype=np.float64).reshape(lda, n, order='F') np_x = np.array(py_x, dtype=np.float64) np_y = np.array(py_y, dtype=np.float64) - sp_res = sp_blas.dsymv( + sp_res = sp_blas.dsbmv( + k, alpha, np_A, np_x, @@ -663,8 +676,8 @@ def trsv_test[ norm_A, norm_x, Scalar[dtype](0), norm_diff ) - assert_true(ok) - + assert_true(ok) + def test_gemv(): gemv_test[DType.float32, 64, 64, False]() gemv_test[DType.float32, 64, 64, True]() @@ -720,7 +733,7 @@ def test_sbmv(): sbmv_test[DType.float64, 512, 2, True]() sbmv_test[DType.float64, 512, 32, False]() sbmv_test[DType.float64, 512, 32, True]() - + def test_trsv(): trsv_test[DType.float32, 64, True, 0, 0]() trsv_test[DType.float32, 64, True, 1, 0]()