Skip to content

Commit d090965

Browse files
CLU Authorscopybara-github
authored andcommitted
Add point clouds summary writer to tensorboard interface and metric writer.
PiperOrigin-RevId: 653292723
1 parent 92334e9 commit d090965

File tree

10 files changed

+254
-21
lines changed

10 files changed

+254
-21
lines changed

clu/metric_writers/async_writer.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@
2121
thread on the next write_*() call.
2222
"""
2323

24+
from collections.abc import Mapping, Sequence
2425
import contextlib
25-
from typing import Any, Mapping, Optional, Sequence
26+
from typing import Any, Optional, Union
2627

2728
from clu import asynclib
2829

2930
from clu.metric_writers import interface
3031
from clu.metric_writers import multi_writer
3132
import wrapt
3233

34+
3335
Array = interface.Array
3436
Scalar = interface.Scalar
3537

@@ -95,21 +97,44 @@ def write_videos(self, step: int, videos: Mapping[str, Array]):
9597

9698
@_wrap_exceptions
9799
def write_audios(
98-
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
100+
self, step: int, audios: Mapping[str, Array], *, sample_rate: int
101+
):
99102
self._pool(self._writer.write_audios)(
100-
step=step, audios=audios, sample_rate=sample_rate)
103+
step=step, audios=audios, sample_rate=sample_rate
104+
)
101105

102106
@_wrap_exceptions
103107
def write_texts(self, step: int, texts: Mapping[str, str]):
104108
self._pool(self._writer.write_texts)(step=step, texts=texts)
105109

106110
@_wrap_exceptions
107-
def write_histograms(self,
108-
step: int,
109-
arrays: Mapping[str, Array],
110-
num_buckets: Optional[Mapping[str, int]] = None):
111+
def write_histograms(
112+
self,
113+
step: int,
114+
arrays: Mapping[str, Array],
115+
num_buckets: Optional[Mapping[str, int]] = None,
116+
):
111117
self._pool(self._writer.write_histograms)(
112-
step=step, arrays=arrays, num_buckets=num_buckets)
118+
step=step, arrays=arrays, num_buckets=num_buckets
119+
)
120+
121+
@_wrap_exceptions
122+
def write_pointcloud(
123+
self,
124+
step: int,
125+
point_clouds: Mapping[str, Array],
126+
*,
127+
point_colors: Optional[Array] = None,
128+
configs: Optional[
129+
Mapping[str, Union[str, int, float, bool, None]]
130+
] = None,
131+
):
132+
self._pool(self._writer.write_pointcloud)(
133+
step=step,
134+
point_clouds=point_clouds,
135+
point_colors=point_colors,
136+
configs=configs,
137+
)
113138

114139
@_wrap_exceptions
115140
def write_hparams(self, hparams: Mapping[str, Any]):

clu/metric_writers/async_writer_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,27 @@ def test_write_videos(self):
7272
self.sync_writer.write_videos.assert_called_with(4,
7373
{"input_videos": mock.ANY})
7474

75+
def test_write_pointcloud(self):
76+
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
77+
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
78+
config = {
79+
"material": "PointCloudMaterial",
80+
"size": 0.09,
81+
}
82+
self.writer.write_pointcloud(
83+
step=0,
84+
point_clouds={"pcd": point_clouds},
85+
point_colors={"pcd": point_colors},
86+
configs={"config": config},
87+
)
88+
self.writer.flush()
89+
self.sync_writer.write_pointcloud.assert_called_with(
90+
step=0,
91+
point_clouds={"pcd": mock.ANY},
92+
point_colors={"pcd": mock.ANY},
93+
configs={"config": mock.ANY},
94+
)
95+
7596
def test_write_texts(self):
7697
self.writer.write_texts(4, {"samples": "bla"})
7798
self.writer.flush()

clu/metric_writers/interface.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
"""
2121

2222
import abc
23-
from typing import Any, Mapping, Optional, Union
23+
from collections.abc import Mapping
24+
from typing import Any, Optional, Union
2425

2526
import jax.numpy as jnp
2627
import numpy as np
@@ -152,6 +153,28 @@ def write_histograms(self,
152153
of the MetricWriter.
153154
"""
154155

156+
@abc.abstractmethod
157+
def write_pointcloud(
158+
self,
159+
step: int,
160+
point_clouds: Mapping[str, Array],
161+
*,
162+
point_colors: Optional[Mapping[str, Array]] = None,
163+
configs: Optional[
164+
Mapping[str, Union[str, int, float, bool, None]]
165+
] = None,
166+
):
167+
"""Writes point cloud summaries.
168+
169+
Args:
170+
step: Step at which the point cloud was generated.
171+
point_clouds: Mapping from point clouds key to point cloud of shape [N, 3]
172+
array of point coordinates.
173+
point_colors: Mapping from point colors key to [N, 3] array of point
174+
colors.
175+
configs: A dictionary of configuration options for the point cloud.
176+
"""
177+
155178
@abc.abstractmethod
156179
def write_hparams(self, hparams: Mapping[str, Any]):
157180
"""Write hyper parameters.

clu/metric_writers/logging_writer.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
"""MetricWriter that writes all values to INFO log."""
1616

17-
from typing import Any, Mapping, Optional, Tuple
17+
from collections.abc import Mapping
18+
from typing import Any, Optional, Union
1819

1920
from absl import logging
2021
from clu.metric_writers import interface
@@ -76,6 +77,29 @@ def write_histograms(self,
7677
self._collection_str, key,
7778
_get_histogram_as_string(histo, bins))
7879

80+
def write_pointcloud(
81+
self,
82+
step: int,
83+
point_clouds: Mapping[str, Array],
84+
*,
85+
point_colors: Optional[Mapping[str, Any]] = None,
86+
configs: Optional[
87+
Mapping[str, Union[str, int, float, bool, None]]
88+
] = None,
89+
):
90+
logging.info(
91+
"[%d]%s Got point clouds: %s, point_colors: %s, configs: %s.",
92+
step,
93+
self._collection_str,
94+
{k: v.shape for k, v in point_clouds.items()},
95+
(
96+
{k: v.shape for k, v in point_colors.items()}
97+
if point_colors is not None
98+
else None
99+
),
100+
configs,
101+
)
102+
79103
def write_hparams(self, hparams: Mapping[str, Any]):
80104
logging.info("[Hyperparameters]%s %s", self._collection_str, hparams)
81105

@@ -89,7 +113,7 @@ def close(self):
89113
def _compute_histogram_as_tf(
90114
array: np.ndarray,
91115
num_buckets: Optional[int] = None
92-
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
116+
) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
93117
"""Compute the histogram of the input array as TF would do.
94118
95119
Args:

clu/metric_writers/logging_writer_test.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,42 @@ def test_write_histogram(self):
8080
"INFO:absl:[4] Histogram for 'c' = {[-0.4, 0.6]: 5}",
8181
])
8282

83+
def test_write_pointcloud(self):
84+
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
85+
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
86+
config = {
87+
"material": "PointCloudMaterial",
88+
"size": 0.09,
89+
}
90+
with self.assertLogs(level="INFO") as logs:
91+
self.writer.write_pointcloud(
92+
step=4,
93+
point_clouds={"pcd": point_clouds},
94+
point_colors={"pcd": point_colors},
95+
configs={"configs": config},
96+
)
97+
self.assertEqual(
98+
logs.output,
99+
[
100+
"INFO:absl:[4] Got point clouds: {'pcd': (1, 1024, 3)},"
101+
" point_colors: {'pcd': (1, 1024, 3)}, configs: {'configs':"
102+
" {'material': 'PointCloudMaterial', 'size': 0.09}}."
103+
],
104+
)
105+
83106
def test_write_hparams(self):
84107
with self.assertLogs(level="INFO") as logs:
85108
self.writer.write_hparams({"learning_rate": 0.1, "batch_size": 128})
86-
self.assertEqual(logs.output, [
87-
"INFO:absl:[Hyperparameters] {'learning_rate': 0.1, 'batch_size': 128}"
88-
])
109+
self.assertEqual(
110+
logs.output,
111+
[
112+
"INFO:absl:[Hyperparameters] {'learning_rate': 0.1, 'batch_size':"
113+
" 128}"
114+
],
115+
)
89116

90117
def test_collection(self):
118+
writer = logging_writer.LoggingWriter(collection="train")
91119
writer = logging_writer.LoggingWriter(collection="train")
92120
with self.assertLogs(level="INFO") as logs:
93121
writer.write_scalars(0, {"a": 3, "b": 0.15})

clu/metric_writers/multi_writer.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
"""MetricWriter that writes to multiple MetricWriters."""
1616

17-
from typing import Any, Mapping, Optional, Sequence
17+
from collections.abc import Mapping, Sequence
18+
from typing import Any, Optional, Union
1819

1920
from clu.metric_writers import interface
2021

@@ -56,13 +57,27 @@ def write_texts(self, step: int, texts: Mapping[str, str]):
5657
for w in self._writers:
5758
w.write_texts(step, texts)
5859

59-
def write_histograms(self,
60-
step: int,
61-
arrays: Mapping[str, Array],
62-
num_buckets: Optional[Mapping[str, int]] = None):
60+
def write_histograms(
61+
self,
62+
step: int,
63+
arrays: Mapping[str, Array],
64+
num_buckets: Optional[Mapping[str, int]] = None):
6365
for w in self._writers:
6466
w.write_histograms(step, arrays, num_buckets)
6567

68+
def write_pointcloud(
69+
self,
70+
step: int,
71+
point_clouds: Mapping[str, Array],
72+
*,
73+
point_colors: Optional[Mapping[str, Array]] = None,
74+
configs: Optional[Mapping[str, Union[str, int, float, bool, None]]] = None
75+
):
76+
for w in self._writers:
77+
w.write_pointcloud(
78+
step, point_clouds, point_colors=point_colors, configs=configs
79+
)
80+
6681
def write_hparams(self, hparams: Mapping[str, Any]):
6782
for w in self._writers:
6883
w.write_hparams(hparams)

clu/metric_writers/multi_writer_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from clu.metric_writers import interface
2020
from clu.metric_writers import multi_writer
21+
import numpy as np
2122
import tensorflow as tf
2223

2324

@@ -48,6 +49,29 @@ def test_write_scalars(self):
4849
])
4950
w.flush.assert_called()
5051

52+
def test_write_pointcloud(self):
53+
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
54+
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
55+
config = {
56+
"material": "PointCloudMaterial",
57+
"size": 0.09,
58+
}
59+
self.writer.write_pointcloud(
60+
step=0,
61+
point_clouds={"pcd": point_clouds},
62+
point_colors={"pcd": point_colors},
63+
configs={"config": config},
64+
)
65+
self.writer.flush()
66+
for w in self.writers:
67+
w.write_pointcloud.assert_called_with(
68+
step=0,
69+
point_clouds={"pcd": point_clouds},
70+
point_colors={"pcd": point_colors},
71+
configs={"config": config},
72+
)
73+
w.flush.assert_called()
74+
5175

5276
if __name__ == "__main__":
5377
tf.test.main()

clu/metric_writers/tf/summary_writer.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""
2020

2121
from collections.abc import Mapping
22-
from typing import Any, Optional
22+
from typing import Any, Optional, Union
2323

2424
from absl import logging
2525

@@ -31,6 +31,7 @@
3131
with epy.lazy_imports():
3232
# pylint: disable=g-import-not-at-top
3333
from tensorboard.plugins.hparams import api as hparams_api
34+
from tensorboard.plugins.mesh import summary as mesh_summary # pylint: disable=line-too-long
3435
# pylint: enable=g-import-not-at-top
3536

3637

@@ -97,6 +98,28 @@ def write_histograms(
9798
buckets = None if num_buckets is None else num_buckets.get(key)
9899
tf.summary.histogram(key, value, step=step, buckets=buckets)
99100

101+
def write_pointcloud(
102+
self,
103+
step: int,
104+
point_clouds: Mapping[str, Array],
105+
*,
106+
point_colors: Optional[Mapping[str, Array]] = None,
107+
configs: Optional[
108+
Mapping[str, Union[str, int, float, bool, None]]
109+
] = None,
110+
):
111+
with self._summary_writer.as_default():
112+
for key, vertices in point_clouds.items():
113+
colors = None if point_colors is None else point_colors.get(key)
114+
config = None if configs is None else configs.get(key)
115+
mesh_summary.mesh(
116+
key,
117+
vertices=vertices,
118+
colors=colors,
119+
step=step,
120+
config_dict=config,
121+
)
122+
100123
def write_hparams(self, hparams: Mapping[str, Any]):
101124
with self._summary_writer.as_default():
102125
hparams_api.hparams(dict(utils.flatten_dict(hparams)))

0 commit comments

Comments
 (0)