-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Add printoptions #3333
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add printoptions #3333
Changes from all commits
0803bee
73d01b0
c3ce56e
64f8413
8de3e6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| Print Options | ||
| ============ | ||
|
|
||
| .. currentmodule:: mlx.core | ||
|
|
||
| .. autosummary:: | ||
| :toctree: _autosummary | ||
|
|
||
| set_printoptions | ||
| printoptions | ||
| get_printoptions | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,8 +53,11 @@ struct PrintFormatter { | |
| inline void print(std::ostream& os, complex64_t val); | ||
|
|
||
| bool capitalize_bool{false}; | ||
| int precision{-1}; | ||
| }; | ||
|
|
||
| MLX_API void set_printoptions(int precision); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you make the API take a |
||
|
|
||
| MLX_API PrintFormatter& get_global_formatter(); | ||
|
|
||
| /** Print the exception and then abort. */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -96,10 +96,63 @@ class ArrayPythonIterator { | |||||
| std::vector<mx::array> splits_; | ||||||
| }; | ||||||
|
|
||||||
| struct PrintOptionsContext { | ||||||
| int old_precision; | ||||||
| int new_precision; | ||||||
| PrintOptionsContext(int p) : new_precision(p) {} | ||||||
| PrintOptionsContext& __enter__() { | ||||||
| old_precision = mx::get_global_formatter().precision; | ||||||
| mx::set_printoptions(new_precision); | ||||||
| return *this; | ||||||
| } | ||||||
| void __exit__(nb::args) { | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the C++ code we don't add
Suggested change
|
||||||
| mx::set_printoptions(old_precision); | ||||||
| } | ||||||
| }; | ||||||
|
|
||||||
| void init_array(nb::module_& m) { | ||||||
| // Set Python print formatting options | ||||||
| mx::get_global_formatter().capitalize_bool = true; | ||||||
|
|
||||||
| // Expose printing options to Python: allow setting global precision. | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you move the print APIs to a new |
||||||
| m.def( | ||||||
| "set_printoptions", | ||||||
| &mx::set_printoptions, | ||||||
| "precision"_a, | ||||||
| R"pbdoc( | ||||||
| Set global printing precision for array formatting. | ||||||
|
|
||||||
| Args: | ||||||
| precision (int): Number of decimal places to use when printing | ||||||
| floating point numbers in arrays. | ||||||
| )pbdoc"); | ||||||
| m.def( | ||||||
| "get_printoptions", | ||||||
| []() { return mx::get_global_formatter().precision; }, | ||||||
| R"pbdoc( | ||||||
| Get global printing precision for array formatting. | ||||||
|
|
||||||
| Returns: | ||||||
| int: The number of decimal places used when printing floating point | ||||||
| numbers in arrays. | ||||||
| )pbdoc"); | ||||||
|
|
||||||
| nb::class_<PrintOptionsContext>(m, "_PrintOptionsContext") | ||||||
| .def(nb::init<int>()) | ||||||
| .def("__enter__", &PrintOptionsContext::__enter__) | ||||||
| .def("__exit__", &PrintOptionsContext::__exit__); | ||||||
|
|
||||||
| m.def( | ||||||
| "printoptions", | ||||||
| [](int precision) { return PrintOptionsContext(precision); }, | ||||||
| "precision"_a, | ||||||
| R"pbdoc( | ||||||
| Context manager for setting print options temporarily. | ||||||
|
|
||||||
| Args: | ||||||
| precision (int): Number of decimal places. | ||||||
| )pbdoc"); | ||||||
|
|
||||||
| // Types | ||||||
| nb::class_<mx::Dtype>( | ||||||
| m, | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you fix the warning when building docs?