diff --git a/src/level1/__init__.mojo b/src/level1/__init__.mojo index 80477f0..14d7d04 100644 --- a/src/level1/__init__.mojo +++ b/src/level1/__init__.mojo @@ -5,6 +5,7 @@ from .copy_device import * from .rot_device import * from .rotg import * from .rotm_device import * +from .rotmg import * from .swap_device import * from .dot_device import * from .dotc_device import * diff --git a/src/level1/rotmg.mojo b/src/level1/rotmg.mojo new file mode 100644 index 0000000..4fdbfbf --- /dev/null +++ b/src/level1/rotmg.mojo @@ -0,0 +1,150 @@ +fn blas_rotmg[dtype: DType]( + d1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + d2: UnsafePointer[Scalar[dtype], MutAnyOrigin], + x1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + y1: UnsafePointer[Scalar[dtype], MutAnyOrigin], + mut param: List[Scalar[dtype]] +): + + var flag: Scalar[dtype] + var h11: Scalar[dtype] = 0 + var h12: Scalar[dtype] = 0 + var h21: Scalar[dtype] = 0 + var h22: Scalar[dtype] = 0 + var p1: Scalar[dtype] + var p2: Scalar[dtype] + var q1: Scalar[dtype] + var q2: Scalar[dtype] + var temp: Scalar[dtype] + var u: Scalar[dtype] + var gam: Scalar[dtype] = 4096.0 + var gamsq: Scalar[dtype] = 16777216.0 + var rgamsq: Scalar[dtype] = 5.9604645e-8 + + if (d1[] < 0): + # GO 0-H-D-AND-x1.. + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + # CASE-d1-NONNEGATIVE + p2 = d2[]*y1[] + if (p2 == 0): + flag = -2 + param[0] = flag + return + + # REGULAR-CASE.. + p1 = d1[]*x1[] + q2 = p2*y1[] + q1 = p1*x1[] + # + if (abs(q1) > abs(q2)): + h21 = -y1[]/x1[] + h12 = p2/p1 + # + u = 1 - h12*h21 + # + if (u > 0): + flag = 0 + d1[] /= u + d2[] /= u + x1[] *= u + else: + # This code path if here for safety. We do not expect this + # condition to ever hold except in edge cases with rounding + # errors. See DOI: 10.1145/355841.355847 + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + # + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + if (q2 < 0): + # GO 0-H-D-AND-x1.. + flag = -1 + h11 = 0 + h12 = 0 + h21 = 0 + h22 = 0 + # + d1[] = 0 + d2[] = 0 + x1[] = 0 + else: + flag = 1 + h11 = p1/p2 + h22 = x1[]/y1[] + u = 1 + h11*h22 + temp = d2[]/u + d2[] = d1[]/u + d1[] = temp + x1[] = y1[]*u + + # PROCEDURE..SCALE-CHECK + if (d1[] != 0): + while ((d1[] <= rgamsq) or (d1[] >= gamsq)): + if (flag == 0): + h11 = 1 + h22 = 1 + flag = -1 + else: + h21 = -1 + h12 = 1 + flag = -1 + + if (d1[] <= rgamsq): + d1[] *= gam**2 + x1[] /= gam + h11 /= gam + h12 /= gam + else: + d1[] /= gam**2 + x1[] *= gam + h11 *= gam + h12 *= gam + + if (d2[] != 0): + while ( (abs(d2[]) <= rgamsq) or (abs(d2[]) >= gamsq) ): + if (flag == 0): + h11 = 1 + h22 = 1 + flag = -1 + else: + h21 = -1 + h12 = 1 + flag = -1 + + if (abs(d2[]) <= rgamsq): + d2[] *= gam**2 + h21 /= gam + h22 /= gam + else: + d2[] /= gam**2 + h21 *= gam + h22 *= gam + + if (flag < 0): + param[1] = h11 + param[2] = h21 + param[3] = h12 + param[4] = h22 + elif (flag == 0): + param[2] = h21 + param[3] = h12 + else: + param[1] = h11 + param[4] = h22 + + param[0] = flag + return diff --git a/test-level1.mojo b/test-level1.mojo index 1820efc..548d8d6 100644 --- a/test-level1.mojo +++ b/test-level1.mojo @@ -691,57 +691,51 @@ def rotm_test[ assert_almost_equal(y_result[i], expected_y, atol=atol) -# def rotmg_test[ -# dtype: DType, -# size: Int -# ](): -# with DeviceContext() as ctx: -# # d1 and d2 must be positive -# var d1 = generate_random_scalar[dtype](1, 10000) -# var d2 = generate_random_scalar[dtype](1, 10000) -# var x1 = generate_random_scalar[dtype](-10000, 10000) -# var y1 = generate_random_scalar[dtype](-10000, 10000) - -# d_d1 = ctx.enqueue_create_buffer[dtype](1) -# d_d1.enqueue_fill(d1) -# d_d2 = ctx.enqueue_create_buffer[dtype](1) -# d_d2.enqueue_fill(d2) -# d_x1 = ctx.enqueue_create_buffer[dtype](1) -# d_x1.enqueue_fill(x1) -# d_y1 = ctx.enqueue_create_buffer[dtype](1) -# d_y1.enqueue_fill(y1) -# d_param = ctx.enqueue_create_buffer[dtype](5) - -# # Launch Mojo BLAS kernel -# # NOTE: not implemented -# # blas_rotmg[dtype]( -# # d1.unsafe_ptr(), -# # d2.unsafe_ptr(), -# # x1.unsafe_ptr(), -# # x2.unsafe_ptr(), -# # d_param.unsafe_ptr(), -# # ctx -# # ) - -# # Import SciPy and numpy -# sp = Python.import_module("scipy") -# np = Python.import_module("numpy") -# sp_blas = sp.linalg.blas - -# # srotmg - float32, drotmg - float64 -# if dtype == DType.float32: -# py_p = sp_blas.srotmg(d1, d2, x1, y1) -# elif dtype == DType.float64: -# py_p = sp_blas.drotmg(d1, d2, x1, y1) -# else: -# print(dtype , " is not supported by SciPy") -# return - -# # Only compare param -# with d_param.map_to_host() as mojo_param: -# for i in range(5): -# var py_ref = Scalar[dtype](py=py_p[i]) -# assert_equal(mojo_param[i], py_ref) +def rotmg_test[ + dtype: DType, + size: Int +](): + # d1 and d2 must be positive + var d1 = generate_random_scalar[dtype](1, 10000) + var d2 = generate_random_scalar[dtype](1, 10000) + var x1 = generate_random_scalar[dtype](-10000, 10000) + var y1 = generate_random_scalar[dtype](-10000, 10000) + var param = List[Scalar[dtype]](length=5, fill=0.0) + + # Launch Mojo BLAS kernel + blas_rotmg[dtype]( + UnsafePointer(to=d1), + UnsafePointer(to=d2), + UnsafePointer(to=x1), + UnsafePointer(to=y1), + param + ) + + # Import SciPy and numpy + sp = Python.import_module("scipy") + np = Python.import_module("numpy") + sp_blas = sp.linalg.blas + + # srotmg - float32, drotmg - float64 + if dtype == DType.float32: + py_p = sp_blas.srotmg(d1, d2, x1, y1) + elif dtype == DType.float64: + py_p = sp_blas.drotmg(d1, d2, x1, y1) + else: + print(dtype , " is not supported by SciPy") + return + + # print(d1) + # print(d2) + # print(x1) + # print(y1) + # print(param) + # print(py_p) + + # Only compare param + for i in range(5): + var py_ref = Scalar[dtype](py=py_p[i]) + assert_equal(param[i], py_ref) def scal_test[ @@ -893,11 +887,11 @@ def test_rotm(): rotm_test[DType.float64, 256]() rotm_test[DType.float64, 4096]() -# def test_rotmg(): -# rotmg_test[DType.float32, 256]() -# rotmg_test[DType.float32, 4096]() -# rotmg_test[DType.float64, 256]() -# rotmg_test[DType.float64, 4096]() +def test_rotmg(): + rotmg_test[DType.float32, 256]() + rotmg_test[DType.float32, 4096]() + rotmg_test[DType.float64, 256]() + rotmg_test[DType.float64, 4096]() def test_scal(): scal_test[DType.float32, 256]()