diff --git a/examples/finetune_regressor.py b/examples/finetune_regressor.py
index f92e88b14..976c57b33 100644
--- a/examples/finetune_regressor.py
+++ b/examples/finetune_regressor.py
@@ -24,6 +24,56 @@
 from tabpfn.utils import meta_dataset_collator
 
 
+def rps_loss(
+    outputs: torch.Tensor,
+    targets: torch.Tensor,
+    bucket_widths: torch.Tensor,
+) -> torch.Tensor:
+    r"""
+    Calculates the Ranked Probability Score (RPS) loss for binned regression with variable bin widths.
+
+    RPS measures the difference between predicted cumulative probabilities
+    and the true cumulative probabilities for ordered categorical outcomes.
+    Lower RPS indicates better probabilistic forecasts.
+
+    The formula for RPS for a single prediction and true value is:
+    $$ \text{RPS} = \sum_{k=1}^{K} w_k (\hat{F}(k) - F(k))^2 $$
+    where $K$ is the number of bins, $w_k$ is the width of bin $k$, $\hat{F}(k)$ is the
+    predicted cumulative probability up to bin $k$, and $F(k)$ is the observed (true)
+    cumulative probability up to bin $k$.
+
+    The function first converts the true target value into a one-hot encoded vector,
+    effectively creating an empirical probability mass function (PMF) where the true bin
+    has a probability of 1 and all others are 0. This PMF is then converted into a
+    step-like empirical cumulative distribution function (CDF).
+
+    For variable bin widths (e.g., quantile-based bins), bucket_widths should be provided
+    to properly weight the squared CDF differences.
+
+    Args:
+        outputs (torch.Tensor): Predicted probabilities for each bin. Shape (batch_size, num_bins).
+        targets (torch.Tensor): True bin labels (integer indices). Shape (batch_size,).
+        bucket_widths (torch.Tensor): Width of each bin. Shape (num_bins,).
+
+    Returns:
+        torch.Tensor: The mean RPS loss for the batch.
+    """
+    # Ensure targets are long integers for scatter operation
+    targets = targets.long()
+
+    pred_cdf = torch.cumsum(outputs, dim=1)
+    target_one_hot = torch.zeros_like(outputs).scatter_(1, targets.unsqueeze(1), 1.0)
+    target_cdf = torch.cumsum(target_one_hot, dim=1)
+    cdf_diff = pred_cdf - target_cdf
+
+    # Weight by bucket widths for variable bin sizes
+    bucket_widths = bucket_widths.to(outputs.device)
+    weighted_squared_diff = cdf_diff**2 * bucket_widths.unsqueeze(0)
+    rps = torch.mean(torch.sum(weighted_squared_diff, dim=1))
+
+    return rps
+
+
 def prepare_data(config: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     """Loads, subsets, and splits the California Housing dataset."""
     print("--- 1. Data Preparation ---")
@@ -145,14 +195,6 @@ def main() -> None:
     X_train, X_test, y_train, y_test = prepare_data(config)
     regressor, regressor_config = setup_regressor(config)
 
-    if len(regressor.models_) > 1:
-        raise ValueError(
-            f"Your TabPFNRegressor uses multiple models ({len(regressor.models_)}). "
-            "Finetuning is not supported for multiple models. Please use a single model."
-        )
-
-    model = regressor.models_[0]
-
     splitter = partial(train_test_split, test_size=config["valid_set_ratio"])
     # Note: `max_data_size` corresponds to the finetuning `batch_size` in the config
     training_datasets = regressor.get_preprocessed_datasets(
@@ -165,7 +207,7 @@ def main() -> None:
     )
 
     # Optimizer must be created AFTER get_preprocessed_datasets, which initializes the model
-    optimizer = Adam(model.parameters(), lr=config["finetuning"]["learning_rate"])
+    optimizer = Adam(regressor.model_.parameters(), lr=config["finetuning"]["learning_rate"])
     print(
         f"--- Optimizer Initialized: Adam, LR: {config['finetuning']['learning_rate']} ---\n"
     )
@@ -206,11 +248,24 @@ def main() -> None:
                 )
                 logits, _, _ = regressor.forward(X_tests_preprocessed)
 
-                # For regression, the loss function is part of the preprocessed data
-                loss_fn = znorm_space_bardist_[0]
-                y_target = y_test_znorm
-
-                loss = loss_fn(logits, y_target.to(config["device"])).mean()
+                # Get bucket widths from the BarDistribution for RPS loss
+                bucket_widths = znorm_space_bardist_[0].bucket_widths
+
+                # Logits have shape (seq_len, batch_size, num_bars), take last sequence position
+                logits = logits[-1]  # Now shape: (batch_size, num_bars)
+
+                # Convert logits to probabilities for RPS loss
+                probabilities = torch.softmax(logits, dim=-1)
+
+                # Map continuous targets to bin indices
+                # y_test_znorm has shape (seq_len, batch_size), take last position
+                y_target = y_test_znorm[-1] if y_test_znorm.dim() > 1 else y_test_znorm
+                y_target_bins = znorm_space_bardist_[0].map_to_bucket_idx(
+                    y_target.to(config["device"])
+                ).clamp(0, znorm_space_bardist_[0].num_bars - 1)
+
+                # Use RPS loss instead of default BarDistribution loss
+                loss = rps_loss(probabilities, y_target_bins, bucket_widths)
                 loss.backward()
                 optimizer.step()
 
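
Note (not part of the patch): the snippet below is a minimal sanity check of the new rps_loss on toy tensors, comparing its result against a direct evaluation of the weighted-CDF formula from the docstring. The bin borders, the import path, and the use of torch.bucketize in place of BarDistribution.map_to_bucket_idx are illustrative assumptions for this sketch, not part of the TabPFN usage shown in the diff.

import torch

# Assumes the rps_loss added above is importable, e.g. when run from the
# examples/ directory; otherwise paste the function into the same session.
from finetune_regressor import rps_loss

torch.manual_seed(0)

# Toy quantile-like bins: 6 borders define 5 bins of unequal width.
# Illustrative values only; the training loop takes these from the BarDistribution.
borders = torch.tensor([-2.0, -1.0, -0.25, 0.25, 1.0, 2.0])
bucket_widths = borders[1:] - borders[:-1]  # shape (num_bins,)
num_bins = bucket_widths.shape[0]

# Fake model outputs for a batch of 4 samples, turned into bin probabilities.
logits = torch.randn(4, num_bins)
probabilities = torch.softmax(logits, dim=-1)

# Continuous targets mapped to bin indices; torch.bucketize stands in for
# BarDistribution.map_to_bucket_idx in this toy setting.
y_continuous = torch.tensor([-1.5, 0.0, 0.7, 1.8])
y_bins = torch.bucketize(y_continuous, borders[1:-1])  # indices in [0, num_bins - 1]

loss = rps_loss(probabilities, y_bins, bucket_widths)

# Direct evaluation of RPS = sum_k w_k * (F_hat(k) - F(k))^2, averaged over the batch.
pred_cdf = torch.cumsum(probabilities, dim=1)
true_cdf = (torch.arange(num_bins).unsqueeze(0) >= y_bins.unsqueeze(1)).float()
manual = torch.mean(torch.sum(bucket_widths * (pred_cdf - true_cdf) ** 2, dim=1))
assert torch.allclose(loss, manual)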