Skip to content

Commit 26585e1

Browse files
committed
Put score back into base class (Accounting for arbitrary deep learning models is not possible)
1 parent c516006 commit 26585e1

File tree

2 files changed

+13
-37
lines changed

2 files changed

+13
-37
lines changed

modAL/models/base.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,19 @@ def query(self, X_pool, *query_args, **query_kwargs) -> Union[Tuple, modALinput]
176176

177177
return query_result, retrieve_rows(X_pool, query_result), query_metrics
178178

179-
@abc.abstractmethod
180-
def score(self, *args, **kwargs) -> None:
181-
pass
179+
def score(self, X: modALinput, y: modALinput, **score_kwargs) -> Any:
180+
"""
181+
Interface for the score method of the predictor.
182+
183+
Args:
184+
X: The samples for which prediction accuracy is to be calculated.
185+
y: Ground truth labels for X.
186+
**score_kwargs: Keyword arguments to be passed to the .score() method of the predictor.
187+
188+
Returns:
189+
The score of the predictor.
190+
"""
191+
return self.estimator.score(X, y, **score_kwargs)
182192

183193
@abc.abstractmethod
184194
def teach(self, *args, **kwargs) -> None:

modAL/models/learners.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -160,20 +160,6 @@ def fit(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwarg
160160
self.X_training, self.y_training = X, y
161161
return self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
162162

163-
def score(self, X: modALinput, y: modALinput, **score_kwargs) -> Any:
164-
"""
165-
Interface for the score method of the predictor.
166-
167-
Args:
168-
X: The samples for which prediction accuracy is to be calculated.
169-
y: Ground truth labels for X.
170-
**score_kwargs: Keyword arguments to be passed to the .score() method of the predictor.
171-
172-
Returns:
173-
The score of the predictor.
174-
"""
175-
return self.estimator.score(X, y, **score_kwargs)
176-
177163
def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new: bool = False, **fit_kwargs) -> None:
178164
"""
179165
Adds X and y to the known training data and retrains the predictor with the augmented dataset.
@@ -245,26 +231,6 @@ def fit(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwarg
245231
"""
246232
return self._fit_on_new(X, y, bootstrap=bootstrap, **fit_kwargs)
247233

248-
def score(self, X: modALinput, y: modALinput) -> Any:
249-
"""
250-
Interface for the score method of the predictor.
251-
252-
Args:
253-
X: The samples for which prediction accuracy is to be calculated.
254-
y: Ground truth labels for X.
255-
256-
Returns:
257-
The score of the predictor.
258-
"""
259-
"""
260-
sklearn does only accept tensors of different dim for X and Y, if we use
261-
Multilabel classifiaction. Using tensors of different sizes for more complex models (e.g. Transformers)
262-
requires to bypass the sklearn checks by directly calling the NeuralNets infer() function.
263-
"""
264-
prediction = self.estimator.infer(X)
265-
criterion = self.estimator.criterion()
266-
return criterion(prediction, y).item()
267-
268234
def teach(self, X: modALinput, y: modALinput, warm_start: bool = True, bootstrap: bool = False, **fit_kwargs) -> None:
269235
"""
270236
Adds X and y to the known training data and retrains the predictor with the augmented dataset.

0 commit comments

Comments
 (0)