diff --git a/examples/finetune_classifier.py b/examples/finetune_classifier.py index 7c59b3bec..5b4a8e5fa 100644 --- a/examples/finetune_classifier.py +++ b/examples/finetune_classifier.py @@ -4,6 +4,7 @@ support for the Apple Silicon (MPS) backend is still under development. """ +import logging import warnings import numpy as np @@ -23,27 +24,23 @@ module=r"google\.api_core\._python_version_support", ) +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) + # ============================================================================= # Fine-tuning Configuration # For details and more options see FinetunedTabPFNClassifier +# +# These settings work well for the Higgs dataset. +# For other datasets, you may need to adjust these settings to get good results. # ============================================================================= # Training hyperparameters NUM_EPOCHS = 30 LEARNING_RATE = 2e-5 -# Data sampling configuration (dataset dependent) -# the ratio of the total dataset to be used for validation during training -VALIDATION_SPLIT_RATIO = 0.1 -# total context split into train/test -NUM_FINETUNE_CTX_PLUS_QUERY_SAMPLES = 10_000 -# the following means 0.2*10_000=2_000 test samples are used in training -FINETUNE_CTX_QUERY_SPLIT_RATIO = 0.2 -NUM_INFERENCE_SUBSAMPLE_SAMPLES = 50_000 -# to reduce memory usage during training we can use activation checkpointing, -# may not be necessary for small datasets -USE_ACTIVATION_CHECKPOINTING = True - # Ensemble configuration # number of estimators to use during finetuning NUM_ESTIMATORS_FINETUNE = 2 @@ -84,14 +81,11 @@ def main() -> None: ) # 2. Initial model evaluation on test set - inference_config = { - "SUBSAMPLE_SAMPLES": NUM_INFERENCE_SUBSAMPLE_SAMPLES, - } base_clf = TabPFNClassifier( device=[f"cuda:{i}" for i in range(torch.cuda.device_count())], n_estimators=NUM_ESTIMATORS_FINAL_INFERENCE, ignore_pretraining_limits=True, - inference_config=inference_config, + inference_config={"SUBSAMPLE_SAMPLES": 50_000}, ) base_clf.fit(X_train, y_train) @@ -110,15 +104,9 @@ def main() -> None: device="cuda" if torch.cuda.is_available() else "cpu", epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE, - validation_split_ratio=VALIDATION_SPLIT_RATIO, - n_finetune_ctx_plus_query_samples=NUM_FINETUNE_CTX_PLUS_QUERY_SAMPLES, - finetune_ctx_query_split_ratio=FINETUNE_CTX_QUERY_SPLIT_RATIO, - n_inference_subsample_samples=NUM_INFERENCE_SUBSAMPLE_SAMPLES, - random_state=RANDOM_STATE, n_estimators_finetune=NUM_ESTIMATORS_FINETUNE, n_estimators_validation=NUM_ESTIMATORS_VALIDATION, n_estimators_final_inference=NUM_ESTIMATORS_FINAL_INFERENCE, - use_activation_checkpointing=USE_ACTIVATION_CHECKPOINTING, ) # 4. Call .fit() to start the fine-tuning process on the training data diff --git a/examples/finetune_regressor.py b/examples/finetune_regressor.py new file mode 100644 index 000000000..b14476594 --- /dev/null +++ b/examples/finetune_regressor.py @@ -0,0 +1,118 @@ +"""Example of fine-tuning a TabPFN regressor using the FinetunedTabPFNRegressor wrapper. + +Note: We recommend running the fine-tuning scripts on a CUDA-enabled GPU, as full +support for the Apple Silicon (MPS) backend is still under development. 
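+
+The script fetches the California Housing dataset via scikit-learn and can be
+run directly, e.g.:
+
+    python examples/finetune_regressor.py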
+""" + +import logging +import warnings + +import sklearn.datasets +import torch +from sklearn.metrics import mean_squared_error, r2_score +from sklearn.model_selection import train_test_split + +from tabpfn import TabPFNRegressor +from tabpfn.finetuning.finetuned_regressor import FinetunedTabPFNRegressor + +warnings.filterwarnings( + "ignore", + category=FutureWarning, + module=r"google\.api_core\._python_version_support", +) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) + +# ============================================================================= +# Fine-tuning Configuration +# For details and more options see FinetunedTabPFNRegressor +# +# These settings work well for the California Housing dataset. +# For other datasets, you may need to adjust these settings to get good results. +# ============================================================================= + +# Training hyperparameters +NUM_EPOCHS = 30 +LEARNING_RATE = 1e-5 + +# We can fine-tune using almost the entire housing dataset +# in the context of the train batches. +N_FINETUNE_CTX_PLUS_QUERY_SAMPLES = 20_000 + +# Ensemble configuration +# number of estimators to use during finetuning +NUM_ESTIMATORS_FINETUNE = 8 +# number of estimators to use during train time validation +NUM_ESTIMATORS_VALIDATION = 8 +# number of estimators to use during final inference +NUM_ESTIMATORS_FINAL_INFERENCE = 8 + +# Reproducibility +RANDOM_STATE = 0 + + +def main() -> None: + data = sklearn.datasets.fetch_california_housing(as_frame=True) + X_all = data.data + y_all = data.target + + X_train, X_test, y_train, y_test = train_test_split( + X_all, y_all, test_size=0.1, random_state=RANDOM_STATE + ) + + print( + f"Loaded {len(X_train):,} samples for training and " + f"{len(X_test):,} samples for testing." + ) + + # 2. Initial model evaluation on test set + base_reg = TabPFNRegressor( + device=[f"cuda:{i}" for i in range(torch.cuda.device_count())], + n_estimators=NUM_ESTIMATORS_FINAL_INFERENCE, + ignore_pretraining_limits=True, + inference_config={"SUBSAMPLE_SAMPLES": 50_000}, + ) + base_reg.fit(X_train, y_train) + + base_pred = base_reg.predict(X_test) + mse = mean_squared_error(y_test, base_pred) + r2 = r2_score(y_test, base_pred) + + print(f"📊 Default TabPFN Test MSE: {mse:.4f}") + print(f"📊 Default TabPFN Test R²: {r2:.4f}\n") + + # 3. Initialize and run fine-tuning + print("--- 2. Initializing and Fitting Model ---\n") + + # Instantiate the wrapper with your desired hyperparameters + finetuned_reg = FinetunedTabPFNRegressor( + device="cuda" if torch.cuda.is_available() else "cpu", + epochs=NUM_EPOCHS, + learning_rate=LEARNING_RATE, + random_state=RANDOM_STATE, + n_finetune_ctx_plus_query_samples=N_FINETUNE_CTX_PLUS_QUERY_SAMPLES, + n_estimators_finetune=NUM_ESTIMATORS_FINETUNE, + n_estimators_validation=NUM_ESTIMATORS_VALIDATION, + n_estimators_final_inference=NUM_ESTIMATORS_FINAL_INFERENCE, + ) + + # 4. Call .fit() to start the fine-tuning process on the training data + finetuned_reg.fit(X_train.values, y_train.values) + print("\n") + + # 5. Evaluate the fine-tuned model + print("--- 3. 
Evaluating Model on Held-out Test Set ---\n") + y_pred = finetuned_reg.predict(X_test.values) + + mse = mean_squared_error(y_test, y_pred) + r2 = r2_score(y_test, y_pred) + + print(f"📊 Finetuned TabPFN Test MSE: {mse:.4f}") + print(f"📊 Finetuned TabPFN Test R²: {r2:.4f}") + + +if __name__ == "__main__": + main() diff --git a/src/tabpfn/finetuning/__init__.py b/src/tabpfn/finetuning/__init__.py index e69de29bb..4ebcbfed2 100644 --- a/src/tabpfn/finetuning/__init__.py +++ b/src/tabpfn/finetuning/__init__.py @@ -0,0 +1,15 @@ +"""Single-dataset fine-tuning wrappers for TabPFN models.""" + +from tabpfn.finetuning.data_util import ClassifierBatch, RegressorBatch +from tabpfn.finetuning.finetuned_base import EvalResult, FinetunedTabPFNBase +from tabpfn.finetuning.finetuned_classifier import FinetunedTabPFNClassifier +from tabpfn.finetuning.finetuned_regressor import FinetunedTabPFNRegressor + +__all__ = [ + "ClassifierBatch", + "EvalResult", + "FinetunedTabPFNBase", + "FinetunedTabPFNClassifier", + "FinetunedTabPFNRegressor", + "RegressorBatch", +] diff --git a/src/tabpfn/finetuning/data_util.py b/src/tabpfn/finetuning/data_util.py index adc689e6f..ac0ce59a3 100644 --- a/src/tabpfn/finetuning/data_util.py +++ b/src/tabpfn/finetuning/data_util.py @@ -2,7 +2,9 @@ from __future__ import annotations +import warnings from collections.abc import Sequence +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Literal from typing_extensions import override @@ -23,6 +25,62 @@ from tabpfn.constants import XType, YType +@dataclass +class ClassifierBatch: + """Batch data for classifier fine-tuning. + + Attributes: + X_context: Preprocessed training features (list per estimator). + X_query: Preprocessed test features (list per estimator). + y_context: Preprocessed training targets (list per estimator). + y_query: Raw test target tensor. + cat_indices: Categorical feature indices (list per estimator). + configs: Preprocessing configurations used for this batch. + """ + + X_context: list[torch.Tensor] + X_query: list[torch.Tensor] + y_context: list[torch.Tensor] + y_query: torch.Tensor + # In a single dataset sample, this is "per-estimator": + # list[list[int] | None] + # After collation (batch_size datasets), categorical indices must be batched: + # list[list[list[int] | None]] + # The batched structure is required by InferenceEngineBatchedNoPreprocessing. + cat_indices: list[list[int] | None] | list[list[list[int] | None]] + configs: list[Any] + + +@dataclass +class RegressorBatch: + """Batch data for regressor fine-tuning. + + Attributes: + X_context: Preprocessed training features (list per estimator). + X_query: Preprocessed test features (list per estimator). + y_context: Preprocessed standardized training targets (list per estimator). + y_query: Standardized test target tensor. + cat_indices: Categorical feature indices (list per estimator). + configs: Preprocessing configurations used for this batch. + raw_space_bardist: Bar distribution in raw (original) target space. + znorm_space_bardist: Bar distribution in z-normalized target space. + X_query_raw: Original unprocessed test features. + y_query_raw: Original unprocessed test targets. + """ + + X_context: list[torch.Tensor] + X_query: list[torch.Tensor] + y_context: list[torch.Tensor] + y_query: torch.Tensor + # See ClassifierBatch.cat_indices for the rationale of this union type. 
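+    # For example, a single sample with two estimators might carry
+    # [[0, 3], None]; after collation with batch_size=1 this becomes
+    # [[[0, 3], None]].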
+ cat_indices: list[list[int] | None] | list[list[list[int] | None]] + configs: list[Any] + raw_space_bardist: FullSupportBarDistribution + znorm_space_bardist: FullSupportBarDistribution + X_query_raw: torch.Tensor + y_query_raw: torch.Tensor + + def _take(obj: Any, idx: np.ndarray) -> Any: """Index obj by idx using .iloc when available (for pd.DataFrame), otherwise [].""" return obj.iloc[idx] if hasattr(obj, "iloc") else obj[idx] @@ -229,85 +287,57 @@ def __len__(self) -> int: return len(self.configs) @override - def __getitem__(self, index: int) -> tuple[Any, ...]: # noqa: C901, PLR0912 + def __getitem__(self, index: int) -> ClassifierBatch | RegressorBatch: # noqa: C901, PLR0912 """Retrieves, splits, and preprocesses the dataset config at the index. Performs train/test splitting and applies potentially multiple preprocessing pipelines defined in the dataset's configuration. Args: - index (int): The index of the dataset configuration in the + index: The index of the dataset configuration in the `dataset_config_collection` to process. Returns: - Tuple: A tuple containing the processed data and metadata. Each - element in the tuple is a list whose length equals the number - of estimators in the TabPFN ensemble. As such each element - in the list corresponds to the preprocessed data/configs for a - single ensemble member. - - The structure depends on the task type derived from the dataset - configuration object (`RegressorDatasetConfig` or - `ClassifierDatasetConfig`): - - For **Classification** tasks (`ClassifierDatasetConfig`): - * `X_trains_preprocessed` (List[torch.Tensor]): List of preprocessed - training feature tensors (one per preprocessing pipeline). - * `X_tests_preprocessed` (List[torch.Tensor]): List of preprocessed - test feature tensors (one per preprocessing pipeline). - * `y_trains_preprocessed` (List[torch.Tensor]): List of preprocessed - training target tensors (one per preprocessing pipeline). - * `y_test_raw` (torch.Tensor): Original, unprocessed test target - tensor. - * `cat_ixs` (List[Optional[List[int]]]): List of categorical feature - indices corresponding to each preprocessed X_train/X_test. - * `conf` (List): The list of preprocessing configurations used for - this dataset (usually reflects ensemble settings). - - For **Regression** tasks (`RegressorDatasetConfig`): - * `X_trains_preprocessed` (List[torch.Tensor]): List of preprocessed - training feature tensors. - * `X_tests_preprocessed` (List[torch.Tensor]): List of preprocessed - test feature tensors. - * `y_trains_preprocessed` (List[torch.Tensor]): List of preprocessed - *standardized* training target tensors. - * `y_test_standardized` (torch.Tensor): *Standardized* test target - tensor (derived from `y_full_standardised`). - * `cat_ixs` (List[Optional[List[int]]]): List of categorical feature - indices corresponding to each preprocessed X_train/X_test. - * `conf` (List): The list of preprocessing configurations used. - * `raw_space_bardist_` (FullSupportBarDistribution): Binning class - for target variable (specific to the regression config). The - calculations will be on raw data in raw space. - * `znorm_space_bardist_` (FullSupportBarDistribution): Binning class for - target variable (specific to the regression config). The calculations - will be on standardized data in znorm space. - * `x_test_raw` (torch.Tensor): Original, unprocessed test feature - tensor. - * `y_test_raw` (torch.Tensor): Original, unprocessed test target - tensor. 
+ A ClassifierBatch or RegressorBatch dataclass containing the + processed data and metadata. Each list field has length equal to + the number of estimators in the TabPFN ensemble. + + For **Classification** tasks: Returns a ClassifierBatch with: + - X_context: Preprocessed training features (per estimator) + - X_query: Preprocessed test features (per estimator) + - y_context: Preprocessed training targets (per estimator) + - y_query: Raw test target tensor + - cat_indices: Categorical feature indices (per estimator) + - configs: Preprocessing configurations used + + For **Regression** tasks: Returns a RegressorBatch with all + ClassifierBatch fields plus: + - raw_space_bardist: Bar distribution in raw target space + - znorm_space_bardist: Bar distribution in z-normalized space + - X_query_raw: Original unprocessed test features + - y_query_raw: Original unprocessed test targets Raises: - IndexError: If the index is out of the bounds of the dataset collection. + IndexError: If the index is out of the bounds of the dataset + collection. ValueError: If the dataset configuration type at the index is not - recognized (neither `RegressorDatasetConfig` nor - `ClassifierDatasetConfig`). - AssertionError: If sanity checks during processing fail (e.g., - standardized mean not close to zero in regression). + recognized. + AssertionError: If sanity checks during processing fail. """ if index < 0 or index >= len(self): raise IndexError("Index out of bounds.") config = self.configs[index] + is_regression_task = isinstance(config, RegressorDatasetConfig) + # Check type of Dataset Config - if isinstance(config, RegressorDatasetConfig): + if is_regression_task: conf = config.config x_full_raw = config.X_raw y_full_raw = config.y_raw cat_ix = config.cat_ix znorm_space_bardist_ = config.znorm_space_bardist_ - regression_task = True else: assert isinstance(config, ClassifierDatasetConfig), ( "Invalid dataset config type" @@ -316,9 +346,8 @@ def __getitem__(self, index: int) -> tuple[Any, ...]: # noqa: C901, PLR0912 x_full_raw = config.X_raw y_full_raw = config.y_raw cat_ix = config.cat_ix - regression_task = False - stratify_y = y_full_raw if not regression_task and self.stratify else None + stratify_y = y_full_raw if not is_regression_task and self.stratify else None x_train_raw, x_test_raw, y_train_raw, y_test_raw = self.split_fn( x_full_raw, y_full_raw, stratify=stratify_y ) @@ -329,9 +358,21 @@ def __getitem__(self, index: int) -> tuple[Any, ...]: # noqa: C901, PLR0912 # it is not set as an attribute of the Regressor class # This however makes also sense when considering that # this attribute changes on every dataset - if regression_task: + if is_regression_task: train_mean = np.mean(y_train_raw) train_std = np.std(y_train_raw) + + eps = 1e-8 + if train_std < eps: + warnings.warn( + f"Target variable has constant or near-constant values " + f"(std={train_std:.2e}). 
Adding epsilon={eps} to prevent " + f"division by zero in standardization.", + UserWarning, + stacklevel=2, + ) + train_std = eps + y_test_standardized = (y_test_raw - train_mean) / train_std y_train_standardized = (y_train_raw - train_mean) / train_std raw_space_bardist_ = FullSupportBarDistribution( @@ -381,7 +422,7 @@ def __getitem__(self, index: int) -> tuple[Any, ...]: # noqa: C901, PLR0912 y_trains_preprocessed[i], dtype=torch.float32 ) - if regression_task and not isinstance(y_test_standardized, torch.Tensor): + if is_regression_task and not isinstance(y_test_standardized, torch.Tensor): y_test_standardized = torch.from_numpy(y_test_standardized) if torch.is_floating_point(y_test_standardized): y_test_standardized = y_test_standardized.float() @@ -391,58 +432,138 @@ def __getitem__(self, index: int) -> tuple[Any, ...]: # noqa: C901, PLR0912 x_test_raw = torch.from_numpy(x_test_raw) y_test_raw = torch.from_numpy(y_test_raw) - # Also return raw_target variable because of flexiblity - # in optimisation space -> see examples/ - # Also return corresponding target variable binning - # classes raw_space_bardist_ and znorm_space_bardist_ - if regression_task: - return ( - X_trains_preprocessed, - X_tests_preprocessed, - y_trains_preprocessed, - y_test_standardized, - cat_ixs, - conf, - raw_space_bardist_, - znorm_space_bardist_, - x_test_raw, - y_test_raw, + # Return structured batch data using dataclasses for clarity + if is_regression_task: + return RegressorBatch( + X_context=X_trains_preprocessed, + X_query=X_tests_preprocessed, + y_context=y_trains_preprocessed, + y_query=y_test_standardized, + cat_indices=list(cat_ixs), + configs=list(conf), + raw_space_bardist=raw_space_bardist_, + znorm_space_bardist=znorm_space_bardist_, + X_query_raw=x_test_raw, + y_query_raw=y_test_raw, ) - return ( - X_trains_preprocessed, - X_tests_preprocessed, - y_trains_preprocessed, - y_test_raw, - cat_ixs, - conf, + return ClassifierBatch( + X_context=X_trains_preprocessed, + X_query=X_tests_preprocessed, + y_context=y_trains_preprocessed, + y_query=y_test_raw, + cat_indices=list(cat_ixs), + configs=list(conf), ) -def meta_dataset_collator(batch: list, padding_val: float = 0.0) -> tuple: +def _collate_list_field( + batch: list, + field_name: str, + num_estimators: int, + padding_val: float, +) -> list: + """Collate a list field (per-estimator data) from batch items.""" + batch_sz = len(batch) + field_values = [getattr(b, field_name) for b in batch] + estim_list = [] + for estim_no in range(num_estimators): + if isinstance(field_values[0][0], torch.Tensor): + labels = field_values[0][0].ndim == 1 + estim_list.append( + torch.stack( + pad_tensors( + [field_values[r][estim_no] for r in range(batch_sz)], + padding_val=padding_val, + labels=labels, + ) + ) + ) + else: + estim_list.append( + list(field_values[r][estim_no] for r in range(batch_sz)) # noqa: C400 + ) + return estim_list + + +def _collate_tensor_field( + batch: list, + field_name: str, + padding_val: float, +) -> torch.Tensor: + """Collate a tensor field from batch items.""" + batch_sz = len(batch) + field_values = [getattr(b, field_name) for b in batch] + labels = field_values[0].ndim == 1 + return torch.stack( + pad_tensors( + [field_values[r] for r in range(batch_sz)], + padding_val=padding_val, + labels=labels, + ) + ) + + +def _collate_cat_indices( + batch: Sequence[ClassifierBatch | RegressorBatch], +) -> list[list[list[int] | None]]: + """Collate cat indices into the batched shape expected by the batched executor. 
+ + In fine-tuning, the batched inference engine expects categorical indices as: + [dataset_batch][estimator][cat_index] + + Individual dataset samples carry categorical indices as: + [estimator][cat_index] (or None per estimator) + """ + batched_cat_indices: list[list[list[int] | None]] = [] + for item in batch: + cat_indices = item.cat_indices + + # Empty is unambiguous (no estimators / no categorical features). + if len(cat_indices) == 0: + batched_cat_indices.append([]) + continue + + # If the first element is a list of ints, it's the per-estimator form. + first = cat_indices[0] + if first is None: + batched_cat_indices.append(cat_indices) # type: ignore[arg-type] + continue + + if len(first) == 0 or isinstance(first[0], int): + batched_cat_indices.append(cat_indices) # type: ignore[arg-type] + continue + + # Otherwise it's already batched: [dataset_batch][estimator][...]. + # We only support batch_size=1 in this collator. + assert len(cat_indices) == 1 + batched_cat_indices.append(cat_indices[0]) # type: ignore[index] + + return batched_cat_indices + + +def meta_dataset_collator( + batch: list[ClassifierBatch | RegressorBatch], + padding_val: float = 0.0, +) -> ClassifierBatch | RegressorBatch: """Collate function for torch.utils.data.DataLoader. Designed for batches from DatasetCollectionWithPreprocessing. Takes a list of dataset samples (the batch) and structures them - into a single tuple suitable for model input, often for fine-tuning - using `fit_from_preprocessed`. + into a single batch dataclass suitable for model input, often for + fine-tuning using `fit_from_preprocessed`. Handles samples containing nested lists (e.g., for ensemble members) and tensors. Pads tensors to consistent shapes using `pad_tensors` before stacking. Non-tensor items are grouped into lists. Args: - batch (list): A list where each element is one sample from the - Dataset. Samples often contain multiple components like - features, labels, configs, etc., potentially nested in lists. - padding_val (float): Value used for padding tensors to allow - stacking across the batch dimension. + batch: A list of ClassifierBatch or RegressorBatch dataclass instances. + padding_val: Value used for padding tensors to allow stacking across + the batch dimension. Returns: - tuple: A tuple where each element is a collated component from the - input batch (e.g., stacked tensors, lists of configs). - The structure matches the input required by methods like - `fit_from_preprocessed`. + A collated ClassifierBatch or RegressorBatch with stacked/padded data. 
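+
+    Example:
+        A minimal sketch, assuming `datasets` comes from
+        `get_preprocessed_dataset_chunks`:
+
+            loader = DataLoader(
+                datasets, batch_size=1, collate_fn=meta_dataset_collator
+            )
+            batch = next(iter(loader))  # ClassifierBatch or RegressorBatch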
Note: Currently only implemented and tested for `batch_size = 1`, @@ -450,43 +571,39 @@ def meta_dataset_collator(batch: list, padding_val: float = 0.0) -> tuple: """ batch_sz = len(batch) assert batch_sz == 1, "Only Implemented and tested for batch size of 1" - num_estim = len(batch[0][0]) - items_list = [] - for item_idx in range(len(batch[0])): - if isinstance(batch[0][item_idx], list): - estim_list = [] - for estim_no in range(num_estim): - if isinstance(batch[0][item_idx][0], torch.Tensor): - labels = batch[0][item_idx][0].ndim == 1 - estim_list.append( - torch.stack( - pad_tensors( - [batch[r][item_idx][estim_no] for r in range(batch_sz)], - padding_val=padding_val, - labels=labels, - ) - ) - ) - else: - estim_list.append( - list(batch[r][item_idx][estim_no] for r in range(batch_sz)) # noqa: C400 - ) - items_list.append(estim_list) - elif isinstance(batch[0][item_idx], torch.Tensor): - labels = batch[0][item_idx].ndim == 1 - items_list.append( - torch.stack( - pad_tensors( - [batch[r][item_idx] for r in range(batch_sz)], - padding_val=padding_val, - labels=labels, - ) - ) - ) - else: - items_list.append([batch[r][item_idx] for r in range(batch_sz)]) - return tuple(items_list) + first_item = batch[0] + num_estimators = len(first_item.X_context) + + if isinstance(first_item, ClassifierBatch): + return ClassifierBatch( + X_context=_collate_list_field( + batch, "X_context", num_estimators, padding_val + ), + X_query=_collate_list_field(batch, "X_query", num_estimators, padding_val), + y_context=_collate_list_field( + batch, "y_context", num_estimators, padding_val + ), + y_query=_collate_tensor_field(batch, "y_query", padding_val), + cat_indices=_collate_cat_indices(batch), + configs=_collate_list_field(batch, "configs", num_estimators, padding_val), + ) + + # RegressorBatch - since batch_size=1, we extract the single item for bardist + # first_item is already narrowed to RegressorBatch by the isinstance check above + assert isinstance(first_item, RegressorBatch) + return RegressorBatch( + X_context=_collate_list_field(batch, "X_context", num_estimators, padding_val), + X_query=_collate_list_field(batch, "X_query", num_estimators, padding_val), + y_context=_collate_list_field(batch, "y_context", num_estimators, padding_val), + y_query=_collate_tensor_field(batch, "y_query", padding_val), + cat_indices=_collate_cat_indices(batch), + configs=_collate_list_field(batch, "configs", num_estimators, padding_val), + raw_space_bardist=first_item.raw_space_bardist, + znorm_space_bardist=first_item.znorm_space_bardist, + X_query_raw=_collate_tensor_field(batch, "X_query_raw", padding_val), + y_query_raw=_collate_tensor_field(batch, "y_query_raw", padding_val), + ) def shuffle_and_chunk_data( diff --git a/src/tabpfn/finetuning/finetuned_base.py b/src/tabpfn/finetuning/finetuned_base.py new file mode 100644 index 000000000..552535a92 --- /dev/null +++ b/src/tabpfn/finetuning/finetuned_base.py @@ -0,0 +1,675 @@ +"""Abstract base class for fine-tuning TabPFN models. + +This module provides the FinetunedTabPFNBase class, which contains shared +functionality for fine-tuning TabPFN on a specific dataset using the familiar +scikit-learn .fit() and .predict() API. 
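+
+Concrete subclasses such as FinetunedTabPFNClassifier and
+FinetunedTabPFNRegressor implement the task-specific hooks defined here.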
+""" + +from __future__ import annotations + +import copy +import logging +import warnings +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +import torch +from sklearn.base import BaseEstimator +from sklearn.model_selection import train_test_split +from torch.nn.utils import clip_grad_norm_ +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +from tabpfn.finetuning._torch_compat import GradScaler, autocast, sdpa_kernel_context +from tabpfn.finetuning.data_util import ( + ClassifierBatch, + RegressorBatch, + get_preprocessed_dataset_chunks, + meta_dataset_collator, +) +from tabpfn.finetuning.train_util import ( + get_and_init_optimizer, + get_checkpoint_path_and_epoch_from_output_dir, + get_cosine_schedule_with_warmup, + save_checkpoint, +) +from tabpfn.utils import infer_devices, validate_Xy_fit + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from tabpfn.constants import XType, YType + +# Currently, we only support a batch size of 1 for finetuning. +META_BATCH_SIZE = 1 + +# Hard limit on the number of samples to use for validation. +# This is used to avoid spending too much time on validation +# and prevent OOM issues for very large datasets. +MAX_VALIDATION_SAMPLES = 50_000 + + +@dataclass +class EvalResult: + """Container for evaluation results. + + Attributes: + primary: The primary metric used for early stopping decisions. + secondary: Additional metrics for logging purposes only. + """ + + primary: float + secondary: dict[str, float] = field(default_factory=dict) + + +class FinetunedTabPFNBase(BaseEstimator, ABC): + """Abstract base class for fine-tuning TabPFN models. + + This class encapsulates the shared fine-tuning logic, allowing you to + fine-tune TabPFN on a specific dataset using the familiar .fit() and + .predict() API. + + Args: + device: The device to run the model on. Defaults to "cuda". + epochs: The total number of passes through the fine-tuning data. + Defaults to 30. + learning_rate: The learning rate for the AdamW optimizer. A small value + is crucial for stable fine-tuning. Defaults to 1e-5. + weight_decay: The weight decay for the AdamW optimizer. Defaults to 0.01. + validation_split_ratio: Fraction of the original training data reserved + as a validation set for early stopping and monitoring. Defaults to 0.1. + n_finetune_ctx_plus_query_samples: The total number of samples per + meta-dataset during fine-tuning (context plus query) before applying + the `finetune_ctx_query_split_ratio`. Defaults to 10_000. + finetune_ctx_query_split_ratio: The proportion of each fine-tuning + meta-dataset to use as query samples for calculating the loss. The + remainder is used as context. Defaults to 0.2. + n_inference_subsample_samples: The total number of subsampled training + samples per estimator during validation and final inference. + Defaults to 50_000. + random_state: Seed for reproducibility of data splitting and model + initialization. Defaults to 0. + early_stopping: Whether to use early stopping based on validation + performance. Defaults to True. + early_stopping_patience: Number of epochs to wait for improvement before + early stopping. Defaults to 8. + min_delta: Minimum change in metric to be considered as an improvement. + Defaults to 1e-4. + grad_clip_value: Maximum norm for gradient clipping. 
If None, gradient + clipping is disabled. Gradient clipping helps stabilize training by + preventing exploding gradients. Defaults to 1.0. + use_lr_scheduler: Whether to use a learning rate scheduler (linear warmup + with optional cosine decay) during fine-tuning. Defaults to True. + lr_warmup_only: If True, only performs linear warmup to the base learning + rate and then keeps it constant. If False, applies cosine decay after + warmup. Defaults to False. + n_estimators_finetune: If set, overrides `n_estimators` of the underlying + estimator only during fine-tuning to control the number of + estimators (ensemble size) used in the training loop. If None, the + value from `kwargs` or the estimator default is used. + Defaults to 2. + n_estimators_validation: If set, overrides `n_estimators` only for + validation-time evaluation during fine-tuning (early-stopping / + monitoring). If None, the value from `kwargs` or the + estimator default is used. Defaults to 2. + n_estimators_final_inference: If set, overrides `n_estimators` only for + the final fitted inference model that is used after fine-tuning. If + None, the value from `kwargs` or the estimator default is used. + Defaults to 8. + use_activation_checkpointing: Whether to use activation checkpointing to + reduce memory usage. Defaults to True. + save_checkpoint_interval: Number of epochs between checkpoint saves. This + only has an effect if `output_dir` is provided during the `fit()` call. + If None, no intermediate checkpoints are saved. The best model checkpoint + is always saved regardless of this setting. Defaults to 10. + """ + + def __init__( # noqa: PLR0913 + self, + *, + device: str = "cuda", + epochs: int = 30, + learning_rate: float = 1e-5, + weight_decay: float = 0.01, + validation_split_ratio: float = 0.1, + n_finetune_ctx_plus_query_samples: int = 10_000, + finetune_ctx_query_split_ratio: float = 0.2, + n_inference_subsample_samples: int = 50_000, + random_state: int = 0, + early_stopping: bool = True, + early_stopping_patience: int = 8, + min_delta: float = 1e-4, + grad_clip_value: float | None = 1.0, + use_lr_scheduler: bool = True, + lr_warmup_only: bool = False, + n_estimators_finetune: int = 2, + n_estimators_validation: int = 2, + n_estimators_final_inference: int = 8, + use_activation_checkpointing: bool = True, + save_checkpoint_interval: int | None = 10, + ): + super().__init__() + self.device = device + self.epochs = epochs + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.validation_split_ratio = validation_split_ratio + self.n_finetune_ctx_plus_query_samples = n_finetune_ctx_plus_query_samples + self.finetune_ctx_query_split_ratio = finetune_ctx_query_split_ratio + self.n_inference_subsample_samples = n_inference_subsample_samples + self.random_state = random_state + self.early_stopping = early_stopping + self.early_stopping_patience = early_stopping_patience + self.min_delta = min_delta + self.grad_clip_value = grad_clip_value + self.use_lr_scheduler = use_lr_scheduler + self.lr_warmup_only = lr_warmup_only + self.n_estimators_finetune = n_estimators_finetune + self.n_estimators_validation = n_estimators_validation + self.n_estimators_final_inference = n_estimators_final_inference + self.use_activation_checkpointing = use_activation_checkpointing + self.save_checkpoint_interval = save_checkpoint_interval + self.meta_batch_size = META_BATCH_SIZE + + def _build_estimator_config( + self, + base_config: dict[str, Any], + n_estimators_override: int | None, + ) -> dict[str, Any]: + 
"""Return a deep-copy of base_config with an optional n_estimators override.""" + config = copy.deepcopy(base_config) + if n_estimators_override is not None: + config["n_estimators"] = n_estimators_override + return config + + def _build_eval_config( + self, + base_config: dict[str, Any], + n_estimators_override: int | None, + ) -> dict[str, Any]: + """Return eval config with n_estimators override and subsample setting.""" + config = self._build_estimator_config(base_config, n_estimators_override) + existing = dict(config.get("inference_config", {}) or {}) + existing["SUBSAMPLE_SAMPLES"] = self.n_inference_subsample_samples + config["inference_config"] = existing + return config + + @property + @abstractmethod + def _estimator_kwargs(self) -> dict[str, Any]: + """Return the task-specific estimator kwargs.""" + ... + + @property + @abstractmethod + def _model_type(self) -> Literal["classifier", "regressor"]: + """Return the model type string ('classifier' or 'regressor').""" + ... + + @property + @abstractmethod + def _metric_name(self) -> str: + """Return the name of the primary metric for logging.""" + ... + + @abstractmethod + def _create_estimator(self, config: dict[str, Any]) -> Any: + """Create and return the underlying TabPFN estimator with the given config.""" + ... + + @abstractmethod + def _setup_estimator(self) -> None: + """Perform any task-specific setup after estimator creation.""" + ... + + @abstractmethod + def _setup_batch(self, batch: ClassifierBatch | RegressorBatch) -> None: + """Perform any batch-specific setup before the forward pass.""" + ... + + @abstractmethod + def _should_skip_batch(self, batch: ClassifierBatch | RegressorBatch) -> bool: + """Check if the batch should be skipped.""" + ... + + @abstractmethod + def _forward_with_loss( + self, + batch: ClassifierBatch | RegressorBatch, + ) -> torch.Tensor: + """Perform forward pass and compute loss for the given batch. + + Args: + batch: The batch tuple from the dataloader. + + Returns: + The computed loss tensor. + """ + ... + + @abstractmethod + def _evaluate_model( + self, + eval_config: dict[str, Any], + X_train: np.ndarray, + y_train: np.ndarray, + X_val: np.ndarray, + y_val: np.ndarray, + ) -> EvalResult: + """Evaluate the model on validation data and return metrics. + + Args: + eval_config: Configuration dictionary for the evaluation estimator. + X_train: Training input samples. + y_train: Training target values. + X_val: Validation input samples. + y_val: Validation target values. + + Returns: + EvalResult with primary metric for early stopping and secondary + metrics for logging. + """ + ... + + @abstractmethod + def _is_improvement(self, current: float, best: float) -> bool: + """Return True if current metric is an improvement over best. + + Args: + current: The current metric value. + best: The best metric value seen so far. + + Returns: + True if current is better than best (accounting for min_delta). + """ + ... + + @abstractmethod + def _get_initial_best_metric(self) -> float: + """Return initial 'best' metric (inf for min, -inf for max).""" + ... + + @abstractmethod + def _get_checkpoint_metrics(self, eval_result: EvalResult) -> dict[str, float]: + """Return the metrics dict to save in checkpoints.""" + ... + + @abstractmethod + def _log_epoch_evaluation( + self, epoch: int, eval_result: EvalResult, mean_train_loss: float | None + ) -> None: + """Log the evaluation results for the current epoch.""" + ... 
+ + @abstractmethod + def _setup_inference_model( + self, final_inference_eval_config: dict[str, Any] + ) -> None: + """Set up the final inference model after fine-tuning completes.""" + ... + + @abstractmethod + def predict(self, X: np.ndarray) -> np.ndarray: + """Predict target values for X.""" + ... + + def _get_train_val_split( + self, X: np.ndarray, y: np.ndarray + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Split data into train/validation sets with task-specific options.""" + n_samples = len(y) + desired_val_size = int(n_samples * self.validation_split_ratio) + + if desired_val_size > MAX_VALIDATION_SAMPLES: + test_size = MAX_VALIDATION_SAMPLES + warnings.warn( + f"Validation set size would be {desired_val_size:,} samples " + f"based on validation_split_ratio=" + f"{self.validation_split_ratio:.2f}, but limiting to " + f"{MAX_VALIDATION_SAMPLES:,} samples to avoid excessive " + f"validation time and memory usage.", + UserWarning, + stacklevel=3, + ) + else: + test_size = self.validation_split_ratio + + return train_test_split( # type: ignore[return-value] + X, + y, + test_size=test_size, + random_state=self.random_state, + stratify=y if self._model_type == "classifier" else None, + ) + + def fit( + self, X: XType, y: YType, output_dir: Path | None = None + ) -> FinetunedTabPFNBase: + """Fine-tune the TabPFN model on the provided training data. + + Args: + X: The training input samples of shape (n_samples, n_features). + y: The target values of shape (n_samples,). + output_dir: Directory path for saving checkpoints. If None, no + checkpointing is performed and progress will be lost if + training is interrupted. + + Returns: + The fitted instance itself. + """ + if output_dir is None: + warnings.warn( + "`output_dir` is not set. This means no checkpointing will be done and " + "all progress will be lost if the training is interrupted.", + UserWarning, + stacklevel=2, + ) + else: + output_dir.mkdir(parents=True, exist_ok=True) + + return self._fit(X=X, y=y, output_dir=output_dir) + + def _fit( # noqa: C901,PLR0912 + self, + X: XType, + y: YType, + output_dir: Path | None = None, + ) -> FinetunedTabPFNBase: + """Internal implementation of fit that runs the finetuning loop.""" + # Store the original training size for checkpoint naming + train_size = X.shape[0] + + inference_config = copy.deepcopy( + self._estimator_kwargs.get("inference_config", {}) + ) + base_estimator_config: dict[str, Any] = { + **self._estimator_kwargs, + "ignore_pretraining_limits": True, + "device": self.device, + "random_state": self.random_state, + "inference_config": inference_config, + } + + # Config used for the finetuning loop. + finetuning_estimator_config = self._build_estimator_config( + base_estimator_config, + self.n_estimators_finetune, + ) + + # Configs used for validation-time evaluation and final inference. 
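+        # They share all settings except for a potential `n_estimators`
+        # override, and both use the same `SUBSAMPLE_SAMPLES` setting.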
+ validation_eval_config = self._build_eval_config( + base_estimator_config, + self.n_estimators_validation, + ) + final_inference_eval_config = self._build_eval_config( + base_estimator_config, + self.n_estimators_final_inference, + ) + + eval_devices = infer_devices(self.device) + validation_eval_config["device"] = eval_devices + final_inference_eval_config["device"] = eval_devices + + epoch_to_start_from = 0 + checkpoint_path = None + if output_dir is not None: + checkpoint_path, epoch_to_start_from = ( + get_checkpoint_path_and_epoch_from_output_dir( + output_dir=output_dir, + train_size=train_size, + get_best=False, + ) + ) + if checkpoint_path is not None: + logger.info( + f"Restarting training from checkpoint {checkpoint_path} at epoch " + f"{epoch_to_start_from}", + ) + finetuning_estimator_config["model_path"] = checkpoint_path + + self.finetuned_estimator_ = self._create_estimator(finetuning_estimator_config) + self._setup_estimator() + + X, y, _, _ = validate_Xy_fit( + X, + y, + estimator=self.finetuned_estimator_, + ensure_y_numeric=self._model_type == "regressor", + max_num_samples=-1, # ignored if ignore_pretraining_limits is True + max_num_features=-1, # ignored if ignore_pretraining_limits is True + ignore_pretraining_limits=True, + ) + + self.X_ = X + self.y_ = y + + X_train, X_val, y_train, y_val = self._get_train_val_split(X, y) + + # Calculate the context size used during finetuning. + n_finetune_ctx_plus_query_samples = min( + self.n_finetune_ctx_plus_query_samples, + len(y_train), + ) + + self.finetuned_estimator_._initialize_model_variables() + self.finetuned_estimator_.model_.to(self.device) + + if self.use_activation_checkpointing: + self.finetuned_estimator_.model_.recompute_layer = True # type: ignore + + optimizer = get_and_init_optimizer( + model_parameters=self.finetuned_estimator_.model_.parameters(), # type: ignore + learning_rate=self.learning_rate, + weight_decay=self.weight_decay, + checkpoint_path=checkpoint_path, + device=self.device, + ) + + use_amp = self.device.startswith("cuda") and torch.cuda.is_available() + scaler = GradScaler() if use_amp else None # type: ignore + + logger.info("--- 🚀 Starting Fine-tuning ---") + + best_metric: float = self._get_initial_best_metric() + patience_counter = 0 + best_model = None + + scheduler: LambdaLR | None = None + + for epoch in range(epoch_to_start_from, self.epochs): + # Per-epoch aggregates for cleaner learning curves. 
+ epoch_loss_sum = 0.0 + epoch_batches = 0 + + # Regenerate datasets each epoch with a different random_state + training_splitter = partial( + train_test_split, + test_size=self.finetune_ctx_query_split_ratio, + random_state=self.random_state + epoch, + ) + + training_datasets = get_preprocessed_dataset_chunks( + calling_instance=self.finetuned_estimator_, + X_raw=X_train, + y_raw=y_train, + split_fn=training_splitter, + max_data_size=n_finetune_ctx_plus_query_samples, + model_type=self._model_type, + equal_split_size=False, + seed=self.random_state + epoch, + ) + + dataloader_generator = torch.Generator().manual_seed( + self.random_state + epoch + ) + finetuning_dataloader = DataLoader( + training_datasets, + batch_size=self.meta_batch_size, + collate_fn=meta_dataset_collator, + shuffle=True, + generator=dataloader_generator, + ) + + # Instantiate the LR scheduler only once + if self.use_lr_scheduler and scheduler is None: + steps_per_epoch = len(finetuning_dataloader) + if steps_per_epoch == 0: + logger.warning( + "No training batches available; ending training early.", + ) + break + + total_steps = steps_per_epoch * self.epochs + warmup_steps = int(total_steps * 0.1) + + lrate_schedule_fn = get_cosine_schedule_with_warmup( + total_steps=total_steps, + warmup_steps=warmup_steps, + warmup_only=self.lr_warmup_only, + ) + scheduler = LambdaLR(optimizer, lr_lambda=lrate_schedule_fn) + + logger.info( + "Using LambdaLR %s schedule: total_steps=%d, warmup_steps=%d", + "warmup-only (constant LR after warmup)" + if self.lr_warmup_only + else "warmup+cosine", + total_steps, + warmup_steps, + ) + + progress_bar = tqdm( + finetuning_dataloader, + desc=f"Finetuning Epoch {epoch + 1}/{self.epochs}", + ) + for batch in progress_bar: + optimizer.zero_grad() + + if self._should_skip_batch(batch): + continue + + self._setup_batch(batch) + + self.finetuned_estimator_.fit_from_preprocessed( + batch.X_context, + batch.y_context, + batch.cat_indices, + batch.configs, + ) + + use_scaler = use_amp and scaler is not None + + with autocast(enabled=use_scaler), sdpa_kernel_context(): # type: ignore + loss = self._forward_with_loss(batch) + + if use_scaler: + with sdpa_kernel_context(): + scaler.scale(loss).backward() # type: ignore + scaler.unscale_(optimizer) # type: ignore + + if self.grad_clip_value is not None: + clip_grad_norm_( + self.finetuned_estimator_.model_.parameters(), # type: ignore + self.grad_clip_value, + ) + + scaler.step(optimizer) # type: ignore + scaler.update() # type: ignore + else: + with sdpa_kernel_context(): + loss.backward() + + if self.grad_clip_value is not None: + clip_grad_norm_( + self.finetuned_estimator_.model_.parameters(), # type: ignore + self.grad_clip_value, + ) + + optimizer.step() + + if scheduler is not None: + scheduler.step() + + loss_scalar = float(loss.detach().item()) + + epoch_loss_sum += loss_scalar + epoch_batches += 1 + + progress_bar.set_postfix( + loss=f"{loss_scalar:.4f}", + ) + + mean_train_loss = ( + epoch_loss_sum / epoch_batches if epoch_batches > 0 else None + ) + + eval_result = self._evaluate_model( + validation_eval_config, + X_train, # pyright: ignore[reportArgumentType] + y_train, # pyright: ignore[reportArgumentType] + X_val, # pyright: ignore[reportArgumentType] + y_val, # pyright: ignore[reportArgumentType] + ) + + self._log_epoch_evaluation(epoch, eval_result, mean_train_loss) + + primary_metric = eval_result.primary + + if output_dir is not None and not np.isnan(primary_metric): + save_interval_checkpoint = ( + self.save_checkpoint_interval is 
not None + and (epoch + 1) % self.save_checkpoint_interval == 0 + ) + + is_best = self._is_improvement(primary_metric, best_metric) + + if save_interval_checkpoint or is_best: + save_checkpoint( + estimator=self.finetuned_estimator_, + output_dir=output_dir, + epoch=epoch + 1, + optimizer=optimizer, + metrics=self._get_checkpoint_metrics(eval_result), + train_size=train_size, + is_best=is_best, + save_interval_checkpoint=save_interval_checkpoint, + ) + + if self.early_stopping and not np.isnan(primary_metric): + if self._is_improvement(primary_metric, best_metric): + best_metric = primary_metric + patience_counter = 0 + best_model = copy.deepcopy(self.finetuned_estimator_) + else: + patience_counter += 1 + logger.info( + "⚠️ No improvement for %s epochs. Best %s: %.4f", + patience_counter, + self._metric_name, + best_metric, + ) + + if patience_counter >= self.early_stopping_patience: + logger.info( + "🛑 Early stopping triggered. Best %s: %.4f", + self._metric_name, + best_metric, + ) + if best_model is not None: + self.finetuned_estimator_ = best_model + break + + if self.early_stopping and best_model is not None: + self.finetuned_estimator_ = best_model + + logger.info("--- ✅ Fine-tuning Finished ---") + + self._setup_inference_model(final_inference_eval_config) + + self.is_fitted_ = True + return self diff --git a/src/tabpfn/finetuning/finetuned_classifier.py b/src/tabpfn/finetuning/finetuned_classifier.py index 813e99e9f..55ea9fd6b 100644 --- a/src/tabpfn/finetuning/finetuned_classifier.py +++ b/src/tabpfn/finetuning/finetuned_classifier.py @@ -7,98 +7,49 @@ from __future__ import annotations -import copy import logging -import warnings -from functools import partial from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any, Literal +from typing_extensions import override import numpy as np import torch -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import ClassifierMixin from sklearn.metrics import log_loss, roc_auc_score -from sklearn.model_selection import train_test_split from sklearn.utils.validation import check_is_fitted -from torch.nn.utils import clip_grad_norm_ -from torch.optim.lr_scheduler import LambdaLR -from torch.utils.data import DataLoader -from tqdm.auto import tqdm from tabpfn import TabPFNClassifier -from tabpfn.finetuning._torch_compat import GradScaler, autocast, sdpa_kernel_context -from tabpfn.finetuning.data_util import ( - get_preprocessed_dataset_chunks, - meta_dataset_collator, -) -from tabpfn.finetuning.train_util import ( - clone_model_for_evaluation, - get_and_init_optimizer, - get_checkpoint_path_and_epoch_from_output_dir, - get_cosine_schedule_with_warmup, - save_checkpoint, -) - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", -) +from tabpfn.finetuning.finetuned_base import EvalResult, FinetunedTabPFNBase +from tabpfn.finetuning.train_util import clone_model_for_evaluation logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from tabpfn.constants import XType, YType + from tabpfn.finetuning.data_util import ClassifierBatch -def evaluate_model( - classifier: TabPFNClassifier, - eval_config: dict, - X_train: np.ndarray, - y_train: np.ndarray, - X_val: np.ndarray, - y_val: np.ndarray, -) -> tuple[float, float]: - """Evaluate the model's performance on the validation set. + +def compute_classification_loss( + *, + predictions_BLQ: torch.Tensor, + targets_BQ: torch.Tensor, +) -> torch.Tensor: + """Compute the cross-entropy training loss. 
+ + Shapes suffixes: + B=batch * estimators, L=logits, Q=n_queries. Args: - classifier: The TabPFNClassifier instance to evaluate. - eval_config: Configuration dictionary for the evaluation classifier. - X_train: Training input samples of shape (n_samples, n_features). - y_train: Training target values of shape (n_samples,). - X_val: Validation input samples of shape (n_samples, n_features). - y_val: Validation target values of shape (n_samples,). + predictions_BLQ: Raw logits of shape (B*E, L, Q). + targets_BQ: Integer class targets of shape (B*E, Q). Returns: - A tuple containing (roc_auc, log_loss_score). Returns (nan, nan) if - evaluation fails due to an error. + A scalar loss tensor. """ - eval_classifier = clone_model_for_evaluation( - classifier, - eval_config, - TabPFNClassifier, - ) - eval_classifier.fit(X_train, y_train) - - try: - probabilities = eval_classifier.predict_proba(X_val) # type: ignore - roc_auc = ( - roc_auc_score( - y_val, - probabilities, - multi_class="ovr", - ) - if len(np.unique(y_val)) > 2 - else roc_auc_score( - y_val, - probabilities[:, 1], - ) - ) - log_loss_score = log_loss(y_val, probabilities) - except (ValueError, RuntimeError, AttributeError) as e: - logger.warning(f"An error occurred during evaluation: {e}") - roc_auc, log_loss_score = np.nan, np.nan - - return roc_auc, log_loss_score # pyright: ignore[reportReturnType] + return torch.nn.functional.cross_entropy(predictions_BLQ, targets_BQ) -class FinetunedTabPFNClassifier(BaseEstimator, ClassifierMixin): +class FinetunedTabPFNClassifier(FinetunedTabPFNBase, ClassifierMixin): """A scikit-learn compatible wrapper for fine-tuning the TabPFNClassifier. This class encapsulates the fine-tuning loop, allowing you to fine-tune @@ -110,27 +61,25 @@ class FinetunedTabPFNClassifier(BaseEstimator, ClassifierMixin): Defaults to 30. learning_rate: The learning rate for the AdamW optimizer. A small value is crucial for stable fine-tuning. Defaults to 1e-5. - weight_decay: The weight decay for the AdamW optimizer. Defaults to 1e-3. + weight_decay: The weight decay for the AdamW optimizer. Defaults to 0.01. validation_split_ratio: Fraction of the original training data reserved as a validation set for early stopping and monitoring. Defaults to 0.1. n_finetune_ctx_plus_query_samples: The total number of samples per meta-dataset during fine-tuning (context plus query) before applying - the `finetune_ctx_query_split_ratio`. Defaults to 20_000. + the `finetune_ctx_query_split_ratio`. Defaults to 10_000. finetune_ctx_query_split_ratio: The proportion of each fine-tuning meta-dataset to use as query samples for calculating the loss. The - remainder is used as context. Defaults to 0.02. + remainder is used as context. Defaults to 0.2. n_inference_subsample_samples: The total number of subsampled training samples per estimator during validation and final inference. Defaults to 50_000. - meta_batch_size: The number of meta-datasets to process in a single - optimization step. Currently, this should be kept at 1. Defaults to 1. random_state: Seed for reproducibility of data splitting and model initialization. Defaults to 0. - early_stopping: Whether to use early stopping based on validation ROC AUC + early_stopping: Whether to use early stopping based on validation performance. Defaults to True. early_stopping_patience: Number of epochs to wait for improvement before - early stopping. Defaults to 6. - min_delta: Minimum change in ROC AUC to be considered as an improvement. + early stopping. Defaults to 8. 
+ min_delta: Minimum change in metric to be considered as an improvement. Defaults to 1e-4. grad_clip_value: Maximum norm for gradient clipping. If None, gradient clipping is disabled. Gradient clipping helps stabilize training by @@ -141,25 +90,29 @@ class FinetunedTabPFNClassifier(BaseEstimator, ClassifierMixin): rate and then keeps it constant. If False, applies cosine decay after warmup. Defaults to False. n_estimators_finetune: If set, overrides `n_estimators` of the underlying - `TabPFNClassifier` only during fine-tuning to control the number of + estimator only during fine-tuning to control the number of estimators (ensemble size) used in the training loop. If None, the - value from `kwargs` or the `TabPFNClassifier` default is used. + value from `kwargs` or the estimator default is used. Defaults to 2. n_estimators_validation: If set, overrides `n_estimators` only for validation-time evaluation during fine-tuning (early-stopping / monitoring). If None, the value from `kwargs` or the - `TabPFNClassifier` default is used. Defaults to 4. + estimator default is used. Defaults to 2. n_estimators_final_inference: If set, overrides `n_estimators` only for the final fitted inference model that is used after fine-tuning. If - None, the value from `kwargs` or the `TabPFNClassifier` default is - used. Defaults to 8. + None, the value from `kwargs` or the estimator default is used. + Defaults to 8. use_activation_checkpointing: Whether to use activation checkpointing to - reduce memory usage. Defaults to False. - save_checkpoint_interval: Number of epochs between checkpoint saves. If - None, no intermediate checkpoints are saved. The best model checkpoint + reduce memory usage. Defaults to True. + save_checkpoint_interval: Number of epochs between checkpoint saves. This + only has an effect if `output_dir` is provided during the `fit()` call. + If None, no intermediate checkpoints are saved. The best model checkpoint is always saved regardless of this setting. Defaults to 10. - **kwargs: Additional keyword arguments to pass to the underlying - `TabPFNClassifier`, such as `n_estimators`. + + FinetunedTabPFNClassifier specific arguments: + + extra_classifier_kwargs: Additional keyword arguments to pass to the + underlying `TabPFNClassifier`, such as `n_estimators`. 
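+
+    Example:
+        A minimal sketch, assuming `X_train`, `y_train`, and `X_test` are
+        in-memory arrays for a classification task:
+
+            clf = FinetunedTabPFNClassifier(device="cuda", epochs=10)
+            clf.fit(X_train, y_train)
+            y_pred = clf.predict(X_test)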
""" def __init__( # noqa: PLR0913 @@ -168,12 +121,11 @@ def __init__( # noqa: PLR0913 device: str = "cuda", epochs: int = 30, learning_rate: float = 1e-5, - weight_decay: float = 0.001, + weight_decay: float = 0.01, validation_split_ratio: float = 0.1, - n_finetune_ctx_plus_query_samples: int = 20_000, - finetune_ctx_query_split_ratio: float = 0.02, + n_finetune_ctx_plus_query_samples: int = 10_000, + finetune_ctx_query_split_ratio: float = 0.2, n_inference_subsample_samples: int = 50_000, - meta_batch_size: int = 1, random_state: int = 0, early_stopping: bool = True, early_stopping_patience: int = 8, @@ -181,448 +133,204 @@ def __init__( # noqa: PLR0913 grad_clip_value: float | None = 1.0, use_lr_scheduler: bool = True, lr_warmup_only: bool = False, - n_estimators_finetune: int | None = 2, - n_estimators_validation: int | None = 2, - n_estimators_final_inference: int | None = 8, - use_activation_checkpointing: bool = False, + n_estimators_finetune: int = 2, + n_estimators_validation: int = 2, + n_estimators_final_inference: int = 8, + use_activation_checkpointing: bool = True, save_checkpoint_interval: int | None = 10, extra_classifier_kwargs: dict[str, Any] | None = None, ): - super().__init__() - self.device = device - self.epochs = epochs - self.learning_rate = learning_rate - self.weight_decay = weight_decay - self.validation_split_ratio = validation_split_ratio - self.n_finetune_ctx_plus_query_samples = n_finetune_ctx_plus_query_samples - self.finetune_ctx_query_split_ratio = finetune_ctx_query_split_ratio - self.n_inference_subsample_samples = n_inference_subsample_samples - self.meta_batch_size = meta_batch_size - self.random_state = random_state - self.early_stopping = early_stopping - self.early_stopping_patience = early_stopping_patience - self.min_delta = min_delta - self.grad_clip_value = grad_clip_value - self.use_lr_scheduler = use_lr_scheduler - self.lr_warmup_only = lr_warmup_only - self.n_estimators_finetune = n_estimators_finetune - self.n_estimators_validation = n_estimators_validation - self.n_estimators_final_inference = n_estimators_final_inference - self.use_activation_checkpointing = use_activation_checkpointing - self.classifier_kwargs = extra_classifier_kwargs or {} - - self.save_checkpoint_interval = save_checkpoint_interval - - assert self.meta_batch_size == 1, "meta_batch_size must be 1 for finetuning" - - def _build_classifier_config( - self, - base_config: dict[str, Any], - n_estimators_override: int | None, - ) -> dict[str, Any]: - """Return a deep-copy of base_config with an optional n_estimators override.""" - config = copy.deepcopy(base_config) - if n_estimators_override is not None: - config["n_estimators"] = n_estimators_override - return config - - def _build_eval_config( - self, - base_config: dict[str, Any], - n_estimators_override: int | None, - ) -> dict[str, Any]: - """Return eval config sharing settings except for an optional n_estimators override.""" # noqa: E501 - config = self._build_classifier_config(base_config, n_estimators_override) - existing = dict(config.get("inference_config", {}) or {}) - existing["SUBSAMPLE_SAMPLES"] = self.n_inference_subsample_samples - config["inference_config"] = existing - return config - - def fit( - self, X: np.ndarray, y: np.ndarray, output_dir: Path | None = None - ) -> FinetunedTabPFNClassifier: - """Fine-tune the TabPFN model on the provided training data. - - Args: - X: The training input samples of shape (n_samples, n_features). - y: The target values of shape (n_samples,). 
- output_dir: Directory path for saving checkpoints. If None, no - checkpointing is performed and progress will be lost if - training is interrupted. - - Returns: - The fitted instance itself. - """ - if output_dir is None: - warnings.warn( - "`output_dir` is not set. This means no checkpointing will be done and " - "all progress will be lost if the training is interrupted.", - UserWarning, - stacklevel=2, - ) - else: - output_dir.mkdir(parents=True, exist_ok=True) - - self.X_ = X - self.y_ = y - - return self._fit(X=X, y=y, output_dir=output_dir) - - def _fit( # noqa: C901,PLR0912 - self, - X: np.ndarray, - y: np.ndarray, - output_dir: Path | None = None, - ) -> FinetunedTabPFNClassifier: - """Internal implementation of fit that runs the finetuning loop.""" - # Store the original training size for checkpoint naming - train_size = X.shape[0] - - X_train, X_val, y_train, y_val = train_test_split( - X, - y, - test_size=self.validation_split_ratio, - random_state=self.random_state, - stratify=y, + super().__init__( + device=device, + epochs=epochs, + learning_rate=learning_rate, + weight_decay=weight_decay, + validation_split_ratio=validation_split_ratio, + n_finetune_ctx_plus_query_samples=n_finetune_ctx_plus_query_samples, + finetune_ctx_query_split_ratio=finetune_ctx_query_split_ratio, + n_inference_subsample_samples=n_inference_subsample_samples, + random_state=random_state, + early_stopping=early_stopping, + early_stopping_patience=early_stopping_patience, + min_delta=min_delta, + grad_clip_value=grad_clip_value, + use_lr_scheduler=use_lr_scheduler, + lr_warmup_only=lr_warmup_only, + n_estimators_finetune=n_estimators_finetune, + n_estimators_validation=n_estimators_validation, + n_estimators_final_inference=n_estimators_final_inference, + use_activation_checkpointing=use_activation_checkpointing, + save_checkpoint_interval=save_checkpoint_interval, ) - - # Calculate the context size used during finetuning. - n_finetune_ctx_plus_query_samples = min( - self.n_finetune_ctx_plus_query_samples, - len(y_train), + self.extra_classifier_kwargs = extra_classifier_kwargs + + @property + @override + def _estimator_kwargs(self) -> dict[str, Any]: + """Return the classifier-specific kwargs.""" + return self.extra_classifier_kwargs or {} + + @property + @override + def _model_type(self) -> Literal["classifier", "regressor"]: + """Return the model type string.""" + return "classifier" + + @property + @override + def _metric_name(self) -> str: + """Return the name of the primary metric.""" + return "ROC AUC" + + @override + def _create_estimator(self, config: dict[str, Any]) -> TabPFNClassifier: + """Create the TabPFNClassifier with the given config.""" + return TabPFNClassifier( + **config, + fit_mode="batched", + differentiable_input=False, ) - inference_config = self.classifier_kwargs.get("inference_config", {}) - base_classifier_config: dict[str, Any] = { - **self.classifier_kwargs, - "ignore_pretraining_limits": True, - "device": self.device, - "random_state": self.random_state, - "inference_config": inference_config, - } - - # Config used for the finetuning loop. - finetuning_classifier_config = self._build_classifier_config( - base_classifier_config, - self.n_estimators_finetune, + @override + def _setup_estimator(self) -> None: + """Set up softmax temperature after estimator creation.""" + self.finetuned_estimator_.softmax_temperature_ = ( + self.finetuned_estimator_.softmax_temperature ) - # Configs used for validation-time evaluation and final inference. 
They - # share all settings except for a potential `n_estimators` override, and - # both use the same `SUBSAMPLE_SAMPLES` setting. - validation_eval_config = self._build_eval_config( - base_classifier_config, - self.n_estimators_validation, + @override + def _setup_batch(self, batch: ClassifierBatch) -> None: # type: ignore[override] + """No batch-specific setup needed for classifier.""" + + @override + def _should_skip_batch(self, batch: ClassifierBatch) -> bool: # type: ignore[override] + """Check if the batch should be skipped.""" + ctx_unique = torch.unique( + torch.cat([torch.unique(t.reshape(-1)) for t in batch.y_context]) ) - final_inference_eval_config = self._build_eval_config( - base_classifier_config, - self.n_estimators_final_inference, + qry_unique = torch.unique( + torch.cat([torch.unique(t.reshape(-1)) for t in batch.y_query]) ) - if self.device.startswith("cuda") and torch.cuda.is_available(): - eval_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())] - else: - eval_devices = ["cpu"] # Used in tests - - validation_eval_config["device"] = eval_devices - final_inference_eval_config["device"] = eval_devices - - epoch_to_start_from = 0 - checkpoint_path = None - if output_dir is not None: - checkpoint_path, epoch_to_start_from = ( - get_checkpoint_path_and_epoch_from_output_dir( - output_dir=output_dir, - train_size=train_size, - get_best=False, - ) + query_in_context = torch.isin(qry_unique, ctx_unique, assume_unique=True) + if not bool(query_in_context.all()): + missing_labels = qry_unique[~query_in_context].detach().cpu().numpy() + context_labels = ctx_unique.detach().cpu().numpy() + logger.warning( + "Skipping batch: query labels %s are not a subset of context labels %s", + missing_labels, + context_labels, ) - if checkpoint_path is not None: - logger.info( - f"Restarting training from checkpoint {checkpoint_path} at epoch " - f"{epoch_to_start_from}", - ) - finetuning_classifier_config["model_path"] = checkpoint_path - - self.finetuned_classifier_ = TabPFNClassifier( - **finetuning_classifier_config, - fit_mode="batched", - differentiable_input=False, - ) - self.finetuned_classifier_.softmax_temperature_ = ( - self.finetuned_classifier_.softmax_temperature - ) + return True + return False - self.finetuned_classifier_._initialize_model_variables() + @override + def _forward_with_loss(self, batch: ClassifierBatch) -> torch.Tensor: # type: ignore[override] + """Perform forward pass and compute and return cross-entropy loss. - self.finetuned_classifier_.model_.to(self.device) + Args: + batch: The ClassifierBatch containing preprocessed context and + query data. - if self.use_activation_checkpointing: - self.finetuned_classifier_.model_.recompute_layer = True # type: ignore + Returns: + The computed cross-entropy loss tensor. 
+ """ + X_query_batch = batch.X_query + y_query_batch = batch.y_query - optimizer = get_and_init_optimizer( - model_parameters=self.finetuned_classifier_.model_.parameters(), # type: ignore - learning_rate=self.learning_rate, - weight_decay=self.weight_decay, - checkpoint_path=checkpoint_path, - device=self.device, + # shape suffix: Q=n_queries, B=batch(=1), E=estimators, L=logits + predictions_QBEL = self.finetuned_estimator_.forward( + X_query_batch, + return_raw_logits=True, ) - loss_function = torch.nn.CrossEntropyLoss() - - use_amp = self.device.startswith("cuda") and torch.cuda.is_available() - scaler = GradScaler() if use_amp else None # type: ignore - - logger.info("--- 🚀 Starting Fine-tuning ---") - - best_roc_auc = -np.inf - patience_counter = 0 - best_model = None - - scheduler: LambdaLR | None = None - - for epoch in range(epoch_to_start_from, self.epochs): - # Per-epoch aggregates for cleaner learning curves. - epoch_loss_sum = 0.0 - epoch_batches = 0 - - # Regenerate datasets each epoch with a different random_state to ensure - # diversity in context/query pairs across epochs. This prevents the - # model from seeing the exact same splits in every epoch, which could - # reduce training signal diversity. - training_splitter = partial( - train_test_split, - test_size=self.finetune_ctx_query_split_ratio, - random_state=self.random_state + epoch, - ) - - training_datasets = get_preprocessed_dataset_chunks( - calling_instance=self.finetuned_classifier_, - X_raw=X_train, - y_raw=y_train, - split_fn=training_splitter, - max_data_size=n_finetune_ctx_plus_query_samples, - model_type="classifier", - equal_split_size=False, - seed=self.random_state + epoch, - ) - - finetuning_dataloader = DataLoader( - training_datasets, - batch_size=self.meta_batch_size, - collate_fn=meta_dataset_collator, - shuffle=True, - ) + Q, B, E, L = predictions_QBEL.shape + assert y_query_batch.shape[1] == Q + assert B == 1 + assert self.n_estimators_finetune == E + assert self.finetuned_estimator_.n_classes_ == L + + # Reshape for CE loss: treat estimator dim as batch dim + # permute to shape (B, E, L, Q) then reshape to (B*E, L, Q) + predictions_BLQ = predictions_QBEL.permute(1, 2, 3, 0).reshape(B * E, L, Q) + targets_BQ = y_query_batch.repeat(B * self.n_estimators_finetune, 1).to( + self.device + ) - # Instantiate the LR scheduler only once so that the warmup and - # cosine schedule run continuously across all epochs. scheduler is None - # only in the first epoch. 
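The new `_forward_with_loss` above folds the estimator dimension into the batch dimension before applying cross-entropy, so the loss is computed per ensemble member and then averaged, rather than averaging logits first. A self-contained sketch of that reshape with toy shapes (random values, not TabPFN outputs):

import torch
import torch.nn.functional as F

Q, B, E, L = 5, 1, 2, 3                      # queries, batch, estimators, classes
predictions_QBEL = torch.randn(Q, B, E, L)   # stand-in for the raw logits
y_query_BQ = torch.randint(0, L, (B, Q))

# (Q, B, E, L) -> (B, E, L, Q) -> (B*E, L, Q): the estimator dim becomes a batch dim.
predictions_BLQ = predictions_QBEL.permute(1, 2, 3, 0).reshape(B * E, L, Q)
targets_BQ = y_query_BQ.repeat(B * E, 1)     # repeat targets once per estimator

# Cross-entropy over the class dimension; the mean runs over estimators and
# queries, so every estimator is optimised individually.
loss = F.cross_entropy(predictions_BLQ, targets_BQ)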
- if self.use_lr_scheduler and scheduler is None: - steps_per_epoch = len(finetuning_dataloader) - if steps_per_epoch == 0: - logger.warning( - "No training batches available; ending training early.", - ) - break - - total_steps = steps_per_epoch * self.epochs - warmup_steps = int(total_steps * 0.1) - - lrate_schedule_fn = get_cosine_schedule_with_warmup( - total_steps=total_steps, - warmup_steps=warmup_steps, - warmup_only=self.lr_warmup_only, - ) - scheduler = LambdaLR(optimizer, lr_lambda=lrate_schedule_fn) - - logger.info( - "Using LambdaLR %s schedule: total_steps=%d, warmup_steps=%d", - "warmup-only (constant LR after warmup)" - if self.lr_warmup_only - else "warmup+cosine", - total_steps, - warmup_steps, - ) - - progress_bar = tqdm( - finetuning_dataloader, - desc=f"Finetuning Epoch {epoch + 1}/{self.epochs}", - ) - for ( - X_context_batch, - X_query_batch, - y_context_batch, - y_query_batch, - cat_ixs, - confs, - ) in progress_bar: - ctx = set( - torch.cat([t.flatten() for t in y_context_batch]).unique().tolist() - ) - qry = set( - torch.cat([t.flatten() for t in y_query_batch]).unique().tolist() - ) - if not qry.issubset(ctx): - logger.warning( - "Skipping batch: query labels %s are not a subset of " - "context labels %s", - qry, - ctx, - ) - continue - - optimizer.zero_grad() - - self.finetuned_classifier_.fit_from_preprocessed( - X_context_batch, - y_context_batch, - cat_ixs, - confs, - ) - - use_scaler = use_amp and scaler is not None - - with autocast(enabled=use_scaler): # type: ignore - with sdpa_kernel_context(): - # shape suffix: Q=n_queries, B=batch(=1), E=estimators, L=logits - predictions_QBEL = self.finetuned_classifier_.forward( - X_query_batch, - return_raw_logits=True, - ) - - Q, B, E, L = predictions_QBEL.shape - assert y_query_batch.shape[1] == Q - assert B == 1 - assert self.n_estimators_finetune == E - assert self.finetuned_classifier_.n_classes_ == L - - # For getting the loss using the CE loss, we need to reshape. - # We treat the estimator dim as batch dim and - # permute so that the shape is (B*E, L, Q). This way - # the loss is first calculated for each estimator and then - # the results are averaged. This is what we want. If we - # average each estimator first and then take the mean we - # don't improve the individual estimators but the sum of it, - # which is not ideal. - predictions_BLQ = predictions_QBEL.permute(1, 2, 3, 0).reshape( - B * E, L, Q - ) - - loss = loss_function( - predictions_BLQ, - y_query_batch.repeat(self.n_estimators_finetune, 1).to( - self.device - ), - ) - - if use_scaler: - # When using activation checkpointing, we need to exclude the cuDNN - # backend also during the backward pass because checkpointing re- - # computes the forward pass during backward. 
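The `_should_skip_batch` guard introduced above drops fine-tuning batches whose query split contains a label that never appears in the context split, replacing the inline `issubset` check in the removed training loop. A standalone illustration of the same test on toy label tensors:

import torch

y_context = [torch.tensor([0, 1, 1, 2]), torch.tensor([0, 2, 2, 1])]  # per-estimator context labels
y_query = [torch.tensor([0, 3])]                                      # class 3 is absent from the context

ctx_unique = torch.unique(torch.cat([torch.unique(t.reshape(-1)) for t in y_context]))
qry_unique = torch.unique(torch.cat([torch.unique(t.reshape(-1)) for t in y_query]))

query_in_context = torch.isin(qry_unique, ctx_unique, assume_unique=True)
if not bool(query_in_context.all()):
    missing = qry_unique[~query_in_context].tolist()
    print(f"Skipping batch: query labels {missing} are not in the context")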
- with sdpa_kernel_context(): - scaler.scale(loss).backward() # type: ignore - scaler.unscale_(optimizer) # type: ignore - - if self.grad_clip_value is not None: - clip_grad_norm_( - self.finetuned_classifier_.model_.parameters(), # type: ignore - self.grad_clip_value, - ) - - scaler.step(optimizer) # type: ignore - scaler.update() # type: ignore - else: - with sdpa_kernel_context(): - loss.backward() - - if self.grad_clip_value is not None: - clip_grad_norm_( - self.finetuned_classifier_.model_.parameters(), # type: ignore - self.grad_clip_value, - ) - - optimizer.step() - - if scheduler is not None: - scheduler.step() - - loss_scalar = float(loss.detach().item()) - - epoch_loss_sum += loss_scalar - epoch_batches += 1 - - progress_bar.set_postfix( - loss=f"{loss_scalar:.4f}", - ) - - mean_train_loss = ( - epoch_loss_sum / epoch_batches if epoch_batches > 0 else None - ) + return compute_classification_loss( + predictions_BLQ=predictions_BLQ, + targets_BQ=targets_BQ, + ) - roc_auc, log_loss_score = evaluate_model( - self.finetuned_classifier_, - validation_eval_config, - X_train, # pyright: ignore[reportArgumentType] - y_train, # pyright: ignore[reportArgumentType] - X_val, # pyright: ignore[reportArgumentType] - y_val, # pyright: ignore[reportArgumentType] - ) + @override + def _evaluate_model( + self, + eval_config: dict[str, Any], + X_train: np.ndarray, + y_train: np.ndarray, + X_val: np.ndarray, + y_val: np.ndarray, + ) -> EvalResult: + """Evaluate the classifier using ROC AUC and log loss.""" + eval_classifier = clone_model_for_evaluation( + self.finetuned_estimator_, + eval_config, + TabPFNClassifier, + ) + eval_classifier.fit(X_train, y_train) + + try: + probabilities = eval_classifier.predict_proba(X_val) # type: ignore + if probabilities.shape[1] > 2: + roc_auc = roc_auc_score(y_val, probabilities, multi_class="ovr") + else: + roc_auc = roc_auc_score(y_val, probabilities[:, 1]) + log_loss_score = log_loss(y_val, probabilities) + except (ValueError, RuntimeError, AttributeError) as e: + logger.warning(f"An error occurred during evaluation: {e}") + roc_auc, log_loss_score = np.nan, np.nan + + return EvalResult( + primary=roc_auc, # pyright: ignore[reportArgumentType] + secondary={"log_loss": log_loss_score}, + ) - logger.info( - f"📊 Epoch {epoch + 1} Evaluation | Val ROC: {roc_auc:.4f}, " - f"Val Log Loss: {log_loss_score:.4f}, Train Loss: {mean_train_loss:.4f}" - ) + @override + def _is_improvement(self, current: float, best: float) -> bool: + """Check if current ROC AUC is better (higher) than best.""" + return current > best + self.min_delta + + @override + def _get_initial_best_metric(self) -> float: + """Return -inf for maximization.""" + return -np.inf + + @override + def _get_checkpoint_metrics(self, eval_result: EvalResult) -> dict[str, float]: + """Return metrics for checkpoint saving.""" + return { + "roc_auc": eval_result.primary, + "log_loss": eval_result.secondary.get("log_loss", np.nan), + } - if output_dir is not None and not np.isnan(roc_auc): - save_interval_checkpoint = ( - self.save_checkpoint_interval is not None - and (epoch + 1) % self.save_checkpoint_interval == 0 - ) - - is_best = roc_auc > best_roc_auc + self.min_delta - - if save_interval_checkpoint or is_best: - save_checkpoint( - estimator=self.finetuned_classifier_, - output_dir=output_dir, - epoch=epoch + 1, - optimizer=optimizer, - metrics={"roc_auc": roc_auc, "log_loss": log_loss_score}, - train_size=train_size, - is_best=is_best, - save_interval_checkpoint=save_interval_checkpoint, - ) - - if 
self.early_stopping and not np.isnan(roc_auc): - if roc_auc > best_roc_auc + self.min_delta: - best_roc_auc = roc_auc - patience_counter = 0 - best_model = copy.deepcopy(self.finetuned_classifier_) - else: - patience_counter += 1 - logger.info( - "⚠️ No improvement for %s epochs. Best ROC AUC: %.4f", - patience_counter, - best_roc_auc, - ) - - if patience_counter >= self.early_stopping_patience: - logger.info( - "🛑 Early stopping triggered. Best ROC AUC: %.4f", - best_roc_auc, - ) - if best_model is not None: - self.finetuned_classifier_ = best_model - # Log one last set of epoch metrics before breaking. - break - - if self.early_stopping and best_model is not None: - self.finetuned_classifier_ = best_model - - logger.info("--- ✅ Fine-tuning Finished ---") + @override + def _log_epoch_evaluation( + self, epoch: int, eval_result: EvalResult, mean_train_loss: float | None + ) -> None: + """Log evaluation results for classification.""" + log_loss_score = eval_result.secondary.get("log_loss", np.nan) + logger.info( + f"📊 Epoch {epoch + 1} Evaluation | Val ROC: {eval_result.primary:.4f}, " + f"Val Log Loss: {log_loss_score:.4f}, Train Loss: {mean_train_loss:.4f}" + ) + @override + def _setup_inference_model( + self, final_inference_eval_config: dict[str, Any] + ) -> None: + """Set up the final inference classifier.""" finetuned_inference_classifier = clone_model_for_evaluation( - self.finetuned_classifier_, # type: ignore + self.finetuned_estimator_, final_inference_eval_config, TabPFNClassifier, ) @@ -630,10 +338,26 @@ def _fit( # noqa: C901,PLR0912 self.finetuned_inference_classifier_.fit_mode = "fit_preprocessors" # type: ignore self.finetuned_inference_classifier_.fit(self.X_, self.y_) # type: ignore - self.is_fitted_ = True + @override + def fit( + self, X: XType, y: YType, output_dir: Path | None = None + ) -> FinetunedTabPFNClassifier: + """Fine-tune the TabPFN model on the provided training data. + + Args: + X: The training input samples of shape (n_samples, n_features). + y: The target values of shape (n_samples,). + output_dir: Directory path for saving checkpoints. If None, no + checkpointing is performed and progress will be lost if + training is interrupted. + + Returns: + The fitted instance itself. + """ + super().fit(X, y, output_dir) return self - def predict_proba(self, X: np.ndarray) -> np.ndarray: + def predict_proba(self, X: XType) -> np.ndarray: """Predict class probabilities for X. Args: @@ -647,7 +371,8 @@ def predict_proba(self, X: np.ndarray) -> np.ndarray: return self.finetuned_inference_classifier_.predict_proba(X) # type: ignore - def predict(self, X: np.ndarray) -> np.ndarray: + @override + def predict(self, X: XType) -> np.ndarray: """Predict the class for X. Args: diff --git a/src/tabpfn/finetuning/finetuned_regressor.py b/src/tabpfn/finetuning/finetuned_regressor.py new file mode 100644 index 000000000..39b58006a --- /dev/null +++ b/src/tabpfn/finetuning/finetuned_regressor.py @@ -0,0 +1,376 @@ +"""A TabPFN regressor that finetunes the underlying model for a single task. + +This module provides the FinetunedTabPFNRegressor class, which wraps TabPFN +and allows fine-tuning on a specific dataset using the familiar scikit-learn +.fit() and .predict() API. 
+""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal +from typing_extensions import override + +import numpy as np +import torch +from sklearn.base import RegressorMixin +from sklearn.metrics import mean_squared_error +from sklearn.utils.validation import check_is_fitted + +from tabpfn import TabPFNRegressor +from tabpfn.finetuning.finetuned_base import EvalResult, FinetunedTabPFNBase +from tabpfn.finetuning.train_util import clone_model_for_evaluation +from tabpfn.model_loading import get_n_out + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from tabpfn.constants import XType, YType + from tabpfn.finetuning.data_util import RegressorBatch + + +def compute_regression_loss( + *, + predictions_BLQ: torch.Tensor, + targets_BQ: torch.Tensor, + bardist_loss_fn: Any, + mse_loss_weight: float, + mse_loss_clip: float | None, +) -> torch.Tensor: + """Compute the regression training loss from bar distribution and auxiliary terms. + + Shapes suffixes: + B=batch * estimators, L=logits, Q=n_queries. + + Args: + predictions_BLQ: Bar distribution logits of shape (B*E, Q, L). + targets_BQ: Regression targets of shape (B*E, Q). + bardist_loss_fn: Bar distribution loss function (callable) which also + exposes a `.mean()` method for converting bar logits to mean + predictions. + mse_loss_weight: Weight for an auxiliary MSE term. Set to 0.0 to disable. + mse_loss_clip: Optional upper bound for the auxiliary MSE term. + + Returns: + A scalar loss tensor. + """ + losses = bardist_loss_fn(predictions_BLQ, targets_BQ) + + if mse_loss_weight > 0.0: + predictions_mean = bardist_loss_fn.mean(predictions_BLQ) + mse_aux_loss = ((predictions_mean - targets_BQ) ** 2).mean() + if mse_loss_clip is not None: + mse_aux_loss = mse_aux_loss.clamp(max=mse_loss_clip) + losses = losses + mse_loss_weight * mse_aux_loss + + return losses.mean() + + +class FinetunedTabPFNRegressor(FinetunedTabPFNBase, RegressorMixin): + """A scikit-learn compatible wrapper for fine-tuning the TabPFNRegressor. + + This class encapsulates the fine-tuning loop, allowing you to fine-tune + TabPFN on a specific dataset using the familiar .fit() and .predict() API. + + Args: + device: The device to run the model on. Defaults to "cuda". + epochs: The total number of passes through the fine-tuning data. + Defaults to 30. + learning_rate: The learning rate for the AdamW optimizer. A small value + is crucial for stable fine-tuning. Defaults to 1e-5. + weight_decay: The weight decay for the AdamW optimizer. Defaults to 0.01. + validation_split_ratio: Fraction of the original training data reserved + as a validation set for early stopping and monitoring. Defaults to 0.1. + n_finetune_ctx_plus_query_samples: The total number of samples per + meta-dataset during fine-tuning (context plus query) before applying + the `finetune_ctx_query_split_ratio`. Defaults to 10_000. + finetune_ctx_query_split_ratio: The proportion of each fine-tuning + meta-dataset to use as query samples for calculating the loss. The + remainder is used as context. Defaults to 0.2. + n_inference_subsample_samples: The total number of subsampled training + samples per estimator during validation and final inference. + Defaults to 50_000. + random_state: Seed for reproducibility of data splitting and model + initialization. Defaults to 0. + early_stopping: Whether to use early stopping based on validation + performance. Defaults to True. 
+ early_stopping_patience: Number of epochs to wait for improvement before + early stopping. Defaults to 8. + min_delta: Minimum change in metric to be considered as an improvement. + Defaults to 1e-4. + grad_clip_value: Maximum norm for gradient clipping. If None, gradient + clipping is disabled. Gradient clipping helps stabilize training by + preventing exploding gradients. Defaults to 1.0. + use_lr_scheduler: Whether to use a learning rate scheduler (linear warmup + with optional cosine decay) during fine-tuning. Defaults to True. + lr_warmup_only: If True, only performs linear warmup to the base learning + rate and then keeps it constant. If False, applies cosine decay after + warmup. Defaults to False. + n_estimators_finetune: If set, overrides `n_estimators` of the underlying + estimator only during fine-tuning to control the number of + estimators (ensemble size) used in the training loop. If None, the + value from `kwargs` or the estimator default is used. + Defaults to 2. + n_estimators_validation: If set, overrides `n_estimators` only for + validation-time evaluation during fine-tuning (early-stopping / + monitoring). If None, the value from `kwargs` or the + estimator default is used. Defaults to 2. + n_estimators_final_inference: If set, overrides `n_estimators` only for + the final fitted inference model that is used after fine-tuning. If + None, the value from `kwargs` or the estimator default is used. + Defaults to 8. + use_activation_checkpointing: Whether to use activation checkpointing to + reduce memory usage. Defaults to True. + save_checkpoint_interval: Number of epochs between checkpoint saves. This + only has an effect if `output_dir` is provided during the `fit()` call. + If None, no intermediate checkpoints are saved. The best model checkpoint + is always saved regardless of this setting. Defaults to 10. + + FinetunedTabPFNRegressor specific arguments: + + extra_regressor_kwargs: Additional keyword arguments to pass to the + underlying `TabPFNRegressor`, such as `n_estimators`. + mse_loss_weight: Weight for an auxiliary MSE loss term added to the + bar distribution loss. Set to 0.0 to disable. Defaults to 8.0. + mse_loss_clip: Optional upper bound for the auxiliary MSE loss term. + If None, no clipping is applied. Defaults to None. 
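`mse_loss_weight` and `mse_loss_clip` are consumed by the `compute_regression_loss` helper defined earlier in this file: the bar-distribution NLL gains an auxiliary MSE term computed on the distribution mean, clamped before being weighted in. A rough sketch of that combination; `DummyBarDist` is only a stand-in for the real bar-distribution object (a callable that also exposes `.mean()`), not the TabPFN implementation:

import torch


class DummyBarDist:
    """Toy stand-in: callable per-sample loss plus a .mean() accessor."""

    def __call__(self, logits_BQL: torch.Tensor, targets_BQ: torch.Tensor) -> torch.Tensor:
        return (logits_BQL.mean(dim=-1) - targets_BQ) ** 2  # pretend per-sample NLL

    def mean(self, logits_BQL: torch.Tensor) -> torch.Tensor:
        return logits_BQL.mean(dim=-1)                      # pretend distribution mean


bardist = DummyBarDist()
logits_BQL = torch.randn(2, 5, 4)   # (B*E, Q, L)
targets_BQ = torch.randn(2, 5)      # (B*E, Q)

mse_loss_weight, mse_loss_clip = 8.0, 1.0
losses = bardist(logits_BQL, targets_BQ)
if mse_loss_weight > 0.0:
    mse_aux = ((bardist.mean(logits_BQL) - targets_BQ) ** 2).mean()
    if mse_loss_clip is not None:
        mse_aux = mse_aux.clamp(max=mse_loss_clip)
    losses = losses + mse_loss_weight * mse_aux
loss = losses.mean()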
+ """ + + def __init__( # noqa: PLR0913 + self, + *, + device: str = "cuda", + epochs: int = 30, + learning_rate: float = 1e-5, + weight_decay: float = 0.01, + validation_split_ratio: float = 0.1, + n_finetune_ctx_plus_query_samples: int = 10_000, + finetune_ctx_query_split_ratio: float = 0.2, + n_inference_subsample_samples: int = 50_000, + random_state: int = 0, + early_stopping: bool = True, + early_stopping_patience: int = 8, + min_delta: float = 1e-4, + grad_clip_value: float | None = 1.0, + use_lr_scheduler: bool = True, + lr_warmup_only: bool = False, + n_estimators_finetune: int = 2, + n_estimators_validation: int = 2, + n_estimators_final_inference: int = 8, + use_activation_checkpointing: bool = True, + save_checkpoint_interval: int | None = 10, + extra_regressor_kwargs: dict[str, Any] | None = None, + mse_loss_weight: float = 8.0, + mse_loss_clip: float | None = None, + ): + super().__init__( + device=device, + epochs=epochs, + learning_rate=learning_rate, + weight_decay=weight_decay, + validation_split_ratio=validation_split_ratio, + n_finetune_ctx_plus_query_samples=n_finetune_ctx_plus_query_samples, + finetune_ctx_query_split_ratio=finetune_ctx_query_split_ratio, + n_inference_subsample_samples=n_inference_subsample_samples, + random_state=random_state, + early_stopping=early_stopping, + early_stopping_patience=early_stopping_patience, + min_delta=min_delta, + grad_clip_value=grad_clip_value, + use_lr_scheduler=use_lr_scheduler, + lr_warmup_only=lr_warmup_only, + n_estimators_finetune=n_estimators_finetune, + n_estimators_validation=n_estimators_validation, + n_estimators_final_inference=n_estimators_final_inference, + use_activation_checkpointing=use_activation_checkpointing, + save_checkpoint_interval=save_checkpoint_interval, + ) + self.extra_regressor_kwargs = extra_regressor_kwargs + self.mse_loss_weight = mse_loss_weight + self.mse_loss_clip = mse_loss_clip + + @property + @override + def _estimator_kwargs(self) -> dict[str, Any]: + """Return the regressor-specific kwargs.""" + return self.extra_regressor_kwargs or {} + + @property + @override + def _model_type(self) -> Literal["classifier", "regressor"]: + """Return the model type string.""" + return "regressor" + + @property + @override + def _metric_name(self) -> str: + """Return the name of the primary metric.""" + return "MSE" + + @override + def _create_estimator(self, config: dict[str, Any]) -> TabPFNRegressor: + """Create the TabPFNRegressor with the given config.""" + return TabPFNRegressor( + **config, + fit_mode="batched", + differentiable_input=False, + ) + + @override + def _setup_estimator(self) -> None: + """No additional setup needed for regressor at creation time.""" + + @override + def _should_skip_batch(self, batch: RegressorBatch) -> bool: # type: ignore[override] + """Never skip a batch for regression.""" + return False + + @override + def _setup_batch(self, batch: RegressorBatch) -> None: # type: ignore[override] + """Set up bar distribution for this batch.""" + self.finetuned_estimator_.raw_space_bardist_ = batch.raw_space_bardist + self.finetuned_estimator_.bardist_ = batch.znorm_space_bardist + self._bardist_loss = batch.znorm_space_bardist + + @override + def _forward_with_loss(self, batch: RegressorBatch) -> torch.Tensor: # type: ignore[override] + """Perform forward pass and compute bar distribution loss with optional MSE. + + Args: + batch: The RegressorBatch containing preprocessed context and query + data plus bar distribution information. 
+ + Returns: + The computed loss tensor (bar distribution + optional MSE auxiliary). + """ + X_query_batch = batch.X_query + y_query_batch = batch.y_query + bardist_loss_fn = self._bardist_loss + + _, per_estim_logits, _ = self.finetuned_estimator_.forward(X_query_batch) + # per_estim_logits is a list (per estimator) of tensors with shape [Q, B(=1), L] + + # shape suffix: Q=n_queries, B=batch(=1), E=estimators, L=logits + predictions_QBEL = torch.stack(per_estim_logits, dim=2) + + Q, B, E, L = predictions_QBEL.shape + num_bars = get_n_out(self.finetuned_estimator_.configs_[0], bardist_loss_fn) + assert y_query_batch.shape[1] == Q + assert B == 1 + assert self.n_estimators_finetune == E + assert num_bars == L + + # Reshape for bar distribution loss: treat estimator dim as batch dim + # permute to shape (B, E, Q, L) then reshape to (B*E, Q, L) + predictions_BLQ = predictions_QBEL.permute(1, 2, 0, 3).reshape(B * E, Q, L) + + targets_BQ = y_query_batch.repeat(B * self.n_estimators_finetune, 1).to( + self.device + ) + + return compute_regression_loss( + predictions_BLQ=predictions_BLQ, + targets_BQ=targets_BQ, + bardist_loss_fn=bardist_loss_fn, + mse_loss_weight=self.mse_loss_weight, + mse_loss_clip=self.mse_loss_clip, + ) + + @override + def _evaluate_model( + self, + eval_config: dict[str, Any], + X_train: np.ndarray, + y_train: np.ndarray, + X_val: np.ndarray, + y_val: np.ndarray, + ) -> EvalResult: + """Evaluate the regressor using MSE.""" + eval_regressor = clone_model_for_evaluation( + self.finetuned_estimator_, + eval_config, + TabPFNRegressor, + ) + eval_regressor.fit(X_train, y_train) + + try: + predictions = eval_regressor.predict(X_val) # type: ignore + mse = mean_squared_error(y_val, predictions) + except (ValueError, RuntimeError, AttributeError) as e: + logger.warning(f"An error occurred during evaluation: {e}") + mse = np.nan + + return EvalResult(primary=mse) # pyright: ignore[reportArgumentType] + + @override + def _is_improvement(self, current: float, best: float) -> bool: + """Check if current MSE is better (lower) than best.""" + return current < best - self.min_delta + + @override + def _get_initial_best_metric(self) -> float: + """Return inf for minimization.""" + return np.inf + + @override + def _get_checkpoint_metrics(self, eval_result: EvalResult) -> dict[str, float]: + """Return metrics for checkpoint saving.""" + return {"mse": eval_result.primary} + + @override + def _log_epoch_evaluation( + self, epoch: int, eval_result: EvalResult, mean_train_loss: float | None + ) -> None: + """Log evaluation results for regression.""" + logger.info( + f"📊 Epoch {epoch + 1} Evaluation | Val MSE: {eval_result.primary:.4f}, " + f"Train Loss: {mean_train_loss:.4f}" + ) + + @override + def _setup_inference_model( + self, final_inference_eval_config: dict[str, Any] + ) -> None: + """Set up the final inference regressor.""" + finetuned_inference_regressor = clone_model_for_evaluation( + self.finetuned_estimator_, + final_inference_eval_config, + TabPFNRegressor, + ) + self.finetuned_inference_regressor_ = finetuned_inference_regressor + self.finetuned_inference_regressor_.fit_mode = "fit_preprocessors" # type: ignore + self.finetuned_inference_regressor_.fit(self.X_, self.y_) # type: ignore + + @override + def fit( + self, X: XType, y: YType, output_dir: Path | None = None + ) -> FinetunedTabPFNRegressor: + """Fine-tune the TabPFN model on the provided training data. + + Args: + X: The training input samples of shape (n_samples, n_features). 
+ y: The target values of shape (n_samples,). + output_dir: Directory path for saving checkpoints. If None, no + checkpointing is performed and progress will be lost if + training is interrupted. + + Returns: + The fitted instance itself. + """ + super().fit(X, y, output_dir) + return self + + @override + def predict(self, X: XType) -> np.ndarray: + """Predict target values for X. + + Args: + X: The input samples of shape (n_samples, n_features). + + Returns: + The predicted target values with shape (n_samples,). + """ + check_is_fitted(self) + + return self.finetuned_inference_regressor_.predict(X) # type: ignore diff --git a/src/tabpfn/finetuning/train_util.py b/src/tabpfn/finetuning/train_util.py index 73b6ccc92..90fb45770 100644 --- a/src/tabpfn/finetuning/train_util.py +++ b/src/tabpfn/finetuning/train_util.py @@ -19,13 +19,7 @@ if TYPE_CHECKING: from collections.abc import Iterator - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", -) - -logging.getLogger().setLevel(logging.INFO) +logger = logging.getLogger(__name__) def _format_train_size(train_size: int) -> str: diff --git a/src/tabpfn/utils.py b/src/tabpfn/utils.py index 8daa57be6..c0e2c5a29 100644 --- a/src/tabpfn/utils.py +++ b/src/tabpfn/utils.py @@ -430,7 +430,23 @@ def validate_Xy_fit( ensure_y_numeric: bool = False, ignore_pretraining_limits: bool = False, ) -> tuple[np.ndarray, np.ndarray, npt.NDArray[Any] | None, int]: - """Validate the input data for fitting.""" + """Validate the input data for fitting. + + Args: + X: The input data. + y: The target data. + estimator: The estimator to validate the data for. + max_num_features: The maximum number of features to allow. + Ignored if `ignore_pretraining_limits` is True. + max_num_samples: The maximum number of samples to allow. + Ignored if `ignore_pretraining_limits` is True. + ensure_y_numeric: Whether to ensure the target data is numeric. + ignore_pretraining_limits: Whether to ignore the pretraining limits. + + Returns: + A tuple of the validated input data X, target data y, feature names, + and number of features. 
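The `train_util.py` hunk above removes import-time `logging.basicConfig(...)` from library code in favour of a module-level named logger, leaving configuration to the calling application. A minimal sketch of that convention; the module body and messages are illustrative:

import logging

logger = logging.getLogger(__name__)  # library module: create a named logger, configure nothing


def do_work() -> None:
    logger.info("finetuning step finished")  # silent unless the application configured logging


if __name__ == "__main__":
    # Application / script side: configure logging once at the entry point.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    do_work()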
+ """ # Calls `validate_data()` with specification # Checks that we do not call validate_data() in case diff --git a/tests/test_finetuning_classifier.py b/tests/test_finetuning_classifier.py index 75ef2f4e8..0bd533673 100644 --- a/tests/test_finetuning_classifier.py +++ b/tests/test_finetuning_classifier.py @@ -26,6 +26,7 @@ from tabpfn import TabPFNClassifier from tabpfn.finetuning.data_util import ( + ClassifierBatch, DatasetCollectionWithPreprocessing, get_preprocessed_dataset_chunks, meta_dataset_collator, @@ -340,7 +341,7 @@ def test_finetuned_tabpfn_classifier_fit_and_predict( finetuned_clf.fit(X_train, y_train) assert finetuned_clf.is_fitted_ - assert hasattr(finetuned_clf, "finetuned_classifier_") + assert hasattr(finetuned_clf, "finetuned_estimator_") assert hasattr(finetuned_clf, "finetuned_inference_classifier_") probabilities = finetuned_clf.predict_proba(X_test) @@ -721,8 +722,13 @@ def test_get_preprocessed_datasets_basic() -> None: assert hasattr(dataset, "__len__") assert len(dataset) > 0 item = dataset[0] - assert isinstance(item, tuple) - assert len(item) == 6 + assert isinstance(item, ClassifierBatch) + assert hasattr(item, "X_context") + assert hasattr(item, "X_query") + assert hasattr(item, "y_context") + assert hasattr(item, "y_query") + assert hasattr(item, "cat_indices") + assert hasattr(item, "configs") def test_datasetcollectionwithpreprocessing_classification_single_dataset( @@ -751,29 +757,17 @@ def test_datasetcollectionwithpreprocessing_classification_single_dataset( assert len(dataset_collection) == 1, "Collection should contain one dataset config" item_index = 0 - processed_dataset_item = dataset_collection[item_index] + batch = dataset_collection[item_index] - assert isinstance(processed_dataset_item, tuple) - assert len(processed_dataset_item) == 6, ( - "Item tuple should have 6 elements for classification" - ) + assert isinstance(batch, ClassifierBatch) - ( - X_trains_preprocessed, - _X_tests_preprocessed, - _y_trains_preprocessed, - y_test_raw_tensor, - _cat_ixs, - _returned_ensemble_configs, - ) = processed_dataset_item - - assert isinstance(X_trains_preprocessed, list) - assert len(X_trains_preprocessed) == n_estimators + assert isinstance(batch.X_context, list) + assert len(batch.X_context) == n_estimators n_samples_total = X_raw.shape[0] expected_n_test = int(np.floor(n_samples_total * test_size)) expected_n_train = n_samples_total - expected_n_test - assert y_test_raw_tensor.shape == (expected_n_test,) - assert X_trains_preprocessed[0].shape[0] == expected_n_train + assert batch.y_query.shape == (expected_n_test,) + assert batch.X_context[0].shape[0] == expected_n_train def test_datasetcollectionwithpreprocessing_classification_multiple_datasets( @@ -807,26 +801,15 @@ def test_datasetcollectionwithpreprocessing_classification_multiple_datasets( ) for item_index in range(len(datasets)): - processed_dataset_item = dataset_collection[item_index] - assert isinstance(processed_dataset_item, tuple) - assert len(processed_dataset_item) == 6, ( - "Item tuple should have 6 elements for classification" - ) - ( - X_trains_preprocessed, - _X_tests_preprocessed, - _y_trains_preprocessed, - y_test_raw_tensor, - _cat_ixs, - _returned_ensemble_configs, - ) = processed_dataset_item - assert isinstance(X_trains_preprocessed, list) - assert len(X_trains_preprocessed) == n_estimators + batch = dataset_collection[item_index] + assert isinstance(batch, ClassifierBatch) + assert isinstance(batch.X_context, list) + assert len(batch.X_context) == n_estimators 
n_samples_total = X_list[item_index].shape[0] expected_n_test = int(np.floor(n_samples_total * test_size)) expected_n_train = n_samples_total - expected_n_test - assert y_test_raw_tensor.shape == (expected_n_test,) - assert X_trains_preprocessed[0].shape[0] == expected_n_train + assert batch.y_query.shape == (expected_n_test,) + assert batch.X_context[0].shape[0] == expected_n_train def test_dataset_and_collator_with_dataloader_uniform( @@ -853,15 +836,13 @@ def test_dataset_and_collator_with_dataloader_uniform( collate_fn=meta_dataset_collator, ) for batch in dl: - # Should be tuple with X_trains, X_tests, y_trains, y_tests, cat_ixs, confs - assert isinstance(batch, tuple) - X_trains, _X_tests, y_trains, _y_tests, _cat_ixs, _confs = batch - for est_tensor in X_trains: + assert isinstance(batch, ClassifierBatch) + for est_tensor in batch.X_context: assert isinstance(est_tensor, torch.Tensor), ( "Each estimator's batch should be a tensor." ) assert est_tensor.shape[0] == batch_size - for est_tensor in y_trains: + for est_tensor in batch.y_context: assert isinstance(est_tensor, torch.Tensor), ( "Each estimator's batch should be a tensor for labels." ) @@ -893,16 +874,15 @@ def test_classifier_dataset_and_collator_batches_type( collate_fn=meta_dataset_collator, ) for batch in dl: - assert isinstance(batch, tuple) - X_trains, _X_tests, y_trains, _y_tests, cat_ixs, confs = batch - for est_tensor in X_trains: + assert isinstance(batch, ClassifierBatch) + for est_tensor in batch.X_context: assert isinstance(est_tensor, torch.Tensor) assert est_tensor.shape[0] == batch_size - for est_tensor in y_trains: + for est_tensor in batch.y_context: assert isinstance(est_tensor, torch.Tensor) assert est_tensor.shape[0] == batch_size - assert isinstance(cat_ixs, list) - for conf in confs: + assert isinstance(batch.cat_indices, list) + for conf in batch.configs: for c in conf: assert isinstance(c, ClassifierEnsembleConfig) break @@ -996,13 +976,15 @@ def test_fit_from_preprocessed_runs( datasets_list, batch_size=batch_size, collate_fn=meta_dataset_collator ) - for data_batch in dl: - X_trains, X_tests, y_trains, y_tests, cat_ixs, confs = data_batch - clf.fit_from_preprocessed(X_trains, y_trains, cat_ixs, confs) - preds = clf.forward(X_tests) + for batch in dl: + assert isinstance(batch, ClassifierBatch) + clf.fit_from_preprocessed( + batch.X_context, batch.y_context, batch.cat_indices, batch.configs + ) + preds = clf.forward(batch.X_query) assert preds.ndim == 3, f"Expected 3D output, got {preds.shape}" - assert preds.shape[0] == X_tests[0].shape[0] - assert preds.shape[0] == y_tests.shape[0] + assert preds.shape[0] == batch.X_query[0].shape[0] + assert preds.shape[0] == batch.y_query.shape[0] assert preds.shape[1] == clf.n_classes_ probs_sum = preds.sum(dim=1) @@ -1124,12 +1106,13 @@ def test_finetuning_consistency_preprocessing_classifier() -> None: collate_fn=meta_dataset_collator, shuffle=False, ) - data_batch = next(iter(dataloader), None) - assert data_batch is not None, "DataLoader yielded no batches." - - X_trains_p2, X_tests_p2, y_trains_p2, _, cat_ixs_p2, confs_p2, *_ = data_batch + batch = next(iter(dataloader), None) + assert batch is not None, "DataLoader yielded no batches." 
+ assert isinstance(batch, ClassifierBatch) - clf_batched.fit_from_preprocessed(X_trains_p2, y_trains_p2, cat_ixs_p2, confs_p2) + clf_batched.fit_from_preprocessed( + batch.X_context, batch.y_context, batch.cat_indices, batch.configs + ) assert hasattr(clf_batched, "models_"), ( "Batched classifier models_ not found after fit_from_preprocessed." ) @@ -1144,7 +1127,7 @@ def test_finetuning_consistency_preprocessing_classifier() -> None: with patch.object( clf_batched.models_[0], "forward", wraps=clf_batched.models_[0].forward ) as mock_forward_p2: - _ = clf_batched.forward(X_tests_p2) + _ = clf_batched.forward(batch.X_query) assert mock_forward_p2.called, "Batched models_[0].forward was not called." # Capture the tensor input 'x' (assuming same argument position as Path 1) diff --git a/tests/test_finetuning_regressor.py b/tests/test_finetuning_regressor.py index 2352c668d..a5fb4137d 100644 --- a/tests/test_finetuning_regressor.py +++ b/tests/test_finetuning_regressor.py @@ -1,109 +1,115 @@ +"""Tests for TabPFN regressor finetuning functionality. + +This module contains regressor-specific tests for: +- The FinetunedTabPFNRegressor wrapper class (.fit() / .predict()). +- Regression checkpoint metric fields (e.g. storing 'mse'). + +We intentionally avoid duplicating tests that primarily exercise common logic in +`FinetunedTabPFNBase`, since those are covered by the classifier finetuning tests. +""" + from __future__ import annotations -import unittest -from functools import partial -from typing import Literal -from unittest.mock import patch +from collections.abc import Callable +from pathlib import Path +from unittest import mock import numpy as np import pytest -import sklearn import torch +from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split -from torch.optim import Adam from torch.utils.data import DataLoader -from tabpfn import TabPFNRegressor -from tabpfn.architectures.base.bar_distribution import ( - BarDistribution, - FullSupportBarDistribution, -) from tabpfn.finetuning.data_util import ( + RegressorBatch, get_preprocessed_dataset_chunks, meta_dataset_collator, ) +from tabpfn.finetuning.finetuned_regressor import ( + FinetunedTabPFNRegressor, +) from tabpfn.preprocessing import RegressorEnsembleConfig +from tabpfn.regressor import TabPFNRegressor -from .utils import get_pytest_devices, mark_mps_configs_as_slow +from .utils import get_pytest_devices rng = np.random.default_rng(42) devices = get_pytest_devices() -fit_modes = [ - "batched", - "fit_preprocessors", -] -inference_precision_methods: list[torch.types._dtype | Literal["autocast", "auto"]] = [ - "auto", - torch.float64, -] -estimators = [1, 2] -optimization_spaces_values = ["raw_label_space", "preprocessed"] - -param_order = [ - "device", - "n_estimators", - "fit_mode", - "inference_precision", - "optimization_space", -] - -default_config = { - "n_estimators": 1, - "device": "cpu", - "fit_mode": "batched", - "inference_precision": "auto", - "optimization_space": "raw_label_space", -} - -param_values: dict[str, list] = { - "n_estimators": estimators, - "device": devices, - "fit_mode": fit_modes, - "inference_precision": inference_precision_methods, - "optimization_space": optimization_spaces_values, -} - -combinations = [tuple(default_config[p] for p in param_order)] -for param_name in param_order: - for value in param_values[param_name]: - if value != default_config[param_name]: - current_config = default_config.copy() - current_config[param_name] = value - 
combinations.append(tuple(current_config[p] for p in param_order)) +def create_mock_architecture_forward_regression() -> Callable[..., torch.Tensor]: + """Return a side_effect for mocking the internal Architecture forward in regression. -@pytest.fixture(scope="module") -def synthetic_regression_data(): - """Generate synthetic regression data.""" - X = rng.normal(size=(30, 4)).astype(np.float32) - # Generate continuous target variable - y = (X @ rng.normal(size=4)).astype(np.float32) - # Add to previous as line too long otherwise - y += rng.normal(size=30).astype(np.float32) * 0.1 - return X, y + The Architecture.forward method signature is: + forward(x, y, *, only_return_standard_out=True, categorical_inds=None) + Where: + - x has shape (train+test rows, batch_size, num_features) + - y has shape (train rows, batch_size) or (train rows, batch_size, 1) + - returns shape (test rows, batch_size, n_out), with n_out determined by the model. + """ -@pytest.fixture(params=devices) -def ft_regressor_instance(request) -> TabPFNRegressor: - """Provides a basic regressor instance, parameterized by device.""" - device = request.param - if device == "cuda" and not torch.cuda.is_available(): - pytest.skip("CUDA device requested but not available.") - return TabPFNRegressor( - n_estimators=2, - device=device, + def mock_forward( + self: torch.nn.Module, + x: torch.Tensor | dict[str, torch.Tensor], + y: torch.Tensor | dict[str, torch.Tensor] | None, + **_kwargs: bool, + ) -> torch.Tensor: + """Mock forward pass that returns random logits with the correct shape.""" + if isinstance(x, dict): + x = x["main"] + + if y is not None: + y_tensor = y["main"] if isinstance(y, dict) else y + num_train_rows = y_tensor.shape[0] + else: + num_train_rows = 0 + + total_rows = x.shape[0] + batch_size = x.shape[1] + num_test_rows = total_rows - num_train_rows + + # Touch a model parameter so gradients flow during backward pass. + # This mirrors the classifier tests and avoids GradScaler issues on CUDA. 
+ first_param = next(self.parameters()) + param_contribution = 0.0 * first_param.sum() + + n_out = int(getattr(self, "n_out", 1)) + return ( + torch.randn( + num_test_rows, + batch_size, + n_out, + requires_grad=True, + device=x.device, + ) + + param_contribution + ) + + return mock_forward + + +@pytest.fixture(scope="module") +def synthetic_regression_data() -> tuple[np.ndarray, np.ndarray]: + """Generate synthetic regression data for testing.""" + result = make_regression( + n_samples=120, + n_features=6, + n_informative=4, + noise=0.1, random_state=42, - inference_precision=torch.float32, - fit_mode="batched", - differentiable_input=False, + coef=False, ) + X = np.asarray(result[0], dtype=np.float32) + y = np.asarray(result[1], dtype=np.float32) + return X, y @pytest.fixture(params=devices) -def std_regressor_instance(request) -> TabPFNRegressor: - """Provides a basic regressor instance, parameterized by device.""" +def regressor_instance(request: pytest.FixtureRequest) -> TabPFNRegressor: + """Provide a basic regressor instance, parameterized by device.""" device = request.param if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA device requested but not available.") @@ -112,501 +118,191 @@ def std_regressor_instance(request) -> TabPFNRegressor: device=device, random_state=42, inference_precision=torch.float32, - fit_mode="low_memory", + fit_mode="batched", differentiable_input=False, ) -def create_regressor( - n_estimators: int, +@pytest.fixture +def variable_synthetic_regression_dataset_collection() -> list[ + tuple[np.ndarray, np.ndarray] +]: + """Create a small collection of synthetic regression datasets with varying sizes.""" + datasets = [] + dataset_sizes = [10, 20, 30] + num_features = 3 + for num_samples in dataset_sizes: + X = rng.normal(size=(num_samples, num_features)).astype(np.float32) + y = rng.normal(size=(num_samples,)).astype(np.float32) + datasets.append((X, y)) + return datasets + + +@pytest.mark.parametrize( + ("device", "early_stopping", "use_lr_scheduler"), + [ + (device, early_stopping, use_lr_scheduler) + for device in devices + for early_stopping in [True, False] + for use_lr_scheduler in [True, False] + ], +) +def test_finetuned_tabpfn_regressor_fit_and_predict( device: str, - fit_mode: str, - inference_precision: torch.types._dtype | Literal["autocast", "auto"], - **kwargs, -) -> TabPFNRegressor: - """Instantiates regressor with common parameters.""" - if device == "cpu" and inference_precision == "autocast": - pytest.skip("Unsupported combination: CPU with 'autocast'") + early_stopping: bool, + use_lr_scheduler: bool, + synthetic_regression_data: tuple[np.ndarray, np.ndarray], +) -> None: + """Test FinetunedTabPFNRegressor fit/predict with a mocked forward pass.""" if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA device requested but not available.") - default_kwargs = {"random_state": 42} - default_kwargs.update(kwargs) + X, y = synthetic_regression_data + X_train, X_test, y_train, _y_test = train_test_split( + X, y, test_size=0.3, random_state=42 + ) + X_train = np.asarray(X_train) + X_test = np.asarray(X_test) + y_train = np.asarray(y_train) - return TabPFNRegressor( - n_estimators=n_estimators, + epochs = 4 if early_stopping else 2 + finetuned_reg = FinetunedTabPFNRegressor( device=device, - fit_mode=fit_mode, - inference_precision=inference_precision, - **default_kwargs, + epochs=epochs, + learning_rate=1e-4, + validation_split_ratio=0.2, + n_finetune_ctx_plus_query_samples=60, + 
finetune_ctx_query_split_ratio=0.2, + n_inference_subsample_samples=120, + random_state=42, + early_stopping=early_stopping, + early_stopping_patience=2, + n_estimators_finetune=1, + n_estimators_validation=1, + n_estimators_final_inference=1, + use_lr_scheduler=use_lr_scheduler, + lr_warmup_only=False, ) + mock_forward = create_mock_architecture_forward_regression() + with mock.patch( + "tabpfn.architectures.base.transformer.PerFeatureTransformer.forward", + autospec=True, + side_effect=mock_forward, + ): + finetuned_reg.fit(X_train, y_train) -# --- Tests --- + assert finetuned_reg.is_fitted_ + assert hasattr(finetuned_reg, "finetuned_estimator_") + assert hasattr(finetuned_reg, "finetuned_inference_regressor_") - -def test_regressor_dataset_and_collator_batches_type( - synthetic_regression_data, ft_regressor_instance -): - """Test that the batches returned by the dataset and collator - are of the correct type. - """ - X, y = synthetic_regression_data - dataset_collection = get_preprocessed_dataset_chunks( - ft_regressor_instance, - X, - y, - train_test_split, - max_data_size=100, - model_type="regressor", - equal_split_size=True, - seed=42, - ) - batch_size = 1 - dl = DataLoader( - dataset_collection, - batch_size=batch_size, - collate_fn=meta_dataset_collator, - ) - for batch in dl: - assert isinstance(batch, tuple) - ( - X_trains_preprocessed, - _X_tests_preprocessed, - y_trains_preprocessed, - _y_test_standardized, - cat_ixs, - confs, - raw_space_bardist_, - bar_distribution, - _x_test_raw, - _y_test_raw, - ) = batch - for est_tensor in X_trains_preprocessed: - assert isinstance(est_tensor, torch.Tensor) - assert est_tensor.shape[0] == batch_size - for est_tensor in y_trains_preprocessed: - assert isinstance(est_tensor, torch.Tensor) - assert est_tensor.shape[0] == batch_size - assert isinstance(cat_ixs, list) - for conf in confs: - for c in conf: - assert isinstance(c, RegressorEnsembleConfig) - for ren_crit in raw_space_bardist_: - assert isinstance(ren_crit, FullSupportBarDistribution) - for bar_dist in bar_distribution: - assert isinstance(bar_dist, BarDistribution) - break + predictions = finetuned_reg.predict(X_test) + assert predictions.shape == (X_test.shape[0],) + assert np.isfinite(predictions).all() -@pytest.mark.parametrize(param_order, mark_mps_configs_as_slow(combinations)) -def test_tabpfn_regressor_finetuning_loop( - device, - n_estimators, - fit_mode, - inference_precision, - optimization_space, - synthetic_regression_data, +@pytest.mark.parametrize("device", devices) +def test_regressor_checkpoint_contains_mse_metric( + device: str, + tmp_path: Path, + synthetic_regression_data: tuple[np.ndarray, np.ndarray], ) -> None: + """Ensure regressor checkpoints store regression metrics (mse). + + This also checks that classifier-only metric fields are not stored. 
+    """
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA device requested but not available.")
+    X, y = synthetic_regression_data
     X_train, _X_test, y_train, _y_test = train_test_split(
         X, y, test_size=0.3, random_state=42
     )
+    X_train = np.asarray(X_train)
+    y_train = np.asarray(y_train)
+    output_folder = tmp_path / "checkpoints_regressor"
-    reg = create_regressor(
-        n_estimators,
-        device,
-        fit_mode,
-        inference_precision,
-        random_state=2,
-        differentiable_input=False,
-    )
-
-    datasets_list = get_preprocessed_dataset_chunks(
-        reg,
-        X_train,
-        y_train,
-        train_test_split,
-        max_data_size=100,
-        model_type="regressor",
-        equal_split_size=True,
-        seed=42,
-    )
-
-    batch_size = 1
-    my_dl_train = DataLoader(
-        datasets_list, batch_size=batch_size, collate_fn=meta_dataset_collator
+    finetuned_reg = FinetunedTabPFNRegressor(
+        device=device,
+        epochs=2,
+        learning_rate=1e-4,
+        validation_split_ratio=0.2,
+        n_finetune_ctx_plus_query_samples=60,
+        finetune_ctx_query_split_ratio=0.2,
+        n_inference_subsample_samples=120,
+        random_state=42,
+        early_stopping=False,
+        use_lr_scheduler=False,
+        n_estimators_finetune=1,
+        n_estimators_validation=1,
+        n_estimators_final_inference=1,
+        save_checkpoint_interval=1,
     )
-    optim_impl = Adam(reg.models_[0].parameters(), lr=1e-5)
-
-    if inference_precision == torch.float64:
-        pass
-        # TODO: check that it fails with the right error
-
-    elif fit_mode in [
-        "fit_preprocessors",
-        "fit_with_cache",
-        "low_memory",
-    ]:
-        # TODO: check that it fails with the right error
-        pass
-    else:
-        for data_batch in my_dl_train:
-            optim_impl.zero_grad()
-
-            (
-                X_trains_preprocessed,
-                X_tests_preprocessed,
-                y_trains_preprocessed,
-                y_test_standardized,
-                cat_ixs,
-                confs,
-                raw_space_bardist_,
-                _bar_distribution,
-                _batch_x_test_raw,
-                batch_y_test_raw,
-            ) = data_batch
-
-            reg.fit_from_preprocessed(
-                X_trains_preprocessed, y_trains_preprocessed, cat_ixs, confs
-            )
-
-            reg.raw_space_bardist_ = raw_space_bardist_[0]
-
-            averaged_pred_logits, _, _ = reg.forward(X_tests_preprocessed)
-
-            # --- Basic Shape Checks ---
-            assert averaged_pred_logits.ndim == 3, (
-                f"Expected 3D output, got {averaged_pred_logits.shape}"
-            )
-
-            # Batch Size
-            assert averaged_pred_logits.shape[0] == batch_y_test_raw.shape[0]
-            assert averaged_pred_logits.shape[0] == batch_size
-            assert averaged_pred_logits.shape[0] == X_tests_preprocessed[0].shape[0]
-            assert averaged_pred_logits.shape[0] == y_test_standardized.shape[0]
-
-            # N_samples
-            assert averaged_pred_logits.shape[1] == batch_y_test_raw.shape[1]
-            assert averaged_pred_logits.shape[1] == y_test_standardized.shape[1]
-
-            # N_bins
-            n_borders_bardist = reg.znorm_space_bardist_.borders.shape[0]
-            assert averaged_pred_logits.shape[2] == n_borders_bardist - 1
-            n_borders_norm_crit = reg.raw_space_bardist_.borders.shape[0]
-            assert averaged_pred_logits.shape[2] == n_borders_norm_crit - 1
-
-            assert len(X_tests_preprocessed) == reg.n_estimators
-            assert len(X_trains_preprocessed) == reg.n_estimators
-            assert len(y_trains_preprocessed) == reg.n_estimators
-            assert reg.models_ is not None, "Model not initialized after fit"
-            assert hasattr(reg, "znorm_space_bardist_"), (
-                "Regressor missing 'znorm_space_bardist_' attribute after fit"
-            )
-            assert hasattr(reg, "raw_space_bardist_"), (
-                "Regressor missing 'raw_space_bardist_' attribute after fit"
-            )
-            assert reg.znorm_space_bardist_ is not None, (
-                "reg.znorm_space_bardist_ is None"
-            )
-
-            lossfn = None
-            if optimization_space == "raw_label_space":
-                lossfn = reg.raw_space_bardist_
-            elif optimization_space == "preprocessed":
-                lossfn = reg.znorm_space_bardist_
-            else:
-                raise ValueError("Need to define optimization space")
-
-            nll_loss_per_sample = lossfn(
-                averaged_pred_logits, batch_y_test_raw.to(device)
-            )
-            loss = nll_loss_per_sample.mean()
-
-            # --- Gradient Check ---
-            loss.backward()
-            optim_impl.step()
-
-            assert torch.isfinite(loss).all(), f"Loss is not finite: {loss.item()}"
-
-            gradients_found = False
-            for param in reg.models_[0].parameters():
-                if (
-                    param.requires_grad
-                    and param.grad is not None
-                    and param.grad.abs().sum().item() > 1e-12
-                ):
-                    gradients_found = True
-                    break
-            assert gradients_found, "No non-zero gradients found."
-
-            reg.models_[0].zero_grad()
-            break  # Only test one batch
-
-
-def test_finetuning_consistency_bar_distribution(
-    std_regressor_instance, ft_regressor_instance, synthetic_regression_data
-):
-    """Tests if predict() output matches the output derived from
-    get_preprocessed_datasets -> fit_from_preprocessed -> forward() -> post-processing,
-    when no actual fine-tuning occurs.
-    """
-    common_seed = 10
-    test_set_size = 0.2
+    mock_forward = create_mock_architecture_forward_regression()
+    with mock.patch(
+        "tabpfn.architectures.base.transformer.PerFeatureTransformer.forward",
+        autospec=True,
+        side_effect=mock_forward,
+    ):
+        finetuned_reg.fit(X_train, y_train, output_dir=output_folder)
-    reg_standard = std_regressor_instance
-    reg_batched = ft_regressor_instance
+    best_checkpoint_candidates = list(output_folder.glob("checkpoint_*_best.pth"))
+    assert len(best_checkpoint_candidates) == 1, "Expected exactly one best checkpoint."
+    best_checkpoint_path = best_checkpoint_candidates[0]
-    if reg_standard.device != reg_batched.device:
-        pytest.skip("Devices do not match.")
+    best_checkpoint = torch.load(best_checkpoint_path, weights_only=False)
+    assert "state_dict" in best_checkpoint
+    assert "config" in best_checkpoint
+    assert "optimizer" in best_checkpoint
+    assert "epoch" in best_checkpoint
+    assert "mse" in best_checkpoint
+    assert "roc_auc" not in best_checkpoint
+    assert "log_loss" not in best_checkpoint
-    x_full_raw, y_full_raw = synthetic_regression_data
-    splitfn = partial(
+def test_regressor_dataset_and_collator_batches_type(
+    variable_synthetic_regression_dataset_collection: list[
+        tuple[np.ndarray, np.ndarray]
+    ],
+    regressor_instance: TabPFNRegressor,
+) -> None:
+    """Test that dataset and collator produce correctly-typed RegressorBatch objects."""
+    X_list = [X for X, _ in variable_synthetic_regression_dataset_collection]
+    y_list = [y for _, y in variable_synthetic_regression_dataset_collection]
+    dataset_collection = get_preprocessed_dataset_chunks(
+        regressor_instance,
+        X_list,
+        y_list,
         train_test_split,
-        test_size=test_set_size,
-        random_state=common_seed,
-        shuffle=False,
-    )
-
-    X_train_raw, X_test_raw, y_train_raw, y_test_raw = splitfn(x_full_raw, y_full_raw)
-
-    reg_standard.fit(X_train_raw, y_train_raw)
-    reg_standard.predict(X_test_raw, output_type="mean")
-
-    datasets_list = get_preprocessed_dataset_chunks(
-        reg_batched,
-        x_full_raw,
-        y_full_raw,
-        splitfn,
-        max_data_size=1_000,
+        100,
         model_type="regressor",
         equal_split_size=True,
         seed=42,
-        shuffle=False,
     )
-    batch_size = 1
-    dataloader = DataLoader(
-        datasets_list,
-        batch_size=batch_size,
+    dl = DataLoader(
+        dataset_collection,
+        batch_size=1,
         collate_fn=meta_dataset_collator,
-        shuffle=False,
-    )
-    data_batch = next(iter(dataloader))
-    (
-        X_trains_preprocessed,
-        _X_tests_preprocessed,
-        y_trains_preprocessed,
-        y_test_standardized,
-        cat_ixs,
-        confs,
-        raw_space_bardist_,
-        _bar_distribution,
-        batch_x_test_raw,
-        batch_y_test_raw,
-    ) = data_batch
-
-    np.testing.assert_allclose(
-        batch_y_test_raw.flatten().detach().cpu().numpy(),
-        y_test_raw,
-        rtol=1e-5,
-        atol=1e-5,
-    )
-
-    reg_batched.fit_from_preprocessed(
-        X_trains_preprocessed, y_trains_preprocessed, cat_ixs, confs
-    )
-
-    mean = np.mean(y_train_raw)
-    std = np.std(y_train_raw)
-    y_train_std_ = std.item() + 1e-20
-    y_train_mean_ = mean.item()
-    y_standardised_investigated = (y_test_raw - y_train_mean_) / y_train_std_
-
-    np.testing.assert_allclose(
-        y_test_standardized[0].flatten().detach().cpu().numpy(),
-        y_standardised_investigated,
-        rtol=1e-5,
-        atol=1e-5,
     )
+    for batch in dl:
+        assert isinstance(batch, RegressorBatch)
+        for est_tensor in batch.X_context:
+            assert isinstance(est_tensor, torch.Tensor)
+            assert est_tensor.shape[0] == 1
+        for est_tensor in batch.y_context:
+            assert isinstance(est_tensor, torch.Tensor)
+            assert est_tensor.shape[0] == 1
-    np.testing.assert_allclose(
-        batch_x_test_raw[0].detach().cpu().numpy(),
-        X_test_raw,
-        rtol=1e-5,
-        atol=1e-5,
-    )
-
-    raw_space_bardist_ = raw_space_bardist_[0]
-    reg_batched.raw_space_bardist_ = raw_space_bardist_
-
-    torch.testing.assert_close(
-        raw_space_bardist_.borders,
-        reg_batched.raw_space_bardist_.borders,
-        rtol=1e-5,
-        atol=1e-5,
-        msg="Renormalized criterion borders do not match.",
-    )
-
-    torch.testing.assert_close(
-        raw_space_bardist_.borders,
-        reg_standard.raw_space_bardist_.borders,
-        rtol=1e-5,  # Standard float tolerance
-        atol=1e-5,
-        msg="Renormalized criterion borders do not match.",
-    )
-
-    torch.testing.assert_close(
-        reg_standard.raw_space_bardist_.borders,
-        reg_batched.raw_space_bardist_.borders,
-        rtol=1e-5,  # Standard float tolerance
-        atol=1e-5,
-        msg="Renormalized criterion borders do not match.",
-    )
-
-    torch.testing.assert_close(
-        reg_standard.znorm_space_bardist_.borders,
-        reg_batched.znorm_space_bardist_.borders,
-        rtol=1e-5,  # Standard float tolerance
-        atol=1e-5,
-        msg="Bar distribution borders do not match.",
-    )
-
-
-# ----------------
-
-
-class TestTabPFNPreprocessingInspection(unittest.TestCase):
-    def test_finetuning_consistency_preprocessing_regressor(self):
-        """In order to test the consistency of our FineTuning code
-        and the preprocessing code, we will test the consistency
-        of the preprocessed datasets. We do this by checking
-        comparing the tensors that enter the internal transformer
-        model.
-        """
-        test_set_size = 0.3
-        common_seed = 42
-        n_total = 20
-        n_features = 10
-        n_estimators = 1
-
-        X, y = sklearn.datasets.make_regression(
-            n_samples=n_total, n_features=n_features, random_state=common_seed
-        )
-        splitfn = partial(
-            train_test_split,
-            test_size=test_set_size,
-            random_state=common_seed,
-            shuffle=False,  # Keep False for consistent results if slicing were needed
-        )
-        X_train_raw, X_test_raw, y_train_raw, _ = splitfn(X, y)
-
-        # Initialize two regressors with the inference and FineTuning
-        reg_standard = TabPFNRegressor(
-            n_estimators=n_estimators,
-            device="auto",
-            random_state=common_seed,
-            fit_mode="fit_preprocessors",  # Example standard mode
-        )
-        reg_batched = TabPFNRegressor(
-            n_estimators=n_estimators,
-            device="auto",
-            random_state=common_seed,
-            fit_mode="batched",  # Mode compatible with get_preprocessed_datasets
-        )
-
-        # --- 2. Path 1: Standard fit -> predict -> Capture Tensor ---
-        reg_standard.fit(X_train_raw, y_train_raw)
-        assert hasattr(reg_standard, "models_")
-        assert hasattr(reg_standard.models_[0], "forward")
-
-        tensor_p1_full = None
-        # Patch the standard regressor's internal model's forward method
-        with patch.object(
-            reg_standard.models_[0], "forward", wraps=reg_standard.models_[0].forward
-        ) as mock_forward_p1:
-            _ = reg_standard.predict(X_test_raw)  # Trigger the patched method
-            assert mock_forward_p1.called
-            # Capture the tensor input to the internal model
-            tensor_p1_full = mock_forward_p1.call_args.args[0]
-
-        assert tensor_p1_full is not None
-        # Standard path's internal model receives the combined train+test sequence
-        assert tensor_p1_full.shape[0] == n_total
-
-        # --- 3. Path 3: FT Full Workflow ---
-        # (get_prep -> fit_prep -> forward -> Capture Tensor)
-
-        datasets_list = get_preprocessed_dataset_chunks(
-            reg_batched,
-            X,
-            y,
-            splitfn,
-            max_data_size=1000,
-            model_type="regressor",
-            equal_split_size=True,
-            seed=42,
-            shuffle=False,
-        )
-
-        # Fit FT regressor
-        dataloader = DataLoader(
-            datasets_list,
-            batch_size=1,
-            collate_fn=meta_dataset_collator,
-            shuffle=False,
-        )
-        data_batch = next(iter(dataloader))
-        (
-            X_trains_p2,
-            X_tests_p2,
-            y_trains_p2,
-            _,
-            cat_ixs_p2,
-            confs_p2,
-            _,
-            _,
-            _,
-            _,
-        ) = data_batch
-        reg_batched.fit_from_preprocessed(
-            X_trains_p2, y_trains_p2, cat_ixs_p2, confs_p2
-        )
-        assert hasattr(reg_batched, "models_")
-        assert hasattr(reg_batched.models_[0], "forward")
-
-        # Step 3c: Call forward and capture the input tensor to the *internal model*
-        tensor_p3_full = None
-        # Patch the *batched* regressor's internal model's forward method
-        with patch.object(
-            reg_batched.models_[0], "forward", wraps=reg_batched.models_[0].forward
-        ) as mock_forward_p3:
-            # Pass the list of preprocessed test tensors obtained earlier
-            _ = reg_batched.forward(X_tests_p2)
-            assert mock_forward_p3.called
-            # Capture the tensor input to the internal model
-            tensor_p3_full = mock_forward_p3.call_args.args[0]
-
-        assert tensor_p3_full is not None
-        # As confirmed before, the internal model in this path
-        # also receives the full sequence
-        assert tensor_p3_full.shape[0] == n_total
-
-        # --- 4. Comparison (Path 1 vs Path 3) ---
-
-        # Compare the two full tensors captured from the input to models_[0].forward
-        # Squeeze dimensions of size 1 for direct comparison
-        # shapes should be [N_Total, Features+1]
-        p1_squeezed = tensor_p1_full.squeeze()
-        p3_squeezed = tensor_p3_full.squeeze()
-
-        assert p1_squeezed.shape == p3_squeezed.shape, (
-            "Shapes of final model input tensors mismatch."
-        )
-
-        atol = 1e-6
-        tensors_match = torch.allclose(p1_squeezed, p3_squeezed, atol=atol)
+        assert isinstance(batch.cat_indices, list)
+        for conf in batch.configs:
+            for c in conf:
+                assert isinstance(c, RegressorEnsembleConfig)
-        assert tensors_match, "Mismatch between preprocessed model input tensors."
+        assert isinstance(batch.X_query_raw, torch.Tensor)
+        assert isinstance(batch.y_query_raw, torch.Tensor)
+        assert batch.X_query_raw.shape[0] == 1
+        assert batch.y_query_raw.shape[0] == 1
+        assert batch.y_query.shape[0] == 1
+        break
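
For reference, a minimal sketch (not part of the diff) of how the best checkpoint asserted in the new test above could be inspected. It assumes a directory already written by FinetunedTabPFNRegressor.fit(..., output_dir=...) using the checkpoint_*_best.pth naming; the "checkpoints_regressor" path is hypothetical.

    # Illustrative sketch only; the directory name is a placeholder.
    from pathlib import Path

    import torch

    output_dir = Path("checkpoints_regressor")
    best_path = next(output_dir.glob("checkpoint_*_best.pth"))

    # weights_only=False because the checkpoint stores Python objects
    # (config, optimizer state) in addition to the model state_dict.
    checkpoint = torch.load(best_path, weights_only=False)

    # Keys asserted in the test: state_dict, config, optimizer, epoch, mse.
    print(sorted(checkpoint))
    print(checkpoint["epoch"], checkpoint["mse"])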