Skip to content

Commit 0b961e5

Browse files
Conchylicultorcopybara-github
authored andcommitted
Fix typing annotations for get_parameters_overview
PiperOrigin-RevId: 612454326
1 parent eed40a1 commit 0b961e5

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

clu/parameter_overview.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,15 +259,17 @@ def _get_parameter_overview(
259259
def get_parameter_overview(
260260
params: _ParamsContainer,
261261
*,
262-
include_stats: bool = True,
262+
include_stats: bool | str = True,
263263
max_lines: int | None = None,
264264
) -> str:
265265
"""Returns a string with variables names, their shapes, count.
266266
267267
Args:
268268
params: Dictionary with parameters as NumPy arrays. The dictionary can be
269269
nested.
270-
include_stats: If True, add columns with mean and std for each variable.
270+
include_stats: If True, add columns with mean and std for each variable. If
271+
the string "global", params are sharded global arrays and this function
272+
assumes it is called on every host, i.e. can use collectives.
271273
max_lines: If not `None`, the maximum number of variables to include.
272274
273275
Returns:

0 commit comments

Comments
 (0)