@@ -82,13 +82,17 @@ def flatten_dict(
def _count_parameters(params: _ParamsContainer) -> int:
  """Returns the count of variables for the module or parameter dictionary.

  Args:
    params: Parameter container accepted by `flatten_dict`. Leaves that are
      `None` are skipped and contribute nothing to the count.

  Returns:
    The total number of scalar elements across all non-`None` leaves, as a
    plain Python `int`.
  """
  params = flatten_dict(params)
  # `np.prod` of an empty shape `()` is the float `1.0`, and of a non-empty
  # shape is a NumPy integer scalar. Coerce each term to a Python `int` so the
  # total always honours the annotated return type (e.g. 11, never 11.0).
  return sum(
      int(np.prod(v.shape)) for v in params.values() if v is not None
  )
8686
8787
def _parameters_size(params: _ParamsContainer) -> int:
  """Returns total size (bytes) for the module or parameter dictionary.

  Args:
    params: Parameter container accepted by `flatten_dict`. Leaves that are
      `None` are skipped and contribute zero bytes.

  Returns:
    The sum over all non-`None` leaves of (number of elements) times
    (`dtype.itemsize`), as a plain Python `int`.
  """
  params = flatten_dict(params)
  # `np.prod` of an empty shape `()` is the float `1.0`; coerce the element
  # count to a Python `int` so scalar leaves do not turn the byte total into a
  # float or NumPy scalar.
  return sum(
      int(np.prod(v.shape)) * int(v.dtype.itemsize)
      for v in params.values()
      if v is not None
  )
9296
9397
9498def count_parameters (params : _ParamsContainer ) -> int :
@@ -127,6 +131,8 @@ def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:
127131
128132def _make_row_with_stats (name , value , mean , std ) -> _ParamRowWithStats :
129133 row = _make_row (name , value )
134+ mean = mean or 0.0
135+ std = std or 0.0
130136 return _ParamRowWithStats (
131137 ** dataclasses .asdict (row ),
132138 mean = float (jax .device_get (mean )),
@@ -156,12 +162,11 @@ def _get_parameter_rows(
156162 params: Dictionary with parameters as NumPy arrays. The dictionary can be
157163 nested. Alternatively a `tf.Module` can be provided, in which case the
158164 `trainable_variables` of the module will be used.
159- include_stats: If True, add columns with mean and std for each variable.
160- If the string "sharding", add column a column with the sharding of the
161- variable.
162- If the string "global", params are sharded global arrays and this
163- function assumes it is called on every host, i.e. can use collectives.
164- The sharding of the variables is also added as a column.
165+ include_stats: If True, add columns with mean and std for each variable. If
166+ the string "sharding", add a column with the sharding of the
167+ variable. If the string "global", params are sharded global arrays and
168+ this function assumes it is called on every host, i.e. can use
169+ collectives. The sharding of the variables is also added as a column.
165170
166171 Returns:
167172 A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value
@@ -185,12 +190,14 @@ def _get_parameter_rows(
185190 case True :
186191 mean_and_std = _mean_std (values )
187192 return jax .tree_util .tree_map (
188- _make_row_with_stats , names , values , * mean_and_std )
193+ _make_row_with_stats , names , values , * mean_and_std
194+ )
189195
190196 case "global" :
191197 mean_and_std = _mean_std_jit (values )
192198 return jax .tree_util .tree_map (
193- _make_row_with_stats_and_sharding , names , values , * mean_and_std )
199+ _make_row_with_stats_and_sharding , names , values , * mean_and_std
200+ )
194201
195202 case "sharding" :
196203 return jax .tree_util .tree_map (_make_row_with_sharding , names , values )
@@ -256,8 +263,7 @@ def __init__(self, name, values):
256263 column_names = [field .name for field in dataclasses .fields (rows [0 ])]
257264
258265 columns = [
259- Column (name , [value_formatter (getattr (row , name ))
260- for row in rows ])
266+ Column (name , [value_formatter (getattr (row , name )) for row in rows ])
261267 for name in column_names
262268 ]
263269
@@ -312,12 +318,11 @@ def get_parameter_overview(
312318 Args:
313319 params: Dictionary with parameters as NumPy arrays. The dictionary can be
314320 nested.
315- include_stats: If True, add columns with mean and std for each variable.
316- If the string "sharding", add column a column with the sharding of the
317- variable.
318- If the string "global", params are sharded global arrays and this
319- function assumes it is called on every host, i.e. can use collectives.
320- The sharding of the variables is also added as a column.
321+ include_stats: If True, add columns with mean and std for each variable. If
322+ the string "sharding", add a column with the sharding of the
323+ variable. If the string "global", params are sharded global arrays and
324+ this function assumes it is called on every host, i.e. can use
325+ collectives. The sharding of the variables is also added as a column.
321326 max_lines: If not `None`, the maximum number of variables to include.
322327
323328 Returns:
@@ -375,16 +380,19 @@ def log_parameter_overview(
375380 Args:
376381 params: Dictionary with parameters as NumPy arrays. The dictionary can be
377382 nested.
378- include_stats: If True, add columns with mean and std for each variable.
379- If the string "global", params are sharded global arrays and this
380- function assumes it is called on every host, i.e. can use collectives.
383+ include_stats: If True, add columns with mean and std for each variable. If
384+ the string "global", params are sharded global arrays and this function
385+ assumes it is called on every host, i.e. can use collectives.
381386 max_lines: If not `None`, the maximum number of variables to include.
382387 msg: Message to be logged before the overview.
383388 jax_logging_process: Which JAX process ID should do the logging. None = all.
384389 Use this to avoid logspam when include_stats="global".
385390 """
386391
387392 _log_parameter_overview (
388- params , include_stats = include_stats , max_lines = max_lines , msg = msg ,
389- jax_logging_process = jax_logging_process
393+ params ,
394+ include_stats = include_stats ,
395+ max_lines = max_lines ,
396+ msg = msg ,
397+ jax_logging_process = jax_logging_process ,
390398 )
0 commit comments