Skip to content

Commit 992f464

Browse files
committed
Improve cast operations on parameters (finally makes generated code type stable)
1 parent 5c45f60 commit 992f464

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

src/Symbolic.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,29 @@ function prependPar(ex::Expr, prefix, parameters=[], inputs=[])
7171
end
7272

7373
"""
74-
e = castToValueType(ex,value)
74+
e = castToFloatType(ex,value)
7575
76-
Cast `ex` to `typeof(value)` if this is a numeric data datatype (`eltype(value) <: Number`).
77-
As a result, the generated code of `ex` will have the correct type instead of `Any`, so will be more efficient.
76+
If manageable, cast `ex` to FloatType, if this is an `AbstractFloat` (`typeof(value) <: AbstractFloat`)
77+
and define it to be FloatType (called _FloatType in getDerivatives), so `FloatType(ex)::FloatType`.
78+
79+
If this is not manageable, cast `ex` to `valueType = typeof(value)` if this is a numeric data datatype
80+
(`eltype(value) <: Number`), and define it to be `valueType`, so `valueType(ex)::valueType`.
81+
82+
As a result, the generated code of `ex` will have the correct type instead of `Any`,
83+
so will be more efficient and no unnecessary memory will be calculated at run-time.
84+
85+
Note, this function should only be used on parameter, init, or start values.
7886
"""
79-
function castToValueType(ex,value)
87+
function castToFloatType(ex,value)
8088
if eltype(value) <: Number
8189
valueType = typeof(value)
82-
:($valueType($ex))
90+
if valueType <: AbstractFloat && !(valueType <: Unitful.AbstractQuantity ||
91+
valueType <: Measurements.Measurement ||
92+
valueType <: MonteCarloMeasurements.AbstractParticles)
93+
:( _FloatType($ex)::_FloatType )
94+
else
95+
:( $valueType($ex)::$valueType )
96+
end
8397
else
8498
ex
8599
end
@@ -96,7 +110,7 @@ Recursively converts der(x) to Symbol(:(der(x))) in expression `ex`
96110
function makeDerVar(ex, parameters, inputs, evaluateParameters=false)
97111
if typeof(ex) in [Symbol, Expr]
98112
if ex in keys(parameters)
99-
castToValueType( prependPar(ex, :(_p), parameters, inputs), parameters[ex] )
113+
castToFloatType( prependPar(ex, :(_p), parameters, inputs), parameters[ex] )
100114
elseif ex in keys(inputs)
101115
prependPar(ex, :(_p), parameters, inputs)
102116
else
@@ -114,7 +128,7 @@ function makeDerVar(ex::Expr, parameters, inputs, evaluateParameters=false)
114128
if evaluateParameters
115129
parameters[ex]
116130
else
117-
castToValueType( prependPar(ex, :(_p), parameters, inputs), parameters[ex] )
131+
castToFloatType( prependPar(ex, :(_p), parameters, inputs), parameters[ex] )
118132
end
119133
elseif isexpr(ex, :.) && ex in keys(inputs)
120134
if evaluateParameters

0 commit comments

Comments
 (0)