Skip to content

Commit a72386a

Browse files
committed
torch no grad in get_predictions
1 parent 4f42d20 commit a72386a

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

modAL/dropout.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,13 +314,14 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
314314
#call Skorch infer function to perform model forward pass
315315
#In comparison to: predict(), predict_proba() the infer()
316316
# does not change train/eval mode of other layers
317-
logits = classifier.estimator.infer(samples)
318-
prediction = logits_adaptor(logits, samples)
319-
320-
mask = ~prediction.isnan()
321-
prediction[mask] = prediction[mask].unsqueeze(0).softmax(1)
322-
if probas is None: probas = torch.empty((number_of_samples, prediction.shape[-1]), device='cpu')
323-
probas[range(sample_per_forward_pass*index, sample_per_forward_pass*(index+1)), :] = prediction.cpu()
317+
with torch.no_grad:
318+
logits = classifier.estimator.infer(samples)
319+
prediction = logits_adaptor(logits, samples)
320+
321+
mask = ~prediction.isnan()
322+
prediction[mask] = prediction[mask].unsqueeze(0).softmax(1)
323+
if probas is None: probas = torch.empty((number_of_samples, prediction.shape[-1]), device='cpu')
324+
probas[range(sample_per_forward_pass*index, sample_per_forward_pass*(index+1)), :] = prediction.cpu()
324325

325326
probas = to_numpy(probas)
326327
predictions.append(probas)

0 commit comments

Comments
 (0)