Skip to content

Commit 5d03fcd

Browse files
aviateskKristofferC
authored andcommitted
inference: revive CachedMethodTable mechanism
`CachedMethodTable` was removed within #44240 as we couldn't confirm any performance improvement then. However it turns out the optimization was critical in some real world cases (e.g. #46492), so this commit revives the mechanism with the following tweaks that should make it more effective: - create method table cache per inference (rather than per local inference on a function call as on the previous implementation) - only use cache mechanism for abstract types (since we already cache lookup result at the next level as for concrete types) As a result, the following snippet reported at #46492 recovers the compilation performance: ```julia using ControlSystems a_2 = [-5 -3; 2 -9] C_212 = ss(a_2, [1; 2], [1 0; 0 1], [0; 0]) @time norm(C_212) ``` > on master ``` julia> @time norm(C_212) 364.489044 seconds (724.44 M allocations: 92.524 GiB, 6.01% gc time, 100.00% compilation time) 0.5345224838248489 ``` > on this commit ``` julia> @time norm(C_212) 26.539016 seconds (62.09 M allocations: 5.537 GiB, 5.55% gc time, 100.00% compilation time) 0.5345224838248489 ``` (cherry picked from commit 8445744)
1 parent c65611a commit 5d03fcd

File tree

4 files changed

+68
-37
lines changed

4 files changed

+68
-37
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
282282
if result === missing
283283
return FailedMethodMatch("For one of the union split cases, too many methods matched")
284284
end
285-
matches, overlayed = result
285+
(; matches, overlayed) = result
286286
nonoverlayed &= !overlayed
287287
push!(infos, MethodMatchInfo(matches))
288288
for m in matches
@@ -323,7 +323,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
323323
# (assume this will always be true, so we don't compute / update valid age in this case)
324324
return FailedMethodMatch("Too many methods matched")
325325
end
326-
matches, overlayed = result
326+
(; matches, overlayed) = result
327327
fullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
328328
return MethodMatches(matches.matches,
329329
MethodMatchInfo(matches),

base/compiler/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,10 @@ something(x::Any, y...) = x
123123
############
124124

125125
include("compiler/cicache.jl")
126+
include("compiler/methodtable.jl")
126127
include("compiler/types.jl")
127128
include("compiler/utilities.jl")
128129
include("compiler/validation.jl")
129-
include("compiler/methodtable.jl")
130130

131131
include("compiler/inferenceresult.jl")
132132
include("compiler/inferencestate.jl")

base/compiler/methodtable.jl

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,27 @@
22

33
abstract type MethodTableView; end
44

5+
struct MethodLookupResult
6+
# Really Vector{Core.MethodMatch}, but it's easier to represent this as
7+
# and work with Vector{Any} on the C side.
8+
matches::Vector{Any}
9+
valid_worlds::WorldRange
10+
ambig::Bool
11+
end
12+
length(result::MethodLookupResult) = length(result.matches)
13+
function iterate(result::MethodLookupResult, args...)
14+
r = iterate(result.matches, args...)
15+
r === nothing && return nothing
16+
match, state = r
17+
return (match::MethodMatch, state)
18+
end
19+
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch
20+
21+
struct MethodMatchResult
22+
matches::MethodLookupResult
23+
overlayed::Bool
24+
end
25+
526
"""
627
struct InternalMethodTable <: MethodTableView
728
@@ -23,25 +44,21 @@ struct OverlayMethodTable <: MethodTableView
2344
mt::Core.MethodTable
2445
end
2546

26-
struct MethodLookupResult
27-
# Really Vector{Core.MethodMatch}, but it's easier to represent this as
28-
# and work with Vector{Any} on the C side.
29-
matches::Vector{Any}
30-
valid_worlds::WorldRange
31-
ambig::Bool
32-
end
33-
length(result::MethodLookupResult) = length(result.matches)
34-
function iterate(result::MethodLookupResult, args...)
35-
r = iterate(result.matches, args...)
36-
r === nothing && return nothing
37-
match, state = r
38-
return (match::MethodMatch, state)
47+
"""
48+
struct CachedMethodTable <: MethodTableView
49+
50+
Overlays another method table view with an additional local fast path cache that
51+
can respond to repeated, identical queries faster than the original method table.
52+
"""
53+
struct CachedMethodTable{T} <: MethodTableView
54+
cache::IdDict{Any, Union{Missing, MethodMatchResult}}
55+
table::T
3956
end
40-
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch
57+
CachedMethodTable(table::T) where T = CachedMethodTable{T}(IdDict{Any, Union{Missing, MethodMatchResult}}(), table)
4158

4259
"""
4360
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) ->
44-
(matches::MethodLookupResult, overlayed::Bool) or missing
61+
MethodMatchResult(matches::MethodLookupResult, overlayed::Bool) or missing
4562
4663
Find all methods in the given method table `view` that are applicable to the given signature `sig`.
4764
If no applicable methods are found, an empty result is returned.
@@ -51,7 +68,7 @@ If the number of applicable methods exceeded the specified limit, `missing` is r
5168
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=Int(typemax(Int32)))
5269
result = _findall(sig, nothing, table.world, limit)
5370
result === missing && return missing
54-
return result, false
71+
return MethodMatchResult(result, false)
5572
end
5673

5774
function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=Int(typemax(Int32)))
@@ -60,18 +77,20 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
6077
nr = length(result)
6178
if nr 1 && result[nr].fully_covers
6279
# no need to fall back to the internal method table
63-
return result, true
80+
return MethodMatchResult(result, true)
6481
end
6582
# fall back to the internal method table
6683
fallback_result = _findall(sig, nothing, table.world, limit)
6784
fallback_result === missing && return missing
6885
# merge the fallback match results with the internal method table
69-
return MethodLookupResult(
70-
vcat(result.matches, fallback_result.matches),
71-
WorldRange(
72-
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
73-
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
74-
result.ambig | fallback_result.ambig), !isempty(result)
86+
return MethodMatchResult(
87+
MethodLookupResult(
88+
vcat(result.matches, fallback_result.matches),
89+
WorldRange(
90+
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
91+
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
92+
result.ambig | fallback_result.ambig),
93+
!isempty(result))
7594
end
7695

7796
function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
@@ -85,6 +104,17 @@ function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
85104
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
86105
end
87106

107+
function findall(@nospecialize(sig::Type), table::CachedMethodTable; limit::Int=typemax(Int))
108+
if isconcretetype(sig)
109+
# as for concrete types, we cache result at on the next level
110+
return findall(sig, table.table; limit)
111+
end
112+
box = Core.Box(sig)
113+
return get!(table.cache, sig) do
114+
findall(box.contents, table.table; limit)
115+
end
116+
end
117+
88118
"""
89119
findsup(sig::Type, view::MethodTableView) ->
90120
(match::MethodMatch, valid_worlds::WorldRange, overlayed::Bool) or nothing
@@ -129,6 +159,10 @@ function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
129159
return match, valid_worlds
130160
end
131161

162+
# This query is not cached
163+
findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table)
164+
132165
isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
133166
isoverlayed(::InternalMethodTable) = false
134167
isoverlayed(::OverlayMethodTable) = true
168+
isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)

base/compiler/types.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,8 @@ struct NativeInterpreter <: AbstractInterpreter
318318
cache::Vector{InferenceResult}
319319
# The world age we're working inside of
320320
world::UInt
321+
# method table to lookup for during inference on this world age
322+
method_table::CachedMethodTable{InternalMethodTable}
321323

322324
# Parameters for inference and optimization
323325
inf_params::InferenceParams
@@ -327,27 +329,21 @@ struct NativeInterpreter <: AbstractInterpreter
327329
inf_params = InferenceParams(),
328330
opt_params = OptimizationParams(),
329331
)
332+
cache = Vector{InferenceResult}() # Initially empty cache
333+
330334
# Sometimes the caller is lazy and passes typemax(UInt).
331335
# we cap it to the current world age
332336
if world == typemax(UInt)
333337
world = get_world_counter()
334338
end
335339

340+
method_table = CachedMethodTable(InternalMethodTable(world))
341+
336342
# If they didn't pass typemax(UInt) but passed something more subtly
337343
# incorrect, fail out loudly.
338344
@assert world <= get_world_counter()
339345

340-
return new(
341-
# Initially empty cache
342-
Vector{InferenceResult}(),
343-
344-
# world age counter
345-
world,
346-
347-
# parameters for inference and optimization
348-
inf_params,
349-
opt_params,
350-
)
346+
return new(cache, world, method_table, inf_params, opt_params)
351347
end
352348
end
353349

@@ -396,6 +392,7 @@ External `AbstractInterpreter` can optionally return `OverlayMethodTable` here
396392
to incorporate customized dispatches for the overridden methods.
397393
"""
398394
method_table(interp::AbstractInterpreter) = InternalMethodTable(get_world_counter(interp))
395+
method_table(interp::NativeInterpreter) = interp.method_table
399396

400397
"""
401398
By default `AbstractInterpreter` implements the following inference bail out logic:

0 commit comments

Comments
 (0)