Explainable deep learning for knee MRI classification on the MRNet dataset (CAM + SHAP) with an interactive viewer.
Train → generate explanations → inspect axial/coronal/sagittal slices side-by-side.
MRNet Dataset » · Report Bug · Request Feature
This repository contains scripts to train MRNet-style models and generate explainability artifacts (CAM/SHAP) for knee MRI exams. MRNet is a knee MRI dataset/competition with 1,370 exams from Stanford University Medical Center.
- Training scripts for MRNet-style classification workflows (`train_mrnet.py`, `train_classifier.py`).
- CAM visualization generation (`create_cams.py`).
- SHAP visualization generation (`create_shaps.py`).
- Interactive OpenCV-based viewer showing CAM / raw slice / SHAP for the axial, coronal, and sagittal planes (`vis_predictions.py`).
- Notebook for experimentation (`MRNet_v1_0.ipynb`).
- PyTorch + TorchVision.
- OpenCV, NumPy, Pandas, Matplotlib.
- SHAP.
- Python environment (venv/conda recommended).
- MRNet dataset access (see Dataset section).
- Clone:

  ```sh
  git clone https://github.com/KaMeLoTmArMoT/mrnet_xai.git
  cd mrnet_xai
  ```

- Install dependencies:

  ```sh
  pip install -r requirements.txt
  ```

  Note: the requirements include a Git dependency (`torchsample`).
- MRNet is distributed under a research-use agreement; the download link and dataset must not be redistributed.
- Official info + registration: https://stanfordmlgroup.github.io/competitions/mrnet/
Start with the training scripts: `train_mrnet.py`, `train_classifier.py`.

Then generate explanations:

- `create_cams.py` for CAM outputs.
- `create_shaps.py` for SHAP outputs.
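For context on what `create_cams.py` produces: a CAM weights the final convolutional feature maps by the classifier's fully connected weights for the target class and sums over channels. A minimal NumPy sketch of that computation (the array shapes and function name are illustrative, not taken from this repo's models):

```python
import numpy as np

def class_activation_map(features, fc_weights, class_idx):
    """Compute a CAM: weight each feature map by the FC weight
    for the target class, then sum over channels.

    features:   (C, H, W) activations from the last conv layer
    fc_weights: (num_classes, C) weights of the final linear layer
    class_idx:  index of the class to explain
    """
    C, H, W = features.shape
    # Weighted sum over channels -> (H, W) heatmap
    cam = fc_weights[class_idx] @ features.reshape(C, -1)
    cam = cam.reshape(H, W)
    # Normalize to [0, 1] for overlaying on the MRI slice
    cam -= cam.min()
    if cam.max() > 0:
        cam /= cam.max()
    return cam

# Toy example: 8 channels, 7x7 spatial grid, 2 classes
rng = np.random.default_rng(0)
feats = rng.random((8, 7, 7))
weights = rng.random((2, 8))
heatmap = class_activation_map(feats, weights, class_idx=1)
print(heatmap.shape)  # (7, 7)
```

In practice the heatmap is upsampled to the slice resolution and blended over the grayscale MRI image.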
Run the viewer:

```sh
python vis_predictions.py
```
Controls:

- Quit: `q` or `Esc`.
- Change slice index (per plane):
  - Decrease: `1` (axial), `2` (coronal), `3` (sagittal).
  - Increase: `7` (axial), `8` (coronal), `9` (sagittal).
- Change case:
  - Prev/next: `4` / `6`.
  - Jump to case: `Enter`, then type the case ID in the console.
- Switch SHAP source: `v` (if multiple sources exist).
- Zoom: `+` / `-`, reset with `0`.
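The keybinding scheme above maps naturally onto a dispatch over viewer state. The following is a hypothetical simplification of such an event loop, not the actual code in `vis_predictions.py` (names like `ViewerState` and `handle_key` are illustrative):

```python
from dataclasses import dataclass, field

PLANES = ("axial", "coronal", "sagittal")

@dataclass
class ViewerState:
    case_idx: int = 0
    zoom: float = 1.0
    slice_idx: dict = field(default_factory=lambda: {p: 0 for p in PLANES})

def handle_key(state, key):
    """Apply one keypress to the viewer state. Returns False on quit."""
    if key in ("q", "\x1b"):            # q or Esc quits
        return False
    if key in "123":                     # decrease slice for one plane
        state.slice_idx[PLANES[int(key) - 1]] -= 1
    elif key in "789":                   # increase slice for one plane
        state.slice_idx[PLANES[int(key) - 7]] += 1
    elif key == "4":                     # previous case
        state.case_idx -= 1
    elif key == "6":                     # next case
        state.case_idx += 1
    elif key == "+":
        state.zoom *= 1.25
    elif key == "-":
        state.zoom /= 1.25
    elif key == "0":                     # reset zoom
        state.zoom = 1.0
    return True

state = ViewerState()
for k in "7784":  # axial +1, axial +1, coronal +1, previous case
    handle_key(state, k)
print(state.slice_idx["axial"])  # 2
```

The real viewer reads keys via OpenCV's event loop (`cv2.waitKey`) rather than strings, but the dispatch idea is the same.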
Contributions are welcome:
- Fork the repo
- Create a feature branch
- Open a PR



