diff --git a/pyfesom2/diagnostics.py b/pyfesom2/diagnostics.py index 14db4ee..c24d329 100644 --- a/pyfesom2/diagnostics.py +++ b/pyfesom2/diagnostics.py @@ -73,10 +73,13 @@ def ice_ext(data, mesh, hemisphere="N", threshhold=0.15, attrs={}): hemis_mask = mesh.y2 < 0 if isinstance(data, xr.DataArray): - data = data.where(data < threshhold, 1) - data = data.where(data > threshhold) + # Create binary ice extent field: 1 where ice >= threshold, 0 where < threshold, NaN stays NaN + ice_binary = xr.where(data >= threshhold, 1, 0) + ice_binary = ice_binary.where(~np.isnan(data)) # Preserve NaN values (land) - ext = (data[:, hemis_mask] * mesh.lump2[hemis_mask]).sum(axis=1) + finite_mask = ~np.isnan(ice_binary.squeeze().values) + + ext = (ice_binary[:, hemis_mask & finite_mask] * mesh.lump2[hemis_mask & finite_mask]).sum(axis=1) da = xr.DataArray( ext, dims=["time"], coords={"time": data.time}, name=varname, attrs=attrs ) @@ -84,10 +87,12 @@ def ice_ext(data, mesh, hemisphere="N", threshhold=0.15, attrs={}): else: logger.debug(data) - i, j = np.where(data < 0.15) - data[:] = 1 - data[i, j] = 0 - ext = (data[:, hemis_mask] * mesh.lump2[hemis_mask]).sum(axis=1) + # Create binary ice extent field: 1 where ice >= threshold, 0 where < threshold, NaN stays NaN + ice_binary = np.where(data >= threshhold, 1, 0).astype(float) + ice_binary[np.isnan(data)] = np.nan # Preserve NaN values (land) + + finite_mask = ~np.isnan(ice_binary.squeeze()) + ext = (ice_binary[:, hemis_mask & finite_mask] * mesh.lump2[hemis_mask & finite_mask]).sum(axis=1) return ext