A PyTorch-based feedforward neural network for predicting cancer severity levels from medical tabular data, with automatic patient risk stratification using K-means clustering.
This project uses deep learning to predict cancer severity scores (on a 0-1 scale) from patient medical records. The model is trained on the NIH Glioblastoma dataset, which contains 21,634 patient records with 164 clinical features (22 of which are used as model inputs).
- Deep Neural Network: 3 hidden layers with batch normalization and dropout
- Automatic Risk Stratification: K-means clustering into Safe 🟢, Moderate 🟡, and Severe 🔴 groups
- Continuous Training: Save/load checkpoints and track training history across sessions
- Comprehensive Visualizations: Training curves, prediction analysis, and patient clustering plots
- High Accuracy: Achieves ~99.4% accuracy (0.006 MAE on 0-1 scale)
```
Input Layer:    22 features
Hidden Layer 1: 128 neurons (ReLU + BatchNorm + Dropout 0.3)
Hidden Layer 2: 64 neurons  (ReLU + BatchNorm + Dropout 0.3)
Hidden Layer 3: 32 neurons  (ReLU + BatchNorm + Dropout 0.2)
Output Layer:   1 neuron    (Sigmoid activation)
```
- Total Parameters: 13,761
- Scaling Rule: Pyramidal (÷2 per layer)
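This architecture can be sketched directly in PyTorch. Layer sizes, dropout rates, and the sigmoid output follow the table above; the class name `SeverityNet` and variable names are assumptions, not the notebook's actual code.

```python
import torch
import torch.nn as nn

class SeverityNet(nn.Module):
    """Sketch of the 3-hidden-layer network described above (names assumed)."""
    def __init__(self, in_features=22):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 128), nn.ReLU(), nn.BatchNorm1d(128), nn.Dropout(0.3),
            nn.Linear(128, 64),          nn.ReLU(), nn.BatchNorm1d(64),  nn.Dropout(0.3),
            nn.Linear(64, 32),           nn.ReLU(), nn.BatchNorm1d(32),  nn.Dropout(0.2),
            nn.Linear(32, 1),            nn.Sigmoid(),  # bounds output to the 0-1 severity scale
        )

    def forward(self, x):
        return self.net(x)

model = SeverityNet()
print(sum(p.numel() for p in model.parameters()))  # 13761, matching the count above
out = model(torch.randn(4, 22))  # batch of 4 dummy patients
print(out.shape)  # torch.Size([4, 1])
```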
- Gender, Race, Ethnicity, Age at diagnosis
- Vital status, Tumor grade, Morphology, Site of biopsy
- Laterality, Prior malignancy, Prior treatment, Another malignancy
- Metastasis, Disease status, Progression, WHO grade
- Alcohol history, Alcohol intensity, Tobacco frequency
- Tobacco onset, Days to death, Karnofsky performance score
The severity level is computed from 5 weighted clinical factors:
| Factor | Weight | Values |
|---|---|---|
| Tumor Grade | 30% | G1: 0.05, G2: 0.15, G3: 0.25, G4: 0.30 |
| Vital Status | 25% | Alive: 0.0, Dead: 0.25 |
| Metastasis | 20% | Yes: 0.20, No: 0.0 |
| Prior Malignancy | 15% | Yes: 0.15, No: 0.0 |
| Disease Status | 10% | With tumor: 0.10, Tumor free: 0.0 |
Total Range: 0.0 (safest) to 1.0 (most severe)
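The weighting scheme above can be sketched as a small scoring function. The field names, value strings, and grade map here are illustrative assumptions; the notebook's `preprocess_medical_data()` is the authoritative implementation.

```python
# Grade contribution per the table above
GRADE_SCORE = {"G1": 0.05, "G2": 0.15, "G3": 0.25, "G4": 0.30}

def severity_score(grade, vital_status, metastasis, prior_malignancy, disease_status):
    """Weighted severity in [0, 1], following the factor table (names assumed)."""
    score = GRADE_SCORE.get(grade, 0.0)                       # up to 0.30
    score += 0.25 if vital_status == "Dead" else 0.0          # up to 0.25
    score += 0.20 if metastasis == "Yes" else 0.0             # up to 0.20
    score += 0.15 if prior_malignancy == "Yes" else 0.0       # up to 0.15
    score += 0.10 if disease_status == "With tumor" else 0.0  # up to 0.10
    return score

# Worst case hits the top of the range
print(round(severity_score("G4", "Dead", "Yes", "Yes", "With tumor"), 2))  # 1.0
```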
Requires Python 3.7+.

```
pip install torch torchvision pandas numpy scikit-learn matplotlib
```

Or, for CPU-only PyTorch (a smaller, faster install):

```
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install pandas numpy scikit-learn matplotlib
```
```
├── cancer_severity_prediction.ipynb   # Main Jupyter notebook
├── Downloads/
│   └── NIH Glioblastoma data.csv      # Input dataset (21,634 patients)
├── medical_cancer_severity_model.pth  # Latest trained model
├── best_model.pth                     # Best performing model
├── training_history.json              # Training session logs
└── patient_clusters.csv               # K-means cluster assignments
```
Place your CSV file with the required columns (see Input Features section).
Update the file path in Cell 8:

```python
csv_file_path = 'Downloads/NIH Glioblastoma data.csv'
```

Open the notebook in Jupyter and run all cells:

```
jupyter notebook cancer_severity_prediction.ipynb
```

Or click: Kernel → Restart & Run All.
```
Epoch [10/100], Loss: 0.006933
Epoch [20/100], Loss: 0.007218
...
Epoch [100/100], Loss: 0.006926

Evaluation Results (0-1 scale):
Mean Absolute Error (MAE): 0.0058
Root Mean Squared Error (RMSE): 0.0117

NEW BEST MODEL! MAE improved to 0.0058
```
Run Cell 17b to train for additional epochs:
```python
continue_training(model, train_loader, test_loader, criterion,
                  optimizer, device, checkpoint, additional_epochs=50)
```

- Type: L1Loss (Mean Absolute Error)
- Range: 0.0 to 1.0
- Interpretation: Average prediction error
- Loss < 0.01 = Excellent (< 1% error)
- Loss 0.01-0.05 = Good
- Loss > 0.05 = Needs improvement
- Definition: One complete pass through all training data
- Example: 100 epochs = model reviews all 17,307 patients 100 times
- Purpose: Each pass refines the model's understanding
- MAE (Mean Absolute Error): Average prediction error
- RMSE (Root Mean Squared Error): Penalizes large errors more
- Current Performance: MAE ~0.006 (99.4% accuracy)
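Both metrics are a one-liner over the prediction arrays. A sketch with dummy stand-ins for the notebook's test-set values:

```python
import numpy as np

# Stand-in arrays; in the notebook these come from the test set
y_true = np.array([0.30, 0.50, 0.70])
y_pred = np.array([0.31, 0.49, 0.68])

mae = np.mean(np.abs(y_pred - y_true))           # average absolute error
rmse = np.sqrt(np.mean((y_pred - y_true) ** 2))  # penalizes large errors more

print(f"MAE: {mae:.4f}, RMSE: {rmse:.4f}")  # MAE: 0.0133, RMSE: 0.0141
```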
Automatically groups patients into risk categories:
| Category | Severity Range | Description |
|---|---|---|
| 🟢 Safe | 0.0 - 0.35 | Low risk, favorable prognosis |
| 🟡 Moderate | 0.35 - 0.55 | Medium risk, requires monitoring |
| 🔴 Severe | 0.55 - 1.0 | High risk, aggressive treatment needed |
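Mapping a predicted score to these bands is a simple thresholding step. A sketch (the function name is an assumption, and the shared boundaries are treated as half-open intervals):

```python
def risk_category(severity: float) -> str:
    """Map a 0-1 severity score to the risk bands in the table above."""
    if severity < 0.35:
        return "Safe"
    if severity < 0.55:
        return "Moderate"
    return "Severe"

print(risk_category(0.487))  # Moderate
```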
```
CLUSTER ANALYSIS

Cluster 0 - SAFE 🟢
  Number of patients: 8,542
  Average severity: 0.298

Cluster 1 - MODERATE 🟡
  Number of patients: 7,123
  Average severity: 0.502

Cluster 2 - SEVERE 🔴
  Number of patients: 5,969
  Average severity: 0.687

Results saved to patient_clusters.csv
```
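The stratification step can be sketched with scikit-learn's `KMeans` on the 1-D severity scores. The synthetic severities and the relabel-by-mean step below are assumptions about how the notebook orders its clusters:

```python
import numpy as np
from sklearn.cluster import KMeans

# Stand-in for the model's predicted severities (1-D, reshaped for sklearn)
rng = np.random.default_rng(0)
severities = rng.uniform(0.28, 0.77, size=200).reshape(-1, 1)

kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(severities)

# Relabel so 0/1/2 correspond to Safe/Moderate/Severe by ascending mean severity
order = np.argsort(kmeans.cluster_centers_.ravel())
labels = np.zeros_like(kmeans.labels_)
for new, old in enumerate(order):
    labels[kmeans.labels_ == old] = new

for c, name in enumerate(["SAFE", "MODERATE", "SEVERE"]):
    mask = labels == c
    print(f"Cluster {c} - {name}: n={mask.sum()}, mean severity={severities[mask].mean():.3f}")
```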
The notebook generates 6 plots:
- Training Loss Curve - Loss vs. epoch
- Predictions vs. Actual - Scatter plot with diagonal line
- Error Distribution - Histogram of prediction errors
- Patient Clusters (2D PCA) - Left: by cluster, Right: by severity
- Cluster Severity Distributions - 3 histograms (one per cluster)
Every training session is logged in training_history.json:
```json
{
  "sessions": [
    {
      "date": "2026-02-09 14:32:15",
      "mae": 0.0064,
      "rmse": 0.0129,
      "epochs": 100,
      "improved": true
    }
  ],
  "best_mae": 0.0064,
  "best_model_date": "2026-02-09 14:32:15"
}
```

You can adjust these in the notebook:
| Parameter | Location | Default | Description |
|---|---|---|---|
| Learning Rate | Cell 13 | 0.001 | How fast the model learns |
| Batch Size | Cell 11 | 32 | Samples per training step |
| Epochs | Cell 14 | 100 | Training iterations |
| Hidden Layers | Cell 4 | 128, 64, 32 | Network architecture |
| Dropout | Cell 4 | 0.3, 0.3, 0.2 | Regularization strength |
| Loss Function | Cell 13 | L1Loss | MAE vs MSE |
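A sketch of where these defaults typically appear in code. Variable names are assumptions, as is the Adam optimizer (the notebook only names SGD and RMSprop as alternatives to try); cell numbers refer to the table above.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins for the scaled feature matrix and severity targets
X = torch.randn(64, 22)
y = torch.rand(64, 1)

train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)  # Cell 11
model = nn.Sequential(nn.Linear(22, 128), nn.ReLU(),                         # Cell 4 (abridged)
                      nn.Linear(128, 1), nn.Sigmoid())
criterion = nn.L1Loss()                                                      # Cell 13: MAE
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)                   # Cell 13 (assumed Adam)
num_epochs = 100                                                             # Cell 14
```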
- Ensure `%matplotlib inline` is in Cell 1
- Restart the kernel and run all cells
- Delete `training_history.json` and `training_history.json.backup`
- The notebook will create fresh files
- Reduce `batch_size` in Cell 11
- The model automatically uses the CPU if no GPU is available
- Check for missing data in the CSV
- The preprocessing handles this automatically, but verify data quality
```python
# Make predictions on new patient data
patient_data = X_test[0:1]  # Single patient
prediction = predict_severity(model, patient_data, device)
print(f"Predicted severity: {prediction[0][0]:.3f}")
# Output: Predicted severity: 0.487 (Moderate risk)
```

Tested on the NIH Glioblastoma dataset:
- Training samples: 17,307 patients
- Test samples: 4,327 patients
- Features: 22 clinical variables
- MAE: 0.0064 (0.64% average error)
- RMSE: 0.0129
- Prediction range: 0.280 - 0.771
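The sample counts above are consistent with a 20% hold-out split of the 21,634 records, e.g. with scikit-learn (the `random_state` value is an assumption):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder arrays with the dataset's dimensions
X = np.zeros((21634, 22))
y = np.zeros(21634)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
print(len(X_train), len(X_test))  # 17307 4327
```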
To improve the model:
- Add more features to `feature_mapping` in Cell 3
- Adjust severity weighting in `preprocess_medical_data()`
- Experiment with different architectures in Cell 4
- Try different optimizers (SGD, RMSprop) in Cell 13
This project is for educational and research purposes.
- Dataset: NIH Glioblastoma Cancer Genome Atlas (TCGA-GBM)
- Framework: PyTorch 2.10.0
- Preprocessing: scikit-learn
- Visualization: matplotlib
Last Updated: February 2026
Version: 1.0
Python: 3.7+
PyTorch: 2.10.0+