diff --git a/distrax/_src/distributions/quantized.py b/distrax/_src/distributions/quantized.py index f01f3f5a..83c54bba 100644 --- a/distrax/_src/distributions/quantized.py +++ b/distrax/_src/distributions/quantized.py @@ -14,7 +14,7 @@ # ============================================================================== """Quantized distribution.""" -from typing import cast, Optional, Tuple +from typing import Optional, Tuple import chex from distrax._src.distributions import distribution as base_distribution @@ -107,9 +107,7 @@ def high(self) -> Optional[Array]: @property def event_shape(self) -> Tuple[int, ...]: """Shape of event of distribution samples.""" - event_shape = self.distribution.event_shape - # TODO(b/149413467): Remove explicit casting when resolved. - return cast(Tuple[int, ...], event_shape) + return self.distribution.event_shape @property def batch_shape(self) -> Tuple[int, ...]: