From e7c12f681d2dc8d5c11e6e62faad4bbab21812f1 Mon Sep 17 00:00:00 2001 From: Yi-Te Huang Date: Tue, 2 Sep 2025 14:31:26 +0800 Subject: [PATCH 1/2] support autodiff for `SteadyStateODESolver` --- src/steadystate.jl | 5 +++-- test/ext-test/cpu/autodiff/autodiff.jl | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/steadystate.jl b/src/steadystate.jl index 49ce2f44c..cd7b8de01 100644 --- a/src/steadystate.jl +++ b/src/steadystate.jl @@ -214,7 +214,8 @@ end function _steadystate(L::AbstractQuantumObject{SuperOperator}, solver::SteadyStateODESolver; kwargs...) ψ0 = isnothing(solver.ψ0) ? rand_ket(L.dimensions) : solver.ψ0 ftype = _float_type(ψ0) - tlist = [ftype(0), ftype(solver.tmax)] + tmax = ftype(solver.tmax) + tlist = [ftype(0), tmax] # overwrite some kwargs and throw warning message to tell the users that we are ignoring these settings haskey(kwargs, :progress_bar) && @warn "Ignore keyword argument 'progress_bar' for SteadyStateODESolver" @@ -222,7 +223,7 @@ function _steadystate(L::AbstractQuantumObject{SuperOperator}, solver::SteadySta haskey(kwargs, :saveat) && @warn "Ignore keyword argument 'saveat' for SteadyStateODESolver" kwargs2 = merge( NamedTuple(kwargs), # we convert to NamedTuple just in case if kwargs is empty - (progress_bar = Val(false), save_everystep = false, saveat = ftype[]), + (progress_bar = Val(false), save_everystep = false, saveat = ftype[tmax]), ) # add terminate condition (callback) diff --git a/test/ext-test/cpu/autodiff/autodiff.jl b/test/ext-test/cpu/autodiff/autodiff.jl index 3f8e601a4..addf532cf 100644 --- a/test/ext-test/cpu/autodiff/autodiff.jl +++ b/test/ext-test/cpu/autodiff/autodiff.jl @@ -67,6 +67,17 @@ function my_f_mesolve(p) return real(expect(a' * a, sol.states[end])) end +function my_f_steadystate(p) + ρss = steadystate( + L, + SteadyStateODESolver(ψ0 = ψ0_mesolve, tmax = tlist_mesolve[end]); + params = p, + sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), + ) + + return real(expect(a' * a, ρss)) +end + # Analytical solution n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2)) @@ -113,8 +124,12 @@ n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2)) my_f_mesolve_direct(params) my_f_mesolve(params) + my_f_steadystate(params) + # calculate exact solution and check if steadystate works grad_exact = Zygote.gradient((p) -> n_ss(p[1], p[2], p[3]), params)[1] + grad_ss = Zygote.gradient(my_f_steadystate, params)[1] + @test grad_ss ≈ grad_exact atol=1e-5 @testset "ForwardDiff.jl" begin grad_qt = ForwardDiff.gradient(my_f_mesolve_direct, params) From d1b186ad1b12ab4d5c4d1c00c1f52096fde44927 Mon Sep 17 00:00:00 2001 From: Yi-Te Huang Date: Tue, 2 Sep 2025 14:54:16 +0800 Subject: [PATCH 2/2] fix typo --- test/ext-test/cpu/autodiff/autodiff.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ext-test/cpu/autodiff/autodiff.jl b/test/ext-test/cpu/autodiff/autodiff.jl index addf532cf..36ecc5294 100644 --- a/test/ext-test/cpu/autodiff/autodiff.jl +++ b/test/ext-test/cpu/autodiff/autodiff.jl @@ -69,8 +69,8 @@ end function my_f_steadystate(p) ρss = steadystate( - L, - SteadyStateODESolver(ψ0 = ψ0_mesolve, tmax = tlist_mesolve[end]); + L; + solver = SteadyStateODESolver(ψ0 = ψ0_mesolve, tmax = tlist_mesolve[end]), params = p, sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), )