Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions abses/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import numpy as np
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.hydra_config import HydraConf, HydraConfig
from joblib import Parallel, delayed
from omegaconf import DictConfig, OmegaConf
from tqdm.auto import tqdm
Expand Down Expand Up @@ -208,8 +208,8 @@ def new(
return cls(model_cls, cfg, **kwargs)

@property
def hydra_config(self) -> DictConfig:
"""Hydra config."""
def hydra_config(self) -> HydraConf:
"""Hydra runtime configuration object (HydraConf)."""
if self.is_hydra_job():
return HydraConfig.get()
raise RuntimeError("Experiment is not running in Hydra.")
Expand Down
8 changes: 5 additions & 3 deletions abses/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,12 @@ def steps(self, steps: int) -> None:
Parameters:
steps: Number of steps. If > 0, automatically advances time.
"""
delta = steps - getattr(self, "_steps", 0)
if not isinstance(delta, int):
raise TypeError(f"Steps must be an integer, got {type(steps)}")
if delta > 0:
self.time.go(delta)
self._steps = steps
if steps > 0:
self.time.go(steps)

def __deepcopy__(self, memo: dict) -> "MainModel":
"""Prevent deep copying of model.
Expand Down Expand Up @@ -437,7 +440,6 @@ def run_model(
run_times = 0
self.do_each("setup", order=order)
while self.running is True:
self.time.go()
self.do_each("step", order=order)
run_times += 1
if steps is not None and run_times >= steps:
Expand Down
33 changes: 21 additions & 12 deletions abses/core/time_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def __init__(self, model: MainModelProtocol):
super().__init__(model=model, name="time")
self._history: Deque[DateTime] = deque()
self._history_ticks: Deque[int] = deque()
# End time can only be DateTime | int | None at runtime
self._end_dt: DateTime | int | None = None
self._parse_ticking_mode(set(self.params.keys()))
self._parse_time_settings(self.params)
self._dt = self.start_dt
Expand Down Expand Up @@ -332,18 +334,25 @@ def end_at(self) -> Optional[DateOrTick]:
@end_at.setter
def end_at(self, dt: Optional[DateOrTick | str]) -> None:
"""Set the end time."""
is_tick = is_positive_int(dt, raise_error=False)
if dt is None or is_tick:
self._end_dt = dt
return
# If the end time is a string / datetime object.
if isinstance(dt, str):
dt = parse_datetime(dt)
if isinstance(dt, datetime) and not isinstance(dt, DateTime):
dt = pendulum.instance(dt).replace(tzinfo=None)
elif isinstance(dt, DateTime):
dt = dt.replace(tzinfo=None)
self._end_dt = dt
# Normalize into DateTime | int | None
normalized: DateTime | int | None
if dt is None:
normalized = None
elif is_positive_int(dt, raise_error=False):
normalized = int(dt) # mypy: dt is int-like here
else:
# If the end time is a string / datetime object.
if isinstance(dt, str):
tmp = parse_datetime(dt)
else:
tmp = dt
if isinstance(tmp, datetime) and not isinstance(tmp, DateTime):
normalized = pendulum.instance(tmp).replace(tzinfo=None)
elif isinstance(tmp, DateTime):
normalized = tmp.replace(tzinfo=None)
else:
raise TypeError(f"Wrong type for end time: {type(dt)}.")
self._end_dt = normalized

@property
def dt(self) -> DateTime:
Expand Down
3 changes: 2 additions & 1 deletion abses/space/nature.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def create_module(
if major_layer is True:
self.major_layer = module
self.convert_crs(module, write_crs=write_crs)
self.add_layer(module)
if module not in self.layers:
self.add_layer(module)
return module

def convert_crs(self, module: PatchModule, write_crs: bool = True):
Expand Down
20 changes: 12 additions & 8 deletions abses/space/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
MainModelProtocol,
Number,
Raster,
T,
)


Expand Down Expand Up @@ -370,8 +369,11 @@ def cells(self) -> List[List[PatchCell]]:
return self._cells

@functools.cached_property
def array_cells(self) -> NDArray[T]:
"""Array type of the `PatchCell` stored in this module."""
def array_cells(self) -> np.ndarray:
"""Array of cells stored in this module.

Returns a 2D numpy array with dtype ``object`` containing ``PatchCell``.
"""
return np.flipud(np.array(self._cells, dtype=object).T)

@property
Expand Down Expand Up @@ -439,7 +441,7 @@ def dynamic_var(
self,
attr_name: str,
dtype: Literal["numpy", "xarray"] = "numpy",
) -> np.ndarray:
) -> np.ndarray | xr.DataArray:
"""Update and get dynamic variable.

Parameters:
Expand Down Expand Up @@ -652,10 +654,12 @@ def apply(self, ufunc: Callable[..., Any], *args: Any, **kwargs: Any) -> np.ndar
return np.vectorize(func)(self.array_cells)

def coord_iter(self) -> Iterator[tuple[Coordinate, PatchCell]]:
"""
An iterator that returns coordinates as well as cell contents.
"""
return np.ndenumerate(self.array_cells)
"""Iterate over coordinates and cells with precise typing."""
arr = self.array_cells
height, width = arr.shape
for i in range(height):
for j in range(width):
yield (i, j), arr[i, j]

def _add_attribute(
self,
Expand Down
4 changes: 2 additions & 2 deletions abses/utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# GitHub : https://github.com/SongshGeo
# Website: https://cv.songshgeo.com/

from typing import Any, Dict
from typing import Any, Dict, cast

from omegaconf import DictConfig, OmegaConf

Expand Down Expand Up @@ -45,4 +45,4 @@ def merge_parameters(parameters: DictConfig, **kwargs: Dict[str, Any]) -> DictCo
if isinstance(merged, DictConfig):
OmegaConf.set_struct(merged, False)

return merged
return cast(DictConfig, merged)
9 changes: 7 additions & 2 deletions abses/viz/customize_marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@

# https://stackoverflow.com/questions/52902086/how-to-use-font-awesome-symbol-as-marker-in-matplotlib
def get_marker(symbol: str) -> Path:
"""Returns Font Awesome marker."""
"""Return a Matplotlib Path for a given marker symbol.

If ``symbol`` is a built-in matplotlib marker (e.g. "o", "x"), convert it to
a ``Path`` via ``MarkerStyle(symbol).get_path()``. Otherwise, treat it as a
Font Awesome icon name and build a path from the configured font file.
"""
if symbol in markers.MarkerStyle.markers:
return symbol
return markers.MarkerStyle(symbol).get_path()
symbol = fa.icons.get(symbol)
if not symbol:
raise KeyError(f"Could not find {symbol} in marker style.")
Expand Down
Loading