Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/level2/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ from .gemv_device import *
from .ger_device import *
from .syr_device import *
from .syr2_device import *
from .gbmv_device import *
154 changes: 154 additions & 0 deletions src/level2/gbmv_device.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from gpu import thread_idx, block_idx, block_dim, grid_dim
from gpu.host import DeviceContext
from math import ceildiv

comptime TBsize = 512

# level2.gbmv
# Performs band matrix-vector multiplication of form
# y := alpha*A*x + beta*y, or y := alpha*A**T*x + beta*y,
# where A is an m by n band matrix with kl sub-diagonals and ku super-diagonals.

fn sgbmv_device(
    trans: Int,
    m: Int,
    n: Int,
    kl: Int,
    ku: Int,
    alpha: Float32,
    A: UnsafePointer[Float32, ImmutAnyOrigin],
    lda: Int,
    x: UnsafePointer[Float32, ImmutAnyOrigin],
    incx: Int,
    beta: Float32,
    y: UnsafePointer[Float32, MutAnyOrigin],
    incy: Int,
):
    """Single-precision band matrix-vector product kernel.

    Computes y := alpha*A*x + beta*y when `trans` is 0, and
    y := alpha*A**T*x + beta*y for any non-zero `trans`.

    Element A(i, j) of the band is read from `A[i * lda + (kl + j - i)]`,
    i.e. one row of the band per matrix row, `lda` elements apart.

    Args:
        trans: 0 for the plain product, non-zero for the transposed product.
        m: Number of rows of A.
        n: Number of columns of A.
        kl: Number of sub-diagonals of A.
        ku: Number of super-diagonals of A.
        alpha: Scale applied to the matrix-vector product.
        A: Band storage of the matrix (read only).
        lda: Distance, in elements, between consecutive band rows in `A`.
        x: Input vector (read only); element k is read at `x[k * incx]`.
        incx: Stride between consecutive elements of `x`.
        beta: Scale applied to the incoming contents of `y`.
        y: Output vector, updated in place; element k lives at `y[k * incy]`.
        incy: Stride between consecutive elements of `y`.
    """
    # Work in signed Int from here on: `i - kl` and `j - ku` below are
    # expected to go negative near the top-left corner of the matrix.
    var first = Int(block_dim.x * block_idx.x + thread_idx.x)
    var stride = Int(grid_dim.x * block_dim.x)

    if not trans:
        # y(i) = alpha * sum_j A(i,j)*x(j) + beta*y(i)
        for i in range(first, m, stride):
            var acc = Scalar[DType.float32](0)

            # Columns of row i that fall inside the band, clipped to [0, n-1].
            var j_start = max(i - kl, 0)
            var j_end = min(i + ku, n - 1)

            for j in range(j_start, j_end + 1):
                var band_col = kl + j - i
                acc += A[i * lda + band_col] * x[j * incx]

            # With beta == 0 the old contents of y must not be read, so
            # NaN/Inf in an uninitialised output cannot leak into the result.
            if beta == 0:
                y[i * incy] = alpha * acc
            else:
                y[i * incy] = alpha * acc + beta * y[i * incy]

    else:
        # y(j) = alpha * sum_i A(i,j)*x(i) + beta*y(j)
        for j in range(first, n, stride):
            var acc = Scalar[DType.float32](0)

            # Rows of column j that fall inside the band, clipped to [0, m-1].
            var i_start = max(j - ku, 0)
            var i_end = min(j + kl, m - 1)

            for i in range(i_start, i_end + 1):
                var band_col = kl + j - i
                acc += A[i * lda + band_col] * x[i * incx]

            # Same beta == 0 rule as in the non-transposed branch.
            if beta == 0:
                y[j * incy] = alpha * acc
            else:
                y[j * incy] = alpha * acc + beta * y[j * incy]


fn dgbmv_device(
    trans: Int,
    m: Int,
    n: Int,
    kl: Int,
    ku: Int,
    alpha: Float64,
    A: UnsafePointer[Float64, ImmutAnyOrigin],
    lda: Int,
    x: UnsafePointer[Float64, ImmutAnyOrigin],
    incx: Int,
    beta: Float64,
    y: UnsafePointer[Float64, MutAnyOrigin],
    incy: Int,
):
    """Double-precision band matrix-vector product kernel.

    Computes y := alpha*A*x + beta*y when `trans` is 0, and
    y := alpha*A**T*x + beta*y for any non-zero `trans`.

    Element A(i, j) of the band is read from `A[i * lda + (kl + j - i)]`,
    i.e. one row of the band per matrix row, `lda` elements apart.

    Args:
        trans: 0 for the plain product, non-zero for the transposed product.
        m: Number of rows of A.
        n: Number of columns of A.
        kl: Number of sub-diagonals of A.
        ku: Number of super-diagonals of A.
        alpha: Scale applied to the matrix-vector product.
        A: Band storage of the matrix (read only).
        lda: Distance, in elements, between consecutive band rows in `A`.
        x: Input vector (read only); element k is read at `x[k * incx]`.
        incx: Stride between consecutive elements of `x`.
        beta: Scale applied to the incoming contents of `y`.
        y: Output vector, updated in place; element k lives at `y[k * incy]`.
        incy: Stride between consecutive elements of `y`.
    """
    # Work in signed Int from here on: `i - kl` and `j - ku` below are
    # expected to go negative near the top-left corner of the matrix.
    var first = Int(block_dim.x * block_idx.x + thread_idx.x)
    var stride = Int(grid_dim.x * block_dim.x)

    if not trans:
        # y(i) = alpha * sum_j A(i,j)*x(j) + beta*y(i)
        for i in range(first, m, stride):
            var acc = Scalar[DType.float64](0)

            # Columns of row i that fall inside the band, clipped to [0, n-1].
            var j_start = max(i - kl, 0)
            var j_end = min(i + ku, n - 1)

            for j in range(j_start, j_end + 1):
                var band_col = kl + j - i
                acc += A[i * lda + band_col] * x[j * incx]

            # With beta == 0 the old contents of y must not be read, so
            # NaN/Inf in an uninitialised output cannot leak into the result.
            if beta == 0:
                y[i * incy] = alpha * acc
            else:
                y[i * incy] = alpha * acc + beta * y[i * incy]

    else:
        # y(j) = alpha * sum_i A(i,j)*x(i) + beta*y(j)
        for j in range(first, n, stride):
            var acc = Scalar[DType.float64](0)

            # Rows of column j that fall inside the band, clipped to [0, m-1].
            var i_start = max(j - ku, 0)
            var i_end = min(j + kl, m - 1)

            for i in range(i_start, i_end + 1):
                var band_col = kl + j - i
                acc += A[i * lda + band_col] * x[i * incx]

            # Same beta == 0 rule as in the non-transposed branch.
            if beta == 0:
                y[j * incy] = alpha * acc
            else:
                y[j * incy] = alpha * acc + beta * y[j * incy]


fn blas_gbmv[dtype: DType](
    trans: Bool,
    m: Int,
    n: Int,
    kl: Int,
    ku: Int,
    alpha: Scalar[dtype],
    d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
    lda: Int,
    d_x: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
    incx: Int,
    beta: Scalar[dtype],
    d_y: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    incy: Int,
    ctx: DeviceContext,
) raises:
    """Host entry point for the band matrix-vector product (GBMV).

    Validates the arguments, enqueues `sgbmv_device` (float32) or
    `dgbmv_device` (float64) on `ctx` with one thread per output element
    (rounded up to whole blocks of `TBsize` threads), then synchronizes `ctx`.

    Parameters:
        dtype: Element type; only float32 and float64 are dispatched.

    Args:
        trans: False for y := alpha*A*x + beta*y, True for the A**T form.
        m: Number of rows of A.
        n: Number of columns of A.
        kl: Number of sub-diagonals of A.
        ku: Number of super-diagonals of A.
        alpha: Scale applied to the matrix-vector product.
        d_A: Band storage of A, forwarded to the kernel.
        lda: Distance between consecutive band rows; at least kl + ku + 1.
        d_x: Input vector, forwarded to the kernel.
        incx: Stride of `d_x`; must be positive.
        beta: Scale applied to the incoming contents of `d_y`.
        d_y: Output vector, forwarded to the kernel and updated in place.
        incy: Stride of `d_y`; must be positive.
        ctx: Device context the kernel is enqueued on.

    Raises:
        Error: If an argument is out of range or `dtype` is unsupported.
    """
    # Reject arguments the kernels would turn into out-of-bounds indexing.
    if m < 0 or n < 0:
        raise Error("blas_gbmv: m and n must be non-negative")
    if kl < 0 or ku < 0:
        raise Error("blas_gbmv: kl and ku must be non-negative")
    if lda < kl + ku + 1:
        raise Error("blas_gbmv: lda must be at least kl + ku + 1")
    # The kernels index x and y as k * inc from the base pointer, so only
    # positive strides are meaningful here.
    if incx <= 0 or incy <= 0:
        raise Error("blas_gbmv: incx and incy must be positive")

    var out_len = n if trans else m
    var trans_i = 1 if trans else 0
    # An empty output would request a zero-sized grid, so the launch is
    # skipped in that case; the final synchronize still runs.
    var n_blocks = ceildiv(out_len, TBsize)

    @parameter
    if dtype == DType.float32:
        if n_blocks > 0:
            ctx.enqueue_function[sgbmv_device, sgbmv_device](
                trans_i, m, n, kl, ku,
                alpha, d_A, lda,
                d_x, incx,
                beta, d_y, incy,
                grid_dim=n_blocks,
                block_dim=TBsize,
            )
    elif dtype == DType.float64:
        if n_blocks > 0:
            ctx.enqueue_function[dgbmv_device, dgbmv_device](
                trans_i, m, n, kl, ku,
                alpha, d_A, lda,
                d_x, incx,
                beta, d_y, incy,
                grid_dim=n_blocks,
                block_dim=TBsize,
            )
    else:
        raise Error("blas_gbmv: Unsupported type")

    ctx.synchronize()
21 changes: 21 additions & 0 deletions src/testing_utils/testing_utils.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,24 @@ fn frobenius_norm_symmetric[dtype: DType](


return sqrt(sum)

fn dense_to_band[dtype: DType](
    A: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    B: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    m: Int,
    n: Int,
    kl: Int,
    ku: Int
):
    """Extract the band of a dense row-major matrix into band storage.

    For every row i, the kl + ku + 1 slots `B[i * (kl + ku + 1) + k]` are
    first cleared, then each in-band element A(i, j) is copied to slot
    k = j - i + kl. Elements of `A` that lie outside the band are
    overwritten with zero, so `A` is modified as well.

    Args:
        A: Dense m-by-n matrix, row-major; out-of-band entries are zeroed.
        B: Destination band storage with m rows of kl + ku + 1 elements.
        m: Number of rows of A.
        n: Number of columns of A.
        kl: Number of sub-diagonals kept.
        ku: Number of super-diagonals kept.
    """
    var width = kl + ku + 1

    for row in range(m):
        var band_base = row * width
        var dense_base = row * n

        # Start from an all-zero band row; slots with no matching column
        # (near the matrix edges) stay zero.
        for slot in range(width):
            B[band_base + slot] = 0.0

        for col in range(n):
            # Position of (row, col) inside the band row; it is in range
            # exactly when row - kl <= col <= row + ku.
            var slot = col - row + kl
            if slot < 0 or slot >= width:
                A[dense_base + col] = Scalar[dtype](0)
            else:
                B[band_base + slot] = A[dense_base + col]
116 changes: 116 additions & 0 deletions test-level2.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,112 @@ def syr2_test[

assert_true(passed)

def gbmv_test[
    dtype: DType,
    m: Int,
    n: Int,
    kl: Int,
    ku: Int,
    trans: Bool,
]():
    """Compare `blas_gbmv` with SciPy's dense gemv on a random band matrix.

    A random dense m-by-n matrix is generated and passed through
    `dense_to_band`, which fills the band storage and zeroes the dense
    entries outside the band. The band copy goes to the device for
    `blas_gbmv`; the masked dense copy is fed to SciPy `sgemv`/`dgemv` as the
    reference. The 2-norm of the difference is then judged by
    `check_gemm_error`.
    """
    comptime lda = kl + ku + 1
    comptime x_len = m if trans else n
    comptime y_len = n if trans else m

    with DeviceContext() as ctx:
        # Dense host copy: source for the band extraction and the reference.
        dense_h = ctx.enqueue_create_host_buffer[dtype](m * n)

        # Band storage on host and device.
        band_h = ctx.enqueue_create_host_buffer[dtype](lda * m)
        band_d = ctx.enqueue_create_buffer[dtype](lda * m)

        x_h = ctx.enqueue_create_host_buffer[dtype](x_len)
        x_d = ctx.enqueue_create_buffer[dtype](x_len)

        y_h = ctx.enqueue_create_host_buffer[dtype](y_len)
        y_d = ctx.enqueue_create_buffer[dtype](y_len)

        # Random inputs, drawn in the order A, x, y (then alpha, beta below).
        generate_random_arr[dtype, m * n](dense_h.unsafe_ptr(), -1, 1)
        generate_random_arr[dtype, x_len](x_h.unsafe_ptr(), -100, 100)
        generate_random_arr[dtype, y_len](y_h.unsafe_ptr(), -100, 100)

        dense_to_band(dense_h.unsafe_ptr(), band_h.unsafe_ptr(), m, n, kl, ku)

        ctx.enqueue_copy(band_d, band_h)
        ctx.enqueue_copy(x_d, x_h)
        ctx.enqueue_copy(y_d, y_h)
        ctx.synchronize()

        var alpha = generate_random_scalar[dtype](-100, 100)
        var beta = generate_random_scalar[dtype](-100, 100)

        # Norms of the inputs, taken before the kernel overwrites y on device.
        var norm_A = frobenius_norm[dtype](dense_h.unsafe_ptr(), m * n)
        var norm_x = frobenius_norm[dtype](x_h.unsafe_ptr(), x_len)
        var norm_y = frobenius_norm[dtype](y_h.unsafe_ptr(), y_len)

        blas_gbmv[dtype](
            trans,
            m, n,
            kl, ku,
            alpha,
            band_d.unsafe_ptr(), lda,
            x_d.unsafe_ptr(), 1,
            beta,
            y_d.unsafe_ptr(), 1,
            ctx,
        )

        # Python reference
        sp = Python.import_module("scipy")
        np = Python.import_module("numpy")
        sp_blas = sp.linalg.blas

        py_A = Python.list()
        for k in range(m * n):
            py_A.append(dense_h[k])

        py_x = Python.list()
        for k in range(x_len):
            py_x.append(x_h[k])

        py_y = Python.list()
        for k in range(y_len):
            py_y.append(y_h[k])

        # Pick the NumPy element type and the matching SciPy routine once,
        # so the reference call below is written a single time.
        var np_dtype: PythonObject
        var gemv: PythonObject

        if dtype == DType.float32:
            np_dtype = np.float32
            gemv = sp_blas.sgemv
        elif dtype == DType.float64:
            np_dtype = np.float64
            gemv = sp_blas.dgemv
        else:
            print("Unsupported type")
            return

        np_A = np.array(py_A, dtype=np_dtype).reshape(m, n)
        np_x = np.array(py_x, dtype=np_dtype)
        np_y = np.array(py_y, dtype=np_dtype)

        var trans_flag = 1 if trans else 0
        var sp_res: PythonObject = gemv(
            alpha, np_A, np_x, beta=beta, y=np_y, trans=trans_flag
        )

        with y_d.map_to_host() as res_mojo:
            # Accumulate the squared error against the reference result.
            var sq_err = Scalar[dtype](0)
            for k in range(y_len):
                var delta = res_mojo[k] - Scalar[dtype](py=sp_res[k])
                sq_err += delta * delta
            var norm_diff = sqrt(sq_err)

            var ok = check_gemm_error[dtype](
                1, y_len, x_len,
                alpha, beta,
                norm_A, norm_x, norm_y,
                norm_diff
            )
            assert_true(ok)

def test_gemv():
gemv_test[DType.float32, 64, 64, False]()
Expand Down Expand Up @@ -349,6 +455,16 @@ def test_syr2():
syr2_test[DType.float64, 512, 0]()
syr2_test[DType.float64, 512, 1]()

def test_gbmv():
    # Runs gbmv_test over float32 and float64, with and without transpose.
    # Parameter order is [dtype, m, n, kl, ku, trans].

    # Square 64 x 64 matrices.
    gbmv_test[DType.float32, 64, 64, 1, 2, False]()
    gbmv_test[DType.float32, 64, 64, 2, 2, True]()
    gbmv_test[DType.float64, 64, 64, 1, 2, False]()
    gbmv_test[DType.float64, 64, 64, 2, 2, True]()
    # Tall 512 x 64 matrices (more rows than columns).
    gbmv_test[DType.float32, 512, 64, 1, 2, False]()
    gbmv_test[DType.float32, 512, 64, 2, 2, True]()
    gbmv_test[DType.float64, 512, 64, 1, 2, False]()
    gbmv_test[DType.float64, 512, 64, 2, 2, True]()

def main():
    # Entry point: print a banner, then hand this module's functions to
    # TestSuite for discovery and run the resulting suite.
    print("--- MojoBLAS Level 2 routines testing ---")
    TestSuite.discover_tests[__functions_in_module()]().run()