Skip to content

Commit 61442b6

Browse files
committed
get_predictions with torch cat
1 parent 53e5484 commit 61442b6

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

modAL/dropout.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
308308

309309
for i in range(num_predictions):
310310

311-
probas = None
311+
probas = []
312312

313313
for index, samples in enumerate(split_args):
314314
#call Skorch infer function to perform model forward pass
@@ -317,14 +317,12 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
317317
with torch.no_grad():
318318
logits = classifier.estimator.infer(samples)
319319
prediction = logits_adaptor(logits, samples)
320-
321320
mask = ~prediction.isnan()
322321
prediction[mask] = prediction[mask].unsqueeze(0).softmax(1)
323-
if probas is None: probas = torch.empty((number_of_samples, prediction.shape[-1]), device='cpu')
324-
probas[range(sample_per_forward_pass*index, sample_per_forward_pass*(index+1)), :] = prediction.cpu()
325-
326-
probas = to_numpy(probas)
327-
predictions.append(probas)
322+
probas.append(prediction)
323+
324+
probas = torch.cat(probas)
325+
predictions.append(to_numpy(probas))
328326

329327
# set dropout layers to eval
330328
set_dropout_mode(classifier.estimator.module_, dropout_layer_indexes, train_mode=False)

0 commit comments

Comments
 (0)