diff --git a/src/amltk/optimization/metric.py b/src/amltk/optimization/metric.py
index 9d5a462b..435dcdcd 100644
--- a/src/amltk/optimization/metric.py
+++ b/src/amltk/optimization/metric.py
@@ -28,8 +28,12 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, field
+from typing import TYPE_CHECKING
 from typing_extensions import Self, override
 
+if TYPE_CHECKING:
+    from sklearn.metrics._scorer import _BaseScorer
+
 
 @dataclass(frozen=True)
 class Metric:
@@ -65,6 +69,53 @@ def __post_init__(self) -> None:
             " Must be a valid Python identifier.",
         )
 
+    @classmethod
+    def from_sklearn(
+        cls,
+        scorer: str | _BaseScorer,
+        *,
+        bounds: tuple[float, float] | None = None,
+        name: str | None = None,
+    ) -> Metric:
+        """Create a metric from a sklearn metric.
+
+        The benefit of using this function is that it will also
+        set the bounds of the metric if possible, which can often
+        help Optimizers know how to normalize metrics and help
+        them with search.
+
+        ```python exec="true" source="material-block" result="python"
+        from amltk.optimization import Metric
+
+        metric_acc = Metric.from_sklearn("accuracy")
+        metric_neg_log_loss = Metric.from_sklearn("neg_log_loss")
+
+        from sklearn.metrics import get_scorer
+
+        scorer = get_scorer("roc_auc")
+        metric_roc_auc = Metric.from_sklearn(scorer)
+        ```
+
+        Args:
+            scorer: The name of the sklearn scorer.
+            bounds: The bounds of the metric, if any.
+                By default, we will do a lookup of known scorers to get
+                their bounds. If not specified and no bounds are found,
+                a warning will be raised.
+            name: The name to give the metric specifically. By default,
+                it will use the `scorer`. If `scorer` is a string, it will
+                use that string, otherwise it will use the name of the
+                scorer function, appending `neg_` to make
+                `neg_{scorer.func.__name__}`. This is to make it match
+                the sklearn `get_scorer()`.
+
+        Returns:
+            The metric.
+        """
+        from amltk.sklearn.metrics import as_metric
+
+        return as_metric(scorer=scorer, bounds=bounds, name=name)
+
     @override
     def __str__(self) -> str:
         parts = [self.name]
diff --git a/src/amltk/sklearn/metrics.py b/src/amltk/sklearn/metrics.py
new file mode 100644
index 00000000..e5199e92
--- /dev/null
+++ b/src/amltk/sklearn/metrics.py
@@ -0,0 +1,91 @@
+"""Utilities for sklearn metrics."""
+from __future__ import annotations
+
+import warnings
+
+import numpy as np
+from sklearn.metrics import get_scorer
+from sklearn.metrics._scorer import _BaseScorer
+
+from amltk.optimization.metric import Metric
+
+# All of these bounds are from the perspective of an sklearn **scorer**
+# where they are already negated
+# these metrics are taken from `get_scorer`
+_SCORER_BOUNDS = {
+    "explained_variance": (-np.inf, 1.0),  # Default metric is positive
+    "r2": (-np.inf, 1.0),  # Default metric is positive
+    "max_error": (-np.inf, 0),  # Default metric is negative
+    "matthews_corrcoef": (-1.0, 1.0),  # Default metric is positive
+    "neg_median_absolute_error": (-np.inf, 0),  # Default metric is negative
+    "neg_mean_absolute_error": (-np.inf, 0),  # Default metric is negative
+    "neg_mean_absolute_percentage_error": (-np.inf, 0),  # Default metric is negative
+    "neg_mean_squared_error": (-np.inf, 0),  # Default metric is negative
+    "neg_mean_squared_log_error": (-np.inf, 0),  # Default metric is negative
+    "neg_root_mean_squared_error": (-np.inf, 0),  # Default metric is negative
+    "neg_root_mean_squared_log_error": (-np.inf, 0),  # Default metric is negative
+    "neg_mean_poisson_deviance": (-np.inf, 0),  # Default metric is negative
+    "neg_mean_gamma_deviance": (-np.inf, 0),  # Default metric is negative
+    "accuracy": (0, 1.0),  # Default metric is positive
+    "top_k_accuracy": (0, 1.0),  # Default metric is positive
+    "roc_auc": (0, 1.0),  # Default metric is positive
+    "roc_auc_ovr": (0, 1.0),  # Default metric is positive
+    "roc_auc_ovo": (0, 1.0),  # Default metric is positive
+    "roc_auc_ovr_weighted": (0, 1.0),  # Default metric is positive
+    "roc_auc_ovo_weighted": (0, 1.0),  # Default metric is positive
+    "balanced_accuracy": (0, 1.0),  # Default metric is positive
+    "average_precision": (0, 1.0),  # Default metric is positive
+    "neg_log_loss": (-np.inf, 0),  # Default metric is negative
+    "neg_brier_score": (-np.inf, 0),  # Default metric is negative
+    "positive_likelihood_ratio": (0, np.inf),  # Default metric is positive
+    "neg_negative_likelihood_ratio": (-np.inf, 0),  # Default metric is negative
+    # Cluster metrics that use supervised evaluation
+    "adjusted_rand_score": (-0.5, 1.0),  # Default metric is positive
+    "rand_score": (0, 1.0),  # Default metric is positive
+    "homogeneity_score": (0, 1.0),  # Default metric is positive
+    "completeness_score": (0, 1.0),  # Default metric is positive
+    "v_measure_score": (0, 1.0),  # Default metric is positive
+    "mutual_info_score": (0, np.inf),  # Default metric is positive
+    # TODO: Not sure about the lower bound on this.
+    # Seems that 0 is pure randomness but theoretically it could be negative
+    "adjusted_mutual_info_score": (-1.0, 1.0),  # Default metric is positive
+    "normalized_mutual_info_score": (0.0, 1.0),  # Default metric is positive
+    "fowlkes_mallows_score": (0.0, 1.0),  # Default metric is positive
+}
+
+
+def as_metric(
+    scorer: str | _BaseScorer,
+    *,
+    bounds: tuple[float, float] | None = None,
+    name: str | None = None,
+) -> Metric:
+    """Convert a scorer to a metric."""
+    match scorer:
+        case str():
+            _scorer = get_scorer(scorer)
+            _name = scorer if name is None else name
+        case _BaseScorer():
+            _scorer = scorer
+            if name is not None:
+                _name = name
+            else:
+                _name = scorer._score_func.__name__
+                _name = f"neg_{_name}" if scorer._sign == -1 else _name
+        case _:
+            raise TypeError(f"Cannot convert {scorer!r} to a metric.")
+
+    # This is using what sklearn uses in its `__repr__` method
+    if bounds is None:
+        bounds = _SCORER_BOUNDS.get(_name, None)
+
+    if bounds is None:
+        warnings.warn(
+            f"Cannot infer bounds for scorer {_name}. Please explicitly provide "
+            "them with the `bounds` argument or set them to `(-np.inf, np.inf)`.",
+            UserWarning,
+            stacklevel=2,
+        )
+
+    # Sklearn scorers are always positive
+    return Metric(name=_name, bounds=bounds, minimize=False)