diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py
index 4aec60140..ca06a83c8 100644
--- a/src/tabpfn/regressor.py
+++ b/src/tabpfn/regressor.py
@@ -879,6 +879,51 @@ def predict(
         quantiles: list[float] | None = None,
     ) -> FullOutputDict: ...
 
+    def _raw_predict(self, X: XType) -> torch.Tensor:
+        """Internal method to run prediction.
+
+        Handles input validation, preprocessing, and the forward pass.
+        Returns the stack of aligned probabilities from all estimators
+        (shape: [n_estimators, n_samples, n_borders]) mapped to the global
+        `znorm_space_bardist_` borders.
+        """
+        check_is_fitted(self)
+
+        # TODO: Move these at some point to InferenceEngine
+        X = validate_X_predict(X, self)
+
+        # Constant target handling
+        if hasattr(self, "is_constant_target_") and self.is_constant_target_:
+            # If the target is constant, there is a single bucket (len(borders) - 1).
+            # Return ones (probability of 1.0) for this single bucket.
+            n_buckets = len(self.znorm_space_bardist_.borders) - 1
+            return torch.ones(
+                (self.n_estimators, X.shape[0], n_buckets),
+                device=self.devices_[0],
+            )
+
+        X = fix_dtypes(X, cat_indices=self.inferred_categorical_indices_)
+        X = process_text_na_dataframe(X, ord_encoder=self.preprocessor_)  # type: ignore
+
+        # Runs over iteration engine
+        (
+            _,
+            outputs,  # list of tensors [N_est, N_samples, N_borders] (after forward)
+            borders,  # list of numpy arrays containing borders for each estimator
+        ) = self.forward(X, use_inference_mode=True)
+
+        # --- Translate probs ---
+        # Map each estimator's borders to the global znorm_space_bardist_ borders
+        transformed_probs = [
+            translate_probs_across_borders(
+                logits,
+                frm=torch.as_tensor(borders_t, device=logits.device),
+                to=self.znorm_space_bardist_.borders.to(logits.device),
+            )
+            for logits, borders_t in zip(outputs, borders)
+        ]
+        return torch.stack(transformed_probs, dim=0)
+
     @config_context(transform_output="default")  # type: ignore
     @track_model_call(model_method="predict", param_names=["X"])
     def predict(
@@ -920,11 +965,6 @@ def predict(
         """
         check_is_fitted(self)
 
-        # TODO: Move these at some point to InferenceEngine
-        X = validate_X_predict(X, self)
-
-        check_is_fitted(self)
-
         if quantiles is None:
             quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
         else:
@@ -935,42 +975,31 @@ def predict(
             raise ValueError(f"Invalid output type: {output_type}")
 
         if hasattr(self, "is_constant_target_") and self.is_constant_target_:
+            # Validate X even for a constant target so the input shape is checked
+            X = validate_X_predict(X, self)
             return self._handle_constant_target(X.shape[0], output_type, quantiles)
 
-        X = fix_dtypes(X, cat_indices=self.inferred_categorical_indices_)
-        X = process_text_na_dataframe(X, ord_encoder=self.preprocessor_)  # type: ignore
+        # Get the aligned stack of probabilities from all estimators
+        stacked_probs = self._raw_predict(X)
 
-        # Runs over iteration engine
-        (
-            _,
-            outputs,  # list of tensors [N_est, N_samples, N_borders] (after forward)
-            borders,  # list of numpy arrays containing borders for each estimator
-        ) = self.forward(X, use_inference_mode=True)
-
-        # --- Translate probs, average, get final logits ---
-        transformed_logits = [
-            translate_probs_across_borders(
-                logits,
-                frm=torch.as_tensor(borders_t, device=logits.device),
-                to=self.znorm_space_bardist_.borders.to(logits.device),
-            )
-            for logits, borders_t in zip(outputs, borders)
-        ]
-        stacked_logits = torch.stack(transformed_logits, dim=0)
         if self.average_before_softmax:
-            logits = stacked_logits.log().mean(dim=0).softmax(dim=-1)
+            # stacked_probs from _raw_predict are probabilities (sum to 1).
+            # Take the log to get logits, average them, then softmax.
+            ensemble_probs = stacked_probs.log().mean(dim=0).softmax(dim=-1)
        else:
-            logits = stacked_logits.mean(dim=0)
+            # Average the probabilities directly
+            ensemble_probs = stacked_probs.mean(dim=0)
 
         # Post-process the logits
-        logits = logits.log()
-        if logits.dtype == torch.float16:
-            logits = logits.float()
+        # Ensure we are working with log-probabilities for the criterion methods
+        ensemble_log_probs = ensemble_probs.log()
+        if ensemble_log_probs.dtype == torch.float16:
+            ensemble_log_probs = ensemble_log_probs.float()
 
         # Determine and return intended output type
         logit_to_output = partial(
             _logits_to_output,
-            logits=logits,
+            logits=ensemble_log_probs,
             criterion=self.raw_space_bardist_,
             quantiles=quantiles,
         )
@@ -1000,13 +1029,36 @@ def predict(
                 return FullOutputDict(
                     **main_outputs,
                     criterion=self.raw_space_bardist_,
-                    logits=logits,
+                    logits=ensemble_log_probs,
                 )
             return main_outputs
 
         return logit_to_output(output_type=output_type)
 
+    @config_context(transform_output="default")
+    @track_model_call(model_method="predict", param_names=["X"])
+    def predict_raw_logits(self, X: XType) -> np.ndarray:
+        """Predict the raw logits for the provided input samples.
+
+        This method returns the raw logits for each estimator, without
+        averaging across estimators. For regression, these logits are aligned
+        to the global bar distribution used by the model (handling potential
+        target shifting/scaling per estimator).
+
+        Args:
+            X: The input data for prediction.
+
+        Returns:
+            An array of predicted logits for each estimator,
+            with shape (n_estimators, n_samples, n_bins).
+        """
+        # _raw_predict returns aligned probabilities (output of translate_probs)
+        stacked_probs = self._raw_predict(X)
+
+        # Convert probabilities to logits (log-space) and detach
+        return stacked_probs.log().float().detach().cpu().numpy()
+
     def forward(
         self,
         X: list[torch.Tensor] | XType,
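Note: the regressor changes above hinge on `translate_probs_across_borders`, which remaps each estimator's per-bin probabilities onto the shared `znorm_space_bardist_` borders so the stacks can be averaged. Below is a minimal sketch of that idea, not the library's implementation: the function name `remap_probs` and the piecewise-linear CDF interpolation are illustrative assumptions.

```python
import torch


def remap_probs(probs: torch.Tensor, frm: torch.Tensor, to: torch.Tensor) -> torch.Tensor:
    """Remap per-bin probabilities from borders `frm` onto borders `to`.

    probs: (..., len(frm) - 1), summing to 1 over the last dim.
    Returns: (..., len(to) - 1) probabilities over the target bins.
    """
    # CDF at the source borders: 0 at the left edge, cumulative sums after.
    cdf = torch.cat([torch.zeros(*probs.shape[:-1], 1), probs.cumsum(dim=-1)], dim=-1)
    # Evaluate the piecewise-linear CDF at each target border.
    idx = torch.searchsorted(frm, to).clamp(1, len(frm) - 1)
    left, right = frm[idx - 1], frm[idx]
    # Target borders outside the source range clamp to CDF 0 or 1.
    w = ((to - left) / (right - left)).clamp(0.0, 1.0)
    cdf_at_to = cdf[..., idx - 1] + w * (cdf[..., idx] - cdf[..., idx - 1])
    # Differencing the interpolated CDF gives per-bin mass on the new borders.
    return (cdf_at_to[..., 1:] - cdf_at_to[..., :-1]).clamp_min(0.0)


# Example: two source bins on [0, 2] remapped onto four target bins.
probs = torch.tensor([[0.25, 0.75]])
frm = torch.tensor([0.0, 1.0, 2.0])
to = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])
print(remap_probs(probs, frm, to))  # tensor([[0.1250, 0.1250, 0.3750, 0.3750]])
```

Because differencing an interpolated CDF preserves total mass, the remapped distributions from different estimators live on the same bins and can be meaningfully stacked and averaged.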
+ "Logits are identical across estimators, indicating trivial output." + )