From 0803beeef5500ed25296df51f66b5e999e90c880 Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Sun, 29 Mar 2026 16:11:25 +0200 Subject: [PATCH 1/5] feat: add print options for precision --- mlx/utils.cpp | 53 +++++++++++++++++++++++++++++++++++--------- mlx/utils.h | 3 +++ python/mlx/utils.py | 3 ++- python/src/array.cpp | 53 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 12 deletions(-) diff --git a/mlx/utils.cpp b/mlx/utils.cpp index cf0e0f38db..f4188898a4 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -1,6 +1,7 @@ // Copyright © 2023 Apple Inc. #include +#include #include #include #include @@ -39,7 +40,7 @@ void PrintFormatter::print(std::ostream& os, bool val) { } } inline void PrintFormatter::print(std::ostream& os, int16_t val) { - os << val; + os << val; } inline void PrintFormatter::print(std::ostream& os, uint16_t val) { os << val; @@ -57,24 +58,49 @@ inline void PrintFormatter::print(std::ostream& os, uint64_t val) { os << val; } inline void PrintFormatter::print(std::ostream& os, float16_t val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, float val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, double val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, complex64_t val) { - os << val.real(); - if (val.imag() >= 0 || std::isnan(val.imag())) { - os << "+" << val.imag() << "j"; - } else { - os << "-" << -val.imag() << "j"; - } + if (precision == -1) { + os << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << val.imag() << "j"; + } else { + os << "-" << -val.imag() << "j"; + } + } else { + os << std::fixed << std::setprecision(precision) << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << std::fixed << std::setprecision(precision) << val.imag() << "j"; + } else { + os << "-" << std::fixed << std::setprecision(precision) << -val.imag() << "j"; + } + } } PrintFormatter& get_global_formatter() { @@ -82,6 +108,11 @@ PrintFormatter& get_global_formatter() { return formatter; } +void set_printoptions(int precision) { + auto &formatter = get_global_formatter(); + formatter.precision = precision; +} + void abort_with_exception(const std::exception& error) { std::ostringstream msg; msg << "Terminating due to uncaught exception: " << error.what(); diff --git a/mlx/utils.h b/mlx/utils.h index 62aa82b658..6be5f09ca2 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -53,8 +53,11 @@ struct PrintFormatter { inline void print(std::ostream& os, complex64_t val); bool capitalize_bool{false}; + int precision{-1}; }; +MLX_API void set_printoptions(int precision); + MLX_API PrintFormatter& get_global_formatter(); /** Print the exception and then abort. */ diff --git a/python/mlx/utils.py b/python/mlx/utils.py index f4aafe1e3d..35a8829485 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,8 +1,9 @@ # Copyright © 2023 Apple Inc. from collections import defaultdict from itertools import zip_longest +from multiprocessing import context from typing import Any, Callable, Dict, List, Optional, Tuple, Union - +from contextlib import contextmanager def tree_map( fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None diff --git a/python/src/array.cpp b/python/src/array.cpp index 231474c2d9..1a07e16463 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -96,10 +96,63 @@ class ArrayPythonIterator { std::vector splits_; }; +struct PrintOptionsContext { + int old_precision; + int new_precision; + PrintOptionsContext(int p) : new_precision(p) {} + PrintOptionsContext& __enter__() { + old_precision = mx::get_global_formatter().precision; + mx::set_printoptions(new_precision); + return *this; + } + void __exit__(nb::args) { + mx::set_printoptions(old_precision); + } +}; + void init_array(nb::module_& m) { // Set Python print formatting options mx::get_global_formatter().capitalize_bool = true; + // Expose printing options to Python: allow setting global precision. + m.def( + "set_printoptions", + &mx::set_printoptions, + "precision"_a, + R"pbdoc( + Set global printing precision for array formatting. + + Args: + precision (int): Number of decimal places to use when printing + floating point numbers in arrays. + )pbdoc"); + m.def( + "get_printoptions", + []() { return mx::get_global_formatter().precision; }, + R"pbdoc( + Get global printing precision for array formatting. + + Returns: + int: The number of decimal places used when printing floating point + numbers in arrays. + )pbdoc"); + + nb::class_(m, "_PrintOptionsContext") + .def(nb::init()) + .def("__enter__", &PrintOptionsContext::__enter__) + .def("__exit__", &PrintOptionsContext::__exit__); + + m.def( + "printoptions", + [](int precision) { return PrintOptionsContext(precision); }, + "precision"_a, + R"pbdoc( + Context manager for setting print options temporarily. + + Args: + precision (int): Number of decimal places. + )pbdoc"); + // Types nb::class_( m, From 73d01b03c0782bdc680077d244e2dc77051078cd Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Sun, 29 Mar 2026 16:11:42 +0200 Subject: [PATCH 2/5] test: add test for print options --- python/tests/test_array.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 86328c2a1b..acf70e0e07 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -597,6 +597,28 @@ def test_array_repr(self): x = mx.array([1 - 1j], dtype=mx.complex64) expected = "array([1-1j], dtype=complex64)" + def test_array_repr_precision(self): + x = mx.array([1.123456789], dtype=mx.float32) + expected = "array([1.12346], dtype=float32)" + self.assertEqual(str(x), expected) + + with mx.printoptions(precision=4): + expected = "array([1.1235], dtype=float32)" + self.assertEqual(str(x), expected) + + mx.set_printoptions(precision=2) + expected = "array([1.12], dtype=float32)" + self.assertEqual(str(x), expected) + + x = mx.sin(x) + expected = "array([0.90], dtype=float32)" + self.assertEqual(str(x), expected) + + with mx.printoptions(precision=4): + expected = "array([0.9016], dtype=float32)" + self.assertEqual(str(x), expected) + + def test_array_to_list(self): types = [mx.bool_, mx.uint32, mx.int32, mx.int64, mx.float32] for t in types: From c3ce56e0506996ae57c28e400e19a835c87e6e37 Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Sun, 29 Mar 2026 16:28:18 +0200 Subject: [PATCH 3/5] style: reformat code --- mlx/utils.cpp | 72 ++++++++++++++++++++------------------ python/mlx/utils.py | 3 +- python/tests/test_array.py | 1 - 3 files changed, 39 insertions(+), 37 deletions(-) diff --git a/mlx/utils.cpp b/mlx/utils.cpp index f4188898a4..56b69d932a 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -40,7 +40,7 @@ void PrintFormatter::print(std::ostream& os, bool val) { } } inline void PrintFormatter::print(std::ostream& os, int16_t val) { - os << val; + os << val; } inline void PrintFormatter::print(std::ostream& os, uint16_t val) { os << val; @@ -58,49 +58,51 @@ inline void PrintFormatter::print(std::ostream& os, uint64_t val) { os << val; } inline void PrintFormatter::print(std::ostream& os, float16_t val) { - if (precision == -1) { - os << val; - } else { - os << std::fixed << std::setprecision(precision) << val; - } + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) { - if (precision == -1) { - os << val; - } else { - os << std::fixed << std::setprecision(precision) << val; - } + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, float val) { - if (precision == -1) { - os << val; - } else { - os << std::fixed << std::setprecision(precision) << val; - } + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, double val) { - if (precision == -1) { - os << val; - } else { - os << std::fixed << std::setprecision(precision) << val; - } + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, complex64_t val) { - if (precision == -1) { - os << val.real(); - if (val.imag() >= 0 || std::isnan(val.imag())) { - os << "+" << val.imag() << "j"; - } else { - os << "-" << -val.imag() << "j"; - } + if (precision == -1) { + os << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << val.imag() << "j"; + } else { + os << "-" << -val.imag() << "j"; + } + } else { + os << std::fixed << std::setprecision(precision) << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << std::fixed << std::setprecision(precision) << val.imag() + << "j"; } else { - os << std::fixed << std::setprecision(precision) << val.real(); - if (val.imag() >= 0 || std::isnan(val.imag())) { - os << "+" << std::fixed << std::setprecision(precision) << val.imag() << "j"; - } else { - os << "-" << std::fixed << std::setprecision(precision) << -val.imag() << "j"; - } + os << "-" << std::fixed << std::setprecision(precision) << -val.imag() + << "j"; } + } } PrintFormatter& get_global_formatter() { @@ -109,7 +111,7 @@ PrintFormatter& get_global_formatter() { } void set_printoptions(int precision) { - auto &formatter = get_global_formatter(); + auto& formatter = get_global_formatter(); formatter.precision = precision; } diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 35a8829485..c5fed71429 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,9 +1,10 @@ # Copyright © 2023 Apple Inc. from collections import defaultdict +from contextlib import contextmanager from itertools import zip_longest from multiprocessing import context from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from contextlib import contextmanager + def tree_map( fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None diff --git a/python/tests/test_array.py b/python/tests/test_array.py index acf70e0e07..6fdb27c61f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -618,7 +618,6 @@ def test_array_repr_precision(self): expected = "array([0.9016], dtype=float32)" self.assertEqual(str(x), expected) - def test_array_to_list(self): types = [mx.bool_, mx.uint32, mx.int32, mx.int64, mx.float32] for t in types: From 64f8413629d1c2709d8fdf0414fd059c9d6957b0 Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Sun, 29 Mar 2026 16:44:29 +0200 Subject: [PATCH 4/5] docs: add documentation for printoptions --- docs/src/index.rst | 3 ++- docs/src/python/printoptions.rst | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 docs/src/python/printoptions.rst diff --git a/docs/src/index.rst b/docs/src/index.rst index 74c52aaa2b..46d069929f 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -32,7 +32,7 @@ are the CPU and GPU. install .. toctree:: - :caption: Usage + :caption: Usage :maxdepth: 1 usage/quick_start @@ -78,6 +78,7 @@ are the CPU and GPU. python/optimizers python/distributed python/tree_utils + python/printoptions .. toctree:: :caption: C++ API Reference diff --git a/docs/src/python/printoptions.rst b/docs/src/python/printoptions.rst new file mode 100644 index 0000000000..ee7d3c191a --- /dev/null +++ b/docs/src/python/printoptions.rst @@ -0,0 +1,11 @@ +Print Options +============ + +.. currentmodule:: mlx.core + +.. autosummary:: + :toctree: _autosummary + + set_printoptions + printoptions + get_printoptions From 8de3e6a1c122d5e8f6489bf7df7a1b4fccc672d9 Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Sun, 29 Mar 2026 16:50:06 +0200 Subject: [PATCH 5/5] fix: remove unused deps --- python/mlx/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/mlx/utils.py b/python/mlx/utils.py index c5fed71429..f4aafe1e3d 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,8 +1,6 @@ # Copyright © 2023 Apple Inc. from collections import defaultdict -from contextlib import contextmanager from itertools import zip_longest -from multiprocessing import context from typing import Any, Callable, Dict, List, Optional, Tuple, Union