114 changes: 83 additions & 31 deletions src/tabpfn/regressor.py
@@ -879,6 +879,51 @@ def predict(
quantiles: list[float] | None = None,
) -> FullOutputDict: ...

def _raw_predict(self, X: XType) -> torch.Tensor:
"""Internal method to run prediction.

Handles input validation, preprocessing, and the forward pass.
Returns the stack of aligned probabilities from all estimators
(shape: [n_estimators, n_samples, n_borders]) mapped to the global
Contributor

medium

The docstring for _raw_predict states that the returned tensor has a shape of [n_estimators, n_samples, n_borders]. However, the last dimension corresponds to the number of bins/buckets, which is len(borders) - 1. To avoid confusion, it would be clearer to use n_bins instead of n_borders.

Suggested change
(shape: [n_estimators, n_samples, n_borders]) mapped to the global
(shape: [n_estimators, n_samples, n_bins]) mapped to the global
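
For illustration, a minimal sketch of the borders/bins relationship (assuming `borders` is a 1-D tensor of bin edges, as it appears to be elsewhere in this file):

    import torch

    borders = torch.tensor([0.0, 1.0, 2.0, 3.0])   # 4 borders (bin edges)
    n_bins = len(borders) - 1                      # 3 bins: [0,1), [1,2), [2,3)
    probs = torch.full((n_bins,), 1.0 / n_bins)    # one probability per bin, not per border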

`znorm_space_bardist_` borders.
"""
check_is_fitted(self)

# TODO: Move these at some point to InferenceEngine
X = validate_X_predict(X, self)

# Constant target handling
if hasattr(self, "is_constant_target_") and self.is_constant_target_:
# If the target is constant, there is a single bucket (len(borders) - 1 == 1).
# We return a probability of 1.0 for this single bucket.
n_buckets = len(self.znorm_space_bardist_.borders) - 1
return torch.ones(
(self.n_estimators, X.shape[0], n_buckets),
device=self.devices_[0],
)

X = fix_dtypes(X, cat_indices=self.inferred_categorical_indices_)
X = process_text_na_dataframe(X, ord_encoder=self.preprocessor_) # type: ignore

# Runs over iteration engine
(
_,
outputs, # list of tensors [N_est, N_samples, N_borders] (after forward)
borders, # list of numpy arrays containing borders for each estimator
) = self.forward(X, use_inference_mode=True)

# --- Translate probs ---
# Map specific estimator borders to the global znorm_space_bardist_
transformed_probs = [
translate_probs_across_borders(
logits,
frm=torch.as_tensor(borders_t, device=logits.device),
to=self.znorm_space_bardist_.borders.to(logits.device),
)
for logits, borders_t in zip(outputs, borders)
]
return torch.stack(transformed_probs, dim=0)
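# Note (editorial): after translation, every estimator's distribution lives on
# the same global borders, so the tensors are comparable bin-for-bin and can be
# stacked along a new leading estimator dimension. Illustration (hypothetical
# numbers, assuming the translation re-bins probability mass): probs [0.5, 0.5]
# over borders [0, 1, 2] become [0.25, 0.25, 0.5] over borders [0, 0.5, 1, 2].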

@config_context(transform_output="default") # type: ignore
@track_model_call(model_method="predict", param_names=["X"])
def predict(
@@ -920,11 +965,6 @@ def predict(
"""
check_is_fitted(self)

# TODO: Move these at some point to InferenceEngine
X = validate_X_predict(X, self)

check_is_fitted(self)

if quantiles is None:
quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
else:
@@ -935,42 +975,31 @@ def predict(
raise ValueError(f"Invalid output type: {output_type}")

if hasattr(self, "is_constant_target_") and self.is_constant_target_:
# We must validate X even if constant target to ensure shape is correct
X = validate_X_predict(X, self)
return self._handle_constant_target(X.shape[0], output_type, quantiles)

X = fix_dtypes(X, cat_indices=self.inferred_categorical_indices_)
X = process_text_na_dataframe(X, ord_encoder=self.preprocessor_) # type: ignore
# Get the aligned stack of probabilities from all estimators
stacked_probs = self._raw_predict(X)

# Runs over iteration engine
(
_,
outputs, # list of tensors [N_est, N_samples, N_borders] (after forward)
borders, # list of numpy arrays containing borders for each estimator
) = self.forward(X, use_inference_mode=True)

# --- Translate probs, average, get final logits ---
transformed_logits = [
translate_probs_across_borders(
logits,
frm=torch.as_tensor(borders_t, device=logits.device),
to=self.znorm_space_bardist_.borders.to(logits.device),
)
for logits, borders_t in zip(outputs, borders)
]
stacked_logits = torch.stack(transformed_logits, dim=0)
if self.average_before_softmax:
logits = stacked_logits.log().mean(dim=0).softmax(dim=-1)
# stacked_probs from _raw_predict are probabilities (sum to 1)
# We take log to get logits, average them, then softmax
ensemble_probs = stacked_probs.log().mean(dim=0).softmax(dim=-1)
else:
logits = stacked_logits.mean(dim=0)
# Average the probabilities directly
ensemble_probs = stacked_probs.mean(dim=0)
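# Illustration with hypothetical values: for two estimators assigning
# probabilities [0.8, 0.2] and [0.6, 0.4] to two bins, averaging in
# log-space yields softmax(mean(log p)) ≈ [0.71, 0.29] (a normalized
# geometric mean), while averaging the probabilities directly yields
# the arithmetic mean [0.7, 0.3].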

# Post-process the logits
logits = logits.log()
if logits.dtype == torch.float16:
logits = logits.float()
# We ensure we are working with log-probabilities for the criterion methods
ensemble_log_probs = ensemble_probs.log()
if ensemble_log_probs.dtype == torch.float16:
ensemble_log_probs = ensemble_log_probs.float()

# Determine and return intended output type
logit_to_output = partial(
_logits_to_output,
logits=logits,
logits=ensemble_log_probs,
criterion=self.raw_space_bardist_,
quantiles=quantiles,
)
@@ -1000,13 +1029,36 @@ def predict(
return FullOutputDict(
**main_outputs,
criterion=self.raw_space_bardist_,
logits=logits,
logits=ensemble_log_probs,
)

return main_outputs

return logit_to_output(output_type=output_type)

@config_context(transform_output="default")
@track_model_call(model_method="predict", param_names=["X"])
Contributor

medium

The @track_model_call decorator for predict_raw_logits appears to have a copy-pasted model_method from the predict method. For accurate telemetry, this should be updated to "predict_raw_logits".

Suggested change
@track_model_call(model_method="predict", param_names=["X"])
@track_model_call(model_method="predict_raw_logits", param_names=["X"])

def predict_raw_logits(self, X: XType) -> np.ndarray:
"""Predict the raw logits for the provided input samples.

This method returns the raw logits for each estimator, without averaging
across estimators. In the case of regression, these logits are aligned to the
global bar distribution used by the model (handling potential target
shifting/scaling per estimator).

Args:
X: The input data for prediction.

Returns:
An array of predicted logits for each estimator,
with shape (n_estimators, n_samples, n_bins).
"""
# _raw_predict returns aligned probabilities (output of translate_probs_across_borders)
stacked_probs = self._raw_predict(X)

# Convert probabilities to logits (log-space) and detach
return stacked_probs.log().float().detach().cpu().numpy()

def forward(
self,
X: list[torch.Tensor] | XType,
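A usage sketch for the new public method (illustrative only; `X_train`, `y_train`, and `X_test` are placeholders for your own data):

    import numpy as np
    from tabpfn import TabPFNRegressor

    reg = TabPFNRegressor(n_estimators=4)
    reg.fit(X_train, y_train)

    logits = reg.predict_raw_logits(X_test)  # shape (4, n_samples, n_bins)
    probs = np.exp(logits)                   # back to per-bin probabilities
    # Each (estimator, sample) row should sum to ~1.0, since the returned
    # values are log-probabilities aligned to the shared bins.
    assert np.allclose(probs.sum(axis=-1), 1.0, atol=1e-4)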
39 changes: 39 additions & 0 deletions tests/test_regressor_interface.py
@@ -834,3 +834,42 @@

assert estimator.n_estimators == 16
assert estimator.softmax_temperature == 0.9


@pytest.mark.parametrize("n_estimators", [1, 2])
def test_predict_raw_logits(
X_y: tuple[np.ndarray, np.ndarray],
n_estimators: int,
) -> None:
"""Tests the predict_raw_logits method."""
X, y = X_y

# Ensure y is float32 for consistency
y = y.astype(np.float32)

regressor = TabPFNRegressor(
n_estimators=n_estimators,
random_state=42,
)
regressor.fit(X, y)

logits = regressor.predict_raw_logits(X)

# The number of bins is determined by the internal bar distribution borders
# (borders - 1 = number of buckets/logits)
n_bins = regressor.znorm_space_bardist_.borders.shape[0] - 1

assert isinstance(logits, np.ndarray)
assert logits.shape == (n_estimators, X.shape[0], n_bins)
assert logits.dtype == np.float32
assert not np.isnan(logits).any()

# Regressors can have -inf logits (log(0) probability), but should not have +inf
assert not (np.isinf(logits) & (logits > 0)).any(), "Found +inf in logits"

if n_estimators > 1:
# Ensure estimators are providing different outputs (diversity check)
# We check whether the first estimator's output equals the second's exactly

assert not np.all(logits[0] == logits[1]), (
"Logits are identical across estimators, indicating trivial output."
)
Loading