import gradio as gr
import mindspore
from mindspore import dtype as mstype
import numpy as np
# These imports assume a vendored copy of MindNLP (the directory name suggests
# v0.4.1); with a pip-installed MindNLP the same classes come from
# mindnlp.transformers.
from mindnlpv041.mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
from mindnlpv041.mindnlp.transformers import TextIteratorStreamer
from threading import Thread

# Load the tokenizer and model from the Hugging Face model hub.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", ms_dtype=mindspore.float16)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", ms_dtype=mindspore.float16)
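# ms_dtype=mindspore.float16 halves memory use compared with float32; the
# weights are downloaded on first use.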

system_prompt = "You are a helpful and friendly chatbot"

def build_input_from_chat_history(chat_history, msg: str):
    messages = [{'role': 'system', 'content': system_prompt}]
    for user_msg, ai_msg in chat_history:
        messages.append({'role': 'user', 'content': user_msg})
        messages.append({'role': 'assistant', 'content': ai_msg})
    messages.append({'role': 'user', 'content': msg})
    return messages
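
# Example: with chat_history = [("Hi", "Hello!")] and msg = "How are you?",
# build_input_from_chat_history returns:
#   [{'role': 'system', 'content': system_prompt},
#    {'role': 'user', 'content': 'Hi'},
#    {'role': 'assistant', 'content': 'Hello!'},
#    {'role': 'user', 'content': 'How are you?'}]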

# Function to generate model predictions.
def predict(message, history):
    # Formatting the input for the model.
    messages = build_input_from_chat_history(history, message)
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="ms",
        tokenize=True
    )
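    # apply_chat_template renders the messages with Qwen1.5's ChatML-style
    # template (<|im_start|>role ... <|im_end|>) and, with return_tensors="ms",
    # returns the token ids as a MindSpore tensor of shape (1, seq_len);
    # add_generation_prompt=True appends the opening assistant turn.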
    # A single unpadded sequence attends to every position, so the mask is all ones.
    attention_mask = mindspore.Tensor(np.ones(input_ids.shape), mstype.float32)
    # The streamer decodes tokens as they arrive and exposes them as an iterator;
    # skip_prompt=True keeps the input prompt out of the streamed output.
    streamer = TextIteratorStreamer(tokenizer, timeout=300, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,    # sample instead of greedy decoding
        top_p=0.9,         # nucleus sampling: keep the smallest set covering 90% of probability
        temperature=0.1,   # low temperature keeps replies focused
        num_beams=1,
        attention_mask=attention_mask,
    )
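    # model.generate blocks until generation finishes, so it runs in a background
    # thread while this function streams partial output back to Gradio.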
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()  # Start the generation in a separate thread.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message
    # The stream ends when generation reaches Qwen1.5's end-of-turn token
    # (<|im_end|>), which skip_special_tokens=True keeps out of the text, so no
    # manual stop-string check is needed.


# Set up the Gradio chat interface.
gr.ChatInterface(predict,
                 title="Qwen1.5-0.5B-Chat",
                 description="Ask a few questions",
                 examples=['Who are you?', 'Tell me about Huawei']
                 ).launch(share=True, server_name='0.0.0.0', server_port=7860)  # Launch the web interface.
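# share=True also requests a temporary public *.gradio.live URL, while
# server_name='0.0.0.0' serves on all interfaces (locally: http://localhost:7860).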