diff --git a/abses/core/experiment.py b/abses/core/experiment.py index fe362f8d..49157dd8 100644 --- a/abses/core/experiment.py +++ b/abses/core/experiment.py @@ -42,7 +42,7 @@ import numpy as np from hydra import compose, initialize from hydra.core.global_hydra import GlobalHydra -from hydra.core.hydra_config import HydraConfig +from hydra.core.hydra_config import HydraConf, HydraConfig from joblib import Parallel, delayed from omegaconf import DictConfig, OmegaConf from tqdm.auto import tqdm @@ -208,8 +208,8 @@ def new( return cls(model_cls, cfg, **kwargs) @property - def hydra_config(self) -> DictConfig: - """Hydra config.""" + def hydra_config(self) -> HydraConf: + """Hydra runtime configuration object (HydraConf).""" if self.is_hydra_job(): return HydraConfig.get() raise RuntimeError("Experiment is not running in Hydra.") diff --git a/abses/core/model.py b/abses/core/model.py index 9ee006e2..3206e597 100644 --- a/abses/core/model.py +++ b/abses/core/model.py @@ -181,9 +181,12 @@ def steps(self, steps: int) -> None: Parameters: steps: Number of steps. If > 0, automatically advances time. """ + delta = steps - getattr(self, "_steps", 0) + if not isinstance(delta, int): + raise TypeError(f"Steps must be an integer, got {type(steps)}") + if delta > 0: + self.time.go(delta) self._steps = steps - if steps > 0: - self.time.go(steps) def __deepcopy__(self, memo: dict) -> "MainModel": """Prevent deep copying of model. @@ -437,7 +440,6 @@ def run_model( run_times = 0 self.do_each("setup", order=order) while self.running is True: - self.time.go() self.do_each("step", order=order) run_times += 1 if steps is not None and run_times >= steps: diff --git a/abses/core/time_driver.py b/abses/core/time_driver.py index 1a18f7fb..ef7ee9d3 100644 --- a/abses/core/time_driver.py +++ b/abses/core/time_driver.py @@ -123,6 +123,8 @@ def __init__(self, model: MainModelProtocol): super().__init__(model=model, name="time") self._history: Deque[DateTime] = deque() self._history_ticks: Deque[int] = deque() + # End time can only be DateTime | int | None at runtime + self._end_dt: DateTime | int | None = None self._parse_ticking_mode(set(self.params.keys())) self._parse_time_settings(self.params) self._dt = self.start_dt @@ -332,18 +334,25 @@ def end_at(self) -> Optional[DateOrTick]: @end_at.setter def end_at(self, dt: Optional[DateOrTick | str]) -> None: """Set the end time.""" - is_tick = is_positive_int(dt, raise_error=False) - if dt is None or is_tick: - self._end_dt = dt - return - # If the end time is a string / datetime object. - if isinstance(dt, str): - dt = parse_datetime(dt) - if isinstance(dt, datetime) and not isinstance(dt, DateTime): - dt = pendulum.instance(dt).replace(tzinfo=None) - elif isinstance(dt, DateTime): - dt = dt.replace(tzinfo=None) - self._end_dt = dt + # Normalize into DateTime | int | None + normalized: DateTime | int | None + if dt is None: + normalized = None + elif is_positive_int(dt, raise_error=False): + normalized = int(dt) # mypy: dt is int-like here + else: + # If the end time is a string / datetime object. + if isinstance(dt, str): + tmp = parse_datetime(dt) + else: + tmp = dt + if isinstance(tmp, datetime) and not isinstance(tmp, DateTime): + normalized = pendulum.instance(tmp).replace(tzinfo=None) + elif isinstance(tmp, DateTime): + normalized = tmp.replace(tzinfo=None) + else: + raise TypeError(f"Wrong type for end time: {type(dt)}.") + self._end_dt = normalized @property def dt(self) -> DateTime: diff --git a/abses/space/nature.py b/abses/space/nature.py index 3e3053b0..67bfb9c6 100644 --- a/abses/space/nature.py +++ b/abses/space/nature.py @@ -82,7 +82,8 @@ def create_module( if major_layer is True: self.major_layer = module self.convert_crs(module, write_crs=write_crs) - self.add_layer(module) + if module not in self.layers: + self.add_layer(module) return module def convert_crs(self, module: PatchModule, write_crs: bool = True): diff --git a/abses/space/patch.py b/abses/space/patch.py index 4c1086b3..6d877f40 100644 --- a/abses/space/patch.py +++ b/abses/space/patch.py @@ -53,7 +53,6 @@ MainModelProtocol, Number, Raster, - T, ) @@ -370,8 +369,11 @@ def cells(self) -> List[List[PatchCell]]: return self._cells @functools.cached_property - def array_cells(self) -> NDArray[T]: - """Array type of the `PatchCell` stored in this module.""" + def array_cells(self) -> np.ndarray: + """Array of cells stored in this module. + + Returns a 2D numpy array with dtype ``object`` containing ``PatchCell``. + """ return np.flipud(np.array(self._cells, dtype=object).T) @property @@ -439,7 +441,7 @@ def dynamic_var( self, attr_name: str, dtype: Literal["numpy", "xarray"] = "numpy", - ) -> np.ndarray: + ) -> np.ndarray | xr.DataArray: """Update and get dynamic variable. Parameters: @@ -652,10 +654,12 @@ def apply(self, ufunc: Callable[..., Any], *args: Any, **kwargs: Any) -> np.ndar return np.vectorize(func)(self.array_cells) def coord_iter(self) -> Iterator[tuple[Coordinate, PatchCell]]: - """ - An iterator that returns coordinates as well as cell contents. - """ - return np.ndenumerate(self.array_cells) + """Iterate over coordinates and cells with precise typing.""" + arr = self.array_cells + height, width = arr.shape + for i in range(height): + for j in range(width): + yield (i, j), arr[i, j] def _add_attribute( self, diff --git a/abses/utils/args.py b/abses/utils/args.py index b8034525..8faf57bd 100644 --- a/abses/utils/args.py +++ b/abses/utils/args.py @@ -5,7 +5,7 @@ # GitHub : https://github.com/SongshGeo # Website: https://cv.songshgeo.com/ -from typing import Any, Dict +from typing import Any, Dict, cast from omegaconf import DictConfig, OmegaConf @@ -45,4 +45,4 @@ def merge_parameters(parameters: DictConfig, **kwargs: Dict[str, Any]) -> DictCo if isinstance(merged, DictConfig): OmegaConf.set_struct(merged, False) - return merged + return cast(DictConfig, merged) diff --git a/abses/viz/customize_marker.py b/abses/viz/customize_marker.py index f61356d6..5907de1c 100644 --- a/abses/viz/customize_marker.py +++ b/abses/viz/customize_marker.py @@ -24,9 +24,14 @@ # https://stackoverflow.com/questions/52902086/how-to-use-font-awesome-symbol-as-marker-in-matplotlib def get_marker(symbol: str) -> Path: - """Returns Font Awesome marker.""" + """Return a Matplotlib Path for a given marker symbol. + + If ``symbol`` is a built-in matplotlib marker (e.g. "o", "x"), convert it to + a ``Path`` via ``MarkerStyle(symbol).get_path()``. Otherwise, treat it as a + Font Awesome icon name and build a path from the configured font file. + """ if symbol in markers.MarkerStyle.markers: - return symbol + return markers.MarkerStyle(symbol).get_path() symbol = fa.icons.get(symbol) if not symbol: raise KeyError(f"Could not find {symbol} in marker style.")