diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 6970d37402f..6fd18ae052e 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -101,7 +101,8 @@ def _ensure_bool_is_ndarray(result, *args): # This function ensures that the result is the appropriate shape in these # cases if isinstance(result, bool): - shape = np.broadcast(*args).shape + # Avoid repeated broadcasting: save result + shape = np.broadcast_shapes(*(getattr(arg, "shape", ()) for arg in args)) constructor = np.ones if result else np.zeros result = constructor(shape, dtype=bool) return result @@ -115,8 +116,9 @@ def array_eq(self, other): def array_ne(self, other): with warnings.catch_warnings(): - warnings.filterwarnings("ignore", r"elementwise comparison failed") - return _ensure_bool_is_ndarray(self != other, self, other) + warnings.simplefilter("ignore", category=UserWarning) + cmp_result = self != other + return _ensure_bool_is_ndarray(cmp_result, self, other) def _is_contiguous(positions):