diff --git a/swift/llm/eval/utils.py b/swift/llm/eval/utils.py
index 1ab6937dec..3b9cc8e8ac 100644
--- a/swift/llm/eval/utils.py
+++ b/swift/llm/eval/utils.py
@@ -21,6 +21,22 @@
 from ..infer import PtEngine, RequestConfig
 from ..template import InferRequest
 
+# Lightweight in-process registry to avoid putting heavyweight torch model objects
+# into the evalscope TaskConfig, which may cause OOM.
+_EVAL_MODEL_REGISTRY: Dict[str, Tuple[Any, Any]] = {}
+
+
+def register_eval_model(key: str, model: Any, template: Any) -> None:
+    _EVAL_MODEL_REGISTRY[key] = (model, template)
+
+
+def get_eval_model(key: str) -> Tuple[Optional[Any], Optional[Any]]:
+    return _EVAL_MODEL_REGISTRY.get(key, (None, None))
+
+
+def unregister_eval_model(key: str) -> None:
+    _EVAL_MODEL_REGISTRY.pop(key, None)
+
 
 @dataclass
 class BatchInferInput:
@@ -97,8 +113,13 @@ def collect_model_arg(name: str) -> Optional[Any]:
             return value
 
         # Extract required model parameters
-        self.model = collect_model_arg('model')  # model path or identifier
-        self.template = collect_model_arg('template')  # conversation template
+        # Prefer the lightweight registry reference to avoid passing the torch model directly.
+        model_ref = collect_model_arg('model_ref')
+        if model_ref is not None:
+            self.model, self.template = get_eval_model(model_ref)
+        else:
+            self.model = collect_model_arg('model')  # torch model object or path/id
+            self.template = collect_model_arg('template')  # conversation template
 
         # Initialize the inference engine with batch support
         self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=self.config.batch_size)
diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py
index 3cda3e3a0d..2f8402f0c5 100644
--- a/swift/trainers/mixin.py
+++ b/swift/trainers/mixin.py
@@ -779,35 +779,38 @@ def _compute_acc(self, outputs, labels) -> None:
 
     @torch.no_grad()
     def _evalscope_eval(self):
-        from ..llm.eval.utils import EvalModel  # registry here
+        from ..llm.eval.utils import EvalModel, register_eval_model, unregister_eval_model  # registry here
         from evalscope import TaskConfig, run_task
 
         self.model.eval()
 
-        task_config_kwargs = dict(
-            model=f'model-step{self.state.global_step}',
-            model_args=dict(
-                model=self.model,
-                template=self.template,
-            ),
-            eval_type='swift_custom',
-            datasets=self.args.eval_dataset,
-            dataset_args=self.args.eval_dataset_args,
-            limit=self.args.eval_limit,
-            work_dir=os.path.join(self.args.output_dir, 'eval'),
-            eval_batch_size=self.args.per_device_eval_batch_size,
-            generation_config=self.args.eval_generation_config or {'max_tokens': 512},
-        )
-        task_config_kwargs.update(self.args.extra_eval_args or {})
-        task_config = TaskConfig(**task_config_kwargs)
-        # start evaluation
-        eval_report = run_task(task_config)
-        # convert to dict
-        eval_dict = {f'test_{k}': v.score for k, v in eval_report.items()}
-        self.log(eval_dict)
-
-        self.model.train()
-        return eval_dict
+        # Register the current model/template under a lightweight key instead of passing the torch model itself.
+        model_key = f'swift_eval_model_{os.getpid()}_{self.state.global_step}'
+        register_eval_model(model_key, self.model, self.template)
+
+        try:
+            task_config_kwargs = dict(
+                model=f'model-step{self.state.global_step}',
+                model_args=dict(model_ref=model_key),
+                eval_type='swift_custom',
+                datasets=self.args.eval_dataset,
+                dataset_args=self.args.eval_dataset_args,
+                limit=self.args.eval_limit,
+                work_dir=os.path.join(self.args.output_dir, 'eval'),
+                eval_batch_size=self.args.per_device_eval_batch_size,
+                generation_config=self.args.eval_generation_config or {'max_tokens': 512},
+            )
+            task_config_kwargs.update(self.args.extra_eval_args or {})
+            task_config = TaskConfig(**task_config_kwargs)
+            # start evaluation
+            eval_report = run_task(task_config)
+            # convert to dict
+            eval_dict = {f'test_{k}': v.score for k, v in eval_report.items()}
+            self.log(eval_dict)
+            return eval_dict
+        finally:
+            unregister_eval_model(model_key)
+            self.model.train()
 
     def prepare_logits_to_keep(self, inputs):
         labels = inputs['labels']
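Reviewer note: the snippet below is a minimal, self-contained sketch of the registry round-trip this patch introduces (register in the trainer, pass only `model_ref` through `model_args`, resolve inside `EvalModel`, unregister in a `finally`). Only the three registry functions are taken verbatim from the diff; the `run_eval` helper is hypothetical and stands in for the `_evalscope_eval` / `EvalModel` wiring, which is elided here.

```python
import os
from typing import Any, Dict, Optional, Tuple

# Registry functions copied from the patch.
_EVAL_MODEL_REGISTRY: Dict[str, Tuple[Any, Any]] = {}


def register_eval_model(key: str, model: Any, template: Any) -> None:
    _EVAL_MODEL_REGISTRY[key] = (model, template)


def get_eval_model(key: str) -> Tuple[Optional[Any], Optional[Any]]:
    return _EVAL_MODEL_REGISTRY.get(key, (None, None))


def unregister_eval_model(key: str) -> None:
    _EVAL_MODEL_REGISTRY.pop(key, None)


def run_eval(model: Any, template: Any, global_step: int) -> None:
    # Hypothetical stand-in for _evalscope_eval: producer side publishes the live
    # objects under a small string key instead of embedding them in the task config.
    model_key = f'swift_eval_model_{os.getpid()}_{global_step}'
    register_eval_model(model_key, model, template)
    try:
        # Only the key travels through model_args.
        model_args = {'model_ref': model_key}
        # Consumer side (EvalModel) resolves the key back to the heavyweight objects.
        resolved_model, resolved_template = get_eval_model(model_args['model_ref'])
        assert resolved_model is model and resolved_template is template
    finally:
        # Always drop the reference so the registry does not pin the model in memory.
        unregister_eval_model(model_key)


if __name__ == '__main__':
    run_eval(model=object(), template=object(), global_step=42)
    print('registry round-trip OK; registry empty:', not _EVAL_MODEL_REGISTRY)
```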