Skip to content

Conversation

lshaw8317
Copy link
Collaborator

@lshaw8317 lshaw8317 commented Oct 16, 2025

Improves lazy expression dtype/shape handling to 1) be more robust (solves #508) and 2) incorporate a wider range of proxies (solves #509).

With this PR, the following code executes successfully.

import dask.array as da
import jax.numpy as jnp
import numpy as np
import pytest
import tensorflow as tf
import torch
import zarr
import blosc2
def test_simpleproxy(xp, dtype):
    """Check a blosc2 lazy expression that mixes a blosc2 array with a foreign one.

    Operand ``a`` is a blosc2 array and operand ``b`` is created with the array
    library ``xp``.  The expression is built with ``blosc2.lazyexpr`` and its
    dtype, shape and computed values are compared against the same computation
    carried out with NumPy.

    Parameters
    ----------
    xp : module
        Array namespace used to create the foreign operand ``b``.  It must
        provide ``zeros``, ``ones`` and (except for tensorflow) ``full``.
    dtype : str
        NumPy dtype name.  ``"bool"`` selects the logical expression; any other
        name selects the arithmetic one, with ``"complex128"`` using complex
        fill values.

    Raises
    ------
    AssertionError
        If the lazy expression is not a ``blosc2.LazyExpr`` or if its dtype,
        shape or values differ from the NumPy reference.
    """
    np_dtype = np.dtype(dtype)
    # Use the library's own dtype object when it exposes one under this name;
    # otherwise fall back to the NumPy dtype.
    dtype_ = getattr(xp, dtype, np_dtype)
    is_bool = dtype == "bool"

    if is_bool:
        blosc_matrix = blosc2.asarray([True, False, False], dtype=np_dtype, chunks=(2,))
        foreign_matrix = xp.zeros((3,), dtype=dtype_)
        expression = "(b & a) | (~b)"
    else:
        N = 10
        shape_a = (N, N, N)
        if dtype == "complex128":
            blosc_fill = 3 + 2j
            if xp is tf:
                # tensorflow is special-cased: build ones and add 1j instead of
                # calling xp.full (presumably because tf has no numpy-style
                # `full` -- TODO confirm).
                foreign_matrix = xp.ones(shape_a, dtype=dtype_) + 1j
            else:
                foreign_matrix = xp.full(shape_a, fill_value=1 + 1j, dtype=dtype_)
        else:
            blosc_fill = 3
            foreign_matrix = xp.ones(shape_a, dtype=dtype_)
        # Each operand is built exactly once, with its final fill value.
        blosc_matrix = blosc2.full(
            shape=shape_a, fill_value=blosc_fill, dtype=np_dtype, chunks=(N // 3,) * 3
        )
        expression = "b + sin(a) + sum(b) - tensordot(a, b, axes=1)"

    # Create the lazy expression object over the mixed operands.
    lexpr = blosc2.lazyexpr(expression, operands={"a": blosc_matrix, "b": foreign_matrix})

    # Reference result computed with plain NumPy arrays.
    npb = np.asarray(foreign_matrix)
    npa = blosc_matrix[()]
    if is_bool:
        res = (npb & npa) | np.logical_not(npb)
    else:
        res = npb + np.sin(npa) + np.sum(npb) - np.tensordot(npa, npb, axes=1)

    # Test object metadata and result.
    assert isinstance(lexpr, blosc2.LazyExpr)
    assert lexpr.dtype == res.dtype
    assert lexpr.shape == res.shape
    np.testing.assert_array_equal(lexpr[()], res)


# Run the check once for every (array library, dtype) combination.
for xp in (torch, tf, np, jnp, da, zarr):
    for dtype in ("bool", "int32", "int64", "float32", "float64", "complex128"):
        test_simpleproxy(xp, dtype)

@lshaw8317 lshaw8317 changed the title Fixes to infer_shape for tensordot Fixes to infer_shape and SimpleProxy for tensordot Oct 16, 2025
@lshaw8317 lshaw8317 changed the title Fixes to infer_shape and SimpleProxy for tensordot Fixes to infer_shape and SimpleProxy for lazy expression handling Oct 16, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant