diff --git a/pyproject.toml b/pyproject.toml index 4f8ac5c9e..86a00e1b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,12 +7,13 @@ name = "tabpfn" version = "2.0.6" dependencies = [ "torch>=2.1,<3", - "scikit-learn>=1.2.0,<1.7", + "scikit-learn>=1.2.1,<1.7", "typing_extensions>=4.4.0", + "pandas>=1.5.3,<3", "scipy>=1.11.1,<2", - "pandas>=1.4.0,<3", "einops>=0.2.0,<0.9", "huggingface-hub>=0.0.1,<1", + "skrub>=0.2.0,<0.6", ] requires-python = ">=3.9,<3.13" authors = [ diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index c52b805fb..97289cdf3 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -48,7 +48,6 @@ from tabpfn.utils import ( _fix_dtypes, _get_embeddings, - _get_ordinal_encoder, infer_categorical_features, infer_device_and_type, infer_random_state, @@ -451,11 +450,30 @@ def fit(self, X: XType, y: YType) -> Self: # as handle `np.object` arrays or otherwise `object` dtype pandas columns. X = _fix_dtypes(X, cat_indices=self.categorical_features_indices) - # Ensure categories are ordinally encoded - ord_encoder = _get_ordinal_encoder() - X = ord_encoder.fit_transform(X) # type: ignore - assert isinstance(X, np.ndarray) - self.preprocessor_ = ord_encoder + # Use skrub's TableVectorizer to handle text columns with NAs properly + from sklearn.preprocessing import OrdinalEncoder + from skrub import TableVectorizer + + # Configure TableVectorizer to handle missing values consistently + # By using OrdinalEncoder with unknown_value=float("nan") for all text columns + table_vectorizer = TableVectorizer( + low_cardinality=OrdinalEncoder( + handle_unknown="use_encoded_value", + unknown_value=float("nan"), + ), + high_cardinality=OrdinalEncoder( + handle_unknown="use_encoded_value", + unknown_value=float("nan"), + ), + numeric="passthrough", + ) + + X = table_vectorizer.fit_transform(X) + self.preprocessor_ = table_vectorizer + + # TableVectorizer returns a DataFrame, convert to numpy array + if hasattr(X, "values"): + X = X.to_numpy() self.inferred_categorical_indices_ = infer_categorical_features( X=X, @@ -533,6 +551,10 @@ def predict_proba(self, X: XType) -> np.ndarray: X = _fix_dtypes(X, cat_indices=self.categorical_features_indices) X = self.preprocessor_.transform(X) + # Ensure X is a numpy array, not a DataFrame + if hasattr(X, "values"): + X = X.to_numpy() + outputs: list[torch.Tensor] = [] for output, config in self.executor_.iter_outputs( diff --git a/src/tabpfn/utils.py b/src/tabpfn/utils.py index 298507a5b..e59b24e8c 100644 --- a/src/tabpfn/utils.py +++ b/src/tabpfn/utils.py @@ -71,6 +71,10 @@ def _get_embeddings( X = _fix_dtypes(X, cat_indices=model.categorical_features_indices) X = model.preprocessor_.transform(X) + # Ensure X is a numpy array, not a DataFrame + if hasattr(X, "values"): + X = X.to_numpy() + embeddings: list[np.ndarray] = [] for output, config in model.executor_.iter_outputs( @@ -390,11 +394,12 @@ def load_model_criterion_config( model_name=model_name, ) if res != "ok": - repo_type = "clf" if which == "classifier" else "reg" raise RuntimeError( f"Failed to download model to {model_path}!\n\n" f"For offline usage, please download the model manually from:\n" - f"https://huggingface.co/Prior-Labs/TabPFN-v2-{repo_type}/resolve/main/{model_name}\n\n" + f"https://huggingface.co/Prior-Labs/TabPFN-v2-" + f"{'clf' if which == 'classifier' else 'reg'}" + f"/resolve/main/{model_name}\n\n" f"Then place it at: {model_path}", ) from res[0] @@ -481,6 +486,7 @@ def _fix_dtypes( if convert_dtype: X = X.convert_dtypes() + # Convert numeric columns to the specified numeric dtype integer_columns = X.select_dtypes(include=["number"]).columns if len(integer_columns) > 0: X[integer_columns] = X[integer_columns].astype(numeric_dtype) diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index 90969d306..b1b0c1dbb 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -6,6 +6,7 @@ from typing import Callable, Literal import numpy as np +import pandas as pd import pytest import sklearn.datasets import torch @@ -310,3 +311,43 @@ def test_get_embeddings(X_y: tuple[np.ndarray, np.ndarray], data_source: str) -> assert embeddings.shape[0] == n_estimators assert embeddings.shape[1] == X.shape[0] assert embeddings.shape[2] == encoder_shape + + +def test_classifier_with_text_and_na() -> None: + """Test that TabPFNClassifier correctly handles text columns with NA values.""" + # Create a DataFrame with text and NA values + # Create test data with text and NA values + data = { + "text_feature": [ + "good product", + "bad service", + None, + "excellent", + "average", + None, + ], + "numeric_feature": [10, 5, 8, 15, 7, 12], + "all_na_column": [None, None, None, None, None, None], # Column with all NaNs + "target": [1, 0, 1, 1, 0, 0], + } + + # Create DataFrame + df = pd.DataFrame(data) + + # Split into X and y + X = df[["text_feature", "numeric_feature", "all_na_column"]] + y = df["target"] + + # Initialize and fit TabPFN on data with text+NA and a column with all NAs + classifier = TabPFNClassifier(device="cpu", n_estimators=2) + + # This should now work without raising errors + classifier.fit(X, y) + + # Verify we can predict + probabilities = classifier.predict_proba(X) + predictions = classifier.predict(X) + + # Check output shapes + assert probabilities.shape == (X.shape[0], len(np.unique(y))) + assert predictions.shape == (X.shape[0],)