Merged
41 commits
86d4e0d  add regressor finetuning wrapper (bejaeger, Dec 31, 2025)
c62f32d  add abstract base class for finetuning wrappers (bejaeger, Dec 31, 2025)
adf2abd  cleaner handling of batches (bejaeger, Dec 31, 2025)
11cf7cf  tweaks (bejaeger, Dec 31, 2025)
b2b5fe1  fix tests (bejaeger, Dec 31, 2025)
ab77196  add finetuning regressor tests (bejaeger, Dec 31, 2025)
e4877e1  Merge branch 'main' into ben/add-regressor-finetuning-wrapper (bejaeger, Dec 31, 2025)
7db6717  make finetuning deterministic (bejaeger, Dec 31, 2025)
b1af62f  better regressor training settings (bejaeger, Dec 31, 2025)
6c5e159  refactor loss calc (bejaeger, Jan 2, 2026)
9f9ab18  Merge branch 'ben/add-regressor-finetuning-wrapper' of github.com:Pri… (bejaeger, Jan 2, 2026)
4170cde  cleanup (bejaeger, Jan 2, 2026)
3adf058  comment update (bejaeger, Jan 2, 2026)
0d8e651  add test (bejaeger, Jan 2, 2026)
ac003b6  revision (bejaeger, Jan 5, 2026)
1c9fc95  revision 2 (bejaeger, Jan 5, 2026)
4d19626  add comment (bejaeger, Jan 5, 2026)
3304175  remove configurations for finetuning examples (bejaeger, Jan 5, 2026)
6289d08  add back one parameter for better performance on example (bejaeger, Jan 5, 2026)
55fcc15  minor tweaks (bejaeger, Jan 7, 2026)
382e41a  add regressor finetuning wrapper (bejaeger, Dec 31, 2025)
8c0e186  add abstract base class for finetuning wrappers (bejaeger, Dec 31, 2025)
e5b72cc  cleaner handling of batches (bejaeger, Dec 31, 2025)
aab7f15  tweaks (bejaeger, Dec 31, 2025)
6712c1a  fix tests (bejaeger, Dec 31, 2025)
c3b644c  add finetuning regressor tests (bejaeger, Dec 31, 2025)
a2f80c8  refactor loss calc (bejaeger, Jan 2, 2026)
6403ae9  make finetuning deterministic (bejaeger, Dec 31, 2025)
81ec221  better regressor training settings (bejaeger, Dec 31, 2025)
b473d58  cleanup (bejaeger, Jan 2, 2026)
0da205b  comment update (bejaeger, Jan 2, 2026)
5b4a9f1  add test (bejaeger, Jan 2, 2026)
4cd088b  revision (bejaeger, Jan 5, 2026)
6ae3420  revision 2 (bejaeger, Jan 5, 2026)
6991cc0  add comment (bejaeger, Jan 5, 2026)
cb18bc7  remove configurations for finetuning examples (bejaeger, Jan 5, 2026)
1cbfb6f  add back one parameter for better performance on example (bejaeger, Jan 5, 2026)
c699c3c  minor tweaks (bejaeger, Jan 7, 2026)
c798691  Merge branch 'main' into ben/add-regressor-finetuning-wrapper (bejaeger, Jan 7, 2026)
266b8b7  Merge branch 'ben/add-regressor-finetuning-wrapper' of github.com:Pri… (bejaeger, Jan 7, 2026)
6027e00  remove old file that sneaked in (bejaeger, Jan 7, 2026)
32 changes: 10 additions & 22 deletions examples/finetune_classifier.py
@@ -4,6 +4,7 @@
support for the Apple Silicon (MPS) backend is still under development.
"""

import logging
import warnings

import numpy as np
@@ -23,27 +24,23 @@
module=r"google\.api_core\._python_version_support",
)

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)

# =============================================================================
# Fine-tuning Configuration
# For details and more options see FinetunedTabPFNClassifier
#
# These settings work well for the Higgs dataset.
# For other datasets, you may need to adjust these settings to get good results.
# =============================================================================

# Training hyperparameters
NUM_EPOCHS = 30
LEARNING_RATE = 2e-5

# Data sampling configuration (dataset dependent)
# the ratio of the total dataset to be used for validation during training
VALIDATION_SPLIT_RATIO = 0.1
# total context split into train/test
NUM_FINETUNE_CTX_PLUS_QUERY_SAMPLES = 10_000
# the following means 0.2*10_000=2_000 test samples are used in training
FINETUNE_CTX_QUERY_SPLIT_RATIO = 0.2
NUM_INFERENCE_SUBSAMPLE_SAMPLES = 50_000
# to reduce memory usage during training we can use activation checkpointing,
# may not be necessary for small datasets
USE_ACTIVATION_CHECKPOINTING = True

# Ensemble configuration
# number of estimators to use during finetuning
NUM_ESTIMATORS_FINETUNE = 2
@@ -84,14 +81,11 @@ def main() -> None:
)

# 2. Initial model evaluation on test set
inference_config = {
"SUBSAMPLE_SAMPLES": NUM_INFERENCE_SUBSAMPLE_SAMPLES,
}
base_clf = TabPFNClassifier(
device=[f"cuda:{i}" for i in range(torch.cuda.device_count())],
n_estimators=NUM_ESTIMATORS_FINAL_INFERENCE,
ignore_pretraining_limits=True,
inference_config=inference_config,
inference_config={"SUBSAMPLE_SAMPLES": 50_000},
)
base_clf.fit(X_train, y_train)

@@ -110,15 +104,9 @@
device="cuda" if torch.cuda.is_available() else "cpu",
epochs=NUM_EPOCHS,
learning_rate=LEARNING_RATE,
validation_split_ratio=VALIDATION_SPLIT_RATIO,
n_finetune_ctx_plus_query_samples=NUM_FINETUNE_CTX_PLUS_QUERY_SAMPLES,
finetune_ctx_query_split_ratio=FINETUNE_CTX_QUERY_SPLIT_RATIO,
n_inference_subsample_samples=NUM_INFERENCE_SUBSAMPLE_SAMPLES,
random_state=RANDOM_STATE,
n_estimators_finetune=NUM_ESTIMATORS_FINETUNE,
n_estimators_validation=NUM_ESTIMATORS_VALIDATION,
n_estimators_final_inference=NUM_ESTIMATORS_FINAL_INFERENCE,
use_activation_checkpointing=USE_ACTIVATION_CHECKPOINTING,
)

# 4. Call .fit() to start the fine-tuning process on the training data
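Reviewer sketch: the constants deleted above (VALIDATION_SPLIT_RATIO, NUM_FINETUNE_CTX_PLUS_QUERY_SAMPLES, FINETUNE_CTX_QUERY_SPLIT_RATIO, NUM_INFERENCE_SUBSAMPLE_SAMPLES, USE_ACTIVATION_CHECKPOINTING) correspond to keyword arguments that the removed call-site lines passed to FinetunedTabPFNClassifier. Assuming those arguments remain supported after this change, they could still be supplied explicitly; the values below are simply the ones removed from the example, not recommendations.

# Sketch only: keyword names and values are taken from the removed lines above;
# assumes FinetunedTabPFNClassifier still accepts them as optional arguments.
finetuned_clf = FinetunedTabPFNClassifier(
    device="cuda" if torch.cuda.is_available() else "cpu",
    epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    random_state=RANDOM_STATE,
    n_estimators_finetune=NUM_ESTIMATORS_FINETUNE,
    n_estimators_validation=NUM_ESTIMATORS_VALIDATION,
    n_estimators_final_inference=NUM_ESTIMATORS_FINAL_INFERENCE,
    validation_split_ratio=0.1,                # fraction of the data held out for validation
    n_finetune_ctx_plus_query_samples=10_000,  # total context + query samples per training batch
    finetune_ctx_query_split_ratio=0.2,        # 0.2 * 10_000 = 2_000 query samples per batch
    n_inference_subsample_samples=50_000,      # subsample size used for inference during training
    use_activation_checkpointing=True,         # trades compute for memory; optional on small datasets
)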
118 changes: 118 additions & 0 deletions examples/finetune_regressor.py
@@ -0,0 +1,118 @@
"""Example of fine-tuning a TabPFN regressor using the FinetunedTabPFNRegressor wrapper.

Note: We recommend running the fine-tuning scripts on a CUDA-enabled GPU, as full
support for the Apple Silicon (MPS) backend is still under development.
"""

import logging
import warnings

import sklearn.datasets
import torch
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNRegressor
from tabpfn.finetuning.finetuned_regressor import FinetunedTabPFNRegressor

warnings.filterwarnings(
"ignore",
category=FutureWarning,
module=r"google\.api_core\._python_version_support",
)

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)

# =============================================================================
# Fine-tuning Configuration
# For details and more options see FinetunedTabPFNRegressor
#
# These settings work well for the California Housing dataset.
# For other datasets, you may need to adjust these settings to get good results.
# =============================================================================

# Training hyperparameters
NUM_EPOCHS = 30
LEARNING_RATE = 1e-5

# We can fine-tune using almost the entire housing dataset
# in the context of the train batches.
N_FINETUNE_CTX_PLUS_QUERY_SAMPLES = 20_000

# Ensemble configuration
# number of estimators to use during finetuning
NUM_ESTIMATORS_FINETUNE = 8
# number of estimators to use during train time validation
NUM_ESTIMATORS_VALIDATION = 8
# number of estimators to use during final inference
NUM_ESTIMATORS_FINAL_INFERENCE = 8

# Reproducibility
RANDOM_STATE = 0


def main() -> None:
data = sklearn.datasets.fetch_california_housing(as_frame=True)
X_all = data.data
y_all = data.target

X_train, X_test, y_train, y_test = train_test_split(
X_all, y_all, test_size=0.1, random_state=RANDOM_STATE
)

print(
f"Loaded {len(X_train):,} samples for training and "
f"{len(X_test):,} samples for testing."
)

# 2. Initial model evaluation on test set
base_reg = TabPFNRegressor(
device=[f"cuda:{i}" for i in range(torch.cuda.device_count())],
n_estimators=NUM_ESTIMATORS_FINAL_INFERENCE,
ignore_pretraining_limits=True,
inference_config={"SUBSAMPLE_SAMPLES": 50_000},
)
base_reg.fit(X_train, y_train)

base_pred = base_reg.predict(X_test)
mse = mean_squared_error(y_test, base_pred)
r2 = r2_score(y_test, base_pred)

print(f"📊 Default TabPFN Test MSE: {mse:.4f}")
print(f"📊 Default TabPFN Test R²: {r2:.4f}\n")

# 3. Initialize and run fine-tuning
print("--- 2. Initializing and Fitting Model ---\n")

# Instantiate the wrapper with your desired hyperparameters
finetuned_reg = FinetunedTabPFNRegressor(
device="cuda" if torch.cuda.is_available() else "cpu",
epochs=NUM_EPOCHS,
learning_rate=LEARNING_RATE,
random_state=RANDOM_STATE,
n_finetune_ctx_plus_query_samples=N_FINETUNE_CTX_PLUS_QUERY_SAMPLES,
n_estimators_finetune=NUM_ESTIMATORS_FINETUNE,
n_estimators_validation=NUM_ESTIMATORS_VALIDATION,
n_estimators_final_inference=NUM_ESTIMATORS_FINAL_INFERENCE,
)

# 4. Call .fit() to start the fine-tuning process on the training data
finetuned_reg.fit(X_train.values, y_train.values)
print("\n")

# 5. Evaluate the fine-tuned model
print("--- 3. Evaluating Model on Held-out Test Set ---\n")
y_pred = finetuned_reg.predict(X_test.values)

mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"📊 Finetuned TabPFN Test MSE: {mse:.4f}")
print(f"📊 Finetuned TabPFN Test R²: {r2:.4f}")


if __name__ == "__main__":
main()
15 changes: 15 additions & 0 deletions src/tabpfn/finetuning/__init__.py
@@ -0,0 +1,15 @@
"""Single-dataset fine-tuning wrappers for TabPFN models."""

from tabpfn.finetuning.data_util import ClassifierBatch, RegressorBatch
from tabpfn.finetuning.finetuned_base import EvalResult, FinetunedTabPFNBase
from tabpfn.finetuning.finetuned_classifier import FinetunedTabPFNClassifier
from tabpfn.finetuning.finetuned_regressor import FinetunedTabPFNRegressor

__all__ = [
"ClassifierBatch",
"EvalResult",
"FinetunedTabPFNBase",
"FinetunedTabPFNClassifier",
"FinetunedTabPFNRegressor",
"RegressorBatch",
]
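For context, a minimal sketch of how these package-level re-exports would be consumed; constructor arguments mirror the examples in this PR and are otherwise an assumption.

# Sketch only: imports resolve through the re-exports listed in __all__ above;
# device/epochs/learning_rate follow the usage shown in the example scripts.
from tabpfn.finetuning import FinetunedTabPFNClassifier, FinetunedTabPFNRegressor

clf = FinetunedTabPFNClassifier(device="cpu", epochs=1, learning_rate=2e-5)
reg = FinetunedTabPFNRegressor(device="cpu", epochs=1, learning_rate=1e-5)
# Both wrappers then follow the sklearn-style fit/predict flow used in
# examples/finetune_classifier.py and examples/finetune_regressor.py.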