# **Mixed Precision Training**

## **1. Definition**

Mixed Precision Training is a **deep learning optimization technique** that uses both **float16** (half precision) and **float32** (single precision) data types during training to reduce memory usage and increase training speed while maintaining model accuracy.

The technique works by:
- **Using float16 for forward pass computations** to save memory and increase speed
- **Using float32 for gradient accumulation** to maintain numerical precision
- **Applying loss scaling** to prevent gradient underflow in float16
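
Frameworks wrap these steps in an automatic mixed precision (AMP) API. The following is a minimal sketch of one training step using PyTorch's `torch.cuda.amp` utilities; `model`, `optimizer`, and `loader` are assumed placeholders rather than anything defined in this document, and the MSE loss from the next section is used as the example.

```python
# Minimal mixed-precision training step with PyTorch AMP.
# `model`, `optimizer`, and `loader` are assumed to exist; parameters stay in float32.
import torch

scaler = torch.cuda.amp.GradScaler()              # maintains the loss scale factor

for inputs, targets in loader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()

    # Forward pass: eligible ops run in float16 to save memory and time.
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = torch.nn.functional.mse_loss(outputs, targets)

    # Loss scaling: scale the loss, backpropagate, then unscale and update in float32.
    scaler.scale(loss).backward()
    scaler.step(optimizer)                        # skips the step if an overflow is detected
    scaler.update()                               # adjusts the scale factor
```
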

---

## **2. Key Components**

### **Mean Squared Error (MSE) Loss**
A common choice of loss function, used in the examples below, is the Mean Squared Error:

$$
\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2
$$

where $y_i$ is the target and $\hat{y}_i$ is the prediction for sample $i$.

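
As a concrete illustration, here is a hypothetical `mse_loss` helper that mirrors the formula above while following this document's precision rules: inputs may be float16, but the subtraction, squaring, and averaging are carried out in float32 (PyTorch's built-in equivalent is `torch.nn.functional.mse_loss`).

```python
import torch

def mse_loss(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean Squared Error, accumulated in float32 even when inputs are float16."""
    diff = preds.float() - targets.float()   # upcast to float32 before reducing
    return (diff ** 2).mean()

# Toy example with float16 inputs, as produced by a float16 forward pass.
preds = torch.tensor([0.5, 1.0, 1.5], dtype=torch.float16)
targets = torch.tensor([0.0, 1.0, 2.0], dtype=torch.float16)
print(mse_loss(preds, targets))              # tensor(0.1667)
```
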
### **Loss Scaling**
To prevent gradient underflow in float16, the loss (e.g., the MSE above) is multiplied by a scale factor before the backward pass, which scales all gradients by the same amount:

$$
\text{scaled\_loss} = \text{loss} \times \text{scale\_factor}
$$

The resulting gradients are then unscaled before the parameter update:

$$
\text{gradient} = \frac{\text{scaled\_gradient}}{\text{scale\_factor}}
$$

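
Written out by hand (without `GradScaler`), loss scaling looks roughly like the sketch below; `model`, `optimizer`, and `loss` are assumed to come from a training loop such as the one in Section 1, and `scale_factor` is a hypothetical constant.

```python
scale_factor = 2.0 ** 16                      # a typical initial loss scale

# Scale the loss so that small gradients stay representable in float16.
(loss * scale_factor).backward()

# Unscale the float32 gradients before they are used for the update.
for param in model.parameters():
    if param.grad is not None:
        param.grad.div_(scale_factor)

# The parameter update itself is applied only if the overflow check
# described in the next section passes.
```
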
### **Overflow Detection**
Check for invalid gradients (NaN or Inf) that indicate numerical overflow:

$$
\text{overflow} = \text{any}(\text{isnan}(\text{gradients}) \text{ or } \text{isinf}(\text{gradients}))
$$

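
Continuing the same hand-written loop, the check might be sketched as follows; halving `scale_factor` on overflow is a common heuristic (production scalers such as `GradScaler` also grow the scale again after a run of stable steps).

```python
import torch

def gradients_overflowed(parameters) -> bool:
    """Return True if any gradient contains NaN or Inf values."""
    for param in parameters:
        if param.grad is None:
            continue
        if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
            return True
    return False

if gradients_overflowed(model.parameters()):
    optimizer.zero_grad()       # discard the corrupted gradients; skip this update
    scale_factor /= 2.0         # back off the loss scale
else:
    optimizer.step()            # apply the update to the float32 parameters
    optimizer.zero_grad()
```
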
---

## **3. Precision Usage**

- **float16**: Forward pass computations, activations, temporary calculations
- **float32**: Gradient accumulation, parameter updates, loss scaling
- **Automatic casting**: Convert between precisions as needed (illustrated below)
- **Loss computation**: Compute the loss (e.g., MSE) in float32 before applying the scale factor

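
The casting behavior can be observed directly: under `autocast`, a linear layer keeps its float32 master weights but produces float16 activations. This small sketch assumes a CUDA-capable GPU is available.

```python
import torch

linear = torch.nn.Linear(16, 4).cuda()        # parameters are stored in float32
x = torch.randn(8, 16, device="cuda")

with torch.cuda.amp.autocast():
    y = linear(x)                             # the matmul runs in float16
    print(y.dtype)                            # torch.float16

print(linear.weight.dtype)                    # torch.float32 (master weights)
```
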
---

## **4. Benefits and Applications**

- **Memory Efficiency**: Reduces memory usage by ~50% for activations
- **Speed Improvement**: Faster computation on modern GPUs with Tensor Cores
- **Training Stability**: Loss scaling prevents gradient underflow
- **Model Accuracy**: Maintains comparable accuracy to full precision training

Mixed precision is common when training large neural networks, where memory is a constraint and speed is critical.

---