Improve dot lift rewrites #1471
Conversation
Codecov Report
Attention: Patch coverage is 92.62%.
❌ Your patch status has failed because the patch coverage (92.62%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files:
@@ Coverage Diff @@
## main #1471 +/- ##
==========================================
- Coverage 81.51% 81.49% -0.02%
==========================================
Files 232 232
Lines 53033 53122 +89
Branches 9424 9444 +20
==========================================
+ Hits 43229 43292 +63
- Misses 7362 7382 +20
- Partials 2442 2448 +6
Pull Request Overview
This PR extends and simplifies subtensor lifting and matmul-related rewrites to support Blockwise ops, unifies all matmul variants under `_matmul`, and adds tests and performance benchmarks for partial Jacobian computations.
- Extend `local_subtensor_of_dot` and `local_subtensor_of_elemwise` to handle batched/blockwise cases, and add a new `squeeze`-based subtensor lift (see the sketch after this list).
- Unify all matmul-like ops (`matvec`, `vecmat`, `vecdot`, and matrix-matrix) to use a single `_matmul` core and implement batch-to-core-matmul rewrites with optional reshape.
- Add new tests for blockwise subtensor lifts, batched matvec rewrites, and partial Jacobian benchmarks; adjust tolerances and seeds for existing tests.
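The following is a minimal numpy sketch of the subtensor-through-Blockwise lift described in the list above; the names and shapes are illustrative and not taken from the PR's tests or implementation.

```python
# Hedged numpy illustration (not PyTensor's implementation): indexing the batch
# dimension of a batched matmul gives the same result as indexing the batched
# operand first, so the rewrite can push the index below the matmul and avoid
# computing the batch entries that are never used.
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(16, 5, 6))  # batched left operand (batch, m, k)
B = rng.normal(size=(6, 7))      # shared right operand (k, n)

full_then_index = (A @ B)[:2]    # compute all 16 products, then keep 2
index_then_matmul = A[:2] @ B    # compute only the 2 products that are needed
np.testing.assert_allclose(full_then_index, index_then_matmul)
```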
Reviewed Changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.
File | Description
---|---
tests/test_gradient.py | Import `sqrt` and add `test_benchmark_partial_jacobian`
tests/tensor/test_math.py | Fix RNG seed and set `atol` for vector/matrix operation tests
tests/tensor/test_blas.py | Remove xfail markers, add skipif and rename parameters
tests/tensor/rewriting/test_subtensor_lift.py | Rename subtensor-of-elemwise tests, import `Op`, add blockwise tests
tests/tensor/rewriting/test_math.py | Add `test_batch_matvec_to_matmul` parameterized test
tests/tensor/rewriting/test_blas.py | Update imports, skip fast compile mode, adjust rewrite assertions
pytensor/tensor/rewriting/subtensor_lift.py | Enhance `local_subtensor_of_dot` and `local_subtensor_of_batch_dims`, add squeeze lift
pytensor/tensor/rewriting/subtensor.py | Minor cleanup in slice merging and useless-slice rewrites
pytensor/tensor/rewriting/math.py | Replace DimShuffle-through-dot rewrite with unified `_matmul`, reposition specializations
pytensor/tensor/rewriting/linalg.py | Update import of `_matmul` and use in transpose/blockwise rewrites
pytensor/tensor/rewriting/elemwise.py | Simplify upcast-constant rewrite, add `register_stabilize`
pytensor/tensor/rewriting/blas.py | Adjust rewrite positions and batched-dot reshaping logic
pytensor/tensor/math.py | Add dimension check to `Dot22.make_node`, unify matmul variants
Comments suppressed due to low confidence (1)
tests/tensor/rewriting/test_subtensor_lift.py:194
- The test references `tensor3` but it is not imported; add `from pytensor.tensor import tensor3` to the file's imports to avoid a `NameError`.
x = tensor3("x", shape=(7, 5, 11), dtype="float64")
This reduces the number of rewrite passes by avoiding constant folding of cast/expand_dims/alloc.
A new rewrite is added to convert unpaired batched row/column matvec or vec products into equivalent matmul products.
The marked xfail test was failing because Ger wasn't being introduced, not because of the complex dtype.
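As a rough, hedged illustration of the batch-to-core idea behind that rewrite (shapes are illustrative and this is not the rewrite's actual code), a batched matvec against an unbatched vector can be computed as one core product by folding the batch dimension into the row dimension:

```python
# Hedged numpy sketch: fold the batch dimension of a batched matvec into the
# core row dimension, replacing the blockwise product with a single core one.
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(8, 3, 4))  # batched matrix (batch, m, k)
v = rng.normal(size=(4,))       # unbatched vector (k,)

blockwise = A @ v                               # batched matvec, shape (8, 3)
core = (A.reshape(8 * 3, 4) @ v).reshape(8, 3)  # one core matvec + reshape
np.testing.assert_allclose(blockwise, core)
```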
We went over this PR extensively on a call, so this isn't just a snap approve.
This PR was motivated by the partial Jacobian computation example in JAX discussed in jax-ml/jax#5904 (comment).
After #1228 it's actually easier to do this sort of optimization in PyTensor, since there's no Scan to worry about. We already have a bunch of rewrites to lift subtensor operations through elemwise and dot, but we did not have any to lift them through Blockwise (and blockwise dot, aka matmul). This PR addresses that.
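For context, here is a minimal sketch of the partial-Jacobian pattern being discussed. It is an assumed usage, not the PR's benchmark; the expression, shapes, and the `vectorize=True` flag are taken from the surrounding discussion rather than from the PR's code.

```python
# Minimal sketch of the partial-Jacobian pattern: only a slice of the Jacobian
# is requested, so subtensor-lift rewrites can prune the rows that are never
# used. The expression and shapes are illustrative.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.vector("x", shape=(512,))
y = pt.tanh(x) ** 2                 # simple vector -> vector expression

J = jacobian(y, x, vectorize=True)  # vectorized jacobian (no Scan), per the discussion above
J_partial = J[:4]                   # only the first 4 rows are needed

f = pytensor.function([x], J_partial)
f(np.random.default_rng(0).normal(size=512).astype(x.type.dtype))
```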
Some notes on the major changes:
- Do constant folding in Python mode. This is not related to this PR, but I noticed a test was taking 10x longer than the others just because a simple constant-folding operation was being triggered during the rewrites, and the whole C cache was being loaded. That incurs a one-time penalty that's pretty large. For users not interested in the C backend at all, there's no reason to involve that machinery; a single Python eval should be pretty fast anyway. This was moved to Fix `CheckAndRaise` Op C implementation #1521, as it revealed an unrelated bug.
- Simplified `local_upcast_elemwise`. This rewrite was too complex and wasteful, in that it wrapped constants in symbolic expand_dims / alloc + cast. I just do it in numpy directly. This reduces the number of rewrite iterations.
- A bunch of improvements to rewrites, including lifting index operations on the batch dimensions of Blockwise, and expanding the dot subtensor lift (which predates Blockwise) to handle the Blockwise case. The others are self-explanatory.
- Canonicalize matvec, vecmat, and vecdot internally to all use matmul (i.e., Blockwise of the 2D-by-2D dot operation). This makes things simpler for our rewrites, because we only need to worry about one case (see the sketch after this list).
- The pre-existing `local_batched_matmul_to_core_matmul` rewrite was extended to better address cases of batched matvec, vecmat, and vecdot (batch dimensions are moved to the core dimensions). It now moves non-overlapping batch dimensions of both inputs to their core dimensions, and it tries to avoid reshape (needed when combining multiple batch/core dimensions) so that the subtensor-lift rewrites mentioned above can work through them.
- Prioritize gemv/ger, which also makes several xfail tests pass. The cause of those xfails was probably misattributed.
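A hedged numpy sketch of the canonicalization mentioned in the list above (illustrative only, not the rewrite's code): matvec, vecmat, and vecdot can all be expressed as a single matmul on operands with an inserted length-1 dimension, which is why the rewrites only need to handle one case.

```python
# Hedged numpy sketch: matvec, vecmat, and vecdot written as matmul + squeeze.
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4))
x = rng.normal(size=(4,))
y = rng.normal(size=(3,))

matvec = (A @ x[:, None])[:, 0]           # (3, 4) @ (4, 1) -> squeeze last dim
vecmat = (y[None, :] @ A)[0]              # (1, 3) @ (3, 4) -> squeeze first dim
vecdot = (x[None, :] @ x[:, None])[0, 0]  # (1, 4) @ (4, 1) -> squeeze both dims

np.testing.assert_allclose(matvec, A @ x)
np.testing.assert_allclose(vecmat, y @ A)
np.testing.assert_allclose(vecdot, x @ x)
```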
Benchmark result added in the last commit:
(Note that `vectorize=True` goes from underperforming (28 ms) to overperforming (0.37 ms).)
vectorized jacobian code before:
and after:
📚 Documentation preview 📚: https://pytensor--1471.org.readthedocs.build/en/1471/