Merged
160 changes: 152 additions & 8 deletions fastdeploy/entrypoints/llm.py
@@ -139,6 +139,7 @@ def generate(
],
sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
use_tqdm: bool = True,
stream: bool = False,
):
"""
Generate function for the LLM class.
@@ -149,9 +150,11 @@
sampling_params (Optional[Union[SamplingParams, list[SamplingParams]]], optional):
The sampling parameters to use for generating the response. Defaults to None.
use_tqdm (bool, optional): Whether to use tqdm for the progress bar. Defaults to True.
stream (bool, optional): Whether to return a streaming iterator. Defaults to False.

Returns:
Union[str, list[str]]: The generated response.
If stream=False: Union[str, list[str]]: The generated response.
If stream=True: Iterator: An iterator that yields partial responses as they become available.
"""

if not self._check_master():
@@ -186,10 +189,13 @@
topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs

# get output
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
for i in range(len(outputs)):
outputs[i].prompt = prompts[i]
return outputs
if stream:
return self._run_engine_stream(req_ids, prompts, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
else:
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
for i in range(len(outputs)):
outputs[i].prompt = prompts[i]
return outputs

def chat(
self,
@@ -198,6 +204,7 @@ def chat(
use_tqdm: bool = True,
chat_template_kwargs: Optional[dict[str, Any]] = None,
chat_template: Optional[str] = None,
stream: bool = False,
):
"""
Args:
@@ -208,9 +215,11 @@
use_tqdm (bool, optional): Whether to use tqdm for the progress bar. Defaults to True.
chat_template_kwargs(Optional[dict[str,Any]]): Additional kwargs to pass to the chat
template.
stream (bool, optional): Whether to return a streaming iterator. Defaults to False.

Returns:
Union[str, list[str]]: The generated response.
If stream=False: Union[str, list[str]]: The generated response.
If stream=True: Iterator: An iterator that yields partial responses as they become available.
"""

if not self._check_master():
@@ -247,8 +256,11 @@ def chat(
topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs

# get output
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
return outputs
if stream:
return self._run_engine_stream(req_ids, messages, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
else:
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
return outputs

def _add_request(
self,
@@ -414,6 +426,138 @@ def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optiona
pbar.close()
return output

def _run_engine_stream(self, req_ids: list[str], prompts, use_tqdm: bool, topk_logprobs: Optional[int] = None):
"""
Run the engine and return an iterator over streaming responses.

Args:
req_ids (list[str]): List of request IDs.
prompts: Original prompts, used to populate the prompt field of the outputs.
use_tqdm (bool, optional): Whether to show a tqdm progress bar.
topk_logprobs (Optional[int]): Number of top-k logprobs to return.

Yields:
list[RequestOutput]: Partial responses containing the incremental updates.
"""
# Initialize tqdm
if use_tqdm:
num_requests = len(req_ids)
pbar = tqdm(
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
postfix=(f"est. speed input: {0:.2f} toks/s, " f"output: {0:.2f} toks/s"),
)

num_requests = len(req_ids)
original_num_requests = len(req_ids) # Keep track of original count
output = [None] * original_num_requests
req_ids_with_pos = [(pos, req_id) for pos, req_id in enumerate(req_ids)]

# Track previous token counts for each request to identify new tokens
previous_token_counts = {req_id: 0 for req_id in req_ids}

while num_requests > 0:
has_new_tokens = False
finished = []

for i, (pos, req_id) in enumerate(req_ids_with_pos):
with self.mutex:
if req_id not in self.req_output:
continue

current_result = self.req_output[req_id]
current_token_count = (
len(current_result.outputs.token_ids) if current_result.outputs.token_ids else 0
)
previous_count = previous_token_counts[req_id]

# Check if there are new tokens since last yield
if current_token_count > previous_count:
has_new_tokens = True
# Create incremental output with only new tokens
incremental_result = self._create_incremental_result(
current_result, previous_count, pos, prompts
)

# Apply logprobs filtering to the incremental result if needed
if incremental_result.outputs.top_logprobs and topk_logprobs:
incremental_result.outputs.logprobs = self._build_sample_logprobs(
incremental_result.outputs.top_logprobs, topk_logprobs
)

output[pos] = incremental_result
previous_token_counts[req_id] = current_token_count

# Check if request is finished
if current_result.finished:
finished.append(i)

# For streaming, the final aggregated result is not emitted again; just drop the cached entry
self.req_output.pop(req_id)

llm_logger.debug(f"Request id: {req_id} has been completed.")

if use_tqdm:
pbar.update(1)

# Yield updates if there are new tokens
if has_new_tokens or finished:
# yield [result for result in output if result is not None]
# Create a complete output array with proper indexing
complete_output = [None] * original_num_requests # Use original length
for i, (pos, _) in enumerate(req_ids_with_pos):
if output[pos] is not None:
complete_output[pos] = output[pos]
yield complete_output
# Clear output for next iteration
output = [None] * original_num_requests

# Remove finished requests
num_requests -= len(finished)
for i in reversed(finished):
req_ids_with_pos.pop(i)

if num_requests > 0:
time.sleep(0.01)

if use_tqdm:
pbar.close()

def _create_incremental_result(self, current_result, previous_count, pos, prompts):
"""
Create a result object containing only the incremental tokens.

Args:
current_result: The current RequestOutput object.
previous_count: Number of tokens already processed for this request.
pos: Position within the prompts list.
prompts: The original prompt(s).

Returns:
RequestOutput: A result object containing the incremental update.
"""
# Create a copy of current result for incremental update
from copy import deepcopy

incremental_result = deepcopy(current_result)

# Extract only new tokens
if current_result.outputs.token_ids and len(current_result.outputs.token_ids) > previous_count:
new_token_ids = current_result.outputs.token_ids[previous_count:]
incremental_result.outputs.token_ids = new_token_ids

# Process new tokens to get text
incremental_result = self.llm_engine.data_processor.process_response(incremental_result)

# Set the prompt
if isinstance(prompts, list):
incremental_result.prompt = prompts[pos]
else:
incremental_result.prompt = prompts

return incremental_result


if __name__ == "__main__":
# llm = LLM(model="llama_model")
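The tests below exercise the new stream=True path end to end. As a quick orientation, here is a minimal consumer-side sketch of the streaming API added above; the import path, model path, and prompt are illustrative assumptions, not taken from this diff. Each item yielded by the iterator is a list with one slot per request, holding either None or a RequestOutput that carries only the newly generated tokens.

from fastdeploy import LLM, SamplingParams  # assumed import path

llm = LLM(model="path/to/model")  # hypothetical model path
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# stream=True returns a generator; consume it incrementally instead of
# waiting for the full completion.
for chunks in llm.generate(["Hello, world"], sampling_params, stream=True):
    for chunk in chunks:
        if chunk is not None:
            # chunk.outputs.text holds only the text decoded since the last yield
            print(chunk.outputs.text, end="", flush=True)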
@@ -193,6 +193,74 @@ def test_chat_completion(llm):
pytest.fail(f"Chat case {i + 1} failed")


def test_generate_prompts_stream(llm):
"""
Test basic prompt generation stream outputs
"""

prompts = [
"请介绍一下中国的四大发明。",
]

sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)

try:
outputs = llm.generate(prompts, sampling_params, stream=True)

# Collect streaming output
output = []
for chunk in outputs:
if chunk[0] is not None:
output.append(chunk[0].outputs.text)
assert len(output) > 0

except Exception:
print("Failed during prompt generation.")
traceback.print_exc()
pytest.fail("Prompt generation test failed")


def test_chat_completion_stream(llm):
"""
Test chat completion stream outputs
"""
chat_cases = [
[
{"role": "user", "content": "你好,请介绍一下你自己。"},
],
[
{"role": "user", "content": "你知道地球到月球的距离是多少吗?"},
{"role": "assistant", "content": "大约是38万公里左右。"},
{"role": "user", "content": "那太阳到地球的距离是多少?"},
],
]

sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)

try:
outputs = llm.chat(chat_cases, sampling_params, stream=True)

# Collect streaming output
output = [[], []]
for chunks in outputs:
for req_idx, chunk in enumerate(chunks):
if chunk is not None:
output[req_idx].append(chunk.outputs.text)
assert len(output[0]) > 0
assert len(output[1]) > 0

except Exception:
print("Failed during prompt chat.")
traceback.print_exc()
pytest.fail("Prompt chat test failed")


def test_seed(llm):
"""
Test chat completion with same seed