This repository implements a Multi-Task Learning (MTL) architecture that performs multiple computer vision tasks in a single pass, enabling cost-effective perception for autonomous driving. The system runs in real time: a single forward pass takes only 0.007944 s (≈126 FPS).
🚀 Featured Insight: Discover how our model handles real-world traffic scenarios in India! Check out this insightful LinkedIn post.
- Multi-Task Learning (MTL): Combines 2D object detection, semantic segmentation, and monocular depth estimation in a single neural network
- Efficient Architecture: Shared encoder with task-specific decoders for real-time processing
- Cost-Effective: Performs multiple tasks with a single model, avoiding the expense of running a separate network (and the hardware behind it) for each task
- Real-time Performance: Achieves >60 FPS on V100 GPU
- Indian Street Adaptation: Trained and tested on the India Driving Dataset from IIIT Hyderabad
The MTL architecture uses a shared encoder-decoder design:
```
┌───────────┐      ┌───────────────┐
│           │      │ Segmentation  │
│           │─────▶│    Decoder    │───▶ Pixel-level Scene Understanding
│           │      └───────────────┘
│  Shared   │      ┌───────────────┐
│  Encoder  │─────▶│     Depth     │───▶ 3D Structure Estimation
│ (ResNet)  │      │    Decoder    │
│           │      └───────────────┘
│           │      ┌───────────────┐
│           │─────▶│   Detection   │───▶ Object Localization
└───────────┘      │    Decoder    │
                   └───────────────┘
```
- Encoder: A ResNet-based backbone that processes raw camera input
- Segmentation Decoder: U-Net style architecture for pixel-wise classification (35 classes)
- Depth Decoder: Predicts per-pixel monocular depth, trained with a specialized loss function
- Detection Decoder: Anchor-free object detection for 15 object classes
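
To make the shared-encoder contract concrete, here is a toy sketch of the wiring, with stub convolutions standing in for the real `Encoder` and decoder modules (channel counts and the single-resolution output here are illustrative assumptions, not taken from the repository):

```python
import torch
import torch.nn as nn

class ToyMTL(nn.Module):
    """Illustrative stand-in: one shared feature extractor, three task heads."""
    def __init__(self, n_seg=35, n_det=15):
        super().__init__()
        self.encoder = nn.Conv2d(3, 64, 3, padding=1)  # stub for the ResNet backbone
        self.seg_head = nn.Conv2d(64, n_seg, 1)        # per-pixel class logits
        self.depth_head = nn.Conv2d(64, 1, 1)          # single-channel depth
        self.det_head = nn.Conv2d(64, n_det + 4, 1)    # heatmap + offset + size channels

    def forward(self, x):
        feats = self.encoder(x)  # computed once, reused by every head
        return (self.seg_head(feats),
                torch.sigmoid(self.depth_head(feats)),
                self.det_head(feats))

seg, depth, det = ToyMTL()(torch.randn(1, 3, 256, 256))
print(seg.shape, depth.shape, det.shape)
# torch.Size([1, 35, 256, 256]) torch.Size([1, 1, 256, 256]) torch.Size([1, 19, 256, 256])
```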
| Task | Metric | Performance |
|---|---|---|
| Segmentation | IoU | 0.979 |
| Segmentation | Pixel Accuracy | 0.943 |
| Depth Estimation | A1 (δ < 1.25) | 0.852 |
| Depth Estimation | RMSE | 0.031 |
| Object Detection | mAP @ 0.5 IoU | 0.256 |
| Object Detection | Average IoU | 0.726 |
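
For reference, the two segmentation metrics can be computed as below; this is an illustrative implementation, not necessarily the repository's evaluation code:

```python
import torch

def pixel_accuracy(pred, target):
    # pred, target: (H, W) integer class maps
    return (pred == target).float().mean().item()

def mean_iou(pred, target, n_classes=35):
    # Average intersection-over-union over classes that actually occur
    ious = []
    for c in range(n_classes):
        inter = ((pred == c) & (target == c)).sum().item()
        union = ((pred == c) | (target == c)).sum().item()
        if union > 0:
            ious.append(inter / union)
    return sum(ious) / len(ious)
```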
Our training approach includes several specialized techniques:
- Knowledge Distillation for Depth: The depth decoder is trained by distilling the predictions of a state-of-the-art depth estimation model
- PCGrad Optimizer: Manages gradient interference between the competing task objectives (see the projection sketch below)
- Custom Loss Functions (sketched after this list):
  - Dice-Focal loss for segmentation
  - A specialized depth loss that includes gradient terms
  - A detection loss combining heatmap, regression, and dimension losses
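
PCGrad handles task interference by projecting one task's gradient off another's whenever the two point in opposing directions. A minimal sketch of the core projection step (our paraphrase of the published algorithm, not the repository's optimizer code):

```python
import torch

def pcgrad_project(g_i, g_j):
    # g_i, g_j: flattened gradient vectors of two tasks.
    # If they conflict (negative dot product), strip from g_i the
    # component that opposes g_j; otherwise leave g_i untouched.
    dot = torch.dot(g_i, g_j)
    if dot < 0:
        g_i = g_i - (dot / g_j.norm() ** 2) * g_j
    return g_i
```

In the full algorithm, this projection is applied pairwise over the tasks in random order before the combined gradient is handed to the optimizer.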
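
The segmentation and depth losses might look roughly like the following sketch; the exact terms and weightings in the repository may differ (`gamma` and the unit loss weights are assumptions):

```python
import torch
import torch.nn.functional as F

def dice_focal_loss(logits, target, gamma=2.0, eps=1e-6):
    # logits: (N, C, H, W); target: (N, H, W) integer labels
    probs = torch.softmax(logits, dim=1)
    onehot = F.one_hot(target, logits.shape[1]).permute(0, 3, 1, 2).float()
    inter = (probs * onehot).sum(dim=(2, 3))
    denom = probs.sum(dim=(2, 3)) + onehot.sum(dim=(2, 3))
    dice = 1 - ((2 * inter + eps) / (denom + eps)).mean()   # overlap term
    ce = F.cross_entropy(logits, target, reduction='none')
    focal = ((1 - torch.exp(-ce)) ** gamma * ce).mean()     # down-weights easy pixels
    return dice + focal

def depth_loss(pred, target):
    # pred, target: (N, 1, H, W); L1 plus gradient-matching terms for sharp edges
    l1 = (pred - target).abs().mean()
    dx = ((pred[..., :, 1:] - pred[..., :, :-1]) - (target[..., :, 1:] - target[..., :, :-1])).abs().mean()
    dy = ((pred[..., 1:, :] - pred[..., :-1, :]) - (target[..., 1:, :] - target[..., :-1, :])).abs().mean()
    return l1 + dx + dy
```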
The model is trained on the India Driving Dataset from IIIT Hyderabad, which includes:
- Images of Indian street scenes with diverse traffic patterns
- Annotations for object detection (15 classes)
- Pixel-level semantic segmentation (35 classes)
- Depth ground truth
The latest version of our MTL architecture (MTL 2.0) includes several improvements:
- Significantly improved depth estimation performance
- Temporally consistent results without explicit temporal modeling
- Reduced computational cost while maintaining accuracy
- Tested on V100 GPU with 32 GB VRAM
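
Latency figures like the ones above can be reproduced with a synchronized timing loop. A minimal sketch, assuming `model` and `device` are set up as in the code below (warm-up count, iteration count, and input resolution are our assumptions):

```python
import time
import torch

x = torch.randn(1, 3, 512, 512, device=device)

with torch.no_grad():
    for _ in range(10):           # warm-up: trigger kernel compilation/caching
        model(x)
    torch.cuda.synchronize()      # make sure warm-up work has finished
    start = time.time()
    for _ in range(100):
        model(x)
    torch.cuda.synchronize()      # wait for queued GPU work before stopping the clock
print(f"Average forward pass: {(time.time() - start) / 100:.6f} s")
```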
The model definition below wires the shared encoder to the three task-specific decoders:

```python
import torch
import torch.nn as nn

from encoder import Encoder
from decoders import seg_decoder, objdet_Decoder


class MTL_Model(nn.Module):
    def __init__(self, n_classes=35, device='cuda'):
        super(MTL_Model, self).__init__()
        self.encoder = Encoder(device=device)                       # shared ResNet backbone
        self.seg_decoder = seg_decoder(n_classes, device=device)    # 35-class segmentation head
        self.dep_decoder = seg_decoder(n_classes=1, device=device)  # depth head (1 channel)
        self.obj_decoder = objdet_Decoder(n_classes=15, device=device)  # anchor-free detection head
        self.to(device)

    def forward(self, X):
        outputs = self.encoder(X)  # one shared encoder pass feeds all decoders
        seg_maps = self.seg_decoder(outputs)
        depth_maps = self.dep_decoder(outputs)
        detection_maps = self.obj_decoder(outputs)
        # sigmoid bounds the depth prediction to (0, 1)
        return (seg_maps, torch.sigmoid(depth_maps), detection_maps)
```

Inference on a single image (assuming a `preprocess` transform and an input `image` are defined):

```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MTL_Model(device=device)
model.load_state_dict(torch.load("model_path.pth", map_location=device))
model.eval()

# Process a single image
with torch.no_grad():
    rgb = preprocess(image).unsqueeze(0).to(device)
    seg_maps, depth_maps, detection_maps = model(rgb)

# Extract results
seg_pred = torch.argmax(torch.softmax(seg_maps, dim=1).squeeze(0), dim=0)
depth_pred = depth_maps.squeeze().cpu().numpy()

# For detection, decode the center heatmap into boxes
hmap, regs, w_h_ = zip(*detection_maps)
detections = ctdet_decode(hmap[0], regs[0], w_h_[0])
```

- This work was built using PyTorch and related computer vision libraries
- We acknowledge the India Driving Dataset from IIIT Hyderabad