
Commit 52b1d5e

cleared up the fact that the loss should be computed as MSE
1 parent 2b66fac commit 52b1d5e

2 files changed: +9 -20 lines

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-Write a Python class to implement Mixed Precision Training that uses both float32 and float16 data types to optimize memory usage and speed. Your class should have an `__init__(self, loss_scale=1024.0)` method to initialize with loss scaling factor. Implement `forward(self, weights, inputs, targets)` to perform forward pass with float16 computation and return loss (scaled) in float32, and `backward(self, gradients)` to unscale gradients and check for overflow. Use float16 for computations but float32 for gradient accumulation. Return gradients as float32 and set them to zero if overflow is detected. The forward method must return loss as float32 dtype. Only use NumPy.
+Write a Python class to implement Mixed Precision Training that uses both float32 and float16 data types to optimize memory usage and speed. Your class should have an `__init__(self, loss_scale=1024.0)` method to initialize with loss scaling factor. Implement `forward(self, weights, inputs, targets)` to perform forward pass with float16 computation and return Mean Squared Error (MSE) loss (scaled) in float32, and `backward(self, gradients)` to unscale gradients and check for overflow. Use float16 for computations but float32 for gradient accumulation. Return gradients as float32 and set them to zero if overflow is detected. Only use NumPy.
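
For reference, a minimal NumPy sketch of the class the updated prompt asks for might look like the following. The class name `MixedPrecisionTraining` and the linear forward pass (`predictions = inputs @ weights`) are illustrative assumptions and are not fixed by the prompt itself.

```python
import numpy as np

class MixedPrecisionTraining:
    def __init__(self, loss_scale=1024.0):
        # Factor applied to the loss so small float16 gradients do not underflow.
        self.loss_scale = loss_scale

    def forward(self, weights, inputs, targets):
        # Cast to float16 for the forward computation (assumed linear model).
        w16 = np.asarray(weights, dtype=np.float16)
        x16 = np.asarray(inputs, dtype=np.float16)
        t16 = np.asarray(targets, dtype=np.float16)
        predictions = x16 @ w16
        # MSE loss over all elements, then scaled; returned as float32.
        mse = np.float32(np.mean((predictions - t16) ** 2))
        return np.float32(mse * np.float32(self.loss_scale))

    def backward(self, gradients):
        # Unscale in float32 to keep full precision during accumulation.
        grads = np.asarray(gradients, dtype=np.float32) / np.float32(self.loss_scale)
        # Overflow check: any NaN or Inf means this step must be discarded.
        if np.any(np.isnan(grads)) or np.any(np.isinf(grads)):
            return np.zeros_like(grads, dtype=np.float32)
        return grads
```

Zeroing the gradients on overflow mirrors the "set them to zero if overflow is detected" requirement; a production training loop would typically also lower the loss scale before retrying the step.
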
Lines changed: 8 additions & 19 deletions
@@ -1,55 +1,44 @@
 # **Mixed Precision Training**
-
 ## **1. Definition**
-
 Mixed Precision Training is a **deep learning optimization technique** that uses both **float16** (half precision) and **float32** (single precision) data types during training to reduce memory usage and increase training speed while maintaining model accuracy.
-
 The technique works by:
 - **Using float16 for forward pass computations** to save memory and increase speed
 - **Using float32 for gradient accumulation** to maintain numerical precision
 - **Applying loss scaling** to prevent gradient underflow in float16
-
 ---
-
 ## **2. Key Components**
+### **Mean Squared Error (MSE) Loss**
+The loss function must be computed as Mean Squared Error:
+$$
+\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2
+$$
+where $y_i$ is the target and $\hat{y}_i$ is the prediction for sample $i$.

 ### **Loss Scaling**
 To prevent gradient underflow in float16, gradients are scaled up during the forward pass:
-
 $$
-\text{scaled\_loss} = \text{loss} \times \text{scale\_factor}
+\text{scaled\_loss} = \text{MSE} \times \text{scale\_factor}
 $$
-
 Then unscaled during backward pass:
-
 $$
 \text{gradient} = \frac{\text{scaled\_gradient}}{\text{scale\_factor}}
 $$
-
 ### **Overflow Detection**
 Check for invalid gradients (NaN or Inf) that indicate numerical overflow:
-
 $$
 \text{overflow} = \text{any}(\text{isnan}(\text{gradients}) \text{ or } \text{isinf}(\text{gradients}))
 $$
-
 ---
-
 ## **3. Precision Usage**
-
 - **float16**: Forward pass computations, activations, temporary calculations
 - **float32**: Gradient accumulation, parameter updates, loss scaling
 - **Automatic casting**: Convert between precisions as needed
-
+- **Loss computation**: Use MSE as the loss function before scaling
 ---
-
 ## **4. Benefits and Applications**
-
 - **Memory Efficiency**: Reduces memory usage by ~50% for activations
 - **Speed Improvement**: Faster computation on modern GPUs with Tensor Cores
 - **Training Stability**: Loss scaling prevents gradient underflow
 - **Model Accuracy**: Maintains comparable accuracy to full precision training
-
 Common in training large neural networks where memory is a constraint and speed is critical.
-
 ---
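
Taken together, the MSE loss, loss scaling, unscaling, and overflow check described in this file reduce to a few NumPy operations. The standalone sketch below uses arbitrary example values (it is not part of the committed file) to show the formulas above in code:

```python
import numpy as np

scale_factor = 1024.0

# Float16 forward pass on small example data (an assumed linear model).
weights = np.array([[0.5], [-0.25]], dtype=np.float16)
inputs = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16)
targets = np.array([[0.0], [1.0]], dtype=np.float16)
predictions = inputs @ weights

# MSE = mean((y - y_hat)^2); scaled_loss = MSE * scale_factor, kept in float32.
mse = np.float32(np.mean((targets - predictions) ** 2))
scaled_loss = mse * np.float32(scale_factor)

# Backward: unscale gradients in float32, then check for overflow.
scaled_gradients = np.array([512.0, np.inf], dtype=np.float16)  # example values
gradients = scaled_gradients.astype(np.float32) / scale_factor
overflow = bool(np.any(np.isnan(gradients)) or np.any(np.isinf(gradients)))
if overflow:
    gradients = np.zeros_like(gradients)  # discard the step when overflow occurs

print(scaled_loss.dtype, overflow, gradients)
```
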
