@@ -137,14 +137,17 @@ Check that the element types in `vi` are compatible with the ADType of `context`
137137Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
138138"""
139139function check_adtype (context:: ADTypeCheckContext , vi:: DynamicPPL.AbstractVarInfo )
140+ # Note that `get_param_eltype` will return `Any` with e.g. InitFromPrior or
141+ # InitFromUniform, so this will fail. But on the bright side, you would never _really_
142+ # use AD with those strategies, so that's fine. The cases where you do want to
143+ # use this are DefaultContext (i.e., old, slow, LogDensityFunction) and
144+ # InitFromParams{<:VectorWithRanges} (i.e., new, fast, LogDensityFunction), and
145+ # both of those give you sensible results for `get_param_eltype`.
146+ param_eltype = DynamicPPL. get_param_eltype (vi, context)
140147 valids = valid_eltypes (context)
141- for val in vi[:]
142- valtype = typeof (val)
143- if ! any (valtype .< : valids)
144- throw (IncompatibleADTypeError (valtype, adtype (context)))
145- end
148+ if ! (param_eltype .< : valids)
149+ throw (IncompatibleADTypeError (valtype, adtype (context)))
146150 end
147- return nothing
148151end
149152
150153# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
176179"""
177180All the ADTypes on which we want to run the tests.
178181"""
179- ADTYPES = [AutoForwardDiff (), AutoReverseDiff (; compile= false )]
182+ adtypes = (
183+ AutoForwardDiff (),
184+ AutoReverseDiff (),
185+ # Don't need to test Mooncake as it doesn't use tracer types
186+ )
180187if INCLUDE_MOONCAKE
181188 push! (ADTYPES, AutoMooncake (; config= nothing ))
182189end
0 commit comments