Skip to content

Commit 307b0bc

Browse files
CLU Authorscopybara-github
authored andcommitted
Add point clouds summary writer to kauldron interface with tensorboard and jaxboard.
PiperOrigin-RevId: 665840429
1 parent 7c22ddf commit 307b0bc

File tree

10 files changed

+206
-1
lines changed

10 files changed

+206
-1
lines changed

clu/metric_writers/async_writer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,22 @@ def write_histograms(self,
112112
self._pool(self._writer.write_histograms)(
113113
step=step, arrays=arrays, num_buckets=num_buckets)
114114

115+
@_wrap_exceptions
116+
def write_pointcloud(
117+
self,
118+
step: int,
119+
point_clouds: Mapping[str, Array],
120+
*,
121+
point_colors: Mapping[str, Array] | None = None,
122+
configs: Mapping[str, str | float | bool | None] | None = None,
123+
):
124+
self._pool(self._writer.write_pointcloud)(
125+
step=step,
126+
point_clouds=point_clouds,
127+
point_colors=point_colors,
128+
configs=configs,
129+
)
130+
115131
@_wrap_exceptions
116132
def write_hparams(self, hparams: Mapping[str, Any]):
117133
self._pool(self._writer.write_hparams)(hparams=hparams)

clu/metric_writers/async_writer_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,27 @@ def test_write_videos(self):
7272
self.sync_writer.write_videos.assert_called_with(4,
7373
{"input_videos": mock.ANY})
7474

75+
def test_write_pointcloud(self):
76+
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
77+
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
78+
config = {
79+
"material": "PointCloudMaterial",
80+
"size": 0.09,
81+
}
82+
self.writer.write_pointcloud(
83+
step=0,
84+
point_clouds={"pcd": point_clouds},
85+
point_colors={"pcd": point_colors},
86+
configs={"config": config},
87+
)
88+
self.writer.flush()
89+
self.sync_writer.write_pointcloud.assert_called_with(
90+
step=0,
91+
point_clouds={"pcd": mock.ANY},
92+
point_colors={"pcd": mock.ANY},
93+
configs={"config": mock.ANY},
94+
)
95+
7596
def test_write_texts(self):
7697
self.writer.write_texts(4, {"samples": "bla"})
7798
self.writer.flush()

clu/metric_writers/interface.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,26 @@ def write_histograms(self,
153153
of the MetricWriter.
154154
"""
155155

156+
def write_pointcloud(
157+
self,
158+
step: int,
159+
point_clouds: Mapping[str, Array],
160+
*,
161+
point_colors: Mapping[str, Array] | None = None,
162+
configs: Mapping[str, str | float | bool | None] | None = None,
163+
):
164+
"""Writes point cloud summaries.
165+
166+
Args:
167+
step: Step at which the point cloud was generated.
168+
point_clouds: Mapping from point clouds key to point cloud of shape [N, 3]
169+
array of point coordinates.
170+
point_colors: Mapping from point colors key to [N, 3] array of point
171+
colors.
172+
configs: A dictionary of configuration options for the point cloud.
173+
"""
174+
raise NotImplementedError()
175+
156176
@abc.abstractmethod
157177
def write_hparams(self, hparams: Mapping[str, Any]):
158178
"""Write hyper parameters.

clu/metric_writers/logging_writer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,27 @@ def write_histograms(self,
7777
self._collection_str, key,
7878
_get_histogram_as_string(histo, bins))
7979

80+
def write_pointcloud(
81+
self,
82+
step: int,
83+
point_clouds: Mapping[str, Array],
84+
*,
85+
point_colors: Mapping[str, Any] | None = None,
86+
configs: Mapping[str, str | float | bool | None] | None = None,
87+
):
88+
logging.info(
89+
"[%d]%s Got point clouds: %s, point_colors: %s, configs: %s.",
90+
step,
91+
self._collection_str,
92+
{k: v.shape for k, v in point_clouds.items()},
93+
(
94+
{k: v.shape for k, v in point_colors.items()}
95+
if point_colors is not None
96+
else None
97+
),
98+
configs,
99+
)
100+
80101
def write_hparams(self, hparams: Mapping[str, Any]):
81102
logging.info("[Hyperparameters]%s %s", self._collection_str, hparams)
82103

clu/metric_writers/logging_writer_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,29 @@ def test_write_histogram(self):
8080
"INFO:absl:[4] Histogram for 'c' = {[-0.4, 0.6]: 5}",
8181
])
8282

83+
def test_write_pointcloud(self):
84+
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
85+
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
86+
config = {
87+
"material": "PointCloudMaterial",
88+
"size": 0.09,
89+
}
90+
with self.assertLogs(level="INFO") as logs:
91+
self.writer.write_pointcloud(
92+
step=4,
93+
point_clouds={"pcd": point_clouds},
94+
point_colors={"pcd": point_colors},
95+
configs={"configs": config},
96+
)
97+
self.assertEqual(
98+
logs.output,
99+
[
100+
"INFO:absl:[4] Got point clouds: {'pcd': (1, 1024, 3)},"
101+
" point_colors: {'pcd': (1, 1024, 3)}, configs: {'configs':"
102+
" {'material': 'PointCloudMaterial', 'size': 0.09}}."
103+
],
104+
)
105+
83106
def test_write_hparams(self):
84107
with self.assertLogs(level="INFO") as logs:
85108
self.writer.write_hparams({"learning_rate": 0.1, "batch_size": 128})

clu/metric_writers/multi_writer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,19 @@ def write_histograms(self,
6464
for w in self._writers:
6565
w.write_histograms(step, arrays, num_buckets)
6666

67+
def write_pointcloud(
68+
self,
69+
step: int,
70+
point_clouds: Mapping[str, Array],
71+
*,
72+
point_colors: Mapping[str, Array] | None = None,
73+
configs: Mapping[str, str | float | bool | None] | None = None,
74+
):
75+
for w in self._writers:
76+
w.write_pointcloud(
77+
step, point_clouds, point_colors=point_colors, configs=configs
78+
)
79+
6780
def write_hparams(self, hparams: Mapping[str, Any]):
6881
for w in self._writers:
6982
w.write_hparams(hparams)

clu/metric_writers/multi_writer_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from clu.metric_writers import interface
2020
from clu.metric_writers import multi_writer
21+
import numpy as np
2122
import tensorflow as tf
2223

2324

@@ -48,6 +49,29 @@ def test_write_scalars(self):
4849
])
4950
w.flush.assert_called()
5051

52+
def test_write_pointcloud(self):
53+
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
54+
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
55+
config = {
56+
"material": "PointCloudMaterial",
57+
"size": 0.09,
58+
}
59+
self.writer.write_pointcloud(
60+
step=0,
61+
point_clouds={"pcd": point_clouds},
62+
point_colors={"pcd": point_colors},
63+
configs={"config": config},
64+
)
65+
self.writer.flush()
66+
for w in self.writers:
67+
w.write_pointcloud.assert_called_with(
68+
step=0,
69+
point_clouds={"pcd": point_clouds},
70+
point_colors={"pcd": point_colors},
71+
configs={"config": config},
72+
)
73+
w.flush.assert_called()
74+
5175

5276
if __name__ == "__main__":
5377
tf.test.main()

clu/metric_writers/tf/summary_writer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
with epy.lazy_imports():
3232
# pylint: disable=g-import-not-at-top
3333
from tensorboard.plugins.hparams import api as hparams_api
34+
from tensorboard.plugins.mesh import summary as mesh_summary # pylint: disable=line-too-long
3435
# pylint: enable=g-import-not-at-top
3536

3637

@@ -97,6 +98,26 @@ def write_histograms(
9798
buckets = None if num_buckets is None else num_buckets.get(key)
9899
tf.summary.histogram(key, value, step=step, buckets=buckets)
99100

101+
def write_pointcloud(
102+
self,
103+
step: int,
104+
point_clouds: Mapping[str, Array],
105+
*,
106+
point_colors: Mapping[str, Array] | None = None,
107+
configs: Mapping[str, str | float | bool | None] | None = None,
108+
):
109+
with self._summary_writer.as_default():
110+
for key, vertices in point_clouds.items():
111+
colors = None if point_colors is None else point_colors.get(key)
112+
config = None if configs is None else configs.get(key)
113+
mesh_summary.mesh(
114+
key,
115+
vertices=vertices,
116+
colors=colors,
117+
step=step,
118+
config_dict=config,
119+
)
120+
100121
def write_hparams(self, hparams: Mapping[str, Any]):
101122
with self._summary_writer.as_default():
102123
hparams_api.hparams(dict(utils.flatten_dict(hparams)))

clu/metric_writers/tf/summary_writer_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,21 @@ def _load_scalars_data(logdir: str):
6969
return data
7070

7171

72+
def _load_pointcloud_data(logdir: str):
73+
"""Loads pointcloud summaries from events in a logdir."""
74+
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
75+
data = collections.defaultdict(dict)
76+
for path in paths:
77+
for event in tf.compat.v1.train.summary_iterator(path):
78+
for value in event.summary.value:
79+
if value.metadata.plugin_data.plugin_name == "mesh":
80+
if "config" not in value.tag:
81+
data[event.step][value.tag] = tf.make_ndarray(value.tensor)
82+
else:
83+
data[event.step][value.tag] = value.metadata.plugin_data.content
84+
return data
85+
86+
7287
def _load_hparams(logdir: str):
7388
"""Loads hparams summaries from events in a logdir."""
7489
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
@@ -142,6 +157,24 @@ def test_write_histograms(self):
142157
]
143158
self.assertAllClose(data["b"], ([0, 2], expected_histograms_b))
144159

160+
def test_write_pointcloud(self):
161+
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
162+
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
163+
config = {
164+
"material": "PointCloudMaterial",
165+
"size": 0.09,
166+
}
167+
self.writer.write_pointcloud(
168+
step=0,
169+
point_clouds={"pcd": point_clouds},
170+
point_colors={"pcd": point_colors},
171+
configs={"config": config},
172+
)
173+
self.writer.flush()
174+
data = _load_pointcloud_data(self.logdir)
175+
self.assertAllClose(data[0]["pcd_VERTEX"], point_clouds)
176+
self.assertAllClose(data[0]["pcd_COLOR"], point_colors)
177+
145178
def test_hparams(self):
146179
self.writer.write_hparams(dict(batch_size=512, num_epochs=90))
147180
hparams = _load_hparams(self.logdir)

clu/metric_writers/torch_tensorboard_writer.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from typing import Any, Optional
2323
from absl import logging
2424

25-
2625
from clu.metric_writers import interface
2726
from torch.utils import tensorboard
2827

@@ -79,6 +78,20 @@ def write_histograms(self,
7978
self._writer.add_histogram(
8079
tag, values, global_step=step, bins="auto", max_bins=bins)
8180

81+
def write_pointcloud(
82+
self,
83+
step: int,
84+
point_clouds: Mapping[str, Array],
85+
*,
86+
point_colors: Mapping[str, Array] | None = None,
87+
configs: Mapping[str, str | float | bool | None] | None = None,
88+
):
89+
logging.log_first_n(
90+
logging.WARNING,
91+
"TorchTensorBoardWriter does not support writing point clouds.",
92+
1,
93+
)
94+
8295
def write_hparams(self, hparams: Mapping[str, Any]):
8396
self._writer.add_hparams(hparams, {})
8497

0 commit comments

Comments
 (0)