64 changes: 28 additions & 36 deletions lmdeploy/pytorch/engine/engine.py
@@ -18,7 +18,7 @@
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,
DistServeInitRequest)
from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
from lmdeploy.utils import get_logger, get_max_batch_size, get_model, logging_timer
from lmdeploy.utils import get_logger, get_max_batch_size, get_model

from ..adapter.adapter import AdapterManager
from ..config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SchedulerConfig, SpecDecodeConfig
@@ -281,23 +281,23 @@ def do_prefill_dp(self):
if self.next_is_prefill:
ret = scheduler.has_waiting()
else:
ret = not scheduler.has_running()
ret = not scheduler.has_ready()
return ret

def do_prefill_default(self):
# decoding if no waiting
scheduler = self.scheduler
if not scheduler.has_waiting():
return False
num_running = scheduler.num_running()
num_ready = scheduler.num_ready()
num_waiting = scheduler.num_waiting()
max_batches = self.scheduler_config.max_batches
# prefill if too many requests are waiting
permitted_waiting = 4 if (self.engine.engine_config.role != EngineRole.Prefill) else 1
if num_waiting >= permitted_waiting:
return True
# prefill if not enough ready sequences
if num_running < max_batches * 0.5:
if num_ready < max_batches * 0.5:
return True
# decoding
return False
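
For quick reference, the decision above can be restated as a standalone sketch; the function name and bare integer parameters are illustrative stand-ins (not the engine's actual API), while the thresholds follow this hunk after the running-to-ready rename:

```python
def should_prefill(num_waiting: int, num_ready: int, max_batches: int,
                   is_prefill_role: bool) -> bool:
    """Sketch of do_prefill_default's heuristic."""
    if num_waiting == 0:
        return False                      # nothing to prefill, keep decoding
    permitted_waiting = 1 if is_prefill_role else 4
    if num_waiting >= permitted_waiting:  # too many requests queued up
        return True
    if num_ready < max_batches * 0.5:     # ready batch is less than half full
        return True
    return False
```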
@@ -328,11 +328,11 @@ async def prefetch_next_inputs(self):
if prefill:
enable = True
else:
num_running = scheduler.num_running()
num_ready = scheduler.num_ready()
is_decoding = self.forward_inputs['inputs'].is_decoding
running_threshold = (self.scheduler_config.max_batches // 4) if is_decoding or self.spec_decoding else 0

if num_running > running_threshold:
if num_ready > running_threshold:
enable = True

if enable:
@@ -592,7 +592,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs):
if session_id in self.scheduler.sessions:
msgs = list(self.scheduler.sessions[session_id].sequences.values())
if len(msgs) > 0 and msgs[0].preserve_cache:
self.scheduler._set_message_status(msgs[0], MessageStatus.TO_BE_MIGRATED)
msgs[0].state.finish()
else:
self.end_session(session_id)
resp_type = ResponseType.SUCCESS
@@ -656,11 +656,10 @@ def __update_max_new_tokens(msg):
scheduler = self.scheduler
for req in reqs:
session_id = req.data['session_id']
if scheduler is None:
sess = scheduler.sessions.get(session_id, None)
if sess is None:
self._response(req.resp, ResponseType.SESSION_NOT_EXIST)
continue
session_id = req.data['session_id']
sess = scheduler.sessions[session_id]
# TODO: support 1 session n sequence
sampling_param = req.data['sampling_param']
if len(sess.sequences) == 0:
@@ -675,10 +674,7 @@
resp_cache=req.data.get('with_cache'),
preserve_cache=req.data.get('preserve_cache'))
msg = next(iter(sess.sequences.values()))
__update_max_new_tokens(msg)
scheduler.add_sequence(msg)
if migration_request:
self.scheduler._set_message_status(msg, MessageStatus.WAITING_MIGRATION)
self.migration_event.set()
else:
msg = next(iter(sess.sequences.values()))
Expand All @@ -689,20 +685,16 @@ def __update_max_new_tokens(msg):
mode=UpdateTokenMode.INPUTS,
)
msg.sampling_param = sampling_param
msg.status = MessageStatus.WAITING
__update_max_new_tokens(msg)
msg.state.activate()

__update_max_new_tokens(msg)
msg.resp = req.resp

@property
def model_config(self) -> ModelConfig:
"""Model config."""
return self.executor.model_config

@property
def gpu_count(self):
return self.dist_config.world_size

@property
def torch_int_dtype(self):
"""Return int32 for cuda, int64 for others."""
@@ -775,8 +767,7 @@ def __has_values(input_multimodals):
return vision_embedding_inputs

@torch.inference_mode()
@logging_timer('CreateModelInputs', logger)
@record_function('CreateModelInputs')
@record_function('create_model_inputs')
def create_model_inputs(self, messages: SeqList, is_prefill: bool):
"""Create model inputs from messages.

@@ -861,7 +852,7 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray,
if model_metas is None:
model_metas = [None] * len(running)
for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas):
if msg.status != MessageStatus.MIGRATION_LOCKED:
if msg.status != MessageStatus.MIGRATION_RUNNING:
continue
update_token = token

@@ -870,7 +861,7 @@
if stop:
update_token = _EMPTY_TOKEN
msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL)
msg.status = MessageStatus.STOPPED
msg.state.finish()

@record_function('make_infer_outputs')
def _make_infer_outputs(
@@ -889,7 +880,7 @@ def _make_infer_outputs(
logprobs.indices = logprobs.indices.tolist()

seq_length = [seq.num_token_ids for seq in running]
is_run = [seq.status == MessageStatus.LOCKED for seq in running]
is_run = [seq.status == MessageStatus.RUNNING for seq in running]
self.seq_strategy.update_running(running=running, batched_outputs=batched_outputs, is_decoding=is_decoding)

# generate output
@@ -941,6 +932,7 @@ def _make_infer_outputs(
outputs[session_id].logits = logits.split(seq_length)[idx]
return outputs

@record_function('make_forward_inputs')
def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False):
"""Make forward inputs."""

@@ -966,7 +958,7 @@ def __need_schedule_again(prefill: bool, scheduler_output):
if (self.engine_config.role == EngineRole.Prefill):
return False
# disable decoding if no running reqs.
if not self.scheduler.has_running():
if not self.scheduler.has_ready():
logger.warning('No running sequences for decoding scheduling after prefill scheduling.')
return False
return True
@@ -1107,12 +1099,12 @@ async def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionReque
async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event):
"""Async loop migration."""
while True:
migration_running = self.scheduler._schedule_migration()
if not migration_running and not self.scheduler.has_migration_waiting():
migration_ready = self.scheduler._schedule_migration()
if not migration_ready and not self.scheduler.has_migration_waiting():
await self.migration_event.wait()
elif migration_running:
elif migration_ready:
self.migration_event.clear()
for msg in migration_running:
for msg in migration_ready:
migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = []
migration_request = msg.migration_request
prefill_block_ids = migration_request.remote_block_ids
@@ -1137,8 +1129,8 @@ async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event

# generate output
outputs: Dict[int, InferOutput] = dict()
self.scheduler.lock_running_migration(migration_running)
for _, msg in enumerate(migration_running):
self.scheduler.activate_migration_seqs(migration_ready)
for _, msg in enumerate(migration_ready):
session_id = msg.session_id
msg.resp.type = ResponseType.SUCCESS
token_ids = [msg.migration_request.remote_token_id]
@@ -1155,7 +1147,7 @@ async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event
outputs[session_id] = out
self.update_running_migration([msg], np.array([token_ids]), [False], [None])
resp_que.put_nowait(outputs)
self.scheduler.unlock_running_migration(migration_running)
self.scheduler.deactivate_migration_seqs(migration_ready)
has_runable_event.set()
else:
# release coroutine for decoding
@@ -1202,7 +1194,7 @@ async def _async_loop_main(
is_decoding = forward_inputs['inputs'].is_decoding
running = next_running
next_running = None
scheduler.lock_running(running)
scheduler.active_seqs(running)
for idx in range(num_loops):

# pre-forward before get last token
@@ -1221,7 +1213,7 @@
if idx == num_loops // 2:
forward_event.clear()

scheduler.unlock_running(running)
scheduler.deactive_seqs(running)
has_runable_event.set()

@staticmethod
@@ -1312,8 +1304,8 @@ async def async_loop(self):
forward_event=forward_event,
has_runable_event=has_runable_event,
inputs_maker=inputs_maker)
except Exception as e:
logger.exception(f'exception happened: {type(e)} {e}')
except Exception:
logger.exception('Engine main loop failed.')
finally:
self._loop_finally()

3 changes: 3 additions & 0 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -440,6 +440,9 @@ def warmup(self):
# warmup decoding (with cuda graph)
capture_batch_sizes = self.patched_model.get_capture_batch_sizes()
capture_batch_sizes = sorted(capture_batch_sizes, reverse=True)
if self.cache_config.role == EngineRole.Prefill:
# do not warmup decoding for prefill engine
capture_batch_sizes = []
for num_tokens in capture_batch_sizes:
inputs = self.inputs_strategy.make_dummy(num_tokens,
is_decoding=True,
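For context on the guard added above: a prefill-role engine never runs the decode path, so capturing CUDA graphs for decoding batch sizes would presumably be wasted work. A condensed, hedged restatement (the string role is a stand-in for `EngineRole.Prefill`, not the real config type):

```python
from typing import List

def decode_capture_sizes(capture_batch_sizes: List[int], role: str) -> List[int]:
    """Batch sizes to warm up for decoding; empty for a prefill-only engine."""
    if role == 'Prefill':  # stand-in for `self.cache_config.role == EngineRole.Prefill`
        return []          # prefill engines never decode, so skip cuda-graph capture
    return sorted(capture_batch_sizes, reverse=True)
```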
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/request.py
@@ -108,7 +108,7 @@ def _gather_request(self, req_types: List[RequestType], data: List[Any]):
resps = []
for rtype, rdata in zip(req_types, data):
event = asyncio.Event()
resp = Response(type=ResponseType.HANDLER_NOT_EXIST,
resp = Response(type=ResponseType.INTERNAL_ENGINE_ERROR,
sender_id=self.sender_id,
event=event,
data=None,
57 changes: 33 additions & 24 deletions lmdeploy/pytorch/messages.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import enum
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional

@@ -14,6 +15,8 @@
from .block import LogicalTokenBlocks

if TYPE_CHECKING:
from lmdeploy.pytorch.paging.scheduler import Scheduler
from lmdeploy.pytorch.paging.seq_states.states import StateBase
from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy

logger = get_logger('lmdeploy')
@@ -146,21 +149,19 @@ class MessageStatus(enum.Enum):
"""Status of a sequence."""

WAITING = enum.auto()
RUNNING = enum.auto()
READY = enum.auto()
STOPPED = enum.auto()
ENDED = enum.auto()
ABORTED = enum.auto()
LOCKED = enum.auto()
RUNNING = enum.auto()

# PD Disaggregation
# WAITING_MIGRATION: state of Unmigrated Requests
# MIGRATION_WAITING: state of Unmigrated Requests
# in both prefill and decode engines
# RUNNING_MIGRATION: state of Migrating Requests
# MIGRATION_READY: state of Migrating Requests
# in decode engine
TO_BE_MIGRATED = enum.auto()
WAITING_MIGRATION = enum.auto()
RUNNING_MIGRATION = enum.auto()
MIGRATION_LOCKED = enum.auto()
MIGRATION_WAITING = enum.auto()
MIGRATION_READY = enum.auto()
MIGRATION_RUNNING = enum.auto()
MIGRATION_DONE = enum.auto()
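
For readers tracking the rename across files, here is a summary of the status mapping as read from this diff (an interpretation, not code from the PR; `TO_BE_MIGRATED` has no direct successor, since the preserve-cache path now calls `msg.state.finish()` instead):

```python
# Old MessageStatus name -> new name used throughout this PR.
STATUS_RENAMES = {
    'RUNNING': 'READY',                        # schedulable, not currently in a forward batch
    'LOCKED': 'RUNNING',                       # actively being forwarded
    'WAITING_MIGRATION': 'MIGRATION_WAITING',  # not yet migrated (prefill and decode engines)
    'RUNNING_MIGRATION': 'MIGRATION_READY',    # migrating requests (decode engine)
    'MIGRATION_LOCKED': 'MIGRATION_RUNNING',   # migration in flight
}
```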


@@ -179,9 +180,7 @@ class SequenceManager:

def __init__(self, seq_meta: SequenceMeta) -> None:
self._seq_map: SeqMap = dict()
self._status_seq_map: Dict[MessageStatus, SeqMap] = dict()
for status in MessageStatus:
self._status_seq_map[status] = dict()
self._status_seq_map: Dict[MessageStatus, SeqMap] = defaultdict(dict)

self.seq_meta = seq_meta
self._seq_count = 0
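
A small self-contained check of why the `defaultdict(dict)` swap above is behavior-preserving for lookups: each per-status map is now created lazily on first access instead of eagerly for every `MessageStatus` member (the `Status` enum below is an illustrative stand-in):

```python
import enum
from collections import defaultdict

class Status(enum.Enum):  # stand-in for MessageStatus
    WAITING = enum.auto()
    READY = enum.auto()

eager = {status: {} for status in Status}  # old: one dict per status up front
lazy = defaultdict(dict)                   # new: per-status dict created on first access

lazy[Status.READY][7] = 'seq-7'
assert eager[Status.READY] == {}
assert lazy[Status.READY] == {7: 'seq-7'}
assert lazy[Status.WAITING] == {}          # auto-created, matching the eager behavior
```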
@@ -247,12 +246,12 @@ def _to_ndarray(token_ids) -> np.ndarray:
class SchedulerSession:
"""Scheduler session."""

def __init__(self, session_id: int, seq_manager: SequenceManager) -> None:
def __init__(self, session_id: int, seq_manager: SequenceManager, scheduler: 'Scheduler') -> None:
self.session_id = session_id
self.seq_meta = seq_manager.seq_meta
self.status: MessageStatus = MessageStatus.RUNNING
self.sequences: SeqMap = dict()
self.seq_manager = seq_manager
self.scheduler = scheduler

def add_sequence(self,
token_ids: Tensor,
@@ -264,6 +263,8 @@ def add_sequence(self,
resp_cache: bool = False,
preserve_cache: bool = False) -> 'SchedulerSequence':
"""Add a new message."""
from lmdeploy.pytorch.paging.seq_states.states import build_seq_state

if sampling_param is None:
sampling_param = SamplingParam()

@@ -282,12 +283,22 @@ def add_sequence(self,
mode=UpdateTokenMode.INPUTS,
)
self.sequences[seq.seq_id] = seq

# set status
# update seq manager
status = MessageStatus.WAITING if migration_request is None else MessageStatus.MIGRATION_WAITING
seq.set_state(build_seq_state(self.scheduler, seq, status))
self.seq_manager.add_sequence(seq)

# metrics
seq.record_event(EventType.QUEUED)

return seq

def remove_sequence(self, seq: 'SchedulerSequence'):
"""Remove sequence."""
assert seq.seq_id in self.sequences
seq.state.free()
self.sequences.pop(seq.seq_id)
self.seq_manager.remove_sequence(seq)

@@ -557,7 +568,6 @@ class SchedulerSequence:
arrive_time: float = 0.0
output_start_pos: int = 0
meta: Any = None
_status: MessageStatus = field(default=MessageStatus.WAITING, init=False)
num_ignored_history: int = 0
model_meta: Dict[str, Any] = None

@@ -583,6 +593,7 @@ def __post_init__(self):
self._num_images: int = len(self.history_embeddings)
self._num_history_cross: int = 0
self._num_cross: int = self.history_multimodals.get_encoder_len(0, self._num_token_ids)
self._state = None

@property
def block_size(self) -> int:
@@ -692,23 +703,21 @@ def num_blocks(self):
return len(self.logical_blocks)

@property
def seq_manager(self) -> SequenceManager:
"""Sequence manager."""
return self.session.seq_manager
def state(self) -> 'StateBase':
return self._state

def set_state(self, state: 'StateBase'):
"""Set state."""
self._state = state

@property
def status(self):
return self._status
return self.state.status

@property
def return_logits(self):
return self.sampling_param.out_logits

@status.setter
def status(self, value: MessageStatus):
self.seq_manager.update_sequence_status(self, value)
self._status = value

def num_all_cross_tokens(self):
"""Num of all cross tokens."""
return self._num_cross + self._num_history_cross
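
To summarize the new lifecycle wiring visible in this diff: a sequence no longer stores a raw `_status` field; `session.add_sequence` builds a state object via `build_seq_state(scheduler, seq, status)`, transitions go through `state.activate()` / `state.finish()` / `state.free()`, and the `status` property simply delegates to `state.status`. Below is a minimal, self-contained illustration of that delegation pattern; the class and method names are simplified stand-ins, not the actual lmdeploy classes:

```python
import enum


class Status(enum.Enum):  # stand-in for MessageStatus
    WAITING = enum.auto()
    RUNNING = enum.auto()
    STOPPED = enum.auto()
    ENDED = enum.auto()


class SeqState:
    """Toy state object: owns the status and the allowed transitions."""

    def __init__(self, status: Status):
        self.status = status

    def activate(self):   # e.g. re-activating a stopped sequence on new input
        self.status = Status.RUNNING

    def finish(self):     # e.g. stop generation while preserving the cache
        self.status = Status.STOPPED

    def free(self):       # e.g. on remove_sequence
        self.status = Status.ENDED


class Sequence:
    """Toy sequence: status is read through the state object, never set directly."""

    def __init__(self, state: SeqState):
        self._state = state

    @property
    def state(self) -> SeqState:
        return self._state

    @property
    def status(self) -> Status:
        return self._state.status  # delegation, mirroring SchedulerSequence.status


seq = Sequence(SeqState(Status.WAITING))
seq.state.activate()
assert seq.status is Status.RUNNING
seq.state.finish()
assert seq.status is Status.STOPPED
```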