|
7 | 7 | import inspect |
8 | 8 |
|
9 | 9 | import numpy as np |
| 10 | +from scipy.special import logsumexp as scipy_logsumexp |
10 | 11 | from xarray import DataArray |
11 | 12 |
|
12 | 13 | import pytensor.scalar as ps |
13 | 14 | import pytensor.xtensor.math as pxm |
14 | 15 | from pytensor import function |
15 | 16 | from pytensor.scalar import ScalarOp |
16 | 17 | from pytensor.xtensor.basic import rename |
17 | | -from pytensor.xtensor.math import add, exp |
| 18 | +from pytensor.xtensor.math import add, exp, logsumexp |
18 | 19 | from pytensor.xtensor.type import xtensor |
19 | 20 | from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function |
20 | 21 |
|
@@ -152,6 +153,28 @@ def test_cast(): |
152 | 153 | yc64.astype("float64") |
153 | 154 |
|
154 | 155 |
|
| 156 | +@pytest.mark.parametrize( |
| 157 | + ["shape", "dims", "axis"], |
| 158 | + [ |
| 159 | + ((3, 4), ("a", "b"), None), |
| 160 | + ((3, 4), "a", 0), |
| 161 | + ((3, 4), "b", 1), |
| 162 | + ], |
| 163 | +) |
| 164 | +def test_logsumexp(shape, dims, axis): |
| 165 | + scipy_inp = np.zeros(shape) |
| 166 | + scipy_out = scipy_logsumexp(scipy_inp, axis=axis) |
| 167 | + |
| 168 | + pytensor_inp = DataArray(scipy_inp, dims=("a", "b")) |
| 169 | + f = function([], logsumexp(pytensor_inp, dim=dims)) |
| 170 | + pytensor_out = f() |
| 171 | + |
| 172 | + np.testing.assert_array_almost_equal( |
| 173 | + pytensor_out, |
| 174 | + scipy_out, |
| 175 | + ) |
| 176 | + |
| 177 | + |
155 | 178 | def test_dot(): |
156 | 179 | """Test basic dot product operations.""" |
157 | 180 | # Test matrix-vector dot product (with multiple-letter dim names) |
|
0 commit comments