Skip to content

Commit 4c4e70f

Browse files
committed
Fix final(?) torch import
1 parent e318088 commit 4c4e70f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/spikeinterface/sortingcomponents/motion/decentralized.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,12 +320,14 @@ def compute_pairwise_displacement(
320320
# use torch if installed
321321
try:
322322
import torch
323-
import torch.nn.functional as F
324323

325324
conv_engine = "torch"
326325
except ImportError:
327326
conv_engine = "numpy"
328327

328+
if conv_engine == "torch":
329+
import torch
330+
329331
assert conv_engine in ("torch", "numpy"), f"'conv_engine' must be 'torch' or 'numpy'"
330332
size = motion_hist.shape[0]
331333
pairwise_displacement = np.zeros((size, size), dtype="float32")

0 commit comments

Comments
 (0)