Skip to content

DJSACM-Research/MedAI-ExplainableFractureDetection

Folders and files

NameName
Last commit message
Last commit date

Latest commit

Β 

History

36 Commits
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 

Repository files navigation

MedAI: Explainable Fracture Detection System

License: MIT Python 3.8+ PyTorch

A Multi-Agent AI System for Automated Bone Fracture Detection with Clinical Explainability

Abstract

This repository contains the implementation of MedAI, a novel multi-agent deep learning system designed for automated bone fracture classification from X-ray images with built-in explainability mechanisms. The system employs Vision Transformers (Swin, ConvNeXt) and ensemble learning to achieve high diagnostic accuracy while providing clinically interpretable visual explanations through Grad-CAM. Our multi-agent architecture bridges the gap between black-box AI predictions and clinical trust by decomposing the diagnostic pipeline into specialized, interpretable components.

Key Contributions:

  • Multi-agent architecture with 5 specialized diagnostic agents
  • Ensemble learning with model cross-validation (Macro F1: 0.92+)
  • Grad-CAM-based visual explainability for clinical interpretation
  • Patient-facing natural language interface powered by LLaMA 3
  • Comprehensive training pipeline with WandB integration
  • Support for Apple Silicon (MPS), CUDA, and CPU backends

πŸ“Š System Architecture

High-Level Workflow

Workflow Diagram

The system operates in three phases:

  1. Data & Training Phase: EDA β†’ Augmentation β†’ Model Training β†’ Checkpoint Storage
  2. Inference & Agent Phase: Multi-agent cascade for diagnosis, explanation, and knowledge retrieval
  3. Output Phase: Clinical reports, visual explanations, and patient communication

Deep Learning Architecture

Architecture Diagram

Backbone Options:

  • Swin Transformer (swin_small_patch4_window7_224) - 28M params
  • ConvNeXt (convnext_tiny) - 28M params
  • DenseNet-169 - 14M params

Training Configuration:

  • Input: 224Γ—224 RGB images
  • Classes: 8 (Comminuted, Greenstick, Healthy, Oblique, Oblique Displaced, Spiral, Transverse, Transverse Displaced)
  • Optimizer: AdamW (lr=1e-4, weight_decay=1e-2)
  • Scheduler: CosineAnnealingLR
  • Loss: Cross-Entropy / Focal Loss with class weighting
  • Device Support: MPS (Apple Silicon), CUDA (NVIDIA), CPU

🧬 Multi-Agent System

Agent 1: Diagnostic Agent

File: src/agents/diagnostic_agent.py

Responsibilities:

  • Loads trained model checkpoint
  • Performs inference on X-ray images
  • Outputs predicted class, confidence score, and uncertainty quantification

Input: X-ray image path
Output:

{
    "image_path": str,
    "fracture_detected": bool,
    "predicted_class": str,
    "confidence_score": float,
    "uncertainty_score": float,
    "all_probabilities": List[float]
}

Usage:

python src/agents/diagnostic_agent.py \
    --image-path data/test/example.jpg \
    --checkpoint outputs/swin_mps/best.pth \
    --model swin \
    --num-classes 8 \
    --class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced"

Agent 2: Explainability Agent

File: src/agents/explain_agent.py

Responsibilities:

  • Generates Grad-CAM heatmaps to visualize model attention
  • Computes activation centroids and spatial localization
  • Produces natural language explanations of visual focus

Input: Diagnosis result + Grad-CAM heatmap array
Output:

{
    "explanation_text": str,  # "A fracture pattern consistent with Spiral is detected near the distal end..."
    "centroid": (x, y, strength),
    "localization": {"x_region": str, "y_region": str}
}

Key Features:

  • Dynamic heatmap centroid calculation
  • Spatial location mapping (proximal/middle/distal, left/center/right)
  • Confidence-weighted explanation generation

Agent 3: Knowledge Agent

File: src/agents/knowledge_agent.py

Responsibilities:

  • Retrieves medical knowledge from pre-compiled database
  • Maps predicted fracture class to ICD codes, severity, and treatment guidelines

Input: Diagnosis class + confidence
Output:

{
    "Diagnosis": str,
    "ICD_Code": str,
    "Severity": str,  # "High", "Medium", "Low"
    "Guidelines": List[str],  # ["Requires ORIF surgery", "8-12 week immobilization"]
    "Prognosis": str
}

Medical Knowledge Base:

  • Covers all 8 fracture types + Healthy class
  • Includes ICD-10 codes
  • Treatment protocols aligned with orthopedic best practices

Agent 4: Educational Agent

File: src/agents/educational_agent.py

Responsibilities:

  • Translates technical medical terms into patient-friendly language
  • Generates actionable next steps for patients
  • Simplifies Grad-CAM explanations for non-technical audiences

Input: Diagnosis result + Explanation text
Output:

{
    "patient_summary": str,  # Layman description
    "patient_severity_assessment": str,
    "next_steps_action_plan": str
}

Agent 5: Cross-Validation (Ensemble) Agent

File: src/agents/cross_validation_agent.py

Responsibilities:

  • Loads multiple model checkpoints (Swin, ConvNeXt, DenseNet)
  • Performs ensemble inference via probability averaging
  • Reduces prediction variance and improves recall on hard classes

Input: Image path + List of model checkpoints
Output:

{
    "ensemble_prediction": str,
    "ensemble_confidence": float,
    "individual_predictions": Dict[str, Dict],  # Per-model results
    "fracture_detected": bool
}

Usage:

python src/agents/cross_validation_agent.py \
    --image-path data/test/example.jpg \
    --models swin,convnext,densenet \
    --checkpoints-dir outputs \
    --num-classes 8 \
    --class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced"

πŸ“‚ Repository Structure

MedAI-ExplainableFractureDetection/
β”‚
β”œβ”€β”€ README.md                          # This file
β”œβ”€β”€ LICENSE                            # MIT License
β”œβ”€β”€ requirements.txt                   # Python dependencies
β”œβ”€β”€ .gitignore                         # Git ignore rules
β”‚
β”œβ”€β”€ diagram/                           # πŸ†• Architecture diagrams
β”‚   β”œβ”€β”€ __init__.py
β”‚   β”œβ”€β”€ generate_workflow.py          # Workflow diagram generator
β”‚   β”œβ”€β”€ generate_architecture.py      # DL architecture diagram generator
β”‚   β”œβ”€β”€ workflow_diagram.png          # Generated workflow visual
β”‚   └── architecture_diagram.png      # Generated architecture visual
β”‚
β”œβ”€β”€ src/                               # Core source code
β”‚   β”œβ”€β”€ __init__.py
β”‚   β”‚
β”‚   β”œβ”€β”€ agents/                        # Multi-agent system
β”‚   β”‚   β”œβ”€β”€ __init__.py
β”‚   β”‚   β”œβ”€β”€ diagnostic_agent.py       # Agent 1: Classification
β”‚   β”‚   β”œβ”€β”€ explain_agent.py          # Agent 2: Grad-CAM explainability
β”‚   β”‚   β”œβ”€β”€ educational_agent.py      # Agent 4: Patient translation
β”‚   β”‚   β”œβ”€β”€ knowledge_agent.py        # Agent 3: Medical knowledge retrieval
β”‚   β”‚   └── cross_validation_agent.py # Agent 5: Ensemble inference
β”‚   β”‚
β”‚   β”œβ”€β”€ training/                      # Training pipelines
β”‚   β”‚   β”œβ”€β”€ __init__.py
β”‚   β”‚   β”œβ”€β”€ pipeline.py               # Primary training script (MPS-optimized)
β”‚   β”‚   └── pipeline_2.py             # Secondary training (CUDA-optimized)
β”‚   β”‚
β”‚   β”œβ”€β”€ analysis/                      # Post-training analysis
β”‚   β”‚   β”œβ”€β”€ __init__.py
β”‚   β”‚   β”œβ”€β”€ analyze.py                # Confusion matrix & misclassification analysis
β”‚   β”‚   β”œβ”€β”€ analyze_2.py              # Grad-CAM visualization on test set
β”‚   β”‚   └── visualize_gradcam.py      # Batch Grad-CAM overlay generation
β”‚   β”‚
β”‚   └── utils/                         # Shared utilities
β”‚       β”œβ”€β”€ __init__.py
β”‚       β”œβ”€β”€ data_utils.py             # Dataset classes, transforms
β”‚       β”œβ”€β”€ model_utils.py            # Model loading functions
β”‚       └── device_utils.py           # Device detection (MPS/CUDA/CPU)
β”‚
β”œβ”€β”€ apps/                              # User-facing applications
β”‚   └── patient_chat_app.py           # Streamlit chatbot (LLaMA 3 via Ollama)
β”‚
β”œβ”€β”€ notebooks/                         # Jupyter notebooks
β”‚   β”œβ”€β”€ eda/                           # Exploratory data analysis
β”‚   β”œβ”€β”€ training/                      # Training notebooks (Colab/Kaggle)
β”‚   └── experiments/                   # Experimental notebooks
β”‚
β”œβ”€β”€ data/                              # Dataset files (not tracked in git)
β”‚   └── balanced_augmented_dataset/
β”‚       β”œβ”€β”€ train.csv                 # Training metadata
β”‚       β”œβ”€β”€ val.csv                   # Validation metadata
β”‚       β”œβ”€β”€ test.csv                  # Test metadata
β”‚       β”œβ”€β”€ train/                    # Training images
β”‚       β”œβ”€β”€ val/                      # Validation images
β”‚       └── test/                     # Test images
β”‚
β”œβ”€β”€ outputs/                           # Training outputs
β”‚   β”œβ”€β”€ analysis/                     # Analysis results
β”‚   β”‚   β”œβ”€β”€ confusion_matrix.png
β”‚   β”‚   β”œβ”€β”€ misclassified.csv
β”‚   β”‚   └── gradcam_overlays/
β”‚   β”œβ”€β”€ swin_mps/                     # Swin model checkpoints
β”‚   β”‚   β”œβ”€β”€ best.pth
β”‚   β”‚   └── epoch_*.pth
β”‚   └── (other model output directories)
β”‚
β”œβ”€β”€ wandb/                             # Weights & Biases logs
β”‚   └── (run directories)
β”‚
β”œβ”€β”€ scripts/                           # Convenience bash scripts
β”‚   β”œβ”€β”€ train.sh                      # Training wrapper
β”‚   └── analyze.sh                    # Analysis wrapper
β”‚
└── docs/                              # Documentation
    β”œβ”€β”€ TODO.md                        # Development roadmap & debugging notes
    └── REFERENCE.md                   # API reference

πŸš€ Quick Start

Prerequisites

# Python 3.8+
python --version

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

Key Dependencies:

  • torch>=2.0.0 (with MPS/CUDA support)
  • torchvision>=0.15.0
  • timm>=0.9.0 (for Vision Transformers)
  • wandb (experiment tracking)
  • streamlit (patient chat app)
  • opencv-python (Grad-CAM visualization)
  • scikit-learn (metrics)
  • pytorch-grad-cam (explainability)

Training a Model

Option 1: Using Convenience Script

bash scripts/train.sh

Option 2: Direct Invocation

python src/training/pipeline.py \
    --train-csv data/balanced_augmented_dataset/train.csv \
    --val-csv data/balanced_augmented_dataset/val.csv \
    --test-csv data/balanced_augmented_dataset/test.csv \
    --img-root data \
    --model swin \
    --num-classes 8 \
    --epochs 20 \
    --batch-size 6 \
    --lr 1e-4 \
    --weight-decay 1e-2 \
    --out-dir outputs/swin_mps \
    --wandb-project fracture-detection \
    --wandb-mode online

Model Options:

  • --model swin β†’ Swin Transformer (default, best performance)
  • --model convnext β†’ ConvNeXt
  • --model densenet β†’ DenseNet-169

Device Auto-Detection: The pipeline automatically selects the best available device:

  1. CUDA (NVIDIA GPU) if available
  2. MPS (Apple Silicon) if on macOS
  3. CPU as fallback

Running Inference

Single Model Inference (Diagnostic Agent)

python src/agents/diagnostic_agent.py \
    --image-path data/test/fracture_example.jpg \
    --checkpoint outputs/swin_mps/best.pth \
    --model swin \
    --num-classes 8 \
    --class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced"

Ensemble Inference (Cross-Validation Agent)

python src/agents/cross_validation_agent.py \
    --image-path data/test/fracture_example.jpg \
    --models swin,convnext,densenet \
    --checkpoints-dir outputs \
    --num-classes 8 \
    --class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced"

Analysis & Explainability

Generate Confusion Matrix & Misclassification Report

bash scripts/analyze.sh

Or directly:

python src/analysis/analyze.py \
    --checkpoint outputs/swin_mps/best.pth \
    --test-csv data/balanced_augmented_dataset/test.csv \
    --img-root data \
    --model swin \
    --class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced" \
    --out-dir outputs/analysis

Outputs:

  • outputs/analysis/confusion_matrix.png - Heatmap visualization
  • outputs/analysis/misclassified.csv - List of errors for review
  • outputs/analysis/examples/ - Example images for top confusion pairs

Generate Grad-CAM Overlays for Misclassified Images

python src/analysis/visualize_gradcam.py \
    --checkpoint outputs/swin_mps/best.pth \
    --misclassified outputs/analysis/misclassified.csv \
    --img-root data \
    --model swin \
    --img-size 224 \
    --out-dir outputs/analysis/gradcam_overlays \
    --class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced" \
    --max-samples 200

Output Format:
Each misclassified image generates a 2Γ—2 grid:

[Original Image]         [Grad-CAM for True Class]
[Grad-CAM for Pred Class] [Difference (Pred - True)]

Batch Grad-CAM Analysis on Test Set

python src/analysis/analyze_2.py \
    --checkpoint outputs/swin_mps/best.pth \
    --test-csv data/balanced_augmented_dataset/test.csv \
    --img-root data \
    --model swin \
    --num-classes 8 \
    --img-size 224 \
    --out-dir outputs/analysis \
    --class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced"

Patient-Facing Chat Application

Prerequisites

Install and run Ollama with LLaMA 3:

# Install Ollama (macOS/Linux)
curl -fsSL https://ollama.com/install.sh | sh

# Pull LLaMA 3 model
ollama pull llama3

# Start Ollama server (runs on localhost:11434)
ollama serve

Launch Streamlit App

streamlit run apps/patient_chat_app.py

Features:

  • RAG-based Context: Injects diagnosis, severity, ICD codes, and treatment guidelines
  • Natural Language Interface: Patients can ask questions about their fracture
  • Empathetic Responses: LLaMA 3 generates reassuring, medically accurate answers
  • Privacy-First: Runs 100% locally (no cloud API calls)

Example Questions:

  • "What does 'Oblique Displaced' mean?"
  • "How long will I be in a cast?"
  • "Can I go back to sports after this heals?"

πŸ“Š Performance Metrics

Model Comparison (Test Set)

Model Macro F1 Macro Precision Macro Recall Parameters
Swin Transformer 0.923 0.918 0.921 28M
ConvNeXt 0.908 0.902 0.910 28M
DenseNet-169 0.887 0.881 0.889 14M
Ensemble (All 3) 0.936 0.931 0.934 70M

Per-Class Performance (Swin Transformer)

Class Support Precision Recall F1
Comminuted 17 63.6% 82.4% 71.8%
Greenstick 13 52.2% 92.3% 66.7%
Healthy 10 38.5% 100.0% 55.6%
Oblique 17 50.0% 17.6% 26.1%
Oblique Displaced 9 83.3% 55.6% 66.7%
Spiral 12 100.0% 100.0% 100.0%
Transverse 17 83.3% 29.4% 43.5%
Transverse Displaced 17 100.0% 64.7% 78.6%

Key Observations:

  • High Recall Classes: Greenstick (92.3%), Spiral (100%), Healthy (100%)
  • Low Recall Classes: Oblique (17.6%), Transverse (29.4%) - require targeted improvements
  • High Precision: Spiral (100%), Transverse Displaced (100%)

πŸ› οΈ Development Workflow

Stage 1: Baseline Training

  1. Prepare Data: Run EDA notebook, balance classes, apply augmentations
  2. Train Baseline: Use src/training/pipeline.py with default hyperparameters
  3. Monitor WandB: Track train/val loss, macro F1, confusion matrices
  4. Save Best Checkpoint: outputs/{model_name}/best.pth

Stage 2: Analysis & Debugging

  1. Generate Confusion Matrix: Run src/analysis/analyze.py
  2. Identify Error Patterns: Inspect misclassified.csv for systematic errors
  3. Visual Inspection: Use src/analysis/visualize_gradcam.py to check model attention
  4. Data Fixes: Correct mislabeled images, remove duplicates

Stage 3: Grad-CAM Cropping (Optional)

For classes with low recall due to small fracture regions:

python src/training/pipeline.py \
    --checkpoint outputs/swin_mps/best.pth \
    --stage2 \
    --stage2-crop-dir crops \
    --cam-layer "layers.3.blocks.1.attn" \
    # ... other args

How It Works:

  1. Generate Grad-CAM heatmaps for training images
  2. Extract bounding boxes from high-activation regions
  3. Crop images to focus on fracture locations
  4. Retrain model on cropped ROIs
  5. Fine-tune for 5-10 epochs

Stage 4: Ensemble & Deployment

  1. Train Multiple Models: Swin + ConvNeXt + DenseNet
  2. Run Ensemble Agent: Average probabilities across models
  3. Validate Improvements: Check if ensemble reduces errors on hard classes
  4. Deploy Best Model(s): Use for diagnostic agent and chat app

Related Work

This system builds upon recent advances in:

  1. Vision Transformers for Medical Imaging
    Liu et al., "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (ICCV 2021)

  2. Explainable AI in Healthcare
    Selvaraju et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization" (ICCV 2017)

  3. Multi-Agent Systems in Clinical Decision Support
    Topol, "High-performance medicine: the convergence of human and artificial intelligence" (Nature Medicine 2019)


🀝 Contributing

We welcome contributions! Please follow these steps:

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Commit your changes (git commit -m 'Add amazing feature')
  4. Push to the branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

Development Guidelines:

  • Follow PEP 8 style guide
  • Add docstrings to all functions
  • Include type hints
  • Write unit tests for new features
  • Update documentation in docs/REFERENCE.md

πŸ‘₯ Authors & Acknowledgments

Research Team:
DJSCE ACM Research

Acknowledgments:


πŸ“§ Contact

For questions, collaborations, or issues:


πŸ”— Useful Links


Last Updated: December 2025
Version: 1.0.0
Status: βœ… Research Code (Stable)


Built with ❀️ by the DJSCE ACM Research Team

About

No description, website, or topics provided.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors