35 commits
42464e3  - Change default estimators for classifier from 4 to 8 (Jul 15, 2025)
abc48a8  Merge remote-tracking branch 'origin/main' (Jul 16, 2025)
afa7fb3  Merge remote-tracking branch 'origin/main' (Jul 16, 2025)
775bb5e  Merge remote-tracking branch 'origin/main' (Jul 17, 2025)
dad171f  Merge remote-tracking branch 'origin/main' (Jul 28, 2025)
5062414  Merge remote-tracking branch 'origin/main' (Aug 3, 2025)
2ce9d5b  Merge remote-tracking branch 'origin/main' (Aug 3, 2025)
eee5d91  Merge remote-tracking branch 'origin/main' (Aug 3, 2025)
db3a919  attempt to fix the naming of bar distributions (rosenyu304, Jul 23, 2025)
bef0d24  ruff fix (rosenyu304, Jul 23, 2025)
402b3da  ruff fix on the ipynb (rosenyu304, Jul 23, 2025)
ad52ac4  naming change (rosenyu304, Jul 23, 2025)
16efb80  resolve gemini suggestions (rosenyu304, Jul 24, 2025)
0c030c9  ruff (rosenyu304, Jul 24, 2025)
39914d9  debug test (rosenyu304, Jul 24, 2025)
1902ac1  Delete my local runs (rosenyu304, Jul 24, 2025)
7a7c914  adding attributes to allow using both the old naming convention and t… (rosenyu304, Jul 30, 2025)
882e758  python compatibility issue on dataclass (rosenyu304, Jul 30, 2025)
cf94778  ruff (rosenyu304, Jul 30, 2025)
76ef0fe  Fixed the comments (rosenyu304, Aug 1, 2025)
f769a8f  call the ys znorm. the x are still preprocessed. (rosenyu304, Aug 6, 2025)
2995386  simplify preprocessing bardist attribute (rosenyu304, Aug 7, 2025)
06014f0  Merge remote-tracking branch 'origin/main' (Aug 7, 2025)
5ad66a9  refactor internally attempt (rosenyu304, Aug 7, 2025)
753fbe8  ruff (rosenyu304, Aug 7, 2025)
879d428  Merge remote-tracking branch 'origin/main' (Aug 8, 2025)
ac9a8bd  Merge branch 'main' into finetuning-debugging (Aug 22, 2025)
32a8ff6  Merge remote-tracking branch 'origin/main' into finetuning-debugging (Aug 22, 2025)
a283d4f  - add subsampling with replacement (Aug 27, 2025)
e3eadb5  Merge remote-tracking branch 'origin/main' into subsampling-with-repl… (Aug 27, 2025)
56ff60b  - add subsampling with replacement (Aug 27, 2025)
b4968d0  Update src/tabpfn/preprocessing.py (Aug 28, 2025)
d46571c  - add subsampling with replacement (Aug 28, 2025)
809327e  Merge remote-tracking branch 'origin/subsampling-with-replacement' in… (Aug 28, 2025)
3714ec2  - add subsampling with replacement (Aug 28, 2025)
1 change: 1 addition & 0 deletions src/tabpfn/classifier.py
@@ -530,6 +530,7 @@ def _initialize_dataset_preprocessing(
        ensemble_configs = EnsembleConfig.generate_for_classification(
            n=self.n_estimators,
            subsample_size=self.interface_config_.SUBSAMPLE_SAMPLES,
+            subsample_with_replacement=self.interface_config_.SUBSAMPLE_SAMPLES_WITH_REPLACEMENT,
            add_fingerprint_feature=self.interface_config_.FINGERPRINT_FEATURE,
            feature_shift_decoder=self.interface_config_.FEATURE_SHIFT_METHOD,
            polynomial_features=self.interface_config_.POLYNOMIAL_FEATURES,
6 changes: 6 additions & 0 deletions src/tabpfn/config.py
@@ -91,6 +91,12 @@ class ModelInterfaceConfig:
        - If a float, the percentage of samples to subsample.
    """

+    SUBSAMPLE_SAMPLES_WITH_REPLACEMENT: bool = False
+    """Whether to subsample with replacement (bootstrapping). If False (default),
+    each sample can appear at most once in a subsample. If True, samples can be
+    drawn multiple times. This is only active when `SUBSAMPLE_SAMPLES` is not None.
+    """
+
    PREPROCESS_TRANSFORMS: list[PreprocessorConfig | dict] | None = None
    """The preprocessing applied to the data before passing it to TabPFN. See
    `PreprocessorConfig` for options and more details. If a list of `PreprocessorConfig`
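For orientation, a minimal sketch of how a user could turn the new flag on. The `inference_config` dict mirrors the one used in the test added at the end of this PR; the dataset, `n_estimators`, and `random_state` values are placeholders:

```python
from sklearn.datasets import load_breast_cancer

from tabpfn import TabPFNClassifier

X, y = load_breast_cancer(return_X_y=True)

# Bootstrap each ensemble member: draw 256 rows per estimator, with replacement.
clf = TabPFNClassifier(
    n_estimators=4,
    inference_config={
        "SUBSAMPLE_SAMPLES": 256,
        "SUBSAMPLE_SAMPLES_WITH_REPLACEMENT": True,
    },
    random_state=0,
)
clf.fit(X, y)
print(clf.predict_proba(X[:5]))
```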
61 changes: 42 additions & 19 deletions src/tabpfn/preprocessing.py
@@ -6,6 +6,7 @@

from __future__ import annotations

+import math
import warnings
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, field
@@ -267,36 +268,52 @@ def generate_index_permutations(
    *,
    max_index: int,
    subsample: int | float,
+    with_replacement: bool = False,
    random_state: int | np.random.Generator | None,
) -> list[npt.NDArray[np.int64]]:
    """Generate indices for subsampling from the data.

    Args:
-        n: Number of indices to generate.
-        max_index: Maximum index to generate.
+        n: Number of index arrays to generate.
+        max_index: The upper bound for the indices (samples from [0, max_index-1]).
        subsample:
-            Number of indices to subsample. If `int`, subsample that many
-            indices. If float, subsample that fraction of indices.
-        random_state: Random number generator.
+            The number of indices to draw.
+            - If `int`, this is the absolute number of indices.
+            - If `float`, this is the fraction of `max_index` to draw.
+        with_replacement: If `True`, indices can be chosen more than once.
+            If `False` (default), indices are unique.
+        random_state: A seed or random number generator for reproducibility.

    Returns:
-        List of indices to subsample.
+        A list containing `n` arrays of subsampled indices.
    """
-    _, rng = infer_random_state(random_state)
-    if isinstance(subsample, int):
-        if subsample < 1:
-            raise ValueError(f"{subsample=} must be larger than 1 if int")
-        subsample = min(subsample, max_index)
+    if max_index < 0:
+        raise ValueError(f"max_index must be non-negative, but got {max_index}")
+    if max_index == 0:
+        return [np.array([], dtype=np.int64) for _ in range(n)]

-        return [rng.permutation(max_index)[:subsample] for _ in range(n)]
+    _, rng = infer_random_state(random_state)

+    # Determine the number of items to subsample (k)
    if isinstance(subsample, float):
-        if not (0 < subsample < 1):
-            raise ValueError(f"{subsample=} must be in (0, 1) if float")
-        subsample = int(subsample * max_index) + 1
-        return [rng.permutation(max_index)[:subsample] for _ in range(n)]
+        if not (0.0 < subsample <= 1.0):
+            raise ValueError(f"If float, {subsample=} must be in (0, 1].")
+        # Ensure at least one sample is drawn
+        k = max(1, math.ceil(subsample * max_index))
+    elif isinstance(subsample, int):
+        if subsample < 1:
+            raise ValueError(f"If int, {subsample=} must be at least 1.")
+        k = subsample
+    else:
+        raise TypeError(f"{subsample=} must be an int or float.")

-    raise ValueError(f"{subsample=} must be int or float.")
+    # Generate n lists of indices based on the replacement strategy
+    if with_replacement:
+        # Sample with replacement. The sample size `k` can be larger than `max_index`.
+        return [rng.choice(max_index, size=k, replace=True) for _ in range(n)]
+    # Sample without replacement. The sample size cannot exceed the population size.
+    sample_size = min(k, max_index)
+    return [rng.permutation(max_index)[:sample_size] for _ in range(n)]


# TODO: (Klemens)
@@ -321,7 +338,7 @@ class EnsembleConfig:
    subsample_ix: npt.NDArray[np.int64] | None  # OPTIM: Could use uintp

    @classmethod
-    def generate_for_classification(
+    def generate_for_classification(  # noqa: PLR0913
        cls,
        *,
        n: int,
@@ -333,6 +350,7 @@ def generate_for_classification(
        preprocessor_configs: Sequence[PreprocessorConfig],
        class_shift_method: Literal["rotate", "shuffle"] | None,
        n_classes: int,
+        subsample_with_replacement: bool = False,
        random_state: int | np.random.Generator | None,
    ) -> list[ClassifierEnsembleConfig]:
        """Generate ensemble configurations for classification.
@@ -350,6 +368,7 @@ def generate_for_classification(
            preprocessor_configs: Preprocessor configurations to use on the data.
            class_shift_method: How to shift classes for classpermutation.
            n_classes: Number of classes.
+            subsample_with_replacement: Whether to subsample with replacement.
            random_state: Random number generator.

        Returns:
@@ -389,9 +408,10 @@ def generate_for_classification(
                n=n,
                max_index=max_index,
                subsample=subsample_size,
+                with_replacement=subsample_with_replacement,
                random_state=static_seed,
            )
-        elif subsample_size is None:
+        elif subsample_size is None:  # No subsampling
            subsamples = [None] * n  # type: ignore
        else:
            raise ValueError(
@@ -440,6 +460,7 @@ def generate_for_regression(
        feature_shift_decoder: Literal["shuffle", "rotate"] | None,
        preprocessor_configs: Sequence[PreprocessorConfig],
        target_transforms: Sequence[TransformerMixin | Pipeline | None],
+        subsample_with_replacement: bool = False,
        random_state: int | np.random.Generator | None,
    ) -> list[RegressorEnsembleConfig]:
        """Generate ensemble configurations for regression.
@@ -456,6 +477,7 @@ def generate_for_regression(
            feature_shift_decoder: How shift features
            preprocessor_configs: Preprocessor configurations to use on the data.
            target_transforms: Target transformations to apply.
+            subsample_with_replacement: Whether to subsample with replacement.
            random_state: Random number generator.

        Returns:
@@ -472,6 +494,7 @@ def generate_for_regression(
                n=n,
                max_index=max_index,
                subsample=subsample_size,
+                with_replacement=subsample_with_replacement,
                random_state=static_seed,
            )
        elif subsample_size is None:
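To make the behavioural difference of the two sampling modes concrete, a minimal sketch, assuming `generate_index_permutations` remains a module-level helper importable from `tabpfn.preprocessing` (argument values are illustrative):

```python
import numpy as np

from tabpfn.preprocessing import generate_index_permutations

# Without replacement: each index appears at most once, so a request for more
# indices than exist is silently capped at max_index.
plain = generate_index_permutations(
    n=2, max_index=5, subsample=10, with_replacement=False, random_state=0
)
assert all(len(ix) == 5 and len(np.unique(ix)) == 5 for ix in plain)

# With replacement (bootstrapping): duplicates are allowed and the requested
# size is honoured even when it exceeds max_index.
boot = generate_index_permutations(
    n=2, max_index=5, subsample=10, with_replacement=True, random_state=0
)
assert all(len(ix) == 10 for ix in boot)
```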
1 change: 1 addition & 0 deletions src/tabpfn/regressor.py
@@ -590,6 +590,7 @@ def _initialize_dataset_preprocessing(
        ensemble_configs = EnsembleConfig.generate_for_regression(
            n=self.n_estimators,
            subsample_size=self.interface_config_.SUBSAMPLE_SAMPLES,
+            subsample_with_replacement=self.interface_config_.SUBSAMPLE_SAMPLES_WITH_REPLACEMENT,
            add_fingerprint_feature=self.interface_config_.FINGERPRINT_FEATURE,
            feature_shift_decoder=self.interface_config_.FEATURE_SHIFT_METHOD,
            polynomial_features=self.interface_config_.POLYNOMIAL_FEATURES,
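The regressor is wired up the same way. A hedged sketch of the equivalent usage, assuming `TabPFNRegressor` accepts the same constructor keywords as the classifier test below (the synthetic data and parameter values are placeholders):

```python
import numpy as np

from tabpfn import TabPFNRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X @ rng.normal(size=5) + rng.normal(scale=0.1, size=200)

reg = TabPFNRegressor(
    n_estimators=2,
    inference_config={
        "SUBSAMPLE_SAMPLES": 0.5,  # float: fraction of the training rows per estimator
        "SUBSAMPLE_SAMPLES_WITH_REPLACEMENT": True,
    },
    random_state=0,
)
reg.fit(X, y)
preds = reg.predict(X[:10])
```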
26 changes: 26 additions & 0 deletions tests/test_classifier_interface.py
@@ -803,3 +803,29 @@ def test_initialize_model_variables_classifier_sets_required_attributes() -> None:
    assert not hasattr(
        classifier2, "bardist_"
    ), "classifier2 should not have bardist_ attribute"
+
+
+def test_subsample_with_replacement_allows_oversampling(
+    X_y: tuple[np.ndarray, np.ndarray],
+) -> None:
+    """Tests that SUBSAMPLE_SAMPLES_WITH_REPLACEMENT=True allows sampling more
+    samples than available in the dataset (oversampling).
+    """
+    X, y = X_y
+    n_samples = X.shape[0]
+    oversample_size = n_samples + 10  # Sample more than available
+
+    # This should work without errors because with_replacement=True allows
+    # drawing the same sample multiple times.
+    model_with_replacement = TabPFNClassifier(
+        n_estimators=2,
+        device="cpu",
+        inference_config={
+            "SUBSAMPLE_SAMPLES": oversample_size,
+            "SUBSAMPLE_SAMPLES_WITH_REPLACEMENT": True,
+        },
+        random_state=42,
+    )
+
+    model_with_replacement.fit(X, y)
+    model_with_replacement.predict(X)