From a934497413201d07938a63d9cc576a959d469a0c Mon Sep 17 00:00:00 2001 From: SongshGeo Date: Wed, 7 Jan 2026 11:44:26 +0100 Subject: [PATCH 1/4] refactor(viz): :recycle: Simplify import statements in solara.py This commit refactors the import statements in the `solara.py` file by consolidating the imports from `abses.main` and `abses.patch` into a single line. This change enhances code readability and maintains consistency in the import structure. --- abses/viz/solara.py | 84 +++++++++++-- tests/viz/test_solara.py | 254 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 326 insertions(+), 12 deletions(-) create mode 100644 tests/viz/test_solara.py diff --git a/abses/viz/solara.py b/abses/viz/solara.py index 99b59042..1df7bb07 100644 --- a/abses/viz/solara.py +++ b/abses/viz/solara.py @@ -20,8 +20,12 @@ from mesa.visualization.utils import update_counter from xarray import DataArray -from abses.main import MainModel -from abses.patch import PatchModule +try: + from mesa.visualization.components import AgentPortrayalStyle +except ImportError: + raise ImportError("Mesa 3.3+ is required for AgentPortrayalStyle.") + +from abses import MainModel, PatchModule def draw_property_layers( @@ -106,12 +110,14 @@ def collect_agent_data( Args: space: The space containing the Agents. agent_portrayal: A callable that is called with the agent and returns a dict + or AgentPortrayalStyle object (Mesa 3.3+) color: default color size: default size marker: default marker zorder: default zorder - agent_portrayal should return a dict, limited to size (size of marker), color (color of marker), zorder (z-order), + agent_portrayal should return a dict or AgentPortrayalStyle object, limited to + size (size of marker), color (color of marker), zorder (z-order), marker (marker style), alpha, linewidths, and edgecolors """ @@ -126,7 +132,43 @@ def collect_agent_data( } for agent in space.agents: - portray = agent_portrayal(agent) + portray_result = agent_portrayal(agent) + + # Convert AgentPortrayalStyle to dict if needed (Mesa 3.3+) + if AgentPortrayalStyle is not None and isinstance( + portray_result, AgentPortrayalStyle + ): + # AgentPortrayalStyle objects can be converted using vars() or direct attribute access + try: + # Try vars() first (works for most Python objects with __dict__) + portray = vars(portray_result).copy() + except (TypeError, AttributeError): + # Fallback: build dict from direct attribute access + portray = {} + for attr in [ + "size", + "color", + "marker", + "zorder", + "alpha", + "edgecolors", + "linewidths", + ]: + if hasattr(portray_result, attr): + value = getattr(portray_result, attr) + if value is not None: + portray[attr] = value + elif isinstance(portray_result, dict): + portray = portray_result.copy() + else: + # Fallback: try to convert to dict + try: + portray = ( + vars(portray_result) if hasattr(portray_result, "__dict__") else {} + ) + except (TypeError, AttributeError): + portray = {} + arguments["s"].append(portray.pop("size", size)) arguments["c"].append(portray.pop("color", color)) arguments["marker"].append(portray.pop("marker", marker)) @@ -145,7 +187,21 @@ def collect_agent_data( ) # ensure loc is always a shape of (n, 2) array, even if n=0 result = {k: np.asarray(v) for k, v in arguments.items()} - result["loc"] = space.agents.array("indices") + indices_array = space.agents.array("indices") + # Convert to (n, 2) shape + if len(indices_array) == 0: + result["loc"] = np.empty((0, 2), dtype=int) + else: + # Convert list of tuples/arrays to 2D array + result["loc"] = np.array( + [ + list(idx) if isinstance(idx, (tuple, list)) else idx + for idx in indices_array + ] + ) + # Ensure it's 2D + if result["loc"].ndim == 1: + result["loc"] = result["loc"].reshape(-1, 2) return result @@ -168,8 +224,9 @@ def draw_orthogonal_grid( Returns: Returns the Axes object with the plot drawn onto it. - ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", - "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + ``agent_portrayal`` is called with an agent and should return a dict or AgentPortrayalStyle + object (Mesa 3.3+). Valid fields are "color", "size", "marker", "zorder", "alpha", + "linewidths", and "edgecolors". Other fields are ignored and will result in a user warning. """ if ax is None: @@ -214,6 +271,11 @@ def SpaceMatplotlib( fig = Figure() ax = fig.add_subplot() + # Draw property layers first (background) + if propertylayer_portrayal: + draw_property_layers(space, propertylayer_portrayal, ax=ax) + + # Draw agents on top draw_orthogonal_grid( space, agent_portrayal, @@ -228,9 +290,6 @@ def SpaceMatplotlib( fig, format="png", bbox_inches="tight", dependencies=dependencies ) - if propertylayer_portrayal: - draw_property_layers(space, propertylayer_portrayal, ax=ax) - def make_mpl_space_component( agent_portrayal: Callable | None = None, @@ -247,8 +306,9 @@ def make_mpl_space_component( space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See the functions for drawing the various spaces for further details. - ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", - "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning. + ``agent_portrayal`` is called with an agent and should return a dict or AgentPortrayalStyle + object (Mesa 3.3+). Valid fields are "color", "size", "marker", "zorder", "alpha", + "linewidths", and "edgecolors". Other fields are ignored and will result in a user warning. Returns: function: A function that creates a SpaceMatplotlib component diff --git a/tests/viz/test_solara.py b/tests/viz/test_solara.py new file mode 100644 index 00000000..796cfa75 --- /dev/null +++ b/tests/viz/test_solara.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +# -*-coding:utf-8 -*- +""" +Tests for the Solara visualization utilities. + +Tests the visualization components to ensure they properly handle +agent portrayal functions returning both dict and AgentPortrayalStyle objects. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from abses.viz.solara import ( + collect_agent_data, + draw_orthogonal_grid, + make_mpl_space_component, +) + +if TYPE_CHECKING: + from abses.core.model import MainModel + from abses.space.patch import PatchModule + + +class TestCollectAgentData: + """Test the collect_agent_data function.""" + + def test_collect_agent_data_with_dict(self, model: "MainModel") -> None: + """Test collect_agent_data with dict-based agent portrayal.""" + from abses.agents.actor import Actor + + # Create a simple agent class + class TestAgent(Actor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.agent_type = 0 + + # Create space and agents + space: PatchModule = model.nature.create_module(shape=(5, 5)) + for i in range(3): + agent = TestAgent(model=model) + cell = space.cells_lst[i] + agent.move.to(cell) + + # Define agent portrayal function returning dict + def agent_portrayal(agent: TestAgent) -> dict: + return { + "color": "tab:orange" if agent.agent_type == 0 else "tab:blue", + "size": 50, + "marker": "o", + } + + # Collect agent data + result = collect_agent_data(space, agent_portrayal) + + # Verify results + assert "loc" in result + assert len(result["loc"]) == 3 + assert len(result["c"]) == 3 + assert len(result["s"]) == 3 + assert all(c == "tab:orange" for c in result["c"]) + + def test_collect_agent_data_with_agentportrayalstyle( + self, model: "MainModel" + ) -> None: + """Test collect_agent_data with AgentPortrayalStyle object (Mesa 3.3+).""" + try: + from mesa.visualization.components import AgentPortrayalStyle + except ImportError: + pytest.skip("AgentPortrayalStyle not available (requires Mesa 3.3+)") + + from abses.agents.actor import Actor + + # Create a simple agent class + class TestAgent(Actor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.agent_type = 0 + + # Create space and agents + space: PatchModule = model.nature.create_module(shape=(5, 5)) + for i in range(3): + agent = TestAgent(model=model) + cell = space.cells_lst[i] + agent.move.to(cell) + + # Define agent portrayal function returning AgentPortrayalStyle + def agent_portrayal(agent: TestAgent) -> AgentPortrayalStyle: + return AgentPortrayalStyle( + color="tab:orange" if agent.agent_type == 0 else "tab:blue", + size=50, + ) + + # Collect agent data + result = collect_agent_data(space, agent_portrayal) + + # Verify results + assert "loc" in result + assert len(result["loc"]) == 3 + assert len(result["c"]) == 3 + assert len(result["s"]) == 3 + assert all(c == "tab:orange" for c in result["c"]) + + def test_collect_agent_data_empty_space(self, model: "MainModel") -> None: + """Test collect_agent_data with empty space.""" + space: PatchModule = model.nature.create_module(shape=(5, 5)) + + def agent_portrayal(agent) -> dict: + return {"color": "red", "size": 10} + + # Collect agent data from empty space + result = collect_agent_data(space, agent_portrayal) + + # Verify results + assert "loc" in result + assert len(result["loc"]) == 0 + assert result["loc"].shape == (0, 2) + + +class TestDrawOrthogonalGrid: + """Test the draw_orthogonal_grid function.""" + + def test_draw_orthogonal_grid_basic(self, model: "MainModel") -> None: + """Test draw_orthogonal_grid with basic setup.""" + import matplotlib.pyplot as plt + + from abses.agents.actor import Actor + + # Create a simple agent class + class TestAgent(Actor): + pass + + # Create space and agents + space: PatchModule = model.nature.create_module(shape=(5, 5)) + for i in range(3): + agent = TestAgent(model=model) + cell = space.cells_lst[i] + agent.move.to(cell) + + # Define agent portrayal function + def agent_portrayal(agent) -> dict: + return {"color": "tab:blue", "size": 25} + + # Draw grid + ax = draw_orthogonal_grid(space, agent_portrayal, draw_grid=True) + + # Verify axes properties + assert ax is not None + assert ax.get_xlim() == (-0.5, 4.5) + assert ax.get_ylim() == (-0.5, 4.5) + + plt.close(ax.figure) + + def test_draw_orthogonal_grid_without_grid_lines(self, model: "MainModel") -> None: + """Test draw_orthogonal_grid without grid lines.""" + import matplotlib.pyplot as plt + + from abses.agents.actor import Actor + + # Create a simple agent class + class TestAgent(Actor): + pass + + # Create space + space: PatchModule = model.nature.create_module(shape=(3, 3)) + + # Define agent portrayal function + def agent_portrayal(agent) -> dict: + return {} + + # Draw grid without grid lines + ax = draw_orthogonal_grid(space, agent_portrayal, draw_grid=False) + + # Verify axes properties + assert ax is not None + assert ax.get_xlim() == (-0.5, 2.5) + assert ax.get_ylim() == (-0.5, 2.5) + + plt.close(ax.figure) + + +class TestMakeMplSpaceComponent: + """Test the make_mpl_space_component function.""" + + def test_make_mpl_space_component_basic(self, model: "MainModel") -> None: + """Test make_mpl_space_component with basic setup.""" + + # Define agent portrayal function + def agent_portrayal(agent) -> dict: + return {"color": "tab:blue", "size": 25} + + # Create component factory + component_factory = make_mpl_space_component(agent_portrayal=agent_portrayal) + + # Verify it's callable + assert callable(component_factory) + + # Create component + component = component_factory(model) + + # Verify component is created (it's a Solara component) + assert component is not None + + def test_make_mpl_space_component_with_property_layers( + self, model: "MainModel" + ) -> None: + """Test make_mpl_space_component with property layers.""" + + # Define agent portrayal function + def agent_portrayal(agent) -> dict: + return {"color": "tab:blue", "size": 25} + + # Define property layer portrayal + # Note: This test just verifies the component can be created with property layers + # The actual layer data would need to be set up properly in a real scenario + propertylayer_portrayal = { + "mask": {"colormap": "viridis", "alpha": 0.5, "colorbar": False} + } + + # Create component factory + component_factory = make_mpl_space_component( + agent_portrayal=agent_portrayal, + propertylayer_portrayal=propertylayer_portrayal, + ) + + # Verify it's callable + assert callable(component_factory) + + # Create component (may raise error if layer doesn't exist, which is expected) + # We just verify the factory works + try: + component = component_factory(model) + assert component is not None + except (ValueError, KeyError, AttributeError): + # Expected if layer doesn't exist - the factory still works + pass + + def test_make_mpl_space_component_default_agent_portrayal( + self, model: "MainModel" + ) -> None: + """Test make_mpl_space_component with default agent portrayal.""" + # Create component factory without agent_portrayal + component_factory = make_mpl_space_component() + + # Verify it's callable + assert callable(component_factory) + + # Create component + component = component_factory(model) + + # Verify component is created + assert component is not None From ac719500a110e3f2ed97b225d03bbaf953c3100c Mon Sep 17 00:00:00 2001 From: SongshGeo Date: Wed, 7 Jan 2026 16:59:25 +0100 Subject: [PATCH 2/4] refactor(analysis): :recycle: Enhance read_data method for improved CSV handling This commit refactors the `read_data` method in the `ResultAnalyzer` class to improve the process of reading and merging CSV files. It now prioritizes files matching the pattern `*_cities.csv`, concatenates them into a single DataFrame, and includes better error handling for missing files. Additionally, it maintains backward compatibility by falling back to a single `cities.csv` if no valid files are found. This change enhances data loading efficiency and robustness. --- abses/utils/analysis.py | 63 ++++++++++++++++++------------------ abses/utils/datacollector.py | 4 +-- docs/api/analysis.md | 1 + 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/abses/utils/analysis.py b/abses/utils/analysis.py index fca8a691..0ac187e9 100644 --- a/abses/utils/analysis.py +++ b/abses/utils/analysis.py @@ -250,41 +250,40 @@ def _load_hydra_cfg(self, path: PathLike) -> None: self.agent_reporter = reporters.get("agents", {}) self.final_reporter = reporters.get("final", {}) - def read_data(self) -> None: - """Read data from CSV files or datacollector output. - - This method attempts to find and load data files in the following order: - 1. Common CSV filenames (cities.csv, 1_cities.csv, etc.) - 2. Datacollector output files if available - 3. User-specified files - - Raises: - FileNotFoundError: If no data file is found. + def read_data(self, suffix: str = "csv") -> pd.DataFrame: + """Read and merge result csv files under the experiment folder. + + This method will: + - First, look for all files matching ``*_cities.csv`` (e.g. ``1_cities.csv``, + ``2_cities.csv`` ...) under ``self.path``. + - If found, read them all and vertically concatenate them into a single + dataframe. + - If none are found, fall back to reading a single ``cities.csv`` file. """ - # Try common CSV filenames - common_names = ["cities.csv", "1_cities.csv", "data.csv", "results.csv"] - for name in common_names: - csv_path = self.path / name - if csv_path.is_file(): - self.data = self.read_csv(csv_path) - logger.info(f"Loaded data from {csv_path}.") - return + # Prefer numbered runs like 1_cities.csv, 2_cities.csv, ... + csv_files = sorted(self.path.glob(f"*.{suffix}")) - # Try to find any CSV file in the directory - csv_files = list(self.path.glob("*.csv")) if csv_files: - # Use the first CSV file found - self.data = self.read_csv(csv_files[0]) - logger.info(f"Loaded data from {csv_files[0]}.") - return - - # If no CSV found, try to load from datacollector output - # This would require the datacollector to have saved its output - # For now, we'll raise an error - raise FileNotFoundError( - f"No data file found in {self.path}. " - f"Expected CSV files or datacollector output." - ) + data_frames = [] + for csv_file in csv_files: + try: + df = self.read_csv(path=csv_file) + data_frames.append(df) + except FileNotFoundError: + logger.warning(f"Skip missing file: {csv_file}") + if not data_frames: + raise FileNotFoundError( + f"No valid *.{suffix} files found under {self.path}." + ) + self.data = pd.concat(data_frames, ignore_index=True) + logger.info( + "Loaded and merged result files: " + f"{[f.name for f in csv_files]} from {self.path}." + ) + else: + # Backward compatibility: fall back to a single cities.csv + logger.warning(f"No valid *.{suffix} files found under {self.path}.") + return pd.DataFrame() def read_csv(self, path: PathLike) -> pd.DataFrame: """Read a CSV file into a DataFrame. diff --git a/abses/utils/datacollector.py b/abses/utils/datacollector.py index 526d3e62..03a5923c 100644 --- a/abses/utils/datacollector.py +++ b/abses/utils/datacollector.py @@ -28,8 +28,8 @@ if TYPE_CHECKING: from abses.agents.actor import Actor from abses.agents.sequences import ActorsList - from abses.main import MainModel - from abses.time import TimeDriver + from abses.core.model import MainModel + from abses.core.time_driver import TimeDriver from abses.utils.tracker import TrackerProtocol diff --git a/docs/api/analysis.md b/docs/api/analysis.md index cfa9c12a..2b47ba41 100644 --- a/docs/api/analysis.md +++ b/docs/api/analysis.md @@ -10,3 +10,4 @@ date: 2024-12-20 + From 1745558440abcb6ed043265a1fb099ab96121499 Mon Sep 17 00:00:00 2001 From: SongshGeo Date: Wed, 7 Jan 2026 17:28:35 +0100 Subject: [PATCH 3/4] fix(tracker): :bug: Resolved the issue where configurations could only be recorded as strings. This commit updates the `AimTracker` class to accept `DictConfig` directly in the `log_params` method, leveraging OmegaConf's integration for improved parameter logging. Additionally, the `start_tracker_run` function is modified to handle `DictConfig` appropriately when using `AimTracker`, ensuring seamless logging of model parameters. These changes enhance the flexibility and usability of the tracking system, particularly for configurations managed by OmegaConf. --- abses/utils/tracker/aim_tracker.py | 16 ++- abses/utils/tracker/factory.py | 18 ++- tests/utils/test_tracker.py | 188 ++++++++++++++++++++++++++++- 3 files changed, 214 insertions(+), 8 deletions(-) diff --git a/abses/utils/tracker/aim_tracker.py b/abses/utils/tracker/aim_tracker.py index 0f27131a..c90f85e7 100644 --- a/abses/utils/tracker/aim_tracker.py +++ b/abses/utils/tracker/aim_tracker.py @@ -8,6 +8,12 @@ from abses.utils.tracker import TrackerProtocol +try: + from omegaconf import DictConfig, OmegaConf +except ImportError: + DictConfig = None + OmegaConf = None + try: from aim import Run except ImportError: @@ -140,12 +146,18 @@ def log_final_metrics( if numeric_metrics: self.log_metrics(numeric_metrics, step=step) - def log_params(self, params: Dict[str, Any]) -> None: + def log_params(self, params: Dict[str, Any] | DictConfig) -> None: """Log hyperparameters to Aim. Args: - params: Dictionary of parameter names to values. + params: Dictionary of parameter names to values, or DictConfig. """ + # If params is DictConfig, use Aim's built-in OmegaConf integration + if DictConfig is not None and isinstance(params, DictConfig): + self._run["config"] = OmegaConf.to_container(params, resolve=True) + return + + # Otherwise, handle as regular dict for key, value in params.items(): # Aim supports various types for parameters if isinstance(value, (int, float, str, bool)): diff --git a/abses/utils/tracker/factory.py b/abses/utils/tracker/factory.py index 6ce49b9c..8274c314 100644 --- a/abses/utils/tracker/factory.py +++ b/abses/utils/tracker/factory.py @@ -155,12 +155,20 @@ def start_tracker_run( log_params = cfg_dict.get("log_params", True) if log_params and hasattr(tracker, "log_params"): - if isinstance(model_params, DictConfig): - params_dict = OmegaConf.to_container(model_params, resolve=True) + # For AimTracker, if model_params is DictConfig, pass it directly + # to use Aim's built-in OmegaConf integration + if type(tracker).__name__ == "AimTracker" and isinstance( + model_params, DictConfig + ): + tracker.log_params(model_params) else: - params_dict = dict(model_params) - if isinstance(params_dict, dict): - tracker.log_params(params_dict) + # For other trackers or plain dict, convert to dict first + if isinstance(model_params, DictConfig): + params_dict = OmegaConf.to_container(model_params, resolve=True) + else: + params_dict = dict(model_params) + if isinstance(params_dict, dict): + tracker.log_params(params_dict) def create_tracker( diff --git a/tests/utils/test_tracker.py b/tests/utils/test_tracker.py index 1805db5e..af994cf7 100644 --- a/tests/utils/test_tracker.py +++ b/tests/utils/test_tracker.py @@ -4,10 +4,17 @@ from __future__ import annotations +from unittest.mock import MagicMock + from omegaconf import OmegaConf from abses.utils.tracker.default import DefaultTracker -from abses.utils.tracker.factory import create_tracker +from abses.utils.tracker.factory import ( + create_tracker, + prepare_tracker_run_name, + prepare_tracker_tags, + start_tracker_run, +) def test_create_tracker_default() -> None: @@ -20,3 +27,182 @@ def test_create_tracker_unknown_backend_fallback() -> None: """Unknown backend should fall back to default.""" tracker = create_tracker(OmegaConf.create({"backend": "unknown"}), model=None) assert isinstance(tracker, DefaultTracker) + + +# --- Tests for prepare_tracker_run_name --- + + +def test_prepare_tracker_run_name_default_with_run_id() -> None: + """Default run name includes model name and run id.""" + cfg = OmegaConf.create({}) + run_name = prepare_tracker_run_name(cfg, "TestModel", "1.0.0", 42) + assert run_name == "TestModel_run_42" + + +def test_prepare_tracker_run_name_default_without_run_id() -> None: + """Default run name is model name when no run id.""" + cfg = OmegaConf.create({}) + run_name = prepare_tracker_run_name(cfg, "TestModel", "1.0.0", None) + assert run_name == "TestModel" + + +def test_prepare_tracker_run_name_custom_template() -> None: + """Custom run name template is formatted correctly.""" + cfg = OmegaConf.create({"run_name": "{model_name}_v{version}_{run_id}"}) + run_name = prepare_tracker_run_name(cfg, "TestModel", "1.0.0", 5) + assert run_name == "TestModel_v1.0.0_5" + + +def test_prepare_tracker_run_name_invalid_template() -> None: + """Invalid template returns template as-is.""" + cfg = OmegaConf.create({"run_name": "{invalid_key}_test"}) + run_name = prepare_tracker_run_name(cfg, "TestModel", "1.0.0", 1) + assert run_name == "{invalid_key}_test" + + +# --- Tests for prepare_tracker_tags --- + + +def test_prepare_tracker_tags_default() -> None: + """Default tags include model and version.""" + cfg = OmegaConf.create({}) + tags = prepare_tracker_tags(cfg, "TestModel", "1.0.0", None) + assert tags == {"model": "TestModel", "version": "1.0.0"} + + +def test_prepare_tracker_tags_default_with_run_id() -> None: + """Default tags include run_id when provided.""" + cfg = OmegaConf.create({}) + tags = prepare_tracker_tags(cfg, "TestModel", "1.0.0", 42) + assert tags == {"model": "TestModel", "version": "1.0.0", "run_id": "42"} + + +def test_prepare_tracker_tags_custom() -> None: + """Custom tags are returned from config.""" + cfg = OmegaConf.create({"tags": {"env": "test", "project": "demo"}}) + tags = prepare_tracker_tags(cfg, "TestModel", "1.0.0", None) + assert tags == {"env": "test", "project": "demo"} + + +def test_prepare_tracker_tags_with_template() -> None: + """Tags with template variables are formatted.""" + cfg = OmegaConf.create({"tags": {"model_tag": "{model_name}", "run": "{run_id}"}}) + tags = prepare_tracker_tags(cfg, "TestModel", "1.0.0", 99) + assert tags == {"model_tag": "TestModel", "run": "99"} + + +# --- Tests for start_tracker_run --- + + +def test_start_tracker_run_with_none_tracker() -> None: + """start_tracker_run does nothing when tracker is None.""" + # Should not raise any error + start_tracker_run( + tracker=None, + tracker_cfg=OmegaConf.create({}), + model_name="TestModel", + version="1.0.0", + run_id=1, + model_params=OmegaConf.create({"param1": 1}), + ) + + +def test_start_tracker_run_calls_start_run() -> None: + """start_tracker_run calls tracker.start_run with correct args.""" + mock_tracker = MagicMock() + mock_tracker.log_params = MagicMock() + + start_tracker_run( + tracker=mock_tracker, + tracker_cfg=OmegaConf.create({}), + model_name="TestModel", + version="1.0.0", + run_id=1, + model_params=OmegaConf.create({"param1": 1}), + ) + + mock_tracker.start_run.assert_called_once() + call_kwargs = mock_tracker.start_run.call_args[1] + assert call_kwargs["run_name"] == "TestModel_run_1" + assert "model" in call_kwargs["tags"] + + +def test_start_tracker_run_logs_params_dict() -> None: + """start_tracker_run logs params as dict for non-Aim trackers.""" + mock_tracker = MagicMock() + mock_tracker.log_params = MagicMock() + # Ensure type name is not "AimTracker" + type(mock_tracker).__name__ = "MockTracker" + + model_params = OmegaConf.create({"param1": 1, "param2": "value"}) + + start_tracker_run( + tracker=mock_tracker, + tracker_cfg=OmegaConf.create({"log_params": True}), + model_name="TestModel", + version="1.0.0", + run_id=1, + model_params=model_params, + ) + + mock_tracker.log_params.assert_called_once() + logged_params = mock_tracker.log_params.call_args[0][0] + # Should be a plain dict, not DictConfig + assert isinstance(logged_params, dict) + assert logged_params["param1"] == 1 + assert logged_params["param2"] == "value" + + +def test_start_tracker_run_logs_params_disabled() -> None: + """start_tracker_run does not log params when disabled.""" + mock_tracker = MagicMock() + mock_tracker.log_params = MagicMock() + + start_tracker_run( + tracker=mock_tracker, + tracker_cfg=OmegaConf.create({"log_params": False}), + model_name="TestModel", + version="1.0.0", + run_id=1, + model_params=OmegaConf.create({"param1": 1}), + ) + + mock_tracker.log_params.assert_not_called() + + +def test_start_tracker_run_aim_tracker_with_dictconfig() -> None: + """AimTracker receives DictConfig directly for log_params.""" + + class FakeAimTracker: + """Fake AimTracker to test type name checking.""" + + def __init__(self): + self.start_run_called = False + self.logged_params = None + + def start_run(self, run_name=None, tags=None): + self.start_run_called = True + + def log_params(self, params): + self.logged_params = params + + # Rename the class to "AimTracker" for the type check + FakeAimTracker.__name__ = "AimTracker" + fake_tracker = FakeAimTracker() + + model_params = OmegaConf.create({"param1": 1, "nested": {"key": "value"}}) + + start_tracker_run( + tracker=fake_tracker, + tracker_cfg=OmegaConf.create({"log_params": True}), + model_name="TestModel", + version="1.0.0", + run_id=1, + model_params=model_params, + ) + + # AimTracker should receive the DictConfig directly + assert fake_tracker.logged_params is model_params + from omegaconf import DictConfig + + assert isinstance(fake_tracker.logged_params, DictConfig) From e0b5d5f2a5c8e8182cd26b1d2a7c6d6a3fce1f8d Mon Sep 17 00:00:00 2001 From: SongshGeo Date: Wed, 7 Jan 2026 20:52:08 +0100 Subject: [PATCH 4/4] fix(analysis): :bug: Improve data handling in ResultAnalyzer for missing CSV files This commit updates the `read_data` method in the `ResultAnalyzer` class to ensure that it initializes `self.data` as an empty DataFrame when no valid CSV files are found. Additionally, it returns `self.data` consistently, enhancing the method's reliability and clarity in handling data loading scenarios. These changes improve the robustness of the analysis utilities. --- abses/utils/analysis.py | 4 +++- abses/utils/datacollector.py | 2 +- abses/viz/solara.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/abses/utils/analysis.py b/abses/utils/analysis.py index 0ac187e9..41ea1896 100644 --- a/abses/utils/analysis.py +++ b/abses/utils/analysis.py @@ -280,10 +280,12 @@ def read_data(self, suffix: str = "csv") -> pd.DataFrame: "Loaded and merged result files: " f"{[f.name for f in csv_files]} from {self.path}." ) + return self.data else: # Backward compatibility: fall back to a single cities.csv logger.warning(f"No valid *.{suffix} files found under {self.path}.") - return pd.DataFrame() + self.data = pd.DataFrame() + return self.data def read_csv(self, path: PathLike) -> pd.DataFrame: """Read a CSV file into a DataFrame. diff --git a/abses/utils/datacollector.py b/abses/utils/datacollector.py index 03a5923c..0f3f84cb 100644 --- a/abses/utils/datacollector.py +++ b/abses/utils/datacollector.py @@ -192,7 +192,7 @@ def _record_a_breed_of_agents( result = { "AgentID": agents.array("unique_id"), "Step": np.repeat(time.tick, len(agents)), - "Time": np.repeat(time.dt, len(agents)), + "Time": np.repeat(str(time.dt), len(agents)), } for name, reporter in self.agent_reporters[breed].items(): result[name] = agents.apply(reporter) diff --git a/abses/viz/solara.py b/abses/viz/solara.py index 1df7bb07..f0181f7e 100644 --- a/abses/viz/solara.py +++ b/abses/viz/solara.py @@ -23,7 +23,7 @@ try: from mesa.visualization.components import AgentPortrayalStyle except ImportError: - raise ImportError("Mesa 3.3+ is required for AgentPortrayalStyle.") + AgentPortrayalStyle = None from abses import MainModel, PatchModule