From 3424eaea9146e17f960120c7fa7e8edd53252b70 Mon Sep 17 00:00:00 2001 From: Marvin Ritter Date: Sat, 22 Jul 2023 07:59:13 -0700 Subject: [PATCH] Fix type annotations to pass pytype checks. PiperOrigin-RevId: 550199072 --- clu/data/dataset_iterator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clu/data/dataset_iterator.py b/clu/data/dataset_iterator.py index 21d90dd..e4491f0 100644 --- a/clu/data/dataset_iterator.py +++ b/clu/data/dataset_iterator.py @@ -41,9 +41,9 @@ from etils import epath import jax.numpy as jnp # Just for type checking. import numpy as np +import numpy.typing as npt Array = Union[np.ndarray, jnp.ndarray] -DType = np.dtype # Sizes of dimensions, None means the dimension size is unknown. Shape = Tuple[Optional[int], ...] @@ -51,7 +51,7 @@ @dataclasses.dataclass(frozen=True) class ArraySpec: """Describes an array via it's dtype and shape.""" - dtype: DType + dtype: npt.DTypeLike shape: Shape def __repr__(self):