diff --git a/FREnctools_lib/pyfrenctools/shared/create_xgrid.py b/FREnctools_lib/pyfrenctools/shared/create_xgrid.py index 2c4f15d..4f8ba11 100644 --- a/FREnctools_lib/pyfrenctools/shared/create_xgrid.py +++ b/FREnctools_lib/pyfrenctools/shared/create_xgrid.py @@ -108,11 +108,7 @@ def transfer_data_gpu(nxcells: int, src_nlon: int, tgt_nlon: int): xarea=xarea[:nxcells]) -def get_2dx2d_order1_gpu(src_nlon: int, - src_nlat: int, - tgt_nlon: int, - tgt_nlat: int, - src_lon: npt.NDArray, +def get_2dx2d_order1_gpu(src_lon: npt.NDArray, src_lat: npt.NDArray, tgt_lon: npt.NDArray, tgt_lat: npt.NDArray, @@ -121,11 +117,18 @@ def get_2dx2d_order1_gpu(src_nlon: int, create_xgrid_order1_gpu_wrapper = _lib.create_xgrid_order1_gpu_wrapper + nyp, nxp = src_lon.shape + src_nlon = nxp - 1 + src_nlat = nyp - 1 + + nyp, nxp = tgt_lon.shape + tgt_nlon = nxp - 1 + tgt_nlat = nyp - 1 + if src_mask is None: src_mask = np.ones((src_nlon*src_nlat), dtype=np.float64) if tgt_mask is None: tgt_mask = np.ones((tgt_nlon*tgt_nlat), dtype=np.float64) arrayptr_double = np.ctypeslib.ndpointer(dtype=np.float64, flags="C_CONTIGUOUS") - create_xgrid_order1_gpu_wrapper.restype = np.int32 create_xgrid_order1_gpu_wrapper.argtypes = [c_int, #src_nlon c_int, #src_nlat diff --git a/fmsgridtools/shared/gridobj.py b/fmsgridtools/shared/gridobj.py index bc0bd78..52e53cd 100644 --- a/fmsgridtools/shared/gridobj.py +++ b/fmsgridtools/shared/gridobj.py @@ -88,23 +88,16 @@ def __init__(self, name: str = None, data=None): class MiniGridObj: - def __init__(self, xsize: int = None, ysize: int = None, x: npt.NDArray = None, y: npt.NDArray = None): - self.xsize = xsize - self.ysize = ysize + def __init__(self, nx: int = None, ny: int = None, nxp: int = None, nyp: int = None, + x: npt.NDArray = None, y: npt.NDArray = None): + + self.nx = nx #number of cells + self.ny = ny #number of cells + self.nxp = nxp #number of gridpoints + self.nyp = nyp #number of gridpoints self.x = x self.y = y - def set_size(self): - - xsize, ysize = self.x.shape - - if self.x.shape != self.y.shape: - logger.error("MiniGrid, x and y differ in dimensions") - - self.xsize = xsize - self.ysize = ysize - - class GridObj: """ Class for grid information @@ -171,7 +164,7 @@ def to_domain(self, domain: dict = None): else: if self.domain is not None: logger.warning("Overwriting %s with %s", self.domain, domain) - self.domain = domain + self.domain = domain if not pyfms.fms.module_is_initialized(): logger.error("Please initialize pyfms first") @@ -218,7 +211,7 @@ def to_radians(self): logger.info("Converting %s to radians", {obj.name}) obj.data = np.radians(obj.data, dtype=np.float64) - def get_fms_area(self): + def get_fms_area(self, gridc: bool = False): """ Compute grid cell areas """ @@ -228,8 +221,13 @@ def get_fms_area(self): if not pyfms.fms.module_is_initialized(): logger.error("Please initialize pyfms first") - x = np.ascontiguousarray(self.x, dtype=np.float64) - y = np.ascontiguousarray(self.y, dtype=np.float64) + if gridc: + x = self.gridc.x + y = self.gridc.y + else: + x = np.ascontiguousarray(self.x, dtype=np.float64) + y = np.ascontiguousarray(self.y, dtype=np.float64) + self.area = pyfms.grid_utils.get_grid_area(lon=x, lat=y, convert_cf_order=False) return self.area @@ -244,7 +242,10 @@ def get_gridc(self): self.gridc.x = np.ascontiguousarray(self.x[::2, ::2]) self.gridc.y = np.ascontiguousarray(self.y[::2, ::2]) - self.gridc.set_size() + + self.gridc.nyp, self.gridc.nxp = self.gridc.x.shape + self.gridc.ny = self.gridc.nyp - 1 + self.gridc.nx = self.gridc.nxp - 1 return self.gridc @@ -259,7 +260,7 @@ def get_gridt(self): self.gridt.x = np.ascontiguousarray(self.x[1::2, 1::2]) self.gridt.y = np.ascontiguousarray(self.x[1::2, 1::2]) - self.gridt.set_size() + self.gridt.nyp, self.gridt.nxp = self.gridt.x.shape return self.gridt diff --git a/fmsgridtools/shared/xgridobj.py b/fmsgridtools/shared/xgridobj.py index ddbd5ad..fe8626b 100755 --- a/fmsgridtools/shared/xgridobj.py +++ b/fmsgridtools/shared/xgridobj.py @@ -1,4 +1,5 @@ -import ctypes +import logging +from pathlib import Path import numpy as np import numpy.typing as npt import xarray as xr @@ -7,226 +8,306 @@ import pyfms from fmsgridtools.shared.gridobj import GridObj -from fmsgridtools.shared.gridtools_utils import check_file_is_there from fmsgridtools.shared.mosaicobj import MosaicObj - -class XGridObj() : - - def __init__(self, - input_dir: str = "./", - src_mosaic_file: str = None, - tgt_mosaic_file: str = None, - restart_remap_file: str = None, - write_remap_file: str = "remap.nc", - src_mosaic: type[MosaicObj] = None, - tgt_mosaic: type[MosaicObj] = None, - src_grid: dict[type[GridObj]] = None, - tgt_grid: dict[type[GridObj]] = None, - dataset: type[xr.Dataset] = None, - datadict: dict = None, - on_agrid: bool = True, - order: int = 1, - on_gpu: bool = False): - self.input_dir = input_dir - self.src_mosaic_file = src_mosaic_file - self.tgt_mosaic_file = tgt_mosaic_file - self.restart_remap_file = restart_remap_file - self.write_remap_file = write_remap_file - self.src_mosaic = src_mosaic - self.tgt_mosaic = tgt_mosaic - self.src_grid = src_grid - self.tgt_grid = tgt_grid +logger = logging.getLogger(__name__) + + +class Parent: + + def __init__( + self, + parent: str, + input_dir: str | Path = "./", + mosaicfile: str | Path = None, + mosaic: MosaicObj = None, + gridfile: str | Path = None, + ntiles: int = None, + grid: dict[str, GridObj] = None, + mask: dict[str, npt.NDArray] = None, + domain: pyfms.Domain = None, + ): + + self.parent = parent + self.input_dir = Path(input_dir) + self.mosaicfile = mosaicfile + self.mosaic = mosaic + self.gridfile = gridfile + self.ntiles = ntiles + self.grid = grid + self.mask = mask + self.domain = domain + + if self.grid is None: + if self.mosaic is None: + if self.mosaicfile is None: + logger.warning("Cannot set %s grid", self.parent) + self.mosaic = MosaicObj( + input_dir=self.input_dir, mosaicfile=self.mosaicfile + ).read() + self.grid = self.mosaic.get_grid( + input_dir=self.input_dir, radians=True, domain=self.domain + ) + logger.info("set grid for %s", self.parent) + + +class XGridObj: + + def __init__( + self, + input_dir: str | Path = "./", + src_mosaicfile: str = None, + tgt_mosaicfile: str = None, + remapfile: str | Path = None, + src_mosaic: MosaicObj = None, + tgt_mosaic: MosaicObj = None, + src_gridfile: str | Path = None, + tgt_gridfile: str | Path = None, + src_grid: dict[str, GridObj] = None, + tgt_grid: dict[str, GridObj] = None, + tgt_tile: str = "tile1", + src_mask: dict[str, np.ndarray] = None, + tgt_mask: dict[str, np.ndarray] = None, + order: int = 1, + domain: pyfms.Domain = None, + use_allpoints: bool = False + ): + + """ + Create an XGridObj container for building/reading remap/interp data. + + Args: + input_dir (str|Path): Base directory for input files. + src_mosaicfile (str): Source mosaic filename (optional). + tgt_mosaicfile (str): Target mosaic filename (optional). + remapfile (str|Path): Remap weights filename (optional). + src_mosaic (MosaicObj): Optional pre-built source MosaicObj. + tgt_mosaic (MosaicObj): Optional pre-built target MosaicObj. + src_gridfile (str|Path): Source grid filename (optional). + tgt_gridfile (str|Path): Target grid filename (optional). + src_grid (dict[str, GridObj]): Optional dict of source GridObj keyed by tile name. + tgt_grid (dict[str, GridObj]): Optional dict of target GridObj keyed by tile name. + tgt_tile (str): Name of the target tile to use (default: "tile1"). + src_mask (dict[str, np.ndarray]): Optional masks for source tiles. + tgt_mask (dict[str, np.ndarray]): Optional masks for the target grid. + order (int): Interpolation order (currently stored; semantics depend on downstream code). + on_gpu (bool): If True, use GPU-based building of xgrid via pyfrenctools. + + The constructed object stores container namespaces for source (`self.src`) and + target (`self.tgt`) data and an `interps` mapping that is populated by + `read` or `get_interp`. + """ + + self.input_dir: str | Path = Path(input_dir) + self.src = Parent( + parent="src", + input_dir=input_dir, + mosaicfile=src_mosaicfile, + gridfile=src_gridfile, + mosaic=src_mosaic, + grid=src_grid, + mask=src_mask, + domain=None, + ) + self.tgt = Parent( + parent="tgt", + mosaicfile=tgt_mosaicfile, + gridfile=tgt_gridfile, + mosaic=tgt_mosaic, + grid=tgt_grid, + mask=tgt_mask, + domain=domain, + ) + + self.tgt_tile = tgt_tile + if self.tgt_tile is not None: + self.tgt.grid = self.tgt.grid[tgt_tile] + + if not use_allpoints: + for key in self.src.grid: + self.src.grid[key] = self.src.grid[key].get_gridc() + self.src.grid[key].free_supergrid() + self.tgt.grid = self.tgt.grid.get_gridc() + self.tgt.grid.free_supergrid() + + + self.use_allpoints = use_allpoints + self.remapfile: str | Path = remapfile self.order = order - self.on_gpu = on_gpu - self.on_agrid = on_agrid - self.dataset = dataset - self.datadict = datadict - self._srcinfoisthere = False - self._tgtinfoisthere = False + self.interps: pyfms.ConserveInterp | dict[str, pyfms.ConserveInterp] = None + - if self._check_restart_remap_file(): return - if self.datadict is not None: return - if self.dataset is not None: return + def set_target_tile(self, tgt_tile: str = "tile1"): + self.tgt_tile = tgt_tile + self.tgt.grid = self.tgt.grid[tgt_tile] - self._check_grid() - self._check_mosaic() - self._check_mosaic_file() - if not self._srcinfoisthere or not self._tgtinfoisthere: - raise RuntimeError("Please provide grid information") + def read( + self, + input_dir: Path | str = None, + remapfile: Path | str = None, + domain: pyfms.Domain = None, + ): + """ + read remap file and store as pyfms.ConserveInterp objects + """ - def read(self, infile: str = None): + input_dir = self.input_dir if input_dir is None else Path(input_dir) - if infile is None: - if self.restart_remap_file is None: - raise RuntimeError("must provide the input remap file for reading") - infile = self.restart_remap_file + if remapfile is None: + if self.remapfile is None: + logger.error("Please specify remapfile to read") + remapfile = self.remapfile + remapfile = input_dir / Path(remapfile) - self.dataset = xr.open_dataset(infile) + if not remapfile.exists(): + logger.error("remap file %s does not exist", self.remapfile) - for key in self.dataset.data_vars.keys(): - setattr(self, key, self.dataset[key].values) + if domain is None: + domain = self.tgt.domain - for key in self.dataset.sizes: - setattr(self, key, self.dataset.sizes[key]) + self.interps = {} + for itile, src_tile in enumerate(self.src.grid): + interp_id = pyfms.horiz_interp.read_weights_conserve( + weight_filename=str(remapfile), + weight_file_src="fregrid", + nlon_src=self.src.grid[src_tile].nx, + nlat_src=self.src.grid[src_tile].ny, + nlon_tgt=self.tgt.grid.nx, + nlat_tgt=self.tgt.grid.ny, + domain=domain, + src_tile=itile, + save_xgrid_area=True, + ) + self.interps[src_tile] = pyfms.ConserveInterp(interp_id, save_xgrid_area=True) - def write(self, outfile: str = None): + def gather(self): - if outfile is None: - outfile = self.write_remap_file + """ + gathers xgrid + """ - if self.dataset is None: - if self.datadict is not None: - self.to_dataset() + isc, jsc = self.tgt.domain.isc, self.tgt.domain.jsc - for tgt_tile in self.dataset: - concat_dataset = xr.concat([self.dataset[tgt_tile][src_tile] for src_tile in self.dataset[tgt_tile]], dim="nxcells") + global_interps = {} + for src_tile in self.interps: + interp = self.interps[src_tile] + global_interp = global_interps[src_tile] = pyfms.ConserveInterp() + nxgrids = pyfms.mpp.gather(np.array([interp.nxgrid], dtype=np.int32), rbuf_size=len(self.interps)) + global_interp.nxgrid = np.sum(nxgrids) if pyfms.mpp.pe() == pyfms.mpp.root_pe() else None + global_interp.i_src = pyfms.mpp.gatherv(interp.i_src, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.j_src = pyfms.mpp.gatherv(interp.j_src, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.i_dst = pyfms.mpp.gatherv(interp.i_dst + isc, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.j_dst = pyfms.mpp.gatherv(interp.j_dst + jsc, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.xgrid_area = pyfms.mpp.gatherv(interp.xgrid_area, ssize=interp.nxgrid, rsize=nxgrids) + return global_interps - concat_dataset.to_netcdf(outfile) + def write(self, output_dir: Path | str = "./", outfile: str | Path = Path("remap.nc")): + """ + write remap file + """ - def to_dataset(self): + global_interps = self.interps if self.tgt.domain is None else self.gather() - if self.datadict is None: raise OSError("datadict is None") + if pyfms.mpp.pe() == pyfms.mpp.root_pe(): - datadict = self.datadict - self.dataset = {} + outfile = Path(output_dir) / outfile + logger.info("writing remap file to %s", outfile) - for tgt_tile in datadict: - self.dataset[tgt_tile] = {} - for src_tile in datadict[tgt_tile]: + datasets = [] + for tile1, src_tile in enumerate(global_interps): - thisdict = datadict[tgt_tile][src_tile] - dataset = self.dataset[tgt_tile][src_tile] = xr.Dataset() + interp = global_interps[src_tile] + dataset = xr.Dataset() - dataset["src_cell"] = xr.DataArray(np.column_stack((thisdict['src_i']+1, thisdict['src_j']+1)), - dims=["nxcells", "two"], - attrs={"src_cell": "parent cell indices in src mosaic", - "_FillValue": False} + dataset["tile1"] = xr.DataArray( + np.full(interp.nxgrid, tile1), + dims=["ncells"], + attrs={"standard_name": "tile_number_in_mosaic1"}, ) - dataset["tgt_cell"] = xr.DataArray(np.column_stack((thisdict['tgt_i']+1, thisdict['tgt_j']+1)), - dims=["nxcells", "two"], - attrs={"tgt_cell": "parent cell indices in tgt mosaic", - "_FillValue": False}) - dataset["xarea"] = xr.DataArray(thisdict['xarea'], - dims=["nxcells"], - attrs={"xarea": "exchange grid area", - "_FillValue": False} + dataset["tile1_cell"] = xr.DataArray( + np.column_stack((interp.i_src + 1, interp.j_src + 1)), + dims=["ncells", "two"], + attrs={"standard_name": "parent_cell_indices_in_mosaic1"}, ) - - - def create_xgrid(self, src_mask: dict[str,npt.NDArray] = None, tgt_mask: dict[str, npt.NDArray] = None) -> dict: - - if self.order not in (1,2): - raise RuntimeError("conservative order must be 1 or 2") - - if self.on_gpu: - create_xgrid_2dx2d_order1 = pyfrenctools.create_xgrid.get_2dx2d_order1_gpu - else: - create_xgrid_2dx2d_order1 = pyfrenctools.create_xgrid.get_2dx2d_order1 - - if self.datadict is None: self.datadict = {} - - for tgt_tile in self.tgt_grid: - - self.datadict[tgt_tile], itile = {}, 1 - - itgt_mask = None if tgt_mask is None else tgt_mask[tgt_tile] - - for src_tile in self.src_grid: - - isrc_mask = None if src_mask is None else src_mask[src_tile] - - xgrid_out = create_xgrid_2dx2d_order1( - src_nlon = self.src_grid[src_tile].nx, - src_nlat = self.src_grid[src_tile].ny, - tgt_nlon = self.tgt_grid[tgt_tile].nx, - tgt_nlat = self.tgt_grid[tgt_tile].ny, - src_lon=self.src_grid[src_tile].x, - src_lat=self.src_grid[src_tile].y, - tgt_lon=self.tgt_grid[tgt_tile].x, - tgt_lat=self.tgt_grid[tgt_tile].y, - src_mask=isrc_mask, - tgt_mask=itgt_mask + dataset["tile2_cell"] = xr.DataArray( + np.column_stack((interp.i_dst + 1, interp.j_dst + 1)), + dims=["ncells", "two"], + attrs={"standard_name": "parent_cell_indices_in_mosaic2"}, ) + dataset["xgrid_area"] = xr.DataArray( + interp.xgrid_area, + dims=["ncells"], + attrs={"standard_name": "exchange_grid_area", "units": "m2"}, + ) + datasets.append(xr.Dataset(dataset)) - nxcells = xgrid_out["nxcells"] - if nxcells > 0: - xgrid_out["tile"] = np.full(nxcells, itile, dtype=np.int32) - self.datadict[tgt_tile][src_tile] = xgrid_out - itile = itile + 1 - - - def to_dataset_raw(self): - - if self.datadict is None: raise RunTimeError("datadict is None") - - for tgt_tile in self.datadict: - for src_tile in self.datadict[tgt_tile]: - src_i = xr.DataArray(self.datadict[tgt_tile][src_tile]["src_i"], dims=["nxcells"], - attrs={"src_i": "parent longitudinal (x) cell indices in src_mosaic"}) - src_j = xr.DataArray(self.datadict[tgt_tile][src_tile]["src_j"], dims=["nxcells"], - attrs={"src_j": "parent latitudinal (y) cell indices in src_mosaic"}) - tgt_i = xr.DataArray(self.datadict[tgt_tile][src_tile]["tgt_i"], dims=["nxcells"], - attrs={"src_i": "parent longitudinal (x) cell indices in src_mosaic"}) - tgt_j = xr.DataArray(self.datadict[tgt_tile][src_tile]["tgt_j"], dims=["nxcells"], - attrs={"src_j": "parent latitudinal (y) cell indices in src_mosaic"}) - xarea = xr.DataArray(self.datadict[tgt_i][src_i]["xarea"], dims=["nxcells"], - attrs={"xarea":"exchange grid cell area (m2)"}) - - self.dataset[tgt_tile][src_tile] = xr.DataSet(data_vars={"src_i": src_i, - "src_j": src_j, - "tgt_i": tgt_i, - "tgt_j": tgt_j, - "xarea": xarea}) - - def _check_restart_remap_file(self): - if self.restart_remap_file is not None : - check_file_is_there(self.restart_remap_file) - self.read() - return True - return False - - - def _check_grid(self): + dataset = xr.concat(datasets, dim="ncells") + encoding = {variable: {"_FillValue": None} for variable in dataset} + dataset.to_netcdf(outfile, encoding=encoding) - if self.src_grid is not None: self._srcinfoisthere = True - if self.tgt_grid is not None: self._tgtinfoisthere = True + pyfms.mpp.sync() - def _check_mosaic(self): + def get_interp(self, on_gpu) -> dict: - if self.src_mosaic is None: return - if self.tgt_mosaic is None: return + """ + call fms to compute xgrid + """ - if self.src_mosaic.grid is None: - self.src_mosaic.get_grid(toradians=True, agrid=self.on_agrid, free_dataset=True) - if self.tgt_mosaic.grid is None: - self.tgt_mosaic.get_grid(toradians=True, agrid=self.on_agrid, free_dataset=True) + self.interps = {} - self.src_grid = self.src_mosaic.grid - self.tgt_grid = self.tgt_mosaic.grid + tgt_grid = tgt.grid[self.tgt_tile] + if self.use_allpoints: + tgt_x, tgt_y = tgt_grid.x, tgt_grid.y + else: + tgt_x, tgt_y = tgt_grid.gridc.x, tgt_grid.gridc.y - self._srcinfoisthere = True - self._tgtinfoisthere = True + for src_tile in self.src.grid: - def _check_mosaic_file(self): + src_grid = self.src.grid[src_tile] + if self.use_allpoints: + src_x, src_y = src_grid.x, src_grid.y + else: + src_x, src_y = src_grid.gridc.x, src_grid.gridc.y - if self.src_mosaic_file is not None: - self.src_grid = MosaicObj(input_dir=self.input_dir, - mosaic_file=self.src_mosaic_file).read().get_grid(toradians=True, - agrid=self.on_agrid, - free_dataset=True) - self._srcinfoisthere = True - if self.tgt_mosaic_file is not None: - self.tgt_grid = MosaicObj(input_dir=self.input_dir, - mosaic_file=self.tgt_mosaic_file).read().get_grid(toradians=True, - agrid=self.on_agrid, - free_dataset=True) - self._tgtinfoisthere = True + src_mask = None if self.src.mask is None else self.src.mask[src_tile] + if on_gpu: + xdict = pyfrenctools.create_xgrid.get_2dx2d_order1_gpu( + src_lon=src_x, + src_lat=src_y, + tgt_lon=self.tgt_x, + tgt_lat=self.tgt_y, + src_mask=src_mask, + tgt_mask=self.tgt.mask, + ) + interp = pyfms.ConserveInterp() + interp.nxgrid = xdict["nxcells"] + interp.i_src = xdict["src_i"] + interp.j_src = xdict["src_j"] + interp.i_dst = xdict["tgt_i"] + interp.j_dst = xdict["tgt_j"] + interp.xgrid_area = xdict["xarea"] + self.interps[src_tile] = interp + else: + interp_id = pyfms.horiz_interp.get_weights( + lon_in=src_x, + lat_in=src_y, + lon_out=self.tgt_x, + lat_out=self.tgt_y, + mask_in=src_mask, + mask_out=self.tgt.mask, + is_latlon_in=False, + is_latlon_out=False, + save_xgrid_area=True, + convert_cf_order=False, + as_fregrid=True, + interp_method="conservative", + ) + self.interps[src_tile] = pyfms.ConserveInterp(interp_id, save_xgrid_area=True) diff --git a/setup.py b/setup.py index 00074dc..d0a3d1b 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def local_pkg(name: str, relative_path: str) -> str: "numpy", "xarray", "netCDF4", - local_pkg("pyFMS", "pyFMS"), +# local_pkg("pyFMS", "pyFMS"), local_pkg("pyfrenctools", "FREnctools_lib") ] diff --git a/tests/shared/test_xgridobj.py b/tests/shared/test_xgridobj.py index 226f6a9..6dd142b 100755 --- a/tests/shared/test_xgridobj.py +++ b/tests/shared/test_xgridobj.py @@ -1,66 +1,98 @@ -import os +""" +test functionalities in xgridobj +""" + +from types import SimpleNamespace import numpy as np -import pytest -import xarray as xr +import pyfms import fmsgridtools - -def generate_mosaic(nx: int = 90, ny: int = 45, refine: int = 2): - - xstart, xend = 0, 180 - ystart, yend = -45, 45 - - x_src = np.linspace(xstart, xend, nx+1) - y_src = np.linspace(ystart, yend, ny+1) - x_src, y_src = np.meshgrid(x_src, y_src) - +src = SimpleNamespace( + ntiles=6, + nx=12, + ny=24, + dxy=1.0, + mosaicfile="src_mosaic.nc", + gridfile="src_grid" +) +tgt = SimpleNamespace( + ntiles=1, + nx=src.nx * 2, + ny=src.ny * 2, + dxy=src.dxy / 2.0, + mosaicfile="tgt_mosaic.nc", + gridfile="tgt_grid" +) + +nxgrid_per_tile = tgt.nx//2 * tgt.ny//2 +nxgrid = nxgrid_per_tile * 6 +remapfile = "test_remap.nc" + +def make_testfiles(): + + """ + make mosaic and grid files for testing + """ + +<<<<<<< HEAD + # write mosaic + for parent in [src, tgt]: + fmsgridtools.MosaicObj( + gridtiles=[f"tile{i}" for i in range(1, parent.ntiles+1)], + gridfiles=[f"{parent.gridfile}.tile{i}.nc" for i in range(1, parent.ntiles+1)] + ).write(parent.mosaicfile) +======= x_tgt = np.linspace(xstart, xend, nx*refine+1) y_tgt = np.linspace(ystart, yend, ny*refine+1) x_tgt, y_tgt = np.meshgrid(x_tgt, y_tgt) area_src = np.ones((ny, nx), dtype=np.float64) area_tgt = np.ones((ny*refine, nx*refine), dtype=np.float64) +>>>>>>> origin/main - for ifile in ("src", "tgt"): - mosaicfile = ifile + "_mosaic.nc" - gridfile = ifile + "_grid.nc" - gridlocation = "./" - gridtile = "tile1" - xr.Dataset(data_vars=dict(mosaic=mosaicfile.encode(), - gridlocation=gridlocation.encode(), - gridfiles=(["ntiles"], [gridfile.encode()]), - gridtiles=(["ntiles"], [gridtile.encode()])) - ).to_netcdf(mosaicfile) - + # write grid + for parent in [src, tgt]: + for itile in range(1, parent.ntiles+1): + x1 = np.array([i*parent.dxy for i in range(parent.nx+1)], dtype=np.float64) + y1 = np.array([j*parent.dxy for j in range(parent.ny+1)], dtype=np.float64) + x, y = np.meshgrid(x1, y1) + fmsgridtools.GridObj(x=x, y=y).write(parent.gridfile + f".tile{itile}.nc") - for (x, y, area, prefix) in [(x_src, y_src, area_src, "src"), (x_tgt, y_tgt, area_tgt, "tgt")]: - xr.Dataset(data_vars=dict(x=(["nyp", "nxp"], x), - y=(["nyp", "nxp"], y), - area=(["ny", "nx"], area)) - ).to_netcdf(prefix+"_grid.nc") +#@pytest.mark.parametrize("on_gpu", [False, True]) +def xgridobj_test(on_gpu: bool = False): -def remove_mosaic(): - os.remove("src_grid.nc") - os.remove("tgt_grid.nc") - os.remove("src_mosaic.nc") - os.remove("tgt_mosaic.nc") - os.remove("remap.nc") + """ + tests generating the exchange grid + tests reading and write exchange grid + """ + pyfms.fms.init(ndomain=None if on_gpu else 4) + pyfms.horiz_interp.init(ninterp=src.ntiles*2) -@pytest.mark.parametrize("on_gpu", [False, True]) -def test_create_xgrid(on_gpu): + if on_gpu: + domain = None + else: + domain = pyfms.mpp_domains.define_domains([0, tgt.nx-1, 0, tgt.ny-1]) - nx, ny, refine = 45, 45, 2 - generate_mosaic(nx=nx, ny=ny, refine=refine) + if pyfms.mpp.pe() == pyfms.mpp.root_pe(): + make_testfiles() + pyfms.mpp.sync() - xgrid = fmsgridtools.XGridObj(src_mosaic_file="src_mosaic.nc", - tgt_mosaic_file="tgt_mosaic.nc", - on_gpu=on_gpu, - on_agrid=False + xgrid = fmsgridtools.XGridObj( + src_mosaicfile=src.mosaicfile, + tgt_mosaicfile=tgt.mosaicfile, + domain=domain, ) +<<<<<<< HEAD + xgrid.set_target_tile("tile1") + xgrid.get_interp(on_gpu=on_gpu) + xgrid.write(outfile=remapfile) + + del xgrid +======= xgrid.create_xgrid() xgrid.to_dataset() xgrid.dataset["tile1"]["tile1"].to_netcdf("remap.nc") @@ -68,27 +100,38 @@ def test_create_xgrid(on_gpu): del xgrid xgrid = fmsgridtools.XGridObj(restart_remap_file="remap.nc") - - #check nxcells - nxcells = nx * refine * ny * refine - assert xgrid.nxcells == nxcells - - #check parent input cells - answer_i = [i+1 for i in range(nx) for ixcells in range(refine*refine)]*ny - answer_j = [j+1 for j in range(ny) for i in range(nx*refine) for ixcells in range(refine)] - +>>>>>>> origin/main + + xgrid = fmsgridtools.XGridObj( + src_mosaicfile=src.mosaicfile, + tgt_mosaicfile=tgt.mosaicfile, + remapfile=remapfile, + tgt_tile = "tile1") + xgrid.read(remapfile=remapfile) + + #answers + area = fmsgridtools.GridObj( + gridfile=tgt.gridfile + ".tile1.nc").read(center=True, radians=True).get_fms_area() + +<<<<<<< HEAD + for tile in xgrid.interps: +======= src_i = [xgrid.src_cell[i][0] for i in range(nxcells)] src_j = [xgrid.src_cell[i][1] for i in range(nxcells)] assert src_i == answer_i assert src_j == answer_j +>>>>>>> origin/main - #check parent output cells - answer_i = [] - for j in range(ny): - for i in range(nx): - answer_i += [refine*i + ixcell + 1 for ixcell in range(refine)]*refine + interp = xgrid.interps[tile] + i_src = interp.i_src + j_src = interp.j_src + i_dst = interp.i_dst + j_dst = interp.j_dst +<<<<<<< HEAD + assert interp.nxgrid == tgt.nx//2 * tgt.ny//2, f"src_tile = {tile}, {interp.nxgrid}" +======= answer_j = [] for j in range(ny): for i in range(nx): @@ -97,12 +140,27 @@ def test_create_xgrid(on_gpu): tgt_i = [xgrid.tgt_cell[i][0] for i in range(nxcells)] tgt_j = [xgrid.tgt_cell[i][1] for i in range(nxcells)] +>>>>>>> origin/main + + for i in range(interp.nxgrid): + + i_d, j_d = i_dst[i], j_dst[i] + assert i_src[i] == i_d // 2 and j_src[i] == j_d // 2, f"xcell {i}, i_src={i_src[i]}, j_src={j_src[i]} i_dst={i_d}, j_dst={j_d}" + + np.testing.assert_almost_equal( + interp.xgrid_area[i], + area[j_d, i_d], + decimal=2, + err_msg=f"tile {tile} gridpoint {i}") + + pyfms.fms.end() - assert tgt_i == answer_i - assert tgt_j == answer_j - remove_mosaic() +def test_xgridobj_gpu(): + xgridobj_test(on_gpu=True) +def test_xgridobj_cpu(): + xgridobj_test(on_gpu=False) if __name__ == "__main__": - test_create_xgrid(on_gpu=False) + test_xgridobj_gpu()