diff --git a/whisper_jax/layers.py b/whisper_jax/layers.py index 9b218be..4c2a1a9 100644 --- a/whisper_jax/layers.py +++ b/whisper_jax/layers.py @@ -60,7 +60,7 @@ # Temporary inlined JAX N-d initializer code # TODO(levskaya): remove once new JAX release is out. # ------------------------------------------------------------------------------ -def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1): +def _compute_fans(shape: Tuple[int, ...], in_axis: int = -2, out_axis: int = -1): """Inlined JAX `nn.initializer._compute_fans`.""" if isinstance(in_axis, int): in_size = shape[in_axis]