diff --git a/src/__init__.mojo b/src/__init__.mojo
index 4762ea5..b5c2f85 100644
--- a/src/__init__.mojo
+++ b/src/__init__.mojo
@@ -1,3 +1,4 @@
+from .util import *
 from .testing_utils import *
 from .level1 import *
 from .level2 import *
diff --git a/src/level1/asum_device.mojo b/src/level1/asum_device.mojo
index 5bbc8b4..9fde31b 100644
--- a/src/level1/asum_device.mojo
+++ b/src/level1/asum_device.mojo
@@ -18,9 +18,6 @@ fn asum_device[
     incx: Int,
     result: UnsafePointer[Scalar[dtype], MutAnyOrigin]
 ):
-    if n < 1 or incx <= 0:
-        result[0] = 0
-        return
 
     var local_tid = thread_idx.x
 
@@ -65,6 +62,9 @@ fn blas_asum[dtype: DType](
     d_res: UnsafePointer[Scalar[dtype], MutAnyOrigin],
     ctx: DeviceContext
 ) raises:
+    blas_error_if["blas_asum", "n < 0"](n < 0)
+    blas_error_if["blas_asum", "incx <= 0"](incx <= 0)
+
     comptime kernel = asum_device[TBsize, dtype]
     ctx.enqueue_function[kernel, kernel](
         n,
diff --git a/src/level1/axpy_device.mojo b/src/level1/axpy_device.mojo
index 953d65a..0626d08 100644
--- a/src/level1/axpy_device.mojo
+++ b/src/level1/axpy_device.mojo
@@ -12,12 +12,6 @@ fn axpy_device[dtype: DType](
     y: UnsafePointer[Scalar[dtype], MutAnyOrigin],
     incy: Int
 ):
-    if (n <= 0):
-        return
-    if (a == 0):
-        return
-    if (incx == 0 or incy == 0):
-        return
 
     var global_i = global_idx.x
     var n_threads = Int(grid_dim.x * block_dim.x)
@@ -36,6 +30,14 @@ fn blas_axpy[dtype: DType](
     incy: Int,
     ctx: DeviceContext
 ) raises:
+    blas_error_if["blas_axpy", "n < 0"](n < 0)
+    blas_error_if["blas_axpy", "incx == 0"](incx == 0)
+    blas_error_if["blas_axpy", "incy == 0"](incy == 0)
+
+    # quick return
+    if a == 0:
+        return
+
     comptime kernel = axpy_device[dtype]
     ctx.enqueue_function[kernel, kernel](
         n, a,
diff --git a/src/level1/copy_device.mojo b/src/level1/copy_device.mojo
index ac44ada..0784174 100644
--- a/src/level1/copy_device.mojo
+++ b/src/level1/copy_device.mojo
@@ -13,10 +13,6 @@ fn copy_device[dtype: DType](
     y: UnsafePointer[Scalar[dtype], MutAnyOrigin],
     incy: Int
 ):
-    if (n <= 0):
-        return
-    if (incx == 0 or incy == 0):
-        return
 
     var global_i = global_idx.x
     var n_threads = Int(grid_dim.x * block_dim.x)
@@ -34,6 +30,12 @@ fn blas_copy[dtype: DType](
     incy: Int,
     ctx: DeviceContext
 ) raises:
+
+    blas_error_if["blas_copy", "n < 0"](n < 0)
+    blas_error_if["blas_copy", "incx == 0"](incx == 0)
+    blas_error_if["blas_copy", "incy == 0"](incy == 0)
+
+
     comptime kernel = copy_device[dtype]
     ctx.enqueue_function[kernel, kernel](
         n,
diff --git a/src/level1/dot_device.mojo b/src/level1/dot_device.mojo
index 68025d6..87e8d5f 100644
--- a/src/level1/dot_device.mojo
+++ b/src/level1/dot_device.mojo
@@ -19,8 +19,6 @@ fn dot_device[
     incy: Int,
     output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
 ):
-    if n < 1:
-        return
 
     var global_i = block_dim.x * block_idx.x + thread_idx.x
     var n_threads = grid_dim.x * block_dim.x
@@ -59,6 +57,11 @@ fn blas_dot[dtype: DType](
     d_out: UnsafePointer[Scalar[dtype], MutAnyOrigin],
     ctx: DeviceContext
 ) raises:
+
+    blas_error_if["blas_dot", "n < 0"](n < 0)
+    blas_error_if["blas_dot", "incx == 0"](incx == 0)
+    blas_error_if["blas_dot", "incy == 0"](incy == 0)
+
     comptime kernel = dot_device[TBsize, dtype]
     ctx.enqueue_function[kernel, kernel](
         n, d_x, incx,
diff --git a/src/level1/dotc_device.mojo b/src/level1/dotc_device.mojo
index f4ae13b..c14fac7 100644
--- a/src/level1/dotc_device.mojo
+++ b/src/level1/dotc_device.mojo
@@ -20,8 +20,6 @@ fn dotc_device[
     incy: Int,
     output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
 ):
-    if n < 1:
-        return
 
     var global_i = block_dim.x * block_idx.x + thread_idx.x
     var n_threads = grid_dim.x * block_dim.x
@@ -70,6 +68,10 @@ fn blas_dotc[dtype: DType](
     d_out: UnsafePointer[Scalar[dtype], MutAnyOrigin],
     ctx: DeviceContext
 ) raises:
+    blas_error_if["blas_dotc", "n < 0"](n < 0)
+    blas_error_if["blas_dotc", "incx == 0"](incx == 0)
+    blas_error_if["blas_dotc", "incy == 0"](incy == 0)
+
     comptime kernel = dotc_device[TBsize, dtype]
     ctx.enqueue_function[kernel, kernel](
         n, d_x, incx,
diff --git a/src/level1/dotu_device.mojo b/src/level1/dotu_device.mojo
index 957037a..6d95df1 100644
--- a/src/level1/dotu_device.mojo
+++ b/src/level1/dotu_device.mojo
@@ -20,8 +20,7 @@ fn dotu_device[
     incy: Int,
     output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
 ):
-    if n < 1:
-        return
+
     var global_i = block_dim.x * block_idx.x + thread_idx.x
     var n_threads = grid_dim.x * block_dim.x
 
@@ -70,6 +69,11 @@ fn blas_dotu[dtype: DType](
     d_out: UnsafePointer[Scalar[dtype], MutAnyOrigin],
     ctx: DeviceContext
 ) raises:
+
+    blas_error_if["blas_dotu", "n < 0"](n < 0)
+    blas_error_if["blas_dotu", "incx == 0"](incx == 0)
+    blas_error_if["blas_dotu", "incy == 0"](incy == 0)
+
     comptime kernel = dotu_device[TBsize, dtype]
     ctx.enqueue_function[kernel, kernel](
         n, d_x, incx,
diff --git a/src/level1/iamax_device.mojo b/src/level1/iamax_device.mojo
index 4a70e45..aba2f53 100644
--- a/src/level1/iamax_device.mojo
+++ b/src/level1/iamax_device.mojo
@@ -89,6 +89,11 @@ fn blas_iamax[dtype: DType](
     d_res: UnsafePointer[Scalar[DType.int64], MutAnyOrigin],
     ctx: DeviceContext
 ) raises:
+
+    blas_error_if["blas_iamax", "n <= 0"](n <= 0)
+    blas_error_if["blas_iamax", "incx <= 0"](incx <= 0)
+
+
     comptime kernel = iamax_device[TBsize, dtype]
     ctx.enqueue_function[kernel, kernel](
         n, d_v, incx,
diff --git a/src/level1/nrm2_device.mojo b/src/level1/nrm2_device.mojo
index 6627f3a..df9ae23 100644
--- a/src/level1/nrm2_device.mojo
+++ b/src/level1/nrm2_device.mojo
@@ -58,6 +58,9 @@ fn blas_nrm2[dtype: DType](
     d_out: UnsafePointer[Scalar[dtype], MutAnyOrigin],
     ctx: DeviceContext
 ) raises:
+    blas_error_if["blas_nrm2", "n < 0"](n < 0)
+    blas_error_if["blas_nrm2", "incx <= 0"](incx <= 0)
+
     comptime kernel = nrm2_device[TBsize, dtype]
     ctx.enqueue_function[kernel, kernel](
         n, d_x, incx, d_out,
diff --git a/src/level1/rot_device.mojo b/src/level1/rot_device.mojo
index d4bf07b..63162be 100644
--- a/src/level1/rot_device.mojo
+++ b/src/level1/rot_device.mojo
@@ -17,8 +17,7 @@ fn rot_device[
     c: Scalar[dtype],
     s: Scalar[dtype]
 ):
-    if (n < 1):
-        return
+
     var global_tid = block_idx.x * block_dim.x + thread_idx.x
     var n_threads = grid_dim.x * block_dim.x
 
@@ -43,6 +42,15 @@ fn blas_rot[dtype: DType](
     s: Scalar[dtype],
     ctx: DeviceContext
 ) raises:
+    blas_error_if["blas_rot", "n < 0"](n < 0)
+    blas_error_if["blas_rot", "incx == 0"](incx == 0)
+    blas_error_if["blas_rot", "incy == 0"](incy == 0)
+
+
+    # quick return
+    if n == 0 or (c == 1 and s == 0):
+        return
+
     comptime kernel = rot_device[dtype]
     ctx.enqueue_function[kernel, kernel](
         n,
diff --git a/src/level1/rotm_device.mojo b/src/level1/rotm_device.mojo
index 940b5c0..e290130 100644
--- a/src/level1/rotm_device.mojo
+++ b/src/level1/rotm_device.mojo
@@ -17,8 +17,6 @@ fn rotm_device[
     param: UnsafePointer[Scalar[dtype], MutAnyOrigin]
 ):
     var flag = param[0]
-    if (n < 1):
-        return
 
     var idx = block_idx.x * block_dim.x + thread_idx.x
     var n_threads = grid_dim.x * block_dim.x
@@ -63,6 +61,13 @@ fn blas_rotm[dtype: DType](
     d_param: UnsafePointer[Scalar[dtype], MutAnyOrigin],
     ctx: DeviceContext
 ) raises:
+    blas_error_if["blas_rotm", "n < 0"](n < 0)
+    blas_error_if["blas_rotm", "incx == 0"](incx == 0)
+    blas_error_if["blas_rotm", "incy == 0"](incy == 0)
+
+    # quick return
+    if n == 0:
+        return
     comptime kernel = rotm_device[dtype]
     ctx.enqueue_function[kernel, kernel](
         n,
diff --git a/src/level1/scal_device.mojo b/src/level1/scal_device.mojo
index e35f801..d16f069 100644
--- a/src/level1/scal_device.mojo
+++ b/src/level1/scal_device.mojo
@@ -10,12 +10,6 @@ fn scal_device[dtype: DType](
     x: UnsafePointer[Scalar[dtype], MutAnyOrigin],
     incx: Int,
 ):
-    if (n <= 0):
-        return
-    if (a == 0):
-        return
-    if (incx == 0):
-        return
 
     var global_i = global_idx.x
     var n_threads = Int(grid_dim.x * block_dim.x)
@@ -32,6 +26,11 @@ fn blas_scal[dtype: DType] (
     incx: Int,
     ctx: DeviceContext
 ) raises:
+
+    blas_error_if["blas_scal", "n < 0"](n < 0)
+    blas_error_if["blas_scal", "incx <= 0"](incx <= 0)
+
+
     comptime kernel = scal_device[dtype]
     ctx.enqueue_function[kernel, kernel](
         n, a, d_x, incx,
diff --git a/src/level1/swap_device.mojo b/src/level1/swap_device.mojo
index b880ccc..43b1041 100644
--- a/src/level1/swap_device.mojo
+++ b/src/level1/swap_device.mojo
@@ -11,10 +11,7 @@ fn swap_device[dtype: DType](
     y: UnsafePointer[Scalar[dtype], MutAnyOrigin],
     incy: Int
 ):
-    if (n <= 0):
-        return
-    if (incx == 0 or incy == 0):
-        return
+
     var global_i = global_idx.x
     var n_threads = Int(grid_dim.x * block_dim.x)
 
@@ -34,6 +31,10 @@ fn blas_swap[dtype: DType](
     incy: Int,
     ctx: DeviceContext
 ) raises:
+    blas_error_if["blas_swap", "n < 0"](n < 0)
+    blas_error_if["blas_swap", "incx == 0"](incx == 0)
+    blas_error_if["blas_swap", "incy == 0"](incy == 0)
+
     comptime kernel = swap_device[dtype]
     ctx.enqueue_function[kernel, kernel](
         n,
diff --git a/src/util.mojo b/src/util.mojo
new file mode 100644
index 0000000..26e22e7
--- /dev/null
+++ b/src/util.mojo
@@ -0,0 +1,19 @@
+
+fn blas_error_if[caller: String, cond_str: String](cond: Bool) raises:
+    """
+    Raises an error describing the bad parameters passed to `caller`.
+
+    Parameters:
+        caller: Name of the routine performing the check.
+        cond_str: Text of the violated condition, echoed in the message.
+
+    Args:
+        cond: The error is raised only when this is True.
+
+    Raises:
+        Error: "Error: <cond_str> in <caller>" when `cond` is True.
+    """
+    if cond:
+        raise Error("Error: {} in {}".format(cond_str, caller))
+