From fba2752c278cd2d5aea4e33c956f7c086660d226 Mon Sep 17 00:00:00 2001 From: noahho Date: Fri, 28 Feb 2025 09:23:33 +0100 Subject: [PATCH 01/11] Add support for handling text columns with NA values Fixes issue #138: NA handling in text columns - Add skrub>=0.3.0 dependency to handle mixed string/NA data - Integrate TableVectorizer in TabPFNClassifier to properly process text columns with NA values - Add test to verify the solution works as expected --- pyproject.toml | 1 + src/tabpfn/classifier.py | 35 +++++++++++++++++++++++------ src/tabpfn/utils.py | 31 +++++++++++++++++++++++-- tests/test_classifier_interface.py | 36 ++++++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bacd2c94b..4503edee1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "pandas>=1.4.0,<3", "einops>=0.2.0,<0.9", "huggingface-hub>=0.0.1,<1", + "skrub>=0.3.0", ] requires-python = ">=3.9,<3.13" authors = [ diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index c52b805fb..551b176f9 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -48,7 +48,6 @@ from tabpfn.utils import ( _fix_dtypes, _get_embeddings, - _get_ordinal_encoder, infer_categorical_features, infer_device_and_type, infer_random_state, @@ -451,11 +450,29 @@ def fit(self, X: XType, y: YType) -> Self: # as handle `np.object` arrays or otherwise `object` dtype pandas columns. X = _fix_dtypes(X, cat_indices=self.categorical_features_indices) - # Ensure categories are ordinally encoded - ord_encoder = _get_ordinal_encoder() - X = ord_encoder.fit_transform(X) # type: ignore - assert isinstance(X, np.ndarray) - self.preprocessor_ = ord_encoder + # Use skrub's TableVectorizer to handle text columns with NAs properly + from sklearn.preprocessing import OrdinalEncoder + from skrub import TableVectorizer + + # Configure TableVectorizer to handle missing values consistently + # By using OrdinalEncoder with unknown_value=float("nan") for all text columns + table_vectorizer = TableVectorizer( + low_cardinality=OrdinalEncoder( + handle_unknown="use_encoded_value", unknown_value=float("nan"), + ), + high_cardinality=OrdinalEncoder( + handle_unknown="use_encoded_value", unknown_value=float("nan"), + ), + numeric="passthrough", + drop_null_fraction=None, # Don't drop columns with NAs + ) + + X = table_vectorizer.fit_transform(X) + self.preprocessor_ = table_vectorizer + + # TableVectorizer returns a DataFrame, convert to numpy array + if hasattr(X, "values"): + X = X.to_numpy() self.inferred_categorical_indices_ = infer_categorical_features( X=X, @@ -533,6 +550,10 @@ def predict_proba(self, X: XType) -> np.ndarray: X = _fix_dtypes(X, cat_indices=self.categorical_features_indices) X = self.preprocessor_.transform(X) + # Ensure X is a numpy array, not a DataFrame + if hasattr(X, "values"): + X = X.to_numpy() + outputs: list[torch.Tensor] = [] for output, config in self.executor_.iter_outputs( @@ -591,4 +612,4 @@ def get_embeddings( Returns: np.ndarray: The computed embeddings for each fitted estimator. """ - return _get_embeddings(self, X, data_source) + return _get_embeddings(self, X, data_source) \ No newline at end of file diff --git a/src/tabpfn/utils.py b/src/tabpfn/utils.py index 298507a5b..ef7e18a95 100644 --- a/src/tabpfn/utils.py +++ b/src/tabpfn/utils.py @@ -71,6 +71,10 @@ def _get_embeddings( X = _fix_dtypes(X, cat_indices=model.categorical_features_indices) X = model.preprocessor_.transform(X) + # Ensure X is a numpy array, not a DataFrame + if hasattr(X, "values"): + X = X.to_numpy() + embeddings: list[np.ndarray] = [] for output, config in model.executor_.iter_outputs( @@ -390,11 +394,12 @@ def load_model_criterion_config( model_name=model_name, ) if res != "ok": - repo_type = "clf" if which == "classifier" else "reg" raise RuntimeError( f"Failed to download model to {model_path}!\n\n" f"For offline usage, please download the model manually from:\n" - f"https://huggingface.co/Prior-Labs/TabPFN-v2-{repo_type}/resolve/main/{model_name}\n\n" + f"https://huggingface.co/Prior-Labs/TabPFN-v2-" + f"{'clf' if which == 'classifier' else 'reg'}" + f"/resolve/main/{model_name}\n\n" f"Then place it at: {model_path}", ) from res[0] @@ -481,6 +486,28 @@ def _fix_dtypes( if convert_dtype: X = X.convert_dtypes() + # Handle NAs in text/string/object columns by replacing with a placeholder + # This avoids mixed type errors in scikit-learn's validation + string_cols = X.select_dtypes(include=["string", "object"]).columns + if len(string_cols) > 0: + # Use a placeholder for NaN values in string columns + placeholder = "__MISSING__" + # We need to handle several pandas versions and their different NA handling + for col in string_cols: + # Force to object type first for consistent behavior + X[col] = X[col].astype("object") + # Create mask for all NaN-like values + is_all_na = X[col].isna().all() + is_none = False if is_all_na else X[col].apply(lambda x: x is None) + mask = X[col].isna() | is_none + # Apply mask to set placeholder + X.loc[mask, col] = placeholder + # After encoding (which happens in the preprocessing pipeline), + # these placeholders will be encoded as a category, but at least + # the data will be processable without errors. + # We'll then handle this consistently with numeric NaNs. + + # Convert numeric columns to the specified numeric dtype integer_columns = X.select_dtypes(include=["number"]).columns if len(integer_columns) > 0: X[integer_columns] = X[integer_columns].astype(numeric_dtype) diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index 90969d306..899a9117e 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -310,3 +310,39 @@ def test_get_embeddings(X_y: tuple[np.ndarray, np.ndarray], data_source: str) -> assert embeddings.shape[0] == n_estimators assert embeddings.shape[1] == X.shape[0] assert embeddings.shape[2] == encoder_shape + + +def test_classifier_with_text_and_na() -> None: + """Test that TabPFNClassifier correctly handles text columns with NA values.""" + # Create a DataFrame with text and NA values + import pandas as pd + + # Create test data with text and NA values + data = { + "text_feature": [ + "good product", "bad service", None, "excellent", "average", None, + ], + "numeric_feature": [10, 5, 8, 15, 7, 12], + "target": [1, 0, 1, 1, 0, 0], + } + + # Create DataFrame + df = pd.DataFrame(data) + + # Split into X and y + X = df[["text_feature", "numeric_feature"]] + y = df["target"] + + # Initialize and fit TabPFN on data with text+NA + classifier = TabPFNClassifier(device="cpu", n_estimators=2) + + # This should now work without raising errors + classifier.fit(X, y) + + # Verify we can predict + probabilities = classifier.predict_proba(X) + predictions = classifier.predict(X) + + # Check output shapes + assert probabilities.shape == (X.shape[0], len(np.unique(y))) + assert predictions.shape == (X.shape[0],) From 9c1b1266cd2f60ebcd6714402b3f02976a0bf589 Mon Sep 17 00:00:00 2001 From: noahho Date: Fri, 28 Feb 2025 14:08:56 +0100 Subject: [PATCH 02/11] Refactor string NA handling in utils.py for improved readability --- src/tabpfn/utils.py | 46 +++++++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/tabpfn/utils.py b/src/tabpfn/utils.py index ef7e18a95..41840d47b 100644 --- a/src/tabpfn/utils.py +++ b/src/tabpfn/utils.py @@ -424,6 +424,30 @@ def load_model_criterion_config( UNSUPPORTED_DTYPE_KINDS = "cM" # Not needed, just for completeness +def _handle_string_na_values(X: pd.DataFrame) -> pd.DataFrame: + """Replace NA values in string columns with a placeholder. + + This avoids mixed type errors in scikit-learn's validation. + """ + string_cols = X.select_dtypes(include=["string", "object"]).columns + if len(string_cols) == 0: + return X + + # Use a placeholder for NaN values in string columns + placeholder = "__MISSING__" + # We need to handle several pandas versions and their different NA handling + for col in string_cols: + # Force to object type first for consistent behavior + X[col] = X[col].astype("object") + # Create mask for all NaN-like values + is_all_na = X[col].isna().all() + is_none = False if is_all_na else X[col].apply(lambda x: x is None) + mask = X[col].isna() | is_none + # Apply mask to set placeholder + X.loc[mask, col] = placeholder + return X + + def _fix_dtypes( X: pd.DataFrame | np.ndarray, cat_indices: Sequence[int | str] | None, @@ -486,26 +510,8 @@ def _fix_dtypes( if convert_dtype: X = X.convert_dtypes() - # Handle NAs in text/string/object columns by replacing with a placeholder - # This avoids mixed type errors in scikit-learn's validation - string_cols = X.select_dtypes(include=["string", "object"]).columns - if len(string_cols) > 0: - # Use a placeholder for NaN values in string columns - placeholder = "__MISSING__" - # We need to handle several pandas versions and their different NA handling - for col in string_cols: - # Force to object type first for consistent behavior - X[col] = X[col].astype("object") - # Create mask for all NaN-like values - is_all_na = X[col].isna().all() - is_none = False if is_all_na else X[col].apply(lambda x: x is None) - mask = X[col].isna() | is_none - # Apply mask to set placeholder - X.loc[mask, col] = placeholder - # After encoding (which happens in the preprocessing pipeline), - # these placeholders will be encoded as a category, but at least - # the data will be processable without errors. - # We'll then handle this consistently with numeric NaNs. + # Handle NAs in text/string/object columns + X = _handle_string_na_values(X) # Convert numeric columns to the specified numeric dtype integer_columns = X.select_dtypes(include=["number"]).columns From 1bbe8f07b6f11197a7137783f3b3bb3dab965cb4 Mon Sep 17 00:00:00 2001 From: noahho Date: Fri, 28 Feb 2025 14:16:23 +0100 Subject: [PATCH 03/11] Format code to comply with ruff styling rules --- src/tabpfn/classifier.py | 8 +++++--- tests/test_classifier_interface.py | 7 ++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index 551b176f9..52491c4b5 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -458,10 +458,12 @@ def fit(self, X: XType, y: YType) -> Self: # By using OrdinalEncoder with unknown_value=float("nan") for all text columns table_vectorizer = TableVectorizer( low_cardinality=OrdinalEncoder( - handle_unknown="use_encoded_value", unknown_value=float("nan"), + handle_unknown="use_encoded_value", + unknown_value=float("nan"), ), high_cardinality=OrdinalEncoder( - handle_unknown="use_encoded_value", unknown_value=float("nan"), + handle_unknown="use_encoded_value", + unknown_value=float("nan"), ), numeric="passthrough", drop_null_fraction=None, # Don't drop columns with NAs @@ -612,4 +614,4 @@ def get_embeddings( Returns: np.ndarray: The computed embeddings for each fitted estimator. """ - return _get_embeddings(self, X, data_source) \ No newline at end of file + return _get_embeddings(self, X, data_source) diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index 899a9117e..6f2d5dd78 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -320,7 +320,12 @@ def test_classifier_with_text_and_na() -> None: # Create test data with text and NA values data = { "text_feature": [ - "good product", "bad service", None, "excellent", "average", None, + "good product", + "bad service", + None, + "excellent", + "average", + None, ], "numeric_feature": [10, 5, 8, 15, 7, 12], "target": [1, 0, 1, 1, 0, 0], From 9007260d2a62c3abead996b5b912c4c6a111eece Mon Sep 17 00:00:00 2001 From: noahho Date: Fri, 28 Feb 2025 14:22:10 +0100 Subject: [PATCH 04/11] Update skrub version to 0.2.0 for compatibility with existing scipy dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4503edee1..36b03941d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "pandas>=1.4.0,<3", "einops>=0.2.0,<0.9", "huggingface-hub>=0.0.1,<1", - "skrub>=0.3.0", + "skrub>=0.2.0,<0.3", ] requires-python = ">=3.9,<3.13" authors = [ From 050e9539c7f5e210fdacab8aeeca2ddeddc9fcf4 Mon Sep 17 00:00:00 2001 From: LeoGrin Date: Mon, 3 Mar 2025 17:46:06 +0000 Subject: [PATCH 05/11] Simplify text and NA handling using only TableVectorizer, remove _handle_string_na_values --- src/tabpfn/classifier.py | 1 - src/tabpfn/utils.py | 27 --------------------------- 2 files changed, 28 deletions(-) diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index 52491c4b5..97289cdf3 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -466,7 +466,6 @@ def fit(self, X: XType, y: YType) -> Self: unknown_value=float("nan"), ), numeric="passthrough", - drop_null_fraction=None, # Don't drop columns with NAs ) X = table_vectorizer.fit_transform(X) diff --git a/src/tabpfn/utils.py b/src/tabpfn/utils.py index 41840d47b..e59b24e8c 100644 --- a/src/tabpfn/utils.py +++ b/src/tabpfn/utils.py @@ -424,30 +424,6 @@ def load_model_criterion_config( UNSUPPORTED_DTYPE_KINDS = "cM" # Not needed, just for completeness -def _handle_string_na_values(X: pd.DataFrame) -> pd.DataFrame: - """Replace NA values in string columns with a placeholder. - - This avoids mixed type errors in scikit-learn's validation. - """ - string_cols = X.select_dtypes(include=["string", "object"]).columns - if len(string_cols) == 0: - return X - - # Use a placeholder for NaN values in string columns - placeholder = "__MISSING__" - # We need to handle several pandas versions and their different NA handling - for col in string_cols: - # Force to object type first for consistent behavior - X[col] = X[col].astype("object") - # Create mask for all NaN-like values - is_all_na = X[col].isna().all() - is_none = False if is_all_na else X[col].apply(lambda x: x is None) - mask = X[col].isna() | is_none - # Apply mask to set placeholder - X.loc[mask, col] = placeholder - return X - - def _fix_dtypes( X: pd.DataFrame | np.ndarray, cat_indices: Sequence[int | str] | None, @@ -510,9 +486,6 @@ def _fix_dtypes( if convert_dtype: X = X.convert_dtypes() - # Handle NAs in text/string/object columns - X = _handle_string_na_values(X) - # Convert numeric columns to the specified numeric dtype integer_columns = X.select_dtypes(include=["number"]).columns if len(integer_columns) > 0: From d94c2d73a64084642dcdd66fc6fe37bf4ab54e9f Mon Sep 17 00:00:00 2001 From: LeoGrin Date: Mon, 3 Mar 2025 17:50:49 +0000 Subject: [PATCH 06/11] Bump scikit-learn minimum version to 1.2.1 for compatibility with skrub --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 36b03941d..f600e18c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "tabpfn" version = "2.0.5" dependencies = [ "torch>=2.1,<3", - "scikit-learn>=1.2.0,<1.7", + "scikit-learn>=1.2.1,<1.7", "typing_extensions>=4.4.0", "scipy>=1.7.3,<2", "pandas>=1.4.0,<3", From a9c739c58ffe3e082ee1c3ff16b3eba91b2b270c Mon Sep 17 00:00:00 2001 From: LeoGrin Date: Mon, 3 Mar 2025 17:54:32 +0000 Subject: [PATCH 07/11] Add test case for column with all NaNs --- tests/test_classifier_interface.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index 6f2d5dd78..c2c575e65 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -316,6 +316,7 @@ def test_classifier_with_text_and_na() -> None: """Test that TabPFNClassifier correctly handles text columns with NA values.""" # Create a DataFrame with text and NA values import pandas as pd + import numpy as np # Create test data with text and NA values data = { @@ -328,6 +329,7 @@ def test_classifier_with_text_and_na() -> None: None, ], "numeric_feature": [10, 5, 8, 15, 7, 12], + "all_na_column": [None, None, None, None, None, None], # Column with all NaNs "target": [1, 0, 1, 1, 0, 0], } @@ -335,10 +337,10 @@ def test_classifier_with_text_and_na() -> None: df = pd.DataFrame(data) # Split into X and y - X = df[["text_feature", "numeric_feature"]] + X = df[["text_feature", "numeric_feature", "all_na_column"]] y = df["target"] - # Initialize and fit TabPFN on data with text+NA + # Initialize and fit TabPFN on data with text+NA and a column with all NAs classifier = TabPFNClassifier(device="cpu", n_estimators=2) # This should now work without raising errors From 3910535b2cce495bce063fde28af7ec6e92c7fae Mon Sep 17 00:00:00 2001 From: LeoGrin Date: Mon, 3 Mar 2025 17:55:37 +0000 Subject: [PATCH 08/11] Bump pandas minimum version to 1.5.3 for compatibility with skrub --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f600e18c6..4ad5be86c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dependencies = [ "scikit-learn>=1.2.1,<1.7", "typing_extensions>=4.4.0", "scipy>=1.7.3,<2", - "pandas>=1.4.0,<3", + "pandas>=1.5.3,<3", "einops>=0.2.0,<0.9", "huggingface-hub>=0.0.1,<1", "skrub>=0.2.0,<0.3", From 1c4f1aeb1565524e372490da234acb480fe19d2f Mon Sep 17 00:00:00 2001 From: LeoGrin <45738728+LeoGrin@users.noreply.github.com> Date: Mon, 3 Mar 2025 18:05:42 +0000 Subject: [PATCH 09/11] Fix merge --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ce0caf302..e987a48ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,10 +9,8 @@ dependencies = [ "torch>=2.1,<3", "scikit-learn>=1.2.1,<1.7", "typing_extensions>=4.4.0", - "scipy>=1.7.3,<2", "pandas>=1.5.3,<3", "scipy>=1.11.1,<2", - "pandas>=1.4.0,<3", "einops>=0.2.0,<0.9", "huggingface-hub>=0.0.1,<1", "skrub>=0.2.0,<0.3", From 3adc9395c6f6a1748107f7355c5425ffe82ad871 Mon Sep 17 00:00:00 2001 From: LeoGrin <45738728+LeoGrin@users.noreply.github.com> Date: Mon, 3 Mar 2025 18:11:43 +0000 Subject: [PATCH 10/11] update max skrub --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e987a48ea..86a00e1b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "scipy>=1.11.1,<2", "einops>=0.2.0,<0.9", "huggingface-hub>=0.0.1,<1", - "skrub>=0.2.0,<0.3", + "skrub>=0.2.0,<0.6", ] requires-python = ">=3.9,<3.13" authors = [ From 7fb2935b0e4581212176577d24de26979ce1c270 Mon Sep 17 00:00:00 2001 From: LeoGrin Date: Mon, 3 Mar 2025 18:21:01 +0000 Subject: [PATCH 11/11] fix ruff? --- tests/test_classifier_interface.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index c2c575e65..b1b0c1dbb 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -6,6 +6,7 @@ from typing import Callable, Literal import numpy as np +import pandas as pd import pytest import sklearn.datasets import torch @@ -315,9 +316,6 @@ def test_get_embeddings(X_y: tuple[np.ndarray, np.ndarray], data_source: str) -> def test_classifier_with_text_and_na() -> None: """Test that TabPFNClassifier correctly handles text columns with NA values.""" # Create a DataFrame with text and NA values - import pandas as pd - import numpy as np - # Create test data with text and NA values data = { "text_feature": [