diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 37dae14..28be623 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,9 +1,3 @@ -style = "sciml" -format_markdown = true -whitespace_in_kwargs = false -margin = 92 -indent = 4 -format_docstrings = true -separate_kwargs_with_semicolon = true -always_for_in = true -annotate_untyped_fields_with_any = false +style = "blue" +pipe_to_function_call = false +always_use_return = true diff --git a/.buildkite/documentation.yml b/.buildkite/documentation.yml index f20d9c9..c04c184 100644 --- a/.buildkite/documentation.yml +++ b/.buildkite/documentation.yml @@ -2,7 +2,7 @@ steps: - label: ":julia: Documentation" plugins: - JuliaCI/julia#v1: - version: "1.10" + version: "1" - JuliaCI/julia-coverage#v1: codecov: true command: | diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml index 7979b22..3722d64 100644 --- a/.buildkite/testing.yml +++ b/.buildkite/testing.yml @@ -19,29 +19,29 @@ steps: julia: - "1" - - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - - JuliaCI/julia-coverage#v1: - codecov: true - env: - JULIA_AMDGPU_CORE_MUST_LOAD: "1" - JULIA_AMDGPU_HIP_MUST_LOAD: "1" - JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" - BACKEND_GROUP: "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 - matrix: - setup: - julia: - - "1" + # - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + # plugins: + # - JuliaCI/julia#v1: + # version: "{{matrix.julia}}" + # - JuliaCI/julia-test#v1: + # test_args: "--quickfail" + # - JuliaCI/julia-coverage#v1: + # codecov: true + # env: + # JULIA_AMDGPU_CORE_MUST_LOAD: "1" + # JULIA_AMDGPU_HIP_MUST_LOAD: "1" + # JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + # BACKEND_GROUP: "AMDGPU" + # agents: + # queue: "juliagpu" + # rocm: "*" + # rocmgpu: "*" + # if: build.message !~ /\[skip tests\]/ + 
# timeout_in_minutes: 60 + # matrix: + # setup: + # julia: + # - "1" env: RETESTITEMS_NWORKERS: 4 diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1b306d2..f440168 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -29,10 +29,10 @@ jobs: matrix: version: - "1.10" + - "1" os: - ubuntu-latest - macos-latest - - windows-latest steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 0603391..aa70e3f 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -37,7 +37,7 @@ jobs: - name: "Run CompatHelper" run: | import CompatHelper - CompatHelper.main() + CompatHelper.main(; subdirs=["", "docs", "test"]) shell: julia --color=yes {0} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/Project.toml b/Project.toml index e1406d5..aa608b6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,33 +1,25 @@ name = "NeuralOperators" uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b" authors = ["Avik Pal "] -version = "0.5.3" +version = "0.6.0" [deps] -ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" -MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [compat] -ArgCheck = "2.3" -ChainRulesCore = "1.24" +AbstractFFTs = "1.5.0" ConcreteStructs = "0.2.3" -FFTW = "1.8" -Lux = "1" -LuxCore = "1" -LuxLib = "1.2" -MLDataDevices = "1.2.0" -NNlib = 
"0.9.21" +Lux = "1.13" +LuxCore = "1.2" +LuxLib = "1.8" +NNlib = "0.9.30" Random = "1.10" -Static = "1.1.1" WeightInitializers = "1" julia = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index 29b4d3c..54a83c4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,29 +1,32 @@ [deps] +AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MAT = "23992714-dd62-5051-b70f-ba57cb901cac" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" + +[sources] +NeuralOperators = {path = ".."} [compat] -CairoMakie = "0.12.11" +AlgebraOfGraphics = "0.10.7" +CairoMakie = "0.13" CondaPkg = "0.2.23" DataDeps = "0.7.13" Documenter = "1.7.0" Lux = "1" -LuxCUDA = "0.3.3" MAT = "0.10.7" MLUtils = "0.4.4" -NeuralOperators = "0.5" -Optimisers = "0.3.3" +NeuralOperators = "0.6" +Optimisers = "0.4" Printf = "1.10" PythonCall = "0.9.23" -Zygote = "0.6.71" +Reactant = "0.2.127" diff --git a/docs/make.jl b/docs/make.jl index 8693552..e9048a7 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -17,8 +17,9 @@ makedocs(; format=Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", canonical="https://docs.sciml.ai/NeuralOperators/stable/", - assets=["assets/favicon.ico"]), - pages + assets=["assets/favicon.ico"], + ), + pages, ) deploydocs(; repo="github.com/SciML/NeuralOperators.jl.git", push_preview=true) diff --git a/docs/pages.jl b/docs/pages.jl index 2c9c8a4..a1dfeb4 100644 
--- a/docs/pages.jl +++ b/docs/pages.jl @@ -3,10 +3,13 @@ pages = [ "Pre-built Models" => [ "FNO" => "models/fno.md", "DeepONet" => "models/deeponet.md", - "NOMAD" => "models/nomad.md" + "NOMAD" => "models/nomad.md", ], "Tutorials" => [ - "Burgers Equation" => "tutorials/burgers.md" + "Solving Burgers Equation" => [ + "DeepONet" => "tutorials/burgers_deeponet.md", + "FNO" => "tutorials/burgers_fno.md", + ], ], - "API Reference" => "api.md" + "API Reference" => "api.md", ] diff --git a/docs/src/api.md b/docs/src/api.md index 0f65e31..a643193 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -21,4 +21,5 @@ SpectralKernel ```@docs NeuralOperators.AbstractTransform +NeuralOperators.FourierTransform ``` diff --git a/docs/src/models/deeponet.md b/docs/src/models/deeponet.md index 0765306..3f5a186 100644 --- a/docs/src/models/deeponet.md +++ b/docs/src/models/deeponet.md @@ -11,7 +11,7 @@ u(y) \xrightarrow{\text{branch}} & \; b \\ & \quad \searrow\\ &\quad \quad \mathcal{G}_{\theta} u(y) = \sum_k b_k t_k \\ & \quad \nearrow \\ -y \; \; \xrightarrow{\text{trunk}} \; \; & t +y \; \; \xrightarrow{\text{trunk}} \; \; & t \end{align*} ``` @@ -38,24 +38,32 @@ v(x) = \frac{du}{dx} \quad \forall \; x \in [0, 2\pi], \; \alpha \in [0.5, 1] ### Copy-pastable code ```@example deeponet_tutorial -using NeuralOperators, Lux, Random, Optimisers, Zygote, CairoMakie +using NeuralOperators, Lux, Random, Optimisers, Reactant + +using CairoMakie, AlgebraOfGraphics +set_aog_theme!() +const AoG = AlgebraOfGraphics rng = Random.default_rng() +Random.seed!(rng, 1234) + +xdev = reactant_device() eval_points = 1 -data_size = 64 +batch_size = 64 dim_y = 1 m = 32 xrange = range(0, 2π; length=m) .|> Float32 -u_data = zeros(Float32, m, data_size) -α = 0.5f0 .+ 0.5f0 .* rand(Float32, data_size) +α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size) + +u_data = zeros(Float32, m, batch_size) +y_data = rand(rng, Float32, 1, eval_points) .* Float32(2π) +v_data = zeros(Float32, eval_points, batch_size) -y_data 
= rand(Float32, 1, eval_points, data_size) .* 2π -v_data = zeros(Float32, eval_points, data_size) -for i in 1:data_size +for i in 1:batch_size u_data[:, i] .= sin.(α[i] .* xrange) - v_data[:, i] .= -inv(α[i]) .* cos.(α[i] .* y_data[1, :, i]) + v_data[:, i] .= -inv(α[i]) .* cos.(α[i] .* y_data[1, :]) end deeponet = DeepONet( @@ -63,23 +71,32 @@ deeponet = DeepONet( Chain(Dense(1 => 4, σ), Dense(4 => 8, σ)) ) -ps, st = Lux.setup(rng, deeponet) -data = [((u_data, y_data), v_data)] +ps, st = Lux.setup(rng, deeponet) |> xdev; + +u_data = u_data |> xdev; +y_data = y_data |> xdev; +v_data = v_data |> xdev; +data = [((u_data, y_data), v_data)]; function train!(model, ps, st, data; epochs=10) losses = [] tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) for _ in 1:epochs, (x, y) in data - - _, loss, - _, tstate = Training.single_train_step!(AutoZygote(), MSELoss(), (x, y), - tstate) - push!(losses, loss) + (_, loss, _, tstate) = Training.single_train_step!( + AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false) + ) + push!(losses, Float32(loss)) end return losses end losses = train!(deeponet, ps, st, data; epochs=1000) -lines(losses) +draw( + AoG.data((; losses, iteration=1:length(losses))) * + mapping(:iteration => "Iteration", :losses => "Loss (log10 scale)") * + visual(Lines); + axis=(; yscale=log10), + figure=(; title="Using DeepONet to learn the anti-derivative operator") +) ``` diff --git a/docs/src/models/fno.md b/docs/src/models/fno.md index 0de5552..1050ea0 100644 --- a/docs/src/models/fno.md +++ b/docs/src/models/fno.md @@ -18,7 +18,7 @@ convolution operation, which can be efficiently computed in the fourier domain. 
```math \begin{align*} -(\Kappa_{\theta}u)(x) +(\Kappa_{\theta}u)(x) &= \int_D \kappa_{\theta}(x - y) dy \quad \forall x \in D\\ &= \mathcal{F}^{-1}(\mathcal{F}(\kappa_{\theta}) \mathcal{F}(u))(x) \quad \forall x \in D \end{align*} @@ -57,228 +57,241 @@ v(x) = \frac{du}{dx} \quad \forall \; x \in [0, 2\pi], \; \alpha \in [0.5, 1] ``` ```@example fno_tutorial -using NeuralOperators, Lux, Random, Optimisers, Zygote, CairoMakie +using NeuralOperators, Lux, Random, Optimisers, Reactant + +using CairoMakie, AlgebraOfGraphics +set_aog_theme!() +const AoG = AlgebraOfGraphics rng = Random.default_rng() +Random.seed!(rng, 1234) + +xdev = reactant_device() -data_size = 128 +batch_size = 128 m = 32 xrange = range(0, 2π; length=m) .|> Float32; -u_data = zeros(Float32, m, 1, data_size); -α = 0.5f0 .+ 0.5f0 .* rand(Float32, data_size); -v_data = zeros(Float32, m, 1, data_size); +u_data = zeros(Float32, m, 1, batch_size); +α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size); +v_data = zeros(Float32, m, 1, batch_size); -for i in 1:data_size +for i in 1:batch_size u_data[:, 1, i] .= sin.(α[i] .* xrange) v_data[:, 1, i] .= -inv(α[i]) .* cos.(α[i] .* xrange) end -fno = FourierNeuralOperator(gelu; chs=(1, 64, 64, 128, 1), modes=(16,), permuted=Val(true)) +fno = FourierNeuralOperator(gelu; chs=(1, 64, 64, 128, 1), modes=(16,)) -ps, st = Lux.setup(rng, fno); +ps, st = Lux.setup(rng, fno) |> xdev; +u_data = u_data |> xdev; +v_data = v_data |> xdev; data = [(u_data, v_data)]; function train!(model, ps, st, data; epochs=10) losses = [] - tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.003f0)) for _ in 1:epochs, (x, y) in data - - _, loss, - _, tstate = Training.single_train_step!(AutoZygote(), MSELoss(), (x, y), - tstate) - push!(losses, loss) + (_, loss, _, tstate) = Training.single_train_step!( + AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false) + ) + push!(losses, Float32(loss)) end return losses end -losses = 
train!(fno, ps, st, data; epochs=100) +losses = train!(fno, ps, st, data; epochs=1000) -lines(losses) +draw( + AoG.data((; losses, iteration=1:length(losses))) * + mapping(:iteration => "Iteration", :losses => "Loss (log10 scale)") * + visual(Lines); + axis=(; yscale=log10), + figure=(; title="Using Fourier Neural Operator to learn the anti-derivative operator") +) ``` ```@raw html ``` -````@example minimal_lux -using NeuralOperators, Lux, Random, Optimisers, Zygote, CairoMakie -```` +```@example fno_tutorial_details +using NeuralOperators, Lux, Random, Optimisers, Reactant +``` + +We will use Reactant.jl to accelerate the training process. + +```@example fno_tutorial_details +xdev = reactant_device() +``` ### Constructing training data First, we construct our training data. -````@example minimal_lux +```@example fno_tutorial_details rng = Random.default_rng() -```` +``` -`data_size` is the number of observations. +`batch_size` is the number of observations. -````@example minimal_lux -data_size = 128 -```` +```@example fno_tutorial_details +batch_size = 128 +``` -`m` is the length of a single observation, you can also interpret this as the size of the grid we're evaluating our function on. +`m` is the length of a single observation, you can also interpret this as the size of the +grid we're evaluating our function on. -````@example minimal_lux +```@example fno_tutorial_details m = 32 -```` +``` -We instantiate the domain that the function operates on -as a range from `0` to `2π`, whose length is the grid size. +We instantiate the domain that the function operates on as a range from `0` to `2π`, whose +length is the grid size. -````@example minimal_lux +```@example fno_tutorial_details xrange = range(0, 2π; length=m) .|> Float32; nothing #hide -```` +``` -Each value in the array here, `α`, will be the multiplicative -factor on the input to the sine function. +Each value in the array here, `α`, will be the multiplicative factor on the input to the +sine function. 
-````@example minimal_lux -α = 0.5f0 .+ 0.5f0 .* rand(Float32, data_size); +```@example fno_tutorial_details +α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size); nothing #hide -```` +``` -Now, we create our data arrays. We are storing all -of the training data in a single array, in order to -batch process them more efficiently. +Now, we create our data arrays. We are storing all of the training data in a single array, +in order to batch process them more efficiently. -````@example minimal_lux -u_data = zeros(Float32, m, 1, data_size); -v_data = zeros(Float32, m, 1, data_size); +```@example fno_tutorial_details +u_data = zeros(Float32, m, 1, batch_size); +v_data = zeros(Float32, m, 1, batch_size); nothing #hide -```` +``` -and fill the data arrays with values. -Here, `u_data` is +and fill the data arrays with values. Here, `u_data` holds the inputs `sin(αx)` and `v_data` holds the matching targets `-cos(αx)/α`, both sampled on `xrange`: -````@example minimal_lux -for i in 1:data_size +```@example fno_tutorial_details +for i in 1:batch_size u_data[:, 1, i] .= sin.(α[i] .* xrange) v_data[:, 1, i] .= -inv(α[i]) .* cos.(α[i] .* xrange) end -```` +``` ### Creating the model -Finally, we get to the model itself. We instantiate a `FourierNeuralOperator` and provide it several parameters. +Finally, we get to the model itself. We instantiate a `FourierNeuralOperator` and provide +it several parameters. The first argument is the "activation function" for each neuron. The keyword arguments are: - - `chs` is a tuple, representing the layer sizes for each layer. - - `modes` is a 1-tuple, where the number represents the number of Fourier modes that - are preserved, and the size of the tuple represents the number of dimensions. - - `permuted` indicates that the order of the arguments is permuted such that each column - of the array represents a single observation. This is substantially faster than the usual - row access pattern, since Julia stores arrays by concatenating columns. 
- `Val(true)` is another way of expressing `true`, but in the type domain, so that - the compiler can see the value and use the appropriate optimizations. +- `chs` is a tuple, representing the layer sizes for each layer. +- `modes` is a 1-tuple, where the number represents the number of Fourier modes that + are preserved, and the size of the tuple represents the number of dimensions. -````@example minimal_lux +```@example fno_tutorial_details fno = FourierNeuralOperator( gelu; # activation function chs=(1, 64, 64, 128, 1), # channel weights modes=(16,), # number of Fourier modes to retain - permuted=Val(true) # structure of the data means that columns are observations ) -```` +``` -Now, we set up the model. This function returns two things, -a set of parameters and a set of states. Since the operator is -"stateless", the states are empty and will remain so. The parameters +Now, we set up the model. This function returns two things, +a set of parameters and a set of states. Since the operator is +"stateless", the states are empty and will remain so. The parameters are the weights of the neural network, and we will be modifying them in the training loop. -````@example minimal_lux -ps, st = Lux.setup(rng, fno); +```@example fno_tutorial_details +ps, st = Lux.setup(rng, fno) |> xdev; nothing #hide -```` +``` -We construct data as a vector of tuples (input, output). These are pre-batched, +We construct data as a vector of tuples (input, output). These are pre-batched, but for example if we had a lot of training data, we could dynamically load it, or create multiple batches. -````@example minimal_lux +```@example fno_tutorial_details +u_data = u_data |> xdev; +v_data = v_data |> xdev; data = [(u_data, v_data)]; nothing #hide -```` +``` ### Training the model -Now, we create a function to train the model. -An "epoch" is basically a run over all input data, -and the more epochs we have, the better the neural network gets! +Now, we create a function to train the model. 
An "epoch" is basically a run over all +input data, and the more epochs we have, the better the neural network gets! -````@example minimal_lux +```@example fno_tutorial_details function train!(model, ps, st, data; epochs=10) # The `losses` array is used only for visualization, # you don't actually need it to train. losses = [] # Initialize a training state and an optimizer (Adam, in this case). - tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) - # Loop over epochs, then loop over each batch of training data, and step into the training: + tstate = Training.TrainState(model, ps, st, Adam(0.003f0)) + # Loop over epochs, then loop over each batch of training data, and step into the + # training: for _ in 1:epochs for (x, y) in data - _, loss, - _, tstate = Training.single_train_step!( - AutoZygote(), MSELoss(), (x, y), - tstate) - push!(losses, loss) + (_, loss, _, tstate) = Training.single_train_step!( + AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false) + ) + push!(losses, Float32(loss)) end end - return losses + return losses, tstate.parameters, tstate.states end -```` +``` Now we train our model! -````@example minimal_lux -losses = @time train!(fno, ps, st, data; epochs=500) -```` - -We can plot the losses - you can see that at some point, we hit diminishing returns. - -````@example minimal_lux -lines(losses; axis=(; yscale=log10, ylabel="Loss", xlabel="Epoch")) -```` +```@example fno_tutorial_details +losses, ps, st = @time train!(fno, ps, st, data; epochs=500) +``` ### Applying the model Let's try to actually apply this model using some input data. -````@example minimal_lux +```@example fno_tutorial_details input_data = u_data[:, 1, 1] -```` +``` -This is our input data. It's currently one-dimensional, +This is our input data. It's currently one-dimensional, but our neural network expects input in batched form, so we simply `reshape` it (a no-cost operation) to a 3d array with singleton dimensions. 
-````@example minimal_lux +```@example fno_tutorial_details reshaped_input = reshape(input_data, length(input_data), 1, 1) -```` +``` -Now we can pass this to `Lux.apply`: +Now we can pass this to `Lux.apply` (`@jit` is used to run the function with Reactant.jl): -````@example minimal_lux -output_data, st = Lux.apply(fno, reshaped_input, ps, st) -```` +```@example fno_tutorial_details +output_data, st = @jit Lux.apply(fno, reshaped_input, ps, st) +``` and plot it: -````@example minimal_lux -f, a, p = lines(dropdims(reshaped_input; dims=(2, 3)); label="u") -lines!(a, dropdims(output_data; dims=(2, 3)); label="Predicted") -lines!(a, v_data[:, 1, 1]; label="Expected") +```@example fno_tutorial_details +using CairoMakie, AlgebraOfGraphics +const AoG = AlgebraOfGraphics +AoG.set_aog_theme!() + +f, a, p = lines(dropdims(Array(reshaped_input); dims=(2, 3)); label="u") +lines!(a, dropdims(Array(output_data); dims=(2, 3)); label="Predicted") +lines!(a, Array(v_data)[:, 1, 1]; label="Expected") axislegend(a) # Compute the absolute error and plot that too, # on a separate axis. 
-absolute_error = v_data[:, 1, 1] .- dropdims(output_data; dims=(2, 3)) +absolute_error = Array(v_data)[:, 1, 1] .- dropdims(Array(output_data); dims=(2, 3)) a2, p2 = lines(f[2, 1], absolute_error; axis=(; ylabel="Error")) rowsize!(f.layout, 2, Aspect(1, 1 / 8)) linkxaxes!(a, a2) f -```` +``` diff --git a/docs/src/models/nomad.md b/docs/src/models/nomad.md index 7ecf801..0a944fc 100644 --- a/docs/src/models/nomad.md +++ b/docs/src/models/nomad.md @@ -40,46 +40,64 @@ v(x) = \frac{du}{dx} \quad \forall \; x \in [0, 2\pi], \; \alpha \in [0.5, 1] ### Copy-pastable code ```@example nomad_tutorial -using NeuralOperators, Lux, Random, Optimisers, Zygote, CairoMakie +using NeuralOperators, Lux, Random, Optimisers, Reactant + +using CairoMakie, AlgebraOfGraphics +set_aog_theme!() +const AoG = AlgebraOfGraphics rng = Random.default_rng() +Random.seed!(rng, 1234) + +xdev = reactant_device() eval_points = 1 -data_size = 128 +batch_size = 64 dim_y = 1 m = 32 xrange = range(0, 2π; length=m) .|> Float32 -u_data = zeros(Float32, m, data_size) -α = 0.5f0 .+ 0.5f0 .* rand(Float32, data_size) +α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size) -y_data = rand(Float32, 1, eval_points, data_size) .* 2π -v_data = zeros(Float32, eval_points, data_size) -for i in 1:data_size +u_data = zeros(Float32, m, batch_size) +y_data = rand(rng, Float32, eval_points, batch_size) .* Float32(2π) +v_data = zeros(Float32, eval_points, batch_size) + +for i in 1:batch_size u_data[:, i] .= sin.(α[i] .* xrange) - v_data[:, i] .= -inv(α[i]) .* cos.(α[i] .* y_data[1, :, i]) + v_data[:, i] .= -inv(α[i]) .* cos.(α[i] .* y_data[:, i]) end -nomad = NOMAD(Chain(Dense(m => 8, σ), Dense(8 => 8, σ), Dense(8 => 7)), - Chain(Dense(8 => 4, σ), Dense(4 => 1))) +nomad = NOMAD( + Chain(Dense(m => 8, σ), Dense(8 => 8, σ), Dense(8 => 8 - eval_points)), + Chain(Dense(8 => 4, σ), Dense(4 => eval_points)) +) -ps, st = Lux.setup(rng, nomad) -data = [((u_data, y_data), v_data)] +ps, st = Lux.setup(rng, nomad) |> xdev; +u_data = u_data 
|> xdev; +y_data = y_data |> xdev; +v_data = v_data |> xdev; +data = [((u_data, y_data), v_data)]; function train!(model, ps, st, data; epochs=10) losses = [] - tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) for _ in 1:epochs, (x, y) in data - - _, loss, - _, tstate = Training.single_train_step!(AutoZygote(), MSELoss(), (x, y), - tstate) - push!(losses, loss) + (_, loss, _, tstate) = Training.single_train_step!( + AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false) + ) + push!(losses, Float32(loss)) end return losses end losses = train!(nomad, ps, st, data; epochs=1000) -lines(losses) +draw( + AoG.data((; losses, iteration=1:length(losses))) * + mapping(:iteration => "Iteration", :losses => "Loss (log10 scale)") * + visual(Lines); + axis=(; yscale=log10), + figure=(; title="Using NOMAD to learn the anti-derivative operator") +) ``` diff --git a/docs/src/tutorials/burgers.md b/docs/src/tutorials/burgers.md deleted file mode 100644 index 76f8ece..0000000 --- a/docs/src/tutorials/burgers.md +++ /dev/null @@ -1,136 +0,0 @@ -# Burgers Equation using DeepONet - -## Data Loading - -```@example burgers -using DataDeps, MAT, MLUtils -using PythonCall, CondaPkg # For `gdown` -using Printf - -const gdown = pyimport("gdown") - -register( - DataDep( - "Burgers", - """ - Burgers' equation dataset from - [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator) - - mapping between initial conditions to the solutions at the last point of time \ - evolution in some function space. 
- - u(x,0) -> u(x, time_end): - - * `a`: initial conditions u(x,0) - * `u`: solutions u(x,t_end) - """, - "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe", - "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd"; - fetch_method=(url, - local_dir) -> begin - pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip"))) - end, - post_fetch_method=unpack -) -) - -filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat") - -const N = 2048 -const Δsamples = 2^3 -const grid_size = div(2^13, Δsamples) -const T = Float32 - -file = matopen(filepath) -x_data = reshape(T.(collect(read(file, "a")[1:N, 1:Δsamples:end])), N, :, 1) -y_data = reshape(T.(collect(read(file, "u")[1:N, 1:Δsamples:end])), N, :, 1) -close(file) - -x_data = permutedims(x_data, (2, 1, 3)) -grid = reshape(T.(collect(range(0, 1; length=grid_size)')), :, grid_size, 1) -``` - -## Model - -```@example burgers -using Lux, NeuralOperators, Optimisers, Zygote, Random -using LuxCUDA - -const cdev = cpu_device() -const gdev = gpu_device() - -deeponet = DeepONet(; - branch=(size(x_data, 1), ntuple(Returns(32), 5)...), - trunk=(size(grid, 1), ntuple(Returns(32), 5)...), - branch_activation=tanh, - trunk_activation=tanh -) -ps, st = Lux.setup(Random.default_rng(), deeponet) |> gdev; -``` - -## Training - -```@example burgers -x_data_dev = x_data |> gdev -y_data_dev = y_data |> gdev -grid_dev = grid |> gdev - -function loss_function(model, ps, st, ((v, y), u)) - û, stₙ = model((v, y), ps, st) - return MAELoss()(û, u), stₙ, (;) -end - -function train_model!(model, ps, st, data; epochs=5000) - train_state = Training.TrainState(model, ps, st, Adam(0.0001f0)) - - for epoch in 1:epochs - _, loss, - _, - train_state = Training.single_train_step!( - AutoZygote(), loss_function, data, train_state) - - if epoch % 25 == 1 || epoch == epochs - @printf("Epoch %d: loss = %.6e\n", epoch, loss) - end - end - - return train_state.parameters, train_state.states -end - -ps_trained, 
-st_trained = train_model!( - deeponet, ps, st, ((x_data_dev, grid_dev), y_data_dev)) -``` - -## Plotting - -```@example burgers -using CairoMakie - -pred = first(deeponet((x_data_dev, grid_dev), ps_trained, st_trained)) |> cdev - -begin - fig = Figure(; size=(1024, 1024)) - - axs = [Axis(fig[i, j]) for i in 1:4, j in 1:4] - for i in 1:4, j in 1:4 - - idx = i + (j - 1) * 4 - ax = axs[i, j] - l1 = lines!(ax, vec(grid), pred[idx, :, 1]) - l2 = lines!(ax, vec(grid), y_data[idx, :, 1]) - - i == 4 && (ax.xlabel = "x") - j == 1 && (ax.ylabel = "u(x)") - - if i == 1 && j == 1 - axislegend(ax, [l1, l2], ["Predictions", "Ground Truth"]) - end - end - linkaxes!(axs...) - - fig[0, :] = Label(fig, "Burgers Equation using DeepONet"; tellwidth=false, font=:bold) - - fig -end -``` diff --git a/docs/src/tutorials/burgers_deeponet.md b/docs/src/tutorials/burgers_deeponet.md new file mode 100644 index 0000000..5e0be51 --- /dev/null +++ b/docs/src/tutorials/burgers_deeponet.md @@ -0,0 +1,145 @@ +# Burgers Equation using DeepONet + +## Data Loading + +```@example burgers +using DataDeps, MAT, MLUtils +using PythonCall, CondaPkg # For `gdown` +using Printf + +const gdown = pyimport("gdown") + +register( + DataDep( + "Burgers", + """ + Burgers' equation dataset from + [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator) + + mapping between initial conditions to the solutions at the last point of time \ + evolution in some function space. 
+ + u(x,0) -> u(x, time_end): + + * `a`: initial conditions u(x,0) + * `u`: solutions u(x,t_end) + """, + "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe", + "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd"; + fetch_method=(url, + local_dir) -> begin + pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip"))) + end, + post_fetch_method=unpack +) +) + +filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat") + +const N = 2048 +const Δsamples = 2^3 +const grid_size = div(2^13, Δsamples) +const T = Float32 + +file = matopen(filepath) +x_data = reshape(T.(collect(read(file, "a")[1:N, 1:Δsamples:end])), N, :) +y_data = reshape(T.(collect(read(file, "u")[1:N, 1:Δsamples:end])), N, :) +close(file) + +x_data = permutedims(x_data, (2, 1)) +y_data = permutedims(y_data, (2, 1)) +grid = reshape(collect(T, range(0, 1; length=grid_size)), 1, :) +``` + +## Model + +```@example burgers +using Lux, NeuralOperators, Optimisers, Random, Reactant + +const cdev = cpu_device() +const xdev = reactant_device(; force=true) + +deeponet = DeepONet(; + branch=(size(x_data, 1), ntuple(Returns(32), 5)...), + trunk=(size(grid, 1), ntuple(Returns(32), 5)...), + branch_activation=gelu, + trunk_activation=gelu +) +ps, st = Lux.setup(Random.default_rng(), deeponet) |> xdev; +``` + +## Training + +```@example burgers +x_data_dev = x_data |> xdev; +y_data_dev = y_data |> xdev; +grid_dev = grid |> xdev; + +function train_model!(model, ps, st, data; epochs=5000) + train_state = Training.TrainState(model, ps, st, Adam(0.0001f0)) + + for epoch in 1:epochs + (_, loss, _, train_state) = Training.single_train_step!( + AutoEnzyme(), MAELoss(), data, train_state + ) + + if epoch % 100 == 1 || epoch == epochs + @printf("Epoch %d: loss = %.6e\n", epoch, loss) + end + end + + return train_state.parameters, train_state.states +end + +(ps_trained, st_trained) = train_model!( + deeponet, ps, st, ((x_data_dev, grid_dev), y_data_dev) +) +nothing #hide +``` 
+ +## Plotting + +```@example burgers +using CairoMakie, AlgebraOfGraphics +const AoG = AlgebraOfGraphics +AoG.set_aog_theme!() + +pred = first( + Reactant.with_config(; + convolution_precision=PrecisionConfig.HIGH, + dot_general_precision=PrecisionConfig.HIGH, + ) do + @jit(deeponet((x_data_dev, grid_dev), ps_trained, st_trained)) + end +) |> cdev + +data_sequence, sequence, repeated_grid, label = Float32[], Int[], Float32[], String[] +for i in 1:16 + append!(repeated_grid, vcat(vec(grid), vec(grid))) + append!(sequence, repeat([i], grid_size * 2)) + append!(label, repeat(["Ground Truth"], grid_size)) + append!(label, repeat(["Predictions"], grid_size)) + append!(data_sequence, vec(y_data[:, i])) + append!(data_sequence, vec(pred[:, i])) +end +plot_data = (; data_sequence, sequence, repeated_grid, label) + +draw( + AoG.data(plot_data) * + mapping( + :repeated_grid => L"x", + :data_sequence => L"u(x)"; + color=:label => "", + layout=:sequence => nonnumeric, + ) * + visual(Lines), + scales(; Color=(; palette=:tab10)); + figure=(; + size=(1024, 1024), + title="Using DeepONet to solve the Burgers equation", + titlesize=25, + ), + axis=(; xlabelsize=25, ylabelsize=25), + legend=(; label=L"u(x)", position=:bottom, labelsize=20), +) +``` diff --git a/docs/src/tutorials/burgers_fno.md b/docs/src/tutorials/burgers_fno.md new file mode 100644 index 0000000..a80ec02 --- /dev/null +++ b/docs/src/tutorials/burgers_fno.md @@ -0,0 +1,147 @@ +# Burgers Equation using Fourier Neural Operator + +## Data Loading + +```@example burgers_fno +using DataDeps, MAT, MLUtils +using PythonCall, CondaPkg # For `gdown` +using Printf + +const gdown = pyimport("gdown") + +register( + DataDep( + "Burgers", + """ + Burgers' equation dataset from + [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator) + + mapping between initial conditions to the solutions at the last point of time \ + evolution in some function space. 
+ + u(x,0) -> u(x, time_end): + + * `a`: initial conditions u(x,0) + * `u`: solutions u(x,t_end) + """, + "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe", + "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd"; + fetch_method=(url, + local_dir) -> begin + pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip"))) + end, + post_fetch_method=unpack +) +) + +filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat") + +const N = 2048 +const Δsamples = 2^3 +const grid_size = div(2^13, Δsamples) +const T = Float32 + +file = matopen(filepath) +x_data = reshape(T.(collect(read(file, "a")[1:N, 1:Δsamples:end])), N, :) +y_data = reshape(T.(collect(read(file, "u")[1:N, 1:Δsamples:end])), N, :) +close(file) + +x_data = hcat( + repeat(reshape(collect(T, range(0, 1; length=grid_size)), :, 1, 1), 1, 1, N), + reshape(permutedims(x_data, (2, 1)), grid_size, 1, N) +); +y_data = reshape(permutedims(y_data, (2, 1)), grid_size, 1, N); +``` + +## Model + +```@example burgers_fno +using Lux, NeuralOperators, Optimisers, Random, Reactant + +const cdev = cpu_device() +const xdev = reactant_device(; force=true) + +fno = FourierNeuralOperator( + gelu; + chs = (2, 32, 32, 32, 1), + modes = (16,) +) +ps, st = Lux.setup(Random.default_rng(), fno) |> xdev; +``` + +## Training + +```@example burgers_fno +dataloader = DataLoader((x_data, y_data); batchsize=128, shuffle=true) |> xdev; + +function train_model!(model, ps, st, dataloader; epochs=5000) + train_state = Training.TrainState(model, ps, st, Adam(0.0001f0)) + + for epoch in 1:epochs, data in dataloader + (_, loss, _, train_state) = Training.single_train_step!( + AutoEnzyme(), MAELoss(), data, train_state + ) + + if epoch % 100 == 1 || epoch == epochs + @printf("Epoch %d: loss = %.6e\n", epoch, loss) + end + end + + return train_state.parameters, train_state.states +end + +(ps_trained, st_trained) = train_model!(fno, ps, st, dataloader) +nothing #hide +``` + +## Plotting + +```@example 
burgers_fno +using CairoMakie, AlgebraOfGraphics +const AoG = AlgebraOfGraphics +AoG.set_aog_theme!() + +x_data_dev = x_data |> xdev; +y_data_dev = y_data |> xdev; + +grid = x_data[:, 1, :] +pred = first( + Reactant.with_config(; + convolution_precision=PrecisionConfig.HIGH, + dot_general_precision=PrecisionConfig.HIGH, + ) do + @jit(fno(x_data_dev, ps_trained, st_trained)) + end +) |> cdev + +data_sequence, sequence, repeated_grid, label = Float32[], Int[], Float32[], String[] +for i in 1:16 + append!(repeated_grid, vcat(grid[:, i], grid[:, i])) + append!(sequence, repeat([i], grid_size * 2)) + append!(label, repeat(["Ground Truth"], grid_size)) + append!(label, repeat(["Predictions"], grid_size)) + append!(data_sequence, vec(y_data[:, 1, i])) + append!(data_sequence, vec(pred[:, 1, i])) +end +plot_data = (; data_sequence, sequence, repeated_grid, label) + +draw( + AoG.data(plot_data) * + mapping( + :repeated_grid => L"x", + :data_sequence => L"u(x)"; + color=:label => "", + layout=:sequence => nonnumeric, + linestyle=:label => "", + ) * + visual(Lines; linewidth=4), + scales(; Color=(; palette=:tab10), LineStyle = (; palette = [:solid, :dash, :dot])); + figure=(; + size=(1024, 1024), + title="Using FNO to solve the Burgers equation", + titlesize=25, + ), + axis=(; xlabelsize=25, ylabelsize=25), + legend=(; label=L"u(x)", position=:bottom, labelsize=20), +) +``` diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index 78a1552..ab60f2b 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -1,20 +1,14 @@ module NeuralOperators -using ArgCheck: @argcheck -using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable +using AbstractFFTs: rfft, irfft using ConcreteStructs: @concrete -using FFTW: FFTW, irfft, rfft using Random: Random, AbstractRNG -using Static: StaticBool, False, True, known, static -using Lux -using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer -using LuxLib: batched_matmul -using 
MLDataDevices: AbstractDevice, AbstractGPUDevice -using NNlib: NNlib - -const BoolLike = Union{Bool, StaticBool, Val{true}, Val{false}} -const CRC = ChainRulesCore +using Lux: Lux, Chain, Dense, Conv, Parallel, NoOpLayer, WrappedFunction +using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer +using LuxLib: fast_activation!! +using NNlib: NNlib, batched_mul, pad_constant, gelu +using WeightInitializers: glorot_uniform include("utils.jl") @@ -27,6 +21,7 @@ include("models/nomad.jl") export FourierTransform export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel + export FourierNeuralOperator export DeepONet export NOMAD diff --git a/src/layers.jl b/src/layers.jl index 218cdaa..dfb63df 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -1,33 +1,28 @@ """ - OperatorConv(ch::Pair{<:Integer, <:Integer}, modes::Dims, - ::Type{<:AbstractTransform}; init_weight=glorot_uniform, - permuted=Val(false)) + OperatorConv( + ch::Pair{<:Integer, <:Integer}, modes::Dims, tr::AbstractTransform; + init_weight=glorot_uniform + ) ## Arguments - `ch`: A `Pair` of input and output channel size `ch_in => ch_out`, e.g. `64 => 64`. - `modes`: The modes to be preserved. A tuple of length `d`, where `d` is the dimension of data. - - `::Type{TR}`: The transform to operate the transformation. + - `tr`: The transform to operate the transformation. ## Keyword Arguments - `init_weight`: Initial function to initialize parameters. - - `permuted`: Whether the dim is permuted. If `permuted = Val(false)`, the layer accepts - data in the order of `(ch, x_1, ... , x_d, batch)`. Otherwise the order is - `(x_1, ... , x_d, ch, batch)`. 
## Example ```jldoctest -julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32}); - -julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32}; permuted=Val(true)); +julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32}((16,))); ``` """ @concrete struct OperatorConv <: AbstractLuxLayer - perm <: StaticBool in_chs::Int out_chs::Int prod_modes::Int @@ -35,17 +30,14 @@ julia> OperatorConv(2 => 5, (16,), FourierTransform{ComplexF32}; permuted=Val(tr init_weight end -function Base.show(io::IO, layer::OperatorConv) - print(io, "OperatorConv($(layer.in_chs) => $(layer.out_chs), $(layer.tform.modes), \ - $(printable_type(layer.tform)); permuted = $(layer.perm))") -end - function LuxCore.initialparameters(rng::AbstractRNG, layer::OperatorConv) in_chs, out_chs = layer.in_chs, layer.out_chs scale = real(one(eltype(layer.tform))) / (in_chs * out_chs) return (; weight=scale * layer.init_weight( - rng, eltype(layer.tform), out_chs, in_chs, layer.prod_modes)) + rng, eltype(layer.tform), out_chs, in_chs, layer.prod_modes + ) + ) end function LuxCore.parameterlength(layer::OperatorConv) @@ -53,31 +45,27 @@ function LuxCore.parameterlength(layer::OperatorConv) end function OperatorConv( - ch::Pair{<:Integer, <:Integer}, modes::Dims, ::Type{TR}; init_weight=glorot_uniform, - permuted::BoolLike=False()) where {TR <: AbstractTransform{<:Number}} - return OperatorConv(static(permuted), ch..., prod(modes), TR(modes), init_weight) + ch::Pair{<:Integer,<:Integer}, + modes::Dims, + tform::AbstractTransform; + init_weight=glorot_uniform, +) + return OperatorConv(ch..., prod(modes), tform, init_weight) end -function (conv::OperatorConv{True})(x::AbstractArray, ps, st) +function (conv::OperatorConv)(x::AbstractArray{T,N}, ps, st) where {T,N} return operator_conv(x, conv.tform, ps.weight), st end -function (conv::OperatorConv{False})(x::AbstractArray, ps, st) - N = ndims(conv.tform) - xᵀ = permutedims(x, (ntuple(i -> i + 1, N)..., 1, N + 2)) - yᵀ = operator_conv(xᵀ, 
conv.tform, ps.weight) - y = permutedims(yᵀ, (N + 1, 1:N..., N + 2)) - return y, st -end - function operator_conv(x, tform::AbstractTransform, weights) x_t = transform(tform, x) x_tr = truncate_modes(tform, x_t) x_p = apply_pattern(x_tr, weights) pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)] - x_padded = NNlib.pad_constant(x_p, expand_pad_dims(pad_dims), false; - dims=ntuple(identity, ndims(x_p) - 2))::typeof(x_p) + x_padded = pad_constant( + x_p, expand_pad_dims(pad_dims), false; dims=ntuple(identity, ndims(x_p) - 2) + ) return inverse(tform, x_padded, size(x)) end @@ -93,40 +81,32 @@ Construct a `OperatorConv` with `FourierTransform{ComplexF32}` as the transform. ```jldoctest julia> SpectralConv(2 => 5, (16,)); -julia> SpectralConv(2 => 5, (16,); permuted=Val(true)); - ``` """ -function SpectralConv(args...; kwargs...) - return OperatorConv(args..., FourierTransform{ComplexF32}; kwargs...) +function SpectralConv(ch::Pair{<:Integer,<:Integer}, modes::Dims; kwargs...) + return OperatorConv(ch, modes, FourierTransform{ComplexF32}(modes); kwargs...) end """ - OperatorKernel(ch::Pair{<:Integer, <:Integer}, modes::Dims, transform::Type{TR}, - act::A=identity; permuted=Val(false), kwargs...) where {TR <: AbstractTransform, A} + OperatorKernel( + ch::Pair{<:Integer, <:Integer}, modes::Dims, transform::AbstractTransform, + act=identity; kwargs... + ) ## Arguments - `ch`: A `Pair` of input and output channel size `ch_in => ch_out`, e.g. `64 => 64`. - `modes`: The modes to be preserved. A tuple of length `d`, where `d` is the dimension of data. - - `::Type{TR}`: The transform to operate the transformation. - -## Keyword Arguments - - - `σ`: Activation function. - - `permuted`: Whether the dim is permuted. If `permuted = Val(true)`, the layer accepts - data in the order of `(ch, x_1, ... , x_d , batch)`. Otherwise the order is - `(x_1, ... , x_d, ch, batch)`. + - `transform`: The transform to operate the transformation. + - `act`: Activation function. 
All the keyword arguments are passed to the [`OperatorConv`](@ref) constructor. ## Example ```jldoctest -julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}); - -julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}; permuted=Val(true)); +julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}((16,))); ``` """ @@ -134,14 +114,20 @@ julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}; permuted=Val( layer end -OperatorKernel(lin, conv) = OperatorKernel(lin, conv, identity) - function OperatorKernel( - ch::Pair{<:Integer, <:Integer}, modes::Dims{N}, transform::Type{TR}, act=identity; - permuted::BoolLike=False(), kwargs...) where {N, TR <: AbstractTransform{<:Number}} - lin = known(static(permuted)) ? Conv(ntuple(one, N), ch) : Dense(ch) - conv = OperatorConv(ch, modes, transform; permuted, kwargs...) - return OperatorKernel(Parallel(Fix1(add_act, act), lin, conv)) + ch::Pair{<:Integer,<:Integer}, + modes::Dims{N}, + transform::AbstractTransform, + act=identity; + kwargs..., +) where {N} + return OperatorKernel( + Parallel( + Fix1(add_act, act), + Conv(ntuple(one, N), ch), + OperatorConv(ch, modes, transform; kwargs...), + ), + ) end """ @@ -155,11 +141,8 @@ Construct a `OperatorKernel` with `FourierTransform{ComplexF32}` as the transfor ```jldoctest julia> SpectralKernel(2 => 5, (16,)); -julia> SpectralKernel(2 => 5, (16,); permuted=Val(true)); - ``` """ -function SpectralKernel( - ch::Pair{<:Integer, <:Integer}, modes::Dims, act=identity; kwargs...) - return OperatorKernel(ch, modes, FourierTransform{ComplexF32}, act; kwargs...) +function SpectralKernel(ch::Pair{<:Integer,<:Integer}, modes::Dims, act=identity; kwargs...) + return OperatorKernel(ch, modes, FourierTransform{ComplexF32}(modes), act; kwargs...) 
end diff --git a/src/models/deeponet.jl b/src/models/deeponet.jl index 76e50ff..3a855f3 100644 --- a/src/models/deeponet.jl +++ b/src/models/deeponet.jl @@ -9,11 +9,6 @@ nets output should have the same first dimension. - `branch`: `Lux` network to be used as branch net. - `trunk`: `Lux` network to be used as trunk net. -## Keyword Arguments - - - `additional`: `Lux` network to pass the output of DeepONet, to include additional - operations for embeddings, defaults to `nothing` - ## References [1] Lu Lu, Pengzhan Jin, George Em Karniadakis, "DeepONet: Learning nonlinear operators for @@ -27,9 +22,7 @@ We are given several (b = 200) instances of the IC, discretized at 50 points eac to query the solution for 100 different locations and times [0;1]. That makes the branch input of shape [50 x 200] and the trunk input of shape [2 x 100]. So, -the input for the branch net is 50 and 100 for the trunk net. Note that the inputs must be -batched so the branch input is of shape [50 x 200 x 1] and the trunk input is of shape -[2 x 100 x 1]. +the input for the branch net is 50 and 100 for the trunk net. 
## Example @@ -44,23 +37,30 @@ julia> ps, st = Lux.setup(Xoshiro(), deeponet); julia> u = rand(Float32, 64, 5); -julia> y = rand(Float32, 1, 10, 5); +julia> y = rand(Float32, 1, 10); julia> size(first(deeponet((u, y), ps, st))) (10, 5) ``` """ -@concrete struct DeepONet <: AbstractLuxContainerLayer{(:branch, :trunk, :additional)} - branch - trunk - additional +@concrete struct DeepONet <: AbstractLuxWrapperLayer{:model} + model end -DeepONet(branch, trunk) = DeepONet(branch, trunk, NoOpLayer()) +function DeepONet(branch, trunk) + return DeepONet( + Chain( + Parallel(*; branch=Chain(branch, WrappedFunction(adjoint)), trunk=trunk), + WrappedFunction(adjoint), + ), + ) +end """ - DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16), - branch_activation = identity, trunk_activation = identity) + DeepONet(; + branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16), + branch_activation = identity, trunk_activation = identity + ) Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and `trunk` are same. @@ -71,8 +71,6 @@ Constructs a DeepONet composed of Dense layers. 
Make sure the last node of `bran - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net - `branch_activation`: activation function for branch net - `trunk_activation`: activation function for trunk net - - `additional`: `Lux` network to pass the output of DeepONet, to include additional - operations for embeddings, defaults to `nothing` ## References @@ -89,90 +87,41 @@ julia> ps, st = Lux.setup(Xoshiro(), deeponet); julia> u = rand(Float32, 64, 5); -julia> y = rand(Float32, 1, 10, 5); +julia> y = rand(Float32, 1, 10); julia> size(first(deeponet((u, y), ps, st))) (10, 5) ``` """ function DeepONet(; - branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity, - trunk_activation=identity, additional=NoOpLayer()) + branch=(64, 32, 32, 16), + trunk=(1, 8, 8, 16), + branch_activation=identity, + trunk_activation=identity, +) # checks for last dimension size - @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \ - nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ - work." - - branch_net = Chain([Dense(branch[i] => branch[i + 1], - ifelse(i == length(branch) - 1, identity, branch_activation)) - for i in 1:(length(branch) - 1)]...) - - trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], - ifelse(i == length(trunk) - 1, identity, trunk_activation)) - for i in 1:(length(trunk) - 1)]...) - - return DeepONet(branch_net, trunk_net, additional) -end - -function (deeponet::DeepONet)((x1, x2), ps, st::NamedTuple) - b, st_b = deeponet.branch(x1, ps.branch, st.branch) - t, st_t = deeponet.trunk(x2, ps.trunk, st.trunk) - - @argcheck size(b, 1)==size(t, 1) "Branch and Trunk net must share the same amount of \ - nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ - work." - - additional = deeponet.additional isa NoOpLayer ? 
nothing : - StatefulLuxLayer{true}(deeponet.additional, ps.additional, st.additional) - out = deeponet_project(b, t, additional) - - stₙ = merge((; branch=st_b, trunk=st_t), - deeponet.additional isa NoOpLayer ? (;) : additional.st) - return out, stₙ -end - -function deeponet_project( - b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, ::Nothing) where {T1, T2} - # b [p, nb], t [p, N, nb] - bᵣ = reshape(b, size(b, 1), 1, size(b, 2)) - return dropdims(sum(bᵣ .* t; dims=1); dims=1) # [N, nb] -end - -function deeponet_project( - b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, ::Nothing) where {T1, T2} - # b [p, u, nb], t [p, N, nb] - return batched_matmul(safe_batched_adjoint(b), t) # [u, N, b] -end - -function deeponet_project( - b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, ::Nothing) where {T1, T2, N} - # b [p, u_size..., nb], t [p, N, nb] - bᵣ = reshape(b, size(b, 1), :, size(b, N)) - return reshape(batched_matmul(safe_batched_adjoint(bᵣ), t), - size(b)[2:(N - 1)]..., size(t, 2), size(b, N)) -end - -function deeponet_project( - b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional) where {T1, T2} - # b [p, nb], t [p, N, nb] - bᵣ = reshape(b, size(b, 1), 1, size(b, 2)) - return additional(bᵣ .* t) # [p, N, nb] => [out_dims, N, nb] -end - -function deeponet_project( - b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional) where {T1, T2} - # b [p, u, nb], t [p, N, nb] - bᵣ = reshape(b, size(b, 1), size(b, 2), 1, size(b, 3)) # [p, u, 1, nb] - tᵣ = reshape(t, size(t, 1), 1, size(t)[2:end]...) 
# [p, 1, N, nb] - return additional(bᵣ .* tᵣ) # [p, u, N, nb] => [out_size, u, N, nb] -end - -function deeponet_project( - b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, additional) where {T1, T2, N} - # b [p, u_size..., nb], t [p, N, nb] - bᵣ = reshape(b, size(b, 1), :, 1, size(b, N)) # [p, (u_size...), 1, nb] - tᵣ = reshape(t, size(t, 1), 1, size(t, 2), size(t, 3)) # [p, 1, N, nb] - bᵣtᵣ = reshape(bᵣ .* tᵣ, size(b, 1), size(b)[2:(N - 1)]..., size(t, 2), size(b, N)) - return additional(bᵣtᵣ) # [p, u_size..., N, nb] => [out_size, u_size..., N, nb] + @assert branch[end] == trunk[end] "Branch and Trunk net must share the same amount \ + of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \ + won't work." + + branch_net = Chain( + [ + Dense( + branch[i] => branch[i + 1], + ifelse(i == length(branch) - 1, identity, branch_activation), + ) for i in 1:(length(branch) - 1) + ]..., + ) + + trunk_net = Chain( + [ + Dense( + trunk[i] => trunk[i + 1], + ifelse(i == length(trunk) - 1, identity, trunk_activation), + ) for i in 1:(length(trunk) - 1) + ]..., + ) + + return DeepONet(branch_net, trunk_net) end diff --git a/src/models/fno.jl b/src/models/fno.jl index 0a7304c..5c7f99f 100644 --- a/src/models/fno.jl +++ b/src/models/fno.jl @@ -1,15 +1,19 @@ """ FourierNeuralOperator( - σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,), - permuted::Val{perm}=False, kwargs...) where {C, M, perm} + σ=gelu; + chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), + modes::Dims{M}=(16,), + kwargs... + ) where {C, M} -The Fourier neural operator is a operator learning model that uses a Fourier kernel to perform -spectral convolutions. It is a promising operator for surrogate methods, and can be regarded as -a physics operator. +The Fourier neural operator is a operator learning model that uses a Fourier kernel to +perform spectral convolutions. It is a promising operator for surrogate methods, and can be +regarded as a physics operator. 
The model is composed of a `Dense` layer to lift a `(d + 1)`-dimensional vector field to an `n`-dimensional vector field, an integral kernel operator which consists of four Fourier -kernels, and two `Dense` layers to project data back to the scalar field of the space of interest. +kernels, and two `Dense` layers to project data back to the scalar field of the space of +interest. ## Arguments @@ -19,11 +23,8 @@ kernels, and two `Dense` layers to project data back to the scalar field of the - `chs`: A `Tuple` or `Vector` of the size of each of the 8 channels. - `modes`: The modes to be preserved. A tuple of length `d`, where `d` is the dimension - of data. For example, one-dimensional data would have a 1-element tuple, and two-dimensional data - would have a 2-element tuple. - - `permuted`: Whether the dim is permuted. If `permuted = Val(false)`, the layer accepts - data in the order of `(ch, x_1, ... , x_d , batch)`. Otherwise the order is - `(x_1, ... , x_d, ch, batch)`. + of data. For example, one-dimensional data would have a 1-element tuple, and + two-dimensional data would have a 2-element tuple. ## Example @@ -32,33 +33,34 @@ julia> fno = FourierNeuralOperator(gelu; chs=(2, 64, 64, 128, 1), modes=(16,)); julia> ps, st = Lux.setup(Xoshiro(), fno); -julia> u = rand(Float32, 2, 1024, 5); +julia> u = rand(Float32, 1024, 2, 5); julia> size(first(fno(u, ps, st))) -(1, 1024, 5) +(1024, 1, 5) ``` """ @concrete struct FourierNeuralOperator <: AbstractLuxWrapperLayer{:model} - model <: Chain + model <: AbstractLuxLayer end -function FourierNeuralOperator(σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), - modes::Dims{M}=(16,), permuted::BoolLike=False(), kwargs...) where {C, M} - @argcheck length(chs) ≥ 5 - - map₁ = chs[1] => chs[2] - map₂ = chs[C - 2] => chs[C - 1] - map₃ = chs[C - 1] => chs[C] - - kernel_size = map(Returns(1), modes) - - lifting = known(static(permuted)) ? Conv(kernel_size, map₁) : Dense(map₁) - project = known(static(permuted)) ? 
- Chain(Conv(kernel_size, map₂, σ), Conv(kernel_size, map₃)) : - Chain(Dense(map₂, σ), Dense(map₃)) - - mapping = Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...) - for i in 2:(C - 3)]...) - - return FourierNeuralOperator(Chain(lifting, mapping, project)) +function FourierNeuralOperator( + σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,), kwargs... +) where {C,M} + @assert length(chs) ≥ 5 + + return FourierNeuralOperator( + Chain( + Conv(map(Returns(1), modes), chs[1] => chs[2]), + Chain( + [ + SpectralKernel(chs[i] => chs[i + 1], modes, σ; kwargs...) for + i in 2:(C - 3) + ]..., + ), + Chain( + Conv(map(Returns(1), modes), chs[C - 2] => chs[C - 1], σ), + Conv(map(Returns(1), modes), chs[C - 1] => chs[C]), + ), + ), + ) end diff --git a/src/models/nomad.jl b/src/models/nomad.jl index 0a50d2c..5b671ab 100644 --- a/src/models/nomad.jl +++ b/src/models/nomad.jl @@ -1,5 +1,5 @@ """ - NOMAD(approximator, decoder, concatenate) + NOMAD(approximator, decoder) Constructs a NOMAD from `approximator` and `decoder` architectures. Make sure the output from `approximator` combined with the coordinate dimension has compatible size for input to @@ -10,12 +10,6 @@ from `approximator` combined with the coordinate dimension has compatible size f - `approximator`: `Lux` network to be used as approximator net. - `decoder`: `Lux` network to be used as decoder net. -## Keyword Arguments - - - `concatenate`: function that defines the concatenation of output from `approximator` and - the coordinate dimension, defaults to concatenation along first dimension after - vectorizing the tensors - ## References [1] Jacob H. Seidman and Georgios Kissas and Paris Perdikaris and George J. 
Pappas, "NOMAD: @@ -40,15 +34,19 @@ julia> size(first(nomad((u, y), ps, st))) (8, 5) ``` """ -@concrete struct NOMAD <: AbstractLuxContainerLayer{(:approximator, :decoder)} - approximator - decoder - concatenate <: Function +@concrete struct NOMAD <: AbstractLuxWrapperLayer{:model} + model +end + +function NOMAD(approximator, decoder) + return NOMAD(Chain(; approximator=Parallel(vcat, approximator, NoOpLayer()), decoder)) end """ - NOMAD(; approximator = (8, 32, 32, 16), decoder = (18, 16, 8, 8), - approximator_activation = identity, decoder_activation = identity) + NOMAD(; + approximator = (8, 32, 32, 16), decoder = (18, 16, 8, 8), + approximator_activation = identity, decoder_activation = identity + ) Constructs a NOMAD composed of Dense layers. Make sure that last node of `approximator` + coordinate length = first node of `decoder`. @@ -61,9 +59,6 @@ coordinate length = first node of `decoder`. net - `approximator_activation`: activation function for approximator net - `decoder_activation`: activation function for decoder net - - `concatenate`: function that defines the concatenation of output from `approximator` and - the coordinate dimension, defaults to concatenation along first dimension after - vectorizing the tensors ## References @@ -85,32 +80,25 @@ julia> size(first(nomad((u, y), ps, st))) (8, 5) ``` """ -function NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), - approximator_activation=identity, - decoder_activation=identity, concatenate=nomad_concatenate) - approximator_net = Chain([Dense(approximator[i] => approximator[i + 1], - approximator_activation) - for i in 1:(length(approximator) - 1)]...) - - decoder_net = Chain([Dense(decoder[i] => decoder[i + 1], decoder_activation) - for i in 1:(length(decoder) - 1)]...) 
- - return NOMAD(approximator_net, decoder_net, concatenate) -end - -function (nomad::NOMAD)(x, ps, st::NamedTuple) - a, st_a = nomad.approximator(x[1], ps.approximator, st.approximator) - out, st_d = nomad.decoder(nomad.concatenate(a, x[2]), ps.decoder, st.decoder) - return out, (approximator=st_a, decoder=st_d) -end - -function NOMAD(approximator_net, decoder_net; concatenate=nomad_concatenate) - NOMAD(approximator_net, decoder_net, concatenate) -end - -batch_vectorize(x::AbstractArray) = reshape(x, :, size(x, ndims(x))) - -nomad_concatenate(x::AbstractMatrix, y::AbstractMatrix) = cat(x, y; dims=1) -function nomad_concatenate(x::AbstractArray, y::AbstractArray) - return nomad_concatenate(batch_vectorize(x), batch_vectorize(y)) +function NOMAD(; + approximator=(8, 32, 32, 16), + decoder=(18, 16, 8, 8), + approximator_activation=identity, + decoder_activation=identity, +) + approximator_net = Chain( + [ + Dense(approximator[i] => approximator[i + 1], approximator_activation) for + i in 1:(length(approximator) - 1) + ]..., + ) + + decoder_net = Chain( + [ + Dense(decoder[i] => decoder[i + 1], decoder_activation) for + i in 1:(length(decoder) - 1) + ]..., + ) + + return NOMAD(approximator_net, decoder_net) end diff --git a/src/transform.jl b/src/transform.jl index e683936..9f729e8 100644 --- a/src/transform.jl +++ b/src/transform.jl @@ -14,8 +14,15 @@ abstract type AbstractTransform{T} end Base.eltype(::Type{<:AbstractTransform{T}}) where {T} = T -printable_type(T::AbstractTransform) = "$(nameof(typeof(T))){$(eltype(T))}" +function transform end +function truncate_modes end +function inverse end +""" + FourierTransform{T}(modes) + +A concrete implementation of `AbstractTransform` for Fourier transforms. 
+""" @concrete struct FourierTransform{T} <: AbstractTransform{T} modes end @@ -25,12 +32,13 @@ Base.ndims(T::FourierTransform) = length(T.modes) transform(ft::FourierTransform, x::AbstractArray) = rfft(x, 1:ndims(ft)) function low_pass(ft::FourierTransform, x_fft::AbstractArray) - return view(x_fft,(map(d -> 1:d, ft.modes)...),:,:) + return view(x_fft, map(d -> 1:d, ft.modes)..., :, :) end truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft) function inverse( - ft::FourierTransform, x_fft::AbstractArray{T, N}, M::NTuple{N, Int64}) where {T, N} + ft::FourierTransform, x_fft::AbstractArray{T,N}, M::NTuple{N,Int64} +) where {T,N} return real(irfft(x_fft, first(M), 1:ndims(ft))) end diff --git a/src/utils.jl b/src/utils.jl index 459a1b4..7b4bb3b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,18 +1,18 @@ function apply_pattern( - x_tr::AbstractArray{T1, N}, weights::AbstractArray{T2, 3}) where {T1, T2, N} + x_tr::AbstractArray{T1,N}, weights::AbstractArray{T2,3} +) where {T1,T2,N} x_size = size(x_tr) x_flat = reshape(x_tr, :, x_size[N - 1], x_size[N]) x_flat_t = permutedims(x_flat, (2, 3, 1)) # i x b x m - x_weighted = permutedims(batched_matmul(weights, x_flat_t), (3, 1, 2)) # m x o x b + x_weighted = permutedims(batched_mul(weights, x_flat_t), (3, 1, 2)) # m x o x b return reshape(x_weighted, x_size[1:(N - 2)]..., size(x_weighted)[2:3]...) end function add_act(act::F, x1, x2) where {F} y = x1 .+ x2 - act = NNlib.fast_act(act, y) - return fast_activation!!(act, y) + return fast_activation!!(NNlib.fast_act(act, y), y) end @concrete struct Fix1 <: Function @@ -27,27 +27,3 @@ Base.show(io::IO, f::Fix1) = print(io, "Fix1($(f.f), $(f.x))") function expand_pad_dims(pad_dims::Dims{N}) where {N} return ntuple(i -> isodd(i) ? 0 : pad_dims[i ÷ 2], 2N) end - -@non_differentiable expand_pad_dims(::Any) - -# Handling Wrapper Types are hard. 
Make sure to not construct a ReshapedArray of -# BatchedAdjoint -safe_batched_adjoint(x::AbstractArray) = NNlib.batched_adjoint(x) - -function CRC.rrule(::typeof(safe_batched_adjoint), x::AbstractArray) - return safe_batched_adjoint(x), ∇safe_batched_adjoint -end - -∇safe_batched_adjoint(Δ) = NoTangent(), safe_batched_adjoint(Δ) -function ∇safe_batched_adjoint(Δ::AbstractArray{T, 3}) where {T} - return ∇safe_batched_adjoint(get_device_type(Δ), Δ) -end - -function ∇safe_batched_adjoint(::Type{<:AbstractDevice}, Δ::AbstractArray{T, 3}) where {T} - return NoTangent(), safe_batched_adjoint(Δ) -end - -function ∇safe_batched_adjoint( - ::Type{<:AbstractGPUDevice}, Δ::AbstractArray{T, 3}) where {T} - return NoTangent(), stack(adjoint, eachslice(Δ; dims=3)) -end diff --git a/test/Project.toml b/test/Project.toml index 8890976..02ee99b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,19 +1,18 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" -MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -22,20 +21,19 @@ 
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Aqua = "0.8.7" Documenter = "1.5.0" +Enzyme = "0.13.48" ExplicitImports = "1.9.0" -Hwloc = "3.2.0" -InteractiveUtils = "<0.0.1, 1" +FFTW = "1.9.0" +Hwloc = "3.2" Lux = "1" LuxCore = "1" LuxLib = "1.2" LuxTestUtils = "1.1.2" -MLDataDevices = "1" -Optimisers = "0.3.3" -Pkg = "1.10" -Preferences = "1" +Optimisers = "0.4" Random = "1.10" ReTestItems = "1.24.0" +Reactant = "0.2.127" Reexport = "1.2.2" StableRNGs = "1.0.2" Test = "1.10" -Zygote = "0.6.70" +Zygote = "0.7" diff --git a/test/deeponet_tests.jl b/test/deeponet_tests.jl index c67d31d..ae9608d 100644 --- a/test/deeponet_tests.jl +++ b/test/deeponet_tests.jl @@ -1,71 +1,54 @@ -@testitem "DeepONet" setup=[SharedTestSetup] begin - @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = StableRNG(12345) - - setups = [ - (u_size=(64, 5), y_size=(1, 10, 5), out_size=(10, 5), - branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar"), - (u_size=(64, 1, 5), y_size=(1, 10, 5), out_size=(1, 10, 5), - branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar II"), - (u_size=(64, 3, 5), y_size=(4, 10, 5), out_size=(3, 10, 5), - branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Vector"), - (u_size=(64, 4, 3, 3, 5), y_size=(4, 10, 5), out_size=(4, 3, 3, 10, 5), - branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Tensor") - ] - - @testset "$(setup.name)" for setup in setups - u = rand(Float32, setup.u_size...) |> aType - y = rand(Float32, setup.y_size...) 
|> aType - deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk) - - ps, st = Lux.setup(rng, deeponet) |> dev - @inferred first(deeponet((u, y), ps, st)) - @jet first(deeponet((u, y), ps, st)) - - pred = first(deeponet((u, y), ps, st)) - @test setup.out_size == size(pred) - end - - setups = [ - (u_size=(64, 5), y_size=(1, 10, 5), out_size=(4, 10, 5), - branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), - additional=Dense(16 => 4), name="Scalar"), - (u_size=(64, 1, 5), y_size=(1, 10, 5), out_size=(4, 1, 10, 5), - branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), - additional=Dense(16 => 4), name="Scalar II"), - (u_size=(64, 3, 5), y_size=(8, 10, 5), out_size=(4, 3, 10, 5), - branch=(64, 32, 32, 16), trunk=(8, 8, 8, 16), - additional=Dense(16 => 4), name="Vector") - ] - - @testset "Additional layer: $(setup.name)" for setup in setups - u = rand(Float32, setup.u_size...) |> aType - y = rand(Float32, setup.y_size...) |> aType - deeponet = DeepONet(; - branch=setup.branch, trunk=setup.trunk, additional=setup.additional) - - ps, st = Lux.setup(rng, deeponet) |> dev - @inferred first(deeponet((u, y), ps, st)) - @jet first(deeponet((u, y), ps, st)) - - pred = first(deeponet((u, y), ps, st)) - @test setup.out_size == size(pred) - - __f = (u, y, ps) -> sum(abs2, first(deeponet((u, y), ps, st))) - @test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3) - end - - @testset "Embedding layer mismatch" begin - u = rand(Float32, 64, 5) |> aType - y = rand(Float32, 1, 10, 5) |> aType - - deeponet = DeepONet( - Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)), - Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)) - ) - - ps, st = Lux.setup(rng, deeponet) |> dev - @test_throws ArgumentError deeponet((u, y), ps, st) +@testitem "DeepONet" setup = [SharedTestSetup] begin + rng = StableRNG(12345) + + setups = [ + ( + u_size=(64, 5), + y_size=(1, 10), + out_size=(10, 5), + branch=(64, 32, 32, 16), + trunk=(1, 8, 8, 16), + name="Scalar", + ), + ( + u_size=(64, 5), + y_size=(4, 10), 
+ out_size=(10, 5), + branch=(64, 32, 32, 16), + trunk=(4, 8, 8, 16), + name="Vector", + ), + ] + + xdev = reactant_device() + + @testset "$(setup.name)" for setup in setups + u = rand(Float32, setup.u_size...) + y = rand(Float32, setup.y_size...) + deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk) + + ps, st = Lux.setup(rng, deeponet) + + pred = first(deeponet((u, y), ps, st)) + @test setup.out_size == size(pred) + + ps_ra, st_ra = (ps, st) |> xdev + u_ra, y_ra = (u, y) |> xdev + + @testset "check gradients" begin + ∂u_zyg, ∂ps_zyg = zygote_gradient(deeponet, (u, y), ps, st) + + ∂u_ra, ∂ps_ra = Reactant.with_config(; + dot_general_precision=PrecisionConfig.HIGH, + convolution_precision=PrecisionConfig.HIGH, + ) do + @jit enzyme_gradient(deeponet, (u_ra, y_ra), ps_ra, st_ra) + end + ∂u_ra, ∂ps_ra = (∂u_ra, ∂ps_ra) |> cpu_device() + + @test ∂u_zyg[1] ≈ ∂u_ra[1] atol = 1.0f-2 rtol = 1.0f-2 + @test ∂u_zyg[2] ≈ ∂u_ra[2] atol = 1.0f-2 rtol = 1.0f-2 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) end end end diff --git a/test/fno_tests.jl b/test/fno_tests.jl index cf586ca..0b89130 100644 --- a/test/fno_tests.jl +++ b/test/fno_tests.jl @@ -1,39 +1,48 @@ -@testitem "Fourier Neural Operator" setup=[SharedTestSetup] begin - @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = StableRNG(12345) - - setups = [ - (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1), - x_size=(2, 1024, 5), y_size=(1, 1024, 5), permuted=Val(false)), - (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1), - x_size=(1024, 2, 5), y_size=(1024, 1, 5), permuted=Val(true)) - ] - - @testset "$(length(setup.modes))D: permuted = $(setup.permuted)" for setup in setups - fno = FourierNeuralOperator(; setup.chs, setup.modes, setup.permuted) - display(fno) - ps, st = Lux.setup(rng, fno) |> dev - - x = rand(rng, Float32, setup.x_size...) |> aType - y = rand(rng, Float32, setup.y_size...) 
|> aType - - @inferred fno(x, ps, st) - @jet fno(x, ps, st) - - @test size(first(fno(x, ps, st))) == setup.y_size - - data = [(x, y)] - @test begin - l2, l1 = train!(fno, ps, st, data; epochs=10) - l2 < l1 +@testitem "Fourier Neural Operator" setup = [SharedTestSetup] begin + rng = StableRNG(12345) + + setups = [ + ( + modes=(16,), + chs=(2, 64, 64, 64, 64, 64, 128, 1), + x_size=(1024, 2, 5), + y_size=(1024, 1, 5), + ), + ] + + @testset "$(length(setup.modes))D" for setup in setups + fno = FourierNeuralOperator(; setup.chs, setup.modes) + display(fno) + ps, st = Lux.setup(rng, fno) + + x = rand(rng, Float32, setup.x_size...) + y = rand(rng, Float32, setup.y_size...) + + @test size(first(fno(x, ps, st))) == setup.y_size + + ps_ra, st_ra = (ps, st) |> reactant_device() + x_ra, y_ra = (x, y) |> reactant_device() + + @test begin + l2, l1 = train!( + MSELoss(), AutoEnzyme(), fno, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 + ) + l2 < l1 + end + + @testset "check gradients" begin + ∂x_zyg, ∂ps_zyg = zygote_gradient(fno, x, ps, st) + + ∂x_ra, ∂ps_ra = Reactant.with_config(; + dot_general_precision=PrecisionConfig.HIGH, + convolution_precision=PrecisionConfig.HIGH, + ) do + @jit enzyme_gradient(fno, x_ra, ps_ra, st_ra) end + ∂x_ra, ∂ps_ra = (∂x_ra, ∂ps_ra) |> cpu_device() - __f = (x, ps) -> sum(abs2, first(fno(x, ps, st))) - @test_gradients(__f, x, - ps; - atol=1.0f-3, - rtol=1.0f-3, - skip_backends=[AutoTracker(), AutoEnzyme(), AutoReverseDiff()]) + @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-2 rtol = 1.0f-2 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) end end end diff --git a/test/layers_tests.jl b/test/layers_tests.jl index 0be6931..56abd57 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -1,49 +1,51 @@ -@testitem "SpectralConv & SpectralKernel" setup=[SharedTestSetup] begin - @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = StableRNG(12345) - - opconv = [SpectralConv, SpectralKernel] - setups = [ - (; m=(16,), 
permuted=Val(false), x_size=(2, 1024, 5), y_size=(128, 1024, 5)), - (; m=(16,), permuted=Val(true), x_size=(1024, 2, 5), y_size=(1024, 128, 5)), - (; m=(10, 10), permuted=Val(false), - x_size=(1, 22, 22, 5), y_size=(64, 22, 22, 5)), - (; m=(10, 10), permuted=Val(true), - x_size=(22, 22, 1, 5), y_size=(22, 22, 64, 5)) - ] - - @testset "$(op) $(length(setup.m))D: permuted = $(setup.permuted)" for setup in - setups, - op in opconv - - p = Lux.Utils.unwrap_val(setup.permuted) - in_chs = ifelse(p, setup.x_size[end - 1], first(setup.x_size)) - out_chs = ifelse(p, setup.y_size[end - 1], first(setup.y_size)) - ch = 64 => out_chs - - l1 = p ? Conv(ntuple(_ -> 1, length(setup.m)), in_chs => first(ch)) : - Dense(in_chs => first(ch)) - m = Chain(l1, op(ch, setup.m; setup.permuted)) - display(m) - ps, st = Lux.setup(rng, m) |> dev - - x = rand(rng, Float32, setup.x_size...) |> aType - @test size(first(m(x, ps, st))) == setup.y_size - @inferred m(x, ps, st) - @jet m(x, ps, st) - - data = [(x, aType(rand(rng, Float32, setup.y_size...)))] - @test begin - l2, l1 = train!(m, ps, st, data; epochs=10) - l2 < l1 +@testitem "SpectralConv & SpectralKernel" setup = [SharedTestSetup] begin + rng = StableRNG(12345) + + opconv = [SpectralConv, SpectralKernel] + setups = [ + (; m=(16,), x_size=(1024, 2, 5), y_size=(1024, 16, 5)), + (; m=(10, 10), x_size=(22, 22, 1, 5), y_size=(22, 22, 16, 5)), + ] + + rdev = reactant_device() + + @testset "$(op) $(length(setup.m))D" for setup in setups, op in opconv + in_chs = setup.x_size[end - 1] + out_chs = setup.y_size[end - 1] + ch = 4 => out_chs + + l1 = Conv(ntuple(_ -> 1, length(setup.m)), in_chs => first(ch)) + m = Chain(l1, op(ch, setup.m)) + display(m) + ps, st = Lux.setup(rng, m) + + x = rand(rng, Float32, setup.x_size...) 
+ @test size(first(m(x, ps, st))) == setup.y_size + + ps_ra, st_ra = rdev((ps, st)) + x_ra = rdev(x) + y_ra = rdev(rand(rng, Float32, setup.y_size...)) + + @test begin + l2, l1 = train!( + MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 + ) + l2 < l1 + end + + @testset "check gradients" begin + ∂x_zyg, ∂ps_zyg = zygote_gradient(m, x, ps, st) + + ∂x_ra, ∂ps_ra = Reactant.with_config(; + dot_general_precision=PrecisionConfig.HIGH, + convolution_precision=PrecisionConfig.HIGH, + ) do + @jit enzyme_gradient(m, x_ra, ps_ra, st_ra) end + ∂x_ra, ∂ps_ra = (∂x_ra, ∂ps_ra) |> cpu_device() - __f = (x, ps) -> sum(abs2, first(m(x, ps, st))) - @test_gradients(__f, x, - ps; - atol=1.0f-3, - rtol=1.0f-3, - skip_backends=[AutoTracker(), AutoEnzyme(), AutoReverseDiff()]) + @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-2 rtol = 1.0f-2 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) end end end diff --git a/test/nomad_tests.jl b/test/nomad_tests.jl index c371fa4..d55f155 100644 --- a/test/nomad_tests.jl +++ b/test/nomad_tests.jl @@ -1,28 +1,54 @@ -@testitem "NOMAD" setup=[SharedTestSetup] begin - @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = StableRNG(12345) - - setups = [ - (u_size=(1, 5), y_size=(1, 5), out_size=(1, 5), - approximator=(1, 16, 16, 15), decoder=(16, 8, 4, 1), name="Scalar"), - (u_size=(8, 5), y_size=(2, 5), out_size=(8, 5), - approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), name="Vector") - ] - - @testset "$(setup.name)" for setup in setups - u = rand(Float32, setup.u_size...) |> aType - y = rand(Float32, setup.y_size...) 
|> aType - nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder) - - ps, st = Lux.setup(rng, nomad) |> dev - @inferred first(nomad((u, y), ps, st)) - @jet first(nomad((u, y), ps, st)) - - pred = first(nomad((u, y), ps, st)) - @test setup.out_size == size(pred) - - __f = (u, y, ps) -> sum(abs2, first(nomad((u, y), ps, st))) - @test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3) +@testitem "NOMAD" setup = [SharedTestSetup] begin + rng = StableRNG(12345) + + setups = [ + ( + u_size=(1, 5), + y_size=(1, 5), + out_size=(1, 5), + approximator=(1, 16, 16, 15), + decoder=(16, 8, 4, 1), + name="Scalar", + ), + ( + u_size=(8, 5), + y_size=(2, 5), + out_size=(8, 5), + approximator=(8, 32, 32, 16), + decoder=(18, 16, 8, 8), + name="Vector", + ), + ] + + xdev = reactant_device() + + @testset "$(setup.name)" for setup in setups + u = rand(Float32, setup.u_size...) + y = rand(Float32, setup.y_size...) + nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder) + + ps, st = Lux.setup(rng, nomad) + + pred = first(nomad((u, y), ps, st)) + @test setup.out_size == size(pred) + + ps_ra, st_ra = xdev((ps, st)) + u_ra, y_ra = xdev(u), xdev(y) + + @testset "check gradients" begin + ∂u_zyg, ∂ps_zyg = zygote_gradient(nomad, (u, y), ps, st) + + ∂u_ra, ∂ps_ra = Reactant.with_config(; + dot_general_precision=PrecisionConfig.HIGH, + convolution_precision=PrecisionConfig.HIGH, + ) do + @jit enzyme_gradient(nomad, (u_ra, y_ra), ps_ra, st_ra) + end + ∂u_ra, ∂ps_ra = (∂u_ra, ∂ps_ra) |> cpu_device() + + @test ∂u_zyg[1] ≈ ∂u_ra[1] atol = 1.0f-2 rtol = 1.0f-2 + @test ∂u_zyg[2] ≈ ∂u_ra[2] atol = 1.0f-2 rtol = 1.0f-2 + @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) end end end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 7305720..57c8042 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -1,19 +1,23 @@ -@testitem "doctests: Quality Assurance" tags=[:qa] begin +@testitem "doctests: Quality Assurance" tags = [:qa] begin using Documenter, 
NeuralOperators - DocMeta.setdocmeta!(NeuralOperators, :DocTestSetup, - :(using Lux, NeuralOperators, Random); recursive=true) + DocMeta.setdocmeta!( + NeuralOperators, + :DocTestSetup, + :(using Lux, NeuralOperators, Random); + recursive=true, + ) doctest(NeuralOperators; manual=false) end -@testitem "Aqua: Quality Assurance" tags=[:qa] begin +@testitem "Aqua: Quality Assurance" tags = [:qa] begin using Aqua Aqua.test_all(NeuralOperators; ambiguities=false) Aqua.test_ambiguities(NeuralOperators; recursive=false) end -@testitem "Explicit Imports: Quality Assurance" tags=[:qa] begin +@testitem "Explicit Imports: Quality Assurance" tags = [:qa] begin using ExplicitImports, Lux # Skip our own packages @@ -22,8 +26,4 @@ end @test check_no_self_qualified_accesses(NeuralOperators) === nothing @test check_all_explicit_imports_via_owners(NeuralOperators) === nothing @test check_all_qualified_accesses_via_owners(NeuralOperators) === nothing - if VERSION >= v"1.11-" - @test_broken check_all_explicit_imports_are_public(NeuralOperators) === nothing # mostly upstream problems - @test_broken check_all_qualified_accesses_are_public(NeuralOperators) === nothing # mostly upstream problems - end end diff --git a/test/runtests.jl b/test/runtests.jl index 1987473..c777622 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,31 +1,16 @@ -using Preferences - -Preferences.set_preferences!("LuxLib", "instability_check" => "error") -Preferences.set_preferences!("LuxCore", "instability_check" => "error") - -using ReTestItems, Pkg, Test, InteractiveUtils, Hwloc, NeuralOperators +using ReTestItems, Test, Hwloc, NeuralOperators, Reactant const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all")) -const EXTRA_PKGS = String[] -(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "LuxCUDA") -(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") - -if !isempty(EXTRA_PKGS) - @info "Installing Extra Packages for testing" 
EXTRA_PKGS=EXTRA_PKGS - Pkg.add(EXTRA_PKGS) - Pkg.update() - Base.retry_load_extensions() - Pkg.instantiate() -end - -const RETESTITEMS_NWORKERS = parse( - Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16)))) -const RETESTITEMS_NWORKER_THREADS = parse(Int, - get(ENV, "RETESTITEMS_NWORKER_THREADS", - string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1)))) +const RETESTITEMS_NWORKER_THREADS = parse( + Int, get(ENV, "RETESTITEMS_NWORKER_THREADS", string(Hwloc.num_virtual_cores())) +) @testset "NeuralOperators.jl Tests" begin - ReTestItems.runtests(NeuralOperators; nworkers=RETESTITEMS_NWORKERS, - nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) + ReTestItems.runtests( + NeuralOperators; + nworkers=1, + nworker_threads=RETESTITEMS_NWORKER_THREADS, + testitem_timeout=3600, + ) end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index 6dcb2bf..fd16b02 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -1,56 +1,39 @@ @testsetup module SharedTestSetup import Reexport: @reexport -@reexport using Lux, Zygote, Optimisers, Random, StableRNGs, LuxTestUtils -using MLDataDevices - -LuxTestUtils.jet_target_modules!(["NeuralOperators", "Lux", "LuxLib"]) +@reexport using Lux, Zygote, Optimisers, Random, StableRNGs, Reactant, Enzyme +using LuxTestUtils: check_approx +using FFTW const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) -if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" - using LuxCUDA -end - -if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" - using AMDGPU -end - -cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" -function cuda_testing() - return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && - MLDataDevices.functional(CUDADevice) -end -function amdgpu_testing() - return (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && - MLDataDevices.functional(AMDGPUDevice) -end - -const MODES = begin - modes = [] - cpu_testing() && 
push!(modes, ("CPU", Array, CPUDevice(), false)) - cuda_testing() && push!(modes, ("CUDA", CuArray, CUDADevice(), true)) - amdgpu_testing() && push!(modes, ("AMDGPU", ROCArray, AMDGPUDevice(), true)) - modes -end - train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...) function train!(loss, backend, model, ps, st, data; epochs=10) - l1 = loss(model, ps, st, first(data)) + l1 = @jit loss(model, ps, st, first(data)) tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) for _ in 1:epochs, (x, y) in data - _, _, _, tstate = Training.single_train_step!(backend, loss, (x, y), tstate) end - l2 = loss(model, ps, st, first(data)) + l2 = @jit loss(model, tstate.parameters, tstate.states, first(data)) return l2, l1 end +sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st))) + +function zygote_gradient(model, x, ps, st) + return Zygote.gradient(sumabs2first, model, x, ps, st)[2:3] +end + +function enzyme_gradient(model, x, ps, st) + return Enzyme.gradient(Reverse, sumabs2first, Const(model), x, ps, Const(st))[2:3] +end + export check_approx -export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, train! +export BACKEND_GROUP, train! 
+export zygote_gradient, enzyme_gradient end diff --git a/test/utils_tests.jl b/test/utils_tests.jl deleted file mode 100644 index 801e4a9..0000000 --- a/test/utils_tests.jl +++ /dev/null @@ -1,58 +0,0 @@ -@testitem "utils" setup=[SharedTestSetup] begin - import NeuralOperators: deeponet_project, nomad_concatenate, batch_vectorize - - @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES - rng = StableRNG(12345) - - setups = [ - (b_size=(16, 5), t_size=(16, 10, 5), out_size=(10, 5), - additional=NoOpLayer(), name="Scalar"), - (b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(1, 10, 5), - additional=NoOpLayer(), name="Scalar II"), - (b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(3, 10, 5), - additional=NoOpLayer(), name="Vector"), - (b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5), - out_size=(4, 3, 3, 10, 5), additional=NoOpLayer(), name="Tensor"), - (b_size=(16, 5), t_size=(16, 10, 5), out_size=(4, 10, 5), - additional=Dense(16 => 4), name="additional : Scalar"), - (b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(4, 1, 10, 5), - additional=Dense(16 => 4), name="additional : Scalar II"), - (b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(4, 3, 10, 5), - additional=Dense(16 => 4), name="additional : Vector"), - (b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5), out_size=(3, 4, 3, 4, 10, 5), - additional=Chain(Dense(16 => 4), ReshapeLayer((3, 4, 3, 4, 10))), - name="additional : Tensor") - ] - - @testset "project : $(setup.name)" for setup in setups - b = rand(Float32, setup.b_size...) |> aType - t = rand(Float32, setup.t_size...) |> aType - - ps, st = Lux.setup(rng, setup.additional) |> dev - additional = setup.additional isa NoOpLayer ? 
nothing : - StatefulLuxLayer{true}(setup.additional, ps, st) - - @inferred deeponet_project(b, t, additional) - @jet deeponet_project(b, t, additional) - @test setup.out_size == size(deeponet_project(b, t, additional)) - end - - setups = [(x_size=(6, 5), y_size=(4, 5), out_size=(10, 5), name="Scalar"), - (x_size=(12, 5), y_size=(8, 5), out_size=(20, 5), name="Vector I"), - (x_size=(4, 6, 5), y_size=(6, 5), out_size=(30, 5), name="Vector II"), - (x_size=(4, 2, 3, 5), y_size=(2, 2, 3, 5), out_size=(36, 5), name="Tensor")] - - @testset "nomad_concatenate $(setup.name)" for setup in setups - x_size = rand(Float32, setup.x_size...) |> aType - y_size = rand(Float32, setup.y_size...) |> aType - - @test setup.out_size == size(nomad_concatenate(x_size, y_size)) - end - - @testset "batch vectorize" begin - x_size = (4, 2, 3) - x = rand(Float32, x_size..., 5) |> aType - @test size(batch_vectorize(x)) == (prod(x_size), 5) - end - end -end