-
Notifications
You must be signed in to change notification settings - Fork 37
Improve FastLDF type stability when all parameters are linked or unlinked #1141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: py/not-experimental
Are you sure you want to change the base?
Conversation
Benchmark Report
Computer InformationBenchmark Results |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## py/not-experimental #1141 +/- ##
=======================================================
- Coverage 77.25% 77.05% -0.20%
=======================================================
Files 40 40
Lines 3706 3731 +25
=======================================================
+ Hits 2863 2875 +12
- Misses 843 856 +13 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
3875d41 to
0a01995
Compare
|
DynamicPPL.jl documentation for PR #1141 is available at: |
177656b to
9310ec0
Compare
b403dce to
072da15
Compare
|
Would something bad happen if we just wrapped all the arrays returned by Bijectors in trivial SubArrays? I did some very crude benchmarks locally and at least |
992cea9 to
759bf8a
Compare
072da15 to
e5a3c97
Compare
759bf8a to
7fa0986
Compare
7fa0986 to
6849bc2
Compare
e5a3c97 to
fa0022b
Compare
6849bc2 to
b1368dd
Compare
The approach used in FastLDF potentially suffers from type stability issues.
One of the issues is just me being stupid: I implemented
fast_evaluate!!quite poorly (one branch would return TSVI, the other branch would return OAVI). This PR fixes that.But there is also a separate, more subtle, issue with using views. For example, this is responsible for failing type stability tests on #1115, which implement the naive solution of adding
@viewthroughout DefaultContext code. It's also (partly) responsible for Enzyme failures on #1139.The crux of the issue is that if you cannot tell whether a parameter is linked or unlinked, then you have to do something like this:
Now, consider
dist = product_distribution([Beta(2, 2), Beta(2, 2)]):and the effects of this transformation on a view:
So, generally when executing this code, if you can't tell whether the parameter is linked ahead of time, you will get a union type. Now running this in Julia itself doesn't affect performance that much because Julia is capable of handling this via union splitting. However, a test like
@inferredin #1115, or Enzyme's analysis, requires stricter type stability.This PR therefore implements special cases for what are by far the two most common use cases, where either all the parameters are linked, or all the parameters are unlinked. This is determined at LogDensityFunction construction time, and passed all the way down into
initvia a type parameter.I am still quite unsure whether there is a real scenario where mixed linked and unlinked variables. I think this was something to do with Gibbs, but if some samplers need linking (e.g. HMC), then surely we can just force all variables to be linked. This would only not be possible if some samplers need to be not linked, but I'm genuinely not sure if there is any sampler that has that property.
However, Gibbs doesn't use LDF, so I am not sure that this is an important consideration for this PR. Even so, there should be no regression in performance for the mixed linked/unlinked case: this PR should just be a strict improvement for the all-linked or all-unlinked case.
Why can't we just store the transform in the LDF?
The transform has to be constructed on-the-fly from
dist, and can't be stored ahead of time because ofBenchmarks (unlinked)
For most of the models that were benchmarked previously, the only real difference is that this PR makes Enzyme quite a bit faster. Still, it's good to verify that for those models, this PR does not cause any regressions.
Here 'before this PR' = #1139, 'after this PR' = this branch, 'v0.38.9' is current main.
The 'problem' with these benchmarks is that those models didn't catch this type stability issue. For a model where the type instability actually kicks in (
demo3here isDynamicPPL.TestUtils.DEMO_MODELS[3], see definition here), this makes a huge difference.Benchmarks (linked)
Here are the same benchmarks but run with linked parameters instead. This is arguably the more important case because HMC/NUTS use this.
Benchmark code