diff --git a/src/__init__.mojo b/src/__init__.mojo
index b5c2f85..587d1ea 100644
--- a/src/__init__.mojo
+++ b/src/__init__.mojo
@@ -2,4 +2,4 @@ from .util import *
 from .testing_utils import *
 from .level1 import *
 from .level2 import *
-# from .level3 import *
+from .level3 import *
diff --git a/src/level3/__init__.mojo b/src/level3/__init__.mojo
new file mode 100644
index 0000000..f3a2520
--- /dev/null
+++ b/src/level3/__init__.mojo
@@ -0,0 +1 @@
+from .gemm_device import*
diff --git a/src/level3/gemm_device.mojo b/src/level3/gemm_device.mojo
new file mode 100644
index 0000000..adf91cf
--- /dev/null
+++ b/src/level3/gemm_device.mojo
@@ -0,0 +1,189 @@
+from gpu import thread_idx, block_idx, block_dim, grid_dim
+from gpu.host import DeviceContext
+from math import ceildiv
+
+comptime TBsize = 512
+comptime TBx = 32
+comptime TBy = 16
+
+
+@always_inline
+fn _gemm_device_impl[dtype: DType](
+    trans_a: Int, trans_b: Int,
+    m: Int,
+    n: Int,
+    k: Int,
+    alpha: Scalar[dtype],
+    A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
+    lda: Int,
+    B: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
+    ldb: Int,
+    beta: Scalar[dtype],
+    C: UnsafePointer[Scalar[dtype], MutAnyOrigin],
+    ldc: Int,
+) :
+    """
+    Shared body of the GEMM kernels: C := alpha*op(A)*op(B) + beta*C.
+
+    Matrices are indexed row-major: element (r, c) of X is X[r * ldx + c].
+    A non-zero trans_a / trans_b means the operand is read transposed, i.e.
+    A is then indexed as a k by m matrix and B as an n by k matrix.
+    Each thread walks the m by n output with a grid-stride loop in both
+    dimensions, so the whole of C is covered whatever the launch geometry.
+    """
+    var global_row = block_dim.y * block_idx.y + thread_idx.y
+    var global_col = block_dim.x * block_idx.x + thread_idx.x
+    var n_threads_row = grid_dim.y * block_dim.y
+    var n_threads_col = grid_dim.x * block_dim.x
+
+    for i in range(global_row, m, n_threads_row) :
+        for j in range(global_col, n, n_threads_col) :
+            var sum = Scalar[dtype](0)
+            if trans_a and trans_b :
+                for kk in range(k) :
+                    sum += A[kk * lda + i] * B[j * ldb + kk]
+            elif trans_a :
+                for kk in range(k) :
+                    sum += A[kk * lda + i] * B[kk * ldb + j]
+            elif trans_b :
+                for kk in range(k) :
+                    sum += A[i * lda + kk] * B[j * ldb + kk]
+            else :
+                for kk in range(k) :
+                    sum += A[i * lda + kk] * B[kk * ldb + j]
+            # BLAS semantics: when beta is zero C is write-only, so NaN/Inf
+            # left in an uninitialised C must not leak into the result.
+            if beta == Scalar[dtype](0) :
+                C[i * ldc + j] = alpha * sum
+            else :
+                C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j]
+
+
+fn sgemm_device(
+    trans_a: Int, trans_b: Int,
+    m: Int,
+    n: Int,
+    k: Int,
+    alpha: Float32,
+    A: UnsafePointer[Float32, ImmutAnyOrigin],
+    lda: Int,
+    B: UnsafePointer[Float32, ImmutAnyOrigin],
+    ldb: Int,
+    beta: Float32,
+    C: UnsafePointer[Float32, MutAnyOrigin],
+    ldc: Int,
+) :
+    """Single-precision GEMM kernel: C := alpha*op(A)*op(B) + beta*C, row-major."""
+    _gemm_device_impl[DType.float32](
+        trans_a, trans_b, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc
+    )
+
+
+fn dgemm_device(
+    trans_a: Int, trans_b: Int,
+    m: Int,
+    n: Int,
+    k: Int,
+    alpha: Float64,
+    A: UnsafePointer[Float64, ImmutAnyOrigin],
+    lda: Int,
+    B: UnsafePointer[Float64, ImmutAnyOrigin],
+    ldb: Int,
+    beta: Float64,
+    C: UnsafePointer[Float64, MutAnyOrigin],
+    ldc: Int,
+) :
+    """Double-precision GEMM kernel: C := alpha*op(A)*op(B) + beta*C, row-major."""
+    _gemm_device_impl[DType.float64](
+        trans_a, trans_b, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc
+    )
+
+
+fn blas_gemm[dtype: DType](
+    trans_a: Bool, trans_b: Bool,
+    m: Int,
+    n: Int,
+    k: Int,
+    alpha: Scalar[dtype],
+    d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
+    lda: Int,
+    d_B: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
+    ldb: Int,
+    beta: Scalar[dtype],
+    d_C: UnsafePointer[Scalar[dtype], MutAnyOrigin],
+    ldc: Int,
+    ctx: DeviceContext
+) raises :
+    """
+    Performs matrix multiplication of the form:
+        C := alpha*op( A )*op( B ) + beta*C
+    where op( X ) is one of
+        op( X ) = X   or   op( X ) = X**T
+    alpha and beta are scalars, and A, B and C are matrices, with op( A )
+    an m by k matrix, op( B ) a k by n matrix and C an m by n matrix.
+
+    The kernels index the matrices row-major; lda, ldb and ldc are the row
+    strides (lda >= k, or >= m when trans_a; ldb >= n, or >= k when
+    trans_b; ldc >= n). Only float32 and float64 are supported; any other
+    dtype raises. ctx.synchronize() is called before returning.
+    """
+
+    # NOTE(review): blas_error_if is not imported in this module - confirm
+    # it is in scope here.
+    blas_error_if["blas_gemm" , "m < 0"](m < 0)
+    blas_error_if["blas_gemm" , "n < 0"](n < 0)
+    blas_error_if["blas_gemm" , "k < 0"](k < 0)
+    var trans_a_i = 0
+    var trans_b_i = 0
+
+    if trans_a :
+        blas_error_if["blas_gemm" , "lda < m"](lda < m)
+        trans_a_i = 1
+    else :
+        blas_error_if["blas_gemm" , "lda < k"](lda < k)
+    if trans_b :
+        blas_error_if["blas_gemm" , "ldb < k"](ldb < k)
+        trans_b_i = 1
+    else :
+        blas_error_if["blas_gemm" , "ldb < n"](ldb < n)
+
+    blas_error_if["blas_gemm" , "ldc < n"](ldc < n)
+
+    @parameter
+    if dtype != DType.float32 and dtype != DType.float64 :
+        raise Error("blas_gemm: Unsupported type")
+
+    # Quick return: C is empty (the launch would request a zero-sized grid),
+    # or the update would leave C unchanged. Still synchronize so the call
+    # keeps its blocking behaviour.
+    var c_unchanged = (alpha == Scalar[dtype](0) or k == 0) and beta == Scalar[dtype](1)
+    if m == 0 or n == 0 or c_unchanged :
+        ctx.synchronize()
+        return
+
+    @parameter
+    if dtype == DType.float32:
+        ctx.enqueue_function[sgemm_device, sgemm_device](
+            trans_a_i, trans_b_i,
+            m, n, k,
+            alpha,
+            d_A, lda,
+            d_B, ldb,
+            beta,
+            d_C, ldc,
+            grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)),
+            block_dim=(TBx, TBy))
+    elif dtype == DType.float64:
+        ctx.enqueue_function[dgemm_device, dgemm_device](
+            trans_a_i, trans_b_i,
+            m, n, k,
+            alpha,
+            d_A, lda,
+            d_B, ldb,
+            beta,
+            d_C, ldc,
+            grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)),
+            block_dim=(TBx, TBy)
+        )
+
+    ctx.synchronize()
diff --git a/test-level3.mojo b/test-level3.mojo
new file mode 100644
index 0000000..3e38596
--- /dev/null
+++ b/test-level3.mojo
@@ -0,0 +1,197 @@
+from sys import argv
+from testing import assert_equal, assert_almost_equal, assert_true, TestSuite
+from gpu.host import DeviceContext
+
+from src import *
+from python import Python, PythonObject
+
+
+def gemm_test[
+    dtype: DType,
+    m: Int,
+    n: Int,
+    k: Int,
+    trans_a: Bool, trans_b: Bool
+]() :
+    """
+    Check blas_gemm against NumPy on inputs filled by generate_random_arr.
+
+    The reference is alpha * op(A) @ op(B) + beta * C computed with NumPy;
+    the Frobenius norm of the difference is passed to check_gemm_error.
+    """
+    with DeviceContext() as ctx:
+        A_d = ctx.enqueue_create_buffer[dtype](m * k)
+        A = ctx.enqueue_create_host_buffer[dtype](m * k)
+        B_d = ctx.enqueue_create_buffer[dtype](k * n)
+        B = ctx.enqueue_create_host_buffer[dtype](k * n)
+        C_d = ctx.enqueue_create_buffer[dtype](m * n)
+        C = ctx.enqueue_create_host_buffer[dtype](m * n)
+
+        generate_random_arr[dtype](m * k, A.unsafe_ptr(), -100, 100)
+        generate_random_arr[dtype](k * n, B.unsafe_ptr(), -100, 100)
+        generate_random_arr[dtype](m * n, C.unsafe_ptr(), -100, 100)
+
+        ctx.enqueue_copy(A_d, A)
+        ctx.enqueue_copy(B_d, B)
+        ctx.enqueue_copy(C_d, C)
+
+        var alpha = generate_random_scalar[dtype](-100, 100)
+        var beta = generate_random_scalar[dtype](-100, 100)
+
+        var norm_A = frobenius_norm[dtype](A.unsafe_ptr(), m * k)
+        var norm_B = frobenius_norm[dtype](B.unsafe_ptr(), k * n)
+        var norm_C = frobenius_norm[dtype](C.unsafe_ptr(), m * n)
+
+        # Leading dimension = number of columns of each matrix as stored.
+        var lda = m if trans_a else k
+        var ldb = k if trans_b else n
+        var ldc = n
+
+        blas_gemm[dtype](
+            trans_a, trans_b,
+            m, n, k,
+            alpha,
+            A_d.unsafe_ptr(), lda,
+            B_d.unsafe_ptr(), ldb,
+            beta,
+            C_d.unsafe_ptr(), ldc,
+            ctx
+        )
+
+        np = Python.import_module("numpy")
+
+        py_A = Python.list()
+        py_B = Python.list()
+        py_C = Python.list()
+        for i in range(m * k):
+            py_A.append(A[i])
+        for i in range(k * n):
+            py_B.append(B[i])
+        for i in range(m * n):
+            py_C.append(C[i])
+
+        if dtype == DType.float32:
+            np_A = np.array(py_A, dtype=np.float32).reshape(m, k) if not trans_a else np.array(py_A, dtype=np.float32).reshape(k, m)
+            np_B = np.array(py_B, dtype=np.float32).reshape(k, n) if not trans_b else np.array(py_B, dtype=np.float32).reshape(n, k)
+            np_C = np.array(py_C, dtype=np.float32).reshape(m, n)
+
+        elif dtype == DType.float64:
+            np_A = np.array(py_A, dtype=np.float64).reshape(m, k) if not trans_a else np.array(py_A, dtype=np.float64).reshape(k, m)
+            np_B = np.array(py_B, dtype=np.float64).reshape(k, n) if not trans_b else np.array(py_B, dtype=np.float64).reshape(n, k)
+            np_C = np.array(py_C, dtype=np.float64).reshape(m, n)
+        else :
+            print("Unsupported type: ", dtype)
+            return
+
+        var op_A = np.transpose(np_A) if trans_a else np_A
+        var op_B = np.transpose(np_B) if trans_b else np_B
+
+        var np_res = alpha * np.matmul(op_A, op_B) + beta * np_C
+
+        with C_d.map_to_host() as res_mojo :
+            var norm_diff = Scalar[dtype](0)
+            for i in range(m):
+                for j in range(n) :
+                    var diff = res_mojo[i * n + j] - Scalar[dtype](py=np_res[i][j])
+                    norm_diff += diff * diff
+            norm_diff = sqrt(norm_diff)
+            var ok = check_gemm_error[dtype](m, n, k, alpha, beta, norm_A, norm_B, norm_C, norm_diff)
+            assert_true(ok)
+
+def test_gemm() :
+    """Run gemm_test for float32 and float64 over 64/1024 sizes and all four transpose combinations."""
+
+    gemm_test[DType.float32, 64, 64, 64, False, False]()
+    gemm_test[DType.float32, 64, 64, 64, True, False]()
+    gemm_test[DType.float32, 64, 64, 64, False, True]()
+    gemm_test[DType.float32, 64, 64, 64, True, True]()
+
+    gemm_test[DType.float32, 64, 1024, 64, False, False]()
+    gemm_test[DType.float32, 64, 1024, 64, True, False]()
+    gemm_test[DType.float32, 64, 1024, 64, False, True]()
+    gemm_test[DType.float32, 64, 1024, 64, True, True]()
+
+    gemm_test[DType.float32, 1024, 64, 64, False, False]()
+    gemm_test[DType.float32, 1024, 64, 64, True, False]()
+    gemm_test[DType.float32, 1024, 64, 64, False, True]()
+    gemm_test[DType.float32, 1024, 64, 64, True, True]()
+
+    gemm_test[DType.float32, 64, 64, 1024, False, False]()
+    gemm_test[DType.float32, 64, 64, 1024, True, False]()
+    gemm_test[DType.float32, 64, 64, 1024, False, True]()
+    gemm_test[DType.float32, 64, 64, 1024, True, True]()
+
+    gemm_test[DType.float32, 1024, 64, 1024, False, False]()
+    gemm_test[DType.float32, 1024, 64, 1024, True, False]()
+    gemm_test[DType.float32, 1024, 64, 1024, False, True]()
+    gemm_test[DType.float32, 1024, 64, 1024, True, True]()
+
+    gemm_test[DType.float32, 64, 1024, 1024, False, False]()
+    gemm_test[DType.float32, 64, 1024, 1024, True, False]()
+    gemm_test[DType.float32, 64, 1024, 1024, False, True]()
+    gemm_test[DType.float32, 64, 1024, 1024, True, True]()
+
+    gemm_test[DType.float32, 1024, 1024, 64, False, False]()
+    gemm_test[DType.float32, 1024, 1024, 64, True, False]()
+    gemm_test[DType.float32, 1024, 1024, 64, False, True]()
+    gemm_test[DType.float32, 1024, 1024, 64, True, True]()
+
+    gemm_test[DType.float32, 1024, 1024, 1024, False, False]()
+    gemm_test[DType.float32, 1024, 1024, 1024, True, False]()
+    gemm_test[DType.float32, 1024, 1024, 1024, False, True]()
+    gemm_test[DType.float32, 1024, 1024, 1024, True, True]()
+
+    gemm_test[DType.float64, 64, 64, 64, False, False]()
+    gemm_test[DType.float64, 64, 64, 64, True, False]()
+    gemm_test[DType.float64, 64, 64, 64, False, True]()
+    gemm_test[DType.float64, 64, 64, 64, True, True]()
+
+    gemm_test[DType.float64, 64, 1024, 64, False, False]()
+    gemm_test[DType.float64, 64, 1024, 64, True, False]()
+    gemm_test[DType.float64, 64, 1024, 64, False, True]()
+    gemm_test[DType.float64, 64, 1024, 64, True, True]()
+
+    gemm_test[DType.float64, 1024, 64, 64, False, False]()
+    gemm_test[DType.float64, 1024, 64, 64, True, False]()
+    gemm_test[DType.float64, 1024, 64, 64, False, True]()
+    gemm_test[DType.float64, 1024, 64, 64, True, True]()
+
+    gemm_test[DType.float64, 64, 64, 1024, False, False]()
+    gemm_test[DType.float64, 64, 64, 1024, True, False]()
+    gemm_test[DType.float64, 64, 64, 1024, False, True]()
+    gemm_test[DType.float64, 64, 64, 1024, True, True]()
+
+    gemm_test[DType.float64, 1024, 64, 1024, False, False]()
+    gemm_test[DType.float64, 1024, 64, 1024, True, False]()
+    gemm_test[DType.float64, 1024, 64, 1024, False, True]()
+    gemm_test[DType.float64, 1024, 64, 1024, True, True]()
+
+    gemm_test[DType.float64, 64, 1024, 1024, False, False]()
+    gemm_test[DType.float64, 64, 1024, 1024, True, False]()
+    gemm_test[DType.float64, 64, 1024, 1024, False, True]()
+    gemm_test[DType.float64, 64, 1024, 1024, True, True]()
+
+    gemm_test[DType.float64, 1024, 1024, 64, False, False]()
+    gemm_test[DType.float64, 1024, 1024, 64, True, False]()
+    gemm_test[DType.float64, 1024, 1024, 64, False, True]()
+    gemm_test[DType.float64, 1024, 1024, 64, True, True]()
+
+    gemm_test[DType.float64, 1024, 1024, 1024, False, False]()
+    gemm_test[DType.float64, 1024, 1024, 1024, True, False]()
+    gemm_test[DType.float64, 1024, 1024, 1024, False, True]()
+    gemm_test[DType.float64, 1024, 1024, 1024, True, True]()
+
+
+def main():
+    """Run every discovered test, or only the routines named on the command line."""
+    print("--- MojoBLAS Level 3 routines testing ---")
+    var args = argv()
+    if (len(args) < 2):
+        TestSuite.discover_tests[__functions_in_module()]().run()
+        return
+
+    var suite = TestSuite(cli_args=List[StaticString]())
+    for i in range(1, len(args)):
+        if args[i] == "gemm": suite.test[test_gemm]()
+        else: print("unknown routine:", args[i])
+    suite^.run()