Commit 9c05a83
log model config (#627)
* log model config
* log model config
* fix tests
* make tests fail at first failure
* make tests fail at first failure
* fix tests
* fix tests
1 parent 14d289e commit 9c05a83

4 files changed, +13 -5 lines

.github/workflows/tests.yaml (+1 -1)

@@ -38,7 +38,7 @@ jobs:
           HF_HOME: "cache/models"
           HF_DATASETS_CACHE: "cache/datasets"
         run: | # PYTHONPATH="${PYTHONPATH}:src" HF_DATASETS_CACHE="cache/datasets" HF_HOME="cache/models"
-          python -m pytest --disable-pytest-warnings
+          python -m pytest -x --disable-pytest-warnings
       - name: Write cache
         uses: actions/cache@v4
         with:
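
The only functional change here is the added -x flag, which stops the pytest session at the first failing test instead of running the whole suite; on a CI runner that first downloads models and datasets, this surfaces breakage much earlier. A minimal sketch of the equivalent programmatic invocation (the flags mirror the CI step above):

import pytest

# -x (alias of --exitfirst) aborts the session after the first failure;
# --disable-pytest-warnings hides the warnings summary, matching the CI step.
exit_code = pytest.main(["-x", "--disable-pytest-warnings"])
print(exit_code)  # 0 on success, non-zero as soon as a test fails or errors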

src/lighteval/logging/info_loggers.py (+5 -1)

@@ -90,6 +90,8 @@ class GeneralConfigLogger:
     model_dtype: str = None
     model_size: str = None
 
+    generation_parameters: dict | None = None
+
     # Nanotron config
     config: "Config" = None
 
@@ -133,14 +135,16 @@ def log_args_info(
         self.job_id = job_id
         self.config = config
 
-    def log_model_info(self, model_info: ModelInfo) -> None:
+    def log_model_info(self, generation_parameters: dict, model_info: ModelInfo) -> None:
         """
         Logs the model information.
 
         Args:
+            generation_parameters (dict): the generation parameters used to initialize the model.
             model_info (ModelInfo): Model information to be logged.
 
         """
+        self.generation_parameters = generation_parameters
         self.model_name = model_info.model_name
         self.model_sha = model_info.model_sha
         self.model_dtype = model_info.model_dtype
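
With the new field, each run's serialized general config carries the generation parameters next to the model metadata. A minimal sketch of the updated call shape, assuming a plain dict of sampling settings (the tracker and model names here are hypothetical stand-ins, and the dict values are for illustration only):

# Hypothetical values; only the call signature follows the diff above.
generation_parameters = {"temperature": 0.0, "top_p": 1.0, "seed": 42}
tracker.general_config_logger.log_model_info(generation_parameters, model.model_info)
assert tracker.general_config_logger.generation_parameters == generation_parameters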

src/lighteval/pipeline.py (+4 -2)

@@ -27,7 +27,7 @@
 import re
 import shutil
 from contextlib import nullcontext
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from datetime import timedelta
 from enum import Enum, auto
 
@@ -156,7 +156,9 @@ def __init__(
         self.accelerator, self.parallel_context = self._init_parallelism_manager()
         self.model = self._init_model(model_config, model)
 
-        self.evaluation_tracker.general_config_logger.log_model_info(self.model.model_info)
+        generation_parameters = asdict(model_config.generation_parameters) if model_config else {}
+
+        self.evaluation_tracker.general_config_logger.log_model_info(generation_parameters, self.model.model_info)
         self._init_tasks_and_requests(tasks=tasks)
         self._init_random_seeds()
         # Final results
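
asdict recursively converts a dataclass instance into a plain dict, which keeps the logged config JSON-serializable; the "if model_config else {}" guard covers the code path where the pipeline is handed an already-initialized model and no config. A self-contained sketch of the conversion, using an assumed two-field dataclass in place of lighteval's real GenerationParameters:

from dataclasses import asdict, dataclass
from typing import Optional

@dataclass
class GenerationParameters:  # assumed stand-in, not the library's definition
    temperature: float = 1.0
    top_p: Optional[float] = None

params = GenerationParameters(temperature=0.0)
print(asdict(params))  # {'temperature': 0.0, 'top_p': None}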

tests/models/test_base_model.py (+3 -1)

@@ -26,7 +26,9 @@
 
 
 def test_empty_requests():
-    model_config = TransformersModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    model_config = TransformersModelConfig(
+        "hf-internal-testing/tiny-random-LlamaForCausalLM", model_parallel=False, revision="main"
+    )
     model: TransformersModel = load_model(config=model_config, env_config=EnvConfig(cache_dir="."))
 
     assert model.loglikelihood([]) == []
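
The assertion pins down the empty-input contract: loglikelihood over an empty request list returns an empty list without scoring anything. A minimal stand-alone sketch of that contract with a hypothetical stub in place of the real model:

class _StubModel:
    # Hypothetical stand-in; models only the contract asserted above:
    # no requests in, no results out.
    def loglikelihood(self, requests: list) -> list:
        if not requests:
            return []
        raise NotImplementedError("real scoring lives in TransformersModel")

assert _StubModel().loglikelihood([]) == []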
