Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion src/testing_utils/testing_utils.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ fn dense_to_band[dtype: DType](
else:
A[i * n + j] = Scalar[dtype](0)

fn dense_to_sym_band[dtype: DType](
# Packs row-major dense matrix to row-major band buffer
fn dense_to_sym_band_rm[dtype: DType](
A_dense: UnsafePointer[Scalar[dtype], MutAnyOrigin],
A_band: UnsafePointer[Scalar[dtype], MutAnyOrigin],
n: Int,
Expand Down Expand Up @@ -204,3 +205,53 @@ fn dense_to_sym_band[dtype: DType](
val = A_band[j * lda + (j - i)]

A_dense[i * n + j] = val

# Packs row-major dense matrix to column-major band buffer
fn dense_to_sym_band_cm[dtype: DType](
    A_dense: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    A_band: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    n: Int,
    k: Int,
    upper: Bool,
):
    """Pack one triangle of a dense matrix into a column-major band buffer.

    The band buffer is addressed as `A_band[row + col * lda]` with
    `lda = k + 1`, i.e. each of the `n` columns owns `lda` contiguous slots.
    After packing, `A_dense` is overwritten with the symmetric, banded matrix
    implied by the packed triangle (entries outside the band become zero).

    Parameters:
        dtype: Element type of both buffers.

    Args:
        A_dense: Row-major `n x n` matrix, read as `A_dense[i * n + j]`.
            Overwritten in place with the band reconstruction.
        A_band: Output buffer of at least `n * (k + 1)` elements.
        n: Matrix order.
        k: Number of super-diagonals (`upper`) or sub-diagonals (lower).
        upper: If True pack the upper triangle, otherwise the lower triangle.
    """
    var lda = k + 1

    # Zero the whole band so slots that fall outside the matrix stay defined.
    for col in range(n):
        for row in range(lda):
            A_band[col * lda + row] = Scalar[dtype](0)

    if upper:
        # Store A[i, j] for max(0, j - k) <= i <= j at band row (k + i - j).
        for j in range(n):
            var m = k - j
            var i_start = j - k if (j - k > 0) else 0
            for i in range(i_start, j + 1):
                # Read the upper-triangle entry A[i, j]; the previous code
                # read A[j, i], which is only equal for symmetric input.
                A_band[j * lda + m + i] = A_dense[i * n + j]
    else:
        # Store A[i, j] for j <= i <= min(n - 1, j + k) at band row (i - j).
        for j in range(n):
            var i_end = (j + k) if (j + k < n - 1) else (n - 1)
            for i in range(j, i_end + 1):
                var band_row = i - j
                A_band[j * lda + band_row] = A_dense[i * n + j]

    # Overwrite original matrix with band reconstruction
    for i in range(n):
        for j in range(n):
            var val = Scalar[dtype](0)

            if upper:
                if j >= i and j <= i + k:
                    # Stored upper-triangle entry (i, j).
                    val = A_band[(k - (j - i)) + j * lda]
                elif i >= j and i <= j + k:
                    # Symmetric mirror: take the stored entry (j, i).
                    val = A_band[(k - (i - j)) + i * lda]
            else:
                if j <= i and j >= i - k:
                    # Stored lower-triangle entry (i, j).
                    val = A_band[(i - j) + j * lda]
                elif j >= i and j <= i + k:
                    # Symmetric mirror: take the stored entry (j, i).
                    val = A_band[(j - i) + i * lda]

            A_dense[i * n + j] = val
41 changes: 27 additions & 14 deletions test-level2.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,8 @@ def sbmv_test[
A_dense = ctx.enqueue_create_host_buffer[dtype](n * n)

# Band storage (symmetric)
A_band = ctx.enqueue_create_host_buffer[dtype](lda * n)
A_band_rm = ctx.enqueue_create_host_buffer[dtype](lda * n)
A_band_cm = ctx.enqueue_create_host_buffer[dtype](n * lda)
A_d = ctx.enqueue_create_buffer[dtype](lda * n)

x = ctx.enqueue_create_host_buffer[dtype](n)
Expand All @@ -464,16 +465,25 @@ def sbmv_test[
generate_random_arr[dtype](n, x.unsafe_ptr(), -100, 100)
generate_random_arr[dtype](n, y.unsafe_ptr(), -100, 100)

# Convert dense -> symmetric band
dense_to_sym_band(
# Convert dense -> symmetric row-major band for MojoBLAS routine
dense_to_sym_band_rm(
A_dense.unsafe_ptr(),
A_band.unsafe_ptr(),
A_band_rm.unsafe_ptr(),
n,
k,
upper
)

ctx.enqueue_copy(A_d, A_band)
# Convert dense -> symmetric col-major band for SciPy routine
dense_to_sym_band_cm(
A_dense.unsafe_ptr(),
A_band_cm.unsafe_ptr(),
n,
k,
upper
)

ctx.enqueue_copy(A_d, A_band_rm)
ctx.enqueue_copy(x_d, x)
ctx.enqueue_copy(y_d, y)
ctx.synchronize()
Expand Down Expand Up @@ -505,20 +515,22 @@ def sbmv_test[
py_x = Python.list()
py_y = Python.list()

for i in range(n * n):
py_A.append(A_dense[i])
for i in range(n * lda):
py_A.append(A_band_cm[i])
for i in range(n):
py_x.append(x[i])
py_y.append(y[i])

var sp_res: PythonObject

if dtype == DType.float32:
np_A = np.array(py_A, dtype=np.float32).reshape(n, n)
# order='F' is col-major, 'C' is row-major
np_A = np.array(py_A, dtype=np.float32).reshape(lda, n, order='F')
np_x = np.array(py_x, dtype=np.float32)
np_y = np.array(py_y, dtype=np.float32)

sp_res = sp_blas.ssymv(
sp_res = sp_blas.ssbmv(
k,
alpha,
np_A,
np_x,
Expand All @@ -528,11 +540,12 @@ def sbmv_test[
)

elif dtype == DType.float64:
np_A = np.array(py_A, dtype=np.float64).reshape(n, n)
np_A = np.array(py_A, dtype=np.float64).reshape(lda, n, order='F')
np_x = np.array(py_x, dtype=np.float64)
np_y = np.array(py_y, dtype=np.float64)

sp_res = sp_blas.dsymv(
sp_res = sp_blas.dsbmv(
k,
alpha,
np_A,
np_x,
Expand Down Expand Up @@ -663,8 +676,8 @@ def trsv_test[
norm_A, norm_x, Scalar[dtype](0),
norm_diff
)
assert_true(ok)
assert_true(ok)

def test_gemv():
gemv_test[DType.float32, 64, 64, False]()
gemv_test[DType.float32, 64, 64, True]()
Expand Down Expand Up @@ -720,7 +733,7 @@ def test_sbmv():
sbmv_test[DType.float64, 512, 2, True]()
sbmv_test[DType.float64, 512, 32, False]()
sbmv_test[DType.float64, 512, 32, True]()

def test_trsv():
trsv_test[DType.float32, 64, True, 0, 0]()
trsv_test[DType.float32, 64, True, 1, 0]()
Expand Down
Loading