From 7664236b075bc9006d62e089042b672b1961b458 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 May 2025 21:17:51 -0400 Subject: [PATCH 01/24] refactor: rework the package to be Reactant first --- .JuliaFormatter.toml | 2 +- .buildkite/testing.yml | 46 +++++++-------- Project.toml | 20 ++----- src/NeuralOperators.jl | 24 +++----- src/layers.jl | 103 ++++++++++++++-------------------- src/models/deeponet.jl | 84 ++++++++++++++++++---------- src/models/fno.jl | 27 ++++++--- src/models/nomad.jl | 31 ++++++---- src/transform.jl | 14 ++++- src/utils.jl | 33 ++--------- test/Project.toml | 14 ++--- test/deeponet_tests.jl | 118 +++++++++++++++++++-------------------- test/fno_tests.jl | 64 ++++++++++----------- test/layers_tests.jl | 87 ++++++++++++++--------------- test/nomad_tests.jl | 46 +++++++-------- test/qa_tests.jl | 16 ++++-- test/runtests.jl | 35 ++++-------- test/shared_testsetup.jl | 46 +++++---------- test/utils_tests.jl | 58 ------------------- 19 files changed, 386 insertions(+), 482 deletions(-) delete mode 100644 test/utils_tests.jl diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 37dae14..f444ca1 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,4 +1,4 @@ -style = "sciml" +style = "blue" format_markdown = true whitespace_in_kwargs = false margin = 92 diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml index 7979b22..3722d64 100644 --- a/.buildkite/testing.yml +++ b/.buildkite/testing.yml @@ -19,29 +19,29 @@ steps: julia: - "1" - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - BACKEND_GROUP: "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - 
setup: - julia: - - "1" + # - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + # plugins: + # - JuliaCI/julia#v1: + # version: "{{matrix.julia}}" + # - JuliaCI/julia-test#v1: + # test_args: "--quickfail" + # - JuliaCI/julia-coverage#v1: + # codecov: true + # env: + # JULIA_AMDGPU_CORE_MUST_LOAD: "1" + # JULIA_AMDGPU_HIP_MUST_LOAD: "1" + # JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + # BACKEND_GROUP: "AMDGPU" + # agents: + # queue: "juliagpu" + # rocm: "*" + # rocmgpu: "*" + # if: build.message !~ /\[skip tests\]/ + # timeout_in_minutes: 60 + # matrix: + # setup: + # julia: + # - "1" env: RETESTITEMS_NWORKERS: 4 diff --git a/Project.toml b/Project.toml index e1406d5..425e257 100644 --- a/Project.toml +++ b/Project.toml @@ -1,33 +1,23 @@ name = "NeuralOperators" uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b" authors = ["Avik Pal "] -version = "0.5.3" +version = "0.6.0" [deps] -ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" -MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" -NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [compat] -ArgCheck = "2.3" -ChainRulesCore = "1.24" +AbstractFFTs = "1.5.0" ConcreteStructs = "0.2.3" -FFTW = "1.8" Lux = "1" LuxCore = "1" -LuxLib = "1.2" -MLDataDevices = "1.2.0" -NNlib = "0.9.21" Random = "1.10" -Static = "1.1.1" +Reactant = "0.2.122" WeightInitializers = "1" julia = "1.10" diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index 78a1552..315aac6 100644 --- 
a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -1,34 +1,26 @@ module NeuralOperators -using ArgCheck: @argcheck -using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable +using AbstractFFTs: rfft, irfft using ConcreteStructs: @concrete -using FFTW: FFTW, irfft, rfft using Random: Random, AbstractRNG -using Static: StaticBool, False, True, known, static using Lux using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer -using LuxLib: batched_matmul -using MLDataDevices: AbstractDevice, AbstractGPUDevice -using NNlib: NNlib - -const BoolLike = Union{Bool, StaticBool, Val{true}, Val{false}} -const CRC = ChainRulesCore +using WeightInitializers: glorot_uniform include("utils.jl") include("transform.jl") include("layers.jl") -include("models/fno.jl") -include("models/deeponet.jl") -include("models/nomad.jl") +# include("models/fno.jl") +# include("models/deeponet.jl") +# include("models/nomad.jl") export FourierTransform export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel -export FourierNeuralOperator -export DeepONet -export NOMAD +# export FourierNeuralOperator +# export DeepONet +# export NOMAD end diff --git a/src/layers.jl b/src/layers.jl index 218cdaa..0488f4e 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -1,33 +1,28 @@ """ - OperatorConv(ch::Pair{<:Integer, <:Integer}, modes::Dims, - ::Type{<:AbstractTransform}; init_weight=glorot_uniform, - permuted=Val(false)) + OperatorConv( + ch::Pair{<:Integer, <:Integer}, modes::Dims, tr::AbstractTransform; + init_weight=glorot_uniform + ) ## Arguments - `ch`: A `Pair` of input and output channel size `ch_in => ch_out`, e.g. `64 => 64`. - `modes`: The modes to be preserved. A tuple of length `d`, where `d` is the dimension of data. - - `::Type{TR}`: The transform to operate the transformation. + - `tr`: The transform to operate the transformation. ## Keyword Arguments - `init_weight`: Initial function to initialize parameters. 
- - `permuted`: Whether the dim is permuted. If `permuted = Val(false)`, the layer accepts - data in the order of `(ch, x_1, ... , x_d, batch)`. Otherwise the order is - `(x_1, ... , x_d, ch, batch)`. ## Example ```jldoctest -julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32}); - -julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32}; permuted=Val(true)); +julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32}((16,))); ``` """ @concrete struct OperatorConv <: AbstractLuxLayer - perm <: StaticBool in_chs::Int out_chs::Int prod_modes::Int @@ -35,17 +30,14 @@ julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32}; permuted=Val(tr init_weight end -function Base.show(io::IO, layer::OperatorConv) - print(io, "OperatorConv($(layer.in_chs) => $(layer.out_chs), $(layer.tform.modes), \ - $(printable_type(layer.tform)); permuted = $(layer.perm))") -end - function LuxCore.initialparameters(rng::AbstractRNG, layer::OperatorConv) in_chs, out_chs = layer.in_chs, layer.out_chs scale = real(one(eltype(layer.tform))) / (in_chs * out_chs) return (; weight=scale * layer.init_weight( - rng, eltype(layer.tform), out_chs, in_chs, layer.prod_modes)) + rng, eltype(layer.tform), out_chs, in_chs, layer.prod_modes + ) + ) end function LuxCore.parameterlength(layer::OperatorConv) @@ -53,31 +45,27 @@ function LuxCore.parameterlength(layer::OperatorConv) end function OperatorConv( - ch::Pair{<:Integer, <:Integer}, modes::Dims, ::Type{TR}; init_weight=glorot_uniform, - permuted::BoolLike=False()) where {TR <: AbstractTransform{<:Number}} - return OperatorConv(static(permuted), ch..., prod(modes), TR(modes), init_weight) + ch::Pair{<:Integer,<:Integer}, + modes::Dims, + tform::AbstractTransform; + init_weight=glorot_uniform, +) + return OperatorConv(ch..., prod(modes), tform, init_weight) end -function (conv::OperatorConv{True})(x::AbstractArray, ps, st) +function (conv::OperatorConv)(x::AbstractArray{T,N}, ps, st) where {T,N} return operator_conv(x, 
conv.tform, ps.weight), st end -function (conv::OperatorConv{False})(x::AbstractArray, ps, st) - N = ndims(conv.tform) - xᵀ = permutedims(x, (ntuple(i -> i + 1, N)..., 1, N + 2)) - yᵀ = operator_conv(xᵀ, conv.tform, ps.weight) - y = permutedims(yᵀ, (N + 1, 1:N..., N + 2)) - return y, st -end - function operator_conv(x, tform::AbstractTransform, weights) x_t = transform(tform, x) x_tr = truncate_modes(tform, x_t) x_p = apply_pattern(x_tr, weights) pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)] - x_padded = NNlib.pad_constant(x_p, expand_pad_dims(pad_dims), false; - dims=ntuple(identity, ndims(x_p) - 2))::typeof(x_p) + x_padded = NNlib.pad_constant( + x_p, expand_pad_dims(pad_dims), false; dims=ntuple(identity, ndims(x_p) - 2) + ) return inverse(tform, x_padded, size(x)) end @@ -93,40 +81,32 @@ Construct a `OperatorConv` with `FourierTransform{ComplexF32}` as the transform. ```jldoctest julia> SpectralConv(2 => 5, (16,)); -julia> SpectralConv(2 => 5, (16,); permuted=Val(true)); - ``` """ -function SpectralConv(args...; kwargs...) - return OperatorConv(args..., FourierTransform{ComplexF32}; kwargs...) +function SpectralConv(ch::Pair{<:Integer,<:Integer}, modes::Dims; kwargs...) + return OperatorConv(ch, modes, FourierTransform{ComplexF32}(modes); kwargs...) end """ - OperatorKernel(ch::Pair{<:Integer, <:Integer}, modes::Dims, transform::Type{TR}, - act::A=identity; permuted=Val(false), kwargs...) where {TR <: AbstractTransform, A} + OperatorKernel( + ch::Pair{<:Integer, <:Integer}, modes::Dims, transform::AbstractTransform, + act=identity; kwargs... + ) ## Arguments - `ch`: A `Pair` of input and output channel size `ch_in => ch_out`, e.g. `64 => 64`. - `modes`: The modes to be preserved. A tuple of length `d`, where `d` is the dimension of data. - - `::Type{TR}`: The transform to operate the transformation. - -## Keyword Arguments - - - `σ`: Activation function. - - `permuted`: Whether the dim is permuted. 
If `permuted = Val(true)`, the layer accepts - data in the order of `(ch, x_1, ... , x_d , batch)`. Otherwise the order is - `(x_1, ... , x_d, ch, batch)`. + - `transform`: The transform to operate the transformation. + - `act`: Activation function. All the keyword arguments are passed to the [`OperatorConv`](@ref) constructor. ## Example ```jldoctest -julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}); - -julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}; permuted=Val(true)); +julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}((16,))); ``` """ @@ -134,14 +114,20 @@ julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}; permuted=Val( layer end -OperatorKernel(lin, conv) = OperatorKernel(lin, conv, identity) - function OperatorKernel( - ch::Pair{<:Integer, <:Integer}, modes::Dims{N}, transform::Type{TR}, act=identity; - permuted::BoolLike=False(), kwargs...) where {N, TR <: AbstractTransform{<:Number}} - lin = known(static(permuted)) ? Conv(ntuple(one, N), ch) : Dense(ch) - conv = OperatorConv(ch, modes, transform; permuted, kwargs...) - return OperatorKernel(Parallel(Fix1(add_act, act), lin, conv)) + ch::Pair{<:Integer,<:Integer}, + modes::Dims{N}, + transform::AbstractTransform, + act=identity; + kwargs..., +) where {N} + return OperatorKernel( + Parallel( + Fix1(add_act, act), + Conv(ntuple(one, N), ch), + OperatorConv(ch, modes, transform; kwargs...), + ), + ) end """ @@ -155,11 +141,8 @@ Construct a `OperatorKernel` with `FourierTransform{ComplexF32}` as the transfor ```jldoctest julia> SpectralKernel(2 => 5, (16,)); -julia> SpectralKernel(2 => 5, (16,); permuted=Val(true)); - ``` """ -function SpectralKernel( - ch::Pair{<:Integer, <:Integer}, modes::Dims, act=identity; kwargs...) - return OperatorKernel(ch, modes, FourierTransform{ComplexF32}, act; kwargs...) +function SpectralKernel(ch::Pair{<:Integer,<:Integer}, modes::Dims, act=identity; kwargs...) 
+ return OperatorKernel(ch, modes, FourierTransform{ComplexF32}(modes), act; kwargs...) end diff --git a/src/models/deeponet.jl b/src/models/deeponet.jl index 76e50ff..84f7551 100644 --- a/src/models/deeponet.jl +++ b/src/models/deeponet.jl @@ -96,21 +96,35 @@ julia> size(first(deeponet((u, y), ps, st))) ``` """ function DeepONet(; - branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity, - trunk_activation=identity, additional=NoOpLayer()) + branch=(64, 32, 32, 16), + trunk=(1, 8, 8, 16), + branch_activation=identity, + trunk_activation=identity, + additional=NoOpLayer(), +) # checks for last dimension size - @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \ - nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ - work." - - branch_net = Chain([Dense(branch[i] => branch[i + 1], - ifelse(i == length(branch) - 1, identity, branch_activation)) - for i in 1:(length(branch) - 1)]...) - - trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], - ifelse(i == length(trunk) - 1, identity, trunk_activation)) - for i in 1:(length(trunk) - 1)]...) + @argcheck branch[end] == trunk[end] "Branch and Trunk net must share the same amount of \ + nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ + work." + + branch_net = Chain( + [ + Dense( + branch[i] => branch[i + 1], + ifelse(i == length(branch) - 1, identity, branch_activation), + ) for i in 1:(length(branch) - 1) + ]..., + ) + + trunk_net = Chain( + [ + Dense( + trunk[i] => trunk[i + 1], + ifelse(i == length(trunk) - 1, identity, trunk_activation), + ) for i in 1:(length(trunk) - 1) + ]..., + ) return DeepONet(branch_net, trunk_net, additional) end @@ -119,49 +133,62 @@ function (deeponet::DeepONet)((x1, x2), ps, st::NamedTuple) b, st_b = deeponet.branch(x1, ps.branch, st.branch) t, st_t = deeponet.trunk(x2, ps.trunk, st.trunk) - @argcheck size(b, 1)==size(t, 1) "Branch and Trunk net must share the same amount of \ - nodes in the last layer. 
Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ - work." + @argcheck size(b, 1) == size(t, 1) "Branch and Trunk net must share the same amount of \ + nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ + work." - additional = deeponet.additional isa NoOpLayer ? nothing : - StatefulLuxLayer{true}(deeponet.additional, ps.additional, st.additional) + additional = if deeponet.additional isa NoOpLayer + nothing + else + StatefulLuxLayer{true}(deeponet.additional, ps.additional, st.additional) + end out = deeponet_project(b, t, additional) - stₙ = merge((; branch=st_b, trunk=st_t), - deeponet.additional isa NoOpLayer ? (;) : additional.st) + stₙ = merge( + (; branch=st_b, trunk=st_t), deeponet.additional isa NoOpLayer ? (;) : additional.st + ) return out, stₙ end function deeponet_project( - b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, ::Nothing) where {T1, T2} + b::AbstractArray{T1,2}, t::AbstractArray{T2,3}, ::Nothing +) where {T1,T2} # b [p, nb], t [p, N, nb] bᵣ = reshape(b, size(b, 1), 1, size(b, 2)) return dropdims(sum(bᵣ .* t; dims=1); dims=1) # [N, nb] end function deeponet_project( - b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, ::Nothing) where {T1, T2} + b::AbstractArray{T1,3}, t::AbstractArray{T2,3}, ::Nothing +) where {T1,T2} # b [p, u, nb], t [p, N, nb] return batched_matmul(safe_batched_adjoint(b), t) # [u, N, b] end function deeponet_project( - b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, ::Nothing) where {T1, T2, N} + b::AbstractArray{T1,N}, t::AbstractArray{T2,3}, ::Nothing +) where {T1,T2,N} # b [p, u_size..., nb], t [p, N, nb] bᵣ = reshape(b, size(b, 1), :, size(b, N)) - return reshape(batched_matmul(safe_batched_adjoint(bᵣ), t), - size(b)[2:(N - 1)]..., size(t, 2), size(b, N)) + return reshape( + batched_matmul(safe_batched_adjoint(bᵣ), t), + size(b)[2:(N - 1)]..., + size(t, 2), + size(b, N), + ) end function deeponet_project( - b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional) where {T1, T2} + b::AbstractArray{T1,2}, t::AbstractArray{T2,3}, 
additional +) where {T1,T2} # b [p, nb], t [p, N, nb] bᵣ = reshape(b, size(b, 1), 1, size(b, 2)) return additional(bᵣ .* t) # [p, N, nb] => [out_dims, N, nb] end function deeponet_project( - b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional) where {T1, T2} + b::AbstractArray{T1,3}, t::AbstractArray{T2,3}, additional +) where {T1,T2} # b [p, u, nb], t [p, N, nb] bᵣ = reshape(b, size(b, 1), size(b, 2), 1, size(b, 3)) # [p, u, 1, nb] tᵣ = reshape(t, size(t, 1), 1, size(t)[2:end]...) # [p, 1, N, nb] @@ -169,7 +196,8 @@ function deeponet_project( end function deeponet_project( - b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, additional) where {T1, T2, N} + b::AbstractArray{T1,N}, t::AbstractArray{T2,3}, additional +) where {T1,T2,N} # b [p, u_size..., nb], t [p, N, nb] bᵣ = reshape(b, size(b, 1), :, 1, size(b, N)) # [p, (u_size...), 1, nb] tᵣ = reshape(t, size(t, 1), 1, size(t, 2), size(t, 3)) # [p, 1, N, nb] diff --git a/src/models/fno.jl b/src/models/fno.jl index 0a7304c..d66abd3 100644 --- a/src/models/fno.jl +++ b/src/models/fno.jl @@ -42,8 +42,13 @@ julia> size(first(fno(u, ps, st))) model <: Chain end -function FourierNeuralOperator(σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), - modes::Dims{M}=(16,), permuted::BoolLike=False(), kwargs...) where {C, M} +function FourierNeuralOperator( + σ=gelu; + chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), + modes::Dims{M}=(16,), + permuted::BoolLike=False(), + kwargs..., +) where {C,M} @argcheck length(chs) ≥ 5 map₁ = chs[1] => chs[2] @@ -53,12 +58,18 @@ function FourierNeuralOperator(σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128 kernel_size = map(Returns(1), modes) lifting = known(static(permuted)) ? Conv(kernel_size, map₁) : Dense(map₁) - project = known(static(permuted)) ? - Chain(Conv(kernel_size, map₂, σ), Conv(kernel_size, map₃)) : - Chain(Dense(map₂, σ), Dense(map₃)) - - mapping = Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...) - for i in 2:(C - 3)]...) 
+ project = if known(static(permuted)) + Chain(Conv(kernel_size, map₂, σ), Conv(kernel_size, map₃)) + else + Chain(Dense(map₂, σ), Dense(map₃)) + end + + mapping = Chain( + [ + SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...) for + i in 2:(C - 3) + ]..., + ) return FourierNeuralOperator(Chain(lifting, mapping, project)) end diff --git a/src/models/nomad.jl b/src/models/nomad.jl index 0a50d2c..fc2de42 100644 --- a/src/models/nomad.jl +++ b/src/models/nomad.jl @@ -85,15 +85,26 @@ julia> size(first(nomad((u, y), ps, st))) (8, 5) ``` """ -function NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), - approximator_activation=identity, - decoder_activation=identity, concatenate=nomad_concatenate) - approximator_net = Chain([Dense(approximator[i] => approximator[i + 1], - approximator_activation) - for i in 1:(length(approximator) - 1)]...) - - decoder_net = Chain([Dense(decoder[i] => decoder[i + 1], decoder_activation) - for i in 1:(length(decoder) - 1)]...) +function NOMAD(; + approximator=(8, 32, 32, 16), + decoder=(18, 16, 8, 8), + approximator_activation=identity, + decoder_activation=identity, + concatenate=nomad_concatenate, +) + approximator_net = Chain( + [ + Dense(approximator[i] => approximator[i + 1], approximator_activation) for + i in 1:(length(approximator) - 1) + ]..., + ) + + decoder_net = Chain( + [ + Dense(decoder[i] => decoder[i + 1], decoder_activation) for + i in 1:(length(decoder) - 1) + ]..., + ) return NOMAD(approximator_net, decoder_net, concatenate) end @@ -105,7 +116,7 @@ function (nomad::NOMAD)(x, ps, st::NamedTuple) end function NOMAD(approximator_net, decoder_net; concatenate=nomad_concatenate) - NOMAD(approximator_net, decoder_net, concatenate) + return NOMAD(approximator_net, decoder_net, concatenate) end batch_vectorize(x::AbstractArray) = reshape(x, :, size(x, ndims(x))) diff --git a/src/transform.jl b/src/transform.jl index e683936..9f729e8 100644 --- a/src/transform.jl +++ b/src/transform.jl @@ -14,8 +14,15 
@@ abstract type AbstractTransform{T} end Base.eltype(::Type{<:AbstractTransform{T}}) where {T} = T -printable_type(T::AbstractTransform) = "$(nameof(typeof(T))){$(eltype(T))}" +function transform end +function truncate_modes end +function inverse end +""" + FourierTransform{T}(modes) + +A concrete implementation of `AbstractTransform` for Fourier transforms. +""" @concrete struct FourierTransform{T} <: AbstractTransform{T} modes end @@ -25,12 +32,13 @@ Base.ndims(T::FourierTransform) = length(T.modes) transform(ft::FourierTransform, x::AbstractArray) = rfft(x, 1:ndims(ft)) function low_pass(ft::FourierTransform, x_fft::AbstractArray) - return view(x_fft,(map(d -> 1:d, ft.modes)...),:,:) + return view(x_fft, map(d -> 1:d, ft.modes)..., :, :) end truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft) function inverse( - ft::FourierTransform, x_fft::AbstractArray{T, N}, M::NTuple{N, Int64}) where {T, N} + ft::FourierTransform, x_fft::AbstractArray{T,N}, M::NTuple{N,Int64} +) where {T,N} return real(irfft(x_fft, first(M), 1:ndims(ft))) end diff --git a/src/utils.jl b/src/utils.jl index 459a1b4..f42dfe8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,18 +1,17 @@ function apply_pattern( - x_tr::AbstractArray{T1, N}, weights::AbstractArray{T2, 3}) where {T1, T2, N} + x_tr::AbstractArray{T1,N}, weights::AbstractArray{T2,3} +) where {T1,T2,N} x_size = size(x_tr) x_flat = reshape(x_tr, :, x_size[N - 1], x_size[N]) x_flat_t = permutedims(x_flat, (2, 3, 1)) # i x b x m - x_weighted = permutedims(batched_matmul(weights, x_flat_t), (3, 1, 2)) # m x o x b + x_weighted = permutedims(batched_mul(weights, x_flat_t), (3, 1, 2)) # m x o x b return reshape(x_weighted, x_size[1:(N - 2)]..., size(x_weighted)[2:3]...) 
end function add_act(act::F, x1, x2) where {F} - y = x1 .+ x2 - act = NNlib.fast_act(act, y) - return fast_activation!!(act, y) + return fast_activation!!(NNlib.fast_act(act, y), x1 .+ x2) end @concrete struct Fix1 <: Function @@ -27,27 +26,3 @@ Base.show(io::IO, f::Fix1) = print(io, "Fix1($(f.f), $(f.x))") function expand_pad_dims(pad_dims::Dims{N}) where {N} return ntuple(i -> isodd(i) ? 0 : pad_dims[i ÷ 2], 2N) end - -@non_differentiable expand_pad_dims(::Any) - -# Handling Wrapper Types are hard. Make sure to not construct a ReshapedArray of -# BatchedAdjoint -safe_batched_adjoint(x::AbstractArray) = NNlib.batched_adjoint(x) - -function CRC.rrule(::typeof(safe_batched_adjoint), x::AbstractArray) - return safe_batched_adjoint(x), ∇safe_batched_adjoint -end - -∇safe_batched_adjoint(Δ) = NoTangent(), safe_batched_adjoint(Δ) -function ∇safe_batched_adjoint(Δ::AbstractArray{T, 3}) where {T} - return ∇safe_batched_adjoint(get_device_type(Δ), Δ) -end - -function ∇safe_batched_adjoint(::Type{<:AbstractDevice}, Δ::AbstractArray{T, 3}) where {T} - return NoTangent(), safe_batched_adjoint(Δ) -end - -function ∇safe_batched_adjoint( - ::Type{<:AbstractGPUDevice}, Δ::AbstractArray{T, 3}) where {T} - return NoTangent(), stack(adjoint, eachslice(Δ; dims=3)) -end diff --git a/test/Project.toml b/test/Project.toml index 8890976..64ddc63 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,19 +1,17 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" -MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Optimisers = 
"3bd65402-5787-11e9-1adc-39752487f4e2" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -24,18 +22,14 @@ Aqua = "0.8.7" Documenter = "1.5.0" ExplicitImports = "1.9.0" Hwloc = "3.2.0" -InteractiveUtils = "<0.0.1, 1" Lux = "1" LuxCore = "1" LuxLib = "1.2" LuxTestUtils = "1.1.2" -MLDataDevices = "1" -Optimisers = "0.3.3" -Pkg = "1.10" -Preferences = "1" +Optimisers = "0.4" Random = "1.10" ReTestItems = "1.24.0" Reexport = "1.2.2" StableRNGs = "1.0.2" Test = "1.10" -Zygote = "0.6.70" +Zygote = "0.7" diff --git a/test/deeponet_tests.jl b/test/deeponet_tests.jl index c67d31d..418b679 100644 --- a/test/deeponet_tests.jl +++ b/test/deeponet_tests.jl @@ -1,71 +1,71 @@ -@testitem "DeepONet" setup=[SharedTestSetup] begin - @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = StableRNG(12345) +# @testitem "DeepONet" setup=[SharedTestSetup] begin +# @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES +# rng = StableRNG(12345) - setups = [ - (u_size=(64, 5), y_size=(1, 10, 5), out_size=(10, 5), - branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar"), - (u_size=(64, 1, 5), y_size=(1, 10, 5), out_size=(1, 10, 5), - branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar II"), - (u_size=(64, 3, 5), y_size=(4, 10, 5), out_size=(3, 10, 5), - branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Vector"), - (u_size=(64, 4, 3, 3, 5), y_size=(4, 10, 5), out_size=(4, 3, 3, 10, 5), - branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Tensor") - ] +# setups = [ +# (u_size=(64, 5), y_size=(1, 10, 5), out_size=(10, 5), +# branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar"), +# (u_size=(64, 
1, 5), y_size=(1, 10, 5), out_size=(1, 10, 5), +# branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar II"), +# (u_size=(64, 3, 5), y_size=(4, 10, 5), out_size=(3, 10, 5), +# branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Vector"), +# (u_size=(64, 4, 3, 3, 5), y_size=(4, 10, 5), out_size=(4, 3, 3, 10, 5), +# branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Tensor") +# ] - @testset "$(setup.name)" for setup in setups - u = rand(Float32, setup.u_size...) |> aType - y = rand(Float32, setup.y_size...) |> aType - deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk) +# @testset "$(setup.name)" for setup in setups +# u = rand(Float32, setup.u_size...) |> aType +# y = rand(Float32, setup.y_size...) |> aType +# deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk) - ps, st = Lux.setup(rng, deeponet) |> dev - @inferred first(deeponet((u, y), ps, st)) - @jet first(deeponet((u, y), ps, st)) +# ps, st = Lux.setup(rng, deeponet) |> dev +# @inferred first(deeponet((u, y), ps, st)) +# @jet first(deeponet((u, y), ps, st)) - pred = first(deeponet((u, y), ps, st)) - @test setup.out_size == size(pred) - end +# pred = first(deeponet((u, y), ps, st)) +# @test setup.out_size == size(pred) +# end - setups = [ - (u_size=(64, 5), y_size=(1, 10, 5), out_size=(4, 10, 5), - branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), - additional=Dense(16 => 4), name="Scalar"), - (u_size=(64, 1, 5), y_size=(1, 10, 5), out_size=(4, 1, 10, 5), - branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), - additional=Dense(16 => 4), name="Scalar II"), - (u_size=(64, 3, 5), y_size=(8, 10, 5), out_size=(4, 3, 10, 5), - branch=(64, 32, 32, 16), trunk=(8, 8, 8, 16), - additional=Dense(16 => 4), name="Vector") - ] +# setups = [ +# (u_size=(64, 5), y_size=(1, 10, 5), out_size=(4, 10, 5), +# branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), +# additional=Dense(16 => 4), name="Scalar"), +# (u_size=(64, 1, 5), y_size=(1, 10, 5), out_size=(4, 1, 10, 5), +# branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), +# 
additional=Dense(16 => 4), name="Scalar II"), +# (u_size=(64, 3, 5), y_size=(8, 10, 5), out_size=(4, 3, 10, 5), +# branch=(64, 32, 32, 16), trunk=(8, 8, 8, 16), +# additional=Dense(16 => 4), name="Vector") +# ] - @testset "Additional layer: $(setup.name)" for setup in setups - u = rand(Float32, setup.u_size...) |> aType - y = rand(Float32, setup.y_size...) |> aType - deeponet = DeepONet(; - branch=setup.branch, trunk=setup.trunk, additional=setup.additional) +# @testset "Additional layer: $(setup.name)" for setup in setups +# u = rand(Float32, setup.u_size...) |> aType +# y = rand(Float32, setup.y_size...) |> aType +# deeponet = DeepONet(; +# branch=setup.branch, trunk=setup.trunk, additional=setup.additional) - ps, st = Lux.setup(rng, deeponet) |> dev - @inferred first(deeponet((u, y), ps, st)) - @jet first(deeponet((u, y), ps, st)) +# ps, st = Lux.setup(rng, deeponet) |> dev +# @inferred first(deeponet((u, y), ps, st)) +# @jet first(deeponet((u, y), ps, st)) - pred = first(deeponet((u, y), ps, st)) - @test setup.out_size == size(pred) +# pred = first(deeponet((u, y), ps, st)) +# @test setup.out_size == size(pred) - __f = (u, y, ps) -> sum(abs2, first(deeponet((u, y), ps, st))) - @test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3) - end +# __f = (u, y, ps) -> sum(abs2, first(deeponet((u, y), ps, st))) +# @test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3) +# end - @testset "Embedding layer mismatch" begin - u = rand(Float32, 64, 5) |> aType - y = rand(Float32, 1, 10, 5) |> aType +# @testset "Embedding layer mismatch" begin +# u = rand(Float32, 64, 5) |> aType +# y = rand(Float32, 1, 10, 5) |> aType - deeponet = DeepONet( - Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)), - Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)) - ) +# deeponet = DeepONet( +# Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)), +# Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)) +# ) - ps, st = Lux.setup(rng, deeponet) |> dev - @test_throws ArgumentError 
deeponet((u, y), ps, st) - end - end -end +# ps, st = Lux.setup(rng, deeponet) |> dev +# @test_throws ArgumentError deeponet((u, y), ps, st) +# end +# end +# end diff --git a/test/fno_tests.jl b/test/fno_tests.jl index cf586ca..351b2df 100644 --- a/test/fno_tests.jl +++ b/test/fno_tests.jl @@ -1,39 +1,39 @@ -@testitem "Fourier Neural Operator" setup=[SharedTestSetup] begin - @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = StableRNG(12345) +# @testitem "Fourier Neural Operator" setup=[SharedTestSetup] begin +# @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES +# rng = StableRNG(12345) - setups = [ - (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1), - x_size=(2, 1024, 5), y_size=(1, 1024, 5), permuted=Val(false)), - (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1), - x_size=(1024, 2, 5), y_size=(1024, 1, 5), permuted=Val(true)) - ] +# setups = [ +# (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1), +# x_size=(2, 1024, 5), y_size=(1, 1024, 5), permuted=Val(false)), +# (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1), +# x_size=(1024, 2, 5), y_size=(1024, 1, 5), permuted=Val(true)) +# ] - @testset "$(length(setup.modes))D: permuted = $(setup.permuted)" for setup in setups - fno = FourierNeuralOperator(; setup.chs, setup.modes, setup.permuted) - display(fno) - ps, st = Lux.setup(rng, fno) |> dev +# @testset "$(length(setup.modes))D: permuted = $(setup.permuted)" for setup in setups +# fno = FourierNeuralOperator(; setup.chs, setup.modes, setup.permuted) +# display(fno) +# ps, st = Lux.setup(rng, fno) |> dev - x = rand(rng, Float32, setup.x_size...) |> aType - y = rand(rng, Float32, setup.y_size...) |> aType +# x = rand(rng, Float32, setup.x_size...) |> aType +# y = rand(rng, Float32, setup.y_size...) 
|> aType - @inferred fno(x, ps, st) - @jet fno(x, ps, st) +# @inferred fno(x, ps, st) +# @jet fno(x, ps, st) - @test size(first(fno(x, ps, st))) == setup.y_size +# @test size(first(fno(x, ps, st))) == setup.y_size - data = [(x, y)] - @test begin - l2, l1 = train!(fno, ps, st, data; epochs=10) - l2 < l1 - end +# data = [(x, y)] +# @test begin +# l2, l1 = train!(fno, ps, st, data; epochs=10) +# l2 < l1 +# end - __f = (x, ps) -> sum(abs2, first(fno(x, ps, st))) - @test_gradients(__f, x, - ps; - atol=1.0f-3, - rtol=1.0f-3, - skip_backends=[AutoTracker(), AutoEnzyme(), AutoReverseDiff()]) - end - end -end +# __f = (x, ps) -> sum(abs2, first(fno(x, ps, st))) +# @test_gradients(__f, x, +# ps; +# atol=1.0f-3, +# rtol=1.0f-3, +# skip_backends=[AutoTracker(), AutoEnzyme(), AutoReverseDiff()]) +# end +# end +# end diff --git a/test/layers_tests.jl b/test/layers_tests.jl index 0be6931..8a20ca3 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -1,49 +1,48 @@ -@testitem "SpectralConv & SpectralKernel" setup=[SharedTestSetup] begin - @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = StableRNG(12345) - - opconv = [SpectralConv, SpectralKernel] - setups = [ - (; m=(16,), permuted=Val(false), x_size=(2, 1024, 5), y_size=(128, 1024, 5)), - (; m=(16,), permuted=Val(true), x_size=(1024, 2, 5), y_size=(1024, 128, 5)), - (; m=(10, 10), permuted=Val(false), - x_size=(1, 22, 22, 5), y_size=(64, 22, 22, 5)), - (; m=(10, 10), permuted=Val(true), - x_size=(22, 22, 1, 5), y_size=(22, 22, 64, 5)) - ] - - @testset "$(op) $(length(setup.m))D: permuted = $(setup.permuted)" for setup in - setups, - op in opconv - - p = Lux.Utils.unwrap_val(setup.permuted) - in_chs = ifelse(p, setup.x_size[end - 1], first(setup.x_size)) - out_chs = ifelse(p, setup.y_size[end - 1], first(setup.y_size)) - ch = 64 => out_chs - - l1 = p ? 
Conv(ntuple(_ -> 1, length(setup.m)), in_chs => first(ch)) : - Dense(in_chs => first(ch)) - m = Chain(l1, op(ch, setup.m; setup.permuted)) - display(m) - ps, st = Lux.setup(rng, m) |> dev - - x = rand(rng, Float32, setup.x_size...) |> aType - @test size(first(m(x, ps, st))) == setup.y_size - @inferred m(x, ps, st) - @jet m(x, ps, st) - - data = [(x, aType(rand(rng, Float32, setup.y_size...)))] - @test begin - l2, l1 = train!(m, ps, st, data; epochs=10) - l2 < l1 +@testitem "SpectralConv & SpectralKernel" setup = [SharedTestSetup] begin + rng = StableRNG(12345) + + opconv = [SpectralConv, SpectralKernel] + setups = [ + (; m=(16,), x_size=(1024, 2, 5), y_size=(1024, 128, 5)), + (; m=(10, 10), x_size=(22, 22, 1, 5), y_size=(22, 22, 64, 5)), + ] + + rdev = reactant_device() + + @testset "$(op) $(length(setup.m))D" for setup in setups, op in opconv + in_chs = setup.x_size[end - 1] + out_chs = setup.y_size[end - 1] + ch = 64 => out_chs + + l1 = Conv(ntuple(_ -> 1, length(setup.m)), in_chs => first(ch)) + m = Chain(l1, op(ch, setup.m; setup.permuted)) + display(m) + ps, st = Lux.setup(rng, m) + + x = rand(rng, Float32, setup.x_size...) 
+ @test size(first(m(x, ps, st))) == setup.y_size + + ps_ra, st_ra = rdev((ps, st)) + x_ra = rdev(x) + y_ra = rdev(rand(rng, Float32, setup.y_size...)) + + @test begin + l2, l1 = train!(MSELoss(), AutoEnzyme(), m, ps, st, [(x, y)]; epochs=10) + l2 < l1 + end + + @testset "check gradients" begin + ∂x_zyg, ∂ps_zyg = zygote_gradient(m, x, ps, st) + + ∂x_ra, ∂ps_ra = Reactant.with_config(; + dot_general_precision=PrecisionConfig.HIGH, + convolution_precision=PrecisionConfig.HIGH, + ) do + enzyme_gradient(m, x_ra, ps_ra, st_ra) end - __f = (x, ps) -> sum(abs2, first(m(x, ps, st))) - @test_gradients(__f, x, - ps; - atol=1.0f-3, - rtol=1.0f-3, - skip_backends=[AutoTracker(), AutoEnzyme(), AutoReverseDiff()]) + @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-3 rtol = 1.0f-3 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) end end end diff --git a/test/nomad_tests.jl b/test/nomad_tests.jl index c371fa4..8077f3f 100644 --- a/test/nomad_tests.jl +++ b/test/nomad_tests.jl @@ -1,28 +1,28 @@ -@testitem "NOMAD" setup=[SharedTestSetup] begin - @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = StableRNG(12345) +# @testitem "NOMAD" setup=[SharedTestSetup] begin +# @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES +# rng = StableRNG(12345) - setups = [ - (u_size=(1, 5), y_size=(1, 5), out_size=(1, 5), - approximator=(1, 16, 16, 15), decoder=(16, 8, 4, 1), name="Scalar"), - (u_size=(8, 5), y_size=(2, 5), out_size=(8, 5), - approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), name="Vector") - ] +# setups = [ +# (u_size=(1, 5), y_size=(1, 5), out_size=(1, 5), +# approximator=(1, 16, 16, 15), decoder=(16, 8, 4, 1), name="Scalar"), +# (u_size=(8, 5), y_size=(2, 5), out_size=(8, 5), +# approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), name="Vector") +# ] - @testset "$(setup.name)" for setup in setups - u = rand(Float32, setup.u_size...) |> aType - y = rand(Float32, setup.y_size...) 
|> aType - nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder) +# @testset "$(setup.name)" for setup in setups +# u = rand(Float32, setup.u_size...) |> aType +# y = rand(Float32, setup.y_size...) |> aType +# nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder) - ps, st = Lux.setup(rng, nomad) |> dev - @inferred first(nomad((u, y), ps, st)) - @jet first(nomad((u, y), ps, st)) +# ps, st = Lux.setup(rng, nomad) |> dev +# @inferred first(nomad((u, y), ps, st)) +# @jet first(nomad((u, y), ps, st)) - pred = first(nomad((u, y), ps, st)) - @test setup.out_size == size(pred) +# pred = first(nomad((u, y), ps, st)) +# @test setup.out_size == size(pred) - __f = (u, y, ps) -> sum(abs2, first(nomad((u, y), ps, st))) - @test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3) - end - end -end +# __f = (u, y, ps) -> sum(abs2, first(nomad((u, y), ps, st))) +# @test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3) +# end +# end +# end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 7305720..52dd0fc 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -1,19 +1,23 @@ -@testitem "doctests: Quality Assurance" tags=[:qa] begin +@testitem "doctests: Quality Assurance" tags = [:qa] begin using Documenter, NeuralOperators - DocMeta.setdocmeta!(NeuralOperators, :DocTestSetup, - :(using Lux, NeuralOperators, Random); recursive=true) + DocMeta.setdocmeta!( + NeuralOperators, + :DocTestSetup, + :(using Lux, NeuralOperators, Random); + recursive=true, + ) doctest(NeuralOperators; manual=false) end -@testitem "Aqua: Quality Assurance" tags=[:qa] begin +@testitem "Aqua: Quality Assurance" tags = [:qa] begin using Aqua Aqua.test_all(NeuralOperators; ambiguities=false) Aqua.test_ambiguities(NeuralOperators; recursive=false) end -@testitem "Explicit Imports: Quality Assurance" tags=[:qa] begin +@testitem "Explicit Imports: Quality Assurance" tags = [:qa] begin using ExplicitImports, Lux # Skip our own packages @@ -22,7 +26,7 @@ end @test 
check_no_self_qualified_accesses(NeuralOperators) === nothing @test check_all_explicit_imports_via_owners(NeuralOperators) === nothing @test check_all_qualified_accesses_via_owners(NeuralOperators) === nothing - if VERSION >= v"1.11-" + if VERSION ≥ v"1.11-" @test_broken check_all_explicit_imports_are_public(NeuralOperators) === nothing # mostly upstream problems @test_broken check_all_qualified_accesses_are_public(NeuralOperators) === nothing # mostly upstream problems end diff --git a/test/runtests.jl b/test/runtests.jl index 1987473..aaec123 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,31 +1,16 @@ -using Preferences - -Preferences.set_preferences!("LuxLib", "instability_check" => "error") -Preferences.set_preferences!("LuxCore", "instability_check" => "error") - -using ReTestItems, Pkg, Test, InteractiveUtils, Hwloc, NeuralOperators +using ReTestItems, Test, Hwloc, NeuralOperators, Reactant const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all")) -const EXTRA_PKGS = String[] -(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "LuxCUDA") -(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") - -if !isempty(EXTRA_PKGS) - @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS - Pkg.add(EXTRA_PKGS) - Pkg.update() - Base.retry_load_extensions() - Pkg.instantiate() -end - -const RETESTITEMS_NWORKERS = parse( - Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16)))) -const RETESTITEMS_NWORKER_THREADS = parse(Int, - get(ENV, "RETESTITEMS_NWORKER_THREADS", - string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1)))) +const RETESTITEMS_NWORKER_THREADS = parse( + Int, get(ENV, "RETESTITEMS_NWORKER_THREADS", string(Hwloc.num_virtual_cores())) +) @testset "NeuralOperators.jl Tests" begin - ReTestItems.runtests(NeuralOperators; nworkers=RETESTITEMS_NWORKERS, - nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) + ReTestItems.runtests( + 
NeuralOperators; + nworkers=0, + nworker_threads=RETESTITEMS_NWORKER_THREADS, + testitem_timeout=3600, + ) end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index 6dcb2bf..29a123e 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -1,39 +1,11 @@ @testsetup module SharedTestSetup import Reexport: @reexport -@reexport using Lux, Zygote, Optimisers, Random, StableRNGs, LuxTestUtils -using MLDataDevices - -LuxTestUtils.jet_target_modules!(["NeuralOperators", "Lux", "LuxLib"]) +@reexport using Lux, Zygote, Optimisers, Random, StableRNGs, Reactant +using LuxTestUtils: check_approx const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) -if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" - using LuxCUDA -end - -if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" - using AMDGPU -end - -cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" -function cuda_testing() - return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && - MLDataDevices.functional(CUDADevice) -end -function amdgpu_testing() - return (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && - MLDataDevices.functional(AMDGPUDevice) -end - -const MODES = begin - modes = [] - cpu_testing() && push!(modes, ("CPU", Array, CPUDevice(), false)) - cuda_testing() && push!(modes, ("CUDA", CuArray, CUDADevice(), true)) - amdgpu_testing() && push!(modes, ("AMDGPU", ROCArray, AMDGPUDevice(), true)) - modes -end - train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...) 
function train!(loss, backend, model, ps, st, data; epochs=10) @@ -41,7 +13,6 @@ function train!(loss, backend, model, ps, st, data; epochs=10) tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) for _ in 1:epochs, (x, y) in data - _, _, _, tstate = Training.single_train_step!(backend, loss, (x, y), tstate) end @@ -50,7 +21,18 @@ function train!(loss, backend, model, ps, st, data; epochs=10) return l2, l1 end +sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st))) + +function zygote_gradient(model, x, ps, st) + return Zygote.gradient(sumabs2first, model, x, ps, st)[2:3] +end + +function enzyme_gradient(model, x, ps, st) + return Enzyme.gradient(Reverse, sumabs2first, Const(model), x, ps, Const(st))[2:3] +end + export check_approx -export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, train! +export BACKEND_GROUP, train! +export zygote_gradient, enzyme_gradient end diff --git a/test/utils_tests.jl b/test/utils_tests.jl deleted file mode 100644 index 801e4a9..0000000 --- a/test/utils_tests.jl +++ /dev/null @@ -1,58 +0,0 @@ -@testitem "utils" setup=[SharedTestSetup] begin - import NeuralOperators: deeponet_project, nomad_concatenate, batch_vectorize - - @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = StableRNG(12345) - - setups = [ - (b_size=(16, 5), t_size=(16, 10, 5), out_size=(10, 5), - additional=NoOpLayer(), name="Scalar"), - (b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(1, 10, 5), - additional=NoOpLayer(), name="Scalar II"), - (b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(3, 10, 5), - additional=NoOpLayer(), name="Vector"), - (b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5), - out_size=(4, 3, 3, 10, 5), additional=NoOpLayer(), name="Tensor"), - (b_size=(16, 5), t_size=(16, 10, 5), out_size=(4, 10, 5), - additional=Dense(16 => 4), name="additional : Scalar"), - (b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(4, 1, 10, 5), - additional=Dense(16 => 4), name="additional : Scalar II"), - 
(b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(4, 3, 10, 5), - additional=Dense(16 => 4), name="additional : Vector"), - (b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5), out_size=(3, 4, 3, 4, 10, 5), - additional=Chain(Dense(16 => 4), ReshapeLayer((3, 4, 3, 4, 10))), - name="additional : Tensor") - ] - - @testset "project : $(setup.name)" for setup in setups - b = rand(Float32, setup.b_size...) |> aType - t = rand(Float32, setup.t_size...) |> aType - - ps, st = Lux.setup(rng, setup.additional) |> dev - additional = setup.additional isa NoOpLayer ? nothing : - StatefulLuxLayer{true}(setup.additional, ps, st) - - @inferred deeponet_project(b, t, additional) - @jet deeponet_project(b, t, additional) - @test setup.out_size == size(deeponet_project(b, t, additional)) - end - - setups = [(x_size=(6, 5), y_size=(4, 5), out_size=(10, 5), name="Scalar"), - (x_size=(12, 5), y_size=(8, 5), out_size=(20, 5), name="Vector I"), - (x_size=(4, 6, 5), y_size=(6, 5), out_size=(30, 5), name="Vector II"), - (x_size=(4, 2, 3, 5), y_size=(2, 2, 3, 5), out_size=(36, 5), name="Tensor")] - - @testset "nomad_concatenate $(setup.name)" for setup in setups - x_size = rand(Float32, setup.x_size...) |> aType - y_size = rand(Float32, setup.y_size...) 
|> aType - - @test setup.out_size == size(nomad_concatenate(x_size, y_size)) - end - - @testset "batch vectorize" begin - x_size = (4, 2, 3) - x = rand(Float32, x_size..., 5) |> aType - @test size(batch_vectorize(x)) == (prod(x_size), 5) - end - end -end From 8e33dd24ea35d6b4c180e3822eb7415af6da4d14 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 May 2025 21:44:07 -0400 Subject: [PATCH 02/24] fix: more test fixes --- .github/workflows/CI.yml | 2 +- Project.toml | 6 ++++-- src/NeuralOperators.jl | 6 ++++-- src/layers.jl | 2 +- test/Project.toml | 1 + test/layers_tests.jl | 6 ++++-- test/qa_tests.jl | 6 ++---- test/shared_testsetup.jl | 5 +++-- 8 files changed, 20 insertions(+), 14 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1b306d2..f440168 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -29,10 +29,10 @@ jobs: matrix: version: - "1.10" + - "1" os: - ubuntu-latest - macos-latest - - windows-latest steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/Project.toml b/Project.toml index 425e257..15a935d 100644 --- a/Project.toml +++ b/Project.toml @@ -8,8 +8,9 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [compat] @@ -17,7 +18,8 @@ AbstractFFTs = "1.5.0" ConcreteStructs = "0.2.3" Lux = "1" LuxCore = "1" +LuxLib = "1.8.0" +NNlib = "0.9.30" Random = "1.10" -Reactant = "0.2.122" WeightInitializers = "1" julia = "1.10" diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index 315aac6..e847fd6 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl 
@@ -4,8 +4,10 @@ using AbstractFFTs: rfft, irfft using ConcreteStructs: @concrete using Random: Random, AbstractRNG -using Lux -using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer +using Lux: Lux, Conv, Parallel +using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer +using LuxLib: fast_activation!! +using NNlib: batched_mul, pad_constant using WeightInitializers: glorot_uniform include("utils.jl") diff --git a/src/layers.jl b/src/layers.jl index 0488f4e..dfb63df 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -63,7 +63,7 @@ function operator_conv(x, tform::AbstractTransform, weights) x_p = apply_pattern(x_tr, weights) pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)] - x_padded = NNlib.pad_constant( + x_padded = pad_constant( x_p, expand_pad_dims(pad_dims), false; dims=ntuple(identity, ndims(x_p) - 2) ) diff --git a/test/Project.toml b/test/Project.toml index 64ddc63..51b3fef 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" diff --git a/test/layers_tests.jl b/test/layers_tests.jl index 8a20ca3..e72497e 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -15,7 +15,7 @@ ch = 64 => out_chs l1 = Conv(ntuple(_ -> 1, length(setup.m)), in_chs => first(ch)) - m = Chain(l1, op(ch, setup.m; setup.permuted)) + m = Chain(l1, op(ch, setup.m)) display(m) ps, st = Lux.setup(rng, m) @@ -27,7 +27,9 @@ y_ra = rdev(rand(rng, Float32, setup.y_size...)) @test begin - l2, l1 = train!(MSELoss(), AutoEnzyme(), m, ps, st, [(x, y)]; epochs=10) + l2, l1 = train!( + MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, 
[(x_ra, y_ra)]; epochs=10 + ) l2 < l1 end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 52dd0fc..ed099ec 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -26,8 +26,6 @@ end @test check_no_self_qualified_accesses(NeuralOperators) === nothing @test check_all_explicit_imports_via_owners(NeuralOperators) === nothing @test check_all_qualified_accesses_via_owners(NeuralOperators) === nothing - if VERSION ≥ v"1.11-" - @test_broken check_all_explicit_imports_are_public(NeuralOperators) === nothing # mostly upstream problems - @test_broken check_all_qualified_accesses_are_public(NeuralOperators) === nothing # mostly upstream problems - end + @test check_all_explicit_imports_are_public(NeuralOperators) === nothing + @test check_all_qualified_accesses_are_public(NeuralOperators) === nothing end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index 29a123e..b82505a 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -3,20 +3,21 @@ import Reexport: @reexport @reexport using Lux, Zygote, Optimisers, Random, StableRNGs, Reactant using LuxTestUtils: check_approx +using FFTW const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...) 
function train!(loss, backend, model, ps, st, data; epochs=10) - l1 = loss(model, ps, st, first(data)) + l1 = @jit loss(model, ps, st, first(data)) tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) for _ in 1:epochs, (x, y) in data _, _, _, tstate = Training.single_train_step!(backend, loss, (x, y), tstate) end - l2 = loss(model, ps, st, first(data)) + l2 = @jit loss(model, ps, st, first(data)) return l2, l1 end From b80b2b23de372563dfb7621871c55c946bdaf6a9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 May 2025 22:23:10 -0400 Subject: [PATCH 03/24] feat: simplify nomad --- src/NeuralOperators.jl | 7 ++-- src/models/nomad.jl | 47 +++++++----------------- test/nomad_tests.jl | 81 +++++++++++++++++++++++++++--------------- test/runtests.jl | 2 +- 4 files changed, 70 insertions(+), 67 deletions(-) diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index e847fd6..c994837 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -4,7 +4,7 @@ using AbstractFFTs: rfft, irfft using ConcreteStructs: @concrete using Random: Random, AbstractRNG -using Lux: Lux, Conv, Parallel +using Lux: Lux, Chain, Dense, Conv, Parallel, NoOpLayer using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer using LuxLib: fast_activation!! using NNlib: batched_mul, pad_constant @@ -17,12 +17,13 @@ include("layers.jl") # include("models/fno.jl") # include("models/deeponet.jl") -# include("models/nomad.jl") +include("models/nomad.jl") export FourierTransform export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel + # export FourierNeuralOperator # export DeepONet -# export NOMAD +export NOMAD end diff --git a/src/models/nomad.jl b/src/models/nomad.jl index fc2de42..5b671ab 100644 --- a/src/models/nomad.jl +++ b/src/models/nomad.jl @@ -1,5 +1,5 @@ """ - NOMAD(approximator, decoder, concatenate) + NOMAD(approximator, decoder) Constructs a NOMAD from `approximator` and `decoder` architectures. 
Make sure the output from `approximator` combined with the coordinate dimension has compatible size for input to @@ -10,12 +10,6 @@ from `approximator` combined with the coordinate dimension has compatible size f - `approximator`: `Lux` network to be used as approximator net. - `decoder`: `Lux` network to be used as decoder net. -## Keyword Arguments - - - `concatenate`: function that defines the concatenation of output from `approximator` and - the coordinate dimension, defaults to concatenation along first dimension after - vectorizing the tensors - ## References [1] Jacob H. Seidman and Georgios Kissas and Paris Perdikaris and George J. Pappas, "NOMAD: @@ -40,15 +34,19 @@ julia> size(first(nomad((u, y), ps, st))) (8, 5) ``` """ -@concrete struct NOMAD <: AbstractLuxContainerLayer{(:approximator, :decoder)} - approximator - decoder - concatenate <: Function +@concrete struct NOMAD <: AbstractLuxWrapperLayer{:model} + model +end + +function NOMAD(approximator, decoder) + return NOMAD(Chain(; approximator=Parallel(vcat, approximator, NoOpLayer()), decoder)) end """ - NOMAD(; approximator = (8, 32, 32, 16), decoder = (18, 16, 8, 8), - approximator_activation = identity, decoder_activation = identity) + NOMAD(; + approximator = (8, 32, 32, 16), decoder = (18, 16, 8, 8), + approximator_activation = identity, decoder_activation = identity + ) Constructs a NOMAD composed of Dense layers. Make sure that last node of `approximator` + coordinate length = first node of `decoder`. @@ -61,9 +59,6 @@ coordinate length = first node of `decoder`. 
net - `approximator_activation`: activation function for approximator net - `decoder_activation`: activation function for decoder net - - `concatenate`: function that defines the concatenation of output from `approximator` and - the coordinate dimension, defaults to concatenation along first dimension after - vectorizing the tensors ## References @@ -90,7 +85,6 @@ function NOMAD(; decoder=(18, 16, 8, 8), approximator_activation=identity, decoder_activation=identity, - concatenate=nomad_concatenate, ) approximator_net = Chain( [ @@ -106,22 +100,5 @@ function NOMAD(; ]..., ) - return NOMAD(approximator_net, decoder_net, concatenate) -end - -function (nomad::NOMAD)(x, ps, st::NamedTuple) - a, st_a = nomad.approximator(x[1], ps.approximator, st.approximator) - out, st_d = nomad.decoder(nomad.concatenate(a, x[2]), ps.decoder, st.decoder) - return out, (approximator=st_a, decoder=st_d) -end - -function NOMAD(approximator_net, decoder_net; concatenate=nomad_concatenate) - return NOMAD(approximator_net, decoder_net, concatenate) -end - -batch_vectorize(x::AbstractArray) = reshape(x, :, size(x, ndims(x))) - -nomad_concatenate(x::AbstractMatrix, y::AbstractMatrix) = cat(x, y; dims=1) -function nomad_concatenate(x::AbstractArray, y::AbstractArray) - return nomad_concatenate(batch_vectorize(x), batch_vectorize(y)) + return NOMAD(approximator_net, decoder_net) end diff --git a/test/nomad_tests.jl b/test/nomad_tests.jl index 8077f3f..c97f1cb 100644 --- a/test/nomad_tests.jl +++ b/test/nomad_tests.jl @@ -1,28 +1,53 @@ -# @testitem "NOMAD" setup=[SharedTestSetup] begin -# @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES -# rng = StableRNG(12345) - -# setups = [ -# (u_size=(1, 5), y_size=(1, 5), out_size=(1, 5), -# approximator=(1, 16, 16, 15), decoder=(16, 8, 4, 1), name="Scalar"), -# (u_size=(8, 5), y_size=(2, 5), out_size=(8, 5), -# approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), name="Vector") -# ] - -# @testset "$(setup.name)" for setup in setups -# u = 
rand(Float32, setup.u_size...) |> aType -# y = rand(Float32, setup.y_size...) |> aType -# nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder) - -# ps, st = Lux.setup(rng, nomad) |> dev -# @inferred first(nomad((u, y), ps, st)) -# @jet first(nomad((u, y), ps, st)) - -# pred = first(nomad((u, y), ps, st)) -# @test setup.out_size == size(pred) - -# __f = (u, y, ps) -> sum(abs2, first(nomad((u, y), ps, st))) -# @test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3) -# end -# end -# end +@testitem "NOMAD" setup = [SharedTestSetup] begin + rng = StableRNG(12345) + + setups = [ + ( + u_size=(1, 5), + y_size=(1, 5), + out_size=(1, 5), + approximator=(1, 16, 16, 15), + decoder=(16, 8, 4, 1), + name="Scalar", + ), + ( + u_size=(8, 5), + y_size=(2, 5), + out_size=(8, 5), + approximator=(8, 32, 32, 16), + decoder=(18, 16, 8, 8), + name="Vector", + ), + ] + + xdev = reactant_device() + + @testset "$(setup.name)" for setup in setups + u = rand(Float32, setup.u_size...) + y = rand(Float32, setup.y_size...) 
+ nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder) + + ps, st = Lux.setup(rng, nomad) + + pred = first(nomad((u, y), ps, st)) + @test setup.out_size == size(pred) + + ps_ra, st_ra = xdev((ps, st)) + u_ra, y_ra = xdev(u), xdev(y) + + @testset "check gradients" begin + ∂u_zyg, ∂ps_zyg = zygote_gradient(nomad, (u, y), ps, st) + + ∂u_ra, ∂ps_ra = Reactant.with_config(; + dot_general_precision=PrecisionConfig.HIGH, + convolution_precision=PrecisionConfig.HIGH, + ) do + enzyme_gradient(nomad, (u_ra, y_ra), ps_ra, st_ra) + end + + @test ∂u_zyg[1] ≈ ∂u_ra[1] atol = 1.0f-3 rtol = 1.0f-3 + @test ∂u_zyg[2] ≈ ∂u_ra[2] atol = 1.0f-3 rtol = 1.0f-3 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index aaec123..c777622 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,7 +9,7 @@ const RETESTITEMS_NWORKER_THREADS = parse( @testset "NeuralOperators.jl Tests" begin ReTestItems.runtests( NeuralOperators; - nworkers=0, + nworkers=1, nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600, ) From 0ea99b254d12e76e2415ff354c2445500a5222d9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 May 2025 22:36:06 -0400 Subject: [PATCH 04/24] test: temporarily mark the test as broken --- test/layers_tests.jl | 48 +++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/test/layers_tests.jl b/test/layers_tests.jl index e72497e..a358b69 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -3,8 +3,8 @@ opconv = [SpectralConv, SpectralKernel] setups = [ - (; m=(16,), x_size=(1024, 2, 5), y_size=(1024, 128, 5)), - (; m=(10, 10), x_size=(22, 22, 1, 5), y_size=(22, 22, 64, 5)), + (; m=(16,), x_size=(1024, 2, 5), y_size=(1024, 16, 5)), + (; m=(10, 10), x_size=(22, 22, 1, 5), y_size=(22, 22, 16, 5)), ] rdev = reactant_device() @@ -12,7 +12,7 @@ @testset "$(op) $(length(setup.m))D" for setup in setups, op in opconv 
in_chs = setup.x_size[end - 1] out_chs = setup.y_size[end - 1] - ch = 64 => out_chs + ch = 4 => out_chs l1 = Conv(ntuple(_ -> 1, length(setup.m)), in_chs => first(ch)) m = Chain(l1, op(ch, setup.m)) @@ -26,25 +26,27 @@ x_ra = rdev(x) y_ra = rdev(rand(rng, Float32, setup.y_size...)) - @test begin - l2, l1 = train!( - MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 - ) - l2 < l1 - end - - @testset "check gradients" begin - ∂x_zyg, ∂ps_zyg = zygote_gradient(m, x, ps, st) - - ∂x_ra, ∂ps_ra = Reactant.with_config(; - dot_general_precision=PrecisionConfig.HIGH, - convolution_precision=PrecisionConfig.HIGH, - ) do - enzyme_gradient(m, x_ra, ps_ra, st_ra) - end - - @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-3 rtol = 1.0f-3 - @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) - end + # XXX: upstream fix is needed for the FFT adjoint to work correctly + # https://github.com/EnzymeAD/Reactant.jl/issues/246 + # @test begin + # l2, l1 = train!( + # MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 + # ) + # l2 < l1 + # end + + # @testset "check gradients" begin + # ∂x_zyg, ∂ps_zyg = zygote_gradient(m, x, ps, st) + + # ∂x_ra, ∂ps_ra = Reactant.with_config(; + # dot_general_precision=PrecisionConfig.HIGH, + # convolution_precision=PrecisionConfig.HIGH, + # ) do + # @jit enzyme_gradient(m, x_ra, ps_ra, st_ra) + # end + + # @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-3 rtol = 1.0f-3 + # @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) + # end end end From 303e69f3a8a5c5c8a7ff8bfc58418c001411173e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 May 2025 23:03:45 -0400 Subject: [PATCH 05/24] test: more test fixes --- src/NeuralOperators.jl | 2 +- src/models/deeponet.jl | 118 ++++++--------------------------------- src/utils.jl | 3 +- test/Project.toml | 2 +- test/deeponet_tests.jl | 101 +++++++++++++++++++++++---------- test/layers_tests.jl | 1 + test/nomad_tests.jl | 3 +- test/qa_tests.jl | 2 - test/shared_testsetup.jl | 2 +- 9 files 
changed, 98 insertions(+), 136 deletions(-) diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index c994837..6afcc4b 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -7,7 +7,7 @@ using Random: Random, AbstractRNG using Lux: Lux, Chain, Dense, Conv, Parallel, NoOpLayer using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer using LuxLib: fast_activation!! -using NNlib: batched_mul, pad_constant +using NNlib: NNlib, batched_mul, pad_constant using WeightInitializers: glorot_uniform include("utils.jl") diff --git a/src/models/deeponet.jl b/src/models/deeponet.jl index 84f7551..9806ccd 100644 --- a/src/models/deeponet.jl +++ b/src/models/deeponet.jl @@ -9,11 +9,6 @@ nets output should have the same first dimension. - `branch`: `Lux` network to be used as branch net. - `trunk`: `Lux` network to be used as trunk net. -## Keyword Arguments - - - `additional`: `Lux` network to pass the output of DeepONet, to include additional - operations for embeddings, defaults to `nothing` - ## References [1] Lu Lu, Pengzhan Jin, George Em Karniadakis, "DeepONet: Learning nonlinear operators for @@ -27,9 +22,7 @@ We are given several (b = 200) instances of the IC, discretized at 50 points eac to query the solution for 100 different locations and times [0;1]. That makes the branch input of shape [50 x 200] and the trunk input of shape [2 x 100]. So, -the input for the branch net is 50 and 100 for the trunk net. Note that the inputs must be -batched so the branch input is of shape [50 x 200 x 1] and the trunk input is of shape -[2 x 100 x 1]. +the input for the branch net is 50 and 100 for the trunk net. 
## Example @@ -44,23 +37,27 @@ julia> ps, st = Lux.setup(Xoshiro(), deeponet); julia> u = rand(Float32, 64, 5); -julia> y = rand(Float32, 1, 10, 5); +julia> y = rand(Float32, 1, 10); julia> size(first(deeponet((u, y), ps, st))) (10, 5) ``` """ -@concrete struct DeepONet <: AbstractLuxContainerLayer{(:branch, :trunk, :additional)} - branch - trunk - additional +@concrete struct DeepONet <: AbstractLuxWrapperLayer{:model} + model end -DeepONet(branch, trunk) = DeepONet(branch, trunk, NoOpLayer()) +function DeepONet(branch, trunk) + return DeepONet( + Parallel(*; branch=Chain(branch, WrappedFunction(adjoint)), trunk=trunk) + ) +end """ - DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16), - branch_activation = identity, trunk_activation = identity) + DeepONet(; + branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16), + branch_activation = identity, trunk_activation = identity + ) Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and `trunk` are same. @@ -71,8 +68,6 @@ Constructs a DeepONet composed of Dense layers. Make sure the last node of `bran - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net - `branch_activation`: activation function for branch net - `trunk_activation`: activation function for trunk net - - `additional`: `Lux` network to pass the output of DeepONet, to include additional - operations for embeddings, defaults to `nothing` ## References @@ -89,7 +84,7 @@ julia> ps, st = Lux.setup(Xoshiro(), deeponet); julia> u = rand(Float32, 64, 5); -julia> y = rand(Float32, 1, 10, 5); +julia> y = rand(Float32, 1, 10); julia> size(first(deeponet((u, y), ps, st))) (10, 5) @@ -100,13 +95,12 @@ function DeepONet(; trunk=(1, 8, 8, 16), branch_activation=identity, trunk_activation=identity, - additional=NoOpLayer(), ) # checks for last dimension size - @argcheck branch[end] == trunk[end] "Branch and Trunk net must share the same amount of \ - nodes in the last layer. 
Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ - work." + @argcheck branch[end] == trunk[end] "Branch and Trunk net must share the same amount \ + of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \ + won't work." branch_net = Chain( [ @@ -126,81 +120,5 @@ function DeepONet(; ]..., ) - return DeepONet(branch_net, trunk_net, additional) -end - -function (deeponet::DeepONet)((x1, x2), ps, st::NamedTuple) - b, st_b = deeponet.branch(x1, ps.branch, st.branch) - t, st_t = deeponet.trunk(x2, ps.trunk, st.trunk) - - @argcheck size(b, 1) == size(t, 1) "Branch and Trunk net must share the same amount of \ - nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ - work." - - additional = if deeponet.additional isa NoOpLayer - nothing - else - StatefulLuxLayer{true}(deeponet.additional, ps.additional, st.additional) - end - out = deeponet_project(b, t, additional) - - stₙ = merge( - (; branch=st_b, trunk=st_t), deeponet.additional isa NoOpLayer ? (;) : additional.st - ) - return out, stₙ -end - -function deeponet_project( - b::AbstractArray{T1,2}, t::AbstractArray{T2,3}, ::Nothing -) where {T1,T2} - # b [p, nb], t [p, N, nb] - bᵣ = reshape(b, size(b, 1), 1, size(b, 2)) - return dropdims(sum(bᵣ .* t; dims=1); dims=1) # [N, nb] -end - -function deeponet_project( - b::AbstractArray{T1,3}, t::AbstractArray{T2,3}, ::Nothing -) where {T1,T2} - # b [p, u, nb], t [p, N, nb] - return batched_matmul(safe_batched_adjoint(b), t) # [u, N, b] -end - -function deeponet_project( - b::AbstractArray{T1,N}, t::AbstractArray{T2,3}, ::Nothing -) where {T1,T2,N} - # b [p, u_size..., nb], t [p, N, nb] - bᵣ = reshape(b, size(b, 1), :, size(b, N)) - return reshape( - batched_matmul(safe_batched_adjoint(bᵣ), t), - size(b)[2:(N - 1)]..., - size(t, 2), - size(b, N), - ) -end - -function deeponet_project( - b::AbstractArray{T1,2}, t::AbstractArray{T2,3}, additional -) where {T1,T2} - # b [p, nb], t [p, N, nb] - bᵣ = reshape(b, size(b, 1), 1, size(b, 2)) - return additional(bᵣ .* t) # [p, N, nb] => [out_dims, N, nb] -end - 
-function deeponet_project( - b::AbstractArray{T1,3}, t::AbstractArray{T2,3}, additional -) where {T1,T2} - # b [p, u, nb], t [p, N, nb] - bᵣ = reshape(b, size(b, 1), size(b, 2), 1, size(b, 3)) # [p, u, 1, nb] - tᵣ = reshape(t, size(t, 1), 1, size(t)[2:end]...) # [p, 1, N, nb] - return additional(bᵣ .* tᵣ) # [p, u, N, nb] => [out_size, u, N, nb] -end - -function deeponet_project( - b::AbstractArray{T1,N}, t::AbstractArray{T2,3}, additional -) where {T1,T2,N} - # b [p, u_size..., nb], t [p, N, nb] - bᵣ = reshape(b, size(b, 1), :, 1, size(b, N)) # [p, (u_size...), 1, nb] - tᵣ = reshape(t, size(t, 1), 1, size(t, 2), size(t, 3)) # [p, 1, N, nb] - bᵣtᵣ = reshape(bᵣ .* tᵣ, size(b, 1), size(b)[2:(N - 1)]..., size(t, 2), size(b, N)) - return additional(bᵣtᵣ) # [p, u_size..., N, nb] => [out_size, u_size..., N, nb] + return DeepONet(branch_net, trunk_net) end diff --git a/src/utils.jl b/src/utils.jl index f42dfe8..7b4bb3b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -11,7 +11,8 @@ function apply_pattern( end function add_act(act::F, x1, x2) where {F} - return fast_activation!!(NNlib.fast_act(act, y), x1 .+ x2) + y = x1 .+ x2 + return fast_activation!!(NNlib.fast_act(act, y), y) end @concrete struct Fix1 <: Function diff --git a/test/Project.toml b/test/Project.toml index 51b3fef..4b9ba91 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Aqua = "0.8.7" Documenter = "1.5.0" ExplicitImports = "1.9.0" -Hwloc = "3.2.0" +Hwloc = "3.2" Lux = "1" LuxCore = "1" LuxLib = "1.2" diff --git a/test/deeponet_tests.jl b/test/deeponet_tests.jl index 418b679..048ef10 100644 --- a/test/deeponet_tests.jl +++ b/test/deeponet_tests.jl @@ -1,24 +1,48 @@ -# @testitem "DeepONet" setup=[SharedTestSetup] begin +# @testitem "DeepONet" setup = [SharedTestSetup] begin # @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES # rng = StableRNG(12345) # setups = [ -# (u_size=(64, 5), y_size=(1, 10, 5), out_size=(10, 5), -# 
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar"), -# (u_size=(64, 1, 5), y_size=(1, 10, 5), out_size=(1, 10, 5), -# branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar II"), -# (u_size=(64, 3, 5), y_size=(4, 10, 5), out_size=(3, 10, 5), -# branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Vector"), -# (u_size=(64, 4, 3, 3, 5), y_size=(4, 10, 5), out_size=(4, 3, 3, 10, 5), -# branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Tensor") +# ( +# u_size=(64, 5), +# y_size=(1, 10, 5), +# out_size=(10, 5), +# branch=(64, 32, 32, 16), +# trunk=(1, 8, 8, 16), +# name="Scalar", +# ), +# ( +# u_size=(64, 1, 5), +# y_size=(1, 10, 5), +# out_size=(1, 10, 5), +# branch=(64, 32, 32, 16), +# trunk=(1, 8, 8, 16), +# name="Scalar II", +# ), +# ( +# u_size=(64, 3, 5), +# y_size=(4, 10, 5), +# out_size=(3, 10, 5), +# branch=(64, 32, 32, 16), +# trunk=(4, 8, 8, 16), +# name="Vector", +# ), +# ( +# u_size=(64, 4, 3, 3, 5), +# y_size=(4, 10, 5), +# out_size=(4, 3, 3, 10, 5), +# branch=(64, 32, 32, 16), +# trunk=(4, 8, 8, 16), +# name="Tensor", +# ), # ] # @testset "$(setup.name)" for setup in setups -# u = rand(Float32, setup.u_size...) |> aType -# y = rand(Float32, setup.y_size...) 
|> aType +# u = aType(rand(Float32, setup.u_size...)) +# y = aType(rand(Float32, setup.y_size...)) # deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk) -# ps, st = Lux.setup(rng, deeponet) |> dev +# ps, st = dev(Lux.setup(rng, deeponet)) # @inferred first(deeponet((u, y), ps, st)) # @jet first(deeponet((u, y), ps, st)) @@ -27,24 +51,43 @@ # end # setups = [ -# (u_size=(64, 5), y_size=(1, 10, 5), out_size=(4, 10, 5), -# branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), -# additional=Dense(16 => 4), name="Scalar"), -# (u_size=(64, 1, 5), y_size=(1, 10, 5), out_size=(4, 1, 10, 5), -# branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), -# additional=Dense(16 => 4), name="Scalar II"), -# (u_size=(64, 3, 5), y_size=(8, 10, 5), out_size=(4, 3, 10, 5), -# branch=(64, 32, 32, 16), trunk=(8, 8, 8, 16), -# additional=Dense(16 => 4), name="Vector") +# ( +# u_size=(64, 5), +# y_size=(1, 10, 5), +# out_size=(4, 10, 5), +# branch=(64, 32, 32, 16), +# trunk=(1, 8, 8, 16), +# additional=Dense(16 => 4), +# name="Scalar", +# ), +# ( +# u_size=(64, 1, 5), +# y_size=(1, 10, 5), +# out_size=(4, 1, 10, 5), +# branch=(64, 32, 32, 16), +# trunk=(1, 8, 8, 16), +# additional=Dense(16 => 4), +# name="Scalar II", +# ), +# ( +# u_size=(64, 3, 5), +# y_size=(8, 10, 5), +# out_size=(4, 3, 10, 5), +# branch=(64, 32, 32, 16), +# trunk=(8, 8, 8, 16), +# additional=Dense(16 => 4), +# name="Vector", +# ), # ] # @testset "Additional layer: $(setup.name)" for setup in setups -# u = rand(Float32, setup.u_size...) |> aType -# y = rand(Float32, setup.y_size...) 
|> aType +# u = aType(rand(Float32, setup.u_size...)) +# y = aType(rand(Float32, setup.y_size...)) # deeponet = DeepONet(; -# branch=setup.branch, trunk=setup.trunk, additional=setup.additional) +# branch=setup.branch, trunk=setup.trunk, additional=setup.additional +# ) -# ps, st = Lux.setup(rng, deeponet) |> dev +# ps, st = dev(Lux.setup(rng, deeponet)) # @inferred first(deeponet((u, y), ps, st)) # @jet first(deeponet((u, y), ps, st)) @@ -56,15 +99,15 @@ # end # @testset "Embedding layer mismatch" begin -# u = rand(Float32, 64, 5) |> aType -# y = rand(Float32, 1, 10, 5) |> aType +# u = aType(rand(Float32, 64, 5)) +# y = aType(rand(Float32, 1, 10, 5)) # deeponet = DeepONet( # Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)), -# Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)) +# Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)), # ) -# ps, st = Lux.setup(rng, deeponet) |> dev +# ps, st = dev(Lux.setup(rng, deeponet)) # @test_throws ArgumentError deeponet((u, y), ps, st) # end # end diff --git a/test/layers_tests.jl b/test/layers_tests.jl index a358b69..34e73e5 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -44,6 +44,7 @@ # ) do # @jit enzyme_gradient(m, x_ra, ps_ra, st_ra) # end + # ∂x_ra, ∂ps_ra = (∂x_ra, ∂ps_ra) |> cpu_device() # @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-3 rtol = 1.0f-3 # @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) diff --git a/test/nomad_tests.jl b/test/nomad_tests.jl index c97f1cb..8567967 100644 --- a/test/nomad_tests.jl +++ b/test/nomad_tests.jl @@ -42,8 +42,9 @@ dot_general_precision=PrecisionConfig.HIGH, convolution_precision=PrecisionConfig.HIGH, ) do - enzyme_gradient(nomad, (u_ra, y_ra), ps_ra, st_ra) + @jit enzyme_gradient(nomad, (u_ra, y_ra), ps_ra, st_ra) end + ∂u_ra, ∂ps_ra = (∂u_ra, ∂ps_ra) |> cpu_device() @test ∂u_zyg[1] ≈ ∂u_ra[1] atol = 1.0f-3 rtol = 1.0f-3 @test ∂u_zyg[2] ≈ ∂u_ra[2] atol = 1.0f-3 rtol = 1.0f-3 diff --git a/test/qa_tests.jl b/test/qa_tests.jl index ed099ec..57c8042 100644 
--- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -26,6 +26,4 @@ end @test check_no_self_qualified_accesses(NeuralOperators) === nothing @test check_all_explicit_imports_via_owners(NeuralOperators) === nothing @test check_all_qualified_accesses_via_owners(NeuralOperators) === nothing - @test check_all_explicit_imports_are_public(NeuralOperators) === nothing - @test check_all_qualified_accesses_are_public(NeuralOperators) === nothing end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index b82505a..e453fa8 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -1,7 +1,7 @@ @testsetup module SharedTestSetup import Reexport: @reexport -@reexport using Lux, Zygote, Optimisers, Random, StableRNGs, Reactant +@reexport using Lux, Zygote, Optimisers, Random, StableRNGs, Reactant, Enzyme using LuxTestUtils: check_approx using FFTW From 2e5054a7dcf58bc8b9b30e8d677b44494c850e64 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 May 2025 23:12:49 -0400 Subject: [PATCH 06/24] fix: deeponets --- .JuliaFormatter.toml | 10 +-- src/NeuralOperators.jl | 6 +- src/models/deeponet.jl | 11 ++- test/deeponet_tests.jl | 166 +++++++++++++---------------------------- 4 files changed, 64 insertions(+), 129 deletions(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index f444ca1..28be623 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,9 +1,3 @@ style = "blue" -format_markdown = true -whitespace_in_kwargs = false -margin = 92 -indent = 4 -format_docstrings = true -separate_kwargs_with_semicolon = true -always_for_in = true -annotate_untyped_fields_with_any = false +pipe_to_function_call = false +always_use_return = true diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index 6afcc4b..641a050 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -4,7 +4,7 @@ using AbstractFFTs: rfft, irfft using ConcreteStructs: @concrete using Random: Random, AbstractRNG -using Lux: Lux, Chain, Dense, Conv, Parallel, 
NoOpLayer +using Lux: Lux, Chain, Dense, Conv, Parallel, NoOpLayer, WrappedFunction using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer using LuxLib: fast_activation!! using NNlib: NNlib, batched_mul, pad_constant @@ -16,14 +16,14 @@ include("transform.jl") include("layers.jl") # include("models/fno.jl") -# include("models/deeponet.jl") +include("models/deeponet.jl") include("models/nomad.jl") export FourierTransform export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel # export FourierNeuralOperator -# export DeepONet +export DeepONet export NOMAD end diff --git a/src/models/deeponet.jl b/src/models/deeponet.jl index 9806ccd..e87a5db 100644 --- a/src/models/deeponet.jl +++ b/src/models/deeponet.jl @@ -49,7 +49,10 @@ end function DeepONet(branch, trunk) return DeepONet( - Parallel(*; branch=Chain(branch, WrappedFunction(adjoint)), trunk=trunk) + Chain( + Parallel(*; branch=Chain(branch, WrappedFunction(adjoint)), trunk=trunk), + WrappedFunction(adjoint), + ) ) end @@ -98,9 +101,9 @@ function DeepONet(; ) # checks for last dimension size - @argcheck branch[end] == trunk[end] "Branch and Trunk net must share the same amount \ - of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \ - won't work." + @assert branch[end] == trunk[end] "Branch and Trunk net must share the same amount \ + of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \ + won't work." 
branch_net = Chain( [ diff --git a/test/deeponet_tests.jl b/test/deeponet_tests.jl index 048ef10..f1d4a42 100644 --- a/test/deeponet_tests.jl +++ b/test/deeponet_tests.jl @@ -1,114 +1,52 @@ -# @testitem "DeepONet" setup = [SharedTestSetup] begin -# @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES -# rng = StableRNG(12345) - -# setups = [ -# ( -# u_size=(64, 5), -# y_size=(1, 10, 5), -# out_size=(10, 5), -# branch=(64, 32, 32, 16), -# trunk=(1, 8, 8, 16), -# name="Scalar", -# ), -# ( -# u_size=(64, 1, 5), -# y_size=(1, 10, 5), -# out_size=(1, 10, 5), -# branch=(64, 32, 32, 16), -# trunk=(1, 8, 8, 16), -# name="Scalar II", -# ), -# ( -# u_size=(64, 3, 5), -# y_size=(4, 10, 5), -# out_size=(3, 10, 5), -# branch=(64, 32, 32, 16), -# trunk=(4, 8, 8, 16), -# name="Vector", -# ), -# ( -# u_size=(64, 4, 3, 3, 5), -# y_size=(4, 10, 5), -# out_size=(4, 3, 3, 10, 5), -# branch=(64, 32, 32, 16), -# trunk=(4, 8, 8, 16), -# name="Tensor", -# ), -# ] - -# @testset "$(setup.name)" for setup in setups -# u = aType(rand(Float32, setup.u_size...)) -# y = aType(rand(Float32, setup.y_size...)) -# deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk) - -# ps, st = dev(Lux.setup(rng, deeponet)) -# @inferred first(deeponet((u, y), ps, st)) -# @jet first(deeponet((u, y), ps, st)) - -# pred = first(deeponet((u, y), ps, st)) -# @test setup.out_size == size(pred) -# end - -# setups = [ -# ( -# u_size=(64, 5), -# y_size=(1, 10, 5), -# out_size=(4, 10, 5), -# branch=(64, 32, 32, 16), -# trunk=(1, 8, 8, 16), -# additional=Dense(16 => 4), -# name="Scalar", -# ), -# ( -# u_size=(64, 1, 5), -# y_size=(1, 10, 5), -# out_size=(4, 1, 10, 5), -# branch=(64, 32, 32, 16), -# trunk=(1, 8, 8, 16), -# additional=Dense(16 => 4), -# name="Scalar II", -# ), -# ( -# u_size=(64, 3, 5), -# y_size=(8, 10, 5), -# out_size=(4, 3, 10, 5), -# branch=(64, 32, 32, 16), -# trunk=(8, 8, 8, 16), -# additional=Dense(16 => 4), -# name="Vector", -# ), -# ] - -# @testset "Additional layer: 
$(setup.name)" for setup in setups -# u = aType(rand(Float32, setup.u_size...)) -# y = aType(rand(Float32, setup.y_size...)) -# deeponet = DeepONet(; -# branch=setup.branch, trunk=setup.trunk, additional=setup.additional -# ) - -# ps, st = dev(Lux.setup(rng, deeponet)) -# @inferred first(deeponet((u, y), ps, st)) -# @jet first(deeponet((u, y), ps, st)) - -# pred = first(deeponet((u, y), ps, st)) -# @test setup.out_size == size(pred) - -# __f = (u, y, ps) -> sum(abs2, first(deeponet((u, y), ps, st))) -# @test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3) -# end - -# @testset "Embedding layer mismatch" begin -# u = aType(rand(Float32, 64, 5)) -# y = aType(rand(Float32, 1, 10, 5)) - -# deeponet = DeepONet( -# Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)), -# Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)), -# ) - -# ps, st = dev(Lux.setup(rng, deeponet)) -# @test_throws ArgumentError deeponet((u, y), ps, st) -# end -# end -# end +@testitem "DeepONet" setup = [SharedTestSetup] begin + rng = StableRNG(12345) + + setups = [ + ( + u_size=(64, 5), + y_size=(1, 10), + out_size=(10, 5), + branch=(64, 32, 32, 16), + trunk=(1, 8, 8, 16), + name="Scalar", + ), + ( + u_size=(64, 5), + y_size=(4, 10), + out_size=(10, 5), + branch=(64, 32, 32, 16), + trunk=(4, 8, 8, 16), + name="Vector", + ), + ] + + @testset "$(setup.name)" for setup in setups + u = rand(Float32, setup.u_size...) + y = rand(Float32, setup.y_size...) 
+ deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk) + + ps, st = Lux.setup(rng, deeponet) + + pred = first(deeponet((u, y), ps, st)) + @test setup.out_size == size(pred) + + ps_ra, st_ra = reactant_device()(ps, st) + u_ra, y_ra = reactant_device()(u, y) + + @testset "check gradients" begin + ∂u_zyg, ∂ps_zyg = zygote_gradient(deeponet, (u, y), ps, st) + + ∂u_ra, ∂ps_ra = Reactant.with_config(; + dot_general_precision=PrecisionConfig.HIGH, + convolution_precision=PrecisionConfig.HIGH, + ) do + @jit enzyme_gradient(deeponet, (u_ra, y_ra), ps_ra, st_ra) + end + ∂u_ra, ∂ps_ra = (∂u_ra, ∂ps_ra) |> cpu_device() + + @test ∂u_zyg[1] ≈ ∂u_ra[1] atol = 1.0f-3 rtol = 1.0f-3 + @test ∂u_zyg[2] ≈ ∂u_ra[2] atol = 1.0f-3 rtol = 1.0f-3 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) + end + end +end From b5e943103e1c14fd8f28b426d69d4798ee7f6413 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 May 2025 23:54:43 -0400 Subject: [PATCH 07/24] fix: more test fixes --- src/NeuralOperators.jl | 6 +-- src/models/fno.jl | 69 ++++++++++++++------------------ test/deeponet_tests.jl | 6 ++- test/fno_tests.jl | 89 ++++++++++++++++++++++++------------------ 4 files changed, 87 insertions(+), 83 deletions(-) diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index 641a050..ab60f2b 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -7,7 +7,7 @@ using Random: Random, AbstractRNG using Lux: Lux, Chain, Dense, Conv, Parallel, NoOpLayer, WrappedFunction using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer using LuxLib: fast_activation!! 
-using NNlib: NNlib, batched_mul, pad_constant +using NNlib: NNlib, batched_mul, pad_constant, gelu using WeightInitializers: glorot_uniform include("utils.jl") @@ -15,14 +15,14 @@ include("utils.jl") include("transform.jl") include("layers.jl") -# include("models/fno.jl") +include("models/fno.jl") include("models/deeponet.jl") include("models/nomad.jl") export FourierTransform export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel -# export FourierNeuralOperator +export FourierNeuralOperator export DeepONet export NOMAD diff --git a/src/models/fno.jl b/src/models/fno.jl index d66abd3..a9ea053 100644 --- a/src/models/fno.jl +++ b/src/models/fno.jl @@ -1,15 +1,19 @@ """ FourierNeuralOperator( - σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,), - permuted::Val{perm}=False, kwargs...) where {C, M, perm} + σ=gelu; + chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), + modes::Dims{M}=(16,), + kwargs... + ) where {C, M} -The Fourier neural operator is a operator learning model that uses a Fourier kernel to perform -spectral convolutions. It is a promising operator for surrogate methods, and can be regarded as -a physics operator. +The Fourier neural operator is a operator learning model that uses a Fourier kernel to +perform spectral convolutions. It is a promising operator for surrogate methods, and can be +regarded as a physics operator. The model is composed of a `Dense` layer to lift a `(d + 1)`-dimensional vector field to an `n`-dimensional vector field, an integral kernel operator which consists of four Fourier -kernels, and two `Dense` layers to project data back to the scalar field of the space of interest. +kernels, and two `Dense` layers to project data back to the scalar field of the space of +interest. ## Arguments @@ -19,11 +23,8 @@ kernels, and two `Dense` layers to project data back to the scalar field of the - `chs`: A `Tuple` or `Vector` of the size of each of the 8 channels. - `modes`: The modes to be preserved. 
A tuple of length `d`, where `d` is the dimension - of data. For example, one-dimensional data would have a 1-element tuple, and two-dimensional data - would have a 2-element tuple. - - `permuted`: Whether the dim is permuted. If `permuted = Val(false)`, the layer accepts - data in the order of `(ch, x_1, ... , x_d , batch)`. Otherwise the order is - `(x_1, ... , x_d, ch, batch)`. + of data. For example, one-dimensional data would have a 1-element tuple, and + two-dimensional data would have a 2-element tuple. ## Example @@ -39,37 +40,27 @@ julia> size(first(fno(u, ps, st))) ``` """ @concrete struct FourierNeuralOperator <: AbstractLuxWrapperLayer{:model} - model <: Chain + model <: AbstractLuxLayer end function FourierNeuralOperator( - σ=gelu; - chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), - modes::Dims{M}=(16,), - permuted::BoolLike=False(), - kwargs..., + σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,), kwargs... ) where {C,M} - @argcheck length(chs) ≥ 5 - - map₁ = chs[1] => chs[2] - map₂ = chs[C - 2] => chs[C - 1] - map₃ = chs[C - 1] => chs[C] - - kernel_size = map(Returns(1), modes) - - lifting = known(static(permuted)) ? Conv(kernel_size, map₁) : Dense(map₁) - project = if known(static(permuted)) - Chain(Conv(kernel_size, map₂, σ), Conv(kernel_size, map₃)) - else - Chain(Dense(map₂, σ), Dense(map₃)) - end - - mapping = Chain( - [ - SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...) for - i in 2:(C - 3) - ]..., + @assert length(chs) ≥ 5 + + return FourierNeuralOperator( + Chain( + Conv(map(Returns(1), modes), chs[1] => chs[2]), + Chain( + [ + SpectralKernel(chs[i] => chs[i + 1], modes, σ; kwargs...) 
for + i in 2:(C - 3) + ]..., + ), + Chain( + Conv(map(Returns(1), modes), chs[C - 2] => chs[C - 1], σ), + Conv(map(Returns(1), modes), chs[C - 1] => chs[C]), + ), + ), ) - - return FourierNeuralOperator(Chain(lifting, mapping, project)) end diff --git a/test/deeponet_tests.jl b/test/deeponet_tests.jl index f1d4a42..7ef1a30 100644 --- a/test/deeponet_tests.jl +++ b/test/deeponet_tests.jl @@ -20,6 +20,8 @@ ), ] + xdev = reactant_device() + @testset "$(setup.name)" for setup in setups u = rand(Float32, setup.u_size...) y = rand(Float32, setup.y_size...) @@ -30,8 +32,8 @@ pred = first(deeponet((u, y), ps, st)) @test setup.out_size == size(pred) - ps_ra, st_ra = reactant_device()(ps, st) - u_ra, y_ra = reactant_device()(u, y) + ps_ra, st_ra = (ps, st) |> xdev + u_ra, y_ra = (u, y) |> xdev @testset "check gradients" begin ∂u_zyg, ∂ps_zyg = zygote_gradient(deeponet, (u, y), ps, st) diff --git a/test/fno_tests.jl b/test/fno_tests.jl index 351b2df..a80b3ee 100644 --- a/test/fno_tests.jl +++ b/test/fno_tests.jl @@ -1,39 +1,50 @@ -# @testitem "Fourier Neural Operator" setup=[SharedTestSetup] begin -# @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES -# rng = StableRNG(12345) - -# setups = [ -# (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1), -# x_size=(2, 1024, 5), y_size=(1, 1024, 5), permuted=Val(false)), -# (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1), -# x_size=(1024, 2, 5), y_size=(1024, 1, 5), permuted=Val(true)) -# ] - -# @testset "$(length(setup.modes))D: permuted = $(setup.permuted)" for setup in setups -# fno = FourierNeuralOperator(; setup.chs, setup.modes, setup.permuted) -# display(fno) -# ps, st = Lux.setup(rng, fno) |> dev - -# x = rand(rng, Float32, setup.x_size...) |> aType -# y = rand(rng, Float32, setup.y_size...) 
|> aType - -# @inferred fno(x, ps, st) -# @jet fno(x, ps, st) - -# @test size(first(fno(x, ps, st))) == setup.y_size - -# data = [(x, y)] -# @test begin -# l2, l1 = train!(fno, ps, st, data; epochs=10) -# l2 < l1 -# end - -# __f = (x, ps) -> sum(abs2, first(fno(x, ps, st))) -# @test_gradients(__f, x, -# ps; -# atol=1.0f-3, -# rtol=1.0f-3, -# skip_backends=[AutoTracker(), AutoEnzyme(), AutoReverseDiff()]) -# end -# end -# end +@testitem "Fourier Neural Operator" setup = [SharedTestSetup] begin + rng = StableRNG(12345) + + setups = [ + ( + modes=(16,), + chs=(2, 64, 64, 64, 64, 64, 128, 1), + x_size=(1024, 2, 5), + y_size=(1024, 1, 5), + ), + ] + + @testset "$(length(setup.modes))D" for setup in setups + fno = FourierNeuralOperator(; setup.chs, setup.modes) + display(fno) + ps, st = Lux.setup(rng, fno) + + x = rand(rng, Float32, setup.x_size...) + y = rand(rng, Float32, setup.y_size...) + + @test size(first(fno(x, ps, st))) == setup.y_size + + ps_ra, st_ra = reactant_device(ps, st) + x_ra, y_ra = reactant_device(x, y) + + # XXX: upstream fix is needed for the FFT adjoint to work correctly + # https://github.com/EnzymeAD/Reactant.jl/issues/246 + # @test begin + # l2, l1 = train!( + # MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 + # ) + # l2 < l1 + # end + + # @testset "check gradients" begin + # ∂x_zyg, ∂ps_zyg = zygote_gradient(fno, x, ps, st) + + # ∂x_ra, ∂ps_ra = Reactant.with_config(; + # dot_general_precision=PrecisionConfig.HIGH, + # convolution_precision=PrecisionConfig.HIGH, + # ) do + # enzyme_gradient(fno, x_ra, ps_ra, st_ra) + # end + # ∂x_ra, ∂ps_ra = (∂x_ra, ∂ps_ra) |> cpu_device() + + # @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-3 rtol = 1.0f-3 + # @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) + # end + end +end From be7031b2192160c4fa0a88786893b37cb3cd6db9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 May 2025 23:55:12 -0400 Subject: [PATCH 08/24] chore: fmt --- src/models/deeponet.jl | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/src/models/deeponet.jl b/src/models/deeponet.jl index e87a5db..3a855f3 100644 --- a/src/models/deeponet.jl +++ b/src/models/deeponet.jl @@ -52,7 +52,7 @@ function DeepONet(branch, trunk) Chain( Parallel(*; branch=Chain(branch, WrappedFunction(adjoint)), trunk=trunk), WrappedFunction(adjoint), - ) + ), ) end From 907eb819859aff2d87c83f2dbacc46c816cf12fc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 31 May 2025 16:17:52 -0400 Subject: [PATCH 09/24] fix: calling --- test/fno_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/fno_tests.jl b/test/fno_tests.jl index a80b3ee..6e3982e 100644 --- a/test/fno_tests.jl +++ b/test/fno_tests.jl @@ -20,8 +20,8 @@ @test size(first(fno(x, ps, st))) == setup.y_size - ps_ra, st_ra = reactant_device(ps, st) - x_ra, y_ra = reactant_device(x, y) + # ps_ra, st_ra = (ps, st) |> reactant_device() + # x_ra, y_ra = (x, y) |> reactant_device() # XXX: upstream fix is needed for the FFT adjoint to work correctly # https://github.com/EnzymeAD/Reactant.jl/issues/246 From f32dfe422e6c29e64340df51aa0d8665dfcdc702 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 2 Jun 2025 11:19:58 -0400 Subject: [PATCH 10/24] fix: try activating fft tests --- test/Project.toml | 1 + test/fno_tests.jl | 50 +++++++++++++++++++++----------------------- test/layers_tests.jl | 44 +++++++++++++++++++------------------- 3 files changed, 46 insertions(+), 49 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 4b9ba91..1f3a3bf 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,6 +30,7 @@ LuxTestUtils = "1.1.2" Optimisers = "0.4" Random = "1.10" ReTestItems = "1.24.0" +Reactant = "0.2.123" Reexport = "1.2.2" StableRNGs = "1.0.2" Test = "1.10" diff --git a/test/fno_tests.jl b/test/fno_tests.jl index 6e3982e..2bfe063 100644 --- a/test/fno_tests.jl +++ b/test/fno_tests.jl @@ -20,31 +20,29 @@ @test size(first(fno(x, ps, st))) == setup.y_size - # ps_ra, st_ra = (ps, st) 
|> reactant_device() - # x_ra, y_ra = (x, y) |> reactant_device() - - # XXX: upstream fix is needed for the FFT adjoint to work correctly - # https://github.com/EnzymeAD/Reactant.jl/issues/246 - # @test begin - # l2, l1 = train!( - # MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 - # ) - # l2 < l1 - # end - - # @testset "check gradients" begin - # ∂x_zyg, ∂ps_zyg = zygote_gradient(fno, x, ps, st) - - # ∂x_ra, ∂ps_ra = Reactant.with_config(; - # dot_general_precision=PrecisionConfig.HIGH, - # convolution_precision=PrecisionConfig.HIGH, - # ) do - # enzyme_gradient(fno, x_ra, ps_ra, st_ra) - # end - # ∂x_ra, ∂ps_ra = (∂x_ra, ∂ps_ra) |> cpu_device() - - # @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-3 rtol = 1.0f-3 - # @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) - # end + ps_ra, st_ra = (ps, st) |> reactant_device() + x_ra, y_ra = (x, y) |> reactant_device() + + @test begin + l2, l1 = train!( + MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 + ) + l2 < l1 + end + + @testset "check gradients" begin + ∂x_zyg, ∂ps_zyg = zygote_gradient(fno, x, ps, st) + + ∂x_ra, ∂ps_ra = Reactant.with_config(; + dot_general_precision=PrecisionConfig.HIGH, + convolution_precision=PrecisionConfig.HIGH, + ) do + @jit enzyme_gradient(fno, x_ra, ps_ra, st_ra) + end + ∂x_ra, ∂ps_ra = (∂x_ra, ∂ps_ra) |> cpu_device() + + @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-3 rtol = 1.0f-3 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) + end end end diff --git a/test/layers_tests.jl b/test/layers_tests.jl index 34e73e5..725cf74 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -26,28 +26,26 @@ x_ra = rdev(x) y_ra = rdev(rand(rng, Float32, setup.y_size...)) - # XXX: upstream fix is needed for the FFT adjoint to work correctly - # https://github.com/EnzymeAD/Reactant.jl/issues/246 - # @test begin - # l2, l1 = train!( - # MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 - # ) - # l2 < l1 - # end - - # @testset "check 
gradients" begin - # ∂x_zyg, ∂ps_zyg = zygote_gradient(m, x, ps, st) - - # ∂x_ra, ∂ps_ra = Reactant.with_config(; - # dot_general_precision=PrecisionConfig.HIGH, - # convolution_precision=PrecisionConfig.HIGH, - # ) do - # @jit enzyme_gradient(m, x_ra, ps_ra, st_ra) - # end - # ∂x_ra, ∂ps_ra = (∂x_ra, ∂ps_ra) |> cpu_device() - - # @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-3 rtol = 1.0f-3 - # @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) - # end + @test begin + l2, l1 = train!( + MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 + ) + l2 < l1 + end + + @testset "check gradients" begin + ∂x_zyg, ∂ps_zyg = zygote_gradient(m, x, ps, st) + + ∂x_ra, ∂ps_ra = Reactant.with_config(; + dot_general_precision=PrecisionConfig.HIGH, + convolution_precision=PrecisionConfig.HIGH, + ) do + @jit enzyme_gradient(m, x_ra, ps_ra, st_ra) + end + ∂x_ra, ∂ps_ra = (∂x_ra, ∂ps_ra) |> cpu_device() + + @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-3 rtol = 1.0f-3 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) + end end end From 1d1316f920de1e3471005d8e1c874731cc81fdaa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 2 Jun 2025 16:00:45 -0400 Subject: [PATCH 11/24] docs: update deeponet docs --- docs/make.jl | 5 +++-- docs/pages.jl | 8 +++----- docs/src/models/deeponet.md | 35 ++++++++++++++++++++--------------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 8693552..e9048a7 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -17,8 +17,9 @@ makedocs(; format=Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", canonical="https://docs.sciml.ai/NeuralOperators/stable/", - assets=["assets/favicon.ico"]), - pages + assets=["assets/favicon.ico"], + ), + pages, ) deploydocs(; repo="github.com/SciML/NeuralOperators.jl.git", push_preview=true) diff --git a/docs/pages.jl b/docs/pages.jl index 2c9c8a4..e50138c 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -3,10 +3,8 @@ pages = [ "Pre-built Models" => [ "FNO" => 
"models/fno.md", "DeepONet" => "models/deeponet.md", - "NOMAD" => "models/nomad.md" + "NOMAD" => "models/nomad.md", ], - "Tutorials" => [ - "Burgers Equation" => "tutorials/burgers.md" - ], - "API Reference" => "api.md" + "Tutorials" => ["Burgers Equation" => "tutorials/burgers.md"], + "API Reference" => "api.md", ] diff --git a/docs/src/models/deeponet.md b/docs/src/models/deeponet.md index 0765306..1a9c086 100644 --- a/docs/src/models/deeponet.md +++ b/docs/src/models/deeponet.md @@ -11,7 +11,7 @@ u(y) \xrightarrow{\text{branch}} & \; b \\ & \quad \searrow\\ &\quad \quad \mathcal{G}_{\theta} u(y) = \sum_k b_k t_k \\ & \quad \nearrow \\ -y \; \; \xrightarrow{\text{trunk}} \; \; & t +y \; \; \xrightarrow{\text{trunk}} \; \; & t \end{align*} ``` @@ -38,24 +38,26 @@ v(x) = \frac{du}{dx} \quad \forall \; x \in [0, 2\pi], \; \alpha \in [0.5, 1] ### Copy-pastable code ```@example deeponet_tutorial -using NeuralOperators, Lux, Random, Optimisers, Zygote, CairoMakie +using NeuralOperators, Lux, Random, Optimisers, Reactant, CairoMakie rng = Random.default_rng() +xdev = reactant_device() eval_points = 1 -data_size = 64 +batch_size = 64 dim_y = 1 m = 32 xrange = range(0, 2π; length=m) .|> Float32 -u_data = zeros(Float32, m, data_size) -α = 0.5f0 .+ 0.5f0 .* rand(Float32, data_size) +α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size) -y_data = rand(Float32, 1, eval_points, data_size) .* 2π -v_data = zeros(Float32, eval_points, data_size) -for i in 1:data_size +u_data = zeros(Float32, m, batch_size) +y_data = rand(Float32, 1, eval_points) .* 2π +v_data = zeros(Float32, eval_points, batch_size) + +for i in 1:batch_size u_data[:, i] .= sin.(α[i] .* xrange) - v_data[:, i] .= -inv(α[i]) .* cos.(α[i] .* y_data[1, :, i]) + v_data[:, i] .= -inv(α[i]) .* cos.(α[i] .* y_data[1, :]) end deeponet = DeepONet( @@ -63,18 +65,21 @@ deeponet = DeepONet( Chain(Dense(1 => 4, σ), Dense(4 => 8, σ)) ) -ps, st = Lux.setup(rng, deeponet) +ps, st = Lux.setup(rng, deeponet) |> xdev + +u_data = u_data |> 
xdev +y_data = y_data |> xdev +v_data = v_data |> xdev data = [((u_data, y_data), v_data)] function train!(model, ps, st, data; epochs=10) losses = [] tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) for _ in 1:epochs, (x, y) in data - - _, loss, - _, tstate = Training.single_train_step!(AutoZygote(), MSELoss(), (x, y), - tstate) - push!(losses, loss) + (_, loss, _, tstate) = Training.single_train_step!( + AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false) + ) + push!(losses, Float32(loss)) end return losses end From 7268d7b7573e7b73acda76a147c58ee6fb07396e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Jun 2025 12:07:03 -0400 Subject: [PATCH 12/24] docs: more updates --- .github/workflows/CompatHelper.yml | 2 +- docs/Project.toml | 15 ++++---- docs/src/models/deeponet.md | 28 ++++++++++---- docs/src/models/fno.md | 60 ++++++++++++++++++------------ docs/src/models/nomad.md | 56 ++++++++++++++++++---------- test/Project.toml | 2 +- 6 files changed, 104 insertions(+), 59 deletions(-) diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 0603391..aa70e3f 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -37,7 +37,7 @@ jobs: - name: "Run CompatHelper" run: | import CompatHelper - CompatHelper.main() + CompatHelper.main(; subdirs=["", "docs", "test"]) shell: julia --color=yes {0} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/docs/Project.toml b/docs/Project.toml index 29b4d3c..a71dd4b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,29 +1,30 @@ [deps] +AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MAT = 
"23992714-dd62-5051-b70f-ba57cb901cac" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" + +[sources] +NeuralOperators = {path = ".."} [compat] -CairoMakie = "0.12.11" +CairoMakie = "0.13" CondaPkg = "0.2.23" DataDeps = "0.7.13" Documenter = "1.7.0" Lux = "1" -LuxCUDA = "0.3.3" MAT = "0.10.7" MLUtils = "0.4.4" -NeuralOperators = "0.5" -Optimisers = "0.3.3" +NeuralOperators = "0.6" +Optimisers = "0.4" Printf = "1.10" PythonCall = "0.9.23" -Zygote = "0.6.71" diff --git a/docs/src/models/deeponet.md b/docs/src/models/deeponet.md index 1a9c086..3f5a186 100644 --- a/docs/src/models/deeponet.md +++ b/docs/src/models/deeponet.md @@ -38,9 +38,15 @@ v(x) = \frac{du}{dx} \quad \forall \; x \in [0, 2\pi], \; \alpha \in [0.5, 1] ### Copy-pastable code ```@example deeponet_tutorial -using NeuralOperators, Lux, Random, Optimisers, Reactant, CairoMakie +using NeuralOperators, Lux, Random, Optimisers, Reactant + +using CairoMakie, AlgebraOfGraphics +set_aog_theme!() +const AoG = AlgebraOfGraphics rng = Random.default_rng() +Random.seed!(rng, 1234) + xdev = reactant_device() eval_points = 1 @@ -52,7 +58,7 @@ xrange = range(0, 2π; length=m) .|> Float32 α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size) u_data = zeros(Float32, m, batch_size) -y_data = rand(Float32, 1, eval_points) .* 2π +y_data = rand(rng, Float32, 1, eval_points) .* Float32(2π) v_data = zeros(Float32, eval_points, batch_size) for i in 1:batch_size @@ -65,12 +71,12 @@ deeponet = DeepONet( Chain(Dense(1 => 4, σ), Dense(4 => 8, σ)) ) -ps, st = Lux.setup(rng, deeponet) |> xdev +ps, st = Lux.setup(rng, deeponet) |> xdev; -u_data = u_data |> xdev -y_data = y_data |> xdev -v_data = v_data |> xdev -data = [((u_data, y_data), 
v_data)] +u_data = u_data |> xdev; +y_data = y_data |> xdev; +v_data = v_data |> xdev; +data = [((u_data, y_data), v_data)]; function train!(model, ps, st, data; epochs=10) losses = [] @@ -86,5 +92,11 @@ end losses = train!(deeponet, ps, st, data; epochs=1000) -lines(losses) +draw( + AoG.data((; losses, iteration=1:length(losses))) * + mapping(:iteration => "Iteration", :losses => "Loss (log10 scale)") * + visual(Lines); + axis=(; yscale=log10), + figure=(; title="Using DeepONet to learn the anti-derivative operator") +) ``` diff --git a/docs/src/models/fno.md b/docs/src/models/fno.md index 0de5552..a628905 100644 --- a/docs/src/models/fno.md +++ b/docs/src/models/fno.md @@ -18,7 +18,7 @@ convolution operation, which can be efficiently computed in the fourier domain. ```math \begin{align*} -(\Kappa_{\theta}u)(x) +(\Kappa_{\theta}u)(x) &= \int_D \kappa_{\theta}(x - y) dy \quad \forall x \in D\\ &= \mathcal{F}^{-1}(\mathcal{F}(\kappa_{\theta}) \mathcal{F}(u))(x) \quad \forall x \in D \end{align*} @@ -57,44 +57,58 @@ v(x) = \frac{du}{dx} \quad \forall \; x \in [0, 2\pi], \; \alpha \in [0.5, 1] ``` ```@example fno_tutorial -using NeuralOperators, Lux, Random, Optimisers, Zygote, CairoMakie +using NeuralOperators, Lux, Random, Optimisers, Reactant + +using CairoMakie, AlgebraOfGraphics +set_aog_theme!() +const AoG = AlgebraOfGraphics rng = Random.default_rng() +Random.seed!(rng, 1234) -data_size = 128 +xdev = reactant_device() + +batch_size = 128 m = 32 xrange = range(0, 2π; length=m) .|> Float32; -u_data = zeros(Float32, m, 1, data_size); -α = 0.5f0 .+ 0.5f0 .* rand(Float32, data_size); -v_data = zeros(Float32, m, 1, data_size); +u_data = zeros(Float32, m, 1, batch_size); +α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size); +v_data = zeros(Float32, m, 1, batch_size); -for i in 1:data_size +for i in 1:batch_size u_data[:, 1, i] .= sin.(α[i] .* xrange) v_data[:, 1, i] .= -inv(α[i]) .* cos.(α[i] .* xrange) end -fno = FourierNeuralOperator(gelu; chs=(1, 64, 64, 128, 1), 
modes=(16,), permuted=Val(true)) +fno = FourierNeuralOperator(gelu; chs=(1, 64, 64, 128, 1), modes=(16,)) -ps, st = Lux.setup(rng, fno); +ps, st = Lux.setup(rng, fno) |> xdev; +u_data = u_data |> xdev; +v_data = v_data |> xdev; data = [(u_data, v_data)]; function train!(model, ps, st, data; epochs=10) losses = [] - tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) for _ in 1:epochs, (x, y) in data - - _, loss, - _, tstate = Training.single_train_step!(AutoZygote(), MSELoss(), (x, y), - tstate) - push!(losses, loss) + (_, loss, _, tstate) = Training.single_train_step!( + AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false) + ) + push!(losses, Float32(loss)) end return losses end -losses = train!(fno, ps, st, data; epochs=100) +losses = train!(fno, ps, st, data; epochs=1000) -lines(losses) +draw( + AoG.data((; losses, iteration=1:length(losses))) * + mapping(:iteration => "Iteration", :losses => "Loss (log10 scale)") * + visual(Lines); + axis=(; yscale=log10), + figure=(; title="Using Fourier Neural Operator to learn the anti-derivative operator") +) ``` ```@raw html @@ -113,10 +127,10 @@ First, we construct our training data. rng = Random.default_rng() ```` -`data_size` is the number of observations. +`batch_size` is the number of observations. ````@example minimal_lux -data_size = 128 +batch_size = 128 ```` `m` is the length of a single observation, you can also interpret this as the size of the grid we're evaluating our function on. @@ -137,7 +151,7 @@ Each value in the array here, `α`, will be the multiplicative factor on the input to the sine function. ````@example minimal_lux -α = 0.5f0 .+ 0.5f0 .* rand(Float32, data_size); +α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size); nothing #hide ```` @@ -146,8 +160,8 @@ of the training data in a single array, in order to batch process them more efficiently. 
````@example minimal_lux -u_data = zeros(Float32, m, 1, data_size); -v_data = zeros(Float32, m, 1, data_size); +u_data = zeros(Float32, m, 1, batch_size); +v_data = zeros(Float32, m, 1, batch_size); nothing #hide ```` @@ -155,7 +169,7 @@ and fill the data arrays with values. Here, `u_data` is ````@example minimal_lux -for i in 1:data_size +for i in 1:batch_size u_data[:, 1, i] .= sin.(α[i] .* xrange) v_data[:, 1, i] .= -inv(α[i]) .* cos.(α[i] .* xrange) end diff --git a/docs/src/models/nomad.md b/docs/src/models/nomad.md index 7ecf801..0a944fc 100644 --- a/docs/src/models/nomad.md +++ b/docs/src/models/nomad.md @@ -40,46 +40,64 @@ v(x) = \frac{du}{dx} \quad \forall \; x \in [0, 2\pi], \; \alpha \in [0.5, 1] ### Copy-pastable code ```@example nomad_tutorial -using NeuralOperators, Lux, Random, Optimisers, Zygote, CairoMakie +using NeuralOperators, Lux, Random, Optimisers, Reactant + +using CairoMakie, AlgebraOfGraphics +set_aog_theme!() +const AoG = AlgebraOfGraphics rng = Random.default_rng() +Random.seed!(rng, 1234) + +xdev = reactant_device() eval_points = 1 -data_size = 128 +batch_size = 64 dim_y = 1 m = 32 xrange = range(0, 2π; length=m) .|> Float32 -u_data = zeros(Float32, m, data_size) -α = 0.5f0 .+ 0.5f0 .* rand(Float32, data_size) +α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size) -y_data = rand(Float32, 1, eval_points, data_size) .* 2π -v_data = zeros(Float32, eval_points, data_size) -for i in 1:data_size +u_data = zeros(Float32, m, batch_size) +y_data = rand(rng, Float32, eval_points, batch_size) .* Float32(2π) +v_data = zeros(Float32, eval_points, batch_size) + +for i in 1:batch_size u_data[:, i] .= sin.(α[i] .* xrange) - v_data[:, i] .= -inv(α[i]) .* cos.(α[i] .* y_data[1, :, i]) + v_data[:, i] .= -inv(α[i]) .* cos.(α[i] .* y_data[:, i]) end -nomad = NOMAD(Chain(Dense(m => 8, σ), Dense(8 => 8, σ), Dense(8 => 7)), - Chain(Dense(8 => 4, σ), Dense(4 => 1))) +nomad = NOMAD( + Chain(Dense(m => 8, σ), Dense(8 => 8, σ), Dense(8 => 8 - eval_points)), + 
Chain(Dense(8 => 4, σ), Dense(4 => eval_points)) +) -ps, st = Lux.setup(rng, nomad) -data = [((u_data, y_data), v_data)] +ps, st = Lux.setup(rng, nomad) |> xdev; +u_data = u_data |> xdev; +y_data = y_data |> xdev; +v_data = v_data |> xdev; +data = [((u_data, y_data), v_data)]; function train!(model, ps, st, data; epochs=10) losses = [] - tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) for _ in 1:epochs, (x, y) in data - - _, loss, - _, tstate = Training.single_train_step!(AutoZygote(), MSELoss(), (x, y), - tstate) - push!(losses, loss) + (_, loss, _, tstate) = Training.single_train_step!( + AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false) + ) + push!(losses, Float32(loss)) end return losses end losses = train!(nomad, ps, st, data; epochs=1000) -lines(losses) +draw( + AoG.data((; losses, iteration=1:length(losses))) * + mapping(:iteration => "Iteration", :losses => "Loss (log10 scale)") * + visual(Lines); + axis=(; yscale=log10), + figure=(; title="Using NOMAD to learn the anti-derivative operator") +) ``` diff --git a/test/Project.toml b/test/Project.toml index 1f3a3bf..0626fb6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,7 +30,7 @@ LuxTestUtils = "1.1.2" Optimisers = "0.4" Random = "1.10" ReTestItems = "1.24.0" -Reactant = "0.2.123" +Reactant = "0.2.124" Reexport = "1.2.2" StableRNGs = "1.0.2" Test = "1.10" From 4901fe15b7c871fd4347cb9311943f010683b145 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Jun 2025 12:16:52 -0400 Subject: [PATCH 13/24] test: more test fixes --- docs/Project.toml | 2 ++ src/models/fno.jl | 4 ++-- test/Project.toml | 2 ++ test/fno_tests.jl | 2 +- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index a71dd4b..0c267d3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -17,6 +17,7 @@ Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" NeuralOperators = {path = ".."} [compat] 
+AlgebraOfGraphics = "0.10.7" CairoMakie = "0.13" CondaPkg = "0.2.23" DataDeps = "0.7.13" @@ -28,3 +29,4 @@ NeuralOperators = "0.6" Optimisers = "0.4" Printf = "1.10" PythonCall = "0.9.23" +Reactant = "0.2.125" diff --git a/src/models/fno.jl b/src/models/fno.jl index a9ea053..5c7f99f 100644 --- a/src/models/fno.jl +++ b/src/models/fno.jl @@ -33,10 +33,10 @@ julia> fno = FourierNeuralOperator(gelu; chs=(2, 64, 64, 128, 1), modes=(16,)); julia> ps, st = Lux.setup(Xoshiro(), fno); -julia> u = rand(Float32, 2, 1024, 5); +julia> u = rand(Float32, 1024, 2, 5); julia> size(first(fno(u, ps, st))) -(1, 1024, 5) +(1024, 1, 5) ``` """ @concrete struct FourierNeuralOperator <: AbstractLuxWrapperLayer{:model} diff --git a/test/Project.toml b/test/Project.toml index 0626fb6..c0f8668 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -21,7 +21,9 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Aqua = "0.8.7" Documenter = "1.5.0" +Enzyme = "0.13.48" ExplicitImports = "1.9.0" +FFTW = "1.9.0" Hwloc = "3.2" Lux = "1" LuxCore = "1" diff --git a/test/fno_tests.jl b/test/fno_tests.jl index 2bfe063..9e59d8e 100644 --- a/test/fno_tests.jl +++ b/test/fno_tests.jl @@ -25,7 +25,7 @@ @test begin l2, l1 = train!( - MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 + MSELoss(), AutoEnzyme(), fno, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 ) l2 < l1 end From 6012a1e845d0255203b673d56e647fdd1cc9c42e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Jun 2025 12:34:58 -0400 Subject: [PATCH 14/24] fix: downgrade CI compats --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 15a935d..aa608b6 100644 --- a/Project.toml +++ b/Project.toml @@ -16,9 +16,9 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [compat] AbstractFFTs = "1.5.0" ConcreteStructs = "0.2.3" -Lux = "1" -LuxCore = "1" -LuxLib = "1.8.0" +Lux = "1.13" +LuxCore = "1.2" +LuxLib = "1.8" NNlib = "0.9.30" Random = "1.10" 
WeightInitializers = "1" From e80de31010b1d57604d307e8e756e24db77d6218 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Jun 2025 14:19:55 -0400 Subject: [PATCH 15/24] docs: use 1.11 --- .buildkite/documentation.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/documentation.yml b/.buildkite/documentation.yml index f20d9c9..c04c184 100644 --- a/.buildkite/documentation.yml +++ b/.buildkite/documentation.yml @@ -2,7 +2,7 @@ steps: - label: ":julia: Documentation" plugins: - JuliaCI/julia#v1: - version: "1.10" + version: "1" - JuliaCI/julia-coverage#v1: codecov: true command: | From 874567bc7dcd54df86f6811ed932d448c07c3d49 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Jun 2025 17:51:29 -0400 Subject: [PATCH 16/24] docs: update burgers deeponet --- docs/pages.jl | 2 +- docs/src/models/fno.md | 102 ++++++++--------- docs/src/tutorials/burgers.md | 136 ----------------------- docs/src/tutorials/burgers_deeponet.md | 145 +++++++++++++++++++++++++ 4 files changed, 197 insertions(+), 188 deletions(-) delete mode 100644 docs/src/tutorials/burgers.md create mode 100644 docs/src/tutorials/burgers_deeponet.md diff --git a/docs/pages.jl b/docs/pages.jl index e50138c..c182f5f 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -5,6 +5,6 @@ pages = [ "DeepONet" => "models/deeponet.md", "NOMAD" => "models/nomad.md", ], - "Tutorials" => ["Burgers Equation" => "tutorials/burgers.md"], + "Tutorials" => ["Burgers Equation" => "tutorials/burgers_deeponet.md"], "API Reference" => "api.md", ] diff --git a/docs/src/models/fno.md b/docs/src/models/fno.md index a628905..8ef1dad 100644 --- a/docs/src/models/fno.md +++ b/docs/src/models/fno.md @@ -115,110 +115,110 @@ draw( ``` -````@example minimal_lux +```@example minimal_lux using NeuralOperators, Lux, Random, Optimisers, Zygote, CairoMakie -```` +``` ### Constructing training data First, we construct our training data. 
-````@example minimal_lux +```@example minimal_lux rng = Random.default_rng() -```` +``` `batch_size` is the number of observations. -````@example minimal_lux +```@example minimal_lux batch_size = 128 -```` +``` `m` is the length of a single observation, you can also interpret this as the size of the grid we're evaluating our function on. -````@example minimal_lux +```@example minimal_lux m = 32 -```` +``` We instantiate the domain that the function operates on as a range from `0` to `2π`, whose length is the grid size. -````@example minimal_lux +```@example minimal_lux xrange = range(0, 2π; length=m) .|> Float32; nothing #hide -```` +``` Each value in the array here, `α`, will be the multiplicative factor on the input to the sine function. -````@example minimal_lux +```@example minimal_lux α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size); nothing #hide -```` +``` -Now, we create our data arrays. We are storing all +Now, we create our data arrays. We are storing all of the training data in a single array, in order to batch process them more efficiently. -````@example minimal_lux +```@example minimal_lux u_data = zeros(Float32, m, 1, batch_size); v_data = zeros(Float32, m, 1, batch_size); nothing #hide -```` +``` and fill the data arrays with values. Here, `u_data` is -````@example minimal_lux +```@example minimal_lux for i in 1:batch_size u_data[:, 1, i] .= sin.(α[i] .* xrange) v_data[:, 1, i] .= -inv(α[i]) .* cos.(α[i] .* xrange) end -```` +``` ### Creating the model -Finally, we get to the model itself. We instantiate a `FourierNeuralOperator` and provide it several parameters. +Finally, we get to the model itself. We instantiate a `FourierNeuralOperator` and provide it several parameters. The first argument is the "activation function" for each neuron. The keyword arguments are: - - `chs` is a tuple, representing the layer sizes for each layer. 
- - `modes` is a 1-tuple, where the number represents the number of Fourier modes that - are preserved, and the size of the tuple represents the number of dimensions. - - `permuted` indicates that the order of the arguments is permuted such that each column - of the array represents a single observation. This is substantially faster than the usual - row access pattern, since Julia stores arrays by concatenating columns. - `Val(true)` is another way of expressing `true`, but in the type domain, so that - the compiler can see the value and use the appropriate optimizations. +- `chs` is a tuple, representing the layer sizes for each layer. +- `modes` is a 1-tuple, where the number represents the number of Fourier modes that + are preserved, and the size of the tuple represents the number of dimensions. +- `permuted` indicates that the order of the arguments is permuted such that each column + of the array represents a single observation. This is substantially faster than the usual + row access pattern, since Julia stores arrays by concatenating columns. + `Val(true)` is another way of expressing `true`, but in the type domain, so that + the compiler can see the value and use the appropriate optimizations. -````@example minimal_lux +```@example minimal_lux fno = FourierNeuralOperator( gelu; # activation function chs=(1, 64, 64, 128, 1), # channel weights modes=(16,), # number of Fourier modes to retain permuted=Val(true) # structure of the data means that columns are observations ) -```` +``` -Now, we set up the model. This function returns two things, -a set of parameters and a set of states. Since the operator is -"stateless", the states are empty and will remain so. The parameters +Now, we set up the model. This function returns two things, +a set of parameters and a set of states. Since the operator is +"stateless", the states are empty and will remain so. The parameters are the weights of the neural network, and we will be modifying them in the training loop. 
-````@example minimal_lux +```@example minimal_lux ps, st = Lux.setup(rng, fno); nothing #hide -```` +``` -We construct data as a vector of tuples (input, output). These are pre-batched, +We construct data as a vector of tuples (input, output). These are pre-batched, but for example if we had a lot of training data, we could dynamically load it, or create multiple batches. -````@example minimal_lux +```@example minimal_lux data = [(u_data, v_data)]; nothing #hide -```` +``` ### Training the model @@ -226,7 +226,7 @@ Now, we create a function to train the model. An "epoch" is basically a run over all input data, and the more epochs we have, the better the neural network gets! -````@example minimal_lux +```@example minimal_lux function train!(model, ps, st, data; epochs=10) # The `losses` array is used only for visualization, # you don't actually need it to train. @@ -245,45 +245,45 @@ function train!(model, ps, st, data; epochs=10) end return losses end -```` +``` Now we train our model! -````@example minimal_lux +```@example minimal_lux losses = @time train!(fno, ps, st, data; epochs=500) -```` +``` We can plot the losses - you can see that at some point, we hit diminishing returns. -````@example minimal_lux +```@example minimal_lux lines(losses; axis=(; yscale=log10, ylabel="Loss", xlabel="Epoch")) -```` +``` ### Applying the model Let's try to actually apply this model using some input data. -````@example minimal_lux +```@example minimal_lux input_data = u_data[:, 1, 1] -```` +``` -This is our input data. It's currently one-dimensional, +This is our input data. It's currently one-dimensional, but our neural network expects input in batched form, so we simply `reshape` it (a no-cost operation) to a 3d array with singleton dimensions. 
-````@example minimal_lux +```@example minimal_lux reshaped_input = reshape(input_data, length(input_data), 1, 1) -```` +``` Now we can pass this to `Lux.apply`: -````@example minimal_lux +```@example minimal_lux output_data, st = Lux.apply(fno, reshaped_input, ps, st) -```` +``` and plot it: -````@example minimal_lux +```@example minimal_lux f, a, p = lines(dropdims(reshaped_input; dims=(2, 3)); label="u") lines!(a, dropdims(output_data; dims=(2, 3)); label="Predicted") lines!(a, v_data[:, 1, 1]; label="Expected") @@ -295,4 +295,4 @@ a2, p2 = lines(f[2, 1], absolute_error; axis=(; ylabel="Error")) rowsize!(f.layout, 2, Aspect(1, 1 / 8)) linkxaxes!(a, a2) f -```` +``` diff --git a/docs/src/tutorials/burgers.md b/docs/src/tutorials/burgers.md deleted file mode 100644 index 76f8ece..0000000 --- a/docs/src/tutorials/burgers.md +++ /dev/null @@ -1,136 +0,0 @@ -# Burgers Equation using DeepONet - -## Data Loading - -```@example burgers -using DataDeps, MAT, MLUtils -using PythonCall, CondaPkg # For `gdown` -using Printf - -const gdown = pyimport("gdown") - -register( - DataDep( - "Burgers", - """ - Burgers' equation dataset from - [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator) - - mapping between initial conditions to the solutions at the last point of time \ - evolution in some function space. 
- - u(x,0) -> u(x, time_end): - - * `a`: initial conditions u(x,0) - * `u`: solutions u(x,t_end) - """, - "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe", - "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd"; - fetch_method=(url, - local_dir) -> begin - pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip"))) - end, - post_fetch_method=unpack -) -) - -filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat") - -const N = 2048 -const Δsamples = 2^3 -const grid_size = div(2^13, Δsamples) -const T = Float32 - -file = matopen(filepath) -x_data = reshape(T.(collect(read(file, "a")[1:N, 1:Δsamples:end])), N, :, 1) -y_data = reshape(T.(collect(read(file, "u")[1:N, 1:Δsamples:end])), N, :, 1) -close(file) - -x_data = permutedims(x_data, (2, 1, 3)) -grid = reshape(T.(collect(range(0, 1; length=grid_size)')), :, grid_size, 1) -``` - -## Model - -```@example burgers -using Lux, NeuralOperators, Optimisers, Zygote, Random -using LuxCUDA - -const cdev = cpu_device() -const gdev = gpu_device() - -deeponet = DeepONet(; - branch=(size(x_data, 1), ntuple(Returns(32), 5)...), - trunk=(size(grid, 1), ntuple(Returns(32), 5)...), - branch_activation=tanh, - trunk_activation=tanh -) -ps, st = Lux.setup(Random.default_rng(), deeponet) |> gdev; -``` - -## Training - -```@example burgers -x_data_dev = x_data |> gdev -y_data_dev = y_data |> gdev -grid_dev = grid |> gdev - -function loss_function(model, ps, st, ((v, y), u)) - û, stₙ = model((v, y), ps, st) - return MAELoss()(û, u), stₙ, (;) -end - -function train_model!(model, ps, st, data; epochs=5000) - train_state = Training.TrainState(model, ps, st, Adam(0.0001f0)) - - for epoch in 1:epochs - _, loss, - _, - train_state = Training.single_train_step!( - AutoZygote(), loss_function, data, train_state) - - if epoch % 25 == 1 || epoch == epochs - @printf("Epoch %d: loss = %.6e\n", epoch, loss) - end - end - - return train_state.parameters, train_state.states -end - -ps_trained, 
-st_trained = train_model!( - deeponet, ps, st, ((x_data_dev, grid_dev), y_data_dev)) -``` - -## Plotting - -```@example burgers -using CairoMakie - -pred = first(deeponet((x_data_dev, grid_dev), ps_trained, st_trained)) |> cdev - -begin - fig = Figure(; size=(1024, 1024)) - - axs = [Axis(fig[i, j]) for i in 1:4, j in 1:4] - for i in 1:4, j in 1:4 - - idx = i + (j - 1) * 4 - ax = axs[i, j] - l1 = lines!(ax, vec(grid), pred[idx, :, 1]) - l2 = lines!(ax, vec(grid), y_data[idx, :, 1]) - - i == 4 && (ax.xlabel = "x") - j == 1 && (ax.ylabel = "u(x)") - - if i == 1 && j == 1 - axislegend(ax, [l1, l2], ["Predictions", "Ground Truth"]) - end - end - linkaxes!(axs...) - - fig[0, :] = Label(fig, "Burgers Equation using DeepONet"; tellwidth=false, font=:bold) - - fig -end -``` diff --git a/docs/src/tutorials/burgers_deeponet.md b/docs/src/tutorials/burgers_deeponet.md new file mode 100644 index 0000000..5e0be51 --- /dev/null +++ b/docs/src/tutorials/burgers_deeponet.md @@ -0,0 +1,145 @@ +# Burgers Equation using DeepONet + +## Data Loading + +```@example burgers +using DataDeps, MAT, MLUtils +using PythonCall, CondaPkg # For `gdown` +using Printf + +const gdown = pyimport("gdown") + +register( + DataDep( + "Burgers", + """ + Burgers' equation dataset from + [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator) + + mapping between initial conditions to the solutions at the last point of time \ + evolution in some function space. 
+ + u(x,0) -> u(x, time_end): + + * `a`: initial conditions u(x,0) + * `u`: solutions u(x,t_end) + """, + "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe", + "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd"; + fetch_method=(url, + local_dir) -> begin + pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip"))) + end, + post_fetch_method=unpack +) +) + +filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat") + +const N = 2048 +const Δsamples = 2^3 +const grid_size = div(2^13, Δsamples) +const T = Float32 + +file = matopen(filepath) +x_data = reshape(T.(collect(read(file, "a")[1:N, 1:Δsamples:end])), N, :) +y_data = reshape(T.(collect(read(file, "u")[1:N, 1:Δsamples:end])), N, :) +close(file) + +x_data = permutedims(x_data, (2, 1)) +y_data = permutedims(y_data, (2, 1)) +grid = reshape(collect(T, range(0, 1; length=grid_size)), 1, :) +``` + +## Model + +```@example burgers +using Lux, NeuralOperators, Optimisers, Random, Reactant + +const cdev = cpu_device() +const xdev = reactant_device(; force=true) + +deeponet = DeepONet(; + branch=(size(x_data, 1), ntuple(Returns(32), 5)...), + trunk=(size(grid, 1), ntuple(Returns(32), 5)...), + branch_activation=gelu, + trunk_activation=gelu +) +ps, st = Lux.setup(Random.default_rng(), deeponet) |> xdev; +``` + +## Training + +```@example burgers +x_data_dev = x_data |> xdev; +y_data_dev = y_data |> xdev; +grid_dev = grid |> xdev; + +function train_model!(model, ps, st, data; epochs=5000) + train_state = Training.TrainState(model, ps, st, Adam(0.0001f0)) + + for epoch in 1:epochs + (_, loss, _, train_state) = Training.single_train_step!( + AutoEnzyme(), MAELoss(), data, train_state + ) + + if epoch % 100 == 1 || epoch == epochs + @printf("Epoch %d: loss = %.6e\n", epoch, loss) + end + end + + return train_state.parameters, train_state.states +end + +(ps_trained, st_trained) = train_model!( + deeponet, ps, st, ((x_data_dev, grid_dev), y_data_dev) +) +nothing #hide +``` 
+ +## Plotting + +```@example burgers +using CairoMakie, AlgebraOfGraphics +const AoG = AlgebraOfGraphics +AoG.set_aog_theme!() + +pred = first( + Reactant.with_config(; + convolution_precision=PrecisionConfig.HIGH, + dot_general_precision=PrecisionConfig.HIGH, + ) do + @jit(deeponet((x_data_dev, grid_dev), ps_trained, st_trained)) + end +) |> cdev + +data_sequence, sequence, repeated_grid, label = Float32[], Int[], Float32[], String[] +for i in 1:16 + append!(repeated_grid, vcat(vec(grid), vec(grid))) + append!(sequence, repeat([i], grid_size * 2)) + append!(label, repeat(["Ground Truth"], grid_size)) + append!(label, repeat(["Predictions"], grid_size)) + append!(data_sequence, vec(y_data[:, i])) + append!(data_sequence, vec(pred[:, i])) +end +plot_data = (; data_sequence, sequence, repeated_grid, label) + +draw( + AoG.data(plot_data) * + mapping( + :repeated_grid => L"x", + :data_sequence => L"u(x)"; + color=:label => "", + layout=:sequence => nonnumeric, + ) * + visual(Lines), + scales(; Color=(; palette=:tab10)); + figure=(; + size=(1024, 1024), + title="Using DeepONet to solve the Burgers equation", + titlesize=25, + ), + axis=(; xlabelsize=25, ylabelsize=25), + legend=(; label=L"u(x)", position=:bottom, labelsize=20), +) +``` From ac919b5b94f951128e53fc5f0130c3570699b4d0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Jun 2025 18:36:30 -0400 Subject: [PATCH 17/24] chore: bump reactant version --- docs/Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 0c267d3..4df738c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -29,4 +29,4 @@ NeuralOperators = "0.6" Optimisers = "0.4" Printf = "1.10" PythonCall = "0.9.23" -Reactant = "0.2.125" +Reactant = "0.2.126" diff --git a/test/Project.toml b/test/Project.toml index c0f8668..1ea6dbe 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -32,7 +32,7 @@ LuxTestUtils = "1.1.2" Optimisers = "0.4" Random = 
"1.10" ReTestItems = "1.24.0" -Reactant = "0.2.124" +Reactant = "0.2.126" Reexport = "1.2.2" StableRNGs = "1.0.2" Test = "1.10" From 9fb6f139f708576b93e2713b0ae84c808535c6c7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Jun 2025 10:36:06 -0500 Subject: [PATCH 18/24] fix: correct loss term --- test/shared_testsetup.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index e453fa8..fd16b02 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -17,7 +17,7 @@ function train!(loss, backend, model, ps, st, data; epochs=10) _, _, _, tstate = Training.single_train_step!(backend, loss, (x, y), tstate) end - l2 = @jit loss(model, ps, st, first(data)) + l2 = @jit loss(model, tstate.parameters, tstate.states, first(data)) return l2, l1 end From 14e323464faf0649fd52f9063941863fae1cdd46 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Jun 2025 10:05:58 -0400 Subject: [PATCH 19/24] test: use specific reactant branch --- docs/Project.toml | 1 + test/Project.toml | 3 +++ 2 files changed, 4 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index 4df738c..b9c185b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -15,6 +15,7 @@ Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" [sources] NeuralOperators = {path = ".."} +Reactant = { url = "https://github.com/EnzymeAD/Reactant.jl", rev = "ap/ignore_derivatives_julia" } [compat] AlgebraOfGraphics = "0.10.7" diff --git a/test/Project.toml b/test/Project.toml index 1ea6dbe..e0f62cb 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -37,3 +37,6 @@ Reexport = "1.2.2" StableRNGs = "1.0.2" Test = "1.10" Zygote = "0.7" + +[sources] +Reactant = { url = "https://github.com/EnzymeAD/Reactant.jl", rev = "ap/ignore_derivatives_julia" } From d642aa91604f376c8edac51eb8d9942ce365a31c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Jun 2025 12:49:30 -0400 Subject: [PATCH 20/24] docs: fno tutorial is working :tada: --- 
docs/src/models/fno.md | 121 ++++++++++++++++++++--------------------- 1 file changed, 60 insertions(+), 61 deletions(-) diff --git a/docs/src/models/fno.md b/docs/src/models/fno.md index 8ef1dad..1050ea0 100644 --- a/docs/src/models/fno.md +++ b/docs/src/models/fno.md @@ -90,7 +90,7 @@ data = [(u_data, v_data)]; function train!(model, ps, st, data; epochs=10) losses = [] - tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.003f0)) for _ in 1:epochs, (x, y) in data (_, loss, _, tstate) = Training.single_train_step!( AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false) @@ -115,60 +115,65 @@ draw( ``` -```@example minimal_lux -using NeuralOperators, Lux, Random, Optimisers, Zygote, CairoMakie +```@example fno_tutorial_details +using NeuralOperators, Lux, Random, Optimisers, Reactant +``` + +We will use Reactant.jl to accelerate the training process. + +```@example fno_tutorial_details +xdev = reactant_device() ``` ### Constructing training data First, we construct our training data. -```@example minimal_lux +```@example fno_tutorial_details rng = Random.default_rng() ``` `batch_size` is the number of observations. -```@example minimal_lux +```@example fno_tutorial_details batch_size = 128 ``` -`m` is the length of a single observation, you can also interpret this as the size of the grid we're evaluating our function on. +`m` is the length of a single observation, you can also interpret this as the size of the +grid we're evaluating our function on. -```@example minimal_lux +```@example fno_tutorial_details m = 32 ``` -We instantiate the domain that the function operates on -as a range from `0` to `2π`, whose length is the grid size. +We instantiate the domain that the function operates on as a range from `0` to `2π`, whose +length is the grid size. 
-```@example minimal_lux +```@example fno_tutorial_details xrange = range(0, 2π; length=m) .|> Float32; nothing #hide ``` -Each value in the array here, `α`, will be the multiplicative -factor on the input to the sine function. +Each value in the array here, `α`, will be the multiplicative factor on the input to the +sine function. -```@example minimal_lux +```@example fno_tutorial_details α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size); nothing #hide ``` -Now, we create our data arrays. We are storing all -of the training data in a single array, in order to -batch process them more efficiently. +Now, we create our data arrays. We are storing all of the training data in a single array, +in order to batch process them more efficiently. -```@example minimal_lux +```@example fno_tutorial_details u_data = zeros(Float32, m, 1, batch_size); v_data = zeros(Float32, m, 1, batch_size); nothing #hide ``` -and fill the data arrays with values. -Here, `u_data` is +and fill the data arrays with values: `u_data` is `sin(αx)` and `v_data` is `-cos(αx) / α`. -```@example minimal_lux +```@example fno_tutorial_details for i in 1:batch_size u_data[:, 1, i] .= sin.(α[i] .* xrange) v_data[:, 1, i] .= -inv(α[i]) .* cos.(α[i] .* xrange) @@ -177,7 +182,8 @@ end ### Creating the model -Finally, we get to the model itself. We instantiate a `FourierNeuralOperator` and provide it several parameters. +Finally, we get to the model itself. We instantiate a `FourierNeuralOperator` and provide +it several parameters. The first argument is the "activation function" for each neuron. @@ -186,18 +192,12 @@ The keyword arguments are: - `chs` is a tuple, representing the layer sizes for each layer. - `modes` is a 1-tuple, where the number represents the number of Fourier modes that are preserved, and the size of the tuple represents the number of dimensions. -- `permuted` indicates that the order of the arguments is permuted such that each column - of the array represents a single observation.
This is substantially faster than the usual - row access pattern, since Julia stores arrays by concatenating columns. - `Val(true)` is another way of expressing `true`, but in the type domain, so that - the compiler can see the value and use the appropriate optimizations. -```@example minimal_lux +```@example fno_tutorial_details fno = FourierNeuralOperator( gelu; # activation function chs=(1, 64, 64, 128, 1), # channel weights modes=(16,), # number of Fourier modes to retain - permuted=Val(true) # structure of the data means that columns are observations ) ``` @@ -206,8 +206,8 @@ a set of parameters and a set of states. Since the operator is "stateless", the states are empty and will remain so. The parameters are the weights of the neural network, and we will be modifying them in the training loop. -```@example minimal_lux -ps, st = Lux.setup(rng, fno); +```@example fno_tutorial_details +ps, st = Lux.setup(rng, fno) |> xdev; nothing #hide ``` @@ -215,55 +215,50 @@ We construct data as a vector of tuples (input, output). These are pre-batched, but for example if we had a lot of training data, we could dynamically load it, or create multiple batches. -```@example minimal_lux +```@example fno_tutorial_details +u_data = u_data |> xdev; +v_data = v_data |> xdev; data = [(u_data, v_data)]; nothing #hide ``` ### Training the model -Now, we create a function to train the model. -An "epoch" is basically a run over all input data, -and the more epochs we have, the better the neural network gets! +Now, we create a function to train the model. An "epoch" is basically a run over all +input data, and the more epochs we have, the better the neural network gets! -```@example minimal_lux +```@example fno_tutorial_details function train!(model, ps, st, data; epochs=10) # The `losses` array is used only for visualization, # you don't actually need it to train. losses = [] # Initialize a training state and an optimizer (Adam, in this case). 
- tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) - # Loop over epochs, then loop over each batch of training data, and step into the training: + tstate = Training.TrainState(model, ps, st, Adam(0.003f0)) + # Loop over epochs, then loop over each batch of training data, and step into the + # training: for _ in 1:epochs for (x, y) in data - _, loss, - _, tstate = Training.single_train_step!( - AutoZygote(), MSELoss(), (x, y), - tstate) - push!(losses, loss) + (_, loss, _, tstate) = Training.single_train_step!( + AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false) + ) + push!(losses, Float32(loss)) end end - return losses + return losses, tstate.parameters, tstate.states end ``` Now we train our model! -```@example minimal_lux -losses = @time train!(fno, ps, st, data; epochs=500) -``` - -We can plot the losses - you can see that at some point, we hit diminishing returns. - -```@example minimal_lux -lines(losses; axis=(; yscale=log10, ylabel="Loss", xlabel="Epoch")) +```@example fno_tutorial_details +losses, ps, st = @time train!(fno, ps, st, data; epochs=500) ``` ### Applying the model Let's try to actually apply this model using some input data. -```@example minimal_lux +```@example fno_tutorial_details input_data = u_data[:, 1, 1] ``` @@ -271,26 +266,30 @@ This is our input data. It's currently one-dimensional, but our neural network expects input in batched form, so we simply `reshape` it (a no-cost operation) to a 3d array with singleton dimensions. 
-```@example minimal_lux +```@example fno_tutorial_details reshaped_input = reshape(input_data, length(input_data), 1, 1) ``` -Now we can pass this to `Lux.apply`: +Now we can pass this to `Lux.apply` (`@jit` is used to run the function with Reactant.jl): -```@example minimal_lux -output_data, st = Lux.apply(fno, reshaped_input, ps, st) +```@example fno_tutorial_details +output_data, st = @jit Lux.apply(fno, reshaped_input, ps, st) ``` and plot it: -```@example minimal_lux -f, a, p = lines(dropdims(reshaped_input; dims=(2, 3)); label="u") -lines!(a, dropdims(output_data; dims=(2, 3)); label="Predicted") -lines!(a, v_data[:, 1, 1]; label="Expected") +```@example fno_tutorial_details +using CairoMakie, AlgebraOfGraphics +const AoG = AlgebraOfGraphics +AoG.set_aog_theme!() + +f, a, p = lines(dropdims(Array(reshaped_input); dims=(2, 3)); label="u") +lines!(a, dropdims(Array(output_data); dims=(2, 3)); label="Predicted") +lines!(a, Array(v_data)[:, 1, 1]; label="Expected") axislegend(a) # Compute the absolute error and plot that too, # on a separate axis. 
-absolute_error = v_data[:, 1, 1] .- dropdims(output_data; dims=(2, 3)) +absolute_error = Array(v_data)[:, 1, 1] .- dropdims(Array(output_data); dims=(2, 3)) a2, p2 = lines(f[2, 1], absolute_error; axis=(; ylabel="Error")) rowsize!(f.layout, 2, Aspect(1, 1 / 8)) linkxaxes!(a, a2) From 70697fb9a21d1afbffeb8e9f9650166447e765f3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Jun 2025 12:50:42 -0400 Subject: [PATCH 21/24] chore: remove sources --- docs/Project.toml | 3 +-- test/Project.toml | 5 +---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index b9c185b..54a83c4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -15,7 +15,6 @@ Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" [sources] NeuralOperators = {path = ".."} -Reactant = { url = "https://github.com/EnzymeAD/Reactant.jl", rev = "ap/ignore_derivatives_julia" } [compat] AlgebraOfGraphics = "0.10.7" @@ -30,4 +29,4 @@ NeuralOperators = "0.6" Optimisers = "0.4" Printf = "1.10" PythonCall = "0.9.23" -Reactant = "0.2.126" +Reactant = "0.2.127" diff --git a/test/Project.toml b/test/Project.toml index e0f62cb..02ee99b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -32,11 +32,8 @@ LuxTestUtils = "1.1.2" Optimisers = "0.4" Random = "1.10" ReTestItems = "1.24.0" -Reactant = "0.2.126" +Reactant = "0.2.127" Reexport = "1.2.2" StableRNGs = "1.0.2" Test = "1.10" Zygote = "0.7" - -[sources] -Reactant = { url = "https://github.com/EnzymeAD/Reactant.jl", rev = "ap/ignore_derivatives_julia" } From 195941287c8ec7b6be02d5e4e3dd4a618a6ad260 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Jun 2025 12:55:06 -0400 Subject: [PATCH 22/24] test: relax tolerance --- test/deeponet_tests.jl | 6 +++--- test/fno_tests.jl | 4 ++-- test/layers_tests.jl | 4 ++-- test/nomad_tests.jl | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/deeponet_tests.jl b/test/deeponet_tests.jl index 7ef1a30..ae9608d 100644 --- a/test/deeponet_tests.jl 
+++ b/test/deeponet_tests.jl @@ -46,9 +46,9 @@ end ∂u_ra, ∂ps_ra = (∂u_ra, ∂ps_ra) |> cpu_device() - @test ∂u_zyg[1] ≈ ∂u_ra[1] atol = 1.0f-3 rtol = 1.0f-3 - @test ∂u_zyg[2] ≈ ∂u_ra[2] atol = 1.0f-3 rtol = 1.0f-3 - @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) + @test ∂u_zyg[1] ≈ ∂u_ra[1] atol = 1.0f-2 rtol = 1.0f-2 + @test ∂u_zyg[2] ≈ ∂u_ra[2] atol = 1.0f-2 rtol = 1.0f-2 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) end end end diff --git a/test/fno_tests.jl b/test/fno_tests.jl index 9e59d8e..0b89130 100644 --- a/test/fno_tests.jl +++ b/test/fno_tests.jl @@ -41,8 +41,8 @@ end ∂x_ra, ∂ps_ra = (∂x_ra, ∂ps_ra) |> cpu_device() - @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-3 rtol = 1.0f-3 - @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) + @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-2 rtol = 1.0f-2 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) end end end diff --git a/test/layers_tests.jl b/test/layers_tests.jl index 725cf74..56abd57 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -44,8 +44,8 @@ end ∂x_ra, ∂ps_ra = (∂x_ra, ∂ps_ra) |> cpu_device() - @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-3 rtol = 1.0f-3 - @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) + @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-2 rtol = 1.0f-2 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) end end end diff --git a/test/nomad_tests.jl b/test/nomad_tests.jl index 8567967..d55f155 100644 --- a/test/nomad_tests.jl +++ b/test/nomad_tests.jl @@ -46,9 +46,9 @@ end ∂u_ra, ∂ps_ra = (∂u_ra, ∂ps_ra) |> cpu_device() - @test ∂u_zyg[1] ≈ ∂u_ra[1] atol = 1.0f-3 rtol = 1.0f-3 - @test ∂u_zyg[2] ≈ ∂u_ra[2] atol = 1.0f-3 rtol = 1.0f-3 - @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-3, rtol=1.0f-3) + @test ∂u_zyg[1] ≈ ∂u_ra[1] atol = 1.0f-2 rtol = 1.0f-2 + @test ∂u_zyg[2] ≈ ∂u_ra[2] atol = 1.0f-2 rtol = 1.0f-2 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) end end end From 860bcedcb117f804260d8673a8793dcc9fe497c3 Mon Sep 17 
00:00:00 2001 From: Avik Pal Date: Sat, 7 Jun 2025 07:59:23 -0400 Subject: [PATCH 23/24] docs: fno tutorial --- docs/pages.jl | 7 +- docs/src/tutorials/burgers_fno.md | 147 ++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 docs/src/tutorials/burgers_fno.md diff --git a/docs/pages.jl b/docs/pages.jl index c182f5f..a1dfeb4 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -5,6 +5,11 @@ pages = [ "DeepONet" => "models/deeponet.md", "NOMAD" => "models/nomad.md", ], - "Tutorials" => ["Burgers Equation" => "tutorials/burgers_deeponet.md"], + "Tutorials" => [ + "Solving Burgers Equation" => [ + "DeepONet" => "tutorials/burgers_deeponet.md", + "FNO" => "tutorials/burgers_fno.md", + ], + ], "API Reference" => "api.md", ] diff --git a/docs/src/tutorials/burgers_fno.md b/docs/src/tutorials/burgers_fno.md new file mode 100644 index 0000000..a80ec02 --- /dev/null +++ b/docs/src/tutorials/burgers_fno.md @@ -0,0 +1,147 @@ +# Burgers Equation using Fourier Neural Operator + +## Data Loading + +```@example burgers_fno +using DataDeps, MAT, MLUtils +using PythonCall, CondaPkg # For `gdown` +using Printf + +const gdown = pyimport("gdown") + +register( + DataDep( + "Burgers", + """ + Burgers' equation dataset from + [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator) + + mapping between initial conditions to the solutions at the last point of time \ + evolution in some function space. 
+ + u(x,0) -> u(x, time_end): + + * `a`: initial conditions u(x,0) + * `u`: solutions u(x,t_end) + """, + "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe", + "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd"; + fetch_method=(url, + local_dir) -> begin + pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip"))) + end, + post_fetch_method=unpack +) +) + +filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat") + +const N = 2048 +const Δsamples = 2^3 +const grid_size = div(2^13, Δsamples) +const T = Float32 + +file = matopen(filepath) +x_data = reshape(T.(collect(read(file, "a")[1:N, 1:Δsamples:end])), N, :) +y_data = reshape(T.(collect(read(file, "u")[1:N, 1:Δsamples:end])), N, :) +close(file) + +x_data = hcat( + repeat(reshape(collect(T, range(0, 1; length=grid_size)), :, 1, 1), 1, 1, N), + reshape(permutedims(x_data, (2, 1)), grid_size, 1, N) +); +y_data = reshape(permutedims(y_data, (2, 1)), grid_size, 1, N); +``` + +## Model + +```@example burgers_fno +using Lux, NeuralOperators, Optimisers, Random, Reactant + +const cdev = cpu_device() +const xdev = reactant_device(; force=true) + +fno = FourierNeuralOperator( + gelu; + chs = (2, 32, 32, 32, 1), + modes = (16,) +) +ps, st = Lux.setup(Random.default_rng(), fno) |> xdev; +``` + +## Training + +```@example burgers_fno +dataloader = DataLoader((x_data, y_data); batchsize=128, shuffle=true) |> xdev; + +function train_model!(model, ps, st, dataloader; epochs=5000) + train_state = Training.TrainState(model, ps, st, Adam(0.0001f0)) + + for epoch in 1:epochs, data in dataloader + (_, loss, _, train_state) = Training.single_train_step!( + AutoEnzyme(), MAELoss(), data, train_state + ) + + if epoch % 100 == 1 || epoch == epochs + @printf("Epoch %d: loss = %.6e\n", epoch, loss) + end + end + + return train_state.parameters, train_state.states +end + +(ps_trained, st_trained) = train_model!(fno, ps, st, dataloader) +nothing #hide +``` + +## Plotting + +```@example 
burgers_fno +using CairoMakie, AlgebraOfGraphics +const AoG = AlgebraOfGraphics +AoG.set_aog_theme!() + +x_data_dev = x_data |> xdev; +y_data_dev = y_data |> xdev; + +grid = x_data[:, 1, :] +pred = first( + Reactant.with_config(; + convolution_precision=PrecisionConfig.HIGH, + dot_general_precision=PrecisionConfig.HIGH, + ) do + @jit(fno(x_data_dev, ps_trained, st_trained)) + end +) |> cdev + +data_sequence, sequence, repeated_grid, label = Float32[], Int[], Float32[], String[] +for i in 1:16 + append!(repeated_grid, vcat(grid[:, i], grid[:, i])) + append!(sequence, repeat([i], grid_size * 2)) + append!(label, repeat(["Ground Truth"], grid_size)) + append!(label, repeat(["Predictions"], grid_size)) + append!(data_sequence, vec(y_data[:, 1, i])) + append!(data_sequence, vec(pred[:, 1, i])) +end +plot_data = (; data_sequence, sequence, repeated_grid, label) + +draw( + AoG.data(plot_data) * + mapping( + :repeated_grid => L"x", + :data_sequence => L"u(x)"; + color=:label => "", + layout=:sequence => nonnumeric, + linestyle=:label => "", + ) * + visual(Lines; linewidth=4), + scales(; Color=(; palette=:tab10), LineStyle = (; palette = [:solid, :dash, :dot])); + figure=(; + size=(1024, 1024), + title="Using FNO to solve the Burgers equation", + titlesize=25, + ), + axis=(; xlabelsize=25, ylabelsize=25), + legend=(; label=L"u(x)", position=:bottom, labelsize=20), +) +``` From 665b7169d7d61aba4d1e354645ecd1f2a85634a1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Jun 2025 08:33:01 -0400 Subject: [PATCH 24/24] docs: missing docs --- docs/src/api.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/api.md b/docs/src/api.md index 0f65e31..a643193 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -21,4 +21,5 @@ SpectralKernel ```@docs NeuralOperators.AbstractTransform +NeuralOperators.FourierTransform ```