83 changes: 69 additions & 14 deletions examples/finetune_regressor.py
@@ -24,6 +24,56 @@
from tabpfn.utils import meta_dataset_collator


def rps_loss(
outputs: torch.Tensor,
targets: torch.Tensor,
bucket_widths: torch.Tensor,
) -> torch.Tensor:
"""
Calculates the Ranked Probability Score (RPS) loss for binned regression with variable bin widths.
RPS measures the difference between predicted cumulative probabilities
and the true cumulative probabilities for ordered categorical outcomes.
Lower RPS indicates better probabilistic forecasts.
The formula for RPS for a single prediction and true value is:
$$ \text{RPS} = \sum_{k=1}^{K} w_k (\hat{F}(k) - F(k))^2 $$

    where $K$ is the number of bins, $w_k$ is the width of bin $k$, $\hat{F}(k)$ is the
    predicted cumulative probability up to bin $k$, and $F(k)$ is the observed (true)
    cumulative probability up to bin $k$.

    The function first converts the true target value into a one-hot encoded vector,
    effectively creating an empirical probability mass function (PMF) where the true bin
    has a probability of 1 and all others are 0. This PMF is then converted into a
    step-like empirical cumulative distribution function (CDF).

    For variable bin widths (e.g., quantile-based bins), bucket_widths should be provided
    to properly weight the squared CDF differences.

    Args:
        outputs (torch.Tensor): Predicted probabilities for each bin. Shape (batch_size, num_bins).
        targets (torch.Tensor): True bin labels (integer indices). Shape (batch_size,).
        bucket_widths (torch.Tensor): Width of each bin. Shape (num_bins,).

    Returns:
        torch.Tensor: The mean RPS loss for the batch.
    """

# Ensure targets are long integers for scatter operation
targets = targets.long()

pred_cdf = torch.cumsum(outputs, dim=1)
target_one_hot = torch.zeros_like(outputs).scatter_(1, targets.unsqueeze(1), 1.)
Contributor (medium):

For creating the one-hot encoded target tensor, consider using torch.nn.functional.one_hot. It's more idiomatic for this operation and can make the code's intent clearer. This function is specifically designed for one-hot encoding and might offer better performance.

Suggested change
target_one_hot = torch.zeros_like(outputs).scatter_(1, targets.unsqueeze(1), 1.)
target_one_hot = torch.nn.functional.one_hot(targets, num_classes=outputs.shape[1]).to(outputs.dtype)
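A quick self-contained check (toy shapes, purely illustrative) that the scatter_-based construction and the suggested F.one_hot call produce the same one-hot matrix:

import torch

# Toy batch: 4 samples, 5 bins.
outputs = torch.rand(4, 5).softmax(dim=1)
targets = torch.tensor([0, 2, 4, 1])

# Construction used in the diff: write 1.0 into each sample's target bin.
via_scatter = torch.zeros_like(outputs).scatter_(1, targets.unsqueeze(1), 1.0)
# Suggested alternative: one_hot, then cast to the probability dtype.
via_one_hot = torch.nn.functional.one_hot(targets, num_classes=outputs.shape[1]).to(outputs.dtype)

assert torch.equal(via_scatter, via_one_hot)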

target_cdf = torch.cumsum(target_one_hot, dim=1)
cdf_diff = pred_cdf - target_cdf

# Weight by bucket widths for variable bin sizes
bucket_widths = bucket_widths.to(outputs.device)
weighted_squared_diff = cdf_diff**2 * bucket_widths.unsqueeze(0)
    return torch.mean(torch.sum(weighted_squared_diff, dim=1))

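As a sanity check on the formula above, a minimal sketch (toy probabilities and hypothetical bin widths) of calling rps_loss directly:

import torch

# Three bins with unequal widths; two samples whose true bins are 0 and 2.
probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.3, 0.6]])
targets = torch.tensor([0, 2])
widths = torch.tensor([0.5, 1.0, 2.0])

loss = rps_loss(probs, targets, widths)  # ~0.11: mean of the per-sample scores 0.055 and 0.165
# A forecast that puts all mass in the true bin scores 0; spreading mass into
# other bins widens the CDF gap and increases the width-weighted penalty.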


def prepare_data(config: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Loads, subsets, and splits the California Housing dataset."""
print("--- 1. Data Preparation ---")
@@ -145,14 +195,6 @@
X_train, X_test, y_train, y_test = prepare_data(config)
regressor, regressor_config = setup_regressor(config)

if len(regressor.models_) > 1:
raise ValueError(
f"Your TabPFNRegressor uses multiple models ({len(regressor.models_)}). "
"Finetuning is not supported for multiple models. Please use a single model."
)

model = regressor.models_[0]

splitter = partial(train_test_split, test_size=config["valid_set_ratio"])
# Note: `max_data_size` corresponds to the finetuning `batch_size` in the config
training_datasets = regressor.get_preprocessed_datasets(
@@ -165,7 +207,7 @@
)

# Optimizer must be created AFTER get_preprocessed_datasets, which initializes the model
optimizer = Adam(model.parameters(), lr=config["finetuning"]["learning_rate"])
optimizer = Adam(regressor.model_.parameters(), lr=config["finetuning"]["learning_rate"])
print(
f"--- Optimizer Initialized: Adam, LR: {config['finetuning']['learning_rate']} ---\n"
)
@@ -206,11 +248,24 @@
)
logits, _, _ = regressor.forward(X_tests_preprocessed)

# For regression, the loss function is part of the preprocessed data
loss_fn = znorm_space_bardist_[0]
y_target = y_test_znorm

loss = loss_fn(logits, y_target.to(config["device"])).mean()
# Get bucket widths from the BarDistribution for RPS loss
bucket_widths = znorm_space_bardist_[0].bucket_widths

# Logits have shape (seq_len, batch_size, num_bars), take last sequence position
logits = logits[-1] # Now shape: (batch_size, num_bars)

# Convert logits to probabilities for RPS loss
probabilities = torch.softmax(logits, dim=-1)

# Map continuous targets to bin indices
# y_test_znorm has shape (seq_len, batch_size), take last position
y_target = y_test_znorm[-1] if y_test_znorm.dim() > 1 else y_test_znorm
y_target_bins = znorm_space_bardist_[0].map_to_bucket_idx(
y_target.to(config["device"])
).clamp(0, znorm_space_bardist_[0].num_bars - 1)

# Use RPS loss instead of default BarDistribution loss
loss = rps_loss(probabilities, y_target_bins, bucket_widths)
loss.backward()
optimizer.step()
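For intuition about the shapes in this training step, a small sketch (toy dimensions; random bin indices stand in for the map_to_bucket_idx call) of the same logits-to-RPS path, reusing rps_loss from above:

import torch

seq_len, batch_size, num_bars = 3, 4, 5                    # toy dimensions
logits = torch.randn(seq_len, batch_size, num_bars, requires_grad=True)
bucket_widths = torch.rand(num_bars) + 0.1                 # stand-in for the BarDistribution bucket widths

probabilities = torch.softmax(logits[-1], dim=-1)          # last sequence position -> (batch_size, num_bars)
y_target_bins = torch.randint(0, num_bars, (batch_size,))  # stand-in for map_to_bucket_idx(...).clamp(...)

loss = rps_loss(probabilities, y_target_bins, bucket_widths)
loss.backward()                                            # gradients reach the logits through softmax and cumsum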
