From ff02162e1bada5bb134402f8236d82f731688a13 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Wed, 3 Dec 2025 18:00:18 +0800 Subject: [PATCH] =?UTF-8?q?issue/93:=20=E6=94=AF=E6=8C=81InternLM3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/jiuge.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/scripts/jiuge.py b/scripts/jiuge.py index 7c31baf8..d727dcf1 100644 --- a/scripts/jiuge.py +++ b/scripts/jiuge.py @@ -455,14 +455,14 @@ def load_all_safetensors_from_dir(dir_path_: str): self.jiuge_model = JiugeModel() - if "llama" == config["model_type"]: + if "llama" == config["model_type"] or "internlm3" == config["model_type"]: model = ( transformers.LlamaForCausalLM.from_pretrained(model_dir_path) .cpu() .half() ) self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) - self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path, trust_remote_code=True) self.weights = JiugeWeightsImpl( self.meta, LlamaWeightsNaming(), @@ -613,6 +613,7 @@ def generate( topk_=1, temperature_=1.0, verbose=False, + skip_special_tokens=False ): input_content = self.tokenizer.apply_chat_template( conversation=[{"role": "user", "content": input_content}], @@ -696,9 +697,12 @@ def generate( output_str = self.tokenizer.decode(output_tokens[0]) output_content += output_str - print(output_str, end="", flush=True) if output_tokens[0] in self.eos_token_id: break + if skip_special_tokens and output_tokens[0] in self.tokenizer.all_special_ids: + continue + + print(output_str, end="", flush=True) infer_task.next(output_tokens[0]) if step_i > 0: @@ -869,7 +873,7 @@ def test(): ndev = int(ndev_args[0]) if ndev_args else 1 model = JiugeForCauslLM(model_path, device_type, ndev) - model.generate("山东最高的山是?", 500, verbose=verbose) + model.generate("山东最高的山是?", 500, verbose=verbose, skip_special_tokens=True) model.destroy_model_instance()