Skip to content

Commit 7037f9a

Browse files
committed
fix: improve performance of implicit affect
1 parent 15065bb commit 7037f9a

File tree

4 files changed

+26
-32
lines changed

4 files changed

+26
-32
lines changed

src/systems/callbacks.jl

+24-16
Original file line numberDiff line numberDiff line change
@@ -942,32 +942,27 @@ function compile_equational_affect(
942942
dvs_to_access = unknowns(affsys)
943943
ps_to_access = parameters(affsys)
944944

945-
u_getters = [getsym(sys, aff_map[u]) for u in dvs_to_access]
946-
p_getters = [getsym(sys, unPre(p)) for p in ps_to_access]
947-
u_setters = [setsym(sys, u) for u in dvs_to_update]
948-
p_setters = [setsym(sys, p) for p in ps_to_update]
949-
affu_getters = [getsym(affsys, sys_map[u]) for u in dvs_to_update]
950-
affp_getters = [getsym(affsys, sys_map[p]) for p in ps_to_update]
945+
u_getter = getsym(sys, [aff_map[u] for u in dvs_to_access])
946+
p_getter = getsym(sys, [unPre(p) for p in ps_to_access])
947+
u_setter! = setsym(sys, dvs_to_update)
948+
p_setter! = setsym(sys, ps_to_update)
949+
affu_getter = getsym(affsys, [sys_map[u] for u in dvs_to_update])
950+
affp_getter = getsym(affsys, [sys_map[p] for p in ps_to_update])
951951

952952
affprob = ImplicitDiscreteProblem(affsys, [dv => 0 for dv in dvs_to_access],
953953
(0, 0), [p => 0 for p in ps_to_access];
954954
build_initializeprob = false, check_length = false)
955955

956956
function implicit_affect!(integ)
957-
pmap = Pair[p => getp(integ) for (p, getp) in zip(ps_to_access, p_getters)]
958-
u0map = Pair[u => getu(integ)
959-
for (u, getu) in zip(dvs_to_access, u_getters)]
960-
affprob = remake(affprob, u0 = u0map, p = pmap, tspan = (integ.t, integ.t))
957+
new_us = u_getter(integ)
958+
new_ps = p_getter(integ)
959+
affprob = remake(affprob, u0 = new_us, p = new_ps, tspan = (integ.t, integ.t))
961960
affsol = init(affprob, IDSolve())
962961
(check_error(affsol) === ReturnCode.InitialFailure) &&
963962
throw(UnsolvableCallbackError(all_equations(aff)))
964963

965-
for (setu!, getu) in zip(u_setters, affu_getters)
966-
setu!(integ, getu(affsol))
967-
end
968-
for (setp!, getp) in zip(p_setters, affp_getters)
969-
setp!(integ, getp(affsol))
970-
end
964+
u_setter!(integ, affu_getter(affsol))
965+
p_setter!(integ, affp_getter(affsol))
971966
end
972967
end
973968
end
@@ -1080,3 +1075,16 @@ function continuous_events_toplevel(sys::AbstractSystem)
10801075
end
10811076
return get_continuous_events(sys)
10821077
end
1078+
1079+
"""
1080+
Process the symbolic events of a system.
1081+
"""
1082+
function create_symbolic_events(cont_events, disc_events, sys_eqs, iv)
1083+
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
1084+
sys_eqs)
1085+
cont_callbacks = to_cb_vector(cont_events; CB_TYPE = SymbolicContinuousCallback,
1086+
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
1087+
disc_callbacks = to_cb_vector(disc_events; CB_TYPE = SymbolicDiscreteCallback,
1088+
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
1089+
cont_callbacks, disc_callbacks
1090+
end

src/systems/diffeqs/odesystem.jl

+1-7
Original file line numberDiff line numberDiff line change
@@ -318,13 +318,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
318318
throw(ArgumentError("System names must be unique."))
319319
end
320320

321-
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
322-
deqs)
323-
cont_callbacks = to_cb_vector(continuous_events; CB_TYPE = SymbolicContinuousCallback,
324-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
325-
disc_callbacks = to_cb_vector(discrete_events; CB_TYPE = SymbolicDiscreteCallback,
326-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
327-
321+
cont_callbacks, disc_callbacks = create_symbolic_events(continuous_events, discrete_events, deqs, iv)
328322
if is_dde === nothing
329323
is_dde = _check_if_dde(deqs, iv′, systems)
330324
end

src/systems/diffeqs/sdesystem.jl

+1-6
Original file line numberDiff line numberDiff line change
@@ -270,12 +270,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
270270
Wfact = RefValue(EMPTY_JAC)
271271
Wfact_t = RefValue(EMPTY_JAC)
272272

273-
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
274-
deqs)
275-
cont_callbacks = to_cb_vector(continuous_events; CB_TYPE = SymbolicContinuousCallback,
276-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
277-
disc_callbacks = to_cb_vector(discrete_events; CB_TYPE = SymbolicDiscreteCallback,
278-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
273+
cont_callbacks, disc_callbacks = create_symbolic_events(continuous_events, discrete_events, deqs, iv)
279274

280275
if is_dde === nothing
281276
is_dde = _check_if_dde(deqs, iv′, systems)

src/systems/discrete_system/implicit_discrete_system.jl

-3
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,6 @@ function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
298298
for k in collect(keys(u0map))
299299
v = u0map[k]
300300
if !((op = operation(k)) isa Shift)
301-
isnothing(getunshifted(k)) &&
302-
@warn "Initial condition given in term of current state of the unknown. If `build_initializeprob = false`, this may be overridden by the implicit discrete solver."
303-
304301
updated[k] = v
305302
elseif op.steps > 0
306303
error("Initial conditions must be for the current or past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(only(arguments(k)))).")

0 commit comments

Comments
 (0)