@@ -9,7 +9,7 @@ function test_allocate()
99 ImmutableArray (a)
1010end
1111let
12- @allocated (test_allocate ())
12+ # @allocated(test_allocate())
1313 # @test @allocated(test_allocate()) < 100
1414end
1515
@@ -18,32 +18,214 @@ function test_broadcast1()
1818 @test typeof (a .+ a) <: Core.ImmutableArray
1919end
2020
21- function test_broadcast2 ()
22- a = Core. ImmutableArray ([1 ,2 ,3 ])
23- @test typeof (a .+ 1 ) <: Core.ImmutableArray
21+ # DiffEq / Performance Tests
22+
23+ using DifferentialEquations
24+ using StaticArrays
25+
26+ function _build_atsit5_caches (:: Type{T} ) where {T}
27+
28+ cs = SVector {6, T} (0.161 , 0.327 , 0.9 , 0.9800255409045097 , 1.0 , 1.0 )
29+
30+ as = SVector {21, T} (
31+ #= a21=# convert (T,0.161 ),
32+ #= a31=# convert (T,- 0.008480655492356989 ),
33+ #= a32=# convert (T,0.335480655492357 ),
34+ #= a41=# convert (T,2.8971530571054935 ),
35+ #= a42=# convert (T,- 6.359448489975075 ),
36+ #= a43=# convert (T,4.3622954328695815 ),
37+ #= a51=# convert (T,5.325864828439257 ),
38+ #= a52=# convert (T,- 11.748883564062828 ),
39+ #= a53=# convert (T,7.4955393428898365 ),
40+ #= a54=# convert (T,- 0.09249506636175525 ),
41+ #= a61=# convert (T,5.86145544294642 ),
42+ #= a62=# convert (T,- 12.92096931784711 ),
43+ #= a63=# convert (T,8.159367898576159 ),
44+ #= a64=# convert (T,- 0.071584973281401 ),
45+ #= a65=# convert (T,- 0.028269050394068383 ),
46+ #= a71=# convert (T,0.09646076681806523 ),
47+ #= a72=# convert (T,0.01 ),
48+ #= a73=# convert (T,0.4798896504144996 ),
49+ #= a74=# convert (T,1.379008574103742 ),
50+ #= a75=# convert (T,- 3.290069515436081 ),
51+ #= a76=# convert (T,2.324710524099774 )
52+ )
53+
54+ btildes = SVector {7,T} (
55+ convert (T,- 0.00178001105222577714 ),
56+ convert (T,- 0.0008164344596567469 ),
57+ convert (T,0.007880878010261995 ),
58+ convert (T,- 0.1447110071732629 ),
59+ convert (T,0.5823571654525552 ),
60+ convert (T,- 0.45808210592918697 ),
61+ convert (T,0.015151515151515152 )
62+ )
63+
64+ rs = SVector {22, T} (
65+ #= r11=# convert (T,1.0 ),
66+ #= r12=# convert (T,- 2.763706197274826 ),
67+ #= r13=# convert (T,2.9132554618219126 ),
68+ #= r14=# convert (T,- 1.0530884977290216 ),
69+ #= r22=# convert (T,0.13169999999999998 ),
70+ #= r23=# convert (T,- 0.2234 ),
71+ #= r24=# convert (T,0.1017 ),
72+ #= r32=# convert (T,3.9302962368947516 ),
73+ #= r33=# convert (T,- 5.941033872131505 ),
74+ #= r34=# convert (T,2.490627285651253 ),
75+ #= r42=# convert (T,- 12.411077166933676 ),
76+ #= r43=# convert (T,30.33818863028232 ),
77+ #= r44=# convert (T,- 16.548102889244902 ),
78+ #= r52=# convert (T,37.50931341651104 ),
79+ #= r53=# convert (T,- 88.1789048947664 ),
80+ #= r54=# convert (T,47.37952196281928 ),
81+ #= r62=# convert (T,- 27.896526289197286 ),
82+ #= r63=# convert (T,65.09189467479366 ),
83+ #= r64=# convert (T,- 34.87065786149661 ),
84+ #= r72=# convert (T,1.5 ),
85+ #= r73=# convert (T,- 4 ),
86+ #= r74=# convert (T,2.5 ),
87+ )
88+ return cs, as, btildes, rs
2489end
2590
26- function test_diffeq ()
91+ function test_imarrays ()
2792 function lorenz (u, p, t)
2893 a,b,c = u
2994 x,y,z = p
3095 dx_dt = x * (b - a)
3196 dy_dt = a* (y - c) - b
3297 dz_dt = a* b - z * c
33- Core. ImmutableArray ([dx_dt, dy_dt, dz_dt])
98+ res = Vector {Float64} (undef, 3 )
99+ res[1 ], res[2 ], res[3 ] = dx_dt, dy_dt, dz_dt
100+ Core. ImmutableArray (res)
34101 end
35- u0 = Core. ImmutableArray ([1.0 , 1.0 , 1.0 ])
36- tspan = (0.0 , 100.0 )
37- p = (10.0 , 28.0 , 8.0 / 3.0 )
38- prob = ODEProblem (lorenz, u0, tspan, p)
39- sol = solve (prob)
40- @test typeof (sol[1 ]) <: Core.ImmutableArray
41- @test typeof (sol[1 ]) == typeof (sol[423 ])
42- end
43102
44- let
45- test_broadcast1 ()
46- test_broadcast2 ()
47- # test_diffeq() disabled bc big dependency
48- end
103+ _u0 = Core. ImmutableArray ([1.0 , 1.0 , 1.0 ])
104+ _tspan = (0.0 , 100.0 )
105+ _p = (10.0 , 28.0 , 8.0 / 3.0 )
106+ prob = ODEProblem (lorenz, _u0, _tspan, _p)
107+
108+ u0 = prob. u0
109+ tspan = prob. tspan
110+ f = prob. f
111+ p = prob. p
112+
113+ dt = 0.1f0
114+ saveat = nothing
115+ save_everystep = true
116+ abstol = 1f-6
117+ reltol = 1f-3
118+
119+ t = tspan[1 ]
120+ tf = prob. tspan[2 ]
121+
122+ beta1 = 7 / 50
123+ beta2 = 2 / 25
124+ qmax = 10.0
125+ qmin = 1 / 5
126+ gamma = 9 / 10
127+ qoldinit = 1e-4
128+
129+ if saveat === nothing
130+ ts = Vector {eltype(dt)} (undef,1 )
131+ ts[1 ] = prob. tspan[1 ]
132+ us = Vector {typeof(u0)} (undef,0 )
133+ push! (us,recursivecopy (u0))
134+ else
135+ ts = saveat
136+ cur_t = 1
137+ us = MVector {length(ts),typeof(u0)} (undef)
138+ if prob. tspan[1 ] == ts[1 ]
139+ cur_t += 1
140+ us[1 ] = u0
141+ end
142+ end
143+
144+ u = u0
145+ qold = 1e-4
146+ k7 = f (u, p, t)
49147
148+ cs, as, btildes, rs = _build_atsit5_caches (eltype (u0))
149+ c1, c2, c3, c4, c5, c6 = cs
150+ a21, a31, a32, a41, a42, a43, a51, a52, a53, a54,
151+ a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76 = as
152+ btilde1, btilde2, btilde3, btilde4, btilde5, btilde6, btilde7 = btildes
153+
154+ # FSAL
155+ while t < tspan[2 ]
156+ uprev = u
157+ k1 = k7
158+ EEst = Inf
159+
160+ while EEst > 1
161+ dt < 1e-14 && error (" dt<dtmin" )
162+
163+ tmp = uprev+ dt* a21* k1
164+ k2 = f (tmp, p, t+ c1* dt)
165+ tmp = uprev+ dt* (a31* k1+ a32* k2)
166+ k3 = f (tmp, p, t+ c2* dt)
167+ tmp = uprev+ dt* (a41* k1+ a42* k2+ a43* k3)
168+ k4 = f (tmp, p, t+ c3* dt)
169+ tmp = uprev+ dt* (a51* k1+ a52* k2+ a53* k3+ a54* k4)
170+ k5 = f (tmp, p, t+ c4* dt)
171+ tmp = uprev+ dt* (a61* k1+ a62* k2+ a63* k3+ a64* k4+ a65* k5)
172+ k6 = f (tmp, p, t+ dt)
173+ u = uprev+ dt* (a71* k1+ a72* k2+ a73* k3+ a74* k4+ a75* k5+ a76* k6)
174+ k7 = f (u, p, t+ dt)
175+
176+ tmp = dt* (btilde1* k1+ btilde2* k2+ btilde3* k3+ btilde4* k4+
177+ btilde5* k5+ btilde6* k6+ btilde7* k7)
178+ tmp = tmp./ (abstol.+ max .(abs .(uprev),abs .(u))* reltol)
179+ EEst = DiffEqBase. ODE_DEFAULT_NORM (tmp, t)
180+
181+ if iszero (EEst)
182+ q = inv (qmax)
183+ else
184+ @fastmath q11 = EEst^ beta1
185+ @fastmath q = q11/ (qold^ beta2)
186+ end
187+
188+ if EEst > 1
189+ dt = dt/ min (inv (qmin),q11/ gamma)
190+ else # EEst <= 1
191+ @fastmath q = max (inv (qmax),min (inv (qmin),q/ gamma))
192+ qold = max (EEst,qoldinit)
193+ dtold = dt
194+ dt = dt/ q # dtnew
195+ dt = min (abs (dt),abs (tf- t- dtold))
196+ told = t
197+
198+ if (tf - t - dtold) < 1e-14
199+ t = tf
200+ else
201+ t += dtold
202+ end
203+
204+ if saveat === nothing && save_everystep
205+ push! (us,recursivecopy (u))
206+ push! (ts,t)
207+ else saveat != = nothing
208+ while cur_t <= length (ts) && ts[cur_t] <= t
209+ savet = ts[cur_t]
210+ θ = (savet - told)/ dtold
211+ b1θ, b2θ, b3θ, b4θ, b5θ, b6θ, b7θ = bθs (rs, θ)
212+ us[cur_t] = uprev + dtold* (
213+ b1θ* k1 + b2θ* k2 + b3θ* k3 + b4θ* k4 + b5θ* k5 + b6θ* k6 + b7θ* k7)
214+ cur_t += 1
215+ end
216+ end
217+ end
218+ end
219+ end
220+
221+ if saveat === nothing && ! save_everystep
222+ push! (us,u)
223+ push! (ts,t)
224+ end
225+
226+ sol = DiffEqBase. build_solution (prob,Tsit5 (),ts,us,calculate_error = false )
227+
228+ DiffEqBase. has_analytic (prob. f) && DiffEqBase. calculate_solution_errors! (sol;timeseries_errors= true ,dense_errors= false )
229+
230+ sol
231+ end
0 commit comments