Open
Labels
compiler, enhancement, research
Description
It would be nice if the compiler could infer that an expression such as F = A * B + C * D needs only one truncation, applied to F, rather than a separate truncation after each intermediate multiplication.
Example from DTI code:
vW_prev = vW[l].copy()
>>>
# vW[l] = vW[l] * MOMENTUM - dW[l] * LEARN_RATE
# vW[l] = vW[l].trunc(mpc.fp)
# temp = vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM
# temp = temp.trunc(mpc.fp)
# W[l] = W[l] + temp
>>> should be
vW[l] = vW[l] * MOMENTUM - dW[l] * LEARN_RATE
W[l] = W[l] + vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM
vb_prev = vb[l].copy()
>>>
# vb[l] = vb[l] * MOMENTUM - db[l] * LEARN_RATE
# vb[l] = vb[l].trunc(mpc.fp)
# temp_v = vb[l] * (MOMENTUM + 1) - vb_prev * MOMENTUM
# temp_v = temp_v.trunc(mpc.fp)
# b[l] = b[l] + temp_v
>>> should be
vb[l] = vb[l] * MOMENTUM - db[l] * LEARN_RATE
b[l] = b[l] + vb[l] * (MOMENTUM + 1) - vb_prev * MOMENTUM
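To make the request concrete, here is a minimal sketch in plain Python integers of why one truncation is enough for F = A * B + C * D. It is not the compiler's implementation and uses no MPC; FRAC_BITS, encode, decode, trunc, eager and deferred are hypothetical names introduced only for this illustration, with FRAC_BITS standing in for mpc.fp.

```python
# Fixed-point sketch: values are integers scaled by 2**FRAC_BITS.
# FRAC_BITS is an assumed stand-in for mpc.fp; all names here are hypothetical.
FRAC_BITS = 16
SCALE = 1 << FRAC_BITS

STATS = {"trunc": 0}  # counts how many truncations were performed


def encode(x: float) -> int:
    """Map a real number to its fixed-point integer representation."""
    return round(x * SCALE)


def decode(v: int) -> float:
    """Map a fixed-point integer back to a real number."""
    return v / SCALE


def trunc(v: int) -> int:
    """Drop FRAC_BITS fractional bits (scale 2*FRAC_BITS -> FRAC_BITS)."""
    STATS["trunc"] += 1
    return v >> FRAC_BITS


def eager(a: int, b: int, c: int, d: int) -> int:
    """Truncate after every multiplication: two truncations."""
    return trunc(a * b) + trunc(c * d)


def deferred(a: int, b: int, c: int, d: int) -> int:
    """Add the double-scale products first, then truncate once."""
    return trunc(a * b + c * d)


a, b, c, d = (encode(x) for x in (1.5, -2.25, 0.75, 4.0))

STATS["trunc"] = 0
f_eager = eager(a, b, c, d)
n_eager = STATS["trunc"]

STATS["trunc"] = 0
f_deferred = deferred(a, b, c, d)
n_deferred = STATS["trunc"]

print(decode(f_eager), n_eager)
print(decode(f_deferred), n_deferred)
```

Both variants decode to the same value here, but the deferred one truncates once instead of twice. Deferring only works when every summand is at the same scale: in W[l] + vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM the two products are at double scale while W[l] is not, so the compiler would presumably have to truncate the product terms together before adding W[l].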