Skip to content

Commit 19f44b6

Browse files
author
Ian Atol
authored
Allow inlining methods with unmatched type parameters (#45062)
1 parent 36aab14 commit 19f44b6

File tree

17 files changed

+355
-82
lines changed

17 files changed

+355
-82
lines changed

base/compiler/ssair/inlining.jl

Lines changed: 117 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -376,11 +376,28 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
376376
boundscheck = :off
377377
end
378378
end
379+
if !validate_sparams(sparam_vals)
380+
if def.isva
381+
nonva_args = argexprs[1:end-1]
382+
va_arg = argexprs[end]
383+
tuple_call = Expr(:call, TOP_TUPLE, def, nonva_args...)
384+
tuple_type = tuple_tfunc(Any[argextype(arg, compact) for arg in nonva_args])
385+
tupl = insert_node_here!(compact, NewInstruction(tuple_call, tuple_type, topline))
386+
apply_iter_expr = Expr(:call, Core._apply_iterate, iterate, Core._compute_sparams, tupl, va_arg)
387+
sparam_vals = insert_node_here!(compact,
388+
effect_free(NewInstruction(apply_iter_expr, SimpleVector, topline)))
389+
else
390+
sparam_vals = insert_node_here!(compact,
391+
effect_free(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline)))
392+
end
393+
end
379394
# If the iterator already moved on to the next basic block,
380395
# temporarily re-open in again.
381396
local return_value
382397
sig = def.sig
383398
# Special case inlining that maintains the current basic block if there's only one BB in the target
399+
new_new_offset = length(compact.new_new_nodes)
400+
late_fixup_offset = length(compact.late_fixup)
384401
if spec.linear_inline_eligible
385402
#compact[idx] = nothing
386403
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
@@ -389,7 +406,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
389406
# face of rename_arguments! mutating in place - should figure out
390407
# something better eventually.
391408
inline_compact[idx′] = nothing
392-
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
409+
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
393410
if isa(stmt′, ReturnNode)
394411
val = stmt′.val
395412
return_value = SSAValue(idx′)
@@ -402,7 +419,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
402419
end
403420
inline_compact[idx′] = stmt′
404421
end
405-
just_fixup!(inline_compact)
422+
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
406423
compact.result_idx = inline_compact.result_idx
407424
else
408425
bb_offset, post_bb_id = popfirst!(todo_bbs)
@@ -416,7 +433,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
416433
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
417434
for ((_, idx′), stmt′) in inline_compact
418435
inline_compact[idx′] = nothing
419-
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
436+
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
420437
if isa(stmt′, ReturnNode)
421438
if isdefined(stmt′, :val)
422439
val = stmt′.val
@@ -436,7 +453,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
436453
end
437454
inline_compact[idx′] = stmt′
438455
end
439-
just_fixup!(inline_compact)
456+
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
440457
compact.result_idx = inline_compact.result_idx
441458
compact.active_result_bb = inline_compact.active_result_bb
442459
if length(pn.edges) == 1
@@ -460,7 +477,8 @@ function fix_va_argexprs!(compact::IncrementalCompact,
460477
push!(tuple_typs, argextype(arg, compact))
461478
end
462479
tuple_typ = tuple_tfunc(tuple_typs)
463-
push!(newargexprs, insert_node_here!(compact, NewInstruction(tuple_call, tuple_typ, line_idx)))
480+
tuple_inst = NewInstruction(tuple_call, tuple_typ, line_idx)
481+
push!(newargexprs, insert_node_here!(compact, tuple_inst))
464482
return newargexprs
465483
end
466484

@@ -875,8 +893,26 @@ function validate_sparams(sparams::SimpleVector)
875893
return true
876894
end
877895

896+
function may_have_fcalls(m::Method)
897+
may_have_fcall = true
898+
if isdefined(m, :source)
899+
src = m.source
900+
isa(src, Vector{UInt8}) && (src = uncompressed_ir(m))
901+
if isa(src, CodeInfo)
902+
may_have_fcall = src.has_fcall
903+
end
904+
end
905+
return may_have_fcall
906+
end
907+
908+
function can_inline_typevars(m::MethodMatch, argtypes::Vector{Any})
909+
may_have_fcalls(m.method) && return false
910+
any(@nospecialize(x) -> x isa UnionAll, argtypes[2:end]) && return false
911+
return true
912+
end
913+
878914
function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, invokesig,
879-
flag::UInt8, state::InliningState)
915+
flag::UInt8, state::InliningState, allow_typevars::Bool = false)
880916
method = match.method
881917
spec_types = match.spec_types
882918

@@ -898,8 +934,9 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, invokesig,
898934
end
899935
end
900936

901-
# Bail out if any static parameters are left as TypeVar
902-
validate_sparams(match.sparams) || return nothing
937+
if !validate_sparams(match.sparams)
938+
(allow_typevars && can_inline_typevars(match, argtypes)) || return nothing
939+
end
903940

904941
et = state.et
905942

@@ -1231,6 +1268,9 @@ function compute_inlining_cases(infos::Vector{MethodMatchInfo},
12311268
flag::UInt8, sig::Signature, state::InliningState)
12321269
argtypes = sig.argtypes
12331270
cases = InliningCase[]
1271+
local only_method = nothing
1272+
local meth::MethodLookupResult
1273+
local revisit_idx = nothing
12341274
local any_fully_covered = false
12351275
local handled_all_cases = true
12361276
for i in 1:length(infos)
@@ -1243,14 +1283,58 @@ function compute_inlining_cases(infos::Vector{MethodMatchInfo},
12431283
# No applicable methods; try next union split
12441284
handled_all_cases = false
12451285
continue
1286+
else
1287+
if length(meth) == 1 && only_method !== false
1288+
if only_method === nothing
1289+
only_method = meth[1].method
1290+
elseif only_method !== meth[1].method
1291+
only_method = false
1292+
end
1293+
else
1294+
only_method = false
1295+
end
12461296
end
1247-
for match in meth
1248-
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true)
1297+
for (j, match) in enumerate(meth)
12491298
any_fully_covered |= match.fully_covers
1299+
if !validate_sparams(match.sparams)
1300+
if !match.fully_covers
1301+
handled_all_cases = false
1302+
continue
1303+
end
1304+
if revisit_idx === nothing
1305+
revisit_idx = (i, j)
1306+
else
1307+
handled_all_cases = false
1308+
revisit_idx = nothing
1309+
end
1310+
else
1311+
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false)
1312+
end
12501313
end
12511314
end
12521315

1253-
if !handled_all_cases
1316+
if handled_all_cases && revisit_idx !== nothing
1317+
# we handled everything except one match with unmatched sparams,
1318+
# so try to handle it by bypassing validate_sparams
1319+
(i, j) = revisit_idx
1320+
match = infos[i].results[j]
1321+
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#true)
1322+
elseif length(cases) == 0 && only_method isa Method
1323+
# if the signature is fully covered and there is only one applicable method,
1324+
# we can try to inline it even in the prescence of unmatched sparams
1325+
# -- But don't try it if we already tried to handle the match in the revisit_idx
1326+
# case, because that'll (necessarily) be the same method.
1327+
if length(infos) > 1
1328+
atype = argtypes_to_type(argtypes)
1329+
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), atype, only_method.sig)::SimpleVector
1330+
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
1331+
else
1332+
@assert length(meth) == 1
1333+
match = meth[1]
1334+
end
1335+
handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#true) || return nothing
1336+
any_fully_covered = handled_all_cases = match.fully_covers
1337+
elseif !handled_all_cases
12541338
# if we've not seen all candidates, union split is valid only for dispatch tuples
12551339
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
12561340
end
@@ -1286,10 +1370,10 @@ function compute_inlining_cases(info::ConstCallInfo,
12861370
case = concrete_result_item(result, state)
12871371
push!(cases, InliningCase(result.mi.specTypes, case))
12881372
elseif isa(result, ConstPropResult)
1289-
handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, true)
1373+
handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, #=allow_abstract=#true)
12901374
else
12911375
@assert result === nothing
1292-
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true)
1376+
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false)
12931377
end
12941378
end
12951379
end
@@ -1324,22 +1408,22 @@ end
13241408

13251409
function handle_match!(
13261410
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
1327-
cases::Vector{InliningCase}, allow_abstract::Bool = false)
1411+
cases::Vector{InliningCase}, allow_abstract::Bool, allow_typevars::Bool)
13281412
spec_types = match.spec_types
13291413
allow_abstract || isdispatchtuple(spec_types) || return false
1330-
# we may see duplicated dispatch signatures here when a signature gets widened
1414+
# We may see duplicated dispatch signatures here when a signature gets widened
13311415
# during abstract interpretation: for the purpose of inlining, we can just skip
1332-
# processing this dispatch candidate
1333-
_any(case->case.sig === spec_types, cases) && return true
1334-
item = analyze_method!(match, argtypes, nothing, flag, state)
1416+
# processing this dispatch candidate (unless unmatched type parameters are present)
1417+
!allow_typevars && _any(case->case.sig === spec_types, cases) && return true
1418+
item = analyze_method!(match, argtypes, nothing, flag, state, allow_typevars)
13351419
item === nothing && return false
13361420
push!(cases, InliningCase(spec_types, item))
13371421
return true
13381422
end
13391423

13401424
function handle_const_prop_result!(
13411425
result::ConstPropResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
1342-
cases::Vector{InliningCase}, allow_abstract::Bool = false)
1426+
cases::Vector{InliningCase}, allow_abstract::Bool)
13431427
(; mi) = item = InliningTodo(result.result, argtypes)
13441428
spec_types = mi.specTypes
13451429
allow_abstract || isdispatchtuple(spec_types) || return false
@@ -1624,30 +1708,37 @@ function late_inline_special_case!(
16241708
end
16251709

16261710
function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{Any},
1627-
@nospecialize(spsig), spvals::SimpleVector,
1711+
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue},
16281712
linetable_offset::Int32, boundscheck::Symbol, compact::IncrementalCompact)
16291713
compact.result[idx][:flag] &= ~IR_FLAG_INBOUNDS
16301714
compact.result[idx][:line] += linetable_offset
1631-
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck)
1715+
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck, compact, idx)
16321716
end
16331717

16341718
function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
1635-
@nospecialize(spsig), spvals::SimpleVector, boundscheck::Symbol)
1719+
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue},
1720+
boundscheck::Symbol, compact::IncrementalCompact, idx::Int)
16361721
if isa(val, Argument)
16371722
return arg_replacements[val.n]
16381723
end
16391724
if isa(val, Expr)
16401725
e = val::Expr
16411726
head = e.head
16421727
if head === :static_parameter
1643-
return quoted(spvals[e.args[1]::Int])
1644-
elseif head === :cfunction
1728+
if isa(spvals, SimpleVector)
1729+
return quoted(spvals[e.args[1]::Int])
1730+
else
1731+
ret = insert_node!(compact, SSAValue(idx),
1732+
effect_free(NewInstruction(Expr(:call, Core._svec_ref, false, spvals, e.args[1]), Any)))
1733+
return ret
1734+
end
1735+
elseif head === :cfunction && isa(spvals, SimpleVector)
16451736
@assert !isa(spsig, UnionAll) || !isempty(spvals)
16461737
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
16471738
e.args[4] = svec(Any[
16481739
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
16491740
for argt in e.args[4]::SimpleVector ]...)
1650-
elseif head === :foreigncall
1741+
elseif head === :foreigncall && isa(spvals, SimpleVector)
16511742
@assert !isa(spsig, UnionAll) || !isempty(spvals)
16521743
for i = 1:length(e.args)
16531744
if i == 2
@@ -1671,7 +1762,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
16711762
isa(val, Union{SSAValue, NewSSAValue}) && return val # avoid infinite loop
16721763
urs = userefs(val)
16731764
for op in urs
1674-
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck)
1765+
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck, compact, idx)
16751766
end
16761767
return urs[]
16771768
end

0 commit comments

Comments
 (0)