From 6bb795c013667ab2ff580d3e563aabb57b3a4fb4 Mon Sep 17 00:00:00 2001 From: pc0618 Date: Mon, 2 Mar 2026 00:54:25 +0000 Subject: [PATCH 1/8] Add TabArena single-table OpenML datasets --- README.md | 11 + examples/translate_tabarena_to_relbench.py | 124 ++++ pyproject.toml | 3 + relbench/datasets/__init__.py | 15 + relbench/datasets/tabarena.py | 710 +++++++++++++++++++++ relbench/tasks/__init__.py | 19 + relbench/tasks/tabarena.py | 258 ++++++++ test/datasets/test_tabarena.py | 158 +++++ 8 files changed, 1298 insertions(+) create mode 100644 examples/translate_tabarena_to_relbench.py create mode 100644 relbench/datasets/tabarena.py create mode 100644 relbench/tasks/tabarena.py create mode 100644 test/datasets/test_tabarena.py diff --git a/README.md b/README.md index 3c973226..a3130746 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,17 @@ If you use the 4DBInfer datasets in your work, please cite [4DBInfer](https://gi ``` +**Using TabArena datasets** + +RelBench includes an optional integration for TabArena: a collection of single-table OpenML tasks. TabArena datasets are exposed under names like `tabarena-credit-g`, with per-fold tasks named `fold-0`, `fold-1`, etc. + +To use TabArena datasets, install the optional dependency: +```bash +pip install relbench[tabarena] +``` + +TabArena datasets are generated locally (from OpenML) and cached under `~/.cache/relbench/tabarena-*/`. Passing `download=True` will skip the RelBench server download step for these datasets/tasks. + # Package Usage diff --git a/examples/translate_tabarena_to_relbench.py b/examples/translate_tabarena_to_relbench.py new file mode 100644 index 00000000..66ddab44 --- /dev/null +++ b/examples/translate_tabarena_to_relbench.py @@ -0,0 +1,124 @@ +"""Utilities to inspect how TabArena datasets are represented in RelBench.""" + +import argparse +from pathlib import Path + +import pandas as pd + +from relbench.datasets import get_dataset +from relbench.datasets.tabarena import TABARENA_DATASETS, get_tabarena_dataset_slugs + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset_slugs", + type=str, + default="all", + help=( + "Comma-separated TabArena slugs or 'all'. " + "Example: credit-g,airfoil-self-noise" + ), + ) + parser.add_argument( + "--output_csv", + type=str, + default="results/tabarena_relbench_translation.csv", + ) + parser.add_argument( + "--include_fold_examples", + action="store_true", + default=False, + help="If set, writes one sample record for fold-0 per dataset.", + ) + return parser.parse_args() + + +def _parse_dataset_slugs(arg: str) -> list[str]: + if arg.strip().lower() == "all": + return get_tabarena_dataset_slugs() + slugs = [slug.strip() for slug in arg.split(",") if slug.strip()] + valid = set(get_tabarena_dataset_slugs()) + invalid = [slug for slug in slugs if slug not in valid] + if invalid: + raise ValueError(f"Unknown dataset slugs: {invalid}") + return slugs + + +def _summarize_dataset(dataset_name: str) -> dict: + dataset = get_dataset(dataset_name, download=False) + spec = TABARENA_DATASETS[dataset.name.replace("tabarena-", "")] + db = dataset.make_db() + records = db.table_dict["records"] + row = { + "dataset_slug": spec.slug, + "dataset_name": dataset.name, + "tabarena_benchmark_name": spec.name, + "openml_task_id": spec.task_id, + "openml_dataset_id": spec.dataset_id, + "target_col": spec.target, + "problem_type": spec.task_type, + "num_classes": spec.num_classes, + "fold_count": spec.fold_count, + "records_rows": int(len(records.df)), + "records_columns": int(len(records.df.columns)), + "entity_table": "records", + "entity_pkey": records.pkey_col, + "split_timestamp_columns": bool(records.time_col is not None), + } + return row + + +def _summarize_fold_example(dataset_name: str) -> list[dict]: + from relbench.tasks import get_task + + rows: list[dict] = [] + for fold in [0]: + task = get_task(dataset_name, f"fold-{fold}") + train = task.get_table("train", mask_input_cols=False) + val = task.get_table("val", mask_input_cols=False) + test = task.get_table("test", mask_input_cols=False) + rows.append( + { + "dataset_name": dataset_name, + "fold": int(fold), + "task_type": str(task.task_type.value), + "fold_train_rows": int(len(train)), + "fold_val_rows": int(len(val)), + "fold_test_rows": int(len(test)), + "train_columns": int(len(train.df.columns)), + "time_col": str(train.time_col), + } + ) + return rows + + +def main() -> None: + args = _parse_args() + dataset_slugs = _parse_dataset_slugs(args.dataset_slugs) + output_path = Path(args.output_csv) + output_path.parent.mkdir(parents=True, exist_ok=True) + + dataset_rows: list[dict] = [] + split_rows: list[dict] = [] + for slug in dataset_slugs: + dataset_name = f"tabarena-{slug}" + print(f"[Inspect] {dataset_name}") + dataset_rows.append(_summarize_dataset(dataset_name)) + if args.include_fold_examples: + split_rows.extend(_summarize_fold_example(dataset_name)) + + df = pd.DataFrame(dataset_rows) + df.to_csv(output_path, index=False) + + if split_rows: + split_path = output_path.with_name(f"{output_path.stem}_fold_samples.csv") + pd.DataFrame(split_rows).to_csv(split_path, index=False) + print(f"[Done] dataset summary: {output_path}") + print(f"[Done] fold sample summary: {split_path}") + else: + print(f"[Done] dataset summary: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 5d908c21..ff6b8b83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,9 @@ example=[ "torch_geometric", "tqdm", ] +tabarena=[ + "openml", +] test=[ "pytest", ] diff --git a/relbench/datasets/__init__.py b/relbench/datasets/__init__.py index 91dcf8f5..b2e91cda 100644 --- a/relbench/datasets/__init__.py +++ b/relbench/datasets/__init__.py @@ -18,6 +18,7 @@ ratebeer, salt, stack, + tabarena, tgb, trial, ) @@ -72,6 +73,13 @@ def download_dataset(name: str) -> None: `dataset.get_db()` is called. """ + if name.startswith("tabarena-"): + print( + f"Dataset '{name}' is derived from TabArena OpenML tasks and must be " + "generated locally; skipping download." + ) + return + if name == "rel-mimic": from relbench.datasets.mimic import verify_mimic_access @@ -172,3 +180,10 @@ def get_dataset(name: str, download=True) -> Dataset: register_dataset("tgbn-genre", tgb.TGBDataset, tgb_name="tgbn-genre") register_dataset("tgbn-reddit", tgb.TGBDataset, tgb_name="tgbn-reddit") register_dataset("tgbn-token", tgb.TGBDataset, tgb_name="tgbn-token") + +for dataset_slug in tabarena.get_tabarena_dataset_slugs(): + register_dataset( + f"tabarena-{dataset_slug}", + tabarena.TabArenaDataset, + dataset_slug=dataset_slug, + ) diff --git a/relbench/datasets/tabarena.py b/relbench/datasets/tabarena.py new file mode 100644 index 00000000..514881a6 --- /dev/null +++ b/relbench/datasets/tabarena.py @@ -0,0 +1,710 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import pandas as pd + +from relbench.base import Database, Dataset, Table + + +@dataclass(frozen=True) +class TabArenaDatasetSpec: + slug: str + name: str + task_id: int + dataset_id: int + target: str + task_type: str + num_classes: int + fold_count: int + + +TABARENA_DATASETS: dict[str, TabArenaDatasetSpec] = { + "airfoil-self-noise": TabArenaDatasetSpec( + slug="airfoil-self-noise", + name="airfoil_self_noise", + task_id=363612, + dataset_id=46904, + target="scaled-sound-pressure", + task_type="Supervised Regression", + num_classes=0, + fold_count=30, + ), + "amazon-employee-access": TabArenaDatasetSpec( + slug="amazon-employee-access", + name="Amazon_employee_access", + task_id=363613, + dataset_id=46905, + target="ResourceApproved", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "anneal": TabArenaDatasetSpec( + slug="anneal", + name="anneal", + task_id=363614, + dataset_id=46906, + target="classes", + task_type="Supervised Classification", + num_classes=5, + fold_count=30, + ), + "another-dataset-on-used-fiat-500": TabArenaDatasetSpec( + slug="another-dataset-on-used-fiat-500", + name="Another-Dataset-on-used-Fiat-500", + task_id=363615, + dataset_id=46907, + target="price", + task_type="Supervised Regression", + num_classes=0, + fold_count=30, + ), + "apsfailure": TabArenaDatasetSpec( + slug="apsfailure", + name="APSFailure", + task_id=363616, + dataset_id=46908, + target="AirPressureSystemFailure", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "bank-customer-churn": TabArenaDatasetSpec( + slug="bank-customer-churn", + name="Bank_Customer_Churn", + task_id=363619, + dataset_id=46911, + target="churn", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "bank-marketing": TabArenaDatasetSpec( + slug="bank-marketing", + name="bank-marketing", + task_id=363618, + dataset_id=46910, + target="SubscribeTermDeposit", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "bioresponse": TabArenaDatasetSpec( + slug="bioresponse", + name="Bioresponse", + task_id=363620, + dataset_id=46912, + target="MoleculeElicitsResponse", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "blood-transfusion-service-center": TabArenaDatasetSpec( + slug="blood-transfusion-service-center", + name="blood-transfusion-service-center", + task_id=363621, + dataset_id=46913, + target="DonatedBloodInMarch2007", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "churn": TabArenaDatasetSpec( + slug="churn", + name="churn", + task_id=363623, + dataset_id=46915, + target="CustomerChurned", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "coil2000-insurance-policies": TabArenaDatasetSpec( + slug="coil2000-insurance-policies", + name="coil2000_insurance_policies", + task_id=363624, + dataset_id=46916, + target="MobileHomePolicy", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "concrete-compressive-strength": TabArenaDatasetSpec( + slug="concrete-compressive-strength", + name="concrete_compressive_strength", + task_id=363625, + dataset_id=46917, + target="ConcreteCompressiveStrength", + task_type="Supervised Regression", + num_classes=0, + fold_count=30, + ), + "credit-card-clients-default": TabArenaDatasetSpec( + slug="credit-card-clients-default", + name="credit_card_clients_default", + task_id=363627, + dataset_id=46919, + target="DefaultOnPaymentNextMonth", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "credit-g": TabArenaDatasetSpec( + slug="credit-g", + name="credit-g", + task_id=363626, + dataset_id=46918, + target="good_or_bad_customer", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "customer-satisfaction-in-airline": TabArenaDatasetSpec( + slug="customer-satisfaction-in-airline", + name="customer_satisfaction_in_airline", + task_id=363628, + dataset_id=46920, + target="satisfaction", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "diabetes": TabArenaDatasetSpec( + slug="diabetes", + name="diabetes", + task_id=363629, + dataset_id=46921, + target="TestedPositiveForDiabetes", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "diabetes130us": TabArenaDatasetSpec( + slug="diabetes130us", + name="Diabetes130US", + task_id=363630, + dataset_id=46922, + target="EarlyReadmission", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "diamonds": TabArenaDatasetSpec( + slug="diamonds", + name="diamonds", + task_id=363631, + dataset_id=46923, + target="price", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "e-commereshippingdata": TabArenaDatasetSpec( + slug="e-commereshippingdata", + name="E-CommereShippingData", + task_id=363632, + dataset_id=46924, + target="ArrivedLate", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "fitness-club": TabArenaDatasetSpec( + slug="fitness-club", + name="Fitness_Club", + task_id=363671, + dataset_id=46927, + target="attended", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "food-delivery-time": TabArenaDatasetSpec( + slug="food-delivery-time", + name="Food_Delivery_Time", + task_id=363672, + dataset_id=46928, + target="Time_taken(min)", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "givemesomecredit": TabArenaDatasetSpec( + slug="givemesomecredit", + name="GiveMeSomeCredit", + task_id=363673, + dataset_id=46929, + target="FinancialDistressNextTwoYears", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "hazelnut-spread-contaminant-detection": TabArenaDatasetSpec( + slug="hazelnut-spread-contaminant-detection", + name="hazelnut-spread-contaminant-detection", + task_id=363674, + dataset_id=46930, + target="Contaminated", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "healthcare-insurance-expenses": TabArenaDatasetSpec( + slug="healthcare-insurance-expenses", + name="healthcare_insurance_expenses", + task_id=363675, + dataset_id=46931, + target="charges", + task_type="Supervised Regression", + num_classes=0, + fold_count=30, + ), + "heloc": TabArenaDatasetSpec( + slug="heloc", + name="heloc", + task_id=363676, + dataset_id=46932, + target="RiskPerformance", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "hiva-agnostic": TabArenaDatasetSpec( + slug="hiva-agnostic", + name="hiva_agnostic", + task_id=363677, + dataset_id=46933, + target="CompoundActivity", + task_type="Supervised Classification", + num_classes=3, + fold_count=9, + ), + "houses": TabArenaDatasetSpec( + slug="houses", + name="houses", + task_id=363678, + dataset_id=46934, + target="LnMedianHouseValue", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "hr-analytics-job-change-of-data-scientists": TabArenaDatasetSpec( + slug="hr-analytics-job-change-of-data-scientists", + name="HR_Analytics_Job_Change_of_Data_Scientists", + task_id=363679, + dataset_id=46935, + target="LookingForJobChange", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "in-vehicle-coupon-recommendation": TabArenaDatasetSpec( + slug="in-vehicle-coupon-recommendation", + name="in_vehicle_coupon_recommendation", + task_id=363681, + dataset_id=46937, + target="AcceptCoupon", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "is-this-a-good-customer": TabArenaDatasetSpec( + slug="is-this-a-good-customer", + name="Is-this-a-good-customer", + task_id=363682, + dataset_id=46938, + target="bad_client_target", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "jm1": TabArenaDatasetSpec( + slug="jm1", + name="jm1", + task_id=363712, + dataset_id=46979, + target="defects", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "kddcup09-appetency": TabArenaDatasetSpec( + slug="kddcup09-appetency", + name="kddcup09_appetency", + task_id=363683, + dataset_id=46939, + target="appetency", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "marketing-campaign": TabArenaDatasetSpec( + slug="marketing-campaign", + name="Marketing_Campaign", + task_id=363684, + dataset_id=46940, + target="Response", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "maternal-health-risk": TabArenaDatasetSpec( + slug="maternal-health-risk", + name="maternal_health_risk", + task_id=363685, + dataset_id=46941, + target="RiskLevel", + task_type="Supervised Classification", + num_classes=3, + fold_count=30, + ), + "miami-housing": TabArenaDatasetSpec( + slug="miami-housing", + name="miami_housing", + task_id=363686, + dataset_id=46942, + target="SALE_PRC", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "mic": TabArenaDatasetSpec( + slug="mic", + name="MIC", + task_id=363711, + dataset_id=46980, + target="LET_IS", + task_type="Supervised Classification", + num_classes=8, + fold_count=30, + ), + "naticusdroid": TabArenaDatasetSpec( + slug="naticusdroid", + name="NATICUSdroid", + task_id=363689, + dataset_id=46969, + target="Malware", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "online-shoppers-intention": TabArenaDatasetSpec( + slug="online-shoppers-intention", + name="online_shoppers_intention", + task_id=363691, + dataset_id=46947, + target="Revenue", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "physiochemical-protein": TabArenaDatasetSpec( + slug="physiochemical-protein", + name="physiochemical_protein", + task_id=363693, + dataset_id=46949, + target="ResidualSize", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "polish-companies-bankruptcy": TabArenaDatasetSpec( + slug="polish-companies-bankruptcy", + name="polish_companies_bankruptcy", + task_id=363694, + dataset_id=46950, + target="company_bankrupt", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "qsar-biodeg": TabArenaDatasetSpec( + slug="qsar-biodeg", + name="qsar-biodeg", + task_id=363696, + dataset_id=46952, + target="Biodegradable", + task_type="Supervised Classification", + num_classes=2, + fold_count=30, + ), + "qsar-fish-toxicity": TabArenaDatasetSpec( + slug="qsar-fish-toxicity", + name="QSAR_fish_toxicity", + task_id=363698, + dataset_id=46954, + target="LC50", + task_type="Supervised Regression", + num_classes=0, + fold_count=30, + ), + "qsar-tid-11": TabArenaDatasetSpec( + slug="qsar-tid-11", + name="QSAR-TID-11", + task_id=363697, + dataset_id=46953, + target="MEDIAN_PXC50", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "sdss17": TabArenaDatasetSpec( + slug="sdss17", + name="SDSS17", + task_id=363699, + dataset_id=46955, + target="ObjectType", + task_type="Supervised Classification", + num_classes=3, + fold_count=9, + ), + "seismic-bumps": TabArenaDatasetSpec( + slug="seismic-bumps", + name="seismic-bumps", + task_id=363700, + dataset_id=46956, + target="HighEnergySeismicBump", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "splice": TabArenaDatasetSpec( + slug="splice", + name="splice", + task_id=363702, + dataset_id=46958, + target="SiteType", + task_type="Supervised Classification", + num_classes=3, + fold_count=9, + ), + "students-dropout-and-academic-success": TabArenaDatasetSpec( + slug="students-dropout-and-academic-success", + name="students_dropout_and_academic_success", + task_id=363704, + dataset_id=46960, + target="AcademicOutcome", + task_type="Supervised Classification", + num_classes=3, + fold_count=9, + ), + "superconductivity": TabArenaDatasetSpec( + slug="superconductivity", + name="superconductivity", + task_id=363705, + dataset_id=46961, + target="critical_temp", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), + "taiwanese-bankruptcy-prediction": TabArenaDatasetSpec( + slug="taiwanese-bankruptcy-prediction", + name="taiwanese_bankruptcy_prediction", + task_id=363706, + dataset_id=46962, + target="Bankrupt", + task_type="Supervised Classification", + num_classes=2, + fold_count=9, + ), + "website-phishing": TabArenaDatasetSpec( + slug="website-phishing", + name="website_phishing", + task_id=363707, + dataset_id=46963, + target="WebsiteType", + task_type="Supervised Classification", + num_classes=3, + fold_count=30, + ), + "wine-quality": TabArenaDatasetSpec( + slug="wine-quality", + name="wine_quality", + task_id=363708, + dataset_id=46964, + target="median_wine_quality", + task_type="Supervised Regression", + num_classes=0, + fold_count=9, + ), +} + + +def get_tabarena_dataset_slugs() -> list[str]: + return sorted(TABARENA_DATASETS.keys()) + + +def _import_openml(): + try: + import openml + except ImportError as exc: + raise ImportError( + "TabArena datasets require the `openml` package. Install it with " + "`pip install relbench[tabarena]` (or `pip install openml`)." + ) from exc + return openml + + +def _problem_type_from_spec(task_type: str, num_classes: int) -> str: + if task_type == "Supervised Regression": + return "regression" + if int(num_classes) <= 2: + return "binary" + return "multiclass" + + +class TabArenaDataset(Dataset): + r"""Single-table RelBench dataset wrapper over TabArena OpenML tasks.""" + + url = "https://huggingface.co/datasets/TabArena/benchmark_results" + val_timestamp = pd.Timestamp("2000-01-02") + test_timestamp = pd.Timestamp("2000-01-03") + + def __init__(self, *, dataset_slug: str, cache_dir: Optional[str] = None): + slug = str(dataset_slug) + if slug not in TABARENA_DATASETS: + raise ValueError( + f"Unknown TabArena dataset slug={slug!r}. Known values: {sorted(TABARENA_DATASETS.keys())}" + ) + + self.spec = TABARENA_DATASETS[slug] + self.name = f"tabarena-{slug}" + self.tabarena_name = self.spec.name + self.task_id = int(self.spec.task_id) + self.openml_dataset_id = int(self.spec.dataset_id) + self.target_name = self.spec.target + self.problem_type = _problem_type_from_spec( + self.spec.task_type, self.spec.num_classes + ) + self.num_classes = ( + int(self.spec.num_classes) if self.problem_type != "regression" else 0 + ) + + self._openml_task = None + self._X_df: Optional[pd.DataFrame] = None + self._y_encoded: Optional[np.ndarray] = None + + super().__init__(cache_dir=cache_dir) + + @property + def available_folds(self) -> list[int]: + return list(range(int(self.spec.fold_count))) + + def _load_task_with_retry(self, task_id: int, retries: int = 4): + openml = _import_openml() + delay = 1.0 + for attempt in range(retries + 1): + try: + return openml.tasks.get_task( + task_id, + download_splits=True, + download_data=True, + download_qualities=False, + download_features_meta_data=True, + ) + except Exception: + if attempt == retries: + raise + time.sleep(delay) + delay *= 2.0 + + def _ensure_openml_loaded(self) -> None: + if ( + self._openml_task is not None + and self._X_df is not None + and self._y_encoded is not None + ): + return + + task = self._load_task_with_retry(self.task_id) + X_df, y_ser, _cat, _names = task.get_dataset().get_data( + target=task.target_name, + dataset_format="dataframe", + ) + + X_df = pd.DataFrame(X_df).reset_index(drop=True) + y_ser = pd.Series(y_ser, name=task.target_name).reset_index(drop=True) + + if self.problem_type == "regression": + y_encoded = y_ser.astype(float).to_numpy(copy=True) + else: + cat = pd.Categorical(y_ser) + if cat.codes.min() < 0: + raise RuntimeError( + f"Encountered missing labels in OpenML task_id={self.task_id} ({self.tabarena_name})." + ) + y_encoded = cat.codes.astype(np.int64, copy=False) + detected_num_classes = int(len(cat.categories)) + if self.num_classes and detected_num_classes != self.num_classes: + raise RuntimeError( + f"Label cardinality mismatch for {self.tabarena_name}: expected {self.num_classes}, got {detected_num_classes}." + ) + self.num_classes = detected_num_classes + + if len(X_df) != len(y_encoded): + raise RuntimeError( + f"Feature/label row mismatch for {self.tabarena_name}: {len(X_df)} vs {len(y_encoded)}." + ) + + self._openml_task = task + self._X_df = X_df + self._y_encoded = y_encoded + + def get_openml_task(self): + self._ensure_openml_loaded() + return self._openml_task + + def get_target_array(self) -> np.ndarray: + self._ensure_openml_loaded() + assert self._y_encoded is not None + return self._y_encoded + + def get_openml_fold_indices(self, fold: int) -> tuple[np.ndarray, np.ndarray]: + fold = int(fold) + if fold < 0 or fold >= int(self.spec.fold_count): + raise ValueError( + f"Invalid fold={fold} for {self.name}. Valid folds are 0..{int(self.spec.fold_count) - 1}." + ) + + task = self.get_openml_task() + n_repeats, n_folds, _n_samples = task.get_split_dimensions() + repeat = fold // int(n_folds) + fold_in_repeat = fold % int(n_folds) + if repeat >= int(n_repeats): + raise ValueError( + f"Fold index {fold} exceeds OpenML split dimensions for {self.name}: repeats={n_repeats}, folds={n_folds}." + ) + + train_idx, test_idx = task.get_train_test_split_indices( + repeat=repeat, + fold=fold_in_repeat, + sample=0, + ) + return ( + np.asarray(train_idx, dtype=np.int64), + np.asarray(test_idx, dtype=np.int64), + ) + + def make_db(self) -> Database: + self._ensure_openml_loaded() + assert self._X_df is not None + + records = self._X_df.copy(deep=True) + records.insert(0, "record_id", np.arange(len(records), dtype=np.int64)) + + return Database( + { + "records": Table( + df=records, + fkey_col_to_pkey_table={}, + pkey_col="record_id", + time_col=None, + ) + } + ) diff --git a/relbench/tasks/__init__.py b/relbench/tasks/__init__.py index cb643d95..b654fcf2 100644 --- a/relbench/tasks/__init__.py +++ b/relbench/tasks/__init__.py @@ -9,6 +9,7 @@ from relbench.base import AutoCompleteTask, BaseTask, TaskType from relbench.datasets import get_dataset +from relbench.datasets.tabarena import TABARENA_DATASETS from relbench.tasks import ( amazon, arxiv, @@ -20,6 +21,7 @@ mimic, ratebeer, stack, + tabarena, tgb, trial, ) @@ -76,6 +78,13 @@ def download_task(dataset_name: str, task_name: str) -> None: `task.get_table(split)` is called. """ + if dataset_name.startswith("tabarena-"): + print( + f"Task '{dataset_name}/{task_name}' is derived from TabArena OpenML tasks " + "and must be generated locally; skipping download." + ) + return + DOWNLOAD_REGISTRY.fetch( f"{dataset_name}/tasks/{task_name}.zip", processor=pooch.Unzip(extract_dir="."), @@ -594,3 +603,13 @@ def _register_thgl_edge_type_tasks(dataset_name: str, edge_types: list[int]) -> spec=tgb.TGBNodePropSpec(), k=10, ) + +for dataset_slug, spec in TABARENA_DATASETS.items(): + dataset_name = f"tabarena-{dataset_slug}" + for fold in range(spec.fold_count): + register_task( + dataset_name, + f"fold-{fold}", + tabarena.TabArenaFoldEntityTask, + fold=fold, + ) diff --git a/relbench/tasks/tabarena.py b/relbench/tasks/tabarena.py new file mode 100644 index 00000000..c092a10f --- /dev/null +++ b/relbench/tasks/tabarena.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import Optional + +import numpy as np +import pandas as pd +from sklearn.metrics import log_loss as sklearn_log_loss +from sklearn.metrics import roc_auc_score +from sklearn.model_selection import train_test_split + +from relbench.base import EntityTask, Table, TaskType +from relbench.datasets.tabarena import TabArenaDataset + +_SPLIT_TIMESTAMPS = { + "train": pd.Timestamp("2000-01-01"), + "val": pd.Timestamp("2000-01-02"), + "test": pd.Timestamp("2000-01-03"), +} + + +def _binary_metric_error(true: np.ndarray, pred: np.ndarray) -> float: + pred = np.asarray(pred, dtype=np.float64) + if pred.ndim > 1: + pred = pred.reshape(pred.shape[0], -1) + if pred.shape[1] == 1: + pred = pred[:, 0] + else: + pred = pred[:, 1] + if pred.min() < 0.0 or pred.max() > 1.0: + pred = np.clip(pred, -40.0, 40.0) + pred = 1.0 / (1.0 + np.exp(-pred)) + score = roc_auc_score(np.asarray(true, dtype=np.int64), pred) + return float(1.0 - score) + + +def _softmax(x: np.ndarray) -> np.ndarray: + x = x - np.max(x, axis=1, keepdims=True) + exp_x = np.exp(x) + return exp_x / np.sum(exp_x, axis=1, keepdims=True) + + +def _multiclass_metric_error(true: np.ndarray, pred: np.ndarray) -> float: + return _multiclass_metric_error_with_num_classes(true, pred, num_classes=None) + + +def _multiclass_metric_error_with_num_classes( + true: np.ndarray, + pred: np.ndarray, + num_classes: Optional[int], +) -> float: + pred = np.asarray(pred, dtype=np.float64) + true_arr = np.asarray(true, dtype=np.int64) + + if num_classes is not None: + inferred_num_classes = int(num_classes) + elif pred.ndim == 2: + inferred_num_classes = int(pred.shape[1]) + else: + inferred_num_classes = int( + max(true_arr.max(initial=0), pred.max(initial=0)) + 1 + ) + if inferred_num_classes <= 1: + inferred_num_classes = 2 + + if pred.ndim == 1: + pred_labels = pred.astype(np.int64, copy=False) + pred_labels = np.clip(pred_labels, 0, inferred_num_classes - 1) + eps = 1e-7 + probs = np.full( + (len(pred_labels), inferred_num_classes), + fill_value=eps / max(inferred_num_classes - 1, 1), + dtype=np.float64, + ) + probs[np.arange(len(pred_labels), dtype=np.int64), pred_labels] = 1.0 - eps + elif pred.ndim == 2: + probs = pred + if probs.shape[1] < inferred_num_classes: + padding = np.full( + (probs.shape[0], inferred_num_classes - probs.shape[1]), + fill_value=0.0, + dtype=np.float64, + ) + probs = np.hstack([probs, padding]) + elif probs.shape[1] > inferred_num_classes: + probs = probs[:, :inferred_num_classes] + + row_sums = probs.sum(axis=1) + if np.all(probs >= 0.0) and np.allclose(row_sums, 1.0, atol=1e-4): + pass + else: + probs = _softmax(probs) + else: + raise ValueError( + "Expected multiclass predictions with shape (N,) or (N, num_classes). " + f"Got shape {pred.shape}." + ) + + labels = np.arange(inferred_num_classes, dtype=np.int64) + return float( + sklearn_log_loss( + true_arr, + probs, + labels=labels, + ) + ) + + +def _regression_metric_error(true: np.ndarray, pred: np.ndarray) -> float: + true = np.asarray(true, dtype=np.float64) + pred = np.asarray(pred, dtype=np.float64).reshape(-1) + return float(np.sqrt(np.mean((true - pred) ** 2))) + + +_binary_metric_error.__name__ = "metric_error" +_multiclass_metric_error.__name__ = "metric_error" +_regression_metric_error.__name__ = "metric_error" + + +class TabArenaFoldEntityTask(EntityTask): + r"""Single-table TabArena task for a specific OpenML fold index.""" + + entity_col = "record_id" + entity_table = "records" + time_col = "timestamp" + target_col = "target" + timedelta = pd.Timedelta(days=1) + num_eval_timestamps = 1 + + def __init__( + self, + dataset, + *, + fold: int, + val_frac: float = 0.2, + random_state: Optional[int] = None, + cache_dir: Optional[str] = None, + ): + if not isinstance(dataset, TabArenaDataset): + raise TypeError( + "TabArenaFoldEntityTask expects a TabArenaDataset instance. " + f"Got {type(dataset)}" + ) + + self.fold = int(fold) + self.val_frac = float(val_frac) + if not (0.0 < self.val_frac < 1.0): + raise ValueError(f"val_frac must be in (0, 1), got {self.val_frac}") + self.random_state = self.fold if random_state is None else int(random_state) + + if self.fold not in dataset.available_folds: + raise ValueError( + f"Fold={self.fold} is unavailable for {dataset.name}. " + f"Available folds: {dataset.available_folds}" + ) + + self.problem_type = dataset.problem_type + if self.problem_type == "regression": + self.task_type = TaskType.REGRESSION + self.metrics = [_regression_metric_error] + elif self.problem_type == "binary": + self.task_type = TaskType.BINARY_CLASSIFICATION + self.metrics = [_binary_metric_error] + elif self.problem_type == "multiclass": + self.task_type = TaskType.MULTICLASS_CLASSIFICATION + self.num_classes = int(dataset.num_classes) + self.metrics = [ + _make_multiclass_metric_error_with_num_classes(self.num_classes) + ] + else: + raise ValueError(f"Unsupported problem_type={self.problem_type}") + + super().__init__(dataset, cache_dir=cache_dir) + + def make_table(self, db, timestamps): # pragma: no cover + raise RuntimeError( + "TabArenaFoldEntityTask uses precomputed OpenML fold indices and overrides _get_table()." + ) + + @lru_cache(maxsize=None) + def _split_indices(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + train_idx, test_idx = self.dataset.get_openml_fold_indices(self.fold) + y = self.dataset.get_target_array() + + stratify = ( + y[train_idx] if self.problem_type in {"binary", "multiclass"} else None + ) + try: + train_idx, val_idx = train_test_split( + train_idx, + test_size=self.val_frac, + random_state=self.random_state, + shuffle=True, + stratify=stratify, + ) + except ValueError: + train_idx, val_idx = train_test_split( + train_idx, + test_size=self.val_frac, + random_state=self.random_state, + shuffle=True, + stratify=None, + ) + + return ( + np.asarray(train_idx, dtype=np.int64), + np.asarray(val_idx, dtype=np.int64), + np.asarray(test_idx, dtype=np.int64), + ) + + def _get_table(self, split: str) -> Table: + if split not in _SPLIT_TIMESTAMPS: + raise ValueError( + f"Unknown split={split!r}. Expected one of {sorted(_SPLIT_TIMESTAMPS.keys())}." + ) + + train_idx, val_idx, test_idx = self._split_indices() + if split == "train": + idx = train_idx + elif split == "val": + idx = val_idx + else: + idx = test_idx + + y = self.dataset.get_target_array() + target = y[idx] + + df = pd.DataFrame( + { + self.time_col: _SPLIT_TIMESTAMPS[split], + self.entity_col: idx.astype(np.int64, copy=False), + self.target_col: target, + } + ) + + if self.task_type != TaskType.REGRESSION: + df[self.target_col] = df[self.target_col].astype(np.int64, copy=False) + + return Table( + df=df, + fkey_col_to_pkey_table={self.entity_col: self.entity_table}, + pkey_col=None, + time_col=self.time_col, + ) + + +def _make_multiclass_metric_error_with_num_classes( + num_classes: int, +): + def _metric(true: np.ndarray, pred: np.ndarray) -> float: + return _multiclass_metric_error_with_num_classes( + true, + pred, + num_classes=num_classes, + ) + + _metric.__name__ = "metric_error" + return _metric diff --git a/test/datasets/test_tabarena.py b/test/datasets/test_tabarena.py new file mode 100644 index 00000000..55c8f8c3 --- /dev/null +++ b/test/datasets/test_tabarena.py @@ -0,0 +1,158 @@ +import numpy as np +import pandas as pd + +from relbench.datasets.tabarena import TabArenaDataset +from relbench.tasks.tabarena import TabArenaFoldEntityTask + + +class _FakeOpenMLDataset: + def __init__(self, X_df: pd.DataFrame, y_ser: pd.Series, target_name: str): + self._X_df = X_df + self._y_ser = y_ser + self._target_name = str(target_name) + + def get_data(self, *, target: str, dataset_format: str): + assert target == self._target_name + assert dataset_format == "dataframe" + categorical_indicator = [False] * len(self._X_df.columns) + attribute_names = list(self._X_df.columns) + return self._X_df, self._y_ser, categorical_indicator, attribute_names + + +class _FakeOpenMLTask: + def __init__( + self, + *, + target_name: str, + X_df: pd.DataFrame, + y_ser: pd.Series, + n_repeats: int, + n_folds: int, + ): + self.target_name = str(target_name) + self._dataset = _FakeOpenMLDataset( + X_df=X_df, y_ser=y_ser, target_name=target_name + ) + self._n_repeats = int(n_repeats) + self._n_folds = int(n_folds) + self._n_samples = int(len(X_df)) + + def get_dataset(self): + return self._dataset + + def get_split_dimensions(self): + return self._n_repeats, self._n_folds, self._n_samples + + def get_train_test_split_indices(self, *, repeat: int, fold: int, sample: int): + assert int(sample) == 0 + repeat = int(repeat) + fold = int(fold) + if repeat < 0 or repeat >= self._n_repeats: + raise ValueError( + f"repeat={repeat} out of range for n_repeats={self._n_repeats}" + ) + if fold < 0 or fold >= self._n_folds: + raise ValueError(f"fold={fold} out of range for n_folds={self._n_folds}") + + # Simple deterministic CV split: each fold takes indices i where i % n_folds == fold. + idx = np.arange(self._n_samples, dtype=np.int64) + test_idx = idx[idx % self._n_folds == fold] + train_idx = idx[idx % self._n_folds != fold] + return train_idx, test_idx + + +def _install_fake_openml(monkeypatch): + def _fake_load_task_with_retry( + self: TabArenaDataset, task_id: int, retries: int = 4 + ): + _ = task_id + _ = retries + n_samples = 90 + n_folds = int(self.spec.fold_count) + target_name = self.spec.target + + X_df = pd.DataFrame( + { + "feat_num": np.arange(n_samples, dtype=np.int64), + "feat_mod3": np.arange(n_samples, dtype=np.int64) % 3, + } + ) + if self.problem_type == "regression": + y_ser = pd.Series(np.linspace(0.0, 1.0, n_samples), name=target_name) + elif self.problem_type == "binary": + y_ser = pd.Series(["no", "yes"] * (n_samples // 2), name=target_name) + else: + classes = [f"class_{i}" for i in range(int(self.spec.num_classes))] + y_ser = pd.Series( + [classes[i % len(classes)] for i in range(n_samples)], name=target_name + ) + + return _FakeOpenMLTask( + target_name=target_name, + X_df=X_df, + y_ser=y_ser, + n_repeats=1, + n_folds=n_folds, + ) + + monkeypatch.setattr( + TabArenaDataset, "_load_task_with_retry", _fake_load_task_with_retry + ) + + +def test_tabarena_dataset_and_task_binary(monkeypatch): + _install_fake_openml(monkeypatch) + + dataset = TabArenaDataset(dataset_slug="apsfailure", cache_dir=None) + db = dataset.get_db() + records = db.table_dict["records"] + assert records.pkey_col == "record_id" + assert records.time_col is None + assert len(records) == 90 + + train_idx, test_idx = dataset.get_openml_fold_indices(0) + assert train_idx.dtype == np.int64 + assert test_idx.dtype == np.int64 + assert set(train_idx).isdisjoint(set(test_idx)) + + task = TabArenaFoldEntityTask(dataset, fold=0, cache_dir=None) + train_table = task.get_table("train") + assert set(train_table.df.columns) == {"timestamp", "record_id", "target"} + test_table = task.get_table("test") + assert set(test_table.df.columns) == {"timestamp", "record_id"} + + # Perfect predictions yield AUC=1.0 => metric_error=0.0. + full_test = task.get_table("test", mask_input_cols=False) + y_true = full_test.df["target"].to_numpy() + metrics = task.evaluate(y_true, target_table=full_test) + assert metrics["metric_error"] == 0.0 + + +def test_tabarena_dataset_and_task_regression(monkeypatch): + _install_fake_openml(monkeypatch) + + dataset = TabArenaDataset(dataset_slug="diamonds", cache_dir=None) + task = TabArenaFoldEntityTask(dataset, fold=0, cache_dir=None) + + full_test = task.get_table("test", mask_input_cols=False) + y_true = full_test.df["target"].to_numpy() + metrics = task.evaluate(y_true, target_table=full_test) + assert metrics["metric_error"] == 0.0 + + +def test_tabarena_task_multiclass(monkeypatch): + _install_fake_openml(monkeypatch) + + dataset = TabArenaDataset(dataset_slug="splice", cache_dir=None) + task = TabArenaFoldEntityTask(dataset, fold=0, cache_dir=None) + + full_test = task.get_table("test", mask_input_cols=False) + y_true = full_test.df["target"].to_numpy() + num_classes = int(dataset.num_classes) + + # Use a uniform distribution: log loss is finite and should be > 0. + probs = np.full( + (len(y_true), num_classes), fill_value=1.0 / num_classes, dtype=np.float64 + ) + metrics = task.evaluate(probs, target_table=full_test) + assert metrics["metric_error"] > 0.0 From f8557b67ce41f14706961c8052e9380a64369c95 Mon Sep 17 00:00:00 2001 From: pc0618 Date: Wed, 4 Mar 2026 04:25:00 +0000 Subject: [PATCH 2/8] Rename TabArena folds to splits and remove synthetic task timestamps --- README.md | 2 +- examples/translate_tabarena_to_relbench.py | 28 ++++---- relbench/datasets/tabarena.py | 25 ++++--- relbench/tasks/__init__.py | 8 +-- relbench/tasks/tabarena.py | 83 ++++++++++++++++------ test/datasets/test_tabarena.py | 14 ++-- 6 files changed, 103 insertions(+), 57 deletions(-) diff --git a/README.md b/README.md index a3130746..f3957199 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,7 @@ If you use the 4DBInfer datasets in your work, please cite [4DBInfer](https://gi **Using TabArena datasets** -RelBench includes an optional integration for TabArena: a collection of single-table OpenML tasks. TabArena datasets are exposed under names like `tabarena-credit-g`, with per-fold tasks named `fold-0`, `fold-1`, etc. +RelBench includes an optional integration for TabArena: a collection of single-table OpenML tasks. TabArena datasets are exposed under names like `tabarena-credit-g`, with per-split tasks named `split-0`, `split-1`, etc. To use TabArena datasets, install the optional dependency: ```bash diff --git a/examples/translate_tabarena_to_relbench.py b/examples/translate_tabarena_to_relbench.py index 66ddab44..a8d69c6c 100644 --- a/examples/translate_tabarena_to_relbench.py +++ b/examples/translate_tabarena_to_relbench.py @@ -26,10 +26,10 @@ def _parse_args() -> argparse.Namespace: default="results/tabarena_relbench_translation.csv", ) parser.add_argument( - "--include_fold_examples", + "--include_split_examples", action="store_true", default=False, - help="If set, writes one sample record for fold-0 per dataset.", + help="If set, writes one sample record for split-0 per dataset.", ) return parser.parse_args() @@ -59,7 +59,7 @@ def _summarize_dataset(dataset_name: str) -> dict: "target_col": spec.target, "problem_type": spec.task_type, "num_classes": spec.num_classes, - "fold_count": spec.fold_count, + "split_count": spec.fold_count, "records_rows": int(len(records.df)), "records_columns": int(len(records.df.columns)), "entity_table": "records", @@ -69,23 +69,23 @@ def _summarize_dataset(dataset_name: str) -> dict: return row -def _summarize_fold_example(dataset_name: str) -> list[dict]: +def _summarize_split_example(dataset_name: str) -> list[dict]: from relbench.tasks import get_task rows: list[dict] = [] - for fold in [0]: - task = get_task(dataset_name, f"fold-{fold}") + for split in [0]: + task = get_task(dataset_name, f"split-{split}") train = task.get_table("train", mask_input_cols=False) val = task.get_table("val", mask_input_cols=False) test = task.get_table("test", mask_input_cols=False) rows.append( { "dataset_name": dataset_name, - "fold": int(fold), + "split": int(split), "task_type": str(task.task_type.value), - "fold_train_rows": int(len(train)), - "fold_val_rows": int(len(val)), - "fold_test_rows": int(len(test)), + "split_train_rows": int(len(train)), + "split_val_rows": int(len(val)), + "split_test_rows": int(len(test)), "train_columns": int(len(train.df.columns)), "time_col": str(train.time_col), } @@ -105,17 +105,17 @@ def main() -> None: dataset_name = f"tabarena-{slug}" print(f"[Inspect] {dataset_name}") dataset_rows.append(_summarize_dataset(dataset_name)) - if args.include_fold_examples: - split_rows.extend(_summarize_fold_example(dataset_name)) + if args.include_split_examples: + split_rows.extend(_summarize_split_example(dataset_name)) df = pd.DataFrame(dataset_rows) df.to_csv(output_path, index=False) if split_rows: - split_path = output_path.with_name(f"{output_path.stem}_fold_samples.csv") + split_path = output_path.with_name(f"{output_path.stem}_split_samples.csv") pd.DataFrame(split_rows).to_csv(split_path, index=False) print(f"[Done] dataset summary: {output_path}") - print(f"[Done] fold sample summary: {split_path}") + print(f"[Done] split sample summary: {split_path}") else: print(f"[Done] dataset summary: {output_path}") diff --git a/relbench/datasets/tabarena.py b/relbench/datasets/tabarena.py index 514881a6..0f4157f8 100644 --- a/relbench/datasets/tabarena.py +++ b/relbench/datasets/tabarena.py @@ -593,9 +593,14 @@ def __init__(self, *, dataset_slug: str, cache_dir: Optional[str] = None): super().__init__(cache_dir=cache_dir) @property - def available_folds(self) -> list[int]: + def available_splits(self) -> list[int]: return list(range(int(self.spec.fold_count))) + @property + def available_folds(self) -> list[int]: + # Backward-compatible alias. + return self.available_splits + def _load_task_with_retry(self, task_id: int, retries: int = 4): openml = _import_openml() delay = 1.0 @@ -665,20 +670,20 @@ def get_target_array(self) -> np.ndarray: assert self._y_encoded is not None return self._y_encoded - def get_openml_fold_indices(self, fold: int) -> tuple[np.ndarray, np.ndarray]: - fold = int(fold) - if fold < 0 or fold >= int(self.spec.fold_count): + def get_openml_split_indices(self, split: int) -> tuple[np.ndarray, np.ndarray]: + split = int(split) + if split < 0 or split >= int(self.spec.fold_count): raise ValueError( - f"Invalid fold={fold} for {self.name}. Valid folds are 0..{int(self.spec.fold_count) - 1}." + f"Invalid split={split} for {self.name}. Valid splits are 0..{int(self.spec.fold_count) - 1}." ) task = self.get_openml_task() n_repeats, n_folds, _n_samples = task.get_split_dimensions() - repeat = fold // int(n_folds) - fold_in_repeat = fold % int(n_folds) + repeat = split // int(n_folds) + fold_in_repeat = split % int(n_folds) if repeat >= int(n_repeats): raise ValueError( - f"Fold index {fold} exceeds OpenML split dimensions for {self.name}: repeats={n_repeats}, folds={n_folds}." + f"Split index {split} exceeds OpenML split dimensions for {self.name}: repeats={n_repeats}, folds={n_folds}." ) train_idx, test_idx = task.get_train_test_split_indices( @@ -691,6 +696,10 @@ def get_openml_fold_indices(self, fold: int) -> tuple[np.ndarray, np.ndarray]: np.asarray(test_idx, dtype=np.int64), ) + def get_openml_fold_indices(self, fold: int) -> tuple[np.ndarray, np.ndarray]: + # Backward-compatible alias. + return self.get_openml_split_indices(split=fold) + def make_db(self) -> Database: self._ensure_openml_loaded() assert self._X_df is not None diff --git a/relbench/tasks/__init__.py b/relbench/tasks/__init__.py index b654fcf2..bf6a5f4d 100644 --- a/relbench/tasks/__init__.py +++ b/relbench/tasks/__init__.py @@ -606,10 +606,10 @@ def _register_thgl_edge_type_tasks(dataset_name: str, edge_types: list[int]) -> for dataset_slug, spec in TABARENA_DATASETS.items(): dataset_name = f"tabarena-{dataset_slug}" - for fold in range(spec.fold_count): + for split in range(spec.fold_count): register_task( dataset_name, - f"fold-{fold}", - tabarena.TabArenaFoldEntityTask, - fold=fold, + f"split-{split}", + tabarena.TabArenaSplitEntityTask, + split=split, ) diff --git a/relbench/tasks/tabarena.py b/relbench/tasks/tabarena.py index c092a10f..2ba9a077 100644 --- a/relbench/tasks/tabarena.py +++ b/relbench/tasks/tabarena.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import lru_cache -from typing import Optional +from typing import Any, Dict, Optional import numpy as np import pandas as pd @@ -12,12 +12,6 @@ from relbench.base import EntityTask, Table, TaskType from relbench.datasets.tabarena import TabArenaDataset -_SPLIT_TIMESTAMPS = { - "train": pd.Timestamp("2000-01-01"), - "val": pd.Timestamp("2000-01-02"), - "test": pd.Timestamp("2000-01-03"), -} - def _binary_metric_error(true: np.ndarray, pred: np.ndarray) -> float: pred = np.asarray(pred, dtype=np.float64) @@ -117,12 +111,12 @@ def _regression_metric_error(true: np.ndarray, pred: np.ndarray) -> float: _regression_metric_error.__name__ = "metric_error" -class TabArenaFoldEntityTask(EntityTask): - r"""Single-table TabArena task for a specific OpenML fold index.""" +class TabArenaSplitEntityTask(EntityTask): + r"""Single-table TabArena task for a specific OpenML split index.""" entity_col = "record_id" entity_table = "records" - time_col = "timestamp" + time_col = None target_col = "target" timedelta = pd.Timedelta(days=1) num_eval_timestamps = 1 @@ -131,27 +125,36 @@ def __init__( self, dataset, *, - fold: int, + split: Optional[int] = None, + fold: Optional[int] = None, val_frac: float = 0.2, random_state: Optional[int] = None, cache_dir: Optional[str] = None, ): if not isinstance(dataset, TabArenaDataset): raise TypeError( - "TabArenaFoldEntityTask expects a TabArenaDataset instance. " + "TabArenaSplitEntityTask expects a TabArenaDataset instance. " f"Got {type(dataset)}" ) - self.fold = int(fold) + if split is None and fold is None: + raise ValueError("Exactly one of `split` or `fold` must be provided.") + if split is not None and fold is not None and int(split) != int(fold): + raise ValueError( + f"Received conflicting split={split} and fold={fold}; please provide one index." + ) + + self.split = int(split if split is not None else fold) + self.fold = self.split # Backward-compatible alias. self.val_frac = float(val_frac) if not (0.0 < self.val_frac < 1.0): raise ValueError(f"val_frac must be in (0, 1), got {self.val_frac}") - self.random_state = self.fold if random_state is None else int(random_state) + self.random_state = self.split if random_state is None else int(random_state) - if self.fold not in dataset.available_folds: + if self.split not in dataset.available_splits: raise ValueError( - f"Fold={self.fold} is unavailable for {dataset.name}. " - f"Available folds: {dataset.available_folds}" + f"Split={self.split} is unavailable for {dataset.name}. " + f"Available splits: {dataset.available_splits}" ) self.problem_type = dataset.problem_type @@ -174,12 +177,12 @@ def __init__( def make_table(self, db, timestamps): # pragma: no cover raise RuntimeError( - "TabArenaFoldEntityTask uses precomputed OpenML fold indices and overrides _get_table()." + "TabArenaSplitEntityTask uses precomputed OpenML split indices and overrides _get_table()." ) @lru_cache(maxsize=None) def _split_indices(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - train_idx, test_idx = self.dataset.get_openml_fold_indices(self.fold) + train_idx, test_idx = self.dataset.get_openml_split_indices(self.split) y = self.dataset.get_target_array() stratify = ( @@ -209,9 +212,10 @@ def _split_indices(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) def _get_table(self, split: str) -> Table: - if split not in _SPLIT_TIMESTAMPS: + if split not in {"train", "val", "test"}: raise ValueError( - f"Unknown split={split!r}. Expected one of {sorted(_SPLIT_TIMESTAMPS.keys())}." + "Unknown split=" + f"{split!r}. Expected one of ['test', 'train', 'val']." ) train_idx, val_idx, test_idx = self._split_indices() @@ -227,7 +231,6 @@ def _get_table(self, split: str) -> Table: df = pd.DataFrame( { - self.time_col: _SPLIT_TIMESTAMPS[split], self.entity_col: idx.astype(np.int64, copy=False), self.target_col: target, } @@ -240,9 +243,43 @@ def _get_table(self, split: str) -> Table: df=df, fkey_col_to_pkey_table={self.entity_col: self.entity_table}, pkey_col=None, - time_col=self.time_col, + time_col=None, ) + def stats(self) -> Dict[str, Dict[str, Any]]: + r"""Get split-level statistics for tasks without a time column.""" + res: Dict[str, Dict[str, Any]] = {} + for split in ["train", "val", "test"]: + table = self.get_table(split, mask_input_cols=False) + stats: Dict[str, Any] = { + "num_rows": len(table.df), + "num_unique_entities": table.df[self.entity_col].nunique(), + } + self._set_stats(table.df, stats) + res[split] = {"total": stats} + + total_df = pd.concat( + [ + self.get_table(split, mask_input_cols=False).df + for split in ["train", "val", "test"] + ] + ) + res["total"] = {} + self._set_stats(total_df, res["total"]) + + train_uniques = set(self.get_table("train").df[self.entity_col].unique()) + test_uniques = set( + self.get_table("test", mask_input_cols=False).df[self.entity_col].unique() + ) + res["total"]["ratio_train_test_entity_overlap"] = len( + train_uniques.intersection(test_uniques) + ) / len(test_uniques) + return res + + +# Backward-compatible alias. +TabArenaFoldEntityTask = TabArenaSplitEntityTask + def _make_multiclass_metric_error_with_num_classes( num_classes: int, diff --git a/test/datasets/test_tabarena.py b/test/datasets/test_tabarena.py index 55c8f8c3..6bf7c2e4 100644 --- a/test/datasets/test_tabarena.py +++ b/test/datasets/test_tabarena.py @@ -2,7 +2,7 @@ import pandas as pd from relbench.datasets.tabarena import TabArenaDataset -from relbench.tasks.tabarena import TabArenaFoldEntityTask +from relbench.tasks.tabarena import TabArenaSplitEntityTask class _FakeOpenMLDataset: @@ -110,16 +110,16 @@ def test_tabarena_dataset_and_task_binary(monkeypatch): assert records.time_col is None assert len(records) == 90 - train_idx, test_idx = dataset.get_openml_fold_indices(0) + train_idx, test_idx = dataset.get_openml_split_indices(0) assert train_idx.dtype == np.int64 assert test_idx.dtype == np.int64 assert set(train_idx).isdisjoint(set(test_idx)) - task = TabArenaFoldEntityTask(dataset, fold=0, cache_dir=None) + task = TabArenaSplitEntityTask(dataset, split=0, cache_dir=None) train_table = task.get_table("train") - assert set(train_table.df.columns) == {"timestamp", "record_id", "target"} + assert set(train_table.df.columns) == {"record_id", "target"} test_table = task.get_table("test") - assert set(test_table.df.columns) == {"timestamp", "record_id"} + assert set(test_table.df.columns) == {"record_id"} # Perfect predictions yield AUC=1.0 => metric_error=0.0. full_test = task.get_table("test", mask_input_cols=False) @@ -132,7 +132,7 @@ def test_tabarena_dataset_and_task_regression(monkeypatch): _install_fake_openml(monkeypatch) dataset = TabArenaDataset(dataset_slug="diamonds", cache_dir=None) - task = TabArenaFoldEntityTask(dataset, fold=0, cache_dir=None) + task = TabArenaSplitEntityTask(dataset, split=0, cache_dir=None) full_test = task.get_table("test", mask_input_cols=False) y_true = full_test.df["target"].to_numpy() @@ -144,7 +144,7 @@ def test_tabarena_task_multiclass(monkeypatch): _install_fake_openml(monkeypatch) dataset = TabArenaDataset(dataset_slug="splice", cache_dir=None) - task = TabArenaFoldEntityTask(dataset, fold=0, cache_dir=None) + task = TabArenaSplitEntityTask(dataset, split=0, cache_dir=None) full_test = task.get_table("test", mask_input_cols=False) y_true = full_test.df["target"].to_numpy() From c7d62a90c6cea28928fc311e3fff03cd8ab5c55b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 04:25:26 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- relbench/tasks/tabarena.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/relbench/tasks/tabarena.py b/relbench/tasks/tabarena.py index 2ba9a077..5c08d711 100644 --- a/relbench/tasks/tabarena.py +++ b/relbench/tasks/tabarena.py @@ -214,8 +214,7 @@ def _split_indices(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: def _get_table(self, split: str) -> Table: if split not in {"train", "val", "test"}: raise ValueError( - "Unknown split=" - f"{split!r}. Expected one of ['test', 'train', 'val']." + "Unknown split=" f"{split!r}. Expected one of ['test', 'train', 'val']." ) train_idx, val_idx, test_idx = self._split_indices() From f796cba2deb0bcce2334bacf4e4bbcdfa547794b Mon Sep 17 00:00:00 2001 From: pc0618 Date: Wed, 4 Mar 2026 04:25:42 +0000 Subject: [PATCH 4/8] Guard TabArena stats overlap when test split is empty --- relbench/tasks/tabarena.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/relbench/tasks/tabarena.py b/relbench/tasks/tabarena.py index 5c08d711..3960546e 100644 --- a/relbench/tasks/tabarena.py +++ b/relbench/tasks/tabarena.py @@ -270,9 +270,10 @@ def stats(self) -> Dict[str, Dict[str, Any]]: test_uniques = set( self.get_table("test", mask_input_cols=False).df[self.entity_col].unique() ) - res["total"]["ratio_train_test_entity_overlap"] = len( - train_uniques.intersection(test_uniques) - ) / len(test_uniques) + overlap = len(train_uniques.intersection(test_uniques)) + res["total"]["ratio_train_test_entity_overlap"] = ( + overlap / len(test_uniques) if test_uniques else float("nan") + ) return res From 790b5dbc9ef99267408c40bc45269c4dfd36bd8c Mon Sep 17 00:00:00 2001 From: pc0618 Date: Wed, 4 Mar 2026 05:20:48 +0000 Subject: [PATCH 5/8] Make TabArena task tables edge-free and add PluRel-16B runbook --- README.md | 3 + examples/tabarena_plurel16b_inference.md | 150 +++++++++++++++++++++++ relbench/tasks/tabarena.py | 13 +- test/datasets/test_tabarena.py | 2 + 4 files changed, 167 insertions(+), 1 deletion(-) create mode 100644 examples/tabarena_plurel16b_inference.md diff --git a/README.md b/README.md index f3957199..10b936f7 100644 --- a/README.md +++ b/README.md @@ -163,6 +163,9 @@ pip install relbench[tabarena] TabArena datasets are generated locally (from OpenML) and cached under `~/.cache/relbench/tabarena-*/`. Passing `download=True` will skip the RelBench server download step for these datasets/tasks. +For an end-to-end PluRel-16B TabArena inference runbook (including `split-*` task naming, random sampling behavior, and `seq_len=2048/4096` commands), see: +[`examples/tabarena_plurel16b_inference.md`](examples/tabarena_plurel16b_inference.md) + # Package Usage diff --git a/examples/tabarena_plurel16b_inference.md b/examples/tabarena_plurel16b_inference.md new file mode 100644 index 00000000..178c77b7 --- /dev/null +++ b/examples/tabarena_plurel16b_inference.md @@ -0,0 +1,150 @@ +# TabArena PluRel-16B Inference (Updated Single-Table Context) + +This runbook is for running TabArena inference with the PluRel 16B checkpoint using +the latest RelBench TabArena task semantics: + +- task names are `split-N` (not `fold-N`) +- task tables are edge-free (`fkey_col_to_pkey_table = {}`) +- no synthetic task timestamps (`time_col=None`) + +This is intended to keep TabArena task context single-node for RT/rustler-style samplers. + +## 1) Pin Revisions + +Use these exact code revisions for reproducibility: + +- `relbench`: branch `tabarena-single-table`, commit `f796cba` +- `relational-transformer`: commit `f1243c35b8410610102cfc478cdb93c1e0ab3d50` + +## 2) Verify TabArena Is Edge-Free (Required) + +Run in the `relbench` environment: + +```bash +cd /home/pc0618/relbench +python - <<'PY' +from relbench.tasks import get_task +t = get_task("tabarena-apsfailure", "split-0", download=False) +train = t.get_table("train", mask_input_cols=False) +test = t.get_table("test") +print("train_fkeys:", train.fkey_col_to_pkey_table) +print("test_fkeys:", test.fkey_col_to_pkey_table) +print("time_col:", train.time_col) +assert train.fkey_col_to_pkey_table == {} +assert test.fkey_col_to_pkey_table == {} +assert train.time_col is None +print("OK") +PY +``` + +## 3) Patch RT Sweep Script For `split-*` Tasks + +`scripts/rt_tabarena_sweep.py` in `relational-transformer` currently references `fold-*`. +Patch it to use `split-*` for task-table access: + +```bash +cd /home/pc0618/relational-transformer +rg -n "fold-" scripts/rt_tabarena_sweep.py +``` + +Update these patterns: + +- `f"fold-{fold}"` -> `f"split-{fold}"` +- `/tasks/f"fold-{fold}"` -> `/tasks/f"split-{fold}"` + +Note: keeping the CSV column name as `fold` is fine; it can still hold split index values. + +## 4) Random Sampling = True + +Rustler sampling is stochastic by default (no extra flag required): + +- random neighbor subsampling is used when fanout exceeds `max_bfs_width` +- random traversal order is used in sampling +- sampling is controlled by `seed` and `shuffle_py(...)` + +So "use random sampling = true" corresponds to running the default sampler behavior +with a fixed `--seed` for reproducibility. + +## 5) Run PluRel-16B Inference (No Training) + +Recommended environment and CPU-safe settings: + +```bash +cd /home/pc0618/relational-transformer +source "$HOME/.cargo/env" +source .venv-rt-cpu-smoke/bin/activate + +export OMP_NUM_THREADS=1 +export OPENBLAS_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export NUMEXPR_NUM_THREADS=1 +export TOKENIZERS_PARALLELISM=false +export HF_HUB_DISABLE_IMPLICIT_TOKEN=1 +``` + +### Context length 2048 + +```bash +PYTHONPATH=. python scripts/rt_tabarena_sweep.py \ + --manifest_csv /home/pc0618/relbench/results/tabarena_all51/tabarena_manifest_no_multiclass.csv \ + --output_csv /home/pc0618/relbench/results/tabarena_all51/tabarena_rt_plurel_16b_split_ctx2048_eval200.csv \ + --resume --preprocess --embed \ + --load_ckpt_path /home/pc0618/scratch/rt_ckpts_plurel/synthetic-pretrain_rdb_512_size_16b.pt \ + --train_steps 0 --skip_train_eval --skip_val_eval \ + --eval_batches 200 \ + --torch_num_threads 1 \ + --seq_len 2048 --max_bfs_width 16 --batch_size 8 \ + --num_blocks 12 --d_model 256 --num_heads 8 --d_ff 1024 \ + --disable_full_attention --use_qk_norm \ + --seed 42 +``` + +### Context length 4096 + +```bash +PYTHONPATH=. python scripts/rt_tabarena_sweep.py \ + --manifest_csv /home/pc0618/relbench/results/tabarena_all51/tabarena_manifest_no_multiclass.csv \ + --output_csv /home/pc0618/relbench/results/tabarena_all51/tabarena_rt_plurel_16b_split_ctx4096_eval200.csv \ + --resume --preprocess --embed \ + --load_ckpt_path /home/pc0618/scratch/rt_ckpts_plurel/synthetic-pretrain_rdb_512_size_16b.pt \ + --train_steps 0 --skip_train_eval --skip_val_eval \ + --eval_batches 200 \ + --torch_num_threads 1 \ + --seq_len 4096 --max_bfs_width 16 --batch_size 4 \ + --num_blocks 12 --d_model 256 --num_heads 8 --d_ff 1024 \ + --disable_full_attention --use_qk_norm \ + --seed 42 +``` + +If memory is tight at `4096`, reduce `--batch_size` to `2`. + +## 6) Confirm Single-Node Context + +After patching the RT script to `split-*`, a quick check for one sample: + +```bash +cd /home/pc0618/relational-transformer +PYTHONPATH=. python - <<'PY' +import numpy as np +from rt.data import RelationalDataset + +ds = RelationalDataset( + tasks=[("tabarena-apsfailure", "split-0", "target", "test", [])], + batch_size=1, + seq_len=4096, + rank=0, + world_size=1, + max_bfs_width=16, + embedding_model="all-MiniLM-L12-v2", + d_text=384, + seed=42, +) +item = ds[0] +mask = (~item["is_padding"][0]).cpu().numpy() +uniq = np.unique(item["node_idxs"][0].cpu().numpy()[mask]) +print("unique_nodes_in_context =", len(uniq)) +assert len(uniq) == 1 +print("OK") +PY +``` + diff --git a/relbench/tasks/tabarena.py b/relbench/tasks/tabarena.py index 3960546e..12d6b0c5 100644 --- a/relbench/tasks/tabarena.py +++ b/relbench/tasks/tabarena.py @@ -240,11 +240,22 @@ def _get_table(self, split: str) -> Table: return Table( df=df, - fkey_col_to_pkey_table={self.entity_col: self.entity_table}, + # Keep task tables edge-free so single-table TabArena contexts stay single-node + # in external context samplers (e.g., RT rustler). + fkey_col_to_pkey_table={}, pkey_col=None, time_col=None, ) + def _mask_input_cols(self, table: Table) -> Table: + # Keep entity ids visible for inference while preserving edge-free task tables. + return Table( + df=table.df[[self.entity_col]], + fkey_col_to_pkey_table={}, + pkey_col=table.pkey_col, + time_col=table.time_col, + ) + def stats(self) -> Dict[str, Dict[str, Any]]: r"""Get split-level statistics for tasks without a time column.""" res: Dict[str, Dict[str, Any]] = {} diff --git a/test/datasets/test_tabarena.py b/test/datasets/test_tabarena.py index 6bf7c2e4..e759838c 100644 --- a/test/datasets/test_tabarena.py +++ b/test/datasets/test_tabarena.py @@ -118,8 +118,10 @@ def test_tabarena_dataset_and_task_binary(monkeypatch): task = TabArenaSplitEntityTask(dataset, split=0, cache_dir=None) train_table = task.get_table("train") assert set(train_table.df.columns) == {"record_id", "target"} + assert train_table.fkey_col_to_pkey_table == {} test_table = task.get_table("test") assert set(test_table.df.columns) == {"record_id"} + assert test_table.fkey_col_to_pkey_table == {} # Perfect predictions yield AUC=1.0 => metric_error=0.0. full_test = task.get_table("test", mask_input_cols=False) From 7aa14255be44c6605db53758a2bb9d3b9b8a9556 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 05:21:07 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/tabarena_plurel16b_inference.md | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/tabarena_plurel16b_inference.md b/examples/tabarena_plurel16b_inference.md index 178c77b7..e49d86b5 100644 --- a/examples/tabarena_plurel16b_inference.md +++ b/examples/tabarena_plurel16b_inference.md @@ -147,4 +147,3 @@ assert len(uniq) == 1 print("OK") PY ``` - From 4e607bdd83c70cadd04d526af299250409e6e2ca Mon Sep 17 00:00:00 2001 From: Pranshu Chaturvedi Date: Mon, 30 Mar 2026 10:48:15 -0700 Subject: [PATCH 7/8] TabArena: replace internal example with public validation scripts --- README.md | 13 +- examples/tabarena_plurel16b_inference.md | 149 ------------ examples/translate_tabarena_to_relbench.py | 219 +++++++++-------- examples/validate_tabarena_baseline.py | 265 +++++++++++++++++++++ relbench/tasks/tabarena.py | 2 +- 5 files changed, 397 insertions(+), 251 deletions(-) delete mode 100644 examples/tabarena_plurel16b_inference.md create mode 100644 examples/validate_tabarena_baseline.py diff --git a/README.md b/README.md index 10b936f7..de59be7a 100644 --- a/README.md +++ b/README.md @@ -161,10 +161,17 @@ To use TabArena datasets, install the optional dependency: pip install relbench[tabarena] ``` -TabArena datasets are generated locally (from OpenML) and cached under `~/.cache/relbench/tabarena-*/`. Passing `download=True` will skip the RelBench server download step for these datasets/tasks. +See [`examples/translate_tabarena_to_relbench.py`](examples/translate_tabarena_to_relbench.py) for a comparison between the original OpenML task and the relbenchified `records` and `split-*` tables, and [`examples/validate_tabarena_baseline.py`](examples/validate_tabarena_baseline.py) for a public-baseline validation script that checks inference consistency between the original and relbenchified views. -For an end-to-end PluRel-16B TabArena inference runbook (including `split-*` task naming, random sampling behavior, and `seq_len=2048/4096` commands), see: -[`examples/tabarena_plurel16b_inference.md`](examples/tabarena_plurel16b_inference.md) +If you use the TabArena datasets in your work, please cite TabArena as below: +``` +@inproceedings{erickson2025tabarena, + title={TabArena: A Living Benchmark for Machine Learning on Tabular Data}, + author={Erickson, Nick and Purucker, Lennart and Tschalzev, Andrej and Holzm{\"u}ller, David and Mutalik Desai, Prateek and Salinas, David and Hutter, Frank}, + booktitle={Advances in Neural Information Processing Systems}, + year={2025} +} +``` # Package Usage diff --git a/examples/tabarena_plurel16b_inference.md b/examples/tabarena_plurel16b_inference.md deleted file mode 100644 index e49d86b5..00000000 --- a/examples/tabarena_plurel16b_inference.md +++ /dev/null @@ -1,149 +0,0 @@ -# TabArena PluRel-16B Inference (Updated Single-Table Context) - -This runbook is for running TabArena inference with the PluRel 16B checkpoint using -the latest RelBench TabArena task semantics: - -- task names are `split-N` (not `fold-N`) -- task tables are edge-free (`fkey_col_to_pkey_table = {}`) -- no synthetic task timestamps (`time_col=None`) - -This is intended to keep TabArena task context single-node for RT/rustler-style samplers. - -## 1) Pin Revisions - -Use these exact code revisions for reproducibility: - -- `relbench`: branch `tabarena-single-table`, commit `f796cba` -- `relational-transformer`: commit `f1243c35b8410610102cfc478cdb93c1e0ab3d50` - -## 2) Verify TabArena Is Edge-Free (Required) - -Run in the `relbench` environment: - -```bash -cd /home/pc0618/relbench -python - <<'PY' -from relbench.tasks import get_task -t = get_task("tabarena-apsfailure", "split-0", download=False) -train = t.get_table("train", mask_input_cols=False) -test = t.get_table("test") -print("train_fkeys:", train.fkey_col_to_pkey_table) -print("test_fkeys:", test.fkey_col_to_pkey_table) -print("time_col:", train.time_col) -assert train.fkey_col_to_pkey_table == {} -assert test.fkey_col_to_pkey_table == {} -assert train.time_col is None -print("OK") -PY -``` - -## 3) Patch RT Sweep Script For `split-*` Tasks - -`scripts/rt_tabarena_sweep.py` in `relational-transformer` currently references `fold-*`. -Patch it to use `split-*` for task-table access: - -```bash -cd /home/pc0618/relational-transformer -rg -n "fold-" scripts/rt_tabarena_sweep.py -``` - -Update these patterns: - -- `f"fold-{fold}"` -> `f"split-{fold}"` -- `/tasks/f"fold-{fold}"` -> `/tasks/f"split-{fold}"` - -Note: keeping the CSV column name as `fold` is fine; it can still hold split index values. - -## 4) Random Sampling = True - -Rustler sampling is stochastic by default (no extra flag required): - -- random neighbor subsampling is used when fanout exceeds `max_bfs_width` -- random traversal order is used in sampling -- sampling is controlled by `seed` and `shuffle_py(...)` - -So "use random sampling = true" corresponds to running the default sampler behavior -with a fixed `--seed` for reproducibility. - -## 5) Run PluRel-16B Inference (No Training) - -Recommended environment and CPU-safe settings: - -```bash -cd /home/pc0618/relational-transformer -source "$HOME/.cargo/env" -source .venv-rt-cpu-smoke/bin/activate - -export OMP_NUM_THREADS=1 -export OPENBLAS_NUM_THREADS=1 -export MKL_NUM_THREADS=1 -export NUMEXPR_NUM_THREADS=1 -export TOKENIZERS_PARALLELISM=false -export HF_HUB_DISABLE_IMPLICIT_TOKEN=1 -``` - -### Context length 2048 - -```bash -PYTHONPATH=. python scripts/rt_tabarena_sweep.py \ - --manifest_csv /home/pc0618/relbench/results/tabarena_all51/tabarena_manifest_no_multiclass.csv \ - --output_csv /home/pc0618/relbench/results/tabarena_all51/tabarena_rt_plurel_16b_split_ctx2048_eval200.csv \ - --resume --preprocess --embed \ - --load_ckpt_path /home/pc0618/scratch/rt_ckpts_plurel/synthetic-pretrain_rdb_512_size_16b.pt \ - --train_steps 0 --skip_train_eval --skip_val_eval \ - --eval_batches 200 \ - --torch_num_threads 1 \ - --seq_len 2048 --max_bfs_width 16 --batch_size 8 \ - --num_blocks 12 --d_model 256 --num_heads 8 --d_ff 1024 \ - --disable_full_attention --use_qk_norm \ - --seed 42 -``` - -### Context length 4096 - -```bash -PYTHONPATH=. python scripts/rt_tabarena_sweep.py \ - --manifest_csv /home/pc0618/relbench/results/tabarena_all51/tabarena_manifest_no_multiclass.csv \ - --output_csv /home/pc0618/relbench/results/tabarena_all51/tabarena_rt_plurel_16b_split_ctx4096_eval200.csv \ - --resume --preprocess --embed \ - --load_ckpt_path /home/pc0618/scratch/rt_ckpts_plurel/synthetic-pretrain_rdb_512_size_16b.pt \ - --train_steps 0 --skip_train_eval --skip_val_eval \ - --eval_batches 200 \ - --torch_num_threads 1 \ - --seq_len 4096 --max_bfs_width 16 --batch_size 4 \ - --num_blocks 12 --d_model 256 --num_heads 8 --d_ff 1024 \ - --disable_full_attention --use_qk_norm \ - --seed 42 -``` - -If memory is tight at `4096`, reduce `--batch_size` to `2`. - -## 6) Confirm Single-Node Context - -After patching the RT script to `split-*`, a quick check for one sample: - -```bash -cd /home/pc0618/relational-transformer -PYTHONPATH=. python - <<'PY' -import numpy as np -from rt.data import RelationalDataset - -ds = RelationalDataset( - tasks=[("tabarena-apsfailure", "split-0", "target", "test", [])], - batch_size=1, - seq_len=4096, - rank=0, - world_size=1, - max_bfs_width=16, - embedding_model="all-MiniLM-L12-v2", - d_text=384, - seed=42, -) -item = ds[0] -mask = (~item["is_padding"][0]).cpu().numpy() -uniq = np.unique(item["node_idxs"][0].cpu().numpy()[mask]) -print("unique_nodes_in_context =", len(uniq)) -assert len(uniq) == 1 -print("OK") -PY -``` diff --git a/examples/translate_tabarena_to_relbench.py b/examples/translate_tabarena_to_relbench.py index a8d69c6c..78495f69 100644 --- a/examples/translate_tabarena_to_relbench.py +++ b/examples/translate_tabarena_to_relbench.py @@ -1,123 +1,146 @@ -"""Utilities to inspect how TabArena datasets are represented in RelBench.""" +"""Inspect how a TabArena OpenML task is represented in RelBench. + +This script compares the original OpenML dataset/task with the RelBench wrapper: +the source rows become a single ``records`` table, while each ``split-*`` task +materializes a thin task table keyed by ``record_id``. +""" + +from __future__ import annotations import argparse -from pathlib import Path +from dataclasses import asdict +import numpy as np import pandas as pd -from relbench.datasets import get_dataset -from relbench.datasets.tabarena import TABARENA_DATASETS, get_tabarena_dataset_slugs +from relbench.datasets.tabarena import TabArenaDataset, get_tabarena_dataset_slugs +from relbench.tasks.tabarena import TabArenaSplitEntityTask -def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--dataset_slugs", + "--dataset", type=str, - default="all", - help=( - "Comma-separated TabArena slugs or 'all'. " - "Example: credit-g,airfoil-self-noise" - ), + default="credit-g", + choices=get_tabarena_dataset_slugs(), + help="TabArena dataset slug, for example `credit-g` or `airfoil-self-noise`.", ) parser.add_argument( - "--output_csv", - type=str, - default="results/tabarena_relbench_translation.csv", + "--split", + type=int, + default=0, + help="OpenML split index exposed in RelBench as `split-`.", ) parser.add_argument( - "--include_split_examples", - action="store_true", - default=False, - help="If set, writes one sample record for split-0 per dataset.", + "--show_rows", + type=int, + default=3, + help="Number of joined examples to print from the RelBench train split.", ) return parser.parse_args() -def _parse_dataset_slugs(arg: str) -> list[str]: - if arg.strip().lower() == "all": - return get_tabarena_dataset_slugs() - slugs = [slug.strip() for slug in arg.split(",") if slug.strip()] - valid = set(get_tabarena_dataset_slugs()) - invalid = [slug for slug in slugs if slug not in valid] - if invalid: - raise ValueError(f"Unknown dataset slugs: {invalid}") - return slugs - - -def _summarize_dataset(dataset_name: str) -> dict: - dataset = get_dataset(dataset_name, download=False) - spec = TABARENA_DATASETS[dataset.name.replace("tabarena-", "")] - db = dataset.make_db() - records = db.table_dict["records"] - row = { - "dataset_slug": spec.slug, +def _load_openml_frame(dataset: TabArenaDataset) -> tuple[pd.DataFrame, pd.Series]: + task = dataset.get_openml_task() + X_df, y_ser, _cat, _names = task.get_dataset().get_data( + target=task.target_name, + dataset_format="dataframe", + ) + X_df = pd.DataFrame(X_df).reset_index(drop=True) + y_ser = pd.Series(y_ser, name=task.target_name).reset_index(drop=True) + return X_df, y_ser + + +def _join_records(records_df: pd.DataFrame, task_df: pd.DataFrame) -> pd.DataFrame: + joined = task_df.merge(records_df, on="record_id", how="left", validate="1:1") + feature_cols = [col for col in joined.columns if col not in {"record_id", "target"}] + return joined[["record_id", *feature_cols, "target"]] + + +def _check_translation( + dataset: TabArenaDataset, + task: TabArenaSplitEntityTask, +) -> dict[str, object]: + X_df, _y_ser = _load_openml_frame(dataset) + y_encoded = pd.Series(dataset.get_target_array(), name="target").reset_index(drop=True) + records_df = dataset.get_db().table_dict["records"].df.reset_index(drop=True) + + openml_train_idx, openml_test_idx = dataset.get_openml_split_indices(task.split) + train_table = task.get_table("train", mask_input_cols=False) + val_table = task.get_table("val", mask_input_cols=False) + test_table = task.get_table("test", mask_input_cols=False) + + relbench_train_ids = train_table.df["record_id"].to_numpy() + relbench_val_ids = val_table.df["record_id"].to_numpy() + relbench_test_ids = test_table.df["record_id"].to_numpy() + relbench_trainval_ids = set(relbench_train_ids).union(relbench_val_ids) + + target_matches = True + for split_name, split_df in { + "train": train_table.df, + "val": val_table.df, + "test": test_table.df, + }.items(): + relbench_target = split_df["target"].reset_index(drop=True) + source_target = y_encoded.iloc[split_df["record_id"].to_numpy()].reset_index(drop=True) + if not relbench_target.equals(source_target): + target_matches = False + print(f"[check] target mismatch on {split_name}") + + return { + "dataset_slug": dataset.spec.slug, "dataset_name": dataset.name, - "tabarena_benchmark_name": spec.name, - "openml_task_id": spec.task_id, - "openml_dataset_id": spec.dataset_id, - "target_col": spec.target, - "problem_type": spec.task_type, - "num_classes": spec.num_classes, - "split_count": spec.fold_count, - "records_rows": int(len(records.df)), - "records_columns": int(len(records.df.columns)), - "entity_table": "records", - "entity_pkey": records.pkey_col, - "split_timestamp_columns": bool(records.time_col is not None), + "tabarena_name": dataset.tabarena_name, + "problem_type": dataset.problem_type, + "openml_task_id": dataset.task_id, + "openml_dataset_id": dataset.openml_dataset_id, + "target_name": dataset.target_name, + "records_rows": len(records_df), + "records_feature_columns": len(records_df.columns) - 1, + "openml_rows": len(X_df), + "openml_train_rows": len(openml_train_idx), + "openml_test_rows": len(openml_test_idx), + "relbench_train_rows": len(train_table.df), + "relbench_val_rows": len(val_table.df), + "relbench_test_rows": len(test_table.df), + "records_match_openml_rows": len(records_df) == len(X_df), + "record_ids_are_row_indices": np.array_equal( + records_df["record_id"].to_numpy(), + np.arange(len(records_df), dtype=np.int64), + ), + "relbench_test_matches_openml_test": set(relbench_test_ids) == set(openml_test_idx), + "relbench_train_val_partition_openml_train": relbench_trainval_ids == set( + openml_train_idx + ), + "relbench_train_val_are_disjoint": set(relbench_train_ids).isdisjoint( + relbench_val_ids + ), + "targets_match_source_rows": target_matches, } - return row - - -def _summarize_split_example(dataset_name: str) -> list[dict]: - from relbench.tasks import get_task - - rows: list[dict] = [] - for split in [0]: - task = get_task(dataset_name, f"split-{split}") - train = task.get_table("train", mask_input_cols=False) - val = task.get_table("val", mask_input_cols=False) - test = task.get_table("test", mask_input_cols=False) - rows.append( - { - "dataset_name": dataset_name, - "split": int(split), - "task_type": str(task.task_type.value), - "split_train_rows": int(len(train)), - "split_val_rows": int(len(val)), - "split_test_rows": int(len(test)), - "train_columns": int(len(train.df.columns)), - "time_col": str(train.time_col), - } - ) - return rows def main() -> None: - args = _parse_args() - dataset_slugs = _parse_dataset_slugs(args.dataset_slugs) - output_path = Path(args.output_csv) - output_path.parent.mkdir(parents=True, exist_ok=True) - - dataset_rows: list[dict] = [] - split_rows: list[dict] = [] - for slug in dataset_slugs: - dataset_name = f"tabarena-{slug}" - print(f"[Inspect] {dataset_name}") - dataset_rows.append(_summarize_dataset(dataset_name)) - if args.include_split_examples: - split_rows.extend(_summarize_split_example(dataset_name)) - - df = pd.DataFrame(dataset_rows) - df.to_csv(output_path, index=False) - - if split_rows: - split_path = output_path.with_name(f"{output_path.stem}_split_samples.csv") - pd.DataFrame(split_rows).to_csv(split_path, index=False) - print(f"[Done] dataset summary: {output_path}") - print(f"[Done] split sample summary: {split_path}") - else: - print(f"[Done] dataset summary: {output_path}") + args = parse_args() + + dataset = TabArenaDataset(dataset_slug=args.dataset) + task = TabArenaSplitEntityTask(dataset, split=args.split) + records_df = dataset.get_db().table_dict["records"].df.reset_index(drop=True) + + print("[dataset spec]") + print(asdict(dataset.spec)) + print() + + summary = _check_translation(dataset, task) + print("[translation summary]") + for key, value in summary.items(): + print(f"{key}: {value}") + print() + + train_table = task.get_table("train", mask_input_cols=False) + joined_train = _join_records(records_df, train_table.df).head(args.show_rows) + print("[joined relbench train rows]") + print(joined_train.to_string(index=False)) if __name__ == "__main__": diff --git a/examples/validate_tabarena_baseline.py b/examples/validate_tabarena_baseline.py new file mode 100644 index 00000000..ca8891bc --- /dev/null +++ b/examples/validate_tabarena_baseline.py @@ -0,0 +1,265 @@ +"""Validate the TabArena RelBench wrapper with a public baseline. + +The script trains a baseline on the original OpenML rows selected by a RelBench +task split, then runs inference on both the original and relbenchified views of +the same validation and test rows. Matching predictions and metrics confirm that +the RelBench wrapper preserves the original single-table task semantics. +""" + +from __future__ import annotations + +import argparse + +import numpy as np +import pandas as pd +from pandas import CategoricalDtype +from sklearn.compose import ColumnTransformer +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.impute import SimpleImputer +from sklearn.metrics import mean_squared_error, roc_auc_score +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import OneHotEncoder + +from relbench.datasets.tabarena import TabArenaDataset, get_tabarena_dataset_slugs +from relbench.tasks.tabarena import TabArenaSplitEntityTask + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--dataset", + type=str, + default="credit-g", + choices=get_tabarena_dataset_slugs(), + help="TabArena dataset slug, for example `credit-g` or `airfoil-self-noise`.", + ) + parser.add_argument( + "--split", + type=int, + default=0, + help="OpenML split index exposed in RelBench as `split-`.", + ) + parser.add_argument( + "--model", + type=str, + default="random-forest", + choices=["random-forest", "xgboost"], + help="Public baseline model used for validation.", + ) + parser.add_argument( + "--random_state", + type=int, + default=0, + ) + return parser.parse_args() + + +def _load_openml_frame(dataset: TabArenaDataset) -> pd.DataFrame: + task = dataset.get_openml_task() + X_df, y_ser, _cat, _names = task.get_dataset().get_data( + target=task.target_name, + dataset_format="dataframe", + ) + X_df = pd.DataFrame(X_df).reset_index(drop=True) + y_ser = pd.Series(y_ser, name=task.target_name).reset_index(drop=True) + _ = y_ser + return X_df + + +def _join_records(records_df: pd.DataFrame, task_df: pd.DataFrame) -> tuple[pd.DataFrame, np.ndarray]: + joined = task_df.merge(records_df, on="record_id", how="left", validate="1:1") + X_df = joined.drop(columns=["record_id", "target"]) + y = joined["target"].to_numpy() + return X_df, y + + +def _build_preprocessor(X_df: pd.DataFrame) -> ColumnTransformer: + categorical_cols = [ + col + for col in X_df.columns + if pd.api.types.is_object_dtype(X_df[col]) + or isinstance(X_df[col].dtype, CategoricalDtype) + or pd.api.types.is_bool_dtype(X_df[col]) + ] + numeric_cols = [col for col in X_df.columns if col not in categorical_cols] + + return ColumnTransformer( + transformers=[ + ( + "numeric", + Pipeline( + [("imputer", SimpleImputer(strategy="median"))] + ), + numeric_cols, + ), + ( + "categorical", + Pipeline( + [ + ("imputer", SimpleImputer(strategy="most_frequent")), + ( + "encoder", + OneHotEncoder(handle_unknown="ignore", sparse_output=True), + ), + ] + ), + categorical_cols, + ), + ] + ) + + +def _build_model( + *, + task: TabArenaSplitEntityTask, + X_df: pd.DataFrame, + model_name: str, + random_state: int, +) -> Pipeline: + preprocessor = _build_preprocessor(X_df) + + if model_name == "random-forest": + if task.task_type.value == "regression": + estimator = RandomForestRegressor( + n_estimators=200, + random_state=random_state, + n_jobs=-1, + ) + else: + estimator = RandomForestClassifier( + n_estimators=200, + random_state=random_state, + n_jobs=-1, + ) + elif model_name == "xgboost": + try: + from xgboost import XGBClassifier, XGBRegressor + except ImportError as exc: # pragma: no cover - dependency is optional + raise ImportError( + "The `xgboost` example requires `pip install xgboost`." + ) from exc + + if task.task_type.value == "regression": + estimator = XGBRegressor( + n_estimators=300, + max_depth=6, + learning_rate=0.05, + subsample=0.8, + colsample_bytree=0.8, + tree_method="hist", + random_state=random_state, + ) + else: + estimator = XGBClassifier( + n_estimators=300, + max_depth=6, + learning_rate=0.05, + subsample=0.8, + colsample_bytree=0.8, + tree_method="hist", + eval_metric="logloss", + random_state=random_state, + ) + else: # pragma: no cover - guarded by argparse + raise ValueError(f"Unsupported model {model_name!r}") + + return Pipeline( + [ + ("preprocess", preprocessor), + ("model", estimator), + ] + ) + + +def _predict( + model: Pipeline, + task: TabArenaSplitEntityTask, + X_df: pd.DataFrame, +) -> np.ndarray: + if task.task_type.value == "regression": + return np.asarray(model.predict(X_df), dtype=np.float64) + proba = model.predict_proba(X_df) + if proba.ndim != 2 or proba.shape[1] != 2: + raise RuntimeError(f"Expected binary predict_proba output, got shape={proba.shape}") + return np.asarray(proba[:, 1], dtype=np.float64) + + +def _metric_name_and_value( + task: TabArenaSplitEntityTask, y_true: np.ndarray, pred: np.ndarray +) -> tuple[str, float]: + if task.task_type.value == "regression": + return "rmse", float(np.sqrt(mean_squared_error(y_true, pred))) + return "auroc", float(roc_auc_score(y_true, pred)) + + +def _max_prediction_delta(a: np.ndarray, b: np.ndarray) -> float: + a = np.asarray(a, dtype=np.float64).reshape(-1) + b = np.asarray(b, dtype=np.float64).reshape(-1) + return float(np.max(np.abs(a - b))) if len(a) else 0.0 + + +def main() -> None: + args = parse_args() + + dataset = TabArenaDataset(dataset_slug=args.dataset) + task = TabArenaSplitEntityTask(dataset, split=args.split) + records_df = dataset.get_db().table_dict["records"].df.reset_index(drop=True) + openml_X = _load_openml_frame(dataset) + openml_y = pd.Series(dataset.get_target_array(), name="target").reset_index(drop=True) + + relbench_train = task.get_table("train", mask_input_cols=False).df + relbench_val = task.get_table("val", mask_input_cols=False).df + relbench_test = task.get_table("test", mask_input_cols=False).df + + X_train_rb, y_train_rb = _join_records(records_df, relbench_train) + X_val_rb, y_val_rb = _join_records(records_df, relbench_val) + X_test_rb, y_test_rb = _join_records(records_df, relbench_test) + + X_train_orig = openml_X.iloc[relbench_train["record_id"].to_numpy()].reset_index(drop=True) + y_train_orig = openml_y.iloc[relbench_train["record_id"].to_numpy()].to_numpy() + X_val_orig = openml_X.iloc[relbench_val["record_id"].to_numpy()].reset_index(drop=True) + y_val_orig = openml_y.iloc[relbench_val["record_id"].to_numpy()].to_numpy() + X_test_orig = openml_X.iloc[relbench_test["record_id"].to_numpy()].reset_index(drop=True) + y_test_orig = openml_y.iloc[relbench_test["record_id"].to_numpy()].to_numpy() + + print("[data equality checks]") + print(f"train features identical: {X_train_orig.equals(X_train_rb)}") + print(f"val features identical: {X_val_orig.equals(X_val_rb)}") + print(f"test features identical: {X_test_orig.equals(X_test_rb)}") + print(f"train labels identical: {np.array_equal(y_train_orig, y_train_rb)}") + print(f"val labels identical: {np.array_equal(y_val_orig, y_val_rb)}") + print(f"test labels identical: {np.array_equal(y_test_orig, y_test_rb)}") + print() + + model = _build_model( + task=task, + X_df=X_train_orig, + model_name=args.model, + random_state=args.random_state, + ) + model.fit(X_train_orig, y_train_orig) + + val_pred_orig = _predict(model, task, X_val_orig) + val_pred_rb = _predict(model, task, X_val_rb) + test_pred_orig = _predict(model, task, X_test_orig) + test_pred_rb = _predict(model, task, X_test_rb) + + metric_name, val_metric_orig = _metric_name_and_value(task, y_val_orig, val_pred_orig) + _, val_metric_rb = _metric_name_and_value(task, y_val_rb, val_pred_rb) + _, test_metric_orig = _metric_name_and_value(task, y_test_orig, test_pred_orig) + _, test_metric_rb = _metric_name_and_value(task, y_test_rb, test_pred_rb) + + print("[prediction consistency]") + print(f"val max |orig - relbench|: {_max_prediction_delta(val_pred_orig, val_pred_rb):.6g}") + print(f"test max |orig - relbench|: {_max_prediction_delta(test_pred_orig, test_pred_rb):.6g}") + print() + + print(f"[{metric_name}]") + print(f"validation original-openml: {val_metric_orig:.6f}") + print(f"validation relbenchified: {val_metric_rb:.6f}") + print(f"test original-openml: {test_metric_orig:.6f}") + print(f"test relbenchified: {test_metric_rb:.6f}") + + +if __name__ == "__main__": + main() diff --git a/relbench/tasks/tabarena.py b/relbench/tasks/tabarena.py index 12d6b0c5..3b6cbec9 100644 --- a/relbench/tasks/tabarena.py +++ b/relbench/tasks/tabarena.py @@ -236,7 +236,7 @@ def _get_table(self, split: str) -> Table: ) if self.task_type != TaskType.REGRESSION: - df[self.target_col] = df[self.target_col].astype(np.int64, copy=False) + df[self.target_col] = df[self.target_col].astype(np.int64) return Table( df=df, From 4558a9d5d3d0193fbb6c9221afea0c9ca63f8064 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 05:19:21 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/translate_tabarena_to_relbench.py | 16 +++++--- examples/validate_tabarena_baseline.py | 48 ++++++++++++++-------- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/examples/translate_tabarena_to_relbench.py b/examples/translate_tabarena_to_relbench.py index 78495f69..420447d5 100644 --- a/examples/translate_tabarena_to_relbench.py +++ b/examples/translate_tabarena_to_relbench.py @@ -63,7 +63,9 @@ def _check_translation( task: TabArenaSplitEntityTask, ) -> dict[str, object]: X_df, _y_ser = _load_openml_frame(dataset) - y_encoded = pd.Series(dataset.get_target_array(), name="target").reset_index(drop=True) + y_encoded = pd.Series(dataset.get_target_array(), name="target").reset_index( + drop=True + ) records_df = dataset.get_db().table_dict["records"].df.reset_index(drop=True) openml_train_idx, openml_test_idx = dataset.get_openml_split_indices(task.split) @@ -83,7 +85,9 @@ def _check_translation( "test": test_table.df, }.items(): relbench_target = split_df["target"].reset_index(drop=True) - source_target = y_encoded.iloc[split_df["record_id"].to_numpy()].reset_index(drop=True) + source_target = y_encoded.iloc[split_df["record_id"].to_numpy()].reset_index( + drop=True + ) if not relbench_target.equals(source_target): target_matches = False print(f"[check] target mismatch on {split_name}") @@ -109,10 +113,10 @@ def _check_translation( records_df["record_id"].to_numpy(), np.arange(len(records_df), dtype=np.int64), ), - "relbench_test_matches_openml_test": set(relbench_test_ids) == set(openml_test_idx), - "relbench_train_val_partition_openml_train": relbench_trainval_ids == set( - openml_train_idx - ), + "relbench_test_matches_openml_test": set(relbench_test_ids) + == set(openml_test_idx), + "relbench_train_val_partition_openml_train": relbench_trainval_ids + == set(openml_train_idx), "relbench_train_val_are_disjoint": set(relbench_train_ids).isdisjoint( relbench_val_ids ), diff --git a/examples/validate_tabarena_baseline.py b/examples/validate_tabarena_baseline.py index ca8891bc..6a1828f3 100644 --- a/examples/validate_tabarena_baseline.py +++ b/examples/validate_tabarena_baseline.py @@ -1,9 +1,9 @@ """Validate the TabArena RelBench wrapper with a public baseline. -The script trains a baseline on the original OpenML rows selected by a RelBench -task split, then runs inference on both the original and relbenchified views of -the same validation and test rows. Matching predictions and metrics confirm that -the RelBench wrapper preserves the original single-table task semantics. +The script trains a baseline on the original OpenML rows selected by a RelBench task +split, then runs inference on both the original and relbenchified views of the same +validation and test rows. Matching predictions and metrics confirm that the RelBench +wrapper preserves the original single-table task semantics. """ from __future__ import annotations @@ -66,7 +66,9 @@ def _load_openml_frame(dataset: TabArenaDataset) -> pd.DataFrame: return X_df -def _join_records(records_df: pd.DataFrame, task_df: pd.DataFrame) -> tuple[pd.DataFrame, np.ndarray]: +def _join_records( + records_df: pd.DataFrame, task_df: pd.DataFrame +) -> tuple[pd.DataFrame, np.ndarray]: joined = task_df.merge(records_df, on="record_id", how="left", validate="1:1") X_df = joined.drop(columns=["record_id", "target"]) y = joined["target"].to_numpy() @@ -87,9 +89,7 @@ def _build_preprocessor(X_df: pd.DataFrame) -> ColumnTransformer: transformers=[ ( "numeric", - Pipeline( - [("imputer", SimpleImputer(strategy="median"))] - ), + Pipeline([("imputer", SimpleImputer(strategy="median"))]), numeric_cols, ), ( @@ -180,7 +180,9 @@ def _predict( return np.asarray(model.predict(X_df), dtype=np.float64) proba = model.predict_proba(X_df) if proba.ndim != 2 or proba.shape[1] != 2: - raise RuntimeError(f"Expected binary predict_proba output, got shape={proba.shape}") + raise RuntimeError( + f"Expected binary predict_proba output, got shape={proba.shape}" + ) return np.asarray(proba[:, 1], dtype=np.float64) @@ -205,7 +207,9 @@ def main() -> None: task = TabArenaSplitEntityTask(dataset, split=args.split) records_df = dataset.get_db().table_dict["records"].df.reset_index(drop=True) openml_X = _load_openml_frame(dataset) - openml_y = pd.Series(dataset.get_target_array(), name="target").reset_index(drop=True) + openml_y = pd.Series(dataset.get_target_array(), name="target").reset_index( + drop=True + ) relbench_train = task.get_table("train", mask_input_cols=False).df relbench_val = task.get_table("val", mask_input_cols=False).df @@ -215,11 +219,17 @@ def main() -> None: X_val_rb, y_val_rb = _join_records(records_df, relbench_val) X_test_rb, y_test_rb = _join_records(records_df, relbench_test) - X_train_orig = openml_X.iloc[relbench_train["record_id"].to_numpy()].reset_index(drop=True) + X_train_orig = openml_X.iloc[relbench_train["record_id"].to_numpy()].reset_index( + drop=True + ) y_train_orig = openml_y.iloc[relbench_train["record_id"].to_numpy()].to_numpy() - X_val_orig = openml_X.iloc[relbench_val["record_id"].to_numpy()].reset_index(drop=True) + X_val_orig = openml_X.iloc[relbench_val["record_id"].to_numpy()].reset_index( + drop=True + ) y_val_orig = openml_y.iloc[relbench_val["record_id"].to_numpy()].to_numpy() - X_test_orig = openml_X.iloc[relbench_test["record_id"].to_numpy()].reset_index(drop=True) + X_test_orig = openml_X.iloc[relbench_test["record_id"].to_numpy()].reset_index( + drop=True + ) y_test_orig = openml_y.iloc[relbench_test["record_id"].to_numpy()].to_numpy() print("[data equality checks]") @@ -244,14 +254,20 @@ def main() -> None: test_pred_orig = _predict(model, task, X_test_orig) test_pred_rb = _predict(model, task, X_test_rb) - metric_name, val_metric_orig = _metric_name_and_value(task, y_val_orig, val_pred_orig) + metric_name, val_metric_orig = _metric_name_and_value( + task, y_val_orig, val_pred_orig + ) _, val_metric_rb = _metric_name_and_value(task, y_val_rb, val_pred_rb) _, test_metric_orig = _metric_name_and_value(task, y_test_orig, test_pred_orig) _, test_metric_rb = _metric_name_and_value(task, y_test_rb, test_pred_rb) print("[prediction consistency]") - print(f"val max |orig - relbench|: {_max_prediction_delta(val_pred_orig, val_pred_rb):.6g}") - print(f"test max |orig - relbench|: {_max_prediction_delta(test_pred_orig, test_pred_rb):.6g}") + print( + f"val max |orig - relbench|: {_max_prediction_delta(val_pred_orig, val_pred_rb):.6g}" + ) + print( + f"test max |orig - relbench|: {_max_prediction_delta(test_pred_orig, test_pred_rb):.6g}" + ) print() print(f"[{metric_name}]")