Skip to content

Commit ce813c4

Browse files
authored
Merge pull request #541 from Open-Deep-ML/moe18-patch-7
Update solution.py
2 parents 45a9988 + ea5db51 commit ce813c4

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

questions/160_mixed_precision_training/solution.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22

3-
43
class MixedPrecision:
54
def __init__(self, loss_scale=1024.0):
65
self.loss_scale = loss_scale
@@ -15,9 +14,8 @@ def forward(self, weights, inputs, targets):
1514
predictions = np.dot(inputs_fp16, weights_fp16)
1615
loss = np.mean((targets_fp16 - predictions) ** 2)
1716

18-
# Scale loss and convert back to float32
19-
scaled_loss = loss.astype(np.float32) * self.loss_scale
20-
17+
# Scale loss and convert back to float32 (Python float)
18+
scaled_loss = float(loss) * self.loss_scale
2119
return scaled_loss
2220

2321
def backward(self, gradients):

0 commit comments

Comments
 (0)