Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ from .util import *
from .testing_utils import *
from .level1 import *
from .level2 import *
# from .level3 import *
from .level3 import *
1 change: 1 addition & 0 deletions src/level3/__init__.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .gemm_device import*
155 changes: 155 additions & 0 deletions src/level3/gemm_device.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from gpu import thread_idx, block_idx, block_dim, grid_dim
from gpu.host import DeviceContext
from math import ceildiv

comptime TBsize = 512
comptime TBx = 32
comptime TBy = 16
fn sgemm_device(
trans_a: Int, trans_b: Int,
m: Int,
n: Int,
k: Int,
alpha: Float32,
A: UnsafePointer[Float32, ImmutAnyOrigin],
lda: Int,
B: UnsafePointer[Float32, ImmutAnyOrigin],
ldb: Int,
beta: Float32,
C: UnsafePointer[Float32, MutAnyOrigin],
ldc: Int,
) :
var global_row = block_dim.y * block_idx.y + thread_idx.y
var global_col = block_dim.x * block_idx.x + thread_idx.x
var n_threads_row = grid_dim.y * block_dim.y
var n_threads_col = grid_dim.x * block_dim.x

for i in range(global_row, m, n_threads_row) :
for j in range(global_col, n, n_threads_col) :
var sum = Scalar[DType.float32](0)
if trans_a and trans_b :
for kk in range(k) :
sum += A[kk * lda + i] * B[j * ldb + kk]
elif trans_a :
for kk in range(k) :
sum += A[kk * lda + i] * B[kk * ldb + j]
elif trans_b :
for kk in range(k) :
sum += A[i * lda + kk] * B[j * ldb + kk]
else :
for kk in range(k) :
sum += A[i * lda + kk] * B[kk * ldb + j]
C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j]


fn dgemm_device(
trans_a: Int, trans_b: Int,
m: Int,
n: Int,
k: Int,
alpha: Float64,
A: UnsafePointer[Float64, ImmutAnyOrigin],
lda: Int,
B: UnsafePointer[Float64, ImmutAnyOrigin],
ldb: Int,
beta: Float64,
C: UnsafePointer[Float64, MutAnyOrigin],
ldc: Int,
) :
var global_row = block_dim.y * block_idx.y + thread_idx.y
var global_col = block_dim.x * block_idx.x + thread_idx.x
var n_threads_row = grid_dim.y * block_dim.y
var n_threads_col = grid_dim.x * block_dim.x

for i in range(global_row, m, n_threads_row) :
for j in range(global_col, n, n_threads_col) :
var sum = Scalar[DType.float64](0)
if trans_a and trans_b :
for kk in range(k) :
sum += A[kk * lda + i] * B[j * ldb + kk]
elif trans_a :
for kk in range(k) :
sum += A[kk * lda + i] * B[kk * ldb + j]
elif trans_b :
for kk in range(k) :
sum += A[i * lda + kk] * B[j * ldb + kk]
else :
for kk in range(k) :
sum += A[i * lda + kk] * B[kk * ldb + j]
C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j]


fn blas_gemm[dtype: DType](
trans_a: Bool, trans_b: Bool,
m: Int,
n: Int,
k: Int,
alpha: Scalar[dtype],
d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
lda: Int,
d_B: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
ldb: Int,
beta: Scalar[dtype],
d_C: UnsafePointer[Scalar[dtype], MutAnyOrigin],
ldc: Int,
ctx: DeviceContext
) raises :
"""
Performs Matrix multiplication of from:
C := alpha*op( A )*op( B ) + beta*C
where op ( X ) is one of
op( X ) = X or op ( X ) = X**T
alpha and beta are scalars, and A, B and C are matrices, with op( A )
an m by k matrix, op( B ) a k by n matrix and C an m by n matrix.
"""

blas_error_if["blas_gemm" , "m < 0"](m < 0)
blas_error_if["blas_gemm" , "n < 0"](n < 0)
blas_error_if["blas_gemm" , "k < 0"](k < 0)
var trans_a_i = 0
var trans_b_i = 0

if trans_a :
blas_error_if["blas_gemm" , "lda < m"](lda < m)
trans_a_i = 1
else :
blas_error_if["blas_gemm" , "lda < k"](lda < k)
if trans_b :
blas_error_if["blas_gemm" , "ldb < k"](ldb < k)
trans_b_i = 1
else :
blas_error_if["blas_gemm" , "ldb < n"](ldb < n)

blas_error_if["blas_gemm" , "ldc < n"](ldc < n)


# quick returns ?
# BLAS Fast path dont load C when beta = 0 ?
@parameter
if dtype == DType.float32:
ctx.enqueue_function[sgemm_device, sgemm_device](
trans_a_i, trans_b_i,
m, n, k,
alpha,
d_A, lda,
d_B, ldb,
beta,
d_C, ldc,
grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)),
block_dim=(TBx, TBy))
elif dtype == DType.float64:
ctx.enqueue_function[dgemm_device, dgemm_device](
trans_a_i, trans_b_i,
m, n, k,
alpha,
d_A, lda,
d_B, ldb,
beta,
d_C, ldc,
grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)),
block_dim=(TBx, TBy)
)
else:
raise Error("blas_gemm: Unsupported type")

ctx.synchronize()
199 changes: 199 additions & 0 deletions test-level3.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
from sys import argv
from testing import assert_equal, assert_almost_equal, assert_true, TestSuite
from gpu.host import DeviceContext

from src import *
from python import Python, PythonObject


def gemm_test[
dtype: DType,
m: Int,
n: Int,
k: Int,
trans_a: Bool, trans_b: Bool
]() :
with DeviceContext() as ctx:
A_d = ctx.enqueue_create_buffer[dtype](m * k)
A = ctx.enqueue_create_host_buffer[dtype](m * k)
B_d = ctx.enqueue_create_buffer[dtype](k * n)
B = ctx.enqueue_create_host_buffer[dtype](k * n)
C_d = ctx.enqueue_create_buffer[dtype](m * n)
C = ctx.enqueue_create_host_buffer[dtype](m * n)

generate_random_arr[dtype](m * k, A.unsafe_ptr(), -100, 100)
generate_random_arr[dtype](k * n, B.unsafe_ptr(), -100, 100)
generate_random_arr[dtype](m * n, C.unsafe_ptr(), -100, 100)

ctx.enqueue_copy(A_d, A)
ctx.enqueue_copy(B_d, B)
ctx.enqueue_copy(C_d, C)

var alpha = generate_random_scalar[dtype](-100, 100)
var beta = generate_random_scalar[dtype](-100, 100)

var norm_A = frobenius_norm[dtype](A.unsafe_ptr(), m * k)
var norm_B = frobenius_norm[dtype](B.unsafe_ptr(), k * n)
var norm_C = frobenius_norm[dtype](C.unsafe_ptr(), m * n)

var lda = m if trans_a else k
var ldb = k if trans_b else n
var ldc = n

blas_gemm[dtype](
trans_a, trans_b,
m, n, k,
alpha,
A_d.unsafe_ptr(), lda,
B_d.unsafe_ptr(), ldb,
beta,
C_d.unsafe_ptr(), ldc,
ctx
)

sp = Python.import_module("scipy")
np = Python.import_module("numpy")
sp_blas = sp.linalg.blas

py_A = Python.list()
py_B = Python.list()
py_C = Python.list()
for i in range(m * k):
py_A.append(A[i])
for i in range(k * n):
py_B.append(B[i])
for i in range(m * n):
py_C.append(C[i])

if dtype == DType.float32:
np_A = np.array(py_A, dtype=np.float32).reshape(m, k) if not trans_a else np.array(py_A, dtype=np.float32).reshape(k, m)
np_B = np.array(py_B, dtype=np.float32).reshape(k, n) if not trans_b else np.array(py_B, dtype=np.float32).reshape(n, k)
np_C = np.array(py_C, dtype=np.float32).reshape(m, n)

elif dtype == DType.float64:
np_A = np.array(py_A, dtype=np.float64).reshape(m, k) if not trans_a else np.array(py_A, dtype=np.float64).reshape(k, m)
np_B = np.array(py_B, dtype=np.float64).reshape(k, n) if not trans_b else np.array(py_B, dtype=np.float64).reshape(n, k)
np_C = np.array(py_C, dtype=np.float64).reshape(m, n)
#sp_res = sp_blas.dgemm(alpha, np_A, np_B, beta=beta, c=np_C, trans_a= 0 if trans_a else 1, trans_b= 0 if trans_b else 1)
else :
print("Unsupported type: ", dtype)
return

var op_A = np.transpose(np_A) if trans_a else np_A
var op_B = np.transpose(np_B) if trans_b else np_B

var sp_res = alpha * np.matmul(op_A, op_B) + beta * np_C

with C_d.map_to_host() as res_mojo :
var norm_diff = Scalar[dtype](0)
for i in range(m):
for j in range(n) :
var diff = res_mojo[i * n + j] - Scalar[dtype](py=sp_res[i][j])
norm_diff += diff * diff
norm_diff = sqrt(norm_diff)
var ok = check_gemm_error[dtype](m, n, k, alpha, beta, norm_A, norm_B, norm_C, norm_diff)
assert_true(ok)

def test_gemm() :

gemm_test[DType.float32, 64, 64, 64, False, False]()
gemm_test[DType.float32, 64, 64, 64, True, False]()
gemm_test[DType.float32, 64, 64, 64, False, True]()
gemm_test[DType.float32, 64, 64, 64, True, True]()

gemm_test[DType.float32, 64, 1024, 64, False, False]()
gemm_test[DType.float32, 64, 1024, 64, True, False]()
gemm_test[DType.float32, 64, 1024, 64, False, True]()
gemm_test[DType.float32, 64, 1024, 64, True, True]()

gemm_test[DType.float32, 1024, 64, 64, False, False]()
gemm_test[DType.float32, 1024, 64, 64, True, False]()
gemm_test[DType.float32, 1024, 64, 64, False, True]()
gemm_test[DType.float32, 1024, 64, 64, True, True]()

gemm_test[DType.float32, 64, 64, 1024, False, False]()
gemm_test[DType.float32, 64, 64, 1024, True, False]()
gemm_test[DType.float32, 64, 64, 1024, False, True]()
gemm_test[DType.float32, 64, 64, 1024, True, True]()

gemm_test[DType.float32, 1024, 64, 1024, False, False]()
gemm_test[DType.float32, 1024, 64, 1024, True, False]()
gemm_test[DType.float32, 1024, 64, 1024, False, True]()
gemm_test[DType.float32, 1024, 64, 1024, True, True]()

gemm_test[DType.float32, 64, 1024, 1024, False, False]()
gemm_test[DType.float32, 64, 1024, 1024, True, False]()
gemm_test[DType.float32, 64, 1024, 1024, False, True]()
gemm_test[DType.float32, 64, 1024, 1024, True, True]()

gemm_test[DType.float32, 1024, 1024, 64, False, False]()
gemm_test[DType.float32, 1024, 1024, 64, True, False]()
gemm_test[DType.float32, 1024, 1024, 64, False, True]()
gemm_test[DType.float32, 1024, 1024, 64, True, True]()

gemm_test[DType.float32, 1024, 1024, 1024, False, False]()
gemm_test[DType.float32, 1024, 1024, 1024, True, False]()
gemm_test[DType.float32, 1024, 1024, 1024, False, True]()
gemm_test[DType.float32, 1024, 1024, 1024, True, True]()

gemm_test[DType.float64, 64, 64, 64, False, False]()
gemm_test[DType.float64, 64, 64, 64, True, False]()
gemm_test[DType.float64, 64, 64, 64, False, True]()
gemm_test[DType.float64, 64, 64, 64, True, True]()

gemm_test[DType.float64, 64, 1024, 64, False, False]()
gemm_test[DType.float64, 64, 1024, 64, True, False]()
gemm_test[DType.float64, 64, 1024, 64, False, True]()
gemm_test[DType.float64, 64, 1024, 64, True, True]()

gemm_test[DType.float64, 1024, 64, 64, False, False]()
gemm_test[DType.float64, 1024, 64, 64, True, False]()
gemm_test[DType.float64, 1024, 64, 64, False, True]()
gemm_test[DType.float64, 1024, 64, 64, True, True]()

gemm_test[DType.float64, 64, 64, 1024, False, False]()
gemm_test[DType.float64, 64, 64, 1024, True, False]()
gemm_test[DType.float64, 64, 64, 1024, False, True]()
gemm_test[DType.float64, 64, 64, 1024, True, True]()

gemm_test[DType.float64, 1024, 64, 1024, False, False]()
gemm_test[DType.float64, 1024, 64, 1024, True, False]()
gemm_test[DType.float64, 1024, 64, 1024, False, True]()
gemm_test[DType.float64, 1024, 64, 1024, True, True]()

gemm_test[DType.float64, 64, 1024, 1024, False, False]()
gemm_test[DType.float64, 64, 1024, 1024, True, False]()
gemm_test[DType.float64, 64, 1024, 1024, False, True]()
gemm_test[DType.float64, 64, 1024, 1024, True, True]()

gemm_test[DType.float64, 1024, 1024, 64, False, False]()
gemm_test[DType.float64, 1024, 1024, 64, True, False]()
gemm_test[DType.float64, 1024, 1024, 64, False, True]()
gemm_test[DType.float64, 1024, 1024, 64, True, True]()

gemm_test[DType.float64, 1024, 1024, 1024, False, False]()
gemm_test[DType.float64, 1024, 1024, 1024, True, False]()
gemm_test[DType.float64, 1024, 1024, 1024, False, True]()
gemm_test[DType.float64, 1024, 1024, 1024, True, True]()


def main():
print("--- MojoBLAS Level 2 routines testing ---")
var args = argv()
if (len(args) < 2):
TestSuite.discover_tests[__functions_in_module()]().run()
return

var suite = TestSuite(cli_args=List[StaticString]())
for i in range(1, len(args)):
if args[i] == "gemm": suite.test[test_gemm]()
else: print("unknown routine:", args[i])
suite^.run()








Loading