diff --git a/examples/hurtful_word_bias.py b/examples/hurtful_word_bias.py
new file mode 100644
index 000000000..fc3686f24
--- /dev/null
+++ b/examples/hurtful_word_bias.py
@@ -0,0 +1,38 @@
+from pyhealth.datasets import MIMIC3Dataset, split_by_patient, get_dataloader
+from pyhealth.models import ClinicalBERTWrapper
+from pyhealth.tasks import HurtfulWordsBiasTask
+from pyhealth.trainer import Trainer
+
+# STEP 1: load MIMIC-III
+base = MIMIC3Dataset(
+    root="/srv/local/data/physionet.org/files/mimiciii/1.4",
+    tables=["NOTEEVENTS", "PATIENTS"],
+    dev=False,
+    refresh_cache=False
+)
+
+# STEP 2: set the bias task
+bias_task = HurtfulWordsBiasTask(positive_group="female", negative_group="male")
+task_dataset = base.set_task(bias_task)
+task_dataset.stat()
+
+# STEP 3: train/test split & dataloaders
+train_ds, val_ds, test_ds = split_by_patient(task_dataset, [0.8, 0.1, 0.1])
+train_dl = get_dataloader(train_ds, batch_size=16, shuffle=True)
+val_dl = get_dataloader(val_ds, batch_size=16, shuffle=False)
+test_dl = get_dataloader(test_ds, batch_size=16, shuffle=False)
+
+# STEP 4: wrap a ClinicalBERT model
+model = ClinicalBERTWrapper(
+    pretrained_model_name="emilyalsentzer/Bio_ClinicalBERT",
+    device="cuda"
+)
+
+# STEP 5: train/calibrate if needed
+trainer = Trainer(model=model, task=bias_task)
+trainer.train(train_dl, val_dl, epochs=1, monitor=None)
+
+# STEP 6: evaluate log_bias and precision_gap
+metrics = ["log_bias", "precision_gap"]
+results = trainer.evaluate(test_dl, metrics=metrics)
+print("Fairness results:", results)
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index e1b8166cd..513b50558 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -48,3 +48,4 @@
 )
 from .sleep_staging_v2 import SleepStagingSleepEDF
 from .temple_university_EEG_tasks import EEG_events_fn, EEG_isAbnormal_fn
+from .hurtful_words_bias import HurtfulWordsBiasTask
diff --git a/pyhealth/tasks/hurtful_words_bias.py b/pyhealth/tasks/hurtful_words_bias.py
new file mode 100644
index 000000000..996db13b3
--- /dev/null
+++ b/pyhealth/tasks/hurtful_words_bias.py
@@ -0,0 +1,109 @@
+# =============================================================================
+# Ritul Soni (rsoni27)
+# “Hurtful Words” Bias Quantification
+# Paper: Hurtful Words: Quantifying Biases in Clinical Contextual Word Embeddings
+# Link: https://arxiv.org/abs/2012.00355
+#
+# Implements:
+#   - log probability bias score per Zhang et al. (2020)
+#   - precision gap as an additional fairness metric
+# =============================================================================
+
+from typing import Dict, List, Optional
+import numpy as np
+from pyhealth.tasks.base import BaseTask
+
+class HurtfulWordsBiasTask(BaseTask):
+    """Compute the log-probability bias score and precision gap on ClinicalBERT outputs.
+
+    Used via `dataset.set_task(HurtfulWordsBiasTask(...))`; see `examples/hurtful_word_bias.py`.
+    """
+
+    def __init__(self, positive_group: str = "female", negative_group: str = "male"):
+        """
+        Args:
+            positive_group (str): demographic label for the privileged group.
+            negative_group (str): demographic label for the unprivileged group.
+        """
+        super().__init__()
+        self.positive = positive_group
+        self.negative = negative_group
+
+    def get_ground_truth(self, patient_record: Dict) -> str:
+        """Extract the demographic label from a record.
+
+        Args:
+            patient_record: a dict containing at least 'gender'.
+
+        Returns:
+            str: either self.positive or self.negative.
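+                Records whose gender does not match `positive_group` are all
+                mapped to `negative_group`.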
+        """
+        gender = patient_record["gender"].lower()
+        return self.positive if gender == self.positive else self.negative
+
+    def get_prediction(self, model, text: str) -> float:
+        """Mask the target word in `text` and compute its log-probability under `model`.
+
+        Args:
+            model: a wrapper exposing `get_log_prob(text)` around a masked language model.
+            text (str): one clinical note containing a single [MASK].
+
+        Returns:
+            float: log P(target_token | context)
+        """
+        # masking and token scoring are delegated to the model wrapper
+        return model.get_log_prob(text)
+
+    def evaluate(self,
+                 data: List[Dict],
+                 model,
+                 metrics: Optional[List[str]] = None
+                 ) -> Dict[str, float]:
+        """
+        Compute the requested metrics over the test split.
+
+        Args:
+            data (List[Dict]): list of records with 'text' and 'gender'.
+            model: a calibrated or uncalibrated ClinicalBERT wrapper.
+            metrics (Optional[List[str]]): metrics to compute; defaults to both.
+
+        Returns:
+            Dict[str, float]: metric_name → value
+        """
+        if metrics is None:
+            metrics = ["log_bias", "precision_gap"]
+        # an empty split has nothing to score; return an empty result dict
+        if len(data) == 0:
+            return {}
+
+        # collect scores and labels
+        scores, labels = [], []
+        for rec in data:
+            scores.append(self.get_prediction(model, rec["text"]))
+            labels.append(self.get_ground_truth(rec))
+        scores = np.array(scores)
+        labels = np.array(labels)
+
+        results = {}
+        if "log_bias" in metrics:
+            priv = scores[labels == self.positive].mean()
+            unpriv = scores[labels == self.negative].mean()
+            results["log_bias"] = priv - unpriv
+
+        if "precision_gap" in metrics:
+            # threshold at the median score
+            thresh = np.median(scores)
+            preds = scores >= thresh
+            def precision(y_true, y_pred, grp):
+                mask = (labels == grp)
+                tp = np.sum((y_true[mask] == 1) & (y_pred[mask] == 1))
+                fp = np.sum((y_true[mask] == 0) & (y_pred[mask] == 1))
+                return tp / (tp + fp + 1e-12)
+            # map gender to a binary y_true: privileged=1, unprivileged=0
+            y_true = (labels == self.positive).astype(int)
+            results["precision_gap"] = precision(y_true, preds, self.positive) - \
+                                       precision(y_true, preds, self.negative)
+
+        return results
diff --git a/pyhealth/unittests/test_hurtful_words_bias.py b/pyhealth/unittests/test_hurtful_words_bias.py
new file mode 100644
index 000000000..7d164866a
--- /dev/null
+++ b/pyhealth/unittests/test_hurtful_words_bias.py
@@ -0,0 +1,75 @@
+# =============================================================================
+# Tests for HurtfulWordsBiasTask
+# Author: Ritul Soni (rsoni27)
+# Description: Unit tests for the log_bias and precision_gap metrics of
+#              HurtfulWordsBiasTask in PyHealth.
+# =============================================================================
+
+import pytest
+import numpy as np
+from pyhealth.tasks.hurtful_words_bias import HurtfulWordsBiasTask
+
+
+class DummyModel:
+    """
+    Dummy model that returns predetermined log-probability scores.
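+
+    Scores are returned in call order, so the list must follow the order of
+    the records passed to `evaluate`.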
+    """
+    def __init__(self, scores):
+        self.scores = scores
+        self.idx = 0
+
+    def get_log_prob(self, text):
+        # Return the next score in the list
+        val = self.scores[self.idx]
+        self.idx += 1
+        return val
+
+
+def test_log_bias_and_precision_gap():
+    # Prepare synthetic data
+    genders = ["female", "female", "male", "male"]
+    # Scores: female->[3.0,1.0], male->[2.0,0.0]
+    scores = [3.0, 1.0, 2.0, 0.0]
+
+    data = [{"text": "", "gender": g} for g in genders]
+    model = DummyModel(scores)
+
+    task = HurtfulWordsBiasTask(positive_group="female", negative_group="male")
+    results = task.evaluate(data, model, metrics=["log_bias", "precision_gap"])
+
+    # log_bias = mean(female) - mean(male) = (3+1)/2 - (2+0)/2 = 2 - 1 = 1
+    assert results["log_bias"] == pytest.approx(1.0, rel=1e-6)
+
+    # precision_gap = 1.0 (privileged precision 1.0 vs unprivileged 0.0)
+    assert results["precision_gap"] == pytest.approx(1.0, rel=1e-6)
+
+
+def test_empty_data():
+    # Edge case: no data
+    data = []
+    model = DummyModel([])
+    task = HurtfulWordsBiasTask()
+
+    # Should return an empty dict (or zeros) without raising an error
+    results = task.evaluate(data, model, metrics=["log_bias", "precision_gap"])
+    assert isinstance(results, dict)
+    assert results.get("log_bias", 0) == 0 or results.get("log_bias") is None
+    assert results.get("precision_gap", 0) == 0 or results.get("precision_gap") is None
+
+
+def test_single_group_data():
+    # Edge case: all records belong to positive_group
+    genders = ["female", "female"]
+    scores = [0.5, 0.7]
+    data = [{"text": "", "gender": g} for g in genders]
+    model = DummyModel(scores)
+
+    task = HurtfulWordsBiasTask(positive_group="female", negative_group="male")
+    results = task.evaluate(data, model, metrics=["precision_gap"])
+
+    # The unprivileged group is absent, so its precision falls back to 0; the gap
+    # equals the privileged precision (1.0: the one predicted positive is a TP).
+    assert results["precision_gap"] == pytest.approx(1.0, rel=1e-6)
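
Note on `get_log_prob`: `HurtfulWordsBiasTask.get_prediction` delegates masking and token scoring to the model wrapper, and no wrapper is included in this diff (the example imports a `ClinicalBERTWrapper` that would have to provide it). Below is a minimal sketch of such a helper, assuming HuggingFace `transformers`, a single-wordpiece target token supplied by the caller, and an illustrative class name that is not part of PyHealth:

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

class MaskedLogProbScorer:
    """Illustrative sketch: log P(target | context) at a single [MASK] position."""

    def __init__(self, model_name="emilyalsentzer/Bio_ClinicalBERT", device="cpu"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(device).eval()
        self.device = device

    @torch.no_grad()
    def get_log_prob(self, text, target_word):
        # Encode the note; `text` is expected to contain exactly one [MASK] token.
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True).to(self.device)
        mask_pos = (inputs["input_ids"][0] == self.tokenizer.mask_token_id).nonzero()[0].item()
        # Log-softmax over the vocabulary at the masked position.
        log_probs = torch.log_softmax(self.model(**inputs).logits[0, mask_pos], dim=-1)
        # Assumes `target_word` maps to a single wordpiece in the tokenizer vocabulary.
        target_id = self.tokenizer.convert_tokens_to_ids(target_word)
        return log_probs[target_id].item()

The task itself passes only `text` to `get_log_prob`, so a real wrapper would need to carry the target token in its own configuration or interface; the sketch above simply makes that dependency explicit.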