diff --git a/README.md b/README.md index 3c973226..de59be7a 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,27 @@ If you use the 4DBInfer datasets in your work, please cite [4DBInfer](https://gi ``` +**Using TabArena datasets** + +RelBench includes an optional integration for TabArena: a collection of single-table OpenML tasks. TabArena datasets are exposed under names like `tabarena-credit-g`, with per-split tasks named `split-0`, `split-1`, etc. + +To use TabArena datasets, install the optional dependency: +```bash +pip install relbench[tabarena] +``` + +See [`examples/translate_tabarena_to_relbench.py`](examples/translate_tabarena_to_relbench.py) for a comparison between the original OpenML task and the relbenchified `records` and `split-*` tables, and [`examples/validate_tabarena_baseline.py`](examples/validate_tabarena_baseline.py) for a public-baseline validation script that checks inference consistency between the original and relbenchified views. + +If you use the TabArena datasets in your work, please cite TabArena as below: +``` +@inproceedings{erickson2025tabarena, + title={TabArena: A Living Benchmark for Machine Learning on Tabular Data}, + author={Erickson, Nick and Purucker, Lennart and Tschalzev, Andrej and Holzm{\"u}ller, David and Mutalik Desai, Prateek and Salinas, David and Hutter, Frank}, + booktitle={Advances in Neural Information Processing Systems}, + year={2025} +} +``` + # Package Usage diff --git a/examples/translate_tabarena_to_relbench.py b/examples/translate_tabarena_to_relbench.py new file mode 100644 index 00000000..420447d5 --- /dev/null +++ b/examples/translate_tabarena_to_relbench.py @@ -0,0 +1,151 @@ +"""Inspect how a TabArena OpenML task is represented in RelBench. + +This script compares the original OpenML dataset/task with the RelBench wrapper: +the source rows become a single ``records`` table, while each ``split-*`` task +materializes a thin task table keyed by ``record_id``. 
+""" + +from __future__ import annotations + +import argparse +from dataclasses import asdict + +import numpy as np +import pandas as pd + +from relbench.datasets.tabarena import TabArenaDataset, get_tabarena_dataset_slugs +from relbench.tasks.tabarena import TabArenaSplitEntityTask + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--dataset", + type=str, + default="credit-g", + choices=get_tabarena_dataset_slugs(), + help="TabArena dataset slug, for example `credit-g` or `airfoil-self-noise`.", + ) + parser.add_argument( + "--split", + type=int, + default=0, + help="OpenML split index exposed in RelBench as `split-`.", + ) + parser.add_argument( + "--show_rows", + type=int, + default=3, + help="Number of joined examples to print from the RelBench train split.", + ) + return parser.parse_args() + + +def _load_openml_frame(dataset: TabArenaDataset) -> tuple[pd.DataFrame, pd.Series]: + task = dataset.get_openml_task() + X_df, y_ser, _cat, _names = task.get_dataset().get_data( + target=task.target_name, + dataset_format="dataframe", + ) + X_df = pd.DataFrame(X_df).reset_index(drop=True) + y_ser = pd.Series(y_ser, name=task.target_name).reset_index(drop=True) + return X_df, y_ser + + +def _join_records(records_df: pd.DataFrame, task_df: pd.DataFrame) -> pd.DataFrame: + joined = task_df.merge(records_df, on="record_id", how="left", validate="1:1") + feature_cols = [col for col in joined.columns if col not in {"record_id", "target"}] + return joined[["record_id", *feature_cols, "target"]] + + +def _check_translation( + dataset: TabArenaDataset, + task: TabArenaSplitEntityTask, +) -> dict[str, object]: + X_df, _y_ser = _load_openml_frame(dataset) + y_encoded = pd.Series(dataset.get_target_array(), name="target").reset_index( + drop=True + ) + records_df = dataset.get_db().table_dict["records"].df.reset_index(drop=True) + + openml_train_idx, openml_test_idx = 
dataset.get_openml_split_indices(task.split) + train_table = task.get_table("train", mask_input_cols=False) + val_table = task.get_table("val", mask_input_cols=False) + test_table = task.get_table("test", mask_input_cols=False) + + relbench_train_ids = train_table.df["record_id"].to_numpy() + relbench_val_ids = val_table.df["record_id"].to_numpy() + relbench_test_ids = test_table.df["record_id"].to_numpy() + relbench_trainval_ids = set(relbench_train_ids).union(relbench_val_ids) + + target_matches = True + for split_name, split_df in { + "train": train_table.df, + "val": val_table.df, + "test": test_table.df, + }.items(): + relbench_target = split_df["target"].reset_index(drop=True) + source_target = y_encoded.iloc[split_df["record_id"].to_numpy()].reset_index( + drop=True + ) + if not relbench_target.equals(source_target): + target_matches = False + print(f"[check] target mismatch on {split_name}") + + return { + "dataset_slug": dataset.spec.slug, + "dataset_name": dataset.name, + "tabarena_name": dataset.tabarena_name, + "problem_type": dataset.problem_type, + "openml_task_id": dataset.task_id, + "openml_dataset_id": dataset.openml_dataset_id, + "target_name": dataset.target_name, + "records_rows": len(records_df), + "records_feature_columns": len(records_df.columns) - 1, + "openml_rows": len(X_df), + "openml_train_rows": len(openml_train_idx), + "openml_test_rows": len(openml_test_idx), + "relbench_train_rows": len(train_table.df), + "relbench_val_rows": len(val_table.df), + "relbench_test_rows": len(test_table.df), + "records_match_openml_rows": len(records_df) == len(X_df), + "record_ids_are_row_indices": np.array_equal( + records_df["record_id"].to_numpy(), + np.arange(len(records_df), dtype=np.int64), + ), + "relbench_test_matches_openml_test": set(relbench_test_ids) + == set(openml_test_idx), + "relbench_train_val_partition_openml_train": relbench_trainval_ids + == set(openml_train_idx), + "relbench_train_val_are_disjoint": 
set(relbench_train_ids).isdisjoint( + relbench_val_ids + ), + "targets_match_source_rows": target_matches, + } + + +def main() -> None: + args = parse_args() + + dataset = TabArenaDataset(dataset_slug=args.dataset) + task = TabArenaSplitEntityTask(dataset, split=args.split) + records_df = dataset.get_db().table_dict["records"].df.reset_index(drop=True) + + print("[dataset spec]") + print(asdict(dataset.spec)) + print() + + summary = _check_translation(dataset, task) + print("[translation summary]") + for key, value in summary.items(): + print(f"{key}: {value}") + print() + + train_table = task.get_table("train", mask_input_cols=False) + joined_train = _join_records(records_df, train_table.df).head(args.show_rows) + print("[joined relbench train rows]") + print(joined_train.to_string(index=False)) + + +if __name__ == "__main__": + main() diff --git a/examples/validate_tabarena_baseline.py b/examples/validate_tabarena_baseline.py new file mode 100644 index 00000000..6a1828f3 --- /dev/null +++ b/examples/validate_tabarena_baseline.py @@ -0,0 +1,281 @@ +"""Validate the TabArena RelBench wrapper with a public baseline. + +The script trains a baseline on the original OpenML rows selected by a RelBench task +split, then runs inference on both the original and relbenchified views of the same +validation and test rows. Matching predictions and metrics confirm that the RelBench +wrapper preserves the original single-table task semantics. 
+""" + +from __future__ import annotations + +import argparse + +import numpy as np +import pandas as pd +from pandas import CategoricalDtype +from sklearn.compose import ColumnTransformer +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.impute import SimpleImputer +from sklearn.metrics import mean_squared_error, roc_auc_score +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import OneHotEncoder + +from relbench.datasets.tabarena import TabArenaDataset, get_tabarena_dataset_slugs +from relbench.tasks.tabarena import TabArenaSplitEntityTask + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--dataset", + type=str, + default="credit-g", + choices=get_tabarena_dataset_slugs(), + help="TabArena dataset slug, for example `credit-g` or `airfoil-self-noise`.", + ) + parser.add_argument( + "--split", + type=int, + default=0, + help="OpenML split index exposed in RelBench as `split-`.", + ) + parser.add_argument( + "--model", + type=str, + default="random-forest", + choices=["random-forest", "xgboost"], + help="Public baseline model used for validation.", + ) + parser.add_argument( + "--random_state", + type=int, + default=0, + ) + return parser.parse_args() + + +def _load_openml_frame(dataset: TabArenaDataset) -> pd.DataFrame: + task = dataset.get_openml_task() + X_df, y_ser, _cat, _names = task.get_dataset().get_data( + target=task.target_name, + dataset_format="dataframe", + ) + X_df = pd.DataFrame(X_df).reset_index(drop=True) + y_ser = pd.Series(y_ser, name=task.target_name).reset_index(drop=True) + _ = y_ser + return X_df + + +def _join_records( + records_df: pd.DataFrame, task_df: pd.DataFrame +) -> tuple[pd.DataFrame, np.ndarray]: + joined = task_df.merge(records_df, on="record_id", how="left", validate="1:1") + X_df = joined.drop(columns=["record_id", "target"]) + y = joined["target"].to_numpy() + return X_df, y + + +def 
_build_preprocessor(X_df: pd.DataFrame) -> ColumnTransformer: + categorical_cols = [ + col + for col in X_df.columns + if pd.api.types.is_object_dtype(X_df[col]) + or isinstance(X_df[col].dtype, CategoricalDtype) + or pd.api.types.is_bool_dtype(X_df[col]) + ] + numeric_cols = [col for col in X_df.columns if col not in categorical_cols] + + return ColumnTransformer( + transformers=[ + ( + "numeric", + Pipeline([("imputer", SimpleImputer(strategy="median"))]), + numeric_cols, + ), + ( + "categorical", + Pipeline( + [ + ("imputer", SimpleImputer(strategy="most_frequent")), + ( + "encoder", + OneHotEncoder(handle_unknown="ignore", sparse_output=True), + ), + ] + ), + categorical_cols, + ), + ] + ) + + +def _build_model( + *, + task: TabArenaSplitEntityTask, + X_df: pd.DataFrame, + model_name: str, + random_state: int, +) -> Pipeline: + preprocessor = _build_preprocessor(X_df) + + if model_name == "random-forest": + if task.task_type.value == "regression": + estimator = RandomForestRegressor( + n_estimators=200, + random_state=random_state, + n_jobs=-1, + ) + else: + estimator = RandomForestClassifier( + n_estimators=200, + random_state=random_state, + n_jobs=-1, + ) + elif model_name == "xgboost": + try: + from xgboost import XGBClassifier, XGBRegressor + except ImportError as exc: # pragma: no cover - dependency is optional + raise ImportError( + "The `xgboost` example requires `pip install xgboost`." 
+ ) from exc + + if task.task_type.value == "regression": + estimator = XGBRegressor( + n_estimators=300, + max_depth=6, + learning_rate=0.05, + subsample=0.8, + colsample_bytree=0.8, + tree_method="hist", + random_state=random_state, + ) + else: + estimator = XGBClassifier( + n_estimators=300, + max_depth=6, + learning_rate=0.05, + subsample=0.8, + colsample_bytree=0.8, + tree_method="hist", + eval_metric="logloss", + random_state=random_state, + ) + else: # pragma: no cover - guarded by argparse + raise ValueError(f"Unsupported model {model_name!r}") + + return Pipeline( + [ + ("preprocess", preprocessor), + ("model", estimator), + ] + ) + + +def _predict( + model: Pipeline, + task: TabArenaSplitEntityTask, + X_df: pd.DataFrame, +) -> np.ndarray: + if task.task_type.value == "regression": + return np.asarray(model.predict(X_df), dtype=np.float64) + proba = model.predict_proba(X_df) + if proba.ndim != 2 or proba.shape[1] != 2: + raise RuntimeError( + f"Expected binary predict_proba output, got shape={proba.shape}" + ) + return np.asarray(proba[:, 1], dtype=np.float64) + + +def _metric_name_and_value( + task: TabArenaSplitEntityTask, y_true: np.ndarray, pred: np.ndarray +) -> tuple[str, float]: + if task.task_type.value == "regression": + return "rmse", float(np.sqrt(mean_squared_error(y_true, pred))) + return "auroc", float(roc_auc_score(y_true, pred)) + + +def _max_prediction_delta(a: np.ndarray, b: np.ndarray) -> float: + a = np.asarray(a, dtype=np.float64).reshape(-1) + b = np.asarray(b, dtype=np.float64).reshape(-1) + return float(np.max(np.abs(a - b))) if len(a) else 0.0 + + +def main() -> None: + args = parse_args() + + dataset = TabArenaDataset(dataset_slug=args.dataset) + task = TabArenaSplitEntityTask(dataset, split=args.split) + records_df = dataset.get_db().table_dict["records"].df.reset_index(drop=True) + openml_X = _load_openml_frame(dataset) + openml_y = pd.Series(dataset.get_target_array(), name="target").reset_index( + drop=True + ) + + 
def main() -> None:
    """Fit a public baseline and compare original vs relbenchified inference."""
    args = parse_args()

    dataset = TabArenaDataset(dataset_slug=args.dataset)
    task = TabArenaSplitEntityTask(dataset, split=args.split)
    records_df = dataset.get_db().table_dict["records"].df.reset_index(drop=True)
    openml_X = _load_openml_frame(dataset)
    openml_y = pd.Series(dataset.get_target_array(), name="target").reset_index(
        drop=True
    )

    def original_view(split_df: pd.DataFrame) -> tuple[pd.DataFrame, np.ndarray]:
        # Select the same rows from the untouched OpenML frame by record_id.
        row_ids = split_df["record_id"].to_numpy()
        return (
            openml_X.iloc[row_ids].reset_index(drop=True),
            openml_y.iloc[row_ids].to_numpy(),
        )

    relbench_train = task.get_table("train", mask_input_cols=False).df
    relbench_val = task.get_table("val", mask_input_cols=False).df
    relbench_test = task.get_table("test", mask_input_cols=False).df

    X_train_rb, y_train_rb = _join_records(records_df, relbench_train)
    X_val_rb, y_val_rb = _join_records(records_df, relbench_val)
    X_test_rb, y_test_rb = _join_records(records_df, relbench_test)

    X_train_orig, y_train_orig = original_view(relbench_train)
    X_val_orig, y_val_orig = original_view(relbench_val)
    X_test_orig, y_test_orig = original_view(relbench_test)

    print("[data equality checks]")
    print(f"train features identical: {X_train_orig.equals(X_train_rb)}")
    print(f"val features identical: {X_val_orig.equals(X_val_rb)}")
    print(f"test features identical: {X_test_orig.equals(X_test_rb)}")
    print(f"train labels identical: {np.array_equal(y_train_orig, y_train_rb)}")
    print(f"val labels identical: {np.array_equal(y_val_orig, y_val_rb)}")
    print(f"test labels identical: {np.array_equal(y_test_orig, y_test_rb)}")
    print()

    # Train on the ORIGINAL view only; inference is then run on both views.
    model = _build_model(
        task=task,
        X_df=X_train_orig,
        model_name=args.model,
        random_state=args.random_state,
    )
    model.fit(X_train_orig, y_train_orig)

    val_pred_orig = _predict(model, task, X_val_orig)
    val_pred_rb = _predict(model, task, X_val_rb)
    test_pred_orig = _predict(model, task, X_test_orig)
    test_pred_rb = _predict(model, task, X_test_rb)

    metric_name, val_metric_orig = _metric_name_and_value(
        task, y_val_orig, val_pred_orig
    )
    _, val_metric_rb = _metric_name_and_value(task, y_val_rb, val_pred_rb)
    _, test_metric_orig = _metric_name_and_value(task, y_test_orig, test_pred_orig)
    _, test_metric_rb = _metric_name_and_value(task, y_test_rb, test_pred_rb)

    print("[prediction consistency]")
    print(
        f"val max |orig - relbench|: {_max_prediction_delta(val_pred_orig, val_pred_rb):.6g}"
    )
    print(
        f"test max |orig - relbench|: {_max_prediction_delta(test_pred_orig, test_pred_rb):.6g}"
    )
    print()

    print(f"[{metric_name}]")
    print(f"validation original-openml: {val_metric_orig:.6f}")
    print(f"validation relbenchified: {val_metric_rb:.6f}")
    print(f"test original-openml: {test_metric_orig:.6f}")
    print(f"test relbenchified: {test_metric_rb:.6f}")


if __name__ == "__main__":
    main()
+ ) + return + if name == "rel-mimic": from relbench.datasets.mimic import verify_mimic_access @@ -172,3 +180,10 @@ def get_dataset(name: str, download=True) -> Dataset: register_dataset("tgbn-genre", tgb.TGBDataset, tgb_name="tgbn-genre") register_dataset("tgbn-reddit", tgb.TGBDataset, tgb_name="tgbn-reddit") register_dataset("tgbn-token", tgb.TGBDataset, tgb_name="tgbn-token") + +for dataset_slug in tabarena.get_tabarena_dataset_slugs(): + register_dataset( + f"tabarena-{dataset_slug}", + tabarena.TabArenaDataset, + dataset_slug=dataset_slug, + ) diff --git a/relbench/datasets/tabarena.py b/relbench/datasets/tabarena.py new file mode 100644 index 00000000..0f4157f8 --- /dev/null +++ b/relbench/datasets/tabarena.py @@ -0,0 +1,719 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import pandas as pd + +from relbench.base import Database, Dataset, Table + + +@dataclass(frozen=True) +class TabArenaDatasetSpec: + slug: str + name: str + task_id: int + dataset_id: int + target: str + task_type: str + num_classes: int + fold_count: int + + +TABARENA_DATASETS: dict[str, TabArenaDatasetSpec] = { + "airfoil-self-noise": TabArenaDatasetSpec( + slug="airfoil-self-noise", + name="airfoil_self_noise", + task_id=363612, + dataset_id=46904, + target="scaled-sound-pressure", + task_type="Supervised Regression", + num_classes=0, + fold_count=30, + ), + "amazon-employee-access": TabArenaDatasetSpec( + slug="amazon-employee-access", + name="Amazon_employee_access", + task_id=363613, + dataset_id=46905, + target="ResourceApproved", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "anneal": TabArenaDatasetSpec( + slug="anneal", + name="anneal", + task_id=363614, + dataset_id=46906, + target="classes", + task_type="Supervised Classification", + num_classes=5, + fold_count=30, + ), + "another-dataset-on-used-fiat-500": TabArenaDatasetSpec( + 
slug="another-dataset-on-used-fiat-500", + name="Another-Dataset-on-used-Fiat-500", + task_id=363615, + dataset_id=46907, + target="price", + task_type="Supervised Regression", + num_classes=0, + fold_count=30, + ), + "apsfailure": TabArenaDatasetSpec( + slug="apsfailure", + name="APSFailure", + task_id=363616, + dataset_id=46908, + target="AirPressureSystemFailure", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "bank-customer-churn": TabArenaDatasetSpec( + slug="bank-customer-churn", + name="Bank_Customer_Churn", + task_id=363619, + dataset_id=46911, + target="churn", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "bank-marketing": TabArenaDatasetSpec( + slug="bank-marketing", + name="bank-marketing", + task_id=363618, + dataset_id=46910, + target="SubscribeTermDeposit", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "bioresponse": TabArenaDatasetSpec( + slug="bioresponse", + name="Bioresponse", + task_id=363620, + dataset_id=46912, + target="MoleculeElicitsResponse", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "blood-transfusion-service-center": TabArenaDatasetSpec( + slug="blood-transfusion-service-center", + name="blood-transfusion-service-center", + task_id=363621, + dataset_id=46913, + target="DonatedBloodInMarch2007", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "churn": TabArenaDatasetSpec( + slug="churn", + name="churn", + task_id=363623, + dataset_id=46915, + target="CustomerChurned", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "coil2000-insurance-policies": TabArenaDatasetSpec( + slug="coil2000-insurance-policies", + name="coil2000_insurance_policies", + task_id=363624, + dataset_id=46916, + target="MobileHomePolicy", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "concrete-compressive-strength": 
TabArenaDatasetSpec( + slug="concrete-compressive-strength", + name="concrete_compressive_strength", + task_id=363625, + dataset_id=46917, + target="ConcreteCompressiveStrength", + task_type="Supervised Regression", + num_classes=0, + fold_count=30, + ), + "credit-card-clients-default": TabArenaDatasetSpec( + slug="credit-card-clients-default", + name="credit_card_clients_default", + task_id=363627, + dataset_id=46919, + target="DefaultOnPaymentNextMonth", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "credit-g": TabArenaDatasetSpec( + slug="credit-g", + name="credit-g", + task_id=363626, + dataset_id=46918, + target="good_or_bad_customer", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "customer-satisfaction-in-airline": TabArenaDatasetSpec( + slug="customer-satisfaction-in-airline", + name="customer_satisfaction_in_airline", + task_id=363628, + dataset_id=46920, + target="satisfaction", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "diabetes": TabArenaDatasetSpec( + slug="diabetes", + name="diabetes", + task_id=363629, + dataset_id=46921, + target="TestedPositiveForDiabetes", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "diabetes130us": TabArenaDatasetSpec( + slug="diabetes130us", + name="Diabetes130US", + task_id=363630, + dataset_id=46922, + target="EarlyReadmission", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "diamonds": TabArenaDatasetSpec( + slug="diamonds", + name="diamonds", + task_id=363631, + dataset_id=46923, + target="price", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "e-commereshippingdata": TabArenaDatasetSpec( + slug="e-commereshippingdata", + name="E-CommereShippingData", + task_id=363632, + dataset_id=46924, + target="ArrivedLate", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "fitness-club": 
TabArenaDatasetSpec( + slug="fitness-club", + name="Fitness_Club", + task_id=363671, + dataset_id=46927, + target="attended", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "food-delivery-time": TabArenaDatasetSpec( + slug="food-delivery-time", + name="Food_Delivery_Time", + task_id=363672, + dataset_id=46928, + target="Time_taken(min)", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "givemesomecredit": TabArenaDatasetSpec( + slug="givemesomecredit", + name="GiveMeSomeCredit", + task_id=363673, + dataset_id=46929, + target="FinancialDistressNextTwoYears", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "hazelnut-spread-contaminant-detection": TabArenaDatasetSpec( + slug="hazelnut-spread-contaminant-detection", + name="hazelnut-spread-contaminant-detection", + task_id=363674, + dataset_id=46930, + target="Contaminated", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "healthcare-insurance-expenses": TabArenaDatasetSpec( + slug="healthcare-insurance-expenses", + name="healthcare_insurance_expenses", + task_id=363675, + dataset_id=46931, + target="charges", + task_type="Supervised Regression", + num_classes=0, + fold_count=30, + ), + "heloc": TabArenaDatasetSpec( + slug="heloc", + name="heloc", + task_id=363676, + dataset_id=46932, + target="RiskPerformance", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "hiva-agnostic": TabArenaDatasetSpec( + slug="hiva-agnostic", + name="hiva_agnostic", + task_id=363677, + dataset_id=46933, + target="CompoundActivity", + task_type="Supervised Classification", + num_classes=3, + fold_count=9, + ), + "houses": TabArenaDatasetSpec( + slug="houses", + name="houses", + task_id=363678, + dataset_id=46934, + target="LnMedianHouseValue", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "hr-analytics-job-change-of-data-scientists": 
TabArenaDatasetSpec( + slug="hr-analytics-job-change-of-data-scientists", + name="HR_Analytics_Job_Change_of_Data_Scientists", + task_id=363679, + dataset_id=46935, + target="LookingForJobChange", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "in-vehicle-coupon-recommendation": TabArenaDatasetSpec( + slug="in-vehicle-coupon-recommendation", + name="in_vehicle_coupon_recommendation", + task_id=363681, + dataset_id=46937, + target="AcceptCoupon", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "is-this-a-good-customer": TabArenaDatasetSpec( + slug="is-this-a-good-customer", + name="Is-this-a-good-customer", + task_id=363682, + dataset_id=46938, + target="bad_client_target", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "jm1": TabArenaDatasetSpec( + slug="jm1", + name="jm1", + task_id=363712, + dataset_id=46979, + target="defects", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "kddcup09-appetency": TabArenaDatasetSpec( + slug="kddcup09-appetency", + name="kddcup09_appetency", + task_id=363683, + dataset_id=46939, + target="appetency", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "marketing-campaign": TabArenaDatasetSpec( + slug="marketing-campaign", + name="Marketing_Campaign", + task_id=363684, + dataset_id=46940, + target="Response", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "maternal-health-risk": TabArenaDatasetSpec( + slug="maternal-health-risk", + name="maternal_health_risk", + task_id=363685, + dataset_id=46941, + target="RiskLevel", + task_type="Supervised Classification", + num_classes=3, + fold_count=30, + ), + "miami-housing": TabArenaDatasetSpec( + slug="miami-housing", + name="miami_housing", + task_id=363686, + dataset_id=46942, + target="SALE_PRC", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "mic": 
TabArenaDatasetSpec( + slug="mic", + name="MIC", + task_id=363711, + dataset_id=46980, + target="LET_IS", + task_type="Supervised Classification", + num_classes=8, + fold_count=30, + ), + "naticusdroid": TabArenaDatasetSpec( + slug="naticusdroid", + name="NATICUSdroid", + task_id=363689, + dataset_id=46969, + target="Malware", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "online-shoppers-intention": TabArenaDatasetSpec( + slug="online-shoppers-intention", + name="online_shoppers_intention", + task_id=363691, + dataset_id=46947, + target="Revenue", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "physiochemical-protein": TabArenaDatasetSpec( + slug="physiochemical-protein", + name="physiochemical_protein", + task_id=363693, + dataset_id=46949, + target="ResidualSize", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "polish-companies-bankruptcy": TabArenaDatasetSpec( + slug="polish-companies-bankruptcy", + name="polish_companies_bankruptcy", + task_id=363694, + dataset_id=46950, + target="company_bankrupt", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "qsar-biodeg": TabArenaDatasetSpec( + slug="qsar-biodeg", + name="qsar-biodeg", + task_id=363696, + dataset_id=46952, + target="Biodegradable", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "qsar-fish-toxicity": TabArenaDatasetSpec( + slug="qsar-fish-toxicity", + name="QSAR_fish_toxicity", + task_id=363698, + dataset_id=46954, + target="LC50", + task_type="Supervised Regression", + num_classes=0, + fold_count=30, + ), + "qsar-tid-11": TabArenaDatasetSpec( + slug="qsar-tid-11", + name="QSAR-TID-11", + task_id=363697, + dataset_id=46953, + target="MEDIAN_PXC50", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "sdss17": TabArenaDatasetSpec( + slug="sdss17", + name="SDSS17", + task_id=363699, + dataset_id=46955, + 
target="ObjectType", + task_type="Supervised Classification", + num_classes=3, + fold_count=9, + ), + "seismic-bumps": TabArenaDatasetSpec( + slug="seismic-bumps", + name="seismic-bumps", + task_id=363700, + dataset_id=46956, + target="HighEnergySeismicBump", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "splice": TabArenaDatasetSpec( + slug="splice", + name="splice", + task_id=363702, + dataset_id=46958, + target="SiteType", + task_type="Supervised Classification", + num_classes=3, + fold_count=9, + ), + "students-dropout-and-academic-success": TabArenaDatasetSpec( + slug="students-dropout-and-academic-success", + name="students_dropout_and_academic_success", + task_id=363704, + dataset_id=46960, + target="AcademicOutcome", + task_type="Supervised Classification", + num_classes=3, + fold_count=9, + ), + "superconductivity": TabArenaDatasetSpec( + slug="superconductivity", + name="superconductivity", + task_id=363705, + dataset_id=46961, + target="critical_temp", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "taiwanese-bankruptcy-prediction": TabArenaDatasetSpec( + slug="taiwanese-bankruptcy-prediction", + name="taiwanese_bankruptcy_prediction", + task_id=363706, + dataset_id=46962, + target="Bankrupt", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "website-phishing": TabArenaDatasetSpec( + slug="website-phishing", + name="website_phishing", + task_id=363707, + dataset_id=46963, + target="WebsiteType", + task_type="Supervised Classification", + num_classes=3, + fold_count=30, + ), + "wine-quality": TabArenaDatasetSpec( + slug="wine-quality", + name="wine_quality", + task_id=363708, + dataset_id=46964, + target="median_wine_quality", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), +} + + +def get_tabarena_dataset_slugs() -> list[str]: + return sorted(TABARENA_DATASETS.keys()) + + +def _import_openml(): + try: + import openml + except 
    # NOTE: this continues the `try: import openml / except ` opened on the
    # previous line of _import_openml.
    ImportError as exc:
        raise ImportError(
            "TabArena datasets require the `openml` package. Install it with "
            "`pip install relbench[tabarena]` (or `pip install openml`)."
        ) from exc
    return openml


def _problem_type_from_spec(task_type: str, num_classes: int) -> str:
    """Map a TabArena spec to one of "regression", "binary", "multiclass".

    Classification with at most two classes is treated as binary.
    """
    if task_type == "Supervised Regression":
        return "regression"
    if int(num_classes) <= 2:
        return "binary"
    return "multiclass"


class TabArenaDataset(Dataset):
    r"""Single-table RelBench dataset wrapper over TabArena OpenML tasks."""

    # Pointer to TabArena's published benchmark results; not read by this
    # module's code — informational only.
    url = "https://huggingface.co/datasets/TabArena/benchmark_results"
    # The wrapped data is non-temporal; these fixed placeholder timestamps
    # presumably satisfy the Dataset interface — TODO confirm against base class.
    val_timestamp = pd.Timestamp("2000-01-02")
    test_timestamp = pd.Timestamp("2000-01-03")

    def __init__(self, *, dataset_slug: str, cache_dir: Optional[str] = None):
        """Resolve *dataset_slug* against the TABARENA_DATASETS catalog.

        Raises:
            ValueError: if the slug is not in the catalog.
        """
        slug = str(dataset_slug)
        if slug not in TABARENA_DATASETS:
            raise ValueError(
                f"Unknown TabArena dataset slug={slug!r}. Known values: {sorted(TABARENA_DATASETS.keys())}"
            )

        self.spec = TABARENA_DATASETS[slug]
        self.name = f"tabarena-{slug}"
        self.tabarena_name = self.spec.name
        self.task_id = int(self.spec.task_id)
        self.openml_dataset_id = int(self.spec.dataset_id)
        self.target_name = self.spec.target
        self.problem_type = _problem_type_from_spec(
            self.spec.task_type, self.spec.num_classes
        )
        # Regression specs carry num_classes=0; force 0 for regression so the
        # attribute is meaningful only for classification.
        self.num_classes = (
            int(self.spec.num_classes) if self.problem_type != "regression" else 0
        )

        # Lazily populated caches: OpenML task handle, feature frame, and the
        # numerically encoded target (see _ensure_openml_loaded).
        self._openml_task = None
        self._X_df: Optional[pd.DataFrame] = None
        self._y_encoded: Optional[np.ndarray] = None

        super().__init__(cache_dir=cache_dir)

    @property
    def available_splits(self) -> list[int]:
        # One RelBench "split-i" task per OpenML fold/repeat combination.
        return list(range(int(self.spec.fold_count)))

    @property
    def available_folds(self) -> list[int]:
        # Backward-compatible alias.
        return self.available_splits

    def _load_task_with_retry(self, task_id: int, retries: int = 4):
        """Fetch an OpenML task, retrying with exponential backoff (1s, 2s, ...).

        NOTE(review): the retry loop catches broad ``Exception`` — presumably
        to cover transient network/server errors; confirm it should not be
        narrowed to OpenML/connection errors.
        """
        openml = _import_openml()
        delay = 1.0
        for attempt in range(retries + 1):
            try:
                return openml.tasks.get_task(
                    task_id,
                    download_splits=True,
                    download_data=True,
                    download_qualities=False,
                    download_features_meta_data=True,
                )
            except Exception:
                # Final attempt: propagate the error instead of sleeping again.
                if attempt == retries:
                    raise
                time.sleep(delay)
                delay *= 2.0

    def _ensure_openml_loaded(self) -> None:
        """Lazily download and cache the task, feature frame, and encoded target."""
        # Fast path: everything already cached.
        if (
            self._openml_task is not None
            and self._X_df is not None
            and self._y_encoded is not None
        ):
            return

        task = self._load_task_with_retry(self.task_id)
        X_df, y_ser, _cat, _names = task.get_dataset().get_data(
            target=task.target_name,
            dataset_format="dataframe",
        )

        X_df = pd.DataFrame(X_df).reset_index(drop=True)
        y_ser = pd.Series(y_ser, name=task.target_name).reset_index(drop=True)

        if self.problem_type == "regression":
            y_encoded = y_ser.astype(float).to_numpy(copy=True)
        else:
            # Categorical encode labels; code -1 marks a missing label.
            cat = pd.Categorical(y_ser)
            if cat.codes.min() < 0:
                raise RuntimeError(
                    f"Encountered missing labels in OpenML task_id={self.task_id} ({self.tabarena_name})."
                )
            y_encoded = cat.codes.astype(np.int64, copy=False)
            # Sanity-check (and then adopt) the observed label cardinality.
            detected_num_classes = int(len(cat.categories))
            if self.num_classes and detected_num_classes != self.num_classes:
                raise RuntimeError(
                    f"Label cardinality mismatch for {self.tabarena_name}: expected {self.num_classes}, got {detected_num_classes}."
                )
            self.num_classes = detected_num_classes

        if len(X_df) != len(y_encoded):
            raise RuntimeError(
                f"Feature/label row mismatch for {self.tabarena_name}: {len(X_df)} vs {len(y_encoded)}."
            )

        # Publish the caches only after all validation passed.
        self._openml_task = task
        self._X_df = X_df
        self._y_encoded = y_encoded

    def get_openml_task(self):
        """Return the cached OpenML task handle, loading it on first use."""
        self._ensure_openml_loaded()
        return self._openml_task

    def get_target_array(self) -> np.ndarray:
        """Return the numerically encoded target for all rows."""
        self._ensure_openml_loaded()
        assert self._y_encoded is not None
        return self._y_encoded

    def get_openml_split_indices(self, split: int) -> tuple[np.ndarray, np.ndarray]:
        """Return (train_idx, test_idx) row indices for flat split index *split*.

        The flat index maps onto OpenML's (repeat, fold) grid as
        ``split = repeat * n_folds + fold``.

        Raises:
            ValueError: if *split* is outside ``0..fold_count - 1`` or exceeds
                the dimensions reported by OpenML.
        """
        split = int(split)
        if split < 0 or split >= int(self.spec.fold_count):
            raise ValueError(
                f"Invalid split={split} for {self.name}. Valid splits are 0..{int(self.spec.fold_count) - 1}."
            )

        task = self.get_openml_task()
        n_repeats, n_folds, _n_samples = task.get_split_dimensions()
        repeat = split // int(n_folds)
        fold_in_repeat = split % int(n_folds)
        if repeat >= int(n_repeats):
            raise ValueError(
                f"Split index {split} exceeds OpenML split dimensions for {self.name}: repeats={n_repeats}, folds={n_folds}."
            )

        train_idx, test_idx = task.get_train_test_split_indices(
            repeat=repeat,
            fold=fold_in_repeat,
            sample=0,
        )
        return (
            np.asarray(train_idx, dtype=np.int64),
            np.asarray(test_idx, dtype=np.int64),
        )

    def get_openml_fold_indices(self, fold: int) -> tuple[np.ndarray, np.ndarray]:
        # Backward-compatible alias.
        return self.get_openml_split_indices(split=fold)

    def make_db(self) -> Database:
        """Materialize the single-table database: one `records` row per OpenML row."""
        self._ensure_openml_loaded()
        assert self._X_df is not None

        records = self._X_df.copy(deep=True)
        # record_id is a dense 0..N-1 primary key aligned with OpenML row order,
        # so OpenML split indices can be used directly as entity ids.
        records.insert(0, "record_id", np.arange(len(records), dtype=np.int64))

        return Database(
            {
                "records": Table(
                    df=records,
                    fkey_col_to_pkey_table={},
                    pkey_col="record_id",
                    time_col=None,
                )
            }
        )


# ---- diff hunk: relbench/tasks/__init__.py (context lines included) ----
from relbench.base import AutoCompleteTask, BaseTask, TaskType
from relbench.datasets import get_dataset
from relbench.datasets.tabarena import TABARENA_DATASETS
from relbench.tasks import (
    amazon,
    arxiv,
    # ... (entries between `arxiv` and `mimic` unchanged, elided by the hunk) ...
    mimic,
    ratebeer,
    stack,
    tabarena,
    tgb,
    trial,
)


def download_task(dataset_name: str, task_name: str) -> None:
    """
    ... (docstring head elided by the hunk) ...
    `task.get_table(split)` is called.
    """
    # TabArena tasks are derived locally from OpenML splits; there is no
    # hosted zip artifact to fetch, so skip the registry download.
    if dataset_name.startswith("tabarena-"):
        print(
            f"Task '{dataset_name}/{task_name}' is derived from TabArena OpenML tasks "
            "and must be generated locally; skipping download."
        )
        return

    DOWNLOAD_REGISTRY.fetch(
        f"{dataset_name}/tasks/{task_name}.zip",
        processor=pooch.Unzip(extract_dir="."),
        # ... (hunk truncated; remainder of download_task unchanged) ...


# ---- diff hunk: tail of relbench/tasks/__init__.py (after the tgb block:
#     spec=tgb.TGBNodePropSpec(), k=10, ) ----
# Register one entity task per dataset and per OpenML (repeat, fold) split.
for dataset_slug, spec in TABARENA_DATASETS.items():
    dataset_name = f"tabarena-{dataset_slug}"
    for split in range(spec.fold_count):
        register_task(
            dataset_name,
            f"split-{split}",
            tabarena.TabArenaSplitEntityTask,
            split=split,
        )


# ---- new file: relbench/tasks/tabarena.py ----
from __future__ import annotations

from functools import lru_cache
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
from sklearn.metrics import log_loss as sklearn_log_loss
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

from relbench.base import EntityTask, Table, TaskType
from relbench.datasets.tabarena import TabArenaDataset


def _binary_metric_error(true: np.ndarray, pred: np.ndarray) -> float:
    """Binary error as 1 - ROC-AUC (lower is better).

    Accepts shape (N,), (N, 1), or (N, 2) predictions; for two columns the
    positive-class column is used. Values outside [0, 1] are treated as
    logits and squashed through a clipped sigmoid.
    """
    pred = np.asarray(pred, dtype=np.float64)
    if pred.ndim > 1:
        pred = pred.reshape(pred.shape[0], -1)
        if pred.shape[1] == 1:
            pred = pred[:, 0]
        else:
            pred = pred[:, 1]
    if pred.min() < 0.0 or pred.max() > 1.0:
        # Clip to avoid overflow in exp before applying the sigmoid.
        pred = np.clip(pred, -40.0, 40.0)
        pred = 1.0 / (1.0 + np.exp(-pred))
    score = roc_auc_score(np.asarray(true, dtype=np.int64), pred)
    return float(1.0 - score)


def _softmax(x: np.ndarray) -> np.ndarray:
    """Row-wise numerically stable softmax (max-subtraction trick)."""
    x = x - np.max(x, axis=1, keepdims=True)
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)


def _multiclass_metric_error(true: np.ndarray, pred: np.ndarray) -> float:
    """Multiclass log-loss with the class count inferred from the inputs."""
    return _multiclass_metric_error_with_num_classes(true, pred, num_classes=None)


def _multiclass_metric_error_with_num_classes(
    true: np.ndarray,
    pred: np.ndarray,
num_classes: Optional[int], +) -> float: + pred = np.asarray(pred, dtype=np.float64) + true_arr = np.asarray(true, dtype=np.int64) + + if num_classes is not None: + inferred_num_classes = int(num_classes) + elif pred.ndim == 2: + inferred_num_classes = int(pred.shape[1]) + else: + inferred_num_classes = int( + max(true_arr.max(initial=0), pred.max(initial=0)) + 1 + ) + if inferred_num_classes <= 1: + inferred_num_classes = 2 + + if pred.ndim == 1: + pred_labels = pred.astype(np.int64, copy=False) + pred_labels = np.clip(pred_labels, 0, inferred_num_classes - 1) + eps = 1e-7 + probs = np.full( + (len(pred_labels), inferred_num_classes), + fill_value=eps / max(inferred_num_classes - 1, 1), + dtype=np.float64, + ) + probs[np.arange(len(pred_labels), dtype=np.int64), pred_labels] = 1.0 - eps + elif pred.ndim == 2: + probs = pred + if probs.shape[1] < inferred_num_classes: + padding = np.full( + (probs.shape[0], inferred_num_classes - probs.shape[1]), + fill_value=0.0, + dtype=np.float64, + ) + probs = np.hstack([probs, padding]) + elif probs.shape[1] > inferred_num_classes: + probs = probs[:, :inferred_num_classes] + + row_sums = probs.sum(axis=1) + if np.all(probs >= 0.0) and np.allclose(row_sums, 1.0, atol=1e-4): + pass + else: + probs = _softmax(probs) + else: + raise ValueError( + "Expected multiclass predictions with shape (N,) or (N, num_classes). " + f"Got shape {pred.shape}." 
+ ) + + labels = np.arange(inferred_num_classes, dtype=np.int64) + return float( + sklearn_log_loss( + true_arr, + probs, + labels=labels, + ) + ) + + +def _regression_metric_error(true: np.ndarray, pred: np.ndarray) -> float: + true = np.asarray(true, dtype=np.float64) + pred = np.asarray(pred, dtype=np.float64).reshape(-1) + return float(np.sqrt(np.mean((true - pred) ** 2))) + + +_binary_metric_error.__name__ = "metric_error" +_multiclass_metric_error.__name__ = "metric_error" +_regression_metric_error.__name__ = "metric_error" + + +class TabArenaSplitEntityTask(EntityTask): + r"""Single-table TabArena task for a specific OpenML split index.""" + + entity_col = "record_id" + entity_table = "records" + time_col = None + target_col = "target" + timedelta = pd.Timedelta(days=1) + num_eval_timestamps = 1 + + def __init__( + self, + dataset, + *, + split: Optional[int] = None, + fold: Optional[int] = None, + val_frac: float = 0.2, + random_state: Optional[int] = None, + cache_dir: Optional[str] = None, + ): + if not isinstance(dataset, TabArenaDataset): + raise TypeError( + "TabArenaSplitEntityTask expects a TabArenaDataset instance. " + f"Got {type(dataset)}" + ) + + if split is None and fold is None: + raise ValueError("Exactly one of `split` or `fold` must be provided.") + if split is not None and fold is not None and int(split) != int(fold): + raise ValueError( + f"Received conflicting split={split} and fold={fold}; please provide one index." + ) + + self.split = int(split if split is not None else fold) + self.fold = self.split # Backward-compatible alias. + self.val_frac = float(val_frac) + if not (0.0 < self.val_frac < 1.0): + raise ValueError(f"val_frac must be in (0, 1), got {self.val_frac}") + self.random_state = self.split if random_state is None else int(random_state) + + if self.split not in dataset.available_splits: + raise ValueError( + f"Split={self.split} is unavailable for {dataset.name}. 
" + f"Available splits: {dataset.available_splits}" + ) + + self.problem_type = dataset.problem_type + if self.problem_type == "regression": + self.task_type = TaskType.REGRESSION + self.metrics = [_regression_metric_error] + elif self.problem_type == "binary": + self.task_type = TaskType.BINARY_CLASSIFICATION + self.metrics = [_binary_metric_error] + elif self.problem_type == "multiclass": + self.task_type = TaskType.MULTICLASS_CLASSIFICATION + self.num_classes = int(dataset.num_classes) + self.metrics = [ + _make_multiclass_metric_error_with_num_classes(self.num_classes) + ] + else: + raise ValueError(f"Unsupported problem_type={self.problem_type}") + + super().__init__(dataset, cache_dir=cache_dir) + + def make_table(self, db, timestamps): # pragma: no cover + raise RuntimeError( + "TabArenaSplitEntityTask uses precomputed OpenML split indices and overrides _get_table()." + ) + + @lru_cache(maxsize=None) + def _split_indices(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + train_idx, test_idx = self.dataset.get_openml_split_indices(self.split) + y = self.dataset.get_target_array() + + stratify = ( + y[train_idx] if self.problem_type in {"binary", "multiclass"} else None + ) + try: + train_idx, val_idx = train_test_split( + train_idx, + test_size=self.val_frac, + random_state=self.random_state, + shuffle=True, + stratify=stratify, + ) + except ValueError: + train_idx, val_idx = train_test_split( + train_idx, + test_size=self.val_frac, + random_state=self.random_state, + shuffle=True, + stratify=None, + ) + + return ( + np.asarray(train_idx, dtype=np.int64), + np.asarray(val_idx, dtype=np.int64), + np.asarray(test_idx, dtype=np.int64), + ) + + def _get_table(self, split: str) -> Table: + if split not in {"train", "val", "test"}: + raise ValueError( + "Unknown split=" f"{split!r}. Expected one of ['test', 'train', 'val']." 
+ ) + + train_idx, val_idx, test_idx = self._split_indices() + if split == "train": + idx = train_idx + elif split == "val": + idx = val_idx + else: + idx = test_idx + + y = self.dataset.get_target_array() + target = y[idx] + + df = pd.DataFrame( + { + self.entity_col: idx.astype(np.int64, copy=False), + self.target_col: target, + } + ) + + if self.task_type != TaskType.REGRESSION: + df[self.target_col] = df[self.target_col].astype(np.int64) + + return Table( + df=df, + # Keep task tables edge-free so single-table TabArena contexts stay single-node + # in external context samplers (e.g., RT rustler). + fkey_col_to_pkey_table={}, + pkey_col=None, + time_col=None, + ) + + def _mask_input_cols(self, table: Table) -> Table: + # Keep entity ids visible for inference while preserving edge-free task tables. + return Table( + df=table.df[[self.entity_col]], + fkey_col_to_pkey_table={}, + pkey_col=table.pkey_col, + time_col=table.time_col, + ) + + def stats(self) -> Dict[str, Dict[str, Any]]: + r"""Get split-level statistics for tasks without a time column.""" + res: Dict[str, Dict[str, Any]] = {} + for split in ["train", "val", "test"]: + table = self.get_table(split, mask_input_cols=False) + stats: Dict[str, Any] = { + "num_rows": len(table.df), + "num_unique_entities": table.df[self.entity_col].nunique(), + } + self._set_stats(table.df, stats) + res[split] = {"total": stats} + + total_df = pd.concat( + [ + self.get_table(split, mask_input_cols=False).df + for split in ["train", "val", "test"] + ] + ) + res["total"] = {} + self._set_stats(total_df, res["total"]) + + train_uniques = set(self.get_table("train").df[self.entity_col].unique()) + test_uniques = set( + self.get_table("test", mask_input_cols=False).df[self.entity_col].unique() + ) + overlap = len(train_uniques.intersection(test_uniques)) + res["total"]["ratio_train_test_entity_overlap"] = ( + overlap / len(test_uniques) if test_uniques else float("nan") + ) + return res + + +# Backward-compatible alias. 
TabArenaFoldEntityTask = TabArenaSplitEntityTask


def _make_multiclass_metric_error_with_num_classes(
    num_classes: int,
):
    """Build a multiclass log-loss metric closed over a fixed class count."""

    def _metric(true: np.ndarray, pred: np.ndarray) -> float:
        return _multiclass_metric_error_with_num_classes(
            true, pred, num_classes=num_classes
        )

    _metric.__name__ = "metric_error"
    return _metric


# ---- new file: test/datasets/test_tabarena.py ----
import numpy as np
import pandas as pd

from relbench.datasets.tabarena import TabArenaDataset
from relbench.tasks.tabarena import TabArenaSplitEntityTask


class _FakeOpenMLDataset:
    """Stand-in for ``openml`` dataset objects exposing ``get_data``."""

    def __init__(self, X_df: pd.DataFrame, y_ser: pd.Series, target_name: str):
        self._X_df = X_df
        self._y_ser = y_ser
        self._target_name = str(target_name)

    def get_data(self, *, target: str, dataset_format: str):
        assert target == self._target_name
        assert dataset_format == "dataframe"
        columns = list(self._X_df.columns)
        # The fake frame has no categorical features.
        return self._X_df, self._y_ser, [False] * len(columns), columns


class _FakeOpenMLTask:
    """Deterministic stand-in for an OpenML task with modular CV splits."""

    def __init__(
        self,
        *,
        target_name: str,
        X_df: pd.DataFrame,
        y_ser: pd.Series,
        n_repeats: int,
        n_folds: int,
    ):
        self.target_name = str(target_name)
        self._dataset = _FakeOpenMLDataset(
            X_df=X_df, y_ser=y_ser, target_name=target_name
        )
        self._n_repeats = int(n_repeats)
        self._n_folds = int(n_folds)
        self._n_samples = int(len(X_df))

    def get_dataset(self):
        return self._dataset

    def get_split_dimensions(self):
        return self._n_repeats, self._n_folds, self._n_samples

    def get_train_test_split_indices(self, *, repeat: int, fold: int, sample: int):
        assert int(sample) == 0
        repeat, fold = int(repeat), int(fold)
        if not 0 <= repeat < self._n_repeats:
            raise ValueError(
                f"repeat={repeat} out of range for n_repeats={self._n_repeats}"
            )
        if not 0 <= fold < self._n_folds:
            raise ValueError(f"fold={fold} out of range for n_folds={self._n_folds}")

        # Simple deterministic CV split: each fold takes indices i where i % n_folds == fold.
        all_rows = np.arange(self._n_samples, dtype=np.int64)
        in_fold = all_rows % self._n_folds == fold
        return all_rows[~in_fold], all_rows[in_fold]


def _install_fake_openml(monkeypatch):
    """Patch TabArenaDataset to load a synthetic task instead of hitting OpenML."""

    def _fake_load_task_with_retry(
        self: TabArenaDataset, task_id: int, retries: int = 4
    ):
        _ = task_id
        _ = retries
        n_samples = 90
        n_folds = int(self.spec.fold_count)
        target_name = self.spec.target

        base = np.arange(n_samples, dtype=np.int64)
        X_df = pd.DataFrame({"feat_num": base, "feat_mod3": base % 3})

        # Synthesize labels matching the dataset's problem type.
        if self.problem_type == "regression":
            labels = np.linspace(0.0, 1.0, n_samples)
        elif self.problem_type == "binary":
            labels = ["no", "yes"] * (n_samples // 2)
        else:
            classes = [f"class_{i}" for i in range(int(self.spec.num_classes))]
            labels = [classes[i % len(classes)] for i in range(n_samples)]
        y_ser = pd.Series(labels, name=target_name)

        return _FakeOpenMLTask(
            target_name=target_name,
            X_df=X_df,
            y_ser=y_ser,
            n_repeats=1,
            n_folds=n_folds,
        )

    monkeypatch.setattr(
        TabArenaDataset, "_load_task_with_retry", _fake_load_task_with_retry
    )


def test_tabarena_dataset_and_task_binary(monkeypatch):
    _install_fake_openml(monkeypatch)

    dataset = TabArenaDataset(dataset_slug="apsfailure", cache_dir=None)
    records = dataset.get_db().table_dict["records"]
    assert records.pkey_col == "record_id"
    assert records.time_col is None
    assert len(records) == 90

    train_idx, test_idx = dataset.get_openml_split_indices(0)
    assert train_idx.dtype == np.int64
    assert test_idx.dtype == np.int64
    assert set(train_idx).isdisjoint(set(test_idx))

    task = TabArenaSplitEntityTask(dataset, split=0, cache_dir=None)
    train_table = task.get_table("train")
    assert set(train_table.df.columns) == {"record_id", "target"}
    assert train_table.fkey_col_to_pkey_table == {}
    test_table = task.get_table("test")
    assert set(test_table.df.columns) == {"record_id"}
    assert test_table.fkey_col_to_pkey_table == {}

    # Perfect predictions yield AUC=1.0 => metric_error=0.0.
    full_test = task.get_table("test", mask_input_cols=False)
    y_true = full_test.df["target"].to_numpy()
    metrics = task.evaluate(y_true, target_table=full_test)
    assert metrics["metric_error"] == 0.0


def test_tabarena_dataset_and_task_regression(monkeypatch):
    _install_fake_openml(monkeypatch)

    dataset = TabArenaDataset(dataset_slug="diamonds", cache_dir=None)
    task = TabArenaSplitEntityTask(dataset, split=0, cache_dir=None)

    full_test = task.get_table("test", mask_input_cols=False)
    y_true = full_test.df["target"].to_numpy()
    # Predicting the exact targets drives RMSE to zero.
    metrics = task.evaluate(y_true, target_table=full_test)
    assert metrics["metric_error"] == 0.0


def test_tabarena_task_multiclass(monkeypatch):
    _install_fake_openml(monkeypatch)

    dataset = TabArenaDataset(dataset_slug="splice", cache_dir=None)
    task = TabArenaSplitEntityTask(dataset, split=0, cache_dir=None)

    full_test = task.get_table("test", mask_input_cols=False)
    y_true = full_test.df["target"].to_numpy()
    num_classes = int(dataset.num_classes)

    # Use a uniform distribution: log loss is finite and should be > 0.
    probs = np.full(
        (len(y_true), num_classes), fill_value=1.0 / num_classes, dtype=np.float64
    )
    metrics = task.evaluate(probs, target_table=full_test)
    assert metrics["metric_error"] > 0.0