85 changes: 55 additions & 30 deletions src/data.py
@@ -9,9 +9,6 @@
import datetime
import numpy as np
import xarray as xr
import progressbar
from tqdm import tqdm


class Dataset():
__name__ = ['fit', 'clip_by_date']
@@ -80,10 +77,10 @@ def fit(self):
"""main process for making training/test data"""
# get path of target et data

p = progressbar.ProgressBar()
et_path = glob.glob(self.data_path+'ET/' +
"*{name}*nc".format(name=self.et_product))[0]

"*{name}_{tr}_{sr}.nc".format(name=self.et_product,tr=self.t_resolution, sr=self.s_resolution))[0]
print(et_path)
# exit(0)
PATH = self.inputs_path+self.et_product+'/'
print(' [DataML] loading lat/lon grids')
lat_file_name = 'lat_{t}_{s}.npy'.format(
@@ -104,7 +101,6 @@ def fit(self):
sr=self.s_resolution,
begin_year=self.begin_year,
end_year=self.end_year)

if os.path.exists(PATH+file_name):
forcing = np.load(PATH+file_name) # (t,lat,lon,feat)
else:
@@ -176,10 +172,22 @@ def fit(self):

print('begin:{begin_year}, end:{end_year}'.format(
begin_year=self.begin_year, end_year=self.end_year))
if self.et_product in ('REA', 'CAMELE'):
    # crop the other grids to the ET latitude dimension
    print('{name} data reshape'.format(name=self.et_product))
    ilat = et.shape[1]
    forcing = forcing[:, :ilat]
    lai = lai[:, :ilat]
    static = static[:, :ilat]
print('forcing shape is {shape}'.format(shape=forcing.shape))
print('ET shape is {shape}'.format(shape=et.shape))
print('LAI shape is {shape}'.format(shape=lai.shape))
print('static shape is {shape}'.format(shape=static.shape))
assert forcing.shape[0] == et.shape[0], "X(t) /= ET(t)"
assert forcing.shape[0] == lai.shape[0], "X(t) /= LAI(t)"
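For reviewers, a toy sketch of what the REA/CAMELE cropping branch and the asserts above do; the grid sizes here are made up, not the real REA/CAMELE shapes:

```python
import numpy as np

# Hypothetical shapes: the ET grid has fewer latitude rows than the others.
et = np.zeros((10, 100, 360, 1))
forcing = np.zeros((10, 120, 360, 7))
ilat = et.shape[1]
forcing = forcing[:, :ilat]             # keep only the first ilat latitude rows
assert forcing.shape[0] == et.shape[0]  # time axes agree
assert forcing.shape[1] == et.shape[1]  # latitude axes now agree too
```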
# get shape
@@ -189,11 +197,9 @@
n=N, m=self.time_length-N))

print('[DataML] preprocessing')

# DEBUG(@lu li): use less memory
feat = np.concatenate([forcing, lai], axis=-1)
del forcing, lai

x_train, y_train = feat[:N], et[:N]
x_test, y_test = feat[N:], et[N:]
del feat, et
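A minimal sketch of the temporal train/test split above on synthetic data; `N` and the array sizes are illustrative only:

```python
import numpy as np

time_length, N = 365, 292                    # e.g. ~80% of days for training
feat = np.zeros((time_length, 50, 60, 8))    # forcing + LAI, concatenated
et = np.zeros((time_length, 50, 60, 1))
x_train, y_train = feat[:N], et[:N]          # first N time steps
x_test, y_test = feat[N:], et[N:]            # remaining time steps
del feat, et                                 # free memory, as the PR does
print(x_train.shape, x_test.shape)           # (292, 50, 60, 8) (73, 50, 60, 8)
```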
@@ -233,28 +239,40 @@ def fit(self):
x_train = x_train[:, lat_idx][:, :, lon_idx]
y_train = y_train[:, lat_idx][:, :, lon_idx]


Nstatic = np.tile(static, (x_test.shape[0], 1, 1, 1))
x_test = np.concatenate([x_test, Nstatic], axis=-1)
del Nstatic
# save output
nt, nlat, nlon, nfeat = x_train.shape
_, nlat, nlon, nf_static = static.shape
x_train = x_train.reshape(nt, -1, nfeat)
y_train = y_train.reshape(nt, -1, 1)
# Debug(@lu li): use less memory
static = static.reshape(1, -1, nf_static)
# drop grid cells whose target is NaN (pixel indices, not raw argwhere output)
nan_pix = np.unique(np.argwhere(np.isnan(y_train))[:, 1])
x_train = np.delete(x_train, nan_pix, axis=1)
static = np.delete(static, nan_pix, axis=1)
y_train = np.delete(y_train, nan_pix, axis=1)
print('after NaN removal:', 'x_train', x_train.shape, 'y_train', y_train.shape)

static = np.tile(static, (nt, 1, 1))
x_train = np.concatenate([x_train, static], axis=-1)

x_train = x_train.reshape(-1, x_train.shape[2])  # TODO(@xuqch): add to reshape
y_train = y_train.reshape(-1, 1)
print('processed:', 'x_train', x_train.shape, 'y_train', y_train.shape, 'x_test', x_test.shape, 'y_test', y_test.shape)
del static
# FIXME(@xuqch): resort
np.save('x_train.npy', x_train)
np.save('y_train.npy', y_train)
np.save('x_test.npy', x_test)
np.save('y_test.npy', y_test)
os.system('mv {} {}'.format("*.npy", PATH))

print('{n} million feats for training'.format(
n=x_train.shape[0]*x_train.shape[1]/1000000))
print('{n} million samples for training'.format(
n=x_train.shape[0]/1000000))
return x_train, y_train, x_test, y_test, lat, lon
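A self-contained sketch of the flatten / NaN-drop / static-tile pipeline above on synthetic data; all sizes are made up, and `nan_pix` mirrors the pixel-index fix:

```python
import numpy as np

nt, nlat, nlon, nfeat, nf_static = 4, 3, 5, 2, 6
x_train = np.random.rand(nt, nlat, nlon, nfeat)
y_train = np.random.rand(nt, nlat, nlon, 1)
static = np.random.rand(1, nlat, nlon, nf_static)
y_train[:, 0, 0, 0] = np.nan                 # pretend one grid cell is ocean

x_train = x_train.reshape(nt, -1, nfeat)     # (t, pixel, feat)
y_train = y_train.reshape(nt, -1, 1)
static = static.reshape(1, -1, nf_static)

nan_pix = np.unique(np.argwhere(np.isnan(y_train))[:, 1])   # bad pixel ids
x_train = np.delete(x_train, nan_pix, axis=1)
static = np.delete(static, nan_pix, axis=1)
y_train = np.delete(y_train, nan_pix, axis=1)

static = np.tile(static, (nt, 1, 1))         # repeat static fields over time
x_train = np.concatenate([x_train, static], axis=-1)
x_train = x_train.reshape(-1, x_train.shape[2])
y_train = y_train.reshape(-1, 1)
print(x_train.shape, y_train.shape)          # (56, 8) (56, 1): 4*(15-1) samples
```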

def _load_forcing(self,
@@ -271,7 +289,9 @@ def _load_forcing(self,
fold = "{tr}_{sr}/".format(tr=t_resolution, sr=s_resolution)
file = forcing_root + fold + "ERA5Land_{year}_{var}_{tr}_{sr}.nc".format(
year=year, var=forcing_list[i], tr=t_resolution, sr=s_resolution)
with xr.open_dataset(file) as f:
tmp.append(f[forcing_list[i]])
tmp = np.stack(tmp, axis=-1)
forcing.append(tmp)
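Each variable comes from its own yearly NetCDF and is stacked on a new trailing axis; a sketch with synthetic arrays standing in for the file reads (the variable names are hypothetical, not confirmed by this PR):

```python
import numpy as np

forcing_list = ['t2m', 'tp', 'ssrd']         # hypothetical ERA5-Land variables
ndays, nlat, nlon = 365, 100, 200
tmp = [np.zeros((ndays, nlat, nlon)) for _ in forcing_list]
year_block = np.stack(tmp, axis=-1)          # (t, lat, lon, n_vars)
print(year_block.shape)                      # (365, 100, 200, 3)
```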
@@ -286,19 +306,24 @@ def _load_et(self, et_root, et_product, temporal_resolution, spatial_resolution)

def _load_lai(self, lai_root, begin_year, end_year, t_resolution, s_resolution):
lai_all = []
fold = "{tr}_{sr}/".format(tr=t_resolution, sr=s_resolution)

with xr.open_dataset(lai_root+fold+'LAI_{tr}_{sr}.nc'.format(
tr=t_resolution, sr=s_resolution)) as f:
fold = "1D_{sr}/".format(sr=s_resolution)
with xr.open_dataset(lai_root+fold+'LAI_1D_{sr}.nc'.format(sr=s_resolution)) as f:
lai = np.array(f.lai)

for year in range(begin_year, end_year+1):
    # the same 366-entry LAI array is reused for every year; drop 2.29
    # in non-leap years, and subsample for the 8-day resolution
    if (year % 4 == 0 and year % 100 != 0) or (year % 400 == 0):  # leap year
        if t_resolution == '1D':
            lai_all.append(lai)
        else:
            lai_all.append(lai[::8])
    else:
        idx = np.delete(np.arange(366), 59, axis=0)  # remove 2.29
        if t_resolution == '1D':
            lai_all.append(lai[idx])
        else:
            lai_all.append(lai[idx][::8])
lai = np.concatenate(lai_all, axis=0)
lai = lai[:, :, :, np.newaxis]
return lai
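A quick, standalone check of the calendar logic in `_load_lai`, assuming the LAI file holds 366 daily entries (the '8D' label below is just an example value for a non-daily resolution):

```python
import numpy as np

def n_days(year, lai_len=366, t_resolution='1D'):
    # mirrors _load_lai: drop 2.29 (index 59) in non-leap years,
    # take every 8th entry otherwise
    leap = (year % 4 == 0 and year % 100 != 0) or (year % 400 == 0)
    days = np.arange(lai_len)
    if not leap:
        days = np.delete(days, 59, axis=0)
    return len(days) if t_resolution == '1D' else len(days[::8])

assert n_days(2020) == 366 and n_days(2021) == 365   # leap vs non-leap
assert n_days(2020, t_resolution='8D') == 46         # 8-day sampling
```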