diff --git a/src/dispatch.jl b/src/dispatch.jl index f8b70c4..d74aa55 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -13,6 +13,16 @@ abstract type AbstractMutable end +# Zero arithmetic methods for AbstractMutable types. +# The main Zero arithmetic is defined in rewrite.jl with Number/AbstractArray; +# these methods extend it to AbstractMutable. +Base.:*(z::Zero, ::AbstractMutable) = z +Base.:*(::AbstractMutable, z::Zero) = z +Base.:+(::Zero, x::AbstractMutable) = copy_if_mutable(x) +Base.:+(x::AbstractMutable, ::Zero) = copy_if_mutable(x) +Base.:-(::Zero, x::AbstractMutable) = operate(-, x) +Base.:-(x::AbstractMutable, ::Zero) = copy_if_mutable(x) + function Base.sum( a::AbstractArray{T}; dims = :, diff --git a/src/rewrite.jl b/src/rewrite.jl index 485cd89..f37ca58 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -58,20 +58,31 @@ broadcast!!(::Union{typeof(add_mul),typeof(+)}, ::Zero, x) = copy_if_mutable(x) broadcast!!(::typeof(add_mul), ::Zero, x, y) = x * y # Needed in `@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)` -Base.:*(z::Zero, ::Any) = z -Base.:*(::Any, z::Zero) = z +# These methods are narrowed to `Number` and `AbstractArray` to avoid invalidating +# the very broad `+(x, y)`, `*(x, y)` fallbacks in Base, which causes thousands of +# method invalidations across the ecosystem. Downstream packages that define custom +# types participating in MutableArithmetics rewrites should define their own +# `+(::MyType, ::Zero)` etc. methods. +Base.:*(z::Zero, ::Number) = z +Base.:*(::Number, z::Zero) = z +Base.:*(z::Zero, ::AbstractArray) = z +Base.:*(::AbstractArray, z::Zero) = z Base.:*(z::Zero, ::Zero) = z -Base.:+(::Zero, x::Any) = x -Base.:+(x::Any, ::Zero) = x +Base.:+(::Zero, x::Number) = x +Base.:+(x::Number, ::Zero) = x +Base.:+(::Zero, x::AbstractArray) = x +Base.:+(x::AbstractArray, ::Zero) = x Base.:+(z::Zero, ::Zero) = z -Base.:-(::Zero, x::Any) = -x -Base.:-(x::Any, ::Zero) = x +Base.:-(::Zero, x::Number) = -x +Base.:-(x::Number, ::Zero) = x +Base.:-(::Zero, x::AbstractArray) = -x +Base.:-(x::AbstractArray, ::Zero) = x Base.:-(z::Zero, ::Zero) = z Base.:-(z::Zero) = z Base.:+(z::Zero) = z Base.:*(z::Zero) = z -function Base.:/(z::Zero, x::Any) +function Base.:/(z::Zero, x::Number) if iszero(x) throw(DivideError()) else