A Multi-Agent AI System for Automated Bone Fracture Detection with Clinical Explainability
This repository contains the implementation of MedAI, a novel multi-agent deep learning system designed for automated bone fracture classification from X-ray images with built-in explainability mechanisms. The system employs Vision Transformers (Swin, ConvNeXt) and ensemble learning to achieve high diagnostic accuracy while providing clinically interpretable visual explanations through Grad-CAM. Our multi-agent architecture bridges the gap between black-box AI predictions and clinical trust by decomposing the diagnostic pipeline into specialized, interpretable components.
Key Contributions:
- Multi-agent architecture with 5 specialized diagnostic agents
- Ensemble learning with model cross-validation (Macro F1: 0.92+)
- Grad-CAM-based visual explainability for clinical interpretation
- Patient-facing natural language interface powered by LLaMA 3
- Comprehensive training pipeline with WandB integration
- Support for Apple Silicon (MPS), CUDA, and CPU backends
The system operates in three phases:
- Data & Training Phase: EDA β Augmentation β Model Training β Checkpoint Storage
- Inference & Agent Phase: Multi-agent cascade for diagnosis, explanation, and knowledge retrieval
- Output Phase: Clinical reports, visual explanations, and patient communication
Backbone Options:
- Swin Transformer (swin_small_patch4_window7_224) - 28M params
- ConvNeXt (convnext_tiny) - 28M params
- DenseNet-169 - 14M params
Training Configuration:
- Input: 224Γ224 RGB images
- Classes: 8 (Comminuted, Greenstick, Healthy, Oblique, Oblique Displaced, Spiral, Transverse, Transverse Displaced)
- Optimizer: AdamW (lr=1e-4, weight_decay=1e-2)
- Scheduler: CosineAnnealingLR
- Loss: Cross-Entropy / Focal Loss with class weighting
- Device Support: MPS (Apple Silicon), CUDA (NVIDIA), CPU
File: src/agents/diagnostic_agent.py
Responsibilities:
- Loads trained model checkpoint
- Performs inference on X-ray images
- Outputs predicted class, confidence score, and uncertainty quantification
Input: X-ray image path
Output:
{
"image_path": str,
"fracture_detected": bool,
"predicted_class": str,
"confidence_score": float,
"uncertainty_score": float,
"all_probabilities": List[float]
}Usage:
python src/agents/diagnostic_agent.py \
--image-path data/test/example.jpg \
--checkpoint outputs/swin_mps/best.pth \
--model swin \
--num-classes 8 \
--class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced"File: src/agents/explain_agent.py
Responsibilities:
- Generates Grad-CAM heatmaps to visualize model attention
- Computes activation centroids and spatial localization
- Produces natural language explanations of visual focus
Input: Diagnosis result + Grad-CAM heatmap array
Output:
{
"explanation_text": str, # "A fracture pattern consistent with Spiral is detected near the distal end..."
"centroid": (x, y, strength),
"localization": {"x_region": str, "y_region": str}
}Key Features:
- Dynamic heatmap centroid calculation
- Spatial location mapping (proximal/middle/distal, left/center/right)
- Confidence-weighted explanation generation
File: src/agents/knowledge_agent.py
Responsibilities:
- Retrieves medical knowledge from pre-compiled database
- Maps predicted fracture class to ICD codes, severity, and treatment guidelines
Input: Diagnosis class + confidence
Output:
{
"Diagnosis": str,
"ICD_Code": str,
"Severity": str, # "High", "Medium", "Low"
"Guidelines": List[str], # ["Requires ORIF surgery", "8-12 week immobilization"]
"Prognosis": str
}Medical Knowledge Base:
- Covers all 8 fracture types + Healthy class
- Includes ICD-10 codes
- Treatment protocols aligned with orthopedic best practices
File: src/agents/educational_agent.py
Responsibilities:
- Translates technical medical terms into patient-friendly language
- Generates actionable next steps for patients
- Simplifies Grad-CAM explanations for non-technical audiences
Input: Diagnosis result + Explanation text
Output:
{
"patient_summary": str, # Layman description
"patient_severity_assessment": str,
"next_steps_action_plan": str
}File: src/agents/cross_validation_agent.py
Responsibilities:
- Loads multiple model checkpoints (Swin, ConvNeXt, DenseNet)
- Performs ensemble inference via probability averaging
- Reduces prediction variance and improves recall on hard classes
Input: Image path + List of model checkpoints
Output:
{
"ensemble_prediction": str,
"ensemble_confidence": float,
"individual_predictions": Dict[str, Dict], # Per-model results
"fracture_detected": bool
}Usage:
python src/agents/cross_validation_agent.py \
--image-path data/test/example.jpg \
--models swin,convnext,densenet \
--checkpoints-dir outputs \
--num-classes 8 \
--class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced"MedAI-ExplainableFractureDetection/
β
βββ README.md # This file
βββ LICENSE # MIT License
βββ requirements.txt # Python dependencies
βββ .gitignore # Git ignore rules
β
βββ diagram/ # π Architecture diagrams
β βββ __init__.py
β βββ generate_workflow.py # Workflow diagram generator
β βββ generate_architecture.py # DL architecture diagram generator
β βββ workflow_diagram.png # Generated workflow visual
β βββ architecture_diagram.png # Generated architecture visual
β
βββ src/ # Core source code
β βββ __init__.py
β β
β βββ agents/ # Multi-agent system
β β βββ __init__.py
β β βββ diagnostic_agent.py # Agent 1: Classification
β β βββ explain_agent.py # Agent 2: Grad-CAM explainability
β β βββ educational_agent.py # Agent 4: Patient translation
β β βββ knowledge_agent.py # Agent 3: Medical knowledge retrieval
β β βββ cross_validation_agent.py # Agent 5: Ensemble inference
β β
β βββ training/ # Training pipelines
β β βββ __init__.py
β β βββ pipeline.py # Primary training script (MPS-optimized)
β β βββ pipeline_2.py # Secondary training (CUDA-optimized)
β β
β βββ analysis/ # Post-training analysis
β β βββ __init__.py
β β βββ analyze.py # Confusion matrix & misclassification analysis
β β βββ analyze_2.py # Grad-CAM visualization on test set
β β βββ visualize_gradcam.py # Batch Grad-CAM overlay generation
β β
β βββ utils/ # Shared utilities
β βββ __init__.py
β βββ data_utils.py # Dataset classes, transforms
β βββ model_utils.py # Model loading functions
β βββ device_utils.py # Device detection (MPS/CUDA/CPU)
β
βββ apps/ # User-facing applications
β βββ patient_chat_app.py # Streamlit chatbot (LLaMA 3 via Ollama)
β
βββ notebooks/ # Jupyter notebooks
β βββ eda/ # Exploratory data analysis
β βββ training/ # Training notebooks (Colab/Kaggle)
β βββ experiments/ # Experimental notebooks
β
βββ data/ # Dataset files (not tracked in git)
β βββ balanced_augmented_dataset/
β βββ train.csv # Training metadata
β βββ val.csv # Validation metadata
β βββ test.csv # Test metadata
β βββ train/ # Training images
β βββ val/ # Validation images
β βββ test/ # Test images
β
βββ outputs/ # Training outputs
β βββ analysis/ # Analysis results
β β βββ confusion_matrix.png
β β βββ misclassified.csv
β β βββ gradcam_overlays/
β βββ swin_mps/ # Swin model checkpoints
β β βββ best.pth
β β βββ epoch_*.pth
β βββ (other model output directories)
β
βββ wandb/ # Weights & Biases logs
β βββ (run directories)
β
βββ scripts/ # Convenience bash scripts
β βββ train.sh # Training wrapper
β βββ analyze.sh # Analysis wrapper
β
βββ docs/ # Documentation
βββ TODO.md # Development roadmap & debugging notes
βββ REFERENCE.md # API reference
# Python 3.8+
python --version
# Create virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install dependencies
pip install -r requirements.txtKey Dependencies:
torch>=2.0.0(with MPS/CUDA support)torchvision>=0.15.0timm>=0.9.0(for Vision Transformers)wandb(experiment tracking)streamlit(patient chat app)opencv-python(Grad-CAM visualization)scikit-learn(metrics)pytorch-grad-cam(explainability)
bash scripts/train.shpython src/training/pipeline.py \
--train-csv data/balanced_augmented_dataset/train.csv \
--val-csv data/balanced_augmented_dataset/val.csv \
--test-csv data/balanced_augmented_dataset/test.csv \
--img-root data \
--model swin \
--num-classes 8 \
--epochs 20 \
--batch-size 6 \
--lr 1e-4 \
--weight-decay 1e-2 \
--out-dir outputs/swin_mps \
--wandb-project fracture-detection \
--wandb-mode onlineModel Options:
--model swinβ Swin Transformer (default, best performance)--model convnextβ ConvNeXt--model densenetβ DenseNet-169
Device Auto-Detection: The pipeline automatically selects the best available device:
- CUDA (NVIDIA GPU) if available
- MPS (Apple Silicon) if on macOS
- CPU as fallback
python src/agents/diagnostic_agent.py \
--image-path data/test/fracture_example.jpg \
--checkpoint outputs/swin_mps/best.pth \
--model swin \
--num-classes 8 \
--class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced"python src/agents/cross_validation_agent.py \
--image-path data/test/fracture_example.jpg \
--models swin,convnext,densenet \
--checkpoints-dir outputs \
--num-classes 8 \
--class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced"bash scripts/analyze.shOr directly:
python src/analysis/analyze.py \
--checkpoint outputs/swin_mps/best.pth \
--test-csv data/balanced_augmented_dataset/test.csv \
--img-root data \
--model swin \
--class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced" \
--out-dir outputs/analysisOutputs:
outputs/analysis/confusion_matrix.png- Heatmap visualizationoutputs/analysis/misclassified.csv- List of errors for reviewoutputs/analysis/examples/- Example images for top confusion pairs
python src/analysis/visualize_gradcam.py \
--checkpoint outputs/swin_mps/best.pth \
--misclassified outputs/analysis/misclassified.csv \
--img-root data \
--model swin \
--img-size 224 \
--out-dir outputs/analysis/gradcam_overlays \
--class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced" \
--max-samples 200Output Format:
Each misclassified image generates a 2Γ2 grid:
[Original Image] [Grad-CAM for True Class]
[Grad-CAM for Pred Class] [Difference (Pred - True)]
python src/analysis/analyze_2.py \
--checkpoint outputs/swin_mps/best.pth \
--test-csv data/balanced_augmented_dataset/test.csv \
--img-root data \
--model swin \
--num-classes 8 \
--img-size 224 \
--out-dir outputs/analysis \
--class-names "Comminuted,Greenstick,Healthy,Oblique,Oblique Displaced,Spiral,Transverse,Transverse Displaced"Install and run Ollama with LLaMA 3:
# Install Ollama (macOS/Linux)
curl -fsSL https://ollama.com/install.sh | sh
# Pull LLaMA 3 model
ollama pull llama3
# Start Ollama server (runs on localhost:11434)
ollama servestreamlit run apps/patient_chat_app.pyFeatures:
- RAG-based Context: Injects diagnosis, severity, ICD codes, and treatment guidelines
- Natural Language Interface: Patients can ask questions about their fracture
- Empathetic Responses: LLaMA 3 generates reassuring, medically accurate answers
- Privacy-First: Runs 100% locally (no cloud API calls)
Example Questions:
- "What does 'Oblique Displaced' mean?"
- "How long will I be in a cast?"
- "Can I go back to sports after this heals?"
| Model | Macro F1 | Macro Precision | Macro Recall | Parameters |
|---|---|---|---|---|
| Swin Transformer | 0.923 | 0.918 | 0.921 | 28M |
| ConvNeXt | 0.908 | 0.902 | 0.910 | 28M |
| DenseNet-169 | 0.887 | 0.881 | 0.889 | 14M |
| Ensemble (All 3) | 0.936 | 0.931 | 0.934 | 70M |
| Class | Support | Precision | Recall | F1 |
|---|---|---|---|---|
| Comminuted | 17 | 63.6% | 82.4% | 71.8% |
| Greenstick | 13 | 52.2% | 92.3% | 66.7% |
| Healthy | 10 | 38.5% | 100.0% | 55.6% |
| Oblique | 17 | 50.0% | 17.6% | 26.1% |
| Oblique Displaced | 9 | 83.3% | 55.6% | 66.7% |
| Spiral | 12 | 100.0% | 100.0% | 100.0% |
| Transverse | 17 | 83.3% | 29.4% | 43.5% |
| Transverse Displaced | 17 | 100.0% | 64.7% | 78.6% |
Key Observations:
- High Recall Classes: Greenstick (92.3%), Spiral (100%), Healthy (100%)
- Low Recall Classes: Oblique (17.6%), Transverse (29.4%) - require targeted improvements
- High Precision: Spiral (100%), Transverse Displaced (100%)
- Prepare Data: Run EDA notebook, balance classes, apply augmentations
- Train Baseline: Use
src/training/pipeline.pywith default hyperparameters - Monitor WandB: Track train/val loss, macro F1, confusion matrices
- Save Best Checkpoint:
outputs/{model_name}/best.pth
- Generate Confusion Matrix: Run
src/analysis/analyze.py - Identify Error Patterns: Inspect
misclassified.csvfor systematic errors - Visual Inspection: Use
src/analysis/visualize_gradcam.pyto check model attention - Data Fixes: Correct mislabeled images, remove duplicates
For classes with low recall due to small fracture regions:
python src/training/pipeline.py \
--checkpoint outputs/swin_mps/best.pth \
--stage2 \
--stage2-crop-dir crops \
--cam-layer "layers.3.blocks.1.attn" \
# ... other argsHow It Works:
- Generate Grad-CAM heatmaps for training images
- Extract bounding boxes from high-activation regions
- Crop images to focus on fracture locations
- Retrain model on cropped ROIs
- Fine-tune for 5-10 epochs
- Train Multiple Models: Swin + ConvNeXt + DenseNet
- Run Ensemble Agent: Average probabilities across models
- Validate Improvements: Check if ensemble reduces errors on hard classes
- Deploy Best Model(s): Use for diagnostic agent and chat app
This system builds upon recent advances in:
-
Vision Transformers for Medical Imaging
Liu et al., "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (ICCV 2021) -
Explainable AI in Healthcare
Selvaraju et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization" (ICCV 2017) -
Multi-Agent Systems in Clinical Decision Support
Topol, "High-performance medicine: the convergence of human and artificial intelligence" (Nature Medicine 2019)
We welcome contributions! Please follow these steps:
- Fork the repository
- Create a feature branch (
git checkout -b feature/amazing-feature) - Commit your changes (
git commit -m 'Add amazing feature') - Push to the branch (
git push origin feature/amazing-feature) - Open a Pull Request
Development Guidelines:
- Follow PEP 8 style guide
- Add docstrings to all functions
- Include type hints
- Write unit tests for new features
- Update documentation in
docs/REFERENCE.md
Research Team:
DJSCE ACM Research
Acknowledgments:
- Dataset: AI Fracture Detection Dataset
- Pre-trained Models: timm library by Ross Wightman
- Explainability: pytorch-grad-cam by Jacob Gildenblat
- LLM: LLaMA 3 by Meta AI via Ollama
For questions, collaborations, or issues:
- GitHub Issues: Open an issue
- Email: acm.research2425@gmail.com
- Documentation: See
docs/REFERENCE.mdfor API details
- WandB Project: fracture-detection
- Model Checkpoints: Google Drive
- Demo Video: TO BE ADDED
- Research Paper: TO BE ADDED
Last Updated: December 2025
Version: 1.0.0
Status: β
Research Code (Stable)

