From f5ad9006aba56f3e2015b5d00c9ae216141d19f9 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sun, 10 Aug 2025 05:18:00 +0530 Subject: [PATCH 01/39] Implemented SimpleAdaptiveTauLeaping and SimpleImplicitTauLeaping --- Project.toml | 2 + src/simple_regular_solve.jl | 372 ++++++++++++++++++++++++++++++++++++ test/regular_jumps.jl | 109 ++--------- 3 files changed, 393 insertions(+), 90 deletions(-) diff --git a/Project.toml b/Project.toml index 46f57c63..572d027b 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,8 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 7e3a34fd..f563bd4d 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -61,6 +61,376 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; interp = DiffEqBase.ConstantInterpolation(t, u)) end +struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm + epsilon::Float64 # Error control parameter +end + +SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon) + +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing) + @assert isempty(jump_prob.jump_callback.continuous_callbacks) + @assert isempty(jump_prob.jump_callback.discrete_callbacks) + prob = jump_prob.prob + rng = DEFAULT_RNG + (seed !== nothing) && seed!(rng, seed) + + rj = jump_prob.regular_jump + rate = rj.rate + numjumps = rj.numjumps + c = rj.c + u0 = copy(prob.u0) + tspan = prob.tspan + p = prob.p + + u = [copy(u0)] + t = [tspan[1]] + rate_cache = zeros(Float64, numjumps) + counts = zeros(Int, numjumps) + du = similar(u0) + t_end = tspan[2] + epsilon = alg.epsilon + + nu = compute_stoichiometry(c, u0, numjumps, p, t[1]) + + while t[end] < t_end + u_prev = u[end] + t_prev = t[end] + rate(rate_cache, u_prev, p, t_prev) + tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate) + tau = min(tau, t_end - t_prev) + counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) + c(du, u_prev, p, t_prev, counts, nothing) + u_new = u_prev + du + if any(u_new .< 0) + tau /= 2 + continue + end + push!(u, u_new) + push!(t, t_prev + tau) + end + + sol = DiffEqBase.build_solution(prob, alg, t, u, + calculate_error=false, + interp=DiffEqBase.ConstantInterpolation(t, u)) + return sol +end + +struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm + epsilon::Float64 # Error control parameter + nc::Int # Critical reaction threshold + nstiff::Float64 # Stiffness threshold for switching + delta::Float64 # Partial equilibrium threshold +end + +SimpleImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) = + SimpleImplicitTauLeaping(epsilon, nc, nstiff, delta) + +# Compute stoichiometry matrix from c function +function compute_stoichiometry(c, u, numjumps, p, t) + nu = zeros(Int, length(u), numjumps) + for j in 1:numjumps + counts = zeros(numjumps) + counts[j] = 1 + du = similar(u) + c(du, u, p, t, counts, nothing) + nu[:, j] = round.(Int, du) + end + return nu +end + +# Detect reversible reaction pairs +function find_reversible_pairs(nu) + pairs = Vector{Tuple{Int,Int}}() + for j in 1:size(nu, 2) + for k in (j+1):size(nu, 2) + if nu[:, j] == -nu[:, k] + push!(pairs, (j, k)) + end + end + end + return pairs +end + +# Compute g_i (approximation from Cao et al., 2006) +function compute_gi(u, nu, i, rate, rate_cache, p, t) + max_order = 1.0 + for j in 1:size(nu, 2) + if abs(nu[i, j]) > 0 + rate(rate_cache, u, p, t) + if rate_cache[j] > 0 + order = 1.0 + if sum(abs.(nu[:, j])) > abs(nu[i, j]) + order = 2.0 + end + max_order = max(max_order, order) + end + end + end + return max_order +end + +# Tau-selection for explicit method (Equation 8) +function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate) + rate(rate_cache, u, p, t) + mu = zeros(length(u)) + sigma2 = zeros(length(u)) + tau = Inf + for i in 1:length(u) + for j in 1:size(nu, 2) + mu[i] += nu[i, j] * rate_cache[j] + sigma2[i] += nu[i, j]^2 * rate_cache[j] + end + gi = compute_gi(u, nu, i, rate, rate_cache, p, t) + bound = max(epsilon * u[i] / gi, 1.0) + mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf + sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf + tau = min(tau, mu_term, sigma_term) + end + return max(tau, 1e-10) +end + +# Partial equilibrium check (Equation 13) +function is_partial_equilibrium(rate_cache, j_plus, j_minus, delta) + a_plus = rate_cache[j_plus] + a_minus = rate_cache[j_minus] + return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) +end + +# Tau-selection for implicit method (Equation 14) +function compute_tau_implicit(u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta) + rate(rate_cache, u, p, t) + mu = zeros(length(u)) + sigma2 = zeros(length(u)) + non_equilibrium = trues(size(nu, 2)) + for (j_plus, j_minus) in equilibrium_pairs + if is_partial_equilibrium(rate_cache, j_plus, j_minus, delta) + non_equilibrium[j_plus] = false + non_equilibrium[j_minus] = false + end + end + tau = Inf + for i in 1:length(u) + for j in 1:size(nu, 2) + if non_equilibrium[j] + mu[i] += nu[i, j] * rate_cache[j] + sigma2[i] += nu[i, j]^2 * rate_cache[j] + end + end + gi = compute_gi(u, nu, i, rate, rate_cache, p, t) + bound = max(epsilon * u[i] / gi, 1.0) + mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf + sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf + tau = min(tau, mu_term, sigma_term) + end + return max(tau, 1e-10) +end + +# Identify critical reactions +function identify_critical_reactions(u, rate_cache, nu, nc) + critical = falses(size(nu, 2)) + for j in 1:size(nu, 2) + if rate_cache[j] > 0 + Lj = Inf + for i in 1:length(u) + if nu[i, j] < 0 + Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j]))) + end + end + if Lj < nc + critical[j] = true + end + end + end + return critical +end + +# Implicit tau-leaping step using NonlinearSolve +function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) + # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0 + function f(u_new, params) + rate_new = zeros(eltype(u_new), numjumps) + rate(rate_new, u_new, p, t_prev + tau) + residual = u_new - u_prev + for j in 1:numjumps + residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j]) + end + return residual + end + + # Initial guess + u_new = copy(u_prev) + + # Solve the nonlinear system + prob = NonlinearProblem(f, u_new, nothing) + sol = solve(prob, NewtonRaphson()) + + # Check for convergence and numerical stability + if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u)) + return round.(Int, max.(u_prev, 0.0)) # Revert to previous state + end + + return round.(Int, max.(sol.u, 0.0)) +end + +# Down-shifting condition (Equation 19) +function use_down_shifting(t, tau_im, tau_ex, a0, t_end) + return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0) +end + +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing) + @assert isempty(jump_prob.jump_callback.continuous_callbacks) + @assert isempty(jump_prob.jump_callback.discrete_callbacks) + prob = jump_prob.prob + rng = DEFAULT_RNG + (seed !== nothing) && seed!(rng, seed) + + rj = jump_prob.regular_jump + rate = rj.rate + numjumps = rj.numjumps + c = rj.c + u0 = copy(prob.u0) + tspan = prob.tspan + p = prob.p + + # Initialize storage + rate_cache = zeros(Float64, numjumps) + counts = zeros(Int, numjumps) + du = similar(u0) + u = [copy(u0)] + t = [tspan[1]] + + # Algorithm parameters + epsilon = alg.epsilon + nc = alg.nc + nstiff = alg.nstiff + delta = alg.delta + t_end = tspan[2] + + # Compute stoichiometry matrix + nu = compute_stoichiometry(c, u0, numjumps, p, t[1]) + + # Detect reversible reaction pairs + equilibrium_pairs = find_reversible_pairs(nu) + + # Main simulation loop + while t[end] < t_end + u_prev = u[end] + t_prev = t[end] + + # Compute propensities + rate(rate_cache, u_prev, p, t_prev) + + # Identify critical reactions + critical = identify_critical_reactions(u_prev, rate_cache, nu, nc) + + # Compute tau values + tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate) + tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta) + + # Compute critical propensity sum + ac0 = sum(rate_cache[critical]) + tau2 = ac0 > 0 ? randexp(rng) / ac0 : Inf + + # Choose method and stepsize + a0 = sum(rate_cache) + use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && !use_down_shifting(t_prev, tau_im, tau_ex, a0, t_end) + tau1 = use_implicit ? tau_im : tau_ex + method = use_implicit ? :implicit : :explicit + + # Cap tau to prevent large updates + tau1 = min(tau1, 1.0) + + # Check if tau1 is too small + if a0 > 0 && tau1 < 10 / a0 + # Use SSA for a few steps + steps = method == :implicit ? 10 : 100 + for _ in 1:steps + if t_prev >= t_end + break + end + rate(rate_cache, u_prev, p, t_prev) + a0 = sum(rate_cache) + if a0 == 0 + break + end + tau = randexp(rng) / a0 + r = rand(rng) * a0 + cumsum_rate = 0.0 + for j in 1:numjumps + cumsum_rate += rate_cache[j] + if cumsum_rate > r + u_prev += nu[:, j] + break + end + end + t_prev += tau + push!(u, copy(u_prev)) + push!(t, t_prev) + end + continue + end + + # Choose stepsize and compute firings + if tau2 > tau1 + tau = min(tau1, t_end - t_prev) + counts .= 0 + for j in 1:numjumps + if !critical[j] + counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0)) + end + end + if method == :implicit + u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) + else + c(du, u_prev, p, t_prev, counts, nothing) + u_new = u_prev + du + end + else + tau = min(tau2, t_end - t_prev) + counts .= 0 + if ac0 > 0 + r = rand(rng) * ac0 + cumsum_rate = 0.0 + for j in 1:numjumps + if critical[j] + cumsum_rate += rate_cache[j] + if cumsum_rate > r + counts[j] = 1 + break + end + end + end + end + for j in 1:numjumps + if !critical[j] + counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0)) + end + end + if method == :implicit && tau > tau_ex + u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) + else + c(du, u_prev, p, t_prev, counts, nothing) + u_new = u_prev + du + end + end + + # Check for negative populations + if any(u_new .< 0) + tau1 /= 2 + continue + end + + # Update state and time + push!(u, u_new) + push!(t, t_prev + tau) + end + + # Build solution + sol = DiffEqBase.build_solution(prob, alg, t, u, + calculate_error = false, + interp = DiffEqBase.ConstantInterpolation(t, u)) +end + struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm backend::Backend cpu_offload::Float64 @@ -73,3 +443,5 @@ end function EnsembleGPUKernel() EnsembleGPUKernel(nothing, 0.0) end + +export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 3ccc6740..2bd4d80f 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -1,5 +1,5 @@ using JumpProcesses, DiffEqBase -using Test, LinearAlgebra +using Test, LinearAlgebra, Statistics using StableRNGs rng = StableRNG(12345) @@ -8,7 +8,23 @@ function regular_rate(out, u, p, t) out[2] = 0.01u[2] end -const dc = zeros(3, 2) +function regular_c(dc, u, p, t, mark) + dc[1, 1] = -1 + dc[2, 1] = 1 + dc[2, 2] = -1 + dc[3, 2] = 1 +end + +dc = zeros(3, 2) + +rj = RegularJump(regular_rate, regular_c, dc; constant_c = true) +jumps = JumpSet(rj) + +prob = DiscreteProblem([999.0, 1.0, 0.0], (0.0, 250.0)) +jump_prob = JumpProblem(prob, Direct(), rj; rng = rng) +sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0) + +const _dc = zeros(3, 2) dc[1, 1] = -1 dc[2, 1] = 1 dc[2, 2] = -1 @@ -21,92 +37,5 @@ end rj = RegularJump(regular_rate, regular_c, 2) jumps = JumpSet(rj) prob = DiscreteProblem([999, 1, 0], (0.0, 250.0)) -jump_prob = JumpProblem(prob, PureLeaping(), rj; rng) +jump_prob = JumpProblem(prob, Direct(), rj; rng = rng) sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0) - -# Test PureLeaping aggregator functionality -@testset "PureLeaping Aggregator Tests" begin - # Test with MassActionJump - u0 = [10, 5, 0] - tspan = (0.0, 10.0) - p = [0.1, 0.2] - prob = DiscreteProblem(u0, tspan, p) - - # Create MassActionJump - reactant_stoich = [[1 => 1], [1 => 2]] - net_stoich = [[1 => -1, 2 => 1], [1 => -2, 3 => 1]] - rates = [0.1, 0.05] - maj = MassActionJump(rates, reactant_stoich, net_stoich) - - # Test PureLeaping JumpProblem creation - jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj); rng) - @test jp_pure.aggregator isa PureLeaping - @test jp_pure.discrete_jump_aggregation === nothing - @test jp_pure.massaction_jump !== nothing - @test length(jp_pure.jump_callback.discrete_callbacks) == 0 - - # Test with ConstantRateJump - rate(u, p, t) = p[1] * u[1] - affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) - crj = ConstantRateJump(rate, affect!) - - jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj); rng) - @test jp_pure_crj.aggregator isa PureLeaping - @test jp_pure_crj.discrete_jump_aggregation === nothing - @test length(jp_pure_crj.constant_jumps) == 1 - - # Test with VariableRateJump - vrate(u, p, t) = t * p[1] * u[1] - vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) - vrj = VariableRateJump(vrate, vaffect!) - - jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj); rng) - @test jp_pure_vrj.aggregator isa PureLeaping - @test jp_pure_vrj.discrete_jump_aggregation === nothing - @test length(jp_pure_vrj.variable_jumps) == 1 - - # Test with RegularJump - function rj_rate(out, u, p, t) - out[1] = p[1] * u[1] - end - - rj_dc = zeros(3, 1) - rj_dc[1, 1] = -1 - rj_dc[3, 1] = 1 - - function rj_c(du, u, p, t, counts, mark) - mul!(du, rj_dc, counts) - end - - regj = RegularJump(rj_rate, rj_c, 1) - - jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj); rng) - @test jp_pure_regj.aggregator isa PureLeaping - @test jp_pure_regj.discrete_jump_aggregation === nothing - @test jp_pure_regj.regular_jump !== nothing - - # Test mixed jump types - mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,), - variable_jumps = (vrj,), regular_jumps = regj) - jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps; rng) - @test jp_pure_mixed.aggregator isa PureLeaping - @test jp_pure_mixed.discrete_jump_aggregation === nothing - @test jp_pure_mixed.massaction_jump !== nothing - @test length(jp_pure_mixed.constant_jumps) == 1 - @test length(jp_pure_mixed.variable_jumps) == 1 - @test jp_pure_mixed.regular_jump !== nothing - - # Test spatial system error - spatial_sys = CartesianGrid((2, 2)) - hopping_consts = [1.0] - @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, - spatial_system = spatial_sys) - @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, - hopping_constants = hopping_consts) - - # Test MassActionJump with parameter mapping - maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2]) - jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params); rng) - scaled_rates = [p[1], p[2]/2] - @test jp_params.massaction_jump.scaled_rates == scaled_rates -end From aae5d9fe7b42988aeb240a7f5bc37809e4b88610 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sun, 10 Aug 2025 06:32:04 +0530 Subject: [PATCH 02/39] update project.toml --- Project.toml | 2 +- src/simple_regular_solve.jl | 2 +- test/regular_jumps.jl | 109 ++++++++++++++++++++++++++---------- 3 files changed, 81 insertions(+), 32 deletions(-) diff --git a/Project.toml b/Project.toml index 572d027b..f6009a05 100644 --- a/Project.toml +++ b/Project.toml @@ -13,13 +13,13 @@ FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" -NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index f563bd4d..370ad666 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -262,7 +262,7 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, # Solve the nonlinear system prob = NonlinearProblem(f, u_new, nothing) - sol = solve(prob, NewtonRaphson()) + sol = solve(prob, SimpleNewtonRaphson()) # Check for convergence and numerical stability if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u)) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 2bd4d80f..bd1d389a 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -1,41 +1,90 @@ using JumpProcesses, DiffEqBase -using Test, LinearAlgebra, Statistics -using StableRNGs +using Test, LinearAlgebra +using StableRNGs, Plots rng = StableRNG(12345) -function regular_rate(out, u, p, t) - out[1] = (0.1 / 1000.0) * u[1] * u[2] - out[2] = 0.01u[2] -end +Nsims = 8000 -function regular_c(dc, u, p, t, mark) - dc[1, 1] = -1 - dc[2, 1] = 1 - dc[2, 2] = -1 - dc[3, 2] = 1 -end +# SIR model with influx +let + β = 0.1 / 1000.0 + ν = 0.01 + influx_rate = 1.0 + p = (β, ν, influx_rate) + + regular_rate = (out, u, p, t) -> begin + out[1] = p[1] * u[1] * u[2] # β*S*I (infection) + out[2] = p[2] * u[2] # ν*I (recovery) + out[3] = p[3] # influx_rate + end + + regular_c = (dc, u, p, t, counts, mark) -> begin + dc .= 0.0 + dc[1] = -counts[1] + counts[3] # S: -infection + influx + dc[2] = counts[1] - counts[2] # I: +infection - recovery + dc[3] = counts[2] # R: +recovery + end -dc = zeros(3, 2) + u0 = [999.0, 10.0, 0.0] # S, I, R + tspan = (0.0, 250.0) -rj = RegularJump(regular_rate, regular_c, dc; constant_c = true) -jumps = JumpSet(rj) + prob_disc = DiscreteProblem(u0, tspan, p) + rj = RegularJump(regular_rate, regular_c, 3) + jump_prob = JumpProblem(prob_disc, Direct(), rj) -prob = DiscreteProblem([999.0, 1.0, 0.0], (0.0, 250.0)) -jump_prob = JumpProblem(prob, Direct(), rj; rng = rng) -sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0) + sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0) + mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims) -const _dc = zeros(3, 2) -dc[1, 1] = -1 -dc[2, 1] = 1 -dc[2, 2] = -1 -dc[3, 2] = 1 + sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims) + mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims) -function regular_c(du, u, p, t, counts, mark) - mul!(du, dc, counts) + sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims) + mean_adaptive = mean(sol.u[i][1,end] for i in 1:Nsims) + + @test isapprox(mean_simple, mean_implicit, rtol=0.05) + @test isapprox(mean_simple, mean_adaptive, rtol=0.05) end -rj = RegularJump(regular_rate, regular_c, 2) -jumps = JumpSet(rj) -prob = DiscreteProblem([999, 1, 0], (0.0, 250.0)) -jump_prob = JumpProblem(prob, Direct(), rj; rng = rng) -sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0) + +# SEIR model with exposed compartment +let + β = 0.3 / 1000.0 + σ = 0.2 + ν = 0.01 + p = (β, σ, ν) + + regular_rate = (out, u, p, t) -> begin + out[1] = p[1] * u[1] * u[3] # β*S*I (infection) + out[2] = p[2] * u[2] # σ*E (progression) + out[3] = p[3] * u[3] # ν*I (recovery) + end + + regular_c = (dc, u, p, t, counts, mark) -> begin + dc .= 0.0 + dc[1] = -counts[1] # S: -infection + dc[2] = counts[1] - counts[2] # E: +infection - progression + dc[3] = counts[2] - counts[3] # I: +progression - recovery + dc[4] = counts[3] # R: +recovery + end + + # Initial state + u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R + tspan = (0.0, 250.0) + + # Create JumpProblem + prob_disc = DiscreteProblem(u0, tspan, p) + rj = RegularJump(regular_rate, regular_c, 3) + jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345)) + + sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0) + mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims) + + sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims) + mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims) + + sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims) + mean_adaptive = mean(sol.u[i][end,end] for i in 1:Nsims) + + @test isapprox(mean_simple, mean_implicit, rtol=0.05) + @test isapprox(mean_simple, mean_adaptive, rtol=0.05) +end From 230d5084ee6f3e41ea8d8591fa7e5b00688766b7 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sun, 10 Aug 2025 06:33:07 +0530 Subject: [PATCH 03/39] test changes --- test/regular_jumps.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index bd1d389a..4241bca9 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -1,6 +1,6 @@ using JumpProcesses, DiffEqBase -using Test, LinearAlgebra -using StableRNGs, Plots +using Test, LinearAlgebra, Statistics +using StableRNGs rng = StableRNG(12345) Nsims = 8000 From 25b8c164700cc00d72adbeab6b758ffbf9a78a2e Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 13 Aug 2025 03:50:32 +0530 Subject: [PATCH 04/39] refactor --- Project.toml | 1 - src/simple_regular_solve.jl | 277 +----------------------------------- test/regular_jumps.jl | 8 -- 3 files changed, 7 insertions(+), 279 deletions(-) diff --git a/Project.toml b/Project.toml index f6009a05..dcafe2bc 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,6 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" -SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 370ad666..663e6913 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -67,7 +67,9 @@ end SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon) -function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing) +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; + seed = nothing, + dtmin = 1e-10) @assert isempty(jump_prob.jump_callback.continuous_callbacks) @assert isempty(jump_prob.jump_callback.discrete_callbacks) prob = jump_prob.prob @@ -96,7 +98,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; u_prev = u[end] t_prev = t[end] rate(rate_cache, u_prev, p, t_prev) - tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate) + tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, dtmin) tau = min(tau, t_end - t_prev) counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) c(du, u_prev, p, t_prev, counts, nothing) @@ -115,16 +117,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; return sol end -struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm - epsilon::Float64 # Error control parameter - nc::Int # Critical reaction threshold - nstiff::Float64 # Stiffness threshold for switching - delta::Float64 # Partial equilibrium threshold -end - -SimpleImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) = - SimpleImplicitTauLeaping(epsilon, nc, nstiff, delta) - # Compute stoichiometry matrix from c function function compute_stoichiometry(c, u, numjumps, p, t) nu = zeros(Int, length(u), numjumps) @@ -138,19 +130,6 @@ function compute_stoichiometry(c, u, numjumps, p, t) return nu end -# Detect reversible reaction pairs -function find_reversible_pairs(nu) - pairs = Vector{Tuple{Int,Int}}() - for j in 1:size(nu, 2) - for k in (j+1):size(nu, 2) - if nu[:, j] == -nu[:, k] - push!(pairs, (j, k)) - end - end - end - return pairs -end - # Compute g_i (approximation from Cao et al., 2006) function compute_gi(u, nu, i, rate, rate_cache, p, t) max_order = 1.0 @@ -170,7 +149,7 @@ function compute_gi(u, nu, i, rate, rate_cache, p, t) end # Tau-selection for explicit method (Equation 8) -function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate) +function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate, dtmin) rate(rate_cache, u, p, t) mu = zeros(length(u)) sigma2 = zeros(length(u)) @@ -186,249 +165,7 @@ function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate) sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf tau = min(tau, mu_term, sigma_term) end - return max(tau, 1e-10) -end - -# Partial equilibrium check (Equation 13) -function is_partial_equilibrium(rate_cache, j_plus, j_minus, delta) - a_plus = rate_cache[j_plus] - a_minus = rate_cache[j_minus] - return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) -end - -# Tau-selection for implicit method (Equation 14) -function compute_tau_implicit(u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta) - rate(rate_cache, u, p, t) - mu = zeros(length(u)) - sigma2 = zeros(length(u)) - non_equilibrium = trues(size(nu, 2)) - for (j_plus, j_minus) in equilibrium_pairs - if is_partial_equilibrium(rate_cache, j_plus, j_minus, delta) - non_equilibrium[j_plus] = false - non_equilibrium[j_minus] = false - end - end - tau = Inf - for i in 1:length(u) - for j in 1:size(nu, 2) - if non_equilibrium[j] - mu[i] += nu[i, j] * rate_cache[j] - sigma2[i] += nu[i, j]^2 * rate_cache[j] - end - end - gi = compute_gi(u, nu, i, rate, rate_cache, p, t) - bound = max(epsilon * u[i] / gi, 1.0) - mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf - sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf - tau = min(tau, mu_term, sigma_term) - end - return max(tau, 1e-10) -end - -# Identify critical reactions -function identify_critical_reactions(u, rate_cache, nu, nc) - critical = falses(size(nu, 2)) - for j in 1:size(nu, 2) - if rate_cache[j] > 0 - Lj = Inf - for i in 1:length(u) - if nu[i, j] < 0 - Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j]))) - end - end - if Lj < nc - critical[j] = true - end - end - end - return critical -end - -# Implicit tau-leaping step using NonlinearSolve -function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) - # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0 - function f(u_new, params) - rate_new = zeros(eltype(u_new), numjumps) - rate(rate_new, u_new, p, t_prev + tau) - residual = u_new - u_prev - for j in 1:numjumps - residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j]) - end - return residual - end - - # Initial guess - u_new = copy(u_prev) - - # Solve the nonlinear system - prob = NonlinearProblem(f, u_new, nothing) - sol = solve(prob, SimpleNewtonRaphson()) - - # Check for convergence and numerical stability - if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u)) - return round.(Int, max.(u_prev, 0.0)) # Revert to previous state - end - - return round.(Int, max.(sol.u, 0.0)) -end - -# Down-shifting condition (Equation 19) -function use_down_shifting(t, tau_im, tau_ex, a0, t_end) - return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0) -end - -function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing) - @assert isempty(jump_prob.jump_callback.continuous_callbacks) - @assert isempty(jump_prob.jump_callback.discrete_callbacks) - prob = jump_prob.prob - rng = DEFAULT_RNG - (seed !== nothing) && seed!(rng, seed) - - rj = jump_prob.regular_jump - rate = rj.rate - numjumps = rj.numjumps - c = rj.c - u0 = copy(prob.u0) - tspan = prob.tspan - p = prob.p - - # Initialize storage - rate_cache = zeros(Float64, numjumps) - counts = zeros(Int, numjumps) - du = similar(u0) - u = [copy(u0)] - t = [tspan[1]] - - # Algorithm parameters - epsilon = alg.epsilon - nc = alg.nc - nstiff = alg.nstiff - delta = alg.delta - t_end = tspan[2] - - # Compute stoichiometry matrix - nu = compute_stoichiometry(c, u0, numjumps, p, t[1]) - - # Detect reversible reaction pairs - equilibrium_pairs = find_reversible_pairs(nu) - - # Main simulation loop - while t[end] < t_end - u_prev = u[end] - t_prev = t[end] - - # Compute propensities - rate(rate_cache, u_prev, p, t_prev) - - # Identify critical reactions - critical = identify_critical_reactions(u_prev, rate_cache, nu, nc) - - # Compute tau values - tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate) - tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta) - - # Compute critical propensity sum - ac0 = sum(rate_cache[critical]) - tau2 = ac0 > 0 ? randexp(rng) / ac0 : Inf - - # Choose method and stepsize - a0 = sum(rate_cache) - use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && !use_down_shifting(t_prev, tau_im, tau_ex, a0, t_end) - tau1 = use_implicit ? tau_im : tau_ex - method = use_implicit ? :implicit : :explicit - - # Cap tau to prevent large updates - tau1 = min(tau1, 1.0) - - # Check if tau1 is too small - if a0 > 0 && tau1 < 10 / a0 - # Use SSA for a few steps - steps = method == :implicit ? 10 : 100 - for _ in 1:steps - if t_prev >= t_end - break - end - rate(rate_cache, u_prev, p, t_prev) - a0 = sum(rate_cache) - if a0 == 0 - break - end - tau = randexp(rng) / a0 - r = rand(rng) * a0 - cumsum_rate = 0.0 - for j in 1:numjumps - cumsum_rate += rate_cache[j] - if cumsum_rate > r - u_prev += nu[:, j] - break - end - end - t_prev += tau - push!(u, copy(u_prev)) - push!(t, t_prev) - end - continue - end - - # Choose stepsize and compute firings - if tau2 > tau1 - tau = min(tau1, t_end - t_prev) - counts .= 0 - for j in 1:numjumps - if !critical[j] - counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0)) - end - end - if method == :implicit - u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) - else - c(du, u_prev, p, t_prev, counts, nothing) - u_new = u_prev + du - end - else - tau = min(tau2, t_end - t_prev) - counts .= 0 - if ac0 > 0 - r = rand(rng) * ac0 - cumsum_rate = 0.0 - for j in 1:numjumps - if critical[j] - cumsum_rate += rate_cache[j] - if cumsum_rate > r - counts[j] = 1 - break - end - end - end - end - for j in 1:numjumps - if !critical[j] - counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0)) - end - end - if method == :implicit && tau > tau_ex - u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) - else - c(du, u_prev, p, t_prev, counts, nothing) - u_new = u_prev + du - end - end - - # Check for negative populations - if any(u_new .< 0) - tau1 /= 2 - continue - end - - # Update state and time - push!(u, u_new) - push!(t, t_prev + tau) - end - - # Build solution - sol = DiffEqBase.build_solution(prob, alg, t, u, - calculate_error = false, - interp = DiffEqBase.ConstantInterpolation(t, u)) + return max(tau, dtmin) end struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm @@ -444,4 +181,4 @@ function EnsembleGPUKernel() EnsembleGPUKernel(nothing, 0.0) end -export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping +export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 4241bca9..a1b278d9 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -35,13 +35,9 @@ let sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0) mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims) - sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims) - mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims) - sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims) mean_adaptive = mean(sol.u[i][1,end] for i in 1:Nsims) - @test isapprox(mean_simple, mean_implicit, rtol=0.05) @test isapprox(mean_simple, mean_adaptive, rtol=0.05) end @@ -79,12 +75,8 @@ let sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0) mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims) - sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims) - mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims) - sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims) mean_adaptive = mean(sol.u[i][end,end] for i in 1:Nsims) - @test isapprox(mean_simple, mean_implicit, rtol=0.05) @test isapprox(mean_simple, mean_adaptive, rtol=0.05) end From 07c429edd0a8bd8e52d17dc106a2d77ebe508cc1 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 13 Aug 2025 04:25:41 +0530 Subject: [PATCH 05/39] test refactor --- test/regular_jumps.jl | 135 +++++++++++++++++++++++++++--------------- 1 file changed, 87 insertions(+), 48 deletions(-) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index a1b278d9..75d0eb4e 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -6,77 +6,116 @@ rng = StableRNG(12345) Nsims = 8000 # SIR model with influx -let +@testset "SIR Model Correctness" begin β = 0.1 / 1000.0 ν = 0.01 influx_rate = 1.0 p = (β, ν, influx_rate) - regular_rate = (out, u, p, t) -> begin - out[1] = p[1] * u[1] * u[2] # β*S*I (infection) - out[2] = p[2] * u[2] # ν*I (recovery) - out[3] = p[3] # influx_rate - end - - regular_c = (dc, u, p, t, counts, mark) -> begin - dc .= 0.0 - dc[1] = -counts[1] + counts[3] # S: -infection + influx - dc[2] = counts[1] - counts[2] # I: +infection - recovery - dc[3] = counts[2] # R: +recovery - end + # ConstantRateJump formulation for SSAStepper + rate1(u, p, t) = p[1] * u[1] * u[2] # β*S*I (infection) + rate2(u, p, t) = p[2] * u[2] # ν*I (recovery) + rate3(u, p, t) = p[3] # influx_rate + affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) + affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) + affect3!(integrator) = (integrator.u[1] += 1; nothing) + jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) u0 = [999.0, 10.0, 0.0] # S, I, R tspan = (0.0, 250.0) - prob_disc = DiscreteProblem(u0, tspan, p) - rj = RegularJump(regular_rate, regular_c, 3) - jump_prob = JumpProblem(prob_disc, Direct(), rj) - - sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0) - mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) - sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims) - mean_adaptive = mean(sol.u[i][1,end] for i in 1:Nsims) + # Solve with SSAStepper + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims) - @test isapprox(mean_simple, mean_adaptive, rtol=0.05) + # RegularJump formulation for TauLeaping methods + regular_rate = (out, u, p, t) -> begin + out[1] = p[1] * u[1] * u[2] + out[2] = p[2] * u[2] + out[3] = p[3] + end + regular_c = (dc, u, p, t, counts, mark) -> begin + dc .= 0.0 + dc[1] = -counts[1] + counts[3] + dc[2] = counts[1] - counts[2] + dc[3] = counts[2] + end + rj = RegularJump(regular_rate, regular_c, 3) + jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng = rng) + + # Solve with SimpleTauLeaping (dt=0.1) + sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) + + # Solve with SimpleAdaptiveTauLeaping + sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims) + + # Compute mean trajectories at t = 0, 1, ..., 250 + t_points = 0:1.0:250.0 + mean_direct_S = [mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points] + mean_simple_S = [mean(sol_simple[i](t)[2] for i in 1:Nsims) for t in t_points] + mean_adaptive_S = [mean(sol_adaptive[i](t)[2] for i in 1:Nsims) for t in t_points] + + for i in 1:251 + @test isapprox(mean_direct_S[i], mean_simple_S[i], rtol=0.10) + @test isapprox(mean_direct_S[i], mean_adaptive_S[i], rtol=0.10) + end end - # SEIR model with exposed compartment -let +@testset "SEIR Model Correctness" begin β = 0.3 / 1000.0 σ = 0.2 ν = 0.01 p = (β, σ, ν) - regular_rate = (out, u, p, t) -> begin - out[1] = p[1] * u[1] * u[3] # β*S*I (infection) - out[2] = p[2] * u[2] # σ*E (progression) - out[3] = p[3] * u[3] # ν*I (recovery) - end - - regular_c = (dc, u, p, t, counts, mark) -> begin - dc .= 0.0 - dc[1] = -counts[1] # S: -infection - dc[2] = counts[1] - counts[2] # E: +infection - progression - dc[3] = counts[2] - counts[3] # I: +progression - recovery - dc[4] = counts[3] # R: +recovery - end + # ConstantRateJump formulation for SSAStepper + rate1(u, p, t) = p[1] * u[1] * u[3] # β*S*I (infection) + rate2(u, p, t) = p[2] * u[2] # σ*E (progression) + rate3(u, p, t) = p[3] * u[3] # ν*I (recovery) + affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) + affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) + affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing) + jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) - # Initial state u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R tspan = (0.0, 250.0) - - # Create JumpProblem prob_disc = DiscreteProblem(u0, tspan, p) - rj = RegularJump(regular_rate, regular_c, 3) - jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345)) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) - sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0) - mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims) + # Solve with SSAStepper + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims) - sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims) - mean_adaptive = mean(sol.u[i][end,end] for i in 1:Nsims) - - @test isapprox(mean_simple, mean_adaptive, rtol=0.05) + # RegularJump formulation for TauLeaping methods + regular_rate = (out, u, p, t) -> begin + out[1] = p[1] * u[1] * u[3] + out[2] = p[2] * u[2] + out[3] = p[3] * u[3] + end + regular_c = (dc, u, p, t, counts, mark) -> begin + dc .= 0.0 + dc[1] = -counts[1] + dc[2] = counts[1] - counts[2] + dc[3] = counts[2] - counts[3] + dc[4] = counts[3] + end + rj = RegularJump(regular_rate, regular_c, 3) + jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng = rng) + + # Solve with SimpleTauLeaping (dt=0.1) + sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) + + # Solve with SimpleAdaptiveTauLeaping + sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims) + + # Compute mean trajectories at t = 0, 1, ..., 250 + t_points = 0:1.0:250.0 + mean_direct_S = [mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points] + mean_simple_S = [mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_points] + mean_adaptive_S = [mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points] + + for i in 1:251 + @test isapprox(mean_direct_S[i], mean_simple_S[i], rtol=0.10) + @test isapprox(mean_direct_S[i], mean_adaptive_S[i], rtol=0.10) + end end From 885ac59574482eec5a60a1c34938bb18b5bc38d8 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 14 Aug 2025 05:02:32 +0530 Subject: [PATCH 06/39] refactor --- src/simple_regular_solve.jl | 105 ++++++++++++++++++------------------ test/regular_jumps.jl | 2 +- 2 files changed, 52 insertions(+), 55 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 663e6913..ca4fb221 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -67,6 +67,35 @@ end SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon) +function compute_gi(u, nu, hor, i) + max_order = 1.0 + for j in 1:size(nu, 2) + if abs(nu[i, j]) > 0 + max_order = max(max_order, Float64(hor[j])) + end + end + return max_order +end + +function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin) + rate(rate_cache, u, p, t) + mu = zeros(length(u)) + sigma2 = zeros(length(u)) + tau = Inf + for i in 1:length(u) + for j in 1:size(nu, 2) + mu[i] += nu[i, j] * rate_cache[j] + sigma2[i] += nu[i, j]^2 * rate_cache[j] + end + gi = compute_gi(u, nu, hor, i) + bound = max(epsilon * u[i] / gi, 1.0) + mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf + sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf + tau = min(tau, mu_term, sigma_term) + end + return max(tau, dtmin) +end + function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed = nothing, dtmin = 1e-10) @@ -92,17 +121,36 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; t_end = tspan[2] epsilon = alg.epsilon - nu = compute_stoichiometry(c, u0, numjumps, p, t[1]) + # Compute initial stoichiometry and HOR + nu = zeros(Int, length(u0), numjumps) + for j in 1:numjumps + counts_temp = zeros(numjumps) + counts_temp[j] = 1 + c(du, u0, p, t[1], counts_temp, nothing) + nu[:, j] = du + end + + hor = zeros(Int, size(nu, 2)) + for j in 1:size(nu, 2) + hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1 + end while t[end] < t_end u_prev = u[end] t_prev = t[end] + # Recompute stoichiometry + for j in 1:numjumps + counts_temp = zeros(numjumps) + counts_temp[j] = 1 + c(du, u_prev, p, t_prev, counts_temp, nothing) + nu[:, j] = du + end rate(rate_cache, u_prev, p, t_prev) - tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, dtmin) + tau = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin) tau = min(tau, t_end - t_prev) counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) c(du, u_prev, p, t_prev, counts, nothing) - u_new = u_prev + du + u_new = max.(u_prev + du, 0) if any(u_new .< 0) tau /= 2 continue @@ -117,57 +165,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; return sol end -# Compute stoichiometry matrix from c function -function compute_stoichiometry(c, u, numjumps, p, t) - nu = zeros(Int, length(u), numjumps) - for j in 1:numjumps - counts = zeros(numjumps) - counts[j] = 1 - du = similar(u) - c(du, u, p, t, counts, nothing) - nu[:, j] = round.(Int, du) - end - return nu -end - -# Compute g_i (approximation from Cao et al., 2006) -function compute_gi(u, nu, i, rate, rate_cache, p, t) - max_order = 1.0 - for j in 1:size(nu, 2) - if abs(nu[i, j]) > 0 - rate(rate_cache, u, p, t) - if rate_cache[j] > 0 - order = 1.0 - if sum(abs.(nu[:, j])) > abs(nu[i, j]) - order = 2.0 - end - max_order = max(max_order, order) - end - end - end - return max_order -end - -# Tau-selection for explicit method (Equation 8) -function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate, dtmin) - rate(rate_cache, u, p, t) - mu = zeros(length(u)) - sigma2 = zeros(length(u)) - tau = Inf - for i in 1:length(u) - for j in 1:size(nu, 2) - mu[i] += nu[i, j] * rate_cache[j] - sigma2[i] += nu[i, j]^2 * rate_cache[j] - end - gi = compute_gi(u, nu, i, rate, rate_cache, p, t) - bound = max(epsilon * u[i] / gi, 1.0) - mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf - sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf - tau = min(tau, mu_term, sigma_term) - end - return max(tau, dtmin) -end - struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm backend::Backend cpu_offload::Float64 diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 75d0eb4e..31aa6928 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -3,7 +3,7 @@ using Test, LinearAlgebra, Statistics using StableRNGs rng = StableRNG(12345) -Nsims = 8000 +Nsims = 1000 # SIR model with influx @testset "SIR Model Correctness" begin From 0ec9d39f2eaf35a6484ccdd77647ee5a63da30f8 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sun, 17 Aug 2025 03:38:56 +0530 Subject: [PATCH 07/39] added saveat in SimpleAdaptiveTauLeaping --- src/simple_regular_solve.jl | 36 +++++++++++++++++++++++++++++++----- test/regular_jumps.jl | 4 ++-- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index ca4fb221..0f326329 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -98,7 +98,8 @@ end function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed = nothing, - dtmin = 1e-10) + dtmin = 1e-10, + saveat = nothing) @assert isempty(jump_prob.jump_callback.continuous_callbacks) @assert isempty(jump_prob.jump_callback.discrete_callbacks) prob = jump_prob.prob @@ -123,24 +124,27 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; # Compute initial stoichiometry and HOR nu = zeros(Int, length(u0), numjumps) + counts_temp = zeros(Int, numjumps) for j in 1:numjumps - counts_temp = zeros(numjumps) + fill!(counts_temp, 0) counts_temp[j] = 1 c(du, u0, p, t[1], counts_temp, nothing) nu[:, j] = du end - hor = zeros(Int, size(nu, 2)) for j in 1:size(nu, 2) hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1 end + saveat_times = isnothing(saveat) ? Float64[] : saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat) + save_idx = 1 + while t[end] < t_end u_prev = u[end] t_prev = t[end] # Recompute stoichiometry for j in 1:numjumps - counts_temp = zeros(numjumps) + fill!(counts_temp, 0) counts_temp[j] = 1 c(du, u_prev, p, t_prev, counts_temp, nothing) nu[:, j] = du @@ -148,15 +152,37 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; rate(rate_cache, u_prev, p, t_prev) tau = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin) tau = min(tau, t_end - t_prev) + if !isempty(saveat_times) + if save_idx <= length(saveat_times) && t_prev + tau > saveat_times[save_idx] + tau = saveat_times[save_idx] - t_prev + end + end counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) c(du, u_prev, p, t_prev, counts, nothing) - u_new = max.(u_prev + du, 0) + u_new = u_prev + du if any(u_new .< 0) + # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468) tau /= 2 continue end + u_new = max.(u_new, 0) # Ensure non-negative states push!(u, u_new) push!(t, t_prev + tau) + if !isempty(saveat_times) && save_idx <= length(saveat_times) && t[end] >= saveat_times[save_idx] + save_idx += 1 + end + end + + # Interpolate to saveat times if specified + if !isempty(saveat_times) + t_out = saveat_times + u_out = [u[end]] + for t_save in saveat_times + idx = findlast(ti -> ti <= t_save, t) + push!(u_out, u[idx]) + end + t = t_out + u = u_out[2:end] end sol = DiffEqBase.build_solution(prob, alg, t, u, diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 31aa6928..923c08f2 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -48,7 +48,7 @@ Nsims = 1000 sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) # Solve with SimpleAdaptiveTauLeaping - sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims) + sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat = 1.0) # Compute mean trajectories at t = 0, 1, ..., 250 t_points = 0:1.0:250.0 @@ -106,7 +106,7 @@ end sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) # Solve with SimpleAdaptiveTauLeaping - sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims) + sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat = 1.0) # Compute mean trajectories at t = 0, 1, ..., 250 t_points = 0:1.0:250.0 From 0e7ff26aa3b5e090b4e10cd4a530d18b36351295 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sun, 24 Aug 2025 11:14:05 +0530 Subject: [PATCH 08/39] update --- src/simple_regular_solve.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 0f326329..1366c45c 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -100,8 +100,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed = nothing, dtmin = 1e-10, saveat = nothing) - @assert isempty(jump_prob.jump_callback.continuous_callbacks) - @assert isempty(jump_prob.jump_callback.discrete_callbacks) + validate_pure_leaping_inputs(jump_prob, alg) || + error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only non-RegularJumps.") prob = jump_prob.prob rng = DEFAULT_RNG (seed !== nothing) && seed!(rng, seed) From 89378c6d84de92df15a37ca2d0d773f576a4f566 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Sun, 24 Aug 2025 11:15:19 +0530 Subject: [PATCH 09/39] Update src/simple_regular_solve.jl Co-authored-by: Christopher Rackauckas --- src/simple_regular_solve.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 1366c45c..77c16a3a 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -61,8 +61,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; interp = DiffEqBase.ConstantInterpolation(t, u)) end -struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm - epsilon::Float64 # Error control parameter +struct SimpleAdaptiveTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm + epsilon::T # Error control parameter end SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon) From 14e0be700b0cb391fa83c07e4d4c4d8c943f57d8 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Sun, 24 Aug 2025 11:15:43 +0530 Subject: [PATCH 10/39] Update src/simple_regular_solve.jl Co-authored-by: Christopher Rackauckas --- src/simple_regular_solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 77c16a3a..84dac18b 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -71,7 +71,7 @@ function compute_gi(u, nu, hor, i) max_order = 1.0 for j in 1:size(nu, 2) if abs(nu[i, j]) > 0 - max_order = max(max_order, Float64(hor[j])) + max_order = max(max_order, float(hor[j])) end end return max_order From e7f975e4647596c3283ae68505dfc0ae97044138 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sun, 24 Aug 2025 11:51:21 +0530 Subject: [PATCH 11/39] update --- Project.toml | 1 - src/simple_regular_solve.jl | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index dcafe2bc..46f57c63 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 84dac18b..b70b484c 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -136,7 +136,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1 end - saveat_times = isnothing(saveat) ? Float64[] : saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat) + saveat_times = isnothing(saveat) ? Vector{typeof(t)}() : saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat) save_idx = 1 while t[end] < t_end @@ -160,7 +160,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) c(du, u_prev, p, t_prev, counts, nothing) u_new = u_prev + du - if any(u_new .< 0) + if any(<(0), u_new) # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468) tau /= 2 continue From 6e789cdec8df6531c7e0494f149055b27112825d Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sun, 24 Aug 2025 11:56:42 +0530 Subject: [PATCH 12/39] test update --- Project.toml | 1 + test/regular_jumps.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 46f57c63..82631e11 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 923c08f2..9087d03b 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -42,7 +42,7 @@ Nsims = 1000 dc[3] = counts[2] end rj = RegularJump(regular_rate, regular_c, 3) - jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng = rng) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng = rng) # Solve with SimpleTauLeaping (dt=0.1) sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) @@ -100,7 +100,7 @@ end dc[4] = counts[3] end rj = RegularJump(regular_rate, regular_c, 3) - jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng = rng) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng = rng) # Solve with SimpleTauLeaping (dt=0.1) sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) From a8999f44c923dc689dc5c518be64978f60d0e6db Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 05:10:57 +0530 Subject: [PATCH 13/39] using maj for adaptive tauleaping --- src/simple_regular_solve.jl | 54 ++++++------ test/regular_jumps.jl | 159 +++++++++++++++++++++++++++++------- 2 files changed, 161 insertions(+), 52 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index b70b484c..b7dbbc94 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -67,6 +67,14 @@ end SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon) +function compute_hor(nu) + hor = zeros(Int, size(nu, 2)) + for j in 1:size(nu, 2) + hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1 + end + return hor +end + function compute_gi(u, nu, hor, i) max_order = 1.0 for j in 1:size(nu, 2) @@ -100,16 +108,21 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed = nothing, dtmin = 1e-10, saveat = nothing) - validate_pure_leaping_inputs(jump_prob, alg) || - error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only non-RegularJumps.") + if jump_prob.massaction_jump === nothing + error("SimpleAdaptiveTauLeaping requires a JumpProblem with a MassActionJump.") + end prob = jump_prob.prob rng = DEFAULT_RNG (seed !== nothing) && seed!(rng, seed) - rj = jump_prob.regular_jump - rate = rj.rate - numjumps = rj.numjumps - c = rj.c + maj = jump_prob.massaction_jump + numjumps = get_num_majumps(maj) + # Extract rates + rate = (out, u, p, t) -> begin + for j in 1:get_num_majumps(maj) + out[j] = evalrxrate(u, j, maj) + end + end u0 = copy(prob.u0) tspan = prob.tspan p = prob.p @@ -122,19 +135,14 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; t_end = tspan[2] epsilon = alg.epsilon - # Compute initial stoichiometry and HOR + # Extract stoichiometry once from MassActionJump nu = zeros(Int, length(u0), numjumps) - counts_temp = zeros(Int, numjumps) for j in 1:numjumps - fill!(counts_temp, 0) - counts_temp[j] = 1 - c(du, u0, p, t[1], counts_temp, nothing) - nu[:, j] = du - end - hor = zeros(Int, size(nu, 2)) - for j in 1:size(nu, 2) - hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1 + for (spec_idx, stoich) in maj.net_stoch[j] + nu[spec_idx, j] = stoich + end end + hor = compute_hor(nu) saveat_times = isnothing(saveat) ? Vector{typeof(t)}() : saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat) save_idx = 1 @@ -142,13 +150,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; while t[end] < t_end u_prev = u[end] t_prev = t[end] - # Recompute stoichiometry - for j in 1:numjumps - fill!(counts_temp, 0) - counts_temp[j] = 1 - c(du, u_prev, p, t_prev, counts_temp, nothing) - nu[:, j] = du - end rate(rate_cache, u_prev, p, t_prev) tau = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin) tau = min(tau, t_end - t_prev) @@ -158,7 +159,12 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; end end counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) - c(du, u_prev, p, t_prev, counts, nothing) + du .= 0 + for j in 1:numjumps + for (spec_idx, stoich) in maj.net_stoch[j] + du[spec_idx] += stoich * counts[j] + end + end u_new = u_prev + du if any(<(0), u_new) # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 9087d03b..1c4f7368 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -15,7 +15,7 @@ Nsims = 1000 # ConstantRateJump formulation for SSAStepper rate1(u, p, t) = p[1] * u[1] * u[2] # β*S*I (infection) rate2(u, p, t) = p[2] * u[2] # ν*I (recovery) - rate3(u, p, t) = p[3] # influx_rate + rate3(u, p, t) = p[3] # influx_rate (S influx) affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) affect3!(integrator) = (integrator.u[1] += 1; nothing) @@ -24,41 +24,49 @@ Nsims = 1000 u0 = [999.0, 10.0, 0.0] # S, I, R tspan = (0.0, 250.0) prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=rng) # Solve with SSAStepper - sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims) + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) - # RegularJump formulation for TauLeaping methods + # RegularJump formulation for SimpleTauLeaping regular_rate = (out, u, p, t) -> begin out[1] = p[1] * u[1] * u[2] out[2] = p[2] * u[2] out[3] = p[3] end regular_c = (dc, u, p, t, counts, mark) -> begin - dc .= 0.0 + dc .= 0 dc[1] = -counts[1] + counts[3] dc[2] = counts[1] - counts[2] dc[3] = counts[2] end rj = RegularJump(regular_rate, regular_c, 3) - jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng = rng) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng=rng) - # Solve with SimpleTauLeaping (dt=0.1) + # Solve with SimpleTauLeaping sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) - + + # MassActionJump formulation for SimpleAdaptiveTauLeaping + reactant_stoich = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]] + net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [1=>1]] + param_idxs = [1, 2, 3] + maj = MassActionJump(reactant_stoich, net_stoich; param_idxs=param_idxs) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng=rng) + # Solve with SimpleAdaptiveTauLeaping - sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat = 1.0) + sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) - # Compute mean trajectories at t = 0, 1, ..., 250 + # Compute mean infected (I) trajectories t_points = 0:1.0:250.0 - mean_direct_S = [mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points] - mean_simple_S = [mean(sol_simple[i](t)[2] for i in 1:Nsims) for t in t_points] - mean_adaptive_S = [mean(sol_adaptive[i](t)[2] for i in 1:Nsims) for t in t_points] + mean_direct_I = [mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points] + mean_simple_I = [mean(sol_simple[i](t)[2] for i in 1:Nsims) for t in t_points] + mean_adaptive_I = [mean(sol_adaptive[i](t)[2] for i in 1:Nsims) for t in t_points] + # Test mean infected trajectories for i in 1:251 - @test isapprox(mean_direct_S[i], mean_simple_S[i], rtol=0.10) - @test isapprox(mean_direct_S[i], mean_adaptive_S[i], rtol=0.10) + @test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.10) + @test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.10) end end @@ -81,12 +89,12 @@ end u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R tspan = (0.0, 250.0) prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=rng) # Solve with SSAStepper - sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims) + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) - # RegularJump formulation for TauLeaping methods + # RegularJump formulation for SimpleTauLeaping regular_rate = (out, u, p, t) -> begin out[1] = p[1] * u[1] * u[3] out[2] = p[2] * u[2] @@ -100,22 +108,117 @@ end dc[4] = counts[3] end rj = RegularJump(regular_rate, regular_c, 3) - jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng = rng) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng=rng) - # Solve with SimpleTauLeaping (dt=0.1) + # Solve with SimpleTauLeaping sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) - + + # MassActionJump formulation for SimpleAdaptiveTauLeaping + reactant_stoich = [[1=>1, 3=>1], [2=>1], [3=>1]] + net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [3=>-1, 4=>1]] + param_idxs = [1, 2, 3] + maj = MassActionJump(reactant_stoich, net_stoich; param_idxs=param_idxs) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng=rng) + # Solve with SimpleAdaptiveTauLeaping - sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat = 1.0) + sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) - # Compute mean trajectories at t = 0, 1, ..., 250 + # Compute mean infected (I) trajectories t_points = 0:1.0:250.0 - mean_direct_S = [mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points] - mean_simple_S = [mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_points] - mean_adaptive_S = [mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points] + mean_direct_I = [mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points] + mean_simple_I = [mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_points] + mean_adaptive_I = [mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points] + # Test mean infected trajectories for i in 1:251 - @test isapprox(mean_direct_S[i], mean_simple_S[i], rtol=0.10) - @test isapprox(mean_direct_S[i], mean_adaptive_S[i], rtol=0.10) + @test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.10) + @test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.10) + end +end + +# Test PureLeaping aggregator functionality +@testset "PureLeaping Aggregator Tests" begin + # Test with MassActionJump + u0 = [10, 5, 0] + tspan = (0.0, 10.0) + p = [0.1, 0.2] + prob = DiscreteProblem(u0, tspan, p) + + # Create MassActionJump + reactant_stoich = [[1 => 1], [1 => 2]] + net_stoich = [[1 => -1, 2 => 1], [1 => -2, 3 => 1]] + rates = [0.1, 0.05] + maj = MassActionJump(rates, reactant_stoich, net_stoich) + + # Test PureLeaping JumpProblem creation + jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj)) + @test jp_pure.aggregator isa PureLeaping + @test jp_pure.discrete_jump_aggregation === nothing + @test jp_pure.massaction_jump !== nothing + @test length(jp_pure.jump_callback.discrete_callbacks) == 0 + + # Test with ConstantRateJump + rate(u, p, t) = p[1] * u[1] + affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) + crj = ConstantRateJump(rate, affect!) + + jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj)) + @test jp_pure_crj.aggregator isa PureLeaping + @test jp_pure_crj.discrete_jump_aggregation === nothing + @test length(jp_pure_crj.constant_jumps) == 1 + + # Test with VariableRateJump + vrate(u, p, t) = t * p[1] * u[1] + vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) + vrj = VariableRateJump(vrate, vaffect!) + + jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj)) + @test jp_pure_vrj.aggregator isa PureLeaping + @test jp_pure_vrj.discrete_jump_aggregation === nothing + @test length(jp_pure_vrj.variable_jumps) == 1 + + # Test with RegularJump + function rj_rate(out, u, p, t) + out[1] = p[1] * u[1] + end + + rj_dc = zeros(3, 1) + rj_dc[1, 1] = -1 + rj_dc[3, 1] = 1 + + function rj_c(du, u, p, t, counts, mark) + mul!(du, rj_dc, counts) end + + regj = RegularJump(rj_rate, rj_c, 1) + + jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj)) + @test jp_pure_regj.aggregator isa PureLeaping + @test jp_pure_regj.discrete_jump_aggregation === nothing + @test jp_pure_regj.regular_jump !== nothing + + # Test mixed jump types + mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,), + variable_jumps = (vrj,), regular_jumps = regj) + jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps) + @test jp_pure_mixed.aggregator isa PureLeaping + @test jp_pure_mixed.discrete_jump_aggregation === nothing + @test jp_pure_mixed.massaction_jump !== nothing + @test length(jp_pure_mixed.constant_jumps) == 1 + @test length(jp_pure_mixed.variable_jumps) == 1 + @test jp_pure_mixed.regular_jump !== nothing + + # Test spatial system error + spatial_sys = CartesianGrid((2, 2)) + hopping_consts = [1.0] + @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); + spatial_system = spatial_sys) + @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); + hopping_constants = hopping_consts) + + # Test MassActionJump with parameter mapping + maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2]) + jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params)) + scaled_rates = [p[1], p[2]/2] + @test jp_params.massaction_jump.scaled_rates == scaled_rates end From 0125e21474069aa8888ecfa379d89867572e4ec8 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 05:12:33 +0530 Subject: [PATCH 14/39] project.toml update --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 82631e11..46f57c63 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,6 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" From 1af190d6e0e41b7bb41a8091d4d51705e2c456a1 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 05:24:23 +0530 Subject: [PATCH 15/39] saveat logic change --- src/simple_regular_solve.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index b7dbbc94..011f6305 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -144,7 +144,15 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; end hor = compute_hor(nu) - saveat_times = isnothing(saveat) ? Vector{typeof(t)}() : saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat) + saveat_times = nothing + if isnothing(saveat) + saveat_times = Vector{typeof(tspan[1])}() + elseif saveat isa Number + saveat_times = collect(range(tspan[1], tspan[2], step=saveat)) + else + saveat_times = collect(saveat) + end + save_idx = 1 while t[end] < t_end From 10f4ce3536b992df520bb22cda2725665f4b65f2 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 05:27:37 +0530 Subject: [PATCH 16/39] test change --- test/regular_jumps.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 1c4f7368..25fcca61 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -65,8 +65,8 @@ Nsims = 1000 # Test mean infected trajectories for i in 1:251 - @test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.10) - @test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.10) + @test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.05) + @test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.05) end end @@ -131,8 +131,8 @@ end # Test mean infected trajectories for i in 1:251 - @test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.10) - @test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.10) + @test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.05) + @test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.05) end end From 6d3d9005de14071149905d3595afc016f91a5193 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 08:18:01 +0530 Subject: [PATCH 17/39] saveat optimization --- src/simple_regular_solve.jl | 59 ++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 011f6305..1489dc70 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -127,8 +127,9 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; tspan = prob.tspan p = prob.p - u = [copy(u0)] - t = [tspan[1]] + # Initialize output vectors + u_out = [copy(u0)] + t_out = [tspan[1]] rate_cache = zeros(Float64, numjumps) counts = zeros(Int, numjumps) du = similar(u0) @@ -155,16 +156,16 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; save_idx = 1 - while t[end] < t_end - u_prev = u[end] - t_prev = t[end] - rate(rate_cache, u_prev, p, t_prev) - tau = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin) - tau = min(tau, t_end - t_prev) - if !isempty(saveat_times) - if save_idx <= length(saveat_times) && t_prev + tau > saveat_times[save_idx] - tau = saveat_times[save_idx] - t_prev - end + # Current state for timestepping + u_current = copy(u0) + t_current = tspan[1] + + while t_current < t_end + rate(rate_cache, u_current, p, t_current) + tau = compute_tau_explicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin) + tau = min(tau, t_end - t_current) + if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] + tau = saveat_times[save_idx] - t_current end counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) du .= 0 @@ -173,35 +174,31 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; du[spec_idx] += stoich * counts[j] end end - u_new = u_prev + du + u_new = u_current + du if any(<(0), u_new) # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468) tau /= 2 continue end - u_new = max.(u_new, 0) # Ensure non-negative states - push!(u, u_new) - push!(t, t_prev + tau) - if !isempty(saveat_times) && save_idx <= length(saveat_times) && t[end] >= saveat_times[save_idx] - save_idx += 1 + u_new = max.(u_new, 0) + t_new = t_current + tau + + # Save state if at a saveat time or if saveat is empty + if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) + push!(u_out, copy(u_new)) + push!(t_out, t_new) + if !isempty(saveat_times) && t_new >= saveat_times[save_idx] + save_idx += 1 + end end - end - # Interpolate to saveat times if specified - if !isempty(saveat_times) - t_out = saveat_times - u_out = [u[end]] - for t_save in saveat_times - idx = findlast(ti -> ti <= t_save, t) - push!(u_out, u[idx]) - end - t = t_out - u = u_out[2:end] + u_current = u_new + t_current = t_new end - sol = DiffEqBase.build_solution(prob, alg, t, u, + sol = DiffEqBase.build_solution(prob, alg, t_out, u_out, calculate_error=false, - interp=DiffEqBase.ConstantInterpolation(t, u)) + interp=DiffEqBase.ConstantInterpolation(t_out, u_out)) return sol end From 2c03d67a0eece1dfb5a7aeb0099be5f638ff5c1e Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 08:36:46 +0530 Subject: [PATCH 18/39] refactor --- src/simple_regular_solve.jl | 47 ++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 1489dc70..afba29fc 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -118,26 +118,30 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; maj = jump_prob.massaction_jump numjumps = get_num_majumps(maj) # Extract rates - rate = (out, u, p, t) -> begin - for j in 1:get_num_majumps(maj) - out[j] = evalrxrate(u, j, maj) + rate = jump_prob.regular_jump !== nothing ? jump_prob.regular_jump.rate : + (out, u, p, t) -> begin + for j in 1:numjumps + out[j] = evalrxrate(u, j, maj) + end end - end + c = jump_prob.regular_jump !== nothing ? jump_prob.regular_jump.c : nothing u0 = copy(prob.u0) tspan = prob.tspan p = prob.p - # Initialize output vectors - u_out = [copy(u0)] - t_out = [tspan[1]] - rate_cache = zeros(Float64, numjumps) - counts = zeros(Int, numjumps) + # Initialize current state and saved history + u_current = copy(u0) + t_current = tspan[1] + usave = [copy(u0)] + tsave = [tspan[1]] + rate_cache = zeros(float(eltype(u0)), numjumps) + counts = zero(rate_cache) du = similar(u0) t_end = tspan[2] epsilon = alg.epsilon # Extract stoichiometry once from MassActionJump - nu = zeros(Int, length(u0), numjumps) + nu = zeros(float(eltype(u0)), length(u0), numjumps) for j in 1:numjumps for (spec_idx, stoich) in maj.net_stoch[j] nu[spec_idx, j] = stoich @@ -145,6 +149,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; end hor = compute_hor(nu) + # Set up saveat_times saveat_times = nothing if isnothing(saveat) saveat_times = Vector{typeof(tspan[1])}() @@ -156,10 +161,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; save_idx = 1 - # Current state for timestepping - u_current = copy(u0) - t_current = tspan[1] - while t_current < t_end rate(rate_cache, u_current, p, t_current) tau = compute_tau_explicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin) @@ -169,9 +170,13 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; end counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) du .= 0 - for j in 1:numjumps - for (spec_idx, stoich) in maj.net_stoch[j] - du[spec_idx] += stoich * counts[j] + if c !== nothing + c(du, u_current, p, t_current, counts, nothing) + else + for j in 1:numjumps + for (spec_idx, stoich) in maj.net_stoch[j] + du[spec_idx] += stoich * counts[j] + end end end u_new = u_current + du @@ -185,8 +190,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; # Save state if at a saveat time or if saveat is empty if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) - push!(u_out, copy(u_new)) - push!(t_out, t_new) + push!(usave, u_new) + push!(tsave, t_new) if !isempty(saveat_times) && t_new >= saveat_times[save_idx] save_idx += 1 end @@ -196,9 +201,9 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; t_current = t_new end - sol = DiffEqBase.build_solution(prob, alg, t_out, u_out, + sol = DiffEqBase.build_solution(prob, alg, tsave, usave, calculate_error=false, - interp=DiffEqBase.ConstantInterpolation(t_out, u_out)) + interp=DiffEqBase.ConstantInterpolation(tsave, usave)) return sol end From 3f90750be76f1d2d9c23d140b175db1949d87022 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 08:41:06 +0530 Subject: [PATCH 19/39] memory optimization --- src/simple_regular_solve.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index afba29fc..afec0857 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -185,7 +185,9 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; tau /= 2 continue end - u_new = max.(u_new, 0) + for i in eachindex(u_new) + u_new[i] = max(u_new[i], 0) + end t_new = t_current + tau # Save state if at a saveat time or if saveat is empty From fb721496ee3307d747b7d8f9a1a3682070535f50 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 08:54:29 +0530 Subject: [PATCH 20/39] validate_pure_leaping_inputs extended for adaptive version --- src/simple_regular_solve.jl | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index afec0857..e97eb455 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -1,5 +1,11 @@ struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end +struct SimpleAdaptiveTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm + epsilon::T # Error control parameter +end + +SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon) + function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) if !(jump_prob.aggregator isa PureLeaping) @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ @@ -14,6 +20,19 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) jump_prob.regular_jump !== nothing end +function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping) + if !(jump_prob.aggregator isa PureLeaping) + @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ + JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \ + Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release." + end + isempty(jump_prob.jump_callback.continuous_callbacks) && + isempty(jump_prob.jump_callback.discrete_callbacks) && + isempty(jump_prob.constant_jumps) && + isempty(jump_prob.variable_jumps) && + jump_prob.massaction_jump !== nothing +end + function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; seed = nothing, dt = error("dt is required for SimpleTauLeaping.")) validate_pure_leaping_inputs(jump_prob, alg) || @@ -61,12 +80,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; interp = DiffEqBase.ConstantInterpolation(t, u)) end -struct SimpleAdaptiveTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm - epsilon::T # Error control parameter -end - -SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon) - function compute_hor(nu) hor = zeros(Int, size(nu, 2)) for j in 1:size(nu, 2) @@ -108,9 +121,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed = nothing, dtmin = 1e-10, saveat = nothing) - if jump_prob.massaction_jump === nothing - error("SimpleAdaptiveTauLeaping requires a JumpProblem with a MassActionJump.") - end + validate_pure_leaping_inputs(jump_prob, alg) || + error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") prob = jump_prob.prob rng = DEFAULT_RNG (seed !== nothing) && seed!(rng, seed) From 7a7232ae3dccaddbfb71ab0b8974d8857264d7be Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 09:06:00 +0530 Subject: [PATCH 21/39] some --- src/simple_regular_solve.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index e97eb455..815ae36e 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -129,14 +129,15 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; maj = jump_prob.massaction_jump numjumps = get_num_majumps(maj) + rj = jump_prob.regular_jump # Extract rates - rate = jump_prob.regular_jump !== nothing ? jump_prob.regular_jump.rate : + rate = rj !== nothing ? rj.rate : (out, u, p, t) -> begin for j in 1:numjumps out[j] = evalrxrate(u, j, maj) end end - c = jump_prob.regular_jump !== nothing ? jump_prob.regular_jump.c : nothing + c = rj !== nothing ? rj.c : nothing u0 = copy(prob.u0) tspan = prob.tspan p = prob.p From fe7cec04424337752566a1a928722bebe55f533f Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 09:27:47 +0530 Subject: [PATCH 22/39] space optimized in compute_tau_explicit --- src/simple_regular_solve.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 815ae36e..db0cfb44 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -100,18 +100,18 @@ end function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin) rate(rate_cache, u, p, t) - mu = zeros(length(u)) - sigma2 = zeros(length(u)) tau = Inf for i in 1:length(u) + mu = zero(eltype(u)) + sigma2 = zero(eltype(u)) for j in 1:size(nu, 2) - mu[i] += nu[i, j] * rate_cache[j] - sigma2[i] += nu[i, j]^2 * rate_cache[j] + mu += nu[i, j] * rate_cache[j] + sigma2 += nu[i, j]^2 * rate_cache[j] end gi = compute_gi(u, nu, hor, i) bound = max(epsilon * u[i] / gi, 1.0) - mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf - sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf + mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf + sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf tau = min(tau, mu_term, sigma_term) end return max(tau, dtmin) From 8e7ff162d617bc215551b66f7979d72e13d2468f Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 11:00:56 +0530 Subject: [PATCH 23/39] computegi and comutehor changes --- src/simple_regular_solve.jl | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index db0cfb44..1d1fd48a 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -83,19 +83,33 @@ end function compute_hor(nu) hor = zeros(Int, size(nu, 2)) for j in 1:size(nu, 2) - hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1 + order = sum(abs(stoich) for stoich in nu[:, j] if stoich < 0; init=0) + if order > 3 + error("Reaction $j has order $order, which is not supported (maximum order is 3).") + end + hor[j] = order end return hor end function compute_gi(u, nu, hor, i) - max_order = 1.0 + max_gi = 1 for j in 1:size(nu, 2) - if abs(nu[i, j]) > 0 - max_order = max(max_order, float(hor[j])) + if nu[i, j] < 0 # Species i is a substrate + if hor[j] == 1 + max_gi = max(max_gi, 1) + elseif hor[j] == 2 || hor[j] == 3 + stoich = abs(nu[i, j]) + if stoich >= 2 + gi = 2 / stoich + 1 / (stoich - 1) + max_gi = max(max_gi, ceil(Int, gi)) + elseif stoich == 1 + max_gi = max(max_gi, hor[j]) + end + end end end - return max_order + return max_gi end function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin) From bc770d16310f3683adadf6e08316ee0d87e76fe8 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Mon, 25 Aug 2025 11:20:19 +0530 Subject: [PATCH 24/39] reactant_stoch in hor --- src/simple_regular_solve.jl | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 1d1fd48a..b2badec3 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -80,10 +80,10 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; interp = DiffEqBase.ConstantInterpolation(t, u)) end -function compute_hor(nu) - hor = zeros(Int, size(nu, 2)) - for j in 1:size(nu, 2) - order = sum(abs(stoich) for stoich in nu[:, j] if stoich < 0; init=0) +function compute_hor(reactant_stoch, numjumps) + hor = zeros(Int, numjumps) + for j in 1:numjumps + order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0) if order > 3 error("Reaction $j has order $order, which is not supported (maximum order is 3).") end @@ -99,11 +99,11 @@ function compute_gi(u, nu, hor, i) if hor[j] == 1 max_gi = max(max_gi, 1) elseif hor[j] == 2 || hor[j] == 3 - stoich = abs(nu[i, j]) - if stoich >= 2 - gi = 2 / stoich + 1 / (stoich - 1) + stoch = abs(nu[i, j]) + if stoch >= 2 + gi = 2 / stoch + 1 / (stoch - 1) max_gi = max(max_gi, ceil(Int, gi)) - elseif stoich == 1 + elseif stoch == 1 max_gi = max(max_gi, hor[j]) end end @@ -167,14 +167,16 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; t_end = tspan[2] epsilon = alg.epsilon - # Extract stoichiometry once from MassActionJump + # Extract stochiometry once from MassActionJump nu = zeros(float(eltype(u0)), length(u0), numjumps) for j in 1:numjumps - for (spec_idx, stoich) in maj.net_stoch[j] - nu[spec_idx, j] = stoich + for (spec_idx, stoch) in maj.net_stoch[j] + nu[spec_idx, j] = stoch end end - hor = compute_hor(nu) + # Extract reactant stochiometry for hor + reactant_stoch = maj.reactant_stoch + hor = compute_hor(reactant_stoch, numjumps) # Set up saveat_times saveat_times = nothing @@ -201,8 +203,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; c(du, u_current, p, t_current, counts, nothing) else for j in 1:numjumps - for (spec_idx, stoich) in maj.net_stoch[j] - du[spec_idx] += stoich * counts[j] + for (spec_idx, stoch) in maj.net_stoch[j] + du[spec_idx] += stoch * counts[j] end end end From b5f77f5514219f90cffdf0c0e16700487d6979a8 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Tue, 26 Aug 2025 11:03:32 +0530 Subject: [PATCH 25/39] compute_gi update --- src/simple_regular_solve.jl | 56 ++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index b2badec3..c665d6cb 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -92,27 +92,44 @@ function compute_hor(reactant_stoch, numjumps) return hor end -function compute_gi(u, nu, hor, i) +function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) + # Initialize reaction_conditions as a vector of vectors of tuples + reaction_conditions = [Vector() for _ in 1:numspecies] + for j in 1:numjumps + for (spec_idx, stoch) in reactant_stoch[j] + if stoch > 0 # Species is a reactant + push!(reaction_conditions[spec_idx], (j, stoch, hor[j])) + end + end + end + return reaction_conditions +end + +function compute_gi(u, reaction_conditions, i) max_gi = 1 - for j in 1:size(nu, 2) - if nu[i, j] < 0 # Species i is a substrate - if hor[j] == 1 - max_gi = max(max_gi, 1) - elseif hor[j] == 2 || hor[j] == 3 - stoch = abs(nu[i, j]) - if stoch >= 2 - gi = 2 / stoch + 1 / (stoch - 1) - max_gi = max(max_gi, ceil(Int, gi)) - elseif stoch == 1 - max_gi = max(max_gi, hor[j]) - end + for (j, nu_ij, hor_j) in reaction_conditions[i] + if hor_j == 1 + max_gi = max(max_gi, 1) + elseif hor_j == 2 + if nu_ij == 1 + max_gi = max(max_gi, 2) + elseif nu_ij >= 2 + gi = u[i] * (2 / nu_ij + 1 / (nu_ij - 1)) + max_gi = max(max_gi, ceil(Int, gi)) + end + elseif hor_j == 3 + if nu_ij == 1 + max_gi = max(max_gi, 3) + elseif nu_ij >= 2 + gi = 1.5 * u[i] * (2 / nu_ij + 1 / (nu_ij - 1)) + max_gi = max(max_gi, ceil(Int, gi)) end end end return max_gi end -function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin) +function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate, dtmin, reaction_conditions, numjumps) rate(rate_cache, u, p, t) tau = Inf for i in 1:length(u) @@ -122,7 +139,7 @@ function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin mu += nu[i, j] * rate_cache[j] sigma2 += nu[i, j]^2 * rate_cache[j] end - gi = compute_gi(u, nu, hor, i) + gi = compute_gi(u, reaction_conditions, i) bound = max(epsilon * u[i] / gi, 1.0) mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf @@ -167,16 +184,17 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; t_end = tspan[2] epsilon = alg.epsilon - # Extract stochiometry once from MassActionJump + # Extract net stoichiometry for state updates nu = zeros(float(eltype(u0)), length(u0), numjumps) for j in 1:numjumps for (spec_idx, stoch) in maj.net_stoch[j] nu[spec_idx, j] = stoch end end - # Extract reactant stochiometry for hor + # Extract reactant stoichiometry for hor and gi reactant_stoch = maj.reactant_stoch hor = compute_hor(reactant_stoch, numjumps) + reaction_conditions = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) # Set up saveat_times saveat_times = nothing @@ -192,7 +210,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; while t_current < t_end rate(rate_cache, u_current, p, t_current) - tau = compute_tau_explicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin) + tau = compute_tau_explicit(u_current, rate_cache, nu, p, t_current, epsilon, rate, dtmin, reaction_conditions, numjumps) tau = min(tau, t_end - t_current) if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] tau = saveat_times[save_idx] - t_current @@ -221,7 +239,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; # Save state if at a saveat time or if saveat is empty if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) - push!(usave, u_new) + push!(usave, copy(u_new)) push!(tsave, t_new) if !isempty(saveat_times) && t_new >= saveat_times[save_idx] save_idx += 1 From 0b72d4c19bc676cd5971625ecb5c45b54e781ea2 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Tue, 26 Aug 2025 11:12:25 +0530 Subject: [PATCH 26/39] added references --- src/simple_regular_solve.jl | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index c665d6cb..c2570189 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -81,6 +81,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; end function compute_hor(reactant_stoch, numjumps) + # Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV. + # HOR is the sum of stoichiometric coefficients of reactants in reaction j. hor = zeros(Int, numjumps) for j in 1:numjumps order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0) @@ -93,7 +95,9 @@ function compute_hor(reactant_stoch, numjumps) end function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) - # Initialize reaction_conditions as a vector of vectors of tuples + # Precompute reaction conditions for each species i, storing reactions j where i is a reactant, + # along with stoichiometry (nu_ij) and HOR (hor_j), to optimize compute_gi. + # Reactant stoichiometry is used per Cao et al. (2006), Section IV, for g_i calculations. reaction_conditions = [Vector() for _ in 1:numspecies] for j in 1:numjumps for (spec_idx, stoch) in reactant_stoch[j] @@ -106,6 +110,14 @@ function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjump end function compute_gi(u, reaction_conditions, i) + # Compute g_i for species i, bounding the relative change in propensity functions, + # as per Cao et al. (2006), Section IV (between equations 27-28). + # g_i is the maximum over all reactions j where species i is a reactant: + # - HOR = 1: g_i = 1 + # - HOR = 2, nu_ij = 1: g_i = 2 + # - HOR = 2, nu_ij >= 2: g_i = x_i * (2/nu_ij + 1/(nu_ij - 1)) + # - HOR = 3, nu_ij = 1: g_i = 3 + # - HOR = 3, nu_ij >= 2: g_i = 1.5 * x_i * (2/nu_ij + 1/(nu_ij - 1)) max_gi = 1 for (j, nu_ij, hor_j) in reaction_conditions[i] if hor_j == 1 @@ -130,20 +142,25 @@ function compute_gi(u, reaction_conditions, i) end function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate, dtmin, reaction_conditions, numjumps) + # Compute the tau-leaping step-size using equation (8) from Cao et al. (2006): + # tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) } + # where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b): + # mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) + # I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified). rate(rate_cache, u, p, t) tau = Inf for i in 1:length(u) mu = zero(eltype(u)) sigma2 = zero(eltype(u)) for j in 1:size(nu, 2) - mu += nu[i, j] * rate_cache[j] - sigma2 += nu[i, j]^2 * rate_cache[j] + mu += nu[i, j] * rate_cache[j] # Equation (9a) + sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b) end gi = compute_gi(u, reaction_conditions, i) - bound = max(epsilon * u[i] / gi, 1.0) - mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf - sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf - tau = min(tau, mu_term, sigma_term) + bound = max(epsilon * u[i] / gi, 1.0) # max(epsilon * x_i / g_i, 1) + mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8) + sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8) + tau = min(tau, mu_term, sigma_term) # Equation (8) end return max(tau, dtmin) end @@ -228,7 +245,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; end u_new = u_current + du if any(<(0), u_new) - # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468) + # Halve tau to avoid negative populations, as per Cao et al. (2006), Section 3.3 tau /= 2 continue end From 5415947d67d2aa75f9fa849e2725e26421c3f11a Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 28 Aug 2025 04:52:01 +0530 Subject: [PATCH 27/39] added unpack --- src/simple_regular_solve.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index c2570189..f8dbee16 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -171,8 +171,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; saveat = nothing) validate_pure_leaping_inputs(jump_prob, alg) || error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") - prob = jump_prob.prob - rng = DEFAULT_RNG + + @unpack prob, rng = jump_prob (seed !== nothing) && seed!(rng, seed) maj = jump_prob.massaction_jump From b39390f329eaa18b771a03488903ea12aa938663 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 28 Aug 2025 04:56:19 +0530 Subject: [PATCH 28/39] test changes --- test/regular_jumps.jl | 45 ++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 25fcca61..7d9f12a2 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -143,82 +143,83 @@ end tspan = (0.0, 10.0) p = [0.1, 0.2] prob = DiscreteProblem(u0, tspan, p) - + # Create MassActionJump reactant_stoich = [[1 => 1], [1 => 2]] net_stoich = [[1 => -1, 2 => 1], [1 => -2, 3 => 1]] rates = [0.1, 0.05] maj = MassActionJump(rates, reactant_stoich, net_stoich) - + # Test PureLeaping JumpProblem creation - jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj)) + jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj); rng) @test jp_pure.aggregator isa PureLeaping @test jp_pure.discrete_jump_aggregation === nothing @test jp_pure.massaction_jump !== nothing @test length(jp_pure.jump_callback.discrete_callbacks) == 0 - + # Test with ConstantRateJump rate(u, p, t) = p[1] * u[1] affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) crj = ConstantRateJump(rate, affect!) - - jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj)) + + jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj); rng) @test jp_pure_crj.aggregator isa PureLeaping @test jp_pure_crj.discrete_jump_aggregation === nothing @test length(jp_pure_crj.constant_jumps) == 1 - + # Test with VariableRateJump vrate(u, p, t) = t * p[1] * u[1] vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) vrj = VariableRateJump(vrate, vaffect!) - - jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj)) + + jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj); rng) @test jp_pure_vrj.aggregator isa PureLeaping @test jp_pure_vrj.discrete_jump_aggregation === nothing @test length(jp_pure_vrj.variable_jumps) == 1 - + # Test with RegularJump function rj_rate(out, u, p, t) out[1] = p[1] * u[1] end - + rj_dc = zeros(3, 1) rj_dc[1, 1] = -1 rj_dc[3, 1] = 1 - + function rj_c(du, u, p, t, counts, mark) mul!(du, rj_dc, counts) end - + regj = RegularJump(rj_rate, rj_c, 1) - - jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj)) + + jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj); rng) @test jp_pure_regj.aggregator isa PureLeaping @test jp_pure_regj.discrete_jump_aggregation === nothing @test jp_pure_regj.regular_jump !== nothing - + # Test mixed jump types mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,), variable_jumps = (vrj,), regular_jumps = regj) - jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps) + jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps; rng) @test jp_pure_mixed.aggregator isa PureLeaping @test jp_pure_mixed.discrete_jump_aggregation === nothing @test jp_pure_mixed.massaction_jump !== nothing @test length(jp_pure_mixed.constant_jumps) == 1 @test length(jp_pure_mixed.variable_jumps) == 1 @test jp_pure_mixed.regular_jump !== nothing - + # Test spatial system error spatial_sys = CartesianGrid((2, 2)) hopping_consts = [1.0] - @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); + @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, spatial_system = spatial_sys) - @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); - hopping_constants = hopping_consts) + @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, + hopping_constants = hopping_consts) + # Test MassActionJump with parameter mapping maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2]) - jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params)) + jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params); rng) scaled_rates = [p[1], p[2]/2] @test jp_params.massaction_jump.scaled_rates == scaled_rates end From 822562f6cdf118e24a3e5f4a13051265cf5c2b9f Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 28 Aug 2025 04:57:38 +0530 Subject: [PATCH 29/39] test changes --- test/regular_jumps.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 7d9f12a2..0f5a53ad 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -214,9 +214,8 @@ end @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, spatial_system = spatial_sys) @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, - hopping_constants = hopping_consts) - + # Test MassActionJump with parameter mapping maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2]) jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params); rng) From b572987cc8edbe9a2622a78fa371ced6146e3652 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 28 Aug 2025 05:00:42 +0530 Subject: [PATCH 30/39] export changes --- src/JumpProcesses.jl | 2 +- src/simple_regular_solve.jl | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 8776cc6f..1ceb8457 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -129,7 +129,7 @@ export SSAStepper # leaping: include("simple_regular_solve.jl") -export SimpleTauLeaping, EnsembleGPUKernel +export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping # spatial: include("spatial/spatial_massaction_jump.jl") diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index f8dbee16..f52a02f2 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -285,5 +285,3 @@ end function EnsembleGPUKernel() EnsembleGPUKernel(nothing, 0.0) end - -export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping From 785266ba91acdadb64951f1cd1b7dfdfd60863d8 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 28 Aug 2025 05:01:46 +0530 Subject: [PATCH 31/39] test changes --- test/regular_jumps.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 0f5a53ad..bed37773 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -215,7 +215,7 @@ end spatial_system = spatial_sys) @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, hopping_constants = hopping_consts) - + # Test MassActionJump with parameter mapping maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2]) jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params); rng) From b47df7c81c5c9014eb07ee831e0c7d5960202c49 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 28 Aug 2025 05:43:54 +0530 Subject: [PATCH 32/39] some change in gi calculation --- src/simple_regular_solve.jl | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index f52a02f2..4f5e0ce6 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -110,14 +110,18 @@ function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjump end function compute_gi(u, reaction_conditions, i) - # Compute g_i for species i, bounding the relative change in propensity functions, + # Compute g_i for species i to bound the relative change in propensity functions, # as per Cao et al. (2006), Section IV (between equations 27-28). # g_i is the maximum over all reactions j where species i is a reactant: - # - HOR = 1: g_i = 1 - # - HOR = 2, nu_ij = 1: g_i = 2 - # - HOR = 2, nu_ij >= 2: g_i = x_i * (2/nu_ij + 1/(nu_ij - 1)) - # - HOR = 3, nu_ij = 1: g_i = 3 - # - HOR = 3, nu_ij >= 2: g_i = 1.5 * x_i * (2/nu_ij + 1/(nu_ij - 1)) + # - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1 + # - HOR = 2 (second-order): + # - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2 + # - nu_ij = 2 (e.g., 2S_i -> products): g_i = x_i * (2/2 + 1/(2-1)) = 2x_i + # - HOR = 3 (third-order): + # - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3 + # - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = 1.5 * x_i * (2/2 + 1/(2-1)) = 3x_i + # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 1.5 * x_i * (2/3 + 1/(3-1)) = 1.75x_i + # Uses precomputed reaction_conditions to optimize checks for HOR = 2 or 3 with nu_ij >= 2. max_gi = 1 for (j, nu_ij, hor_j) in reaction_conditions[i] if hor_j == 1 @@ -126,15 +130,18 @@ function compute_gi(u, reaction_conditions, i) if nu_ij == 1 max_gi = max(max_gi, 2) elseif nu_ij >= 2 + # For nu_ij = 2: g_i = x_i * (2/2 + 1/(2-1)) = 2x_i gi = u[i] * (2 / nu_ij + 1 / (nu_ij - 1)) - max_gi = max(max_gi, ceil(Int, gi)) + max_gi = max(max_gi, ceil(Int64, gi)) end elseif hor_j == 3 if nu_ij == 1 max_gi = max(max_gi, 3) elseif nu_ij >= 2 + # For nu_ij = 2: g_i = 1.5 * x_i * (2/2 + 1/(2-1)) = 3x_i + # For nu_ij = 3: g_i = 1.5 * x_i * (2/3 + 1/(3-1)) = 1.75x_i gi = 1.5 * u[i] * (2 / nu_ij + 1 / (nu_ij - 1)) - max_gi = max(max_gi, ceil(Int, gi)) + max_gi = max(max_gi, ceil(Int64, gi)) end end end From e02d4325f3d7030db97f732ebfa01f3089a913a2 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 28 Aug 2025 05:51:52 +0530 Subject: [PATCH 33/39] changed compute_gi as per paper --- src/simple_regular_solve.jl | 55 +++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 4f5e0ce6..cde3dfdc 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -111,37 +111,46 @@ end function compute_gi(u, reaction_conditions, i) # Compute g_i for species i to bound the relative change in propensity functions, - # as per Cao et al. (2006), Section IV (between equations 27-28). - # g_i is the maximum over all reactions j where species i is a reactant: + # as per Cao et al. (2006), Section IV, equation (27). + # g_i is determined by the highest order of reaction (HOR) where species i is a reactant: # - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1 # - HOR = 2 (second-order): # - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2 - # - nu_ij = 2 (e.g., 2S_i -> products): g_i = x_i * (2/2 + 1/(2-1)) = 2x_i + # - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1) # - HOR = 3 (third-order): # - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3 - # - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = 1.5 * x_i * (2/2 + 1/(2-1)) = 3x_i - # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 1.5 * x_i * (2/3 + 1/(3-1)) = 1.75x_i + # - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1)) + # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2) # Uses precomputed reaction_conditions to optimize checks for HOR = 2 or 3 with nu_ij >= 2. + max_hor = maximum(isempty(reaction_conditions[i]) ? 0 : [hor_j for (j, nu_ij, hor_j) in reaction_conditions[i]]) max_gi = 1 for (j, nu_ij, hor_j) in reaction_conditions[i] - if hor_j == 1 - max_gi = max(max_gi, 1) - elseif hor_j == 2 - if nu_ij == 1 - max_gi = max(max_gi, 2) - elseif nu_ij >= 2 - # For nu_ij = 2: g_i = x_i * (2/2 + 1/(2-1)) = 2x_i - gi = u[i] * (2 / nu_ij + 1 / (nu_ij - 1)) - max_gi = max(max_gi, ceil(Int64, gi)) - end - elseif hor_j == 3 - if nu_ij == 1 - max_gi = max(max_gi, 3) - elseif nu_ij >= 2 - # For nu_ij = 2: g_i = 1.5 * x_i * (2/2 + 1/(2-1)) = 3x_i - # For nu_ij = 3: g_i = 1.5 * x_i * (2/3 + 1/(3-1)) = 1.75x_i - gi = 1.5 * u[i] * (2 / nu_ij + 1 / (nu_ij - 1)) - max_gi = max(max_gi, ceil(Int64, gi)) + if hor_j == max_hor + if hor_j == 1 + max_gi = max(max_gi, 1) + elseif hor_j == 2 + if nu_ij == 1 + max_gi = max(max_gi, 2) + elseif nu_ij == 2 + if u[i] > 1 # Ensure x_i - 1 > 0 + gi = 2 + 1 / (u[i] - 1) + max_gi = max(max_gi, ceil(Int64, gi)) + end + end + elseif hor_j == 3 + if nu_ij == 1 + max_gi = max(max_gi, 3) + elseif nu_ij == 2 + if u[i] > 1 # Ensure x_i - 1 > 0 + gi = 1.5 * (2 + 1 / (u[i] - 1)) + max_gi = max(max_gi, ceil(Int64, gi)) + end + elseif nu_ij == 3 + if u[i] > 2 # Ensure x_i - 2 > 0 + gi = 3 + 1 / (u[i] - 1) + 2 / (u[i] - 2) + max_gi = max(max_gi, ceil(Int64, gi)) + end + end end end end From 092d36101543f834d950215c30eb2940e3f21f96 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 28 Aug 2025 07:06:13 +0530 Subject: [PATCH 34/39] some --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 46f57c63..82631e11 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" From 7217cf0935ac28ffd3a91dbb844d9af8452dcc90 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 28 Aug 2025 07:06:59 +0530 Subject: [PATCH 35/39] some --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 82631e11..46f57c63 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,6 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" From cc3a78a12fc24ab22f3003043da954c5fb16d715 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Fri, 29 Aug 2025 21:27:45 +0530 Subject: [PATCH 36/39] optimized compute_gi --- src/simple_regular_solve.jl | 82 +++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 44 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index cde3dfdc..2e503d79 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -95,24 +95,31 @@ function compute_hor(reactant_stoch, numjumps) end function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) - # Precompute reaction conditions for each species i, storing reactions j where i is a reactant, - # along with stoichiometry (nu_ij) and HOR (hor_j), to optimize compute_gi. - # Reactant stoichiometry is used per Cao et al. (2006), Section IV, for g_i calculations. - reaction_conditions = [Vector() for _ in 1:numspecies] + # Precompute reaction conditions for each species i, including: + # - max_hor: the highest order of reaction (HOR) where species i is a reactant. + # - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor. + # Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27). + max_hor = zeros(Int, numspecies) + max_stoich = zeros(Int, numspecies) for j in 1:numjumps for (spec_idx, stoch) in reactant_stoch[j] if stoch > 0 # Species is a reactant - push!(reaction_conditions[spec_idx], (j, stoch, hor[j])) + if hor[j] > max_hor[spec_idx] + max_hor[spec_idx] = hor[j] + max_stoich[spec_idx] = stoch + elseif hor[j] == max_hor[spec_idx] + max_stoich[spec_idx] = max(max_stoich[spec_idx], stoch) + end end end end - return reaction_conditions + return max_hor, max_stoich end -function compute_gi(u, reaction_conditions, i) +function compute_gi(u, max_hor, max_stoich, i, t) # Compute g_i for species i to bound the relative change in propensity functions, # as per Cao et al. (2006), Section IV, equation (27). - # g_i is determined by the highest order of reaction (HOR) where species i is a reactant: + # g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant: # - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1 # - HOR = 2 (second-order): # - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2 @@ -121,43 +128,30 @@ function compute_gi(u, reaction_conditions, i) # - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3 # - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1)) # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2) - # Uses precomputed reaction_conditions to optimize checks for HOR = 2 or 3 with nu_ij >= 2. - max_hor = maximum(isempty(reaction_conditions[i]) ? 0 : [hor_j for (j, nu_ij, hor_j) in reaction_conditions[i]]) - max_gi = 1 - for (j, nu_ij, hor_j) in reaction_conditions[i] - if hor_j == max_hor - if hor_j == 1 - max_gi = max(max_gi, 1) - elseif hor_j == 2 - if nu_ij == 1 - max_gi = max(max_gi, 2) - elseif nu_ij == 2 - if u[i] > 1 # Ensure x_i - 1 > 0 - gi = 2 + 1 / (u[i] - 1) - max_gi = max(max_gi, ceil(Int64, gi)) - end - end - elseif hor_j == 3 - if nu_ij == 1 - max_gi = max(max_gi, 3) - elseif nu_ij == 2 - if u[i] > 1 # Ensure x_i - 1 > 0 - gi = 1.5 * (2 + 1 / (u[i] - 1)) - max_gi = max(max_gi, ceil(Int64, gi)) - end - elseif nu_ij == 3 - if u[i] > 2 # Ensure x_i - 2 > 0 - gi = 3 + 1 / (u[i] - 1) + 2 / (u[i] - 2) - max_gi = max(max_gi, ceil(Int64, gi)) - end - end - end + # Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep. + if max_hor[i] == 0 # No reactions involve species i as a reactant + return 1.0 + elseif max_hor[i] == 1 + return 1.0 + elseif max_hor[i] == 2 + if max_stoich[i] == 1 + return 2.0 + elseif max_stoich[i] == 2 + return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 # Fallback to 2.0 if x_i <= 1 + end + elseif max_hor[i] == 3 + if max_stoich[i] == 1 + return 3.0 + elseif max_stoich[i] == 2 + return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 # Fallback to 3.0 if x_i <= 1 + elseif max_stoich[i] == 3 + return u[i] > 2 ? 3.0 + 1.0 / (u[i] - 1) + 2.0 / (u[i] - 2) : 3.0 # Fallback to 3.0 if x_i <= 2 end end - return max_gi + return 1.0 # Default case end -function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate, dtmin, reaction_conditions, numjumps) +function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) # Compute the tau-leaping step-size using equation (8) from Cao et al. (2006): # tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) } # where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b): @@ -172,7 +166,7 @@ function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate, dtmin, rea mu += nu[i, j] * rate_cache[j] # Equation (9a) sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b) end - gi = compute_gi(u, reaction_conditions, i) + gi = compute_gi(u, max_hor, max_stoich, i, t) bound = max(epsilon * u[i] / gi, 1.0) # max(epsilon * x_i / g_i, 1) mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8) sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8) @@ -227,7 +221,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; # Extract reactant stoichiometry for hor and gi reactant_stoch = maj.reactant_stoch hor = compute_hor(reactant_stoch, numjumps) - reaction_conditions = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) + max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) # Set up saveat_times saveat_times = nothing @@ -243,7 +237,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; while t_current < t_end rate(rate_cache, u_current, p, t_current) - tau = compute_tau_explicit(u_current, rate_cache, nu, p, t_current, epsilon, rate, dtmin, reaction_conditions, numjumps) + tau = compute_tau_explicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) tau = min(tau, t_end - t_current) if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] tau = saveat_times[save_idx] - t_current From 48fece2e45c958176b6f415c05ea4f969651589b Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sat, 30 Aug 2025 12:42:15 +0530 Subject: [PATCH 37/39] zero rates case for SimpleAdaptiveTauLeaping is added --- src/simple_regular_solve.jl | 5 ++++- test/regular_jumps.jl | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 2e503d79..02119368 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -152,12 +152,15 @@ function compute_gi(u, max_hor, max_stoich, i, t) end function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) - # Compute the tau-leaping step-size using equation (8) from Cao et al. (2006): + # Compute the tau-leaping step-size using equation (20) from Cao et al. (2006): # tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) } # where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b): # mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) # I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified). rate(rate_cache, u, p, t) + if all(==(0), rate_cache) # Handle case where all rates are zero + return dtmin + end tau = Inf for i in 1:length(u) mu = zero(eltype(u)) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index bed37773..33a9a02c 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -136,6 +136,28 @@ end end end +# Test zero-rate case for SimpleAdaptiveTauLeaping +@testset "Zero Rates Test for SimpleAdaptiveTauLeaping" begin + # SIR model: S + I -> 2I, I -> R + reactant_stoch = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]] + net_stoch = [[1=>-1, 2=>1], [2=>-1, 3=>1], []] + rates = [0.1/1000, 0.05, 0.0] # beta/N, gamma, dummy rate for empty reaction + maj = MassActionJump(rates, reactant_stoch, net_stoch) + u0 = [0, 0, 0] # All populations zero + tspan = (0.0, 250.0) + prob = DiscreteProblem(u0, tspan) + jump_prob = JumpProblem(prob, PureLeaping(), maj) + + sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, dtmin = 0.1, saveat=1.0) + + for i in 1:Nsims + # Check that solution completes and covers tspan + @test sol[i].t[end] ≈ 250.0 atol=1e-6 + # Check that state remains zero + @test all(u == [0, 0, 0] for u in sol[i].u) + end +end + # Test PureLeaping aggregator functionality @testset "PureLeaping Aggregator Tests" begin # Test with MassActionJump From 12f84ba669743fb79b2b9d8eb21ec3af6fae8ac4 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Fri, 5 Sep 2025 11:10:06 +0530 Subject: [PATCH 38/39] SimpleExplicitTauLeaping --- src/JumpProcesses.jl | 2 +- src/simple_regular_solve.jl | 11 ++++++----- test/regular_jumps.jl | 22 +++++++++++----------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 1ceb8457..c6488f48 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -129,7 +129,7 @@ export SSAStepper # leaping: include("simple_regular_solve.jl") -export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping +export SimpleTauLeaping, EnsembleGPUKernel, SimpleExplicitTauLeaping # spatial: include("spatial/spatial_massaction_jump.jl") diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 02119368..6af0525b 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -1,10 +1,10 @@ struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end -struct SimpleAdaptiveTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm +struct SimpleExplicitTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm epsilon::T # Error control parameter end -SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon) +SimpleExplicitTauLeaping(; epsilon=0.05) = SimpleExplicitTauLeaping(epsilon) function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) if !(jump_prob.aggregator isa PureLeaping) @@ -20,7 +20,7 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) jump_prob.regular_jump !== nothing end -function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping) +function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping) if !(jump_prob.aggregator isa PureLeaping) @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \ @@ -158,7 +158,7 @@ function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin # mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) # I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified). rate(rate_cache, u, p, t) - if all(==(0), rate_cache) # Handle case where all rates are zero + if all(==(0.0), rate_cache) # Handle case where all rates are zero return dtmin end tau = Inf @@ -178,7 +178,7 @@ function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin return max(tau, dtmin) end -function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; seed = nothing, dtmin = 1e-10, saveat = nothing) @@ -262,6 +262,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; tau /= 2 continue end + # Ensure non-negativity, as per Cao et al. (2006), Section 3.3 for i in eachindex(u_new) u_new[i] = max(u_new[i], 0) end diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 33a9a02c..baf95117 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -47,15 +47,15 @@ Nsims = 1000 # Solve with SimpleTauLeaping sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) - # MassActionJump formulation for SimpleAdaptiveTauLeaping + # MassActionJump formulation for SimpleExplicitTauLeaping reactant_stoich = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]] net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [1=>1]] param_idxs = [1, 2, 3] maj = MassActionJump(reactant_stoich, net_stoich; param_idxs=param_idxs) jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng=rng) - # Solve with SimpleAdaptiveTauLeaping - sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + # Solve with SimpleExplicitTauLeaping + sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) # Compute mean infected (I) trajectories t_points = 0:1.0:250.0 @@ -64,7 +64,7 @@ Nsims = 1000 mean_adaptive_I = [mean(sol_adaptive[i](t)[2] for i in 1:Nsims) for t in t_points] # Test mean infected trajectories - for i in 1:251 + for i in 1:10:251 @test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.05) @test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.05) end @@ -113,15 +113,15 @@ end # Solve with SimpleTauLeaping sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1) - # MassActionJump formulation for SimpleAdaptiveTauLeaping + # MassActionJump formulation for SimpleExplicitTauLeaping reactant_stoich = [[1=>1, 3=>1], [2=>1], [3=>1]] net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [3=>-1, 4=>1]] param_idxs = [1, 2, 3] maj = MassActionJump(reactant_stoich, net_stoich; param_idxs=param_idxs) jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng=rng) - # Solve with SimpleAdaptiveTauLeaping - sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + # Solve with SimpleExplicitTauLeaping + sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) # Compute mean infected (I) trajectories t_points = 0:1.0:250.0 @@ -130,14 +130,14 @@ end mean_adaptive_I = [mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points] # Test mean infected trajectories - for i in 1:251 + for i in 1:10:251 @test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.05) @test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.05) end end -# Test zero-rate case for SimpleAdaptiveTauLeaping -@testset "Zero Rates Test for SimpleAdaptiveTauLeaping" begin +# Test zero-rate case for SimpleExplicitTauLeaping +@testset "Zero Rates Test for SimpleExplicitTauLeaping" begin # SIR model: S + I -> 2I, I -> R reactant_stoch = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]] net_stoch = [[1=>-1, 2=>1], [2=>-1, 3=>1], []] @@ -148,7 +148,7 @@ end prob = DiscreteProblem(u0, tspan) jump_prob = JumpProblem(prob, PureLeaping(), maj) - sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, dtmin = 0.1, saveat=1.0) + sol = solve(EnsembleProblem(jump_prob), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, dtmin = 0.1, saveat=1.0) for i in 1:Nsims # Check that solution completes and covers tspan From 98d64f3445c1874e084b3702522c723900a194cb Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sat, 6 Sep 2025 00:27:06 +0530 Subject: [PATCH 39/39] test update --- src/JumpProcesses.jl | 2 +- src/simple_regular_solve.jl | 4 ++-- test/regular_jumps.jl | 44 +++++++++++++++---------------------- 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index c6488f48..22c51c09 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -129,7 +129,7 @@ export SSAStepper # leaping: include("simple_regular_solve.jl") -export SimpleTauLeaping, EnsembleGPUKernel, SimpleExplicitTauLeaping +export SimpleTauLeaping, SimpleExplicitTauLeaping, EnsembleGPUKernel # spatial: include("spatial/spatial_massaction_jump.jl") diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 6af0525b..78b2a14c 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -151,7 +151,7 @@ function compute_gi(u, max_hor, max_stoich, i, t) return 1.0 # Default case end -function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) +function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) # Compute the tau-leaping step-size using equation (20) from Cao et al. (2006): # tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) } # where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b): @@ -240,7 +240,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; while t_current < t_end rate(rate_cache, u_current, p, t_current) - tau = compute_tau_explicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + tau = compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) tau = min(tau, t_end - t_current) if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] tau = saveat_times[save_idx] - t_current diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index baf95117..8db0566a 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -59,15 +59,12 @@ Nsims = 1000 # Compute mean infected (I) trajectories t_points = 0:1.0:250.0 - mean_direct_I = [mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points] - mean_simple_I = [mean(sol_simple[i](t)[2] for i in 1:Nsims) for t in t_points] - mean_adaptive_I = [mean(sol_adaptive[i](t)[2] for i in 1:Nsims) for t in t_points] - - # Test mean infected trajectories - for i in 1:10:251 - @test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.05) - @test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.05) - end + max_direct_I = maximum([mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_simple_I = maximum([mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_explicit_I = maximum([mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points]) + + @test isapprox(max_direct_I, max_simple_I, rtol=0.05) + @test isapprox(max_direct_I, max_explicit_I, rtol=0.05) end # SEIR model with exposed compartment @@ -125,15 +122,12 @@ end # Compute mean infected (I) trajectories t_points = 0:1.0:250.0 - mean_direct_I = [mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points] - mean_simple_I = [mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_points] - mean_adaptive_I = [mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points] - - # Test mean infected trajectories - for i in 1:10:251 - @test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.05) - @test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.05) - end + max_direct_I = maximum([mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_simple_I = maximum([mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_explicit_I = maximum([mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points]) + + @test isapprox(max_direct_I, max_simple_I, rtol=0.05) + @test isapprox(max_direct_I, max_explicit_I, rtol=0.05) end # Test zero-rate case for SimpleExplicitTauLeaping @@ -148,14 +142,12 @@ end prob = DiscreteProblem(u0, tspan) jump_prob = JumpProblem(prob, PureLeaping(), maj) - sol = solve(EnsembleProblem(jump_prob), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, dtmin = 0.1, saveat=1.0) - - for i in 1:Nsims - # Check that solution completes and covers tspan - @test sol[i].t[end] ≈ 250.0 atol=1e-6 - # Check that state remains zero - @test all(u == [0, 0, 0] for u in sol[i].u) - end + sol = solve(jump_prob, SimpleExplicitTauLeaping(); dtmin = 0.1, saveat=1.0) + + # Check that solution completes and covers tspan + @test sol.t[end] ≈ 250.0 atol=1e-6 + # Check that state remains zero + @test all(u == [0, 0, 0] for u in sol.u) end # Test PureLeaping aggregator functionality