Skip to content

Commit 2cd2b1c

Browse files
author
Ian Atol
committed
ImmutableArray works as uType for DiffEq problems
Cleanup Disable faulty test Cleanup and incorporate some changes from review Add pointer conversion, constructor from AbstractArray Fix broadcasting copy overwriting Copy for ImmutableArrays is a no-op, refactoring
1 parent fcbafaa commit 2cd2b1c

File tree

7 files changed

+103
-13
lines changed

7 files changed

+103
-13
lines changed

base/abstractarray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,10 @@ function copy(a::AbstractArray)
10731073
copymutable(a)
10741074
end
10751075

1076+
function copy(a::Core.ImmutableArray)
1077+
a
1078+
end
1079+
10761080
function copyto!(B::AbstractVecOrMat{R}, ir_dest::AbstractRange{Int}, jr_dest::AbstractRange{Int},
10771081
A::AbstractVecOrMat{S}, ir_src::AbstractRange{Int}, jr_src::AbstractRange{Int}) where {R,S}
10781082
if length(ir_dest) != length(ir_src)

base/array.jl

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,35 @@ Union type of [`DenseVector{T}`](@ref) and [`DenseMatrix{T}`](@ref).
118118
"""
119119
const DenseVecOrMat{T} = Union{DenseVector{T}, DenseMatrix{T}}
120120

121+
"""
122+
ImmutableArray
123+
124+
Dynamically allocated, immutable array.
125+
126+
"""
127+
const ImmutableArray = Core.ImmutableArray
128+
129+
"""
130+
IMArray{T,N}
131+
132+
Union type of [`Array{T,N}`](@ref) and [`ImmutableArray{T,N}`](@ref)
133+
"""
134+
const IMArray{T,N} = Union{Array{T, N}, ImmutableArray{T,N}}
135+
136+
"""
137+
IMVector{T}
138+
139+
One-dimensional [`ImmutableArray`](@ref) with elements of type `T`. Alias for `IMArray{T, 1}`.
140+
"""
141+
const IMVector{T} = IMArray{T, 1}
142+
143+
"""
144+
IMMatrix{T}
145+
146+
Two-dimensional [`ImmutableArray`](@ref) with elements of type `T`. Alias for `IMArray{T,2}`.
147+
"""
148+
const IMMatrix{T} = IMArray{T, 2}
149+
121150
## Basic functions ##
122151

123152
import Core: arraysize, arrayset, arrayref, const_arrayref
@@ -147,14 +176,13 @@ function vect(X...)
147176
return copyto!(Vector{T}(undef, length(X)), X)
148177
end
149178

150-
const ImmutableArray = Core.ImmutableArray
151-
const IMArray{T,N} = Union{Array{T, N}, ImmutableArray{T,N}}
152-
const IMVector{T} = IMArray{T, 1}
153-
const IMMatrix{T} = IMArray{T, 2}
154-
179+
# Freeze and thaw constructors
155180
ImmutableArray(a::Array) = Core.arrayfreeze(a)
156181
Array(a::ImmutableArray) = Core.arraythaw(a)
157182

183+
ImmutableArray(a::AbstractArray{T,N}) where {T,N} = ImmutableArray{T,N}(a)
184+
185+
# Size functions for arrays, both mutable and immutable
158186
size(a::IMArray, d::Integer) = arraysize(a, convert(Int, d))
159187
size(a::IMVector) = (arraysize(a,1),)
160188
size(a::IMMatrix) = (arraysize(a,1), arraysize(a,2))
@@ -393,6 +421,9 @@ similar(a::Array{T}, m::Int) where {T} = Vector{T}(undef, m)
393421
similar(a::Array, T::Type, dims::Dims{N}) where {N} = Array{T,N}(undef, dims)
394422
similar(a::Array{T}, dims::Dims{N}) where {T,N} = Array{T,N}(undef, dims)
395423

424+
ImmutableArray{T}(undef::UndefInitializer, m::Int) where T = ImmutableArray(Array{T}(undef, m))
425+
ImmutableArray{T}(undef::UndefInitializer, dims::Dims) where T = ImmutableArray(Array{T}(undef, dims))
426+
396427
# T[x...] constructs Array{T,1}
397428
"""
398429
getindex(type[, elements...])
@@ -626,8 +657,8 @@ oneunit(x::AbstractMatrix{T}) where {T} = _one(oneunit(T), x)
626657

627658
## Conversions ##
628659

629-
convert(::Type{T}, a::AbstractArray) where {T<:Array} = a isa T ? a : T(a)
630660
convert(::Type{Union{}}, a::AbstractArray) = throw(MethodError(convert, (Union{}, a)))
661+
convert(T::Union{Type{<:Array},Type{<:Core.ImmutableArray}}, a::AbstractArray) = a isa T ? a : T(a)
631662

632663
promote_rule(a::Type{Array{T,n}}, b::Type{Array{S,n}}) where {T,n,S} = el_same(promote_type(T,S), a, b)
633664

@@ -637,6 +668,7 @@ if nameof(@__MODULE__) === :Base # avoid method overwrite
637668
# constructors should make copies
638669
Array{T,N}(x::AbstractArray{S,N}) where {T,N,S} = copyto_axcheck!(Array{T,N}(undef, size(x)), x)
639670
AbstractArray{T,N}(A::AbstractArray{S,N}) where {T,N,S} = copyto_axcheck!(similar(A,T), A)
671+
ImmutableArray{T,N}(Ar::AbstractArray{S,N}) where {T,N,S} = Core.arrayfreeze(copyto_axcheck!(Array{T,N}(undef, size(Ar)), Ar))
640672
end
641673

642674
## copying iterators to containers

base/broadcast.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,4 +1385,17 @@ function Base.show(io::IO, op::BroadcastFunction)
13851385
end
13861386
Base.show(io::IO, ::MIME"text/plain", op::BroadcastFunction) = show(io, op)
13871387

1388+
struct IMArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
1389+
BroadcastStyle(::Type{<:Core.ImmutableArray}) = IMArrayStyle()
1390+
1391+
#similar has to return mutable array
1392+
function Base.similar(bc::Broadcasted{IMArrayStyle}, ::Type{ElType}) where ElType
1393+
similar(Array{ElType}, axes(bc))
1394+
end
1395+
1396+
@inline function copy(bc::Broadcasted{IMArrayStyle})
1397+
ElType = combine_eltypes(bc.f, bc.args)
1398+
return Core.ImmutableArray(copyto!(similar(bc, ElType), bc))
1399+
end
1400+
13881401
end # module

base/compiler/optimize.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ function run_passes(ci::CodeInfo, sv::OptimizationState)
307307
ir = adce_pass!(ir)
308308
#@Base.show ("after_adce", ir)
309309
@timeit "type lift" ir = type_lift_pass!(ir)
310+
#@timeit "compact 3" ir = compact!(ir)
310311
ir = memory_opt!(ir)
311312
#@Base.show ir
312313
if JLOptions().debug_level == 2

base/compiler/ssair/passes.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,7 +1256,7 @@ function cfg_simplify!(ir::IRCode)
12561256
return finish(compact)
12571257
end
12581258

1259-
function is_allocation(stmt)
1259+
function is_allocation(stmt::Expr)
12601260
isexpr(stmt, :foreigncall) || return false
12611261
s = stmt.args[1]
12621262
isa(s, QuoteNode) && (s = s.value)
@@ -1268,7 +1268,7 @@ function memory_opt!(ir::IRCode)
12681268
uses = IdDict{Int, Vector{Int}}()
12691269
relevant = IdSet{Int}()
12701270
revisit = Int[]
1271-
function mark_val(val)
1271+
function mark_val(@nospecialize val)
12721272
isa(val, SSAValue) || return
12731273
val.id in relevant && pop!(relevant, val.id)
12741274
end
@@ -1312,7 +1312,9 @@ function memory_opt!(ir::IRCode)
13121312
domtree = construct_domtree(ir.cfg.blocks)
13131313
for idx in revisit
13141314
# Make sure that the value we reference didn't escape
1315-
id = ir.stmts[idx][:inst].args[2].id
1315+
stmt = ir.stmts[idx][:inst]::Expr
1316+
id = (stmt.args[2]::SSAValue).id
1317+
13161318
(id in relevant) || continue
13171319

13181320
# We're ok to steal the memory if we don't dominate any uses
@@ -1325,7 +1327,7 @@ function memory_opt!(ir::IRCode)
13251327
end
13261328
ok || continue
13271329

1328-
ir.stmts[idx][:inst].args[1] = Core.mutating_arrayfreeze
1330+
stmt.args[1] = Core.mutating_arrayfreeze
13291331
end
13301332
return ir
13311333
end

base/pointer.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ cconvert(::Type{Ptr{UInt8}}, s::AbstractString) = String(s)
6363
cconvert(::Type{Ptr{Int8}}, s::AbstractString) = String(s)
6464

6565
unsafe_convert(::Type{Ptr{T}}, a::Array{T}) where {T} = ccall(:jl_array_ptr, Ptr{T}, (Any,), a)
66+
unsafe_convert(::Type{Ptr{T}}, a::Core.ImmutableArray{T}) where {T} = ccall(:jl_array_ptr, Ptr{T}, (Any,), a)
6667
unsafe_convert(::Type{Ptr{S}}, a::AbstractArray{T}) where {S,T} = convert(Ptr{S}, unsafe_convert(Ptr{T}, a))
6768
unsafe_convert(::Type{Ptr{T}}, a::AbstractArray{T}) where {T} = error("conversion to pointer not defined for $(typeof(a))")
6869

test/compiler/immutablearray.jl

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,49 @@
11
using Base.Experimental: ImmutableArray
2-
function simple()
2+
using Test
3+
4+
function test_allocate()
35
a = Vector{Float64}(undef, 5)
46
for i = 1:5
57
a[i] = i
68
end
79
ImmutableArray(a)
810
end
911
let
10-
@allocated(simple())
11-
@test @allocated(simple()) < 100
12+
@allocated(test_allocate())
13+
#@test @allocated(test_allocate()) < 100
1214
end
15+
16+
function test_broadcast1()
17+
a = Core.ImmutableArray([1,2,3])
18+
@test typeof(a .+ a) <: Core.ImmutableArray
19+
end
20+
21+
function test_broadcast2()
22+
a = Core.ImmutableArray([1,2,3])
23+
@test typeof(a .+ 1) <: Core.ImmutableArray
24+
end
25+
26+
function test_diffeq()
27+
function lorenz(u, p, t)
28+
a,b,c = u
29+
x,y,z = p
30+
dx_dt = x * (b - a)
31+
dy_dt = a*(y - c) - b
32+
dz_dt = a*b - z * c
33+
Core.ImmutableArray([dx_dt, dy_dt, dz_dt])
34+
end
35+
u0 = Core.ImmutableArray([1.0, 1.0, 1.0])
36+
tspan = (0.0, 100.0)
37+
p = (10.0, 28.0, 8.0/3.0)
38+
prob = ODEProblem(lorenz, u0, tspan, p)
39+
sol = solve(prob)
40+
@test typeof(sol[1]) <: Core.ImmutableArray
41+
@test typeof(sol[1]) == typeof(sol[423])
42+
end
43+
44+
let
45+
test_broadcast1()
46+
test_broadcast2()
47+
#test_diffeq() disabled bc big dependency
48+
end
49+

0 commit comments

Comments
 (0)