@@ -1256,76 +1256,184 @@ function cfg_simplify!(ir::IRCode)
12561256 return finish (compact)
12571257end
12581258
1259- function is_allocation (stmt)
1259+ # function is_known_fcall(stmt::Expr, @nospecialize(func))
1260+ # isexpr(stmt, :foreigncall) || return false
1261+ # s = stmt.args[1]
1262+ # isa(s, QuoteNode) && (s = s.value)
1263+ # return s === func
1264+ # end
1265+
1266+ function is_known_fcall (stmt:: Expr , funcs:: Vector{Symbol} )
12601267 isexpr (stmt, :foreigncall ) || return false
12611268 s = stmt. args[1 ]
12621269 isa (s, QuoteNode) && (s = s. value)
1263- return s === :jl_alloc_array_1d
1270+ # return any(e -> s === e, funcs)
1271+ return true in map (e -> s === e, funcs)
1272+ end
1273+
1274+ function is_allocation (stmt:: Expr )
1275+ isexpr (stmt, :foreigncall ) || return false
1276+ s = stmt. args[1 ]
1277+ isa (s, QuoteNode) && (s = s. value)
1278+ return (s === :jl_alloc_array_1d
1279+ || s === :jl_alloc_array_2d
1280+ || s === :jl_alloc_array_3d
1281+ || s === :jl_new_array )
12641282end
12651283
12661284function memory_opt! (ir:: IRCode )
12671285 compact = IncrementalCompact (ir, false )
12681286 uses = IdDict {Int, Vector{Int}} ()
1269- relevant = IdSet {Int} ()
1270- revisit = Int[]
1271- function mark_val (val)
1287+ relevant = IdSet {Int} () # allocations
1288+ revisit = Int[] # potential targets for a mutating_arrayfreeze drop-in
1289+ maybecopies = Int[] # calls to maybecopy
1290+
1291+ function mark_escape (@nospecialize val)
12721292 isa (val, SSAValue) || return
1293+ # println(val.id, " escaped.")
12731294 val. id in relevant && pop! (relevant, val. id)
12741295 end
1296+
1297+ function mark_use (val, idx)
1298+ isa (val, SSAValue) || return
1299+ id = val. id
1300+ id in relevant || return
1301+ (haskey (uses, id)) || (uses[id] = Int[])
1302+ push! (uses[id], idx)
1303+ end
1304+
12751305 for ((_, idx), stmt) in compact
1306+
1307+ # println("idx: ", idx, " = ", stmt)
1308+
12761309 if isa (stmt, ReturnNode)
12771310 isdefined (stmt, :val ) || continue
12781311 val = stmt. val
1279- if isa (val, SSAValue) && val. id in relevant
1280- (haskey (uses, val. id)) || (uses[val. id] = Int[])
1281- push! (uses[val. id], idx)
1282- end
1312+ mark_use (val, idx)
12831313 continue
1314+
1315+ # check for phinodes that are possibly allocations
1316+ elseif isa (stmt, PhiNode)
1317+
1318+ # ensure all of the phinode values are defined
1319+ defined = true
1320+ for i = 1 : length (stmt. values)
1321+ if ! isassigned (stmt. values, i)
1322+ defined = false
1323+ end
1324+ end
1325+
1326+ defined || continue
1327+
1328+ for val in stmt. values
1329+ if isa (val, SSAValue) && val. id in relevant
1330+ push! (relevant, idx)
1331+ end
1332+ end
12841333 end
1334+
12851335 (isexpr (stmt, :call ) || isexpr (stmt, :foreigncall )) || continue
1336+
1337+ if is_known_call (stmt, Core. maybecopy, compact)
1338+ push! (maybecopies, idx)
1339+ continue
1340+ end
1341+
12861342 if is_allocation (stmt)
12871343 push! (relevant, idx)
12881344 # TODO : Mark everything else here
12891345 continue
12901346 end
1291- # TODO : Replace this by interprocedural escape analysis
1292- if is_known_call (stmt, arrayset, compact)
1347+
1348+ if is_known_call (stmt, arrayset, compact) && length (stmt . args) >= 5
12931349 # The value being set escapes, everything else doesn't
1294- mark_val (stmt. args[4 ])
1350+ mark_escape (stmt. args[4 ])
12951351 arr = stmt. args[3 ]
1296- if isa (arr, SSAValue) && arr. id in relevant
1297- (haskey (uses, arr. id)) || (uses[arr. id] = Int[])
1298- push! (uses[arr. id], idx)
1299- end
1352+ mark_use (arr, idx)
1353+
1354+ elseif is_known_call (stmt, arrayref, compact) && length (stmt. args) == 4
1355+ arr = stmt. args[3 ]
1356+ mark_use (arr, idx)
1357+
1358+ elseif is_known_call (stmt, setindex!, compact) && length (stmt. args) == 4
1359+ # handle similarly to arrayset
1360+ val = stmt. args[3 ]
1361+ mark_escape (val)
1362+
1363+ arr = stmt. args[2 ]
1364+ mark_use (arr, idx)
1365+
1366+ elseif is_known_call (stmt, (=== ), compact) && length (stmt. args) == 3
1367+ arr1 = stmt. args[2 ]
1368+ arr2 = stmt. args[3 ]
1369+
1370+ mark_use (arr1, idx)
1371+ mark_use (arr2, idx)
1372+
1373+ # these foreigncalls have similar structure and don't escape our array, so handle them all at once
1374+ elseif is_known_fcall (stmt, [:jl_array_ptr , :jl_array_copy ]) && length (stmt. args) == 6
1375+ arr = stmt. args[6 ]
1376+ mark_use (arr, idx)
1377+
1378+ elseif is_known_call (stmt, arraysize, compact) && isa (stmt. args[2 ], SSAValue)
1379+ arr = stmt. args[2 ]
1380+ mark_use (arr, idx)
1381+
13001382 elseif is_known_call (stmt, Core. arrayfreeze, compact) && isa (stmt. args[2 ], SSAValue)
1383+ # mark these for potential replacement with mutating_arrayfreeze
13011384 push! (revisit, idx)
1385+
13021386 else
1303- # For now we assume everything escapes
1304- # TODO : We could handle PhiNodes specially and improve this
1387+ # Assume everything else escapes
13051388 for ur in userefs (stmt)
1306- mark_val (ur[])
1389+ mark_escape (ur[])
13071390 end
13081391 end
13091392 end
1393+
13101394 ir = finish (compact)
1311- isempty (revisit) && return ir
1395+ isempty (revisit) && isempty (maybecopies) && return ir
1396+
13121397 domtree = construct_domtree (ir. cfg. blocks)
1398+
13131399 for idx in revisit
13141400 # Make sure that the value we reference didn't escape
1315- id = ir. stmts[idx][:inst ]. args[2 ]. id
1401+ stmt = ir. stmts[idx][:inst ]:: Expr
1402+ id = (stmt. args[2 ]:: SSAValue ). id
13161403 (id in relevant) || continue
13171404
1405+ # println("Revisiting ", stmt)
1406+
13181407 # We're ok to steal the memory if we don't dominate any uses
13191408 ok = true
1320- for use in uses[id]
1321- if ssadominates (ir, domtree, idx, use)
1322- ok = false
1323- break
1409+ if haskey (uses, id)
1410+ for use in uses[id]
1411+ if ssadominates (ir, domtree, idx, use)
1412+ ok = false
1413+ break
1414+ end
13241415 end
13251416 end
13261417 ok || continue
1327-
1328- ir. stmts[idx][:inst ]. args[1 ] = Core. mutating_arrayfreeze
1418+ stmt. args[1 ] = Core. mutating_arrayfreeze
13291419 end
1420+
1421+ # TODO : Use escape analysis info to determine if maybecopy should copy
1422+
1423+ # for idx in maybecopies
1424+ # stmt = ir.stmts[idx][:inst]::Expr
1425+ # #println(stmt.args)
1426+ # arr = stmt.args[2]
1427+ # id = isa(arr, SSAValue) ? arr.id : arr.n # SSAValue or Core.Argument
1428+
1429+ # if (id in relevant) # didn't escape elsewhere, so make a copy to keep it un-escaped
1430+ # #println("didn't escape maybecopy")
1431+ # stmt.args[1] = Main.Base.copy
1432+ # else # already escaped, so save the cost of copying and just pass the actual object
1433+ # #println("escaped maybecopy")
1434+ # ir.stmts[idx][:inst] = arr
1435+ # end
1436+ # end
1437+
13301438 return ir
13311439end
0 commit comments