From 4988ee10da32d305e3c74b9cdc0888ff68dcb9e8 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 17 Feb 2025 17:51:46 -0800 Subject: [PATCH 01/14] Early work on the new discrete backend for MTK --- src/systems/clock_inference.jl | 11 ++++++++- src/systems/systems.jl | 7 +++++- src/systems/systemstructure.jl | 43 +++++++++++++++++++++++++--------- 3 files changed, 48 insertions(+), 13 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 42fe28f7c7..86fe7e85ac 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -100,7 +100,7 @@ function infer_clocks!(ci::ClockInference) c = BitSet(c′) idxs = intersect(c, inferred) isempty(idxs) && continue - if !allequal(var_domain[i] for i in idxs) + if !allequal(iscontinuous(var_domain[i]) for i in idxs) display(fullvars[c′]) throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])")) end @@ -155,6 +155,9 @@ function split_system(ci::ClockInference{S}) where {S} cid_to_var = Vector{Int}[] # cid_counter = number of clocks cid_counter = Ref(0) + + # populates clock_to_id and id_to_clock + # checks if there is a continuous_id (for some reason? clock to id does this too) for (i, d) in enumerate(eq_domain) cid = let cid_counter = cid_counter, id_to_clock = id_to_clock, continuous_id = continuous_id @@ -174,9 +177,13 @@ function split_system(ci::ClockInference{S}) where {S} resize_or_push!(cid_to_eq, i, cid) end continuous_id = continuous_id[] + # for each clock partition what are the input (indexes/vars) input_idxs = map(_ -> Int[], 1:cid_counter[]) inputs = map(_ -> Any[], 1:cid_counter[]) + # var_domain corresponds to fullvars/all variables in the system nvv = length(var_domain) + # put variables into the right clock partition + # keep track of inputs to each partition for i in 1:nvv d = var_domain[i] cid = get(clock_to_id, d, 0) @@ -190,6 +197,7 @@ function split_system(ci::ClockInference{S}) where {S} resize_or_push!(cid_to_var, i, cid) end + # breaks the system up into a continous and 0 or more discrete systems tss = similar(cid_to_eq, S) for (id, ieqs) in enumerate(cid_to_eq) ts_i = system_subset(ts, ieqs) @@ -199,6 +207,7 @@ function split_system(ci::ClockInference{S}) where {S} end tss[id] = ts_i end + # put the continous system at the back if continuous_id != 0 tss[continuous_id], tss[end] = tss[end], tss[continuous_id] inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id] diff --git a/src/systems/systems.jl b/src/systems/systems.jl index ff455fb811..c6b9d78f97 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -36,7 +36,7 @@ function mtkcompile( isscheduled(sys) && throw(RepeatedStructuralSimplificationError()) newsys′ = __mtkcompile(sys; simplify, allow_symbolic, allow_parameter, conservative, fully_determined, - inputs, outputs, disturbance_inputs, + inputs, outputs, disturbance_inputs, additional_passes, kwargs...) if newsys′ isa Tuple @assert length(newsys′) == 2 @@ -292,3 +292,8 @@ function map_variables_to_equations(sys::AbstractSystem; rename_dummy_derivative return mapping end + +""" +Mark whether an extra pass `p` can support compiling discrete systems. +""" +discrete_compile_pass(p) = false diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 5dfd36a6fc..a5401c9f9b 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -820,19 +820,40 @@ function mtkcompile!(state::TearingState; simplify = false, time_domains = merge(Dict(state.fullvars .=> ci.var_domain), Dict(default_toterm.(state.fullvars) .=> ci.var_domain)) tss, clocked_inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci) + if continuous_id == 0 + # do a trait check here - handle fully discrete system + additional_passes = get(kwargs, :additional_passes, nothing) + if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) + # take the first discrete compilation pass given for now + discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) + discrete_compile = additional_passes[discrete_pass_idx] + deleteat!(additional_passes, discrete_pass_idx) + return discrete_compile(tss, clocked_inputs) + end + throw(HybridSystemNotSupportedException(""" + Discrete systems with multiple clocks are not supported with the standard \ + MTK compiler. + """)) + end if length(tss) > 1 - if continuous_id == 0 - throw(HybridSystemNotSupportedException(""" - Discrete systems with multiple clocks are not supported with the standard \ - MTK compiler. - """)) - else - throw(HybridSystemNotSupportedException(""" - Hybrid continuous-discrete systems are currently not supported with \ - the standard MTK compiler. This system requires JuliaSimCompiler.jl, \ - see https://help.juliahub.com/juliasimcompiler/stable/ - """)) + # simplify as normal + sys = _mtkcompile!(tss[continuous_id]; simplify, + inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs, + check_consistency, fully_determined, + kwargs...) + if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) + discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) + discrete_compile = additional_passes[discrete_pass_idx] + deleteat!(additional_passes, discrete_pass_idx) + # in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems + # and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed + return discrete_compile(sys, tss[2:end], inputs) end + throw(HybridSystemNotSupportedException(""" + Hybrid continuous-discrete systems are currently not supported with \ + the standard MTK compiler. This system requires JuliaSimCompiler.jl, \ + see https://help.juliahub.com/juliasimcompiler/stable/ + """)) end if get_is_discrete(state.sys) || continuous_id == 1 && any(Base.Fix2(isoperator, Shift), state.fullvars) From ed0612b33c3d112498095a75f08b9ccb6e9d2cef Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 20 Feb 2025 14:52:14 +0530 Subject: [PATCH 02/14] feat: retain original equations of the system in `TearingState` --- src/systems/systems.jl | 4 ++-- src/systems/systemstructure.jl | 21 ++++++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index c6b9d78f97..bdca6ff71a 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -80,7 +80,6 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, @unpack structure, fullvars = state @unpack graph, var_to_diff, var_types = structure - eqs = equations(state) brown_vars = Int[] new_idxs = zeros(Int, length(var_types)) idx = 0 @@ -98,7 +97,8 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, Is = Int[] Js = Int[] vals = Num[] - new_eqs = copy(eqs) + make_eqs_zero_equals!(state) + new_eqs = copy(equations(state)) dvar2eq = Dict{Any, Int}() for (v, dv) in enumerate(var_to_diff) dv === nothing && continue diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index a5401c9f9b..c3ff82478e 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -203,6 +203,7 @@ end mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} """The system of equations.""" sys::T + original_eqs::Vector{Equation} """The set of variables of the system.""" fullvars::Vector{BasicSymbolic} structure::SystemStructure @@ -219,6 +220,7 @@ end TransformationState(sys::AbstractSystem) = TearingState(sys) function system_subset(ts::TearingState, ieqs::Vector{Int}) eqs = equations(ts) + @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.sys.eqs = eqs[ieqs] @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.structure = system_subset(ts.structure, ieqs) @@ -524,7 +526,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) eq_to_diff = DiffGraph(nsrcs(graph)) - ts = TearingState(sys, fullvars, + ts = TearingState(sys, original_eqs, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), complete(graph), nothing, var_types, false), Any[], param_derivative_map, original_eqs, Equation[]) @@ -810,6 +812,22 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure) printstyled(io, " SelectedState") end +function make_eqs_zero_equals!(ts::TearingState) + neweqs = map(enumerate(get_eqs(ts.sys))) do kvp + i, eq = kvp + isalgeq = true + for j in 𝑠neighbors(ts.structure.graph, i) + isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing + end + if isalgeq + return 0 ~ eq.rhs - eq.lhs + else + return eq + end + end + copyto!(get_eqs(ts.sys), neweqs) +end + function mtkcompile!(state::TearingState; simplify = false, check_consistency = true, fully_determined = true, warn_initialize_determined = true, inputs = Any[], outputs = Any[], @@ -836,6 +854,7 @@ function mtkcompile!(state::TearingState; simplify = false, """)) end if length(tss) > 1 + make_eqs_zero_equals!(tss[continuous_id]) # simplify as normal sys = _mtkcompile!(tss[continuous_id]; simplify, inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs, From 7f8b8f260efd0ea7c067baea118c16f40263c091 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 14 Mar 2025 14:50:31 +0530 Subject: [PATCH 03/14] feat: allow namespacing statemachine equations --- src/systems/abstractsystem.jl | 1 + src/systems/state_machines.jl | 33 +++++++++++++++++++++++++++++++++ src/utils.jl | 6 ++++++ 3 files changed, 40 insertions(+) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 70e7b06bfe..34dfb5283e 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1228,6 +1228,7 @@ function namespace_expr( O end end + _nonum(@nospecialize x) = x isa Num ? x.val : x """ diff --git a/src/systems/state_machines.jl b/src/systems/state_machines.jl index 347f92e6f8..ea65981804 100644 --- a/src/systems/state_machines.jl +++ b/src/systems/state_machines.jl @@ -153,3 +153,36 @@ entry When used in a finite state machine, this operator returns `true` if the queried state is active and false otherwise. """ activeState + +function vars!(vars, O::Transition; op = Differential) + vars!(vars, O.from) + vars!(vars, O.to) + vars!(vars, O.cond; op) + return vars +end +function vars!(vars, O::InitialState; op = Differential) + vars!(vars, O.s; op) + return vars +end +function vars!(vars, O::StateMachineOperator; op = Differential) + error("Unhandled state machine operator") +end + +function namespace_expr( + O::Transition, sys, n = nameof(sys); ivs = independent_variables(sys)) + return Transition( + O.from === nothing ? O.from : renamespace(sys, O.from), + O.to === nothing ? O.to : renamespace(sys, O.to), + O.cond === nothing ? O.cond : namespace_expr(O.cond, sys), + O.immediate, O.reset, O.synchronize, O.priority + ) +end + +function namespace_expr( + O::InitialState, sys, n = nameof(sys); ivs = independent_variables(sys)) + return InitialState(O.s === nothing ? O.s : renamespace(sys, O.s)) +end + +function namespace_expr(O::StateMachineOperator, sys, n = nameof(sys); kwargs...) + error("Unhandled state machine operator") +end diff --git a/src/utils.jl b/src/utils.jl index e96f31f533..d028d4ed18 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -391,6 +391,12 @@ vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op) function vars!(vars, eq::Equation; op = Differential) (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) end +function vars!(vars, O::AbstractSystem; op = Differential) + for eq in equations(O) + vars!(vars, eq; op) + end + return vars +end function vars!(vars, O; op = Differential) if isvariable(O) if iscall(O) && operation(O) === getindex && iscalledparameter(first(arguments(O))) From 7cf774d0bd4d355b9f696def2246988b49265d87 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 14 Mar 2025 14:56:00 +0530 Subject: [PATCH 04/14] feat: propagate state machines in structural simplification --- src/systems/systems.jl | 4 +- src/systems/systemstructure.jl | 81 ++++++++++++++++++++++++++++++---- 2 files changed, 76 insertions(+), 9 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index bdca6ff71a..9769d42e96 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -75,8 +75,10 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify) end + sys, statemachines = extract_top_level_statemachines(sys) sys = expand_connections(sys) - state = TearingState(sys; sort_eqs) + state = TearingState(sys) + append!(state.statemachines, statemachines) @unpack structure, fullvars = state @unpack graph, var_to_diff, var_types = structure diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index c3ff82478e..f072dc3f52 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -203,7 +203,6 @@ end mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} """The system of equations.""" sys::T - original_eqs::Vector{Equation} """The set of variables of the system.""" fullvars::Vector{BasicSymbolic} structure::SystemStructure @@ -215,6 +214,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} are not used in the rest of the system. """ additional_observed::Vector{Equation} + statemachines::Vector{T} end TransformationState(sys::AbstractSystem) = TearingState(sys) @@ -224,6 +224,22 @@ function system_subset(ts::TearingState, ieqs::Vector{Int}) @set! ts.sys.eqs = eqs[ieqs] @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.structure = system_subset(ts.structure, ieqs) + if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys)) + names = Symbol[] + for eq in get_eqs(ts.sys) + if eq.lhs isa Transition + push!(names, first(namespace_hierarchy(nameof(eq.rhs.from)))) + push!(names, first(namespace_hierarchy(nameof(eq.rhs.to)))) + elseif eq.lhs isa InitialState + push!(names, first(namespace_hierarchy(nameof(eq.rhs.s)))) + else + error("Unhandled state machine operator") + end + end + @set! ts.statemachines = filter(x -> nameof(x) in names, ts.statemachines) + else + @set! ts.statemachines = eltype(ts.statemachines)[] + end ts end @@ -277,6 +293,49 @@ function symbolic_contains(var, set) all(x -> x in set, Symbolics.scalarize(var)) end +""" + $(TYPEDSIGNATURES) + +Descend through the system hierarchy and look for statemachines. Remove equations from +the inner statemachine systems. Return the new `sys` and an array of top-level +statemachines. +""" +function extract_top_level_statemachines(sys::AbstractSystem) + eqs = get_eqs(sys) + + if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs) + # top-level statemachine + with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys)) + return with_removed, [sys] + elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs) + # error: can't mix + error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.") + else + # descend + subsystems = get_systems(sys) + newsubsystems = eltype(subsystems)[] + statemachines = eltype(subsystems)[] + for subsys in subsystems + newsubsys, sub_statemachines = extract_top_level_statemachines(subsys) + push!(newsubsystems, newsubsys) + append!(statemachines, sub_statemachines) + end + @set! sys.systems = newsubsystems + return sys, statemachines + end +end + +""" + $(TYPEDSIGNATURES) + +Return `sys` with all equations (including those in subsystems) removed. +""" +function remove_child_equations(sys::AbstractSystem) + @set! sys.eqs = eltype(get_eqs(sys))[] + @set! sys.systems = map(remove_child_equations, get_systems(sys)) + return sys +end + function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) # flatten system sys = flatten(sys) @@ -342,9 +401,16 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) # change the equation if the RHS is `missing` so the rest of this loop works eq = 0.0 ~ coalesce(eq.rhs, 0.0) end - rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs - if !_iszero(eq.lhs) + is_statemachine_equation = false + if eq.lhs isa StateMachineOperator + is_statemachine_equation = true + eq = eq + rhs = eq.rhs + elseif _iszero(eq.lhs) + rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs + else lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs + rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs eq = 0 ~ rhs - lhs end empty!(varsbuf) @@ -408,8 +474,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) addvar!(v, VARIABLE) end end - - if isalgeq + if isalgeq || is_statemachine_equation eqs[i] = eq else eqs[i] = eqs[i].lhs ~ rhs @@ -526,11 +591,10 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) eq_to_diff = DiffGraph(nsrcs(graph)) - ts = TearingState(sys, original_eqs, fullvars, + ts = TearingState(sys, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), complete(graph), nothing, var_types, false), - Any[], param_derivative_map, original_eqs, Equation[]) - + Any[], param_derivative_map, original_eqs, Equation[], typeof(sys)[]) return ts end @@ -860,6 +924,7 @@ function mtkcompile!(state::TearingState; simplify = false, inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs, check_consistency, fully_determined, kwargs...) + additional_passes = get(kwargs, :additional_passes, nothing) if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) discrete_compile = additional_passes[discrete_pass_idx] From c081a34b30549b7f2bff9ae5feff11a57067a4ca Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Fri, 14 Mar 2025 18:24:11 -0700 Subject: [PATCH 05/14] Handle nothing updates better --- src/systems/imperative_affect.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 7b1a9fb286..f3d45e258a 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -262,7 +262,9 @@ function compile_functional_affect( upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) # write the new values back to the integrator - _generated_writeback(integ, upd_funs, upd_vals) + if !isnothing(upd_vals) + _generated_writeback(integ, upd_funs, upd_vals) + end reset_jumps && reset_aggregated_jumps!(integ) end From b60be7975ee056f2458340e75332b0f01eb89229 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Fri, 14 Mar 2025 18:24:24 -0700 Subject: [PATCH 06/14] Redefine the discrete_compile interface a bit --- src/systems/systemstructure.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index f072dc3f52..06f8ccc639 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -897,6 +897,9 @@ function mtkcompile!(state::TearingState; simplify = false, inputs = Any[], outputs = Any[], disturbance_inputs = Any[], kwargs...) + # split_system returns one or two systems and the inputs for each + # mod clock inference to be binary + # if it's continous keep going, if not then error unless given trait impl in additional passes ci = ModelingToolkit.ClockInference(state) ci = ModelingToolkit.infer_clocks!(ci) time_domains = merge(Dict(state.fullvars .=> ci.var_domain), @@ -910,7 +913,7 @@ function mtkcompile!(state::TearingState; simplify = false, discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) discrete_compile = additional_passes[discrete_pass_idx] deleteat!(additional_passes, discrete_pass_idx) - return discrete_compile(tss, clocked_inputs) + return discrete_compile(tss, clocked_inputs, ci) end throw(HybridSystemNotSupportedException(""" Discrete systems with multiple clocks are not supported with the standard \ @@ -931,7 +934,7 @@ function mtkcompile!(state::TearingState; simplify = false, deleteat!(additional_passes, discrete_pass_idx) # in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems # and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed - return discrete_compile(sys, tss[2:end], inputs) + return discrete_compile(sys, tss[[i for i in eachindex(tss) if i != continuous_id]], clocked_inputs, ci) end throw(HybridSystemNotSupportedException(""" Hybrid continuous-discrete systems are currently not supported with \ From 405aafabfba1af39bee6f81d6506f004c943dff4 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Wed, 14 May 2025 13:43:00 -0700 Subject: [PATCH 07/14] Change the external synchronous signature to include the id/clock map --- src/systems/systemstructure.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 06f8ccc639..66f3a4b6f9 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -934,7 +934,9 @@ function mtkcompile!(state::TearingState; simplify = false, deleteat!(additional_passes, discrete_pass_idx) # in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems # and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed - return discrete_compile(sys, tss[[i for i in eachindex(tss) if i != continuous_id]], clocked_inputs, ci) + return discrete_compile( + sys, tss[[i for i in eachindex(tss) if i != continuous_id]], + clocked_inputs, ci, id_to_clock) end throw(HybridSystemNotSupportedException(""" Hybrid continuous-discrete systems are currently not supported with \ From 865523b33df9a36698879f85e7e68d32eb659623 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Jun 2025 17:15:07 +0530 Subject: [PATCH 08/14] feat: add `zero_crossing_id` to `SymbolicContinuousCallback` --- src/systems/callbacks.jl | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index a4f39243d9..f05e455cbc 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -165,6 +165,7 @@ struct SymbolicContinuousCallback <: AbstractCallback finalize::Union{Affect, Nothing} rootfind::Union{Nothing, SciMLBase.RootfindOpt} reinitializealg::SciMLBase.DAEInitializationAlgorithm + zero_crossing_id::Symbol function SymbolicContinuousCallback( conditions::Union{Equation, Vector{Equation}}, @@ -174,6 +175,7 @@ struct SymbolicContinuousCallback <: AbstractCallback finalize = nothing, rootfind = SciMLBase.LeftRootFind, reinitializealg = nothing, + zero_crossing_id = gensym(), kwargs...) conditions = (conditions isa AbstractVector) ? conditions : [conditions] @@ -190,7 +192,7 @@ struct SymbolicContinuousCallback <: AbstractCallback make_affect(affect_neg; kwargs...), make_affect(initialize; kwargs...), make_affect( finalize; kwargs...), - rootfind, reinitializealg) + rootfind, reinitializealg, zero_crossing_id) end # Default affect to nothing end @@ -466,7 +468,8 @@ function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuo affect_neg = namespace_affects(affect_negs(cb), s), initialize = namespace_affects(initialize_affects(cb), s), finalize = namespace_affects(finalize_affects(cb), s), - rootfind = cb.rootfind, reinitializealg = cb.reinitializealg) + rootfind = cb.rootfind, reinitializealg = cb.reinitializealg, + zero_crossing_id = cb.zero_crossing_id) end function namespace_conditions(condition, s) @@ -490,6 +493,8 @@ function Base.hash(cb::AbstractCallback, s::UInt) s = hash(finalize_affects(cb), s) !is_discrete(cb) && (s = hash(cb.rootfind, s)) hash(cb.reinitializealg, s) + !is_discrete(cb) && (s = hash(cb.zero_crossing_id, s)) + return s end ########################### @@ -524,13 +529,16 @@ function finalize_affects(cbs::Vector{<:AbstractCallback}) end function Base.:(==)(e1::AbstractCallback, e2::AbstractCallback) - (is_discrete(e1) === is_discrete(e2)) || return false - (isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) && - isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)) && - isequal(e1.reinitializealg, e2.reinitializealg) || - return false - is_discrete(e1) || - (isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)) + is_discrete(e1) === is_discrete(e2) || return false + isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) || return false + isequal(e1.initialize, e2.initialize) || return false + isequal(e1.finalize, e2.finalize) || return false + isequal(e1.reinitializealg, e2.reinitializealg) || return false + if !is_discrete(e1) + isequal(e1.affect_neg, e2.affect_neg) || return false + isequal(e1.rootfind, e2.rootfind) || return false + isequal(e1.zero_crossing_id, e2.zero_crossing_id) || return false + end end Base.isempty(cb::AbstractCallback) = isempty(cb.conditions) From aeefc8abcfd6c00170352459d55f132b6ecfba1a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Jun 2025 17:25:53 +0530 Subject: [PATCH 09/14] feat: add `ZeroCrossing` and `EventClock` from zero crossing --- src/discretedomain.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/discretedomain.jl b/src/discretedomain.jl index da8417de4e..370d93d894 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -365,3 +365,14 @@ function input_timedomain(x) throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression")) end end + +function ZeroCrossing(expr; name = gensym(), up = true, down = true, kwargs...) + return SymbolicContinuousCallback( + [expr ~ 0], up ? ImperativeAffect(Returns(nothing)) : nothing; + affect_neg = down ? ImperativeAffect(Returns(nothing)) : nothing, + kwargs..., zero_crossing_id = name) +end + +function SciMLBase.Clocks.EventClock(cb::SymbolicContinuousCallback) + return SciMLBase.Clocks.EventClock(cb.zero_crossing_id) +end From fbef1a8bd5f27080b8d40c1100c5eb30764b50e0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Jun 2025 13:10:21 +0530 Subject: [PATCH 10/14] feat: subset variables appropriately in clock inference --- src/systems/clock_inference.jl | 4 ++-- src/systems/systemstructure.jl | 30 +++++++++++++++++++++--------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 86fe7e85ac..ff2d77f19b 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -199,8 +199,8 @@ function split_system(ci::ClockInference{S}) where {S} # breaks the system up into a continous and 0 or more discrete systems tss = similar(cid_to_eq, S) - for (id, ieqs) in enumerate(cid_to_eq) - ts_i = system_subset(ts, ieqs) + for (id, (ieqs, ivars)) in enumerate(zip(cid_to_eq, cid_to_var)) + ts_i = system_subset(ts, ieqs, ivars) if id != continuous_id ts_i = shift_discrete_system(ts_i) @set! ts_i.structure.only_discrete = true diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 66f3a4b6f9..3a0e0584d3 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -218,12 +218,12 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} end TransformationState(sys::AbstractSystem) = TearingState(sys) -function system_subset(ts::TearingState, ieqs::Vector{Int}) +function system_subset(ts::TearingState, ieqs::Vector{Int}, ivars::Vector{Int}) eqs = equations(ts) @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.sys.eqs = eqs[ieqs] @set! ts.original_eqs = ts.original_eqs[ieqs] - @set! ts.structure = system_subset(ts.structure, ieqs) + @set! ts.structure = system_subset(ts.structure, ieqs, ivars) if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys)) names = Symbol[] for eq in get_eqs(ts.sys) @@ -240,22 +240,33 @@ function system_subset(ts::TearingState, ieqs::Vector{Int}) else @set! ts.statemachines = eltype(ts.statemachines)[] end + @set! ts.fullvars = ts.fullvars[ivars] ts end -function system_subset(structure::SystemStructure, ieqs::Vector{Int}) - @unpack graph, eq_to_diff = structure +function system_subset(structure::SystemStructure, ieqs::Vector{Int}, ivars::Vector{Int}) + @unpack graph = structure fadj = Vector{Int}[] eq_to_diff = DiffGraph(length(ieqs)) + var_to_diff = DiffGraph(length(ivars)) + ne = 0 + old_to_new_var = zeros(Int, ndsts(graph)) + for (i, iv) in enumerate(ivars) + old_to_new_var[iv] = i + structure.var_to_diff[iv] === nothing && continue + var_to_diff[i] = old_to_new_var[structure.var_to_diff[iv]] + end for (j, eq_i) in enumerate(ieqs) - ivars = copy(graph.fadjlist[eq_i]) - ne += length(ivars) - push!(fadj, ivars) + var_adj = [old_to_new_var[i] for i in graph.fadjlist[eq_i]] + @assert all(!iszero, var_adj) + ne += length(var_adj) + push!(fadj, var_adj) eq_to_diff[j] = structure.eq_to_diff[eq_i] end - @set! structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph))) + @set! structure.graph = complete(BipartiteGraph(ne, fadj, length(ivars))) @set! structure.eq_to_diff = eq_to_diff + @set! structure.var_to_diff = complete(var_to_diff) structure end @@ -440,7 +451,8 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) isdelay(v, iv) && continue if !symbolic_contains(v, dvs) - isvalid = iscall(v) && (operation(v) isa Shift || is_transparent_operator(operation(v))) + isvalid = iscall(v) && + (operation(v) isa Shift || is_transparent_operator(operation(v))) v′ = v while !isvalid && iscall(v′) && operation(v′) isa Union{Differential, Shift} v′ = arguments(v′)[1] From ea917fc060bc6beb695c1a649302162565efa978 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Jun 2025 14:24:24 +0530 Subject: [PATCH 11/14] feat: add hook during problem construction --- src/systems/problem_utils.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index ecba542fd0..a35b3b1663 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1272,6 +1272,8 @@ function get_p_constructor(p_constructor, pType::Type, floatT::Type) end end +abstract type ProblemConstructionHook end + """ $(TYPEDSIGNATURES) @@ -1324,6 +1326,8 @@ function process_SciMLProblem( check_inputmap_keys(sys, op) + op = getmetadata(sys, ProblemConstructionHook, identity)(op) + defs = add_toterms(recursive_unwrap(defaults(sys)); replace = is_discrete_system(sys)) kwargs = NamedTuple(kwargs) From 66d5b0760e382aebb7a8a04fcbbfca8449ce6c77 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 7 Jul 2025 15:40:07 +0530 Subject: [PATCH 12/14] fixup! feat: retain original equations of the system in `TearingState` --- src/systems/systemstructure.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 3a0e0584d3..1a76fc88c0 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -220,7 +220,6 @@ end TransformationState(sys::AbstractSystem) = TearingState(sys) function system_subset(ts::TearingState, ieqs::Vector{Int}, ivars::Vector{Int}) eqs = equations(ts) - @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.sys.eqs = eqs[ieqs] @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.structure = system_subset(ts.structure, ieqs, ivars) From 6e6138087614d0341c9b4be9c7f2237ddd89ff2f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 15:22:42 +0530 Subject: [PATCH 13/14] fix: fix `get_mtkparameters_reconstructor` handling of nonnumerics --- src/systems/problem_utils.jl | 40 ++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index a35b3b1663..1f5c021a3b 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -711,7 +711,8 @@ end $(TYPEDEF) A callable struct which applies `p_constructor` to possibly nested arrays. It also -ensures that views (including nested ones) are concretized. +ensures that views (including nested ones) are concretized. This is implemented manually +of using `narrow_buffer_type` to preserve type-stability. """ struct PConstructorApplicator{F} p_constructor::F @@ -721,10 +722,18 @@ function (pca::PConstructorApplicator)(x::AbstractArray) pca.p_constructor(x) end +function (pca::PConstructorApplicator)(x::AbstractArray{Bool}) + pca.p_constructor(BitArray(x)) +end + function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray) collect(x) end +function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{Bool}) + BitArray(x) +end + function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{<:AbstractArray}) collect(pca.(x)) end @@ -749,6 +758,7 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns """ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem; initials = false, unwrap_initials = false, p_constructor = identity) + _p_constructor = p_constructor p_constructor = PConstructorApplicator(p_constructor) # if we call `getu` on this (and it were able to handle empty tuples) we get the # fields of `MTKParameters` except caches. @@ -802,14 +812,24 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[3]) end - rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf - if buf == () - return Returns(()) - else - return Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, buf) - end + const_getter = if syms[4] == () + Returns(()) + else + Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[4]) end - getters = (tunable_getter, initials_getter, discs_getter, rest_getters...) + nonnumeric_getter = if syms[5] == () + Returns(()) + else + ic = get_index_cache(dstsys) + buftypes = Tuple(map(ic.nonnumeric_buffer_sizes) do bufsize + Vector{bufsize.type} + end) + # nonnumerics retain the assigned buffer type without narrowing + Base.Fix1(broadcast, _p_constructor) ∘ + Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ getu(srcsys, syms[5]) + end + getters = ( + tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter) getter = let getters = getters function _getter(valp, initprob) oldcache = parameter_values(initprob).caches @@ -822,6 +842,10 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac return getter end +function call(f, args...) + f(args...) +end + """ $(TYPEDSIGNATURES) From bcd21af0562582cdb8495453d8790fb902afb345 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 14:23:30 +0530 Subject: [PATCH 14/14] test: test nonnumerics aren't narrowed in `ODEProblem` and `init` --- test/initializationsystem.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 3be3e400c3..8f37ca905b 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1670,3 +1670,23 @@ end prob = ODEProblem(sys, [x[1] => nothing, x[2] => 1], (0.0, 1.0)) @test SciMLBase.initialization_status(prob) == SciMLBase.FULLY_DETERMINED end + +@testset "Nonnumerics aren't narrowed" begin + @mtkmodel Foo begin + @variables begin + x(t) = 1.0 + end + @parameters begin + p::AbstractString + r = 1.0 + end + @equations begin + D(x) ~ r * x + end + end + @mtkbuild sys = Foo(p = "a") + prob = ODEProblem(sys, [], (0.0, 1.0)) + @test prob.p.nonnumeric[1] isa Vector{AbstractString} + integ = init(prob) + @test integ.p.nonnumeric[1] isa Vector{AbstractString} +end