-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpost_process_arc_zarr.py
More file actions
242 lines (188 loc) · 13 KB
/
post_process_arc_zarr.py
File metadata and controls
242 lines (188 loc) · 13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# Script to post-process arctic simulation data and compute histograms for ARC in a more storage and memory efficient manner
# NOTE: Because there is only one restart file for 2008, we will handle that year separately at the bottom of post_process_arc_zarr_repeat.ipynb
import numpy as np
import os
import xarray as xr
import pandas as pd
import os
import gc
from tqdm import tqdm
from argparse import ArgumentParser
# Parse the command-line arguments.
p = ArgumentParser(description="""Arctic post-process of parcels simulations into zarr format""")
p.add_argument('-startdate', '--startdate', default='2000-01-01',
               help='Start date for processing (YYYY-MM-DD)')
# type=int lets argparse report a malformed value as a usage error instead of
# crashing afterwards with a bare ValueError from int().
p.add_argument('-startposition', '--startposition', type=int, default=0,
               help='Start position for processing (number of weeks)')
parsed_args = p.parse_args()
# Date string used to build the input zarr file names (particles_<startdate>.zarr).
startdate = parsed_args.startdate
# Number of whole weeks to skip at the start of the simulation. A negative value
# would make the main loop start at a negative observation index.
startposition = parsed_args.startposition
if startposition < 0:
    p.error('--startposition must be a non-negative number of weeks')
# Locations of trajectory data and output folder
data_path = "/storage/shared/oceanparcels/output_data/data_Michael/NECCTONsimulations/data/copernicus_simulations/"
output_dir = '/storage/shared/oceanparcels/output_data/data_Michael/NECCTONsimulations/data/histograms_arc_zarr/'
scaling_factor_file = '/storage/shared/oceanparcels/output_data/data_Michael/NECCTONsimulations/data/plastic_scaling_factors.npz'
# Load plastic scaling factors: a pickled mapping that the main loop indexes by
# (year, release_class). The NpzFile is used as a context manager so its file
# handle is closed as soon as the mapping has been extracted.
# NOTE: allow_pickle=True runs pickle on load; acceptable only because this is a
# trusted, locally produced file.
with np.load(scaling_factor_file, allow_pickle=True) as scaling_factor_npz:
    scaling_factors = scaling_factor_npz['scaling_factors'].item()
# Construct the ARC grid for the histograms from the bathymetry grid: the
# bathymetry grid points become histogram cell centers and the bin edges lie
# halfway between neighbouring points.
arc_grid_spacing = 1/12  # degrees; assumed spacing of the bathymetry grid, used to extend it
with xr.open_dataset("/storage/shared/oceanparcels/output_data/data_Michael/NECCTONsimulations/data/copernicus_data/cmems_mod_glo_phy_my_static_fulldomain.nc") as bathy:
    bathy_ARC = bathy.sel(latitude=slice(59.9, 90))
    # float64 copies, so the extension points below are computed in double precision
    bathy_lat = bathy_ARC.latitude.values.astype(np.float64)
    bathy_lon = bathy_ARC.longitude.values.astype(np.float64)
# Handle the north pole: append one extra latitude so the northernmost grid row
# gets an upper bin edge.
bathy_ARC_latitude_values = np.append(bathy_lat, bathy_lat[-1] + arc_grid_spacing)
# Handle the periodic BC: pad one point on the west and two on the east. The first
# east point closes the last real column; the second creates an extra far-east
# column whose values are later added to the first column and then removed.
east_pad = bathy_lon[-1] + arc_grid_spacing
bathy_ARC_longitude_values = np.concatenate(
    ([bathy_lon[0] - arc_grid_spacing], bathy_lon, [east_pad, east_pad + arc_grid_spacing])
)
# Bin edges are the midpoints between consecutive grid points. The southernmost
# latitude only supplies the lower edge of the first bin, so it is not a center.
lat_bins = (bathy_ARC_latitude_values[:-1] + bathy_ARC_latitude_values[1:]) / 2
lat_centers = bathy_ARC_latitude_values[1:-1]
lon_bins = (bathy_ARC_longitude_values[:-1] + bathy_ARC_longitude_values[1:]) / 2
lon_centers = bathy_ARC_longitude_values[1:-2]
# Histogram bin edges used to compute the densities
arc_x_edges = lon_bins
arc_y_edges = lat_bins
# Particles with z <= this value count as "surface" (spelling kept: referenced by the main loop)
surface_threshhold = 5  # meters
# Save the grid once. Create the output directory first so np.savez does not fail
# on a fresh output location.
os.makedirs(output_dir, exist_ok=True)
arc_grid_file = os.path.join(output_dir, "arc_grid.npz")
if not os.path.isfile(arc_grid_file):
    np.savez(arc_grid_file, arc_x_edges=arc_x_edges, arc_y_edges=arc_y_edges,
             lon_centers=lon_centers, lat_centers=lat_centers)
# Load the dataset we are going to analyse. Prefer the rechunked copy and fall back
# to the original parcels output when the rechunked store cannot be opened
# (normally because it has not been created yet).
rechunked_zarr_path = os.path.join(data_path, f'rechunked_particles_{startdate}.zarr')
original_zarr_path = os.path.join(data_path, f'particles_{startdate}.zarr')
try:
    sim_ds = xr.open_zarr(rechunked_zarr_path)
    print("Using rechunked data!")
except (OSError, ValueError) as err:
    # Narrow catch: a missing store surfaces as FileNotFoundError (an OSError) in
    # recent xarray, and older zarr raises GroupNotFoundError (a ValueError subclass).
    # Anything else (including KeyboardInterrupt) is no longer silently swallowed.
    print(f"Rechunked data not available ({type(err).__name__}: {err})")
    sim_ds = xr.open_zarr(original_zarr_path)
    print("Using original data!")
# Calendar day (YYYY-MM-DD) of the first observation of trajectory 0, and its year,
# which selects the scaling factors.
starting_day_str = str(sim_ds.isel(obs=0, trajectory=0).time.values.astype('datetime64[D]').astype(str))
starting_day_year = int(starting_day_str[:4])
# Disabled pre-processing, kept for reference: forward fill missing data (when
# particles get deleted). Assumption -> particles deleted remain at the last known position!
#sim_ds['lon'] = sim_ds.lon.ffill(dim='obs') # Forward fill longitude
#sim_ds['lat'] = sim_ds.lat.ffill(dim='obs') # Forward fill latitude
#sim_ds['z'] = sim_ds.z.ffill(dim='obs') # Forward fill depth
# Shift longitudes to -180 to 180 for ARC processing (currently done per week in the main loop)
#sim_ds['lon'] = ((sim_ds['lon'] + 180) % 360) - 180
# NOTE: What would happen if we just request 64GB and load the entire dataset into memory?
# sim_ds = sim_ds.load() # Load the entire dataset into memory for faster processing
# NOTE: And would larger chunk sizes give faster loading? - check normal vs rechunked

# Release day of every trajectory, in whole days relative to the first observation
# of trajectory 0. Positional: entry k belongs to sim_ds.trajectory[k].
global_start_times = (sim_ds.isel(obs=0).time.values - sim_ds.isel(obs=0, trajectory=0).time.values).astype('timedelta64[D]').astype(int)
# Trajectory IDs, in the same order as global_start_times
global_trajectory_id = sim_ds.trajectory.values
# Plastic sizes we model
plastic_sizes = np.unique(sim_ds.plastic_diameter.values)
# Release types (0=river, 1=coastal, 2=fisheries)
release_classes = np.unique(sim_ds.release_class.values).astype(int)
# Trajectory IDs belonging to each plastic size
plastic_class_traj = {
    p_size: sim_ds.sel(trajectory=(sim_ds.plastic_diameter == p_size).rename("plastic_mask")).trajectory.values
    for p_size in plastic_sizes
}
# Trajectory IDs belonging to each release type
release_class_traj = {
    r_class: sim_ds.sel(trajectory=(sim_ds.release_class == r_class).rename("release_mask")).trajectory.values
    for r_class in release_classes
}
# Only whole weeks are processed here. Because output is saved in chunks of 7 obs,
# the last few days may not form a complete week; they are handled in the restart
# processing step.
ndays_to_process = (sim_ds.obs.size // 7) * 7
# Flag whether there are leftover days at the end (e.g. for restart files) that this script skips
additional_days_f = ndays_to_process < sim_ds.obs.size
print(f"Processing: start_day={starting_day_str}. Do we need to handle additional days at the end? {additional_days_f}")
# Observation offsets (in days) of the first day of every week to process;
# the first `startposition` weeks are skipped.
loop_ndays_to_process = range(startposition*7, ndays_to_process, 7)

# Trajectory ID -> position in sim_ds. global_start_times is a positional array, so
# it must be indexed by position rather than by ID (the two only coincide for IDs 0..N-1).
trajectory_positions = pd.Index(global_trajectory_id)
# Number of plastic size classes: sizes the accumulator and splits each release
# class's scaling factor evenly across the sizes.
n_plastic_sizes = plastic_sizes.size
# Day offsets 0..6 within a week
week_offsets = np.arange(7)


def _arc_weekly_histogram(lon_values, lat_values, weights):
    """Return the week-averaged, periodically wrapped ARC histogram of the given points.

    Longitudes are first wrapped into [-180, 180). The histogram is computed on the
    ``arc_x_edges`` x ``arc_y_edges`` grid; its extra far-east column (the cell that
    straddles the dateline) is added to column 0 and dropped, so the result has shape
    (lon_centers.size, lat_centers.size). The sum over 7 daily snapshots is divided
    by 7 to give the average over the week.
    """
    lon_values = ((lon_values + 180) % 360) - 180
    histogram, _, _ = np.histogram2d(lon_values, lat_values, weights=weights,
                                     bins=(arc_x_edges, arc_y_edges), density=False)
    histogram[0, :] += histogram[-1, :]  # periodicity in longitude
    return histogram[:-1, :] / 7


for obs in tqdm(loop_ndays_to_process):
    sim_day = pd.to_datetime(starting_day_str) + pd.Timedelta(days=obs)
    sim_day_str = sim_day.strftime("%Y-%m-%d")
    week_zarr_path = os.path.join(output_dir, f"arc_{starting_day_str}_week_{sim_day_str}.zarr")
    if os.path.exists(week_zarr_path):
        continue  # Skip already processed weeks
    # New week -> fresh accumulator: (watercolumn/surface, plastic size, lon, lat)
    accumulator_histograms = np.zeros((2, n_plastic_sizes, lon_centers.size, lat_centers.size), dtype=np.float32)
    # Trajectories released on or before the first day of this week.
    # NOTE(review): trajectories released on days 1-6 of the week are excluded here, and
    # next week they are read from local obs >= 1, so their first partial week is never
    # counted. Confirm this is intended.
    valid_trajectory_ids = global_trajectory_id[global_start_times <= obs]
    for release_class in release_classes:
        # Valid trajectories of this release type
        valid_release_trajectory_ids = np.intersect1d(valid_trajectory_ids, release_class_traj[release_class], assume_unique=True)
        # Scaling factor for this year and release type, split evenly over the plastic classes
        rc_scaling_factor = scaling_factors[(starting_day_year, release_class)] / n_plastic_sizes
        # Vectorized selection (see https://docs.xarray.dev/en/latest/user-guide/interpolation.html#advanced-interpolation):
        # trajectory t is read at its local observations obs + i - start(t), i = 0..6,
        # i.e. the 7 days of this week on the simulation calendar.
        release_start_times = global_start_times[trajectory_positions.get_indexer(valid_release_trajectory_ids)]
        x = xr.DataArray(valid_release_trajectory_ids, dims="traj")
        y = xr.DataArray(obs + week_offsets[:, None] - release_start_times[None, :], dims=["localobs", "traj"])
        release_object_ds = sim_ds.sel(trajectory=x, obs=y).set_xindex('trajectory')
        release_object_ds.load()  # one read per release class; plastic sizes are subset in memory
        for plastic_i, plastic_size in enumerate(plastic_sizes):
            valid_plastic_trajectory_ids = np.intersect1d(valid_release_trajectory_ids, plastic_class_traj[plastic_size], assume_unique=True)
            object_ds = release_object_ds.sel(trajectory=xr.DataArray(valid_plastic_trajectory_ids, dims="traj"))
            # Broadcast to one common dimension order so the masked/flattened arrays line
            # up element by element (a per-trajectory plastic_amount is repeated over the days).
            lon_values, lat_values, z_values, amount_values = (
                arr.values
                for arr in xr.broadcast(object_ds.lon, object_ds.lat, object_ds.z, object_ds.plastic_amount)
            )
            weights = amount_values * rc_scaling_factor
            # Water column: every daily position that exists (deleted particles are NaN).
            # Surface: only the days on which the particle is within the top layer, so
            # particles moving in/out of the surface during the week are weighted per day.
            has_position = ~np.isnan(lon_values)
            at_surface = has_position & (z_values <= surface_threshhold)
            accumulator_histograms[0, plastic_i, :, :] += _arc_weekly_histogram(
                lon_values[has_position], lat_values[has_position], weights[has_position])
            accumulator_histograms[1, plastic_i, :, :] += _arc_weekly_histogram(
                lon_values[at_surface], lat_values[at_surface], weights[at_surface])
    # Now that the accumulators are ready, construct an xarray dataset, and save to file
    histogram_ds = xr.Dataset(
        {
            "plastic_amount": (("watercolumn_surface_flag", "plastic_size", "lon", "lat"), accumulator_histograms, {"units": "kilograms", 'description': 'Average plastic mass per grid cell over the week.'})
        },
        coords={
            "watercolumn_surface_flag": ("watercolumn_surface_flag", ["watercolumn", "surface"], {'description': 'Flag indicating whether the histogram is for the entire water column or just the surface layer (top 5 meters).'}),
            "plastic_size": ("plastic_size", plastic_sizes, {'units': 'm', 'description': 'Diameter of the plastic particles.'}),
            "lon": ("lon", lon_centers, {'units': 'degrees_east', 'description': 'Longitude bin centers for the histogram.'}),
            "lat": ("lat", lat_centers, {'units': 'degrees_north', 'description': 'Latitude bin centers for the histogram.'}),
            "time": ("time", [pd.to_datetime(sim_day_str)], {'description': 'Starting day of the week for which the histogram is computed.'})
        },
        attrs={
            "description": "Weekly averaged plastic mass histograms over the Arctic region for different plastic sizes.",
            "starting_day": starting_day_str
        }
    )
    # Write to a temporary store and rename it into place, so a job killed mid-write
    # cannot leave a partial week zarr that the exists-check above would skip forever.
    partial_zarr_path = week_zarr_path + ".partial"
    histogram_ds.to_zarr(partial_zarr_path, mode='w')
    os.replace(partial_zarr_path, week_zarr_path)
    # Cleanup memory
    del histogram_ds
    gc.collect()
# Final status report. If additional_days_f is True, the trailing days that do not
# form a complete week were not written here; they are left to the restart
# processing step.
completion_message = f"Processing for: start_day={starting_day_str}, has been completed."
print(completion_message)
leftover_message = f"Additional days at the end: {additional_days_f} (if True, these have not been handled here.)"
print(leftover_message)