diff --git a/docs/src/index.rst b/docs/src/index.rst index 74c52aaa2b..46d069929f 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -32,7 +32,7 @@ are the CPU and GPU. install .. toctree:: - :caption: Usage + :caption: Usage :maxdepth: 1 usage/quick_start @@ -78,6 +78,7 @@ are the CPU and GPU. python/optimizers python/distributed python/tree_utils + python/printoptions .. toctree:: :caption: C++ API Reference diff --git a/docs/src/python/printoptions.rst b/docs/src/python/printoptions.rst new file mode 100644 index 0000000000..ee7d3c191a --- /dev/null +++ b/docs/src/python/printoptions.rst @@ -0,0 +1,11 @@ +Print Options +============ + +.. currentmodule:: mlx.core + +.. autosummary:: + :toctree: _autosummary + + set_printoptions + printoptions + get_printoptions diff --git a/mlx/utils.cpp b/mlx/utils.cpp index cf0e0f38db..56b69d932a 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -1,6 +1,7 @@ // Copyright © 2023 Apple Inc. #include +#include #include #include #include @@ -57,23 +58,50 @@ inline void PrintFormatter::print(std::ostream& os, uint64_t val) { os << val; } inline void PrintFormatter::print(std::ostream& os, float16_t val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, float val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, double val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, complex64_t val) { - os << val.real(); - if (val.imag() >= 0 || std::isnan(val.imag())) { - os << "+" << val.imag() << "j"; + if (precision == -1) { + os << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << val.imag() << "j"; + } else { + os << "-" << -val.imag() << "j"; + } } else { - os << "-" << -val.imag() << "j"; + os << std::fixed << std::setprecision(precision) << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << std::fixed << std::setprecision(precision) << val.imag() + << "j"; + } else { + os << "-" << std::fixed << std::setprecision(precision) << -val.imag() + << "j"; + } } } @@ -82,6 +110,11 @@ PrintFormatter& get_global_formatter() { return formatter; } +void set_printoptions(int precision) { + auto& formatter = get_global_formatter(); + formatter.precision = precision; +} + void abort_with_exception(const std::exception& error) { std::ostringstream msg; msg << "Terminating due to uncaught exception: " << error.what(); diff --git a/mlx/utils.h b/mlx/utils.h index 62aa82b658..6be5f09ca2 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -53,8 +53,11 @@ struct PrintFormatter { inline void print(std::ostream& os, complex64_t val); bool capitalize_bool{false}; + int precision{-1}; }; +MLX_API void set_printoptions(int precision); + MLX_API PrintFormatter& get_global_formatter(); /** Print the exception and then abort. */ diff --git a/python/src/array.cpp b/python/src/array.cpp index 231474c2d9..1a07e16463 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -96,10 +96,63 @@ class ArrayPythonIterator { std::vector splits_; }; +struct PrintOptionsContext { + int old_precision; + int new_precision; + PrintOptionsContext(int p) : new_precision(p) {} + PrintOptionsContext& __enter__() { + old_precision = mx::get_global_formatter().precision; + mx::set_printoptions(new_precision); + return *this; + } + void __exit__(nb::args) { + mx::set_printoptions(old_precision); + } +}; + void init_array(nb::module_& m) { // Set Python print formatting options mx::get_global_formatter().capitalize_bool = true; + // Expose printing options to Python: allow setting global precision. + m.def( + "set_printoptions", + &mx::set_printoptions, + "precision"_a, + R"pbdoc( + Set global printing precision for array formatting. + + Args: + precision (int): Number of decimal places to use when printing + floating point numbers in arrays. + )pbdoc"); + m.def( + "get_printoptions", + []() { return mx::get_global_formatter().precision; }, + R"pbdoc( + Get global printing precision for array formatting. + + Returns: + int: The number of decimal places used when printing floating point + numbers in arrays. + )pbdoc"); + + nb::class_(m, "_PrintOptionsContext") + .def(nb::init()) + .def("__enter__", &PrintOptionsContext::__enter__) + .def("__exit__", &PrintOptionsContext::__exit__); + + m.def( + "printoptions", + [](int precision) { return PrintOptionsContext(precision); }, + "precision"_a, + R"pbdoc( + Context manager for setting print options temporarily. + + Args: + precision (int): Number of decimal places. + )pbdoc"); + // Types nb::class_( m, diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 86328c2a1b..6fdb27c61f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -597,6 +597,27 @@ def test_array_repr(self): x = mx.array([1 - 1j], dtype=mx.complex64) expected = "array([1-1j], dtype=complex64)" + def test_array_repr_precision(self): + x = mx.array([1.123456789], dtype=mx.float32) + expected = "array([1.12346], dtype=float32)" + self.assertEqual(str(x), expected) + + with mx.printoptions(precision=4): + expected = "array([1.1235], dtype=float32)" + self.assertEqual(str(x), expected) + + mx.set_printoptions(precision=2) + expected = "array([1.12], dtype=float32)" + self.assertEqual(str(x), expected) + + x = mx.sin(x) + expected = "array([0.90], dtype=float32)" + self.assertEqual(str(x), expected) + + with mx.printoptions(precision=4): + expected = "array([0.9016], dtype=float32)" + self.assertEqual(str(x), expected) + def test_array_to_list(self): types = [mx.bool_, mx.uint32, mx.int32, mx.int64, mx.float32] for t in types: