diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 97fbf6c4c5..0c4f24e173 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -11,6 +11,10 @@ import numpy as np +def _to_np_dtype(dt): + return np.float32 if dt == "bfloat16" else getattr(np, dt) + + class TestLoad(mlx_tests.MLXTestCase): dtypes = [ "uint8", @@ -38,7 +42,7 @@ def tearDownClass(cls): cls.test_dir_fid.cleanup() def test_save_and_load(self): - for dt in self.dtypes: + for dt in self.dtypes + ["bfloat16"]: with self.subTest(dtype=dt): for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]): with self.subTest(shape=shape): @@ -46,8 +50,8 @@ def test_save_and_load(self): save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy") save_arr = np.random.uniform(0.0, 32.0, size=shape) - save_arr_npy = save_arr.astype(getattr(np, dt)) - save_arr_mlx = mx.array(save_arr_npy) + save_arr_npy = save_arr.astype(_to_np_dtype(dt)) + save_arr_mlx = mx.array(save_arr_npy, dtype=getattr(mx, dt)) mx.save(save_file_mlx, save_arr_mlx) np.save(save_file_npy, save_arr_npy) @@ -304,7 +308,7 @@ def test_save_and_load_fs(self): if not os.path.isdir(self.test_dir): os.mkdir(self.test_dir) - for dt in self.dtypes: + for dt in self.dtypes + ["bfloat16"]: with self.subTest(dtype=dt): for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]): with self.subTest(shape=shape): @@ -316,8 +320,8 @@ def test_save_and_load_fs(self): ) save_arr = np.random.uniform(0.0, 32.0, size=shape) - save_arr_npy = save_arr.astype(getattr(np, dt)) - save_arr_mlx = mx.array(save_arr_npy) + save_arr_npy = save_arr.astype(_to_np_dtype(dt)) + save_arr_mlx = mx.array(save_arr_npy, dtype=getattr(mx, dt)) with open(save_file_mlx, "wb") as f: mx.save(f, save_arr_mlx) @@ -342,7 +346,7 @@ def test_savez_and_loadz(self): if not os.path.isdir(self.test_dir): os.mkdir(self.test_dir) - for dt in self.dtypes: + for dt in self.dtypes + ["bfloat16"]: with self.subTest(dtype=dt): shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)] save_file_mlx_uncomp = os.path.join( @@ -358,10 +362,13 @@ def test_savez_and_loadz(self): save_arrs_npy = { f"save_arr_{i}": np.random.uniform( 0.0, 32.0, size=shapes[i] - ).astype(getattr(np, dt)) + ).astype(_to_np_dtype(dt)) for i in range(len(shapes)) } - save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()} + save_arrs_mlx = { + k: mx.array(v, dtype=getattr(mx, dt)) + for k, v in save_arrs_npy.items() + } # Save as npz files np.savez(save_file_npy_uncomp, **save_arrs_npy)