diff --git a/.gitignore b/.gitignore
index 7e413875..8709beb5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,6 +160,7 @@ cython_debug/
 #.idea/
 
 output/
+models/
 scrap/
 .DS_Store
 .vscode/
@@ -168,3 +169,6 @@ scrap/
 .cursor/
 .private/
 .idea/
+
+# Personal AI context (keep local)
+CLAUDE.LOCAL.MD
diff --git a/CLAUDE.MD b/CLAUDE.MD
new file mode 100644
index 00000000..1f199cfb
--- /dev/null
+++ b/CLAUDE.MD
@@ -0,0 +1,193 @@
+# HealthChain - Claude Code Context
+
+## Project Overview
+
+HealthChain is an open-source Python framework for productionizing healthcare AI applications with native protocol understanding. It provides built-in FHIR support, real-time EHR connectivity, and production-ready deployment capabilities for AI/ML engineers working with healthcare systems.
+
+**Key Problem Solved**: EHR data is specific, complex, and fragmented. HealthChain eliminates months of custom integration work by providing native understanding of healthcare protocols and data formats.
+
+**Target Users**:
+- HealthTech engineers building clinical workflow integrations
+- LLM/GenAI developers aggregating multi-EHR data
+- ML researchers deploying models as healthcare APIs
+
+## Architecture & Structure
+
+```
+healthchain/
+├── cli.py        # Command-line interface
+├── config/       # Configuration management
+├── configs/      # YAML and Liquid templates
+├── fhir/         # FHIR resource utilities and helpers
+├── gateway/      # API gateways (FHIR, CDS Hooks)
+├── interop/      # Format conversion (FHIR ↔ CDA)
+├── io/           # Document and data I/O
+├── models/       # Pydantic data models
+├── pipeline/     # Pipeline components and NLP integrations
+├── sandbox/      # Testing utilities with synthetic data
+├── templates/    # Code generation templates
+└── utils/        # Shared utilities
+
+tests/            # Test suite
+cookbook/         # Usage examples and tutorials
+docs/             # MkDocs documentation
+```
+
+## Core Modules
+
+### 1. Pipeline (`healthchain/pipeline/`)
+- Build medical NLP pipelines with components like SpacyNLP
+- Process clinical documents with automatic FHIR conversion
+- Type-safe pipeline composition using generics
+
+### 2. Gateway (`healthchain/gateway/`)
+- **FHIRGateway**: Connect to multiple FHIR sources, aggregate patient data
+- **CDSHooksGateway**: Real-time clinical decision support integration with Epic/Cerner
+- **HealthChainAPI**: FastAPI-based application framework
+
+### 3. FHIR Utilities (`healthchain/fhir/`)
+- Type-safe FHIR resource creation and validation
+- Bundle manipulation and resource extraction
+- Recently refactored for clearer separation of concerns
+
+### 4. Interop (`healthchain/interop/`)
+- Convert between FHIR and CDA formats
+- Configuration-driven templates using Liquid
+- Support for various healthcare data standards
+
+### 5. Sandbox (`healthchain/sandbox/`)
+- Test CDS Hooks services with synthetic data
+- Load from test datasets (Synthea, MIMIC)
+- Request/response validation and debugging
+
+### 6. I/O (`healthchain/io/`)
+- Document processing and management
+- Data loading for ML workflows
+- Recently refactored for better organization
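+
+Minimal sketch of the FHIR-utilities surface, using the Bundle-to-DataFrame converter added in this PR (assumes `bundle` is an existing FHIR Bundle):
+
+```python
+from healthchain.fhir import BundleConverterConfig, bundle_to_dataframe
+
+# One row per patient; Observation values aggregated by median
+config = BundleConverterConfig(
+    resources=["Patient", "Observation", "Condition"],
+    observation_aggregation="median",
+)
+df = bundle_to_dataframe(bundle, config=config)
+```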
+
+## Development Guidelines
+
+### Code Style
+- **Linter**: Ruff for code formatting and linting
+- **Type Hints**: Use Pydantic models and type annotations throughout
+- **Python Version**: Support 3.9-3.11 (not 3.12+)
+- **Testing**: pytest with async support (`pytest-asyncio`)
+
+### Key Dependencies
+- **fhir.resources**: FHIR resource models (v8.0.0+)
+- **FastAPI/Starlette**: API framework
+- **Pydantic**: Data validation (v2.x, <2.11.0)
+- **spaCy**: NLP processing (v3.x)
+- **python-liquid**: Template engine for data conversion
+
+### Patterns & Conventions
+
+1. **Type Safety**: Leverage Pydantic models for all data structures
+2. **Pipeline Pattern**: Use composable components with `Pipeline[T]` generic type
+3. **Gateway Pattern**: Extend base gateway classes for new integrations
+4. **Configuration**: Use YAML configs in `configs/` directory
+5. **Templates**: Liquid templates for FHIR/CDA conversion
+
+### Testing
+- Tests organized in `tests/` mirroring source structure
+- Use pytest fixtures for common test data
+- Async tests for gateway/API functionality
+- Recently consolidated test structure
+
+### Documentation
+
+**Style Guide:**
+- **Concise**: Get to the point quickly - developers want answers, not essays
+- **Friendly**: Conversational but professional tone; use emojis sparingly in headers
+- **Developer-Friendly**: Code examples first, explanations second; show, don't tell
+- **Scannable**: Use bullets, tables, clear sections; respect the developer's time
+- **Practical**: Focus on "how" over "why"; include working code examples
+
+**Good Documentation Examples:**
+- `docs/index.md`: Clean feature overview, clear use case table, minimal prose
+- `docs/quickstart.md`: Code-first approach, progressive complexity, practical examples
+- `docs/cookbook/index.md`: Brief descriptions, clear outcomes, call-to-action
+
+**Anti-Patterns (avoid):**
+- Long paragraphs explaining concepts before showing code
+- Over-explaining obvious functionality
+- Academic or overly formal tone
+- Excessive background before getting to the practical content
+
+**Structure:**
+- Lead with executable code examples
+- Add brief context only where needed
+- Use tables for feature comparisons
+- Include links to full docs for deep dives
+- Keep cookbook examples focused on one task
+
+**Technical Details:**
+- MkDocs with Material theme
+- API reference auto-generated from docstrings using mkdocstrings
+- Cookbook examples for common use cases
+- Follow existing docs/ structure for consistency
+
+## Recent Changes & Context
+
+Based on recent commits:
+- **FHIR Helper Module**: Refactored for clearer separation of utilities
+- **I/O Module**: Refactored for better organization
+- **Test Consolidation**: Tests reorganized for clarity
+- **MIMIC Loader**: Added support for loading as dict for ML workflows
+- **Bundle Conversion**: Config-based conversion instead of params
+
+## Important Workflows
+
+### Adding a New Gateway
+1. Create class in `healthchain/gateway/` extending base gateway
+2. Implement required protocol methods
+3. Add configuration in `configs/`
+4. Create sandbox test in `healthchain/sandbox/`
+5. Add cookbook example in `cookbook/`
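+
+A rough skeleton of that workflow (illustrative only -- the base class and method names below are assumptions, not the actual `healthchain.gateway` API):
+
+```python
+from healthchain.gateway import FHIRGateway  # assumed import path
+
+class MyVendorGateway(FHIRGateway):
+    """Hypothetical gateway for a new EHR vendor."""
+
+    async def connect(self) -> None:
+        # Authenticate against the vendor endpoint (OAuth2, etc.)
+        ...
+
+    async def fetch_patient(self, patient_id: str):
+        # Map the vendor payload into fhir.resources models
+        ...
+```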
+
+### Adding FHIR Resource Support
+1. Use `fhir.resources` models
+2. Add helper methods in `healthchain/fhir/` if needed
+3. Update type hints and validation
+4. Add tests with synthetic FHIR data
+
+### Adding Data Conversion Templates
+1. Create Liquid template in `configs/`
+2. Add configuration YAML
+3. Implement in `healthchain/interop/`
+4. Test with real healthcare data examples
+
+## Common Gotchas
+
+1. **Pydantic v2**: Use v2 patterns, but stay <2.11.0 for compatibility
+2. **NumPy**: Locked to <2.0.0 for spaCy compatibility
+3. **FHIR Validation**: Always validate resources before serialization
+4. **Async/Sync**: Gateway operations are async, pipeline operations are sync
+5. **Healthcare Standards**: Follow HL7 FHIR R4 and CDS Hooks specifications
+
+## Testing with Real Data
+
+- **Synthea**: Synthetic patient generator for realistic test data
+- **MIMIC**: Medical Information Mart for Intensive Care dataset support
+- **Sandbox**: Use `SandboxClient` for end-to-end testing without a real EHR
+
+## Security & Compliance
+
+- OAuth2 authentication support for FHIR endpoints
+- Audit trails and data provenance (roadmap item)
+- HIPAA compliance features (roadmap item)
+- No PHI in tests - use synthetic data only
+
+## Deployment
+
+- Docker/Kubernetes support (enhanced support on roadmap)
+- FastAPI apps with Uvicorn
+- OpenAPI/Swagger documentation auto-generated
+- Environment-based configuration
+
+## Resources
+
+- Documentation: https://dotimplement.github.io/HealthChain/
+- Repository: https://github.com/dotimplement/HealthChain
+- Discord: https://discord.gg/UQC6uAepUz
+- Standards: HL7 FHIR R4, CDS Hooks
diff --git a/README.md b/README.md
index 9ef165b1..7b5238b0 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,7 @@
 [![Substack][substack-badge]][substack]
 [![Discord][discord-badge]][discord]
+![AI-Assisted Development][ai-badge]
 
@@ -242,6 +243,7 @@ This project builds on [fhir.resources](https://github.com/nazrulworld/fhir.resources)
 [build-badge]: https://img.shields.io/github/actions/workflow/status/dotimplement/healthchain/ci.yml?branch=main&style=flat-square&color=%2379a8a9
 [discord-badge]: https://img.shields.io/badge/chat-%235965f2?style=flat-square&logo=discord&logoColor=white
 [substack-badge]: https://img.shields.io/badge/Cool_Things_In_HealthTech-%23c094ff?style=flat-square&logo=substack&logoColor=white
+[ai-badge]: https://img.shields.io/badge/AI--Assisted_Development_Friendly-CLAUDE.MD-%23FF6B6B?style=flat-square&logo=anthropic&logoColor=white
 
 [pypi]: https://pypi.org/project/healthchain/
 [pypistats]: https://pepy.tech/project/healthchain
diff --git a/cookbook/sepsis_prediction_inference.py b/cookbook/sepsis_prediction_inference.py
new file mode 100644
index 00000000..33edb858
--- /dev/null
+++ b/cookbook/sepsis_prediction_inference.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python3
+"""
+Sepsis Prediction Inference Script
+
+Demonstrates how to load and use the trained sepsis prediction model.
+
+Requirements:
+- pip install scikit-learn xgboost joblib pandas numpy
+
+Usage:
+- python sepsis_prediction_inference.py
+"""
+
+import pandas as pd
+import numpy as np
+from pathlib import Path
+from typing import Dict, Union, Tuple
+import joblib
+
+
+def load_model(model_path: Union[str, Path]) -> Dict:
+    """
+    Load trained sepsis prediction model.
+
+    Args:
+        model_path: Path to saved model file
+
+    Returns:
+        Dictionary containing model, scaler, and metadata
+    """
+    print(f"Loading model from {model_path}...")
+    model_data = joblib.load(model_path)
+
+    metadata = model_data["metadata"]
+    print(f"  Model: {metadata['model_name']}")
+    print(f"  Training date: {metadata['training_date']}")
+    print(f"  Features: {', '.join(metadata['feature_names'])}")
+    print(f"  Test F1-score: {metadata['metrics']['f1']:.4f}")
+    print(f"  Test AUC-ROC: {metadata['metrics']['auc']:.4f}")
+
+    if "optimal_threshold" in metadata["metrics"]:
+        print(f"  Optimal threshold: {metadata['metrics']['optimal_threshold']:.4f}")
+        print(f"  Optimal F1-score: {metadata['metrics']['optimal_f1']:.4f}")
+
+    return model_data
+
+
+def predict_sepsis(
+    model_data: Dict, patient_features: pd.DataFrame, use_optimal_threshold: bool = True
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Predict sepsis risk for patient(s).
+
+    Args:
+        model_data: Dictionary containing model, scaler, and metadata
+        patient_features: DataFrame with patient features
+        use_optimal_threshold: Whether to use optimal threshold (default: True)
+
+    Returns:
+        Tuple of (predictions, probabilities)
+    """
+    model = model_data["model"]
+    scaler = model_data["scaler"]
+    metadata = model_data["metadata"]
+    feature_names = metadata["feature_names"]
+
+    # Ensure features are in correct order
+    patient_features = patient_features[feature_names]
+
+    # Apply scaling if a scaler was saved with the model (Logistic Regression)
+    if scaler is not None:
+        patient_features_scaled = scaler.transform(patient_features)
+        probabilities = model.predict_proba(patient_features_scaled)[:, 1]
+    else:
+        probabilities = model.predict_proba(patient_features)[:, 1]
+
+    # Use optimal threshold if available and requested
+    if use_optimal_threshold and "optimal_threshold" in metadata["metrics"]:
+        threshold = metadata["metrics"]["optimal_threshold"]
+    else:
+        threshold = 0.5
+
+    predictions = (probabilities >= threshold).astype(int)
+
+    return predictions, probabilities
+
+
+def create_example_patients() -> pd.DataFrame:
+    """
+    Create example patient data for demonstration.
+
+    Returns:
+        DataFrame with example patient features
+    """
+    # Example patient data
+    # Patient 1: Healthy patient (low risk)
+    # Patient 2: Moderate risk (some abnormal values)
+    # Patient 3: Low risk (normal values)
+    # Patient 4: High risk for sepsis (multiple severe abnormalities)
+    # Patient 5: Critical sepsis risk (severe multi-organ dysfunction)
+    patients = pd.DataFrame(
+        {
+            "heart_rate": [85, 110, 75, 130, 145],  # beats/min (normal: 60-100)
+            "temperature": [37.2, 38.5, 36.8, 39.2, 35.5],  # Celsius (normal: 36.5-37.5, hypothermia <36)
+            "respiratory_rate": [16, 24, 14, 30, 35],  # breaths/min (normal: 12-20)
+            "wbc": [8.5, 15.2, 7.0, 18.5, 22.0],  # x10^9/L (normal: 4-11)
+            "lactate": [1.2, 3.5, 0.9, 4.8, 6.5],  # mmol/L (normal: <2, severe sepsis: >4)
+            "creatinine": [0.9, 1.8, 0.8, 2.5, 3.2],  # mg/dL (normal: 0.6-1.2)
+            "age": [45, 68, 35, 72, 78],  # years
+            "gender_encoded": [1, 0, 1, 1, 0],  # 1=Male, 0=Female
+        }
+    )
+
+    return patients
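+
+
+# Usage note (illustrative): with a model file on disk, the two helpers above
+# compose directly; `predictions` is a 0/1 array and `probabilities` holds
+# sepsis probabilities in [0, 1]:
+#
+#   model_data = load_model("models/sepsis_model.pkl")
+#   predictions, probabilities = predict_sepsis(model_data, create_example_patients())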
+
+
+def interpret_results(
+    predictions: np.ndarray, probabilities: np.ndarray, patient_features: pd.DataFrame
+) -> None:
+    """
+    Interpret and display prediction results.
+
+    Args:
+        predictions: Binary predictions (0=no sepsis, 1=sepsis)
+        probabilities: Probability scores
+        patient_features: Original patient features
+    """
+    print("\n" + "=" * 80)
+    print("SEPSIS PREDICTION RESULTS")
+    print("=" * 80)
+
+    for i in range(len(predictions)):
+        print(f"\nPatient {i+1}:")
+        print(f"  Risk Score: {probabilities[i]:.2%}")
+        print(f"  Prediction: {'SEPSIS RISK' if predictions[i] == 1 else 'Low Risk'}")
+
+        # Show key vital signs
+        print("  Key Features:")
+        print(f"    Heart Rate: {patient_features.iloc[i]['heart_rate']:.1f} bpm")
+        print(f"    Temperature: {patient_features.iloc[i]['temperature']:.1f}°C")
+        print(f"    Respiratory Rate: {patient_features.iloc[i]['respiratory_rate']:.1f} /min")
+        print(f"    WBC: {patient_features.iloc[i]['wbc']:.1f} x10^9/L")
+        print(f"    Lactate: {patient_features.iloc[i]['lactate']:.1f} mmol/L")
+        print(f"    Creatinine: {patient_features.iloc[i]['creatinine']:.2f} mg/dL")
+
+        # Risk interpretation
+        if probabilities[i] >= 0.7:
+            risk_level = "HIGH"
+        elif probabilities[i] >= 0.4:
+            risk_level = "MODERATE"
+        else:
+            risk_level = "LOW"
+
+        print(f"  Clinical Interpretation: {risk_level} RISK")
+
+    print("\n" + "=" * 80)
+
+
+def main():
+    """Main inference pipeline."""
+    # Model path (relative to script location)
+    script_dir = Path(__file__).parent
+    model_path = script_dir / "models" / "sepsis_model.pkl"
+
+    print("=" * 80)
+    print("Sepsis Prediction Inference")
+    print("=" * 80 + "\n")
+
+    # Load model
+    model_data = load_model(model_path)
+
+    # Create example patients
+    print("\nCreating example patient data...")
+    patient_features = create_example_patients()
+    print(f"Number of patients: {len(patient_features)}")
+
+    # Make predictions
+    print("\nMaking predictions...")
+    predictions, probabilities = predict_sepsis(
+        model_data, patient_features, use_optimal_threshold=True
+    )
+
+    # Interpret results
+    interpret_results(predictions, probabilities, patient_features)
+
+    print("\n" + "=" * 80)
+    print("Inference complete!")
+    print("=" * 80)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/cookbook/sepsis_prediction_training.py b/cookbook/sepsis_prediction_training.py
new file mode 100644
index 00000000..a0ea85ce
--- /dev/null
+++ b/cookbook/sepsis_prediction_training.py
@@ -0,0 +1,1039 @@
+#!/usr/bin/env python3
+"""
+Sepsis Prediction Training Script
+
+Trains Random Forest, XGBoost, and Logistic Regression models for sepsis prediction
+using MIMIC-IV clinical database data.
+
+Requirements:
+- pip install scikit-learn xgboost joblib pandas numpy
+
+Run:
+- python sepsis_prediction_training.py
+"""
+
+import pandas as pd
+import numpy as np
+from pathlib import Path
+from datetime import datetime
+from typing import Dict, Tuple, List, Any, Union
+
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from sklearn.metrics import (
+    accuracy_score,
+    precision_score,
+    recall_score,
+    f1_score,
+    roc_auc_score,
+    precision_recall_curve,
+)
+import xgboost as xgb
+import joblib
+
+
+# MIMIC-IV ItemID mappings for features
+# Note: 220045 is the MIMIC-IV itemid for Heart Rate (matching the
+# sepsis_vitals.yaml schema in this change)
+CHARTEVENTS_ITEMIDS = {
+    "heart_rate": 220045,
+    "temperature_f": 223761,
+    "temperature_c": 223762,
+    "respiratory_rate": 220210,
+}
+
+LABEVENTS_ITEMIDS = {
+    "wbc": [51300, 51301],  # White Blood Cell Count
+    "lactate": 50813,
+    "creatinine": 50912,
+}
+
+# Sepsis ICD-10 codes
+SEPSIS_ICD10_CODES = [
+    "A41.9",  # Sepsis, unspecified organism
+    "A40",  # Streptococcal sepsis (prefix match)
+    "A41",  # Other sepsis (prefix match)
+    "R65.20",  # Severe sepsis without shock
+    "R65.21",  # Severe sepsis with shock
+    "R65.1",  # SIRS (Systemic Inflammatory Response Syndrome)
+    "A41.0",  # Sepsis due to Streptococcus, group A
+    "A41.1",  # Sepsis due to Streptococcus, group B
+    "A41.2",  # Sepsis due to other specified streptococci
+    "A41.3",  # Sepsis due to Haemophilus influenzae
+    "A41.4",  # Sepsis due to anaerobes
+    "A41.5",  # Sepsis due to other Gram-negative organisms
+    "A41.50",  # Sepsis due to unspecified Gram-negative organism
+    "A41.51",  # Sepsis due to Escherichia coli
+    "A41.52",  # Sepsis due to Pseudomonas
+    "A41.53",  # Sepsis due to Serratia
+    "A41.59",  # Sepsis due to other Gram-negative organisms
+    "A41.8",  # Other specified sepsis
+    "A41.81",  # Sepsis due to Enterococcus
+    "A41.89",  # Other specified sepsis
+]
+
+# Sepsis ICD-9 codes (for older data)
+SEPSIS_ICD9_CODES = [
+    "038",  # Septicemia (prefix match)
+    "99591",  # Sepsis
+    "99592",  # Severe sepsis
+    "78552",  # Septic shock
+]
+
+
+def load_mimic_data(data_dir: str) -> Dict[str, pd.DataFrame]:
+    """
+    Load all required MIMIC-IV CSV tables.
+
+    Args:
+        data_dir: Path to MIMIC-IV dataset directory
+
+    Returns:
+        Dictionary mapping table names to DataFrames
+    """
+    data_dir = Path(data_dir)
+
+    print("Loading MIMIC-IV data...")
+
+    tables = {
+        "patients": pd.read_csv(
+            data_dir / "hosp" / "patients.csv.gz", compression="gzip", low_memory=False
+        ),
+        "admissions": pd.read_csv(
+            data_dir / "hosp" / "admissions.csv.gz", compression="gzip", low_memory=False
+        ),
+        "icustays": pd.read_csv(
+            data_dir / "icu" / "icustays.csv.gz", compression="gzip", low_memory=False
+        ),
+        "chartevents": pd.read_csv(
+            data_dir / "icu" / "chartevents.csv.gz", compression="gzip", low_memory=False
+        ),
+        "labevents": pd.read_csv(
+            data_dir / "hosp" / "labevents.csv.gz", compression="gzip", low_memory=False
+        ),
+        "diagnoses_icd": pd.read_csv(
+            data_dir / "hosp" / "diagnoses_icd.csv.gz", compression="gzip", low_memory=False
+        ),
+    }
+
+    print(f"Loaded {len(tables)} tables")
+    for name, df in tables.items():
+        print(f"  {name}: {len(df)} rows")
+
+    return tables
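+
+
+# NOTE (illustrative): on the full MIMIC-IV release, chartevents alone is tens of
+# gigabytes, so loading it whole will exhaust memory. A chunked read that filters
+# itemids as it goes keeps memory bounded, e.g.:
+#
+#   chunks = pd.read_csv(path, compression="gzip", chunksize=1_000_000)
+#   chartevents = pd.concat(
+#       chunk[chunk["itemid"].isin(relevant_itemids)] for chunk in chunks
+#   )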
+
+
+def extract_chartevents_features(
+    chartevents: pd.DataFrame, icustays: pd.DataFrame
+) -> pd.DataFrame:
+    """
+    Extract vital-sign features (heart rate, temperature, respiratory rate)
+    from the chartevents table.
+
+    Args:
+        chartevents: Chart events DataFrame
+        icustays: ICU stays DataFrame
+
+    Returns:
+        DataFrame with features per stay_id
+    """
+    print("Extracting chartevents features...")
+
+    # Filter to relevant itemids
+    relevant_itemids = list(CHARTEVENTS_ITEMIDS.values())
+    chartevents_filtered = chartevents[
+        chartevents["itemid"].isin(relevant_itemids)
+    ].copy()
+
+    # Merge with icustays to get stay times
+    chartevents_merged = chartevents_filtered.merge(
+        icustays[["stay_id", "intime", "outtime"]], on="stay_id", how="inner"
+    )
+
+    # Convert charttime to datetime
+    chartevents_merged["charttime"] = pd.to_datetime(chartevents_merged["charttime"])
+    chartevents_merged["intime"] = pd.to_datetime(chartevents_merged["intime"])
+
+    # Filter to first 24 hours of ICU stay
+    chartevents_merged = chartevents_merged[
+        (chartevents_merged["charttime"] >= chartevents_merged["intime"])
+        & (
+            chartevents_merged["charttime"]
+            <= chartevents_merged["intime"] + pd.Timedelta(hours=24)
+        )
+    ]
+
+    # Extract numeric values
+    chartevents_merged["valuenum"] = pd.to_numeric(
+        chartevents_merged["valuenum"], errors="coerce"
+    )
+
+    # Aggregate by stay_id and itemid (take mean)
+    features = []
+
+    for stay_id in icustays["stay_id"].unique():
+        stay_data = chartevents_merged[chartevents_merged["stay_id"] == stay_id]
+
+        feature_row = {"stay_id": stay_id}
+
+        # Heart Rate
+        hr_data = stay_data[stay_data["itemid"] == CHARTEVENTS_ITEMIDS["heart_rate"]]["valuenum"]
+        feature_row["heart_rate"] = hr_data.mean() if not hr_data.empty else np.nan
+
+        # Temperature (prefer Celsius, convert Fahrenheit if needed)
+        temp_c = stay_data[stay_data["itemid"] == CHARTEVENTS_ITEMIDS["temperature_c"]]["valuenum"]
+        temp_f = stay_data[stay_data["itemid"] == CHARTEVENTS_ITEMIDS["temperature_f"]]["valuenum"]
+
+        if not temp_c.empty:
+            feature_row["temperature"] = temp_c.mean()
+        elif not temp_f.empty:
+            # Convert Fahrenheit to Celsius
+            feature_row["temperature"] = (temp_f.mean() - 32) * 5 / 9
+        else:
+            feature_row["temperature"] = np.nan
+
+        # Respiratory Rate
+        rr_data = stay_data[stay_data["itemid"] == CHARTEVENTS_ITEMIDS["respiratory_rate"]]["valuenum"]
+        feature_row["respiratory_rate"] = rr_data.mean() if not rr_data.empty else np.nan
+
+        features.append(feature_row)
+
+    return pd.DataFrame(features)
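+
+
+# ALTERNATIVE (sketch): the per-stay loop above is simple but scans the merged
+# frame once per ICU stay. An equivalent vectorized version computes the mean
+# per (stay_id, itemid) in one pass and pivots itemids into columns:
+#
+#   agg = (chartevents_merged
+#          .groupby(["stay_id", "itemid"])["valuenum"].mean()
+#          .unstack("itemid"))
+#
+# Columns can then be renamed via CHARTEVENTS_ITEMIDS, with the same
+# Celsius-preferred temperature handling applied afterwards.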
+
+
+def extract_labevents_features(
+    labevents: pd.DataFrame, icustays: pd.DataFrame
+) -> pd.DataFrame:
+    """
+    Extract lab-value features (WBC, lactate, creatinine) from the labevents table.
+
+    Args:
+        labevents: Lab events DataFrame
+        icustays: ICU stays DataFrame
+
+    Returns:
+        DataFrame with features per stay_id
+    """
+    print("Extracting labevents features...")
+
+    # Get relevant itemids
+    relevant_itemids = [
+        LABEVENTS_ITEMIDS["lactate"],
+        LABEVENTS_ITEMIDS["creatinine"],
+    ] + LABEVENTS_ITEMIDS["wbc"]
+
+    labevents_filtered = labevents[labevents["itemid"].isin(relevant_itemids)].copy()
+
+    # Labevents links to ICU stays via hadm_id, which then gives us stay_id
+    icustays_with_hadm = icustays[["stay_id", "hadm_id", "intime"]].copy()
+
+    labevents_merged = labevents_filtered.merge(
+        icustays_with_hadm, on="hadm_id", how="inner"
+    )
+
+    # Convert charttime to datetime
+    labevents_merged["charttime"] = pd.to_datetime(labevents_merged["charttime"])
+    labevents_merged["intime"] = pd.to_datetime(labevents_merged["intime"])
+
+    # Filter to first 24 hours of ICU stay
+    labevents_merged = labevents_merged[
+        (labevents_merged["charttime"] >= labevents_merged["intime"])
+        & (
+            labevents_merged["charttime"]
+            <= labevents_merged["intime"] + pd.Timedelta(hours=24)
+        )
+    ]
+
+    # Extract numeric values
+    labevents_merged["valuenum"] = pd.to_numeric(
+        labevents_merged["valuenum"], errors="coerce"
+    )
+
+    # Aggregate by stay_id and itemid
+    features = []
+
+    for stay_id in icustays["stay_id"].unique():
+        stay_data = labevents_merged[labevents_merged["stay_id"] == stay_id]
+
+        feature_row = {"stay_id": stay_id}
+
+        # WBC (check both itemids)
+        wbc_data = stay_data[stay_data["itemid"].isin(LABEVENTS_ITEMIDS["wbc"])]["valuenum"]
+        feature_row["wbc"] = wbc_data.mean() if not wbc_data.empty else np.nan
+
+        # Lactate
+        lactate_data = stay_data[stay_data["itemid"] == LABEVENTS_ITEMIDS["lactate"]]["valuenum"]
+        feature_row["lactate"] = lactate_data.mean() if not lactate_data.empty else np.nan
+
+        # Creatinine
+        creatinine_data = stay_data[stay_data["itemid"] == LABEVENTS_ITEMIDS["creatinine"]]["valuenum"]
+        feature_row["creatinine"] = creatinine_data.mean() if not creatinine_data.empty else np.nan
+
+        features.append(feature_row)
+
+    return pd.DataFrame(features)
+
+
+def extract_demographics(
+    patients: pd.DataFrame, admissions: pd.DataFrame, icustays: pd.DataFrame
+) -> pd.DataFrame:
+    """
+    Extract age and gender from the patients table.
+
+    Args:
+        patients: Patients DataFrame
+        admissions: Admissions DataFrame (not used, kept for compatibility)
+        icustays: ICU stays DataFrame
+
+    Returns:
+        DataFrame with demographics per stay_id
+    """
+    print("Extracting demographics...")
+
+    # icustays already has subject_id, so merge directly with patients
+    icustays_with_patient = icustays[["stay_id", "subject_id"]].merge(
+        patients[["subject_id", "gender", "anchor_age"]], on="subject_id", how="left"
+    )
+
+    # anchor_age is available in the demo data; otherwise age could be derived
+    # from anchor_year and the admission year
+    demographics = icustays_with_patient[["stay_id", "anchor_age", "gender"]].copy()
+    demographics.rename(columns={"anchor_age": "age"}, inplace=True)
+
+    # Encode gender (M=1, F=0)
+    demographics["gender_encoded"] = (demographics["gender"] == "M").astype(int)
+
+    return demographics[["stay_id", "age", "gender_encoded"]]
+
+
+def extract_sepsis_labels(
+    diagnoses_icd: pd.DataFrame, icustays: pd.DataFrame
+) -> pd.DataFrame:
+    """
+    Extract sepsis labels from the diagnoses_icd table.
+    Checks both ICD-9 and ICD-10 codes to maximize positive samples.
+
+    Args:
+        diagnoses_icd: Diagnoses ICD DataFrame
+        icustays: ICU stays DataFrame
+
+    Returns:
+        DataFrame with sepsis labels per stay_id
+    """
+    print("Extracting sepsis labels...")
+
+    # Check what ICD versions are available
+    icd_versions = diagnoses_icd["icd_version"].unique()
+    print(f"  Available ICD versions: {sorted(icd_versions)}")
+
+    all_sepsis_diagnoses = []
+
+    # Check ICD-10 codes
+    if 10 in icd_versions:
+        diagnoses_icd10 = diagnoses_icd[diagnoses_icd["icd_version"] == 10].copy()
+        print(f"  ICD-10 diagnoses: {len(diagnoses_icd10)} rows")
+
+        sepsis_mask = pd.Series(
+            [False] * len(diagnoses_icd10), index=diagnoses_icd10.index
+        )
+
+        for code in SEPSIS_ICD10_CODES:
+            if "." not in code or code.endswith("."):
+                # Prefix match (e.g., "A40" matches "A40.x")
+                code_prefix = code.rstrip(".")
+                mask = diagnoses_icd10["icd_code"].str.startswith(code_prefix, na=False)
+                sepsis_mask |= mask
+                if mask.sum() > 0:
+                    print(f"    Found {mask.sum()} ICD-10 diagnoses matching pattern '{code}'")
+            else:
+                # Exact match
+                mask = diagnoses_icd10["icd_code"] == code
+                sepsis_mask |= mask
+                if mask.sum() > 0:
+                    print(f"    Found {mask.sum()} ICD-10 diagnoses with exact code '{code}'")
+
+        sepsis_icd10 = diagnoses_icd10[sepsis_mask].copy()
+        if len(sepsis_icd10) > 0:
+            all_sepsis_diagnoses.append(sepsis_icd10)
+            print(f"  Total ICD-10 sepsis diagnoses: {len(sepsis_icd10)}")
+
+    # Check ICD-9 codes
+    if 9 in icd_versions:
+        diagnoses_icd9 = diagnoses_icd[diagnoses_icd["icd_version"] == 9].copy()
+        print(f"  ICD-9 diagnoses: {len(diagnoses_icd9)} rows")
+
+        sepsis_mask = pd.Series(
+            [False] * len(diagnoses_icd9), index=diagnoses_icd9.index
+        )
+
+        for code in SEPSIS_ICD9_CODES:
+            if len(code) <= 3 or code.endswith("."):
+                # Prefix match (e.g., "038" matches "038.x")
+                code_prefix = code.rstrip(".")
+                mask = diagnoses_icd9["icd_code"].str.startswith(code_prefix, na=False)
+                sepsis_mask |= mask
+                if mask.sum() > 0:
+                    print(f"    Found {mask.sum()} ICD-9 diagnoses matching pattern '{code}'")
+            else:
+                # Exact match
+                mask = diagnoses_icd9["icd_code"] == code
+                sepsis_mask |= mask
+                if mask.sum() > 0:
+                    print(f"    Found {mask.sum()} ICD-9 diagnoses with exact code '{code}'")
+
+        sepsis_icd9 = diagnoses_icd9[sepsis_mask].copy()
+        if len(sepsis_icd9) > 0:
+            all_sepsis_diagnoses.append(sepsis_icd9)
+            print(f"  Total ICD-9 sepsis diagnoses: {len(sepsis_icd9)}")
+
+    # Combine all sepsis diagnoses
+    if all_sepsis_diagnoses:
+        sepsis_diagnoses = pd.concat(all_sepsis_diagnoses, ignore_index=True)
+        print(f"  Total sepsis diagnoses (ICD-9 + ICD-10): {len(sepsis_diagnoses)}")
+
+        if len(sepsis_diagnoses) > 0:
+            print(f"  Sample sepsis ICD codes: {sepsis_diagnoses['icd_code'].unique()[:15].tolist()}")
+            print(f"  Unique hadm_id with sepsis: {sepsis_diagnoses['hadm_id'].nunique()}")
+    else:
+        sepsis_diagnoses = pd.DataFrame(columns=diagnoses_icd.columns)
+        print("  No sepsis diagnoses found")
+
+    # Merge with icustays to get stay_id
+    icustays_with_hadm = icustays[["stay_id", "hadm_id"]].copy()
+
+    if len(sepsis_diagnoses) > 0:
+        sepsis_labels = icustays_with_hadm.merge(
+            sepsis_diagnoses[["hadm_id"]].drop_duplicates(),
+            on="hadm_id",
+            how="left",
+            indicator=True,
+        )
+    else:
+        sepsis_labels = icustays_with_hadm.copy()
+        sepsis_labels["_merge"] = "left_only"
+
+    # Create binary label (1 if sepsis, 0 otherwise)
+    sepsis_labels["sepsis"] = (sepsis_labels["_merge"] == "both").astype(int)
+
+    sepsis_count = sepsis_labels["sepsis"].sum()
+    print(
+        f"  ICU stays with sepsis: {sepsis_count}/{len(sepsis_labels)} ({sepsis_count/len(sepsis_labels)*100:.2f}%)"
+    )
+
+    return sepsis_labels[["stay_id", "sepsis"]]
+
+
+def print_feature_summary(X: pd.DataFrame):
+    """Print feature statistics with FHIR mapping information.
+
+    Args:
+        X: Feature matrix with actual data
+    """
+    print("\n" + "=" * 120)
+    print("FEATURE SUMMARY: MIMIC-IV → Model → FHIR Mapping")
+    print("=" * 120)
+
+    # Define FHIR mappings for each feature
+    fhir_mappings = {
+        "heart_rate": {
+            "mimic_table": "chartevents",
+            "mimic_itemid": "220045",
+            "fhir_resource": "Observation",
+            "fhir_code": "8867-4",
+            "fhir_system": "LOINC",
+            "fhir_display": "Heart rate",
+        },
+        "temperature": {
+            "mimic_table": "chartevents",
+            "mimic_itemid": "223762/223761",
+            "fhir_resource": "Observation",
+            "fhir_code": "8310-5",
+            "fhir_system": "LOINC",
+            "fhir_display": "Body temperature",
+        },
+        "respiratory_rate": {
+            "mimic_table": "chartevents",
+            "mimic_itemid": "220210",
+            "fhir_resource": "Observation",
+            "fhir_code": "9279-1",
+            "fhir_system": "LOINC",
+            "fhir_display": "Respiratory rate",
+        },
+        "wbc": {
+            "mimic_table": "labevents",
+            "mimic_itemid": "51300/51301",
+            "fhir_resource": "Observation",
+            "fhir_code": "6690-2",
+            "fhir_system": "LOINC",
+            "fhir_display": "Leukocytes [#/volume] in Blood",
+        },
+        "lactate": {
+            "mimic_table": "labevents",
+            "mimic_itemid": "50813",
+            "fhir_resource": "Observation",
+            "fhir_code": "2524-7",
+            "fhir_system": "LOINC",
+            "fhir_display": "Lactate [Moles/volume] in Blood",
+        },
+        "creatinine": {
+            "mimic_table": "labevents",
+            "mimic_itemid": "50912",
+            "fhir_resource": "Observation",
+            "fhir_code": "2160-0",
+            "fhir_system": "LOINC",
+            "fhir_display": "Creatinine [Mass/volume] in Serum or Plasma",
+        },
+        "age": {
+            "mimic_table": "patients",
+            "mimic_itemid": "anchor_age",
+            "fhir_resource": "Patient",
+            "fhir_code": "birthDate",
+            "fhir_system": "FHIR Core",
+            "fhir_display": "Patient birth date (calculate age)",
+        },
+        "gender_encoded": {
+            "mimic_table": "patients",
+            "mimic_itemid": "gender",
+            "fhir_resource": "Patient",
+            "fhir_code": "gender",
+            "fhir_system": "FHIR Core",
+            "fhir_display": "Administrative Gender (M/F)",
+        },
+    }
+
+    print(f"\n{'Feature':<20} {'Mean±SD':<20} {'MIMIC Source':<20} {'FHIR Resource':<20} {'FHIR Code (System)':<30}")
+    print("-" * 120)
+
+    for feature in X.columns:
+        mapping = fhir_mappings.get(feature, {})
+
+        # Calculate statistics
+        mean_val = X[feature].mean()
+        std_val = X[feature].std()
+
+        # Format based on feature type
+        if feature == "gender_encoded":
+            stats = f"{mean_val:.2f} (M={X[feature].sum():.0f})"
+        else:
+            stats = f"{mean_val:.2f}±{std_val:.2f}"
+
+        mimic_source = f"{mapping.get('mimic_table', 'N/A')} ({mapping.get('mimic_itemid', 'N/A')})"
+        fhir_resource = mapping.get("fhir_resource", "N/A")
+        fhir_code = f"{mapping.get('fhir_code', 'N/A')} ({mapping.get('fhir_system', 'N/A')})"
+
+        print(f"{feature:<20} {stats:<20} {mimic_source:<20} {fhir_resource:<20} {fhir_code:<30}")
+
+    print("\n" + "=" * 120)
+    print("Note: Statistics calculated from first 24 hours of ICU stay. Missing values imputed with median.")
+    print("=" * 120 + "\n")
+
+
+def create_feature_matrix(
+    chartevents_features: pd.DataFrame,
+    labevents_features: pd.DataFrame,
+    demographics: pd.DataFrame,
+    sepsis_labels: pd.DataFrame,
+) -> Tuple[pd.DataFrame, pd.Series]:
+    """
+    Create feature matrix and labels from extracted features.
+
+    Args:
+        chartevents_features: Chart events features
+        labevents_features: Lab events features
+        demographics: Demographics features
+        sepsis_labels: Sepsis labels
+
+    Returns:
+        Tuple of (feature matrix, labels)
+    """
+    print("Creating feature matrix...")
+
+    # Merge all features on stay_id
+    features = (
+        chartevents_features.merge(labevents_features, on="stay_id", how="outer")
+        .merge(demographics, on="stay_id", how="outer")
+        .merge(sepsis_labels, on="stay_id", how="inner")
+    )
+
+    # Select feature columns (exclude stay_id and sepsis)
+    feature_cols = [
+        "heart_rate",
+        "temperature",
+        "respiratory_rate",
+        "wbc",
+        "lactate",
+        "creatinine",
+        "age",
+        "gender_encoded",
+    ]
+
+    X = features[feature_cols].copy()
+    y = features["sepsis"].copy()
+
+    print(f"Feature matrix shape: {X.shape}")
+    print(f"Sepsis cases: {y.sum()} ({y.sum() / len(y) * 100:.2f}%)")
+
+    return X, y
+
+
+def train_models(X_train: pd.DataFrame, y_train: pd.Series) -> Dict[str, Any]:
+    """
+    Train all three models (Random Forest, XGBoost, Logistic Regression).
+
+    Args:
+        X_train: Training features
+        y_train: Training labels
+
+    Returns:
+        Dictionary of trained models
+    """
+    print("\nTraining models...")
+
+    models = {}
+
+    # Check if we have any positive samples
+    positive_samples = y_train.sum()
+    total_samples = len(y_train)
+    positive_rate = positive_samples / total_samples if total_samples > 0 else 0.0
+
+    print(f"  Positive samples: {positive_samples}/{total_samples} ({positive_rate*100:.2f}%)")
+
+    # Random Forest - use class_weight to handle imbalance
+    print("  Training Random Forest...")
+    rf = RandomForestClassifier(
+        n_estimators=100,
+        random_state=42,
+        n_jobs=-1,
+        class_weight="balanced",  # Automatically adjust for class imbalance
+    )
+    rf.fit(X_train, y_train)
+    models["RandomForest"] = rf
+
+    # XGBoost - handle case with no positive samples
+    print("  Training XGBoost...")
+    if positive_samples == 0:
+        # When there are no positive samples, set base_score to a small value
+        # and use scale_pos_weight to avoid errors
+        xgb_model = xgb.XGBClassifier(
+            random_state=42,
+            n_jobs=-1,
+            eval_metric="logloss",
+            base_score=0.01,  # Small positive value instead of 0
+            scale_pos_weight=1.0,
+        )
+    else:
+        # Calculate scale_pos_weight for imbalanced data
+        scale_pos_weight = (total_samples - positive_samples) / positive_samples
+        xgb_model = xgb.XGBClassifier(
+            random_state=42,
+            n_jobs=-1,
+            eval_metric="logloss",
+            scale_pos_weight=scale_pos_weight,
+        )
+    xgb_model.fit(X_train, y_train)
+    models["XGBoost"] = xgb_model
+
+    # Logistic Regression (with scaling) - use class_weight to handle imbalance
+    print("  Training Logistic Regression...")
+    scaler = StandardScaler()
+    X_train_scaled = scaler.fit_transform(X_train)
+    lr = LogisticRegression(
+        random_state=42,
+        max_iter=1000,
+        class_weight="balanced",  # Automatically adjust for class imbalance
+    )
+    lr.fit(X_train_scaled, y_train)
+    models["LogisticRegression"] = lr
+    models["scaler"] = scaler  # Store scaler for later use
+
+    return models
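+
+
+# Worked example (illustrative): with 950 negatives and 50 positives in the
+# training split, scale_pos_weight = (1000 - 50) / 50 = 19, i.e. each positive
+# example carries the loss weight of 19 negatives in XGBoost.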
+
+
+def evaluate_models(
+    models: Dict[str, Any],
+    X_test: pd.DataFrame,
+    y_test: pd.Series,
+    feature_names: List[str],
+) -> Dict[str, Dict[str, float]]:
+    """
+    Evaluate and compare all models.
+
+    Args:
+        models: Dictionary of trained models
+        X_test: Test features
+        y_test: Test labels
+        feature_names: List of feature names
+
+    Returns:
+        Dictionary of evaluation metrics for each model
+    """
+    print("\nEvaluating models...")
+    print(f"Test set: {len(y_test)} samples, {y_test.sum()} positive ({y_test.sum()/len(y_test)*100:.2f}%)")
+
+    results = {}
+
+    for name, model in models.items():
+        if name == "scaler":
+            continue
+
+        # Get probability predictions
+        if name == "LogisticRegression":
+            X_test_scaled = models["scaler"].transform(X_test)
+            y_pred_proba = model.predict_proba(X_test_scaled)[:, 1]
+        else:
+            y_pred_proba = model.predict_proba(X_test)[:, 1]
+
+        # Use default threshold (0.5) for predictions
+        y_pred = (y_pred_proba >= 0.5).astype(int)
+
+        # Calculate metrics with default threshold
+        metrics = {
+            "accuracy": accuracy_score(y_test, y_pred),
+            "precision": precision_score(y_test, y_pred, zero_division=0),
+            "recall": recall_score(y_test, y_pred, zero_division=0),
+            "f1": f1_score(y_test, y_pred, zero_division=0),
+            "auc": roc_auc_score(y_test, y_pred_proba)
+            if len(np.unique(y_test)) > 1
+            else 0.0,
+        }
+
+        # Try to find the optimal threshold for F1 score
+        if len(np.unique(y_test)) > 1 and y_test.sum() > 0:
+            precision, recall, thresholds = precision_recall_curve(y_test, y_pred_proba)
+            f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
+            optimal_idx = np.argmax(f1_scores)
+            optimal_threshold = (
+                thresholds[optimal_idx] if optimal_idx < len(thresholds) else 0.5
+            )
+            optimal_f1 = f1_scores[optimal_idx]
+
+            # Predictions with optimal threshold
+            y_pred_optimal = (y_pred_proba >= optimal_threshold).astype(int)
+            metrics["optimal_threshold"] = optimal_threshold
+            metrics["optimal_f1"] = optimal_f1
+            metrics["optimal_precision"] = precision_score(
+                y_test, y_pred_optimal, zero_division=0
+            )
+            metrics["optimal_recall"] = recall_score(
+                y_test, y_pred_optimal, zero_division=0
+            )
+        else:
+            metrics["optimal_threshold"] = 0.5
+            metrics["optimal_f1"] = 0.0
+            metrics["optimal_precision"] = 0.0
+            metrics["optimal_recall"] = 0.0
+
+        results[name] = metrics
+
+        print(f"\n{name}:")
+        print(f"  Predictions: {y_pred.sum()} positive predicted (actual: {y_test.sum()})")
+        print(f"  Accuracy: {metrics['accuracy']:.4f}")
+        print(f"  Precision: {metrics['precision']:.4f}")
+        print(f"  Recall: {metrics['recall']:.4f}")
+        print(f"  F1-score: {metrics['f1']:.4f}")
+        print(f"  AUC-ROC: {metrics['auc']:.4f}")
+        if metrics["optimal_f1"] > 0:
+            print(f"  Optimal threshold: {metrics['optimal_threshold']:.4f}")
+            print(f"  Optimal F1-score: {metrics['optimal_f1']:.4f}")
+            print(f"  Optimal Precision: {metrics['optimal_precision']:.4f}")
+            print(f"  Optimal Recall: {metrics['optimal_recall']:.4f}")
+
+        # Show feature importance for tree-based models
+        if hasattr(model, "feature_importances_"):
+            print("\n  Top 5 Feature Importances:")
+            importances = model.feature_importances_
+            indices = np.argsort(importances)[::-1][:5]
+            for idx in indices:
+                print(f"    {feature_names[idx]}: {importances[idx]:.4f}")
+
+    return results
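+
+
+# NOTE: sklearn's precision_recall_curve returns len(thresholds) + 1 precision
+# and recall points (the final point has no associated threshold), which is why
+# optimal_idx is bounds-checked against len(thresholds) in evaluate_models above.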
+
+
+def select_best_model(
+    models: Dict[str, Any],
+    results: Dict[str, Dict[str, float]],
+    metric: str = "f1",
+) -> Tuple[str, Any, Dict[str, float]]:
+    """
+    Select the best model based on the specified metric.
+
+    Args:
+        models: Dictionary of trained models
+        results: Evaluation results
+        metric: Metric to optimize ("f1", "recall", "precision", "auc")
+
+    Returns:
+        Tuple of (best model name, best model, best metrics)
+    """
+    print(f"\nSelecting best model based on {metric}...")
+
+    # Get the appropriate metric value (prefer optimal if available)
+    def get_metric_value(metrics, metric_name):
+        if metric_name == "f1":
+            return metrics.get("optimal_f1", metrics["f1"])
+        elif metric_name == "recall":
+            return metrics.get("optimal_recall", metrics["recall"])
+        elif metric_name == "precision":
+            return metrics.get("optimal_precision", metrics["precision"])
+        elif metric_name == "auc":
+            return metrics.get("auc", 0.0)
+        else:
+            return metrics.get("optimal_f1", metrics["f1"])
+
+    best_name = max(results.keys(), key=lambda k: get_metric_value(results[k], metric))
+    best_model = models[best_name]
+    best_metrics = results[best_name]
+
+    best_value = get_metric_value(best_metrics, metric)
+    print(f"Best model: {best_name} ({metric}: {best_value:.4f})")
+
+    return best_name, best_model, best_metrics
+
+
+def save_model(
+    model: Any,
+    model_name: str,
+    feature_names: List[str],
+    metrics: Dict[str, float],
+    scaler: Any,
+    output_path: Union[str, Path],
+) -> None:
+    """
+    Save the best model with metadata.
+
+    Args:
+        model: Trained model
+        model_name: Name of the model
+        feature_names: List of feature names
+        metrics: Evaluation metrics
+        scaler: StandardScaler for Logistic Regression, None otherwise
+        output_path: Path to save model
+    """
+    print(f"\nSaving model to {output_path}...")
+
+    # Create output directory if it doesn't exist
+    output_path = Path(output_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    # Prepare metadata
+    metadata = {
+        "model_name": model_name,
+        "training_date": datetime.now().isoformat(),
+        "feature_names": feature_names,
+        "metrics": metrics,
+        "itemid_mappings": {
+            "chartevents": CHARTEVENTS_ITEMIDS,
+            "labevents": LABEVENTS_ITEMIDS,
+        },
+        "sepsis_icd_codes": {
+            "icd10": SEPSIS_ICD10_CODES,
+            "icd9": SEPSIS_ICD9_CODES,
+        },
+    }
+
+    # Save model and metadata
+    model_data = {
+        "model": model,
+        "scaler": scaler,
+        "metadata": metadata,
+    }
+
+    joblib.dump(model_data, output_path)
+
+    print("Model saved successfully!")
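+
+
+# SECURITY NOTE: joblib.dump/joblib.load use pickle under the hood, so a model
+# file can execute arbitrary code when loaded -- only load artifacts from
+# trusted sources.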
+
+
+def main():
+    """Main training pipeline."""
+    # Data directory
+    data_dir = "../datasets/mimic-iv-clinical-database-demo-2.2"
+
+    # Output path (relative to script location)
+    script_dir = Path(__file__).parent
+    output_path = script_dir / "models" / "sepsis_model.pkl"
+
+    print("=" * 60)
+    print("Sepsis Prediction Model Training")
+    print("=" * 60)
+
+    # Load data
+    tables = load_mimic_data(data_dir)
+
+    # Extract features
+    chartevents_features = extract_chartevents_features(
+        tables["chartevents"], tables["icustays"]
+    )
+    labevents_features = extract_labevents_features(
+        tables["labevents"], tables["icustays"]
+    )
+    demographics = extract_demographics(
+        tables["patients"], tables["admissions"], tables["icustays"]
+    )
+
+    # Extract labels
+    sepsis_labels = extract_sepsis_labels(tables["diagnoses_icd"], tables["icustays"])
+
+    # Create feature matrix
+    X, y = create_feature_matrix(
+        chartevents_features,
+        labevents_features,
+        demographics,
+        sepsis_labels,
+    )
+
+    # Handle missing values (impute with median)
+    print("\nHandling missing values...")
+    missing_before = X.isnull().sum().sum()
+    print(f"  Missing values before imputation: {missing_before}")
+    X = X.fillna(X.median())
+
+    # Print feature summary with actual data statistics
+    print_feature_summary(X)
+
+    # Split data, stratifying to keep positive samples in both sets
+    print("\nSplitting data...")
+    if len(np.unique(y)) > 1 and y.sum() > 0:
+        X_train, X_test, y_train, y_test = train_test_split(
+            X, y, test_size=0.2, random_state=42, stratify=y
+        )
+        print(f"  Training set: {len(X_train)} samples ({y_train.sum()} positive, {y_train.sum()/len(y_train)*100:.2f}%)")
+        print(f"  Test set: {len(X_test)} samples ({y_test.sum()} positive, {y_test.sum()/len(y_test)*100:.2f}%)")
+
+        # Warn if the test set has no positive samples (shouldn't happen with stratify, but check anyway)
+        if y_test.sum() == 0:
+            print("  WARNING: Test set has no positive samples! Consider using a different random seed.")
+    else:
+        print("  Warning: No positive samples or only one class. Skipping stratification.")
+        X_train, X_test, y_train, y_test = train_test_split(
+            X, y, test_size=0.2, random_state=42
+        )
+        print(f"  Training set: {len(X_train)} samples")
+        print(f"  Test set: {len(X_test)} samples")
+
+    # Apply oversampling to training data to balance classes
+    print("\nApplying oversampling to training data...")
+    try:
+        from imblearn.over_sampling import SMOTE
+
+        # Only apply SMOTE if we have positive samples
+        if y_train.sum() > 0 and len(np.unique(y_train)) > 1:
+            print(f"  Before oversampling: {len(X_train)} samples ({y_train.sum()} positive, {y_train.sum()/len(y_train)*100:.2f}%)")
+            # Ensure k_neighbors doesn't exceed available positive samples
+            k_neighbors = min(5, max(1, y_train.sum() - 1))
+            smote = SMOTE(random_state=42, k_neighbors=k_neighbors)
+            X_train_resampled, y_train_resampled = smote.fit_resample(X_train, y_train)
+            print(f"  After oversampling: {len(X_train_resampled)} samples ({y_train_resampled.sum()} positive, {y_train_resampled.sum()/len(X_train_resampled)*100:.2f}%)")
+            # SMOTE returns more rows than the original, so build fresh indices
+            # rather than reusing (and truncating) the original index
+            X_train = pd.DataFrame(X_train_resampled, columns=X_train.columns)
+            y_train = pd.Series(y_train_resampled)
+        else:
+            print("  Skipping oversampling: insufficient positive samples")
+    except (ImportError, ModuleNotFoundError) as e:
+        print("  imbalanced-learn not installed. Install with: pip install imbalanced-learn")
+        print(f"  Error: {e}")
+        print("  Proceeding without oversampling...")
+
+    # Train models
+    models = train_models(X_train, y_train)
+
+    # Evaluate models
+    feature_names = X.columns.tolist()
+    results = evaluate_models(models, X_test, y_test, feature_names)
+
+    # Select best model (can change metric: "f1", "recall", "precision", "auc")
+    # For sepsis prediction, recall (sensitivity) is often most important
+    best_name, best_model, best_metrics = select_best_model(
+        models, results, metric="f1"
+    )
+
+    # Save best model; only keep the scaler when the best model actually uses it,
+    # otherwise inference would wrongly scale features for tree-based models
+    scaler = models.get("scaler") if best_name == "LogisticRegression" else None
+    save_model(
+        best_model,
+        best_name,
+        feature_names,
+        best_metrics,
+        scaler,
+        output_path,
+    )
+
+    print("\n" + "=" * 60)
+    print("Training complete!")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/docs/reference/utilities/sandbox.md b/docs/reference/utilities/sandbox.md
index 3c666bb8..5ab4d952 100644
--- a/docs/reference/utilities/sandbox.md
+++ b/docs/reference/utilities/sandbox.md
@@ -140,6 +140,38 @@ data_dir/
     )
     ```
 
+=== "Direct Loader for ML Workflows"
+    ```python
+    # Use loader directly for ML pipelines (faster, no validation)
+    from healthchain.sandbox.loaders import MimicOnFHIRLoader
+    from healthchain.io import Dataset
+
+    loader = MimicOnFHIRLoader()
+
+    # as_dict=True: Returns single bundle dict (fast, no FHIR validation)
+    # Suitable for ML feature extraction workflows
+    bundle = loader.load(
+        data_dir="./data/mimic-iv-fhir",
+        resource_types=["MimicObservationChartevents", "MimicPatient"],
+        as_dict=True
+    )
+
+    # Convert to DataFrame for ML
+    dataset = Dataset.from_fhir_bundle(
+        bundle,
+        schema="healthchain/configs/features/sepsis_vitals.yaml"
+    )
+    df = dataset.data
+
+    # as_dict=False (default): Returns Dict[str, Bundle]
+    # Validated Bundle objects grouped by resource type (for CDS Hooks)
+    bundles = loader.load(
+        data_dir="./data/mimic-iv-fhir",
+        resource_types=["MimicMedication", "MimicCondition"]
+    )
+    # Use bundles["medicationstatement"] and bundles["condition"]
+    ```
+
 ### Synthea Loader
 
 Synthetic patient data generated by [Synthea](https://synthea.mitre.org), containing realistic FHIR Bundles (typically 100-500 resources per patient). Ideal for single-patient workflows that require diverse data scenarios.
diff --git a/healthchain/configs/features/sepsis_vitals.yaml b/healthchain/configs/features/sepsis_vitals.yaml
new file mode 100644
index 00000000..52133afc
--- /dev/null
+++ b/healthchain/configs/features/sepsis_vitals.yaml
@@ -0,0 +1,85 @@
+name: sepsis_prediction_features
+version: "1.0"
+description: Feature schema for sepsis prediction model trained on MIMIC-IV data
+
+model_info:
+  model_type: Random Forest / XGBoost / Logistic Regression
+  training_data: MIMIC-IV Clinical Database Demo
+  target: Sepsis (ICD-9/ICD-10 codes)
+  prediction_window: First 24 hours of ICU stay
+
+metadata:
+  age_calculation: event_date
+  event_date_source: Observation
+  event_date_strategy: earliest
+
+features:
+  heart_rate:
+    fhir_resource: Observation
+    code: "220045"
+    code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items
+    display: Heart Rate
+    unit: bpm
+    dtype: float64
+    required: true
+
+  temperature:
+    fhir_resource: Observation
+    code: "223761"
+    code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items
+    display: Temperature Fahrenheit
+    unit: °F
+    dtype: float64
+    required: true
+
+  respiratory_rate:
+    fhir_resource: Observation
+    code: "220210"
+    code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items
+    display: Respiratory Rate
+    unit: insp/min
+    dtype: float64
+    required: true
+
+  wbc:
+    fhir_resource: Observation
+    code: "51301"
+    code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems
+    display: White Blood Cells
+    unit: K/uL
+    dtype: float64
+    required: true
+
+  lactate:
+    fhir_resource: Observation
+    code: "50813"
+    code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems
+    display: Lactate
+    unit: mmol/L
+    dtype: float64
+    required: true
+
+  creatinine:
+    fhir_resource: Observation
+    code: "50912"
+    code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems
+    display: Creatinine
+    unit: mg/dL
+    dtype: float64
+    required: true
+
+  age:
+    fhir_resource: Patient
+    field: birthDate
+    transform: calculate_age
+    dtype: int64
+    required: true
+    display: Patient age calculated from birth date
+
+  gender_encoded:
+    fhir_resource: Patient
+    field: gender
+    transform: encode_gender
+    dtype: int64
+    required: true
+    display: Administrative gender (M=1, F=0)
diff --git a/healthchain/fhir/__init__.py b/healthchain/fhir/__init__.py
index 9da081d3..ebd3c848 100644
--- a/healthchain/fhir/__init__.py
+++ b/healthchain/fhir/__init__.py
@@ -1,23 +1,32 @@
 """FHIR utilities for HealthChain."""
 
-from healthchain.fhir.helpers import (
+from healthchain.fhir.resourcehelpers import (
     create_condition,
     create_medication_statement,
     create_allergy_intolerance,
+    create_value_quantity_observation,
+    create_patient,
+    create_risk_assessment_from_prediction,
+    create_document_reference,
+    set_condition_category,
+    add_provenance_metadata,
+    add_coding_to_codeable_concept,
+)
+
+from healthchain.fhir.elementhelpers import (
     create_single_codeable_concept,
     create_single_reaction,
-    set_condition_category,
-    read_content_attachment,
-    create_document_reference,
     create_document_reference_content,
     create_single_attachment,
+)
+
+from healthchain.fhir.readers import (
     create_resource_from_dict,
     convert_prefetch_to_fhir_objects,
-    add_provenance_metadata,
-    add_coding_to_codeable_concept,
+    read_content_attachment,
 )
 
-from healthchain.fhir.bundle_helpers import (
+from healthchain.fhir.bundlehelpers import (
     create_bundle,
     add_resource,
     get_resources,
@@ -27,23 +36,42 @@
     count_resources,
 )
 
+from healthchain.fhir.dataframe import (
+    BundleConverterConfig,
+    bundle_to_dataframe,
+    get_supported_resources,
+    get_resource_info,
+    print_supported_resources,
+)
+
+from healthchain.fhir.utilities import (
+    calculate_age_from_birthdate,
+    calculate_age_from_event_date,
+    encode_gender,
+)
+
 __all__ = [
     # Resource creation
     "create_condition",
     "create_medication_statement",
     "create_allergy_intolerance",
+    "create_value_quantity_observation",
+    "create_patient",
+    "create_risk_assessment_from_prediction",
+    "create_document_reference",
+    # Element creation
     "create_single_codeable_concept",
     "create_single_reaction",
-    "set_condition_category",
-    "read_content_attachment",
-    "create_document_reference",
     "create_document_reference_content",
     "create_single_attachment",
-    "create_resource_from_dict",
-    "convert_prefetch_to_fhir_objects",
     # Resource modification
+    "set_condition_category",
     "add_provenance_metadata",
     "add_coding_to_codeable_concept",
+    # Conversions and readers
+    "create_resource_from_dict",
+    "convert_prefetch_to_fhir_objects",
+    "read_content_attachment",
     # Bundle operations
     "create_bundle",
     "add_resource",
@@ -52,4 +80,14 @@
     "merge_bundles",
     "extract_resources",
     "count_resources",
+    # Bundle to DataFrame conversion
+    "BundleConverterConfig",
+    "bundle_to_dataframe",
+    "get_supported_resources",
+    "get_resource_info",
+    "print_supported_resources",
+    # Utility functions
+    "calculate_age_from_birthdate",
+    "calculate_age_from_event_date",
+    "encode_gender",
 ]
diff --git a/healthchain/fhir/bundle_helpers.py b/healthchain/fhir/bundlehelpers.py
similarity index 100%
rename from healthchain/fhir/bundle_helpers.py
rename to healthchain/fhir/bundlehelpers.py
diff --git a/healthchain/fhir/dataframe.py b/healthchain/fhir/dataframe.py
new file mode 100644
index 00000000..0c9bfdb2
--- /dev/null
+++ b/healthchain/fhir/dataframe.py
@@ -0,0 +1,600 @@
+"""FHIR to DataFrame converters.
+
+This module provides generic functions to convert FHIR Bundles to pandas DataFrames
+for analysis and ML model deployment.
+
+Where multiple codes are present for a single resource, the first code is used as
+the primary code.
+"""
+
+import pandas as pd
+import logging
+
+from typing import Any, Dict, List, Union, Optional, Literal
+from collections import defaultdict
+from fhir.resources.bundle import Bundle
+from pydantic import BaseModel, field_validator, ConfigDict
+
+from healthchain.fhir.utilities import (
+    calculate_age_from_birthdate,
+    calculate_age_from_event_date,
+    encode_gender,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# Resource handler registry
+SUPPORTED_RESOURCES = {
+    "Patient": {
+        "handler": "_flatten_patient",
+        "description": "Patient demographics (age, gender)",
+        "output_columns": ["age", "gender"],
+    },
+    "Observation": {
+        "handler": "_flatten_observations",
+        "description": "Clinical observations (vitals, labs)",
+        "output_columns": "Dynamic based on observation codes",
+        "options": ["aggregation"],
+    },
+    "Condition": {
+        "handler": "_flatten_conditions",
+        "description": "Conditions/diagnoses as binary indicators",
+        "output_columns": "Dynamic: condition_{code}_{display}",
+    },
+    "MedicationStatement": {
+        "handler": "_flatten_medications",
+        "description": "Medications as binary indicators",
+        "output_columns": "Dynamic: medication_{code}_{display}",
+    },
+}
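+
+
+# EXTENDING (illustrative sketch): additional resource types can be supported by
+# registering a handler with the same (resources, config) signature as the
+# _flatten_* functions below; the handler name here is hypothetical:
+#
+#   SUPPORTED_RESOURCES["AllergyIntolerance"] = {
+#       "handler": "_flatten_allergies",
+#       "description": "Allergies as binary indicators",
+#       "output_columns": "Dynamic: allergy_{code}",
+#   }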
+
+
+class BundleConverterConfig(BaseModel):
+    """Configuration for FHIR Bundle to DataFrame conversion.
+
+    This configuration object controls which FHIR resources are processed and how
+    they are converted to DataFrame columns for ML model deployment.
+
+    Attributes:
+        resources: List of FHIR resource types to include in the conversion
+        observation_aggregation: How to aggregate multiple observation values
+        age_calculation: Method for calculating patient age
+        event_date_source: Which resource to extract event date from
+        event_date_strategy: Which date to use when multiple dates exist
+        resource_options: Resource-specific configuration options (extensible)
+
+    Example:
+        >>> config = BundleConverterConfig(
+        ...     resources=["Patient", "Observation", "Condition"],
+        ...     observation_aggregation="median"
+        ... )
+        >>> df = bundle_to_dataframe(bundle, config=config)
+    """
+
+    # Core resources to include
+    resources: List[str] = ["Patient", "Observation"]
+
+    # Observation-specific options
+    observation_aggregation: Literal["mean", "median", "max", "min", "last"] = "mean"
+
+    # Patient age calculation
+    age_calculation: Literal["current_date", "event_date"] = "current_date"
+    event_date_source: Literal["Observation", "Encounter"] = "Observation"
+    event_date_strategy: Literal["earliest", "latest", "first"] = "earliest"
+
+    # Resource-specific options (extensible for future use)
+    resource_options: Dict[str, Dict[str, Any]] = {}
+
+    model_config = ConfigDict(extra="allow")
+
+    @field_validator("resources")
+    @classmethod
+    def validate_resources(cls, v):
+        """Validate that requested resources are supported and warn about unsupported ones."""
+        supported = get_supported_resources()
+        unsupported = [r for r in v if r not in supported]
+        if unsupported:
+            logger.warning(
+                f"Unsupported resources will be skipped: {unsupported}. "
+                f"Supported resources: {supported}"
+            )
+        return v
+
+
+def get_supported_resources() -> List[str]:
+    """Get the list of supported FHIR resource types.
+
+    Returns:
+        List of resource type names that can be converted to DataFrame columns
+
+    Example:
+        >>> resources = get_supported_resources()
+        >>> print(resources)
+        ['Patient', 'Observation', 'Condition', 'MedicationStatement']
+    """
+    return list(SUPPORTED_RESOURCES.keys())
+
+
+def get_resource_info(resource_type: str) -> Dict[str, Any]:
+    """Get detailed information about a supported resource type.
+
+    Args:
+        resource_type: FHIR resource type name
+
+    Returns:
+        Dictionary with resource handler information, or empty dict if unsupported
+
+    Example:
+        >>> info = get_resource_info("Observation")
+        >>> print(info["description"])
+        'Clinical observations (vitals, labs)'
+    """
+    return SUPPORTED_RESOURCES.get(resource_type, {})
+ """ + print("Supported FHIR Resources for ML Dataset Conversion:\n") + for resource, info in SUPPORTED_RESOURCES.items(): + print(f" ✓ {resource}") + print(f" {info['description']}") + if isinstance(info["output_columns"], list): + print(f" Columns: {', '.join(info['output_columns'])}") + else: + print(f" Columns: {info['output_columns']}") + if info.get("options"): + print(f" Options: {', '.join(info['options'])}") + print() + + +def _get_field(resource: Dict, field_name: str, default=None): + """Get field value from a dictionary.""" + return resource.get(field_name, default) + + +def _get_reference(field: Union[str, Dict[str, Any]]) -> Optional[str]: + """Extract reference string from a FHIR Reference field.""" + + if not field: + return None + + # Case 1: Already a string + if isinstance(field, str): + return field + + # Case 2: Dict with 'reference' field + return _get_field(field, "reference") + + +def extract_observation_value(observation: Dict) -> Optional[float]: + """Extract numeric value from an Observation dict. + + Handles different value types (valueQuantity, valueInteger, valueString) and + attempts to convert to float. + """ + + try: + value_quantity = _get_field(observation, "valueQuantity") + if value_quantity: + value = _get_field(value_quantity, "value") + if value is not None: + return float(value) + + value_int = _get_field(observation, "valueInteger") + if value_int is not None: + return float(value_int) + + value_str = _get_field(observation, "valueString") + if value_str: + return float(value_str) + + except (ValueError, TypeError): + pass + + return None + + +def extract_event_date( + resources: Dict[str, List[Any]], + source: str = "Observation", + strategy: str = "earliest", +) -> Optional[str]: + """Extract event date from patient resources for age calculation. + + Used primarily for MIMIC-IV on FHIR datasets where age is calculated + based on event dates rather than current date. + + Args: + resources: Dictionary of patient resources (from group_bundle_by_patient) + source: Which resource type to extract date from ("Observation" or "Encounter") + strategy: Which date to use ("earliest", "latest", "first") + + Returns: + Event date in ISO format, or None if no suitable date found + + Example: + >>> resources = {"observations": [obs1, obs2], "encounters": [enc1]} + >>> event_date = extract_event_date(resources, source="Observation", strategy="earliest") + """ + if source == "Observation": + items = resources.get("observations", []) + date_field = "effectiveDateTime" + elif source == "Encounter": + items = resources.get("encounters", []) + date_field = "period" + else: + return None + + if not items: + return None + + dates = [] + for item in items: + if source == "Encounter": + # Extract start date from period + period = _get_field(item, date_field) + if period: + start = _get_field(period, "start") + if start: + dates.append(start) + else: + # Direct date field + date_value = _get_field(item, date_field) + if date_value: + dates.append(date_value) + + if not dates: + return None + + # Apply strategy + if strategy == "earliest": + return min(dates) + elif strategy == "latest": + return max(dates) + elif strategy == "first": + return dates[0] + else: + return min(dates) # Default to earliest + + +def group_bundle_by_patient( + bundle: Union[Bundle, Dict[str, Any]], +) -> Dict[str, Dict[str, List[Any]]]: + """Group Bundle resources by patient reference. 
+ + Organizes FHIR resources in a Bundle by their associated patient, making it easier + to process patient-centric data. Accepts both Pydantic Bundle objects and dicts, + converts to dict internally for performance. + + Args: + bundle: FHIR Bundle resource (Pydantic object or dict) + + Returns: + Dictionary mapping patient references to their resources: + { + "Patient/123": { + "patient": Patient resource dict, + "observations": [Observation dict, ...], + "conditions": [Condition dict, ...], + ... + } + } + """ + if not isinstance(bundle, dict): + bundle = bundle.model_dump() + + patient_data = defaultdict( + lambda: { + "patient": None, + "observations": [], + "conditions": [], + "medications": [], + "allergies": [], + "procedures": [], + "encounters": [], + "other": [], + } + ) + + # Get bundle entries + entries = _get_field(bundle, "entry") + if not entries: + return dict(patient_data) + + for entry in entries: + # Get resource from entry + resource = _get_field(entry, "resource") + if not resource: + continue + + resource_type = _get_field(resource, "resourceType") + resource_id = _get_field(resource, "id") + + if resource_type == "Patient": + patient_ref = f"Patient/{resource_id}" + patient_data[patient_ref]["patient"] = resource + + else: + # Get patient reference from resource + subject = _get_field(resource, "subject") + patient_field = _get_field(resource, "patient") + + patient_ref = _get_reference(subject) or _get_reference(patient_field) + + if patient_ref: + # Add to appropriate list based on resource type + if resource_type == "Observation": + patient_data[patient_ref]["observations"].append(resource) + elif resource_type == "Condition": + patient_data[patient_ref]["conditions"].append(resource) + elif resource_type == "MedicationStatement": + patient_data[patient_ref]["medications"].append(resource) + elif resource_type == "AllergyIntolerance": + patient_data[patient_ref]["allergies"].append(resource) + elif resource_type == "Procedure": + patient_data[patient_ref]["procedures"].append(resource) + elif resource_type == "Encounter": + patient_data[patient_ref]["encounters"].append(resource) + else: + patient_data[patient_ref]["other"].append(resource) + + return dict(patient_data) + + +def bundle_to_dataframe( + bundle: Union[Bundle, Dict[str, Any]], + config: Optional[BundleConverterConfig] = None, +) -> pd.DataFrame: + """Convert a FHIR Bundle to a pandas DataFrame. + + Converts FHIR resources to a tabular format with one row per patient. + Uses a configuration object to control which resources are processed and how. + + Args: + bundle: FHIR Bundle resource (object or dict) + config: BundleConverterConfig object specifying conversion behavior. + If None, uses default config (Patient + Observation with mean aggregation) + + Returns: + DataFrame with one row per patient and columns for each feature + + Example: + >>> from healthchain.fhir.converters import BundleConverterConfig + >>> + >>> # Default behavior + >>> df = bundle_to_dataframe(bundle) + >>> + >>> # Custom config + >>> config = BundleConverterConfig( + ... resources=["Patient", "Observation", "Condition"], + ... observation_aggregation="median", + ... age_calculation="event_date" + ... 
) + >>> df = bundle_to_dataframe(bundle, config=config) + """ + # Use default config if not provided + if config is None: + config = BundleConverterConfig() + + # Group resources by patient + patient_data = group_bundle_by_patient(bundle) + + if not patient_data: + return pd.DataFrame() + + # Build rows for each patient + rows = [] + for patient_ref, resources in patient_data.items(): + row = {"patient_ref": patient_ref} + + # Process each requested resource type using registry + for resource_type in config.resources: + handler_info = SUPPORTED_RESOURCES.get(resource_type) + + if not handler_info: + # Skip unsupported resources gracefully (already warned by validator) + continue + + # Get handler function by name + handler_name = handler_info["handler"] + handler = globals()[handler_name] + + # Call handler with standardized signature + features = handler(resources, config) + if features: + row.update(features) + + rows.append(row) + + return pd.DataFrame(rows) + + +def _flatten_patient( + resources: Dict[str, Any], config: BundleConverterConfig +) -> Dict[str, Any]: + """Flatten patient demographics into feature columns. + + Args: + resources: Dictionary of patient resources + config: Converter configuration + + Returns: + Dictionary with age and gender features + """ + if not resources["patient"]: + return {} + + features = {} + patient = resources["patient"] + + birth_date = _get_field(patient, "birthDate") + gender = _get_field(patient, "gender") + + # Calculate age based on configuration + if config.age_calculation == "event_date": + event_date = extract_event_date( + resources, config.event_date_source, config.event_date_strategy + ) + features["age"] = calculate_age_from_event_date(birth_date, event_date) + else: + features["age"] = calculate_age_from_birthdate(birth_date) + + features["gender"] = encode_gender(gender) + + return features + + +def _flatten_observations( + resources: Dict[str, Any], config: BundleConverterConfig +) -> Dict[str, float]: + """Flatten observations into feature columns. 
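+
+    Observations are grouped by their primary coding and aggregated with
+    config.observation_aggregation; each code becomes a column named
+    "{code}_{display}" with spaces in the display replaced by underscores
+    (e.g. "8867-4_Heart_rate").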
+
+    Args:
+        resources: Dictionary of patient resources
+        config: Converter configuration
+
+    Returns:
+        Dictionary with observation features
+    """
+    # numpy is used for the aggregation functions below
+    import numpy as np
+
+    observations = resources.get("observations", [])
+    aggregation = config.observation_aggregation
+
+    # Group observations by code
+    obs_by_code = defaultdict(list)
+
+    for obs in observations:
+        code_field = _get_field(obs, "code")
+        if not code_field:
+            continue
+
+        coding_array = _get_field(code_field, "coding")
+        if not coding_array or len(coding_array) == 0:
+            continue
+
+        coding = coding_array[0]
+        code = _get_field(coding, "code")
+        display = _get_field(coding, "display") or code
+        system = _get_field(coding, "system")
+
+        value = extract_observation_value(obs)
+        if value is not None:
+            obs_by_code[code].append(
+                {
+                    "value": value,
+                    "display": display,
+                    "system": system,
+                }
+            )
+
+    # Aggregate and create feature columns
+    features = {}
+    for code, obs_list in obs_by_code.items():
+        values = [item["value"] for item in obs_list]
+        display = obs_list[0]["display"]
+
+        # Create column name: code_display
+        col_name = f"{code}_{display.replace(' ', '_')}"
+
+        # Aggregate values
+        if aggregation == "mean":
+            features[col_name] = np.mean(values)
+        elif aggregation == "median":
+            features[col_name] = np.median(values)
+        elif aggregation == "max":
+            features[col_name] = np.max(values)
+        elif aggregation == "min":
+            features[col_name] = np.min(values)
+        elif aggregation == "last":
+            features[col_name] = values[-1]
+        else:
+            features[col_name] = np.mean(values)
+
+    return features
+
+
+def _flatten_conditions(
+    resources: Dict[str, Any], config: BundleConverterConfig
+) -> Dict[str, int]:
+    """Flatten conditions into binary indicator columns.
+
+    Args:
+        resources: Dictionary of patient resources
+        config: Converter configuration
+
+    Returns:
+        Dictionary with condition indicator features
+    """
+    conditions = resources.get("conditions", [])
+    features = {}
+
+    for condition in conditions:
+        code_field = _get_field(condition, "code")
+        if not code_field:
+            continue
+
+        coding_array = _get_field(code_field, "coding")
+        if not coding_array or len(coding_array) == 0:
+            continue
+
+        # Get primary coding
+        coding = coding_array[0]
+        code = _get_field(coding, "code")
+        display = _get_field(coding, "display") or code
+
+        # Create column name: condition_code_display
+        col_name = f"condition_{code}_{display.replace(' ', '_')}"
+        features[col_name] = 1
+
+    return features
+
+
+def _flatten_medications(
+    resources: Dict[str, Any], config: BundleConverterConfig
+) -> Dict[str, int]:
+    """Flatten medications into binary indicator columns.
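+
+    Each MedicationStatement with a coded medication concept yields a
+    "medication_{code}_{display}" column set to 1; rows for patients without
+    that medication are left missing (NaN) in the assembled DataFrame.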
+ + Args: + resources: Dictionary of patient resources + config: Converter configuration + + Returns: + Dictionary with medication indicator features + """ + medications = resources.get("medications", []) + features = {} + + for med in medications: + medication = _get_field(med, "medication") + if not medication: + continue + + med_concept = _get_field(medication, "concept") + if not med_concept: + continue + + coding_array = _get_field(med_concept, "coding") + if not coding_array or len(coding_array) == 0: + continue + + # Get primary coding + coding = coding_array[0] + code = _get_field(coding, "code") + display = _get_field(coding, "display") or code + + # Create column name: medication_code_display + col_name = f"medication_{code}_{display.replace(' ', '_')}" + features[col_name] = 1 + + return features diff --git a/healthchain/fhir/elementhelpers.py b/healthchain/fhir/elementhelpers.py new file mode 100644 index 00000000..c4b4532f --- /dev/null +++ b/healthchain/fhir/elementhelpers.py @@ -0,0 +1,109 @@ +"""FHIR element creation functions. + +This module provides convenience functions for creating FHIR elements that are used +as building blocks within FHIR resources (e.g., CodeableConcept, Attachment, Coding). +""" + +import logging +import base64 +import datetime + +from typing import Optional, List, Dict, Any +from fhir.resources.codeableconcept import CodeableConcept +from fhir.resources.codeablereference import CodeableReference +from fhir.resources.coding import Coding +from fhir.resources.attachment import Attachment + +logger = logging.getLogger(__name__) + + +def create_single_codeable_concept( + code: str, + display: Optional[str] = None, + system: Optional[str] = "http://snomed.info/sct", +) -> CodeableConcept: + """ + Create a minimal FHIR CodeableConcept with a single coding. + + Args: + code: REQUIRED. The code value from the code system + display: The display name for the code + system: The code system (default: SNOMED CT) + + Returns: + CodeableConcept: A FHIR CodeableConcept resource with a single coding + """ + return CodeableConcept(coding=[Coding(system=system, code=code, display=display)]) + + +def create_single_reaction( + code: str, + display: Optional[str] = None, + system: Optional[str] = "http://snomed.info/sct", + severity: Optional[str] = None, +) -> List[Dict[str, Any]]: + """Create a minimal FHIR Reaction with a single coding. + + Creates a FHIR Reaction object with a single manifestation coding. The manifestation + describes the clinical reaction that was observed. The severity indicates how severe + the reaction was. + + Args: + code: REQUIRED. The code value from the code system representing the reaction manifestation + display: The display name for the manifestation code + system: The code system for the manifestation code (default: SNOMED CT) + severity: The severity of the reaction (mild, moderate, severe) + + Returns: + A list containing a single FHIR Reaction dictionary with manifestation and severity fields + """ + return [ + { + "manifestation": [ + CodeableReference( + concept=CodeableConcept( + coding=[Coding(system=system, code=code, display=display)] + ) + ) + ], + "severity": severity, + } + ] + + +def create_single_attachment( + content_type: Optional[str] = None, + data: Optional[str] = None, + url: Optional[str] = None, + title: Optional[str] = "Attachment created by HealthChain", +) -> Attachment: + """Create a minimal FHIR Attachment. + + Creates a FHIR Attachment resource with basic fields. Either data or url should be provided. 
+ If data is provided, it will be base64 encoded. + + Args: + content_type: The MIME type of the content + data: The actual data content to be base64 encoded + url: The URL where the data can be found + title: A title for the attachment (default: "Attachment created by HealthChain") + + Returns: + Attachment: A FHIR Attachment resource with basic metadata and content + """ + + if not data and not url: + logger.warning("No data or url provided for attachment") + + if data: + data = base64.b64encode(data.encode("utf-8")).decode("utf-8") + + return Attachment( + contentType=content_type, + data=data, + url=url, + title=title, + creation=datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y-%m-%dT%H:%M:%S%z" + ), + ) diff --git a/healthchain/fhir/readers.py b/healthchain/fhir/readers.py new file mode 100644 index 00000000..7d7bbd06 --- /dev/null +++ b/healthchain/fhir/readers.py @@ -0,0 +1,137 @@ +"""FHIR conversion and reading functions. + +This module provides functions for converting between different FHIR representations +and reading data from FHIR resources. +""" + +import logging +import importlib + +from typing import Optional, Dict, Any, List +from fhir.resources.resource import Resource +from fhir.resources.documentreference import DocumentReference + +logger = logging.getLogger(__name__) + + +def create_resource_from_dict( + resource_dict: Dict, resource_type: str +) -> Optional[Resource]: + """Create a FHIR resource instance from a dictionary + + Args: + resource_dict: Dictionary representation of the resource + resource_type: Type of FHIR resource to create + + Returns: + Optional[Resource]: FHIR resource instance or None if creation failed + """ + try: + resource_module = importlib.import_module( + f"fhir.resources.{resource_type.lower()}" + ) + resource_class = getattr(resource_module, resource_type) + return resource_class(**resource_dict) + except Exception as e: + logger.error(f"Failed to create FHIR resource: {str(e)}") + return None + + +def convert_prefetch_to_fhir_objects( + prefetch_dict: Dict[str, Any], +) -> Dict[str, Resource]: + """Convert a dictionary of FHIR resource dicts to FHIR Resource objects. + + Takes a prefetch dictionary where values may be either dict representations of FHIR + resources or already instantiated FHIR Resource objects, and ensures all values are + FHIR Resource objects. + + Args: + prefetch_dict: Dictionary mapping keys to FHIR resource dicts or objects + + Returns: + Dict[str, Resource]: Dictionary with same keys but all values as FHIR Resource objects + + Example: + >>> prefetch = { + ... "patient": {"resourceType": "Patient", "id": "123"}, + ... "condition": Condition(id="456", ...) + ... 
}
+        >>> fhir_objects = convert_prefetch_to_fhir_objects(prefetch)
+        >>> isinstance(fhir_objects["patient"], Patient)  # True
+        >>> isinstance(fhir_objects["condition"], Condition)  # True
+    """
+    from fhir.resources import get_fhir_model_class
+
+    result: Dict[str, Resource] = {}
+
+    for key, resource_data in prefetch_dict.items():
+        if isinstance(resource_data, dict):
+            # Convert dict to FHIR Resource object
+            resource_type = resource_data.get("resourceType")
+            if resource_type:
+                try:
+                    resource_class = get_fhir_model_class(resource_type)
+                    result[key] = resource_class(**resource_data)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to convert {resource_type} to FHIR object: {e}"
+                    )
+                    result[key] = resource_data
+            else:
+                logger.warning(
+                    f"No resourceType found for key '{key}', keeping as dict"
+                )
+                result[key] = resource_data
+        elif isinstance(resource_data, Resource):
+            # Already a FHIR object
+            result[key] = resource_data
+        else:
+            logger.warning(f"Unexpected type for key '{key}': {type(resource_data)}")
+            result[key] = resource_data
+
+    return result
+
+
+def read_content_attachment(
+    document_reference: DocumentReference,
+    include_data: bool = True,
+) -> Optional[List[Dict[str, Any]]]:
+    """Read the attachments in a human readable format from a FHIR DocumentReference content field.
+
+    Args:
+        document_reference: The FHIR DocumentReference resource
+        include_data: Whether to include the data of the attachments. If true, the data will also be decoded (default: True)
+
+    Returns:
+        Optional[List[Dict[str, Any]]]: List of dictionaries containing attachment data and metadata,
+        or None if no attachments are found:
+        [
+            {
+                "data": str,
+                "metadata": Dict[str, Any]
+            }
+        ]
+    """
+    if not document_reference.content:
+        return None
+
+    attachments = []
+    for content in document_reference.content:
+        attachment = content.attachment
+        result = {}
+
+        if include_data:
+            result["data"] = (
+                attachment.url if attachment.url else attachment.data.decode("utf-8")
+            )
+
+        result["metadata"] = {
+            "content_type": attachment.contentType,
+            "title": attachment.title,
+            "creation": attachment.creation,
+        }
+
+        attachments.append(result)
+
+    return attachments
diff --git a/healthchain/fhir/helpers.py b/healthchain/fhir/resourcehelpers.py
similarity index 62%
rename from healthchain/fhir/helpers.py
rename to healthchain/fhir/resourcehelpers.py
index 95114272..a1cacf3c 100644
--- a/healthchain/fhir/helpers.py
+++ b/healthchain/fhir/resourcehelpers.py
@@ -1,214 +1,43 @@
-"""Convenience functions for creating minimal FHIR resources.
+"""FHIR resource creation and modification functions.
+
+This module provides convenience functions for creating and modifying FHIR resources.
+
 Patterns:
-- create_*(): create a new FHIR resource with sensible defaults - useful for dev, use with caution
-- add_*(): add data to resources with list fields safely (e.g. coding)
-- set_*(): set the field of specific resources with soft validation (e.g. category)
-- read_*(): return a human readable format of the data in a resource (e.g. attachments)
+- create_*(): create a new FHIR resource with sensible defaults
+- set_*(): set specific fields of resources with soft validation
+- add_*(): add data to resources safely
+
+Parameters marked REQUIRED are required by FHIR specification.
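+
+Example (a minimal sketch; codes and values are illustrative):
+
+    from healthchain.fhir.resourcehelpers import (
+        create_patient,
+        create_value_quantity_observation,
+    )
+
+    patient = create_patient(gender="female", birth_date="1980-01-01")
+    heart_rate = create_value_quantity_observation(
+        code="8867-4",
+        value=72,
+        unit="beats/min",
+        subject=f"Patient/{patient.id}",
+        display="Heart rate",
+    )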
""" import logging -import base64 import datetime -import uuid -import importlib from typing import Optional, List, Dict, Any from fhir.resources.condition import Condition from fhir.resources.medicationstatement import MedicationStatement from fhir.resources.allergyintolerance import AllergyIntolerance from fhir.resources.documentreference import DocumentReference +from fhir.resources.observation import Observation +from fhir.resources.riskassessment import RiskAssessment +from fhir.resources.patient import Patient +from fhir.resources.quantity import Quantity from fhir.resources.codeableconcept import CodeableConcept -from fhir.resources.codeablereference import CodeableReference -from fhir.resources.coding import Coding -from fhir.resources.attachment import Attachment -from fhir.resources.resource import Resource from fhir.resources.reference import Reference from fhir.resources.meta import Meta +from fhir.resources.coding import Coding +from fhir.resources.identifier import Identifier +from fhir.resources.resource import Resource +from healthchain.fhir.elementhelpers import ( + create_single_codeable_concept, + create_single_attachment, +) +from healthchain.fhir.utilities import _generate_id logger = logging.getLogger(__name__) -def _generate_id() -> str: - """Generate a unique ID prefixed with 'hc-'. - - Returns: - str: A unique ID string prefixed with 'hc-' - """ - return f"hc-{str(uuid.uuid4())}" - - -def create_resource_from_dict( - resource_dict: Dict, resource_type: str -) -> Optional[Resource]: - """Create a FHIR resource instance from a dictionary - - Args: - resource_dict: Dictionary representation of the resource - resource_type: Type of FHIR resource to create - - Returns: - Optional[Resource]: FHIR resource instance or None if creation failed - """ - try: - resource_module = importlib.import_module( - f"fhir.resources.{resource_type.lower()}" - ) - resource_class = getattr(resource_module, resource_type) - return resource_class(**resource_dict) - except Exception as e: - logger.error(f"Failed to create FHIR resource: {str(e)}") - return None - - -def convert_prefetch_to_fhir_objects( - prefetch_dict: Dict[str, Any], -) -> Dict[str, Resource]: - """Convert a dictionary of FHIR resource dicts to FHIR Resource objects. - - Takes a prefetch dictionary where values may be either dict representations of FHIR - resources or already instantiated FHIR Resource objects, and ensures all values are - FHIR Resource objects. - - Args: - prefetch_dict: Dictionary mapping keys to FHIR resource dicts or objects - - Returns: - Dict[str, Resource]: Dictionary with same keys but all values as FHIR Resource objects - - Example: - >>> prefetch = { - ... "patient": {"resourceType": "Patient", "id": "123"}, - ... "condition": Condition(id="456", ...) - ... 
} - >>> fhir_objects = convert_prefetch_to_fhir_objects(prefetch) - >>> isinstance(fhir_objects["patient"], Patient) # True - >>> isinstance(fhir_objects["condition"], Condition) # True - """ - from fhir.resources import get_fhir_model_class - - result: Dict[str, Resource] = {} - - for key, resource_data in prefetch_dict.items(): - if isinstance(resource_data, dict): - # Convert dict to FHIR Resource object - resource_type = resource_data.get("resourceType") - if resource_type: - try: - resource_class = get_fhir_model_class(resource_type) - result[key] = resource_class(**resource_data) - except Exception as e: - logger.warning( - f"Failed to convert {resource_type} to FHIR object: {e}" - ) - result[key] = resource_data - else: - logger.warning( - f"No resourceType found for key '{key}', keeping as dict" - ) - result[key] = resource_data - elif isinstance(resource_data, Resource): - # Already a FHIR object - result[key] = resource_data - else: - logger.warning(f"Unexpected type for key '{key}': {type(resource_data)}") - result[key] = resource_data - - return result - - -def create_single_codeable_concept( - code: str, - display: Optional[str] = None, - system: Optional[str] = "http://snomed.info/sct", -) -> CodeableConcept: - """ - Create a minimal FHIR CodeableConcept with a single coding. - - Args: - code: REQUIRED. The code value from the code system - display: The display name for the code - system: The code system (default: SNOMED CT) - - Returns: - CodeableConcept: A FHIR CodeableConcept resource with a single coding - """ - return CodeableConcept(coding=[Coding(system=system, code=code, display=display)]) - - -def create_single_reaction( - code: str, - display: Optional[str] = None, - system: Optional[str] = "http://snomed.info/sct", - severity: Optional[str] = None, -) -> List[Dict[str, Any]]: - """Create a minimal FHIR Reaction with a single coding. - - Creates a FHIR Reaction object with a single manifestation coding. The manifestation - describes the clinical reaction that was observed. The severity indicates how severe - the reaction was. - - Args: - code: REQUIRED. The code value from the code system representing the reaction manifestation - display: The display name for the manifestation code - system: The code system for the manifestation code (default: SNOMED CT) - severity: The severity of the reaction (mild, moderate, severe) - - Returns: - A list containing a single FHIR Reaction dictionary with manifestation and severity fields - """ - return [ - { - "manifestation": [ - CodeableReference( - concept=CodeableConcept( - coding=[Coding(system=system, code=code, display=display)] - ) - ) - ], - "severity": severity, - } - ] - - -def create_single_attachment( - content_type: Optional[str] = None, - data: Optional[str] = None, - url: Optional[str] = None, - title: Optional[str] = "Attachment created by HealthChain", -) -> Attachment: - """Create a minimal FHIR Attachment. - - Creates a FHIR Attachment resource with basic fields. Either data or url should be provided. - If data is provided, it will be base64 encoded. 
- - Args: - content_type: The MIME type of the content - data: The actual data content to be base64 encoded - url: The URL where the data can be found - title: A title for the attachment (default: "Attachment created by HealthChain") - - Returns: - Attachment: A FHIR Attachment resource with basic metadata and content - """ - - if not data and not url: - logger.warning("No data or url provided for attachment") - - if data: - data = base64.b64encode(data.encode("utf-8")).decode("utf-8") - - return Attachment( - contentType=content_type, - data=data, - url=url, - title=title, - creation=datetime.datetime.now(datetime.timezone.utc).strftime( - "%Y-%m-%dT%H:%M:%S%z" - ), - ) - - def create_condition( subject: str, clinical_status: str = "active", @@ -321,6 +150,187 @@ def create_allergy_intolerance( return allergy +def create_value_quantity_observation( + code: str, + value: float, + unit: str, + status: str = "final", + subject: Optional[str] = None, + system: str = "http://loinc.org", + display: Optional[str] = None, + effective_datetime: Optional[str] = None, +) -> Observation: + """ + Create a minimal FHIR Observation for vital signs or laboratory values. + If you need to create a more complex observation, use the FHIR Observation resource directly. + https://hl7.org/fhir/observation.html + + Args: + status: REQUIRED. The status of the observation (default: "final") + code: REQUIRED. The observation code (e.g., LOINC code for the measurement) + value: The numeric value of the observation + unit: The unit of measure (e.g., "beats/min", "mg/dL") + system: The code system for the observation code (default: LOINC) + display: The display name for the observation code + effective_datetime: When the observation was made (ISO format). Uses current time if not provided. + subject: Reference to the patient (e.g. "Patient/123") + + Returns: + Observation: A FHIR Observation resource with an auto-generated ID prefixed with 'hc-' + """ + if not effective_datetime: + effective_datetime = datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y-%m-%dT%H:%M:%S%z" + ) + subject_ref = None + if subject is not None: + subject_ref = Reference(reference=subject) + + observation = Observation( + id=_generate_id(), + status=status, + code=create_single_codeable_concept(code, display, system), + subject=subject_ref, + effectiveDateTime=effective_datetime, + valueQuantity=Quantity( + value=value, unit=unit, system="http://unitsofmeasure.org", code=unit + ), + ) + + return observation + + +def create_patient( + gender: Optional[str] = None, + birth_date: Optional[str] = None, + identifier: Optional[str] = None, + identifier_system: Optional[str] = "http://hospital.example.org", +) -> Patient: + """ + Create a minimal FHIR Patient resource with basic gender and birthdate + If you need to create a more complex patient, use the FHIR Patient resource directly + https://hl7.org/fhir/patient.html (No required fields). 
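+    All arguments are optional; an identifier, when provided, is wrapped in a
+    FHIR Identifier under identifier_system.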
+
+    Args:
+        gender: Administrative gender (male, female, other, unknown)
+        birth_date: Birth date in YYYY-MM-DD format
+        identifier: Optional identifier value for the patient (e.g., MRN)
+        identifier_system: The system for the identifier (default: "http://hospital.example.org")
+
+    Returns:
+        Patient: A FHIR Patient resource with an auto-generated ID prefixed with 'hc-'
+    """
+    patient_id = _generate_id()
+
+    patient_data = {"id": patient_id}
+
+    if birth_date:
+        patient_data["birthDate"] = birth_date
+
+    if gender:
+        patient_data["gender"] = gender.lower()
+
+    if identifier:
+        patient_data["identifier"] = [
+            Identifier(
+                system=identifier_system,
+                value=identifier,
+            )
+        ]
+
+    patient = Patient(**patient_data)
+    return patient
+
+
+def create_risk_assessment_from_prediction(
+    subject: str,
+    prediction: Dict[str, Any],
+    status: str = "final",
+    method: Optional[CodeableConcept] = None,
+    basis: Optional[List[Reference]] = None,
+    comment: Optional[str] = None,
+    occurrence_datetime: Optional[str] = None,
+) -> RiskAssessment:
+    """
+    Create a FHIR RiskAssessment from ML model prediction output.
+    If you need to create a more complex risk assessment, use the FHIR RiskAssessment resource directly.
+    https://hl7.org/fhir/riskassessment.html
+
+    Args:
+        subject: REQUIRED. Reference to the patient (e.g. "Patient/123")
+        prediction: Dictionary containing prediction details with keys:
+            - outcome: CodeableConcept or dict with code, display, system for the predicted outcome
+            - probability: float between 0 and 1 representing the risk probability
+            - qualitative_risk: Optional str indicating risk level (e.g., "high", "moderate", "low")
+        status: REQUIRED. The status of the assessment (default: "final")
+        method: Optional CodeableConcept describing the assessment method/model used
+        basis: Optional list of References to observations or other resources used as input
+        comment: Optional text comment about the assessment
+        occurrence_datetime: When the assessment was made (ISO format). Uses current time if not provided.
+
+    Returns:
+        RiskAssessment: A FHIR RiskAssessment resource with an auto-generated ID prefixed with 'hc-'
+
+    Example:
+        >>> prediction = {
+        ...     "outcome": {"code": "A41.9", "display": "Sepsis", "system": "http://hl7.org/fhir/sid/icd-10"},
+        ...     "probability": 0.85,
+        ...     "qualitative_risk": "high"
+        ... }
+        >>> risk = create_risk_assessment_from_prediction("Patient/123", prediction)
+    """
+    if not occurrence_datetime:
+        occurrence_datetime = datetime.datetime.now(datetime.timezone.utc).strftime(
+            "%Y-%m-%dT%H:%M:%S%z"
+        )
+
+    outcome = prediction.get("outcome")
+    if isinstance(outcome, dict):
+        outcome_concept = create_single_codeable_concept(
+            code=outcome["code"],
+            display=outcome.get("display"),
+            system=outcome.get("system", "http://snomed.info/sct"),
+        )
+    else:
+        outcome_concept = outcome
+
+    prediction_data = {
+        "outcome": outcome_concept,
+    }
+
+    if "probability" in prediction:
+        prediction_data["probabilityDecimal"] = prediction["probability"]
+
+    if "qualitative_risk" in prediction:
+        prediction_data["qualitativeRisk"] = create_single_codeable_concept(
+            code=prediction["qualitative_risk"],
+            display=prediction["qualitative_risk"].capitalize(),
+            system="http://terminology.hl7.org/CodeSystem/risk-probability",
+        )
+
+    risk_assessment_data = {
+        "id": _generate_id(),
+        "status": status,
+        "subject": Reference(reference=subject),
+        "occurrenceDateTime": occurrence_datetime,
+        "prediction": [prediction_data],
+    }
+
+    if method:
+        risk_assessment_data["method"] = method
+
+    if basis:
+        risk_assessment_data["basis"] = basis
+
+    if comment:
+        risk_assessment_data["note"] = [{"text": comment}]
+
+    risk_assessment = RiskAssessment(**risk_assessment_data)
+
+    return risk_assessment
+
+
 def create_document_reference(
     data: Optional[Any] = None,
     url: Optional[str] = None,
@@ -566,47 +576,3 @@ def add_coding_to_codeable_concept(
     codeable_concept.coding.append(Coding(system=system, code=code, display=display))
 
     return codeable_concept
-
-
-def read_content_attachment(
-    document_reference: DocumentReference,
-    include_data: bool = True,
-) -> Optional[List[Dict[str, Any]]]:
-    """Read the attachments in a human readable format from a FHIR DocumentReference content field.
-
-    Args:
-        document_reference: The FHIR DocumentReference resource
-        include_data: Whether to include the data of the attachments. If true, the data will be also be decoded (default: True)
-
-    Returns:
-        Optional[List[Dict[str, Any]]]: List of dictionaries containing attachment data and metadata,
-        or None if no attachments are found:
-        [
-            {
-                "data": str,
-                "metadata": Dict[str, Any]
-            }
-        ]
-    """
-    if not document_reference.content:
-        return None
-
-    attachments = []
-    for content in document_reference.content:
-        attachment = content.attachment
-        result = {}
-
-        if include_data:
-            result["data"] = (
-                attachment.url if attachment.url else attachment.data.decode("utf-8")
-            )
-
-        result["metadata"] = {
-            "content_type": attachment.contentType,
-            "title": attachment.title,
-            "creation": attachment.creation,
-        }
-
-        attachments.append(result)
-
-    return attachments
diff --git a/healthchain/fhir/utilities.py b/healthchain/fhir/utilities.py
new file mode 100644
index 00000000..31788d42
--- /dev/null
+++ b/healthchain/fhir/utilities.py
@@ -0,0 +1,117 @@
+"""FHIR utility functions.
+
+This module provides utility functions for common operations like ID generation,
+age calculation, and gender encoding.
+"""
+
+import datetime
+import uuid
+from typing import Optional
+
+
+def _generate_id() -> str:
+    """Generate a unique ID prefixed with 'hc-'.
+
+    Returns:
+        str: A unique ID string prefixed with 'hc-'
+    """
+    return f"hc-{str(uuid.uuid4())}"
+
+
+def calculate_age_from_birthdate(birth_date: str) -> Optional[int]:
+    """Calculate age in years from a birth date string.
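+
+    Age is anchored to the current date; for event-anchored ages (the MIMIC-IV
+    style year difference) use calculate_age_from_event_date instead.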
+ + Args: + birth_date: Birth date in ISO format (YYYY-MM-DD or full ISO datetime) + + Returns: + Age in years, or None if birth date is invalid + """ + if not birth_date: + return None + + try: + if isinstance(birth_date, str): + # Remove timezone info for simpler parsing + birth_date_clean = birth_date.replace("Z", "").split("T")[0] + birth_dt = datetime.datetime.strptime(birth_date_clean, "%Y-%m-%d") + else: + birth_dt = birth_date + + # Calculate age + today = datetime.datetime.now() + age = today.year - birth_dt.year + + # Adjust if birthday hasn't occurred this year + if (today.month, today.day) < (birth_dt.month, birth_dt.day): + age -= 1 + + return age + except (ValueError, AttributeError, TypeError): + return None + + +def calculate_age_from_event_date(birth_date: str, event_date: str) -> Optional[int]: + """Calculate age in years from birth date and event date (MIMIC-IV style). + + Uses the formula: age = year(eventDate) - year(birthDate) + This matches MIMIC-IV on FHIR de-identified age calculation. + + Args: + birth_date: Birth date in ISO format (YYYY-MM-DD or full ISO datetime) + event_date: Event date in ISO format (YYYY-MM-DD or full ISO datetime) + + Returns: + Age in years based on year difference, or None if dates are invalid + + Example: + >>> calculate_age_from_event_date("1990-06-15", "2020-03-10") + 30 + """ + if not birth_date or not event_date: + return None + + try: + # Parse birth date + if isinstance(birth_date, str): + birth_date_clean = birth_date.replace("Z", "").split("T")[0] + birth_year = int(birth_date_clean.split("-")[0]) + else: + birth_year = birth_date.year + + # Parse event date + if isinstance(event_date, str): + event_date_clean = event_date.replace("Z", "").split("T")[0] + event_year = int(event_date_clean.split("-")[0]) + else: + event_year = event_date.year + + # MIMIC-IV style: simple year difference + age = event_year - birth_year + + return age + except (ValueError, AttributeError, TypeError, IndexError): + return None + + +def encode_gender(gender: str) -> Optional[int]: + """Encode gender as integer for ML models. + + Standard encoding: Male=1, Female=0, Other/Unknown=None + + Args: + gender: Gender string (case-insensitive) + + Returns: + Encoded gender (1 for male, 0 for female, None for other/unknown) + """ + if not gender: + return None + + gender_lower = gender.lower() + if gender_lower in ["male", "m"]: + return 1 + elif gender_lower in ["female", "f"]: + return 0 + else: + return None diff --git a/healthchain/io/__init__.py b/healthchain/io/__init__.py index 52bc38ef..2e33328c 100644 --- a/healthchain/io/__init__.py +++ b/healthchain/io/__init__.py @@ -1,15 +1,32 @@ -from .containers import DataContainer, Document, Tabular -from .base import BaseAdapter +"""IO module for data containers, adapters, and mappers. 
+
+This module provides:
+- Containers: Data structures for documents and datasets
+- Adapters: Convert external formats (CDA, CDS Hooks) to/from HealthChain
+- Mappers: Transform clinical data between formats (FHIR to pandas, FHIR versions)
+"""
+
+from .containers import DataContainer, Document, Dataset, FeatureSchema
+from .adapters.base import BaseAdapter
 from .adapters.cdaadapter import CdaAdapter
 from .adapters.cdsfhiradapter import CdsFhirAdapter
+from .mappers import BaseMapper, FHIRFeatureMapper
+from .types import TimeWindow, ValidationResult
 
 __all__ = [
     # Containers
     "DataContainer",
     "Document",
-    "Tabular",
+    "Dataset",
+    "FeatureSchema",
     # Adapters
     "BaseAdapter",
     "CdaAdapter",
     "CdsFhirAdapter",
+    # Mappers
+    "BaseMapper",
+    "FHIRFeatureMapper",
+    # Types
+    "TimeWindow",
+    "ValidationResult",
 ]
diff --git a/healthchain/io/adapters/__init__.py b/healthchain/io/adapters/__init__.py
index 6fb1012a..cc698d79 100644
--- a/healthchain/io/adapters/__init__.py
+++ b/healthchain/io/adapters/__init__.py
@@ -7,5 +7,6 @@
 from .cdaadapter import CdaAdapter
 from .cdsfhiradapter import CdsFhirAdapter
+from .base import BaseAdapter
 
-__all__ = ["CdaAdapter", "CdsFhirAdapter"]
+__all__ = ["CdaAdapter", "CdsFhirAdapter", "BaseAdapter"]
diff --git a/healthchain/io/base.py b/healthchain/io/adapters/base.py
similarity index 100%
rename from healthchain/io/base.py
rename to healthchain/io/adapters/base.py
diff --git a/healthchain/io/adapters/cdaadapter.py b/healthchain/io/adapters/cdaadapter.py
index 8271f52e..91d5d456 100644
--- a/healthchain/io/adapters/cdaadapter.py
+++ b/healthchain/io/adapters/cdaadapter.py
@@ -2,7 +2,7 @@
 from typing import Optional
 
 from healthchain.io.containers import Document
-from healthchain.io.base import BaseAdapter
+from healthchain.io.adapters.base import BaseAdapter
 from healthchain.interop import create_interop, FormatType, InteropEngine
 from healthchain.models.requests.cdarequest import CdaRequest
 from healthchain.models.responses.cdaresponse import CdaResponse
diff --git a/healthchain/io/adapters/cdsfhiradapter.py b/healthchain/io/adapters/cdsfhiradapter.py
index 7d3be0e7..42d36572 100644
--- a/healthchain/io/adapters/cdsfhiradapter.py
+++ b/healthchain/io/adapters/cdsfhiradapter.py
@@ -4,7 +4,7 @@
 from fhir.resources.documentreference import DocumentReference
 
 from healthchain.io.containers import Document
-from healthchain.io.base import BaseAdapter
+from healthchain.io.adapters.base import BaseAdapter
 from healthchain.models.requests.cdsrequest import CDSRequest
 from healthchain.models.responses.cdsresponse import CDSResponse
 from healthchain.fhir import read_content_attachment, convert_prefetch_to_fhir_objects
diff --git a/healthchain/io/containers/__init__.py b/healthchain/io/containers/__init__.py
index 05db1a95..d11372a8 100644
--- a/healthchain/io/containers/__init__.py
+++ b/healthchain/io/containers/__init__.py
@@ -1,5 +1,6 @@
 from .base import DataContainer, BaseDocument
 from .document import Document
-from .tabular import Tabular
+from .dataset import Dataset
+from .featureschema import FeatureSchema
 
-__all__ = ["DataContainer", "BaseDocument", "Document", "Tabular"]
+__all__ = ["DataContainer", "BaseDocument", "Document", "Dataset", "FeatureSchema"]
diff --git a/healthchain/io/containers/dataset.py b/healthchain/io/containers/dataset.py
new file mode 100644
index 00000000..39740be5
--- /dev/null
+++ b/healthchain/io/containers/dataset.py
@@ -0,0 +1,307 @@
+import pandas as pd
+import numpy as np
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Union, Optional
+
+from fhir.resources.bundle import Bundle
+from fhir.resources.riskassessment import RiskAssessment
+
+from healthchain.io.containers.base import DataContainer
+from healthchain.io.containers.featureschema import FeatureSchema
+from healthchain.io.mappers.fhirfeaturemapper import FHIRFeatureMapper
+from healthchain.io.types import ValidationResult
+from healthchain.fhir.resourcehelpers import (
+    create_risk_assessment_from_prediction,
+    create_single_codeable_concept,
+)
+
+
+@dataclass
+class Dataset(DataContainer[pd.DataFrame]):
+    """
+    A container for tabular data optimized for ML inference; a lightweight wrapper around a pandas DataFrame.
+
+    Methods:
+        from_csv: Load Dataset from CSV.
+        from_dict: Load Dataset from dict.
+        from_fhir_bundle: Create Dataset from FHIR Bundle and schema.
+        to_csv: Save Dataset to CSV.
+        to_risk_assessment: Convert predictions to FHIR RiskAssessment.
+    """
+
+    def __post_init__(self):
+        if not isinstance(self.data, pd.DataFrame):
+            raise TypeError("data must be a pandas DataFrame")
+
+    @property
+    def columns(self) -> List[str]:
+        return list(self.data.columns)
+
+    @property
+    def index(self) -> pd.Index:
+        return self.data.index
+
+    @property
+    def dtypes(self) -> Dict[str, str]:
+        return {col: str(dtype) for col, dtype in self.data.dtypes.items()}
+
+    def column_count(self) -> int:
+        return len(self.columns)
+
+    def row_count(self) -> int:
+        return len(self.data)
+
+    def get_dtype(self, column: str) -> str:
+        return str(self.data[column].dtype)
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self.columns)
+
+    def __len__(self) -> int:
+        return self.row_count()
+
+    def describe(self) -> str:
+        return f"Dataset with {self.column_count()} columns and {self.row_count()} rows"
+
+    def remove_column(self, name: str) -> None:
+        self.data.drop(columns=[name], inplace=True)
+
+    @classmethod
+    def from_csv(cls, path: str, **kwargs) -> "Dataset":
+        return cls(pd.read_csv(path, **kwargs))
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Dataset":
+        df = pd.DataFrame(data["data"])
+        return cls(df)
+
+    def to_csv(self, path: str, **kwargs) -> None:
+        self.data.to_csv(path, **kwargs)
+
+    @classmethod
+    def from_fhir_bundle(
+        cls,
+        bundle: Union[Bundle, Dict[str, Any]],
+        schema: Union[str, Path, FeatureSchema],
+        aggregation: str = "mean",
+    ) -> "Dataset":
+        """Create Dataset from a FHIR Bundle using a feature schema.
+
+        Extracts features from FHIR resources according to the schema specification,
+        converting FHIR data to a pandas DataFrame suitable for ML inference.
+
+        Args:
+            bundle: FHIR Bundle resource (object or dict)
+            schema: FeatureSchema object, or path to YAML schema file
+            aggregation: How to aggregate multiple observation values.
+                Options: "mean", "median", "max", "min", "last" (default: "mean")
+
+        Returns:
+            Dataset container with extracted features
+
+        Example:
+            >>> from fhir.resources.bundle import Bundle
+            >>> bundle = Bundle(**patient_data)
+            >>> dataset = Dataset.from_fhir_bundle(
+            ...     bundle,
+            ...     schema="healthchain/configs/features/sepsis_vitals.yaml"
+            ...
) + >>> df = dataset.data + """ + # Load schema if path provided + if isinstance(schema, (str, Path)): + schema = FeatureSchema.from_yaml(schema) + + # Extract features using mapper + mapper = FHIRFeatureMapper(schema) + df = mapper.extract_features(bundle, aggregation=aggregation) + + return cls(df) + + def validate( + self, schema: FeatureSchema, raise_on_error: bool = False + ) -> ValidationResult: + """Validate DataFrame against a feature schema. + + Checks that required features are present and have correct data types. + + Args: + schema: FeatureSchema to validate against + raise_on_error: Whether to raise exception on validation failure + + Returns: + ValidationResult with validation status and details + + Raises: + ValueError: If raise_on_error is True and validation fails + + Example: + >>> schema = FeatureSchema.from_yaml("configs/features/sepsis_vitals.yaml") + >>> result = dataset.validate(schema) + >>> if not result.valid: + ... print(result.errors) + """ + result = ValidationResult(valid=True) + + # Check for missing required features + required = schema.get_required_features() + missing = [f for f in required if f not in self.data.columns] + + for feature in missing: + result.add_missing_feature(feature) + + # Check data types for present features + for feature_name, mapping in schema.features.items(): + if feature_name in self.data.columns: + actual_dtype = str(self.data[feature_name].dtype) + expected_dtype = mapping.dtype + + # Check for type mismatches (allow some flexibility) + if not self._dtypes_compatible(actual_dtype, expected_dtype): + result.add_type_mismatch(feature_name, expected_dtype, actual_dtype) + + # Warn about optional missing features + optional = set(schema.get_feature_names()) - set(required) + missing_optional = [f for f in optional if f not in self.data.columns] + + for feature in missing_optional: + result.add_warning(f"Optional feature '{feature}' is missing") + + if raise_on_error and not result.valid: + raise ValueError(str(result)) + + return result + + def _dtypes_compatible(self, actual: str, expected: str) -> bool: + """Check if actual dtype is compatible with expected dtype. + + Args: + actual: Actual dtype string + expected: Expected dtype string + + Returns: + True if dtypes are compatible + """ + # Handle numeric types flexibly + numeric_types = {"int64", "int32", "float64", "float32"} + if expected in numeric_types and actual in numeric_types: + return True + + # Exact match for non-numeric types + return actual == expected + + def to_risk_assessment( + self, + predictions: np.ndarray, + probabilities: np.ndarray, + outcome_code: str, + outcome_display: str, + outcome_system: str = "http://hl7.org/fhir/sid/icd-10", + model_name: Optional[str] = None, + model_version: Optional[str] = None, + high_threshold: float = 0.7, + moderate_threshold: float = 0.4, + ) -> List[RiskAssessment]: + """Convert model predictions to FHIR RiskAssessment resources. + + Creates RiskAssessment resources from ML model output, suitable for + including in FHIR Bundles or sending to FHIR servers. 
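+
+        Qualitative risk is bucketed by the two thresholds: probabilities at
+        or above high_threshold map to "high", those at or above
+        moderate_threshold map to "moderate", and anything lower maps to "low".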
+
+        Args:
+            predictions: Binary predictions array (0/1)
+            probabilities: Probability scores array (0-1)
+            outcome_code: Code for the predicted outcome (e.g., "A41.9" for sepsis)
+            outcome_display: Display text for the outcome (e.g., "Sepsis")
+            outcome_system: Code system for the outcome (default: ICD-10)
+            model_name: Name of the ML model (optional)
+            model_version: Version of the ML model (optional)
+            high_threshold: Threshold for high risk (default: 0.7)
+            moderate_threshold: Threshold for moderate risk (default: 0.4)
+
+        Returns:
+            List of RiskAssessment resources, one per patient
+
+        Example:
+            >>> predictions = np.array([0, 1, 0])
+            >>> probabilities = np.array([0.15, 0.85, 0.32])
+            >>> risk_assessments = dataset.to_risk_assessment(
+            ...     predictions,
+            ...     probabilities,
+            ...     outcome_code="A41.9",
+            ...     outcome_display="Sepsis, unspecified",
+            ...     model_name="RandomForest",
+            ...     model_version="1.0"
+            ... )
+        """
+        if len(predictions) != len(self.data):
+            raise ValueError(
+                f"Predictions length ({len(predictions)}) must match "
+                f"DataFrame length ({len(self.data)})"
+            )
+
+        if len(probabilities) != len(self.data):
+            raise ValueError(
+                f"Probabilities length ({len(probabilities)}) must match "
+                f"DataFrame length ({len(self.data)})"
+            )
+
+        risk_assessments = []
+
+        # Get patient references
+        if "patient_ref" not in self.data.columns:
+            raise ValueError("DataFrame must have 'patient_ref' column")
+
+        # Iterate positionally so a non-default DataFrame index cannot
+        # misalign rows with the prediction arrays
+        for pos, (_, row) in enumerate(self.data.iterrows()):
+            patient_ref = row["patient_ref"]
+            prediction = int(predictions[pos])
+            probability = float(probabilities[pos])
+
+            # Determine qualitative risk
+            if probability >= high_threshold:
+                qualitative_risk = "high"
+            elif probability >= moderate_threshold:
+                qualitative_risk = "moderate"
+            else:
+                qualitative_risk = "low"
+
+            # Build prediction dict
+            prediction_dict = {
+                "outcome": {
+                    "code": outcome_code,
+                    "display": outcome_display,
+                    "system": outcome_system,
+                },
+                "probability": probability,
+                "qualitative_risk": qualitative_risk,
+            }
+
+            # Create method CodeableConcept if model info provided
+            method = None
+            if model_name:
+                method = create_single_codeable_concept(
+                    code=model_name,
+                    display=f"{model_name} v{model_version}"
+                    if model_version
+                    else model_name,
+                    system="https://healthchain.github.io/ml-models",
+                )
+
+            # Create comment with prediction details
+            comment = (
+                f"ML prediction: {'Positive' if prediction == 1 else 'Negative'} "
+                f"(probability: {probability:.2%}, risk: {qualitative_risk})"
+            )
+
+            # Create RiskAssessment
+            risk_assessment = create_risk_assessment_from_prediction(
+                subject=patient_ref,
+                prediction=prediction_dict,
+                method=method,
+                comment=comment,
+            )
+
+            risk_assessments.append(risk_assessment)
+
+        return risk_assessments
diff --git a/healthchain/io/containers/featureschema.py b/healthchain/io/containers/featureschema.py
new file mode 100644
index 00000000..6504d4ef
--- /dev/null
+++ b/healthchain/io/containers/featureschema.py
@@ -0,0 +1,220 @@
+"""Feature schema definitions for FHIR to Dataset data conversion.
+
+This module provides classes to define and manage feature schemas that map
+FHIR resources to pandas DataFrame columns for ML model deployment.
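+
+A schema YAML might look like this (a minimal sketch; the codes, names, and
+file layout are illustrative):
+
+    name: sepsis_vitals
+    version: "1.0"
+    features:
+      heart_rate:
+        fhir_resource: Observation
+        code: "8867-4"
+        code_system: http://loinc.org
+      age:
+        fhir_resource: Patient
+        field: birthDate
+    metadata:
+      age_calculation: event_date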
+""" + +import yaml +from pathlib import Path +from typing import Dict, List, Optional, Union, Any +from pydantic import BaseModel, field_validator, ConfigDict, model_validator + + +class FeatureMapping(BaseModel): + """Maps a single feature to its FHIR source.""" + + name: str + fhir_resource: str + code: Optional[str] = None + code_system: Optional[str] = None + field: Optional[str] = None + transform: Optional[str] = None + dtype: str = "float64" + required: bool = True + unit: Optional[str] = None + display: Optional[str] = None + + model_config = ConfigDict(extra="allow") + + @model_validator(mode="after") + def validate_resource_requirements(self) -> "FeatureMapping": + """Validate the feature mapping configuration based on resource type.""" + if self.fhir_resource == "Observation": + if not self.code: + raise ValueError( + f"Feature '{self.name}': Observation resources require a 'code'" + ) + if not self.code_system: + raise ValueError( + f"Feature '{self.name}': Observation resources require a 'code_system'" + ) + elif self.fhir_resource == "Patient": + if not self.field: + raise ValueError( + f"Feature '{self.name}': Patient resources require a 'field'" + ) + return self + + @classmethod + def from_dict(cls, name: str, data: Dict[str, Any]) -> "FeatureMapping": + """Create a FeatureMapping from a dictionary. + + Args: + name: The feature name + data: Dictionary containing feature configuration + + Returns: + FeatureMapping instance + """ + return cls(name=name, **data) + + +class FeatureSchema(BaseModel): + """Schema defining how to extract features from FHIR resources.""" + + name: str + version: str + features: Dict[str, FeatureMapping] = {} + description: Optional[str] = None + model_info: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None + + model_config = ConfigDict(extra="allow") + + @field_validator("features", mode="before") + @classmethod + def convert_feature_dicts(cls, v): + """Convert feature dicts to FeatureMapping objects if needed.""" + if v and isinstance(v, dict): + # Check if values are dicts (need conversion) or already FeatureMapping + if v and isinstance(list(v.values())[0], dict): + return { + name: FeatureMapping.from_dict(name, mapping) + for name, mapping in v.items() + } + return v + + @classmethod + def from_yaml(cls, path: Union[str, Path]) -> "FeatureSchema": + """Load schema from a YAML file. + + Args: + path: Path to the YAML file + + Returns: + FeatureSchema instance + + Example: + >>> schema = FeatureSchema.from_yaml("configs/features/sepsis_vitals.yaml") + """ + path = Path(path) + with open(path, "r") as f: + data = yaml.safe_load(f) + + return cls.model_validate(data) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "FeatureSchema": + """Create a FeatureSchema from a dictionary. + + Args: + data: Dictionary containing schema configuration + + Returns: + FeatureSchema instance + """ + return cls.model_validate(data) + + def to_dict(self) -> Dict[str, Any]: + """Convert schema to dictionary format. + + Returns: + Dictionary representation of the schema + """ + result = { + "name": self.name, + "version": self.version, + "description": self.description, + "model_info": self.model_info, + "features": { + name: { + k: v + for k, v in mapping.model_dump().items() + if k != "name" and v is not None + } + for name, mapping in self.features.items() + }, + } + if self.metadata: + result["metadata"] = self.metadata + return result + + def to_yaml(self, path: Union[str, Path]) -> None: + """Save schema to a YAML file. 
+ + Args: + path: Path where the YAML file will be saved + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + with open(path, "w") as f: + yaml.dump(self.to_dict(), f, default_flow_style=False, sort_keys=False) + + def get_feature_names(self) -> List[str]: + """Get list of feature names in order. + + Returns: + List of feature names + """ + return list(self.features.keys()) + + def get_required_features(self) -> List[str]: + """Get list of required feature names. + + Returns: + List of required feature names + """ + return [name for name, mapping in self.features.items() if mapping.required] + + def get_features_by_resource(self, resource_type: str) -> Dict[str, FeatureMapping]: + """Get all features mapped to a specific FHIR resource type. + + Args: + resource_type: FHIR resource type (e.g., "Observation", "Patient") + + Returns: + Dictionary of features for the specified resource type + """ + return { + name: mapping + for name, mapping in self.features.items() + if mapping.fhir_resource == resource_type + } + + def get_observation_codes(self) -> Dict[str, FeatureMapping]: + """Get all Observation features with their codes. + + Returns: + Dictionary mapping codes to feature mappings + """ + observations = self.get_features_by_resource("Observation") + return { + mapping.code: mapping for mapping in observations.values() if mapping.code + } + + def validate_dataframe_columns(self, columns: List[str]) -> Dict[str, Any]: + """Validate that a DataFrame has the expected columns. + + Args: + columns: List of column names from a DataFrame + + Returns: + Dictionary with validation results: + - valid: bool + - missing_required: List of missing required features + - unexpected: List of unexpected columns + """ + expected = set(self.get_feature_names()) + actual = set(columns) + required = set(self.get_required_features()) + + missing_required = list(required - actual) + unexpected = list(actual - expected) + + return { + "valid": len(missing_required) == 0, + "missing_required": missing_required, + "unexpected": unexpected, + "missing_optional": list((expected - required) - actual), + } diff --git a/healthchain/io/containers/tabular.py b/healthchain/io/containers/tabular.py deleted file mode 100644 index 809bb16e..00000000 --- a/healthchain/io/containers/tabular.py +++ /dev/null @@ -1,81 +0,0 @@ -import pandas as pd - -from dataclasses import dataclass -from typing import Any, Dict, Iterator, List - -from healthchain.io.containers.base import DataContainer - - -@dataclass -class Tabular(DataContainer[pd.DataFrame]): - """ - A container for tabular data, wrapping a pandas DataFrame. - - Attributes: - data (pd.DataFrame): The pandas DataFrame containing the tabular data. - - Methods: - __post_init__(): Validates that the data is a pandas DataFrame. - columns: Property that returns a list of column names. - index: Property that returns the DataFrame's index. - dtypes: Property that returns a dictionary of column names and their data types. - column_count(): Returns the number of columns in the DataFrame. - row_count(): Returns the number of rows in the DataFrame. - get_dtype(column: str): Returns the data type of a specific column. - __iter__(): Returns an iterator over the column names. - __len__(): Returns the number of rows in the DataFrame. - describe(): Returns a string description of the tabular data. - remove_column(name: str): Removes a column from the DataFrame. - from_csv(path: str, **kwargs): Class method to create a Tabular object from a CSV file. 
- from_dict(data: Dict[str, Any]): Class method to create a Tabular object from a dictionary. - to_csv(path: str, **kwargs): Saves the DataFrame to a CSV file. - """ - - def __post_init__(self): - if not isinstance(self.data, pd.DataFrame): - raise TypeError("data must be a pandas DataFrame") - - @property - def columns(self) -> List[str]: - return list(self.data.columns) - - @property - def index(self) -> pd.Index: - return self.data.index - - @property - def dtypes(self) -> Dict[str, str]: - return {col: str(dtype) for col, dtype in self.data.dtypes.items()} - - def column_count(self) -> int: - return len(self.columns) - - def row_count(self) -> int: - return len(self.data) - - def get_dtype(self, column: str) -> str: - return str(self.data[column].dtype) - - def __iter__(self) -> Iterator[str]: - return iter(self.columns) - - def __len__(self) -> int: - return self.row_count() - - def describe(self) -> str: - return f"Tabular data with {self.column_count()} columns and {self.row_count()} rows" - - def remove_column(self, name: str) -> None: - self.data.drop(columns=[name], inplace=True) - - @classmethod - def from_csv(cls, path: str, **kwargs) -> "Tabular": - return cls(pd.read_csv(path, **kwargs)) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Tabular": - df = pd.DataFrame(**data["data"]) - return cls(df) - - def to_csv(self, path: str, **kwargs) -> None: - self.data.to_csv(path, **kwargs) diff --git a/healthchain/io/mappers/__init__.py b/healthchain/io/mappers/__init__.py new file mode 100644 index 00000000..2bee9cff --- /dev/null +++ b/healthchain/io/mappers/__init__.py @@ -0,0 +1,12 @@ +"""Clinical data mappers for transformations between formats. + +Mappers handle transformations between different clinical data formats: +- FHIR to pandas (ML feature extraction) +- FHIR version migrations +- Clinical standard conversions (FHIR to OMOP) +""" + +from .base import BaseMapper +from .fhirfeaturemapper import FHIRFeatureMapper + +__all__ = ["BaseMapper", "FHIRFeatureMapper"] diff --git a/healthchain/io/mappers/base.py b/healthchain/io/mappers/base.py new file mode 100644 index 00000000..ddc4031c --- /dev/null +++ b/healthchain/io/mappers/base.py @@ -0,0 +1,48 @@ +"""Base mapper for clinical data transformations. + +Mappers handle transformations between different clinical data formats and +representations, including: +- Clinical standard conversions (FHIR versions, FHIR to OMOP) +- Feature extraction for ML (FHIR to pandas) +- Data model transformations +""" + +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +SourceType = TypeVar("SourceType") +TargetType = TypeVar("TargetType") + + +class BaseMapper(Generic[SourceType, TargetType], ABC): + """ + Abstract base class for clinical data mappers. + + Mappers transform clinical data between different formats and representations, + distinct from Adapters which handle external message format conversion. + + Use mappers for: + - FHIR to pandas feature extraction (ML workflows) + - FHIR version migrations (R4 to R5) + - Clinical standard conversions (FHIR to OMOP) + - Semantic and structural data transformations + + Example: + >>> class FHIRFeatureMapper(BaseMapper[Bundle, pd.DataFrame]): + ... def map(self, source: Bundle) -> pd.DataFrame: + ... # Extract features from FHIR Bundle + ... return dataframe + """ + + @abstractmethod + def map(self, source: SourceType) -> TargetType: + """ + Transform source data to target format. 
+
+        Args:
+            source: Source data in input format
+
+        Returns:
+            Transformed data in target format
+        """
+        pass
diff --git a/healthchain/io/mappers/fhirfeaturemapper.py b/healthchain/io/mappers/fhirfeaturemapper.py
new file mode 100644
index 00000000..24eba60f
--- /dev/null
+++ b/healthchain/io/mappers/fhirfeaturemapper.py
@@ -0,0 +1,149 @@
+"""Schema-driven FHIR to feature mapper for ML model deployment.
+
+This module provides schema-driven feature extraction from FHIR Bundles,
+using FeatureSchema to specify which features to extract and how to transform them.
+"""
+
+from typing import Any, Dict, Union
+import pandas as pd
+import numpy as np
+
+from fhir.resources.bundle import Bundle
+
+from healthchain.io.containers.featureschema import FeatureSchema
+from healthchain.io.mappers.base import BaseMapper
+from healthchain.fhir.dataframe import bundle_to_dataframe, BundleConverterConfig
+
+
+class FHIRFeatureMapper(BaseMapper[Bundle, pd.DataFrame]):
+    """Schema-driven mapper from FHIR resources to DataFrame features.
+
+    Uses a FeatureSchema to extract and transform specific features from FHIR Bundles.
+    Leverages the generic bundle_to_dataframe converter and filters/renames columns
+    based on the schema.
+    """
+
+    def __init__(self, schema: FeatureSchema):
+        self.schema = schema
+
+    def map(self, source: Bundle) -> pd.DataFrame:
+        """Transform FHIR Bundle to DataFrame using default aggregation.
+
+        Args:
+            source: FHIR Bundle resource
+
+        Returns:
+            DataFrame with extracted features
+        """
+        return self.extract_features(source)
+
+    def extract_features(
+        self,
+        bundle: Union[Bundle, Dict[str, Any]],
+        aggregation: str = "mean",
+    ) -> pd.DataFrame:
+        """Extract features from a FHIR Bundle according to the schema.
+
+        Args:
+            bundle: FHIR Bundle resource (object or dict)
+            aggregation: How to aggregate multiple observation values.
+                Options: "mean", "median", "max", "min", "last" (default: "mean")
+
+        Returns:
+            DataFrame with one row per patient and columns matching schema features
+
+        Example:
+            >>> from healthchain.io.containers.featureschema import FeatureSchema
+            >>> schema = FeatureSchema.from_yaml("configs/features/sepsis_vitals.yaml")
+            >>> mapper = FHIRFeatureMapper(schema)
+            >>> df = mapper.extract_features(bundle)
+        """
+        # Build config from schema
+        config = self._build_config_from_schema(aggregation)
+
+        # Extract features using config
+        df = bundle_to_dataframe(bundle, config=config)
+
+        if df.empty:
+            return pd.DataFrame(
+                columns=["patient_ref"] + self.schema.get_feature_names()
+            )
+
+        # Map generic column names to schema feature names
+        df_mapped = self._map_columns_to_schema(df)
+
+        # Ensure all schema features are present (fill missing with NaN)
+        feature_names = self.schema.get_feature_names()
+        for feature in feature_names:
+            if feature not in df_mapped.columns:
+                df_mapped[feature] = np.nan
+
+        # Reorder columns to match schema
+        df_mapped = df_mapped[["patient_ref"] + feature_names]
+
+        return df_mapped
+
+    def _build_config_from_schema(self, aggregation: str) -> BundleConverterConfig:
+        """Build converter config from feature schema.
+ + Args: + aggregation: Aggregation method for observations + + Returns: + BundleConverterConfig configured based on schema + """ + # Determine which resources are needed from schema + resources = set() + for feature in self.schema.features.values(): + resources.add(feature.fhir_resource) + + # Extract age calculation metadata if present + metadata = self.schema.metadata or {} + age_calculation = metadata.get("age_calculation", "current_date") + event_date_source = metadata.get("event_date_source", "Observation") + event_date_strategy = metadata.get("event_date_strategy", "earliest") + + return BundleConverterConfig( + resources=list(resources), + observation_aggregation=aggregation, + age_calculation=age_calculation, + event_date_source=event_date_source, + event_date_strategy=event_date_strategy, + ) + + def _map_columns_to_schema(self, df: pd.DataFrame) -> pd.DataFrame: + """Map generic DataFrame columns to schema feature names. + + Args: + df: DataFrame from bundle_to_dataframe + + Returns: + DataFrame with columns renamed according to schema + """ + rename_map = {} + + # Map observation columns + obs_features = self.schema.get_features_by_resource("Observation") + for feature_name, mapping in obs_features.items(): + # Generic converter creates columns like: "8867-4_Heart_rate" + # Find matching column in df + for col in df.columns: + if col.startswith(mapping.code): + rename_map[col] = feature_name + break + + # Map patient columns (already have correct names from helpers) + patient_features = self.schema.get_features_by_resource("Patient") + for feature_name, mapping in patient_features.items(): + if mapping.field == "birthDate": + # Generic converter uses "age" + if "age" in df.columns: + rename_map["age"] = feature_name + elif mapping.field == "gender": + # Generic converter uses "gender" + if "gender" in df.columns: + rename_map["gender"] = feature_name + + # Rename columns + df_renamed = df.rename(columns=rename_map) + + return df_renamed diff --git a/healthchain/io/types.py b/healthchain/io/types.py new file mode 100644 index 00000000..260cab61 --- /dev/null +++ b/healthchain/io/types.py @@ -0,0 +1,139 @@ +"""Type definitions for IO operations. + +This module provides common types used across IO operations, particularly +for FHIR to Dataset data conversion. +""" + +from dataclasses import dataclass +from typing import List, Dict, Tuple +from pydantic import BaseModel, Field, field_validator + + +class TimeWindow(BaseModel): + """Defines a time window for filtering temporal data. + + Used to extract data from a specific time period relative to a reference point, + such as the first 24 hours after ICU admission. 
+ + Attributes: + reference_field: Field name in the FHIR resource marking the reference time + (e.g., "intime" for ICU admission, "admittime" for hospital admission) + hours: Duration of the time window in hours from the reference point + offset_hours: Number of hours to offset from the reference point (default: 0) + For example, offset_hours=6 and hours=24 would capture hours 6-30 + + Example: + >>> # Capture first 24 hours after ICU admission + >>> window = TimeWindow(reference_field="intime", hours=24) + >>> + >>> # Capture hours 6-30 after admission + >>> window = TimeWindow(reference_field="admittime", hours=24, offset_hours=6) + """ + + reference_field: str + hours: int + offset_hours: int = Field(default=0) + + @field_validator("hours") + @classmethod + def hours_must_be_positive(cls, v): + if v <= 0: + raise ValueError("hours must be positive") + return v + + @field_validator("offset_hours") + @classmethod + def offset_hours_non_negative(cls, v): + if v < 0: + raise ValueError("offset_hours must be non-negative") + return v + + +@dataclass +class ValidationResult: + """Result of data validation operations. + + Attributes: + valid: Overall validation status + missing_features: List of required features that are missing + type_mismatches: Dictionary mapping feature names to (expected, actual) type tuples + warnings: List of non-critical validation warnings + errors: List of validation errors + + Example: + >>> result = ValidationResult( + ... valid=False, + ... missing_features=["heart_rate"], + ... type_mismatches={"age": ("int64", "object")}, + ... warnings=["Optional feature 'temperature' is missing"], + ... errors=["Required feature 'heart_rate' is missing"] + ... ) + """ + + valid: bool + missing_features: List[str] = None + type_mismatches: Dict[str, Tuple[str, str]] = None + warnings: List[str] = None + errors: List[str] = None + + def __post_init__(self): + """Initialize empty lists and dicts for None values.""" + if self.missing_features is None: + self.missing_features = [] + if self.type_mismatches is None: + self.type_mismatches = {} + if self.warnings is None: + self.warnings = [] + if self.errors is None: + self.errors = [] + + def __str__(self) -> str: + """Human-readable validation result.""" + if self.valid: + return "Validation passed" + + lines = ["Validation failed:"] + + if self.errors: + lines.append("\nErrors:") + for error in self.errors: + lines.append(f" - {error}") + + if self.missing_features: + lines.append("\nMissing features:") + for feature in self.missing_features: + lines.append(f" - {feature}") + + if self.type_mismatches: + lines.append("\nType mismatches:") + for feature, (expected, actual) in self.type_mismatches.items(): + lines.append(f" - {feature}: expected {expected}, got {actual}") + + if self.warnings: + lines.append("\nWarnings:") + for warning in self.warnings: + lines.append(f" - {warning}") + + return "\n".join(lines) + + def add_error(self, error: str) -> None: + """Add an error to the validation result.""" + self.errors.append(error) + self.valid = False + + def add_warning(self, warning: str) -> None: + """Add a warning to the validation result.""" + self.warnings.append(warning) + + def add_missing_feature(self, feature: str) -> None: + """Add a missing feature.""" + self.missing_features.append(feature) + self.errors.append(f"Required feature '{feature}' is missing") + self.valid = False + + def add_type_mismatch(self, feature: str, expected: str, actual: str) -> None: + """Add a type mismatch.""" + self.type_mismatches[feature] 
= (expected, actual) + self.errors.append( + f"Type mismatch for '{feature}': expected {expected}, got {actual}" + ) diff --git a/healthchain/sandbox/generators/conditiongenerators.py b/healthchain/sandbox/generators/conditiongenerators.py index 366b9984..09dd6354 100644 --- a/healthchain/sandbox/generators/conditiongenerators.py +++ b/healthchain/sandbox/generators/conditiongenerators.py @@ -4,7 +4,7 @@ from fhir.resources.reference import Reference from fhir.resources.condition import ConditionStage, ConditionParticipant -from healthchain.fhir.helpers import create_single_codeable_concept, create_condition +from healthchain.fhir import create_single_codeable_concept, create_condition from healthchain.sandbox.generators.basegenerators import ( BaseGenerator, generator_registry, diff --git a/healthchain/sandbox/loaders/mimic.py b/healthchain/sandbox/loaders/mimic.py index be79adc3..57e03219 100644 --- a/healthchain/sandbox/loaders/mimic.py +++ b/healthchain/sandbox/loaders/mimic.py @@ -7,7 +7,7 @@ import logging import random from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from fhir.resources.R4B.bundle import Bundle @@ -54,28 +54,49 @@ def load( resource_types: Optional[List[str]] = None, sample_size: Optional[int] = None, random_seed: Optional[int] = None, + as_dict: bool = False, **kwargs, - ) -> Dict: + ) -> Union[Dict[str, Bundle], Dict[str, Any]]: """ - Load MIMIC-on-FHIR data as a dict of FHIR Bundles. + Load MIMIC-on-FHIR data as FHIR Bundle(s). Args: data_dir: Path to root MIMIC-on-FHIR directory (expects a /fhir subdir with .ndjson.gz files) resource_types: Resource type names to load (e.g., ["MimicMedication"]). Required. sample_size: Number of resources to randomly sample per type (loads all if None) random_seed: Seed for sampling + as_dict: If True, return single bundle dict (fast, no validation - for ML workflows). + If False, return dict of validated Bundle objects grouped by resource type (for CDS Hooks). + Default: False **kwargs: Reserved for future use Returns: - Dict mapping resource type (e.g., "MedicationStatement") to FHIR R4B Bundle + If as_dict=False: Dict[str, Bundle] - validated Pydantic Bundle objects grouped by resource type + Example: {"observation": Bundle(...), "patient": Bundle(...)} + If as_dict=True: Dict[str, Any] - single combined bundle dict (no validation) + Example: {"type": "collection", "entry": [...]} Raises: FileNotFoundError: If directory or resource files not found ValueError: If resource_types is None/empty or resources fail validation - Example: + Examples: + CDS Hooks prefetch format (validated, grouped by resource type): >>> loader = MimicOnFHIRLoader() - >>> loader.load(data_dir="./data/mimic-iv-fhir", resource_types=["MimicMedication"], sample_size=100) + >>> prefetch = loader.load( + ... data_dir="./data/mimic-iv-fhir", + ... resource_types=["MimicMedication", "MimicCondition"] + ... ) + >>> prefetch["medicationstatement"] # Pydantic Bundle object + + ML workflow (single bundle dict, fast, no validation): + >>> bundle = loader.load( + ... data_dir="./data/mimic-iv-fhir", + ... resource_types=["MimicObservationChartevents", "MimicPatient"], + ... as_dict=True + ... 
) + >>> from healthchain.io import Dataset + >>> dataset = Dataset.from_fhir_bundle(bundle, schema="sepsis_vitals.yaml") """ data_dir = Path(data_dir) @@ -141,6 +162,15 @@ def load( f"No valid resources loaded from specified resource types: {resource_types}" ) + # ML workflow + if as_dict: + all_entries = [] + for resources in resources_by_type.values(): + all_entries.extend([{"resource": r} for r in resources]) + + return {"type": "collection", "entry": all_entries} + + # CDS Hooks prefetch bundles = {} for fhir_type, resources in resources_by_type.items(): bundles[fhir_type.lower()] = Bundle( diff --git a/tests/containers/conftest.py b/tests/containers/conftest.py deleted file mode 100644 index bc77ee38..00000000 --- a/tests/containers/conftest.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -from healthchain.io.containers.document import FhirData, Document -from healthchain.fhir import create_bundle, create_document_reference - - -@pytest.fixture -def fhir_data(): - return FhirData() - - -@pytest.fixture -def sample_bundle(): - return create_bundle() - - -@pytest.fixture -def sample_document(): - return Document("This is a sample text for testing.") - - -@pytest.fixture -def sample_document_reference(): - return create_document_reference( - data="test content", - content_type="text/plain", - description="Test Document", - ) - - -@pytest.fixture -def document_family(): - """Create a family of related documents.""" - original = create_document_reference( - data="original content", - content_type="text/plain", - description="Original Report", - ) - - summary = create_document_reference( - data="summary content", content_type="text/plain", description="Summary" - ) - - translation = create_document_reference( - data="translated content", content_type="text/plain", description="Translation" - ) - - return original, summary, translation diff --git a/tests/fhir/test_bundle_helpers.py b/tests/fhir/test_bundle_helpers.py index d82ba692..cfc89c4e 100644 --- a/tests/fhir/test_bundle_helpers.py +++ b/tests/fhir/test_bundle_helpers.py @@ -7,7 +7,7 @@ from fhir.resources.allergyintolerance import AllergyIntolerance from fhir.resources.documentreference import DocumentReference -from healthchain.fhir.bundle_helpers import ( +from healthchain.fhir.bundlehelpers import ( create_bundle, add_resource, get_resources, diff --git a/tests/fhir/test_converters.py b/tests/fhir/test_converters.py new file mode 100644 index 00000000..40e06725 --- /dev/null +++ b/tests/fhir/test_converters.py @@ -0,0 +1,525 @@ +"""Tests for FHIR converters module. + +Tests the converter functions that transform FHIR Bundles to DataFrames, +with focus on the dict-based conversion architecture. 
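+
+A minimal sketch of the flow exercised below (helper names as imported in this
+module; the "8867-4_Heart_rate" column name assumes a heart-rate observation
+with that LOINC code and display):
+
+    bundle = create_bundle()
+    patient = create_patient("male", "1980-01-01")
+    patient.id = "123"
+    add_resource(bundle, patient)
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123", code="8867-4", value=85.0, unit="bpm",
+            display="Heart rate",
+        ),
+    )
+    df = bundle_to_dataframe(bundle)  # one row per patient
+    assert df["8867-4_Heart_rate"].iloc[0] == 85.0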
+""" + +import pytest +import pandas as pd + +from healthchain.fhir.dataframe import ( + extract_observation_value, + group_bundle_by_patient, + bundle_to_dataframe, + extract_event_date, + get_supported_resources, + get_resource_info, + BundleConverterConfig, +) +from healthchain.fhir import ( + create_bundle, + add_resource, + create_patient, + create_value_quantity_observation, + create_condition, + create_medication_statement, +) + + +@pytest.mark.parametrize( + "obs_dict,expected", + [ + ({"valueQuantity": {"value": 85.0}}, 85.0), + ({"valueInteger": 100}, 100.0), + ({"valueString": "98.6"}, 98.6), + ({}, None), + ({"valueString": "not a number"}, None), + ({"valueBoolean": True}, None), + ], +) +def test_extract_observation_value_handles_value_types(obs_dict, expected): + """extract_observation_value handles different value types and invalid values.""" + assert extract_observation_value(obs_dict) == expected + + +def test_group_bundle_by_patient_handles_both_input_types(): + """group_bundle_by_patient handles Pydantic Bundle and dict input.""" + # Test Pydantic input + pydantic_bundle = create_bundle() + patient1 = create_patient("male", "1980-01-01") + patient1.id = "123" + add_resource(pydantic_bundle, patient1) + add_resource( + pydantic_bundle, + create_value_quantity_observation( + subject="Patient/123", code="8867-4", value=85.0, unit="bpm" + ), + ) + + result = group_bundle_by_patient(pydantic_bundle) + assert "Patient/123" in result + assert isinstance(result["Patient/123"]["patient"], dict) + assert result["Patient/123"]["patient"]["resourceType"] == "Patient" + + # Test dict input + dict_bundle = { + "resourceType": "Bundle", + "type": "collection", + "entry": [ + {"resource": {"resourceType": "Patient", "id": "456", "gender": "female"}}, + { + "resource": { + "resourceType": "Observation", + "subject": {"reference": "Patient/456"}, + "code": {"coding": [{"code": "8310-5"}]}, + "valueQuantity": {"value": 37.0}, + } + }, + ], + } + + result = group_bundle_by_patient(dict_bundle) + assert "Patient/456" in result + assert len(result["Patient/456"]["observations"]) == 1 + + +def test_group_bundle_by_patient_handles_reference_formats(): + """group_bundle_by_patient handles string and dict references, plus patient field.""" + bundle_dict = { + "resourceType": "Bundle", + "type": "collection", + "entry": [ + {"resource": {"resourceType": "Patient", "id": "789"}}, + { + "resource": { + "resourceType": "Observation", + "subject": "Patient/789", # String reference + "code": {"coding": [{"code": "8867-4"}]}, + "valueQuantity": {"value": 90.0}, + } + }, + { + "resource": { + "resourceType": "AllergyIntolerance", + "patient": {"reference": "Patient/789"}, # Uses patient field + "code": {"coding": [{"code": "123"}]}, + } + }, + ], + } + + result = group_bundle_by_patient(bundle_dict) + assert len(result["Patient/789"]["observations"]) == 1 + assert len(result["Patient/789"]["allergies"]) == 1 + + +def test_group_bundle_by_patient_groups_multiple_resource_types(): + """group_bundle_by_patient correctly categorizes different resource types.""" + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "999" + + add_resource(bundle, patient) + + # Add one of each type + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/999", code="8867-4", value=85.0, unit="bpm" + ), + ) + + from healthchain.fhir import ( + create_condition, + create_medication_statement, + create_allergy_intolerance, + ) + + add_resource(bundle, 
create_condition("Patient/999", code="E11.9")) + add_resource(bundle, create_medication_statement("Patient/999", code="123")) + add_resource(bundle, create_allergy_intolerance("Patient/999", code="456")) + + result = group_bundle_by_patient(bundle) + + assert len(result["Patient/999"]["observations"]) == 1 + assert len(result["Patient/999"]["conditions"]) == 1 + assert len(result["Patient/999"]["medications"]) == 1 + assert len(result["Patient/999"]["allergies"]) == 1 + + +def test_bundle_to_dataframe_basic_conversion(): + """bundle_to_dataframe converts both Pydantic and dict Bundles to DataFrames.""" + # Test with Pydantic Bundle + pydantic_bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + add_resource(pydantic_bundle, patient) + add_resource( + pydantic_bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=85.0, + unit="bpm", + display="Heart rate", + ), + ) + + df = bundle_to_dataframe(pydantic_bundle) + assert isinstance(df, pd.DataFrame) + assert len(df) == 1 + assert "age" in df.columns and "gender" in df.columns + assert "8867-4_Heart_rate" in df.columns + assert df["8867-4_Heart_rate"].iloc[0] == 85.0 + + # Test with dict Bundle + dict_bundle = { + "resourceType": "Bundle", + "type": "collection", + "entry": [ + { + "resource": { + "resourceType": "Patient", + "id": "456", + "gender": "female", + "birthDate": "1990-05-15", + } + }, + { + "resource": { + "resourceType": "Observation", + "subject": {"reference": "Patient/456"}, + "code": { + "coding": [{"code": "8310-5", "display": "Body temperature"}] + }, + "valueQuantity": {"value": 37.0}, + } + }, + ], + } + + df = bundle_to_dataframe(dict_bundle) + assert len(df) == 1 + assert "8310-5_Body_temperature" in df.columns + + +@pytest.mark.parametrize( + "resources,source,strategy,expected", + [ + ( + { + "observations": [ + {"effectiveDateTime": "2024-01-15"}, + {"effectiveDateTime": "2024-01-10"}, + {"effectiveDateTime": "2024-01-20"}, + ] + }, + "Observation", + "earliest", + "2024-01-10", + ), + ( + { + "observations": [ + {"effectiveDateTime": "2024-01-15"}, + {"effectiveDateTime": "2024-01-10"}, + {"effectiveDateTime": "2024-01-20"}, + ] + }, + "Observation", + "latest", + "2024-01-20", + ), + ( + { + "observations": [ + {"effectiveDateTime": "2024-01-15"}, + {"effectiveDateTime": "2024-01-10"}, + {"effectiveDateTime": "2024-01-20"}, + ] + }, + "Observation", + "first", + "2024-01-15", + ), + ( + { + "encounters": [ + {"period": {"start": "2024-01-15T10:00:00Z"}}, + {"period": {"start": "2024-01-10T08:00:00Z"}}, + ] + }, + "Encounter", + "earliest", + "2024-01-10T08:00:00Z", + ), + ({}, "Observation", "earliest", None), + ({"observations": []}, "Observation", "earliest", None), + ], +) +def test_extract_event_date_strategies_and_sources( + resources, source, strategy, expected +): + """extract_event_date handles different strategies and resource sources.""" + assert extract_event_date(resources, source=source, strategy=strategy) == expected + + +@pytest.mark.parametrize( + "aggregation,values,expected", + [ + ("mean", [85.0, 92.0], 88.5), + ("median", [85.0, 92.0, 100.0], 92.0), + ("max", [85.0, 92.0, 100.0], 100.0), + ("min", [85.0, 92.0, 100.0], 85.0), + ("last", [85.0, 92.0, 100.0], 100.0), + ], +) +def test_bundle_to_dataframe_observation_aggregation_strategies( + aggregation, values, expected +): + """bundle_to_dataframe applies different aggregation strategies correctly.""" + bundle = create_bundle() + patient = create_patient("male", 
"1980-01-01") + patient.id = "123" + add_resource(bundle, patient) + + # Add multiple observations with same code + for value in values: + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=value, + unit="bpm", + display="Heart rate", + ), + ) + + config = BundleConverterConfig( + resources=["Patient", "Observation"], observation_aggregation=aggregation + ) + df = bundle_to_dataframe(bundle, config=config) + + assert df["8867-4_Heart_rate"].iloc[0] == expected + + +def test_bundle_to_dataframe_age_calculation_modes(): + """bundle_to_dataframe calculates age from current date or event date.""" + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + add_resource(bundle, patient) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=85.0, + unit="bpm", + effective_datetime="2020-01-01T00:00:00Z", + ), + ) + + # Test event_date calculation + config = BundleConverterConfig( + age_calculation="event_date", + event_date_source="Observation", + event_date_strategy="earliest", + ) + df = bundle_to_dataframe(bundle, config=config) + assert df["age"].iloc[0] == 40 # 2020 - 1980 + + # Test current_date calculation (default) + config_default = BundleConverterConfig() + df_default = bundle_to_dataframe(bundle, config=config_default) + assert df_default["age"].iloc[0] is not None + + +def test_bundle_to_dataframe_creates_binary_indicators_for_conditions_and_medications(): + """bundle_to_dataframe creates binary indicator columns for conditions and medications.""" + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + add_resource(bundle, patient) + add_resource( + bundle, create_condition("Patient/123", code="E11.9", display="Type_2_diabetes") + ) + add_resource( + bundle, + create_medication_statement("Patient/123", code="1049221", display="Insulin"), + ) + + config = BundleConverterConfig( + resources=["Patient", "Condition", "MedicationStatement"] + ) + df = bundle_to_dataframe(bundle, config=config) + + assert "condition_E11.9_Type_2_diabetes" in df.columns + assert df["condition_E11.9_Type_2_diabetes"].iloc[0] == 1 + assert "medication_1049221_Insulin" in df.columns + assert df["medication_1049221_Insulin"].iloc[0] == 1 + + +def test_bundle_to_dataframe_handles_edge_cases(): + """bundle_to_dataframe handles empty bundles and malformed data gracefully.""" + # Empty bundle + empty_bundle = create_bundle() + df = bundle_to_dataframe(empty_bundle) + assert isinstance(df, pd.DataFrame) and len(df) == 0 + + # Missing coding arrays - should skip bad observation + bundle_dict = { + "resourceType": "Bundle", + "type": "collection", + "entry": [ + { + "resource": { + "resourceType": "Patient", + "id": "123", + "gender": "male", + "birthDate": "1980-01-01", + } + }, + { + "resource": { + "resourceType": "Observation", + "subject": {"reference": "Patient/123"}, + "code": {}, # Missing coding array + "valueQuantity": {"value": 85.0}, + } + }, + ], + } + + df = bundle_to_dataframe(bundle_dict) + assert len(df) == 1 + assert df["patient_ref"].iloc[0] == "Patient/123" + + # Missing display - should use code as fallback + bundle_with_condition = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "456" + add_resource(bundle_with_condition, patient) + add_resource(bundle_with_condition, create_condition("Patient/456", code="E11.9")) + + config = 
BundleConverterConfig(resources=["Patient", "Condition"]) + df = bundle_to_dataframe(bundle_with_condition, config=config) + assert "condition_E11.9_E11.9" in df.columns # Code used as display + + +def test_bundle_to_dataframe_handles_multiple_patients(): + """bundle_to_dataframe creates one row per patient in multi-patient bundles.""" + bundle = create_bundle() + + # Add first patient with observations + patient1 = create_patient("male", "1980-01-01") + patient1.id = "123" + add_resource(bundle, patient1) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=85.0, + unit="bpm", + display="Heart rate", + ), + ) + + # Add second patient with observations + patient2 = create_patient("female", "1990-05-15") + patient2.id = "456" + add_resource(bundle, patient2) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/456", + code="8867-4", + value=72.0, + unit="bpm", + display="Heart rate", + ), + ) + + df = bundle_to_dataframe(bundle) + + assert len(df) == 2 + assert set(df["patient_ref"]) == {"Patient/123", "Patient/456"} + assert df[df["patient_ref"] == "Patient/123"]["8867-4_Heart_rate"].iloc[0] == 85.0 + assert df[df["patient_ref"] == "Patient/456"]["8867-4_Heart_rate"].iloc[0] == 72.0 + + +def test_bundle_converter_config_defaults(): + """BundleConverterConfig uses sensible defaults.""" + config = BundleConverterConfig() + + assert config.resources == ["Patient", "Observation"] + assert config.observation_aggregation == "mean" + assert config.age_calculation == "current_date" + assert config.event_date_source == "Observation" + assert config.event_date_strategy == "earliest" + + +def test_bundle_converter_config_validates_unsupported_resources(caplog): + """BundleConverterConfig warns about unsupported resources but doesn't fail.""" + import logging + + caplog.set_level(logging.WARNING) + + config = BundleConverterConfig( + resources=["Patient", "Observation", "UnsupportedResource", "AnotherFakeOne"] + ) + + # Should still create config successfully + assert "Patient" in config.resources + assert "Observation" in config.resources + + # Should have logged warnings + assert any("UnsupportedResource" in record.message for record in caplog.records) + + +def test_get_supported_resources_returns_expected_types(): + """get_supported_resources returns list of supported resource types.""" + resources = get_supported_resources() + + assert isinstance(resources, list) + assert "Patient" in resources + assert "Observation" in resources + assert "Condition" in resources + assert "MedicationStatement" in resources + + +def test_get_resource_info_returns_handler_details(): + """get_resource_info returns metadata for supported resources.""" + obs_info = get_resource_info("Observation") + + assert obs_info["handler"] == "_flatten_observations" + assert "description" in obs_info + assert "observation" in obs_info["description"].lower() + + # Unsupported resource returns empty dict + assert get_resource_info("UnsupportedResource") == {} + + +def test_bundle_to_dataframe_skips_unsupported_resources_gracefully(): + """bundle_to_dataframe skips unsupported resources without error.""" + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + add_resource(bundle, patient) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", code="8867-4", value=85.0, unit="bpm" + ), + ) + + # Include unsupported resource types in config + config = BundleConverterConfig( 
+ resources=["Patient", "Observation", "UnsupportedType"] + ) + + # Should not raise error, just skip unsupported types + df = bundle_to_dataframe(bundle, config=config) + assert len(df) == 1 diff --git a/tests/fhir/test_helpers.py b/tests/fhir/test_helpers.py index 75763b16..35f7ce24 100644 --- a/tests/fhir/test_helpers.py +++ b/tests/fhir/test_helpers.py @@ -9,7 +9,7 @@ from datetime import datetime -from healthchain.fhir.helpers import ( +from healthchain.fhir import ( create_resource_from_dict, create_single_codeable_concept, create_single_reaction, @@ -23,6 +23,8 @@ read_content_attachment, add_provenance_metadata, add_coding_to_codeable_concept, + calculate_age_from_birthdate, + calculate_age_from_event_date, ) import pytest @@ -336,3 +338,72 @@ def test_set_condition_category_invalid_raises(): def test_create_condition_without_code_is_none(): cond = create_condition(subject="Patient/1") assert cond.code is None + + +def test_calculate_age_from_birthdate(): + """Test standard age calculation from birth date.""" + # Test with date 30 years ago + from datetime import datetime + + birth_year = datetime.now().year - 30 + birth_date = f"{birth_year}-06-15" + + age = calculate_age_from_birthdate(birth_date) + assert age is not None + # Age should be 29 or 30 depending on current date + assert age in [29, 30] + + +def test_calculate_age_from_birthdate_with_full_datetime(): + """Test age calculation with full ISO datetime.""" + from datetime import datetime + + birth_year = datetime.now().year - 25 + birth_date = f"{birth_year}-03-10T10:30:00Z" + + age = calculate_age_from_birthdate(birth_date) + assert age is not None + assert age in [24, 25] + + +def test_calculate_age_from_birthdate_invalid(): + """Test age calculation with invalid date.""" + assert calculate_age_from_birthdate(None) is None + assert calculate_age_from_birthdate("") is None + assert calculate_age_from_birthdate("invalid") is None + + +def test_calculate_age_from_event_date(): + """Test MIMIC-IV style age calculation using event date.""" + birth_date = "1990-06-15" + event_date = "2020-03-10" + + age = calculate_age_from_event_date(birth_date, event_date) + assert age == 30 # 2020 - 1990 = 30 + + +def test_calculate_age_from_event_date_with_full_datetime(): + """Test MIMIC-IV style calculation with full ISO datetime.""" + birth_date = "1985-12-25T08:00:00Z" + event_date = "2023-01-15T14:30:00Z" + + age = calculate_age_from_event_date(birth_date, event_date) + assert age == 38 # 2023 - 1985 = 38 + + +def test_calculate_age_from_event_date_same_year(): + """Test MIMIC-IV style calculation when birth and event in same year.""" + birth_date = "2020-01-01" + event_date = "2020-12-31" + + age = calculate_age_from_event_date(birth_date, event_date) + assert age == 0 # Same year = 0 + + +def test_calculate_age_from_event_date_invalid(): + """Test MIMIC-IV style calculation with invalid dates.""" + assert calculate_age_from_event_date(None, "2020-01-01") is None + assert calculate_age_from_event_date("1990-01-01", None) is None + assert calculate_age_from_event_date("", "2020-01-01") is None + assert calculate_age_from_event_date("invalid", "2020-01-01") is None + assert calculate_age_from_event_date("1990-01-01", "invalid") is None diff --git a/tests/io/conftest.py b/tests/io/conftest.py new file mode 100644 index 00000000..efbe95dd --- /dev/null +++ b/tests/io/conftest.py @@ -0,0 +1,212 @@ +import pytest +import pandas as pd +from pathlib import Path + +from healthchain.io.containers.featureschema import FeatureSchema +from 
healthchain.io.containers.dataset import Dataset
+from healthchain.fhir import create_bundle
+from healthchain.fhir import create_patient, create_value_quantity_observation
+
+
+@pytest.fixture
+def sepsis_schema():
+    """Load the actual sepsis_vitals.yaml schema.
+
+    Uses the real schema file for integration-style testing.
+    """
+    schema_path = Path("healthchain/configs/features/sepsis_vitals.yaml")
+    return FeatureSchema.from_yaml(schema_path)
+
+
+@pytest.fixture
+def minimal_schema():
+    """Minimal schema with required and optional features.
+
+    Useful for testing basic functionality without all the complexity
+    of the full sepsis schema.
+    """
+    return FeatureSchema.from_dict(
+        {
+            "name": "test_schema",
+            "version": "1.0",
+            "features": {
+                "heart_rate": {
+                    "fhir_resource": "Observation",
+                    "code": "8867-4",
+                    "code_system": "http://loinc.org",
+                    "display": "Heart rate",
+                    "dtype": "float64",
+                    "required": True,
+                },
+                "temperature": {
+                    "fhir_resource": "Observation",
+                    "code": "8310-5",
+                    "code_system": "http://loinc.org",
+                    "display": "Body temperature",
+                    "dtype": "float64",
+                    "required": False,
+                },
+                "age": {
+                    "fhir_resource": "Patient",
+                    "field": "birthDate",
+                    "transform": "calculate_age",
+                    "dtype": "int64",
+                    "required": True,
+                },
+                "gender_encoded": {
+                    "fhir_resource": "Patient",
+                    "field": "gender",
+                    "transform": "encode_gender",
+                    "dtype": "int64",
+                    "required": True,
+                },
+            },
+        }
+    )
+
+
+@pytest.fixture
+def observation_bundle():
+    """Bundle with patient and observations matching minimal schema.
+
+    Contains a single patient with heart rate and temperature observations.
+    """
+    from healthchain.fhir import add_resource
+
+    bundle = create_bundle()
+    patient = create_patient("male", "1980-01-01")
+    patient.id = "123"
+
+    # Use add_resource to properly add to bundle
+    add_resource(bundle, patient)
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=85.0,
+            unit="bpm",
+            system="http://loinc.org",
+            display="Heart rate",
+        ),
+    )
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8310-5",
+            value=37.0,
+            unit="Cel",
+            system="http://loinc.org",
+            display="Body temperature",
+        ),
+    )
+
+    return bundle
+
+
+@pytest.fixture
+def observation_bundle_with_duplicates():
+    """Bundle with multiple observations of the same type for testing aggregation."""
+    from healthchain.fhir import add_resource
+
+    bundle = create_bundle()
+    patient = create_patient("male", "1980-01-01")
+    patient.id = "123"
+
+    # Use add_resource consistently like observation_bundle
+    add_resource(bundle, patient)
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=85.0,
+            unit="bpm",
+            system="http://loinc.org",
+            display="Heart rate",
+        ),
+    )
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=90.0,
+            unit="bpm",
+            system="http://loinc.org",
+            display="Heart rate",
+        ),
+    )
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=88.0,
+            unit="bpm",
+            system="http://loinc.org",
+            display="Heart rate",
+        ),
+    )
+
+    return bundle
+
+
+@pytest.fixture
+def empty_observation_bundle():
+    """Bundle with patient but no observations."""
+    from healthchain.fhir import add_resource
+
+    bundle = create_bundle()
+    patient = create_patient("female", "1990-05-15")
+    patient.id = "456"
+
+    add_resource(bundle, patient)
+    return bundle
+
+
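+# Sketch of how the bundle fixtures above pair with the schema fixtures
+# (Dataset.from_fhir_bundle is the entry point exercised in tests/io/test_dataset.py):
+#
+#     dataset = Dataset.from_fhir_bundle(observation_bundle, minimal_schema)
+#     dataset.data["heart_rate"].iloc[0]  # -> 85.0
+
+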
+@pytest.fixture +def sample_dataset(): + """Sample dataset with minimal schema features. + + Contains two patients with complete feature data. + """ + data = { + "patient_ref": ["Patient/1", "Patient/2"], + "heart_rate": [85.0, 92.0], + "temperature": [37.0, 37.5], + "age": [45, 62], + "gender_encoded": [1, 0], + } + return Dataset(pd.DataFrame(data)) + + +@pytest.fixture +def sample_dataset_incomplete(): + """Sample dataset missing required features. + + Useful for testing validation logic. + """ + data = { + "patient_ref": ["Patient/1", "Patient/2"], + "heart_rate": [85.0, 92.0], + # Missing temperature (optional), age, and gender_encoded (required) + } + return Dataset(pd.DataFrame(data)) + + +@pytest.fixture +def sample_dataset_wrong_types(): + """Sample dataset with incorrect data types. + + Useful for testing type validation logic. + """ + data = { + "patient_ref": ["Patient/1", "Patient/2"], + "heart_rate": ["85.0", "92.0"], # String instead of float + "temperature": [37.0, 37.5], + "age": [45.5, 62.5], # Float instead of int + "gender_encoded": [1, 0], + } + return Dataset(pd.DataFrame(data)) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py new file mode 100644 index 00000000..be2e25f1 --- /dev/null +++ b/tests/io/test_dataset.py @@ -0,0 +1,294 @@ +import pytest +import pandas as pd +import numpy as np + + +from healthchain.io.containers.dataset import Dataset + + +def test_dataset_from_fhir_bundle(observation_bundle, minimal_schema): + """Dataset.from_fhir_bundle extracts features using schema.""" + dataset = Dataset.from_fhir_bundle(observation_bundle, minimal_schema) + + assert len(dataset.data) == 1 + assert "patient_ref" in dataset.columns + assert "heart_rate" in dataset.columns + assert "temperature" in dataset.columns + assert "age" in dataset.columns + + +def test_dataset_from_fhir_bundle_with_yaml_path(observation_bundle): + """Dataset.from_fhir_bundle accepts YAML schema path.""" + schema_path = "healthchain/configs/features/sepsis_vitals.yaml" + dataset = Dataset.from_fhir_bundle(observation_bundle, schema_path) + + assert len(dataset.data) == 1 + assert "patient_ref" in dataset.columns + + +def test_dataset_from_fhir_bundle_with_aggregation( + observation_bundle_with_duplicates, minimal_schema +): + """Dataset.from_fhir_bundle respects aggregation parameter.""" + dataset_mean = Dataset.from_fhir_bundle( + observation_bundle_with_duplicates, minimal_schema, aggregation="mean" + ) + dataset_max = Dataset.from_fhir_bundle( + observation_bundle_with_duplicates, minimal_schema, aggregation="max" + ) + + assert dataset_mean.data["heart_rate"].iloc[0] == pytest.approx(87.666667, rel=1e-5) + assert dataset_max.data["heart_rate"].iloc[0] == 90.0 + + +def test_dataset_validate_with_complete_data(sample_dataset, minimal_schema): + """Dataset.validate passes with complete valid data.""" + result = sample_dataset.validate(minimal_schema) + + assert result.valid is True + assert len(result.missing_features) == 0 + assert len(result.errors) == 0 + + +def test_dataset_validate_detects_missing_required_features( + sample_dataset_incomplete, minimal_schema +): + """Dataset.validate detects missing required features.""" + result = sample_dataset_incomplete.validate(minimal_schema) + + assert result.valid is False + assert len(result.missing_features) > 0 + assert "age" in result.missing_features + assert "gender_encoded" in result.missing_features + + +def test_dataset_validate_raises_on_error_when_requested( + sample_dataset_incomplete, minimal_schema +): + 
"""Dataset.validate raises exception when raise_on_error is True.""" + with pytest.raises(ValueError, match="Validation failed"): + sample_dataset_incomplete.validate(minimal_schema, raise_on_error=True) + + +def test_dataset_validate_detects_type_mismatches( + sample_dataset_wrong_types, minimal_schema +): + """Dataset.validate detects incorrect data types.""" + result = sample_dataset_wrong_types.validate(minimal_schema) + + # Type mismatches are recorded even if they don't fail validation due to dtype_compatible + assert len(result.type_mismatches) > 0 + # heart_rate should be object (string) instead of float64 + assert "heart_rate" in result.type_mismatches + # Check that errors were added for the type mismatches + assert len(result.errors) > 0 + assert any("heart_rate" in error for error in result.errors) + + +def test_dataset_validate_warns_about_missing_optional(minimal_schema): + """Dataset.validate generates warnings for missing optional features.""" + data = pd.DataFrame( + { + "patient_ref": ["Patient/1"], + "heart_rate": [85.0], + "age": [45], + "gender_encoded": [1], + # Missing optional "temperature" + } + ) + dataset = Dataset(data) + + result = dataset.validate(minimal_schema) + + assert result.valid is True + assert len(result.warnings) > 0 + assert any("temperature" in w for w in result.warnings) + + +def test_dataset_dtype_compatibility_allows_numeric_flexibility(): + """Dataset._dtypes_compatible allows flexibility between numeric types.""" + data = pd.DataFrame( + { + "patient_ref": ["Patient/1"], + "value_int": [45], # int64 + "value_float": [45.0], # float64 + } + ) + dataset = Dataset(data) + + # int64 and float64 should be compatible + assert dataset._dtypes_compatible("int64", "float64") + assert dataset._dtypes_compatible("float64", "int64") + assert dataset._dtypes_compatible("int32", "float64") + + +def test_dataset_to_risk_assessment_creates_resources_with_metadata(sample_dataset): + """Dataset.to_risk_assessment creates RiskAssessment resources with probabilities, model metadata, and comments.""" + predictions = np.array([0, 1]) + probabilities = np.array([0.15, 0.85]) + + # Test with model metadata + risks = sample_dataset.to_risk_assessment( + predictions, + probabilities, + outcome_code="A41.9", + outcome_display="Sepsis", + model_name="RandomForest", + model_version="1.0", + ) + + # Basic structure + assert len(risks) == 2 + assert risks[0].subject.reference == "Patient/1" + assert risks[1].subject.reference == "Patient/2" + assert risks[0].status == "final" + + # Probabilities + assert risks[0].prediction[0].probabilityDecimal == 0.15 + assert risks[1].prediction[0].probabilityDecimal == 0.85 + + # Model metadata + assert risks[0].method is not None + assert risks[0].method.coding[0].code == "RandomForest" + assert "v1.0" in risks[0].method.coding[0].display + + # Comments + assert risks[0].note is not None + assert "Negative" in risks[0].note[0].text + assert "15.00%" in risks[0].note[0].text + assert "low" in risks[0].note[0].text + assert "Positive" in risks[1].note[0].text + assert "85.00%" in risks[1].note[0].text + assert "high" in risks[1].note[0].text + + +@pytest.mark.parametrize( + "predictions,probabilities,expected_risks", + [ + ([0, 1, 0], [0.15, 0.85, 0.55], ["low", "high", "moderate"]), + ([0, 1, 0], [0.0, 1.0, 0.5], ["low", "high", "moderate"]), # Edge cases + ], +) +def test_dataset_to_risk_assessment_categorizes_risk_levels( + predictions, probabilities, expected_risks +): + """Dataset.to_risk_assessment correctly categorizes risk 
levels including edge probabilities.""" + data = pd.DataFrame( + { + "patient_ref": ["Patient/1", "Patient/2", "Patient/3"], + "heart_rate": [85.0, 92.0, 88.0], + "temperature": [37.0, 37.5, 37.2], + "age": [45, 62, 50], + "gender_encoded": [1, 0, 1], + } + ) + dataset = Dataset(data) + + risks = dataset.to_risk_assessment( + np.array(predictions), + np.array(probabilities), + outcome_code="A41.9", + outcome_display="Sepsis", + ) + + for i, expected_risk in enumerate(expected_risks): + assert risks[i].prediction[0].qualitativeRisk.coding[0].code == expected_risk + + +@pytest.mark.parametrize( + "data_dict,predictions,probabilities,expected_error", + [ + ( + {"heart_rate": [85.0, 92.0], "age": [45, 62]}, # Missing patient_ref + [0, 1], + [0.15, 0.85], + "DataFrame must have 'patient_ref' column", + ), + ( + {"patient_ref": ["Patient/1", "Patient/2"], "value": [1, 2]}, + [0], # Wrong prediction length + [0.15, 0.85], + "Predictions length .* must match", + ), + ( + {"patient_ref": ["Patient/1", "Patient/2"], "value": [1, 2]}, + [0, 1], + [0.15], # Wrong probability length + "Probabilities length .* must match", + ), + ], +) +def test_dataset_to_risk_assessment_validation_errors( + data_dict, predictions, probabilities, expected_error +): + """Dataset.to_risk_assessment validates required columns and array lengths.""" + data = pd.DataFrame(data_dict) + dataset = Dataset(data) + + with pytest.raises(ValueError, match=expected_error): + dataset.to_risk_assessment( + np.array(predictions), + np.array(probabilities), + outcome_code="A41.9", + outcome_display="Sepsis", + ) + + +def test_dataset_from_csv_loads_correctly(tmp_path): + """Dataset.from_csv loads CSV files into DataFrame.""" + csv_file = tmp_path / "test.csv" + csv_file.write_text( + "patient_ref,heart_rate,age\nPatient/1,85.0,45\nPatient/2,92.0,62" + ) + + dataset = Dataset.from_csv(str(csv_file)) + + assert len(dataset.data) == 2 + assert "patient_ref" in dataset.columns + assert dataset.data["heart_rate"].iloc[0] == 85.0 + + +def test_dataset_from_dict_creates_dataframe(): + """Dataset.from_dict creates DataFrame from dict.""" + data_dict = { + "data": {"patient_ref": ["Patient/1", "Patient/2"], "heart_rate": [85.0, 92.0]} + } + + dataset = Dataset.from_dict(data_dict) + + assert len(dataset.data) == 2 + assert "patient_ref" in dataset.columns + assert "heart_rate" in dataset.columns + assert dataset.data["heart_rate"].iloc[0] == 85.0 + + +def test_dataset_to_csv_saves_correctly(tmp_path, sample_dataset): + """Dataset.to_csv exports DataFrame to CSV.""" + csv_file = tmp_path / "output.csv" + + sample_dataset.to_csv(str(csv_file), index=False) + + assert csv_file.exists() + df = pd.read_csv(csv_file) + assert len(df) == 2 + assert "patient_ref" in df.columns + + +def test_dataset_rejects_non_dataframe_input(): + """Dataset validates input is a DataFrame in __post_init__.""" + with pytest.raises(TypeError, match="data must be a pandas DataFrame"): + Dataset([{"patient_ref": "Patient/1"}]) + + +def test_dataset_to_risk_assessment_validates_probability_length(): + """Dataset.to_risk_assessment validates probabilities array length.""" + data = pd.DataFrame({"patient_ref": ["Patient/1", "Patient/2"], "value": [1, 2]}) + dataset = Dataset(data) + + predictions = np.array([0, 1]) + probabilities = np.array([0.15]) # Wrong length + + with pytest.raises(ValueError, match="Probabilities length .* must match"): + dataset.to_risk_assessment( + predictions, probabilities, outcome_code="A41.9", outcome_display="Sepsis" + ) diff --git 
a/tests/containers/test_document.py b/tests/io/test_document.py similarity index 86% rename from tests/containers/test_document.py rename to tests/io/test_document.py index 9faa038f..ab86a839 100644 --- a/tests/containers/test_document.py +++ b/tests/io/test_document.py @@ -6,6 +6,11 @@ from healthchain.fhir import create_bundle, add_resource, create_condition +@pytest.fixture +def sample_document(): + return Document("This is a sample text for testing.") + + def test_document_initialization(sample_document): """Test basic Document initialization and properties.""" assert sample_document.data == "This is a sample text for testing." @@ -23,55 +28,6 @@ def test_document_initialization(sample_document): assert sample_document.nlp.get_embeddings() is None -def test_document_properties(sample_document): - """Test Document property access.""" - # Test property access - assert hasattr(sample_document, "nlp") - assert hasattr(sample_document, "fhir") - assert hasattr(sample_document, "cds") - assert hasattr(sample_document, "models") - - -def test_document_word_count(sample_document): - """Test word count functionality.""" - assert sample_document.word_count() == 7 - - -def test_document_iteration(sample_document): - """Test document iteration over tokens.""" - tokens = list(sample_document) - assert tokens == [ - "This", - "is", - "a", - "sample", - "text", - "for", - "testing.", - ] - - -def test_document_length(sample_document): - """Test document length.""" - assert len(sample_document) == 34 # Length of the text string - - -def test_document_post_init(sample_document): - """Test post-initialization behavior.""" - # Test that text is set from data - assert sample_document.text == sample_document.data - # Test that basic tokenization is performed - assert len(sample_document.nlp._tokens) > 0 - - -def test_empty_document(): - """Test Document initialization with empty text.""" - doc = Document("") - assert doc.text == "" - assert doc.nlp._tokens == [] - assert doc.word_count() == 0 - - @pytest.mark.parametrize( "data_builder, expect_bundle, expected_entries, expected_text", [ diff --git a/tests/io/test_feature_schema.py b/tests/io/test_feature_schema.py new file mode 100644 index 00000000..4569434f --- /dev/null +++ b/tests/io/test_feature_schema.py @@ -0,0 +1,244 @@ +import pytest +import tempfile +from pathlib import Path + +from healthchain.io.containers.featureschema import FeatureSchema, FeatureMapping + + +@pytest.mark.parametrize( + "mapping_data,expected_error", + [ + ( + {"fhir_resource": "Observation"}, + "Observation resources require a 'code'", + ), + ( + {"fhir_resource": "Observation", "code": "123"}, + "Observation resources require a 'code_system'", + ), + ( + {"fhir_resource": "Observation", "code_system": "http://loinc.org"}, + "Observation resources require a 'code'", + ), + ( + {"fhir_resource": "Patient"}, + "Patient resources require a 'field'", + ), + ], +) +def test_feature_mapping_required_fields_and_validations(mapping_data, expected_error): + """FeatureMapping enforces required fields and validates resource-specific requirements.""" + with pytest.raises(ValueError, match=expected_error): + FeatureMapping(name="test_feature", dtype="float64", **mapping_data) + + +def test_feature_schema_loads_from_yaml(sepsis_schema): + """FeatureSchema.from_yaml loads the sepsis_vitals schema correctly.""" + assert sepsis_schema.name == "sepsis_prediction_features" + assert sepsis_schema.version == "1.0" + assert len(sepsis_schema.features) == 8 + assert "heart_rate" in 
sepsis_schema.features + assert "age" in sepsis_schema.features + + +def test_feature_schema_from_dict(minimal_schema): + """FeatureSchema.from_dict creates schema with proper FeatureMapping objects.""" + assert minimal_schema.name == "test_schema" + assert isinstance(minimal_schema.features["heart_rate"], FeatureMapping) + assert minimal_schema.features["heart_rate"].required is True + assert minimal_schema.features["temperature"].required is False + + +def test_feature_schema_to_dict_and_back_handles_unknown_and_nested_fields( + minimal_schema, +): + """FeatureSchema.to_dict/from_dict: unknown fields are allowed (Pydantic extra='allow').""" + # Add an unknown field at the top-level + schema_dict = minimal_schema.to_dict() + schema_dict["extra_top_level"] = "foo" + # Add extra/unknown fields at the feature level + schema_dict["features"]["heart_rate"]["unknown_field"] = 12345 + schema_dict["features"]["temperature"]["nested_field"] = {"inner": ["a", {"b": 7}]} + + # With Pydantic extra='allow', unknown fields are accepted and preserved + loaded = FeatureSchema.from_dict(schema_dict) + + # Core fields should still be correct + assert loaded.name == minimal_schema.name + assert loaded.version == minimal_schema.version + assert len(loaded.features) == len(minimal_schema.features) + + # Unknown fields are preserved in the model + assert "heart_rate" in loaded.features + assert loaded.features["heart_rate"].code == "8867-4" + + +def test_feature_schema_to_yaml_and_back(minimal_schema): + """FeatureSchema can be saved to YAML and reloaded.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + temp_path = f.name + + try: + minimal_schema.to_yaml(temp_path) + loaded = FeatureSchema.from_yaml(temp_path) + + assert loaded.name == minimal_schema.name + assert len(loaded.features) == len(minimal_schema.features) + assert loaded.features["heart_rate"].code == "8867-4" + finally: + Path(temp_path).unlink() + + +def test_feature_schema_required_vs_optional_distinction(minimal_schema): + """FeatureSchema correctly distinguishes required from optional features.""" + required = minimal_schema.get_required_features() + all_features = minimal_schema.get_feature_names() + + # Required features should be a subset of all features + assert set(required).issubset(set(all_features)) + + # Temperature is optional, others are required + assert "temperature" not in required + assert len(required) == 3 + assert all(f in required for f in ["heart_rate", "age", "gender_encoded"]) + + +@pytest.mark.parametrize( + "columns, expected_valid, missing_required, missing_optional, unexpected", + [ + ( + ["heart_rate", "temperature"], # missing required + False, + {"age", "gender_encoded"}, + set(), + set(), + ), + ( + ["heart_rate", "age", "gender_encoded"], # missing optional + True, + set(), + {"temperature"}, + set(), + ), + ( + ["heart_rate", "age", "gender_encoded", "unexpected_col"], # unexpected col + True, + set(), + set(), + {"unexpected_col"}, + ), + ], +) +def test_feature_schema_validate_dataframe_columns_various_cases( + minimal_schema, + columns, + expected_valid, + missing_required, + missing_optional, + unexpected, +): + """FeatureSchema.validate_dataframe_columns: missing required, optional, and unexpected columns.""" + result = minimal_schema.validate_dataframe_columns(columns) + assert result["valid"] is expected_valid + assert set(result["missing_required"]) == missing_required + if missing_optional: + assert set(result["missing_optional"]) == missing_optional + if unexpected: 
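+        # Unexpected columns are reported but, per the cases above, do not
+        # make the result invalid on their own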
+ assert set(result["unexpected"]) == unexpected + + +def test_feature_schema_get_features_by_resource(minimal_schema): + """FeatureSchema.get_features_by_resource filters features by FHIR resource type.""" + observations = minimal_schema.get_features_by_resource("Observation") + patients = minimal_schema.get_features_by_resource("Patient") + + assert len(observations) == 2 # heart_rate, temperature + assert "heart_rate" in observations + assert "temperature" in observations + + assert len(patients) == 2 # age, gender_encoded + assert "age" in patients + assert "gender_encoded" in patients + + # Non-existent resource type returns empty dict + assert minimal_schema.get_features_by_resource("Condition") == {} + + +def test_feature_schema_get_observation_codes(minimal_schema): + """FeatureSchema.get_observation_codes returns mapping of codes to features.""" + obs_codes = minimal_schema.get_observation_codes() + + assert "8867-4" in obs_codes # heart_rate code + assert "8310-5" in obs_codes # temperature code + assert obs_codes["8867-4"].name == "heart_rate" + assert obs_codes["8310-5"].name == "temperature" + + +def test_feature_schema_get_feature_names_preserves_order(minimal_schema): + """FeatureSchema.get_feature_names returns features in definition order.""" + names = minimal_schema.get_feature_names() + + assert isinstance(names, list) + assert len(names) == 4 + # Order should match the features dict order + assert names == ["heart_rate", "temperature", "age", "gender_encoded"] + + +def test_feature_schema_from_yaml_handles_malformed_file(): + """FeatureSchema.from_yaml raises error for malformed YAML.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write("invalid: yaml: content: [\n") # Malformed YAML + temp_path = f.name + + try: + with pytest.raises( + Exception + ): # Could be yaml.YAMLError or other parsing errors + FeatureSchema.from_yaml(temp_path) + finally: + Path(temp_path).unlink() + + +def test_feature_mapping_from_dict_creates_instance(): + """FeatureMapping.from_dict creates instance with name parameter.""" + mapping_data = { + "fhir_resource": "Observation", + "code": "8867-4", + "code_system": "http://loinc.org", + "dtype": "float64", + "required": True, + } + + mapping = FeatureMapping.from_dict("test_feature", mapping_data) + + assert mapping.name == "test_feature" + assert mapping.code == "8867-4" + assert mapping.fhir_resource == "Observation" + assert mapping.required is True + + +def test_feature_schema_handles_optional_fields(minimal_schema): + """FeatureSchema preserves optional metadata fields.""" + # Check that optional fields can be None + assert minimal_schema.description is None or isinstance( + minimal_schema.description, str + ) + assert minimal_schema.model_info is None or isinstance( + minimal_schema.model_info, dict + ) + + # Create schema with metadata + schema_with_metadata = FeatureSchema.from_dict( + { + "name": "test", + "version": "1.0", + "description": "Test description", + "model_info": {"type": "RandomForest"}, + "metadata": {"custom_field": "value"}, + "features": {}, + } + ) + + assert schema_with_metadata.description == "Test description" + assert schema_with_metadata.model_info["type"] == "RandomForest" + assert schema_with_metadata.metadata["custom_field"] == "value" diff --git a/tests/containers/test_fhir_data.py b/tests/io/test_fhir_data.py similarity index 86% rename from tests/containers/test_fhir_data.py rename to tests/io/test_fhir_data.py index 90830e3e..0c28dc13 100644 --- 
diff --git a/tests/containers/test_fhir_data.py b/tests/io/test_fhir_data.py
similarity index 86%
rename from tests/containers/test_fhir_data.py
rename to tests/io/test_fhir_data.py
index 90830e3e..0c28dc13 100644
--- a/tests/containers/test_fhir_data.py
+++ b/tests/io/test_fhir_data.py
@@ -1,12 +1,41 @@
+import pytest
+from healthchain.io.containers.document import FhirData
+
 from healthchain.fhir import create_condition, create_document_reference
 
 
-def test_bundle_operations(fhir_data, sample_bundle):
-    """Test basic bundle operations."""
-    assert fhir_data.bundle is None
+@pytest.fixture
+def fhir_data():
+    return FhirData()
+
+
+@pytest.fixture
+def sample_document_reference():
+    return create_document_reference(
+        data="test content",
+        content_type="text/plain",
+        description="Test Document",
+    )
+
+
+@pytest.fixture
+def document_family():
+    """Create a family of related documents."""
+    original = create_document_reference(
+        data="original content",
+        content_type="text/plain",
+        description="Original Report",
+    )
+
+    summary = create_document_reference(
+        data="summary content", content_type="text/plain", description="Summary"
+    )
+
+    translation = create_document_reference(
+        data="translated content", content_type="text/plain", description="Translation"
+    )
 
-    fhir_data.bundle = sample_bundle
-    assert fhir_data.bundle == sample_bundle
+    return original, summary, translation
 
 
 def test_resource_operations(fhir_data):
diff --git a/tests/io/test_fhir_feature_mapper.py b/tests/io/test_fhir_feature_mapper.py
new file mode 100644
index 00000000..8c6bb8eb
--- /dev/null
+++ b/tests/io/test_fhir_feature_mapper.py
@@ -0,0 +1,363 @@
+import pytest
+import numpy as np
+
+from healthchain.io.mappers.fhirfeaturemapper import FHIRFeatureMapper
+
+
+def test_mapper_extracts_features_from_bundle(observation_bundle, minimal_schema):
+    """FHIRFeatureMapper extracts features matching schema from FHIR Bundle."""
+    mapper = FHIRFeatureMapper(minimal_schema)
+    df = mapper.extract_features(observation_bundle)
+
+    assert len(df) == 1
+    assert "patient_ref" in df.columns
+    assert df["patient_ref"].iloc[0] == "Patient/123"
+    assert "heart_rate" in df.columns
+    assert "temperature" in df.columns
+    assert "age" in df.columns
+    assert "gender_encoded" in df.columns
+
+
+@pytest.mark.parametrize(
+    "aggregation,expected_value",
+    [
+        ("mean", 87.666667),
+        ("median", 88.0),
+        ("max", 90.0),
+        ("min", 85.0),
+        ("last", 88.0),
+    ],
+)
+def test_mapper_aggregation_methods(
+    observation_bundle_with_duplicates, minimal_schema, aggregation, expected_value
+):
+    """FHIRFeatureMapper correctly aggregates multiple observation values."""
+    mapper = FHIRFeatureMapper(minimal_schema)
+    df = mapper.extract_features(
+        observation_bundle_with_duplicates, aggregation=aggregation
+    )
+
+    assert len(df) == 1
+    assert df["heart_rate"].iloc[0] == pytest.approx(expected_value, rel=1e-5)
+
+
+def test_mapper_fills_missing_observations_with_nan(
+    empty_observation_bundle, minimal_schema
+):
+    """FHIRFeatureMapper fills missing observations with NaN."""
+    mapper = FHIRFeatureMapper(minimal_schema)
+    df = mapper.extract_features(empty_observation_bundle)
+
+    assert len(df) == 1
+    assert df["patient_ref"].iloc[0] == "Patient/456"
+    # Patient features should be present
+    assert df["age"].notna().iloc[0]
+    assert df["gender_encoded"].notna().iloc[0]
+    # Observation features should be NaN
+    assert np.isnan(df["heart_rate"].iloc[0])
+    assert np.isnan(df["temperature"].iloc[0])
+
+
+def test_mapper_column_mapping_from_generic_to_schema():
+    """FHIRFeatureMapper correctly maps generic column names to schema feature names."""
+    from healthchain.fhir import (
+        create_bundle,
+        add_resource,
+        create_patient,
+        create_value_quantity_observation,
+    )
+    from healthchain.io.containers.featureschema import FeatureSchema
+
+    # Create schema with specific LOINC codes
+    schema = FeatureSchema.from_dict(
+        {
+            "name": "test_schema",
+            "version": "1.0",
+            "features": {
+                "hr": {  # Schema uses "hr" as feature name
+                    "fhir_resource": "Observation",
+                    "code": "8867-4",  # LOINC for heart rate
+                    "code_system": "http://loinc.org",
+                    "dtype": "float64",
+                    "required": True,
+                }
+            },
+        }
+    )
+
+    # Create bundle with observation that has code 8867-4
+    bundle = create_bundle()
+    patient = create_patient("male", "1980-01-01")
+    patient.id = "123"
+    add_resource(bundle, patient)
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=85.0,
+            unit="bpm",
+            system="http://loinc.org",
+            display="Heart rate",
+        ),
+    )
+
+    mapper = FHIRFeatureMapper(schema)
+    df = mapper.extract_features(bundle)
+
+    # Should be renamed to "hr", not "8867-4_Heart_rate"
+    assert "hr" in df.columns
+    assert df["hr"].iloc[0] == 85.0
+
+
+def test_mapper_handles_bundle_with_no_matching_observations(observation_bundle):
+    """FHIRFeatureMapper handles bundle with observations that don't match schema."""
+    from healthchain.io.containers.featureschema import FeatureSchema
+
+    # Schema with different codes than what's in the bundle
+    schema = FeatureSchema.from_dict(
+        {
+            "name": "test_schema",
+            "version": "1.0",
+            "features": {
+                "blood_pressure": {
+                    "fhir_resource": "Observation",
+                    "code": "85354-9",  # Different code
+                    "code_system": "http://loinc.org",
+                    "dtype": "float64",
+                    "required": False,
+                }
+            },
+        }
+    )
+
+    mapper = FHIRFeatureMapper(schema)
+    df = mapper.extract_features(observation_bundle)
+
+    assert len(df) == 1
+    assert "blood_pressure" in df.columns
+    assert np.isnan(df["blood_pressure"].iloc[0])
+
+
+def test_mapper_extracts_patient_demographics(observation_bundle, minimal_schema):
+    """FHIRFeatureMapper correctly extracts and transforms patient demographics."""
+    mapper = FHIRFeatureMapper(minimal_schema)
+    df = mapper.extract_features(observation_bundle)
+
+    # Age should be calculated from birthDate (1980-01-01)
+    assert df["age"].iloc[0] > 40  # Age should be around 44-45
+    assert df["age"].dtype == np.int64
+
+    # Gender should be encoded (male = 1)
+    assert df["gender_encoded"].iloc[0] == 1
+    assert df["gender_encoded"].dtype == np.int64
+
+
+def test_mapper_preserves_column_order_from_schema(observation_bundle, minimal_schema):
+    """FHIRFeatureMapper returns DataFrame with columns ordered as in schema."""
+    mapper = FHIRFeatureMapper(minimal_schema)
+    df = mapper.extract_features(observation_bundle)
+
+    expected_order = ["patient_ref"] + minimal_schema.get_feature_names()
+    assert list(df.columns) == expected_order
+
+
+def test_mapper_handles_multiple_patients():
+    """FHIRFeatureMapper processes multiple patients in a bundle."""
+    from healthchain.fhir import (
+        create_bundle,
+        add_resource,
+        create_patient,
+        create_value_quantity_observation,
+    )
+    from healthchain.io.containers.featureschema import FeatureSchema
+
+    schema = FeatureSchema.from_dict(
+        {
+            "name": "test_schema",
+            "version": "1.0",
+            "features": {
+                "heart_rate": {
+                    "fhir_resource": "Observation",
+                    "code": "8867-4",
+                    "code_system": "http://loinc.org",
+                    "dtype": "float64",
+                    "required": True,
+                }
+            },
+        }
+    )
+
+    bundle = create_bundle()
+    patient1 = create_patient("male", "1980-01-01")
+    patient1.id = "123"
+    patient2 = create_patient("female", "1990-05-15")
+    patient2.id = "456"
+
+    add_resource(bundle, patient1)
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=85.0,
+            unit="bpm",
+            system="http://loinc.org",
+        ),
+    )
+    add_resource(bundle, patient2)
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/456",
+            code="8867-4",
+            value=92.0,
+            unit="bpm",
+            system="http://loinc.org",
+        ),
+    )
+
+    mapper = FHIRFeatureMapper(schema)
+    df = mapper.extract_features(bundle)
+
+    assert len(df) == 2
+    assert set(df["patient_ref"]) == {"Patient/123", "Patient/456"}
+    assert 85.0 in df["heart_rate"].values
+    assert 92.0 in df["heart_rate"].values
+
+
+def test_mapper_aggregation_with_mixed_values():
+    """FHIRFeatureMapper aggregates correctly with extreme value differences."""
+    from healthchain.fhir import (
+        create_bundle,
+        add_resource,
+        create_patient,
+        create_value_quantity_observation,
+    )
+    from healthchain.io.containers.featureschema import FeatureSchema
+
+    schema = FeatureSchema.from_dict(
+        {
+            "name": "test_schema",
+            "version": "1.0",
+            "features": {
+                "heart_rate": {
+                    "fhir_resource": "Observation",
+                    "code": "8867-4",
+                    "code_system": "http://loinc.org",
+                    "dtype": "float64",
+                    "required": True,
+                }
+            },
+        }
+    )
+
+    bundle = create_bundle()
+    patient = create_patient("male", "1980-01-01")
+    patient.id = "123"
+
+    # Extreme values
+    add_resource(bundle, patient)
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=50.0,
+            unit="bpm",
+            system="http://loinc.org",
+        ),
+    )
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=100.0,
+            unit="bpm",
+            system="http://loinc.org",
+        ),
+    )
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=75.0,
+            unit="bpm",
+            system="http://loinc.org",
+        ),
+    )
+
+    mapper = FHIRFeatureMapper(schema)
+
+    # Test different aggregation methods
+    df_mean = mapper.extract_features(bundle, aggregation="mean")
+    assert df_mean["heart_rate"].iloc[0] == 75.0
+
+    df_max = mapper.extract_features(bundle, aggregation="max")
+    assert df_max["heart_rate"].iloc[0] == 100.0
+
+    df_min = mapper.extract_features(bundle, aggregation="min")
+    assert df_min["heart_rate"].iloc[0] == 50.0
+
+
+def test_mapper_with_schema_metadata_configuration():
+    """FHIRFeatureMapper uses schema metadata for age calculation."""
+    from healthchain.fhir import (
+        create_bundle,
+        add_resource,
+        create_patient,
+        create_value_quantity_observation,
+    )
+    from healthchain.io.containers.featureschema import FeatureSchema
+
+    schema = FeatureSchema.from_dict(
+        {
+            "name": "test_schema",
+            "version": "1.0",
+            "metadata": {
+                "age_calculation": "event_date",
+                "event_date_source": "Observation",
+                "event_date_strategy": "earliest",
+            },
+            "features": {
+                "heart_rate": {
+                    "fhir_resource": "Observation",
+                    "code": "8867-4",
+                    "code_system": "http://loinc.org",
+                    "dtype": "float64",
+                    "required": True,
+                },
+                "age": {
+                    "fhir_resource": "Patient",
+                    "field": "birthDate",
+                    "transform": "calculate_age",
+                    "dtype": "int64",
+                    "required": True,
+                },
+            },
+        }
+    )
+
+    bundle = create_bundle()
+    patient = create_patient("male", "1980-01-01")
+    patient.id = "123"
+
+    add_resource(bundle, patient)
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=85.0,
+            unit="bpm",
+            effective_datetime="2020-01-01T00:00:00Z",
+        ),
+    )
+
+    mapper = FHIRFeatureMapper(schema)
+    df = mapper.extract_features(bundle)
+
+    # Age should be calculated from birthdate to event date (40 years)
+    assert df["age"].iloc[0] == 40
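Reviewer note: taken together, the mapper tests above describe a small end-to-end flow. Here is a minimal sketch of that flow, reusing only helpers that already appear in this diff; treat it as illustrative rather than canonical API documentation.

```python
# Illustrative sketch only: every helper below appears in the tests in
# this diff, but this snippet is not part of the change itself.
from healthchain.fhir import (
    create_bundle,
    add_resource,
    create_patient,
    create_value_quantity_observation,
)
from healthchain.io.containers.featureschema import FeatureSchema
from healthchain.io.mappers.fhirfeaturemapper import FHIRFeatureMapper

# A one-feature schema, mirroring the fixtures used in the tests
schema = FeatureSchema.from_dict(
    {
        "name": "demo_schema",
        "version": "1.0",
        "features": {
            "heart_rate": {
                "fhir_resource": "Observation",
                "code": "8867-4",  # LOINC heart rate
                "code_system": "http://loinc.org",
                "dtype": "float64",
                "required": True,
            }
        },
    }
)

# Build a bundle with one patient and one matching observation
bundle = create_bundle()
patient = create_patient("female", "1990-05-15")
patient.id = "456"
add_resource(bundle, patient)
add_resource(
    bundle,
    create_value_quantity_observation(
        subject="Patient/456",
        code="8867-4",
        value=92.0,
        unit="bpm",
        system="http://loinc.org",
        display="Heart rate",
    ),
)

# Duplicate observations per patient are collapsed with `aggregation`;
# the tests cover "mean", "median", "max", "min", and "last".
mapper = FHIRFeatureMapper(schema)
df = mapper.extract_features(bundle, aggregation="mean")
print(df[["patient_ref", "heart_rate"]])
```

The column-order assertion in `test_mapper_preserves_column_order_from_schema` implies the frame comes back as `patient_ref` followed by the schema's features in definition order.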
diff --git a/tests/sandbox/test_mimic_loader.py b/tests/sandbox/test_mimic_loader.py
index 0c2614e2..bf4e10dd 100644
--- a/tests/sandbox/test_mimic_loader.py
+++ b/tests/sandbox/test_mimic_loader.py
@@ -316,3 +316,115 @@ def test_mimic_loader_skips_resources_without_resource_type(temp_mimic_data_dir)
     # Should only load the valid resource
     bundle = result["medicationstatement"]
     assert len(bundle.entry) == 1
+
+
+def test_mimic_loader_as_dict_returns_plain_dict(
+    temp_mimic_data_dir, mock_medication_resources
+):
+    """MimicOnFHIRLoader with as_dict=True returns plain dict (not Pydantic Bundle)."""
+    fhir_dir = temp_mimic_data_dir / "fhir"
+    create_ndjson_gz_file(
+        fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources
+    )
+
+    loader = MimicOnFHIRLoader()
+    result = loader.load(
+        data_dir=str(temp_mimic_data_dir),
+        resource_types=["MimicMedication"],
+        as_dict=True,
+    )
+
+    # Should return a plain dict, not Dict[str, Bundle]
+    assert isinstance(result, dict)
+    assert "type" in result
+    assert result["type"] == "collection"
+    assert "entry" in result
+    assert isinstance(result["entry"], list)
+    assert len(result["entry"]) == 2
+
+
+def test_mimic_loader_as_dict_combines_multiple_resource_types(
+    temp_mimic_data_dir, mock_medication_resources, mock_condition_resources
+):
+    """MimicOnFHIRLoader with as_dict=True combines all resources into a single bundle."""
+    fhir_dir = temp_mimic_data_dir / "fhir"
+    create_ndjson_gz_file(
+        fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources
+    )
+    create_ndjson_gz_file(
+        fhir_dir / "MimicCondition.ndjson.gz", mock_condition_resources
+    )
+
+    loader = MimicOnFHIRLoader()
+    result = loader.load(
+        data_dir=str(temp_mimic_data_dir),
+        resource_types=["MimicMedication", "MimicCondition"],
+        as_dict=True,
+    )
+
+    # Should be a single bundle dict with all resources combined
+    assert isinstance(result, dict)
+    assert result["type"] == "collection"
+    assert len(result["entry"]) == 3  # 2 medications + 1 condition
+
+    # Verify resource types are mixed
+    resource_types = {entry["resource"]["resourceType"] for entry in result["entry"]}
+    assert resource_types == {"MedicationStatement", "Condition"}
+
+
+def test_mimic_loader_default_returns_validated_bundles(
+    temp_mimic_data_dir, mock_medication_resources, mock_condition_resources
+):
+    """MimicOnFHIRLoader with as_dict=False (default) returns validated Bundle objects."""
+    fhir_dir = temp_mimic_data_dir / "fhir"
+    create_ndjson_gz_file(
+        fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources
+    )
+    create_ndjson_gz_file(
+        fhir_dir / "MimicCondition.ndjson.gz", mock_condition_resources
+    )
+
+    loader = MimicOnFHIRLoader()
+    result = loader.load(
+        data_dir=str(temp_mimic_data_dir),
+        resource_types=["MimicMedication", "MimicCondition"],
+        as_dict=False,  # Explicit default
+    )
+
+    # Should return Dict[str, Bundle] with validated Pydantic objects
+    assert isinstance(result, dict)
+    assert "medicationstatement" in result
+    assert "condition" in result
+
+    # Each value should be a Pydantic Bundle
+    assert type(result["medicationstatement"]).__name__ == "Bundle"
+    assert type(result["condition"]).__name__ == "Bundle"
+    assert len(result["medicationstatement"].entry) == 2
+    assert len(result["condition"].entry) == 1
+
+
+def test_mimic_loader_as_dict_structure_matches_fhir_bundle(
+    temp_mimic_data_dir, mock_medication_resources
+):
+    """MimicOnFHIRLoader with as_dict=True produces valid FHIR Bundle structure."""
+    fhir_dir = temp_mimic_data_dir / "fhir"
+    create_ndjson_gz_file(
+        fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources
+    )
+
+    loader = MimicOnFHIRLoader()
+    result = loader.load(
+        data_dir=str(temp_mimic_data_dir),
+        resource_types=["MimicMedication"],
+        as_dict=True,
+    )
+
+    # Verify FHIR Bundle structure
+    assert result["type"] == "collection"
+    assert isinstance(result["entry"], list)
+
+    # Each entry should have a resource field
+    for entry in result["entry"]:
+        assert "resource" in entry
+        assert "resourceType" in entry["resource"]
+        assert entry["resource"]["resourceType"] == "MedicationStatement"
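Reviewer note: a short sketch of how the new `as_dict=True` path could be consumed downstream. The loader's import path and the data directory are assumptions (neither is shown in this diff); the assertions mirror the ones in the tests above.

```python
# Sketch only: the import path for MimicOnFHIRLoader is assumed, and
# "path/to/mimic" is a placeholder for a local MIMIC-on-FHIR export.
from healthchain.sandbox import MimicOnFHIRLoader

loader = MimicOnFHIRLoader()
bundle_dict = loader.load(
    data_dir="path/to/mimic",
    resource_types=["MimicMedication", "MimicCondition"],
    as_dict=True,  # one combined plain-dict collection bundle
)

# Plain dicts skip per-resource Pydantic validation, which is the cheap
# path when bulk-loading large NDJSON exports for ML workflows.
assert bundle_dict["type"] == "collection"
for entry in bundle_dict["entry"]:
    print(entry["resource"]["resourceType"])
```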