|
21 | 21 | thread on the next write_*() call. |
22 | 22 | """ |
23 | 23 |
|
| 24 | +from collections.abc import Mapping, Sequence |
24 | 25 | import contextlib |
25 | | -from typing import Any, Mapping, Optional, Sequence |
| 26 | +from typing import Any, Optional, Union |
26 | 27 |
|
27 | 28 | from clu import asynclib |
28 | 29 |
|
29 | 30 | from clu.metric_writers import interface |
30 | 31 | from clu.metric_writers import multi_writer |
31 | 32 | import wrapt |
32 | 33 |
|
| 34 | + |
33 | 35 | Array = interface.Array |
34 | 36 | Scalar = interface.Scalar |
35 | 37 |
|
@@ -95,21 +97,44 @@ def write_videos(self, step: int, videos: Mapping[str, Array]): |
95 | 97 |
|
96 | 98 | @_wrap_exceptions |
97 | 99 | def write_audios( |
98 | | - self, step: int, audios: Mapping[str, Array], *, sample_rate: int): |
| 100 | + self, step: int, audios: Mapping[str, Array], *, sample_rate: int |
| 101 | + ): |
99 | 102 | self._pool(self._writer.write_audios)( |
100 | | - step=step, audios=audios, sample_rate=sample_rate) |
| 103 | + step=step, audios=audios, sample_rate=sample_rate |
| 104 | + ) |
101 | 105 |
|
102 | 106 | @_wrap_exceptions |
103 | 107 | def write_texts(self, step: int, texts: Mapping[str, str]): |
104 | 108 | self._pool(self._writer.write_texts)(step=step, texts=texts) |
105 | 109 |
|
106 | 110 | @_wrap_exceptions |
107 | | - def write_histograms(self, |
108 | | - step: int, |
109 | | - arrays: Mapping[str, Array], |
110 | | - num_buckets: Optional[Mapping[str, int]] = None): |
| 111 | + def write_histograms( |
| 112 | + self, |
| 113 | + step: int, |
| 114 | + arrays: Mapping[str, Array], |
| 115 | + num_buckets: Optional[Mapping[str, int]] = None, |
| 116 | + ): |
111 | 117 | self._pool(self._writer.write_histograms)( |
112 | | - step=step, arrays=arrays, num_buckets=num_buckets) |
| 118 | + step=step, arrays=arrays, num_buckets=num_buckets |
| 119 | + ) |
| 120 | + |
| 121 | + @_wrap_exceptions |
| 122 | + def write_pointcloud( |
| 123 | + self, |
| 124 | + step: int, |
| 125 | + point_clouds: Mapping[str, Array], |
| 126 | + *, |
| 127 | + point_colors: Optional[Array] = None, |
| 128 | + configs: Optional[ |
| 129 | + Mapping[str, Union[str, int, float, bool, None]] |
| 130 | + ] = None, |
| 131 | + ): |
| 132 | + self._pool(self._writer.write_pointcloud)( |
| 133 | + step=step, |
| 134 | + point_clouds=point_clouds, |
| 135 | + point_colors=point_colors, |
| 136 | + configs=configs, |
| 137 | + ) |
113 | 138 |
|
114 | 139 | @_wrap_exceptions |
115 | 140 | def write_hparams(self, hparams: Mapping[str, Any]): |
|
0 commit comments