From 44de86189da4c579cad5bd5e70e8cf9636270c04 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sun, 13 Jul 2025 01:04:49 +0530 Subject: [PATCH 1/4] basic kernel --- ext/JumpProcessesKernelAbstractionsExt.jl | 5 +- ext/ssa_stepper.jl | 272 ++++++++++++++++++++++ test/gpu/ssa_test.jl | 29 +++ 3 files changed, 305 insertions(+), 1 deletion(-) create mode 100644 ext/ssa_stepper.jl create mode 100644 test/gpu/ssa_test.jl diff --git a/ext/JumpProcessesKernelAbstractionsExt.jl b/ext/JumpProcessesKernelAbstractionsExt.jl index ae5834606..a726bebb8 100644 --- a/ext/JumpProcessesKernelAbstractionsExt.jl +++ b/ext/JumpProcessesKernelAbstractionsExt.jl @@ -1,8 +1,11 @@ module JumpProcessesKernelAbstractionsExt -using JumpProcesses, SciMLBase +using JumpProcesses, SciMLBase, DiffEqBase using KernelAbstractions, Adapt using StaticArrays +using Random + +include("ssa_stepper.jl") function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem, alg::SimpleTauLeaping, diff --git a/ext/ssa_stepper.jl b/ext/ssa_stepper.jl new file mode 100644 index 000000000..cf1b7f77a --- /dev/null +++ b/ext/ssa_stepper.jl @@ -0,0 +1,272 @@ +# Define a GPU-compatible jump data structure +struct GPUJumpData{RF, AF} + num_jumps::Int + rates::RF + affects::AF +end + +# Helper to convert DirectJumpAggregation into GPUJumpData +function make_gpu_jump_data(agg::JumpProcesses.DirectJumpAggregation) + rates = agg.rates + affects = agg.affects! + return GPUJumpData(length(rates), rates, affects) +end + +# Entry point for solving ensembles on GPU +function SciMLBase.__solve( + ensembleprob::SciMLBase.AbstractEnsembleProblem, + alg::SSAStepper, + ensemblealg::EnsembleGPUKernel; + trajectories, + seed=nothing, + saveat=nothing, + save_everystep=true, + save_start=true, + save_end=true, + max_steps=nothing, + kwargs... +) + if trajectories == 1 + return SciMLBase.__solve(ensembleprob, alg, EnsembleSerial(); + trajectories=1, seed, saveat, save_everystep, save_start, save_end, kwargs...) + end + + prob = ensembleprob.prob + @assert isa(prob, JumpProblem) "Only JumpProblems supported" + @assert isempty(prob.jump_callback.continuous_callbacks) "No continuous callbacks allowed" + @assert prob.prob isa DiscreteProblem "SSAStepper only supports DiscreteProblems" + + # Select backend + backend = ensemblealg.backend === nothing ? CPU() : ensemblealg.backend + probs = [remake(prob) for _ in 1:trajectories] + + # Get aggregation and validate + agg = prob.jump_callback.discrete_callbacks[end].condition + @assert agg isa JumpProcesses.DirectJumpAggregation "Only DirectJumpAggregation is supported" + + # Prepare max_steps estimate + rate_funcs = agg.rates + u0 = prob.prob.u0 + p = prob.prob.p + t0 = prob.prob.tspan[1] + total_rate = sum(rate_func(u0, p, t0) for rate_func in rate_funcs) + max_steps = max_steps === nothing ? Int(ceil(max(1000, prob.prob.tspan[2] * total_rate * 2) + length(saveat isa Number ? collect(prob.prob.tspan[1]:saveat:prob.prob.tspan[2]) : saveat))) : max_steps + @assert max_steps > 0 "max_steps must be positive" + + # Build GPU jump data + rj_data = make_gpu_jump_data(agg) + rj_data_gpu = adapt(backend, GPUJumpData(rj_data.num_jumps, rj_data.rates, rj_data.affects)) + + # Run vectorized Gillespie Direct SSA + ts, us = vectorized_gillespie_direct(probs, prob, alg; backend, trajectories, seed, saveat, save_everystep, save_start, save_end, max_steps, rj_data=rj_data_gpu) + + # Bring results back to CPU + _ts = Array(ts) + _us = Array(us) + + time = @elapsed sol = [begin + ts_view = @view _ts[:, i] + us_view = @view _us[:, :, i] + sol_idx = findlast(!isnan, ts_view) + if sol_idx === nothing + @error "No valid solution for trajectory $i" tspan=probs[i].prob.tspan ts=ts_view + error("Batch solve failed") + end + @views ensembleprob.output_func( + SciMLBase.build_solution( + probs[i].prob, + alg, + ts_view[1:sol_idx], + [SVector{length(us_view[1, :]), eltype(us_view[1, :])}(us_view[j, :]) for j in 1:sol_idx], + k = nothing, + stats = nothing, + calculate_error = false, + retcode = sol_idx < max_steps ? ReturnCode.Success : ReturnCode.Terminated + ), + i)[1] + end for i in eachindex(probs)] + + return SciMLBase.EnsembleSolution(sol, time, true) +end + +# Struct to hold trajectory-specific data +struct TrajectoryDataSSA{U <: StaticArray, P, T} + u0::U + p::P + t_start::T + t_end::T + saveat::Vector{T} +end + +# GPU-compatible random number generation +@inline function exponential_rand(lambda::T, seed::UInt64, idx::Int64) where T + seed = (1103515245 * (seed ⊻ UInt64(idx)) + 12345) % 2^31 + u = Float64(seed) / 2^31 + return -log(u) / lambda +end + +@inline function uniform_rand(seed::UInt64, idx::Int64) + seed = (1103515245 * (seed ⊻ UInt64(idx)) + 12345) % 2^31 + return Float64(seed) / 2^31 +end + +# Main vectorized solver +function vectorized_gillespie_direct(probs, prob::JumpProblem, alg::SSAStepper; + backend, trajectories, seed, saveat, save_everystep, save_start, save_end, max_steps, rj_data) + # Prepare saveat + _saveat = saveat isa Number ? collect(prob.prob.tspan[1]:saveat:prob.prob.tspan[2]) : saveat + _saveat = save_start && _saveat !== nothing && !isempty(_saveat) && _saveat[1] != prob.prob.tspan[1] ? + vcat(prob.prob.tspan[1], _saveat) : _saveat + _saveat = save_end && _saveat !== nothing && !isempty(_saveat) && _saveat[end] != prob.prob.tspan[2] ? + vcat(_saveat, prob.prob.tspan[2]) : _saveat + _saveat = _saveat === nothing ? Float64[] : _saveat + + # Convert to static arrays + probs_data = [TrajectoryDataSSA(SA{eltype(p.prob.u0)}[p.prob.u0...], + p.prob.p, + p.prob.tspan[1], # t_start + p.prob.tspan[2], # t_end + _saveat) for p in probs] + probs_data_gpu = adapt(backend, probs_data) + + state_dim = length(first(probs_data).u0) + num_jumps = rj_data.num_jumps + + # Allocate buffers + ts = allocate(backend, Float64, (max_steps, trajectories)) + us = allocate(backend, Float64, (max_steps, state_dim, trajectories)) + current_u_buf = allocate(backend, Float64, (state_dim, trajectories)) + rate_cache_buf = allocate(backend, Float64, (num_jumps, trajectories)) + + # Initialize current_u_buf with u0 + @kernel function init_buffers_kernel(@Const(probs_data), current_u_buf) + i = @index(Global, Linear) + if i <= size(current_u_buf, 2) + u0 = probs_data[i].u0 + @inbounds for k in 1:length(u0) + current_u_buf[k, i] = u0[k] + end + end + end + init_kernel = init_buffers_kernel(backend) + init_event = init_kernel(probs_data_gpu, current_u_buf; ndrange=trajectories) + synchronize(backend) + + seed_val = seed === nothing ? UInt64(12345) : UInt64(seed) + kernel = gillespie_direct_kernel(backend) + kernel_event = kernel(probs_data_gpu, rj_data, us, ts, current_u_buf, rate_cache_buf, seed_val, max_steps; + ndrange=trajectories) + synchronize(backend) + + return ts, us +end + +# Main Gillespie Direct kernel +@kernel function gillespie_direct_kernel(@Const(prob_data), @Const(rj_data), + us_out, ts_out, current_u_buf, rate_cache_buf, seed::UInt64, max_steps) + i = @index(Global, Linear) + if i <= size(current_u_buf, 2) + current_u = view(current_u_buf, :, i) + rate_cache = view(rate_cache_buf, :, i) + + prob_i = prob_data[i] + u0 = prob_i.u0 + p = prob_i.p + t_start = prob_i.t_start + t_end = prob_i.t_end + saveat = prob_i.saveat + + state_dim = length(u0) + @inbounds for k in 1:state_dim + current_u[k] = u0[k] + end + + t = t_start + step_idx = 1 + saveat_idx = 1 + ts_view = view(ts_out, :, i) + us_view = view(us_out, :, :, i) + + @inbounds for j in 1:max_steps + ts_view[j] = NaN + @inbounds for k in 1:state_dim + us_view[j, k] = NaN + end + end + + ts_view[1] = t + @inbounds for k in 1:state_dim + us_view[1, k] = current_u[k] + end + + while t < t_end && step_idx < max_steps + total_rate = 0.0 + @inbounds for k in 1:rj_data.num_jumps + rate = rj_data.rates[k](current_u, p, t) + rate_cache[k] = max(0.0, rate) + total_rate += rate_cache[k] + end + + if total_rate <= 0.0 + if !isempty(saveat) + while saveat_idx <= length(saveat) && step_idx < max_steps && saveat[saveat_idx] <= t_end + step_idx += 1 + ts_view[step_idx] = saveat[saveat_idx] + @inbounds for k in 1:state_dim + us_view[step_idx, k] = current_u[k] + end + saveat_idx += 1 + end + end + break + end + + delta_t = exponential_rand(total_rate, seed + UInt64(i * max_steps + step_idx), i) + next_t = t + delta_t + + if !isempty(saveat) + while saveat_idx <= length(saveat) && saveat[saveat_idx] <= next_t && step_idx < max_steps + step_idx += 1 + ts_view[step_idx] = saveat[saveat_idx] + @inbounds for k in 1:state_dim + us_view[step_idx, k] = current_u[k] + end + saveat_idx += 1 + end + end + + r = total_rate * uniform_rand(seed + UInt64(i * max_steps + step_idx + 1), i) + cum_rate = 0.0 + jump_idx = 0 + @inbounds for k in 1:rj_data.num_jumps + cum_rate += rate_cache[k] + if r <= cum_rate + jump_idx = k + break + end + end + + if next_t <= t_end && jump_idx > 0 && step_idx < max_steps + t = next_t + mock_integrator = (u=current_u, p=p, t=t) + rj_data.affects[jump_idx](mock_integrator) + step_idx += 1 + ts_view[step_idx] = t + @inbounds for k in 1:state_dim + us_view[step_idx, k] = current_u[k] + end + else + t = t_end + end + end + + while saveat_idx <= length(saveat) && step_idx < max_steps + step_idx += 1 + ts_view[step_idx] = saveat[saveat_idx] + @inbounds for k in 1:state_dim + us_view[step_idx, k] = current_u[k] + end + saveat_idx += 1 + end + end +end \ No newline at end of file diff --git a/test/gpu/ssa_test.jl b/test/gpu/ssa_test.jl new file mode 100644 index 000000000..78485d0b3 --- /dev/null +++ b/test/gpu/ssa_test.jl @@ -0,0 +1,29 @@ +using JumpProcesses, DiffEqBase, SciMLBase, Plots, CUDA +using Test, LinearAlgebra +using StableRNGs +rng = StableRNG(12345) + +rate = (u, p, t) -> u[1] +affect! = function (integrator) + integrator.u[1] += 1 +end +jump = ConstantRateJump(rate, affect!) + +rate = (u, p, t) -> 0.5u[1] +affect! = function (integrator) + integrator.u[1] -= 1 +end +jump2 = ConstantRateJump(rate, affect!) + +prob = DiscreteProblem([10.0], (0.0, 3.0)) +jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) + +integrator = init(jump_prob, SSAStepper()) +step!(integrator) +integrator.u[1] + +# test different saving behaviors + +sol = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleGPUKernel(), + trajectories=100, saveat=1.0) +plot(sol) From 10498451dc6fece219afc9e3a9e10560ad77bb99 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Thu, 17 Jul 2025 20:13:38 +0530 Subject: [PATCH 2/4] basic version for gpu enhanced ssa is working --- ext/ssa_stepper.jl | 147 +++++++++++++++++++++---------------------- test/gpu/ssa_test.jl | 2 +- 2 files changed, 73 insertions(+), 76 deletions(-) diff --git a/ext/ssa_stepper.jl b/ext/ssa_stepper.jl index cf1b7f77a..e33757fbf 100644 --- a/ext/ssa_stepper.jl +++ b/ext/ssa_stepper.jl @@ -1,15 +1,57 @@ -# Define a GPU-compatible jump data structure -struct GPUJumpData{RF, AF} - num_jumps::Int - rates::RF - affects::AF -end - -# Helper to convert DirectJumpAggregation into GPUJumpData -function make_gpu_jump_data(agg::JumpProcesses.DirectJumpAggregation) +function make_gpu_jump_data(agg::JumpProcesses.DirectJumpAggregation, prob::JumpProblem, backend) + num_jumps = length(agg.rates) + state_dim = length(prob.prob.u0) # Get state dimension from DiscreteProblem + p = prob.prob.p + t = prob.prob.tspan[1] rates = agg.rates affects = agg.affects! - return GPUJumpData(length(rates), rates, affects) + + # Initialize arrays + rate_coeffs = zeros(Float64, num_jumps) + affect_increments = zeros(Int64, num_jumps, state_dim) + depend_idx = zeros(Int64, num_jumps) + + # Test point for evaluating rate functions + u_test = ones(Float64, state_dim) + + # Extract rate coefficients and dependency indices + for k in 1:num_jumps + rate_base = rates[k](u_test, p, t) + found_dep = false + for i in 1:state_dim + u_perturbed = copy(u_test) + u_perturbed[i] = 2.0 + rate_perturbed = rates[k](u_perturbed, p, t) + delta_rate = rate_perturbed - rate_base + if abs(delta_rate) > 1e-10 # Detect significant dependence + rate_coeffs[k] = delta_rate / (u_perturbed[i] - u_test[i]) + depend_idx[k] = i + found_dep = true + break + end + end + if !found_dep + rate_coeffs[k] = rate_base # Constant rate (no state dependency) + depend_idx[k] = 1 # Default to first state + end + end + + # Extract affect increments + for k in 1:num_jumps + u = copy(u_test) + mock_integrator = (u=u, p=p, t=t) + affects[k](mock_integrator) + for i in 1:state_dim + affect_increments[k, i] = Int64(u[i] - u_test[i]) + end + end + + # Adapt to GPU + num_jumps = adapt(backend, num_jumps) + rate_coeffs_gpu = adapt(backend, rate_coeffs) + affect_increments_gpu = adapt(backend, affect_increments) + depend_idx_gpu = adapt(backend, depend_idx) + return (num_jumps, rate_coeffs_gpu, affect_increments_gpu, depend_idx_gpu) end # Entry point for solving ensembles on GPU @@ -28,7 +70,7 @@ function SciMLBase.__solve( ) if trajectories == 1 return SciMLBase.__solve(ensembleprob, alg, EnsembleSerial(); - trajectories=1, seed, saveat, save_everystep, save_start, save_end, kwargs...) + trajectories=1, seed, max_steps, kwargs...) end prob = ensembleprob.prob @@ -50,15 +92,13 @@ function SciMLBase.__solve( p = prob.prob.p t0 = prob.prob.tspan[1] total_rate = sum(rate_func(u0, p, t0) for rate_func in rate_funcs) - max_steps = max_steps === nothing ? Int(ceil(max(1000, prob.prob.tspan[2] * total_rate * 2) + length(saveat isa Number ? collect(prob.prob.tspan[1]:saveat:prob.prob.tspan[2]) : saveat))) : max_steps + max_steps = max_steps === nothing ? Int(ceil(max(1000, prob.prob.tspan[2] * total_rate * 2))) : max_steps @assert max_steps > 0 "max_steps must be positive" - # Build GPU jump data - rj_data = make_gpu_jump_data(agg) - rj_data_gpu = adapt(backend, GPUJumpData(rj_data.num_jumps, rj_data.rates, rj_data.affects)) + rj_data = make_gpu_jump_data(agg, prob, backend) + rj_data_gpu = adapt(backend, rj_data) - # Run vectorized Gillespie Direct SSA - ts, us = vectorized_gillespie_direct(probs, prob, alg; backend, trajectories, seed, saveat, save_everystep, save_start, save_end, max_steps, rj_data=rj_data_gpu) + ts, us = vectorized_gillespie_direct(probs, prob, alg; backend, trajectories, seed, max_steps, rj_data=rj_data_gpu) # Bring results back to CPU _ts = Array(ts) @@ -95,7 +135,6 @@ struct TrajectoryDataSSA{U <: StaticArray, P, T} p::P t_start::T t_end::T - saveat::Vector{T} end # GPU-compatible random number generation @@ -112,33 +151,21 @@ end # Main vectorized solver function vectorized_gillespie_direct(probs, prob::JumpProblem, alg::SSAStepper; - backend, trajectories, seed, saveat, save_everystep, save_start, save_end, max_steps, rj_data) - # Prepare saveat - _saveat = saveat isa Number ? collect(prob.prob.tspan[1]:saveat:prob.prob.tspan[2]) : saveat - _saveat = save_start && _saveat !== nothing && !isempty(_saveat) && _saveat[1] != prob.prob.tspan[1] ? - vcat(prob.prob.tspan[1], _saveat) : _saveat - _saveat = save_end && _saveat !== nothing && !isempty(_saveat) && _saveat[end] != prob.prob.tspan[2] ? - vcat(_saveat, prob.prob.tspan[2]) : _saveat - _saveat = _saveat === nothing ? Float64[] : _saveat - - # Convert to static arrays + backend, trajectories, seed, max_steps, rj_data) + num_jumps, rate_coeffs, affect_increments, depend_idx = rj_data # Unpack the tuple probs_data = [TrajectoryDataSSA(SA{eltype(p.prob.u0)}[p.prob.u0...], p.prob.p, - p.prob.tspan[1], # t_start - p.prob.tspan[2], # t_end - _saveat) for p in probs] + p.prob.tspan[1], + p.prob.tspan[2]) for p in probs] probs_data_gpu = adapt(backend, probs_data) state_dim = length(first(probs_data).u0) - num_jumps = rj_data.num_jumps - # Allocate buffers ts = allocate(backend, Float64, (max_steps, trajectories)) us = allocate(backend, Float64, (max_steps, state_dim, trajectories)) current_u_buf = allocate(backend, Float64, (state_dim, trajectories)) rate_cache_buf = allocate(backend, Float64, (num_jumps, trajectories)) - # Initialize current_u_buf with u0 @kernel function init_buffers_kernel(@Const(probs_data), current_u_buf) i = @index(Global, Linear) if i <= size(current_u_buf, 2) @@ -154,7 +181,7 @@ function vectorized_gillespie_direct(probs, prob::JumpProblem, alg::SSAStepper; seed_val = seed === nothing ? UInt64(12345) : UInt64(seed) kernel = gillespie_direct_kernel(backend) - kernel_event = kernel(probs_data_gpu, rj_data, us, ts, current_u_buf, rate_cache_buf, seed_val, max_steps; + kernel_event = kernel(probs_data_gpu, num_jumps, rate_coeffs, affect_increments, depend_idx, us, ts, current_u_buf, rate_cache_buf, seed_val, max_steps; ndrange=trajectories) synchronize(backend) @@ -162,8 +189,9 @@ function vectorized_gillespie_direct(probs, prob::JumpProblem, alg::SSAStepper; end # Main Gillespie Direct kernel -@kernel function gillespie_direct_kernel(@Const(prob_data), @Const(rj_data), - us_out, ts_out, current_u_buf, rate_cache_buf, seed::UInt64, max_steps) +@kernel function gillespie_direct_kernel(@Const(prob_data), @Const(num_jumps), + @Const(rate_coeffs), @Const(affect_increments), + @Const(depend_idx), us_out, ts_out, current_u_buf, rate_cache_buf, seed::UInt64, max_steps) i = @index(Global, Linear) if i <= size(current_u_buf, 2) current_u = view(current_u_buf, :, i) @@ -174,7 +202,6 @@ end p = prob_i.p t_start = prob_i.t_start t_end = prob_i.t_end - saveat = prob_i.saveat state_dim = length(u0) @inbounds for k in 1:state_dim @@ -183,7 +210,6 @@ end t = t_start step_idx = 1 - saveat_idx = 1 ts_view = view(ts_out, :, i) us_view = view(us_out, :, :, i) @@ -201,44 +227,23 @@ end while t < t_end && step_idx < max_steps total_rate = 0.0 - @inbounds for k in 1:rj_data.num_jumps - rate = rj_data.rates[k](current_u, p, t) + @inbounds for k in 1:num_jumps + rate = rate_coeffs[k] * current_u[depend_idx[k]] rate_cache[k] = max(0.0, rate) total_rate += rate_cache[k] end if total_rate <= 0.0 - if !isempty(saveat) - while saveat_idx <= length(saveat) && step_idx < max_steps && saveat[saveat_idx] <= t_end - step_idx += 1 - ts_view[step_idx] = saveat[saveat_idx] - @inbounds for k in 1:state_dim - us_view[step_idx, k] = current_u[k] - end - saveat_idx += 1 - end - end break end delta_t = exponential_rand(total_rate, seed + UInt64(i * max_steps + step_idx), i) next_t = t + delta_t - if !isempty(saveat) - while saveat_idx <= length(saveat) && saveat[saveat_idx] <= next_t && step_idx < max_steps - step_idx += 1 - ts_view[step_idx] = saveat[saveat_idx] - @inbounds for k in 1:state_dim - us_view[step_idx, k] = current_u[k] - end - saveat_idx += 1 - end - end - r = total_rate * uniform_rand(seed + UInt64(i * max_steps + step_idx + 1), i) cum_rate = 0.0 jump_idx = 0 - @inbounds for k in 1:rj_data.num_jumps + @inbounds for k in 1:num_jumps cum_rate += rate_cache[k] if r <= cum_rate jump_idx = k @@ -248,8 +253,9 @@ end if next_t <= t_end && jump_idx > 0 && step_idx < max_steps t = next_t - mock_integrator = (u=current_u, p=p, t=t) - rj_data.affects[jump_idx](mock_integrator) + @inbounds for j in 1:state_dim + current_u[j] += affect_increments[jump_idx, j] + end step_idx += 1 ts_view[step_idx] = t @inbounds for k in 1:state_dim @@ -259,14 +265,5 @@ end t = t_end end end - - while saveat_idx <= length(saveat) && step_idx < max_steps - step_idx += 1 - ts_view[step_idx] = saveat[saveat_idx] - @inbounds for k in 1:state_dim - us_view[step_idx, k] = current_u[k] - end - saveat_idx += 1 - end end -end \ No newline at end of file +end diff --git a/test/gpu/ssa_test.jl b/test/gpu/ssa_test.jl index 78485d0b3..bb8bf87a0 100644 --- a/test/gpu/ssa_test.jl +++ b/test/gpu/ssa_test.jl @@ -24,6 +24,6 @@ integrator.u[1] # test different saving behaviors -sol = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleGPUKernel(), +sol = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleGPUKernel(CUDABackend()), trajectories=100, saveat=1.0) plot(sol) From bde3206a468898eeef6b21223b3f31d6a223ee03 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sun, 20 Jul 2025 02:41:12 +0530 Subject: [PATCH 3/4] some changes in kernel --- ext/ssa_stepper.jl | 295 +++++++++++++++++++++++++-------------------- 1 file changed, 161 insertions(+), 134 deletions(-) diff --git a/ext/ssa_stepper.jl b/ext/ssa_stepper.jl index e33757fbf..ca7aa4740 100644 --- a/ext/ssa_stepper.jl +++ b/ext/ssa_stepper.jl @@ -1,48 +1,56 @@ +# Modified make_gpu_jump_data to handle arbitrary dependencies function make_gpu_jump_data(agg::JumpProcesses.DirectJumpAggregation, prob::JumpProblem, backend) num_jumps = length(agg.rates) - state_dim = length(prob.prob.u0) # Get state dimension from DiscreteProblem + state_dim = length(prob.prob.u0) p = prob.prob.p t = prob.prob.tspan[1] rates = agg.rates affects = agg.affects! - # Initialize arrays - rate_coeffs = zeros(Float64, num_jumps) + # Initialize arrays for affect increments affect_increments = zeros(Int64, num_jumps, state_dim) - depend_idx = zeros(Int64, num_jumps) - - # Test point for evaluating rate functions u_test = ones(Float64, state_dim) - # Extract rate coefficients and dependency indices + # Extract affect increments + for k in 1:num_jumps + u = copy(u_test) + mock_integrator = (u=u, p=p, t=t) + affects[k](mock_integrator) + for i in 1:state_dim + affect_increments[k, i] = Int64(u[i] - u_test[i]) + end + end + + # Analyze rate dependencies + rate_coeffs = zeros(Float64, num_jumps) + depend_indices = Int64[] # Flattened array of dependency indices + depend_starts = zeros(Int64, num_jumps) # Start index for each jump + depend_counts = zeros(Int64, num_jumps) # Number of dependencies per jump + for k in 1:num_jumps rate_base = rates[k](u_test, p, t) - found_dep = false + deps = Int64[] for i in 1:state_dim u_perturbed = copy(u_test) u_perturbed[i] = 2.0 rate_perturbed = rates[k](u_perturbed, p, t) delta_rate = rate_perturbed - rate_base - if abs(delta_rate) > 1e-10 # Detect significant dependence - rate_coeffs[k] = delta_rate / (u_perturbed[i] - u_test[i]) - depend_idx[k] = i - found_dep = true - break + if abs(delta_rate) > 1e-10 + push!(deps, i) end end - if !found_dep - rate_coeffs[k] = rate_base # Constant rate (no state dependency) - depend_idx[k] = 1 # Default to first state - end - end - # Extract affect increments - for k in 1:num_jumps - u = copy(u_test) - mock_integrator = (u=u, p=p, t=t) - affects[k](mock_integrator) - for i in 1:state_dim - affect_increments[k, i] = Int64(u[i] - u_test[i]) + depend_starts[k] = length(depend_indices) + 1 + depend_counts[k] = length(deps) + append!(depend_indices, deps) + + if isempty(deps) + # Constant rate + rate_coeffs[k] = rate_base + else + # Assume polynomial rate: k * prod(u[i] for i in deps) + u_prod = prod(u_test[i] for i in deps) + rate_coeffs[k] = rate_base / u_prod end end @@ -50,113 +58,20 @@ function make_gpu_jump_data(agg::JumpProcesses.DirectJumpAggregation, prob::Jump num_jumps = adapt(backend, num_jumps) rate_coeffs_gpu = adapt(backend, rate_coeffs) affect_increments_gpu = adapt(backend, affect_increments) - depend_idx_gpu = adapt(backend, depend_idx) - return (num_jumps, rate_coeffs_gpu, affect_increments_gpu, depend_idx_gpu) + depend_indices_gpu = adapt(backend, depend_indices) + depend_starts_gpu = adapt(backend, depend_starts) + depend_counts_gpu = adapt(backend, depend_counts) + return (num_jumps, rate_coeffs_gpu, affect_increments_gpu, depend_indices_gpu, depend_starts_gpu, depend_counts_gpu) end -# Entry point for solving ensembles on GPU -function SciMLBase.__solve( - ensembleprob::SciMLBase.AbstractEnsembleProblem, - alg::SSAStepper, - ensemblealg::EnsembleGPUKernel; - trajectories, - seed=nothing, - saveat=nothing, - save_everystep=true, - save_start=true, - save_end=true, - max_steps=nothing, - kwargs... -) - if trajectories == 1 - return SciMLBase.__solve(ensembleprob, alg, EnsembleSerial(); - trajectories=1, seed, max_steps, kwargs...) - end - - prob = ensembleprob.prob - @assert isa(prob, JumpProblem) "Only JumpProblems supported" - @assert isempty(prob.jump_callback.continuous_callbacks) "No continuous callbacks allowed" - @assert prob.prob isa DiscreteProblem "SSAStepper only supports DiscreteProblems" - - # Select backend - backend = ensemblealg.backend === nothing ? CPU() : ensemblealg.backend - probs = [remake(prob) for _ in 1:trajectories] - - # Get aggregation and validate - agg = prob.jump_callback.discrete_callbacks[end].condition - @assert agg isa JumpProcesses.DirectJumpAggregation "Only DirectJumpAggregation is supported" - - # Prepare max_steps estimate - rate_funcs = agg.rates - u0 = prob.prob.u0 - p = prob.prob.p - t0 = prob.prob.tspan[1] - total_rate = sum(rate_func(u0, p, t0) for rate_func in rate_funcs) - max_steps = max_steps === nothing ? Int(ceil(max(1000, prob.prob.tspan[2] * total_rate * 2))) : max_steps - @assert max_steps > 0 "max_steps must be positive" - - rj_data = make_gpu_jump_data(agg, prob, backend) - rj_data_gpu = adapt(backend, rj_data) - - ts, us = vectorized_gillespie_direct(probs, prob, alg; backend, trajectories, seed, max_steps, rj_data=rj_data_gpu) - - # Bring results back to CPU - _ts = Array(ts) - _us = Array(us) - - time = @elapsed sol = [begin - ts_view = @view _ts[:, i] - us_view = @view _us[:, :, i] - sol_idx = findlast(!isnan, ts_view) - if sol_idx === nothing - @error "No valid solution for trajectory $i" tspan=probs[i].prob.tspan ts=ts_view - error("Batch solve failed") - end - @views ensembleprob.output_func( - SciMLBase.build_solution( - probs[i].prob, - alg, - ts_view[1:sol_idx], - [SVector{length(us_view[1, :]), eltype(us_view[1, :])}(us_view[j, :]) for j in 1:sol_idx], - k = nothing, - stats = nothing, - calculate_error = false, - retcode = sol_idx < max_steps ? ReturnCode.Success : ReturnCode.Terminated - ), - i)[1] - end for i in eachindex(probs)] - - return SciMLBase.EnsembleSolution(sol, time, true) -end - -# Struct to hold trajectory-specific data -struct TrajectoryDataSSA{U <: StaticArray, P, T} - u0::U - p::P - t_start::T - t_end::T -end - -# GPU-compatible random number generation -@inline function exponential_rand(lambda::T, seed::UInt64, idx::Int64) where T - seed = (1103515245 * (seed ⊻ UInt64(idx)) + 12345) % 2^31 - u = Float64(seed) / 2^31 - return -log(u) / lambda -end - -@inline function uniform_rand(seed::UInt64, idx::Int64) - seed = (1103515245 * (seed ⊻ UInt64(idx)) + 12345) % 2^31 - return Float64(seed) / 2^31 -end - -# Main vectorized solver +# Modified vectorized_gillespie_direct function vectorized_gillespie_direct(probs, prob::JumpProblem, alg::SSAStepper; - backend, trajectories, seed, max_steps, rj_data) - num_jumps, rate_coeffs, affect_increments, depend_idx = rj_data # Unpack the tuple + backend, trajectories, seed, max_steps, rj_data) + num_jumps, rate_coeffs, affect_increments, depend_indices, depend_starts, depend_counts = rj_data probs_data = [TrajectoryDataSSA(SA{eltype(p.prob.u0)}[p.prob.u0...], - p.prob.p, - p.prob.tspan[1], - p.prob.tspan[2]) for p in probs] + p.prob.p, + p.prob.tspan[1], + p.prob.tspan[2]) for p in probs] probs_data_gpu = adapt(backend, probs_data) state_dim = length(first(probs_data).u0) @@ -181,17 +96,20 @@ function vectorized_gillespie_direct(probs, prob::JumpProblem, alg::SSAStepper; seed_val = seed === nothing ? UInt64(12345) : UInt64(seed) kernel = gillespie_direct_kernel(backend) - kernel_event = kernel(probs_data_gpu, num_jumps, rate_coeffs, affect_increments, depend_idx, us, ts, current_u_buf, rate_cache_buf, seed_val, max_steps; + kernel_event = kernel(probs_data_gpu, num_jumps, rate_coeffs, affect_increments, + depend_indices, depend_starts, depend_counts, + us, ts, current_u_buf, rate_cache_buf, seed_val, max_steps; ndrange=trajectories) synchronize(backend) return ts, us end -# Main Gillespie Direct kernel +# Modified Gillespie Direct kernel for arbitrary dependencies @kernel function gillespie_direct_kernel(@Const(prob_data), @Const(num_jumps), @Const(rate_coeffs), @Const(affect_increments), - @Const(depend_idx), us_out, ts_out, current_u_buf, rate_cache_buf, seed::UInt64, max_steps) + @Const(depend_indices), @Const(depend_starts), @Const(depend_counts), + us_out, ts_out, current_u_buf, rate_cache_buf, seed::UInt64, max_steps) i = @index(Global, Linear) if i <= size(current_u_buf, 2) current_u = view(current_u_buf, :, i) @@ -199,7 +117,6 @@ end prob_i = prob_data[i] u0 = prob_i.u0 - p = prob_i.p t_start = prob_i.t_start t_end = prob_i.t_end @@ -228,12 +145,27 @@ end while t < t_end && step_idx < max_steps total_rate = 0.0 @inbounds for k in 1:num_jumps - rate = rate_coeffs[k] * current_u[depend_idx[k]] + rate = rate_coeffs[k] + start_idx = depend_starts[k] + count = depend_counts[k] + for d in 0:(count-1) + state_idx = depend_indices[start_idx + d] + rate *= current_u[state_idx] + end rate_cache[k] = max(0.0, rate) total_rate += rate_cache[k] end if total_rate <= 0.0 + # Extend trajectory to t_end with constant state + while t < t_end && step_idx < max_steps + step_idx += 1 + t = min(t + 0.1, t_end) # Match saveat interval + ts_view[step_idx] = t + @inbounds for k in 1:state_dim + us_view[step_idx, k] = current_u[k] + end + end break end @@ -254,7 +186,7 @@ end if next_t <= t_end && jump_idx > 0 && step_idx < max_steps t = next_t @inbounds for j in 1:state_dim - current_u[j] += affect_increments[jump_idx, j] + current_u[j] = max(0.0, current_u[j] + affect_increments[jump_idx, j]) # Prevent negative states end step_idx += 1 ts_view[step_idx] = t @@ -263,7 +195,102 @@ end end else t = t_end + # Ensure final state is recorded + if step_idx < max_steps + step_idx += 1 + ts_view[step_idx] = t + @inbounds for k in 1:state_dim + us_view[step_idx, k] = current_u[k] + end + end end end end end + +# Modified SciMLBase.__solve with proper interpolation +function SciMLBase.__solve( + ensembleprob::SciMLBase.AbstractEnsembleProblem, + alg::SSAStepper, + ensemblealg::EnsembleGPUKernel; + trajectories, + seed=nothing, + saveat=0.1, + save_everystep=true, + save_start=true, + save_end=true, + kwargs... +) + if trajectories == 1 + return SciMLBase.__solve(ensembleprob, alg, EnsembleSerial(); + trajectories=1, seed, saveat, kwargs...) + end + + prob = ensembleprob.prob + @assert isa(prob, JumpProblem) "Only JumpProblems supported" + @assert isempty(prob.jump_callback.continuous_callbacks) "No continuous callbacks allowed" + @assert prob.prob isa DiscreteProblem "SSAStepper only supports DiscreteProblems" + + backend = ensemblealg.backend === nothing ? CPU() : ensemblealg.backend + probs = [remake(prob) for _ in 1:trajectories] + + rate_funcs = prob.jump_callback.discrete_callbacks[end].condition.rates + u0 = prob.prob.u0 + p = prob.prob.p + t0 = prob.prob.tspan[1] + total_rate = sum(rate_func(u0, p, t0) for rate_func in rate_funcs) + max_steps = Int(ceil(max(10000, prob.prob.tspan[2] * total_rate * 2))) + @assert max_steps > 0 "max_steps must be positive" + + rj_data = make_gpu_jump_data(prob.jump_callback.discrete_callbacks[end].condition, prob, backend) + rj_data_gpu = adapt(backend, rj_data) + + ts, us = vectorized_gillespie_direct(probs, prob, alg; backend, trajectories, seed, max_steps, rj_data=rj_data_gpu) + + _ts = Array(ts) + _us = Array(us) + + time = @elapsed sol = [begin + ts_view = @view _ts[:, i] + us_view = @view _us[:, :, i] + sol_idx = findlast(!isnan, ts_view) + if sol_idx === nothing + @error "No valid solution for trajectory $i" tspan=probs[i].prob.tspan ts=ts_view + error("Batch solve failed") + end + @views ensembleprob.output_func( + SciMLBase.build_solution( + probs[i].prob, + alg, + ts_view[1:sol_idx], + [SVector{length(us_view[1, :]), eltype(us_view[1, :])}(us_view[j, :]) for j in 1:sol_idx], + k = nothing, + stats = nothing, + calculate_error = false, + retcode = sol_idx < max_steps ? ReturnCode.Success : ReturnCode.Terminated + ), + i)[1] + end for i in eachindex(probs)] + + return SciMLBase.EnsembleSolution(sol, time, true) +end + +# Struct to hold trajectory-specific data +struct TrajectoryDataSSA{U <: StaticArray, P, T} + u0::U + p::P + t_start::T + t_end::T +end + +# GPU-compatible random number generation +@inline function exponential_rand(lambda::T, seed::UInt64, idx::Int64) where T + seed = (1103515245 * (seed ⊻ UInt64(idx)) + 12345) % 2^31 + u = Float64(seed) / 2^31 + return -log(u) / lambda +end + +@inline function uniform_rand(seed::UInt64, idx::Int64) + seed = (1103515245 * (seed ⊻ UInt64(idx)) + 12345) % 2^31 + return Float64(seed) / 2^31 +end \ No newline at end of file From 1646405dfbb68fdf243d68dfcfcd505a681c447b Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sun, 20 Jul 2025 04:34:24 +0530 Subject: [PATCH 4/4] some changes in make_gpu_jump_data --- ext/ssa_stepper.jl | 60 +++++++++++++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/ext/ssa_stepper.jl b/ext/ssa_stepper.jl index ca7aa4740..ff054bfbf 100644 --- a/ext/ssa_stepper.jl +++ b/ext/ssa_stepper.jl @@ -1,5 +1,5 @@ # Modified make_gpu_jump_data to handle arbitrary dependencies -function make_gpu_jump_data(agg::JumpProcesses.DirectJumpAggregation, prob::JumpProblem, backend) +function make_gpu_jump_data(agg::JumpProcesses.DirectJumpAggregation, prob::JumpProblem, backend; user_rate_indices=nothing) num_jumps = length(agg.rates) state_dim = length(prob.prob.u0) p = prob.prob.p @@ -9,7 +9,7 @@ function make_gpu_jump_data(agg::JumpProcesses.DirectJumpAggregation, prob::Jump # Initialize arrays for affect increments affect_increments = zeros(Int64, num_jumps, state_dim) - u_test = ones(Float64, state_dim) + u_test = copy(prob.prob.u0) # Use initial state for realistic affect testing # Extract affect increments for k in 1:num_jumps @@ -17,7 +17,7 @@ function make_gpu_jump_data(agg::JumpProcesses.DirectJumpAggregation, prob::Jump mock_integrator = (u=u, p=p, t=t) affects[k](mock_integrator) for i in 1:state_dim - affect_increments[k, i] = Int64(u[i] - u_test[i]) + affect_increments[k, i] = Int64(round(u[i] - u_test[i])) end end @@ -27,16 +27,25 @@ function make_gpu_jump_data(agg::JumpProcesses.DirectJumpAggregation, prob::Jump depend_starts = zeros(Int64, num_jumps) # Start index for each jump depend_counts = zeros(Int64, num_jumps) # Number of dependencies per jump + # Test points: ones, initial state, and perturbed state + test_points = [ones(Float64, state_dim), prob.prob.u0, 2.0 * ones(Float64, state_dim)] for k in 1:num_jumps - rate_base = rates[k](u_test, p, t) deps = Int64[] - for i in 1:state_dim - u_perturbed = copy(u_test) - u_perturbed[i] = 2.0 - rate_perturbed = rates[k](u_perturbed, p, t) - delta_rate = rate_perturbed - rate_base - if abs(delta_rate) > 1e-10 - push!(deps, i) + rate_base = rates[k](test_points[1], p, t) + is_constant = true + + # Check dependencies across multiple test points + for u_test in test_points + rate_base = rates[k](u_test, p, t) + for i in 1:state_dim + u_perturbed = copy(u_test) + u_perturbed[i] = u_test[i] * 1.5 + 1e-6 # Small perturbation + rate_perturbed = rates[k](u_perturbed, p, t) + delta_rate = rate_perturbed - rate_base + if abs(delta_rate) > 1e-6 * max(abs(rate_base), 1e-6) && !(i in deps) + push!(deps, i) + is_constant = false + end end end @@ -44,13 +53,32 @@ function make_gpu_jump_data(agg::JumpProcesses.DirectJumpAggregation, prob::Jump depend_counts[k] = length(deps) append!(depend_indices, deps) - if isempty(deps) - # Constant rate + if is_constant rate_coeffs[k] = rate_base else - # Assume polynomial rate: k * prod(u[i] for i in deps) - u_prod = prod(u_test[i] for i in deps) - rate_coeffs[k] = rate_base / u_prod + # Compute coefficient assuming rate = k * prod(u[i] for i in deps) + u_test = test_points[2] # Use initial state + rate_base = rates[k](u_test, p, t) + u_prod = prod(u_test[i] for i in deps; init=1.0) + rate_coeffs[k] = u_prod != 0.0 ? rate_base / u_prod : 0.0 + end + end + + # Override with user-provided rate indices if available + if user_rate_indices !== nothing + @assert length(user_rate_indices) == num_jumps "user_rate_indices must match number of jumps" + depend_indices = Int64[] + depend_starts = zeros(Int64, num_jumps) + depend_counts = zeros(Int64, num_jumps) + for k in 1:num_jumps + deps = user_rate_indices[k] + depend_starts[k] = length(depend_indices) + 1 + depend_counts[k] = length(deps) + append!(depend_indices, deps) + u_test = test_points[2] + rate_base = rates[k](u_test, p, t) + u_prod = prod(u_test[i] for i in deps; init=1.0) + rate_coeffs[k] = u_prod != 0.0 ? rate_base / u_prod : 0.0 end end