Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 156 additions & 21 deletions abses/utils/tracker/aim_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# @Author : ABSESpy Team
from __future__ import annotations

import statistics
from typing import Any, Dict

from abses.utils.tracker import TrackerProtocol
Expand All @@ -15,9 +14,10 @@
OmegaConf = None

try:
from aim import Run
from aim import Distribution, Run
except ImportError:
Run = None
Distribution = None


class AimTracker(TrackerProtocol):
Expand All @@ -41,11 +41,13 @@ def __init__(self, config: Dict[str, Any]) -> None:
config: Aim-specific configuration dictionary. Supported keys:
- experiment: Experiment name (optional)
- repo: Path to Aim repository (optional, defaults to ~/.aim)
- distribution_bin_count: Number of bins for Distribution (optional, default 64, range 1-512)
- log_categorical_stats: Whether to log categorical statistics (optional, default True)

Raises:
ImportError: If aim is not installed.
"""
if Run is None:
if Run is None or Distribution is None:
raise ImportError(
"Aim is not installed. Install with: pip install abses[aim] or pip install aim"
)
Expand All @@ -54,6 +56,15 @@ def __init__(self, config: Dict[str, Any]) -> None:
self._run = Run(experiment=experiment, repo=repo)
self._params_logged = False

# Distribution configuration
bin_count = config.get("distribution_bin_count", 64)
if not isinstance(bin_count, int) or bin_count < 1 or bin_count > 512:
raise ValueError(
f"distribution_bin_count must be an integer between 1 and 512, got {bin_count}"
)
self._bin_count = bin_count
self._log_categorical_stats = config.get("log_categorical_stats", True)

def start_run(
self, run_name: str | None = None, tags: Dict[str, str] | None = None
) -> None:
Expand Down Expand Up @@ -102,33 +113,157 @@ def log_agent_vars(
) -> None:
"""Log agent variables with breed prefix.

Uses Aim Distribution for numeric variables and frequency statistics for categorical variables.
Directly uses pandas Series and numpy arrays, leveraging built-in tools for type conversion
and NaN handling.

Args:
breed: Agent breed/class name.
agent_vars: Dictionary of variable names to values (can be lists for aggregation).
agent_vars: Dictionary of variable names to values (can be lists, Series, or arrays).
step: Step number (optional).
"""
# Handle list values (aggregated agent data)
metrics_to_log: Dict[str, float] = {}
import numpy as np
import pandas as pd

for key, value in agent_vars.items():
metric_name = f"{breed}.{key}"
if isinstance(value, (int, float)):
metrics_to_log[metric_name] = value
elif isinstance(value, list) and value:
# Aggregate list values (mean, min, max)
numeric_values = [v for v in value if isinstance(v, (int, float))]
if numeric_values:
metrics_to_log[f"{metric_name}.mean"] = statistics.mean(
numeric_values

# Convert to pandas Series (if not already)
if isinstance(value, list):
series = pd.Series(value)
elif isinstance(value, pd.Series):
series = value
elif isinstance(value, np.ndarray):
series = pd.Series(value)
elif isinstance(value, (int, float)):
# Single scalar value
self._run.track(value, name=metric_name, step=step)
continue
else:
# Other types, try to convert
try:
series = pd.Series(value)
except (TypeError, ValueError):
continue

# Skip empty Series
if len(series) == 0:
continue

# Handle based on data type
# Note: Boolean must be checked before numeric, because is_numeric_dtype
# returns True for boolean types as well.

# 1. Boolean type -> Convert to 0/1 then use Distribution
# Check both bool dtype and object dtype with boolean values
# (pandas converts bool dtype to object when None values are present)
is_boolean_type = pd.api.types.is_bool_dtype(series) or (
pd.api.types.is_object_dtype(series)
and len(series.dropna()) > 0
and all(isinstance(x, bool) for x in series.dropna())
)
if is_boolean_type:
bool_series = series.dropna()
if len(bool_series) == 0:
continue
# Convert to 0/1
numeric_series = bool_series.astype(int)
if len(numeric_series) == 1:
self._run.track(numeric_series.iloc[0], name=metric_name, step=step)
else:
dist = Distribution(
distribution=numeric_series, bin_count=self._bin_count
)
metrics_to_log[f"{metric_name}.min"] = min(numeric_values)
metrics_to_log[f"{metric_name}.max"] = max(numeric_values)
if len(numeric_values) > 1:
metrics_to_log[f"{metric_name}.std"] = statistics.stdev(
numeric_values
self._run.track(dist, name=metric_name, step=step)
# Additional statistics
true_count = bool_series.sum()
self._run.track(true_count, name=f"{metric_name}.true_count", step=step)
self._run.track(
true_count / len(bool_series),
name=f"{metric_name}.true_ratio",
step=step,
)

# 2. Numeric types (int, float) -> Distribution
elif pd.api.types.is_numeric_dtype(series):
# Remove NaN (pandas handles automatically)
numeric_series = series.dropna()
if len(numeric_series) == 0:
continue
elif len(numeric_series) == 1:
# Single value, log as scalar
self._run.track(numeric_series.iloc[0], name=metric_name, step=step)
else:
# Multiple values, use Distribution
dist = Distribution(
distribution=numeric_series, bin_count=self._bin_count
)
self._run.track(dist, name=metric_name, step=step)

# 3. String type (categorical) -> Use pandas value_counts()
elif pd.api.types.is_string_dtype(series) or pd.api.types.is_object_dtype(
series
):
if not self._log_categorical_stats:
continue
# Remove NaN and empty strings
str_series = series.dropna()
str_series = str_series[str_series != ""]
if len(str_series) == 0:
continue

# Use pandas value_counts() for statistics
value_counts = str_series.value_counts()
unique_count = len(value_counts)
total_count = len(str_series)

# Log unique count
self._run.track(
unique_count, name=f"{metric_name}.unique_count", step=step
)

# Log most common category
if len(value_counts) > 0:
most_common = value_counts.iloc[0]
self._run.track(
most_common,
name=f"{metric_name}.most_common_count",
step=step,
)
self._run.track(
most_common / total_count,
name=f"{metric_name}.most_common_ratio",
step=step,
)

# If category count <= 10, log each category's count
if unique_count <= 10:
for category, count in value_counts.items():
# Clean category name (replace special characters)
safe_name = str(category).replace(".", "_").replace(" ", "_")
self._run.track(
count,
name=f"{metric_name}.{safe_name}_count",
step=step,
)

if metrics_to_log:
self.log_metrics(metrics_to_log, step=step)
# 4. Other types: Try to convert to numeric
else:
try:
numeric_series = pd.to_numeric(series, errors="coerce").dropna()
if len(numeric_series) > 0:
if len(numeric_series) == 1:
self._run.track(
numeric_series.iloc[0], name=metric_name, step=step
)
else:
dist = Distribution(
distribution=numeric_series, bin_count=self._bin_count
)
self._run.track(dist, name=metric_name, step=step)
except (TypeError, ValueError):
# Cannot convert, skip
pass

def log_final_metrics(
self, metrics: Dict[str, Any], step: int | None = None
Expand Down
64 changes: 64 additions & 0 deletions docs/home/configuration_schema.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,32 @@ Agent trackers collect data from agent instances at each step.
- The referenced attribute must exist on the agent class
- Attributes should be scalar values

**Data Type Handling:**

When using the Aim tracker backend (`backend: aim`), agent variables are automatically handled based on their data types:

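- **Numeric types** (int, float): Recorded as **Distribution** objects in Aim, allowing you to visualize the full distribution of values across agents (histograms, density plots, etc.). This preserves the heterogeneity of agent attributes. Missing values (NaN) are dropped before logging; if only a single value remains, it is logged as a plain scalar instead of a Distribution.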

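- **Boolean types**: Missing values are dropped, the remaining values are converted to 0/1 and recorded as a Distribution, and two additional scalar metrics are logged: `{breed}.{attribute}.true_count` and `{breed}.{attribute}.true_ratio`.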

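- **String types** (categorical): Recorded as frequency statistics, computed after missing values and empty strings are removed. These metrics are skipped entirely when `log_categorical_stats` is `false`:
    - `{breed}.{attribute}.unique_count` - Number of unique categories
    - `{breed}.{attribute}.most_common_count` - Count of the most common category
    - `{breed}.{attribute}.most_common_ratio` - Share of all values that belong to the most common category
    - `{breed}.{attribute}.{category}_count` - Count for each category (only when there are ≤10 categories; `.` and spaces in the category name are replaced with `_`)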

**Aim Tracker Configuration:**

```yaml
tracker:
  backend: aim
  aim:
    experiment: "my_experiment"
    repo: "./aim_repo"  # Optional, defaults to ~/.aim
    distribution_bin_count: 64  # Optional, default 64, range 1-512
    log_categorical_stats: true  # Optional, default true
```

**Example Agent Classes:**

```python
Expand Down Expand Up @@ -430,6 +456,43 @@ tracker:
burned_rate: "burned_rate"
```

### Agent Variable Distribution Tracking (Aim Backend)

When using the Aim tracker backend, agent variables are tracked as distributions rather than simple aggregates. This allows you to:

- **Visualize heterogeneity**: See the full distribution of agent attributes, not just mean/min/max
- **Track changes over time**: Observe how distributions evolve during simulation
- **Compare runs**: Compare distributions across different parameter settings

**Example:**

```yaml
tracker:
  backend: aim
  aim:
    experiment: "flood_adaptation_abm"
  agents:
    City:
      budget: budget
      population: population
    Individual:
      wealth: wealth
      moved: moved  # Boolean
      status: status  # String/categorical

In the Aim UI, you'll see:

- `City.budget` as a distribution (histogram) showing the full range of budgets
- `City.population` as a distribution
- `Individual.wealth` as a distribution
- `Individual.moved` as a distribution (0/1) plus `Individual.moved.true_count` and `Individual.moved.true_ratio`
- `Individual.status` as frequency statistics (unique_count, most_common_count, etc.)

**Configuration Options:**

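- `distribution_bin_count` (default: 64, range: 1-512): Number of bins for Distribution histograms. Must be an integer; a non-integer or out-of-range value raises `ValueError` when the tracker is created.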
- `log_categorical_stats` (default: true): Whether to log statistics for string/categorical variables

### Common Tracker Errors

| Error | Cause | Solution |
Expand All @@ -438,6 +501,7 @@ tracker:
| `KeyError: 'Sheep'` | Agent breed name mismatch | Use exact class name (case-sensitive) |
| Empty DataFrame | No trackers defined | Add at least one tracker |
| `TypeError: 'str' object is not callable` | Tried to call a string | Use method name without quotes in code |
| `ValueError: distribution_bin_count must be...` | Invalid `distribution_bin_count` value | Use an integer between 1 and 512 |

---

Expand Down
Loading
Loading