25 changes: 25 additions & 0 deletions tests/test_classifier_interface.py
@@ -135,6 +135,31 @@ def test_fit(
assert predictions.shape == (X.shape[0],), "Predictions shape is incorrect!"


def test_fit_modes_all_return_equal_results(
X_y: tuple[np.ndarray, np.ndarray],
) -> None:
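"""All fit modes should produce matching probabilities and identical class predictions."""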
kwargs = {"n_estimators": 2, "device": "cpu", "random_state": 0}
X, y = X_y

torch.random.manual_seed(0)
tabpfn = TabPFNClassifier(fit_mode="fit_preprocessors", **kwargs)
tabpfn.fit(X, y)
probs = tabpfn.predict_proba(X)
preds = tabpfn.predict(X)

torch.random.manual_seed(0)
tabpfn = TabPFNClassifier(fit_mode="fit_with_cache", **kwargs)
tabpfn.fit(X, y)
np.testing.assert_array_almost_equal(probs, tabpfn.predict_proba(X))
np.testing.assert_array_equal(preds, tabpfn.predict(X))

torch.random.manual_seed(0)
tabpfn = TabPFNClassifier(fit_mode="low_memory", **kwargs)
tabpfn.fit(X, y)
np.testing.assert_array_almost_equal(probs, tabpfn.predict_proba(X))
np.testing.assert_array_equal(preds, tabpfn.predict(X))


@pytest.mark.parametrize(
(
"n_estimators",
212 changes: 0 additions & 212 deletions tests/test_finetuning_classifier.py
@@ -1,13 +1,10 @@
from __future__ import annotations

import unittest
from functools import partial
from typing import Any, Literal
from unittest.mock import patch

import numpy as np
import pytest
import sklearn
import torch
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
@@ -505,212 +502,3 @@ def test_fit_from_preprocessed_runs(classifier_instance, classification_data) ->
probs_sum, torch.ones_like(probs_sum), atol=1e-5
), "Probabilities do not sum to 1"
break # Only need to check one batch for this test


class TestTabPFNClassifierPreprocessingInspection(unittest.TestCase):
def test_finetuning_consistency_preprocessing_classifier(self):
"""Tests the consistency between standard preprocessing (fit -> predict_proba)
and the fine-tuning preprocessing pipeline
(get_preprocessed_datasets -> fit_from_preprocessed
-> forward)
for the TabPFNClassifier by comparing the tensors entering the internal model.
"""
# --- Test Parameters ---
test_set_size = 0.3
common_seed = 42
n_total = 50 # Increased slightly for more robust testing
n_features = 8
n_classes = 2 # Use a specific number of classes
n_informative = 5 # For make_classification
n_estimators = 1 # Keep N=1 for easier direct comparison of tensors

# --- 1. Setup ---
X, y = sklearn.datasets.make_classification(
n_samples=n_total,
n_features=n_features,
n_informative=n_informative,
n_redundant=n_features - n_informative,
n_classes=n_classes,
n_clusters_per_class=1, # Simpler structure
random_state=common_seed,
)
splitfn = partial(
train_test_split,
test_size=test_set_size,
random_state=common_seed,
shuffle=False, # Keep False for consistent splitting
)
X_train_raw, X_test_raw, y_train_raw, y_test_raw = splitfn(X, y)

# Initialize two classifiers with the necessary modes
clf_standard = TabPFNClassifier(
n_estimators=n_estimators,
device="cpu",
random_state=common_seed,
fit_mode="fit_preprocessors", # A standard mode that preprocesses on fit
)
# 'batched' mode is required for get_preprocessed_datasets
# and fit_from_preprocessed
clf_batched = TabPFNClassifier(
n_estimators=n_estimators,
device="cpu",
random_state=common_seed,
fit_mode="batched",
)

# --- 2. Path 1: Standard fit -> predict_proba -> Capture Tensor ---

clf_standard.fit(X_train_raw, y_train_raw)
# Ensure the internal model attribute exists after fit
assert all(
[hasattr(clf_standard, "model_"), hasattr(clf_standard.model_, "forward")]
), "Standard classifier model_ or model_.forward not found after fit."

tensor_p1_full = None
# Patch the standard classifier's *internal model's* forward method
# The internal model typically receives the combined train+test sequence
with patch.object(
clf_standard.model_, "forward", wraps=clf_standard.model_.forward
) as mock_forward_p1:
_ = clf_standard.predict_proba(X_test_raw)
assert mock_forward_p1.called, "Standard model_.forward was not called."

# Capture the tensor input 'x' (usually the second positional argument)
call_args_list = mock_forward_p1.call_args_list
assert (
len(call_args_list) > 0
), "No calls recorded for standard model_.forward."
if len(call_args_list[0].args) > 1:
tensor_p1_full = call_args_list[0].args[1]
else:
self.fail(
f"Standard model_.forward call had "
f"unexpected arguments: {call_args_list[0].args}"
)

assert (
tensor_p1_full is not None
), "Failed to capture tensor from standard path."
# Shape might be [1, N_Total, Features+1] or similar. Check the actual shape.
# Example assertion: Check if the sequence length matches n_total
assert tensor_p1_full.shape[0] == n_total, (
f"Path 1 tensor sequence length ({tensor_p1_full.shape[0]})"
f"does not match n_total ({n_total}). Shape was {tensor_p1_full.shape}"
)

# --- 3. Path 2: FT workflow (get_prep -> fit_prep -> predict_prep -> Capture Tensor) ---
# Step 3a: Get preprocessed datasets using the *full* dataset
# Requires fit_mode='batched' on clf_batched
# Make sure default max_data_size is large enough.
datasets_list = clf_batched.get_preprocessed_datasets(
X,
y,
splitfn, # Use the full X, y and the split function
)
assert len(datasets_list) > 0, "get_preprocessed_datasets returned empty list."

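# Wrap the preprocessed datasets in a DataLoader (batch_size=1, no shuffling)
# so a single batch can be unpacked for the fine-tuning workflow below.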
dataloader = DataLoader(
datasets_list,
batch_size=1,
collate_fn=meta_dataset_collator,
shuffle=False,
)
try:
data_batch = next(iter(dataloader))
except StopIteration:
self.fail("DataLoader yielded no batches.")

try:
(X_trains_p2, X_tests_p2, y_trains_p2, _, cat_ixs_p2, confs_p2, *_) = (
data_batch
)
except ValueError as e:
self.fail(
f"Failed to unpack data batch from DataLoader."
f"Structure might be different. Error: {e}. Batch content: {data_batch}"
)

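# Step 3b: Fit the batched classifier directly from the preprocessed tensors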
clf_batched.fit_from_preprocessed(
X_trains_p2, y_trains_p2, cat_ixs_p2, confs_p2
)
assert all(
[hasattr(clf_batched, "model_"), hasattr(clf_batched.model_, "forward")]
), (
"Batched classifier model_ or model_.forward not"
"found after fit_from_preprocessed."
)

# Step 3c: Call forward and capture the input tensor
# to the *internal transformer model*
tensor_p2_full = None
# Patch the *batched* classifier's internal model's forward method
with patch.object(
clf_batched.model_, "forward", wraps=clf_batched.model_.forward
) as mock_forward_p2:
_ = clf_batched.forward(X_tests_p2)
assert mock_forward_p2.called, "Batched model_.forward was not called."

# Capture the tensor input 'x' (assuming same argument position as Path 1)
call_args_list = mock_forward_p2.call_args_list
assert (
len(call_args_list) > 0
), "No calls recorded for batched model_.forward."
if len(call_args_list[0].args) > 1:
tensor_p2_full = mock_forward_p2.call_args.args[1]
else:
self.fail(
f"Batched model_.forward call had "
f"unexpected arguments: {call_args_list[0].args}"
)

assert tensor_p2_full is not None, "Failed to capture tensor from batched path."
# The internal model in this path should
# also receive the full sequence if n_estimators=1
# and the dataloader yielded the full split.
assert tensor_p2_full.shape[0] == n_total, (
f"Path 2 tensor sequence length ({tensor_p2_full.shape[0]}) "
f"does not match n_total ({n_total}). Shape was {tensor_p2_full.shape}"
)

# --- 4. Comparison (Path 1 vs Path 2) ---

# Ensure tensors are on the same device (CPU) for comparison
tensor_p1_full = tensor_p1_full.cpu()
tensor_p2_full = tensor_p2_full.cpu()

# Squeeze dimensions of size 1
# E.g., if shape is [1, N_Total, Features+1], squeeze the first dim
if tensor_p1_full.shape[0] == 1:
p1_squeezed = tensor_p1_full.squeeze(0)
else:
p1_squeezed = tensor_p1_full

if tensor_p2_full.shape[0] == 1:
p2_squeezed = tensor_p2_full.squeeze(0)
else:
p2_squeezed = tensor_p2_full

# Final check of shapes after potential squeeze
assert (
p1_squeezed.shape == p2_squeezed.shape
), "Shapes of final model input tensors mismatch after squeeze. "

# Visual inspection snippet

# Perform numerical comparison using torch.allclose
# Use a reasonably small tolerance. Preprocessing should be near-identical.
# Floating point ops might introduce tiny differences.
atol = 1e-6
rtol = 1e-5
tensors_match = torch.allclose(p1_squeezed, p2_squeezed, atol=atol, rtol=rtol)

if not tensors_match:
diff = torch.abs(p1_squeezed - p2_squeezed)
# Locate the largest discrepancy for a more informative failure message
max_diff_val, max_diff_idx = torch.max(diff.flatten(), dim=0)
max_diff_loc = np.unravel_index(max_diff_idx.item(), p1_squeezed.shape)
self.fail(
f"Mismatch between final model input tensors: max abs diff "
f"{max_diff_val.item():.3e} at index {max_diff_loc}."
)

# Assertion: the final tensors fed to the model should match across both paths.
assert tensors_match, "Mismatch between final model input tensors."