Skip to content

Commit 9624103

Browse files
committed
implement LogDensityProblems.dimension
1 parent 535ce4f commit 9624103

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/fasteval.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ struct FastLDF{
149149
_iden_varname_ranges::N
150150
_varname_ranges::Dict{VarName,RangeAndLinked}
151151
_adprep::ADP
152+
_dim::Int
152153

153154
function FastLDF(
154155
model::Model,
@@ -159,13 +160,14 @@ struct FastLDF{
159160
# Figure out which variable corresponds to which index, and
160161
# which variables are linked.
161162
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
163+
x = [val for val in varinfo[:]]
164+
dim = length(x)
162165
# Do AD prep if needed
163166
prep = if adtype === nothing
164167
nothing
165168
else
166169
# Make backend-specific tweaks to the adtype
167170
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
168-
x = [val for val in varinfo[:]]
169171
DI.prepare_gradient(
170172
FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
171173
adtype,
@@ -179,7 +181,7 @@ struct FastLDF{
179181
typeof(all_iden_ranges),
180182
typeof(prep),
181183
}(
182-
model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep
184+
model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim
183185
)
184186
end
185187
end
@@ -260,6 +262,10 @@ function LogDensityProblems.logdensity_and_gradient(
260262
)
261263
end
262264

265+
function LogDensityProblems.dimension(fldf::FastLDF)
266+
return fldf._dim
267+
end
268+
263269
######################################################
264270
# Helper functions to extract ranges and link status #
265271
######################################################

0 commit comments

Comments
 (0)