2525from clu .metric_writers import utils
2626from clu .metric_writers .async_writer import AsyncMultiWriter
2727from clu .metric_writers .async_writer import AsyncWriter
28- from clu .metric_writers .summary_writer import SummaryWriter
2928from clu .metric_writers .interface import MetricWriter
3029from clu .metric_writers .logging_writer import LoggingWriter
3130from clu .metric_writers .multi_writer import MultiWriter
31+ from clu .metric_writers .summary_writer import SummaryWriter
3232import clu .metrics
3333import flax .struct
3434import jax .numpy as jnp
@@ -129,17 +129,20 @@ def test_write(self):
129129 "image" : ImageMetric (jnp .asarray ([[4 , 5 ], [1 , 2 ]])),
130130 }
131131 histogram_metrics = {
132- "hist" :
133- HistogramMetric (value = jnp .asarray ([7 , 8 ]), num_buckets = num_buckets ),
134- "hist2" :
135- HistogramMetric (
136- value = jnp .asarray ([9 , 10 ]), num_buckets = num_buckets ),
132+ "hist" : HistogramMetric (
133+ value = jnp .asarray ([7 , 8 ]), num_buckets = num_buckets
134+ ),
135+ "hist2" : HistogramMetric (
136+ value = jnp .asarray ([9 , 10 ]), num_buckets = num_buckets
137+ ),
137138 }
138139 audio_metrics = {
139- "audio" :
140- AudioMetric (value = jnp .asarray ([1 , 5 ]), sample_rate = sample_rate ),
141- "audio2" :
142- AudioMetric (value = jnp .asarray ([1 , 5 ]), sample_rate = sample_rate + 2 ),
140+ "audio" : AudioMetric (
141+ value = jnp .asarray ([1 , 5 ]), sample_rate = sample_rate
142+ ),
143+ "audio2" : AudioMetric (
144+ value = jnp .asarray ([1 , 5 ]), sample_rate = sample_rate + 2
145+ ),
143146 }
144147 text_metrics = {
145148 "text" : TextMetric (value = "hello" ),
@@ -148,10 +151,10 @@ def test_write(self):
148151 "lr" : HyperParamMetric (value = 0.01 ),
149152 }
150153 summary_metrics = {
151- "summary" :
152- SummaryMetric ( value = jnp .asarray ([2 , 3 , 10 ]), metadata = "some info" ),
153- "summary2" :
154- SummaryMetric (value = jnp .asarray ([2 , 3 , 10 ]), metadata = 5 ),
154+ "summary" : SummaryMetric (
155+ value = jnp .asarray ([2 , 3 , 10 ]), metadata = "some info"
156+ ),
157+ "summary2" : SummaryMetric (value = jnp .asarray ([2 , 3 , 10 ]), metadata = 5 ),
155158 }
156159 metrics = {
157160 ** scalar_metrics ,
@@ -166,29 +169,36 @@ def test_write(self):
166169 utils .write_values (writer , step , metrics )
167170
168171 writer .write_scalars .assert_called_once_with (
169- step , {k : m .compute () for k , m in scalar_metrics .items ()})
170- writer .write_images .assert_called_once_with (step ,
171- _to_summary (image_metrics ))
172+ step , {k : m .compute () for k , m in scalar_metrics .items ()}
173+ )
174+ writer .write_images .assert_called_once_with (
175+ step , _to_summary (image_metrics )
176+ )
172177 writer .write_histograms .assert_called_once_with (
173178 step ,
174179 _to_summary (histogram_metrics ),
175- num_buckets = {k : v .num_buckets for k , v in histogram_metrics .items ()})
180+ num_buckets = {k : v .num_buckets for k , v in histogram_metrics .items ()},
181+ )
176182 writer .write_audios .assert_called_with (
177183 step ,
178184 ONEOF (_to_list_of_dicts (_to_summary (audio_metrics ))),
179- sample_rate = ONEOF ([sample_rate , sample_rate + 2 ]))
185+ sample_rate = ONEOF ([sample_rate , sample_rate + 2 ]),
186+ )
180187 writer .write_texts .assert_called_once_with (step , _to_summary (text_metrics ))
181- writer .write_hparams .assert_called_once_with (step ,
182- _to_summary (hparam_metrics ))
188+ writer .write_hparams .assert_called_once_with (
189+ step , _to_summary (hparam_metrics )
190+ )
183191 writer .write_summaries .assert_called_with (
184192 step ,
185193 ONEOF (_to_list_of_dicts (_to_summary (summary_metrics ))),
186- metadata = ONEOF (["some info" , 5 ]))
194+ metadata = ONEOF (["some info" , 5 ]),
195+ )
187196
188197
189198 def test_create_default_writer_summary_writer_is_added (self ):
190199 writer = utils .create_default_writer (
191- logdir = self .get_temp_dir (), asynchronous = False )
200+ logdir = self .get_temp_dir (), asynchronous = False
201+ )
192202 self .assertTrue (any (isinstance (w , SummaryWriter ) for w in writer ._writers ))
193203 writer = utils .create_default_writer (logdir = None , asynchronous = False )
194204 self .assertFalse (any (isinstance (w , SummaryWriter ) for w in writer ._writers ))
0 commit comments