Skip to content

Commit a233410

Browse files
author
Ian Atol
committed
memory_opt! performance and safety improvements
1 parent 2cd2b1c commit a233410

File tree

3 files changed

+249
-33
lines changed

3 files changed

+249
-33
lines changed

base/array.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,14 @@ const IMArray{T,N} = Union{Array{T, N}, ImmutableArray{T,N}}
136136
"""
137137
IMVector{T}
138138
139-
One-dimensional [`ImmutableArray`](@ref) with elements of type `T`. Alias for `IMArray{T, 1}`.
139+
One-dimensional [`ImmutableArray`](@ref) or [`Array`](@ref) with elements of type `T`. Alias for `IMArray{T, 1}`.
140140
"""
141141
const IMVector{T} = IMArray{T, 1}
142142

143143
"""
144144
IMMatrix{T}
145145
146-
Two-dimensional [`ImmutableArray`](@ref) with elements of type `T`. Alias for `IMArray{T,2}`.
146+
Two-dimensional [`ImmutableArray`](@ref) or [`Array`](@ref) with elements of type `T`. Alias for `IMArray{T,2}`.
147147
"""
148148
const IMMatrix{T} = IMArray{T, 2}
149149

base/compiler/ssair/passes.jl

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,19 +1260,22 @@ function is_allocation(stmt::Expr)
12601260
isexpr(stmt, :foreigncall) || return false
12611261
s = stmt.args[1]
12621262
isa(s, QuoteNode) && (s = s.value)
1263-
return s === :jl_alloc_array_1d
1263+
return (s === :jl_alloc_array_1d || s === :jl_alloc_array_2d || s === :jl_alloc_array_3d || s === :jl_new_array)
12641264
end
12651265

12661266
function memory_opt!(ir::IRCode)
12671267
compact = IncrementalCompact(ir, false)
12681268
uses = IdDict{Int, Vector{Int}}()
1269-
relevant = IdSet{Int}()
1270-
revisit = Int[]
1271-
function mark_val(@nospecialize val)
1269+
relevant = IdSet{Int}() # allocations
1270+
revisit = Int[] # potential targets for mutating_arrayfreeze
1271+
1272+
function mark_escape(@nospecialize val)
12721273
isa(val, SSAValue) || return
12731274
val.id in relevant && pop!(relevant, val.id)
12741275
end
1276+
12751277
for ((_, idx), stmt) in compact
1278+
12761279
if isa(stmt, ReturnNode)
12771280
isdefined(stmt, :val) || continue
12781281
val = stmt.val
@@ -1282,51 +1285,82 @@ function memory_opt!(ir::IRCode)
12821285
end
12831286
continue
12841287
end
1288+
12851289
(isexpr(stmt, :call) || isexpr(stmt, :foreigncall)) || continue
1290+
12861291
if is_allocation(stmt)
12871292
push!(relevant, idx)
12881293
# TODO: Mark everything else here
12891294
continue
12901295
end
1296+
12911297
# TODO: Replace this by interprocedural escape analysis
12921298
if is_known_call(stmt, arrayset, compact)
1299+
# arrayset expr.args:
1300+
# :(Base.arrayset)
1301+
# false
1302+
# :(%2) array
1303+
# :(%8) value
1304+
# :(%7) index
12931305
# The value being set escapes, everything else doesn't
1294-
mark_val(stmt.args[4])
1306+
(length(stmt.args) == 5) || continue # fix boundserror during precompile --- but how do we have arrayset with < 5 args?
1307+
mark_escape(stmt.args[4])
12951308
arr = stmt.args[3]
12961309
if isa(arr, SSAValue) && arr.id in relevant
12971310
(haskey(uses, arr.id)) || (uses[arr.id] = Int[])
12981311
push!(uses[arr.id], idx)
12991312
end
1313+
13001314
elseif is_known_call(stmt, Core.arrayfreeze, compact) && isa(stmt.args[2], SSAValue)
13011315
push!(revisit, idx)
1316+
1317+
elseif is_known_call(stmt, arraysize, compact) && isa(stmt.args[2], SSAValue)
1318+
arr = stmt.args[2]
1319+
typ = abstract_eval_ssavalue(arr, compact)
1320+
1321+
# make sure this call isn't going to throw
1322+
if typ != Union{} && typ <: AbstractArray && ndims(typ) == stmt.args[3]
1323+
# don't escape the array, but mark usage for dom analysis
1324+
if arr.id in relevant
1325+
(haskey(uses, arr.id)) || (uses[arr.id] = Int[])
1326+
push!(uses[arr.id], idx)
1327+
end
1328+
else # if this call throws, the array definitely escapes
1329+
for ur in userefs(stmt)
1330+
mark_escape(ur[])
1331+
end
1332+
end
13021333
else
13031334
# For now we assume everything escapes
13041335
# TODO: We could handle PhiNodes specially and improve this
13051336
for ur in userefs(stmt)
1306-
mark_val(ur[])
1337+
mark_escape(ur[])
13071338
end
13081339
end
13091340
end
1341+
13101342
ir = finish(compact)
13111343
isempty(revisit) && return ir
1344+
13121345
domtree = construct_domtree(ir.cfg.blocks)
1346+
13131347
for idx in revisit
13141348
# Make sure that the value we reference didn't escape
13151349
stmt = ir.stmts[idx][:inst]::Expr
13161350
id = (stmt.args[2]::SSAValue).id
1317-
13181351
(id in relevant) || continue
13191352

13201353
# We're ok to steal the memory if we don't dominate any uses
13211354
ok = true
1322-
for use in uses[id]
1323-
if ssadominates(ir, domtree, idx, use)
1324-
ok = false
1325-
break
1355+
if haskey(uses, id)
1356+
for use in uses[id]
1357+
if ssadominates(ir, domtree, idx, use)
1358+
ok = false
1359+
break
1360+
end
13261361
end
13271362
end
13281363
ok || continue
1329-
13301364
stmt.args[1] = Core.mutating_arrayfreeze
13311365
end
13321366
return ir

test/compiler/immutablearray.jl

Lines changed: 201 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function test_allocate()
99
ImmutableArray(a)
1010
end
1111
let
12-
@allocated(test_allocate())
12+
#@allocated(test_allocate())
1313
#@test @allocated(test_allocate()) < 100
1414
end
1515

@@ -18,32 +18,214 @@ function test_broadcast1()
1818
@test typeof(a .+ a) <: Core.ImmutableArray
1919
end
2020

21-
function test_broadcast2()
22-
a = Core.ImmutableArray([1,2,3])
23-
@test typeof(a .+ 1) <: Core.ImmutableArray
21+
#DiffEq / Performance Tests
22+
23+
using DifferentialEquations
24+
using StaticArrays
25+
26+
function _build_atsit5_caches(::Type{T}) where {T}
27+
28+
cs = SVector{6, T}(0.161, 0.327, 0.9, 0.9800255409045097, 1.0, 1.0)
29+
30+
as = SVector{21, T}(
31+
#=a21=# convert(T,0.161),
32+
#=a31=# convert(T,-0.008480655492356989),
33+
#=a32=# convert(T,0.335480655492357),
34+
#=a41=# convert(T,2.8971530571054935),
35+
#=a42=# convert(T,-6.359448489975075),
36+
#=a43=# convert(T,4.3622954328695815),
37+
#=a51=# convert(T,5.325864828439257),
38+
#=a52=# convert(T,-11.748883564062828),
39+
#=a53=# convert(T,7.4955393428898365),
40+
#=a54=# convert(T,-0.09249506636175525),
41+
#=a61=# convert(T,5.86145544294642),
42+
#=a62=# convert(T,-12.92096931784711),
43+
#=a63=# convert(T,8.159367898576159),
44+
#=a64=# convert(T,-0.071584973281401),
45+
#=a65=# convert(T,-0.028269050394068383),
46+
#=a71=# convert(T,0.09646076681806523),
47+
#=a72=# convert(T,0.01),
48+
#=a73=# convert(T,0.4798896504144996),
49+
#=a74=# convert(T,1.379008574103742),
50+
#=a75=# convert(T,-3.290069515436081),
51+
#=a76=# convert(T,2.324710524099774)
52+
)
53+
54+
btildes = SVector{7,T}(
55+
convert(T,-0.00178001105222577714),
56+
convert(T,-0.0008164344596567469),
57+
convert(T,0.007880878010261995),
58+
convert(T,-0.1447110071732629),
59+
convert(T,0.5823571654525552),
60+
convert(T,-0.45808210592918697),
61+
convert(T,0.015151515151515152)
62+
)
63+
64+
rs = SVector{22, T}(
65+
#=r11=# convert(T,1.0),
66+
#=r12=# convert(T,-2.763706197274826),
67+
#=r13=# convert(T,2.9132554618219126),
68+
#=r14=# convert(T,-1.0530884977290216),
69+
#=r22=# convert(T,0.13169999999999998),
70+
#=r23=# convert(T,-0.2234),
71+
#=r24=# convert(T,0.1017),
72+
#=r32=# convert(T,3.9302962368947516),
73+
#=r33=# convert(T,-5.941033872131505),
74+
#=r34=# convert(T,2.490627285651253),
75+
#=r42=# convert(T,-12.411077166933676),
76+
#=r43=# convert(T,30.33818863028232),
77+
#=r44=# convert(T,-16.548102889244902),
78+
#=r52=# convert(T,37.50931341651104),
79+
#=r53=# convert(T,-88.1789048947664),
80+
#=r54=# convert(T,47.37952196281928),
81+
#=r62=# convert(T,-27.896526289197286),
82+
#=r63=# convert(T,65.09189467479366),
83+
#=r64=# convert(T,-34.87065786149661),
84+
#=r72=# convert(T,1.5),
85+
#=r73=# convert(T,-4),
86+
#=r74=# convert(T,2.5),
87+
)
88+
return cs, as, btildes, rs
2489
end
2590

26-
function test_diffeq()
91+
function test_imarrays()
2792
function lorenz(u, p, t)
2893
a,b,c = u
2994
x,y,z = p
3095
dx_dt = x * (b - a)
3196
dy_dt = a*(y - c) - b
3297
dz_dt = a*b - z * c
33-
Core.ImmutableArray([dx_dt, dy_dt, dz_dt])
98+
res = Vector{Float64}(undef, 3)
99+
res[1], res[2], res[3] = dx_dt, dy_dt, dz_dt
100+
Core.ImmutableArray(res)
34101
end
35-
u0 = Core.ImmutableArray([1.0, 1.0, 1.0])
36-
tspan = (0.0, 100.0)
37-
p = (10.0, 28.0, 8.0/3.0)
38-
prob = ODEProblem(lorenz, u0, tspan, p)
39-
sol = solve(prob)
40-
@test typeof(sol[1]) <: Core.ImmutableArray
41-
@test typeof(sol[1]) == typeof(sol[423])
42-
end
43102

44-
let
45-
test_broadcast1()
46-
test_broadcast2()
47-
#test_diffeq() disabled bc big dependency
48-
end
103+
_u0 = Core.ImmutableArray([1.0, 1.0, 1.0])
104+
_tspan = (0.0, 100.0)
105+
_p = (10.0, 28.0, 8.0/3.0)
106+
prob = ODEProblem(lorenz, _u0, _tspan, _p)
107+
108+
u0 = prob.u0
109+
tspan = prob.tspan
110+
f = prob.f
111+
p = prob.p
112+
113+
dt = 0.1f0
114+
saveat = nothing
115+
save_everystep = true
116+
abstol = 1f-6
117+
reltol = 1f-3
118+
119+
t = tspan[1]
120+
tf = prob.tspan[2]
121+
122+
beta1 = 7/50
123+
beta2 = 2/25
124+
qmax = 10.0
125+
qmin = 1/5
126+
gamma = 9/10
127+
qoldinit = 1e-4
128+
129+
if saveat === nothing
130+
ts = Vector{eltype(dt)}(undef,1)
131+
ts[1] = prob.tspan[1]
132+
us = Vector{typeof(u0)}(undef,0)
133+
push!(us,recursivecopy(u0))
134+
else
135+
ts = saveat
136+
cur_t = 1
137+
us = MVector{length(ts),typeof(u0)}(undef)
138+
if prob.tspan[1] == ts[1]
139+
cur_t += 1
140+
us[1] = u0
141+
end
142+
end
143+
144+
u = u0
145+
qold = 1e-4
146+
k7 = f(u, p, t)
49147

148+
cs, as, btildes, rs = _build_atsit5_caches(eltype(u0))
149+
c1, c2, c3, c4, c5, c6 = cs
150+
a21, a31, a32, a41, a42, a43, a51, a52, a53, a54,
151+
a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76 = as
152+
btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7 = btildes
153+
154+
# FSAL
155+
while t < tspan[2]
156+
uprev = u
157+
k1 = k7
158+
EEst = Inf
159+
160+
while EEst > 1
161+
dt < 1e-14 && error("dt<dtmin")
162+
163+
tmp = uprev+dt*a21*k1
164+
k2 = f(tmp, p, t+c1*dt)
165+
tmp = uprev+dt*(a31*k1+a32*k2)
166+
k3 = f(tmp, p, t+c2*dt)
167+
tmp = uprev+dt*(a41*k1+a42*k2+a43*k3)
168+
k4 = f(tmp, p, t+c3*dt)
169+
tmp = uprev+dt*(a51*k1+a52*k2+a53*k3+a54*k4)
170+
k5 = f(tmp, p, t+c4*dt)
171+
tmp = uprev+dt*(a61*k1+a62*k2+a63*k3+a64*k4+a65*k5)
172+
k6 = f(tmp, p, t+dt)
173+
u = uprev+dt*(a71*k1+a72*k2+a73*k3+a74*k4+a75*k5+a76*k6)
174+
k7 = f(u, p, t+dt)
175+
176+
tmp = dt*(btilde1*k1+btilde2*k2+btilde3*k3+btilde4*k4+
177+
btilde5*k5+btilde6*k6+btilde7*k7)
178+
tmp = tmp./(abstol.+max.(abs.(uprev),abs.(u))*reltol)
179+
EEst = DiffEqBase.ODE_DEFAULT_NORM(tmp, t)
180+
181+
if iszero(EEst)
182+
q = inv(qmax)
183+
else
184+
@fastmath q11 = EEst^beta1
185+
@fastmath q = q11/(qold^beta2)
186+
end
187+
188+
if EEst > 1
189+
dt = dt/min(inv(qmin),q11/gamma)
190+
else # EEst <= 1
191+
@fastmath q = max(inv(qmax),min(inv(qmin),q/gamma))
192+
qold = max(EEst,qoldinit)
193+
dtold = dt
194+
dt = dt/q #dtnew
195+
dt = min(abs(dt),abs(tf-t-dtold))
196+
told = t
197+
198+
if (tf - t - dtold) < 1e-14
199+
t = tf
200+
else
201+
t += dtold
202+
end
203+
204+
if saveat === nothing && save_everystep
205+
push!(us,recursivecopy(u))
206+
push!(ts,t)
207+
else saveat !== nothing
208+
while cur_t <= length(ts) && ts[cur_t] <= t
209+
savet = ts[cur_t]
210+
θ = (savet - told)/dtold
211+
b1θ, b2θ, b3θ, b4θ, b5θ, b6θ, b7θ = bθs(rs, θ)
212+
us[cur_t] = uprev + dtold*(
213+
b1θ*k1 + b2θ*k2 + b3θ*k3 + b4θ*k4 + b5θ*k5 + b6θ*k6 + b7θ*k7)
214+
cur_t += 1
215+
end
216+
end
217+
end
218+
end
219+
end
220+
221+
if saveat === nothing && !save_everystep
222+
push!(us,u)
223+
push!(ts,t)
224+
end
225+
226+
sol = DiffEqBase.build_solution(prob,Tsit5(),ts,us,calculate_error = false)
227+
228+
DiffEqBase.has_analytic(prob.f) && DiffEqBase.calculate_solution_errors!(sol;timeseries_errors=true,dense_errors=false)
229+
230+
sol
231+
end

0 commit comments

Comments
 (0)