Skip to content

Commit af49d6d

Browse files
committed
device tensor adaption
1 parent ad081bb commit af49d6d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

modAL/dropout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
317317
logits = classifier.estimator.infer(samples)
318318
prediction = logits_adaptor(logits, samples)
319319

320-
if probas is None: probas = torch.empty((number_of_samples, prediction.shape[-1]))
320+
if probas is None: probas = torch.empty((number_of_samples, prediction.shape[-1]), device=prediction.device)
321321
probas[range(sample_per_forward_pass*index, sample_per_forward_pass*(index+1)), :] = prediction
322322

323323

0 commit comments

Comments
 (0)