From 73de44f7d9ecdb5bb599f732f6ed6e033f65c7a1 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Thu, 20 Nov 2025 13:46:28 -0500 Subject: [PATCH 01/27] initial files --- fmsgridtools/__init__.py | 4 + fmsgridtools/make_mosaic/coupler_mosaic.py | 12 +- fmsgridtools/make_mosaic/solo_mosaic.py | 4 +- fmsgridtools/shared/gridobj.py | 588 ++++++++++++++++----- fmsgridtools/shared/mosaicobj.py | 400 ++++++++++---- fmsgridtools/utils/setlogger.py | 24 +- setup.py | 2 +- tests/shared/test_gridobj.py | 207 +++++--- 8 files changed, 914 insertions(+), 327 deletions(-) diff --git a/fmsgridtools/__init__.py b/fmsgridtools/__init__.py index beb9bca..683eaa7 100644 --- a/fmsgridtools/__init__.py +++ b/fmsgridtools/__init__.py @@ -9,3 +9,7 @@ from .shared.gridobj import GridObj from .shared.mosaicobj import MosaicObj from .shared.xgridobj import XGridObj + +import fmsgridtools.utils.setlogger as setlogger + +setlogger.setconfig(filename="FMSGRIDTOOLS.LOG", debug=True) diff --git a/fmsgridtools/make_mosaic/coupler_mosaic.py b/fmsgridtools/make_mosaic/coupler_mosaic.py index 41b04d1..592d0f0 100644 --- a/fmsgridtools/make_mosaic/coupler_mosaic.py +++ b/fmsgridtools/make_mosaic/coupler_mosaic.py @@ -193,14 +193,14 @@ def make(atm_mosaic_file: str, lnd_mosaic_file: str, ocn_mosaic_file: str, topog_file: str, input_dir: str = "./", on_gpu: bool = False): #read in mosaic files - atm_mosaic = MosaicObj(input_dir=input_dir, mosaic_file=atm_mosaic_file).read() - lnd_mosaic = MosaicObj(input_dir=input_dir, mosaic_file=lnd_mosaic_file).read() - ocn_mosaic = MosaicObj(input_dir=input_dir, mosaic_file=ocn_mosaic_file).read() + atm_mosaic = MosaicObj(input_dir=input_dir, mosaicfile=atm_mosaic_file).read() + lnd_mosaic = MosaicObj(input_dir=input_dir, mosaicfile=lnd_mosaic_file).read() + ocn_mosaic = MosaicObj(input_dir=input_dir, mosaicfile=ocn_mosaic_file).read() #read in grids - atm_mosaic.get_grid(toradians=True, agrid=True, free_dataset=True) - lnd_mosaic.get_grid(toradians=True, agrid=True, free_dataset=True) - ocn_mosaic.get_grid(toradians=True, agrid=True, free_dataset=True) + atmg_grids = atm_mosaic.get_grid(radians=True, center=True) + lnd_grids =lnd_mosaic.get_grid(radians=True, center=True) + ocn_grids = ocn_mosaic.get_grid(radians=True, center=True) #get ocean mask topogfile_dict = {'tile1': input_dir + '/' + topog_file} diff --git a/fmsgridtools/make_mosaic/solo_mosaic.py b/fmsgridtools/make_mosaic/solo_mosaic.py index 3de77e9..1fca5a4 100644 --- a/fmsgridtools/make_mosaic/solo_mosaic.py +++ b/fmsgridtools/make_mosaic/solo_mosaic.py @@ -36,7 +36,7 @@ def make(num_tiles, if len(gridtiles) != num_tiles: sys.exit("Error, number of gridtiles does not equal num_tiles") - grid = MosaicObj(ntiles=num_tiles, gridfiles = tilefiles, gridtiles = gridtiles).get_grid(toradians=True) + grid = MosaicObj(ntiles=num_tiles, gridfiles = tilefiles, gridtiles = gridtiles).get_grid(radians=True) ncontact, contacts, contact_index = 0, [], [] #FIND CONTACT REGIONS @@ -64,7 +64,7 @@ def make(num_tiles, print(f"NOTE: There are {ncontact} contacts\n") if ncontact > 0: - mosaic = MosaicObj(name=mosaic_name, + mosaic = MosaicObj(mosaic=mosaic_name, gridlocation=dir_name, gridfiles=tilefiles, gridtiles=gridtiles, diff --git a/fmsgridtools/shared/gridobj.py b/fmsgridtools/shared/gridobj.py index 47c1b62..1922047 100644 --- a/fmsgridtools/shared/gridobj.py +++ b/fmsgridtools/shared/gridobj.py @@ -1,166 +1,510 @@ -import dataclasses -from typing import List, Optional +""" +GridObj: +Class for containing basic grid data to be used by other grid objects +""" + +import logging +from pathlib import Path +from types import SimpleNamespace + import numpy as np import numpy.typing as npt import xarray as xr -from fmsgridtools.shared.gridtools_utils import check_file_is_there +import pyfms + +logger = logging.getLogger(__name__) + +attrs = {} +attrs["x"] = dict( + standard_name = "geographic_longitude", + units = "degree_east" +) +attrs["y"] = dict( + standard_name = "geographic_latitude", + units = "degrees_north" +) +attrs["tile"] = {} +attrs["tile_options"] = {} +attrs["tile_options"]["cubic"] = dict( + standard_name = "grid_tile_spec", + geometry = "spherical", + north_pole = "0.0 90.0", + discretization = "logically_rectangular", + conformal = "false" +) +attrs["tile_options"]["simple_cartesian"] = dict( + standard_name = "grid_tile_spec", + geometry = "planar", + discretization = "logically_rectangular", + conformal = "true" +) +attrs["tile_options"]["none"] = dict( + standard_name = "grid_tile_spec", + geometry = "spherical", + north_pole = "0.0 90.0", + projection = "none", + discretization = "logically_rectangular", + conformal = "true" +) + +attrs["dx"] = dict( + standard_name = "grid_edge_x_distance", + units = "meters" +) +attrs["dy"] = dict( + standard_name = "grid_edge_y_distance", + units = "meters" +) +attrs["area"] = dict( + standard_name = "grid_cell_area", + units = "m2" +) +attrs["angle_dx"] = dict( + standard_name = "grid_vertex_x_angle_WRT_geographic_east", + units = "degrees_east" +) +attrs["angle_dy"] = dict( + standard_name = "grid_vertex_y_angle_WRT_geographic_north", + units = "degrees_north" +) +attrs["arcx"] = dict( + standard_name = "grid_edge_x_arc_type", + north_pole = "0.0,90.0" +) + +dims = {} +dims["x"] = ["nyp", "nxp"] +dims["y"] = ["nyp", "nxp"] +dims["dx"] = ["nyp", "nx"] +dims["dy"] = ["ny", "nxp"] +dims["area"] = ["ny", "nx"] +dims["angle_dx"] = ["nyp", "nxp"] +dims["angle_dy"] = ["nyp", "nxp"] +dims["tile"] = () +dims["arcx"] = () -""" -GridObj: - -Class for containing basic grid data to be used by other grid objects -""" class GridObj: - def __init__(self, dataset: type[xr.Dataset] = None, gridfile: str = None): + """ + Class for grid information + """ + + def __init__(self, + input_dir: str = "./", + gridfile: str = None, + domain: pyfms.Domain = None, + gridtype: str = None, + tile: str = None, + nx: int = None, + ny: int = None, + nxp: int = None, + nyp: int = None, + x: npt.NDArray = None, + y: npt.NDArray = None, + dx: npt.NDArray = None, + dy: npt.NDArray = None, + area: npt.NDArray = None, + angle_dx: npt.NDArray = None, + angle_dy: npt.NDArray = None, + arcx: npt.NDArray = None, + on_gpu: bool = False + ): + + self.input_dir = Path(input_dir) self.gridfile = gridfile - self.tile = None - self.nx = None - self.ny = None - self.nxp = None - self.nyp = None - self.tile = None - self.x = None - self.y = None - self.dx = None - self.dy = None - self.area = None - self.angle_dx = None - self.angle_dy = None - self.arcx = None - self.dataset = dataset + self.domain = domain + + self.nx = nx + self.ny = ny + self.nxp = nxp + self.nyp = nyp + + self.gridtype = gridtype + + self.x_obj = SimpleNamespace( + name="x", + data=x, + ) + self.y_obj = SimpleNamespace( + name="y", + data=y, + ) + self.tile_obj = SimpleNamespace( + name="tile", + data=tile, + ) + self.dx_obj = SimpleNamespace( + name="dx", + data=dx, + ) + self.dy_obj = SimpleNamespace( + name="dy", + data=dy, + ) + self.area_obj = SimpleNamespace( + name="area", + data=area, + ) + self.angle_dx_obj = SimpleNamespace( + name="angle_dx", + data=angle_dx, + ) + self.angle_dy_obj = SimpleNamespace( + name="angle_dy", + data=angle_dy, + ) + self.arcx_obj = SimpleNamespace( + name="arcx", + data=arcx, + ) + self.objlist = [ + self.x_obj, + self.y_obj, + self.dx_obj, + self.dy_obj, + self.area_obj, + self.angle_dx_obj, + self.angle_dy_obj, + self.arcx_obj, + self.tile_obj + ] + + def to_domain(self): - """ - read: - This function reads in the gridfile and initializes the instance variables - """ - def read(self, toradians: bool = False, agrid: bool = False, free_dataset: bool = False): + """ + Stores data on the compute domain + """ + + isc, iec, jsc, jec = self.domain.isc, self.domain.iec, self.domain.jsc, self.domain.jec + + for obj in self.objlist: + if obj.data is not None and not self.arcx_obj: + logger.info("saving %s on domain", {obj.name}) + edge = 1 if obj is self.area_obj else 2 + obj.data = np.ascontiguousarray(obj.data[jsc:jec+edge, isc:iec+edge]) + + self._set_dims(on_domain=True) + + return self + + + def to_radians(self): + + """ + Converts data from degres to radians + """ - check_file_is_there(self.gridfile) - self.dataset = xr.open_dataset(self.gridfile) - self.get_attributes() + objlist = [self.x_obj, self.y_obj, self.dx_obj, self.dy_obj, self.angle_dx_obj, self.angle_dy_obj] + for obj in objlist: + if obj.data is not None: + logger.info("converting %s to radians", {obj.name}) + obj.data = np.radians(obj.data, dtype=np.float64) + - if free_dataset: - del self.dataset - self.dataset = None + def get_fms_area(self): + + """ + Compute grid cell areas + """ + + logger.info("computing grid cell area with fms") + + x = np.ascontiguousarray(self.x, dtype=np.float64) + y = np.ascontiguousarray(self.y, dtype=np.float64) + self.area = pyfms.grid_utils.get_grid_area(lon=x, lat=y, convert_cf_order=False) + return self.area + + + def read(self, radians: bool = False, center: bool = False, on_domain: bool = False, xy_only: bool = True): + + """ + Reads in the gridfile and initializes the instance variables + """ + + if xy_only: + logger.info("reading only x and y coordinates from file %s", {self.gridfile}) + objlist = [self.x_obj, self.y_obj] + else: + objlist = self.objlist + logger.info(f"reading in file {self.gridfile}") + + with xr.open_dataset(self.input_dir/self.gridfile) as ds: + for obj in objlist: + if obj.name in ds: + obj.data = ds[obj.name].data + if center: + logger.info("saving center points for %s", {obj.name}) + obj.data = np.ascontiguousarray(obj.data[::2, ::2]) + else: + logger.error("could not %s in %s", {obj.name}, {self.gridfile}) + + if radians: + self.to_radians() - if toradians: - self.x = np.radians(self.x, dtype=np.float64) - self.y = np.radians(self.y, dtype=np.float64) + if on_domain: + if self.domain is None: + logger.error("please specify domain by ") + self.to_domain() + + self._set_dims(ds.sizes, center=center, on_domain=on_domain) + + return self + - if agrid: - self.x, self.y = self.agrid() - [self.nyp, self.nxp] = self.x.shape + def write(self, gridfile: str = None): + + """ + write_out_grid: + This method will generate a netcdf file containing grid content + """ + + if gridfile is None: + if self.gridfile is None: + logger.error("must provide gridfile name") + gridfile = self.gridfile + + logger.info()"writing out gridfile %s", {gridfile}) + + if self.gridtype == "none": + attrs["tile"] = attrs["tile_options"]["none"] + elif self.gridtype == "cubic": + attrs["tile"] = attrs["tile_options"]["cubic"] + elif self.gridtype == "simple_cartesian": + attrs["tile"] = attrs["tile_options"]["simple_cartesian"] + + ds = {} + for obj in self.objlist: + if obj.data is not None: + name = obj.name + ds[name] = xr.DataArray( + data = obj.data, + attrs = attrs[name], + dims=dims[name] + ) + logger.info(ds[name]) + + xr.Dataset(data_vars=ds).to_netcdf(gridfile) + + + def _set_dims(self, dims: dict = None, center: bool = True, on_domain: bool = False): + + + if on_domain: + self.nx = self.domain.xsize_c + self.ny = self.domain.ysize_c + self.nxp = self.nx + 1 + self.nyp = self.ny + 1 + return + + self.nxp = dims.get("nxp") + self.nyp = dims.get("nyp") + self.nx = dims.get("nx") + self.ny = dims.get("ny") + + #nx + if self.nxp is None: + if self.nx is None: + logger.error("cannot set dimension nxp") + self.nxp = self.nx + 1 + elif self.nx is None: + if self.nxp is None: + logger.error("cannot set dimension nx") self.nx = self.nxp - 1 + + #ny + if self.nyp is None: + if self.ny is None: + logger.error("cannot set dimension nyp") + self.nyp = self.ny + 1 + elif self.ny is None: + if self.nyp is None: + logger.error("cannot set dimension ny") self.ny = self.nyp - 1 - - return self - - def get_attributes(self): + if center: + self.nx = self.nx // 2 + self.ny = self.ny // 2 + self.nxp = self.nx + 1 + self.nyp = self.ny + 1 - for key in self.dataset.data_vars: - if isinstance(self.dataset.data_vars[key].values, np.ndarray): - setattr(self, key, self.dataset[key].values) - else: - setattr(self, key, str(self.dataset[key].astype(str).values)) - for key in self.dataset.sizes: - setattr(self, key, self.dataset.sizes[key]) + @property + def x(self): - """ - write_out_grid: - This method will generate a netcdf file containing the contents of the - dataset attribute. - """ - def write(self, filepath: str): + """ + retrieve x + """ - if self.dataset is not None: - self.dataset.to_netcdf(filepath) + return self.x_obj.data + @x.setter + def x(self, data): - """ - get_variable_list: - This method returns a list of variables contained within the dataset. - """ - def get_variable_list(self) -> list: + """ + set x + """ - if self.dataset is not None: - return list(self.dataset.data_vars.keys()) - else: - return None + self.x_obj.data = data - - def x_contiguous(self): + @property + def y(self): - if self.x is not None: - return np.ascontiguousarray(self.x) - else: - return None + """ + retrieve y + """ - - def y_contiguous(self): + return self.y_obj.data - if self.y is not None: - return np.ascontiguousarray(self.y) - else: - return None + @y.setter + def y(self, data): - - def dx_contiguous(self): + """ + set y + """ - if self.dx is not None: - return np.ascontiguousarray(self.dx) - else: - return None + self.y_obj.data = data - - def dy_contiguous(self): + @property + def tile(self): - if self.dy is not None: - return np.ascontiguousarray(self.dataset.dy) - else: - return None + """" + retrieve tile + """ - - def area_contiguous(self): + return self.tile_obj.data - if self.area is not None: - return np.ascontiguousarray(self.dataset.area) - else: - return None + @tile.setter + def tile(self, data): - - def angle_dx_contiguous(self): + """ + set tile + """ - if self.angle_dx is not None: - return np.ascontiguousarray(self.dataset.angle_dx) - else: - return None + self.tile_obj.data = data - - def angle_dy_contiguous(self): + @property + def dx(self): - if self.angle_dy is not None: - return np.ascontiguousarray(self.dataset.angle_dy) - else: - return None + """ + retrieve dx + """ + + return self.dx_obj.data + + @dx.setter + def dx(self, data): + + """ + set dx + """ + + self.dx_obj.data = data + + @property + def dy(self): + + """ + retrieve dy + """ + + return self.dy_obj.data + + @dy.setter + def dy(self, data): + + """ + set dy + """ + + self.dy_obj.data = data + + @property + def area(self): + + """ + retrieve area + """ + + return self.area_obj.data + + @area.setter + def area(self, data): + + """ + set area + """ + + self.area_obj.data = data + + @property + def angle_dx(self): + + """ + retrieve angle_dx + """ + + return self.angle_dx_obj.data + + @angle_dx.setter + def angle_dx(self, data): + + """ + set angle_dx + """ + + self.angle_dx_obj.data = data + + @property + def angle_dy(self): + + """ + retrieve angle_dy + """ + + return self.angle_dy_obj.data + + @angle_dy.setter + def angle_dy(self, data): + + """ + set angle_dy + """ + + self.angle_dy_obj.data = data + + @property + def arcx(self): + + """ + retrieve arcx + """ + + return self.arcx_obj.data + + @arcx.setter + def arcx(self, data): + + """ + set arcx + """ + + self.arcx_obj.data = data + + + def __repr__(self): + summary = f"\n\nGrid for {self.gridfile}, tile = {self.tile_obj.name}\n" + summary += "nx = {:>5} ny = {:>5} nxp = {:>5} nyp = {:>5}\n".format(self.nx, self.ny, self.nxp, self.nyp) + summary += f"gridtype = {self.gridtype}\n" + + for obj in self.objlist: + summary += f"{obj.name} = {obj.data}\n" + + return summary - """ - get_agrid_lonlat: - This method returns the lon and lat for the A-grid as calculated from the - x and y attributes of the GridObj. - """ - def agrid(self)-> tuple[npt.NDArray, npt.NDArray]: - if self.x is not None and self.y is not None: - a_lon = np.ascontiguousarray(self.x[::2, ::2]) - a_lat = np.ascontiguousarray(self.y[::2, ::2]) - return a_lon, a_lat -#TODO: I/O method for passing to the host diff --git a/fmsgridtools/shared/mosaicobj.py b/fmsgridtools/shared/mosaicobj.py index ab0397f..c845201 100644 --- a/fmsgridtools/shared/mosaicobj.py +++ b/fmsgridtools/shared/mosaicobj.py @@ -1,137 +1,343 @@ +""" +MosaicObj class +""" -from typing import Optional, Dict, List, Any +from pathlib import Path +from types import SimpleNamespace + +import numpy as np import xarray as xr + +import pyfms from fmsgridtools.shared.gridobj import GridObj -from fmsgridtools.shared.gridtools_utils import check_file_is_there + +attrs = dict( + mosaic = dict( + standard_name="grid_mosaic_spec", + contact_regions="contacts", + children="gridtiles", + grid_descriptor="" + ), + gridlocation = dict( + standard_name="grid_file_location" + ), + gridfiles = dict(), + gridtiles = dict(), + contacts = dict( + standard_name="grid_contact_spec", + contact_type="boundary", + alignment="true", + contact_index="contact_index", + orientation="orient" + ), + contact_index = dict( + standard_name="starting_ending_point_index_of_contact" + ) +) + +dims = dict( + mosaic = (), + gridlocation = (), + gridfiles = ["ntiles"], + gridtiles = ["ntiles"], + contacts = ["ncontact"], + contact_index = ["ncontact"] +) + + +def set_attribute(variable: str, var_attr: dict): + + global attrs + + if variable in attrs: + attrs[variable] = var_attr + else: + raise RuntimeError(f"{variable} does not exist in attributes") + + +def set_dims(variable: str, var_dim: list): + + global dims + + if variable in dims: + dims[variable] = var_dim + else: + raise RuntimeError(f"{variable} does not exist in dims") class MosaicObj: - def __init__(self, input_dir: str = "./", - mosaic_file: str = None, - name: str = None, + """ + MosaicObj + """ + + def __init__(self, + input_dir: str = "./", + mosaicfile: str = None, + mosaic: str = None, ntiles: int = None, gridlocation: str = "./", gridfiles: list[str] = None, gridtiles: list[str] = None, contacts: list[str] = None, contact_index: list[str] = None, - dataset: type[xr.Dataset] = None, - grid: dict = None): - - self.input_dir = input_dir+"/" - self.mosaic_file = mosaic_file - self.name = name - self.gridlocation = gridlocation - self.gridfiles = gridfiles - self.gridtiles = gridtiles - self.contacts = contacts - self.contact_index = contact_index - self.dataset = dataset - self.grid = grid + ): + + self.input_dir = Path(input_dir) + self.mosaicfile = mosaicfile + self.ntiles = ntiles - #for key, value in self.__dict__.items(): - # if key == 'gridfiles' or 'gridtiles' or 'contacts' or 'contact_index': - # if value is None: - # self.__dict__[key] = [] - # if key == 'grid': - # if value is None: - # self.__dict__[key] = {} + self.ncontacts = None + + self.mosaic_obj = SimpleNamespace( + name = "mosaic", + data = mosaic + ) + self.gridlocation_obj = SimpleNamespace( + name = "gridlocation", + data = gridlocation + ) + self.gridfiles_obj = SimpleNamespace( + name = "gridfiles", + data = gridfiles + ) + self.gridtiles_obj = SimpleNamespace( + name = "gridtiles", + data = gridtiles + ) + self.contacts_obj = SimpleNamespace( + name = "contacts", + data = contacts + ) + self.contact_index_obj = SimpleNamespace( + name = "contact_index", + data = contact_index + ) + self.objlist = [ + self.mosaic_obj, + self.gridlocation_obj, + self.gridfiles_obj, + self.gridtiles_obj, + self.contacts_obj, + self.contact_index_obj + ] + + + def read(self, mosaicfile: str|Path = None, input_dir: str|Path = "."): + + """ + Read the mosac file + """ + + if mosaicfile is None: + if self.mosaicfile is None: + raise IOError("Please specify the mosaic file") + mosaicfile = self.mosaicfile + + with xr.open_dataset(Path(self.input_dir)/mosaicfile) as ds: + + for obj in self.objlist: + variable = ds.get(obj.name) + if variable is not None: + if variable.dtype is bytes: + variable = variable.str.decode(encoding="utf-8") + if isinstance(variable.data, np.ndarray): + obj.data = variable.data.tolist() + else: + obj.data = str(variable.data) + self.ntiles = ds.sizes.get("ntiles") + self.ncontacts = ds.sizes.get("ncontact") + self.input_dir = input_dir + self.mosaicfile = mosaicfile + return self - def read(self): - if self.mosaic_file is None: - raise IOError("Please specify the mosaic file") + def from_dict(self, mosaic_dict: dict): - check_file_is_there(self.input_dir+self.mosaic_file) - self.dataset = xr.open_dataset(self.input_dir+self.mosaic_file) + """ + Generate mosaic file from dictionary + """ - self.get_attributes() + names = [obj.name for obj in self.objlist] + for key in mosaic_dict: + if key not in names: + raise RuntimeError(f"{key} not a field in MosaicObj") - if hasattr(self, "mosaic_name"): - self.name = self.mosaic_name - else: - self.name = self.mosaic_file[:-3] + for key in mosaic_dict: + for obj in self.objlist: + if obj.name == key: + obj.data = mosaic_dict[key] + + for obj in self.objlist: + if obj.data is None: + printf(f"{obj.name} not set") + + self.ntiles = None if self.gridfiles is None else len(self.gridfiles.data) + self.ncontacts = None if self.contacts is None else len(self.contacts.data) return self - def get_attributes(self) -> None: - for key in self.dataset.data_vars: - setattr(self, key, self.dataset[key].astype(str).values) + def get_grid(self, input_dir: str|Path = "./", + radians: bool = False, + center: bool = False, + domain: pyfms.Domain = None) -> dict: - for key in self.dataset.sizes: - setattr(self, key, self.dataset.sizes[key]) + """ + Get grids from gridfiles + """ - def add_attributes(self, attribute: str, value: Any = None) -> None: + if self.gridfiles is None: + raise RuntimeError("need to set gridfiles") - setattr(self, attribute, value) + if self.gridtiles is None: + raise RuntimeError("need to set gridtiles") + if self.ntiles is None: + ntiles = len(self.gridfiles) - def get_grid(self, toradians: bool = False, agrid: bool = False, free_dataset: bool = False) -> dict: + grid = {} - if self.grid is None: self.grid = {} - for i in range(self.ntiles): - gridfile = str(self.input_dir) + str(self.gridlocation) + str(self.gridfiles[i]) - self.grid[self.gridtiles[i]] = GridObj(gridfile=gridfile).read(toradians=toradians, - agrid=agrid, - free_dataset=free_dataset) + for gridfile, gridtile in zip(self.gridfiles, self.gridtiles): + readfile = Path(input_dir)/gridfile + grid[gridtile] = GridObj(gridfile=readfile).read_xy(radians=radians, center=center, domain=domain) - return self.grid + return grid - def write(self, outfile: str = None) -> None: + def write(self, mosaicfile: str = None) -> None: - dataset = xr.Dataset() - if self.name is not None: - dataset["mosaic"] = xr.DataArray( - data=self.name.encode(), - attrs=dict( - standard_name="grid_mosaic_spec", - contact_regions="contacts", - children="gridtiles", - grid_descriptor="" - ) - ) + """ + write mosaic file + """ - if self.gridlocation is not None: - dataset["gridlocation"] = xr.DataArray( - data=self.gridlocation, - attrs=dict( - standard_name="grid_file_location" - ) - ) - - if self.gridfiles is not None: - dataset["gridfiles"] = xr.DataArray( - data=self.gridfiles, dims=["ntiles"] - ) - - if self.gridtiles is not None: - dataset["gridtiles"] = xr.DataArray( - data=self.gridtiles, dims=["ntiles"] - ) - - if self.contacts is not None: - dataset["contacts"] = xr.DataArray( - data=self.contacts, dims=["ncontact"], - attrs=dict( - standard_name="grid_contact_spec", - contact_type="boundary", - alignment="true", contact_index="contact_index", - orientation="orient" - ) - ) + if mosaicfile is None: + if self.mosaicfile is None: + raise RuntimeError("need to specify mosaic filename") + else: + mosaicfile = self.mosaicfile + + ds = {} - if self.contact_index is not None: - dataset["contact_index"] = xr.DataArray( - data=self.contact_index, dims=["ncontact"], - attrs=dict( - standard_name="starting_ending_point_index_of_contact" + for obj in self.objlist: + if obj.data is not None: + name = obj.name + ds[name] = xr.DataArray( + data=obj.data, + attrs=attrs[name], + dims=dims[name] ) - ) - dataset.to_netcdf(outfile) + xr.Dataset(data_vars=ds).to_netcdf(mosaicfile) + + + @property + def mosaic(self): + + """ + retrieve mosaic + """ + + return self.mosaic_obj.data + + @mosaic.setter + def mosaic(self, data): + + """ + set mosaic data + """ + + self.mosaic_obj.data = data + + @property + def gridlocation(self): + + """ + retrieve gridlocation + """ + + return self.gridlocation_obj.data + + @gridlocation.setter + def gridlocation(self, data): + + """ + set gridlocation data + """ + + self.gridlocation_obj.data = data + + @property + def gridtiles(self): + + """ + retrieve gridtiles + """ + + return self.gridtiles_obj.data + + @gridtiles.setter + def gridtiles(self, data): + + """ + set gridtiles data + """ + + self.gridtiles_obj.data = data + + @property + def gridfiles(self): + + """ + retrieve gridfiles + """ + + return self.gridfiles_obj.data + + @gridfiles.setter + def gridfiles(self, data): + + """ + set gridfiles data + """ + + self.gridfiles_obj.data = data + + @property + def contacts(self): + + """ + retrieve contacts + """ + + return self.contacts_obj.data + + @contacts.setter + def contacts(self, data): + + """ + set contacts data + """ + + self.contacts_obj.data = data + + @property + def contact_index(self): + + """ + retrieve contact_index + """ + + return self.contact_index_obj.data + + @contact_index.setter + def contact_index(self, data): + + """ + set contact_index data + """ + + self.contact_index_obj.data = data diff --git a/fmsgridtools/utils/setlogger.py b/fmsgridtools/utils/setlogger.py index d762e8c..7274dad 100644 --- a/fmsgridtools/utils/setlogger.py +++ b/fmsgridtools/utils/setlogger.py @@ -1,20 +1,16 @@ import logging -import os -import datetime def setconfig(filename: str, debug: bool = False): - if os.path.isfile(filename): - os.rename(filename, "OLDLOG") - if debug: - logging.basicConfig(filename=filename, - format=("[%(levelname)s]%(filename)s:%(module)s:" - "%(funcName)s:%(lineno)d:" - "%(asctime)s:\n %(message)s\n"), - level=logging.DEBUG) + logging.basicConfig( + filename=filename, + format=("[%(levelname)s]%(filename)s:%(module)s: %(funcName)s:%(lineno)d: %(asctime)s:\n %(message)s\n"), + level=logging.DEBUG + ) else: - logging.basicConfig(filename=filename, - format=("[%(levelname)s]%(module)s:%(funcName)s:" - "%(message)s\n"), - level=logging.INFO) + logging.basicConfig( + filename=filename, + format=("[%(levelname)s]%(module)s: %(funcName)s: %(message)s\n"), + level=logging.INFO + ) diff --git a/setup.py b/setup.py index 00074dc..d0a3d1b 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def local_pkg(name: str, relative_path: str) -> str: "numpy", "xarray", "netCDF4", - local_pkg("pyFMS", "pyFMS"), +# local_pkg("pyFMS", "pyFMS"), local_pkg("pyfrenctools", "FREnctools_lib") ] diff --git a/tests/shared/test_gridobj.py b/tests/shared/test_gridobj.py index 8e45e06..37aaf28 100644 --- a/tests/shared/test_gridobj.py +++ b/tests/shared/test_gridobj.py @@ -1,9 +1,15 @@ -import os - +import logging import numpy as np +from pathlib import Path +import pytest +from types import SimpleNamespace import xarray as xr + +import pyfms +import fmsgridtools from fmsgridtools import GridObj +logger = logging.getLogger(__name__) """ Creating data to generate xarray dataset from @@ -14,11 +20,9 @@ nxp = nx + 1 nyp = ny + 1 -x = np.array([[i*10+j for j in range(nyp)] for i in range(nxp)], dtype=np.float64) -y = np.array([[-i*10-j for j in range(nyp)] for i in range(nxp)], dtype=np.float64) - -tile = xr.DataArray( - [b'tile1'], +ds = SimpleNamespace() +ds.tile = xr.DataArray( + data='tile1', attrs=dict( standard_name="grid_tile_spec", geometry="spherical", @@ -27,23 +31,23 @@ discretization="logically_rectangular", ) ) -x = xr.DataArray( - data=x, +ds.x = xr.DataArray( + data=np.full(shape=(nyp,nxp), fill_value=0.5, dtype=np.float64), dims=["nyp", "nxp"], attrs=dict( units="degree_east", standard_name="geographic_longitude", ) ) -y = xr.DataArray( - data=y, +ds.y = xr.DataArray( + data=np.full(shape=(nyp,nxp), fill_value=1.0, dtype=np.float64), dims=["nyp", "nxp"], attrs=dict( units="degree_north", standard_name="geographic_latitude", ) ) -dx = xr.DataArray( +ds.dx = xr.DataArray( data=np.full(shape=(nyp,nx), fill_value=1.5, dtype=np.float64), dims=["nyp", "nx"], attrs=dict( @@ -51,7 +55,7 @@ standard_name="grid_edge_x_distance", ) ) -dy = xr.DataArray( +ds.dy = xr.DataArray( data=np.full(shape=(ny,nxp), fill_value=2.5, dtype=np.float64), dims=["ny", "nxp"], attrs=dict( @@ -59,7 +63,7 @@ standard_name="grid_edge_y_distance", ) ) -area = xr.DataArray( +ds.area = xr.DataArray( data=np.full(shape=(ny,nx), fill_value=4.0, dtype=np.float64), dims=["ny", "nx"], attrs=dict( @@ -67,7 +71,7 @@ standard_name="grid_cell_area", ) ) -angle_dx = xr.DataArray( +ds.angle_dx = xr.DataArray( data=np.full(shape=(nyp,nxp), fill_value=3.0, dtype=np.float64), dims=["nyp", "nxp"], attrs=dict( @@ -75,7 +79,7 @@ standard_name="grid_vertex_x_angle_WRT_geographic_east", ) ) -angle_dy = xr.DataArray( +ds.angle_dy = xr.DataArray( data=np.full(shape=(nyp,nxp), fill_value=5.0, dtype=np.float64), dims=["nyp", "nxp"], attrs=dict( @@ -83,7 +87,7 @@ standard_name="grid_vertex_x_angle_WRT_geographic_east", ) ) -arcx = xr.DataArray( +ds.arcx = xr.DataArray( [b'arcx'], attrs=dict( standard_name="grid_edge_x_arc_type", @@ -92,98 +96,131 @@ ) ) -out_grid_dataset = xr.Dataset( - data_vars={ - "tile": tile, - "x": x, - "y": y, - "dx": dx, - "dy": dy, - "area": area, - "angle_dx": angle_dx, - "angle_dy": angle_dy, - "arcx": arcx, - } -) +#@pytest.fixture(autouse=True) +def set_fms_files(): + + inputnml = Path("input.nml") + logfile = Path("logfile.000000.out") + warnfile = Path("warnfile.000000.out") + + inputnml.touch() + yield -def test_empty_grid_obj(): + if inputnml.exists(): inputnml.unlink() + if logfile.exists(): logfile.unlink() + if warnfile.exists(): warnfile.unlink() + - empty_grid_obj = GridObj() - assert isinstance(empty_grid_obj, GridObj) +def test_read_write(set_fms_files): + + gridfile = Path("test_read_write.nc") + pyfms.fms.init() -def test_gridobj_from_dataset(): + testgrid = GridObj(gridtype="cubic") - from_dataset_grid_obj = GridObj(dataset=out_grid_dataset) - from_dataset_grid_obj.get_attributes() - assert isinstance(from_dataset_grid_obj, GridObj) - assert from_dataset_grid_obj.dataset is not None + testgrid.x = ds.x.data + testgrid.y = ds.y.data + testgrid.dx = ds.dx.data + testgrid.dy = ds.dy.data + testgrid.area = ds.area.data + testgrid.angle_dx = ds.angle_dx.data + testgrid.angle_dy = ds.angle_dy.data + testgrid.arcx = str(ds.arcx.data) + testgrid.tile = str(ds.tile.data) - np.testing.assert_array_equal(from_dataset_grid_obj.x, out_grid_dataset.x.values) - np.testing.assert_array_equal(from_dataset_grid_obj.y, out_grid_dataset.y.values) - np.testing.assert_array_equal(from_dataset_grid_obj.dx, out_grid_dataset.dx.values) - np.testing.assert_array_equal(from_dataset_grid_obj.dy, out_grid_dataset.dy.values) - np.testing.assert_array_equal(from_dataset_grid_obj.area, out_grid_dataset.area.values) - np.testing.assert_array_equal(from_dataset_grid_obj.angle_dx, out_grid_dataset.angle_dx.values) - np.testing.assert_array_equal(from_dataset_grid_obj.angle_dy, out_grid_dataset.angle_dy) + testgrid.write(gridfile) + exit() - -def test_write_grid(tmp_path): + assert gridfile.exists() - from_dataset_grid_obj = GridObj(dataset=out_grid_dataset) + del testgrid + + testgrid = GridObj(gridfile=gridfile).read() - file_path = tmp_path / "test_grid.nc" + #test dims + assert testgrid.nx == nx + assert testgrid.ny == ny + assert testgrid.nxp == nxp + assert testgrid.nyp == nyp + + #test values + np.testing.assert_array_equal(testgrid.x, ds.x.data) + np.testing.assert_array_equal(testgrid.y, ds.y.data) + np.testing.assert_array_equal(testgrid.dx, ds.dx.data) + np.testing.assert_array_equal(testgrid.dy, ds.dy.data) + np.testing.assert_array_equal(testgrid.area, ds.area.data) + np.testing.assert_array_equal(testgrid.angle_dx, ds.angle_dx.data) + np.testing.assert_array_equal(testgrid.angle_dy, ds.angle_dy.data) + assert testgrid.arcx == str(ds.arcx.data) + assert testgrid.tile == str(ds.tile.data) - from_dataset_grid_obj.write(filepath=file_path) + gridfile.unlink() - assert file_path.exists() + pyfms.fms.end() - file_path.unlink() +test_read_write(None) - assert not file_path.exists() +def test_center_option(set_fms_files): + pyfms.fms.init() -def test_gridobj_from_file(tmp_path): + gridfile = Path("test_center.nc") + nx2 = nx // 2 + ny2 = ny // 2 + nx2p = nx2 + 1 + ny2p = ny2 + 1 - gridfile = tmp_path / "test_grid.nc" - - out_grid_dataset.to_netcdf(gridfile) + # center points are value of 1 + x = np.array([[1,0]*nx2 + [1]]*nyp) + y = np.array([[1]*nxp, [0]*nxp]*ny2 + [[1]*nxp]) - from_file_init_grid_obj = GridObj(gridfile=gridfile).read() - assert isinstance(from_file_init_grid_obj, GridObj) - assert from_file_init_grid_obj.gridfile is not None + GridObj(gridfile=gridfile, x=x, y=y).write() - np.testing.assert_array_equal(from_file_init_grid_obj.x, out_grid_dataset.x.values) - np.testing.assert_array_equal(from_file_init_grid_obj.y, out_grid_dataset.y.values) - np.testing.assert_array_equal(from_file_init_grid_obj.dx, out_grid_dataset.dx.values) - np.testing.assert_array_equal(from_file_init_grid_obj.dy, out_grid_dataset.dy.values) - np.testing.assert_array_equal(from_file_init_grid_obj.area, out_grid_dataset.area.values) - np.testing.assert_array_equal(from_file_init_grid_obj.angle_dx, out_grid_dataset.angle_dx.values) - np.testing.assert_array_equal(from_file_init_grid_obj.angle_dy, out_grid_dataset.angle_dy.values) + grid = GridObj(gridfile=gridfile).read_xy(center=True, radians=True) - os.remove(gridfile) + assert grid.nx == nx2 + assert grid.ny == ny2 + assert grid.nxp == nx2 + 1 + assert grid.nyp == ny2 + 1 + answer = np.radians(np.ones((ny2p,nx2p), dtype=np.float64)) + + np.testing.assert_array_equal(grid.x, answer) + np.testing.assert_array_equal(grid.y, answer) -def test_gridobj_read(tmp_path): + gridfile.unlink() + + pyfms.fms.end() + - gridfile = tmp_path / "test_grid.nc" +def test_to_domain(set_fms_files): - out_grid_dataset.to_netcdf(gridfile) + nx, ny = 8, 8 + global_indices = [0, nx-1, 0, ny-1] - grid = GridObj(gridfile=gridfile).read(toradians=True, agrid=True, free_dataset=True) + Path("input.nml").touch() + pyfms.fms.init(ndomain=1) + domain = pyfms.mpp_domains.define_domains(global_indices) - assert grid.dataset == None + x1 = np.arange(nx+1, dtype=np.float64) + y1 = np.arange(ny+1, dtype=np.float64) + x, y = np.meshgrid(x1, y1) + + area = np.ones((ny, nx), dtype=np.float64) + + grid = GridObj(x=x, y=y, area=area) + grid.to_domain(domain) + + x1 = np.arange(domain.isc, domain.iec+2, dtype=np.float64) + y1 = np.arange(domain.jsc, domain.jec+2, dtype=np.float64) + xanswer, yanswer = np.meshgrid(x1, y1) + area_answer = np.ones((domain.ysize_c, domain.xsize_c), dtype=np.float64) - assert grid.nx == nx//2 - assert grid.ny == ny//2 - assert grid.nxp == nx//2 + 1 - assert grid.nyp == ny//2 + 1 - - for i in range(grid.nxp): - for j in range(grid.nyp): - answer = 2*10*i+2*j - assert grid.x[i][j] == np.radians(answer) - assert grid.y[i][j] == np.radians(-answer) - - os.remove(gridfile) + np.testing.assert_array_equal(grid.x, xanswer) + np.testing.assert_array_equal(grid.y, yanswer) + np.testing.assert_array_equal(grid.area, area_answer) + + pyfms.fms.end() + From 642c884aa0655a59dc5325c5d369484a174f5c21 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Thu, 20 Nov 2025 14:20:35 -0500 Subject: [PATCH 02/27] logging --- fmsgridtools/shared/gridobj.py | 69 +++++++++++----------------------- 1 file changed, 22 insertions(+), 47 deletions(-) diff --git a/fmsgridtools/shared/gridobj.py b/fmsgridtools/shared/gridobj.py index 1922047..60ea3c8 100644 --- a/fmsgridtools/shared/gridobj.py +++ b/fmsgridtools/shared/gridobj.py @@ -169,7 +169,11 @@ def __init__(self, self.arcx_obj, self.tile_obj ] - + + self._set_dims() + + logger.info("Created new GridObj named:\n %s", self.__repr__()) + def to_domain(self): @@ -185,7 +189,7 @@ def to_domain(self): edge = 1 if obj is self.area_obj else 2 obj.data = np.ascontiguousarray(obj.data[jsc:jec+edge, isc:iec+edge]) - self._set_dims(on_domain=True) + self._set_dims() return self @@ -228,7 +232,7 @@ def read(self, radians: bool = False, center: bool = False, on_domain: bool = Fa objlist = [self.x_obj, self.y_obj] else: objlist = self.objlist - logger.info(f"reading in file {self.gridfile}") + logger.info("reading in file %s\n", self.gridfile) with xr.open_dataset(self.input_dir/self.gridfile) as ds: for obj in objlist: @@ -248,7 +252,7 @@ def read(self, radians: bool = False, center: bool = False, on_domain: bool = Fa logger.error("please specify domain by ") self.to_domain() - self._set_dims(ds.sizes, center=center, on_domain=on_domain) + self._set_dims() return self @@ -265,7 +269,7 @@ def write(self, gridfile: str = None): logger.error("must provide gridfile name") gridfile = self.gridfile - logger.info()"writing out gridfile %s", {gridfile}) + logger.info("writing out gridfile %s", {gridfile}) if self.gridtype == "none": attrs["tile"] = attrs["tile_options"]["none"] @@ -288,46 +292,13 @@ def write(self, gridfile: str = None): xr.Dataset(data_vars=ds).to_netcdf(gridfile) - def _set_dims(self, dims: dict = None, center: bool = True, on_domain: bool = False): - - - if on_domain: - self.nx = self.domain.xsize_c - self.ny = self.domain.ysize_c - self.nxp = self.nx + 1 - self.nyp = self.ny + 1 - return + def _set_dims(self): - self.nxp = dims.get("nxp") - self.nyp = dims.get("nyp") - self.nx = dims.get("nx") - self.ny = dims.get("ny") - - #nx - if self.nxp is None: - if self.nx is None: - logger.error("cannot set dimension nxp") - self.nxp = self.nx + 1 - elif self.nx is None: - if self.nxp is None: - logger.error("cannot set dimension nx") - self.nx = self.nxp - 1 - - #ny - if self.nyp is None: - if self.ny is None: - logger.error("cannot set dimension nyp") - self.nyp = self.ny + 1 - elif self.ny is None: - if self.nyp is None: - logger.error("cannot set dimension ny") + if self.x_obj.data is not None: + print("hereherehere") + self.nyp, self.nxp = self.x_obj.data.shape self.ny = self.nyp - 1 - - if center: - self.nx = self.nx // 2 - self.ny = self.ny // 2 - self.nxp = self.nx + 1 - self.nyp = self.ny + 1 + self.nx = self.nxp - 1 @property @@ -494,12 +465,16 @@ def arcx(self, data): def __repr__(self): - summary = f"\n\nGrid for {self.gridfile}, tile = {self.tile_obj.name}\n" - summary += "nx = {:>5} ny = {:>5} nxp = {:>5} nyp = {:>5}\n".format(self.nx, self.ny, self.nxp, self.nyp) - summary += f"gridtype = {self.gridtype}\n" + summary = "%s\n" % (self.__class__.__name__) + summary += "gridfile = %s\n" % (self.gridfile) + summary += "gridtype = %s\n" % (self.gridtype) + summary += "nx = %s" % (self.nx) + summary += "ny = %s" % (self.ny) + summary += "nxp = %s" %(self.nxp) + summary += "nyp = %s" %(self.nyp) for obj in self.objlist: - summary += f"{obj.name} = {obj.data}\n" + summary += "%s = %s\n" % (obj.name, obj.data) return summary From 5de9813cc3102ce258880ae4eeba0ccfb34e912c Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 21 Nov 2025 12:16:13 -0500 Subject: [PATCH 03/27] finish gridobj --- .github/workflows/test.yaml | 4 +- fmsgridtools/shared/gridobj.py | 168 ++++++++++++++++++--------------- pyFMS | 2 +- tests/shared/test_gridobj.py | 62 ++++++------ 4 files changed, 129 insertions(+), 107 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 0b318f5..5f432fd 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -28,6 +28,8 @@ jobs: run: | pytest tests/shared/test_libs.py pytest tests/mosaic/test_mosaic.py - pytest tests/shared/test_gridobj.py + pytest tests/shared/test_gridobj.py::test_read_write + pytest tests/shared/test_gridobj.py::test_center_option + pytest tests/shared/test_gridobj.py::test_to_domain pytest tests/shared/test_xgridobj.py pytest tests/hgrid/test_hgrid.py diff --git a/fmsgridtools/shared/gridobj.py b/fmsgridtools/shared/gridobj.py index 60ea3c8..bbd0516 100644 --- a/fmsgridtools/shared/gridobj.py +++ b/fmsgridtools/shared/gridobj.py @@ -18,11 +18,13 @@ attrs = {} attrs["x"] = dict( standard_name = "geographic_longitude", - units = "degree_east" + units = "degree_east", + _FillValue = False ) attrs["y"] = dict( standard_name = "geographic_latitude", - units = "degrees_north" + units = "degrees_north", + _FillValue = False ) attrs["tile"] = {} attrs["tile_options"] = {} @@ -31,13 +33,15 @@ geometry = "spherical", north_pole = "0.0 90.0", discretization = "logically_rectangular", - conformal = "false" + conformal = "false", + _FillValue = False ) attrs["tile_options"]["simple_cartesian"] = dict( standard_name = "grid_tile_spec", geometry = "planar", discretization = "logically_rectangular", - conformal = "true" + conformal = "true", + _FillValue = False ) attrs["tile_options"]["none"] = dict( standard_name = "grid_tile_spec", @@ -45,32 +49,39 @@ north_pole = "0.0 90.0", projection = "none", discretization = "logically_rectangular", - conformal = "true" + conformal = "true", + _FillValue = False ) attrs["dx"] = dict( standard_name = "grid_edge_x_distance", - units = "meters" + units = "meters", + _FillValue = False ) attrs["dy"] = dict( standard_name = "grid_edge_y_distance", - units = "meters" + units = "meters", + _FillValue = False ) attrs["area"] = dict( standard_name = "grid_cell_area", - units = "m2" + units = "m2", + _FillValue = False ) attrs["angle_dx"] = dict( standard_name = "grid_vertex_x_angle_WRT_geographic_east", - units = "degrees_east" + units = "degrees_east", + _FillValue = False ) attrs["angle_dy"] = dict( standard_name = "grid_vertex_y_angle_WRT_geographic_north", - units = "degrees_north" + units = "degrees_north", + _FillValue = False ) attrs["arcx"] = dict( standard_name = "grid_edge_x_arc_type", - north_pole = "0.0,90.0" + north_pole = "0.0,90.0", + _FillValue = False ) dims = {} @@ -84,6 +95,11 @@ dims["tile"] = () dims["arcx"] = () +class Variable: + def __init__(self, name: str = None, data = None): + self.name = name + self.data = data + class GridObj: """ @@ -122,42 +138,15 @@ def __init__(self, self.gridtype = gridtype - self.x_obj = SimpleNamespace( - name="x", - data=x, - ) - self.y_obj = SimpleNamespace( - name="y", - data=y, - ) - self.tile_obj = SimpleNamespace( - name="tile", - data=tile, - ) - self.dx_obj = SimpleNamespace( - name="dx", - data=dx, - ) - self.dy_obj = SimpleNamespace( - name="dy", - data=dy, - ) - self.area_obj = SimpleNamespace( - name="area", - data=area, - ) - self.angle_dx_obj = SimpleNamespace( - name="angle_dx", - data=angle_dx, - ) - self.angle_dy_obj = SimpleNamespace( - name="angle_dy", - data=angle_dy, - ) - self.arcx_obj = SimpleNamespace( - name="arcx", - data=arcx, - ) + self.x_obj = Variable(name="x", data=x) + self.y_obj = Variable(name="y", data=y) + self.tile_obj = Variable(name="tile", data=tile) + self.dx_obj = Variable(name="dx", data=dx) + self.dy_obj = Variable(name="dy", data=dy) + self.area_obj = Variable(name="area", data=area) + self.angle_dx_obj = Variable(name="angle_dx", data=angle_dx) + self.angle_dy_obj = Variable(name="angle_dy", data=angle_dy) + self.arcx_obj = Variable(name="arcx", data=arcx) self.objlist = [ self.x_obj, self.y_obj, @@ -175,19 +164,41 @@ def __init__(self, logger.info("Created new GridObj named:\n %s", self.__repr__()) - def to_domain(self): + def to_domain(self, domain: dict = None): """ Stores data on the compute domain """ - isc, iec, jsc, jec = self.domain.isc, self.domain.iec, self.domain.jsc, self.domain.jec - - for obj in self.objlist: - if obj.data is not None and not self.arcx_obj: - logger.info("saving %s on domain", {obj.name}) - edge = 1 if obj is self.area_obj else 2 - obj.data = np.ascontiguousarray(obj.data[jsc:jec+edge, isc:iec+edge]) + if domain is None: + if self.domain is None: + logger.error("Please specify Domain object from pyfms") + domain = self.domain + else: + if self.domain is not None: + logger.warning("Overwriting %s with %s", self.domain, domain) + self.domain = domain + + if not pyfms.fms.module_is_initialized(): + logger.error("Please initialize pyfms first") + + isc, jsc = self.domain.isc, self.domain.jsc + xsize_c, ysize_c = self.domain.xsize_c, self.domain.ysize_c + + objdict = { + self.x_obj: (ysize_c+1, xsize_c+1), + self.y_obj: (ysize_c+1, xsize_c+1), + self.area_obj: (ysize_c, xsize_c), + self.dx_obj: (ysize_c+1, xsize_c), + self.dy_obj: (ysize_c, xsize_c+1), + self.angle_dx_obj: (ysize_c+1, xsize_c), + self.angle_dy_obj: (ysize_c, xsize_c+1) + } + + for obj, (ysize, xsize) in objdict.items(): + if obj.data is not None: + logger.info("Saving %s on domain", {obj.name}) + obj.data = np.ascontiguousarray(obj.data[jsc:jsc+ysize, isc:isc+xsize]) self._set_dims() @@ -203,9 +214,9 @@ def to_radians(self): objlist = [self.x_obj, self.y_obj, self.dx_obj, self.dy_obj, self.angle_dx_obj, self.angle_dy_obj] for obj in objlist: if obj.data is not None: - logger.info("converting %s to radians", {obj.name}) + logger.info("Converting %s to radians", {obj.name}) obj.data = np.radians(obj.data, dtype=np.float64) - + def get_fms_area(self): @@ -213,7 +224,10 @@ def get_fms_area(self): Compute grid cell areas """ - logger.info("computing grid cell area with fms") + logger.info("Computing grid cell area with fms") + + if not pyfms.fms.module_is_initialized(): + logger.error("Please initialize pyfms first") x = np.ascontiguousarray(self.x, dtype=np.float64) y = np.ascontiguousarray(self.y, dtype=np.float64) @@ -221,17 +235,19 @@ def get_fms_area(self): return self.area - def read(self, radians: bool = False, center: bool = False, on_domain: bool = False, xy_only: bool = True): + def read(self, radians: bool = False, center: bool = False, on_domain: bool = False, xy_only: bool = False): """ Reads in the gridfile and initializes the instance variables """ + objlist = [self.x_obj, self.y_obj] if xy_only: logger.info("reading only x and y coordinates from file %s", {self.gridfile}) - objlist = [self.x_obj, self.y_obj] else: - objlist = self.objlist + objlist += [self.area_obj, self.dx_obj, self.dy_obj, self.angle_dx_obj, + self.angle_dy_obj, self.arcx_obj, self.tile_obj + ] logger.info("reading in file %s\n", self.gridfile) with xr.open_dataset(self.input_dir/self.gridfile) as ds: @@ -243,33 +259,32 @@ def read(self, radians: bool = False, center: bool = False, on_domain: bool = Fa obj.data = np.ascontiguousarray(obj.data[::2, ::2]) else: logger.error("could not %s in %s", {obj.name}, {self.gridfile}) - + if radians: self.to_radians() - + if on_domain: - if self.domain is None: - logger.error("please specify domain by ") self.to_domain() self._set_dims() + logger.info("Finished reading file %s\n %s", self.gridfile, self.__repr__()) + return self def write(self, gridfile: str = None): """ - write_out_grid: - This method will generate a netcdf file containing grid content + Generate a netcdf file containing grid content """ if gridfile is None: - if self.gridfile is None: - logger.error("must provide gridfile name") + if self.gridfile is None: + logger.error("Please provide gridfile name") gridfile = self.gridfile - logger.info("writing out gridfile %s", {gridfile}) + logger.info("Writing out gridfile %s", {gridfile}) if self.gridtype == "none": attrs["tile"] = attrs["tile_options"]["none"] @@ -294,8 +309,9 @@ def write(self, gridfile: str = None): def _set_dims(self): - if self.x_obj.data is not None: - print("hereherehere") + if self.x_obj.data is None: + logger.warning("Cannot set dimensions if x and y coordinates are not set") + else: self.nyp, self.nxp = self.x_obj.data.shape self.ny = self.nyp - 1 self.nx = self.nxp - 1 @@ -468,10 +484,10 @@ def __repr__(self): summary = "%s\n" % (self.__class__.__name__) summary += "gridfile = %s\n" % (self.gridfile) summary += "gridtype = %s\n" % (self.gridtype) - summary += "nx = %s" % (self.nx) - summary += "ny = %s" % (self.ny) - summary += "nxp = %s" %(self.nxp) - summary += "nyp = %s" %(self.nyp) + summary += "nx = %s\n" % (self.nx) + summary += "ny = %s\n" % (self.ny) + summary += "nxp = %s\n" %(self.nxp) + summary += "nyp = %s\n" %(self.nyp) for obj in self.objlist: summary += "%s = %s\n" % (obj.name, obj.data) diff --git a/pyFMS b/pyFMS index 9369625..9e74ea8 160000 --- a/pyFMS +++ b/pyFMS @@ -1 +1 @@ -Subproject commit 936962562b62df96b492565a1dbbfa78dc97ca10 +Subproject commit 9e74ea87b410abfab5456d8621223fa68ca1cc10 diff --git a/tests/shared/test_gridobj.py b/tests/shared/test_gridobj.py index 37aaf28..216c3a7 100644 --- a/tests/shared/test_gridobj.py +++ b/tests/shared/test_gridobj.py @@ -29,6 +29,7 @@ north_pole="0.0 90.0", projection="cube_gnomonic", discretization="logically_rectangular", + _FillValue=None ) ) ds.x = xr.DataArray( @@ -37,6 +38,7 @@ attrs=dict( units="degree_east", standard_name="geographic_longitude", + _FillValue=None ) ) ds.y = xr.DataArray( @@ -45,6 +47,7 @@ attrs=dict( units="degree_north", standard_name="geographic_latitude", + _FillValue=None ) ) ds.dx = xr.DataArray( @@ -53,6 +56,7 @@ attrs=dict( units="meters", standard_name="grid_edge_x_distance", + _FillValue=None ) ) ds.dy = xr.DataArray( @@ -61,6 +65,7 @@ attrs=dict( units="meters", standard_name="grid_edge_y_distance", + _FillValue=None ) ) ds.area = xr.DataArray( @@ -69,6 +74,7 @@ attrs=dict( units="m2", standard_name="grid_cell_area", + _FillValue=None ) ) ds.angle_dx = xr.DataArray( @@ -77,6 +83,7 @@ attrs=dict( units="degrees_east", standard_name="grid_vertex_x_angle_WRT_geographic_east", + _FillValue=None ) ) ds.angle_dy = xr.DataArray( @@ -85,24 +92,26 @@ attrs=dict( units="degrees_east", standard_name="grid_vertex_x_angle_WRT_geographic_east", + _FillValue=None ) ) ds.arcx = xr.DataArray( - [b'arcx'], + 'arcx', attrs=dict( standard_name="grid_edge_x_arc_type", north_pole="0.0 90.0", - _FillValue=False, + _FillValue=None, ) ) -#@pytest.fixture(autouse=True) + +@pytest.fixture(autouse=True) def set_fms_files(): inputnml = Path("input.nml") logfile = Path("logfile.000000.out") warnfile = Path("warnfile.000000.out") - + inputnml.touch() yield @@ -110,14 +119,14 @@ def set_fms_files(): if inputnml.exists(): inputnml.unlink() if logfile.exists(): logfile.unlink() if warnfile.exists(): warnfile.unlink() - + def test_read_write(set_fms_files): - - gridfile = Path("test_read_write.nc") pyfms.fms.init() - + + gridfile = Path("test_read_write.nc") + testgrid = GridObj(gridtype="cubic") testgrid.x = ds.x.data @@ -131,12 +140,11 @@ def test_read_write(set_fms_files): testgrid.tile = str(ds.tile.data) testgrid.write(gridfile) - exit() assert gridfile.exists() del testgrid - + testgrid = GridObj(gridfile=gridfile).read() #test dims @@ -144,7 +152,7 @@ def test_read_write(set_fms_files): assert testgrid.ny == ny assert testgrid.nxp == nxp assert testgrid.nyp == nyp - + #test values np.testing.assert_array_equal(testgrid.x, ds.x.data) np.testing.assert_array_equal(testgrid.y, ds.y.data) @@ -157,15 +165,13 @@ def test_read_write(set_fms_files): assert testgrid.tile == str(ds.tile.data) gridfile.unlink() - pyfms.fms.end() -test_read_write(None) def test_center_option(set_fms_files): pyfms.fms.init() - + gridfile = Path("test_center.nc") nx2 = nx // 2 ny2 = ny // 2 @@ -178,7 +184,7 @@ def test_center_option(set_fms_files): GridObj(gridfile=gridfile, x=x, y=y).write() - grid = GridObj(gridfile=gridfile).read_xy(center=True, radians=True) + grid = GridObj(gridfile=gridfile).read(center=True, radians=True, xy_only=True) assert grid.nx == nx2 assert grid.ny == ny2 @@ -186,41 +192,39 @@ def test_center_option(set_fms_files): assert grid.nyp == ny2 + 1 answer = np.radians(np.ones((ny2p,nx2p), dtype=np.float64)) - + np.testing.assert_array_equal(grid.x, answer) np.testing.assert_array_equal(grid.y, answer) gridfile.unlink() - pyfms.fms.end() - - + + def test_to_domain(set_fms_files): nx, ny = 8, 8 global_indices = [0, nx-1, 0, ny-1] Path("input.nml").touch() - pyfms.fms.init(ndomain=1) + + pyfms.fms.init() domain = pyfms.mpp_domains.define_domains(global_indices) x1 = np.arange(nx+1, dtype=np.float64) y1 = np.arange(ny+1, dtype=np.float64) x, y = np.meshgrid(x1, y1) - area = np.ones((ny, nx), dtype=np.float64) - - grid = GridObj(x=x, y=y, area=area) + + grid = GridObj(x=x, y=y, area=area, domain=domain) grid.to_domain(domain) - x1 = np.arange(domain.isc, domain.iec+2, dtype=np.float64) - y1 = np.arange(domain.jsc, domain.jec+2, dtype=np.float64) - xanswer, yanswer = np.meshgrid(x1, y1) + x1_answer = np.arange(domain.isc, domain.iec+2, dtype=np.float64) + y1_answer = np.arange(domain.jsc, domain.jec+2, dtype=np.float64) + xanswer, yanswer = np.meshgrid(x1_answer, y1_answer) area_answer = np.ones((domain.ysize_c, domain.xsize_c), dtype=np.float64) - + np.testing.assert_array_equal(grid.x, xanswer) np.testing.assert_array_equal(grid.y, yanswer) np.testing.assert_array_equal(grid.area, area_answer) - pyfms.fms.end() - + pyfms.fms.end() \ No newline at end of file From 70f27e2919c0465a9ab3d62a28ad0d4a1f83407b Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 21 Nov 2025 12:37:06 -0500 Subject: [PATCH 04/27] update pyFMS --- fmsgridtools/shared/gridobj.py | 25 +++++++++---------------- pyFMS | 2 +- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/fmsgridtools/shared/gridobj.py b/fmsgridtools/shared/gridobj.py index bbd0516..2ebe229 100644 --- a/fmsgridtools/shared/gridobj.py +++ b/fmsgridtools/shared/gridobj.py @@ -5,7 +5,6 @@ import logging from pathlib import Path -from types import SimpleNamespace import numpy as np import numpy.typing as npt @@ -123,8 +122,7 @@ def __init__(self, area: npt.NDArray = None, angle_dx: npt.NDArray = None, angle_dy: npt.NDArray = None, - arcx: npt.NDArray = None, - on_gpu: bool = False + arcx: npt.NDArray = None ): self.input_dir = Path(input_dir) @@ -147,17 +145,6 @@ def __init__(self, self.angle_dx_obj = Variable(name="angle_dx", data=angle_dx) self.angle_dy_obj = Variable(name="angle_dy", data=angle_dy) self.arcx_obj = Variable(name="arcx", data=arcx) - self.objlist = [ - self.x_obj, - self.y_obj, - self.dx_obj, - self.dy_obj, - self.area_obj, - self.angle_dx_obj, - self.angle_dy_obj, - self.arcx_obj, - self.tile_obj - ] self._set_dims() @@ -293,8 +280,11 @@ def write(self, gridfile: str = None): elif self.gridtype == "simple_cartesian": attrs["tile"] = attrs["tile_options"]["simple_cartesian"] + objlist = [self.x_obj, self.y_obj, self.dx_obj, self.dy_obj, self.area_obj, + self.angle_dx_obj, self.angle_dy_obj, self.arcx_obj, self.tile_obj] + ds = {} - for obj in self.objlist: + for obj in objlist: if obj.data is not None: name = obj.name ds[name] = xr.DataArray( @@ -489,7 +479,10 @@ def __repr__(self): summary += "nxp = %s\n" %(self.nxp) summary += "nyp = %s\n" %(self.nyp) - for obj in self.objlist: + objlist = [self.x_obj, self.y_obj, self.dx_obj, self.dy_obj, self.area_obj, + self.angle_dx_obj, self.angle_dy_obj, self.arcx_obj, self.tile_obj] + + for obj in objlist: summary += "%s = %s\n" % (obj.name, obj.data) return summary diff --git a/pyFMS b/pyFMS index 9e74ea8..4e828f0 160000 --- a/pyFMS +++ b/pyFMS @@ -1 +1 @@ -Subproject commit 9e74ea87b410abfab5456d8621223fa68ca1cc10 +Subproject commit 4e828f0f7f1253e7be166fd8efb11608da13a282 From 52640615066d3e5169c2fffb811a4b9a8c978416 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 21 Nov 2025 14:05:02 -0500 Subject: [PATCH 05/27] mosaicobj --- fmsgridtools/shared/gridobj.py | 250 +++++++++++++++---------------- fmsgridtools/shared/mosaicobj.py | 209 +++++++++++--------------- 2 files changed, 215 insertions(+), 244 deletions(-) diff --git a/fmsgridtools/shared/gridobj.py b/fmsgridtools/shared/gridobj.py index 2ebe229..5bdfb31 100644 --- a/fmsgridtools/shared/gridobj.py +++ b/fmsgridtools/shared/gridobj.py @@ -16,71 +16,57 @@ attrs = {} attrs["x"] = dict( - standard_name = "geographic_longitude", - units = "degree_east", - _FillValue = False + standard_name="geographic_longitude", units="degree_east", _FillValue=False ) attrs["y"] = dict( - standard_name = "geographic_latitude", - units = "degrees_north", - _FillValue = False + standard_name="geographic_latitude", units="degrees_north", _FillValue=False ) attrs["tile"] = {} attrs["tile_options"] = {} attrs["tile_options"]["cubic"] = dict( - standard_name = "grid_tile_spec", - geometry = "spherical", - north_pole = "0.0 90.0", - discretization = "logically_rectangular", - conformal = "false", - _FillValue = False + standard_name="grid_tile_spec", + geometry="spherical", + north_pole="0.0 90.0", + discretization="logically_rectangular", + conformal="false", + _FillValue=False, ) attrs["tile_options"]["simple_cartesian"] = dict( - standard_name = "grid_tile_spec", - geometry = "planar", - discretization = "logically_rectangular", - conformal = "true", - _FillValue = False + standard_name="grid_tile_spec", + geometry="planar", + discretization="logically_rectangular", + conformal="true", + _FillValue=False, ) attrs["tile_options"]["none"] = dict( - standard_name = "grid_tile_spec", - geometry = "spherical", - north_pole = "0.0 90.0", - projection = "none", - discretization = "logically_rectangular", - conformal = "true", - _FillValue = False + standard_name="grid_tile_spec", + geometry="spherical", + north_pole="0.0 90.0", + projection="none", + discretization="logically_rectangular", + conformal="true", + _FillValue=False, ) attrs["dx"] = dict( - standard_name = "grid_edge_x_distance", - units = "meters", - _FillValue = False + standard_name="grid_edge_x_distance", units="meters", _FillValue=False ) attrs["dy"] = dict( - standard_name = "grid_edge_y_distance", - units = "meters", - _FillValue = False -) -attrs["area"] = dict( - standard_name = "grid_cell_area", - units = "m2", - _FillValue = False + standard_name="grid_edge_y_distance", units="meters", _FillValue=False ) +attrs["area"] = dict(standard_name="grid_cell_area", units="m2", _FillValue=False) attrs["angle_dx"] = dict( - standard_name = "grid_vertex_x_angle_WRT_geographic_east", - units = "degrees_east", - _FillValue = False + standard_name="grid_vertex_x_angle_WRT_geographic_east", + units="degrees_east", + _FillValue=False, ) attrs["angle_dy"] = dict( - standard_name = "grid_vertex_y_angle_WRT_geographic_north", - units = "degrees_north", - _FillValue = False + standard_name="grid_vertex_y_angle_WRT_geographic_north", + units="degrees_north", + _FillValue=False, ) attrs["arcx"] = dict( - standard_name = "grid_edge_x_arc_type", - north_pole = "0.0,90.0", - _FillValue = False + standard_name="grid_edge_x_arc_type", north_pole="0.0,90.0", _FillValue=False ) dims = {} @@ -94,35 +80,37 @@ dims["tile"] = () dims["arcx"] = () + class Variable: - def __init__(self, name: str = None, data = None): + def __init__(self, name: str = None, data=None): self.name = name self.data = data -class GridObj: +class GridObj: """ Class for grid information """ - def __init__(self, - input_dir: str = "./", - gridfile: str = None, - domain: pyfms.Domain = None, - gridtype: str = None, - tile: str = None, - nx: int = None, - ny: int = None, - nxp: int = None, - nyp: int = None, - x: npt.NDArray = None, - y: npt.NDArray = None, - dx: npt.NDArray = None, - dy: npt.NDArray = None, - area: npt.NDArray = None, - angle_dx: npt.NDArray = None, - angle_dy: npt.NDArray = None, - arcx: npt.NDArray = None + def __init__( + self, + input_dir: str = "./", + gridfile: str = None, + domain: pyfms.Domain = None, + gridtype: str = None, + tile: str = None, + nx: int = None, + ny: int = None, + nxp: int = None, + nyp: int = None, + x: npt.NDArray = None, + y: npt.NDArray = None, + dx: npt.NDArray = None, + dy: npt.NDArray = None, + area: npt.NDArray = None, + angle_dx: npt.NDArray = None, + angle_dy: npt.NDArray = None, + arcx: npt.NDArray = None, ): self.input_dir = Path(input_dir) @@ -150,9 +138,7 @@ def __init__(self, logger.info("Created new GridObj named:\n %s", self.__repr__()) - def to_domain(self, domain: dict = None): - """ Stores data on the compute domain """ @@ -173,40 +159,45 @@ def to_domain(self, domain: dict = None): xsize_c, ysize_c = self.domain.xsize_c, self.domain.ysize_c objdict = { - self.x_obj: (ysize_c+1, xsize_c+1), - self.y_obj: (ysize_c+1, xsize_c+1), + self.x_obj: (ysize_c + 1, xsize_c + 1), + self.y_obj: (ysize_c + 1, xsize_c + 1), self.area_obj: (ysize_c, xsize_c), - self.dx_obj: (ysize_c+1, xsize_c), - self.dy_obj: (ysize_c, xsize_c+1), - self.angle_dx_obj: (ysize_c+1, xsize_c), - self.angle_dy_obj: (ysize_c, xsize_c+1) + self.dx_obj: (ysize_c + 1, xsize_c), + self.dy_obj: (ysize_c, xsize_c + 1), + self.angle_dx_obj: (ysize_c + 1, xsize_c), + self.angle_dy_obj: (ysize_c, xsize_c + 1), } for obj, (ysize, xsize) in objdict.items(): if obj.data is not None: logger.info("Saving %s on domain", {obj.name}) - obj.data = np.ascontiguousarray(obj.data[jsc:jsc+ysize, isc:isc+xsize]) + obj.data = np.ascontiguousarray( + obj.data[jsc : jsc + ysize, isc : isc + xsize] + ) self._set_dims() return self - def to_radians(self): - """ Converts data from degres to radians """ - objlist = [self.x_obj, self.y_obj, self.dx_obj, self.dy_obj, self.angle_dx_obj, self.angle_dy_obj] + objlist = [ + self.x_obj, + self.y_obj, + self.dx_obj, + self.dy_obj, + self.angle_dx_obj, + self.angle_dy_obj, + ] for obj in objlist: if obj.data is not None: logger.info("Converting %s to radians", {obj.name}) obj.data = np.radians(obj.data, dtype=np.float64) - def get_fms_area(self): - """ Compute grid cell areas """ @@ -221,23 +212,43 @@ def get_fms_area(self): self.area = pyfms.grid_utils.get_grid_area(lon=x, lat=y, convert_cf_order=False) return self.area - - def read(self, radians: bool = False, center: bool = False, on_domain: bool = False, xy_only: bool = False): - + def read( + self, + gridfile: str = None, + domain: dict = None, + radians: bool = False, + center: bool = False, + on_domain: bool = False, + xy_only: bool = False, + ): """ Reads in the gridfile and initializes the instance variables """ + if gridfile is None: + if self.gridfile is None: + logger.error("Please provide gridfile name") + else: + self.gridfile = gridfile + objlist = [self.x_obj, self.y_obj] if xy_only: - logger.info("reading only x and y coordinates from file %s", {self.gridfile}) + logger.info( + "reading only x and y coordinates from file %s", {self.gridfile} + ) else: - objlist += [self.area_obj, self.dx_obj, self.dy_obj, self.angle_dx_obj, - self.angle_dy_obj, self.arcx_obj, self.tile_obj + objlist += [ + self.area_obj, + self.dx_obj, + self.dy_obj, + self.angle_dx_obj, + self.angle_dy_obj, + self.arcx_obj, + self.tile_obj, ] logger.info("reading in file %s\n", self.gridfile) - with xr.open_dataset(self.input_dir/self.gridfile) as ds: + with xr.open_dataset(self.input_dir / self.gridfile) as ds: for obj in objlist: if obj.name in ds: obj.data = ds[obj.name].data @@ -251,7 +262,7 @@ def read(self, radians: bool = False, center: bool = False, on_domain: bool = Fa self.to_radians() if on_domain: - self.to_domain() + self.to_domain(domain) self._set_dims() @@ -259,9 +270,7 @@ def read(self, radians: bool = False, center: bool = False, on_domain: bool = Fa return self - def write(self, gridfile: str = None): - """ Generate a netcdf file containing grid content """ @@ -280,23 +289,29 @@ def write(self, gridfile: str = None): elif self.gridtype == "simple_cartesian": attrs["tile"] = attrs["tile_options"]["simple_cartesian"] - objlist = [self.x_obj, self.y_obj, self.dx_obj, self.dy_obj, self.area_obj, - self.angle_dx_obj, self.angle_dy_obj, self.arcx_obj, self.tile_obj] + objlist = [ + self.x_obj, + self.y_obj, + self.dx_obj, + self.dy_obj, + self.area_obj, + self.angle_dx_obj, + self.angle_dy_obj, + self.arcx_obj, + self.tile_obj, + ] ds = {} for obj in objlist: if obj.data is not None: name = obj.name ds[name] = xr.DataArray( - data = obj.data, - attrs = attrs[name], - dims=dims[name] + data=obj.data, attrs=attrs[name], dims=dims[name] ) logger.info(ds[name]) xr.Dataset(data_vars=ds).to_netcdf(gridfile) - def _set_dims(self): if self.x_obj.data is None: @@ -306,10 +321,8 @@ def _set_dims(self): self.ny = self.nyp - 1 self.nx = self.nxp - 1 - @property def x(self): - """ retrieve x """ @@ -318,7 +331,6 @@ def x(self): @x.setter def x(self, data): - """ set x """ @@ -327,7 +339,6 @@ def x(self, data): @property def y(self): - """ retrieve y """ @@ -336,7 +347,6 @@ def y(self): @y.setter def y(self, data): - """ set y """ @@ -345,8 +355,7 @@ def y(self, data): @property def tile(self): - - """" + """ " retrieve tile """ @@ -354,7 +363,6 @@ def tile(self): @tile.setter def tile(self, data): - """ set tile """ @@ -363,7 +371,6 @@ def tile(self, data): @property def dx(self): - """ retrieve dx """ @@ -372,7 +379,6 @@ def dx(self): @dx.setter def dx(self, data): - """ set dx """ @@ -381,7 +387,6 @@ def dx(self, data): @property def dy(self): - """ retrieve dy """ @@ -390,7 +395,6 @@ def dy(self): @dy.setter def dy(self, data): - """ set dy """ @@ -399,7 +403,6 @@ def dy(self, data): @property def area(self): - """ retrieve area """ @@ -408,7 +411,6 @@ def area(self): @area.setter def area(self, data): - """ set area """ @@ -417,7 +419,6 @@ def area(self, data): @property def angle_dx(self): - """ retrieve angle_dx """ @@ -426,7 +427,6 @@ def angle_dx(self): @angle_dx.setter def angle_dx(self, data): - """ set angle_dx """ @@ -435,7 +435,6 @@ def angle_dx(self, data): @property def angle_dy(self): - """ retrieve angle_dy """ @@ -444,7 +443,6 @@ def angle_dy(self): @angle_dy.setter def angle_dy(self, data): - """ set angle_dy """ @@ -453,7 +451,6 @@ def angle_dy(self, data): @property def arcx(self): - """ retrieve arcx """ @@ -462,33 +459,34 @@ def arcx(self): @arcx.setter def arcx(self, data): - """ set arcx """ self.arcx_obj.data = data - def __repr__(self): summary = "%s\n" % (self.__class__.__name__) summary += "gridfile = %s\n" % (self.gridfile) summary += "gridtype = %s\n" % (self.gridtype) summary += "nx = %s\n" % (self.nx) summary += "ny = %s\n" % (self.ny) - summary += "nxp = %s\n" %(self.nxp) - summary += "nyp = %s\n" %(self.nyp) - - objlist = [self.x_obj, self.y_obj, self.dx_obj, self.dy_obj, self.area_obj, - self.angle_dx_obj, self.angle_dy_obj, self.arcx_obj, self.tile_obj] + summary += "nxp = %s\n" % (self.nxp) + summary += "nyp = %s\n" % (self.nyp) + + objlist = [ + self.x_obj, + self.y_obj, + self.dx_obj, + self.dy_obj, + self.area_obj, + self.angle_dx_obj, + self.angle_dy_obj, + self.arcx_obj, + self.tile_obj, + ] for obj in objlist: summary += "%s = %s\n" % (obj.name, obj.data) return summary - - - - - - diff --git a/fmsgridtools/shared/mosaicobj.py b/fmsgridtools/shared/mosaicobj.py index c845201..bad69b3 100644 --- a/fmsgridtools/shared/mosaicobj.py +++ b/fmsgridtools/shared/mosaicobj.py @@ -2,8 +2,8 @@ MosaicObj class """ +import logging from pathlib import Path -from types import SimpleNamespace import numpy as np import xarray as xr @@ -11,77 +11,66 @@ import pyfms from fmsgridtools.shared.gridobj import GridObj +logger = logging.getLogger(__name__) + attrs = dict( - mosaic = dict( + mosaic=dict( standard_name="grid_mosaic_spec", contact_regions="contacts", children="gridtiles", - grid_descriptor="" - ), - gridlocation = dict( - standard_name="grid_file_location" + grid_descriptor="", + _FillValue=False, ), - gridfiles = dict(), - gridtiles = dict(), - contacts = dict( + gridlocation=dict(standard_name="grid_file_location", _FillValue=False), + gridfiles=dict(), + gridtiles=dict(), + contacts=dict( standard_name="grid_contact_spec", contact_type="boundary", alignment="true", contact_index="contact_index", - orientation="orient" + orientation="orient", + _FillValue=False, + ), + contact_index=dict( + standard_name="starting_ending_point_index_of_contact", _FillValue=False ), - contact_index = dict( - standard_name="starting_ending_point_index_of_contact" - ) ) dims = dict( - mosaic = (), - gridlocation = (), - gridfiles = ["ntiles"], - gridtiles = ["ntiles"], - contacts = ["ncontact"], - contact_index = ["ncontact"] + mosaic=(), + gridlocation=(), + gridfiles=["ntiles"], + gridtiles=["ntiles"], + contacts=["ncontact"], + contact_index=["ncontact"], ) -def set_attribute(variable: str, var_attr: dict): - - global attrs - - if variable in attrs: - attrs[variable] = var_attr - else: - raise RuntimeError(f"{variable} does not exist in attributes") - +class Variable: -def set_dims(variable: str, var_dim: list): - - global dims - - if variable in dims: - dims[variable] = var_dim - else: - raise RuntimeError(f"{variable} does not exist in dims") + def __init__(self, name: str = None, data=None): + self.name = name + self.data = data class MosaicObj: - """ MosaicObj """ - def __init__(self, - input_dir: str = "./", - mosaicfile: str = None, - mosaic: str = None, - ntiles: int = None, - gridlocation: str = "./", - gridfiles: list[str] = None, - gridtiles: list[str] = None, - contacts: list[str] = None, - contact_index: list[str] = None, - ): + def __init__( + self, + input_dir: str = "./", + mosaicfile: str = None, + mosaic: str = None, + ntiles: int = None, + gridlocation: str = "./", + gridfiles: list[str] = None, + gridtiles: list[str] = None, + contacts: list[str] = None, + contact_index: list[str] = None, + ): self.input_dir = Path(input_dir) self.mosaicfile = mosaicfile @@ -89,52 +78,33 @@ def __init__(self, self.ntiles = ntiles self.ncontacts = None - self.mosaic_obj = SimpleNamespace( - name = "mosaic", - data = mosaic - ) - self.gridlocation_obj = SimpleNamespace( - name = "gridlocation", - data = gridlocation - ) - self.gridfiles_obj = SimpleNamespace( - name = "gridfiles", - data = gridfiles - ) - self.gridtiles_obj = SimpleNamespace( - name = "gridtiles", - data = gridtiles - ) - self.contacts_obj = SimpleNamespace( - name = "contacts", - data = contacts - ) - self.contact_index_obj = SimpleNamespace( - name = "contact_index", - data = contact_index - ) + self.mosaic_obj = Variable(name="mosaic", data=mosaic) + self.gridlocation_obj = Variable(name="gridlocation", data=gridlocation) + self.gridfiles_obj = Variable(name="gridfiles", data=gridfiles) + self.gridtiles_obj = Variable(name="gridtiles", data=gridtiles) + self.contacts_obj = Variable(name="contacts", data=contacts) + self.contact_index_obj = Variable(name="contact_index", data=contact_index) self.objlist = [ self.mosaic_obj, self.gridlocation_obj, self.gridfiles_obj, self.gridtiles_obj, self.contacts_obj, - self.contact_index_obj + self.contact_index_obj, ] - - def read(self, mosaicfile: str|Path = None, input_dir: str|Path = "."): - + def read(self, mosaicfile: str | Path = None, input_dir: str | Path = "."): """ - Read the mosac file + Read the mosaic file """ if mosaicfile is None: if self.mosaicfile is None: - raise IOError("Please specify the mosaic file") - mosaicfile = self.mosaicfile + logger.error("Please specify mosaic file") + else: + self.mosaicfile = mosaicfile - with xr.open_dataset(Path(self.input_dir)/mosaicfile) as ds: + with xr.open_dataset(Path(self.input_dir) / self.mosaicfile) as ds: for obj in self.objlist: variable = ds.get(obj.name) @@ -148,13 +118,14 @@ def read(self, mosaicfile: str|Path = None, input_dir: str|Path = "."): self.ntiles = ds.sizes.get("ntiles") self.ncontacts = ds.sizes.get("ncontact") self.input_dir = input_dir - self.mosaicfile = mosaicfile - return self + logger.info( + "Finished reading file %s\n %s\n", self.mosaicfile, self.__repr__() + ) + return self def from_dict(self, mosaic_dict: dict): - """ Generate mosaic file from dictionary """ @@ -162,38 +133,38 @@ def from_dict(self, mosaic_dict: dict): names = [obj.name for obj in self.objlist] for key in mosaic_dict: if key not in names: - raise RuntimeError(f"{key} not a field in MosaicObj") + logger.warning(f"{key} not a field in MosaicObj") for key in mosaic_dict: for obj in self.objlist: if obj.name == key: obj.data = mosaic_dict[key] - for obj in self.objlist: - if obj.data is None: - printf(f"{obj.name} not set") - self.ntiles = None if self.gridfiles is None else len(self.gridfiles.data) self.ncontacts = None if self.contacts is None else len(self.contacts.data) - return self - - - def get_grid(self, input_dir: str|Path = "./", - radians: bool = False, - center: bool = False, - domain: pyfms.Domain = None) -> dict: + logger.info("Finished setting mosaicobj from dict: %s\n", self.__repr__()) + return self + def get_grid( + self, + input_dir: str | Path = "./", + radians: bool = False, + center: bool = False, + domain: pyfms.Domain = None, + ) -> dict: """ Get grids from gridfiles """ - if self.gridfiles is None: - raise RuntimeError("need to set gridfiles") + logger.info("Reding in grid") + + if self.gridfiles_obj is None: + raise RuntimeError("Cannot find gridfiles to read") if self.gridtiles is None: - raise RuntimeError("need to set gridtiles") + raise RuntimeError("Cannot find gridtile information") if self.ntiles is None: ntiles = len(self.gridfiles) @@ -201,14 +172,20 @@ def get_grid(self, input_dir: str|Path = "./", grid = {} for gridfile, gridtile in zip(self.gridfiles, self.gridtiles): - readfile = Path(input_dir)/gridfile - grid[gridtile] = GridObj(gridfile=readfile).read_xy(radians=radians, center=center, domain=domain) + readfile = Path(input_dir) / gridfile + grid[gridtile] = GridObj(gridfile=readfile).read( + radians=radians, + center=center, + domain=domain, + on_domain=False if domain is None else True, + xy_only=True, + ) - return grid + logger.info("Finished reading in grid %s\n", grid) + return grid def write(self, mosaicfile: str = None) -> None: - """ write mosaic file """ @@ -219,23 +196,21 @@ def write(self, mosaicfile: str = None) -> None: else: mosaicfile = self.mosaicfile + logger.info("Writing out mosaicfile %s\n", mosaicfile) + ds = {} for obj in self.objlist: if obj.data is not None: name = obj.name ds[name] = xr.DataArray( - data=obj.data, - attrs=attrs[name], - dims=dims[name] + data=obj.data, attrs=attrs[name], dims=dims[name] ) xr.Dataset(data_vars=ds).to_netcdf(mosaicfile) - @property def mosaic(self): - """ retrieve mosaic """ @@ -244,7 +219,6 @@ def mosaic(self): @mosaic.setter def mosaic(self, data): - """ set mosaic data """ @@ -253,7 +227,6 @@ def mosaic(self, data): @property def gridlocation(self): - """ retrieve gridlocation """ @@ -262,7 +235,6 @@ def gridlocation(self): @gridlocation.setter def gridlocation(self, data): - """ set gridlocation data """ @@ -271,7 +243,6 @@ def gridlocation(self, data): @property def gridtiles(self): - """ retrieve gridtiles """ @@ -280,7 +251,6 @@ def gridtiles(self): @gridtiles.setter def gridtiles(self, data): - """ set gridtiles data """ @@ -289,7 +259,6 @@ def gridtiles(self, data): @property def gridfiles(self): - """ retrieve gridfiles """ @@ -298,7 +267,6 @@ def gridfiles(self): @gridfiles.setter def gridfiles(self, data): - """ set gridfiles data """ @@ -307,7 +275,6 @@ def gridfiles(self, data): @property def contacts(self): - """ retrieve contacts """ @@ -316,7 +283,6 @@ def contacts(self): @contacts.setter def contacts(self, data): - """ set contacts data """ @@ -325,7 +291,6 @@ def contacts(self, data): @property def contact_index(self): - """ retrieve contact_index """ @@ -334,10 +299,18 @@ def contact_index(self): @contact_index.setter def contact_index(self, data): - """ set contact_index data """ self.contact_index_obj.data = data + def __repr__(self): + summary = "%s\n" % (self.__class__.__name__) + summary += "ntiles = %s\n" % (self.ntiles) + summary += "ncontacts = %s\n" % (self.ncontacts) + + for obj in self.objlist: + summary += "%s = %s\n" % (obj.name, obj.data) + + return summary From 2bcc5b5eb68f18a21905b6fcec8c8d76ffb3cbd5 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 21 Nov 2025 14:09:23 -0500 Subject: [PATCH 06/27] undo setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d0a3d1b..00074dc 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def local_pkg(name: str, relative_path: str) -> str: "numpy", "xarray", "netCDF4", -# local_pkg("pyFMS", "pyFMS"), + local_pkg("pyFMS", "pyFMS"), local_pkg("pyfrenctools", "FREnctools_lib") ] From 0d8efe9d7ae45cbb4300793ef54520207c897b81 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 21 Nov 2025 14:19:26 -0500 Subject: [PATCH 07/27] fix tests --- fmsgridtools/shared/mosaicobj.py | 16 +++- tests/mosaic/test_mosaic.py | 49 +++------- tests/shared/test_mosaicobj.py | 148 +++++++++++++++++++++++++++++++ 3 files changed, 170 insertions(+), 43 deletions(-) create mode 100644 tests/shared/test_mosaicobj.py diff --git a/fmsgridtools/shared/mosaicobj.py b/fmsgridtools/shared/mosaicobj.py index bad69b3..f6fd772 100644 --- a/fmsgridtools/shared/mosaicobj.py +++ b/fmsgridtools/shared/mosaicobj.py @@ -98,13 +98,19 @@ def read(self, mosaicfile: str | Path = None, input_dir: str | Path = "."): Read the mosaic file """ + logger.info("Reading mosaicfile") + if mosaicfile is None: if self.mosaicfile is None: logger.error("Please specify mosaic file") else: self.mosaicfile = mosaicfile - with xr.open_dataset(Path(self.input_dir) / self.mosaicfile) as ds: + if str(input_dir) != (self.input_dir): + logger.warning("Resetting input_dir to %s", input_dir) + self.input_dir = Path(input_dir) + + with xr.open_dataset(self.input_dir / self.mosaicfile) as ds: for obj in self.objlist: variable = ds.get(obj.name) @@ -130,10 +136,12 @@ def from_dict(self, mosaic_dict: dict): Generate mosaic file from dictionary """ + logger.info("Setting MosaicObj from dict") + names = [obj.name for obj in self.objlist] for key in mosaic_dict: if key not in names: - logger.warning(f"{key} not a field in MosaicObj") + logger.error(f"{key} not a field in MosaicObj") for key in mosaic_dict: for obj in self.objlist: @@ -158,9 +166,9 @@ def get_grid( Get grids from gridfiles """ - logger.info("Reding in grid") + logger.info("Reading in grid") - if self.gridfiles_obj is None: + if self.gridfiles is None: raise RuntimeError("Cannot find gridfiles to read") if self.gridtiles is None: diff --git a/tests/mosaic/test_mosaic.py b/tests/mosaic/test_mosaic.py index fee9828..c688bec 100644 --- a/tests/mosaic/test_mosaic.py +++ b/tests/mosaic/test_mosaic.py @@ -53,46 +53,18 @@ def test_create_regional_input(): "grid_yt_sub01": yt_data}).to_netcdf( f"regional_input_file.tile{tile_number}.nc") -def test_write(): - mosaic = fmsgridtools.MosaicObj(ntiles=6, - name=f"output[:-3]", - gridlocation='./', - gridfiles=np.asarray(gridfiles), - gridtiles=np.asarray(gridtiles), - contacts=np.full(6, "", dtype=str), - contact_index=np.full(6, "", dtype=str)) - mosaic.write(output) - assert os.path.exists(output) - -def test_ntiles(): - mosaic = fmsgridtools.MosaicObj(mosaic_file=output).read() - assert mosaic.ntiles == 6 - - -def test_gridfiles(): - mosaic2 = fmsgridtools.MosaicObj(mosaic_file=output).read() - assert all([mosaic2.gridfiles[i] == gridfiles[i] for i in range(mosaic2.ntiles)]) - os.remove(output) - - -def test_getgrid(): - - for ifile in gridfiles: make_grid(ifile) - mosaic = fmsgridtools.MosaicObj(ntiles=ntiles, gridtiles=gridtiles, gridfiles=gridfiles) - mosaic.get_grid(toradians=True, agrid=True, free_dataset=True) - def test_solo_mosaic(): - x1, y1 = np.meshgrid(np.arange(0,46,1, dtype=np.float64), np.arange(0,11,1, dtype=np.float64)) - xr.Dataset(data_vars=dict(x=xr.DataArray(x1, dims=["nyp","nxp"]), - y=xr.DataArray(y1, dims=["nyp", "nxp"])) - ).to_netcdf('grid.tile1.nc') - - x2, y2 = np.meshgrid(np.arange(45,90,1, dtype=np.float64), np.arange(0,11,1, dtype=np.float64)) - xr.Dataset(data_vars=dict(x=xr.DataArray(x2, dims=["nyp","nxp"]), - y=xr.DataArray(y2, dims=["nyp", "nxp"])) - ).to_netcdf('grid.tile2.nc') + x = np.arange(0, 46, dtype=np.float64) + y = np.arange(0, 11, dtype=np.float64) + x1, y1 = np.meshgrid(x, y) + fmsgridtools.GridObj(x=x1, y=y1).write("grid.tile1.nc") + + x = np.arange(45, 90, 1, dtype=np.float64) + y = np.arange(0, 11, 1, dtype=np.float64) + x2, y2 = np.meshgrid(x, y) + fmsgridtools.GridObj(x=x2, y=y2).write("grid.tile2.nc") runner = CliRunner() result = runner.invoke(fmsgridtools.make_mosaic.solo, ['--num_tiles', '2', @@ -100,8 +72,7 @@ def test_solo_mosaic(): '--tile_file', 'grid.tile2.nc']) assert result.exit_code == 0 - print(result.stdout) - assert 'NOTE: There are 1 contacts' in result.stdout + assert 'NOTE: There are 1 contacts' in result.stdout, print(result.stdout) os.remove('mosaic.nc') os.remove('grid.tile1.nc') os.remove('grid.tile2.nc') diff --git a/tests/shared/test_mosaicobj.py b/tests/shared/test_mosaicobj.py new file mode 100644 index 0000000..e61626d --- /dev/null +++ b/tests/shared/test_mosaicobj.py @@ -0,0 +1,148 @@ +import numpy as np +from pathlib import Path +import pytest +import xarray as xr + +import pyfms +import fmsgridtools + +ntiles = 6 +ncontacts = 12 + +ds = {} +ds["mosaic"] = xr.DataArray( + data="test_mosaic", + attrs=dict( + standard_name="grid_mosaic_spec", + children="gridtiles", + contact_regions="contacts", + grid_descriptor="", + ), +) +ds["gridlocation"] = xr.DataArray( + data="./", attrs=dict(standard_name="grid_file_location") +) +ds["gridfiles"] = xr.DataArray( + data=[f"C96.tile{i}.nc" for i in range(1, ntiles + 1)], dims=["ntiles"] +) +ds["gridtiles"] = xr.DataArray( + data=[f"tile{i}" for i in range(1, ntiles + 1)], dims=["ntiles"] +) +ds["contacts"] = xr.DataArray( + data=[ + "C384_mosaic:tile1::C384_mosaic:tile2", + "C384_mosaic:tile1::C384_mosaic:tile3", + "C384_mosaic:tile1::C384_mosaic:tile5", + "C384_mosaic:tile1::C384_mosaic:tile6", + "C384_mosaic:tile2::C384_mosaic:tile3", + "C384_mosaic:tile2::C384_mosaic:tile4", + "C384_mosaic:tile2::C384_mosaic:tile6", + "C384_mosaic:tile3::C384_mosaic:tile4", + "C384_mosaic:tile3::C384_mosaic:tile5", + "C384_mosaic:tile4::C384_mosaic:tile5", + "C384_mosaic:tile4::C384_mosaic:tile6", + "C384_mosaic:tile5::C384_mosaic:tile6", + ], + attrs=dict( + standard_name="grid_contact_spec", + contact_type="boundary", + alignment="true", + contact_index="contact_index", + orientation="orient", + ), + dims=["ncontacts"], +) +ds["contact_index"] = xr.DataArray( + data=[ + "768:768,1:768::1:1,1:768", + "1:768,768:768::1:1,768:1", + "1:1,1:768::768:1,768:768", + "1:768,1:1::1:768,768:768", + "1:768,768:768::1:768,1:1", + "768:768,1:768::768:1,1:1", + "1:768,1:1::768:768,768:1", + "768:768,1:768::1:1,1:768", + "1:768,768:768::1:1,768:1", + "1:768,768:768::1:768,1:1", + "768:768,1:768::768:1,1:1", + "768:768,1:768::1:1,1:768", + ], + dims=["ncontacts"], + attrs=dict(standard_name="starting_ending_point_index_of_contact"), +) +example_ds = xr.Dataset(data_vars=ds) + + +@pytest.fixture(autouse=True) +def set_fms_files(): + + inputnml = Path("input.nml") + logfile = Path("logfile.000000.out") + warnfile = Path("warnfile.000000.out") + + inputnml.touch() + + yield + + if inputnml.exists(): + inputnml.unlink() + if logfile.exists(): + logfile.unlink() + if warnfile.exists(): + warnfile.unlink() + + +def test_read_and_write(set_fms_files): + + mosaicfile = "test_mosaic.nc" + + mosaic = fmsgridtools.MosaicObj(mosaicfile=mosaicfile) + mosaic.mosaic = example_ds.mosaic.data + mosaic.gridlocation = example_ds.gridlocation.data + mosaic.gridfiles = example_ds.gridfiles.data + mosaic.gridtiles = example_ds.gridtiles.data + mosaic.contacts = example_ds.contacts.data + mosaic.contact_index = example_ds.contact_index.data + + mosaic.write() + + del mosaic + + mosaic = fmsgridtools.MosaicObj(mosaicfile=mosaicfile).read() + + assert mosaic.mosaic == example_ds.mosaic.data + assert mosaic.gridlocation == example_ds.gridlocation.data + assert mosaic.gridfiles == list(example_ds.gridfiles.data) + assert mosaic.gridtiles == list(example_ds.gridtiles.data) + assert mosaic.contacts == list(example_ds.contacts.data) + assert mosaic.contact_index == list(example_ds.contact_index.data) + + +def test_get_grid(set_fms_files): + + pyfms.fms.init() + + mosaicfile = "test_get_grid.nc" + example_ds.to_netcdf(mosaicfile) + + mosaic = fmsgridtools.MosaicObj(mosaicfile=mosaicfile).read() + + # write grid + xy = np.arange(0, 96, dtype=np.float64) + x, y = np.meshgrid(xy, xy) + for gridfile in mosaic.gridfiles: + fmsgridtools.GridObj(x=x, y=y).write(gridfile) + + # test get grid + grids = mosaic.get_grid() + for tile in mosaic.gridtiles: + np.testing.assert_array_equal(grids[tile].x, x) + np.testing.assert_array_equal(grids[tile].y, y) + + # test get grid in radians + grids_radians = mosaic.get_grid(radians=True) + for tile in mosaic.gridtiles: + np.testing.assert_array_equal(grids_radians[tile].x, np.radians(x)) + np.testing.assert_array_equal(grids_radians[tile].y, np.radians(y)) + + pyfms.fms.end() From ac48ec9b3362b4553a4967581c4890b2b2511c6a Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 21 Nov 2025 14:27:41 -0500 Subject: [PATCH 08/27] turn off tests that need to be updated --- .github/workflows/test.yaml | 4 +-- tests/shared/test_mosaicobj.py | 57 +++++++++++++++++++--------------- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 5f432fd..380a973 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -31,5 +31,5 @@ jobs: pytest tests/shared/test_gridobj.py::test_read_write pytest tests/shared/test_gridobj.py::test_center_option pytest tests/shared/test_gridobj.py::test_to_domain - pytest tests/shared/test_xgridobj.py - pytest tests/hgrid/test_hgrid.py + #pytest tests/shared/test_xgridobj.py + #pytest tests/hgrid/test_hgrid.py diff --git a/tests/shared/test_mosaicobj.py b/tests/shared/test_mosaicobj.py index e61626d..057f749 100644 --- a/tests/shared/test_mosaicobj.py +++ b/tests/shared/test_mosaicobj.py @@ -9,8 +9,8 @@ ntiles = 6 ncontacts = 12 -ds = {} -ds["mosaic"] = xr.DataArray( +answers = {} +answers["mosaic"] = xr.DataArray( data="test_mosaic", attrs=dict( standard_name="grid_mosaic_spec", @@ -19,16 +19,16 @@ grid_descriptor="", ), ) -ds["gridlocation"] = xr.DataArray( +answers["gridlocation"] = xr.DataArray( data="./", attrs=dict(standard_name="grid_file_location") ) -ds["gridfiles"] = xr.DataArray( +answers["gridfiles"] = xr.DataArray( data=[f"C96.tile{i}.nc" for i in range(1, ntiles + 1)], dims=["ntiles"] ) -ds["gridtiles"] = xr.DataArray( +answers["gridtiles"] = xr.DataArray( data=[f"tile{i}" for i in range(1, ntiles + 1)], dims=["ntiles"] ) -ds["contacts"] = xr.DataArray( +answers["contacts"] = xr.DataArray( data=[ "C384_mosaic:tile1::C384_mosaic:tile2", "C384_mosaic:tile1::C384_mosaic:tile3", @@ -52,7 +52,7 @@ ), dims=["ncontacts"], ) -ds["contact_index"] = xr.DataArray( +answers["contact_index"] = xr.DataArray( data=[ "768:768,1:768::1:1,1:768", "1:768,768:768::1:1,768:1", @@ -70,7 +70,7 @@ dims=["ncontacts"], attrs=dict(standard_name="starting_ending_point_index_of_contact"), ) -example_ds = xr.Dataset(data_vars=ds) +answers_ds = xr.Dataset(data_vars=answers) @pytest.fixture(autouse=True) @@ -96,26 +96,30 @@ def test_read_and_write(set_fms_files): mosaicfile = "test_mosaic.nc" + # write mosaic file mosaic = fmsgridtools.MosaicObj(mosaicfile=mosaicfile) - mosaic.mosaic = example_ds.mosaic.data - mosaic.gridlocation = example_ds.gridlocation.data - mosaic.gridfiles = example_ds.gridfiles.data - mosaic.gridtiles = example_ds.gridtiles.data - mosaic.contacts = example_ds.contacts.data - mosaic.contact_index = example_ds.contact_index.data + mosaic.mosaic = answers_ds.mosaic.data + mosaic.gridlocation = answers_ds.gridlocation.data + mosaic.gridfiles = answers_ds.gridfiles.data + mosaic.gridtiles = answers_ds.gridtiles.data + mosaic.contacts = answers_ds.contacts.data + mosaic.contact_index = answers_ds.contact_index.data mosaic.write() + # delete object in order to read del mosaic + # read mosaic file mosaic = fmsgridtools.MosaicObj(mosaicfile=mosaicfile).read() - assert mosaic.mosaic == example_ds.mosaic.data - assert mosaic.gridlocation == example_ds.gridlocation.data - assert mosaic.gridfiles == list(example_ds.gridfiles.data) - assert mosaic.gridtiles == list(example_ds.gridtiles.data) - assert mosaic.contacts == list(example_ds.contacts.data) - assert mosaic.contact_index == list(example_ds.contact_index.data) + # check answers + assert mosaic.mosaic == answers_ds.mosaic.data + assert mosaic.gridlocation == answers_ds.gridlocation.data + assert mosaic.gridfiles == list(answers_ds.gridfiles.data) + assert mosaic.gridtiles == list(answers_ds.gridtiles.data) + assert mosaic.contacts == list(answers_ds.contacts.data) + assert mosaic.contact_index == list(answers_ds.contact_index.data) def test_get_grid(set_fms_files): @@ -123,23 +127,26 @@ def test_get_grid(set_fms_files): pyfms.fms.init() mosaicfile = "test_get_grid.nc" - example_ds.to_netcdf(mosaicfile) - mosaic = fmsgridtools.MosaicObj(mosaicfile=mosaicfile).read() + # write mosaic file for testing + answers_ds.to_netcdf(mosaicfile) - # write grid + # write grid for testing xy = np.arange(0, 96, dtype=np.float64) x, y = np.meshgrid(xy, xy) for gridfile in mosaic.gridfiles: fmsgridtools.GridObj(x=x, y=y).write(gridfile) - # test get grid + # read mosaic file, get grid + mosaic = fmsgridtools.MosaicObj(mosaicfile=mosaicfile).read() grids = mosaic.get_grid() + + # test grids have been read in correctly for tile in mosaic.gridtiles: np.testing.assert_array_equal(grids[tile].x, x) np.testing.assert_array_equal(grids[tile].y, y) - # test get grid in radians + # test get grids have been converted to radians correctly grids_radians = mosaic.get_grid(radians=True) for tile in mosaic.gridtiles: np.testing.assert_array_equal(grids_radians[tile].x, np.radians(x)) From ebe7eb212a77baa559e9b037af60c206cbed7820 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 21 Nov 2025 14:33:03 -0500 Subject: [PATCH 09/27] minor comments --- tests/mosaic/test_mosaic.py | 96 ++++++++++++---------- tests/shared/test_gridobj.py | 155 ++++++++++++++++------------------- 2 files changed, 124 insertions(+), 127 deletions(-) diff --git a/tests/mosaic/test_mosaic.py b/tests/mosaic/test_mosaic.py index c688bec..e09765f 100644 --- a/tests/mosaic/test_mosaic.py +++ b/tests/mosaic/test_mosaic.py @@ -9,51 +9,47 @@ grid_size = 48 tile_number = 1 -gridfiles = [f'grid.tile{x}.nc' for x in range(6)] -gridtiles = [f'tile{x}' for x in range(6)] +gridfiles = [f"grid.tile{x}.nc" for x in range(6)] +gridtiles = [f"tile{x}" for x in range(6)] ntiles = 6 -output = 'test_mosaic.nc' +output = "test_mosaic.nc" def make_grid(gridfile): - xy = np.arange(0, grid_size+1, dtype=np.float64) + xy = np.arange(0, grid_size + 1, dtype=np.float64) x, y = np.meshgrid(xy, xy) - area = xr.DataArray(np.ones((grid_size, grid_size), dtype=np.float64), dims=["ny", "nx"]) + area = xr.DataArray( + np.ones((grid_size, grid_size), dtype=np.float64), dims=["ny", "nx"] + ) x = xr.DataArray(x, dims=["nyp", "nxp"]) y = xr.DataArray(y, dims=["nyp", "nxp"]) - - xr.Dataset(data_vars = {"x": x, "y":y, "area":area}).to_netcdf(gridfile) - + xr.Dataset(data_vars={"x": x, "y": y, "area": area}).to_netcdf(gridfile) + + @pytest.mark.skip def test_create_regional_input(): - nx = 1 + random.randint(1,100) % grid_size - ny = 1 + random.randint(1,100) % grid_size + nx = 1 + random.randint(1, 100) % grid_size + ny = 1 + random.randint(1, 100) % grid_size - nx_start = 1 + random.randint(1,100) % (grid_size - nx + 1) - ny_start = 1 + random.randint(1,100) % (grid_size - nx + 1) + nx_start = 1 + random.randint(1, 100) % (grid_size - nx + 1) + ny_start = 1 + random.randint(1, 100) % (grid_size - nx + 1) - xt = [nx_start+i for i in range(1, nx+1)] - yt = [ny_start+i for i in range(1, ny+1)] + xt = [nx_start + i for i in range(1, nx + 1)] + yt = [ny_start + i for i in range(1, ny + 1)] - xt_data = xr.DataArray( - data = xt, - dims = ["grid_xt_sub01"]).astype(np.float64) + xt_data = xr.DataArray(data=xt, dims=["grid_xt_sub01"]).astype(np.float64) - yt_data = xr.DataArray( - data = yt, - dims = ["grid_yt_sub01"]).astype(np.float64) + yt_data = xr.DataArray(data=yt, dims=["grid_yt_sub01"]).astype(np.float64) xr.Dataset( - data_vars={ - "grid_xt_sub01": xt_data, - "grid_yt_sub01": yt_data}).to_netcdf( - f"regional_input_file.tile{tile_number}.nc") + data_vars={"grid_xt_sub01": xt_data, "grid_yt_sub01": yt_data} + ).to_netcdf(f"regional_input_file.tile{tile_number}.nc") + - def test_solo_mosaic(): x = np.arange(0, 46, dtype=np.float64) @@ -65,28 +61,42 @@ def test_solo_mosaic(): y = np.arange(0, 11, 1, dtype=np.float64) x2, y2 = np.meshgrid(x, y) fmsgridtools.GridObj(x=x2, y=y2).write("grid.tile2.nc") - + runner = CliRunner() - result = runner.invoke(fmsgridtools.make_mosaic.solo, ['--num_tiles', '2', - '--tile_file', 'grid.tile1.nc', - '--tile_file', 'grid.tile2.nc']) - + result = runner.invoke( + fmsgridtools.make_mosaic.solo, + [ + "--num_tiles", + "2", + "--tile_file", + "grid.tile1.nc", + "--tile_file", + "grid.tile2.nc", + ], + ) + assert result.exit_code == 0 - assert 'NOTE: There are 1 contacts' in result.stdout, print(result.stdout) - os.remove('mosaic.nc') - os.remove('grid.tile1.nc') - os.remove('grid.tile2.nc') + assert "NOTE: There are 1 contacts" in result.stdout, print(result.stdout) + os.remove("mosaic.nc") + os.remove("grid.tile1.nc") + os.remove("grid.tile2.nc") + @pytest.mark.skip def test_regional_mosaic(): runner = CliRunner() - result = runner.invoke(fmsgridtools.make_mosaic, ['regional', - '--global_mosaic', - 'C48_mosaic.nc', - '--regional_file', - 'regional_input_file.tile1.nc']) + result = runner.invoke( + fmsgridtools.make_mosaic, + [ + "regional", + "--global_mosaic", + "C48_mosaic.nc", + "--regional_file", + "regional_input_file.tile1.nc", + ], + ) assert result.exit_code == 0 - assert 'Congratulations: You have successfully run regional mosaic' in result.stdout - os.remove('regional_mosaic.nc') - os.remove(f'regional_grid.tile{tile_number}.nc') - os.remove(f'regional_input_file.tile{tile_number}.nc') + assert "Congratulations: You have successfully run regional mosaic" in result.stdout + os.remove("regional_mosaic.nc") + os.remove(f"regional_grid.tile{tile_number}.nc") + os.remove(f"regional_input_file.tile{tile_number}.nc") diff --git a/tests/shared/test_gridobj.py b/tests/shared/test_gridobj.py index 216c3a7..fc6628c 100644 --- a/tests/shared/test_gridobj.py +++ b/tests/shared/test_gridobj.py @@ -20,88 +20,72 @@ nxp = nx + 1 nyp = ny + 1 -ds = SimpleNamespace() -ds.tile = xr.DataArray( - data='tile1', +answers = SimpleNamespace() +answers.tile = xr.DataArray( + data="tile1", attrs=dict( standard_name="grid_tile_spec", geometry="spherical", north_pole="0.0 90.0", projection="cube_gnomonic", discretization="logically_rectangular", - _FillValue=None - ) + _FillValue=None, + ), ) -ds.x = xr.DataArray( - data=np.full(shape=(nyp,nxp), fill_value=0.5, dtype=np.float64), +answers.x = xr.DataArray( + data=np.full(shape=(nyp, nxp), fill_value=0.5, dtype=np.float64), dims=["nyp", "nxp"], attrs=dict( - units="degree_east", - standard_name="geographic_longitude", - _FillValue=None - ) + units="degree_east", standard_name="geographic_longitude", _FillValue=None + ), ) -ds.y = xr.DataArray( - data=np.full(shape=(nyp,nxp), fill_value=1.0, dtype=np.float64), +answers.y = xr.DataArray( + data=np.full(shape=(nyp, nxp), fill_value=1.0, dtype=np.float64), dims=["nyp", "nxp"], attrs=dict( - units="degree_north", - standard_name="geographic_latitude", - _FillValue=None - ) + units="degree_north", standard_name="geographic_latitude", _FillValue=None + ), ) -ds.dx = xr.DataArray( - data=np.full(shape=(nyp,nx), fill_value=1.5, dtype=np.float64), +answers.dx = xr.DataArray( + data=np.full(shape=(nyp, nx), fill_value=1.5, dtype=np.float64), dims=["nyp", "nx"], - attrs=dict( - units="meters", - standard_name="grid_edge_x_distance", - _FillValue=None - ) + attrs=dict(units="meters", standard_name="grid_edge_x_distance", _FillValue=None), ) -ds.dy = xr.DataArray( - data=np.full(shape=(ny,nxp), fill_value=2.5, dtype=np.float64), +answers.dy = xr.DataArray( + data=np.full(shape=(ny, nxp), fill_value=2.5, dtype=np.float64), dims=["ny", "nxp"], - attrs=dict( - units="meters", - standard_name="grid_edge_y_distance", - _FillValue=None - ) + attrs=dict(units="meters", standard_name="grid_edge_y_distance", _FillValue=None), ) -ds.area = xr.DataArray( - data=np.full(shape=(ny,nx), fill_value=4.0, dtype=np.float64), +answers.area = xr.DataArray( + data=np.full(shape=(ny, nx), fill_value=4.0, dtype=np.float64), dims=["ny", "nx"], - attrs=dict( - units="m2", - standard_name="grid_cell_area", - _FillValue=None - ) + attrs=dict(units="m2", standard_name="grid_cell_area", _FillValue=None), ) -ds.angle_dx = xr.DataArray( - data=np.full(shape=(nyp,nxp), fill_value=3.0, dtype=np.float64), +answers.angle_dx = xr.DataArray( + data=np.full(shape=(nyp, nxp), fill_value=3.0, dtype=np.float64), dims=["nyp", "nxp"], attrs=dict( units="degrees_east", standard_name="grid_vertex_x_angle_WRT_geographic_east", - _FillValue=None - ) + _FillValue=None, + ), ) -ds.angle_dy = xr.DataArray( - data=np.full(shape=(nyp,nxp), fill_value=5.0, dtype=np.float64), +answers.angle_dy = xr.DataArray( + data=np.full(shape=(nyp, nxp), fill_value=5.0, dtype=np.float64), dims=["nyp", "nxp"], attrs=dict( units="degrees_east", standard_name="grid_vertex_x_angle_WRT_geographic_east", - _FillValue=None - ) + _FillValue=None, + ), ) -ds.arcx = xr.DataArray( - 'arcx', +answers.arcx = xr.DataArray( + "arcx", attrs=dict( standard_name="grid_edge_x_arc_type", north_pole="0.0 90.0", - _FillValue=None, - ) + _FillValue=None, + ), ) @@ -116,9 +100,12 @@ def set_fms_files(): yield - if inputnml.exists(): inputnml.unlink() - if logfile.exists(): logfile.unlink() - if warnfile.exists(): warnfile.unlink() + if inputnml.exists(): + inputnml.unlink() + if logfile.exists(): + logfile.unlink() + if warnfile.exists(): + warnfile.unlink() def test_read_write(set_fms_files): @@ -127,42 +114,42 @@ def test_read_write(set_fms_files): gridfile = Path("test_read_write.nc") + # write gridobj testgrid = GridObj(gridtype="cubic") - - testgrid.x = ds.x.data - testgrid.y = ds.y.data - testgrid.dx = ds.dx.data - testgrid.dy = ds.dy.data - testgrid.area = ds.area.data - testgrid.angle_dx = ds.angle_dx.data - testgrid.angle_dy = ds.angle_dy.data - testgrid.arcx = str(ds.arcx.data) - testgrid.tile = str(ds.tile.data) + testgrid.x = answers.x.data + testgrid.y = answers.y.data + testgrid.dx = answers.dx.data + testgrid.dy = answers.dy.data + testgrid.area = answers.area.data + testgrid.angle_dx = answers.angle_dx.data + testgrid.angle_dy = answers.angle_dy.data + testgrid.arcx = str(answers.arcx.data) + testgrid.tile = str(answers.tile.data) testgrid.write(gridfile) - assert gridfile.exists() del testgrid + # read gridobj testgrid = GridObj(gridfile=gridfile).read() - #test dims + # test dims assert testgrid.nx == nx assert testgrid.ny == ny assert testgrid.nxp == nxp assert testgrid.nyp == nyp - #test values - np.testing.assert_array_equal(testgrid.x, ds.x.data) - np.testing.assert_array_equal(testgrid.y, ds.y.data) - np.testing.assert_array_equal(testgrid.dx, ds.dx.data) - np.testing.assert_array_equal(testgrid.dy, ds.dy.data) - np.testing.assert_array_equal(testgrid.area, ds.area.data) - np.testing.assert_array_equal(testgrid.angle_dx, ds.angle_dx.data) - np.testing.assert_array_equal(testgrid.angle_dy, ds.angle_dy.data) - assert testgrid.arcx == str(ds.arcx.data) - assert testgrid.tile == str(ds.tile.data) + # test values + np.testing.assert_array_equal(testgrid.x, answers.x.data) + np.testing.assert_array_equal(testgrid.y, answers.y.data) + np.testing.assert_array_equal(testgrid.dx, answers.dx.data) + np.testing.assert_array_equal(testgrid.dy, answers.dy.data) + np.testing.assert_array_equal(testgrid.area, answers.area.data) + np.testing.assert_array_equal(testgrid.angle_dx, answers.angle_dx.data) + np.testing.assert_array_equal(testgrid.angle_dy, answers.angle_dy.data) + assert testgrid.arcx == str(answers.arcx.data) + assert testgrid.tile == str(answers.tile.data) gridfile.unlink() pyfms.fms.end() @@ -179,8 +166,8 @@ def test_center_option(set_fms_files): ny2p = ny2 + 1 # center points are value of 1 - x = np.array([[1,0]*nx2 + [1]]*nyp) - y = np.array([[1]*nxp, [0]*nxp]*ny2 + [[1]*nxp]) + x = np.array([[1, 0] * nx2 + [1]] * nyp) + y = np.array([[1] * nxp, [0] * nxp] * ny2 + [[1] * nxp]) GridObj(gridfile=gridfile, x=x, y=y).write() @@ -191,7 +178,7 @@ def test_center_option(set_fms_files): assert grid.nxp == nx2 + 1 assert grid.nyp == ny2 + 1 - answer = np.radians(np.ones((ny2p,nx2p), dtype=np.float64)) + answer = np.radians(np.ones((ny2p, nx2p), dtype=np.float64)) np.testing.assert_array_equal(grid.x, answer) np.testing.assert_array_equal(grid.y, answer) @@ -203,23 +190,23 @@ def test_center_option(set_fms_files): def test_to_domain(set_fms_files): nx, ny = 8, 8 - global_indices = [0, nx-1, 0, ny-1] + global_indices = [0, nx - 1, 0, ny - 1] Path("input.nml").touch() pyfms.fms.init() domain = pyfms.mpp_domains.define_domains(global_indices) - x1 = np.arange(nx+1, dtype=np.float64) - y1 = np.arange(ny+1, dtype=np.float64) + x1 = np.arange(nx + 1, dtype=np.float64) + y1 = np.arange(ny + 1, dtype=np.float64) x, y = np.meshgrid(x1, y1) area = np.ones((ny, nx), dtype=np.float64) grid = GridObj(x=x, y=y, area=area, domain=domain) grid.to_domain(domain) - x1_answer = np.arange(domain.isc, domain.iec+2, dtype=np.float64) - y1_answer = np.arange(domain.jsc, domain.jec+2, dtype=np.float64) + x1_answer = np.arange(domain.isc, domain.iec + 2, dtype=np.float64) + y1_answer = np.arange(domain.jsc, domain.jec + 2, dtype=np.float64) xanswer, yanswer = np.meshgrid(x1_answer, y1_answer) area_answer = np.ones((domain.ysize_c, domain.xsize_c), dtype=np.float64) @@ -227,4 +214,4 @@ def test_to_domain(set_fms_files): np.testing.assert_array_equal(grid.y, yanswer) np.testing.assert_array_equal(grid.area, area_answer) - pyfms.fms.end() \ No newline at end of file + pyfms.fms.end() From ddccf58864371bc9f7cf12df7e1ea00efe34ae76 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 21 Nov 2025 14:37:35 -0500 Subject: [PATCH 10/27] add file exists check --- tests/shared/test_mosaicobj.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/shared/test_mosaicobj.py b/tests/shared/test_mosaicobj.py index 057f749..c3993ac 100644 --- a/tests/shared/test_mosaicobj.py +++ b/tests/shared/test_mosaicobj.py @@ -106,6 +106,7 @@ def test_read_and_write(set_fms_files): mosaic.contact_index = answers_ds.contact_index.data mosaic.write() + assert mosaicfile.exists() # delete object in order to read del mosaic From fae7d43f340a71f54e6fee42de146d9a1f428f15 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 21 Nov 2025 14:49:53 -0500 Subject: [PATCH 11/27] add files --- fmsgridtools/shared/xgridobj.py | 433 +++++++++++++++++--------------- tests/shared/test_xgridobj.py | 191 +++++++------- 2 files changed, 336 insertions(+), 288 deletions(-) diff --git a/fmsgridtools/shared/xgridobj.py b/fmsgridtools/shared/xgridobj.py index ddbd5ad..b3e5205 100755 --- a/fmsgridtools/shared/xgridobj.py +++ b/fmsgridtools/shared/xgridobj.py @@ -1,232 +1,257 @@ -import ctypes import numpy as np -import numpy.typing as npt +from pathlib import Path +from types import SimpleNamespace import xarray as xr import pyfrenctools import pyfms from fmsgridtools.shared.gridobj import GridObj -from fmsgridtools.shared.gridtools_utils import check_file_is_there from fmsgridtools.shared.mosaicobj import MosaicObj class XGridObj() : def __init__(self, - input_dir: str = "./", - src_mosaic_file: str = None, - tgt_mosaic_file: str = None, - restart_remap_file: str = None, - write_remap_file: str = "remap.nc", - src_mosaic: type[MosaicObj] = None, - tgt_mosaic: type[MosaicObj] = None, - src_grid: dict[type[GridObj]] = None, - tgt_grid: dict[type[GridObj]] = None, - dataset: type[xr.Dataset] = None, - datadict: dict = None, - on_agrid: bool = True, + input_dir: str|Path = "./", + src_mosaicfile: str = None, + tgt_mosaicfile: str = None, + remapfile: str|Path = None, + src_mosaic: MosaicObj = None, + tgt_mosaic: MosaicObj = None, + src_gridfile: str|Path = None, + tgt_gridfile: str|Path = None, + src_grid: dict[str, GridObj] = None, + tgt_grid: dict[str, GridObj] = None, + tgt_tile: str = "tile1", + src_mask: dict[str, np.ndarray] = None, + tgt_mask: dict[str, np.ndarray] = None, order: int = 1, + domain: pyfms.Domain = None, on_gpu: bool = False): - self.input_dir = input_dir - self.src_mosaic_file = src_mosaic_file - self.tgt_mosaic_file = tgt_mosaic_file - self.restart_remap_file = restart_remap_file - self.write_remap_file = write_remap_file - self.src_mosaic = src_mosaic - self.tgt_mosaic = tgt_mosaic - self.src_grid = src_grid - self.tgt_grid = tgt_grid + + """Create an XGridObj container for building/reading remap/interp data. + + Args: + input_dir (str|Path): Base directory for input files. + src_mosaicfile (str): Source mosaic filename (optional). + tgt_mosaicfile (str): Target mosaic filename (optional). + remapfile (str|Path): Remap weights filename (optional). + src_mosaic (MosaicObj): Optional pre-built source MosaicObj. + tgt_mosaic (MosaicObj): Optional pre-built target MosaicObj. + src_gridfile (str|Path): Source grid filename (optional). + tgt_gridfile (str|Path): Target grid filename (optional). + src_grid (dict[str, GridObj]): Optional dict of source GridObj keyed by tile name. + tgt_grid (dict[str, GridObj]): Optional dict of target GridObj keyed by tile name. + tgt_tile (str): Name of the target tile to use (default: "tile1"). + src_mask (dict[str, np.ndarray]): Optional masks for source tiles. + tgt_mask (dict[str, np.ndarray]): Optional masks for the target grid. + order (int): Interpolation order (currently stored; semantics depend on downstream code). + on_gpu (bool): If True, use GPU-based building of xgrid via pyfrenctools. + + The constructed object stores container namespaces for source (`self.src`) and + target (`self.tgt`) data and an `interps` mapping that is populated by + `read` or `get_interp`. + """ + + self.input_dir: str|Path = Path(input_dir) + self.src = SimpleNamespace( + mosaicfile = src_mosaicfile, + gridfile = src_gridfile, + ntiles = None, + mosaic = src_mosaic, + grid = src_grid, + mask = src_mask, + domain = None + ) + self.tgt = SimpleNamespace( + tile = tgt_tile, + mosaicfile = tgt_mosaicfile, + gridfile = tgt_gridfile, + mosaic = tgt_mosaic, + grid = tgt_grid, + mask = tgt_mask, + domain = domain + ) + + self.remapfile: str|Path = remapfile self.order = order self.on_gpu = on_gpu - self.on_agrid = on_agrid - self.dataset = dataset - self.datadict = datadict - - self._srcinfoisthere = False - self._tgtinfoisthere = False - - if self._check_restart_remap_file(): return - if self.datadict is not None: return - if self.dataset is not None: return - - self._check_grid() - self._check_mosaic() - self._check_mosaic_file() - - if not self._srcinfoisthere or not self._tgtinfoisthere: - raise RuntimeError("Please provide grid information") - - - def read(self, infile: str = None): - - if infile is None: - if self.restart_remap_file is None: - raise RuntimeError("must provide the input remap file for reading") - infile = self.restart_remap_file - - self.dataset = xr.open_dataset(infile) - - for key in self.dataset.data_vars.keys(): - setattr(self, key, self.dataset[key].values) - - for key in self.dataset.sizes: - setattr(self, key, self.dataset.sizes[key]) - - - def write(self, outfile: str = None): - - if outfile is None: - outfile = self.write_remap_file - if self.dataset is None: - if self.datadict is not None: - self.to_dataset() - - for tgt_tile in self.dataset: - concat_dataset = xr.concat([self.dataset[tgt_tile][src_tile] for src_tile in self.dataset[tgt_tile]], dim="nxcells") - - concat_dataset.to_netcdf(outfile) - - - def to_dataset(self): + self.interps: pyfms.ConserveInterp|dict[str, pyfms.ConserveInterp] = None + + + def read(self, input_dir: Path|str = None, remapfile: Path|str = None, domain: pyfms.Domain = None): + + """ + read remap file and store as pyfms.ConserveInterp objects + """ + + if input_dir is None: + input_dir = self.input_dir + input_dir = Path(input_dir) + + if remapfile is None: + if self.remapfile is None: + print("specify remapfile") + remapfile = input_dir/self.remapfile + remapfile = input_dir/Path(remapfile) + + if remapfile.exists(): + itile = 1 + self.interps = {} + for src_tile in self.src.grid: + interp_id = pyfms.horiz_interp.read_weights_conserve( + weight_filename=str(remapfile), + weight_file_src="fregrid", + nlon_src=self.src.grid[src_tile].nx, + nlat_src=self.src.grid[src_tile].ny, + nlon_tgt=self.tgt.grid.nx, + nlat_tgt=self.tgt.grid.ny, + domain=domain, + src_tile = itile, + save_weights_as_fregrid=True + ) + self.interps[src_tile] = pyfms.ConserveInterp( + interp_id, + weights_as_fregrid=True) + itile += 1 - if self.datadict is None: raise OSError("datadict is None") - datadict = self.datadict - self.dataset = {} + def write(self, output_dir: Path|str = "./", outfile: str|Path = None): - for tgt_tile in datadict: - self.dataset[tgt_tile] = {} - for src_tile in datadict[tgt_tile]: + """ + write remap file + """ - thisdict = datadict[tgt_tile][src_tile] - dataset = self.dataset[tgt_tile][src_tile] = xr.Dataset() + is_root_pe = pyfms.mpp.pe() == pyfms.mpp.root_pe() - dataset["src_cell"] = xr.DataArray(np.column_stack((thisdict['src_i']+1, thisdict['src_j']+1)), - dims=["nxcells", "two"], - attrs={"src_cell": "parent cell indices in src mosaic", - "_FillValue": False} + if self.tgt.domain is None: + i_src = {itile: self.interps[itile].i_src for itile in self.interps} + j_src = {itile: self.interps[itile].j_src for itile in self.interps} + i_dst = {itile: self.interps[itile].i_dst for itile in self.interps} + j_dst = {itile: self.interps[itile].j_dst for itile in self.interps} + xgrid_area = {itile: self.interps[itile].xgrid_area for itile in self.interps} + else: + i_src, j_src, i_dst, j_dst, xgrid_area = {}, {}, {}, {}, {} + for src_tile in self.interps: + interp = self.interps[src_tile] + nxgrids = pyfms.mpp.gather(np.array([interp.nxgrid], dtype=np.int32)) + i_src[src_tile] = pyfms.mpp.gather(interp.i_src, ssize=interp.nxgrid, rsize=nxgrids) + j_src[src_tile] = pyfms.mpp.gather(interp.j_src, ssize=interp.nxgrid, rsize=nxgrids) + i_dst[src_tile] = pyfms.mpp.gather(interp.i_dst, ssize=interp.nxgrid, rsize=nxgrids) + j_dst[src_tile] = pyfms.mpp.gather(interp.j_dst, ssize=interp.nxgrid, rsize=nxgrids) + xgrid_area[src_tile] = pyfms.mpp.gather(interp.xgrid_area, ssize=interp.nxgrid, rsize=nxgrids) + + if pyfms.mpp.pe() == pyfms.mpp.root_pe(): + + if outfile is None: + print("writing remap file to remap.nc") + outfile = Path(output_dir)/"remap.nc" + else: + outfile = Path(output_dir)/outfile + + datasets, tile1 = [], 1 + for src_tile in self.interps: + nxgrid = i_src[src_tile].size + dataset = xr.Dataset() + dataset["tile1"] = xr.DataArray( + np.full(nxgrid, tile1), + dims=["ncells"], + attrs={"standard_name": "tile_number_in_mosaic1"} ) - dataset["tgt_cell"] = xr.DataArray(np.column_stack((thisdict['tgt_i']+1, thisdict['tgt_j']+1)), - dims=["nxcells", "two"], - attrs={"tgt_cell": "parent cell indices in tgt mosaic", - "_FillValue": False}) - dataset["xarea"] = xr.DataArray(thisdict['xarea'], - dims=["nxcells"], - attrs={"xarea": "exchange grid area", - "_FillValue": False} + tile1 += 1 + dataset["tile1_cell"] = xr.DataArray( + np.column_stack((i_src[src_tile]+1, j_src[src_tile]+1)), + dims=["ncells", "two"], + attrs={"standard_name": "parent_cell_indices_in_mosaic1"} ) - - - def create_xgrid(self, src_mask: dict[str,npt.NDArray] = None, tgt_mask: dict[str, npt.NDArray] = None) -> dict: - - if self.order not in (1,2): - raise RuntimeError("conservative order must be 1 or 2") - - if self.on_gpu: - create_xgrid_2dx2d_order1 = pyfrenctools.create_xgrid.get_2dx2d_order1_gpu - else: - create_xgrid_2dx2d_order1 = pyfrenctools.create_xgrid.get_2dx2d_order1 - - if self.datadict is None: self.datadict = {} - - for tgt_tile in self.tgt_grid: - - self.datadict[tgt_tile], itile = {}, 1 - - itgt_mask = None if tgt_mask is None else tgt_mask[tgt_tile] - - for src_tile in self.src_grid: - - isrc_mask = None if src_mask is None else src_mask[src_tile] - - xgrid_out = create_xgrid_2dx2d_order1( - src_nlon = self.src_grid[src_tile].nx, - src_nlat = self.src_grid[src_tile].ny, - tgt_nlon = self.tgt_grid[tgt_tile].nx, - tgt_nlat = self.tgt_grid[tgt_tile].ny, - src_lon=self.src_grid[src_tile].x, - src_lat=self.src_grid[src_tile].y, - tgt_lon=self.tgt_grid[tgt_tile].x, - tgt_lat=self.tgt_grid[tgt_tile].y, - src_mask=isrc_mask, - tgt_mask=itgt_mask + dataset["tile2_cell"] = xr.DataArray( + np.column_stack((i_dst[src_tile]+1, j_dst[src_tile]+1)), + dims=["ncells", "two"], + attrs={"standard_name": "parent_cell_indices_in_mosaic2"} + ) + dataset["xgrid_area"] = xr.DataArray( + xgrid_area[src_tile], + dims=["ncells"], + attrs={"standard_name": "exchange_grid_area", "units": "m2"} + ) + datasets.append(xr.Dataset(dataset)) + + dataset = xr.concat(datasets, dim="ncells") + encoding = {variable: {"_FillValue": None} for variable in dataset} + dataset.to_netcdf(outfile,encoding=encoding) + + pyfms.mpp.sync() + + + def get_interp(self) -> dict: + + self.interps = {} + + for src_tile in self.src.grid: + src_grid = self.src.grid[src_tile] + src_mask = None if self.src.mask is None else self.src.mask[src_tile] + if self.on_gpu: + xdict = pyfrenctools.create_xgrid.get_2dx2d_order1_gpu( + src_nlon=src_grid.nx, + src_nlat=src_grid.ny, + tgt_nlon=self.tgt.grid.nx, + tgt_nlat=self.tgt.grid.ny, + src_lon=src_grid.x, + src_lat=src_grid.y, + tgt_lon=self.tgt.grid.x, + tgt_lat=self.tgt.grid.y, + src_mask=src_mask, + tgt_mask=self.tgt.mask + ) + interp = pyfms.ConserveInterp() + interp.nxgrid = xdict["ncells"], + interp.i_src = xdict["src_i"], + interp.j_src = xdict["src_j"], + interp.i_tgt = xdict["tgt_i"], + interp.j_tgt = xdict["tgt_j"], + interp.xgrid_area = xdict["xarea"] + self.interps[src_tile] = interp + else: + interp_id = pyfms.horiz_interp.get_weights( + lon_in=src_grid.x, + lat_in=src_grid.y, + lon_out=self.tgt.grid.x, + lat_out=self.tgt.grid.y, + mask_in=src_mask, + mask_out=self.tgt.mask, + is_latlon_in=False, + is_latlon_out=False, + save_weights_as_fregrid=True, + convert_cf_order=False, + interp_method="conservative" + ) + self.interps[src_tile] = pyfms.ConserveInterp( + interp_id, + weights_as_fregrid=True ) - nxcells = xgrid_out["nxcells"] - if nxcells > 0: - xgrid_out["tile"] = np.full(nxcells, itile, dtype=np.int32) - self.datadict[tgt_tile][src_tile] = xgrid_out - itile = itile + 1 - - - def to_dataset_raw(self): - - if self.datadict is None: raise RunTimeError("datadict is None") - - for tgt_tile in self.datadict: - for src_tile in self.datadict[tgt_tile]: - src_i = xr.DataArray(self.datadict[tgt_tile][src_tile]["src_i"], dims=["nxcells"], - attrs={"src_i": "parent longitudinal (x) cell indices in src_mosaic"}) - src_j = xr.DataArray(self.datadict[tgt_tile][src_tile]["src_j"], dims=["nxcells"], - attrs={"src_j": "parent latitudinal (y) cell indices in src_mosaic"}) - tgt_i = xr.DataArray(self.datadict[tgt_tile][src_tile]["tgt_i"], dims=["nxcells"], - attrs={"src_i": "parent longitudinal (x) cell indices in src_mosaic"}) - tgt_j = xr.DataArray(self.datadict[tgt_tile][src_tile]["tgt_j"], dims=["nxcells"], - attrs={"src_j": "parent latitudinal (y) cell indices in src_mosaic"}) - xarea = xr.DataArray(self.datadict[tgt_i][src_i]["xarea"], dims=["nxcells"], - attrs={"xarea":"exchange grid cell area (m2)"}) - - self.dataset[tgt_tile][src_tile] = xr.DataSet(data_vars={"src_i": src_i, - "src_j": src_j, - "tgt_i": tgt_i, - "tgt_j": tgt_j, - "xarea": xarea}) - - def _check_restart_remap_file(self): - if self.restart_remap_file is not None : - check_file_is_there(self.restart_remap_file) - self.read() - return True - return False - - - def _check_grid(self): - - if self.src_grid is not None: self._srcinfoisthere = True - if self.tgt_grid is not None: self._tgtinfoisthere = True - - - def _check_mosaic(self): - - if self.src_mosaic is None: return - if self.tgt_mosaic is None: return - - if self.src_mosaic.grid is None: - self.src_mosaic.get_grid(toradians=True, agrid=self.on_agrid, free_dataset=True) - if self.tgt_mosaic.grid is None: - self.tgt_mosaic.get_grid(toradians=True, agrid=self.on_agrid, free_dataset=True) - - self.src_grid = self.src_mosaic.grid - self.tgt_grid = self.tgt_mosaic.grid - - self._srcinfoisthere = True - self._tgtinfoisthere = True - - - def _check_mosaic_file(self): - - if self.src_mosaic_file is not None: - self.src_grid = MosaicObj(input_dir=self.input_dir, - mosaic_file=self.src_mosaic_file).read().get_grid(toradians=True, - agrid=self.on_agrid, - free_dataset=True) - self._srcinfoisthere = True + def get_parents(self): + + input_dir = self.input_dir + for parent in [self.src, self.tgt]: + if parent.grid is None: + if parent.mosaic is None: + if parent.mosaicfile is None: + raise RuntimeError("can't get grid") + parent.mosaic = MosaicObj( + input_dir=input_dir, + mosaicfile=parent.mosaicfile + ).read() + parent.grid = parent.mosaic.get_grid( + input_dir=input_dir, + center=True, + radians=True, + domain=parent.domain + ) + else: + print("parent grid exists") - if self.tgt_mosaic_file is not None: - self.tgt_grid = MosaicObj(input_dir=self.input_dir, - mosaic_file=self.tgt_mosaic_file).read().get_grid(toradians=True, - agrid=self.on_agrid, - free_dataset=True) - self._tgtinfoisthere = True + self.tgt.grid = self.tgt.grid[self.tgt.tile] + self.src.ntiles = len(self.src.grid) diff --git a/tests/shared/test_xgridobj.py b/tests/shared/test_xgridobj.py index 64c69d6..1d56d59 100755 --- a/tests/shared/test_xgridobj.py +++ b/tests/shared/test_xgridobj.py @@ -1,108 +1,131 @@ -import os +""" +test functionalities in xgridobj +""" + +from types import SimpleNamespace import numpy as np import pytest -import xarray as xr +import pyfms import fmsgridtools +src = SimpleNamespace( + ntiles=6, + nx=12, + ny=24, + dxy=1, + mosaicfile="src_mosaic.nc", + gridfile="src_grid" +) +tgt = SimpleNamespace( + ntiles=1, + nx=24, + ny=48, + dxy=0.5, + mosaicfile="tgt_mosaic.nc", + gridfile="tgt_grid" +) + +nxgrid_per_tile = tgt.nx//2 * tgt.ny//2 +nxgrid = nxgrid_per_tile * 6 +remapfile = "test_remap.nc" + + +def make_testfiles(): + + """ + make mosaic and grid files for testing + """ + + # write mosaic + for parent in [src, tgt]: + fmsgridtools.MosaicObj( + gridtiles=[f"tile{i}" for i in range(1, parent.ntiles+1)], + gridfiles=[f"{parent.gridfile}.tile{i}.nc" for i in range(1, parent.ntiles+1)] + ).write(parent.mosaicfile) + + # write grid + for parent in [src, tgt]: + for itile in range(1, parent.ntiles+1): + x1 = np.array([i*parent.dxy for i in range(parent.nx+1)], dtype=np.float64) + y1 = np.array([j*parent.dxy for j in range(parent.ny+1)], dtype=np.float64) + x, y = np.meshgrid(x1, y1) + fmsgridtools.GridObj(x=x, y=y).write(parent.gridfile + f".tile{itile}.nc") + + +#@pytest.mark.parametrize("on_gpu", [False, True]) +def xgridobj_test(on_gpu: bool = False): + + """ + tests generating the exchange grid + tests reading and write exchange grid + """ + + pyfms.fms.init(ndomain=4) + pyfms.horiz_interp.init(ninterp=src.ntiles) + + domain = pyfms.mpp_domains.define_domains([0, tgt.nx-1, 0, tgt.ny-1]) + + if pyfms.mpp.pe() == pyfms.mpp.root_pe(): + make_testfiles() + pyfms.mpp.sync() + + xgrid = fmsgridtools.XGridObj( + src_mosaicfile=src.mosaicfile, + tgt_mosaicfile=tgt.mosaicfile, + domain=domain + ) -def generate_mosaic(nx: int = 90, ny: int = 45, refine: int = 2): - - xstart, xend = 0, 180 - ystart, yend = -45, 45 - - x_src = np.linspace(xstart, xend, nx+1) - y_src = np.linspace(ystart, yend, ny+1) - x_src, y_src = np.meshgrid(x_src, y_src) - - x_tgt = np.linspace(xstart, xend, nx*refine+1) - y_tgt = np.linspace(ystart, yend, ny*refine+1) - x_tgt, y_tgt = np.meshgrid(x_tgt, y_tgt) - - area_src = np.ones((ny, nx), dtype=np.float64) - area_tgt = np.ones((ny*refine, nx*refine), dtype=np.float64) - - for ifile in ("src", "tgt"): - mosaicfile = ifile + "_mosaic.nc" - gridfile = ifile + "_grid.nc" - gridlocation = "./" - gridtile = "tile1" - xr.Dataset(data_vars=dict(mosaic=mosaicfile.encode(), - gridlocation=gridlocation.encode(), - gridfiles=(["ntiles"], [gridfile.encode()]), - gridtiles=(["ntiles"], [gridtile.encode()])) - ).to_netcdf(mosaicfile) - + xgrid.get_parents() + xgrid.get_interp() + xgrid.write(outfile=remapfile) - for (x, y, area, prefix) in [(x_src, y_src, area_src, "src"), (x_tgt, y_tgt, area_tgt, "tgt")]: - xr.Dataset(data_vars=dict(x=(["nyp", "nxp"], x), - y=(["nyp", "nxp"], y), - area=(["ny", "nx"], area)) - ).to_netcdf(prefix+"_grid.nc") + pyfms.horiz_interp.end() + del xgrid + pyfms.horiz_interp.init(ninterp=src.ntiles) -def remove_mosaic(): - os.remove("src_grid.nc") - os.remove("tgt_grid.nc") - os.remove("src_mosaic.nc") - os.remove("tgt_mosaic.nc") - os.remove("remap.nc") + xgrid = fmsgridtools.XGridObj( + src_mosaicfile=src.mosaicfile, + tgt_mosaicfile=tgt.mosaicfile, + remapfile=remapfile) + xgrid.get_parents() + xgrid.read(remapfile=remapfile) -@pytest.mark.parametrize("on_gpu", [False, True]) -def test_create_xgrid(on_gpu): - nx, ny, refine = 45, 45, 2 - generate_mosaic(nx=nx, ny=ny, refine=refine) + area = fmsgridtools.GridObj( + gridfile=tgt.gridfile + ".tile1.nc").read(center=True, radians=True).get_fms_area() - xgrid = fmsgridtools.XGridObj(src_mosaic_file="src_mosaic.nc", - tgt_mosaic_file="tgt_mosaic.nc", - on_gpu=on_gpu, - on_agrid=False - ) - xgrid.create_xgrid() - xgrid.to_dataset() - xgrid.dataset["tile1"]["tile1"].to_netcdf("remap.nc") - - del xgrid - - xgrid = fmsgridtools.XGridObj(restart_remap_file="remap.nc") + for tile in xgrid.interps: - #check nxcells - nxcells = nx * refine * ny * refine - assert xgrid.nxcells == nxcells + interp = xgrid.interps[tile] + i_src = interp.i_src + j_src = interp.j_src + i_dst = interp.i_dst + j_dst = interp.j_dst - #check parent input cells - answer_i = [i+1 for i in range(nx) for ixcells in range(refine*refine)]*ny - answer_j = [j+1 for j in range(ny) for i in range(nx*refine) for ixcells in range(refine)] + assert interp.nxgrid == tgt.nx//2 * tgt.ny//2, errmsg.format(tile, "N/A", nxgrid, interp.nxgrid) - src_i = [xgrid.src_cell[i][0] for i in range(nxcells)] - src_j = [xgrid.src_cell[i][1] for i in range(nxcells)] - - assert src_i == answer_i - assert src_j == answer_j + for i in range(interp.nxgrid): - #check parent output cells - answer_i = [] - for j in range(ny): - for i in range(nx): - answer_i += [refine*i + ixcell + 1 for ixcell in range(refine)]*refine + idd, jdd = i_dst[i], j_dst[i] + assert i_src[i] == idd//2, f"xcell {i}, i_dst={idd}, j_dst={jdd}" + assert j_src[i] == jdd//2, f"xcell {i}, i_dst={idd}, j_dst={jdd}" - answer_j = [] - for j in range(ny): - for i in range(nx): - for ixcell in range(refine): - answer_j += [j*refine + ixcell + 1]*refine - - tgt_i = [xgrid.tgt_cell[i][0] for i in range(nxcells)] - tgt_j = [xgrid.tgt_cell[i][1] for i in range(nxcells)] + np.testing.assert_almost_equal( + interp.xgrid_area[i], + area[jdd, idd], + decimal=2, + err_msg=f"tile {tile} gridpoint {i}") - assert tgt_i == answer_i - assert tgt_j == answer_j - remove_mosaic() +def test_xgridobj_gpu(): + xgridobj_test(on_gpu=True) +def test_xgridobj_cpu(): + xgridobj_test(on_gpu=False) if __name__ == "__main__": - test_create_xgrid(on_gpu=False) + test_xgridobj_cpu() \ No newline at end of file From 62b42f29d751062536f60e44a3a9219a3693fe81 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 19 Dec 2025 12:09:59 -0500 Subject: [PATCH 12/27] saving --- fmsgridtools/shared/xgridobj.py | 256 +++++++++++++++++++------------- 1 file changed, 154 insertions(+), 102 deletions(-) diff --git a/fmsgridtools/shared/xgridobj.py b/fmsgridtools/shared/xgridobj.py index b3e5205..3f181d4 100755 --- a/fmsgridtools/shared/xgridobj.py +++ b/fmsgridtools/shared/xgridobj.py @@ -1,6 +1,7 @@ -import numpy as np +import logging from pathlib import Path -from types import SimpleNamespace +import numpy as np +import numpy.typing as npt import xarray as xr import pyfrenctools @@ -9,27 +10,68 @@ from fmsgridtools.shared.gridobj import GridObj from fmsgridtools.shared.mosaicobj import MosaicObj - -class XGridObj() : - - def __init__(self, - input_dir: str|Path = "./", - src_mosaicfile: str = None, - tgt_mosaicfile: str = None, - remapfile: str|Path = None, - src_mosaic: MosaicObj = None, - tgt_mosaic: MosaicObj = None, - src_gridfile: str|Path = None, - tgt_gridfile: str|Path = None, - src_grid: dict[str, GridObj] = None, - tgt_grid: dict[str, GridObj] = None, - tgt_tile: str = "tile1", - src_mask: dict[str, np.ndarray] = None, - tgt_mask: dict[str, np.ndarray] = None, - order: int = 1, - domain: pyfms.Domain = None, - on_gpu: bool = False): - +logger = logging.getLogger(__name__) + + +class Parent: + + def __init__( + self, + parent: str, + input_dir: str | Path = "./", + mosaicfile: str | Path = None, + mosaic: MosaicObj = None, + gridfile: str | Path = None, + ntiles: int = None, + grid: dict[str, GridObj] = None, + mask: dict[str, npt.NDArray] = None, + domain: pyfms.Domain = None, + ): + + self.parent = parent + self.input_dir = Path(input_dir) + self.mosaicfile = mosaicfile + self.mosaic = mosaic + self.gridfile = gridfile + self.ntiles = ntiles + self.grid = grid + self.mask = mask + self.domain = domain + + if self.grid is None: + if self.mosaic is None: + if self.mosaicfile is None: + logger.warning("Cannot set %s grid", self.parent) + self.mosaic = MosaicObj( + input_dir=self.input_dir, mosaicfile=self.mosaicfile + ) + self.grid = self.mosaic.get_grid( + input_dir=self.input_dir, center=True, radians=True, domain=self.domain + ) + logger.info("set grid for %s", self.parent) + + +class XGridObj: + + def __init__( + self, + input_dir: str | Path = "./", + src_mosaicfile: str = None, + tgt_mosaicfile: str = None, + remapfile: str | Path = None, + src_mosaic: MosaicObj = None, + tgt_mosaic: MosaicObj = None, + src_gridfile: str | Path = None, + tgt_gridfile: str | Path = None, + src_grid: dict[str, GridObj] = None, + tgt_grid: dict[str, GridObj] = None, + tgt_tile: str = "tile1", + src_mask: dict[str, np.ndarray] = None, + tgt_mask: dict[str, np.ndarray] = None, + order: int = 1, + domain: pyfms.Domain = None, + on_gpu: bool = False, + ): """Create an XGridObj container for building/reading remap/interp data. Args: @@ -54,35 +96,39 @@ def __init__(self, `read` or `get_interp`. """ - self.input_dir: str|Path = Path(input_dir) - self.src = SimpleNamespace( - mosaicfile = src_mosaicfile, - gridfile = src_gridfile, - ntiles = None, - mosaic = src_mosaic, - grid = src_grid, - mask = src_mask, - domain = None + self.input_dir: str | Path = Path(input_dir) + self.src = Parent( + parent="src", + input_dir=input_dir, + mosaicfile=src_mosaicfile, + gridfile=src_gridfile, + mosaic=src_mosaic, + grid=src_grid, + mask=src_mask, + domain=None, ) - self.tgt = SimpleNamespace( - tile = tgt_tile, - mosaicfile = tgt_mosaicfile, - gridfile = tgt_gridfile, - mosaic = tgt_mosaic, - grid = tgt_grid, - mask = tgt_mask, - domain = domain + self.tgt = Parent( + parent="tgt", + mosaicfile=tgt_mosaicfile, + gridfile=tgt_gridfile, + mosaic=tgt_mosaic, + grid=tgt_grid, + mask=tgt_mask, + domain=domain, ) - self.remapfile: str|Path = remapfile + self.remapfile: str | Path = remapfile self.order = order self.on_gpu = on_gpu - self.interps: pyfms.ConserveInterp|dict[str, pyfms.ConserveInterp] = None - - - def read(self, input_dir: Path|str = None, remapfile: Path|str = None, domain: pyfms.Domain = None): + self.interps: pyfms.ConserveInterp | dict[str, pyfms.ConserveInterp] = None + def read( + self, + input_dir: Path | str = None, + remapfile: Path | str = None, + domain: pyfms.Domain = None, + ): """ read remap file and store as pyfms.ConserveInterp objects """ @@ -93,33 +139,33 @@ def read(self, input_dir: Path|str = None, remapfile: Path|str = None, domain: p if remapfile is None: if self.remapfile is None: - print("specify remapfile") - remapfile = input_dir/self.remapfile - remapfile = input_dir/Path(remapfile) - - if remapfile.exists(): - itile = 1 - self.interps = {} - for src_tile in self.src.grid: - interp_id = pyfms.horiz_interp.read_weights_conserve( - weight_filename=str(remapfile), - weight_file_src="fregrid", - nlon_src=self.src.grid[src_tile].nx, - nlat_src=self.src.grid[src_tile].ny, - nlon_tgt=self.tgt.grid.nx, - nlat_tgt=self.tgt.grid.ny, - domain=domain, - src_tile = itile, - save_weights_as_fregrid=True - ) - self.interps[src_tile] = pyfms.ConserveInterp( - interp_id, - weights_as_fregrid=True) - itile += 1 + logger.error("Please specify remapfile to read") + remapfile = self.remapfile + remapfile = input_dir / Path(remapfile) + if not remapfile.exists(): + logger.error("remap file %s does not exist", self.remapfile) - def write(self, output_dir: Path|str = "./", outfile: str|Path = None): - + itile = 1 + self.interps = {} + for src_tile in self.src.grid: + interp_id = pyfms.horiz_interp.read_weights_conserve( + weight_filename=str(remapfile), + weight_file_src="fregrid", + nlon_src=self.src.grid[src_tile].nx, + nlat_src=self.src.grid[src_tile].ny, + nlon_tgt=self.tgt.grid.nx, + nlat_tgt=self.tgt.grid.ny, + domain=domain, + src_tile=itile, + save_weights_as_fregrid=True, + ) + self.interps[src_tile] = pyfms.ConserveInterp( + interp_id, weights_as_fregrid=True + ) + itile += 1 + + def write(self, output_dir: Path | str = "./", outfile: str | Path = None): """ write remap file """ @@ -131,25 +177,37 @@ def write(self, output_dir: Path|str = "./", outfile: str|Path = None): j_src = {itile: self.interps[itile].j_src for itile in self.interps} i_dst = {itile: self.interps[itile].i_dst for itile in self.interps} j_dst = {itile: self.interps[itile].j_dst for itile in self.interps} - xgrid_area = {itile: self.interps[itile].xgrid_area for itile in self.interps} + xgrid_area = { + itile: self.interps[itile].xgrid_area for itile in self.interps + } else: i_src, j_src, i_dst, j_dst, xgrid_area = {}, {}, {}, {}, {} for src_tile in self.interps: interp = self.interps[src_tile] - nxgrids = pyfms.mpp.gather(np.array([interp.nxgrid], dtype=np.int32)) - i_src[src_tile] = pyfms.mpp.gather(interp.i_src, ssize=interp.nxgrid, rsize=nxgrids) - j_src[src_tile] = pyfms.mpp.gather(interp.j_src, ssize=interp.nxgrid, rsize=nxgrids) - i_dst[src_tile] = pyfms.mpp.gather(interp.i_dst, ssize=interp.nxgrid, rsize=nxgrids) - j_dst[src_tile] = pyfms.mpp.gather(interp.j_dst, ssize=interp.nxgrid, rsize=nxgrids) - xgrid_area[src_tile] = pyfms.mpp.gather(interp.xgrid_area, ssize=interp.nxgrid, rsize=nxgrids) + nxgrids = pyfms.mpp.gather(np.array([interp.nxgrid], dtype=np.int32)) + i_src[src_tile] = pyfms.mpp.gather( + interp.i_src, ssize=interp.nxgrid, rsize=nxgrids + ) + j_src[src_tile] = pyfms.mpp.gather( + interp.j_src, ssize=interp.nxgrid, rsize=nxgrids + ) + i_dst[src_tile] = pyfms.mpp.gather( + interp.i_dst, ssize=interp.nxgrid, rsize=nxgrids + ) + j_dst[src_tile] = pyfms.mpp.gather( + interp.j_dst, ssize=interp.nxgrid, rsize=nxgrids + ) + xgrid_area[src_tile] = pyfms.mpp.gather( + interp.xgrid_area, ssize=interp.nxgrid, rsize=nxgrids + ) if pyfms.mpp.pe() == pyfms.mpp.root_pe(): if outfile is None: print("writing remap file to remap.nc") - outfile = Path(output_dir)/"remap.nc" + outfile = Path(output_dir) / "remap.nc" else: - outfile = Path(output_dir)/outfile + outfile = Path(output_dir) / outfile datasets, tile1 = [], 1 for src_tile in self.interps: @@ -158,33 +216,32 @@ def write(self, output_dir: Path|str = "./", outfile: str|Path = None): dataset["tile1"] = xr.DataArray( np.full(nxgrid, tile1), dims=["ncells"], - attrs={"standard_name": "tile_number_in_mosaic1"} + attrs={"standard_name": "tile_number_in_mosaic1"}, ) tile1 += 1 dataset["tile1_cell"] = xr.DataArray( - np.column_stack((i_src[src_tile]+1, j_src[src_tile]+1)), + np.column_stack((i_src[src_tile] + 1, j_src[src_tile] + 1)), dims=["ncells", "two"], - attrs={"standard_name": "parent_cell_indices_in_mosaic1"} + attrs={"standard_name": "parent_cell_indices_in_mosaic1"}, ) dataset["tile2_cell"] = xr.DataArray( - np.column_stack((i_dst[src_tile]+1, j_dst[src_tile]+1)), + np.column_stack((i_dst[src_tile] + 1, j_dst[src_tile] + 1)), dims=["ncells", "two"], - attrs={"standard_name": "parent_cell_indices_in_mosaic2"} + attrs={"standard_name": "parent_cell_indices_in_mosaic2"}, ) dataset["xgrid_area"] = xr.DataArray( xgrid_area[src_tile], dims=["ncells"], - attrs={"standard_name": "exchange_grid_area", "units": "m2"} + attrs={"standard_name": "exchange_grid_area", "units": "m2"}, ) datasets.append(xr.Dataset(dataset)) dataset = xr.concat(datasets, dim="ncells") encoding = {variable: {"_FillValue": None} for variable in dataset} - dataset.to_netcdf(outfile,encoding=encoding) + dataset.to_netcdf(outfile, encoding=encoding) pyfms.mpp.sync() - def get_interp(self) -> dict: self.interps = {} @@ -203,14 +260,14 @@ def get_interp(self) -> dict: tgt_lon=self.tgt.grid.x, tgt_lat=self.tgt.grid.y, src_mask=src_mask, - tgt_mask=self.tgt.mask + tgt_mask=self.tgt.mask, ) interp = pyfms.ConserveInterp() - interp.nxgrid = xdict["ncells"], - interp.i_src = xdict["src_i"], - interp.j_src = xdict["src_j"], - interp.i_tgt = xdict["tgt_i"], - interp.j_tgt = xdict["tgt_j"], + interp.nxgrid = (xdict["ncells"],) + interp.i_src = (xdict["src_i"],) + interp.j_src = (xdict["src_j"],) + interp.i_tgt = (xdict["tgt_i"],) + interp.j_tgt = (xdict["tgt_j"],) interp.xgrid_area = xdict["xarea"] self.interps[src_tile] = interp else: @@ -225,11 +282,10 @@ def get_interp(self) -> dict: is_latlon_out=False, save_weights_as_fregrid=True, convert_cf_order=False, - interp_method="conservative" + interp_method="conservative", ) self.interps[src_tile] = pyfms.ConserveInterp( - interp_id, - weights_as_fregrid=True + interp_id, weights_as_fregrid=True ) def get_parents(self): @@ -241,14 +297,10 @@ def get_parents(self): if parent.mosaicfile is None: raise RuntimeError("can't get grid") parent.mosaic = MosaicObj( - input_dir=input_dir, - mosaicfile=parent.mosaicfile + input_dir=input_dir, mosaicfile=parent.mosaicfile ).read() parent.grid = parent.mosaic.get_grid( - input_dir=input_dir, - center=True, - radians=True, - domain=parent.domain + input_dir=input_dir, center=True, radians=True, domain=parent.domain ) else: print("parent grid exists") From 94586050ca203d0ea099dae5197b73b1e0f2467d Mon Sep 17 00:00:00 2001 From: mlee03 Date: Tue, 23 Dec 2025 13:14:43 -0500 Subject: [PATCH 13/27] almost there --- fmsgridtools/shared/gridobj.py | 2 +- fmsgridtools/shared/xgridobj.py | 156 ++++++++++++++------------------ pyFMS | 2 +- setup.py | 2 +- tests/shared/test_xgridobj.py | 44 ++++----- 5 files changed, 93 insertions(+), 113 deletions(-) diff --git a/fmsgridtools/shared/gridobj.py b/fmsgridtools/shared/gridobj.py index 9a08048..44d659a 100644 --- a/fmsgridtools/shared/gridobj.py +++ b/fmsgridtools/shared/gridobj.py @@ -150,7 +150,7 @@ def to_domain(self, domain: dict = None): else: if self.domain is not None: logger.warning("Overwriting %s with %s", self.domain, domain) - self.domain = domain + self.domain = domain if not pyfms.fms.module_is_initialized(): logger.error("Please initialize pyfms first") diff --git a/fmsgridtools/shared/xgridobj.py b/fmsgridtools/shared/xgridobj.py index 3f181d4..f6823db 100755 --- a/fmsgridtools/shared/xgridobj.py +++ b/fmsgridtools/shared/xgridobj.py @@ -44,7 +44,7 @@ def __init__( logger.warning("Cannot set %s grid", self.parent) self.mosaic = MosaicObj( input_dir=self.input_dir, mosaicfile=self.mosaicfile - ) + ).read() self.grid = self.mosaic.get_grid( input_dir=self.input_dir, center=True, radians=True, domain=self.domain ) @@ -70,9 +70,10 @@ def __init__( tgt_mask: dict[str, np.ndarray] = None, order: int = 1, domain: pyfms.Domain = None, - on_gpu: bool = False, ): - """Create an XGridObj container for building/reading remap/interp data. + + """ + Create an XGridObj container for building/reading remap/interp data. Args: input_dir (str|Path): Base directory for input files. @@ -117,9 +118,9 @@ def __init__( domain=domain, ) + self.tgt_tile = tgt_tile self.remapfile: str | Path = remapfile self.order = order - self.on_gpu = on_gpu self.interps: pyfms.ConserveInterp | dict[str, pyfms.ConserveInterp] = None @@ -129,6 +130,7 @@ def read( remapfile: Path | str = None, domain: pyfms.Domain = None, ): + """ read remap file and store as pyfms.ConserveInterp objects """ @@ -146,60 +148,56 @@ def read( if not remapfile.exists(): logger.error("remap file %s does not exist", self.remapfile) - itile = 1 + if domain is None: + domain = self.tgt.domain + self.interps = {} - for src_tile in self.src.grid: + tgt_tile = self.tgt_tile + for itile, src_tile in enumerate(self.src.grid): interp_id = pyfms.horiz_interp.read_weights_conserve( weight_filename=str(remapfile), weight_file_src="fregrid", nlon_src=self.src.grid[src_tile].nx, nlat_src=self.src.grid[src_tile].ny, - nlon_tgt=self.tgt.grid.nx, - nlat_tgt=self.tgt.grid.ny, + nlon_tgt=self.tgt.grid[tgt_tile].nx, + nlat_tgt=self.tgt.grid[tgt_tile].ny, domain=domain, src_tile=itile, - save_weights_as_fregrid=True, + save_xgrid_area=True, ) - self.interps[src_tile] = pyfms.ConserveInterp( - interp_id, weights_as_fregrid=True - ) - itile += 1 + self.interps[src_tile] = pyfms.ConserveInterp(interp_id, save_xgrid_area=True) + + + def gather(self): + + """ + gathers xgrid + """ + + isc, jsc = self.tgt.domain.isc, self.tgt.domain.jsc + + global_interps = {} + for src_tile in self.interps: + interp = self.interps[src_tile] + global_interp = global_interps[src_tile] = pyfms.ConserveInterp() + nxgrids = pyfms.mpp.gather(np.array([interp.nxgrid], dtype=np.int32), rbuf_size=len(self.interps)) + global_interp.nxgrid = np.sum(nxgrids) if pyfms.mpp.pe() == pyfms.mpp.root_pe() else None + global_interp.i_src = pyfms.mpp.gatherv(interp.i_src, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.j_src = pyfms.mpp.gatherv(interp.j_src, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.i_dst = pyfms.mpp.gatherv(interp.i_dst + isc, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.j_dst = pyfms.mpp.gatherv(interp.j_dst + jsc, ssize=interp.nxgrid, rsize=nxgrids) + global_interp.xgrid_area = pyfms.mpp.gatherv(interp.xgrid_area, ssize=interp.nxgrid, rsize=nxgrids) + return global_interps def write(self, output_dir: Path | str = "./", outfile: str | Path = None): """ write remap file """ - is_root_pe = pyfms.mpp.pe() == pyfms.mpp.root_pe() - if self.tgt.domain is None: - i_src = {itile: self.interps[itile].i_src for itile in self.interps} - j_src = {itile: self.interps[itile].j_src for itile in self.interps} - i_dst = {itile: self.interps[itile].i_dst for itile in self.interps} - j_dst = {itile: self.interps[itile].j_dst for itile in self.interps} - xgrid_area = { - itile: self.interps[itile].xgrid_area for itile in self.interps - } + global_interps = self.interps else: - i_src, j_src, i_dst, j_dst, xgrid_area = {}, {}, {}, {}, {} - for src_tile in self.interps: - interp = self.interps[src_tile] - nxgrids = pyfms.mpp.gather(np.array([interp.nxgrid], dtype=np.int32)) - i_src[src_tile] = pyfms.mpp.gather( - interp.i_src, ssize=interp.nxgrid, rsize=nxgrids - ) - j_src[src_tile] = pyfms.mpp.gather( - interp.j_src, ssize=interp.nxgrid, rsize=nxgrids - ) - i_dst[src_tile] = pyfms.mpp.gather( - interp.i_dst, ssize=interp.nxgrid, rsize=nxgrids - ) - j_dst[src_tile] = pyfms.mpp.gather( - interp.j_dst, ssize=interp.nxgrid, rsize=nxgrids - ) - xgrid_area[src_tile] = pyfms.mpp.gather( - interp.xgrid_area, ssize=interp.nxgrid, rsize=nxgrids - ) + global_interps = self.gather() if pyfms.mpp.pe() == pyfms.mpp.root_pe(): @@ -209,28 +207,29 @@ def write(self, output_dir: Path | str = "./", outfile: str | Path = None): else: outfile = Path(output_dir) / outfile - datasets, tile1 = [], 1 - for src_tile in self.interps: - nxgrid = i_src[src_tile].size + datasets = [] + for tile1, src_tile in enumerate(global_interps): + + interp = global_interps[src_tile] dataset = xr.Dataset() + dataset["tile1"] = xr.DataArray( - np.full(nxgrid, tile1), + np.full(interp.nxgrid, tile1), dims=["ncells"], attrs={"standard_name": "tile_number_in_mosaic1"}, ) - tile1 += 1 dataset["tile1_cell"] = xr.DataArray( - np.column_stack((i_src[src_tile] + 1, j_src[src_tile] + 1)), + np.column_stack((interp.i_src + 1, interp.j_src + 1)), dims=["ncells", "two"], attrs={"standard_name": "parent_cell_indices_in_mosaic1"}, ) dataset["tile2_cell"] = xr.DataArray( - np.column_stack((i_dst[src_tile] + 1, j_dst[src_tile] + 1)), + np.column_stack((interp.i_dst + 1, interp.j_dst + 1)), dims=["ncells", "two"], attrs={"standard_name": "parent_cell_indices_in_mosaic2"}, ) dataset["xgrid_area"] = xr.DataArray( - xgrid_area[src_tile], + interp.xgrid_area, dims=["ncells"], attrs={"standard_name": "exchange_grid_area", "units": "m2"}, ) @@ -242,68 +241,53 @@ def write(self, output_dir: Path | str = "./", outfile: str | Path = None): pyfms.mpp.sync() - def get_interp(self) -> dict: + + def get_interp(self, on_gpu) -> dict: + + """ + call fms to compute xgrid + """ self.interps = {} + tgt_grid = list(self.tgt.grid.values())[0] for src_tile in self.src.grid: src_grid = self.src.grid[src_tile] src_mask = None if self.src.mask is None else self.src.mask[src_tile] - if self.on_gpu: + if on_gpu: xdict = pyfrenctools.create_xgrid.get_2dx2d_order1_gpu( src_nlon=src_grid.nx, src_nlat=src_grid.ny, - tgt_nlon=self.tgt.grid.nx, - tgt_nlat=self.tgt.grid.ny, + tgt_nlon=tgt_grid.nx, + tgt_nlat=tgt_grid.ny, src_lon=src_grid.x, src_lat=src_grid.y, - tgt_lon=self.tgt.grid.x, - tgt_lat=self.tgt.grid.y, + tgt_lon=tgt_grid.x, + tgt_lat=tgt_grid.y, src_mask=src_mask, tgt_mask=self.tgt.mask, ) interp = pyfms.ConserveInterp() - interp.nxgrid = (xdict["ncells"],) - interp.i_src = (xdict["src_i"],) - interp.j_src = (xdict["src_j"],) - interp.i_tgt = (xdict["tgt_i"],) - interp.j_tgt = (xdict["tgt_j"],) + interp.nxgrid = xdict["nxcells"] + interp.i_src = xdict["src_i"] + interp.j_src = xdict["src_j"] + interp.i_dst = xdict["tgt_i"] + interp.j_dst = xdict["tgt_j"] interp.xgrid_area = xdict["xarea"] self.interps[src_tile] = interp else: interp_id = pyfms.horiz_interp.get_weights( lon_in=src_grid.x, lat_in=src_grid.y, - lon_out=self.tgt.grid.x, - lat_out=self.tgt.grid.y, + lon_out=tgt_grid.x, + lat_out=tgt_grid.y, mask_in=src_mask, mask_out=self.tgt.mask, is_latlon_in=False, is_latlon_out=False, - save_weights_as_fregrid=True, + save_xgrid_area=True, convert_cf_order=False, + as_fregrid=True, interp_method="conservative", ) - self.interps[src_tile] = pyfms.ConserveInterp( - interp_id, weights_as_fregrid=True - ) - - def get_parents(self): - - input_dir = self.input_dir - for parent in [self.src, self.tgt]: - if parent.grid is None: - if parent.mosaic is None: - if parent.mosaicfile is None: - raise RuntimeError("can't get grid") - parent.mosaic = MosaicObj( - input_dir=input_dir, mosaicfile=parent.mosaicfile - ).read() - parent.grid = parent.mosaic.get_grid( - input_dir=input_dir, center=True, radians=True, domain=parent.domain - ) - else: - print("parent grid exists") - - self.tgt.grid = self.tgt.grid[self.tgt.tile] - self.src.ntiles = len(self.src.grid) + self.interps[src_tile] = pyfms.ConserveInterp(interp_id, save_xgrid_area=True) \ No newline at end of file diff --git a/pyFMS b/pyFMS index 4e828f0..dd9948b 160000 --- a/pyFMS +++ b/pyFMS @@ -1 +1 @@ -Subproject commit 4e828f0f7f1253e7be166fd8efb11608da13a282 +Subproject commit dd9948b5d2d61cfea1704667ab21a7ccb4c88664 diff --git a/setup.py b/setup.py index 00074dc..d0a3d1b 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def local_pkg(name: str, relative_path: str) -> str: "numpy", "xarray", "netCDF4", - local_pkg("pyFMS", "pyFMS"), +# local_pkg("pyFMS", "pyFMS"), local_pkg("pyfrenctools", "FREnctools_lib") ] diff --git a/tests/shared/test_xgridobj.py b/tests/shared/test_xgridobj.py index 1d56d59..65c84cf 100755 --- a/tests/shared/test_xgridobj.py +++ b/tests/shared/test_xgridobj.py @@ -14,15 +14,15 @@ ntiles=6, nx=12, ny=24, - dxy=1, + dxy=1.0, mosaicfile="src_mosaic.nc", gridfile="src_grid" ) tgt = SimpleNamespace( ntiles=1, - nx=24, - ny=48, - dxy=0.5, + nx=src.nx * 2, + ny=src.ny * 2, + dxy=src.dxy / 2.0, mosaicfile="tgt_mosaic.nc", gridfile="tgt_grid" ) @@ -31,7 +31,6 @@ nxgrid = nxgrid_per_tile * 6 remapfile = "test_remap.nc" - def make_testfiles(): """ @@ -62,10 +61,13 @@ def xgridobj_test(on_gpu: bool = False): tests reading and write exchange grid """ - pyfms.fms.init(ndomain=4) - pyfms.horiz_interp.init(ninterp=src.ntiles) + pyfms.fms.init(ndomain=None if on_gpu else 4) + pyfms.horiz_interp.init(ninterp=src.ntiles*2) - domain = pyfms.mpp_domains.define_domains([0, tgt.nx-1, 0, tgt.ny-1]) + if on_gpu: + domain = None + else: + domain = pyfms.mpp_domains.define_domains([0, tgt.nx-1, 0, tgt.ny-1]) if pyfms.mpp.pe() == pyfms.mpp.root_pe(): make_testfiles() @@ -74,27 +76,20 @@ def xgridobj_test(on_gpu: bool = False): xgrid = fmsgridtools.XGridObj( src_mosaicfile=src.mosaicfile, tgt_mosaicfile=tgt.mosaicfile, - domain=domain + domain=domain, ) - - xgrid.get_parents() - xgrid.get_interp() + xgrid.get_interp(on_gpu=on_gpu) xgrid.write(outfile=remapfile) - - pyfms.horiz_interp.end() + del xgrid - pyfms.horiz_interp.init(ninterp=src.ntiles) - xgrid = fmsgridtools.XGridObj( src_mosaicfile=src.mosaicfile, tgt_mosaicfile=tgt.mosaicfile, remapfile=remapfile) - - xgrid.get_parents() xgrid.read(remapfile=remapfile) - + #answers area = fmsgridtools.GridObj( gridfile=tgt.gridfile + ".tile1.nc").read(center=True, radians=True).get_fms_area() @@ -106,20 +101,21 @@ def xgridobj_test(on_gpu: bool = False): i_dst = interp.i_dst j_dst = interp.j_dst - assert interp.nxgrid == tgt.nx//2 * tgt.ny//2, errmsg.format(tile, "N/A", nxgrid, interp.nxgrid) + assert interp.nxgrid == tgt.nx//2 * tgt.ny//2, f"src_tile = {tile}, {interp.nxgrid}" for i in range(interp.nxgrid): - idd, jdd = i_dst[i], j_dst[i] - assert i_src[i] == idd//2, f"xcell {i}, i_dst={idd}, j_dst={jdd}" - assert j_src[i] == jdd//2, f"xcell {i}, i_dst={idd}, j_dst={jdd}" + i_d, j_d = i_dst[i], j_dst[i] + assert i_src[i] == i_d // 2 and j_src[i] == j_d // 2, f"xcell {i}, i_src={i_src[i]}, j_src={j_src[i]} i_dst={i_d}, j_dst={j_d}" np.testing.assert_almost_equal( interp.xgrid_area[i], - area[jdd, idd], + area[j_d, i_d], decimal=2, err_msg=f"tile {tile} gridpoint {i}") + pyfms.fms.end() + def test_xgridobj_gpu(): xgridobj_test(on_gpu=True) From 0c4b65d7a6ac84c7b7041b4e29369f61b4dc4c32 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Tue, 23 Dec 2025 13:36:50 -0500 Subject: [PATCH 14/27] bit more changes --- fmsgridtools/shared/xgridobj.py | 38 +++++++++++++-------------------- tests/shared/test_xgridobj.py | 3 +-- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/fmsgridtools/shared/xgridobj.py b/fmsgridtools/shared/xgridobj.py index f6823db..e3681da 100755 --- a/fmsgridtools/shared/xgridobj.py +++ b/fmsgridtools/shared/xgridobj.py @@ -118,6 +118,7 @@ def __init__( domain=domain, ) + self.tgt.grid = self.tgt.grid[tgt_tile] self.tgt_tile = tgt_tile self.remapfile: str | Path = remapfile self.order = order @@ -135,9 +136,7 @@ def read( read remap file and store as pyfms.ConserveInterp objects """ - if input_dir is None: - input_dir = self.input_dir - input_dir = Path(input_dir) + input_dir = self.input_dir if input_dir is None else Path(input_dir) if remapfile is None: if self.remapfile is None: @@ -152,15 +151,14 @@ def read( domain = self.tgt.domain self.interps = {} - tgt_tile = self.tgt_tile for itile, src_tile in enumerate(self.src.grid): interp_id = pyfms.horiz_interp.read_weights_conserve( weight_filename=str(remapfile), weight_file_src="fregrid", nlon_src=self.src.grid[src_tile].nx, nlat_src=self.src.grid[src_tile].ny, - nlon_tgt=self.tgt.grid[tgt_tile].nx, - nlat_tgt=self.tgt.grid[tgt_tile].ny, + nlon_tgt=self.tgt.grid.nx, + nlat_tgt=self.tgt.grid.ny, domain=domain, src_tile=itile, save_xgrid_area=True, @@ -189,23 +187,18 @@ def gather(self): global_interp.xgrid_area = pyfms.mpp.gatherv(interp.xgrid_area, ssize=interp.nxgrid, rsize=nxgrids) return global_interps - def write(self, output_dir: Path | str = "./", outfile: str | Path = None): + + def write(self, output_dir: Path | str = "./", outfile: str | Path = Path("remap.nc")): """ write remap file """ - if self.tgt.domain is None: - global_interps = self.interps - else: - global_interps = self.gather() + global_interps = self.interps if self.tgt.domain is None else self.gather() if pyfms.mpp.pe() == pyfms.mpp.root_pe(): - if outfile is None: - print("writing remap file to remap.nc") - outfile = Path(output_dir) / "remap.nc" - else: - outfile = Path(output_dir) / outfile + outfile = Path(output_dir) / outfile + logger.info("writing remap file to %s", outfile) datasets = [] for tile1, src_tile in enumerate(global_interps): @@ -249,7 +242,6 @@ def get_interp(self, on_gpu) -> dict: """ self.interps = {} - tgt_grid = list(self.tgt.grid.values())[0] for src_tile in self.src.grid: src_grid = self.src.grid[src_tile] @@ -258,12 +250,12 @@ def get_interp(self, on_gpu) -> dict: xdict = pyfrenctools.create_xgrid.get_2dx2d_order1_gpu( src_nlon=src_grid.nx, src_nlat=src_grid.ny, - tgt_nlon=tgt_grid.nx, - tgt_nlat=tgt_grid.ny, + tgt_nlon=self.tgt.grid.nx, + tgt_nlat=self.tgt.grid.ny, src_lon=src_grid.x, src_lat=src_grid.y, - tgt_lon=tgt_grid.x, - tgt_lat=tgt_grid.y, + tgt_lon=self.tgt.grid.x, + tgt_lat=self.tgt.grid.y, src_mask=src_mask, tgt_mask=self.tgt.mask, ) @@ -279,8 +271,8 @@ def get_interp(self, on_gpu) -> dict: interp_id = pyfms.horiz_interp.get_weights( lon_in=src_grid.x, lat_in=src_grid.y, - lon_out=tgt_grid.x, - lat_out=tgt_grid.y, + lon_out=self.tgt.grid.x, + lat_out=self.tgt.grid.y, mask_in=src_mask, mask_out=self.tgt.mask, is_latlon_in=False, diff --git a/tests/shared/test_xgridobj.py b/tests/shared/test_xgridobj.py index 65c84cf..38d1b5f 100755 --- a/tests/shared/test_xgridobj.py +++ b/tests/shared/test_xgridobj.py @@ -5,7 +5,6 @@ from types import SimpleNamespace import numpy as np -import pytest import pyfms import fmsgridtools @@ -124,4 +123,4 @@ def test_xgridobj_cpu(): xgridobj_test(on_gpu=False) if __name__ == "__main__": - test_xgridobj_cpu() \ No newline at end of file + test_xgridobj_gpu() \ No newline at end of file From 4b2ef6e9415d93573f7a69d8f91ab77ddad0cf65 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Tue, 23 Dec 2025 14:32:10 -0500 Subject: [PATCH 15/27] add set target tile --- fmsgridtools/shared/xgridobj.py | 14 +++++++++++--- tests/shared/test_xgridobj.py | 6 ++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/fmsgridtools/shared/xgridobj.py b/fmsgridtools/shared/xgridobj.py index e3681da..2285a34 100755 --- a/fmsgridtools/shared/xgridobj.py +++ b/fmsgridtools/shared/xgridobj.py @@ -65,7 +65,7 @@ def __init__( tgt_gridfile: str | Path = None, src_grid: dict[str, GridObj] = None, tgt_grid: dict[str, GridObj] = None, - tgt_tile: str = "tile1", + tgt_tile: str = None, src_mask: dict[str, np.ndarray] = None, tgt_mask: dict[str, np.ndarray] = None, order: int = 1, @@ -118,13 +118,21 @@ def __init__( domain=domain, ) - self.tgt.grid = self.tgt.grid[tgt_tile] self.tgt_tile = tgt_tile + if self.tgt_tile is not None: + self.tgt.grid = self.tgt.grid[tgt_tile] + self.remapfile: str | Path = remapfile self.order = order self.interps: pyfms.ConserveInterp | dict[str, pyfms.ConserveInterp] = None + + def set_target_tile(self, tgt_tile: str = "tile1"): + self.tgt_tile = tgt_tile + self.tgt.grid = self.tgt.grid[tgt_tile] + + def read( self, input_dir: Path | str = None, @@ -282,4 +290,4 @@ def get_interp(self, on_gpu) -> dict: as_fregrid=True, interp_method="conservative", ) - self.interps[src_tile] = pyfms.ConserveInterp(interp_id, save_xgrid_area=True) \ No newline at end of file + self.interps[src_tile] = pyfms.ConserveInterp(interp_id, save_xgrid_area=True) diff --git a/tests/shared/test_xgridobj.py b/tests/shared/test_xgridobj.py index 38d1b5f..02aa58c 100755 --- a/tests/shared/test_xgridobj.py +++ b/tests/shared/test_xgridobj.py @@ -77,6 +77,7 @@ def xgridobj_test(on_gpu: bool = False): tgt_mosaicfile=tgt.mosaicfile, domain=domain, ) + xgrid.set_target_tile("tile1") xgrid.get_interp(on_gpu=on_gpu) xgrid.write(outfile=remapfile) @@ -85,7 +86,8 @@ def xgridobj_test(on_gpu: bool = False): xgrid = fmsgridtools.XGridObj( src_mosaicfile=src.mosaicfile, tgt_mosaicfile=tgt.mosaicfile, - remapfile=remapfile) + remapfile=remapfile, + tgt_tile = "tile1") xgrid.read(remapfile=remapfile) #answers @@ -123,4 +125,4 @@ def test_xgridobj_cpu(): xgridobj_test(on_gpu=False) if __name__ == "__main__": - test_xgridobj_gpu() \ No newline at end of file + test_xgridobj_gpu() From 73f73da8cff6d8d6080e7f57794202a2e605f9e8 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Wed, 24 Dec 2025 12:25:10 -0500 Subject: [PATCH 16/27] start dataobj rewrite --- fmsgridtools/__init__.py | 1 + fmsgridtools/remap/conservative.py | 59 ++++++++++++++++++------------ fmsgridtools/remap/remap.py | 44 ++++++++++++++-------- 3 files changed, 64 insertions(+), 40 deletions(-) diff --git a/fmsgridtools/__init__.py b/fmsgridtools/__init__.py index 683eaa7..fa573aa 100644 --- a/fmsgridtools/__init__.py +++ b/fmsgridtools/__init__.py @@ -5,6 +5,7 @@ from .make_mosaic import coupler_mosaic from .make_topog import make_topog from .remap import remap +from .remap.variableobj import VariableObj from .shared.gridtools_utils import check_file_is_there, get_provenance_attrs from .shared.gridobj import GridObj from .shared.mosaicobj import MosaicObj diff --git a/fmsgridtools/remap/conservative.py b/fmsgridtools/remap/conservative.py index dac398d..97e829f 100644 --- a/fmsgridtools/remap/conservative.py +++ b/fmsgridtools/remap/conservative.py @@ -2,30 +2,41 @@ from fmsgridtools.shared.xgridobj import XGridObj -def remap(src_mosaic: str, - input_dir: str = "./", - output_dir: str = None, - input_file: str = None, - output_file: str = None, - tgt_mosaic: str = None, - tgt_nlon: int = None, - tgt_nlat: int = None, - lon_bounds: list = None, - lat_bounds: list = None, - kbounds: list = None, - tbounds: list = None, - order: int = 1, - static_file: str = None, - check_conserve: bool = False): - - #create an xgrid object - xgrid = XGridObj(input_dir, src_mosaic_file=src_mosaic, tgt_mosaic_file=tgt_mosaic) - - #create xgrid - xgrid.create_xgrid() - - #write - xgrid.write() +import pyfms + + +def remap(input_dir: str = "./", + src_mosaic: str = None, + tgt_mosaic: str = None, + input_file: str = None, + output_dir: str = "./", + output_file: str = None, + scalar_variables: list[str] = None, + lon_bounds: list = None, + lat_bounds: list = None, + kbounds: list = None, + tbounds: list = None, + order: int = 1, + check_conserve: bool = False, + gpu: bool = False): + + xgrid = XGridObj(input_dir=input_dir, src_mosaic_file = src_mosaic, tgt_mosaicfile=tgt_mosaic) + tgt_tiles = list(xgrid.tgt.grid.keys()) + nsrc_tiles = list(xgrid.src.mosaic.ntiles) + + pyfms.fms.init(ndomain=len(tgt_tiles)) + + for tgt_tile in tgt_tiles: + + pyfms.horiz_interp.init(nsrc_tiles) + + xgrid.domain = pyfms.mpp_domains.define_domains([0, xgrid.tgt.grid.nx-1, 0, xgrid.tgt.grid.ny-1]) + xgrid.set_target_tile(tgt_tile) + xgrid.get_interp() + + + + diff --git a/fmsgridtools/remap/remap.py b/fmsgridtools/remap/remap.py index d29beef..bab73b0 100644 --- a/fmsgridtools/remap/remap.py +++ b/fmsgridtools/remap/remap.py @@ -17,30 +17,42 @@ Only 1 or 2 order is supported for conservative interpolation """ ) -@click.option("--static_file", - type = click.Path(exists=True), +@click.option("--check_conserve", + type = bool, help = """ - To remap data where cell_methods = CELL_METHODS_MEAN, the static_file - will src grid cell areas should be provided + If true, output grid area conservative will be checked """ ) -@click.option("--check_conserve", +@click.option("--gpu", type = bool, help = """ - If true, output grid area conservative will be checked + If true, the exchange grid will be created on the GPU """ ) -def conservative_method(input_dir, output_dir, input_file, output_file, #common_options - src_mosaic, tgt_mosaic, tgt_nlon, tgt_nlat, #common_options - lon_bounds, lat_bounds, kbounds, tbounds, #common_options - debug, order, static_file, check_conserve): - +def conservative_method(input_dir, output_dir, input_mosaic_dir, output_mosaic_dir, + input_file, output_file, #common_options + src_mosaic, tgt_mosaic, tgt_nlon, tgt_nlat, #common_options + scalar_variables, lon_bounds, lat_bounds, + kbounds, tbounds, debug, order, check_conserve, gpu): + setlogger.setconfig("remap.log", debug) logger.info("Starting conservative remapping") - - conservative.remap(src_mosaic, input_dir, output_dir, input_file, - output_file, tgt_mosaic, tgt_nlon, tgt_nlat, - lon_bounds, lat_bounds, kbounds, tbounds, - order, static_file, check_conserve) + + xgrid = conservative.remap(input_dir=input_dir, + output_dir=output_dir, + input_mosaic_dir=input_mosaic_dir, + output_mosaic_dir=output_mosaic_dir, + input_file=input_file, + src_mosaic=src_mosaic, + tgt_mosaic=tgt_mosaic, + output_file=output_file, + scalar_variables=scalar_variables, + lon_bounds=lon_bounds, + lat_bounds=lat_bounds, + kbounds=kbounds, + tbounds=tbounds, + order=order, + check_conserve=check_conserve, + gpu=gpu) From d84ae8e999643f937ed3c6f35b46b3b62da258a5 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Wed, 24 Dec 2025 12:27:06 -0500 Subject: [PATCH 17/27] add untracked files --- fmsgridtools/remap/dataobj.py | 113 ++++++++++++++++++++++++++ fmsgridtools/remap/variableobj.py | 128 ++++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 fmsgridtools/remap/dataobj.py create mode 100644 fmsgridtools/remap/variableobj.py diff --git a/fmsgridtools/remap/dataobj.py b/fmsgridtools/remap/dataobj.py new file mode 100644 index 0000000..0342b42 --- /dev/null +++ b/fmsgridtools/remap/dataobj.py @@ -0,0 +1,113 @@ +from itertools import pairwise +from pathlib import Path +import xarray as xr + +class DimObj(): + + def __init__(self, name: str, here: bool = False, size: int = None): + self.name = name + self.here = here + self.size = size + + +class DimsObj(): + + def __init__(self): + self.x: DimObj("X") + self.y: DimObj("Y") + self.z: DimObj("Z") + self.time: DimObj("T") + + + def get(self, da_coords: xr.Coordinates): + + """ + get dimensions + """ + + dims_list = [self.x, self.y, self.z, self.time] + + for name, coord in da_coords.items(): + for dim in dims_list: + if dim.name == coord.attrs["axis"]: + dim.size = coord.size + dim.here = True + dims_list.remove(dim) + break + + +class VariableObj(): + + time_coder = xr.coders.CFDatetimeCoder(use_cftime=True) + + def __init__(self, datafile: str, tiles: list = None, input_dir: str | Path = Path("./")): + + self.tiles = tiles + self.dims = DimsObj() + + if tiles is None: + self.datafiles = {"tile1": input_dir/Path(datafile).with_suffix(".nc")} + else: + self.datafiles = {tile: input_dir/Path(datafile.with_suffix(tile + ".nc")) for tile in tiles} + + self.attr = SimpleNamespace( + missing = None, + fill_value = None, + offset = None, + scale_factor = None + ) + + self.area = SimpleNamespace( + static_files = None, + cell_method = None, + area = None + ) + + + + def get_attributes(self, variable): + + infile = self.datafiles[0] + + if not infile.exists: + raise RuntimeError("file does not exist") + + with xr.open_dataset(infile, decode_cf=False) as dataset: + + if variable not in dataset: + raise RuntimeError("variable not found") + dataarray = dataset[variable] + + self.dims.get(dataarray.coords) + + attributes = dataarray.attrs + self.attr.missing = attributes.get("missing_value") + self.attr.fill_value = attributes.get("_FillValue") + self.attr.offset = attributes.get("add_offset") + self.attr.scale_factor = attributes.get("scale_factor") + + if "area" in str(attributes.get("cell_method")): + self._get_static_files(dataset, variable) + + + def _get_static_files(self, dataset, variable): + + # get cell measures (area type) + cell_measures = str(dataset[variable].attrs.get("cell_measures")) + if "area:" in cell_measures: + cell_measures = cell_measures.split()[1] + else: + return + + # soil_area: 00010101.land_static.nc cell_area: 00010101.land_static_sg.nc + global_attrs = str(dataset.attrs.get("associated_files")) + if cell_measures in global_attrs: + global_attrs_dict = {keyval.strip(":"): areaval for keyval, areaval in pairwise(global_attrs.split())} + static_file = global_attrs_dict[cell_measures] + else: + raise RuntimeError("cannot find static file") + + if self.tiles is not None: + self.area.static_files = {tile: input_dir/static_file.replace(".nc", tile+".nc") for tile in self.tiles} + else: + self.area.static_files = {"tile1": input_dir/static_file} diff --git a/fmsgridtools/remap/variableobj.py b/fmsgridtools/remap/variableobj.py new file mode 100644 index 0000000..7d4e7f8 --- /dev/null +++ b/fmsgridtools/remap/variableobj.py @@ -0,0 +1,128 @@ +from itertools import pairwise +from pathlib import Path +from types import SimpleNamespace + +import xarray as xr + +class DimObj(): + + def __init__(self, axis: str, name: str = None, here: bool = False, size: int = None): + self.axis = axis + self.name = name + self.here = here + self.size = size + + +class DimsObj(): + + def __init__(self): + self.x = DimObj("X") + self.y = DimObj("Y") + self.z = DimObj("Z") + self.time = DimObj("T") + + def get(self, da_coords: xr.Coordinates): + + """ + get dimensions + """ + + dims_list = [self.x, self.y, self.z, self.time] + + for name, coord in da_coords.items(): + for dim in dims_list: + if dim.axis == coord.attrs["axis"]: + dim.name = name + dim.size = coord.size + dim.here = True + dims_list.remove(dim) + break + + def __repr__(self): + + repr_str ="\n" + for obj in [self.x, self.y, self.z, self.time]: + repr_str += f"axis={obj.axis}, name={obj.name}, here={obj.here}, size={obj.size}\n" + return repr_str + + +class VariableObj(): + + time_coder = xr.coders.CFDatetimeCoder(use_cftime=True) + + def __init__(self, datafile: str, tiles: list = None, input_dir: str | Path = Path("./")): + + self.tiles = tiles + self.dims = DimsObj() + self.input_dir = Path(input_dir) + + if tiles is None: + self.datafiles = {"tile1": input_dir/Path(datafile).with_suffix(".nc")} + else: + self.datafiles = {tile: input_dir/Path(datafile.with_suffix(tile + ".nc")) for tile in tiles} + + self.attrs = SimpleNamespace( + missing = None, + fill_value = None, + offset = None, + scale_factor = None + ) + + self.area = SimpleNamespace( + static_files = None, + cell_measures = None, + area = None + ) + + + def get_attributes(self, variable): + + infile = list(self.datafiles.values())[0] + + if not infile.exists: + raise RuntimeError("file does not exist") + + with xr.open_dataset(infile, decode_cf=False) as dataset: + + if variable not in dataset: + raise RuntimeError("variable not found") + dataarray = dataset[variable] + + self.dims.get(dataarray.coords) + + attributes = dataarray.attrs + self.attrs.missing = attributes.get("missing_value") + self.attrs.fill_value = attributes.get("_FillValue") + self.attrs.offset = attributes.get("add_offset") + self.attrs.scale_factor = attributes.get("scale_factor") + + if "area" in str(attributes.get("cell_method")): + self._get_static_files(dataset, variable) + + + def slice(self, time: int = None, z: int = None): + + + def _get_static_files(self, dataset, variable): + + # get cell measures (area type) + cell_measures = str(dataset[variable].attrs.get("cell_measures")) + if "area:" in cell_measures: + cell_measures = cell_measures.split()[1] + else: + return + + # soil_area: 00010101.land_static.nc cell_area: 00010101.land_static_sg.nc + global_attrs = str(dataset.attrs.get("associated_files")) + if cell_measures in global_attrs: + global_attrs_dict = {keyval.strip(":"): areaval for keyval, areaval in pairwise(global_attrs.split())} + static_file = global_attrs_dict[cell_measures] + else: + raise RuntimeError("cannot find static file") + + if self.tiles is not None: + self.area.static_files = {tile: self.input_dir/static_file.replace(".nc", tile+".nc") for tile in self.tiles} + else: + self.area.static_files = {"tile1": self.input_dir/static_file} + + self.area.cell_measures = cell_measures \ No newline at end of file From 405b5cc5fba006fe87e89c394a294ec278e8dabc Mon Sep 17 00:00:00 2001 From: mlee03 Date: Wed, 24 Dec 2025 12:29:13 -0500 Subject: [PATCH 18/27] get rid of dataobj.py --- fmsgridtools/remap/dataobj.py | 113 ---------------------------------- 1 file changed, 113 deletions(-) delete mode 100644 fmsgridtools/remap/dataobj.py diff --git a/fmsgridtools/remap/dataobj.py b/fmsgridtools/remap/dataobj.py deleted file mode 100644 index 0342b42..0000000 --- a/fmsgridtools/remap/dataobj.py +++ /dev/null @@ -1,113 +0,0 @@ -from itertools import pairwise -from pathlib import Path -import xarray as xr - -class DimObj(): - - def __init__(self, name: str, here: bool = False, size: int = None): - self.name = name - self.here = here - self.size = size - - -class DimsObj(): - - def __init__(self): - self.x: DimObj("X") - self.y: DimObj("Y") - self.z: DimObj("Z") - self.time: DimObj("T") - - - def get(self, da_coords: xr.Coordinates): - - """ - get dimensions - """ - - dims_list = [self.x, self.y, self.z, self.time] - - for name, coord in da_coords.items(): - for dim in dims_list: - if dim.name == coord.attrs["axis"]: - dim.size = coord.size - dim.here = True - dims_list.remove(dim) - break - - -class VariableObj(): - - time_coder = xr.coders.CFDatetimeCoder(use_cftime=True) - - def __init__(self, datafile: str, tiles: list = None, input_dir: str | Path = Path("./")): - - self.tiles = tiles - self.dims = DimsObj() - - if tiles is None: - self.datafiles = {"tile1": input_dir/Path(datafile).with_suffix(".nc")} - else: - self.datafiles = {tile: input_dir/Path(datafile.with_suffix(tile + ".nc")) for tile in tiles} - - self.attr = SimpleNamespace( - missing = None, - fill_value = None, - offset = None, - scale_factor = None - ) - - self.area = SimpleNamespace( - static_files = None, - cell_method = None, - area = None - ) - - - - def get_attributes(self, variable): - - infile = self.datafiles[0] - - if not infile.exists: - raise RuntimeError("file does not exist") - - with xr.open_dataset(infile, decode_cf=False) as dataset: - - if variable not in dataset: - raise RuntimeError("variable not found") - dataarray = dataset[variable] - - self.dims.get(dataarray.coords) - - attributes = dataarray.attrs - self.attr.missing = attributes.get("missing_value") - self.attr.fill_value = attributes.get("_FillValue") - self.attr.offset = attributes.get("add_offset") - self.attr.scale_factor = attributes.get("scale_factor") - - if "area" in str(attributes.get("cell_method")): - self._get_static_files(dataset, variable) - - - def _get_static_files(self, dataset, variable): - - # get cell measures (area type) - cell_measures = str(dataset[variable].attrs.get("cell_measures")) - if "area:" in cell_measures: - cell_measures = cell_measures.split()[1] - else: - return - - # soil_area: 00010101.land_static.nc cell_area: 00010101.land_static_sg.nc - global_attrs = str(dataset.attrs.get("associated_files")) - if cell_measures in global_attrs: - global_attrs_dict = {keyval.strip(":"): areaval for keyval, areaval in pairwise(global_attrs.split())} - static_file = global_attrs_dict[cell_measures] - else: - raise RuntimeError("cannot find static file") - - if self.tiles is not None: - self.area.static_files = {tile: input_dir/static_file.replace(".nc", tile+".nc") for tile in self.tiles} - else: - self.area.static_files = {"tile1": input_dir/static_file} From f55303b5261b252ed7736f0f8b8cd06351845b4e Mon Sep 17 00:00:00 2001 From: mlee03 Date: Wed, 24 Dec 2025 12:32:13 -0500 Subject: [PATCH 19/27] test file --- tests/remap/test_dataobj.py | 104 ++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 tests/remap/test_dataobj.py diff --git a/tests/remap/test_dataobj.py b/tests/remap/test_dataobj.py new file mode 100644 index 0000000..587e5f7 --- /dev/null +++ b/tests/remap/test_dataobj.py @@ -0,0 +1,104 @@ +import numpy as np +from types import SimpleNamespace +import xarray as xr + +import fmsgridtools + +def write_files(outfile): + grid_xt, grid_yt = 100, 50 + + variable_xt = xr.DataArray( + np.arange(grid_xt, dtype=np.float64), + dims = ["grid_xt"], + attrs={ + "units": "degrees_E", + "long_name": "made up T-cell longitude", + "axis": "X" + } + ) + + variable_yt = xr.DataArray( + np.arange(grid_yt, dtype=np.float64), + dims = ["grid_yt"], + attrs={ + "units": "degrees_N", + "long_name": "made up T-cell latitude", + "axis": "Y" + } + ) + + variable1 = xr.DataArray( + np.ones((grid_yt, grid_xt), dtype=np.float64), + dims = ["grid_yt", "grid_xt"], + attrs = { + "_FillValue": -np.float64(-100), + "long_name": "test variable 1", + "units": "kg m-2", + "missing": np.float64(45.12), + "add_offset": np.float64(-24.326), + "scale_factor": np.float64(-55), + "cell_method": "area:mean time:mean", + "cell_measures": "area: pemberley_area", + "standard_name": "Mr. Darcy" + } + ) + + variable2 = xr.DataArray( + np.zeros((grid_yt, grid_xt), dtype=np.float32), + dims = ["grid_yt", "grid_xt"], + attrs = { + "_FillValue": False, + "long_name": "test variable 2", + "cell_method": "area:mean", + "standard_name": "Missing cell_measures, expected to fail", + "missing_value": -np.float32(-1.0), + "scale_factor": -np.float32(-0.05), + "offset": np.float32(0.0) + } + ) + + data_vars = { + "grid_xt": variable_xt, + "grid_yt": variable_yt, + "variable1": variable1, + "variable2": variable2 + } + + dataset = xr.Dataset( + data_vars=data_vars, + attrs={"associated_files": "pemberley_area: pemberley.nc"} + ) + + dataset.to_netcdf(outfile) + return dataset + + +def test_dataobj(): + + dataset = write_files("test.nc") + + variable1 = fmsgridtools.VariableObj(datafile="test") + variable1.get_attributes(variable="variable1") + + #check dimensions + for (name, dim) in [("grid_xt", variable1.dims.x), ("grid_yt", variable1.dims.y)]: + assert dim.name == name + assert dim.here + assert dim.size == dataset[name].size + for dim in [variable1.dims.time, variable1.dims.z]: + assert dim.name is None + assert not dim.here + assert dim.size is None + + attributes = dataset["variable1"].attrs + assert variable1.attrs.missing == attributes.get("missing_value") + assert variable1.attrs.fill_value == attributes.get("_FillValue") + assert variable1.attrs.offset == attributes.get("add_offset") + assert variable1.attrs.scale_factor == attributes.get("scale_factor") + + print(variable1.area) + + + +if __name__ == "__main__": + test_dataobj() From 97c611d040020692402b35a6883bb3ee4ad42de9 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Wed, 24 Dec 2025 13:02:40 -0500 Subject: [PATCH 20/27] add slice --- fmsgridtools/remap/variableobj.py | 53 +++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/fmsgridtools/remap/variableobj.py b/fmsgridtools/remap/variableobj.py index 7d4e7f8..e7fcb36 100644 --- a/fmsgridtools/remap/variableobj.py +++ b/fmsgridtools/remap/variableobj.py @@ -31,7 +31,7 @@ def get(self, da_coords: xr.Coordinates): for name, coord in da_coords.items(): for dim in dims_list: - if dim.axis == coord.attrs["axis"]: + if dim.axis == coord.att["axis"]: dim.name = name dim.size = coord.size dim.here = True @@ -61,11 +61,11 @@ def __init__(self, datafile: str, tiles: list = None, input_dir: str | Path = Pa else: self.datafiles = {tile: input_dir/Path(datafile.with_suffix(tile + ".nc")) for tile in tiles} - self.attrs = SimpleNamespace( - missing = None, - fill_value = None, - offset = None, - scale_factor = None + self.att = SimpleNamespace( + missing = False, + fill_value = False, + offset = False, + scale_factor = False ) self.area = SimpleNamespace( @@ -91,17 +91,40 @@ def get_attributes(self, variable): self.dims.get(dataarray.coords) attributes = dataarray.attrs - self.attrs.missing = attributes.get("missing_value") - self.attrs.fill_value = attributes.get("_FillValue") - self.attrs.offset = attributes.get("add_offset") - self.attrs.scale_factor = attributes.get("scale_factor") + self.att.missing = attributes.get("missing_value") + self.att.fill_value = attributes.get("_FillValue") + self.att.offset = attributes.get("add_offset") + self.att.scale_factor = attributes.get("scale_factor") if "area" in str(attributes.get("cell_method")): self._get_static_files(dataset, variable) - def slice(self, time: int = None, z: int = None): + def slice(self, tile: str = "tile1", timepoint: int|bool = False, klevel: int|bool = False): + #python, values above 0 are true... + + with xr.open_dataset(self.datafiles[tile], decode_cf=False) as dataset: + + slice_dict = {} + + if klevel and self.dims.z.here: slice_dict[self.dims.z.name] = klevel + if timepoint and self.dims.time.here: slice_dict[self.dims.time.name] = timepoint + + data = dataset[self.variable].isel(slice_dict) + + #missing value mask + mask = None + if self.att.missing_value: + mask = data != self.att.missing_value + + if self.att.offset: data += self.att.offset + if self.att.scale_factor: data *= self.att.scale_factor + + #zero out missing values so it doens't contribute to remapping + data = data.where(mask, 0.0, data) + + return data def _get_static_files(self, dataset, variable): @@ -113,11 +136,9 @@ def _get_static_files(self, dataset, variable): return # soil_area: 00010101.land_static.nc cell_area: 00010101.land_static_sg.nc - global_attrs = str(dataset.attrs.get("associated_files")) - if cell_measures in global_attrs: - global_attrs_dict = {keyval.strip(":"): areaval for keyval, areaval in pairwise(global_attrs.split())} - static_file = global_attrs_dict[cell_measures] - else: + global_att = str(dataset.attrs.get("associated_files")) + static_file = {keyval.strip(":"): areaval for keyval, areaval in pairwise(global_att.split())}.get(cell_measures) + if static_file is None: raise RuntimeError("cannot find static file") if self.tiles is not None: From 0f37fcb4392cdc9e84663311da337bff11b2e444 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Wed, 24 Dec 2025 13:13:39 -0500 Subject: [PATCH 21/27] better --- fmsgridtools/remap/variableobj.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/fmsgridtools/remap/variableobj.py b/fmsgridtools/remap/variableobj.py index e7fcb36..21c7c37 100644 --- a/fmsgridtools/remap/variableobj.py +++ b/fmsgridtools/remap/variableobj.py @@ -27,16 +27,16 @@ def get(self, da_coords: xr.Coordinates): get dimensions """ - dims_list = [self.x, self.y, self.z, self.time] + dims_dict = {dim.axis: dim for dim in [self.x, self.y, self.z, self.time]} for name, coord in da_coords.items(): - for dim in dims_list: - if dim.axis == coord.att["axis"]: - dim.name = name - dim.size = coord.size - dim.here = True - dims_list.remove(dim) - break + try: + dim = dims_dict[coord.att.get("axis")] + dim.name = name + dim.size = coord.size + dim.here = True + except: + print("dim not found") def __repr__(self): @@ -114,7 +114,7 @@ def slice(self, tile: str = "tile1", timepoint: int|bool = False, klevel: int|bo data = dataset[self.variable].isel(slice_dict) #missing value mask - mask = None + mask = False if self.att.missing_value: mask = data != self.att.missing_value @@ -122,10 +122,11 @@ def slice(self, tile: str = "tile1", timepoint: int|bool = False, klevel: int|bo if self.att.scale_factor: data *= self.att.scale_factor #zero out missing values so it doens't contribute to remapping - data = data.where(mask, 0.0, data) + if mask: data = data.where(mask, 0.0, data) return data + def _get_static_files(self, dataset, variable): # get cell measures (area type) From 9248be115f3f8657b0fd473a40775a8ccf1e4468 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 26 Dec 2025 12:50:16 -0500 Subject: [PATCH 22/27] fileobj --- fmsgridtools/__init__.py | 2 +- fmsgridtools/remap/variableobj.py | 136 +++++++++++++++--------------- tests/remap/test_dataobj.py | 58 ++++++++++--- 3 files changed, 114 insertions(+), 82 deletions(-) diff --git a/fmsgridtools/__init__.py b/fmsgridtools/__init__.py index fa573aa..980043b 100644 --- a/fmsgridtools/__init__.py +++ b/fmsgridtools/__init__.py @@ -5,7 +5,7 @@ from .make_mosaic import coupler_mosaic from .make_topog import make_topog from .remap import remap -from .remap.variableobj import VariableObj +from .remap.variableobj import VariableObj, FileObj, DimObj, DimsObj from .shared.gridtools_utils import check_file_is_there, get_provenance_attrs from .shared.gridobj import GridObj from .shared.mosaicobj import MosaicObj diff --git a/fmsgridtools/remap/variableobj.py b/fmsgridtools/remap/variableobj.py index 21c7c37..5f4d0e0 100644 --- a/fmsgridtools/remap/variableobj.py +++ b/fmsgridtools/remap/variableobj.py @@ -9,8 +9,8 @@ class DimObj(): def __init__(self, axis: str, name: str = None, here: bool = False, size: int = None): self.axis = axis self.name = name - self.here = here self.size = size + self.here = False class DimsObj(): @@ -21,7 +21,7 @@ def __init__(self): self.z = DimObj("Z") self.time = DimObj("T") - def get(self, da_coords: xr.Coordinates): + def get(self, da: xr.DataArray): """ get dimensions @@ -29,14 +29,14 @@ def get(self, da_coords: xr.Coordinates): dims_dict = {dim.axis: dim for dim in [self.x, self.y, self.z, self.time]} - for name, coord in da_coords.items(): + for name, coord in da.coords.items(): try: - dim = dims_dict[coord.att.get("axis")] + dim = dims_dict[coord.attrs.get("axis")] dim.name = name dim.size = coord.size dim.here = True except: - print("dim not found") + print(f"{name} dim not found") def __repr__(self): @@ -46,38 +46,55 @@ def __repr__(self): return repr_str +class FileObj(): + + def __init__(self, datafile: str, tiles: list = ["tile1"], input_dir: str = "./"): + + self.input_dir = str(input_dir) + self.tiles = tiles + self.datafiles = {tile: Path(input_dir)/Path(datafile + f".{tile}.nc") for tile in tiles} + self.static_files = {} + + with xr.open_dataset(self.datafiles[self.tiles[0]], decode_cf=False) as dataset: + # soil_area: 00010101.land_static.nc cell_area: 00010101.land_static_sg.nc + associated_files = dataset.attrs.get("associated_files") + if associated_files is not None: + stringsplit = associated_files.replace(":", " ").split() + for i in range(0, len(stringsplit), 2): + self.static_files[stringsplit[i]] = { + tile: Path(input_dir)/stringsplit[i+1].replace(".nc", f".{tile}.nc") for tile in self.tiles + } + + def __repr__(self): + repr_str = "\n" + repr_str += f"input_dir = {self.input_dir}\n" + repr_str += f"tiles = {self.tiles}\n" + repr_str += f"datafiles = {self.datafiles}\n" + repr_str += f"static_files = {self.static_files}\n" + return repr_str + + class VariableObj(): time_coder = xr.coders.CFDatetimeCoder(use_cftime=True) - def __init__(self, datafile: str, tiles: list = None, input_dir: str | Path = Path("./")): + def __init__(self, fileobj: FileObj = None): - self.tiles = tiles + self.fileobj = fileobj self.dims = DimsObj() - self.input_dir = Path(input_dir) - - if tiles is None: - self.datafiles = {"tile1": input_dir/Path(datafile).with_suffix(".nc")} - else: - self.datafiles = {tile: input_dir/Path(datafile.with_suffix(tile + ".nc")) for tile in tiles} - self.att = SimpleNamespace( - missing = False, - fill_value = False, - offset = False, - scale_factor = False - ) + self.missing = False, + self.fill_value = False, + self.offset = False, + self.scale_factor = False - self.area = SimpleNamespace( - static_files = None, - cell_measures = None, - area = None - ) + self.static_files = {} def get_attributes(self, variable): - infile = list(self.datafiles.values())[0] + tile = self.fileobj.tiles[0] + infile = self.fileobj.datafiles[tile] if not infile.exists: raise RuntimeError("file does not exist") @@ -88,63 +105,42 @@ def get_attributes(self, variable): raise RuntimeError("variable not found") dataarray = dataset[variable] - self.dims.get(dataarray.coords) + self.dims.get(dataarray) attributes = dataarray.attrs - self.att.missing = attributes.get("missing_value") - self.att.fill_value = attributes.get("_FillValue") - self.att.offset = attributes.get("add_offset") - self.att.scale_factor = attributes.get("scale_factor") + self.missing = attributes.get("missing_value") + self.fill_value = attributes.get("_FillValue") + self.offset = attributes.get("add_offset") + self.scale_factor = attributes.get("scale_factor") if "area" in str(attributes.get("cell_method")): - self._get_static_files(dataset, variable) - - - def slice(self, tile: str = "tile1", timepoint: int|bool = False, klevel: int|bool = False): - - #python, values above 0 are true... - - with xr.open_dataset(self.datafiles[tile], decode_cf=False) as dataset: - - slice_dict = {} - - if klevel and self.dims.z.here: slice_dict[self.dims.z.name] = klevel - if timepoint and self.dims.time.here: slice_dict[self.dims.time.name] = timepoint + cell_measures = str(attributes.get("cell_measures")) + if "area:" in cell_measures: + self.static_files = self.fileobj.static_files[cell_measures.split()[1]] - data = dataset[self.variable].isel(slice_dict) - #missing value mask - mask = False - if self.att.missing_value: - mask = data != self.att.missing_value + # def slice(self, tile: str = "tile1", timepoint: int|bool = False, klevel: int|bool = False): - if self.att.offset: data += self.att.offset - if self.att.scale_factor: data *= self.att.scale_factor + # #python, values above 0 are true... - #zero out missing values so it doens't contribute to remapping - if mask: data = data.where(mask, 0.0, data) + # with xr.open_dataset(self.datafiles[tile], decode_cf=False) as dataset: - return data + # slice_dict = {} + # if klevel and self.dims.z.here: slice_dict[self.dims.z.name] = klevel + # if timepoint and self.dims.time.here: slice_dict[self.dims.time.name] = timepoint - def _get_static_files(self, dataset, variable): + # data = dataset[self.variable].isel(slice_dict) - # get cell measures (area type) - cell_measures = str(dataset[variable].attrs.get("cell_measures")) - if "area:" in cell_measures: - cell_measures = cell_measures.split()[1] - else: - return + # #missing value mask + # mask = False + # if self.missing_value: + # mask = data != self.missing_value - # soil_area: 00010101.land_static.nc cell_area: 00010101.land_static_sg.nc - global_att = str(dataset.attrs.get("associated_files")) - static_file = {keyval.strip(":"): areaval for keyval, areaval in pairwise(global_att.split())}.get(cell_measures) - if static_file is None: - raise RuntimeError("cannot find static file") + # if self.offset: data += self.offset + # if self.scale_factor: data *= self.scale_factor - if self.tiles is not None: - self.area.static_files = {tile: self.input_dir/static_file.replace(".nc", tile+".nc") for tile in self.tiles} - else: - self.area.static_files = {"tile1": self.input_dir/static_file} + # #zero out missing values so it doens't contribute to remapping + # if mask: data = data.where(mask, 0.0, data) - self.area.cell_measures = cell_measures \ No newline at end of file + # return data diff --git a/tests/remap/test_dataobj.py b/tests/remap/test_dataobj.py index 587e5f7..7918b49 100644 --- a/tests/remap/test_dataobj.py +++ b/tests/remap/test_dataobj.py @@ -5,7 +5,8 @@ import fmsgridtools def write_files(outfile): - grid_xt, grid_yt = 100, 50 + + grid_xt, grid_yt, nk, ntimes = 100, 50, 10, 4 variable_xt = xr.DataArray( np.arange(grid_xt, dtype=np.float64), @@ -27,6 +28,39 @@ def write_files(outfile): } ) + variable_yt = xr.DataArray( + np.arange(grid_yt, dtype=np.float64), + dims = ["grid_yt"], + attrs={ + "units": "degrees_N", + "long_name": "made up T-cell latitude", + "axis": "Y" + } + ) + + variable_pfull = xr.DataArray( + np.arange(nk, dtype=np.float64), + dims = ["pfull"], + attrs = { + "units": "mb", + "long_name": "ref full pressure level", + "axis": "Z", + "positive": "down" + } + ) + + variable_time = xr.DataArray( + np.arange(ntimes, dtype=np.float64), + dims = ["time"], + attrs= { + "units": "days since 0001-01-01 00:00:00", + "long_name": "time", + "axis": "T", + "calendar_type": "NOLEAP", + "calendar": "noleap" + } + ) + variable1 = xr.DataArray( np.ones((grid_yt, grid_xt), dtype=np.float64), dims = ["grid_yt", "grid_xt"], @@ -60,13 +94,15 @@ def write_files(outfile): data_vars = { "grid_xt": variable_xt, "grid_yt": variable_yt, + "pfull": variable_pfull, + "time": variable_time, "variable1": variable1, "variable2": variable2 } dataset = xr.Dataset( data_vars=data_vars, - attrs={"associated_files": "pemberley_area: pemberley.nc"} + attrs={"associated_files": "pemberley_area: pemberley.nc longbourn_area: longbourn.nc"} ) dataset.to_netcdf(outfile) @@ -75,10 +111,13 @@ def write_files(outfile): def test_dataobj(): - dataset = write_files("test.nc") + dataset = write_files("test.tile1.nc") + + fileobj = fmsgridtools.FileObj("test") - variable1 = fmsgridtools.VariableObj(datafile="test") + variable1 = fmsgridtools.VariableObj(fileobj=fileobj) variable1.get_attributes(variable="variable1") + print(variable1.dims) #check dimensions for (name, dim) in [("grid_xt", variable1.dims.x), ("grid_yt", variable1.dims.y)]: @@ -91,13 +130,10 @@ def test_dataobj(): assert dim.size is None attributes = dataset["variable1"].attrs - assert variable1.attrs.missing == attributes.get("missing_value") - assert variable1.attrs.fill_value == attributes.get("_FillValue") - assert variable1.attrs.offset == attributes.get("add_offset") - assert variable1.attrs.scale_factor == attributes.get("scale_factor") - - print(variable1.area) - + assert variable1.missing == attributes.get("missing_value") + assert variable1.fill_value == attributes.get("_FillValue") + assert variable1.offset == attributes.get("add_offset") + assert variable1.scale_factor == attributes.get("scale_factor") if __name__ == "__main__": From 127d252bf5aead265909987f4c60b3e08406d066 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 26 Dec 2025 13:50:37 -0500 Subject: [PATCH 23/27] slice --- fmsgridtools/remap/variableobj.py | 64 ++++++++++++++----------- tests/remap/test_dataobj.py | 77 +++++++++++++++++++------------ 2 files changed, 84 insertions(+), 57 deletions(-) diff --git a/fmsgridtools/remap/variableobj.py b/fmsgridtools/remap/variableobj.py index 5f4d0e0..bb35732 100644 --- a/fmsgridtools/remap/variableobj.py +++ b/fmsgridtools/remap/variableobj.py @@ -78,20 +78,22 @@ class VariableObj(): time_coder = xr.coders.CFDatetimeCoder(use_cftime=True) - def __init__(self, fileobj: FileObj = None): + def __init__(self, variable: str, fileobj: FileObj = None): - self.fileobj = fileobj + self.variable = variable + self.fileobj = fileobj self.dims = DimsObj() - self.missing = False, - self.fill_value = False, - self.offset = False, - self.scale_factor = False + self.missing_value = None, + self.fill_value = None, + self.offset = None, + self.scale_factor = None - self.static_files = {} + self.static_files: {} = None + self.data = None - def get_attributes(self, variable): + def get_attributes(self): tile = self.fileobj.tiles[0] infile = self.fileobj.datafiles[tile] @@ -101,14 +103,14 @@ def get_attributes(self, variable): with xr.open_dataset(infile, decode_cf=False) as dataset: - if variable not in dataset: + if self.variable not in dataset: raise RuntimeError("variable not found") - dataarray = dataset[variable] + dataarray = dataset[self.variable] self.dims.get(dataarray) attributes = dataarray.attrs - self.missing = attributes.get("missing_value") + self.missing_value = attributes.get("missing_value") self.fill_value = attributes.get("_FillValue") self.offset = attributes.get("add_offset") self.scale_factor = attributes.get("scale_factor") @@ -119,28 +121,34 @@ def get_attributes(self, variable): self.static_files = self.fileobj.static_files[cell_measures.split()[1]] - # def slice(self, tile: str = "tile1", timepoint: int|bool = False, klevel: int|bool = False): - - # #python, values above 0 are true... + def slice(self, tile: str = "tile1", timepoint: int = -99, klevel: int = -99): - # with xr.open_dataset(self.datafiles[tile], decode_cf=False) as dataset: + #python, values above 0 are true... - # slice_dict = {} + with xr.open_dataset(self.fileobj.datafiles[tile], decode_cf=False) as dataset: - # if klevel and self.dims.z.here: slice_dict[self.dims.z.name] = klevel - # if timepoint and self.dims.time.here: slice_dict[self.dims.time.name] = timepoint + slice_dict = {} - # data = dataset[self.variable].isel(slice_dict) + if klevel>-1 and self.dims.z.here: slice_dict[self.dims.z.name] = klevel + if timepoint>-1 and self.dims.time.here: slice_dict[self.dims.time.name] = timepoint - # #missing value mask - # mask = False - # if self.missing_value: - # mask = data != self.missing_value + self.data = dataset[self.variable].isel(slice_dict).values + return self.data - # if self.offset: data += self.offset - # if self.scale_factor: data *= self.scale_factor + + def prepare_data(self): - # #zero out missing values so it doens't contribute to remapping - # if mask: data = data.where(mask, 0.0, data) + #missing value mask + missing_value_mask = None + if self.missing_value is not None: + missing_value_mask = self.data == self.missing_value + + if self.offset: self.data += self.offset + if self.scale_factor: self.data *= self.scale_factor + + #zero out missing values so it doens't contribute to remapping + if missing_value_mask is not None: + self.data = xr.where(missing_value_mask, 0.0, self.data) - # return data + return self.data + diff --git a/tests/remap/test_dataobj.py b/tests/remap/test_dataobj.py index 7918b49..812cfa5 100644 --- a/tests/remap/test_dataobj.py +++ b/tests/remap/test_dataobj.py @@ -4,12 +4,26 @@ import fmsgridtools +nx, ny, nk, ntimes = 10, 8, 3, 4 + +answers = np.zeros((ntimes, nk, ny, nx), dtype=np.float64) +for itime in range(ntimes): + for k in range(nk): + for j in range(ny): + start = itime*1000+k*100+j*10 + answers[itime, k, j, :] = np.arange(start, start+nx, dtype=np.float64) + +missing_ijkl = [(1,1,1,1), (2,2,2,2)] +for (itime, k, j, i) in missing_ijkl: + answers[itime,k,j,i] = missing_value + + def write_files(outfile): - grid_xt, grid_yt, nk, ntimes = 100, 50, 10, 4 + variable_xt = xr.DataArray( - np.arange(grid_xt, dtype=np.float64), + np.arange(nx, dtype=np.float64), dims = ["grid_xt"], attrs={ "units": "degrees_E", @@ -19,17 +33,7 @@ def write_files(outfile): ) variable_yt = xr.DataArray( - np.arange(grid_yt, dtype=np.float64), - dims = ["grid_yt"], - attrs={ - "units": "degrees_N", - "long_name": "made up T-cell latitude", - "axis": "Y" - } - ) - - variable_yt = xr.DataArray( - np.arange(grid_yt, dtype=np.float64), + np.arange(ny, dtype=np.float64), dims = ["grid_yt"], attrs={ "units": "degrees_N", @@ -62,15 +66,15 @@ def write_files(outfile): ) variable1 = xr.DataArray( - np.ones((grid_yt, grid_xt), dtype=np.float64), - dims = ["grid_yt", "grid_xt"], + answers, + dims = ["time", "pfull", "grid_yt", "grid_xt"], attrs = { "_FillValue": -np.float64(-100), "long_name": "test variable 1", "units": "kg m-2", - "missing": np.float64(45.12), - "add_offset": np.float64(-24.326), - "scale_factor": np.float64(-55), + "missing_value": missing_value, + "add_offset": np.float64(0.0), + "scale_factor": np.float64(0.0), "cell_method": "area:mean time:mean", "cell_measures": "area: pemberley_area", "standard_name": "Mr. Darcy" @@ -78,7 +82,7 @@ def write_files(outfile): ) variable2 = xr.DataArray( - np.zeros((grid_yt, grid_xt), dtype=np.float32), + np.zeros((ny, nx), dtype=np.float32), dims = ["grid_yt", "grid_xt"], attrs = { "_FillValue": False, @@ -105,7 +109,7 @@ def write_files(outfile): attrs={"associated_files": "pemberley_area: pemberley.nc longbourn_area: longbourn.nc"} ) - dataset.to_netcdf(outfile) + dataset.to_netcdf(outfile, unlimited_dims="time") return dataset @@ -115,26 +119,41 @@ def test_dataobj(): fileobj = fmsgridtools.FileObj("test") - variable1 = fmsgridtools.VariableObj(fileobj=fileobj) - variable1.get_attributes(variable="variable1") - print(variable1.dims) + variable1 = fmsgridtools.VariableObj(variable="variable1", fileobj=fileobj) #check dimensions - for (name, dim) in [("grid_xt", variable1.dims.x), ("grid_yt", variable1.dims.y)]: + variable1.get_attributes() + dims = [ + ("grid_xt", variable1.dims.x), + ("grid_yt", variable1.dims.y), + ("time", variable1.dims.time), + ("pfull", variable1.dims.z) + ] + for (name, dim) in dims: assert dim.name == name assert dim.here assert dim.size == dataset[name].size - for dim in [variable1.dims.time, variable1.dims.z]: - assert dim.name is None - assert not dim.here - assert dim.size is None attributes = dataset["variable1"].attrs - assert variable1.missing == attributes.get("missing_value") + assert variable1.missing_value == attributes.get("missing_value") assert variable1.fill_value == attributes.get("_FillValue") assert variable1.offset == attributes.get("add_offset") assert variable1.scale_factor == attributes.get("scale_factor") + #slice + reconstruct_data = np.zeros((ntimes, nk, ny, nx), dtype=np.float64) + for itime in range(ntimes): + for k in range(nk): + sliced_data = variable1.slice(timepoint=itime, klevel=k) + np.testing.assert_equal(sliced_data, answers[itime, k, :, :]) + reconstruct_data[itime, k, :, :] = variable1.prepare_data() + + #check missing values + for (itime, k, j, i) in missing_ijkl: + assert reconstruct_data[itime,k,j,i] == np.float64(0.0) + + + if __name__ == "__main__": test_dataobj() From 5dbdf55e598533214abf5a5b70e87897d866094a Mon Sep 17 00:00:00 2001 From: mlee03 Date: Fri, 26 Dec 2025 14:13:57 -0500 Subject: [PATCH 24/27] saving --- fmsgridtools/remap/conservative.py | 30 ++++++++++++++++++---- fmsgridtools/remap/variableobj.py | 41 +++++++++++++++++++++++++----- tests/remap/test_dataobj.py | 2 -- 3 files changed, 59 insertions(+), 14 deletions(-) diff --git a/fmsgridtools/remap/conservative.py b/fmsgridtools/remap/conservative.py index 97e829f..88f6f07 100644 --- a/fmsgridtools/remap/conservative.py +++ b/fmsgridtools/remap/conservative.py @@ -1,5 +1,6 @@ import numpy as np +from fmsgridtools.remap.variableobj import FileObj, VariableObj from fmsgridtools.shared.xgridobj import XGridObj import pyfms @@ -18,23 +19,42 @@ def remap(input_dir: str = "./", tbounds: list = None, order: int = 1, check_conserve: bool = False, - gpu: bool = False): + gpu: bool = False): xgrid = XGridObj(input_dir=input_dir, src_mosaic_file = src_mosaic, tgt_mosaicfile=tgt_mosaic) tgt_tiles = list(xgrid.tgt.grid.keys()) - nsrc_tiles = list(xgrid.src.mosaic.ntiles) + src_tiles = list(xgrid.src.grid.keys()) pyfms.fms.init(ndomain=len(tgt_tiles)) for tgt_tile in tgt_tiles: - pyfms.horiz_interp.init(nsrc_tiles) - + pyfms.horiz_interp.init(len(src_tiles)) + + #get xgrid xgrid.domain = pyfms.mpp_domains.define_domains([0, xgrid.tgt.grid.nx-1, 0, xgrid.tgt.grid.ny-1]) xgrid.set_target_tile(tgt_tile) xgrid.get_interp() - + if input_file is None: return + + fileobj = FileObj(datafile=input_file, input_dir=input_dir, tiles=src_tiles) + + for var in fileobj.variables: + + variable = VariableObj(var, fileobj) + variable.get_attributes() + + times = list(range(variable.dims.ntimes)) if variable.time.here else [None] + klevels = list(range(variable.klevels.nz)) if variable.z.here else [None] + + for itime in times: + for klevel in klevels: + for src_tile in src_tiles: + input_data = variable.slice(tile=src_tile, timepoint=itime, klevel=klevel, prepare_data=True) + #HEREHEREHERE + pyfms.fms.horiz_interp(xgrid.interps[src_tile].interp_id, ) + diff --git a/fmsgridtools/remap/variableobj.py b/fmsgridtools/remap/variableobj.py index bb35732..a79b46b 100644 --- a/fmsgridtools/remap/variableobj.py +++ b/fmsgridtools/remap/variableobj.py @@ -48,14 +48,26 @@ def __repr__(self): class FileObj(): - def __init__(self, datafile: str, tiles: list = ["tile1"], input_dir: str = "./"): + skip = [ + "geolon_c", "geolat_c", "geolon_u", "geolat_u", "geolon_v", "geolat_v", + "FA_X", "FA_Y", "FI_X", "FI_Y", "IX_TRANS", "IY_TRANS", + "UI", "VI", "UO", "VO", "wet_c", "wet_v", "wet_u", + "dxCu", "dyCu", "dxCv", "dyCv", "Coriolis", + "areacello_cu", "areacello_cv", "areacello_bu", + "average_T1", "average_T2", "average_DT", "time_bnds" + ] + + + def __init__(self, datafile: str, tiles: list = ["tile1"], input_dir: str = "./", variables: list = None): self.input_dir = str(input_dir) self.tiles = tiles self.datafiles = {tile: Path(input_dir)/Path(datafile + f".{tile}.nc") for tile in tiles} self.static_files = {} + self.variables = variables with xr.open_dataset(self.datafiles[self.tiles[0]], decode_cf=False) as dataset: + # soil_area: 00010101.land_static.nc cell_area: 00010101.land_static_sg.nc associated_files = dataset.attrs.get("associated_files") if associated_files is not None: @@ -65,6 +77,16 @@ def __init__(self, datafile: str, tiles: list = ["tile1"], input_dir: str = "./" tile: Path(input_dir)/stringsplit[i+1].replace(".nc", f".{tile}.nc") for tile in self.tiles } + #get list of variables + if self.variables is None: + self.variables = [] + for variable in dataset: + if variable in dataset: + print(f"skipping {variable}") + else: + self.variables.append(variable) + + def __repr__(self): repr_str = "\n" repr_str += f"input_dir = {self.input_dir}\n" @@ -121,7 +143,7 @@ def get_attributes(self): self.static_files = self.fileobj.static_files[cell_measures.split()[1]] - def slice(self, tile: str = "tile1", timepoint: int = -99, klevel: int = -99): + def slice(self, tile: str = "tile1", timepoint: int = None, klevel: int = None, prepare_data: bool = False): #python, values above 0 are true... @@ -129,10 +151,15 @@ def slice(self, tile: str = "tile1", timepoint: int = -99, klevel: int = -99): slice_dict = {} - if klevel>-1 and self.dims.z.here: slice_dict[self.dims.z.name] = klevel - if timepoint>-1 and self.dims.time.here: slice_dict[self.dims.time.name] = timepoint + if klevel is not None and self.dims.z.here: + slice_dict[self.dims.z.name] = klevel + if timepoint is not None and self.dims.time.here: + slice_dict[self.dims.time.name] = timepoint + + self.data = dataset[self.variable].isel(slice_dict).values - self.data = dataset[self.variable].isel(slice_dict).values + if prepare_data: + self.prepare_data() return self.data @@ -143,8 +170,8 @@ def prepare_data(self): if self.missing_value is not None: missing_value_mask = self.data == self.missing_value - if self.offset: self.data += self.offset - if self.scale_factor: self.data *= self.scale_factor + if self.offset is not None: self.data += self.offset + if self.scale_factor is not None: self.data *= self.scale_factor #zero out missing values so it doens't contribute to remapping if missing_value_mask is not None: diff --git a/tests/remap/test_dataobj.py b/tests/remap/test_dataobj.py index 812cfa5..f0900b9 100644 --- a/tests/remap/test_dataobj.py +++ b/tests/remap/test_dataobj.py @@ -153,7 +153,5 @@ def test_dataobj(): assert reconstruct_data[itime,k,j,i] == np.float64(0.0) - - if __name__ == "__main__": test_dataobj() From 0d8b7b048d95a38f89df51f4455e8079294e7c5c Mon Sep 17 00:00:00 2001 From: mlee03 Date: Mon, 29 Dec 2025 13:10:16 -0500 Subject: [PATCH 25/27] save --- fmsgridtools/remap/conservative.py | 19 ++-- fmsgridtools/remap/variableobj.py | 21 ++-- fmsgridtools/shared/xgridobj.py | 10 +- pyFMS | 2 +- tests/__init__.py | 0 tests/remap/__init__.py | 0 tests/remap/test_dataobj.py | 135 ++++--------------------- tests/remap/write_files.py | 157 +++++++++++++++++++++++++++++ tests/shared/test_xgridobj.py | 4 +- 9 files changed, 203 insertions(+), 145 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/remap/__init__.py create mode 100644 tests/remap/write_files.py diff --git a/fmsgridtools/remap/conservative.py b/fmsgridtools/remap/conservative.py index 88f6f07..52e5811 100644 --- a/fmsgridtools/remap/conservative.py +++ b/fmsgridtools/remap/conservative.py @@ -7,8 +7,8 @@ def remap(input_dir: str = "./", - src_mosaic: str = None, - tgt_mosaic: str = None, + src_mosaicfile: str = None, + tgt_mosaicfile: str = None, input_file: str = None, output_dir: str = "./", output_file: str = None, @@ -21,7 +21,7 @@ def remap(input_dir: str = "./", check_conserve: bool = False, gpu: bool = False): - xgrid = XGridObj(input_dir=input_dir, src_mosaic_file = src_mosaic, tgt_mosaicfile=tgt_mosaic) + xgrid = XGridObj(input_dir=input_dir, src_mosaicfile=src_mosaicfile, tgt_mosaicfile=tgt_mosaicfile) tgt_tiles = list(xgrid.tgt.grid.keys()) src_tiles = list(xgrid.src.grid.keys()) @@ -32,7 +32,7 @@ def remap(input_dir: str = "./", pyfms.horiz_interp.init(len(src_tiles)) #get xgrid - xgrid.domain = pyfms.mpp_domains.define_domains([0, xgrid.tgt.grid.nx-1, 0, xgrid.tgt.grid.ny-1]) + xgrid.domain = pyfms.mpp_domains.define_domains([0, xgrid.tgt.grid[tgt_tile].nx-1, 0, xgrid.tgt.grid[tgt_tile].ny-1]) xgrid.set_target_tile(tgt_tile) xgrid.get_interp() @@ -45,18 +45,17 @@ def remap(input_dir: str = "./", variable = VariableObj(var, fileobj) variable.get_attributes() - times = list(range(variable.dims.ntimes)) if variable.time.here else [None] - klevels = list(range(variable.klevels.nz)) if variable.z.here else [None] + times = list(range(variable.dims.time.size)) if variable.dims.time.here else [None] + klevels = list(range(variable.dims.z.size)) if variable.dims.z.here else [None] for itime in times: for klevel in klevels: for src_tile in src_tiles: input_data = variable.slice(tile=src_tile, timepoint=itime, klevel=klevel, prepare_data=True) - #HEREHEREHERE - pyfms.fms.horiz_interp(xgrid.interps[src_tile].interp_id, ) - - + data = pyfms.horiz_interp.interp(xgrid.interps[src_tile].interp_id, input_data, convert_cf_order=False) + print(data) + exit() diff --git a/fmsgridtools/remap/variableobj.py b/fmsgridtools/remap/variableobj.py index a79b46b..68d106f 100644 --- a/fmsgridtools/remap/variableobj.py +++ b/fmsgridtools/remap/variableobj.py @@ -48,7 +48,7 @@ def __repr__(self): class FileObj(): - skip = [ + skip_variables = [ "geolon_c", "geolat_c", "geolon_u", "geolat_u", "geolon_v", "geolat_v", "FA_X", "FA_Y", "FI_X", "FI_Y", "IX_TRANS", "IY_TRANS", "UI", "VI", "UO", "VO", "wet_c", "wet_v", "wet_u", @@ -57,7 +57,7 @@ class FileObj(): "average_T1", "average_T2", "average_DT", "time_bnds" ] - + def __init__(self, datafile: str, tiles: list = ["tile1"], input_dir: str = "./", variables: list = None): self.input_dir = str(input_dir) @@ -81,11 +81,11 @@ def __init__(self, datafile: str, tiles: list = ["tile1"], input_dir: str = "./" if self.variables is None: self.variables = [] for variable in dataset: - if variable in dataset: + if variable in self.skip_variables: print(f"skipping {variable}") else: self.variables.append(variable) - + def __repr__(self): repr_str = "\n" @@ -103,7 +103,7 @@ class VariableObj(): def __init__(self, variable: str, fileobj: FileObj = None): self.variable = variable - self.fileobj = fileobj + self.fileobj = fileobj self.dims = DimsObj() self.missing_value = None, @@ -145,7 +145,6 @@ def get_attributes(self): def slice(self, tile: str = "tile1", timepoint: int = None, klevel: int = None, prepare_data: bool = False): - #python, values above 0 are true... with xr.open_dataset(self.fileobj.datafiles[tile], decode_cf=False) as dataset: @@ -159,23 +158,23 @@ def slice(self, tile: str = "tile1", timepoint: int = None, klevel: int = None, self.data = dataset[self.variable].isel(slice_dict).values if prepare_data: - self.prepare_data() + self.prepare_data() return self.data - + def prepare_data(self): #missing value mask missing_value_mask = None if self.missing_value is not None: missing_value_mask = self.data == self.missing_value - + if self.offset is not None: self.data += self.offset if self.scale_factor is not None: self.data *= self.scale_factor - + #zero out missing values so it doens't contribute to remapping if missing_value_mask is not None: self.data = xr.where(missing_value_mask, 0.0, self.data) return self.data - + diff --git a/fmsgridtools/shared/xgridobj.py b/fmsgridtools/shared/xgridobj.py index 2285a34..415bfdf 100755 --- a/fmsgridtools/shared/xgridobj.py +++ b/fmsgridtools/shared/xgridobj.py @@ -130,9 +130,9 @@ def __init__( def set_target_tile(self, tgt_tile: str = "tile1"): self.tgt_tile = tgt_tile - self.tgt.grid = self.tgt.grid[tgt_tile] + self.tgt.grid = self.tgt.grid[tgt_tile] + - def read( self, input_dir: Path | str = None, @@ -177,7 +177,7 @@ def read( def gather(self): """ - gathers xgrid + gathers xgrid """ isc, jsc = self.tgt.domain.isc, self.tgt.domain.jsc @@ -210,7 +210,7 @@ def write(self, output_dir: Path | str = "./", outfile: str | Path = Path("remap datasets = [] for tile1, src_tile in enumerate(global_interps): - + interp = global_interps[src_tile] dataset = xr.Dataset() @@ -243,7 +243,7 @@ def write(self, output_dir: Path | str = "./", outfile: str | Path = Path("remap pyfms.mpp.sync() - def get_interp(self, on_gpu) -> dict: + def get_interp(self, on_gpu: bool = False) -> dict: """ call fms to compute xgrid diff --git a/pyFMS b/pyFMS index dd9948b..12c87d8 160000 --- a/pyFMS +++ b/pyFMS @@ -1 +1 @@ -Subproject commit dd9948b5d2d61cfea1704667ab21a7ccb4c88664 +Subproject commit 12c87d8532e1fcaa627d3163b3c8ec0fc115c4ba diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/remap/__init__.py b/tests/remap/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/remap/test_dataobj.py b/tests/remap/test_dataobj.py index f0900b9..34c0f76 100644 --- a/tests/remap/test_dataobj.py +++ b/tests/remap/test_dataobj.py @@ -3,119 +3,11 @@ import xarray as xr import fmsgridtools - -nx, ny, nk, ntimes = 10, 8, 3, 4 - -answers = np.zeros((ntimes, nk, ny, nx), dtype=np.float64) -for itime in range(ntimes): - for k in range(nk): - for j in range(ny): - start = itime*1000+k*100+j*10 - answers[itime, k, j, :] = np.arange(start, start+nx, dtype=np.float64) - -missing_ijkl = [(1,1,1,1), (2,2,2,2)] -for (itime, k, j, i) in missing_ijkl: - answers[itime,k,j,i] = missing_value - - -def write_files(outfile): - - - - variable_xt = xr.DataArray( - np.arange(nx, dtype=np.float64), - dims = ["grid_xt"], - attrs={ - "units": "degrees_E", - "long_name": "made up T-cell longitude", - "axis": "X" - } - ) - - variable_yt = xr.DataArray( - np.arange(ny, dtype=np.float64), - dims = ["grid_yt"], - attrs={ - "units": "degrees_N", - "long_name": "made up T-cell latitude", - "axis": "Y" - } - ) - - variable_pfull = xr.DataArray( - np.arange(nk, dtype=np.float64), - dims = ["pfull"], - attrs = { - "units": "mb", - "long_name": "ref full pressure level", - "axis": "Z", - "positive": "down" - } - ) - - variable_time = xr.DataArray( - np.arange(ntimes, dtype=np.float64), - dims = ["time"], - attrs= { - "units": "days since 0001-01-01 00:00:00", - "long_name": "time", - "axis": "T", - "calendar_type": "NOLEAP", - "calendar": "noleap" - } - ) - - variable1 = xr.DataArray( - answers, - dims = ["time", "pfull", "grid_yt", "grid_xt"], - attrs = { - "_FillValue": -np.float64(-100), - "long_name": "test variable 1", - "units": "kg m-2", - "missing_value": missing_value, - "add_offset": np.float64(0.0), - "scale_factor": np.float64(0.0), - "cell_method": "area:mean time:mean", - "cell_measures": "area: pemberley_area", - "standard_name": "Mr. Darcy" - } - ) - - variable2 = xr.DataArray( - np.zeros((ny, nx), dtype=np.float32), - dims = ["grid_yt", "grid_xt"], - attrs = { - "_FillValue": False, - "long_name": "test variable 2", - "cell_method": "area:mean", - "standard_name": "Missing cell_measures, expected to fail", - "missing_value": -np.float32(-1.0), - "scale_factor": -np.float32(-0.05), - "offset": np.float32(0.0) - } - ) - - data_vars = { - "grid_xt": variable_xt, - "grid_yt": variable_yt, - "pfull": variable_pfull, - "time": variable_time, - "variable1": variable1, - "variable2": variable2 - } - - dataset = xr.Dataset( - data_vars=data_vars, - attrs={"associated_files": "pemberley_area: pemberley.nc longbourn_area: longbourn.nc"} - ) - - dataset.to_netcdf(outfile, unlimited_dims="time") - return dataset - +import write_files def test_dataobj(): - dataset = write_files("test.tile1.nc") + dataset = write_files.write_data("test.tile1.nc") fileobj = fmsgridtools.FileObj("test") @@ -141,17 +33,28 @@ def test_dataobj(): assert variable1.scale_factor == attributes.get("scale_factor") #slice - reconstruct_data = np.zeros((ntimes, nk, ny, nx), dtype=np.float64) - for itime in range(ntimes): - for k in range(nk): + src = write_files.src + reconstruct_data = np.zeros((src.ntimes, src.nk, src.ny, src.nx), dtype=np.float64) + for itime in range(src.ntimes): + for k in range(src.nk): sliced_data = variable1.slice(timepoint=itime, klevel=k) - np.testing.assert_equal(sliced_data, answers[itime, k, :, :]) + np.testing.assert_equal(sliced_data, write_files.data[itime, k, :, :]) reconstruct_data[itime, k, :, :] = variable1.prepare_data() #check missing values - for (itime, k, j, i) in missing_ijkl: + for (itime, k, j, i) in write_files.missing_ijkl: assert reconstruct_data[itime,k,j,i] == np.float64(0.0) +def test_remap(): + + write_files.write_mosaics() + dataset = write_files.write_data("test") + + src_mosaic = write_files.src.mosaicfile + tgt_mosaic = write_files.tgt.mosaicfile + + fmsgridtools.remap.conservative.remap(src_mosaicfile=src_mosaic, tgt_mosaicfile=tgt_mosaic, input_file="test") + if __name__ == "__main__": - test_dataobj() + test_remap() diff --git a/tests/remap/write_files.py b/tests/remap/write_files.py new file mode 100644 index 0000000..016719c --- /dev/null +++ b/tests/remap/write_files.py @@ -0,0 +1,157 @@ +import numpy as np +from types import SimpleNamespace + +import xarray as xr +import fmsgridtools + +src = SimpleNamespace( + ntiles=6, + nx=12, + ny=24, + nk=3, + ntimes=4, + dxy=1.0, + mosaicfile="src_mosaic.nc", + gridfile="src_grid" +) +tgt = SimpleNamespace( + ntiles=1, + nx=src.nx, #* 2, + ny=src.ny, #* 2, + nk=3, + ntimes=4, + dxy=src.dxy, # / 2.0, + mosaicfile="tgt_mosaic.nc", + gridfile="tgt_grid" +) + +nxgrid_per_tile = tgt.nx//2 * tgt.ny//2 +nxgrid = nxgrid_per_tile * 6 + +ntimes, nk, ny, nx = src.ntimes, src.nk, src.ny//2, src.nx//2 + +missing_value = -np.float64(-99.) +missing_ijkl = [(1,1,1,1), (2,2,2,2)] + +data_list = [] +for itile in range(1, src.ntiles+1): + data = np.zeros((ntimes, nk, ny, nx), dtype=np.float64) + for itime in range(ntimes): + for k in range(nk): + for j in range(ny): + start = itile*10000 + itime*1000+k*100+j*10 + data[itime, k, j, :] = np.arange(start, start+nx, dtype=np.float64) + for (itime, k, j, i) in missing_ijkl: + data[itime,k,j,i] = missing_value + data_list.append(data) + + +def write_mosaics(): + + """ + make mosaic and grid files for testing + """ + + # write mosaic + for parent in [src, tgt]: + fmsgridtools.MosaicObj( + gridtiles=[f"tile{i}" for i in range(1, parent.ntiles+1)], + gridfiles=[f"{parent.gridfile}.tile{i}.nc" for i in range(1, parent.ntiles+1)] + ).write(parent.mosaicfile) + + # write grid + for parent in [src, tgt]: + for itile in range(1, parent.ntiles+1): + x1 = np.array([i*parent.dxy for i in range(parent.nx+1)], dtype=np.float64) + y1 = np.array([j*parent.dxy for j in range(parent.ny+1)], dtype=np.float64) + x, y = np.meshgrid(x1, y1) + fmsgridtools.GridObj(x=x, y=y).write(parent.gridfile + f".tile{itile}.nc") + + +def write_data(outfile): + + for itile in range(src.ntiles): + + variable_xt = xr.DataArray( + np.arange(nx, dtype=np.float64), + dims = ["grid_xt"], + attrs={ + "units": "degrees_E", + "long_name": "made up T-cell longitude", + "axis": "X" + } + ) + variable_yt = xr.DataArray( + np.arange(ny, dtype=np.float64), + dims = ["grid_yt"], + attrs={ + "units": "degrees_N", + "long_name": "made up T-cell latitude", + "axis": "Y" + } + ) + variable_pfull = xr.DataArray( + np.arange(nk, dtype=np.float64), + dims = ["pfull"], + attrs = { + "units": "mb", + "long_name": "ref full pressure level", + "axis": "Z", + "positive": "down" + } + ) + variable_time = xr.DataArray( + np.arange(ntimes, dtype=np.float64), + dims = ["time"], + attrs= { + "units": "days since 0001-01-01 00:00:00", + "long_name": "time", + "axis": "T", + "calendar_type": "NOLEAP", + "calendar": "noleap" + } + ) + variable1 = xr.DataArray( + data_list[itile], + dims = ["time", "pfull", "grid_yt", "grid_xt"], + attrs = { + "_FillValue": -np.float64(-100), + "long_name": "test variable 1", + "units": "kg m-2", + "missing_value": missing_value, + "add_offset": np.float64(0.0), + "scale_factor": np.float64(1.0), + "cell_method": "area:mean time:mean", + "cell_measures": "area: pemberley_area", + "standard_name": "Mr. Darcy" + } + ) + variable2 = xr.DataArray( + np.zeros((ny, nx), dtype=np.float32), + dims = ["grid_yt", "grid_xt"], + attrs = { + "_FillValue": False, + "long_name": "test variable 2", + "cell_method": "area:mean", + "standard_name": "Missing cell_measures, expected to fail", + "missing_value": -np.float32(-1.0), + "scale_factor": -np.float32(-0.05), + "offset": np.float32(0.0) + } + ) + data_vars = { + "grid_xt": variable_xt, + "grid_yt": variable_yt, + "pfull": variable_pfull, + "time": variable_time, + "variable1": variable1, + "variable2": variable2 + } + + dataset = xr.Dataset( + data_vars=data_vars, + attrs={"associated_files": "pemberley_area: pemberley.nc longbourn_area: longbourn.nc"} + ) + + dataset.to_netcdf(outfile+f".tile{itile+1}.nc", unlimited_dims="time") + return dataset diff --git a/tests/shared/test_xgridobj.py b/tests/shared/test_xgridobj.py index 02aa58c..3267566 100755 --- a/tests/shared/test_xgridobj.py +++ b/tests/shared/test_xgridobj.py @@ -80,14 +80,14 @@ def xgridobj_test(on_gpu: bool = False): xgrid.set_target_tile("tile1") xgrid.get_interp(on_gpu=on_gpu) xgrid.write(outfile=remapfile) - + del xgrid xgrid = fmsgridtools.XGridObj( src_mosaicfile=src.mosaicfile, tgt_mosaicfile=tgt.mosaicfile, remapfile=remapfile, - tgt_tile = "tile1") + tgt_tile = "tile1") xgrid.read(remapfile=remapfile) #answers From d28813559bce44d838a1c34a276d370fa97c24c2 Mon Sep 17 00:00:00 2001 From: mlee03 Date: Mon, 29 Dec 2025 15:15:46 -0500 Subject: [PATCH 26/27] save --- fmsgridtools/remap/conservative.py | 28 +++-- fmsgridtools/remap/variableobj.py | 175 +++++++++++++++++++---------- 2 files changed, 137 insertions(+), 66 deletions(-) diff --git a/fmsgridtools/remap/conservative.py b/fmsgridtools/remap/conservative.py index 52e5811..f31c492 100644 --- a/fmsgridtools/remap/conservative.py +++ b/fmsgridtools/remap/conservative.py @@ -31,31 +31,41 @@ def remap(input_dir: str = "./", pyfms.horiz_interp.init(len(src_tiles)) - #get xgrid - xgrid.domain = pyfms.mpp_domains.define_domains([0, xgrid.tgt.grid[tgt_tile].nx-1, 0, xgrid.tgt.grid[tgt_tile].ny-1]) + nx_tgt, ny_tgt = xgrid.tgt.grid[tgt_tile].nx, xgrid.tgt.grid[tgt_tile].ny + + #get xgrid, make sure mpi works cause i don't think it will + xgrid.domain = pyfms.mpp_domains.define_domains([0, nx_tgt-1, 0, ny_tgt-1]) xgrid.set_target_tile(tgt_tile) xgrid.get_interp() if input_file is None: return + if output_file is None: output_file = input_file + if output_dir is None: output_dir = input_dir - fileobj = FileObj(datafile=input_file, input_dir=input_dir, tiles=src_tiles) + src_fileobj = SrcFileObj(datafile=input_file, input_dir=input_dir, tiles=src_tiles) + tgt_fileobj = TgtFileObj(datafile=output_file, output_dir=output_dir, tile=tgt_tile) - for var in fileobj.variables: + for var in src_fileobj.variables: - variable = VariableObj(var, fileobj) + variable = VariableObj(var, src_fileobj) variable.get_attributes() times = list(range(variable.dims.time.size)) if variable.dims.time.here else [None] klevels = list(range(variable.dims.z.size)) if variable.dims.z.here else [None] + #automatically sets up time and vertical levels + #find a way to not send in nx and ny + variable.init_tgt_data(self, nx=nx_tgt, ny=ny_tgt) + for itime in times: for klevel in klevels: + data_slice = np.zeros(ny_tgt, nx_tgt, dtype=variable.dtype) for src_tile in src_tiles: input_data = variable.slice(tile=src_tile, timepoint=itime, klevel=klevel, prepare_data=True) - data = pyfms.horiz_interp.interp(xgrid.interps[src_tile].interp_id, input_data, convert_cf_order=False) - print(data) - + data_slice += pyfms.horiz_interp.interp(xgrid.interps[src_tile].interp_id, input_data, convert_cf_order=False) - exit() + #gather if parallel, for now, no + variable.set_tgt_data(data_slice, timepoint=itime, klevel=klevel) + tgt_fileobj.set_dataarray(var, data_dict=variable.tgt_dict) diff --git a/fmsgridtools/remap/variableobj.py b/fmsgridtools/remap/variableobj.py index 68d106f..fa90f52 100644 --- a/fmsgridtools/remap/variableobj.py +++ b/fmsgridtools/remap/variableobj.py @@ -2,6 +2,7 @@ from pathlib import Path from types import SimpleNamespace +import numpy as np import xarray as xr class DimObj(): @@ -46,7 +47,7 @@ def __repr__(self): return repr_str -class FileObj(): +class SrcFileObj(): skip_variables = [ "geolon_c", "geolat_c", "geolon_u", "geolat_u", "geolon_v", "geolat_v", @@ -62,104 +63,131 @@ def __init__(self, datafile: str, tiles: list = ["tile1"], input_dir: str = "./" self.input_dir = str(input_dir) self.tiles = tiles - self.datafiles = {tile: Path(input_dir)/Path(datafile + f".{tile}.nc") for tile in tiles} + self.src_datafiles = {tile: Path(input_dir)/Path(datafile + f".{tile}.nc") for tile in tiles} self.static_files = {} self.variables = variables - with xr.open_dataset(self.datafiles[self.tiles[0]], decode_cf=False) as dataset: - - # soil_area: 00010101.land_static.nc cell_area: 00010101.land_static_sg.nc - associated_files = dataset.attrs.get("associated_files") - if associated_files is not None: - stringsplit = associated_files.replace(":", " ").split() - for i in range(0, len(stringsplit), 2): - self.static_files[stringsplit[i]] = { - tile: Path(input_dir)/stringsplit[i+1].replace(".nc", f".{tile}.nc") for tile in self.tiles - } - - #get list of variables - if self.variables is None: - self.variables = [] - for variable in dataset: - if variable in self.skip_variables: - print(f"skipping {variable}") - else: - self.variables.append(variable) - + self.src_datasets = { + tile: xr.open_dataset(self.src_datafiles[tile], decode_cf=False) for tile in self.tiles + } + + dataset = self.src_datasets[tiles[0]] + # soil_area: 00010101.land_static.nc cell_area: 00010101.land_static_sg.nc + associated_files = dataset.attrs.get("associated_files") + if associated_files is not None: + stringsplit = associated_files.replace(":", " ").split() + for i in range(0, len(stringsplit), 2): + self.static_files[stringsplit[i]] = { + tile: Path(input_dir)/stringsplit[i+1].replace(".nc", f".{tile}.nc") for tile in self.tiles + } + + #get list of variables + if self.variables is None: + self.variables = [] + for variable in dataset: + if variable in self.skip_variables: + print(f"skipping {variable}") + else: + self.variables.append(variable) def __repr__(self): repr_str = "\n" repr_str += f"input_dir = {self.input_dir}\n" repr_str += f"tiles = {self.tiles}\n" - repr_str += f"datafiles = {self.datafiles}\n" + repr_str += f"datafiles = {self.src_datafiles}\n" repr_str += f"static_files = {self.static_files}\n" return repr_str + +class TgtFileObj(): + + def __init__(self, datadict: dict = {}): + self.dataarrays = datadict + self.name = None + self.dims = None + + def set_dataarray(self, variable: str = None, data_dict: dict = None, dataarray: xr.DataArray = None): + + if data_dict is not None: + self.dataarrays["variable" ] = xr.DataArray.from_dict(data_dict) + elif dataarray is not None: + self.dataarrays["variable"] = dataarray + else: + raise RuntimeError("must provide something") + + def set_coords(self, grid): + pass + + + class VariableObj(): time_coder = xr.coders.CFDatetimeCoder(use_cftime=True) - def __init__(self, variable: str, fileobj: FileObj = None): + def __init__(self, variable: str, fileobj: SrcFileObj = None): self.variable = variable self.fileobj = fileobj self.dims = DimsObj() + self.dtype = None, self.missing_value = None, self.fill_value = None, self.offset = None, self.scale_factor = None self.static_files: {} = None - self.data = None + self.src_data = None + + self.tgt_dict = { + "data": npt.NDArray = None, + "dims": list = None, + "attrs": dict = None + } def get_attributes(self): tile = self.fileobj.tiles[0] - infile = self.fileobj.datafiles[tile] + dataset = self.fileobj.datasets[tile] - if not infile.exists: - raise RuntimeError("file does not exist") + if self.variable not in dataset: + raise RuntimeError("variable not found") - with xr.open_dataset(infile, decode_cf=False) as dataset: + dataarray = dataset[self.variable] - if self.variable not in dataset: - raise RuntimeError("variable not found") - dataarray = dataset[self.variable] + self.dims.get(dataarray) + self.dtype=dataarray.dtype - self.dims.get(dataarray) + attributes = dataarray.attrs + self.missing_value = attributes.get("missing_value") + self.fill_value = attributes.get("_FillValue") + self.offset = attributes.get("add_offset") + self.scale_factor = attributes.get("scale_factor") - attributes = dataarray.attrs - self.missing_value = attributes.get("missing_value") - self.fill_value = attributes.get("_FillValue") - self.offset = attributes.get("add_offset") - self.scale_factor = attributes.get("scale_factor") - - if "area" in str(attributes.get("cell_method")): - cell_measures = str(attributes.get("cell_measures")) - if "area:" in cell_measures: - self.static_files = self.fileobj.static_files[cell_measures.split()[1]] + if "area" in str(attributes.get("cell_method")): + cell_measures = str(attributes.get("cell_measures")) + if "area:" in cell_measures: + self.static_files = self.fileobj.static_files[cell_measures.split()[1]] def slice(self, tile: str = "tile1", timepoint: int = None, klevel: int = None, prepare_data: bool = False): - with xr.open_dataset(self.fileobj.datafiles[tile], decode_cf=False) as dataset: - - slice_dict = {} + dataset = self.fileobj.datasets[tile] - if klevel is not None and self.dims.z.here: - slice_dict[self.dims.z.name] = klevel - if timepoint is not None and self.dims.time.here: - slice_dict[self.dims.time.name] = timepoint + slice_dict = {} + if klevel is not None and self.dims.z.here: + slice_dict[self.dims.z.name] = klevel + if timepoint is not None and self.dims.time.here: + slice_dict[self.dims.time.name] = timepoint - self.data = dataset[self.variable].isel(slice_dict).values + self.src_data = dataset[self.variable].isel(slice_dict).values if prepare_data: self.prepare_data() - return self.data + return self.src_data def prepare_data(self): @@ -167,14 +195,47 @@ def prepare_data(self): #missing value mask missing_value_mask = None if self.missing_value is not None: - missing_value_mask = self.data == self.missing_value + missing_value_mask = self.src_data == self.missing_value - if self.offset is not None: self.data += self.offset - if self.scale_factor is not None: self.data *= self.scale_factor + if self.offset is not None: self.src_data += self.offset + if self.scale_factor is not None: self.src_data *= self.scale_factor #zero out missing values so it doens't contribute to remapping if missing_value_mask is not None: - self.data = xr.where(missing_value_mask, 0.0, self.data) + self.src_data = xr.where(missing_value_mask, 0.0, self.src_data) + + return self.src_data + + + def init_tgt_dict(self, nx: int, ny: int, ntimes: int = None, nz: int = None): + + dims = [] + if self.dims.time.here: + if ntimes is None: ntimes = self.dims.time.size + dims.append(ntimes) + + if self.dims.z.here: + if nz is None: nz = self.dims.z.size + dims.append(nx) + + dims += [self.dims.y.size, self.dims.x.size] + + self.tgt_dict["dims"] = dims + self.tgt_dict["attrs"]= self.src_datasets[self.fileobj.tiles[0]][self.variable].attrs + self.tgt_dict["data"] = np.zeros((ntimes, nz, ny, nx), dtype=self.dtype) + + return self.tgt_dict + + + def set_tgt_data(self, data: npt.NDArray, timepoint: int = None, klevel: int = None): - return self.data + if timepoint is None and klevel is None: + self.tgt_dict["data"] = data + elif timepoint is not None and klevel is not None: + self.tgt_dict["data"][timepoint, klevel, :, :] = data + elif timepoint is not None: + self.tgt_dict["data"][timepoint, :, :] = data + elif klevel is not None: + self.tgt_dict["data"][klevel, :, :] = data + return self.tgt_dict From 5923bc8589d021f4257181c2ba434856a7a66fad Mon Sep 17 00:00:00 2001 From: mlee03 Date: Tue, 30 Dec 2025 14:35:31 -0500 Subject: [PATCH 27/27] save --- fmsgridtools/__init__.py | 2 +- fmsgridtools/remap/conservative.py | 22 ++-- fmsgridtools/remap/variableobj.py | 181 ++++++++++++++++------------- tests/remap/test_dataobj.py | 29 ++--- 4 files changed, 122 insertions(+), 112 deletions(-) diff --git a/fmsgridtools/__init__.py b/fmsgridtools/__init__.py index 980043b..6c42b8e 100644 --- a/fmsgridtools/__init__.py +++ b/fmsgridtools/__init__.py @@ -5,7 +5,7 @@ from .make_mosaic import coupler_mosaic from .make_topog import make_topog from .remap import remap -from .remap.variableobj import VariableObj, FileObj, DimObj, DimsObj +from .remap.variableobj import VariableObj, SrcFileObj, TgtFileObj, DimObj, FileDimsObj from .shared.gridtools_utils import check_file_is_there, get_provenance_attrs from .shared.gridobj import GridObj from .shared.mosaicobj import MosaicObj diff --git a/fmsgridtools/remap/conservative.py b/fmsgridtools/remap/conservative.py index f31c492..7168f31 100644 --- a/fmsgridtools/remap/conservative.py +++ b/fmsgridtools/remap/conservative.py @@ -1,6 +1,6 @@ import numpy as np -from fmsgridtools.remap.variableobj import FileObj, VariableObj +from fmsgridtools.remap.variableobj import SrcFileObj, TgtFileObj, VariableObj from fmsgridtools.shared.xgridobj import XGridObj import pyfms @@ -43,29 +43,23 @@ def remap(input_dir: str = "./", if output_dir is None: output_dir = input_dir src_fileobj = SrcFileObj(datafile=input_file, input_dir=input_dir, tiles=src_tiles) - tgt_fileobj = TgtFileObj(datafile=output_file, output_dir=output_dir, tile=tgt_tile) + tgt_fileobj = TgtFileObj(datafile=output_file, output_dir=output_dir, nx=nx_tgt, ny=ny_tgt) for var in src_fileobj.variables: variable = VariableObj(var, src_fileobj) - variable.get_attributes() + variable.init() - times = list(range(variable.dims.time.size)) if variable.dims.time.here else [None] - klevels = list(range(variable.dims.z.size)) if variable.dims.z.here else [None] + for itime in variable.timeslist: + for klevel in variable.zlist: - #automatically sets up time and vertical levels - #find a way to not send in nx and ny - variable.init_tgt_data(self, nx=nx_tgt, ny=ny_tgt) + input_data = variable.slice(tile=src_tiles[0], timepoint=itime, klevel=klevel, prepare_data=True) + data_slice = pyfms.horiz_interp.interp(xgrid.interps[src_tile].interp_id, input_data, convert_cf_order=False) - for itime in times: - for klevel in klevels: - data_slice = np.zeros(ny_tgt, nx_tgt, dtype=variable.dtype) - for src_tile in src_tiles: + for src_tile in src_tiles[1:]: input_data = variable.slice(tile=src_tile, timepoint=itime, klevel=klevel, prepare_data=True) data_slice += pyfms.horiz_interp.interp(xgrid.interps[src_tile].interp_id, input_data, convert_cf_order=False) #gather if parallel, for now, no variable.set_tgt_data(data_slice, timepoint=itime, klevel=klevel) - tgt_fileobj.set_dataarray(var, data_dict=variable.tgt_dict) - diff --git a/fmsgridtools/remap/variableobj.py b/fmsgridtools/remap/variableobj.py index fa90f52..a1fe40c 100644 --- a/fmsgridtools/remap/variableobj.py +++ b/fmsgridtools/remap/variableobj.py @@ -3,18 +3,21 @@ from types import SimpleNamespace import numpy as np +import numpy.typing as npt import xarray as xr class DimObj(): - def __init__(self, axis: str, name: str = None, here: bool = False, size: int = None): + def __init__(self, axis: str, name: str = None, here: bool = False, size: int = None, attr: dict = None, coord_values: npt.NDArray = None): self.axis = axis self.name = name self.size = size + self.attr = None + self.coords = coord_values self.here = False -class DimsObj(): +class FileDimsObj(): def __init__(self): self.x = DimObj("X") @@ -22,23 +25,39 @@ def __init__(self): self.z = DimObj("Z") self.time = DimObj("T") - def get(self, da: xr.DataArray): + + def init(self, dataset: xr.Dataset): """ get dimensions """ - dims_dict = {dim.axis: dim for dim in [self.x, self.y, self.z, self.time]} + dims_dict = {dim.axis: dim for dim in [self.time, self.z, self.y, self.x]} - for name, coord in da.coords.items(): + for name, coord in dataset.coords.items(): try: dim = dims_dict[coord.attrs.get("axis")] dim.name = name + dim.attr = coord.attrs dim.size = coord.size + dim.coords = coord.values dim.here = True except: print(f"{name} dim not found") + def get_z(self, dims_list: list): + if self.z.name in dims_list: + return self.z + else: + return DimObj(axis="Z", name=self.z.name) + + def get_time(self, dims_list: list): + if self.time.name in dims_list: + return self.time + else: + return DimObj(axis="T", name=self.time.name) + + def __repr__(self): repr_str ="\n" @@ -63,15 +82,20 @@ def __init__(self, datafile: str, tiles: list = ["tile1"], input_dir: str = "./" self.input_dir = str(input_dir) self.tiles = tiles + self.dims = FileDimsObj() self.src_datafiles = {tile: Path(input_dir)/Path(datafile + f".{tile}.nc") for tile in tiles} self.static_files = {} - self.variables = variables + self.variables = [] if variables is None else variables - self.src_datasets = { + self.datasets = { tile: xr.open_dataset(self.src_datafiles[tile], decode_cf=False) for tile in self.tiles } - dataset = self.src_datasets[tiles[0]] + dataset = self.datasets[tiles[0]] + + # get dims + self.dims.init(dataset) + # soil_area: 00010101.land_static.nc cell_area: 00010101.land_static_sg.nc associated_files = dataset.attrs.get("associated_files") if associated_files is not None: @@ -81,9 +105,8 @@ def __init__(self, datafile: str, tiles: list = ["tile1"], input_dir: str = "./" tile: Path(input_dir)/stringsplit[i+1].replace(".nc", f".{tile}.nc") for tile in self.tiles } - #get list of variables - if self.variables is None: - self.variables = [] + #get list of variables if not specified + if not bool(self.variables): for variable in dataset: if variable in self.skip_variables: print(f"skipping {variable}") @@ -102,19 +125,31 @@ def __repr__(self): class TgtFileObj(): - def __init__(self, datadict: dict = {}): - self.dataarrays = datadict - self.name = None - self.dims = None + def __init__(self, datafile: str|Path, nx: int, ny: int, output_dir: str = "./"): - def set_dataarray(self, variable: str = None, data_dict: dict = None, dataarray: xr.DataArray = None): + self.output_dir = str(output_dir) + self.datafile = datafile + self.nx = nx + self.ny = ny + + self.datadict = {} + + + def init_variable(self, variable: str, dtype, dims_list: list = None, attributes: dict = None, z_size: int = None, time_size: int = None): + + data_info = {} + if dims_list is not None: data_info["dims"] = dims_list + if attributes is not None: data_info["attrs"] = attributes + + shape = [] + if time_size is not None: shape.append(time_size) + if z_size is not None: shape.append(z_size) + shape += [self.ny, self.nx] + + data_info["data"] = np.zeros(shape, dtype=dtype) + + self.datadict[variable] = data_info - if data_dict is not None: - self.dataarrays["variable" ] = xr.DataArray.from_dict(data_dict) - elif dataarray is not None: - self.dataarrays["variable"] = dataarray - else: - raise RuntimeError("must provide something") def set_coords(self, grid): pass @@ -125,39 +160,35 @@ class VariableObj(): time_coder = xr.coders.CFDatetimeCoder(use_cftime=True) - def __init__(self, variable: str, fileobj: SrcFileObj = None): + def __init__(self, variable: str, src_fileobj: SrcFileObj = None, tgt_fileobj: TgtFileObj = None): self.variable = variable - self.fileobj = fileobj - self.dims = DimsObj() + self.src_fileobj = src_fileobj + self.tgt_fileobj = tgt_fileobj - self.dtype = None, - self.missing_value = None, - self.fill_value = None, - self.offset = None, + self.z = None + self.time = None + + self.dtype = None + self.missing_value = None + self.fill_value = None + self.offset = None self.scale_factor = None self.static_files: {} = None - self.src_data = None - - self.tgt_dict = { - "data": npt.NDArray = None, - "dims": list = None, - "attrs": dict = None - } - - - def get_attributes(self): - - tile = self.fileobj.tiles[0] - dataset = self.fileobj.datasets[tile] + + tile = self.src_fileobj.tiles[0] + dataset = self.src_fileobj.datasets[tile] if self.variable not in dataset: raise RuntimeError("variable not found") dataarray = dataset[self.variable] - self.dims.get(dataarray) + dims_list = dataarray.dims + self.z = self.src_fileobj.dims.get_z(dims_list) + self.time = self.src_fileobj.dims.get_time(dims_list) + self.dtype=dataarray.dtype attributes = dataarray.attrs @@ -169,73 +200,57 @@ def get_attributes(self): if "area" in str(attributes.get("cell_method")): cell_measures = str(attributes.get("cell_measures")) if "area:" in cell_measures: - self.static_files = self.fileobj.static_files[cell_measures.split()[1]] + self.static_files = self.src_fileobj.static_files[cell_measures.split()[1]] + + self.tgt_fileobj.init_variable(self.variable, self.dtype, dims_list=dims_list)#, attributes=attributes, z_size=self.z.size, time_size=self.time.size) def slice(self, tile: str = "tile1", timepoint: int = None, klevel: int = None, prepare_data: bool = False): - dataset = self.fileobj.datasets[tile] + dataset = self.src_fileobj.datasets[tile] slice_dict = {} - if klevel is not None and self.dims.z.here: - slice_dict[self.dims.z.name] = klevel - if timepoint is not None and self.dims.time.here: - slice_dict[self.dims.time.name] = timepoint + if klevel is not None: + slice_dict[self.z.name] = klevel + if timepoint is not None: + slice_dict[self.time.name] = timepoint - self.src_data = dataset[self.variable].isel(slice_dict).values + src_data = dataset[self.variable].isel(slice_dict).values if prepare_data: - self.prepare_data() - return self.src_data + src_data = self.prepare_data(src_data) + return src_data - def prepare_data(self): + def prepare_data(self, src_data: npt.NDArray = None): #missing value mask missing_value_mask = None if self.missing_value is not None: - missing_value_mask = self.src_data == self.missing_value + missing_value_mask = src_data == self.missing_value - if self.offset is not None: self.src_data += self.offset - if self.scale_factor is not None: self.src_data *= self.scale_factor + if self.offset is not None: src_data += self.offset + if self.scale_factor is not None: src_data *= self.scale_factor #zero out missing values so it doens't contribute to remapping if missing_value_mask is not None: - self.src_data = xr.where(missing_value_mask, 0.0, self.src_data) + src_data = xr.where(missing_value_mask, 0.0, src_data) - return self.src_data - - - def init_tgt_dict(self, nx: int, ny: int, ntimes: int = None, nz: int = None): - - dims = [] - if self.dims.time.here: - if ntimes is None: ntimes = self.dims.time.size - dims.append(ntimes) - - if self.dims.z.here: - if nz is None: nz = self.dims.z.size - dims.append(nx) - - dims += [self.dims.y.size, self.dims.x.size] - - self.tgt_dict["dims"] = dims - self.tgt_dict["attrs"]= self.src_datasets[self.fileobj.tiles[0]][self.variable].attrs - self.tgt_dict["data"] = np.zeros((ntimes, nz, ny, nx), dtype=self.dtype) - - return self.tgt_dict + return src_data def set_tgt_data(self, data: npt.NDArray, timepoint: int = None, klevel: int = None): + tgt_data = self.tgt_fileobj.datadict[self.variable]["data"] + if timepoint is None and klevel is None: - self.tgt_dict["data"] = data + tgt_data = data elif timepoint is not None and klevel is not None: - self.tgt_dict["data"][timepoint, klevel, :, :] = data + tgt_data[timepoint, klevel, :, :] = data elif timepoint is not None: - self.tgt_dict["data"][timepoint, :, :] = data + tgt_data[timepoint, :, :] = data elif klevel is not None: - self.tgt_dict["data"][klevel, :, :] = data + tgt_data[klevel, :, :] = data - return self.tgt_dict + self.tgt_fileobj.datadict[self.variable]["data"] = tgt_data \ No newline at end of file diff --git a/tests/remap/test_dataobj.py b/tests/remap/test_dataobj.py index 34c0f76..0494a41 100644 --- a/tests/remap/test_dataobj.py +++ b/tests/remap/test_dataobj.py @@ -7,20 +7,21 @@ def test_dataobj(): - dataset = write_files.write_data("test.tile1.nc") - - fileobj = fmsgridtools.FileObj("test") + dataset = write_files.write_data("test") - variable1 = fmsgridtools.VariableObj(variable="variable1", fileobj=fileobj) + srcfileobj = fmsgridtools.SrcFileObj("test", tiles=["tile1"]) + tgtfileobj = fmsgridtools.TgtFileObj("test", nx=10, ny=14) + + variable1 = fmsgridtools.VariableObj(variable="variable1", src_fileobj=srcfileobj, tgt_fileobj=tgtfileobj) #check dimensions - variable1.get_attributes() dims = [ - ("grid_xt", variable1.dims.x), - ("grid_yt", variable1.dims.y), - ("time", variable1.dims.time), - ("pfull", variable1.dims.z) + ("grid_xt", variable1.src_fileobj.dims.x), + ("grid_yt", variable1.src_fileobj.dims.y), + ("time", variable1.time), + ("pfull", variable1.z) ] + for (name, dim) in dims: assert dim.name == name assert dim.here @@ -34,12 +35,12 @@ def test_dataobj(): #slice src = write_files.src - reconstruct_data = np.zeros((src.ntimes, src.nk, src.ny, src.nx), dtype=np.float64) + reconstruct_data = np.zeros((src.ntimes, src.nk, src.ny//2, src.nx//2), dtype=np.float64) for itime in range(src.ntimes): for k in range(src.nk): - sliced_data = variable1.slice(timepoint=itime, klevel=k) - np.testing.assert_equal(sliced_data, write_files.data[itime, k, :, :]) - reconstruct_data[itime, k, :, :] = variable1.prepare_data() + sliced_data = variable1.slice(tile="tile1", timepoint=itime, klevel=k) + np.testing.assert_equal(sliced_data, write_files.data_list[0][itime, k, :, :]) + reconstruct_data[itime, k, :, :] = variable1.prepare_data(sliced_data) #check missing values for (itime, k, j, i) in write_files.missing_ijkl: @@ -57,4 +58,4 @@ def test_remap(): if __name__ == "__main__": - test_remap() + test_dataobj()