Azjob21/masked-xray-challenge

# Masked X-Ray Challenge

Binary X-ray classification using two independent deep learning strategies, both trained on the same dataset and submitted to the Kaggle leaderboard.


## Results Summary

| Strategy | Method | Best Val AUC |
|---|---|---|
| Strategy 1 | EfficientNet-B0 SSL Pipeline | 0.9860 |
| Strategy 2 | EfficientNetV2-S (standalone) | 0.9991 |
| Strategy 2 | ConvNeXt-Small (standalone) | 0.9999 |
| Strategy 2 | Ensemble (EfficientNetV2 + ConvNeXt) | ~0.9995 |

## Repository Structure

```
masked-xray-challenge/
├── notebooks/
│   ├── strategy1_efficientnet_ssl.ipynb
│   └── strategy2_ensemble.ipynb
├── checkpoints/
│   ├── strategy1/
│   │   ├── tf_efficientnet_b0_ssl_backbone.pth
│   │   ├── tf_efficientnet_b0_simsiam_full.pth
│   │   ├── tf_efficientnet_b0_best_classifier.pth
│   │   └── tf_efficientnet_b0_training_history.csv
│   └── strategy2/
│       ├── tf_efficientnetv2_s_in21k_ft_in1k_best.pth
│       ├── tf_efficientnetv2_s_in21k_ft_in1k_history.csv
│       ├── convnext_small_fb_in22k_ft_in1k_best.pth
│       └── convnext_small_fb_in22k_ft_in1k_history.csv
├── submissions/
│   ├── strategy1_auc0986.csv
│   └── strategy2_ensemble.csv
├── .gitignore
└── README.md
```

## Dataset

| Split | Images | Labels |
|---|---|---|
| Train (unlabeled) | 4,099 | - |
| Val (labeled) | 878 | 641 positive / 237 negative |
| Test | 879 | - |


## Strategy 1 — EfficientNet SSL Pipeline

### Overview

A self-supervised learning pipeline that uses SimSiam contrastive learning to pretrain on the unlabeled train images, followed by supervised fine-tuning on the labeled validation images.

### Pipeline Architecture

```
Unlabeled Train Images (4,099)
         ↓
Hand-Crafted Feature Extraction
(Intensity + GLCM Texture + Edge Features)
         ↓
Nearest Neighbor Pair Finding (k=1)
         ↓
SimSiam Self-Supervised Pretraining
(EfficientNet-B0 backbone, 20 epochs)
         ↓
SSL Backbone Weights
         ↓
Supervised Fine-Tuning
(878 labeled images, frozen backbone)
         ↓
Test-Time Augmentation (5 iterations)
         ↓
submission.csv
```

### Results

| Metric | Value |
|---|---|
| Best Validation AUC | 0.9860 |
| Best Epoch | 27 |
| Early Stopping Epoch | 42 |
| SSL Pretraining Epochs | 20 |
| Train Images (SSL) | 4,099 |
| Labeled Images (Supervised) | 878 |

### Method Details

#### Step 1 — Hand-Crafted Feature Extraction

Features extracted per image for similarity-based pair finding:

- **Intensity** — mean, std, median, min, max, percentiles (25/75/90), entropy, variance (10 features)
- **GLCM Texture** — contrast, homogeneity, energy at distances [1, 3, 5] (9 features)
- **Edge** — Sobel gradient mean, std, max (3 features)

Total: 22 features per image, normalized with StandardScaler.

#### Step 2 — Nearest Neighbor Pair Finding

Each image is paired with its single closest neighbor in feature space using Euclidean distance, generating 4,099 positive pairs for SSL training.
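A minimal version of this pairing step with scikit-learn (the `build_pairs` helper is hypothetical; `k=2` is queried because the nearest point to each sample is itself):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

def build_pairs(features):
    """Pair each image index with its nearest neighbor (excluding itself)."""
    z = StandardScaler().fit_transform(features)
    nn = NearestNeighbors(n_neighbors=2, metric='euclidean').fit(z)
    _, idx = nn.kneighbors(z)
    # Column 0 is the sample itself; column 1 is its closest other sample
    return [(i, int(j)) for i, j in enumerate(idx[:, 1])]
```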

#### Step 3 — SimSiam Pretraining

| Setting | Value |
|---|---|
| Backbone | `tf_efficientnet_b0` (ImageNet pretrained) |
| Projection head | 3-layer MLP, 2048-dim output |
| Predictor head | 2-layer MLP, 512-dim bottleneck |
| Loss | Negative cosine similarity |
| Optimizer | SGD, lr=0.03, momentum=0.9, weight_decay=1e-4 |
| Scheduler | CosineAnnealingLR |
| Epochs | 20 |
| Augmentations | RandomCrop, HorizontalFlip, Rotation(40°), ColorJitter, RandomPerspective |

The SSL loss converged from -0.23 to -0.89 over 20 epochs.
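The SimSiam objective is the symmetric negative cosine similarity with a stop-gradient on the projector outputs, which is why the loss lives in [-1, 0]. A sketch in PyTorch (function name assumed):

```python
import torch
import torch.nn.functional as F

def simsiam_loss(p1, p2, z1, z2):
    """Symmetric negative cosine similarity with stop-gradient (SimSiam).

    p1, p2: predictor outputs for the two views
    z1, z2: projector outputs for the two views
    """
    def d(p, z):
        # stop-gradient on z is the key ingredient that prevents collapse
        return -F.cosine_similarity(p, z.detach(), dim=-1).mean()
    return 0.5 * d(p1, z2) + 0.5 * d(p2, z1)
```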

#### Step 4 — Supervised Fine-Tuning

| Setting | Value |
|---|---|
| Backbone | Initialized from SSL weights, frozen |
| Classifier head | Linear(feat_dim→1024) → ReLU → Dropout(0.5) → Linear(1024→512) → ReLU → Dropout(0.5) → Linear(512→1) |
| Loss | BCEWithLogitsLoss |
| Optimizer | RMSprop, lr=1e-4, weight_decay=1e-6 |
| Scheduler | ReduceLROnPlateau (factor=0.5, patience=5) |
| Early stopping | patience=15 |
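The classifier head in the table translates directly to PyTorch; `make_head` is an illustrative name, and `feat_dim` depends on the backbone (1280 for EfficientNet-B0):

```python
import torch.nn as nn

def make_head(feat_dim, dropout=0.5):
    """Head trained on top of the frozen SSL backbone, as listed above."""
    return nn.Sequential(
        nn.Linear(feat_dim, 1024), nn.ReLU(), nn.Dropout(dropout),
        nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(dropout),
        nn.Linear(512, 1),  # single logit for BCEWithLogitsLoss
    )
```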

#### Step 5 — Test-Time Augmentation

Predictions are averaged over 5 inference passes:

- Pass 1: standard resize to 224×224 (original)
- Passes 2–5: training augmentations (random crops, flips, rotations)

### Loading Strategy 1 Model

```python
import torch

# EfficientNetClassifier is defined in notebooks/strategy1_efficientnet_ssl.ipynb
checkpoint = torch.load('checkpoints/strategy1/tf_efficientnet_b0_best_classifier.pth',
                        map_location='cpu')

model = EfficientNetClassifier(
    model_name=checkpoint['model_name'],
    pretrained=False,
    freeze_backbone=checkpoint['freeze_backbone'],
    dropout=0.5
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Loaded from epoch {checkpoint['epoch']}")
print(f"Val AUC at save: {checkpoint['val_auc']:.4f}")
```

### Key Design Decisions

**Why SimSiam for SSL?** SimSiam does not require negative pairs or a momentum encoder, making it simpler and faster to train on a small unlabeled dataset.

**Why nearest neighbor pairs?** Hand-crafted features (texture, intensity, edges) capture radiological similarity without requiring labels, giving the SSL model meaningful positive pairs to learn from rather than random augmentation views.

**Why freeze the backbone during fine-tuning?** With only 878 labeled samples, freezing prevents catastrophic forgetting of the SSL representations while allowing the classifier head to specialize quickly.



## Strategy 2 — Advanced Multi-Model Ensemble

### Overview

Four diverse architectures trained independently with advanced augmentation techniques (MixUp, CutMix), Focal Loss, spatial attention heads, and CLAHE X-ray preprocessing. Final predictions are combined via an AUC-weighted ensemble.

### Pipeline Architecture

```
Labeled Val Images (878)
         ↓
CLAHE Preprocessing
(Contrast Limited Adaptive Histogram Equalization)
         ↓
┌──────────────────────────────────────┐
│  Train 4 Models Independently        │
│                                      │
│  EfficientNetV2-S  @ 384px           │
│  ConvNeXt-Small    @ 384px           │
│  SwinV2-Small      @ 256px (dropped) │
│  EVA-02-Small      @ 336px (dropped) │
└──────────────────────────────────────┘
         ↓
Test-Time Augmentation (5 passes each)
         ↓
AUC-Weighted Ensemble
(EfficientNetV2 + ConvNeXt only)
         ↓
submission.csv
```

### Results

| Model | Image Size | Best Val AUC | Epochs | Status |
|---|---|---|---|---|
| EfficientNetV2-S | 384px | 0.9991 | 60 | ✓ Used in ensemble |
| ConvNeXt-Small | 384px | 0.9999 | 57 | ✓ Used in ensemble |
| SwinV2-Small | 256px | 0.8360 | 31 | ✗ Dropped |
| EVA-02-Small | 336px | 0.7040 | 29 | ✗ Dropped |
| Ensemble | - | ~0.9995 | - | ✓ Final submission |

### Method Details

#### Preprocessing — CLAHE

Each image is enhanced with Contrast Limited Adaptive Histogram Equalization before training and inference, improving visibility of subtle X-ray features.

#### Model Architecture — AttentionClassifier

Each backbone is extended with a spatial attention module and an enhanced classifier head:

```
Backbone (pretrained)
     ↓
Spatial Attention (Conv1×1 → ReLU → Conv1×1 → Sigmoid)
     ↓
AdaptiveAvgPool2d
     ↓
Linear(feat_dim→1024) → ReLU → BN → Dropout(0.5)
     ↓
Linear(1024→512) → ReLU → BN → Dropout(0.25)
     ↓
Linear(512→1)
```
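The spatial attention block in the diagram can be sketched as a single-channel sigmoid gate over the backbone feature map; the bottleneck width is an assumption:

```python
import torch.nn as nn

class SpatialAttention(nn.Module):
    """Conv1×1 → ReLU → Conv1×1 → Sigmoid gate, multiplied into the feature map."""

    def __init__(self, channels, hidden=None):
        super().__init__()
        hidden = hidden or max(channels // 8, 1)  # bottleneck width is assumed
        self.gate = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=1), nn.ReLU(),
            nn.Conv2d(hidden, 1, kernel_size=1), nn.Sigmoid(),
        )

    def forward(self, x):
        # The 1-channel mask broadcasts across all feature channels
        return x * self.gate(x)
```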

#### Training Settings

| Setting | Value |
|---|---|
| Loss | Focal Loss (α=0.25, γ=2.0, label smoothing=0.1) |
| Optimizer | AdamW, lr=1e-4, weight_decay=1e-5 |
| Scheduler | CosineAnnealingWarmRestarts (T₀=10, T_mult=2) |
| Batch size | 16 |
| Max epochs | 60 |
| Early stopping | patience=12 |
| Class balancing | WeightedRandomSampler |
| Augmentations | MixUp (α=0.4) + CutMix (α=1.0), randomly applied 50/50 |
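A sketch of the listed Focal Loss settings for binary logits; how label smoothing interacts with the modulating factor is an implementation choice, assumed here to smooth only the BCE targets:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0, smoothing=0.1):
    """Binary focal loss with label smoothing (illustrative, not the exact notebook code)."""
    # Smooth the hard 0/1 targets toward 0.5 for the BCE term
    t = targets * (1 - smoothing) + 0.5 * smoothing
    bce = F.binary_cross_entropy_with_logits(logits, t, reduction='none')
    # Modulating factor uses the probability assigned to the true (hard) class
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)
    a_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (a_t * (1 - p_t) ** gamma * bce).mean()
```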

#### Why Transformers Underperformed

SwinV2 (0.836) and EVA-02 (0.704) both struggled on this dataset. With only 878 labeled samples, vision transformers lack sufficient data to learn effective attention patterns, while CNN-based models (EfficientNetV2, ConvNeXt) benefit from stronger inductive biases that suit small medical imaging datasets.

#### Ensemble Weighting

Final predictions are a weighted average of the two strong models, with weights proportional to their validation AUC scores:

```
weight_efficientnetv2 = 0.9991 / (0.9991 + 0.9999) ≈ 0.4998
weight_convnext       = 0.9999 / (0.9991 + 0.9999) ≈ 0.5002
```
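Equivalently, as a small helper (hypothetical name):

```python
import numpy as np

def auc_weighted_ensemble(preds, aucs):
    """Average per-model prediction arrays with weights proportional to val AUC."""
    w = np.asarray(aucs, dtype=float) / np.sum(aucs)
    return np.average(np.asarray(preds), axis=0, weights=w)
```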

#### Test-Time Augmentation

5 inference passes per model:

- Pass 1: Standard resize (original)
- Pass 2: Horizontal flip
- Pass 3: Rotation +5°
- Pass 4: Rotation -5°
- Pass 5: Center crop (slight zoom)

### Loading Strategy 2 Models

```python
import torch

# AttentionClassifier is defined in notebooks/strategy2_ensemble.ipynb

# Load EfficientNetV2
ckpt = torch.load('checkpoints/strategy2/tf_efficientnetv2_s_in21k_ft_in1k_best.pth',
                  map_location='cpu')
model = AttentionClassifier(model_name=ckpt['model_name'], img_size=ckpt['img_size'])
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
print(f"Loaded epoch {ckpt['epoch']} | Val AUC: {ckpt['val_auc']:.4f}")

# Load ConvNeXt
ckpt = torch.load('checkpoints/strategy2/convnext_small_fb_in22k_ft_in1k_best.pth',
                  map_location='cpu')
model = AttentionClassifier(model_name=ckpt['model_name'], img_size=ckpt['img_size'])
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
print(f"Loaded epoch {ckpt['epoch']} | Val AUC: {ckpt['val_auc']:.4f}")
```


## Requirements

```bash
pip install torch torchvision timm scikit-learn pandas numpy opencv-python scikit-image tqdm
```

## Notes

- Model `.pth` files are tracked with Git LFS due to their size (86 MB–201 MB)
- Both strategies train and validate on the same 878 labeled val images
- The unlabeled train set (4,099 images) is used only in Strategy 1, for SSL pretraining
- Transformer models (SwinV2, EVA-02) would require significantly more labeled data to be competitive with CNN models on this task
