
Rewrite more cases of Blockwise IncSubtensor #1560


Open
ricardoV94 wants to merge 4 commits into main from rewrite_blockwise_incsubtensor
Conversation

@ricardoV94 ricardoV94 (Member) commented Jul 29, 2025

Spin-off from #1558: the rewrite that removes Blockwise IncSubtensor is extended to basic IncSubtensor (before it only covered AdvancedIncSubtensor) and to batched indices. This pattern shows up frequently when building vectorized jacobians, so we want to make sure it's optimized away.

We could perhaps do the rewrite eagerly when we call vectorize_node, but since the logic is pretty complex I decided to keep it as a rewrite. The graph with Blockwise is still readable, so this is merely a matter of performance / of enabling other Subtensor rewrites (including inplace!).

PS: We should be able to reuse the arange logic to rewrite away Blockwise Subtensor with batch indices and AdvancedSubtensor.

Except for batched slices, which may not always be vectorizable in a "square manner", we shouldn't ever end up with a Blockwise of a subtensor (or a subtensor update) in the final graph. NumPy indexing is flexible enough to cover any vectorization case, it's just not trivial to write :D
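For illustration, a minimal sketch of how such a Blockwise IncSubtensor arises when vectorizing (assuming the public vectorize_graph helper; variable names are made up):

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.vector("x")
# A basic IncSubtensor: increment a sliced view of x
out = pt.inc_subtensor(x[1:], x[:-1])

# Vectorizing over a new leading batch dimension wraps the core
# IncSubtensor in a Blockwise node; this rewrite removes it again
batch_x = pt.matrix("batch_x")
batch_out = vectorize_graph(out, replace={x: batch_x})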


📚 Documentation preview 📚: https://pytensor--1560.org.readthedocs.build/en/1560/

@ricardoV94 ricardoV94 changed the title Rewrite blockwise incsubtensor Rewrite more cases of Blockwise incsubtensor Jul 29, 2025
@ricardoV94 ricardoV94 force-pushed the rewrite_blockwise_incsubtensor branch 2 times, most recently from 11218cf to 5608a21 on July 29, 2025 11:55
@ricardoV94 ricardoV94 force-pushed the rewrite_blockwise_incsubtensor branch from 5608a21 to 9e9d8ed on July 29, 2025 11:55
@ricardoV94 ricardoV94 changed the title Rewrite more cases of Blockwise incsubtensor Rewrite more cases of Blockwise IncSubtensor Jul 29, 2025

codecov bot commented Jul 29, 2025

Codecov Report

❌ Patch coverage is 98.78049% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 81.54%. Comparing base (b55c473) to head (9e9d8ed).

Files with missing lines Patch % Lines
pytensor/tensor/rewriting/subtensor.py 98.64% 0 Missing and 1 partial ⚠️

@@            Coverage Diff             @@
##             main    #1560      +/-   ##
==========================================
+ Coverage   81.53%   81.54%   +0.01%     
==========================================
  Files         230      230              
  Lines       53079    53099      +20     
  Branches     9425     9432       +7     
==========================================
+ Hits        43279    43302      +23     
+ Misses       7365     7364       -1     
+ Partials     2435     2433       -2     
Files with missing lines Coverage Δ
pytensor/tensor/rewriting/blockwise.py 96.24% <ø> (-0.32%) ⬇️
pytensor/tensor/subtensor.py 89.89% <100.00%> (-0.12%) ⬇️
pytensor/tensor/rewriting/subtensor.py 91.02% <98.64%> (+0.74%) ⬆️

... and 1 file with indirect coverage changes


@ricardoV94 ricardoV94 added the enhancement label Jul 29, 2025
@ricardoV94 ricardoV94 requested a review from lucianopaz August 3, 2025 07:32
@jessegrabowski jessegrabowski (Member) left a comment

Looks solid to me; left a few questions.

if isinstance(core_op, IncSubtensor):
    # For basic IncSubtensor there are two cases:
    # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
    # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
Member

Where are these dummy dimensions coming from?

Member Author

Blockwise adds expand_left dims for all inputs to match the output batch dims, like Elemwise does. That makes other parts of the code easier to reason about.
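A small numpy illustration of those dummy dimensions (illustrative shapes only):

import numpy as np

idx = np.array(2)              # core integer index, shape ()
batched_idx = idx[np.newaxis]  # after Blockwise's expand_left: shape (1,)
# The rewrite squeezes the dummy batch dim away so the entry can be
# used again as a plain integer (or converted back to a slice)
assert batched_idx.squeeze(0) == idx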

counter += 1
else:
    # For AdvancedIncSubtensor we have tensor integer indices;
    # we need to expand batch indices on the right, so they don't interact with core index dimensions
Member

It isn't clear to me why we're expanding right in this case. Is it because we're just reasoning about the batch indexes, and we want to broadcast them with the core indices in the end?

Member Author

Batch indices by definition can't interact with core indices.

In advanced indexing, interaction happens when indices are on the same axes. When they are on different axes they broadcast / act like orthogonal indices (i.e., you visit all combinations of index1 and index2, instead of iterating simultaneously over both).
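A numpy example of that orthogonal behaviour (illustrative values):

import numpy as np

x = np.arange(12).reshape(3, 4)
batch_idx = np.array([0, 2])
core_idx = np.array([1, 3])
# Expanding the batch index on the right puts it on its own axis, so the
# two indices broadcast orthogonally: all combinations are visited
result = x[batch_idx[:, None], core_idx]
assert result.shape == (2, 2)
# With both indices on the same axis they iterate in lockstep instead
assert x[batch_idx, core_idx].shape == (2,)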

# Step 4. Introduce any implicit expand_dims on core dimension of y
missing_y_core_ndim = x_view.type.ndim - y.type.ndim
implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
Member

Is it possible for the squeeze_left to cancel some of the work done by expand_dims here?

@ricardoV94 ricardoV94 (Member Author) commented Aug 3, 2025

I could call the rewrite that combines dimshuffles directly, but that would increase the code complexity a bit. It will be called later anyway.
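A numpy sketch of why the two steps act on disjoint axes (assuming _squeeze_left only drops broadcastable dims left of stop_at_dim):

import numpy as np

batch_ndim = 1
y = np.ones((1, 3))                    # dummy batch dim + one core dim
# expand_dims inserts the implicit core dims at axes >= batch_ndim
y = np.expand_dims(y, axis=batch_ndim)  # -> (1, 1, 3)
# the left-squeeze only touches broadcastable axes < batch_ndim
y = np.squeeze(y, axis=0)               # -> (1, 3)
# Disjoint axis ranges, so no cancellation; at worst we get two
# consecutive dimshuffles that a later rewrite merges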

indices = tuple(
    (
        next(flat_indices_iterator)
        if isinstance(entry, Type)
Member

What is the Type type?

@ricardoV94 ricardoV94 (Member Author) commented Aug 3, 2025

Stuff like TensorType / ScalarType (the type of a variable).

They're used as placeholders in this flat-mapped index structure. It's actually always ScalarType, but I left it as it was.
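For context, a simplified sketch of that convention (rebuild_indices is a hypothetical helper; the real code also handles slices whose components are Type placeholders):

from pytensor.graph.type import Type

def rebuild_indices(idx_list, symbolic_inputs):
    # Type entries (in practice ScalarType) stand in for the op's
    # symbolic index inputs; slices and constants are stored literally
    inputs_iter = iter(symbolic_inputs)
    return tuple(
        next(inputs_iter) if isinstance(entry, Type) else entry
        for entry in idx_list
    )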
