Implement clever joint truncations pattern matcher #13

@hsmajlovic

It would be nice if the compiler could figure out that an expression such as F = A * B + C * D needs only one truncation, applied to F as a whole, instead of a separate truncation after each intermediate multiplication.
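Why a single truncation is enough can be checked on plain (non-secret-shared) fixed-point integers. The sketch below is only an illustration: `FRAC_BITS`, `encode`, `decode` and `trunc` are hypothetical helpers, not the library's API, and `trunc` here is a local right shift rather than an MPC protocol. Each product carries twice the fractional bits, so the two products can be added at that doubled scale and shifted back once.

```python
FRAC_BITS = 16  # hypothetical number of fractional bits


def encode(x: float) -> int:
    """Encode a real number as a fixed-point integer with FRAC_BITS fractional bits."""
    return round(x * (1 << FRAC_BITS))


def decode(x: int) -> float:
    """Decode a fixed-point integer that carries FRAC_BITS fractional bits."""
    return x / (1 << FRAC_BITS)


def trunc(x: int) -> int:
    """Drop FRAC_BITS fractional bits (arithmetic shift, rounds toward -infinity)."""
    return x >> FRAC_BITS


def separate_truncations(a: int, b: int, c: int, d: int) -> int:
    # One truncation per multiplication: two truncations in total.
    return trunc(a * b) + trunc(c * d)


def joint_truncation(a: int, b: int, c: int, d: int) -> int:
    # Both products have 2 * FRAC_BITS fractional bits, so they can be
    # added first and truncated together: one truncation in total.
    return trunc(a * b + c * d)


a, b, c, d = (encode(v) for v in (1.5, -2.25, 0.125, 3.0))
sep = separate_truncations(a, b, c, d)
joint = joint_truncation(a, b, c, d)
print(decode(sep), decode(joint))
```

With floor-style truncation the two results differ by at most one unit in the last place, and the joint version is the more accurate of the two because it rounds only once.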

Example from DTI code:

vW_prev = vW[l].copy()
>>>
# vW[l] = vW[l] * MOMENTUM - dW[l] * LEARN_RATE
# vW[l] = vW[l].trunc(mpc.fp)
# temp = vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM
# temp = temp.trunc(mpc.fp)
# W[l] = W[l] + temp
>>> should be
vW[l] = vW[l] * MOMENTUM - dW[l] * LEARN_RATE
W[l] = W[l] + vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM

vb_prev = vb[l].copy()
>>>
# vb[l] = vb[l] * MOMENTUM - db[l] * LEARN_RATE
# vb[l] = vb[l].trunc(mpc.fp)
# temp_v = vb[l] * (MOMENTUM + 1) - vb_prev * MOMENTUM
# temp_v = temp_v.trunc(mpc.fp)
# b[l] = b[l] + temp_v
>>> should be
vb[l] = vb[l] * MOMENTUM - db[l] * LEARN_RATE
b[l] = b[l] + vb[l] * (MOMENTUM + 1) - vb_prev * MOMENTUM
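
One possible shape for such a pattern matcher is sketched below. It uses Python's standard `ast` module as a stand-in for the compiler's actual IR, and `count_truncations` is a hypothetical helper introduced only for illustration. The sketch assumes that every leaf (name, subscript, constant) is a fixed-point value at single scale, that a product is at double scale, and that a single-scale operand of an addition can be shifted up to double scale locally at no cost. Under those assumptions, truncation is deferred through `+` and `-` and is forced only before a value feeds another multiplication or when the expression ends.

```python
import ast


def _analyze(node: ast.AST, joint: bool) -> tuple[int, bool]:
    """Return (truncations emitted so far, result still at doubled scale)."""
    if not isinstance(node, ast.BinOp):
        # Names, subscripts, constants: treated as single-scale leaves.
        return 0, False
    left_truncs, left_pending = _analyze(node.left, joint)
    right_truncs, right_pending = _analyze(node.right, joint)
    truncs = left_truncs + right_truncs
    if isinstance(node.op, ast.Mult):
        # Operands must be back at single scale before they are multiplied.
        truncs += int(left_pending) + int(right_pending)
        if joint:
            return truncs, True  # defer truncating the product
        return truncs + 1, False  # naive: truncate right after the product
    if isinstance(node.op, (ast.Add, ast.Sub)):
        # Assumption: a single-scale operand is shifted up locally, so the
        # sum stays at doubled scale if either side is still pending.
        return truncs, left_pending or right_pending
    raise NotImplementedError(f"unsupported operator: {type(node.op).__name__}")


def count_truncations(expr: str, joint: bool) -> int:
    """Count truncations needed to evaluate expr back down to single scale."""
    tree = ast.parse(expr, mode="eval").body
    truncs, pending = _analyze(tree, joint)
    return truncs + int(pending)


for expr in (
    "A * B + C * D",
    "W[l] + vW[l] * (MOMENTUM + 1) - vW_prev * MOMENTUM",
):
    print(expr, count_truncations(expr, joint=False), count_truncations(expr, joint=True))
```

For both expressions the naive count is 2 and the joint count is 1, which matches the rewrite requested above. A real implementation would also need type information to tell fixed-point operands from plain integers, which this sketch ignores.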

Labels: compiler, enhancement, research
