Skip to content

Commit 21633ae

Browse files
committed
Only visualize XLA tensors
1 parent 6aed8a0 commit 21633ae

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

torchprime/torch_xla_models/train.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,14 @@ def _log_shapes(self, batch):
204204
)
205205
logger.info(f"[{self.name}] data shapes: {shapes}")
206206

207-
# Visualize one tensor.
208-
import click
209-
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
210-
207+
# Visualize one example tensor.
211208
t = next(iter(pytree.tree_iter(batch)))
212-
generated_table = visualize_tensor_sharding(t, use_color=False)
213-
click.echo(generated_table)
209+
if t.device.type == "xla":
210+
import click
211+
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
212+
213+
generated_table = visualize_tensor_sharding(t, use_color=False)
214+
click.echo(generated_table)
214215

215216
def __len__(self):
216217
return len(self.dataloader)

0 commit comments

Comments
 (0)