Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/level2/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ from .gbmv_device import *
from .sbmv_device import *
from .trsv_device import *
from .symv_device import *
from .tpmv_device import *
150 changes: 150 additions & 0 deletions src/level2/tpmv_device.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from gpu import thread_idx, block_idx, block_dim, grid_dim
from gpu.host import DeviceContext
from math import ceildiv

# Threads per block for the tpmv kernel launches below; the launch grid is
# sized as ceildiv(n, TBsize) blocks.
comptime TBsize = 512

# level2.tpmv
# Performs matrix-vector multiplication of form
#     x := A*x,   or   x := A**T*x
# where x is an n element vector and A is an n by n unit or non-unit, upper or
# lower triangular matrix, supplied in packed form.
fn stpmv_device(
    uplo: Int,
    trans: Int,
    diag: Int,
    n: Int,
    AP: UnsafePointer[Float32, ImmutAnyOrigin],
    x: UnsafePointer[Float32, ImmutAnyOrigin],
    incx: Int,
    workspace: UnsafePointer[Float32, MutAnyOrigin],
):
    # Single-precision TPMV kernel: stores op(A) * x into `workspace`, one
    # output element per loop iteration.
    #
    #   uplo  != 0 : AP holds the upper triangle column by column, so A(r, c)
    #                with r <= c is AP[c * (c + 1) / 2 + r]
    #   uplo  == 0 : AP holds the lower triangle column by column, so A(r, c)
    #                with r >= c is AP[c * n - c * (c - 1) / 2 + (r - c)]
    #   trans != 0 : op(A) = A**T, otherwise op(A) = A
    #   diag  != 0 : unit diagonal; the stored diagonal entries are never read
    #   incx       : stride between consecutive elements of x and workspace
    #
    # x is only read and workspace is only written, so no output element
    # depends on another thread's result.
    var global_i = block_dim.x * block_idx.x + thread_idx.x
    var n_threads = grid_dim.x * block_dim.x

    # Each thread starts at its global index and advances by the total number
    # of launched threads, so any launch size covers all n outputs.
    for i in range(global_i, n, n_threads):
        # Diagonal contribution: x[i] for a unit diagonal, A(i, i) * x[i]
        # otherwise.
        var sum = x[i * incx]
        if trans:
            # (A**T * x)[i] = sum over r of A(r, i) * x[r]: walk column i,
            # which is contiguous in packed storage.
            if uplo:
                # Upper: column i holds rows 0..i and ends on the diagonal.
                var index = (i * (i + 1)) // 2 + i
                if not diag:
                    sum *= AP[index]
                index -= i  # rewind to A(0, i), the start of column i
                for j in range(0, i):
                    sum += AP[index] * x[j * incx]
                    index += 1
            else:
                # Lower: column i starts on the diagonal and holds rows i..n-1.
                var index = i * n - ((i - 1) * i) // 2
                if not diag:
                    sum *= AP[index]
                for j in range(i + 1, n):
                    index += 1
                    sum += AP[index] * x[j * incx]
        else:
            # (A * x)[i] = sum over c of A(i, c) * x[c]: walk row i, hopping
            # from one packed column to the next.
            if uplo:
                # Upper: A(i, c) for c = i..n-1. Column c holds c + 1 entries,
                # so stepping from A(i, c - 1) to A(i, c) advances by c.
                var index = (i * (i + 1)) // 2 + i
                if not diag:
                    sum *= AP[index]
                for j in range(i + 1, n):
                    index += j
                    sum += AP[index] * x[j * incx]
            else:
                # Lower: A(i, c) for c = 0..i.
                if not diag:
                    sum *= AP[i * n - ((i - 1) * i) // 2]
                var index = i  # A(i, 0)
                for j in range(0, i):
                    sum += AP[index] * x[j * incx]
                    # Column j holds n - j entries and row i sits one slot
                    # closer to the top of column j + 1, hence n - j - 1.
                    index += n - j - 1
        workspace[i * incx] = sum


fn dtpmv_device(
    uplo: Int,
    trans: Int,
    diag: Int,
    n: Int,
    AP: UnsafePointer[Float64, ImmutAnyOrigin],
    x: UnsafePointer[Float64, ImmutAnyOrigin],
    incx: Int,
    workspace: UnsafePointer[Float64, MutAnyOrigin],
):
    # Double-precision TPMV kernel: stores op(A) * x into `workspace`, one
    # output element per loop iteration.
    #
    #   uplo  != 0 : AP holds the upper triangle column by column, so A(r, c)
    #                with r <= c is AP[c * (c + 1) / 2 + r]
    #   uplo  == 0 : AP holds the lower triangle column by column, so A(r, c)
    #                with r >= c is AP[c * n - c * (c - 1) / 2 + (r - c)]
    #   trans != 0 : op(A) = A**T, otherwise op(A) = A
    #   diag  != 0 : unit diagonal; the stored diagonal entries are never read
    #   incx       : stride between consecutive elements of x and workspace
    #
    # x is only read and workspace is only written, so no output element
    # depends on another thread's result.
    var global_i = block_dim.x * block_idx.x + thread_idx.x
    var n_threads = grid_dim.x * block_dim.x

    # Each thread starts at its global index and advances by the total number
    # of launched threads, so any launch size covers all n outputs.
    for i in range(global_i, n, n_threads):
        # Diagonal contribution: x[i] for a unit diagonal, A(i, i) * x[i]
        # otherwise.
        var sum = x[i * incx]
        if trans:
            # (A**T * x)[i] = sum over r of A(r, i) * x[r]: walk column i,
            # which is contiguous in packed storage.
            if uplo:
                # Upper: column i holds rows 0..i and ends on the diagonal.
                var index = (i * (i + 1)) // 2 + i
                if not diag:
                    sum *= AP[index]
                index -= i  # rewind to A(0, i), the start of column i
                for j in range(0, i):
                    sum += AP[index] * x[j * incx]
                    index += 1
            else:
                # Lower: column i starts on the diagonal and holds rows i..n-1.
                var index = i * n - ((i - 1) * i) // 2
                if not diag:
                    sum *= AP[index]
                for j in range(i + 1, n):
                    index += 1
                    sum += AP[index] * x[j * incx]
        else:
            # (A * x)[i] = sum over c of A(i, c) * x[c]: walk row i, hopping
            # from one packed column to the next.
            if uplo:
                # Upper: A(i, c) for c = i..n-1. Column c holds c + 1 entries,
                # so stepping from A(i, c - 1) to A(i, c) advances by c.
                var index = (i * (i + 1)) // 2 + i
                if not diag:
                    sum *= AP[index]
                for j in range(i + 1, n):
                    index += j
                    sum += AP[index] * x[j * incx]
            else:
                # Lower: A(i, c) for c = 0..i.
                if not diag:
                    sum *= AP[i * n - ((i - 1) * i) // 2]
                var index = i  # A(i, 0)
                for j in range(0, i):
                    sum += AP[index] * x[j * incx]
                    # Column j holds n - j entries and row i sits one slot
                    # closer to the top of column j + 1, hence n - j - 1.
                    index += n - j - 1
        workspace[i * incx] = sum


fn blas_tpmv[dtype: DType](
    uplo: Int,
    trans: Bool,
    diag: Int,
    n: Int,
    d_AP: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
    d_x: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    incx: Int,
    ctx: DeviceContext,
) raises:
    # Host-side driver for TPMV: overwrites the strided vector at d_x with
    # op(A) * x, where A is the packed triangular matrix at d_AP.
    #
    #   uplo, diag : forwarded unchanged to the device kernel
    #   trans      : True selects the transposed product
    #   n          : order of the matrix / number of vector elements
    #   d_AP, d_x  : pointers handed to the kernel; d_x must span
    #                (n - 1) * incx + 1 elements
    #   incx       : stride between consecutive vector elements (must be > 0)
    #
    # Raises on a negative n, a non-positive incx, or a dtype other than
    # float32 / float64.
    if n < 0:
        raise Error("blas_tpmv: n must be non-negative")
    if incx <= 0:
        # The kernels index x[i * incx] with i >= 0, so a zero or negative
        # stride cannot be honoured.
        raise Error("blas_tpmv: incx must be positive")
    if n == 0:
        # Nothing to compute; also avoids a launch with an empty grid.
        return

    var trans_i = 1 if trans else 0

    # The kernel reads x and writes workspace[i * incx]; the workspace is then
    # copied back over x, so it has to span the same strided extent as x.
    var span = (n - 1) * incx + 1
    var workspace = ctx.enqueue_create_buffer[dtype](span)
    if incx != 1:
        # Seed the slots between strided elements with the current contents
        # of x so that the copy-back below leaves them unchanged.
        ctx.enqueue_copy(workspace, d_x)

    @parameter
    if dtype == DType.float32:
        ctx.enqueue_function[stpmv_device, stpmv_device](
            uplo, trans_i, diag,
            n, d_AP,
            d_x, incx,
            workspace,
            grid_dim=ceildiv(n, TBsize),
            block_dim=TBsize,
        )
    elif dtype == DType.float64:
        ctx.enqueue_function[dtpmv_device, dtpmv_device](
            uplo, trans_i, diag,
            n, d_AP,
            d_x, incx,
            workspace,
            grid_dim=ceildiv(n, TBsize),
            block_dim=TBsize,
        )
    else:
        raise Error("blas_tpmv: Unsupported type")

    # Publish the result into x, then wait so the workspace is not released
    # while the enqueued work is still using it.
    ctx.enqueue_copy(d_x, workspace)

    ctx.synchronize()
21 changes: 21 additions & 0 deletions src/testing_utils/testing_utils.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,24 @@ fn dense_to_sym_band_cm[dtype: DType](
val = A_band[(j - i) + i * lda]

A_dense[i * n + j] = val

fn dense_to_tri_band[dtype: DType](
    A_dense: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    A_band: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    n: Int,
    uplo: Int,
):
    # Packs one triangle of the n x n matrix A_dense (element (i, j) read from
    # A_dense[i * n + j]) into A_band, column by column.
    #
    #   uplo != 0 : rows 0..j of every column j (upper triangle)
    #   uplo == 0 : rows j..n-1 of every column j (lower triangle)
    #
    # Exactly n * (n + 1) / 2 entries of A_band are written; anything beyond
    # that in the destination buffer is left untouched.
    # Despite the name, the output is packed storage, not band storage.
    var index = 0
    for j in range(n):
        if uplo:
            for i in range(0, j + 1):
                A_band[index] = A_dense[i * n + j]
                index += 1
        else:
            for i in range(j, n):
                A_band[index] = A_dense[i * n + j]
                index += 1
130 changes: 130 additions & 0 deletions test-level2.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,118 @@ def trsv_test[
)
assert_true(ok)


def tpmv_test[
    dtype: DType,
    n: Int,
    trans: Bool,
    uplo: Int,
    diag: Int,
]():
    # Runs blas_tpmv on a random n x n triangular matrix and compares the
    # result against SciPy's ?tpmv applied to the same packed array.
    #
    #   uplo != 0 : upper triangle is packed (SciPy lower=0)
    #   uplo == 0 : lower triangle is packed (SciPy lower=1)
    #   trans     : transposed product when True
    #   diag != 0 : unit diagonal
    with DeviceContext() as ctx:
        # Packed storage holds exactly one triangle of the matrix.
        var n_packed = (n * (n + 1)) // 2

        A_dense = ctx.enqueue_create_host_buffer[dtype](n * n)

        AP_d = ctx.enqueue_create_buffer[dtype](n_packed)
        AP = ctx.enqueue_create_host_buffer[dtype](n_packed)
        x_d = ctx.enqueue_create_buffer[dtype](n)
        x = ctx.enqueue_create_host_buffer[dtype](n)

        generate_random_arr[dtype](n * n, A_dense.unsafe_ptr(), -1, 1)
        generate_random_arr[dtype](n, x.unsafe_ptr(), -1, 1)

        # Zero the triangle that is NOT packed, so that A_dense (used for the
        # norm below) describes the same matrix as AP. dense_to_tri_band packs
        # entries with i <= j when uplo is non-zero and i >= j otherwise.
        for i in range(n):
            for j in range(n):
                # upper triangular: clear everything below the diagonal
                if uplo != 0:
                    if i > j:
                        A_dense[i * n + j] = 0
                # lower triangular: clear everything above the diagonal
                else:
                    if i < j:
                        A_dense[i * n + j] = 0

        # Handle diagonal based on diag parameter
        for i in range(n):
            # unit diagonal
            if diag == 1:
                A_dense[i * n + i] = 1
            # non-unit diagonal: shifted well away from zero
            else:
                A_dense[i * n + i] += 1000

        dense_to_tri_band(A_dense.unsafe_ptr(), AP.unsafe_ptr(), n, uplo)

        ctx.enqueue_copy(AP_d, AP)
        ctx.enqueue_copy(x_d, x)
        ctx.synchronize()

        # Compute norms for error checks
        var norm_A = frobenius_norm[dtype](A_dense.unsafe_ptr(), n * n)
        var norm_x = frobenius_norm[dtype](x.unsafe_ptr(), n)

        blas_tpmv[dtype](
            uplo,
            trans,
            diag,
            n,
            AP_d.unsafe_ptr(),
            x_d.unsafe_ptr(), 1,
            ctx,
        )

        # Import SciPy and numpy
        sp = Python.import_module("scipy")
        np = Python.import_module("numpy")
        sp_blas = sp.linalg.blas

        # Hand SciPy only the packed entries that were actually written.
        py_A = Python.list()
        py_x = Python.list()
        for i in range(n_packed):
            py_A.append(AP[i])
        for i in range(n):
            py_x.append(x[i])

        var sp_res: PythonObject

        if dtype == DType.float32:
            np_A = np.array(py_A, dtype=np.float32)
            np_x = np.array(py_x, dtype=np.float32)
            sp_res = sp_blas.stpmv(n, np_A, np_x,
                lower=0 if uplo else 1,
                trans=1 if trans else 0,
                diag=1 if diag else 0
            )
        elif dtype == DType.float64:
            np_A = np.array(py_A, dtype=np.float64)
            np_x = np.array(py_x, dtype=np.float64)
            sp_res = sp_blas.dtpmv(n, np_A, np_x,
                lower=0 if uplo else 1,
                trans=1 if trans else 0,
                diag=1 if diag else 0
            )
        else:
            print("Unsupported type: ", dtype)
            return

        with x_d.map_to_host() as res_mojo:

            # 2-norm of the difference between the device and SciPy results.
            var norm_diff = Scalar[dtype](0)
            for i in range(n):
                var diff = res_mojo[i] - Scalar[dtype](py=sp_res[i])
                norm_diff += diff * diff
            norm_diff = sqrt(norm_diff)
            print(norm_diff)

            var ok = check_gemm_error[dtype](
                1, n, n,
                Scalar[dtype](1), Scalar[dtype](0),
                norm_A, norm_x, Scalar[dtype](0),
                norm_diff
            )
            assert_true(ok)

def test_gemv():
gemv_test[DType.float32, 64, 64, False]()
gemv_test[DType.float32, 64, 64, True]()
Expand Down Expand Up @@ -848,6 +960,24 @@ def test_trsv():
trsv_test[DType.float64, 64, False, 0, 1]()
trsv_test[DType.float64, 64, False, 1, 1]()

def _tpmv_uplo_diag_sweep[dtype: DType, trans: Bool]():
    # Runs the four (uplo, diag) combinations for one dtype / trans setting
    # on a 64 x 64 problem.
    tpmv_test[dtype, 64, trans, 0, 0]()
    tpmv_test[dtype, 64, trans, 1, 0]()
    tpmv_test[dtype, 64, trans, 0, 1]()
    tpmv_test[dtype, 64, trans, 1, 1]()


def test_tpmv():
    # Every trans / uplo / diag combination, first in single and then in
    # double precision.
    _tpmv_uplo_diag_sweep[DType.float32, True]()
    _tpmv_uplo_diag_sweep[DType.float32, False]()
    _tpmv_uplo_diag_sweep[DType.float64, True]()
    _tpmv_uplo_diag_sweep[DType.float64, False]()

def main():
print("--- MojoBLAS Level 2 routines testing ---")
var args = argv()
Expand Down
Loading