diff --git a/scripts/jiuge.py b/scripts/jiuge.py
index 7c31baf8..d727dcf1 100644
--- a/scripts/jiuge.py
+++ b/scripts/jiuge.py
@@ -455,14 +455,14 @@ def load_all_safetensors_from_dir(dir_path_: str):
 
         self.jiuge_model = JiugeModel()
 
-        if "llama" == config["model_type"]:
+        if "llama" == config["model_type"] or "internlm3" == config["model_type"]:
             model = (
                 transformers.LlamaForCausalLM.from_pretrained(model_dir_path)
                 .cpu()
                 .half()
             )
             self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
-            self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
+            self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path, trust_remote_code=True)
             self.weights = JiugeWeightsImpl(
                 self.meta,
                 LlamaWeightsNaming(),
@@ -613,6 +613,7 @@ def generate(
         topk_=1,
         temperature_=1.0,
         verbose=False,
+        skip_special_tokens=False,
     ):
         input_content = self.tokenizer.apply_chat_template(
             conversation=[{"role": "user", "content": input_content}],
@@ -696,9 +697,10 @@ def generate(
 
             output_str = self.tokenizer.decode(output_tokens[0])
             output_content += output_str
-            print(output_str, end="", flush=True)
             if output_tokens[0] in self.eos_token_id:
                 break
+            if not (skip_special_tokens and output_tokens[0] in self.tokenizer.all_special_ids):
+                print(output_str, end="", flush=True)
             infer_task.next(output_tokens[0])
 
         if step_i > 0:
@@ -869,7 +871,7 @@ def test():
 
     ndev = int(ndev_args[0]) if ndev_args else 1
     model = JiugeForCauslLM(model_path, device_type, ndev)
-    model.generate("山东最高的山是?", 500, verbose=verbose)
+    model.generate("山东最高的山是?", 500, verbose=verbose, skip_special_tokens=True)
     model.destroy_model_instance()
 
 
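
Two notes on the change, plus a sketch. `trust_remote_code=True` is needed because InternLM3 distributes its tokenizer as custom code on the Hugging Face Hub, so `AutoTokenizer` must be allowed to execute it; for plain Llama checkpoints the flag is a no-op. The `skip_special_tokens` check guards only the echo to stdout: the sampled token must still be fed back through `infer_task.next()`, otherwise the next decode step would re-run from the same state and emit the same special token until the step limit. Below is a minimal, self-contained sketch of that loop shape, not the Jiuge runtime itself; `fake_infer`, `SPECIAL_IDS`, `EOS_IDS`, and `VOCAB` are hypothetical stand-ins for `infer()`, `tokenizer.all_special_ids`, `self.eos_token_id`, and `tokenizer.decode()`.

# Hypothetical stand-ins; none of these names exist in scripts/jiuge.py.
SPECIAL_IDS = {0, 2}          # e.g. <unk> and an <|im_end|>-style marker
EOS_IDS = {2}                 # terminators; checked before the filter
VOCAB = {1: "Hello", 3: " world", 0: "<unk>", 2: "<|im_end|>"}

def fake_infer(history):
    """Deterministic stand-in for the model: emits a fixed token stream."""
    stream = [1, 0, 3, 2]     # note the special token 0 mid-stream
    return stream[len(history)]

def generate(skip_special_tokens=False, max_steps=10):
    history, output = [], ""
    for _ in range(max_steps):
        tok = fake_infer(history)
        text = VOCAB[tok]
        output += text
        if tok in EOS_IDS:
            break
        # Guard only the echo: the token must still be fed back via
        # history.append() (infer_task.next() in jiuge.py), or the next
        # step would re-run from the same state and never advance.
        if not (skip_special_tokens and tok in SPECIAL_IDS):
            print(text, end="", flush=True)
        history.append(tok)
    return output

generate(skip_special_tokens=True)  # prints "Hello world", not "<unk>"

Running `generate(skip_special_tokens=True)` prints "Hello world", while the returned string (like `output_content` in jiuge.py) still contains the raw special tokens; only the console echo is filtered.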