diff --git a/Project.toml b/Project.toml index c2b6a3e..8595dff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,30 +1,35 @@ name = "StructUtils" uuid = "ec057cc2-7a8d-4b58-b3b3-92acb9f63b42" -version = "2.6.2" +version = "2.7.0" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [weakdeps] -Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [extensions] -StructUtilsTablesExt = ["Tables"] StructUtilsMeasurementsExt = ["Measurements"] +StructUtilsStaticArraysCoreExt = ["StaticArraysCore"] +StructUtilsTablesExt = ["Tables"] [compat] -julia = "1.9" -Tables = "1" Measurements = "2" +StaticArraysCore = "1" +Tables = "1" +julia = "1.9" [extras] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" -Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [targets] -test = ["Dates", "Measurements", "Tables", "Test", "UUIDs"] +test = ["Dates", "Measurements", "StaticArrays", "StaticArraysCore", "Tables", "Test", "UUIDs"] diff --git a/ext/StructUtilsStaticArraysCoreExt.jl b/ext/StructUtilsStaticArraysCoreExt.jl new file mode 100644 index 0000000..e211bab --- /dev/null +++ b/ext/StructUtilsStaticArraysCoreExt.jl @@ -0,0 +1,19 @@ +module StructUtilsStaticArraysCoreExt + +using StructUtils +using StaticArraysCore: StaticArray, size_to_tuple + +StructUtils.fixedsizearray(::Type{<:StaticArray}) = true + +StructUtils.discover_dims(style, ::Type{<:StaticArray{S}}, source) where {S<:Tuple} = + size_to_tuple(S) + +StructUtils.arrayfromdata(::Type{T}, buf::Vector, dims::Tuple) where {T<:StaticArray} = + T(Tuple(buf)) + +if VERSION >= v"1.11" + StructUtils.arrayfromdata(::Type{T}, mem::Memory, dims::Tuple) where {T<:StaticArray} = + T(Tuple(mem)) +end + +end # module diff --git a/src/StructUtils.jl b/src/StructUtils.jl index 6ba1da8..63bfaa5 100644 --- a/src/StructUtils.jl +++ b/src/StructUtils.jl @@ -219,6 +219,25 @@ arraylike(::Type{<:Base.Generator}) = true arraylike(::Type{<:Core.SimpleVector}) = true arraylike(@nospecialize(T)) = false +""" + StructUtils.fixedsizearray(::Type{T}) -> Bool + StructUtils.fixedsizearray(::StructStyle, ::Type{T}) -> Bool + +Returns `true` if `T` is a fixed-size array type that should be pre-allocated +and filled via `setindex!` rather than grown via `push!`. The default +implementation returns `true` for multidimensional `<:AbstractArray` types +(ndims > 1) and `false` for everything else. + +Override this for custom array types that have a fixed, known size but +are not growable (e.g. `StaticArrays.StaticArray`). +""" +function fixedsizearray end + +fixedsizearray(::Type) = false +fixedsizearray(::Type{<:AbstractArray{T,N}}) where {T,N} = N > 1 +fixedsizearray(::Type{<:AbstractSet}) = false +fixedsizearray(st::StructStyle, ::Type{T}) where {T} = fixedsizearray(T) + """ StructUtils.structlike(x) -> Bool StructUtils.structlike(::StructStyle, x) -> Bool @@ -671,6 +690,32 @@ function discover_dims(style, x) return (ret.value..., len) end +""" + StructUtils.discover_dims(style, ::Type{T}, source) -> Tuple + +Discover the dimensions for a fixed-size array type `T`. By default, +delegates to `discover_dims(style, source)` to scan the source object. +Override for types where dimensions are encoded in the type itself +(e.g. `StaticArrays.StaticArray`), avoiding the need to scan the source. +""" +discover_dims(style, ::Type{T}, source) where {T} = discover_dims(style, source) + +""" + StructUtils.arrayfromdata(::Type{T}, mem, dims::Tuple) -> T + +Convert a filled data buffer `mem` with shape `dims` into the target array +type `T`. Called at the end of `makearray` for `fixedsizearray` types. +""" +function arrayfromdata end + +arrayfromdata(::Type{T}, buf::Vector, dims::Tuple) where {T<:AbstractArray} = + reshape(buf, dims) + +if VERSION >= v"1.11" + arrayfromdata(::Type{T}, mem::Memory, dims::Tuple) where {T<:AbstractArray} = + Base.wrap(Array, Base.memoryref(mem), dims) +end + struct MultiDimClosure{S,A} style::S arr::A @@ -886,7 +931,42 @@ function (f::ArrayClosure{T,S})(_, v) where {T,S} return st end -makearray(style, ::Type{T}, source) where {T} = @inline makearray(style, initialize(style, T, source), source) +struct FixedArrayClosure{A,S} + arr::A + style::S + idx::Base.RefValue{Int} +end + +function (f::FixedArrayClosure{A,S})(_, v) where {A,S} + val, st = make(f.style, eltype(f.arr), v) + i = f.idx[] + @inbounds f.arr[i] = val + f.idx[] = i + 1 + return st +end + +function makearray(style, ::Type{T}, source) where {T} + if fixedsizearray(style, T) + ET = eltype(T) + dims = discover_dims(style, T, source) + L = prod(dims) + if VERSION >= v"1.11" + data = Memory{ET}(undef, L) + else + data = Vector{ET}(undef, L) + end + N = length(dims) + if N > 1 + buf = reshape(data, dims) + st = applyeach(style, MultiDimClosure(style, buf, ones(Int, N), Ref(N)), source) + else + st = applyeach(style, FixedArrayClosure(data, style, Ref(1)), source) + end + return arrayfromdata(T, data, dims), st + else + return @inline makearray(style, initialize(style, T, source), source) + end +end function makearray(style, x::T, source) where {T} if !(T <: AbstractSet) && ndims(T) > 1 diff --git a/test/runtests.jl b/test/runtests.jl index 38357df..1e3ad0a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -379,4 +379,16 @@ end @test StructUtils.make(SomeStruct, Dict(:my_field => 45)).my_field == 45 end +if VERSION >= v"1.11" + using StaticArrays + + @testset "StaticArrays" begin + @test StructUtils.make(SVector{3,Int}, [1, 2, 3]) == SVector{3,Int}((1, 2, 3)) + @test StructUtils.make(SVector{2,Float64}, [1, 2]) == SVector{2,Float64}((1.0, 2.0)) + @test StructUtils.make(SMatrix{2,2,Int}, [[1, 3], [2, 4]]) == SMatrix{2,2,Int}((1, 3, 2, 4)) + @test StructUtils.make(MVector{3,Int}, [1, 2, 3]) == MVector{3,Int}((1, 2, 3)) + @test StructUtils.make(Vector{SVector{2,Int}}, [[1, 2], [3, 4]]) == [SVector{2,Int}((1, 2)), SVector{2,Int}((3, 4))] + end +end + end