swift/llm/eval/utils.py (25 changes: 23 additions & 2 deletions)
@@ -21,6 +21,22 @@
from ..infer import PtEngine, RequestConfig
from ..template import InferRequest

+# Lightweight in-process registry to avoid putting heavyweight torch model objects
+# into the evalscope TaskConfig, which may cause OOM.
+_EVAL_MODEL_REGISTRY: Dict[str, Tuple[Any, Any]] = {}
+
+
+def register_eval_model(key: str, model: Any, template: Any) -> None:
+    _EVAL_MODEL_REGISTRY[key] = (model, template)
+
+
+def get_eval_model(key: str) -> Tuple[Optional[Any], Optional[Any]]:
+    return _EVAL_MODEL_REGISTRY.get(key, (None, None))
+
+
+def unregister_eval_model(key: str) -> None:
+    _EVAL_MODEL_REGISTRY.pop(key, None)
+
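Taken on its own, the registry is nothing more than a module-level dict keyed by strings. A minimal usage sketch of the round-trip, assuming the three helpers above are in scope (the `object()` stand-ins are hypothetical placeholders, not part of this PR):

```python
import os

# Hypothetical stand-ins for a real (model, template) pair; placeholders only.
model, template = object(), object()

key = f'swift_eval_model_{os.getpid()}_0'
register_eval_model(key, model, template)
try:
    # The consumer gets back the exact same objects: nothing is copied,
    # pickled, or re-loaded from disk.
    m, t = get_eval_model(key)
    assert m is model and t is template
finally:
    # Drop the reference so the model can be garbage-collected after eval.
    unregister_eval_model(key)

# Unknown keys fail soft with (None, None) instead of raising KeyError.
assert get_eval_model('no-such-key') == (None, None)
```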

@dataclass
class BatchInferInput:
@@ -97,8 +113,13 @@ def collect_model_arg(name: str) -> Optional[Any]:
            return value

        # Extract required model parameters
-        self.model = collect_model_arg('model')  # model path or identifier
-        self.template = collect_model_arg('template')  # conversation template
+        # Prefer a lightweight registry reference over receiving the torch model directly.
+        model_ref = collect_model_arg('model_ref')
+        if model_ref is not None:
+            self.model, self.template = get_eval_model(model_ref)
+        else:
+            self.model = collect_model_arg('model')  # torch model object or path/id
+            self.template = collect_model_arg('template')  # conversation template

        # Initialize the inference engine with batch support
        self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=self.config.batch_size)
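The effect of the `model_ref` branch is that only a short string has to cross the evalscope TaskConfig boundary, while the torch model itself never leaves the current process. A standalone sketch of that dispatch, assuming `model_args` arrives as a plain dict and the registry helpers are in scope (`resolve_model` is a hypothetical name; the real class reads the arguments through `collect_model_arg`):

```python
from typing import Any, Dict, Optional, Tuple

def resolve_model(model_args: Dict[str, Any]) -> Tuple[Optional[Any], Optional[Any]]:
    # Hypothetical distillation of the branch above; not part of this PR.
    model_ref = model_args.get('model_ref')
    if model_ref is not None:
        # Only the string key travelled through the task config; fetch the
        # live model/template objects from the in-process registry.
        return get_eval_model(model_ref)
    # Fallback: the model object (or a path/identifier) was passed directly.
    return model_args.get('model'), model_args.get('template')
```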
swift/trainers/mixin.py (53 changes: 28 additions & 25 deletions)
@@ -779,35 +779,38 @@ def _compute_acc(self, outputs, labels) -> None:

    @torch.no_grad()
    def _evalscope_eval(self):
-        from ..llm.eval.utils import EvalModel  # registry here
+        from ..llm.eval.utils import EvalModel, register_eval_model, unregister_eval_model  # in-process registry helpers
        from evalscope import TaskConfig, run_task

        self.model.eval()

-        task_config_kwargs = dict(
-            model=f'model-step{self.state.global_step}',
-            model_args=dict(
-                model=self.model,
-                template=self.template,
-            ),
-            eval_type='swift_custom',
-            datasets=self.args.eval_dataset,
-            dataset_args=self.args.eval_dataset_args,
-            limit=self.args.eval_limit,
-            work_dir=os.path.join(self.args.output_dir, 'eval'),
-            eval_batch_size=self.args.per_device_eval_batch_size,
-            generation_config=self.args.eval_generation_config or {'max_tokens': 512},
-        )
-        task_config_kwargs.update(self.args.extra_eval_args or {})
-        task_config = TaskConfig(**task_config_kwargs)
-        # start evaluation
-        eval_report = run_task(task_config)
-        # convert to dict
-        eval_dict = {f'test_{k}': v.score for k, v in eval_report.items()}
-        self.log(eval_dict)
-
-        self.model.train()
-        return eval_dict
+        # Register the current model+template under a lightweight key to avoid
+        # passing the whole torch model through the task config.
+        model_key = f'swift_eval_model_{os.getpid()}_{self.state.global_step}'
+        register_eval_model(model_key, self.model, self.template)
+
+        try:
+            task_config_kwargs = dict(
+                model=f'model-step{self.state.global_step}',
+                model_args=dict(model_ref=model_key),
+                eval_type='swift_custom',
+                datasets=self.args.eval_dataset,
+                dataset_args=self.args.eval_dataset_args,
+                limit=self.args.eval_limit,
+                work_dir=os.path.join(self.args.output_dir, 'eval'),
+                eval_batch_size=self.args.per_device_eval_batch_size,
+                generation_config=self.args.eval_generation_config or {'max_tokens': 512},
+            )
+            task_config_kwargs.update(self.args.extra_eval_args or {})
+            task_config = TaskConfig(**task_config_kwargs)
+            # start evaluation
+            eval_report = run_task(task_config)
+            # flatten the report into a metric dict for self.log
+            eval_dict = {f'test_{k}': v.score for k, v in eval_report.items()}
+            self.log(eval_dict)
+            return eval_dict
+        finally:
+            unregister_eval_model(model_key)
+            self.model.train()
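Because registration and cleanup always bracket `run_task` as a pair, one possible refactor (not part of this PR) is a small context manager that guarantees the registry entry is dropped even when evaluation raises:

```python
from contextlib import contextmanager

@contextmanager
def registered_eval_model(key, model, template):
    # Hypothetical wrapper mirroring the try/finally above.
    register_eval_model(key, model, template)
    try:
        yield key
    finally:
        unregister_eval_model(key)

# Usage sketch:
# with registered_eval_model(model_key, self.model, self.template):
#     eval_report = run_task(task_config)
```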

    def prepare_logits_to_keep(self, inputs):
        labels = inputs['labels']