diff --git a/fmsgridtools/__init__.py b/fmsgridtools/__init__.py index 683eaa7..6c42b8e 100644 --- a/fmsgridtools/__init__.py +++ b/fmsgridtools/__init__.py @@ -5,6 +5,7 @@ from .make_mosaic import coupler_mosaic from .make_topog import make_topog from .remap import remap +from .remap.variableobj import VariableObj, SrcFileObj, TgtFileObj, DimObj, FileDimsObj from .shared.gridtools_utils import check_file_is_there, get_provenance_attrs from .shared.gridobj import GridObj from .shared.mosaicobj import MosaicObj diff --git a/fmsgridtools/remap/conservative.py b/fmsgridtools/remap/conservative.py index dac398d..7168f31 100644 --- a/fmsgridtools/remap/conservative.py +++ b/fmsgridtools/remap/conservative.py @@ -1,31 +1,65 @@ import numpy as np +from fmsgridtools.remap.variableobj import SrcFileObj, TgtFileObj, VariableObj from fmsgridtools.shared.xgridobj import XGridObj -def remap(src_mosaic: str, - input_dir: str = "./", - output_dir: str = None, - input_file: str = None, - output_file: str = None, - tgt_mosaic: str = None, - tgt_nlon: int = None, - tgt_nlat: int = None, - lon_bounds: list = None, - lat_bounds: list = None, - kbounds: list = None, - tbounds: list = None, - order: int = 1, - static_file: str = None, - check_conserve: bool = False): +import pyfms - #create an xgrid object - xgrid = XGridObj(input_dir, src_mosaic_file=src_mosaic, tgt_mosaic_file=tgt_mosaic) - #create xgrid - xgrid.create_xgrid() +def remap(input_dir: str = "./", + src_mosaicfile: str = None, + tgt_mosaicfile: str = None, + input_file: str = None, + output_dir: str = "./", + output_file: str = None, + scalar_variables: list[str] = None, + lon_bounds: list = None, + lat_bounds: list = None, + kbounds: list = None, + tbounds: list = None, + order: int = 1, + check_conserve: bool = False, + gpu: bool = False): - #write - xgrid.write() + xgrid = XGridObj(input_dir=input_dir, src_mosaicfile=src_mosaicfile, tgt_mosaicfile=tgt_mosaicfile) + tgt_tiles = list(xgrid.tgt.grid.keys()) + src_tiles = list(xgrid.src.grid.keys()) + pyfms.fms.init(ndomain=len(tgt_tiles)) + for tgt_tile in tgt_tiles: + + pyfms.horiz_interp.init(len(src_tiles)) + + nx_tgt, ny_tgt = xgrid.tgt.grid[tgt_tile].nx, xgrid.tgt.grid[tgt_tile].ny + + #get xgrid, make sure mpi works cause i don't think it will + xgrid.domain = pyfms.mpp_domains.define_domains([0, nx_tgt-1, 0, ny_tgt-1]) + xgrid.set_target_tile(tgt_tile) + xgrid.get_interp() + + if input_file is None: return + if output_file is None: output_file = input_file + if output_dir is None: output_dir = input_dir + + src_fileobj = SrcFileObj(datafile=input_file, input_dir=input_dir, tiles=src_tiles) + tgt_fileobj = TgtFileObj(datafile=output_file, output_dir=output_dir, nx=nx_tgt, ny=ny_tgt) + + for var in src_fileobj.variables: + + variable = VariableObj(var, src_fileobj) + variable.init() + + for itime in variable.timeslist: + for klevel in variable.zlist: + + input_data = variable.slice(tile=src_tiles[0], timepoint=itime, klevel=klevel, prepare_data=True) + data_slice = pyfms.horiz_interp.interp(xgrid.interps[src_tile].interp_id, input_data, convert_cf_order=False) + + for src_tile in src_tiles[1:]: + input_data = variable.slice(tile=src_tile, timepoint=itime, klevel=klevel, prepare_data=True) + data_slice += pyfms.horiz_interp.interp(xgrid.interps[src_tile].interp_id, input_data, convert_cf_order=False) + + #gather if parallel, for now, no + variable.set_tgt_data(data_slice, timepoint=itime, klevel=klevel) diff --git a/fmsgridtools/remap/remap.py b/fmsgridtools/remap/remap.py index d29beef..bab73b0 100644 --- a/fmsgridtools/remap/remap.py +++ b/fmsgridtools/remap/remap.py @@ -17,30 +17,42 @@ Only 1 or 2 order is supported for conservative interpolation """ ) -@click.option("--static_file", - type = click.Path(exists=True), +@click.option("--check_conserve", + type = bool, help = """ - To remap data where cell_methods = CELL_METHODS_MEAN, the static_file - will src grid cell areas should be provided + If true, output grid area conservative will be checked """ ) -@click.option("--check_conserve", +@click.option("--gpu", type = bool, help = """ - If true, output grid area conservative will be checked + If true, the exchange grid will be created on the GPU """ ) -def conservative_method(input_dir, output_dir, input_file, output_file, #common_options - src_mosaic, tgt_mosaic, tgt_nlon, tgt_nlat, #common_options - lon_bounds, lat_bounds, kbounds, tbounds, #common_options - debug, order, static_file, check_conserve): - +def conservative_method(input_dir, output_dir, input_mosaic_dir, output_mosaic_dir, + input_file, output_file, #common_options + src_mosaic, tgt_mosaic, tgt_nlon, tgt_nlat, #common_options + scalar_variables, lon_bounds, lat_bounds, + kbounds, tbounds, debug, order, check_conserve, gpu): + setlogger.setconfig("remap.log", debug) logger.info("Starting conservative remapping") - - conservative.remap(src_mosaic, input_dir, output_dir, input_file, - output_file, tgt_mosaic, tgt_nlon, tgt_nlat, - lon_bounds, lat_bounds, kbounds, tbounds, - order, static_file, check_conserve) + + xgrid = conservative.remap(input_dir=input_dir, + output_dir=output_dir, + input_mosaic_dir=input_mosaic_dir, + output_mosaic_dir=output_mosaic_dir, + input_file=input_file, + src_mosaic=src_mosaic, + tgt_mosaic=tgt_mosaic, + output_file=output_file, + scalar_variables=scalar_variables, + lon_bounds=lon_bounds, + lat_bounds=lat_bounds, + kbounds=kbounds, + tbounds=tbounds, + order=order, + check_conserve=check_conserve, + gpu=gpu) diff --git a/fmsgridtools/remap/variableobj.py b/fmsgridtools/remap/variableobj.py new file mode 100644 index 0000000..a1fe40c --- /dev/null +++ b/fmsgridtools/remap/variableobj.py @@ -0,0 +1,256 @@ +from itertools import pairwise +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import numpy.typing as npt +import xarray as xr + +class DimObj(): + + def __init__(self, axis: str, name: str = None, here: bool = False, size: int = None, attr: dict = None, coord_values: npt.NDArray = None): + self.axis = axis + self.name = name + self.size = size + self.attr = None + self.coords = coord_values + self.here = False + + +class FileDimsObj(): + + def __init__(self): + self.x = DimObj("X") + self.y = DimObj("Y") + self.z = DimObj("Z") + self.time = DimObj("T") + + + def init(self, dataset: xr.Dataset): + + """ + get dimensions + """ + + dims_dict = {dim.axis: dim for dim in [self.time, self.z, self.y, self.x]} + + for name, coord in dataset.coords.items(): + try: + dim = dims_dict[coord.attrs.get("axis")] + dim.name = name + dim.attr = coord.attrs + dim.size = coord.size + dim.coords = coord.values + dim.here = True + except: + print(f"{name} dim not found") + + def get_z(self, dims_list: list): + if self.z.name in dims_list: + return self.z + else: + return DimObj(axis="Z", name=self.z.name) + + def get_time(self, dims_list: list): + if self.time.name in dims_list: + return self.time + else: + return DimObj(axis="T", name=self.time.name) + + + def __repr__(self): + + repr_str ="\n" + for obj in [self.x, self.y, self.z, self.time]: + repr_str += f"axis={obj.axis}, name={obj.name}, here={obj.here}, size={obj.size}\n" + return repr_str + + +class SrcFileObj(): + + skip_variables = [ + "geolon_c", "geolat_c", "geolon_u", "geolat_u", "geolon_v", "geolat_v", + "FA_X", "FA_Y", "FI_X", "FI_Y", "IX_TRANS", "IY_TRANS", + "UI", "VI", "UO", "VO", "wet_c", "wet_v", "wet_u", + "dxCu", "dyCu", "dxCv", "dyCv", "Coriolis", + "areacello_cu", "areacello_cv", "areacello_bu", + "average_T1", "average_T2", "average_DT", "time_bnds" + ] + + + def __init__(self, datafile: str, tiles: list = ["tile1"], input_dir: str = "./", variables: list = None): + + self.input_dir = str(input_dir) + self.tiles = tiles + self.dims = FileDimsObj() + self.src_datafiles = {tile: Path(input_dir)/Path(datafile + f".{tile}.nc") for tile in tiles} + self.static_files = {} + self.variables = [] if variables is None else variables + + self.datasets = { + tile: xr.open_dataset(self.src_datafiles[tile], decode_cf=False) for tile in self.tiles + } + + dataset = self.datasets[tiles[0]] + + # get dims + self.dims.init(dataset) + + # soil_area: 00010101.land_static.nc cell_area: 00010101.land_static_sg.nc + associated_files = dataset.attrs.get("associated_files") + if associated_files is not None: + stringsplit = associated_files.replace(":", " ").split() + for i in range(0, len(stringsplit), 2): + self.static_files[stringsplit[i]] = { + tile: Path(input_dir)/stringsplit[i+1].replace(".nc", f".{tile}.nc") for tile in self.tiles + } + + #get list of variables if not specified + if not bool(self.variables): + for variable in dataset: + if variable in self.skip_variables: + print(f"skipping {variable}") + else: + self.variables.append(variable) + + def __repr__(self): + repr_str = "\n" + repr_str += f"input_dir = {self.input_dir}\n" + repr_str += f"tiles = {self.tiles}\n" + repr_str += f"datafiles = {self.src_datafiles}\n" + repr_str += f"static_files = {self.static_files}\n" + return repr_str + + + +class TgtFileObj(): + + def __init__(self, datafile: str|Path, nx: int, ny: int, output_dir: str = "./"): + + self.output_dir = str(output_dir) + self.datafile = datafile + self.nx = nx + self.ny = ny + + self.datadict = {} + + + def init_variable(self, variable: str, dtype, dims_list: list = None, attributes: dict = None, z_size: int = None, time_size: int = None): + + data_info = {} + if dims_list is not None: data_info["dims"] = dims_list + if attributes is not None: data_info["attrs"] = attributes + + shape = [] + if time_size is not None: shape.append(time_size) + if z_size is not None: shape.append(z_size) + shape += [self.ny, self.nx] + + data_info["data"] = np.zeros(shape, dtype=dtype) + + self.datadict[variable] = data_info + + + def set_coords(self, grid): + pass + + + +class VariableObj(): + + time_coder = xr.coders.CFDatetimeCoder(use_cftime=True) + + def __init__(self, variable: str, src_fileobj: SrcFileObj = None, tgt_fileobj: TgtFileObj = None): + + self.variable = variable + self.src_fileobj = src_fileobj + self.tgt_fileobj = tgt_fileobj + + self.z = None + self.time = None + + self.dtype = None + self.missing_value = None + self.fill_value = None + self.offset = None + self.scale_factor = None + + self.static_files: {} = None + + tile = self.src_fileobj.tiles[0] + dataset = self.src_fileobj.datasets[tile] + + if self.variable not in dataset: + raise RuntimeError("variable not found") + + dataarray = dataset[self.variable] + + dims_list = dataarray.dims + self.z = self.src_fileobj.dims.get_z(dims_list) + self.time = self.src_fileobj.dims.get_time(dims_list) + + self.dtype=dataarray.dtype + + attributes = dataarray.attrs + self.missing_value = attributes.get("missing_value") + self.fill_value = attributes.get("_FillValue") + self.offset = attributes.get("add_offset") + self.scale_factor = attributes.get("scale_factor") + + if "area" in str(attributes.get("cell_method")): + cell_measures = str(attributes.get("cell_measures")) + if "area:" in cell_measures: + self.static_files = self.src_fileobj.static_files[cell_measures.split()[1]] + + self.tgt_fileobj.init_variable(self.variable, self.dtype, dims_list=dims_list)#, attributes=attributes, z_size=self.z.size, time_size=self.time.size) + + + def slice(self, tile: str = "tile1", timepoint: int = None, klevel: int = None, prepare_data: bool = False): + + + dataset = self.src_fileobj.datasets[tile] + + slice_dict = {} + if klevel is not None: + slice_dict[self.z.name] = klevel + if timepoint is not None: + slice_dict[self.time.name] = timepoint + + src_data = dataset[self.variable].isel(slice_dict).values + + if prepare_data: + src_data = self.prepare_data(src_data) + return src_data + + + def prepare_data(self, src_data: npt.NDArray = None): + + #missing value mask + missing_value_mask = None + if self.missing_value is not None: + missing_value_mask = src_data == self.missing_value + + if self.offset is not None: src_data += self.offset + if self.scale_factor is not None: src_data *= self.scale_factor + + #zero out missing values so it doens't contribute to remapping + if missing_value_mask is not None: + src_data = xr.where(missing_value_mask, 0.0, src_data) + + return src_data + + + def set_tgt_data(self, data: npt.NDArray, timepoint: int = None, klevel: int = None): + + tgt_data = self.tgt_fileobj.datadict[self.variable]["data"] + + if timepoint is None and klevel is None: + tgt_data = data + elif timepoint is not None and klevel is not None: + tgt_data[timepoint, klevel, :, :] = data + elif timepoint is not None: + tgt_data[timepoint, :, :] = data + elif klevel is not None: + tgt_data[klevel, :, :] = data + + self.tgt_fileobj.datadict[self.variable]["data"] = tgt_data \ No newline at end of file diff --git a/fmsgridtools/shared/gridobj.py b/fmsgridtools/shared/gridobj.py index 9a08048..44d659a 100644 --- a/fmsgridtools/shared/gridobj.py +++ b/fmsgridtools/shared/gridobj.py @@ -150,7 +150,7 @@ def to_domain(self, domain: dict = None): else: if self.domain is not None: logger.warning("Overwriting %s with %s", self.domain, domain) - self.domain = domain + self.domain = domain if not pyfms.fms.module_is_initialized(): logger.error("Please initialize pyfms first") diff --git a/fmsgridtools/shared/xgridobj.py b/fmsgridtools/shared/xgridobj.py index ddbd5ad..415bfdf 100755 --- a/fmsgridtools/shared/xgridobj.py +++ b/fmsgridtools/shared/xgridobj.py @@ -1,4 +1,5 @@ -import ctypes +import logging +from pathlib import Path import numpy as np import numpy.typing as npt import xarray as xr @@ -7,226 +8,286 @@ import pyfms from fmsgridtools.shared.gridobj import GridObj -from fmsgridtools.shared.gridtools_utils import check_file_is_there from fmsgridtools.shared.mosaicobj import MosaicObj - -class XGridObj() : - - def __init__(self, - input_dir: str = "./", - src_mosaic_file: str = None, - tgt_mosaic_file: str = None, - restart_remap_file: str = None, - write_remap_file: str = "remap.nc", - src_mosaic: type[MosaicObj] = None, - tgt_mosaic: type[MosaicObj] = None, - src_grid: dict[type[GridObj]] = None, - tgt_grid: dict[type[GridObj]] = None, - dataset: type[xr.Dataset] = None, - datadict: dict = None, - on_agrid: bool = True, - order: int = 1, - on_gpu: bool = False): - self.input_dir = input_dir - self.src_mosaic_file = src_mosaic_file - self.tgt_mosaic_file = tgt_mosaic_file - self.restart_remap_file = restart_remap_file - self.write_remap_file = write_remap_file - self.src_mosaic = src_mosaic - self.tgt_mosaic = tgt_mosaic - self.src_grid = src_grid - self.tgt_grid = tgt_grid +logger = logging.getLogger(__name__) + + +class Parent: + + def __init__( + self, + parent: str, + input_dir: str | Path = "./", + mosaicfile: str | Path = None, + mosaic: MosaicObj = None, + gridfile: str | Path = None, + ntiles: int = None, + grid: dict[str, GridObj] = None, + mask: dict[str, npt.NDArray] = None, + domain: pyfms.Domain = None, + ): + + self.parent = parent + self.input_dir = Path(input_dir) + self.mosaicfile = mosaicfile + self.mosaic = mosaic + self.gridfile = gridfile + self.ntiles = ntiles + self.grid = grid + self.mask = mask + self.domain = domain + + if self.grid is None: + if self.mosaic is None: + if self.mosaicfile is None: + logger.warning("Cannot set %s grid", self.parent) + self.mosaic = MosaicObj( + input_dir=self.input_dir, mosaicfile=self.mosaicfile + ).read() + self.grid = self.mosaic.get_grid( + input_dir=self.input_dir, center=True, radians=True, domain=self.domain + ) + logger.info("set grid for %s", self.parent) + + +class XGridObj: + + def __init__( + self, + input_dir: str | Path = "./", + src_mosaicfile: str = None, + tgt_mosaicfile: str = None, + remapfile: str | Path = None, + src_mosaic: MosaicObj = None, + tgt_mosaic: MosaicObj = None, + src_gridfile: str | Path = None, + tgt_gridfile: str | Path = None, + src_grid: dict[str, GridObj] = None, + tgt_grid: dict[str, GridObj] = None, + tgt_tile: str = None, + src_mask: dict[str, np.ndarray] = None, + tgt_mask: dict[str, np.ndarray] = None, + order: int = 1, + domain: pyfms.Domain = None, + ): + + """ + Create an XGridObj container for building/reading remap/interp data. + + Args: + input_dir (str|Path): Base directory for input files. + src_mosaicfile (str): Source mosaic filename (optional). + tgt_mosaicfile (str): Target mosaic filename (optional). + remapfile (str|Path): Remap weights filename (optional). + src_mosaic (MosaicObj): Optional pre-built source MosaicObj. + tgt_mosaic (MosaicObj): Optional pre-built target MosaicObj. + src_gridfile (str|Path): Source grid filename (optional). + tgt_gridfile (str|Path): Target grid filename (optional). + src_grid (dict[str, GridObj]): Optional dict of source GridObj keyed by tile name. + tgt_grid (dict[str, GridObj]): Optional dict of target GridObj keyed by tile name. + tgt_tile (str): Name of the target tile to use (default: "tile1"). + src_mask (dict[str, np.ndarray]): Optional masks for source tiles. + tgt_mask (dict[str, np.ndarray]): Optional masks for the target grid. + order (int): Interpolation order (currently stored; semantics depend on downstream code). + on_gpu (bool): If True, use GPU-based building of xgrid via pyfrenctools. + + The constructed object stores container namespaces for source (`self.src`) and + target (`self.tgt`) data and an `interps` mapping that is populated by + `read` or `get_interp`. + """ + + self.input_dir: str | Path = Path(input_dir) + self.src = Parent( + parent="src", + input_dir=input_dir, + mosaicfile=src_mosaicfile, + gridfile=src_gridfile, + mosaic=src_mosaic, + grid=src_grid, + mask=src_mask, + domain=None, + ) + self.tgt = Parent( + parent="tgt", + mosaicfile=tgt_mosaicfile, + gridfile=tgt_gridfile, + mosaic=tgt_mosaic, + grid=tgt_grid, + mask=tgt_mask, + domain=domain, + ) + + self.tgt_tile = tgt_tile + if self.tgt_tile is not None: + self.tgt.grid = self.tgt.grid[tgt_tile] + + self.remapfile: str | Path = remapfile self.order = order - self.on_gpu = on_gpu - self.on_agrid = on_agrid - self.dataset = dataset - self.datadict = datadict - self._srcinfoisthere = False - self._tgtinfoisthere = False + self.interps: pyfms.ConserveInterp | dict[str, pyfms.ConserveInterp] = None + - if self._check_restart_remap_file(): return - if self.datadict is not None: return - if self.dataset is not None: return + def set_target_tile(self, tgt_tile: str = "tile1"): + self.tgt_tile = tgt_tile + self.tgt.grid = self.tgt.grid[tgt_tile] - self._check_grid() - self._check_mosaic() - self._check_mosaic_file() - if not self._srcinfoisthere or not self._tgtinfoisthere: - raise RuntimeError("Please provide grid information") + def read( + self, + input_dir: Path | str = None, + remapfile: Path | str = None, + domain: pyfms.Domain = None, + ): + """ + read remap file and store as pyfms.ConserveInterp objects + """ - def read(self, infile: str = None): + input_dir = self.input_dir if input_dir is None else Path(input_dir) - if infile is None: - if self.restart_remap_file is None: - raise RuntimeError("must provide the input remap file for reading") - infile = self.restart_remap_file + if remapfile is None: + if self.remapfile is None: + logger.error("Please specify remapfile to read") + remapfile = self.remapfile + remapfile = input_dir / Path(remapfile) - self.dataset = xr.open_dataset(infile) + if not remapfile.exists(): + logger.error("remap file %s does not exist", self.remapfile) - for key in self.dataset.data_vars.keys(): - setattr(self, key, self.dataset[key].values) + if domain is None: + domain = self.tgt.domain - for key in self.dataset.sizes: - setattr(self, key, self.dataset.sizes[key]) + self.interps = {} + for itile, src_tile in enumerate(self.src.grid): + interp_id = pyfms.horiz_interp.read_weights_conserve( + weight_filename=str(remapfile), + weight_file_src="fregrid", + nlon_src=self.src.grid[src_tile].nx, + nlat_src=self.src.grid[src_tile].ny, + nlon_tgt=self.tgt.grid.nx, + nlat_tgt=self.tgt.grid.ny, + domain=domain, + src_tile=itile, + save_xgrid_area=True, + ) + self.interps[src_tile] = pyfms.ConserveInterp(interp_id, save_xgrid_area=True) - def write(self, outfile: str = None): + def gather(self): - if outfile is None: - outfile = self.write_remap_file + """ + gathers xgrid + """ - if self.dataset is None: - if self.datadict is not None: - self.to_dataset() + isc, jsc = self.tgt.domain.isc, self.tgt.domain.jsc - for tgt_tile in self.dataset: - concat_dataset = xr.concat([self.dataset[tgt_tile][src_tile] for src_tile in self.dataset[tgt_tile]], dim="nxcells") + global_interps = {} + for src_tile in self.interps: + interp = self.interps[src_tile] + global_interp = global_interps[src_tile] = pyfms.ConserveInterp() + nxgrids = pyfms.mpp.gather(np.array([interp.nxgrid], dtype=np.int32), rbuf_size=len(self.interps)) + global_interp.nxgrid = np.sum(nxgrids) if pyfms.mpp.pe() == pyfms.mpp.root_pe() else None + global_interp.i_src = pyfms.mpp.gatherv(interp.i_src, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.j_src = pyfms.mpp.gatherv(interp.j_src, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.i_dst = pyfms.mpp.gatherv(interp.i_dst + isc, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.j_dst = pyfms.mpp.gatherv(interp.j_dst + jsc, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.xgrid_area = pyfms.mpp.gatherv(interp.xgrid_area, ssize=interp.nxgrid, rsize=nxgrids) + return global_interps - concat_dataset.to_netcdf(outfile) + def write(self, output_dir: Path | str = "./", outfile: str | Path = Path("remap.nc")): + """ + write remap file + """ - def to_dataset(self): + global_interps = self.interps if self.tgt.domain is None else self.gather() - if self.datadict is None: raise OSError("datadict is None") + if pyfms.mpp.pe() == pyfms.mpp.root_pe(): - datadict = self.datadict - self.dataset = {} + outfile = Path(output_dir) / outfile + logger.info("writing remap file to %s", outfile) - for tgt_tile in datadict: - self.dataset[tgt_tile] = {} - for src_tile in datadict[tgt_tile]: + datasets = [] + for tile1, src_tile in enumerate(global_interps): - thisdict = datadict[tgt_tile][src_tile] - dataset = self.dataset[tgt_tile][src_tile] = xr.Dataset() + interp = global_interps[src_tile] + dataset = xr.Dataset() - dataset["src_cell"] = xr.DataArray(np.column_stack((thisdict['src_i']+1, thisdict['src_j']+1)), - dims=["nxcells", "two"], - attrs={"src_cell": "parent cell indices in src mosaic", - "_FillValue": False} + dataset["tile1"] = xr.DataArray( + np.full(interp.nxgrid, tile1), + dims=["ncells"], + attrs={"standard_name": "tile_number_in_mosaic1"}, ) - dataset["tgt_cell"] = xr.DataArray(np.column_stack((thisdict['tgt_i']+1, thisdict['tgt_j']+1)), - dims=["nxcells", "two"], - attrs={"tgt_cell": "parent cell indices in tgt mosaic", - "_FillValue": False}) - dataset["xarea"] = xr.DataArray(thisdict['xarea'], - dims=["nxcells"], - attrs={"xarea": "exchange grid area", - "_FillValue": False} + dataset["tile1_cell"] = xr.DataArray( + np.column_stack((interp.i_src + 1, interp.j_src + 1)), + dims=["ncells", "two"], + attrs={"standard_name": "parent_cell_indices_in_mosaic1"}, ) - - - def create_xgrid(self, src_mask: dict[str,npt.NDArray] = None, tgt_mask: dict[str, npt.NDArray] = None) -> dict: - - if self.order not in (1,2): - raise RuntimeError("conservative order must be 1 or 2") - - if self.on_gpu: - create_xgrid_2dx2d_order1 = pyfrenctools.create_xgrid.get_2dx2d_order1_gpu - else: - create_xgrid_2dx2d_order1 = pyfrenctools.create_xgrid.get_2dx2d_order1 - - if self.datadict is None: self.datadict = {} - - for tgt_tile in self.tgt_grid: - - self.datadict[tgt_tile], itile = {}, 1 - - itgt_mask = None if tgt_mask is None else tgt_mask[tgt_tile] - - for src_tile in self.src_grid: - - isrc_mask = None if src_mask is None else src_mask[src_tile] - - xgrid_out = create_xgrid_2dx2d_order1( - src_nlon = self.src_grid[src_tile].nx, - src_nlat = self.src_grid[src_tile].ny, - tgt_nlon = self.tgt_grid[tgt_tile].nx, - tgt_nlat = self.tgt_grid[tgt_tile].ny, - src_lon=self.src_grid[src_tile].x, - src_lat=self.src_grid[src_tile].y, - tgt_lon=self.tgt_grid[tgt_tile].x, - tgt_lat=self.tgt_grid[tgt_tile].y, - src_mask=isrc_mask, - tgt_mask=itgt_mask + dataset["tile2_cell"] = xr.DataArray( + np.column_stack((interp.i_dst + 1, interp.j_dst + 1)), + dims=["ncells", "two"], + attrs={"standard_name": "parent_cell_indices_in_mosaic2"}, ) - - nxcells = xgrid_out["nxcells"] - if nxcells > 0: - xgrid_out["tile"] = np.full(nxcells, itile, dtype=np.int32) - self.datadict[tgt_tile][src_tile] = xgrid_out - itile = itile + 1 - - - def to_dataset_raw(self): - - if self.datadict is None: raise RunTimeError("datadict is None") - - for tgt_tile in self.datadict: - for src_tile in self.datadict[tgt_tile]: - src_i = xr.DataArray(self.datadict[tgt_tile][src_tile]["src_i"], dims=["nxcells"], - attrs={"src_i": "parent longitudinal (x) cell indices in src_mosaic"}) - src_j = xr.DataArray(self.datadict[tgt_tile][src_tile]["src_j"], dims=["nxcells"], - attrs={"src_j": "parent latitudinal (y) cell indices in src_mosaic"}) - tgt_i = xr.DataArray(self.datadict[tgt_tile][src_tile]["tgt_i"], dims=["nxcells"], - attrs={"src_i": "parent longitudinal (x) cell indices in src_mosaic"}) - tgt_j = xr.DataArray(self.datadict[tgt_tile][src_tile]["tgt_j"], dims=["nxcells"], - attrs={"src_j": "parent latitudinal (y) cell indices in src_mosaic"}) - xarea = xr.DataArray(self.datadict[tgt_i][src_i]["xarea"], dims=["nxcells"], - attrs={"xarea":"exchange grid cell area (m2)"}) - - self.dataset[tgt_tile][src_tile] = xr.DataSet(data_vars={"src_i": src_i, - "src_j": src_j, - "tgt_i": tgt_i, - "tgt_j": tgt_j, - "xarea": xarea}) - - def _check_restart_remap_file(self): - if self.restart_remap_file is not None : - check_file_is_there(self.restart_remap_file) - self.read() - return True - return False - - - def _check_grid(self): - - if self.src_grid is not None: self._srcinfoisthere = True - if self.tgt_grid is not None: self._tgtinfoisthere = True - - - def _check_mosaic(self): - - if self.src_mosaic is None: return - if self.tgt_mosaic is None: return - - if self.src_mosaic.grid is None: - self.src_mosaic.get_grid(toradians=True, agrid=self.on_agrid, free_dataset=True) - if self.tgt_mosaic.grid is None: - self.tgt_mosaic.get_grid(toradians=True, agrid=self.on_agrid, free_dataset=True) - - self.src_grid = self.src_mosaic.grid - self.tgt_grid = self.tgt_mosaic.grid - - self._srcinfoisthere = True - self._tgtinfoisthere = True - - - def _check_mosaic_file(self): - - if self.src_mosaic_file is not None: - self.src_grid = MosaicObj(input_dir=self.input_dir, - mosaic_file=self.src_mosaic_file).read().get_grid(toradians=True, - agrid=self.on_agrid, - free_dataset=True) - self._srcinfoisthere = True - - if self.tgt_mosaic_file is not None: - self.tgt_grid = MosaicObj(input_dir=self.input_dir, - mosaic_file=self.tgt_mosaic_file).read().get_grid(toradians=True, - agrid=self.on_agrid, - free_dataset=True) - self._tgtinfoisthere = True + dataset["xgrid_area"] = xr.DataArray( + interp.xgrid_area, + dims=["ncells"], + attrs={"standard_name": "exchange_grid_area", "units": "m2"}, + ) + datasets.append(xr.Dataset(dataset)) + + dataset = xr.concat(datasets, dim="ncells") + encoding = {variable: {"_FillValue": None} for variable in dataset} + dataset.to_netcdf(outfile, encoding=encoding) + + pyfms.mpp.sync() + + + def get_interp(self, on_gpu: bool = False) -> dict: + + """ + call fms to compute xgrid + """ + + self.interps = {} + + for src_tile in self.src.grid: + src_grid = self.src.grid[src_tile] + src_mask = None if self.src.mask is None else self.src.mask[src_tile] + if on_gpu: + xdict = pyfrenctools.create_xgrid.get_2dx2d_order1_gpu( + src_nlon=src_grid.nx, + src_nlat=src_grid.ny, + tgt_nlon=self.tgt.grid.nx, + tgt_nlat=self.tgt.grid.ny, + src_lon=src_grid.x, + src_lat=src_grid.y, + tgt_lon=self.tgt.grid.x, + tgt_lat=self.tgt.grid.y, + src_mask=src_mask, + tgt_mask=self.tgt.mask, + ) + interp = pyfms.ConserveInterp() + interp.nxgrid = xdict["nxcells"] + interp.i_src = xdict["src_i"] + interp.j_src = xdict["src_j"] + interp.i_dst = xdict["tgt_i"] + interp.j_dst = xdict["tgt_j"] + interp.xgrid_area = xdict["xarea"] + self.interps[src_tile] = interp + else: + interp_id = pyfms.horiz_interp.get_weights( + lon_in=src_grid.x, + lat_in=src_grid.y, + lon_out=self.tgt.grid.x, + lat_out=self.tgt.grid.y, + mask_in=src_mask, + mask_out=self.tgt.mask, + is_latlon_in=False, + is_latlon_out=False, + save_xgrid_area=True, + convert_cf_order=False, + as_fregrid=True, + interp_method="conservative", + ) + self.interps[src_tile] = pyfms.ConserveInterp(interp_id, save_xgrid_area=True) diff --git a/pyFMS b/pyFMS index 4e828f0..12c87d8 160000 --- a/pyFMS +++ b/pyFMS @@ -1 +1 @@ -Subproject commit 4e828f0f7f1253e7be166fd8efb11608da13a282 +Subproject commit 12c87d8532e1fcaa627d3163b3c8ec0fc115c4ba diff --git a/setup.py b/setup.py index 00074dc..d0a3d1b 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def local_pkg(name: str, relative_path: str) -> str: "numpy", "xarray", "netCDF4", - local_pkg("pyFMS", "pyFMS"), +# local_pkg("pyFMS", "pyFMS"), local_pkg("pyfrenctools", "FREnctools_lib") ] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/remap/__init__.py b/tests/remap/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/remap/test_dataobj.py b/tests/remap/test_dataobj.py new file mode 100644 index 0000000..0494a41 --- /dev/null +++ b/tests/remap/test_dataobj.py @@ -0,0 +1,61 @@ +import numpy as np +from types import SimpleNamespace +import xarray as xr + +import fmsgridtools +import write_files + +def test_dataobj(): + + dataset = write_files.write_data("test") + + srcfileobj = fmsgridtools.SrcFileObj("test", tiles=["tile1"]) + tgtfileobj = fmsgridtools.TgtFileObj("test", nx=10, ny=14) + + variable1 = fmsgridtools.VariableObj(variable="variable1", src_fileobj=srcfileobj, tgt_fileobj=tgtfileobj) + + #check dimensions + dims = [ + ("grid_xt", variable1.src_fileobj.dims.x), + ("grid_yt", variable1.src_fileobj.dims.y), + ("time", variable1.time), + ("pfull", variable1.z) + ] + + for (name, dim) in dims: + assert dim.name == name + assert dim.here + assert dim.size == dataset[name].size + + attributes = dataset["variable1"].attrs + assert variable1.missing_value == attributes.get("missing_value") + assert variable1.fill_value == attributes.get("_FillValue") + assert variable1.offset == attributes.get("add_offset") + assert variable1.scale_factor == attributes.get("scale_factor") + + #slice + src = write_files.src + reconstruct_data = np.zeros((src.ntimes, src.nk, src.ny//2, src.nx//2), dtype=np.float64) + for itime in range(src.ntimes): + for k in range(src.nk): + sliced_data = variable1.slice(tile="tile1", timepoint=itime, klevel=k) + np.testing.assert_equal(sliced_data, write_files.data_list[0][itime, k, :, :]) + reconstruct_data[itime, k, :, :] = variable1.prepare_data(sliced_data) + + #check missing values + for (itime, k, j, i) in write_files.missing_ijkl: + assert reconstruct_data[itime,k,j,i] == np.float64(0.0) + +def test_remap(): + + write_files.write_mosaics() + dataset = write_files.write_data("test") + + src_mosaic = write_files.src.mosaicfile + tgt_mosaic = write_files.tgt.mosaicfile + + fmsgridtools.remap.conservative.remap(src_mosaicfile=src_mosaic, tgt_mosaicfile=tgt_mosaic, input_file="test") + + +if __name__ == "__main__": + test_dataobj() diff --git a/tests/remap/write_files.py b/tests/remap/write_files.py new file mode 100644 index 0000000..016719c --- /dev/null +++ b/tests/remap/write_files.py @@ -0,0 +1,157 @@ +import numpy as np +from types import SimpleNamespace + +import xarray as xr +import fmsgridtools + +src = SimpleNamespace( + ntiles=6, + nx=12, + ny=24, + nk=3, + ntimes=4, + dxy=1.0, + mosaicfile="src_mosaic.nc", + gridfile="src_grid" +) +tgt = SimpleNamespace( + ntiles=1, + nx=src.nx, #* 2, + ny=src.ny, #* 2, + nk=3, + ntimes=4, + dxy=src.dxy, # / 2.0, + mosaicfile="tgt_mosaic.nc", + gridfile="tgt_grid" +) + +nxgrid_per_tile = tgt.nx//2 * tgt.ny//2 +nxgrid = nxgrid_per_tile * 6 + +ntimes, nk, ny, nx = src.ntimes, src.nk, src.ny//2, src.nx//2 + +missing_value = -np.float64(-99.) +missing_ijkl = [(1,1,1,1), (2,2,2,2)] + +data_list = [] +for itile in range(1, src.ntiles+1): + data = np.zeros((ntimes, nk, ny, nx), dtype=np.float64) + for itime in range(ntimes): + for k in range(nk): + for j in range(ny): + start = itile*10000 + itime*1000+k*100+j*10 + data[itime, k, j, :] = np.arange(start, start+nx, dtype=np.float64) + for (itime, k, j, i) in missing_ijkl: + data[itime,k,j,i] = missing_value + data_list.append(data) + + +def write_mosaics(): + + """ + make mosaic and grid files for testing + """ + + # write mosaic + for parent in [src, tgt]: + fmsgridtools.MosaicObj( + gridtiles=[f"tile{i}" for i in range(1, parent.ntiles+1)], + gridfiles=[f"{parent.gridfile}.tile{i}.nc" for i in range(1, parent.ntiles+1)] + ).write(parent.mosaicfile) + + # write grid + for parent in [src, tgt]: + for itile in range(1, parent.ntiles+1): + x1 = np.array([i*parent.dxy for i in range(parent.nx+1)], dtype=np.float64) + y1 = np.array([j*parent.dxy for j in range(parent.ny+1)], dtype=np.float64) + x, y = np.meshgrid(x1, y1) + fmsgridtools.GridObj(x=x, y=y).write(parent.gridfile + f".tile{itile}.nc") + + +def write_data(outfile): + + for itile in range(src.ntiles): + + variable_xt = xr.DataArray( + np.arange(nx, dtype=np.float64), + dims = ["grid_xt"], + attrs={ + "units": "degrees_E", + "long_name": "made up T-cell longitude", + "axis": "X" + } + ) + variable_yt = xr.DataArray( + np.arange(ny, dtype=np.float64), + dims = ["grid_yt"], + attrs={ + "units": "degrees_N", + "long_name": "made up T-cell latitude", + "axis": "Y" + } + ) + variable_pfull = xr.DataArray( + np.arange(nk, dtype=np.float64), + dims = ["pfull"], + attrs = { + "units": "mb", + "long_name": "ref full pressure level", + "axis": "Z", + "positive": "down" + } + ) + variable_time = xr.DataArray( + np.arange(ntimes, dtype=np.float64), + dims = ["time"], + attrs= { + "units": "days since 0001-01-01 00:00:00", + "long_name": "time", + "axis": "T", + "calendar_type": "NOLEAP", + "calendar": "noleap" + } + ) + variable1 = xr.DataArray( + data_list[itile], + dims = ["time", "pfull", "grid_yt", "grid_xt"], + attrs = { + "_FillValue": -np.float64(-100), + "long_name": "test variable 1", + "units": "kg m-2", + "missing_value": missing_value, + "add_offset": np.float64(0.0), + "scale_factor": np.float64(1.0), + "cell_method": "area:mean time:mean", + "cell_measures": "area: pemberley_area", + "standard_name": "Mr. Darcy" + } + ) + variable2 = xr.DataArray( + np.zeros((ny, nx), dtype=np.float32), + dims = ["grid_yt", "grid_xt"], + attrs = { + "_FillValue": False, + "long_name": "test variable 2", + "cell_method": "area:mean", + "standard_name": "Missing cell_measures, expected to fail", + "missing_value": -np.float32(-1.0), + "scale_factor": -np.float32(-0.05), + "offset": np.float32(0.0) + } + ) + data_vars = { + "grid_xt": variable_xt, + "grid_yt": variable_yt, + "pfull": variable_pfull, + "time": variable_time, + "variable1": variable1, + "variable2": variable2 + } + + dataset = xr.Dataset( + data_vars=data_vars, + attrs={"associated_files": "pemberley_area: pemberley.nc longbourn_area: longbourn.nc"} + ) + + dataset.to_netcdf(outfile+f".tile{itile+1}.nc", unlimited_dims="time") + return dataset diff --git a/tests/shared/test_xgridobj.py b/tests/shared/test_xgridobj.py index 64c69d6..3267566 100755 --- a/tests/shared/test_xgridobj.py +++ b/tests/shared/test_xgridobj.py @@ -1,108 +1,128 @@ -import os +""" +test functionalities in xgridobj +""" + +from types import SimpleNamespace import numpy as np -import pytest -import xarray as xr +import pyfms import fmsgridtools +src = SimpleNamespace( + ntiles=6, + nx=12, + ny=24, + dxy=1.0, + mosaicfile="src_mosaic.nc", + gridfile="src_grid" +) +tgt = SimpleNamespace( + ntiles=1, + nx=src.nx * 2, + ny=src.ny * 2, + dxy=src.dxy / 2.0, + mosaicfile="tgt_mosaic.nc", + gridfile="tgt_grid" +) + +nxgrid_per_tile = tgt.nx//2 * tgt.ny//2 +nxgrid = nxgrid_per_tile * 6 +remapfile = "test_remap.nc" + +def make_testfiles(): + + """ + make mosaic and grid files for testing + """ + + # write mosaic + for parent in [src, tgt]: + fmsgridtools.MosaicObj( + gridtiles=[f"tile{i}" for i in range(1, parent.ntiles+1)], + gridfiles=[f"{parent.gridfile}.tile{i}.nc" for i in range(1, parent.ntiles+1)] + ).write(parent.mosaicfile) + + # write grid + for parent in [src, tgt]: + for itile in range(1, parent.ntiles+1): + x1 = np.array([i*parent.dxy for i in range(parent.nx+1)], dtype=np.float64) + y1 = np.array([j*parent.dxy for j in range(parent.ny+1)], dtype=np.float64) + x, y = np.meshgrid(x1, y1) + fmsgridtools.GridObj(x=x, y=y).write(parent.gridfile + f".tile{itile}.nc") + + +#@pytest.mark.parametrize("on_gpu", [False, True]) +def xgridobj_test(on_gpu: bool = False): + + """ + tests generating the exchange grid + tests reading and write exchange grid + """ + + pyfms.fms.init(ndomain=None if on_gpu else 4) + pyfms.horiz_interp.init(ninterp=src.ntiles*2) + + if on_gpu: + domain = None + else: + domain = pyfms.mpp_domains.define_domains([0, tgt.nx-1, 0, tgt.ny-1]) + + if pyfms.mpp.pe() == pyfms.mpp.root_pe(): + make_testfiles() + pyfms.mpp.sync() + + xgrid = fmsgridtools.XGridObj( + src_mosaicfile=src.mosaicfile, + tgt_mosaicfile=tgt.mosaicfile, + domain=domain, + ) + xgrid.set_target_tile("tile1") + xgrid.get_interp(on_gpu=on_gpu) + xgrid.write(outfile=remapfile) -def generate_mosaic(nx: int = 90, ny: int = 45, refine: int = 2): - - xstart, xend = 0, 180 - ystart, yend = -45, 45 - - x_src = np.linspace(xstart, xend, nx+1) - y_src = np.linspace(ystart, yend, ny+1) - x_src, y_src = np.meshgrid(x_src, y_src) - - x_tgt = np.linspace(xstart, xend, nx*refine+1) - y_tgt = np.linspace(ystart, yend, ny*refine+1) - x_tgt, y_tgt = np.meshgrid(x_tgt, y_tgt) - - area_src = np.ones((ny, nx), dtype=np.float64) - area_tgt = np.ones((ny*refine, nx*refine), dtype=np.float64) - - for ifile in ("src", "tgt"): - mosaicfile = ifile + "_mosaic.nc" - gridfile = ifile + "_grid.nc" - gridlocation = "./" - gridtile = "tile1" - xr.Dataset(data_vars=dict(mosaic=mosaicfile.encode(), - gridlocation=gridlocation.encode(), - gridfiles=(["ntiles"], [gridfile.encode()]), - gridtiles=(["ntiles"], [gridtile.encode()])) - ).to_netcdf(mosaicfile) - - - for (x, y, area, prefix) in [(x_src, y_src, area_src, "src"), (x_tgt, y_tgt, area_tgt, "tgt")]: - xr.Dataset(data_vars=dict(x=(["nyp", "nxp"], x), - y=(["nyp", "nxp"], y), - area=(["ny", "nx"], area)) - ).to_netcdf(prefix+"_grid.nc") - - -def remove_mosaic(): - os.remove("src_grid.nc") - os.remove("tgt_grid.nc") - os.remove("src_mosaic.nc") - os.remove("tgt_mosaic.nc") - os.remove("remap.nc") + del xgrid + xgrid = fmsgridtools.XGridObj( + src_mosaicfile=src.mosaicfile, + tgt_mosaicfile=tgt.mosaicfile, + remapfile=remapfile, + tgt_tile = "tile1") + xgrid.read(remapfile=remapfile) -@pytest.mark.parametrize("on_gpu", [False, True]) -def test_create_xgrid(on_gpu): + #answers + area = fmsgridtools.GridObj( + gridfile=tgt.gridfile + ".tile1.nc").read(center=True, radians=True).get_fms_area() - nx, ny, refine = 45, 45, 2 - generate_mosaic(nx=nx, ny=ny, refine=refine) + for tile in xgrid.interps: - xgrid = fmsgridtools.XGridObj(src_mosaic_file="src_mosaic.nc", - tgt_mosaic_file="tgt_mosaic.nc", - on_gpu=on_gpu, - on_agrid=False - ) - xgrid.create_xgrid() - xgrid.to_dataset() - xgrid.dataset["tile1"]["tile1"].to_netcdf("remap.nc") - - del xgrid - - xgrid = fmsgridtools.XGridObj(restart_remap_file="remap.nc") + interp = xgrid.interps[tile] + i_src = interp.i_src + j_src = interp.j_src + i_dst = interp.i_dst + j_dst = interp.j_dst - #check nxcells - nxcells = nx * refine * ny * refine - assert xgrid.nxcells == nxcells + assert interp.nxgrid == tgt.nx//2 * tgt.ny//2, f"src_tile = {tile}, {interp.nxgrid}" - #check parent input cells - answer_i = [i+1 for i in range(nx) for ixcells in range(refine*refine)]*ny - answer_j = [j+1 for j in range(ny) for i in range(nx*refine) for ixcells in range(refine)] + for i in range(interp.nxgrid): - src_i = [xgrid.src_cell[i][0] for i in range(nxcells)] - src_j = [xgrid.src_cell[i][1] for i in range(nxcells)] - - assert src_i == answer_i - assert src_j == answer_j + i_d, j_d = i_dst[i], j_dst[i] + assert i_src[i] == i_d // 2 and j_src[i] == j_d // 2, f"xcell {i}, i_src={i_src[i]}, j_src={j_src[i]} i_dst={i_d}, j_dst={j_d}" - #check parent output cells - answer_i = [] - for j in range(ny): - for i in range(nx): - answer_i += [refine*i + ixcell + 1 for ixcell in range(refine)]*refine + np.testing.assert_almost_equal( + interp.xgrid_area[i], + area[j_d, i_d], + decimal=2, + err_msg=f"tile {tile} gridpoint {i}") - answer_j = [] - for j in range(ny): - for i in range(nx): - for ixcell in range(refine): - answer_j += [j*refine + ixcell + 1]*refine - - tgt_i = [xgrid.tgt_cell[i][0] for i in range(nxcells)] - tgt_j = [xgrid.tgt_cell[i][1] for i in range(nxcells)] + pyfms.fms.end() - assert tgt_i == answer_i - assert tgt_j == answer_j - remove_mosaic() +def test_xgridobj_gpu(): + xgridobj_test(on_gpu=True) +def test_xgridobj_cpu(): + xgridobj_test(on_gpu=False) if __name__ == "__main__": - test_create_xgrid(on_gpu=False) + test_xgridobj_gpu()