Rewrite more cases of Blockwise IncSubtensor #1560
Conversation
Force-pushed from 11218cf to 5608a21 (Compare)
Also cover cases of AdvancedIncSubtensor with batch indices that were not supported before
Force-pushed from 5608a21 to 9e9d8ed (Compare)
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1560 +/- ##
==========================================
+ Coverage 81.53% 81.54% +0.01%
==========================================
Files 230 230
Lines 53079 53099 +20
Branches 9425 9432 +7
==========================================
+ Hits 43279 43302 +23
+ Misses 7365 7364 -1
+ Partials 2435 2433 -2
Looks solid to me, left a few questions
if isinstance(core_op, IncSubtensor):
    # For basic IncSubtensor there are two cases:
    # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
    # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
Where are these dummy dimensions coming from?
Blockwise adds expand_left dims for all inputs to match the output batch dims, like Elemwise does. It makes other parts of the code easier to reason about.
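A rough numpy sketch of those dummy dimensions (illustrative only, not PyTensor code): Blockwise left-pads inputs with length-1 batch dims so every input has the same batch rank, and those padded dims are what the rewrite later squeezes away.

```python
import numpy as np

out_batch_shape = (4, 3)          # batch dims of the Blockwise output
idx = np.array(2)                 # a non-batched integer index: 0 batch dims

# Blockwise-style left padding: prepend length-1 ("dummy") dims until the
# index has as many batch dims as the output
n_missing = len(out_batch_shape) - idx.ndim
idx_padded = idx.reshape((1,) * n_missing + idx.shape)   # shape (1, 1)

# The padded dims carry no information, so the rewrite can squeeze them away
# and treat the entry as a plain integer / slice component again.
```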
            counter += 1
    else:
        # For AdvancedIncSubtensor we have tensor integer indices,
        # We need to expand batch indexes on the right, so they don't interact with core index dimensions
It isn't clear to me why we're expanding right in this case. Is it because we're just reasoning about the batch indexes, and we want to broadcast them with the core indices in the end?
Batch indices by definition can't interact with core indices.
In advanced indexing, interaction happens when indices are on the same axes. When they are on different axes they broadcast / act like orthogonal indices (i.e., you visit all combinations of index1 and index2, instead of iterating simultaneously over both).
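A small numpy example of that distinction (not from the PR): indices on the same axis are iterated together, while indices placed on different axes broadcast orthogonally and visit every combination.

```python
import numpy as np

x = np.arange(16).reshape(4, 4)
rows = np.array([0, 2])
cols = np.array([1, 3])

paired = x[rows, cols]                        # shape (2,): picks x[0, 1] and x[2, 3]
orthogonal = x[rows[:, None], cols[None, :]]  # shape (2, 2): all row/column combinations
```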
# Step 4. Introduce any implicit expand_dims on core dimension of y
missing_y_core_ndim = x_view.type.ndim - y.type.ndim
implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
Is it possible for the squeeze_left to cancel some of the work done by expand_dims here?
I could call the rewrite that combines DimShuffles directly, but that increases the code complexity a bit. It will be called anyway later.
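For what it's worth, a toy numpy emulation of step 4 (my own sketch, not the actual PyTensor helpers): the implicit core axes are inserted at positions >= batch_ndim, while the left-squeeze stops at stop_at_dim=batch_ndim, so it can only drop dummy batch dims, not the freshly added core dims.

```python
import numpy as np

batch_ndim = 2
y = np.ones((1, 5, 3))            # dummy batch dim, real batch dim, one core dim
implicit_axes = (batch_ndim,)     # y is missing one core dim of x_view

y_expanded = np.expand_dims(y, implicit_axes)   # shape (1, 5, 1, 3)

# emulate _squeeze_left(..., stop_at_dim=batch_ndim): drop leading length-1
# axes, but never go past the batch dimensions
n_drop = 0
for size in y_expanded.shape[:batch_ndim]:
    if size != 1:
        break
    n_drop += 1
y_final = y_expanded.reshape(y_expanded.shape[n_drop:])  # shape (5, 1, 3)
```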
indices = tuple(
    (
        next(flat_indices_iterator)
        if isinstance(entry, Type)
What is Type type?
Stuff like TensorType / ScalarType (the type of a variable).
They use those in this flat map thing. Actually it's always ScalarType, but I left it as it was.
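For context, a hedged sketch of where those Type entries live (my reading of the Subtensor machinery, not part of this PR): basic (Inc)Subtensor stores the static index structure in idx_list, with scalar Type placeholders standing in for dynamic indices whose values are passed as extra node inputs, which is what the flat iterator in the snippet above walks over.

```python
import pytensor.tensor as pt

x = pt.matrix("x")
i = pt.iscalar("i")
view = x[1:, i]   # basic Subtensor: one slice entry, one dynamic scalar index

# idx_list mirrors the index structure; the dynamic index shows up as a
# scalar Type placeholder rather than as a value
print(view.owner.op.idx_list)
# the actual index value travels as a node input, after x itself
print(view.owner.inputs)
```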
Spin-off from #1558, with the rewrite that removes Blockwise IncSubtensor extended to basic IncSubtensor (before it only covered AdvancedIncSubtensor) and to batch indices. This shows up frequently when building vectorized jacobians, so we want to make sure it's optimized away.
We could perhaps do the rewrite eagerly when we call vectorize_node, but since the logic is pretty complex I decided to keep it in a rewrite. The graph with Blockwise is still readable, so it's merely a matter of performance / enabling other Subtensor rewrites (including inplace!).

PS: We should be able to reuse the arange logic to rewrite away Blockwise Subtensor with batch indices and AdvancedSubtensor.
Except for batched slices, which may not always be vectorizable in a "square manner", we shouldn't ever end up with a Blockwise of a subtensor (or a subtensor update) in the final graph. NumPy indexing is flexible enough to cover any vectorization case; it's just not trivial to write it :D
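As a rough illustration of how such graphs arise in the first place (hypothetical snippet, not from the PR): vectorizing a set_subtensor over a new batch dimension of x can leave a Blockwise-wrapped IncSubtensor in the graph, which this rewrite then turns back into a plain indexing update during rewriting.

```python
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.vector("x")
y = pt.scalar("y")
out = pt.set_subtensor(x[0], y)     # core graph: IncSubtensor(set_instead_of_inc=True)

batch_x = pt.matrix("batch_x")      # give x a leading batch dimension
batch_out = vectorize_graph(out, replace={x: batch_x})
# Depending on the PyTensor version, batch_out may be computed by a Blockwise
# wrapping the core IncSubtensor; the rewrite in this PR replaces such nodes
# with ordinary (Advanced)IncSubtensor so the usual Subtensor rewrites
# (including inplace ones) can apply.
```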
📚 Documentation preview 📚: https://pytensor--1560.org.readthedocs.build/en/1560/