
Commit cb30103

Merge pull request #111 from Tridu33/master

feat: qwen demo code

2 parents 81226e6 + 7dc2e45

File tree: 4 files changed, +1909 −0 lines changed
Lines changed: 48 additions & 0 deletions
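The first new file is a minimal, non-streaming generation demo: it loads Qwen1.5-0.5B-Chat through mindnlp (from the modelscope mirror), applies the chat template to a two-message conversation, and greedily decodes a short reply.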
@@ -0,0 +1,48 @@
import mindspore
import numpy as np
from mindspore import dtype as mstype
from mindspore import Tensor
from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLM
import faulthandler

# Dump a traceback if the process crashes (useful for debugging native-code faults).
faulthandler.enable()

model_id = "Qwen/Qwen1.5-0.5B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_id, mirror='modelscope')
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    ms_dtype=mindspore.float16,
    mirror='modelscope'
)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

# Render the conversation through Qwen's chat template and tokenize as MindSpore tensors.
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="ms"
)
attention_mask = Tensor(np.ones(input_ids.shape), mstype.float32)

# Stop at either the model's EOS token or an explicit <|endoftext|> token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|endoftext|>")
]
outputs = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=20,
    eos_token_id=terminators,
    do_sample=False,  # greedy decoding; uncomment the lines below to sample instead
    # do_sample=True,
    # temperature=0.6,
    # top_p=0.9,
)
# Strip the prompt tokens so only the newly generated reply is decoded.
response = outputs[0][input_ids.shape[-1]:]
print(outputs)
print(tokenizer.decode(response, skip_special_tokens=True))
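The script above is single-turn. To carry the conversation further, append the decoded reply back onto messages and regenerate. A minimal sketch reusing only the objects defined above (the helper name chat_once is ours, not part of the committed demo):

def chat_once(messages, max_new_tokens=20):
    # Re-render the running conversation through the chat template.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="ms"
    )
    attention_mask = Tensor(np.ones(input_ids.shape), mstype.float32)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=False,
    )
    reply = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    # Append the reply so the next call sees the full history.
    messages.append({"role": "assistant", "content": reply})
    return reply

# Example: record the first reply, then ask a follow-up.
messages.append({"role": "assistant", "content": tokenizer.decode(response, skip_special_tokens=True)})
messages.append({"role": "user", "content": "Where do ye sail next?"})
print(chat_once(messages))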
Lines changed: 61 additions & 0 deletions
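The second file wraps the same model in a streaming Gradio chat UI: generation runs in a background thread while a TextIteratorStreamer yields partial replies to the browser as tokens arrive.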
@@ -0,0 +1,61 @@
import gradio as gr
import mindspore
from mindspore import dtype as mstype
import numpy as np
from threading import Thread
# These imports assume a vendored copy of mindnlp v0.4.1; with a standard
# install, use `from mindnlp.transformers import ...` instead.
from mindnlpv041.mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
from mindnlpv041.mindnlp.transformers import TextIteratorStreamer

# Load the tokenizer and model (fetched from the Hugging Face model hub by default).
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", ms_dtype=mindspore.float16)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", ms_dtype=mindspore.float16)

system_prompt = "You are a helpful and friendly chatbot"

def build_input_from_chat_history(chat_history, msg: str):
    # Rebuild the full message list: system prompt, every prior turn, then the new user message.
    messages = [{'role': 'system', 'content': system_prompt}]
    for user_msg, ai_msg in chat_history:
        messages.append({'role': 'user', 'content': user_msg})
        messages.append({'role': 'assistant', 'content': ai_msg})
    messages.append({'role': 'user', 'content': msg})
    return messages

# Generate model predictions, streamed token by token.
def predict(message, history):
    # Format the input for the model.
    messages = build_input_from_chat_history(history, message)
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="ms",
        tokenize=True
    )
    attention_mask = mindspore.Tensor(np.ones(input_ids.shape), mstype.float32)
    streamer = TextIteratorStreamer(tokenizer, timeout=300, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.9,
        temperature=0.1,
        num_beams=1,
        attention_mask=attention_mask,
    )
    # Run generation in a background thread so tokens can be yielded as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if '</s>' in partial_message:  # Stop early if the stop token is generated.
            break
        yield partial_message


# Set up and launch the Gradio chat interface.
gr.ChatInterface(predict,
                 title="Qwen1.5-0.5b-Chat",
                 description="问几个问题",  # "Ask a few questions"
                 examples=['你是谁?', '介绍一下华为公司']  # "Who are you?", "Tell me about Huawei"
                 ).launch(share=True, server_name='0.0.0.0', server_port=7860)  # Launch the web UI.
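For reference, gr.ChatInterface (in its default tuple format) passes history as a list of [user, assistant] pairs, which is exactly what build_input_from_chat_history unpacks. An illustrative check, with made-up turn text:

history = [["你好", "你好!有什么可以帮你?"]]  # one prior turn: ("Hello", "Hello! How can I help?")
msgs = build_input_from_chat_history(history, "介绍一下华为公司")  # "Tell me about Huawei"
# msgs is now:
# [{'role': 'system', 'content': system_prompt},
#  {'role': 'user', 'content': '你好'},
#  {'role': 'assistant', 'content': '你好!有什么可以帮你?'},
#  {'role': 'user', 'content': '介绍一下华为公司'}]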
