diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4581b672f..91833357a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Changed
 - @benraha Improved the inference speed on CPU significantly [#459](https://github.com/PriorLabs/TabPFN/pull/459).
 - @benraha Added a fast-path for the column selection in RemoveEmptyFeaturesEncoderStep [#468](https://github.com/PriorLabs/TabPFN/pull/468).
+- **(Breaking)** The `TabPFNRegressor.forward()` method signature has changed. It now returns a single logits tensor instead of a tuple, simplifying its interface for fine-tuning.
+- Reduced memory consumption for `TabPFNRegressor` during inference by processing ensemble outputs sequentially instead of stacking them in memory. This lowers peak memory usage, especially with a high `n_estimators`.
 
 ### Bug Fixes
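The memory claim in the second entry reduces to a standard accumulation pattern: instead of materializing and stacking all `n_estimators` output tensors, a running sum keeps at most one partial tensor alive, so peak memory no longer grows with the ensemble size. A minimal sketch of the two patterns; the function names are illustrative, not TabPFN API:

```python
import torch


def average_stacked(outputs: list[torch.Tensor]) -> torch.Tensor:
    """Old pattern: peak memory ~ n_estimators x output size."""
    # All outputs must exist simultaneously inside the stacked tensor.
    return torch.stack(outputs, dim=0).mean(dim=0)


def average_sequential(outputs) -> torch.Tensor:
    """New pattern: peak memory is independent of n_estimators."""
    total, count = None, 0
    for out in outputs:  # `outputs` may be a generator; each item can be freed
        total = out if total is None else total + out
        count += 1
    if total is None:
        raise ValueError("No outputs to average.")
    return total / count
```

Both compute the same mean; only the peak memory differs, which is exactly the trade the new `forward()` makes by iterating over the engine's outputs instead of stacking them.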
diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py
index af1f7b5e2..ca915a2f2 100644
--- a/src/tabpfn/regressor.py
+++ b/src/tabpfn/regressor.py
@@ -20,7 +20,7 @@ import logging
 import typing
 import warnings
 
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Generator, Sequence
 from functools import partial
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal, Union
@@ -45,7 +45,6 @@
     get_preprocessed_datasets_helper,
     initialize_model_variables_helper,
 )
-from tabpfn.inference import InferenceEngine, InferenceEngineBatchedNoPreprocessing
 from tabpfn.model_loading import load_fitted_tabpfn_model, save_fitted_tabpfn_model
 from tabpfn.preprocessing import (
     DatasetCollectionWithPreprocessing,
@@ -126,6 +125,19 @@ class FullOutputDict(MainOutputDict):
 class TabPFNRegressor(RegressorMixin, BaseEstimator):
     """TabPFNRegressor class."""
 
+    # Default quantiles returned by the predict method
+    _DEFAULT_REGRESSION_QUANTILES: typing.ClassVar[list[float]] = [
+        0.1,
+        0.2,
+        0.3,
+        0.4,
+        0.5,
+        0.6,
+        0.7,
+        0.8,
+        0.9,
+    ]
+
     config_: ArchitectureConfig
     """The configuration of the loaded model to be used for inference.
 
@@ -746,6 +758,25 @@ def fit(self, X: XType, y: YType) -> Self:
 
         return self
 
+    def _raw_predict(self, X: XType) -> torch.Tensor:
+        """Handles preprocessing and calls the forward pass to get final predictions."""
+        # 1. Preprocess the input data
+        X_processed = validate_X_predict(X, self)
+        X_processed = fix_dtypes(
+            X_processed, cat_indices=self.inferred_categorical_indices_
+        )
+        X_processed = process_text_na_dataframe(
+            X_processed, ord_encoder=self.preprocessor_
+        )
+
+        # 2. Get the final logits from the memory-efficient forward pass
+        final_logits = self.forward(X_processed, use_inference_mode=True)
+
+        if final_logits is None:
+            raise ValueError("Prediction failed: the model produced no output.")
+
+        return final_logits
+
     @overload
     def predict(
         self,
@@ -798,7 +829,6 @@ def predict(
 
             X: The input data.
             output_type: Determines the type of output to return.
-
                 - If `"mean"`, we return the mean over the predicted distribution.
                 - If `"median"`, we return the median over the predicted distribution.
                 - If `"mode"`, we return the mode over the predicted distribution.
@@ -808,10 +838,8 @@ def predict(
                 - If `"main"`, we return the all output types above in a dict.
                 - If `"full"`, we return the full output of the model, including the
                   logits and the criterion, and all the output types from "main".
-
             quantiles: The quantiles to return if `output="quantiles"`.
-                By default, the `[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]`
                 quantiles are returned.
                 The predictions per quantile match the input order.
@@ -822,54 +850,25 @@ def predict(
         """
         check_is_fitted(self)
 
-        # TODO: Move these at some point to InferenceEngine
-        X = validate_X_predict(X, self)
+        if quantiles is None:
+            quantiles = self._DEFAULT_REGRESSION_QUANTILES.copy()
 
-        check_is_fitted(self)
+        assert all(
+            isinstance(q, float) and 0 <= q <= 1 for q in quantiles
+        ), "All quantiles must be floats between 0 and 1."
 
-        if quantiles is None:
-            quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
-        else:
-            assert all(
-                (0 <= q <= 1) and (isinstance(q, float)) for q in quantiles
-            ), "All quantiles must be between 0 and 1 and floats."
         if output_type not in _USABLE_OUTPUT_TYPES:
             raise ValueError(f"Invalid output type: {output_type}")
 
         if hasattr(self, "is_constant_target_") and self.is_constant_target_:
-            return self._handle_constant_target(X.shape[0], output_type, quantiles)
-
-        X = fix_dtypes(X, cat_indices=self.inferred_categorical_indices_)
-        X = process_text_na_dataframe(X, ord_encoder=self.preprocessor_)  # type: ignore
-
-        # Runs over iteration engine
-        (
-            _,
-            outputs,  # list of tensors [N_est, N_samples, N_borders] (after forward)
-            borders,  # list of numpy arrays containing borders for each estimator
-        ) = self.forward(X, use_inference_mode=True)
-
-        # --- Translate probs, average, get final logits ---
-        transformed_logits = [
-            translate_probs_across_borders(
-                logits,
-                frm=torch.as_tensor(borders_t, device=self.device_),
-                to=self.znorm_space_bardist_.borders.to(self.device_),
+            return self._handle_constant_target(
+                validate_X_predict(X, self).shape[0], output_type, quantiles
             )
-            for logits, borders_t in zip(outputs, borders)
-        ]
-        stacked_logits = torch.stack(transformed_logits, dim=0)
-        if self.average_before_softmax:
-            logits = stacked_logits.log().mean(dim=0).softmax(dim=-1)
-        else:
-            logits = stacked_logits.mean(dim=0)
 
-        # Post-process the logits
-        logits = logits.log()
-        if logits.dtype == torch.float16:
-            logits = logits.float()
+        # Get the final logits from the shared prediction helper
+        logits = self._raw_predict(X)
 
-        # Determine and return intended output type
+        # Convert the final logits to the requested output format
         logit_to_output = partial(
             _logits_to_output,
             logits=logits,
@@ -877,8 +876,6 @@ def predict(
             quantiles=quantiles,
        )
         if output_type in ["full", "main"]:
-            # Create a dictionary of outputs with proper typing via TypedDict
-            # Get individual outputs with proper typing
             mean_out = typing.cast("np.ndarray", logit_to_output(output_type="mean"))
             median_out = typing.cast(
                 "np.ndarray", logit_to_output(output_type="median")
@@ -888,79 +885,134 @@ def predict(
                 "list[np.ndarray]",
                 logit_to_output(output_type="quantiles"),
             )
-
-            # Create our typed dictionary
             main_outputs = MainOutputDict(
                 mean=mean_out,
                 median=median_out,
                 mode=mode_out,
                 quantiles=quantiles_out,
             )
-
             if output_type == "full":
-                # Return full output with criterion and logits
                 return FullOutputDict(
                     **main_outputs,
                     criterion=self.raw_space_bardist_,
                     logits=logits,
                 )
-
             return main_outputs
 
         return logit_to_output(output_type=output_type)
 
+    def _get_raw_output_generator(
+        self, X: XType, *, use_inference_mode: bool
+    ) -> Generator[tuple[torch.Tensor, RegressorEnsembleConfig], None, None]:
+        """Executes the model and yields the raw output tensor and its
+        corresponding config for each estimator in the ensemble.
+        """
+        # Toggling torch inference mode is only supported by fit modes that may
+        # switch between training and inference.
+        if self.fit_mode in ["fit_preprocessors", "batched"]:
+            self.executor_.use_torch_inference_mode(use_inference=use_inference_mode)
+
+        for output, config in self.executor_.iter_outputs(
+            X, device=self.device_, autocast=self.use_autocast_
+        ):
+            config_for_ensemble = config[0] if isinstance(config, list) else config
+            if not isinstance(config_for_ensemble, RegressorEnsembleConfig):
+                raise TypeError(
+                    "Expected RegressorEnsembleConfig, "
+                    f"but got {type(config_for_ensemble).__name__}."
+                )
+
+            yield output, config_for_ensemble
+
+    def _process_one_output_for_prediction(
+        self, output: torch.Tensor, config: RegressorEnsembleConfig
+    ) -> torch.Tensor:
+        """Processes a single raw model output for the prediction path.
+
+        This includes border transformation, temperature scaling, and probability
+        translation to the standardized target distribution.
+        """
+        std_borders = self.znorm_space_bardist_.borders.cpu().numpy()
+
+        # Check if a target transform exists before applying it
+        if config.target_transform is None:
+            borders_t = std_borders
+            logit_cancel_mask = None
+            descending_borders = False
+        else:
+            # A transform exists, so call the helper
+            (
+                logit_cancel_mask,
+                descending_borders,
+                borders_t,
+            ) = transform_borders_one(
+                std_borders,
+                target_transform=config.target_transform,
+                repair_nan_borders_after_transform=(
+                    self.interface_config_.FIX_NAN_BORDERS_AFTER_TARGET_TRANSFORM
+                ),
+            )
+
+        # Apply transformations based on the results
+        if descending_borders:
+            # Some target transforms can reverse the bin order
+            borders_t = borders_t.flip(-1)
+
+        if logit_cancel_mask is not None:
+            # Clone to avoid modifying the original tensor, which may be used elsewhere
+            output = output.clone()
+            output[..., logit_cancel_mask] = float("-inf")
+
+        # Apply temperature
+        if self.softmax_temperature != 1.0:
+            output = output.float() / self.softmax_temperature
+
+        # Translate to standardized borders and return probabilities
+        return translate_probs_across_borders(
+            output,
+            frm=torch.as_tensor(borders_t, device=self.device_),
+            to=self.znorm_space_bardist_.borders.to(self.device_),
+        )
+
     def forward(
         self,
         X: list[torch.Tensor] | XType,
         *,
         use_inference_mode: bool = False,
-    ) -> tuple[torch.Tensor | None, list[torch.Tensor], list[np.ndarray]]:
-        """Forward pass for TabPFNRegressor Inference Engine.
-        Used in fine-tuning and prediction. Called directly
-        in FineTuning training loop or by predict() function
-        with the use_inference_mode flag explicitly set to True.
-
-        Iterates over outputs of InferenceEngine.
+    ) -> torch.Tensor | None:
+        """Performs a memory-efficient forward pass for fine-tuning or prediction.
+
+        Includes an optimized fast path for fine-tuning with a single ensemble
+        config.
+
+        Args:
+            X: A list of tensors during fine-tuning, XType during normal prediction.
+            use_inference_mode: True when called from `predict()`. Fine-tuning loops
+                call `forward()` directly and keep the default False so that
+                gradients can flow.
+
+        Returns:
+            The logits averaged over the ensemble, or None if the executor yielded
+            no outputs.
+        """
+        check_is_fitted(self)
 
-        Args:
-            X: list[torch.Tensor] in fine-tuning, XType in normal predictions.
-            use_inference_mode: Flag for inference mode., default at False since
-                it is called within predict. During FineTuning forward() is called
-                directly by user, so default should be False here.
+        # --- Pre-flight Checks and Assertions ---
+        # Imported locally; used at runtime by the isinstance checks below.
+        from tabpfn.inference import InferenceEngineBatchedNoPreprocessing
 
-        Returns:
-            A tuple containing:
-            - Averaged logits over the ensemble (for fine-tuning).
-            - Raw outputs from each estimator in the ensemble.
-            - Borders used for each estimator.
- """ - # Scenario 1: Standard inference path + # Scenario 1: Standard inference path for predict() is_standard_inference = use_inference_mode and not isinstance( self.executor_, InferenceEngineBatchedNoPreprocessing ) - - # Scenario 2: Batched path, typically for fine-tuning with gradients + # Scenario 2: Batched path for fine-tuning with gradients is_batched_for_grads = ( not use_inference_mode and isinstance(self.executor_, InferenceEngineBatchedNoPreprocessing) and isinstance(X, list) and (not X or isinstance(X[0], torch.Tensor)) ) - assert is_standard_inference or is_batched_for_grads, ( "Invalid forward pass: Bad combination of inference mode, input X, " - "or executor type. Ensure call is from standard predict or a " - "batched fine-tuning context." + "or executor type. Ensure the call is from the standard predict() method " + "or a batched fine-tuning context." ) - - # Specific check for float64 incompatibility if the batched engine is being - # used, now framed as an assertion that the problematic condition is NOT met. + # Specific check for float64 incompatibility with the fine-tuning workflow. assert not ( isinstance(self.executor_, InferenceEngineBatchedNoPreprocessing) and self.forced_inference_dtype_ == torch.float64 ), ( "Batched engine error: float64 precision is not supported for the " - "fine-tuning workflow (requires float32 for backpropagation)." + "fine-tuning workflow, which requires float32 for backpropagation." ) # Ensure torch.inference_mode is OFF to allow gradients @@ -968,79 +1020,72 @@ def forward( # only these two modes support this option self.executor_.use_torch_inference_mode(use_inference=use_inference_mode) - check_is_fitted(self) + # --- Check for Fast Path Optimization --- + configs = self.executor_.ensemble_configs - std_borders = self.znorm_space_bardist_.borders.cpu().numpy() - outputs: list[torch.Tensor] = [] - borders: list[np.ndarray] = [] + # Create a single iterator that handles both flat and nested lists + # TODO: We can use the fast path as long as the target transforms are the same + # to make it simpler we disable the fast path for now + # We could even merge transformed logits but this would likely be numerically + # unstable. + can_use_fast_path = len(configs) == 1 - # Iterate over estimators - for output, config in self.executor_.iter_outputs( - X, - device=self.device_, - autocast=self.use_autocast_, - ): - if self.softmax_temperature != 1: - output = output.float() / self.softmax_temperature # noqa: PLW2901 - - # BSz.= 1 Scenario, the same as normal predict() function - # Handled by first if-statement - config_for_ensemble = config - if isinstance(config, list) and len(config) == 1: - single_config = config[0] - config_for_ensemble = single_config - - if isinstance(config_for_ensemble, RegressorEnsembleConfig): - borders_t: np.ndarray - logit_cancel_mask: np.ndarray | None - descending_borders: bool - - # TODO(eddiebergman): Maybe this could be parallelized or done in fit - # but I somehow doubt it takes much time to be worth it. - # One reason to make it worth it is if you want fast predictions, i.e. - # don't re-do this each time. - # However it gets a bit more difficult as you need to line up the - # outputs from `iter_outputs` above (which may be in arbitrary order), - # along with the specific config the output belongs to. This is because - # the transformation done to the borders for a given output is dependant - # upon the target_transform of the config. 
-                if config_for_ensemble.target_transform is None:
-                    borders_t = std_borders.copy()
-                    logit_cancel_mask = None
-                    descending_borders = False
-                else:
-                    logit_cancel_mask, descending_borders, borders_t = (
-                        transform_borders_one(
-                            std_borders,
-                            target_transform=config_for_ensemble.target_transform,
-                            repair_nan_borders_after_transform=self.interface_config_.FIX_NAN_BORDERS_AFTER_TARGET_TRANSFORM,
-                        )
-                    )
-                    if descending_borders:
-                        borders_t = borders_t.flip(-1)  # type: ignore
-
-                borders.append(borders_t)
-
-                if logit_cancel_mask is not None:
-                    output = output.clone()  # noqa: PLW2901
-                    output[..., logit_cancel_mask] = float("-inf")
+        output_generator = self._get_raw_output_generator(
+            X, use_inference_mode=use_inference_mode
+        )
 
-            else:
-                raise ValueError(
-                    "Unexpected config format "
-                    "and Batch prediction is not supported yet!"
+        final_logits: torch.Tensor | None
+
+        if can_use_fast_path and not use_inference_mode:
+            # --- Fast Path for Fine-Tuning: cleaner gradients, no softmax ---
+            sum_of_logits = None
+            estimator_count = 0
+            for raw_output, _ in output_generator:
+                estimator_count += 1
+                sum_of_logits = (
+                    raw_output if sum_of_logits is None else sum_of_logits + raw_output
                 )
+            final_logits = (
+                (sum_of_logits / estimator_count) if sum_of_logits is not None else None
+            )
+        elif not can_use_fast_path and not use_inference_mode:
+            raise NotImplementedError(
+                "Fine-tuning is currently only supported with a single "
+                "ensemble config."
+            )
+        else:
+            # --- General Path (standard prediction; use_inference_mode=True) ---
+            accumulator = None
+            estimator_count = 0
+            for raw_output, config in output_generator:
+                estimator_count += 1
+                translated_proba = self._process_one_output_for_prediction(
+                    raw_output, config
+                )
+                current_val = (
+                    translated_proba.log()
+                    if self.average_before_softmax
+                    else translated_proba
+                )
+                accumulator = (
+                    current_val if accumulator is None else accumulator + current_val
+                )
 
-            outputs.append(output)  # type: ignore
+            avg_val = accumulator / estimator_count
+            final_logits = (
+                torch.nn.functional.log_softmax(avg_val, dim=-1)
+                if self.average_before_softmax
+                else avg_val.log()
+            )
 
-        averaged_logits = None
-        all_logits = None
-        if outputs:
-            all_logits = torch.stack(outputs, dim=0)  # [N_est, N_sampls, N_bord]
-            averaged_logits_over_ensemble = torch.mean(all_logits, dim=0)
-            averaged_logits = averaged_logits_over_ensemble.transpose(0, 1)
+        # --- Final Processing and Return ---
+        if final_logits is not None and final_logits.dtype == torch.float16:
+            final_logits = final_logits.float()
 
-        return averaged_logits, outputs, borders
+        if final_logits is not None and not use_inference_mode:
+            return final_logits.transpose(0, 1)
+
+        return final_logits
 
     def _handle_constant_target(
         self, n_samples: int, output_type: OutputType, quantiles: list[float]
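From the caller's side, the breaking change is that `forward()` now returns the ensemble-averaged logits as a single tensor (or `None`) instead of the old `(averaged_logits, outputs, borders)` tuple. A hedged sketch of a fine-tuning step under the new contract; the batch format, loss function, and optimizer wiring here are assumptions for illustration, not the library's documented fine-tuning API:

```python
import torch
from tabpfn import TabPFNRegressor


def finetune_step(
    model: TabPFNRegressor,        # assumed fitted in a batched fine-tuning mode
    X_batch: list[torch.Tensor],   # batched inputs, as forward() expects here
    y_batch: torch.Tensor,
    loss_fn,                       # e.g. an NLL over the bar-distribution logits
    optimizer: torch.optim.Optimizer,
) -> float:
    optimizer.zero_grad()
    # New contract: one logits tensor. Keep use_inference_mode=False (the
    # default) so gradients can flow; the old API unpacked a 3-tuple here.
    logits = model.forward(X_batch)
    if logits is None:
        raise RuntimeError("forward() produced no output.")
    loss = loss_fn(logits, y_batch)
    loss.backward()
    optimizer.step()
    return float(loss.detach())
```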
diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py
index 2c63e62c4..5b78a223a 100644
--- a/tests/test_regressor_interface.py
+++ b/tests/test_regressor_interface.py
@@ -23,7 +23,12 @@
 from tabpfn.base import RegressorModelSpecs, initialize_tabpfn_model
 from tabpfn.model_loading import ModelSource
 from tabpfn.preprocessing import PreprocessorConfig
-from tabpfn.utils import infer_device_and_type
+from tabpfn.utils import (
+    _fix_dtypes,
+    _process_text_na_dataframe,
+    infer_device_and_type,
+    validate_X_predict,
+)
 
 from .utils import check_cpu_float16_support, get_pytest_devices
 
@@ -641,3 +646,45 @@ def test_initialize_model_variables_regressor_sets_required_attributes() -> None:
     assert (
         reg2.znorm_space_bardist_ is not None
     ), "znorm_space_bardist_ should be initialized for regressor2"
+
+
+@pytest.mark.parametrize("average_before_softmax", [True, False])
+def test_forward_predict_logit_consistency(
+    X_y: tuple[np.ndarray, np.ndarray], average_before_softmax: bool
+) -> None:
+    """Verify that the low-level `forward` method's output is identical to the
+    `logits` returned by the high-level `predict(output_type="full")` method.
+    """
+    X, y = X_y
+    model = TabPFNRegressor(
+        n_estimators=2,
+        average_before_softmax=average_before_softmax,
+        random_state=42,
+        device="cpu",
+        inference_precision=torch.float64,  # Use high precision for stability
+    )
+    model.fit(X, y)
+
+    # 1. Get the "ground truth" logits from the high-level predict() API.
+    #    `predict` internally calls `_raw_predict`, which calls `forward`.
+    full_prediction = model.predict(X, output_type="full")
+    logits_from_predict = full_prediction["logits"]
+
+    # 2. Call the low-level `forward()` method directly to get its output.
+    #    We must replicate the minimal preprocessing done in `_raw_predict`.
+    X_processed = validate_X_predict(X, model)
+    X_processed = _fix_dtypes(
+        X_processed, cat_indices=model.inferred_categorical_indices_
+    )
+    X_processed = _process_text_na_dataframe(
+        X_processed, ord_encoder=model.preprocessor_
+    )
+    logits_from_forward = model.forward(X_processed, use_inference_mode=True)
+
+    # 3. Assert the logits from both paths are identical.
+    assert logits_from_forward is not None
+    torch.testing.assert_close(
+        logits_from_predict,
+        logits_from_forward,
+        msg="Logits from low-level forward() differ from high-level predict().",
+    )
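For reviewers unfamiliar with `translate_probs_across_borders` (used in `_process_one_output_for_prediction` above): it moves each estimator's per-bin probability mass from its transformed borders back onto the shared standardized borders, so the ensemble can be averaged bin by bin. A conceptual re-binning sketch via the piecewise-linear CDF, offered as an assumption-level illustration rather than the library's implementation:

```python
import torch


def rebin_probs(probs: torch.Tensor, frm: torch.Tensor, to: torch.Tensor) -> torch.Tensor:
    """Move per-bin probability mass from bin edges `frm` onto bin edges `to`."""
    # CDF at each source edge: F(frm[0]) = 0, ..., F(frm[-1]) = 1.
    cdf = torch.cat([torch.zeros_like(probs[..., :1]), probs.cumsum(-1)], dim=-1)
    # Evaluate F at the target edges by linear interpolation inside source bins.
    to_c = to.clamp(frm[0].item(), frm[-1].item())
    idx = torch.searchsorted(frm, to_c).clamp(1, frm.numel() - 1)
    left, right = frm[idx - 1], frm[idx]
    w = (to_c - left) / (right - left)
    cdf_at_to = (1 - w) * cdf[..., idx - 1] + w * cdf[..., idx]
    # New bin masses are CDF differences between consecutive target edges.
    return (cdf_at_to[..., 1:] - cdf_at_to[..., :-1]).clamp_min(0)
```

In the diff, `frm` plays the role of the per-estimator `borders_t` and `to` the standardized `znorm_space_bardist_.borders`.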