Skip to content

Commit 84f8b7e

Browse files
committed
Move tensor to device
1 parent 27f362e commit 84f8b7e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

chebai/models/electra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def __call__(self, target, input):
369369
negated_cone_axes, negated_cone_arcs = self.negate(cone_arcs, cone_axes)
370370

371371
predicted_vectors = input["predicted_vectors"]
372-
loss = torch.zeros((predicted_vectors.shape[0], cone_axes.shape[1]))
372+
loss = torch.zeros((predicted_vectors.shape[0], cone_axes.shape[1]), device=target.get_device())
373373
fltr = target.bool()
374374
loss[fltr] = 1 - self.cal_logit_cone(predicted_vectors, cone_axes, cone_arcs)[fltr]
375375
loss[~fltr] = 1 - self.cal_logit_cone(predicted_vectors, negated_cone_axes,

0 commit comments

Comments
 (0)