Skip to content

Rewrite solves involving kron to eliminate kron #1559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jul 28, 2025

Description

Rewrite graphs of the form solve(kron(A, B), x) to solve(A, solve(B, x.reshape).mT).mT.reshape. This eliminates the kronecker product, and provides significant speedup.

Important limitation is that it only covers the case when b_ndim=1, because the math underpinning the rewrite requires that x is a vector. This is still an important case, however, because it's what arises in the logp of a multivariate normal when the covariance matrix is kronecker.

Also I hit what appears to be a numerical bug in the batch case when assume_a = 'pos'. There is disagreement, but only in the 2nd row of the outputs. No matter the batch size, it's always the 2nd batch that has a numerical problem -- all other batches agree. I've left in the failing test for now. We don't even vectorize kron by default, so if I can't figure it out I might just disable the rewrite for the Blockwise(Kron) case for now.

Benchmarks follow, with:

  • small: A, B are (10, 10)
  • medium: A, B are (50, 50)
  • large: A, B are (100, 100)
-----------------------------------------------------------------------------------------
Name (time in us)                                                            Min                       Max                      Mean                 StdDev                    Median                    IQR            Outliers          OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_solve_kron_to_solve_benchmark[no_rewrite-small]             17.2500 (1.0)             34.6250 (1.0)             18.5068 (1.0)           2.3577 (1.0)             17.6670 (1.0)           0.8755 (1.0)          8;10  54,034.2951 (1.0)          93           1
test_rewrite_solve_kron_to_solve_benchmark[rewrite-small]                19.2910 (1.12)            98.7500 (2.85)            21.9831 (1.19)          4.0015 (1.70)            20.9160 (1.18)          3.6250 (4.14)       135;35  45,489.5626 (0.84)       3261           1

test_rewrite_solve_kron_to_solve_benchmark[no_rewrite-medium]        93,532.8330 (>1000.0)     96,359.5420 (>1000.0)     94,835.3042 (>1000.0)     857.3874 (363.65)      94,672.9585 (>1000.0)   1,327.0000 (>1000.0)       3;0      10.5446 (0.00)         10           1
test_rewrite_solve_kron_to_solve_benchmark[rewrite-medium]               66.1660 (3.84)           288.5420 (8.33)            74.0905 (4.00)          8.2418 (3.50)            72.8750 (4.12)          5.9580 (6.81)      405;317  13,497.0108 (0.25)       7247           1

test_rewrite_solve_kron_to_solve_benchmark[no_rewrite-large]      3,250,903.0000 (>1000.0)  3,333,840.0830 (>1000.0)  3,300,615.3582 (>1000.0)  31,539.6476 (>1000.0)  3,302,135.1670 (>1000.0)  38,780.1145 (>1000.0)       2;0       0.3030 (0.00)          5           1
test_rewrite_solve_kron_to_solve_benchmark[rewrite-large]               183.1670 (10.62)          357.8750 (10.34)          196.7968 (10.63)        11.5470 (4.90)           194.6250 (11.02)         8.7920 (10.04)     401;206   5,081.3837 (0.09)       3442           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1559.org.readthedocs.build/en/1559/

@jessegrabowski jessegrabowski requested review from ricardoV94 and Copilot and removed request for ricardoV94 July 28, 2025 23:59
Copilot

This comment was marked as outdated.

Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR implements a rewrite optimization for solving linear systems involving Kronecker products. The goal is to transform expressions of the form solve(kron(A, B), x) into an equivalent form that eliminates the Kronecker product computation, providing significant performance improvements.

Key changes:

  • Added a new rewrite rule rewrite_solve_kron_to_solve that transforms Kronecker-based solves using mathematical identities
  • Comprehensive test coverage including correctness tests and benchmarks demonstrating substantial speedups
  • Support for both batched and non-batched operations with limitations for certain matrix dimensions

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
pytensor/tensor/rewriting/linalg.py Implements the core rewrite logic with mathematical transformation from Kronecker solve to nested solves
tests/tensor/rewriting/test_linalg.py Adds comprehensive test suite including correctness verification and performance benchmarks

Copy link

codecov bot commented Aug 3, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 81.54%. Comparing base (892a8f0) to head (b06b0c7).
⚠️ Report is 7 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1559   +/-   ##
=======================================
  Coverage   81.53%   81.54%           
=======================================
  Files         230      230           
  Lines       53066    53144   +78     
  Branches     9423     9445   +22     
=======================================
+ Hits        43269    43336   +67     
- Misses       7364     7370    +6     
- Partials     2433     2438    +5     
Files with missing lines Coverage Δ
pytensor/tensor/rewriting/linalg.py 92.56% <100.00%> (+0.50%) ⬆️

... and 7 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request graph rewriting linalg Linear algebra performance
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Rewrite Solve involving Kron
2 participants