-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathxrtools.py
More file actions
321 lines (252 loc) · 11 KB
/
xrtools.py
File metadata and controls
321 lines (252 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
import xarray as xr
import numpy as np
from scipy.stats import binned_statistic, binned_statistic_2d
from scipy.ndimage import median_filter, generic_filter
from matplotlib.pyplot import subplots
from .misc import get_edges
def has_same_coords(da1, da2):
    """
    Tell whether two xr.DataArrays share the same coordinates (names and values).

    The order of the dimensions is irrelevant: only the set of dimension names
    and the coordinate values along them are compared.

    Parameters:
        da1, da2 : xarray DataArrays to compare
    Returns:
        True if both arrays have the same dimension names and xr.align accepts
        them with join='exact', otherwise False.
    """
    # Different dimension names: no need to look at the coordinate values
    if set(da1.dims) != set(da2.dims):
        return False
    # An exact join raises ValueError when the coordinate values differ
    try:
        xr.align(da1, da2, join='exact')
    except ValueError:
        return False
    return True
def grid_data_1d(da_x, da_y, x_grid, label = None, statistics = ['mean', 'count', 'std']):
    """
    Grid da_y onto x_grid (binning by da_x) and collect per-cell statistics.

    Each requested statistic is computed with da_from_binned_statistic_1d
    (which relies on scipy.stats.binned_statistic) and stored as one variable
    of the returned dataset.

    Parameters:
        da_x       : xarray DataArray with the values to be binned
        da_y       : xarray DataArray with the data the statistics are computed on
        x_grid     : xarray DataArray describing the grid
        label      : Optional prefix. If given, the variables are called
                     <label>_mean, <label>_count, <label>_std, ...;
                     otherwise they are named after the statistic alone.
        statistics : Statistics to compute, one output variable per entry
                     (default: mean, count and standard deviation).
                     The default list is only read, never modified.
    Returns:
        An xarray Dataset with one variable per statistic.
    """
    prefix = "" if label is None else f"{label}_"
    # One single-variable dataset per statistic, merged at the end
    per_statistic = [
        da_from_binned_statistic_1d(da_x, da_y, x_grid, stat).to_dataset(name = prefix + stat)
        for stat in statistics
    ]
    return xr.merge(per_statistic)
def grid_data_2d(x1, x2, f, dims, **kwargs):
    """
    Bin variable f(x1,x2) using scipy.stats.binned_statistic_2d and return the
    binned data as an xarray DataArray.

    The coordinates of the result are the centers of the bins, i.e. the
    midpoints between consecutive bin edges returned by binned_statistic_2d.

    Parameters:
        x1, x2   : Sample positions along the first and second binning direction
        f        : Values on which the statistic is computed
        dims     : List with names of the two dimensions, used when creating
                   coordinates for the DataArray
        **kwargs : Passed to scipy.stats.binned_statistic_2d, see its doc for details
    Returns:
        An xarray DataArray with the binned statistic.
    Example:
        x = np.random.rand(100)-0.5
        y = 2*np.random.rand(100)-1
        f = x**2+y**2
        grid_data_2d(x, y, f, ['x','y'], statistic='median').plot()
    """
    binned = binned_statistic_2d(x1, x2, f, **kwargs)

    def _bin_centers(edges):
        # Midpoint between each pair of neighbouring edges
        return (edges[:-1] + edges[1:])/2

    center_coords = {dims[0] : _bin_centers(binned.x_edge),
                     dims[1] : _bin_centers(binned.y_edge)}
    return xr.DataArray(binned.statistic, coords = center_coords)
def da_from_binned_statistic_1d(da_x, da_y, x_grid, statistic):
    """
    Compute a binned statistic of da_y on x_grid and wrap it in a DataArray.

    Parameters:
        da_x      : An xarray DataArray with values to be binned
        da_y      : An xarray DataArray with the data on which the statistic will be computed.
                    Can have one more dimension than da_x, but the common dimensions must have
                    same names and coordinate values as da_x.
        x_grid    : one-dimensional xarray DataArray with center points of the grid
        statistic : passed to scipy.stats.binned_statistic (e.g 'mean', 'count', 'std')
    Returns:
        An xarray DataArray with statistic of the binned data. Its first dimension
        is the grid dimension; if da_y has an extra dimension, that dimension is
        kept as the second one.
    Raises:
        ValueError if da_y has more than one dimension that da_x lacks.
    """
    extra_dims = set(da_y.dims) - set(da_x.dims)
    grid_dim = list(x_grid.dims)[0]

    if len(extra_dims) > 1:
        raise ValueError("da_y can have at most one more dimension than da_x")

    if not extra_dims:
        # Same dimensions: one call gives the statistic for every grid cell
        grid_coords = {grid_dim: x_grid.values}
        binned = get_binned_statistics(da_x, da_y, x_grid, statistic).statistic
        return xr.DataArray(binned,
                            dims = grid_dim,
                            coords = grid_coords)

    # Exactly one extra dimension: bin each slice along it separately
    (keep_dim,) = extra_dims
    keep_values = da_y[keep_dim].values
    n_keep = len(keep_values)
    result_coords = {grid_dim: x_grid.values,
                     keep_dim: keep_values}
    binned = np.empty([len(x_grid.values), n_keep])
    for column in range(n_keep):
        slice_y = da_y.isel({keep_dim : column})
        binned[:,column] = get_binned_statistics(da_x, slice_y, x_grid, statistic).statistic
    return xr.DataArray(binned,
                        dims = (grid_dim, keep_dim),
                        coords = result_coords)
def get_binned_statistics(da_x, da_y, x_grid, statistic):
    """
    Run scipy.stats.binned_statistic on two matching DataArrays.

    Both arrays are flattened, samples where either x or y is NaN are
    discarded, and the bin edges are derived from x_grid with get_edges.

    Parameters:
        da_x      : An xarray DataArray with values to be binned
        da_y      : An xarray DataArray with the data on which the statistic will be computed.
                    Must have same dimensions and coordinate values as da_x
        x_grid    : one-dimensional xarray DataArray with center points of the grid
        statistic : passed to scipy.stats.binned_statistic (e.g 'mean', 'count', 'std')
    Returns:
        Output directly from scipy.stats.binned_statistic
    Raises:
        ValueError if x_grid is not one-dimensional, or if da_x and da_y do
        not have the same coordinates.
    """
    # Validate the input before doing any work
    if len(x_grid.dims) != 1:
        raise ValueError('Grid must be one-dimensional!')
    if not has_same_coords(da_x, da_y):
        raise ValueError("da_x and da_y don't match")

    # Flatten both arrays over the dimensions of da_x so that samples line up
    stack_dims = da_x.dims
    flat_x = da_x.stack(z=stack_dims).values
    flat_y = da_y.stack(z=stack_dims).values

    # Keep only the samples where both x and y are valid numbers
    valid = ~(np.isnan(flat_y) | np.isnan(flat_x))

    edges = get_edges(x_grid.values)
    return binned_statistic(flat_x[valid], flat_y[valid], statistic = statistic, bins = edges)
def violin_plot(da, dim, ax=None, plot_hist = True, xlabel=None, ylabel=None, hist_kwargs=None, **kwargs):
    """
    Visualize data from a DataArray with violin plots along the given dimension.

    Note: Keywords showextrema/showmedians/showmeans are useful for controlling
    the appearance of the violins.

    Parameters:
        da          : xarray.DataArray with data
        dim         : Dimension to keep (x-axis). All other dimensions will be collapsed
                      into a single dimension representing observations.
        ax          : Axes to draw on. If plot_hist is True, ax should be a list of 2 axes
                      (violins first, histogram second), otherwise a single axes.
                      If None, a new figure with the required axes is created.
        plot_hist   : True/False. Plot histogram with the number of non-NaN
                      observations at each position (default True)
        xlabel      : Optional xlabel (default: dim)
        ylabel      : Optional ylabel (default: none)
        hist_kwargs : (dict) Keyword arguments passed to the bar plot
                      (used for plotting histogram if plot_hist is True).
                      They override the defaults 'width' and 'alpha'.
        **kwargs    : Passed to violinplot. They override the defaults
                      'showmedians', 'showmeans', 'positions' and 'widths'.

    Returns (ax, violins) where violins is the handle returned by violinplot.
    """
    # A None default avoids sharing one dict object between calls
    if hist_kwargs is None:
        hist_kwargs = {}

    if ax is None:
        if plot_hist:
            _, ax = subplots(nrows=2, sharex=True, layout='tight', height_ratios=[3,1])
        else:
            _, ax = subplots()
    if xlabel is None:
        xlabel = dim

    # Collapse every dimension except dim into 'observations' and drop the
    # positions along dim that contain no data at all
    collapse_dims = [d for d in da.dims if d != dim]
    collapsed_data = da.stack({'observations' : collapse_dims}).dropna(dim=dim, how='all')
    positions = collapsed_data[dim].values

    # Remove nans and bring the data to the form violinplot expects:
    # one array of valid observations per position
    data = []
    for i in range(len(positions)):
        obs = collapsed_data.isel({dim : i}).values
        data.append(obs[~np.isnan(obs)])
    count = [len(obs) for obs in data]

    ## Plot violins
    ax_violin = ax[0] if plot_hist else ax

    # Width of violins and bars: 80% of the mean coordinate spacing. With a
    # single coordinate there is no spacing to measure (the mean of an empty
    # array would be NaN), so fall back to a fixed width of 0.5.
    spacing = np.diff(da[dim].values)
    widths = 0.8*np.mean(spacing) if spacing.size else 0.5

    # Keyword arguments to violinplot; user defined keyword arguments win
    default_kwargs = {'showmedians' : True,
                      'showmeans' : True,
                      'positions' : positions,
                      'widths' : widths}
    kwargs_all = {**default_kwargs, **kwargs}
    violins = ax_violin.violinplot(data, **kwargs_all)

    # Make nicer. Which line collections exist depends on the show* keywords.
    if 'cmedians' in violins:
        violins['cmedians'].set_label('median')
        violins['cmedians'].set_color('steelblue')
    if 'cmeans' in violins:
        violins['cmeans'].set_label('mean')
        violins['cmeans'].set_color('black')
        violins['cmeans'].set_linestyle('dotted')
    if 'cbars' in violins:
        violins['cbars'].set_label('extent')
        violins['cbars'].set_color('silver')
        violins['cbars'].set_linewidth(0.5)
        violins['cbars'].set_zorder(-10)
        # The end caps are drawn together with the bars (showextrema); hide
        # them so only the thin vertical extent line remains
        violins['cmaxes'].set_visible(False)
        violins['cmins'].set_visible(False)
    ax_violin.legend()

    ## Plot histogram
    if plot_hist:
        default_hist_kwargs = {'width' : widths,
                               'alpha' : 0.4}
        hist_kwargs_all = {**default_hist_kwargs, **hist_kwargs}
        ax_hist = ax[1]
        ax_hist.bar(positions, count, **hist_kwargs_all)

    ## Label axes (the x label goes on the lowest axes)
    ax_violin.set_ylabel(ylabel)
    if plot_hist:
        ax_hist.set_ylabel('count')
        ax_hist.set_xlabel(xlabel)
    else:
        ax_violin.set_xlabel(xlabel)

    return ax, violins
def apply_median_filter(da, filter_lengths, filter_dims):
    """
    Apply median filter on data array using scipy.ndimage.median_filter

    Parameters:
        da             : xarray.DataArray with data to be filtered
        filter_lengths : Sizes of the median filter, one per entry in filter_dims.
        filter_dims    : Corresponding dimensions. The filter will be applied along
                         those dimensions in da. Names that are not dimensions of
                         da are ignored.

    Returns a new xr.DataArray with filtered data (da itself is left untouched).
    Boundaries are handled with mode='reflect'.

    Raises ValueError if the resulting filter has size 1 along every dimension.

    Example:
        apply_median_filter(da, [9], ['time'])
        apply_median_filter(da, [9,21], ['time', 'range'])
    """
    # Build a kernel with one entry per dimension of da: size 1 (no filtering)
    # everywhere except along the requested dimensions
    dim_names = list(da.dims)
    kernel = [1] * len(dim_names)
    for length, filter_dim in zip(filter_lengths, filter_dims):
        for axis, name in enumerate(dim_names):
            if name == filter_dim:
                kernel[axis] = length

    if max(kernel) == 1:
        raise ValueError('No dimension to apply filter along!')

    # Copy first so that coordinates and attributes carry over to the result
    result = da.copy()
    result.data = median_filter(da, kernel, mode='reflect')
    return result
def da_filter(da, filter_fun, filter_size, filter_dims, **kwargs):
    """
    Apply filter on data array using scipy.ndimage.generic_filter

    Parameters:
        da          : xarray.DataArray with data to be filtered
        filter_fun  : Function passed to generic_filter (e.g. np.nanmean)
        filter_size : Size of filter along dimensions in filter_dims.
                      ('size' in generic_filter) (must be a list)
        filter_dims : Dimensions along which filter will be applied (must be a list).
                      Names that are not dimensions of da are ignored.
        **kwargs    : Keyword arguments passed to generic_filter (e.g. mode, cval)

    Returns a new xr.DataArray with filtered data (da itself is left untouched).

    Raises ValueError if the resulting filter has size 1 along every dimension.

    Example:
        da_filter(da, np.nanmean, [9], ['time'])
        da_filter(da, np.max, [9,21], ['time', 'range'], mode='wrap')
    """
    # Expand kernel to same dimensions as da: size 1 (no filtering) everywhere
    # except along the requested dimensions
    kernel_size = [1 for _ in da.dims]
    for (i, dim) in enumerate(da.dims):
        for (N, fdim) in zip(filter_size, filter_dims):
            if dim == fdim:
                kernel_size[i] = N

    if max(kernel_size) == 1:
        raise ValueError('No dimension to apply filter along!')

    # Create DataArray with filtered data. Copy first so that coordinates and
    # attributes carry over to the result.
    filtered_da = da.copy()
    # Forward the user's keyword arguments; they used to be silently dropped,
    # so options such as mode='wrap' had no effect
    filtered_da.data = generic_filter(da, filter_fun, size=kernel_size, **kwargs)
    return filtered_da