Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/level2/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ from .syr2_device import *
from .gbmv_device import *
from .sbmv_device import *
from .trsv_device import *
from .symv_device import *
116 changes: 116 additions & 0 deletions src/level2/symv_device.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from gpu import thread_idx, block_idx, block_dim, grid_dim
from gpu.host import DeviceContext
from math import ceildiv

# Threads per block for the symv launches below: it is passed as block_dim and
# is the divisor used to compute the number of blocks from n.
comptime TBsize = 512

# level2.symv (single precision)
# Device kernel performing the matrix-vector multiplication
#     y := alpha*A*x + beta*y
# where alpha and beta are scalars, x and y are vectors and A is an n by n
# symmetric matrix of which only one triangle is read.
#
# Parameters:
#   upper  nonzero: only elements A[r * lda + c] with c >= r are read;
#          zero:    only elements A[r * lda + c] with c <= r are read.
#          The other triangle is obtained by swapping r and c.
#   n      order of A; number of elements of x and y that are used.
#   alpha  scalar multiplying A*x.
#   A      matrix storage; element (r, c) is addressed as A[r * lda + c].
#   lda    distance, in elements, between consecutive rows of A.
#   x      input vector; element j is read from x[j * incx].
#   incx   element stride of x.
#   beta   scalar multiplying y. When it is zero, y is only written and
#          never read.
#   y      input/output vector; element i lives at y[i * incy].
#   incy   element stride of y.
fn ssymv_device(
    upper: Int,
    n: Int,
    alpha: Float32,
    A: UnsafePointer[Float32, ImmutAnyOrigin],
    lda: Int,
    x: UnsafePointer[Float32, ImmutAnyOrigin],
    incx: Int,
    beta: Float32,
    y: UnsafePointer[Float32, MutAnyOrigin],
    incy: Int,
):
    # Grid-stride loop: a thread starts at its global index and advances by
    # the total number of launched threads, so every row i < n is processed
    # whatever the grid size is.
    var global_i = block_dim.x * block_idx.x + thread_idx.x
    var n_threads = grid_dim.x * block_dim.x

    for i in range(global_i, n, n_threads):
        # Dot product of row i of the full symmetric matrix with x.
        var sum = Scalar[DType.float32](0)
        for j in range(n):
            var val: Float32
            if upper:
                # Stored triangle is c >= r: mirror (i, j) when j < i.
                if j >= i:
                    val = A[i * lda + j]
                else:
                    val = A[j * lda + i]
            else:
                # Stored triangle is c <= r: mirror (i, j) when j > i.
                if j <= i:
                    val = A[i * lda + j]
                else:
                    val = A[j * lda + i]
            sum += val * x[j * incx]

        # BLAS convention: with beta == 0 the output is write-only, so a
        # NaN/Inf left in an uninitialised y cannot leak into the result
        # through 0 * y.
        if beta == Scalar[DType.float32](0):
            y[i * incy] = alpha * sum
        else:
            y[i * incy] = alpha * sum + beta * y[i * incy]


# level2.symv (double precision)
# Device kernel performing the matrix-vector multiplication
#     y := alpha*A*x + beta*y
# where alpha and beta are scalars, x and y are vectors and A is an n by n
# symmetric matrix of which only one triangle is read.
#
# Parameters:
#   upper  nonzero: only elements A[r * lda + c] with c >= r are read;
#          zero:    only elements A[r * lda + c] with c <= r are read.
#          The other triangle is obtained by swapping r and c.
#   n      order of A; number of elements of x and y that are used.
#   alpha  scalar multiplying A*x.
#   A      matrix storage; element (r, c) is addressed as A[r * lda + c].
#   lda    distance, in elements, between consecutive rows of A.
#   x      input vector; element j is read from x[j * incx].
#   incx   element stride of x.
#   beta   scalar multiplying y. When it is zero, y is only written and
#          never read.
#   y      input/output vector; element i lives at y[i * incy].
#   incy   element stride of y.
fn dsymv_device(
    upper: Int,
    n: Int,
    alpha: Float64,
    A: UnsafePointer[Float64, ImmutAnyOrigin],
    lda: Int,
    x: UnsafePointer[Float64, ImmutAnyOrigin],
    incx: Int,
    beta: Float64,
    y: UnsafePointer[Float64, MutAnyOrigin],
    incy: Int,
):
    # Grid-stride loop: a thread starts at its global index and advances by
    # the total number of launched threads, so every row i < n is processed
    # whatever the grid size is.
    var global_i = block_dim.x * block_idx.x + thread_idx.x
    var n_threads = grid_dim.x * block_dim.x

    for i in range(global_i, n, n_threads):
        # Dot product of row i of the full symmetric matrix with x.
        var sum = Scalar[DType.float64](0)
        for j in range(n):
            var val: Float64
            if upper:
                # Stored triangle is c >= r: mirror (i, j) when j < i.
                if j >= i:
                    val = A[i * lda + j]
                else:
                    val = A[j * lda + i]
            else:
                # Stored triangle is c <= r: mirror (i, j) when j > i.
                if j <= i:
                    val = A[i * lda + j]
                else:
                    val = A[j * lda + i]
            sum += val * x[j * incx]

        # BLAS convention: with beta == 0 the output is write-only, so a
        # NaN/Inf left in an uninitialised y cannot leak into the result
        # through 0 * y.
        if beta == Scalar[DType.float64](0):
            y[i * incy] = alpha * sum
        else:
            y[i * incy] = alpha * sum + beta * y[i * incy]


# Host-side entry point for level2.symv: y := alpha*A*x + beta*y with A an
# n by n symmetric matrix. Enqueues the float32 or float64 kernel on `ctx`
# and waits for it to finish.
#
# Parameters:
#   dtype  element type; only DType.float32 and DType.float64 are supported.
#
# Args:
#   upper  True to read the triangle with column >= row, False to read the
#          triangle with column <= row.
#   n      order of A; must be >= 0. Nothing is launched when n == 0.
#   alpha  scalar multiplying A*x.
#   d_A    device pointer to A; element (r, c) is addressed as r * lda + c.
#   lda    distance, in elements, between consecutive rows; >= max(1, n).
#   d_x    device pointer to x.
#   incx   element stride of x; must not be zero.
#   beta   scalar multiplying y.
#   d_y    device pointer to y, overwritten with the result.
#   incy   element stride of y; must not be zero.
#   ctx    device context used to enqueue the kernel and synchronize.
#
# Raises:
#   Error if dtype is unsupported, if an argument fails the checks above,
#   or if enqueueing / synchronizing fails.
fn blas_symv[dtype: DType](
    upper: Bool,
    n: Int,
    alpha: Scalar[dtype],
    d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
    lda: Int,
    d_x: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
    incx: Int,
    beta: Scalar[dtype],
    d_y: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    incy: Int,
    ctx: DeviceContext,
) raises:
    # Reject unsupported element types first so this error does not depend on
    # the runtime arguments.
    @parameter
    if dtype != DType.float32 and dtype != DType.float64:
        raise Error("blas_symv: Unsupported type")

    # Argument checks mirroring the reference BLAS xSYMV contract.
    if n < 0:
        raise Error("blas_symv: n must be non-negative")
    if lda < 1 or lda < n:
        raise Error("blas_symv: lda must be at least max(1, n)")
    if incx == 0:
        raise Error("blas_symv: incx must not be zero")
    if incy == 0:
        raise Error("blas_symv: incy must not be zero")

    # Quick return: ceildiv(0, TBsize) is 0 and a grid with zero blocks is not
    # a valid launch configuration, so skip the launch entirely.
    if n == 0:
        return

    # The kernels take the triangle selector as an Int flag.
    var upper_i = 1 if upper else 0

    @parameter
    if dtype == DType.float32:
        ctx.enqueue_function[ssymv_device, ssymv_device](
            upper_i, n,
            alpha, d_A, lda,
            d_x, incx,
            beta, d_y, incy,
            grid_dim=ceildiv(n, TBsize),
            block_dim=TBsize,
        )
    elif dtype == DType.float64:
        ctx.enqueue_function[dsymv_device, dsymv_device](
            upper_i, n,
            alpha, d_A, lda,
            d_x, incx,
            beta, d_y, incy,
            grid_dim=ceildiv(n, TBsize),
            block_dim=TBsize,
        )

    # Block until the kernel has completed so d_y is ready on return.
    ctx.synchronize()
101 changes: 99 additions & 2 deletions test-level2.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,92 @@ def sbmv_test[
)
assert_true(ok)

def symv_test[
    dtype: DType,
    n: Int,
    upper: Bool,
]():
    # Checks blas_symv against SciPy's ssymv/dsymv on a random symmetric
    # n x n matrix with unit strides, for the triangle selected by `upper`.
    with DeviceContext() as ctx:
        mat_dev = ctx.enqueue_create_buffer[dtype](n * n)
        mat = ctx.enqueue_create_host_buffer[dtype](n * n)
        vx_dev = ctx.enqueue_create_buffer[dtype](n)
        vx = ctx.enqueue_create_host_buffer[dtype](n)
        vy_dev = ctx.enqueue_create_buffer[dtype](n)
        vy = ctx.enqueue_create_host_buffer[dtype](n)

        generate_random_arr[dtype](n * n, mat.unsafe_ptr(), -1, 1)

        # Make the matrix symmetric by replacing each mirrored pair with its
        # average: A = 0.5 * (A + A^T)
        for r in range(n):
            for c in range(r, n):
                var above = r * n + c
                var below = c * n + r
                var avg = (mat[above] + mat[below]) * Scalar[dtype](0.5)
                mat[above] = avg
                mat[below] = avg

        generate_random_arr[dtype](n, vx.unsafe_ptr(), -100, 100)
        generate_random_arr[dtype](n, vy.unsafe_ptr(), -100, 100)

        var alpha = generate_random_scalar[dtype](-100, 100)
        var beta = generate_random_scalar[dtype](-100, 100)

        # Input norms are handed to the tolerance check at the end.
        var norm_A = frobenius_norm[dtype](mat.unsafe_ptr(), n * n)
        var norm_x = frobenius_norm[dtype](vx.unsafe_ptr(), n)
        var norm_y = frobenius_norm[dtype](vy.unsafe_ptr(), n)

        ctx.enqueue_copy(mat_dev, mat)
        ctx.enqueue_copy(vx_dev, vx)
        ctx.enqueue_copy(vy_dev, vy)
        ctx.synchronize()

        blas_symv[dtype](
            upper,
            n,
            alpha,
            mat_dev.unsafe_ptr(), n,
            vx_dev.unsafe_ptr(), 1,
            beta,
            vy_dev.unsafe_ptr(), 1,
            ctx,
        )

        scipy = Python.import_module("scipy")
        numpy = Python.import_module("numpy")
        ref_blas = scipy.linalg.blas

        # Mirror the host inputs into Python lists for the reference call.
        # The host copy of y still holds the values from before the kernel.
        py_mat = Python.list()
        py_vx = Python.list()
        py_vy = Python.list()
        for k in range(n * n):
            py_mat.append(mat[k])
        for k in range(n):
            py_vx.append(vx[k])
            py_vy.append(vy[k])

        # Pick the NumPy element type and SciPy routine matching dtype.
        var np_dtype: PythonObject
        var ref_symv: PythonObject

        if dtype == DType.float32:
            np_dtype = numpy.float32
            ref_symv = ref_blas.ssymv
        elif dtype == DType.float64:
            np_dtype = numpy.float64
            ref_symv = ref_blas.dsymv
        else:
            print("Unsupported type: ", dtype)
            return

        ref_mat = numpy.array(py_mat, dtype=np_dtype).reshape(n, n)
        ref_vx = numpy.array(py_vx, dtype=np_dtype)
        ref_vy = numpy.array(py_vy, dtype=np_dtype)
        var expected: PythonObject = ref_symv(alpha, ref_mat, ref_vx, beta=beta, y=ref_vy, lower=0 if upper else 1)

        with vy_dev.map_to_host() as gpu_res:
            # Euclidean norm of (device result - reference result).
            var sq_err = Scalar[dtype](0)
            for k in range(n):
                var delta = gpu_res[k] - Scalar[dtype](py=expected[k])
                sq_err += delta * delta
            var err_norm = sqrt(sq_err)
            var ok = check_gemm_error[dtype](1, n, n, alpha, beta, norm_A, norm_x, norm_y, err_norm)
            assert_true(ok)

def trsv_test[
dtype: DType,
n: Int,
Expand Down Expand Up @@ -734,6 +820,16 @@ def test_sbmv():
sbmv_test[DType.float64, 512, 32, False]()
sbmv_test[DType.float64, 512, 32, True]()

def test_symv():
    # Runs symv_test for float32 and float64 and both triangle selections,
    # first at n = 64 and then at n = 1024.
    symv_test[dtype=DType.float32, n=64, upper=True]()
    symv_test[dtype=DType.float32, n=64, upper=False]()
    symv_test[dtype=DType.float64, n=64, upper=True]()
    symv_test[dtype=DType.float64, n=64, upper=False]()

    symv_test[dtype=DType.float32, n=1024, upper=True]()
    symv_test[dtype=DType.float32, n=1024, upper=False]()
    symv_test[dtype=DType.float64, n=1024, upper=True]()
    symv_test[dtype=DType.float64, n=1024, upper=False]()

def test_trsv():
trsv_test[DType.float32, 64, True, 0, 0]()
trsv_test[DType.float32, 64, True, 1, 0]()
Expand Down Expand Up @@ -761,13 +857,14 @@ def main():

var suite = TestSuite(cli_args=List[StaticString]())
for i in range(1, len(args)):
if args[i] == "gemv": suite.test[test_gemv]()
if args[i] == "gemv": suite.test[test_gemv]()
elif args[i] == "ger": suite.test[test_ger]()
elif args[i] == "sbmv": suite.test[test_sbmv]()
elif args[i] == "sbmv": suite.test[test_sbmv]()
elif args[i] == "syr": suite.test[test_syr]()
elif args[i] == "syr2": suite.test[test_syr2]()
elif args[i] == "gbmv": suite.test[test_gbmv]()
elif args[i] == "trsv": suite.test[test_trsv]()
elif args[i] == "symv": suite.test[test_symv]()
else: print("unknown routine:", args[i])
suite^.run()

Loading