diff --git a/clu/data/dataset_iterator.py b/clu/data/dataset_iterator.py index 21d90dd..e4491f0 100644 --- a/clu/data/dataset_iterator.py +++ b/clu/data/dataset_iterator.py @@ -41,9 +41,9 @@ from etils import epath import jax.numpy as jnp # Just for type checking. import numpy as np +import numpy.typing as npt Array = Union[np.ndarray, jnp.ndarray] -DType = np.dtype # Sizes of dimensions, None means the dimension size is unknown. Shape = Tuple[Optional[int], ...] @@ -51,7 +51,7 @@ @dataclasses.dataclass(frozen=True) class ArraySpec: """Describes an array via it's dtype and shape.""" - dtype: DType + dtype: npt.DTypeLike shape: Shape def __repr__(self):