File tree: 4 files changed, +13 −5 lines changed
File 1 of 4

@@ -38,7 +38,7 @@
           HF_HOME: "cache/models"
           HF_DATASETS_CACHE: "cache/datasets"
         run: | # PYTHONPATH="${PYTHONPATH}:src" HF_DATASETS_CACHE="cache/datasets" HF_HOME="cache/models"
-          python -m pytest --disable-pytest-warnings
+          python -m pytest -x --disable-pytest-warnings
       - name: Write cache
         uses: actions/cache@v4
         with:
File 2 of 4

@@ -90,6 +90,8 @@ class GeneralConfigLogger:
     model_dtype: str = None
     model_size: str = None

+    generation_parameters: dict | None = None
+
     # Nanotron config
     config: "Config" = None

@@ -133,14 +135,16 @@ def log_args_info(
         self.job_id = job_id
         self.config = config

-    def log_model_info(self, model_info: ModelInfo) -> None:
+    def log_model_info(self, generation_parameters: dict, model_info: ModelInfo) -> None:
         """
         Logs the model information.

         Args:
+            generation_parameters (dict): the generation parameters used to initialize the model.
             model_info (ModelInfo): Model information to be logged.

         """
+        self.generation_parameters = generation_parameters
         self.model_name = model_info.model_name
         self.model_sha = model_info.model_sha
         self.model_dtype = model_info.model_dtype
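
For reference, here is a minimal sketch of the new call pattern introduced above, in which the generation parameters travel with the model info when the logger is populated. The stand-in classes below are illustrative only and mirror just the fields the logger reads; they are not the library's actual ModelInfo or logger definitions.

from dataclasses import dataclass


@dataclass
class ModelInfoStub:
    # Illustrative stand-in mirroring only the fields read in log_model_info above.
    model_name: str
    model_sha: str = ""
    model_dtype: str = ""


@dataclass
class GeneralConfigLoggerStub:
    # Illustrative stand-in; the real logger carries more state.
    generation_parameters: dict | None = None
    model_name: str | None = None
    model_sha: str | None = None
    model_dtype: str | None = None

    def log_model_info(self, generation_parameters: dict, model_info: ModelInfoStub) -> None:
        # Same assignment order as the diff: store generation parameters, then model metadata.
        self.generation_parameters = generation_parameters
        self.model_name = model_info.model_name
        self.model_sha = model_info.model_sha
        self.model_dtype = model_info.model_dtype


logger = GeneralConfigLoggerStub()
logger.log_model_info(
    {"temperature": 0.0, "top_p": 1.0},
    ModelInfoStub(model_name="my-org/my-model", model_sha="abc123", model_dtype="float16"),
)
assert logger.generation_parameters == {"temperature": 0.0, "top_p": 1.0}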
File 3 of 4

@@ -27,7 +27,7 @@
 import re
 import shutil
 from contextlib import nullcontext
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from datetime import timedelta
 from enum import Enum, auto

@@ -156,7 +156,9 @@ def __init__(
         self.accelerator, self.parallel_context = self._init_parallelism_manager()
         self.model = self._init_model(model_config, model)

-        self.evaluation_tracker.general_config_logger.log_model_info(self.model.model_info)
+        generation_parameters = asdict(model_config.generation_parameters) if model_config else {}
+
+        self.evaluation_tracker.general_config_logger.log_model_info(generation_parameters, self.model.model_info)
         self._init_tasks_and_requests(tasks=tasks)
         self._init_random_seeds()
         # Final results
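
The guard added above serializes the config's generation parameters with dataclasses.asdict and falls back to an empty dict when no model config is provided. Below is a minimal sketch of that pattern, assuming a hypothetical GenerationParameters dataclass whose field names are illustrative rather than the library's:

from dataclasses import asdict, dataclass, field
from typing import Optional


@dataclass
class GenerationParameters:
    # Hypothetical fields for illustration; the library's dataclass may differ.
    temperature: float = 1.0
    top_p: float = 1.0
    max_new_tokens: Optional[int] = None


@dataclass
class ModelConfigStub:
    # Illustrative config holding a nested generation-parameters dataclass.
    model_name: str
    generation_parameters: GenerationParameters = field(default_factory=GenerationParameters)


def extract_generation_parameters(model_config: Optional[ModelConfigStub]) -> dict:
    # Mirrors the pipeline guard: serialize when a config exists, otherwise log an empty dict.
    return asdict(model_config.generation_parameters) if model_config else {}


print(extract_generation_parameters(ModelConfigStub("my-org/my-model")))
# -> {'temperature': 1.0, 'top_p': 1.0, 'max_new_tokens': None}
print(extract_generation_parameters(None))
# -> {}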
File 4 of 4

@@ -26,7 +26,9 @@


 def test_empty_requests():
-    model_config = TransformersModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    model_config = TransformersModelConfig(
+        "hf-internal-testing/tiny-random-LlamaForCausalLM", model_parallel=False, revision="main"
+    )
     model: TransformersModel = load_model(config=model_config, env_config=EnvConfig(cache_dir="."))

     assert model.loglikelihood([]) == []
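
To reproduce the CI behaviour from the first file locally, the same fail-fast flag can be passed through pytest's Python entry point. Selecting the test by name with -k is just one convenient option and assumes the script is run from the repository root:

import sys

import pytest

# -x stops at the first failure, matching the updated workflow command;
# -k narrows the run to the test modified above.
sys.exit(pytest.main(["-x", "--disable-pytest-warnings", "-k", "test_empty_requests"]))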