From 1f254ec6cce6e1bf9f08e2644fc5af88926803f3 Mon Sep 17 00:00:00 2001 From: Jens Glaser Date: Sun, 14 Aug 2022 14:43:00 -0400 Subject: [PATCH] Fix masking with more than one input feature --- neural_tangents/_src/stax/requirements.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/neural_tangents/_src/stax/requirements.py b/neural_tangents/_src/stax/requirements.py index 0858cc5f..d6e3e2c0 100644 --- a/neural_tangents/_src/stax/requirements.py +++ b/neural_tangents/_src/stax/requirements.py @@ -764,6 +764,11 @@ def get_x_cov_mask(x): x = _get_masked_array(x, mask_constant) x, mask = x.masked_value, x.mask + # reduce mask + if mask_constant and mask.shape[channel_axis] > 1: + warnings.warn("Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.") + mask = np.any(mask, axis=channel_axis, keepdims=True) + # TODO(schsam): Think more about dtype automatic vs manual dtype promotion. x = x.astype(jax.dtypes.canonicalize_dtype(np.float64))