Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 23 additions & 51 deletions abses/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ def run_single(
hooks:
The hooks to run after the model is run.
"""
job_id, repeat_id = key
job_id, run_id = key
model = model_cls(
parameters=cfg,
run_id=repeat_id,
run_id=run_id,
seed=seed,
**kwargs,
)
Expand All @@ -135,7 +135,7 @@ def run_single(
for hook_name, hook_func in hooks.items():
logger.info(f"Running hook {hook_name}.")
_call_hook_with_optional_args(
hook_func, model, job_id=job_id, repeat_id=repeat_id
hook_func, model, job_id=job_id, run_id=run_id
)
return key, seed, results

Expand Down Expand Up @@ -370,45 +370,17 @@ def _load_hydra_cfg(

return cfg

# def _get_logging_mode(self, repeat_id: Optional[int] = None) -> str | bool:
# log_mode = self.exp_config.get("logging", "once")
# if log_mode == "once":
# if repeat_id == 1:
# logging: bool | str = self.name
# else:
# return False
# elif bool(log_mode):
# logging = f"{self.name}_{repeat_id}"
# else:
# logging = False
# return logging

# def _update_log_config(
# self, config, repeat_id: Optional[int] = None
# ) -> bool:
# """Update the log configuration."""
# if isinstance(config, dict):
# config = DictConfig(config)
# OmegaConf.set_struct(config, False)
# log_name = self._get_logging_mode(repeat_id=repeat_id)
# if not log_name:
# config["log"] = False
# return config
# logging_cfg = OmegaConf.create({"log": {"name": log_name}})
# config = OmegaConf.merge(config, logging_cfg)
# return config

def _get_seed(self, repeat_id: int, job_id: Optional[int] = None) -> Optional[int]:
def _get_seed(self, run_id: int, job_id: Optional[int] = None) -> Optional[int]:
"""获取每次运行的随机种子

使用基础种子初始化随机数生成器,为每次运行生成唯一的随机种子。
这样可以保证:
1. 如果基础种子相同,生成的种子序列也相同
2. 不同的 job_id 和 repeat_id 组合会得到不同的种子
2. 不同的 job_id 和 run_id 组合会得到不同的种子
3. 种子序列具有更好的随机性

Args:
repeat_id: 重复实验的ID
run_id: 重复实验的ID

Returns:
如果没有设置基础种子则返回 None,否则返回生成的随机种子
Expand All @@ -419,7 +391,7 @@ def _get_seed(self, repeat_id: int, job_id: Optional[int] = None) -> Optional[in
if job_id is None:
job_id = self.job_id
# 使用基础种子和 job_id 创建随机数生成器
r = random.Random(self._base_seed + job_id * 1000 + repeat_id)
r = random.Random(self._base_seed + job_id * 1000 + run_id)
return r.randrange(2**32)

def _get_logging_mode(self) -> str:
Expand All @@ -431,13 +403,13 @@ def _get_logging_mode(self) -> str:
return get_log_mode(self._cfg)

def _get_log_file_path(
self, log_name: str, repeat_id: int, logging_mode: str
self, log_name: str, run_id: int, logging_mode: str
) -> Optional[Path]:
"""Get log file path for a specific repeat.

Args:
log_name: Base log file name.
repeat_id: Repeat ID (1-indexed).
run_id: Repeat ID (1-indexed).
logging_mode: Logging mode.

Returns:
Expand All @@ -449,7 +421,7 @@ def _get_log_file_path(
outpath=self.outpath,
log_name=log_name,
logging_mode=logging_mode,
repeat_id=repeat_id,
run_id=run_id,
)

def _log_experiment_info(
Expand Down Expand Up @@ -518,18 +490,18 @@ def _batch_run_repeats(
if self._is_hydra_parallel() or number_process == 1:
# Hydra 并行或指定单进程时,顺序执行
disable = repeats == 1 or not display_progress
for repeat_id in tqdm(
for run_id in tqdm(
range(1, repeats + 1),
disable=disable,
desc=f"Job {self.job_id} repeats {repeats} times.",
):
# Log separator for merge mode
if logging_mode == "merge" and repeat_id > 1:
if logging_mode == "merge" and run_id > 1:
# Note: Separator will be logged in model setup
pass

# Get log file path for this repeat
log_path = self._get_log_file_path(log_name, repeat_id, logging_mode)
log_path = self._get_log_file_path(log_name, run_id, logging_mode)

# Display log file location for separate mode
# This should only go to stdout, not to model run log files
Expand All @@ -539,14 +511,14 @@ def _batch_run_repeats(
and log_path is not None
):
# Use print instead of logger to avoid writing to model run log files
print(f"Repeat {repeat_id}: Logging to {log_path}")
print(f"Repeat {run_id}: Logging to {log_path}")

run_single(
model_cls=self.model_cls,
cfg=cfg,
key=(self.job_id, repeat_id),
key=(self.job_id, run_id),
outpath=self.outpath,
seed=self._get_seed(repeat_id),
seed=self._get_seed(run_id),
hooks=self._manager.hooks,
**self._extra_kwargs,
)
Expand All @@ -564,13 +536,13 @@ def _batch_run_repeats(
delayed(run_single)(
model_cls=self.model_cls,
cfg=cfg,
key=(self.job_id, repeat_id),
key=(self.job_id, run_id),
outpath=self.outpath,
seed=self._get_seed(repeat_id),
seed=self._get_seed(run_id),
hooks=self._manager.hooks,
**self._extra_kwargs,
)
for repeat_id in tqdm(
for run_id in tqdm(
range(1, repeats + 1),
disable=not display_progress,
desc=f"Job {self.job_id} repeats {repeats} times, with {number_process} processes.",
Expand Down Expand Up @@ -644,22 +616,22 @@ def _call_hook_with_optional_args(
hook_func: Callable,
model: MainModelProtocol,
job_id: Optional[int] = None,
repeat_id: Optional[int] = None,
run_id: Optional[int] = None,
) -> Any:
"""根据钩子函数的参数签名动态调用函数

Args:
hook_func: 要调用的钩子函数
model: 模型实例
job_id: 可选的任务ID
repeat_id: 可选的重复实验ID
run_id: 可选的重复实验ID
"""
sig = inspect.signature(hook_func)
hook_args = {}

if "job_id" in sig.parameters:
hook_args["job_id"] = job_id
if "repeat_id" in sig.parameters:
hook_args["repeat_id"] = repeat_id
if "run_id" in sig.parameters:
hook_args["run_id"] = run_id

return hook_func(model, **hook_args)
30 changes: 25 additions & 5 deletions abses/core/job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type

import pandas as pd
Expand Down Expand Up @@ -102,7 +103,7 @@ def update_result(
"""更新实验结果

Args:
key: (job_id, repeat_id) tuple
key: (job_id, run_id) tuple
overrides: Configuration overrides for this run
datasets: Row-like mapping of metrics/values to store
seed: Random seed used for this run
Expand All @@ -115,25 +116,44 @@ def dict_to_df(self, results: dict) -> pd.DataFrame:
"""将嵌套字典转换为 DataFrame

Args:
results: 形如 {(job_id, repeat_id): {'metric': value}} 的字典
results: 形如 {(job_id, run_id): {'metric': value}} 的字典

Returns:
包含 job_id, repeat_id 和指标值的 DataFrame
包含 job_id, run_id 和指标值的 DataFrame
"""
return pd.DataFrame(results.values(), index=self.index)

def get_datasets(
self,
seed: bool = True,
) -> pd.DataFrame:
"""获取所有实验结果的 DataFrame"""
"""获取所有实验结果的 DataFrame

Note:
The ``repeat_id`` column is **deprecated** and will be removed in a
future version. Please use the ``run_id`` column instead.
"""
to_concat = []
to_concat.append(self.dict_to_df(self._overrides))
if seed:
seed = pd.Series(self._seeds, name="seed", index=self.index)
to_concat.append(seed)
to_concat.append(self.dict_to_df(self._datasets))
return pd.concat(to_concat, axis=1).reset_index()
df = pd.concat(to_concat, axis=1).reset_index()

# Backward compatibility: if legacy results contain a `repeat_id` column
# (e.g. from older versions or custom datasets), mirror it into `run_id`
# and emit a deprecation warning. New code should only rely on `run_id`.
if "repeat_id" in df.columns and "run_id" not in df.columns:
warnings.warn(
"Column 'repeat_id' is deprecated and will be removed in a future "
"version. Please use 'run_id' instead.",
DeprecationWarning,
stacklevel=2,
)
df["run_id"] = df["repeat_id"]

return df

def add_a_hook(
self,
Expand Down
6 changes: 4 additions & 2 deletions abses/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ def __init__(
tracker_backend = create_tracker(tracker_cfg, model=self)
collector_cfg = prepare_collector_config(tracker_cfg)
self.datacollector: ABSESpyDataCollector = ABSESpyDataCollector(
reports=collector_cfg, tracker=tracker_backend
reports=collector_cfg,
tracker=tracker_backend,
run_id=run_id,
)

# Setup logging BEFORE initialize() so user logs in initialize() are captured
Expand Down Expand Up @@ -347,7 +349,7 @@ def _setup_logger(self, log_cfg: Dict[str, Any]) -> None:
rotation=rotation,
retention=retention,
logging_mode=logging_mode,
repeat_id=self.run_id,
run_id=self.run_id,
file_level=file_level,
file_format=file_format,
file_datefmt=file_datefmt,
Expand Down
28 changes: 21 additions & 7 deletions abses/utils/datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,18 @@ def __init__(
self,
reports: Dict[ReportType, Dict[str, Reporter]] | None = None,
tracker: Optional[TrackerProtocol] = None,
run_id: Optional[int] = None,
):
"""Initialize data collector.

Args:
reports: Reporters configuration.
tracker: Optional tracker backend.
run_id: Optional run id.
"""
reports = reports or {}
self.tracker = tracker
self.run_id = run_id
self.model_reporters: Dict[str, Reporter] = {}
self.final_reporters: Dict[str, Reporter] = {}
self.agent_reporters: Dict[str, Dict[str, Reporter]] = {}
Expand Down Expand Up @@ -162,6 +165,13 @@ def add_reporters(
for name, reporter in reporters.items():
self._new_agent_reporter(breed=item, name=name, reporter=reporter)

def _add_run_id_to_data(
self, data: pd.DataFrame | Dict[str, Any]
) -> pd.DataFrame | Dict[str, Any]:
if self.run_id is not None:
data["run_id"] = self.run_id
return data

def _new_model_reporter(self, name: str, reporter: Reporter) -> None:
"""Add a new model-level reporter to collect data.

Expand Down Expand Up @@ -216,8 +226,9 @@ def get_model_vars_dataframe(self):
logger.warning(
"No model reporters have been definedreturning empty DataFrame."
)

return pd.DataFrame(self.model_vars)
df = pd.DataFrame(self.model_vars)
df = self._add_run_id_to_data(df)
return df

def get_agent_vars_dataframe(self, breed: Optional[str] = None) -> pd.DataFrame:
"""获取某种 Agents 的 DataFrame"""
Expand All @@ -229,8 +240,12 @@ def get_agent_vars_dataframe(self, breed: Optional[str] = None) -> pd.DataFrame:
if not self.agent_reporters:
logger.warning("No agent reporters have been defined in the DataCollector.")
if results := self._agent_records.get(breed):
return pd.concat([pd.DataFrame(res) for res in results])
return pd.DataFrame()
df = pd.concat([pd.DataFrame(res) for res in results])
else:
logger.warning(f"No agent records found for breed {breed}.")
df = pd.DataFrame()
df = self._add_run_id_to_data(data=df)
return df

def get_final_vars_report(self, model: MainModel) -> Dict[str, Any]:
"""Report at the end of this model.
Expand All @@ -239,11 +254,10 @@ def get_final_vars_report(self, model: MainModel) -> Dict[str, Any]:
A dictionary mapping variable names to their computed values.
"""
if not self.final_reporters:
logger.warning(
"No final reporters have been defined, returning empty dict."
)
logger.info("No final reporters have been defined.")
return {}
results = {var: func(model) for var, func in self.final_reporters.items()}
self._add_run_id_to_data(results)
if self.tracker is not None:
self.tracker.log_final_metrics(results)
return results
Expand Down
14 changes: 7 additions & 7 deletions abses/utils/log_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,15 @@ def determine_log_file_path(
outpath: Optional[Path],
log_name: str,
logging_mode: str = "once",
repeat_id: Optional[int] = None,
run_id: Optional[int] = None,
) -> Optional[Path]:
"""Determine log file path based on logging mode.

Args:
outpath: Output directory for log files.
log_name: Base log file name (without extension).
logging_mode: Logging mode - 'once', 'separate', or 'merge'.
repeat_id: Repeat ID for the current run (1-indexed).
run_id: Run ID for the current run (1-indexed).

Comment thread
coderabbitai[bot] marked this conversation as resolved.
Returns:
Path to log file, or None if logging should be disabled.
Expand All @@ -250,21 +250,21 @@ def determine_log_file_path(

if logging_mode == "once":
# Only log the first repeat
if repeat_id is None or repeat_id == 1:
if run_id is None or run_id == 1:
return outpath / f"{log_name}.log"
return None
elif logging_mode == "separate":
# Each repeat gets its own file
# In separate mode, repeat_id must be provided
if repeat_id is None:
# In separate mode, run_id must be provided
if run_id is None:
return None # Don't create default file in separate mode
return outpath / f"{log_name}_{repeat_id}.log"
return outpath / f"{log_name}_{run_id}.log"
elif logging_mode == "merge":
# All repeats go to the same file
return outpath / f"{log_name}.log"
else:
# Unknown mode, default to once behavior
if repeat_id is None or repeat_id == 1:
if run_id is None or run_id == 1:
return outpath / f"{log_name}.log"
return None

Expand Down
Loading
Loading