File tree Expand file tree Collapse file tree 1 file changed +7
-6
lines changed
torchprime/torch_xla_models Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change @@ -204,13 +204,14 @@ def _log_shapes(self, batch):
204
204
)
205
205
logger .info (f"[{ self .name } ] data shapes: { shapes } " )
206
206
207
- # Visualize one tensor.
208
- import click
209
- from torch_xla .distributed .spmd .debugging import visualize_tensor_sharding
210
-
207
+ # Visualize one example tensor.
211
208
t = next (iter (pytree .tree_iter (batch )))
212
- generated_table = visualize_tensor_sharding (t , use_color = False )
213
- click .echo (generated_table )
209
+ if t .device .type == "xla" :
210
+ import click
211
+ from torch_xla .distributed .spmd .debugging import visualize_tensor_sharding
212
+
213
+ generated_table = visualize_tensor_sharding (t , use_color = False )
214
+ click .echo (generated_table )
214
215
215
216
def __len__ (self ):
216
217
return len (self .dataloader )
You can’t perform that action at this time.
0 commit comments