From 3cc73949f7f9bea205c6b7f2a006744e52d46cd5 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Mon, 4 Mar 2024 05:21:11 -0800 Subject: [PATCH] Support sharding for get_parameter_overview PiperOrigin-RevId: 612412904 --- clu/parameter_overview.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/clu/parameter_overview.py b/clu/parameter_overview.py index 84cf435..ae2fd5c 100644 --- a/clu/parameter_overview.py +++ b/clu/parameter_overview.py @@ -259,7 +259,7 @@ def _get_parameter_overview( def get_parameter_overview( params: _ParamsContainer, *, - include_stats: bool = True, + include_stats: bool | str = True, max_lines: int | None = None, ) -> str: """Returns a string with variables names, their shapes, count. @@ -267,7 +267,9 @@ def get_parameter_overview( Args: params: Dictionary with parameters as NumPy arrays. The dictionary can be nested. - include_stats: If True, add columns with mean and std for each variable. + include_stats: If True, add columns with mean and std for each variable. If + the string "global", params are sharded global arrays and this function + assumes it is called on every host, i.e. can use collectives. max_lines: If not `None`, the maximum number of variables to include. Returns: