From d96b324fc4ea46783b9efb8345218a278d328364 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 27 Nov 2025 18:27:16 +0800 Subject: [PATCH 1/6] refactor sequence states --- lmdeploy/pytorch/engine/engine.py | 45 ++-- lmdeploy/pytorch/messages.py | 57 +++-- .../paging/eviction_helper/__init__.py | 17 +- .../recompute_eviction_helper.py | 7 +- lmdeploy/pytorch/paging/scheduler.py | 207 +++++------------- .../pytorch/paging/seq_states/__init__.py | 2 + lmdeploy/pytorch/paging/seq_states/states.py | 166 ++++++++++++++ lmdeploy/pytorch/strategies/ar/sequence.py | 4 +- .../pytorch/strategies/ar_spec/sequence.py | 4 +- lmdeploy/pytorch/strategies/dllm/sequence.py | 4 +- tests/pytorch/paging/test_block_manager.py | 92 ++++++-- tests/pytorch/paging/test_block_trie.py | 55 +++-- tests/pytorch/paging/test_scheduler.py | 66 +++--- 13 files changed, 438 insertions(+), 288 deletions(-) create mode 100644 lmdeploy/pytorch/paging/seq_states/__init__.py create mode 100644 lmdeploy/pytorch/paging/seq_states/states.py diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index b343040873..609f8fa274 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -18,7 +18,7 @@ from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, DistServeInitRequest) from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch -from lmdeploy.utils import get_logger, get_max_batch_size, get_model, logging_timer +from lmdeploy.utils import get_logger, get_max_batch_size, get_model from ..adapter.adapter import AdapterManager from ..config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SchedulerConfig, SpecDecodeConfig @@ -281,7 +281,7 @@ def do_prefill_dp(self): if self.next_is_prefill: ret = scheduler.has_waiting() else: - ret = not scheduler.has_running() + ret = not scheduler.has_ready() return ret def do_prefill_default(self): @@ -289,7 +289,7 @@ def do_prefill_default(self): scheduler = self.scheduler if not scheduler.has_waiting(): return False - num_running = scheduler.num_running() + num_ready = scheduler.num_ready() num_waiting = scheduler.num_waiting() max_batches = self.scheduler_config.max_batches # prefill if too much waiting @@ -297,7 +297,7 @@ def do_prefill_default(self): if num_waiting >= permitted_waiting: return True # prefill if no enough running - if num_running < max_batches * 0.5: + if num_ready < max_batches * 0.5: return True # decoding return False @@ -328,11 +328,11 @@ async def prefetch_next_inputs(self): if prefill: enable = True else: - num_running = scheduler.num_running() + num_ready = scheduler.num_ready() is_decoding = self.forward_inputs['inputs'].is_decoding running_threshold = (self.scheduler_config.max_batches // 4) if is_decoding or self.spec_decoding else 0 - if num_running > running_threshold: + if num_ready > running_threshold: enable = True if enable: @@ -592,7 +592,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs): if session_id in self.scheduler.sessions: msgs = list(self.scheduler.sessions[session_id].sequences.values()) if len(msgs) > 0 and msgs[0].preserve_cache: - self.scheduler._set_message_status(msgs[0], MessageStatus.TO_BE_MIGRATED) + msgs[0].state.finish() else: self.end_session(session_id) resp_type = ResponseType.SUCCESS @@ -676,9 +676,7 @@ def __update_max_new_tokens(msg): preserve_cache=req.data.get('preserve_cache')) msg = next(iter(sess.sequences.values())) __update_max_new_tokens(msg) - 
scheduler.add_sequence(msg) if migration_request: - self.scheduler._set_message_status(msg, MessageStatus.WAITING_MIGRATION) self.migration_event.set() else: msg = next(iter(sess.sequences.values())) @@ -689,7 +687,7 @@ def __update_max_new_tokens(msg): mode=UpdateTokenMode.INPUTS, ) msg.sampling_param = sampling_param - msg.status = MessageStatus.WAITING + msg.state.activate() __update_max_new_tokens(msg) msg.resp = req.resp @@ -775,7 +773,6 @@ def __has_values(input_multimodals): return vision_embedding_inputs @torch.inference_mode() - @logging_timer('CreateModelInputs', logger) @record_function('CreateModelInputs') def create_model_inputs(self, messages: SeqList, is_prefill: bool): """Create model inputs from messages. @@ -861,7 +858,7 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, if model_metas is None: model_metas = [None] * len(running) for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): - if msg.status != MessageStatus.MIGRATION_LOCKED: + if msg.status != MessageStatus.MIGRATION_RUNNING: continue update_token = token @@ -870,7 +867,7 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, if stop: update_token = _EMPTY_TOKEN msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL) - msg.status = MessageStatus.STOPPED + msg.state.finish() @record_function('make_infer_outputs') def _make_infer_outputs( @@ -889,7 +886,7 @@ def _make_infer_outputs( logprobs.indices = logprobs.indices.tolist() seq_length = [seq.num_token_ids for seq in running] - is_run = [seq.status == MessageStatus.LOCKED for seq in running] + is_run = [seq.status == MessageStatus.RUNNING for seq in running] self.seq_strategy.update_running(running=running, batched_outputs=batched_outputs, is_decoding=is_decoding) # generate output @@ -966,7 +963,7 @@ def __need_schedule_again(prefill: bool, scheduler_output): if (self.engine_config.role == EngineRole.Prefill): return False # disable decoding if no running reqs. 
- if not self.scheduler.has_running(): + if not self.scheduler.has_ready(): logger.warning('No running sequences for decoding scheduling after prefill scheduling.') return False return True @@ -1107,12 +1104,12 @@ async def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionReque async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event): """Async loop migration.""" while True: - migration_running = self.scheduler._schedule_migration() - if not migration_running and not self.scheduler.has_migration_waiting(): + migration_ready = self.scheduler._schedule_migration() + if not migration_ready and not self.scheduler.has_migration_waiting(): await self.migration_event.wait() - elif migration_running: + elif migration_ready: self.migration_event.clear() - for msg in migration_running: + for msg in migration_ready: migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] migration_request = msg.migration_request prefill_block_ids = migration_request.remote_block_ids @@ -1137,8 +1134,8 @@ async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event # generate output outputs: Dict[int, InferOutput] = dict() - self.scheduler.lock_running_migration(migration_running) - for _, msg in enumerate(migration_running): + self.scheduler.activate_migration_seqs(migration_ready) + for _, msg in enumerate(migration_ready): session_id = msg.session_id msg.resp.type = ResponseType.SUCCESS token_ids = [msg.migration_request.remote_token_id] @@ -1155,7 +1152,7 @@ async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event outputs[session_id] = out self.update_running_migration([msg], np.array([token_ids]), [False], [None]) resp_que.put_nowait(outputs) - self.scheduler.unlock_running_migration(migration_running) + self.scheduler.deactivate_migration_seqs(migration_ready) has_runable_event.set() else: # release coroutine for decoding @@ -1202,7 +1199,7 @@ async def _async_loop_main( is_decoding = forward_inputs['inputs'].is_decoding running = next_running next_running = None - scheduler.lock_running(running) + scheduler.active_seqs(running) for idx in range(num_loops): # pre-forward before get last token @@ -1221,7 +1218,7 @@ async def _async_loop_main( if idx == num_loops // 2: forward_event.clear() - scheduler.unlock_running(running) + scheduler.deactive_seqs(running) has_runable_event.set() @staticmethod diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 2e02fabf04..1fd26ea2cd 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -14,6 +14,8 @@ from .block import LogicalTokenBlocks if TYPE_CHECKING: + from lmdeploy.pytorch.paging.scheduler import Scheduler + from lmdeploy.pytorch.paging.seq_states.states import StateBase from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy logger = get_logger('lmdeploy') @@ -146,21 +148,19 @@ class MessageStatus(enum.Enum): """Status of a sequence.""" WAITING = enum.auto() - RUNNING = enum.auto() + READY = enum.auto() STOPPED = enum.auto() - ENDED = enum.auto() - ABORTED = enum.auto() - LOCKED = enum.auto() + RUNNING = enum.auto() # PD Disaggregation - # WAITING_MIGRATION: state of Unmigrated Requests + # MIGRATION_WAITING: state of Unmigrated Requests # in both prefill and decode engines are tagged by - # RUNNING_MIGRATION: state of Migrating Requests + # MIGRATION_READY: state of Migrating Requests # in decode engine TO_BE_MIGRATED = enum.auto() - WAITING_MIGRATION = enum.auto() - RUNNING_MIGRATION 
= enum.auto() - MIGRATION_LOCKED = enum.auto() + MIGRATION_WAITING = enum.auto() + MIGRATION_READY = enum.auto() + MIGRATION_RUNNING = enum.auto() MIGRATION_DONE = enum.auto() @@ -203,10 +203,9 @@ def num_sequences(self, status: MessageStatus): """Num sequences.""" return len(self.get_sequences(status)) - def add_sequence(self, seq: 'SchedulerSequence'): + def add_sequence(self, seq: 'SchedulerSequence', status: MessageStatus): """Add sequence.""" seq_id = seq.seq_id - status = seq.status status_map = self._status_seq_map[status] self._seq_map[seq_id] = seq status_map[seq_id] = seq @@ -247,12 +246,12 @@ def _to_ndarray(token_ids) -> np.ndarray: class SchedulerSession: """Scheduler session.""" - def __init__(self, session_id: int, seq_manager: SequenceManager) -> None: + def __init__(self, session_id: int, seq_manager: SequenceManager, scheduler: 'Scheduler') -> None: self.session_id = session_id self.seq_meta = seq_manager.seq_meta - self.status: MessageStatus = MessageStatus.RUNNING self.sequences: SeqMap = dict() self.seq_manager = seq_manager + self.scheduler = scheduler def add_sequence(self, token_ids: Tensor, @@ -264,6 +263,8 @@ def add_sequence(self, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Add a new message.""" + from lmdeploy.pytorch.paging.seq_states.states import build_seq_state + if sampling_param is None: sampling_param = SamplingParam() @@ -282,12 +283,22 @@ def add_sequence(self, mode=UpdateTokenMode.INPUTS, ) self.sequences[seq.seq_id] = seq - self.seq_manager.add_sequence(seq) + + # set status + # update seq manager + status = MessageStatus.WAITING if migration_request is None else MessageStatus.MIGRATION_WAITING + seq.set_state(build_seq_state(self.scheduler, seq, status)) + self.seq_manager.add_sequence(seq, status) + + # metrics + seq.record_event(EventType.QUEUED) + return seq def remove_sequence(self, seq: 'SchedulerSequence'): """Remove sequence.""" assert seq.seq_id in self.sequences + seq.state.free() self.sequences.pop(seq.seq_id) self.seq_manager.remove_sequence(seq) @@ -557,7 +568,6 @@ class SchedulerSequence: arrive_time: float = 0.0 output_start_pos: int = 0 meta: Any = None - _status: MessageStatus = field(default=MessageStatus.WAITING, init=False) num_ignored_history: int = 0 model_meta: Dict[str, Any] = None @@ -583,6 +593,7 @@ def __post_init__(self): self._num_images: int = len(self.history_embeddings) self._num_history_cross: int = 0 self._num_cross: int = self.history_multimodals.get_encoder_len(0, self._num_token_ids) + self._state = None @property def block_size(self) -> int: @@ -692,23 +703,21 @@ def num_blocks(self): return len(self.logical_blocks) @property - def seq_manager(self) -> SequenceManager: - """Sequence manager.""" - return self.session.seq_manager + def state(self) -> 'StateBase': + return self._state + + def set_state(self, state: 'StateBase'): + """Set state.""" + self._state = state @property def status(self): - return self._status + return self.state.status @property def return_logits(self): return self.sampling_param.out_logits - @status.setter - def status(self, value: MessageStatus): - self.seq_manager.update_sequence_status(self, value) - self._status = value - def num_all_cross_tokens(self): """Num of all cross tokens.""" return self._num_cross + self._num_history_cross diff --git a/lmdeploy/pytorch/paging/eviction_helper/__init__.py b/lmdeploy/pytorch/paging/eviction_helper/__init__.py index 9d82582f1e..6b5c44ff97 100644 --- a/lmdeploy/pytorch/paging/eviction_helper/__init__.py +++ 
b/lmdeploy/pytorch/paging/eviction_helper/__init__.py @@ -1,4 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .recompute_eviction_helper import RecomputeEvictionHelper +from lmdeploy.utils import get_logger -__all__ = ['RecomputeEvictionHelper'] +logger = get_logger('lmdeploy') + + +def build_eviction_helper(scheduler, eviction_type: str): + """Build eviction helper.""" + if eviction_type == 'copy': + logger.warning('`copy` eviction has been deprecated, ' + 'use `recompute` instead.') + eviction_type = 'recompute' + if eviction_type == 'recompute': + from .recompute_eviction_helper import RecomputeEvictionHelper + return RecomputeEvictionHelper(scheduler) + else: + raise TypeError(f'Unknown eviction type: {eviction_type}') diff --git a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py index be0d09a5f9..bdded115dd 100644 --- a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py +++ b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py @@ -35,8 +35,7 @@ def _evict_for_seq_default(self, seq: SchedulerSequence, evictable_seqs: List[Sc if evict_seq.num_blocks == 0: continue - block_manager.free(evict_seq) - evict_seq.set_step(0) + evict_seq.state.free() num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks()) if num_req <= 0: success = True @@ -77,9 +76,7 @@ def _evict_for_ssm(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerS continue # free sequence - block_manager.free(evict_seq) - evict_seq.set_step(0) - state_manager.free(evict_seq) + evict_seq.state.free() has_free_state = True num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks()) if num_req <= 0: diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 579d335d22..1359c88ae5 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -6,12 +6,13 @@ from typing import Dict, List from lmdeploy.messages import EventType, ScheduleMetrics -from lmdeploy.utils import get_logger, logging_timer +from lmdeploy.utils import get_logger from ..config import CacheConfig, SchedulerConfig from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager, SequenceMeta from .block_manager import build_block_manager from .block_trie import BlockTrie +from .eviction_helper import build_eviction_helper from .state_manager import StateManager logger = get_logger('lmdeploy') @@ -55,7 +56,7 @@ def __init__( self.state_manager = StateManager(self.cache_config.num_state_caches) self.is_ssm = len(self.cache_config.states_shapes) > 0 - self.eviction_helper = self.build_eviction_helper(self.scheduler_config.eviction_type) + self.eviction_helper = build_eviction_helper(self, self.scheduler_config.eviction_type) seq_meta = seq_meta or SequenceMeta(self.cache_config.block_size) self.seq_manager = SequenceManager(seq_meta) @@ -67,9 +68,9 @@ def waiting(self): return list(seq_map.values()) @property - def running(self): + def ready(self): """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING) + seq_map = self.seq_manager.get_sequences(MessageStatus.READY) return list(seq_map.values()) @property @@ -79,21 +80,15 @@ def hanging(self): return list(seq_map.values()) @property - def locked(self): + def running(self): """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.LOCKED) - return list(seq_map.values()) - - @property - 
def waiting_migration(self): - """Get migration sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.WAITING_MIGRATION) + seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING) return list(seq_map.values()) @property - def running_migration(self): + def migration_waiting(self): """Get migration sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING_MIGRATION) + seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_WAITING) return list(seq_map.values()) @property @@ -102,26 +97,6 @@ def migration_done(self): seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_DONE) return list(seq_map.values()) - def build_eviction_helper(self, eviction_type: str): - if eviction_type == 'copy': - logger.warning('`copy` eviction has been deprecated, ' - 'use `recompute` instead.') - eviction_type = 'recompute' - if eviction_type == 'recompute': - from .eviction_helper import RecomputeEvictionHelper - return RecomputeEvictionHelper(self) - else: - raise TypeError(f'Unknown eviction type: {eviction_type}') - - def _set_message_status(self, message: SchedulerSequence, status: MessageStatus): - """Set status of message. - - Args: - message (SchedulerSequence): message to setup status. - status (MessageStatus): New message status. - """ - message.status = status - def add_session(self, session_id: int): """Add new session. @@ -129,32 +104,18 @@ def add_session(self, session_id: int): session_id (int): New session id. """ assert session_id not in self.sessions - session = SchedulerSession(session_id, seq_manager=self.seq_manager) + session = SchedulerSession(session_id, seq_manager=self.seq_manager, scheduler=self) self.sessions[session_id] = session return session - def add_sequence(self, seq: SchedulerSequence): - """Add sequence. - - Args: - seq (SchedulerSequence): New sequence. 
- """ - assert (seq.session_id in self.sessions), f'Unknown session id {seq.session_id}' - - # push message to waiting queue - self._set_message_status(seq, MessageStatus.WAITING) - - seq.record_event(EventType.QUEUED) - - @logging_timer('ScheduleMigration', logger) def _schedule_migration(self): - running_migration: SeqList = [] + migration_ready: SeqList = [] migrating_token_count = 0 def _to_running(seq: SchedulerSequence): """To running.""" - seq.status = MessageStatus.RUNNING_MIGRATION - running_migration.append(seq) + seq.state.activate() + migration_ready.append(seq) nonlocal migrating_token_count migrating_token_count += seq.num_token_ids @@ -169,28 +130,27 @@ def __evict_for_seq(seq: SchedulerSequence, waiting): def _reorder_migrating(): """Reorder waiting.""" - return sorted(self.waiting_migration, key=lambda seq: seq.arrive_time) + return sorted(self.migration_waiting, key=lambda seq: seq.arrive_time) - waiting_migration = _reorder_migrating() + migration_waiting = _reorder_migrating() - max_batches = self.scheduler_config.max_batches - self.num_running() - self.num_locked() - while len(waiting_migration) > 0 and len(running_migration) < max_batches: - seq = waiting_migration.pop(0) - self.block_trie.match(waiting_migration) - if not __evict_for_seq(seq, waiting_migration): + max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running() + while len(migration_waiting) > 0 and len(migration_ready) < max_batches: + seq = migration_waiting.pop(0) + self.block_trie.match(migration_waiting) + if not __evict_for_seq(seq, migration_waiting): break # allocate session memory self.block_manager.allocate(seq) _to_running(seq) - return running_migration + return migration_ready - @logging_timer('SchedulePrefilling', logger) def _schedule_prefill(self, prealloc_size: int = 0): """Schedule for prefilling.""" - max_batches = self.scheduler_config.max_batches - self.num_running() - self.num_locked() + max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running() eviction_helper = self.eviction_helper swap_out_map: Dict[int, int] = dict() swap_in_map: Dict[int, int] = dict() @@ -200,7 +160,7 @@ def _schedule_prefill(self, prealloc_size: int = 0): def _to_running(seq: SchedulerSequence): """To running.""" - seq.status = MessageStatus.RUNNING + seq.state.activate() running.append(seq) nonlocal token_count token_count += seq.num_token_ids @@ -243,11 +203,10 @@ def _reorder_waiting(): return running, swap_in_map, swap_out_map, copy_map - @logging_timer('ScheduleDecoding', logger) def _schedule_decoding(self, prealloc_size: int = 0): """Schedule decoding.""" - running = self.running + running = self.ready assert len(running) != 0 eviction_helper = self.eviction_helper @@ -272,27 +231,18 @@ def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): # 1. running for seq in running: - # token + n - num_required_blocks = self.block_manager.num_required_blocks(seq, prealloc_size) - if len(seq.logical_blocks) + num_required_blocks > self.block_manager.num_gpu_blocks: - # Reach max gpu cache size. 
- logger.warning(f'session[{seq.session_id}] ' - f'sequence[{seq.seq_id}] ' - 'reach max gpu size.') - self._set_message_status(seq, MessageStatus.ABORTED) - self.block_manager.free(seq) - seq.set_step(0) - continue + assert seq.num_blocks + num_required_blocks <= self.block_manager.num_gpu_blocks, ( + 'Sequence requires more blocks than total gpu blocks.') if not __evict_for_seq(seq, num_required_blocks): - self._set_message_status(seq, MessageStatus.WAITING) + seq.state.evict() continue self.block_manager.allocate(seq, prealloc_size) self.block_trie.allocate(seq) - return self.running, swap_in_map, swap_out_map, copy_map + return self.ready[:self.scheduler_config.max_batches], swap_in_map, swap_out_map, copy_map def schedule(self, is_prefill: bool, prealloc_size: int = 0): """Schedule inputs for next steps.""" @@ -304,37 +254,16 @@ def schedule(self, is_prefill: bool, prealloc_size: int = 0): return SchedulerOutput(running=running, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map) - def _set_session_status(self, session_id: int, status: MessageStatus): - """Setup the status of session. + def stop_session(self, session_id: int): + """Stop session. Args: session_id (int): The session id. - status (MessageStatus): New status. """ assert session_id in self.sessions session = self.sessions[session_id] - session.status = status for seq in session.sequences.values(): - seq.status = status - - def stop_session(self, session_id: int): - """Stop session. - - Args: - session_id (int): The session id. - """ - self._set_session_status(session_id, MessageStatus.STOPPED) - - def _remove_sequence(self, seq: SchedulerSequence): - """Remove sequence(unsafe) - - Args: - seq (SchedulerSequence): sequence to remove - """ - self.block_manager.free(seq) - self.state_manager.free(seq) - seq.set_step(0) - seq.session.remove_sequence(seq) + seq.state.stop() def end_session(self, session_id: int): """End session. 
@@ -345,25 +274,20 @@ def end_session(self, session_id: int): session = self.sessions[session_id] seqs = list(session.sequences.values()) for seq in seqs: - self._remove_sequence(seq) + seq.state.stop() + session.remove_sequence(seq) self.sessions.pop(session_id) def has_unfinished(self): """Check if there are any unfinished message.""" - return self.has_running() or self.has_waiting() or self.has_migration_done() + return self.has_ready() or self.has_waiting() or self.has_migration_done() - def has_running(self): - return self.num_running() > 0 + def has_ready(self): + return self.num_ready() > 0 def has_waiting(self): return self.num_waiting() > 0 - def has_to_be_migrated(self): - return self.num_to_be_migrated() > 0 - - def has_migration_running(self): - return self.num_running() > 0 - def has_migration_waiting(self): return self.num_migration_waiting() > 0 @@ -374,71 +298,58 @@ def get_block_tables(self, seqs: SeqList): """Get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] - def num_running(self): + def num_ready(self): """Num running.""" - return self.seq_manager.num_sequences(MessageStatus.RUNNING) + return self.seq_manager.num_sequences(MessageStatus.READY) def num_waiting(self): """Num waiting.""" return self.seq_manager.num_sequences(MessageStatus.WAITING) - def num_to_be_migrated(self): - """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.TO_BE_MIGRATED) - - def num_migration_locked(self): - """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.MIGRATION_LOCKED) - - def num_migration_running(self): - """Num migration running.""" - return self.seq_manager.num_sequences(MessageStatus.RUNNING_MIGRATION) - def num_migration_done(self): """Num migration done.""" return self.seq_manager.num_sequences(MessageStatus.MIGRATION_DONE) def num_migration_waiting(self): """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.WAITING_MIGRATION) + return self.seq_manager.num_sequences(MessageStatus.MIGRATION_WAITING) - def num_locked(self): + def num_running(self): """Num locked.""" - return self.seq_manager.num_sequences(MessageStatus.LOCKED) + return self.seq_manager.num_sequences(MessageStatus.RUNNING) - def lock_running(self, running: SeqList): + def active_seqs(self, running: SeqList): """Lock running sequence.""" for seq in running: - if seq.status == MessageStatus.RUNNING: - self._set_message_status(seq, MessageStatus.LOCKED) + if seq.status == MessageStatus.READY: + seq.state.activate() - def unlock_running(self, locked: SeqList): - for seq in locked: - if seq.status == MessageStatus.LOCKED: - self._set_message_status(seq, MessageStatus.RUNNING) + def deactive_seqs(self, running: SeqList): + for seq in running: + if seq.status == MessageStatus.RUNNING: + seq.state.deactivate() - def lock_running_migration(self, running: SeqList): + def activate_migration_seqs(self, running: SeqList): """Lock running sequence.""" for seq in running: - if seq.status == MessageStatus.RUNNING_MIGRATION: - self._set_message_status(seq, MessageStatus.MIGRATION_LOCKED) + if seq.status == MessageStatus.MIGRATION_READY: + seq.state.activate() - def unlock_running_migration(self, locked: SeqList): + def deactivate_migration_seqs(self, running: SeqList): """Unlock running migration.""" - for seq in locked: - if seq.status == MessageStatus.MIGRATION_LOCKED: - self._set_message_status(seq, MessageStatus.MIGRATION_DONE) + for seq in running: + if seq.status == MessageStatus.MIGRATION_RUNNING: + 
seq.state.deactivate() def collect_migration_done(self): - migration_done = self.migration_done - for seq in migration_done: - self._set_message_status(seq, MessageStatus.RUNNING) + for seq in self.migration_done: + seq.state.activate() @property def schedule_metrics(self): return ScheduleMetrics( - active_seqs=self.num_locked(), - waiting_seqs=self.num_waiting() + self.num_running(), + active_seqs=self.num_running(), + waiting_seqs=self.num_waiting() + self.num_ready(), total_blocks=self.block_manager.num_gpu_blocks, free_blocks=self.block_manager.get_num_free_gpu_blocks(), ) diff --git a/lmdeploy/pytorch/paging/seq_states/__init__.py b/lmdeploy/pytorch/paging/seq_states/__init__.py new file mode 100644 index 0000000000..bba2109f8e --- /dev/null +++ b/lmdeploy/pytorch/paging/seq_states/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .states import StateBase, build_seq_state # noqa: F401 diff --git a/lmdeploy/pytorch/paging/seq_states/states.py b/lmdeploy/pytorch/paging/seq_states/states.py new file mode 100644 index 0000000000..4ab7b1f154 --- /dev/null +++ b/lmdeploy/pytorch/paging/seq_states/states.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import TYPE_CHECKING + +from lmdeploy.pytorch.messages import MessageStatus, SchedulerSequence + +if TYPE_CHECKING: + from lmdeploy.pytorch.paging import Scheduler + + +def _free_seq(seq: SchedulerSequence, scheduler: 'Scheduler'): + """Free the sequence.""" + if seq.num_blocks > 0: + scheduler.block_manager.free(seq) + if seq.logical_state >= 0: + scheduler.state_manager.free(seq) + seq.set_step(0) + + +class StateBase: + status = None + _registry = dict() + + def __init_subclass__(cls, **kargs) -> None: + super().__init_subclass__(**kargs) + if cls.status: + cls._registry[cls.status] = cls + + @classmethod + def build(cls, scheduler: 'Scheduler', seq: 'SchedulerSequence', status: MessageStatus) -> 'StateBase': + """Build sequence state.""" + if status not in cls._registry: + raise NotImplementedError(f'Unsupported status {status} for building seq state.') + return cls._registry[status](seq, scheduler) + + def __init__(self, seq: SchedulerSequence, scheduler: 'Scheduler'): + self.seq = seq + self.scheduler = scheduler + + def to_state(self, new_state): + """Transition to a new state.""" + self.scheduler.seq_manager.update_sequence_status(self.seq, new_state.status) + self.seq.set_state(new_state(self.seq, self.scheduler)) + + def evict(self): + """Evict the state.""" + raise NotImplementedError(f'evict not implemented for state {self.status}') + + def activate(self): + """Activate the state.""" + raise NotImplementedError(f'activate not implemented for state {self.status}') + + def deactivate(self): + """Deactivate the state.""" + raise NotImplementedError(f'deactivate not implemented for state {self.status}') + + def finish(self): + """Finish the state.""" + raise NotImplementedError(f'finish not implemented for state {self.status}') + + def stop(self): + """Stop the state.""" + self.to_state(StoppedState) + + def free(self): + """Free the state.""" + _free_seq(self.seq, self.scheduler) + + +class WaitingState(StateBase): + """State for waiting sequences.""" + status = MessageStatus.WAITING + + def activate(self): + """From WAITING to READY.""" + num_req_blocks = self.scheduler.block_manager.num_required_blocks(self.seq) + assert self.seq.num_blocks >= num_req_blocks + if self.scheduler.is_ssm: + assert self.seq.logical_state >= 0 + self.to_state(ReadyState) + + def 
evict(self): + self.to_state(WaitingState) + + +class ReadyState(StateBase): + """State for ready sequences.""" + status = MessageStatus.READY + + def activate(self): + """From READY to RUNNING.""" + self.to_state(RunningState) + + def evict(self): + self.to_state(WaitingState) + + +class StoppedState(StateBase): + """State for stopped sequences.""" + status = MessageStatus.STOPPED + + def activate(self): + """From STOPPED to WAITING.""" + assert self.seq.num_token_ids > 0 + self.to_state(WaitingState) + + +class RunningState(StateBase): + """State for running sequences.""" + status = MessageStatus.RUNNING + + def deactivate(self): + self.to_state(ReadyState) + + def finish(self): + if self.seq.preserve_cache: + self.to_state(ToBeMigratedState) + else: + self.to_state(StoppedState) + + +class ToBeMigratedState(StateBase): + """State for to be migrated sequences.""" + status = MessageStatus.TO_BE_MIGRATED + + +class MigrationWaitingState(StateBase): + """State for migration waiting sequences.""" + status = MessageStatus.MIGRATION_WAITING + + def activate(self): + self.to_state(MigrationReadyState) + + def evict(self): + self.to_state(MigrationWaitingState) + + +class MigrationReadyState(StateBase): + """State for migration ready sequences.""" + status = MessageStatus.MIGRATION_READY + + def activate(self): + self.to_state(MigrationRunningState) + + def evict(self): + self.to_state(MigrationWaitingState) + + +class MigrationDoneState(StateBase): + """State for migration done sequences.""" + status = MessageStatus.MIGRATION_DONE + + def finish(self): + self.to_state(ReadyState) + + +class MigrationRunningState(StateBase): + """State for migration running sequences.""" + status = MessageStatus.MIGRATION_RUNNING + + def finish(self): + self.to_state(MigrationDoneState) + + +def build_seq_state(scheduler: 'Scheduler', seq: 'SchedulerSequence', status: MessageStatus) -> StateBase: + """Build sequence state.""" + return StateBase.build(scheduler, seq, status) diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py index 197217c8bb..de7b68e2de 100644 --- a/lmdeploy/pytorch/strategies/ar/sequence.py +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -125,10 +125,10 @@ def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_d update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL for token, msg, stop, model_meta, routed_experts in zip(next_token_ids, running, stopped, model_metas, all_routed_experts): - if msg.status != MessageStatus.LOCKED: + if msg.status != MessageStatus.RUNNING: continue # fill token msg.update_token_ids(token, model_meta=model_meta, mode=update_mode, routed_experts=routed_experts) if stop: - msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED + msg.state.finish() diff --git a/lmdeploy/pytorch/strategies/ar_spec/sequence.py b/lmdeploy/pytorch/strategies/ar_spec/sequence.py index ba4236e988..2e272e1473 100644 --- a/lmdeploy/pytorch/strategies/ar_spec/sequence.py +++ b/lmdeploy/pytorch/strategies/ar_spec/sequence.py @@ -179,11 +179,11 @@ def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_d msg = running[idx] stop = stopped[idx] model_meta = model_metas[idx] - if msg.status != MessageStatus.LOCKED: + if msg.status != MessageStatus.RUNNING: continue cur_draft_tokens = draft_token_ids[idx] # fill token msg.update_token_ids(token, draft_token_ids=cur_draft_tokens, model_meta=model_meta, mode=update_mode) if stop: 
msg.set_stop_pos(stop_pos[idx]) - msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED + msg.state.finish() diff --git a/lmdeploy/pytorch/strategies/dllm/sequence.py b/lmdeploy/pytorch/strategies/dllm/sequence.py index ab004a2b63..05c962632c 100644 --- a/lmdeploy/pytorch/strategies/dllm/sequence.py +++ b/lmdeploy/pytorch/strategies/dllm/sequence.py @@ -238,11 +238,11 @@ def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_d stop = stopped[idx] model_meta = model_metas[idx] mask = dllm_mask[idx] - if msg.status != MessageStatus.LOCKED: + if msg.status != MessageStatus.RUNNING: continue # fill token msg.update_token_ids(token, dllm_mask=mask, model_meta=model_meta, mode=update_mode) if stop: msg.set_stop_pos(stop_pos[idx]) - msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED + msg.state.finish() diff --git a/tests/pytorch/paging/test_block_manager.py b/tests/pytorch/paging/test_block_manager.py index f74b6548cf..b08116d7f4 100644 --- a/tests/pytorch/paging/test_block_manager.py +++ b/tests/pytorch/paging/test_block_manager.py @@ -2,9 +2,10 @@ import pytest import torch -from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta -from lmdeploy.pytorch.paging.block_manager import DefaultBlockManager, WindowBlockManager +from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig +from lmdeploy.pytorch.messages import SequenceMeta from lmdeploy.pytorch.paging.block_manager.base_block_manager import LogicalAllocator +from lmdeploy.pytorch.paging.scheduler import Scheduler # yapf: enable @@ -86,18 +87,39 @@ def num_gpu_blocks(self): yield 4 @pytest.fixture - def block_mgr(self, num_cpu_blocks, num_gpu_blocks): - yield DefaultBlockManager(num_cpu_blocks, num_gpu_blocks) + def max_batch_size(self): + yield 4 + + @pytest.fixture + def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size): + yield CacheConfig(max_batches=max_batch_size, + block_size=block_size, + num_cpu_blocks=num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks) + + @pytest.fixture + def scheduler_config(self, max_batch_size): + yield SchedulerConfig(max_batches=max_batch_size, + max_session_len=128, + max_request_output_len=64, + eviction_type='recompute') @pytest.fixture - def seq_manager(self, block_size): + def seq_meta(self, block_size): from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy strategy = ARSequenceStrategy() - seq_meta = SequenceMeta(block_size, strategy=strategy) - yield SequenceManager(seq_meta) + yield SequenceMeta(block_size, strategy=strategy) + + @pytest.fixture + def scheduler(self, cache_config, scheduler_config, seq_meta): + yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + + @pytest.fixture + def block_mgr(self, scheduler): + yield scheduler.block_manager - def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): - sess = SchedulerSession(0, seq_manager) + def test_alloc(self, scheduler, block_mgr, num_gpu_blocks): + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size # test alloc @@ -121,9 +143,9 @@ def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): msg = sess.add_sequence(token_ids) assert not block_mgr.can_allocate(msg) - def test_num_required_blocks(self, block_mgr, seq_manager, num_gpu_blocks): + def test_num_required_blocks(self, scheduler, block_mgr): from lmdeploy.pytorch.messages import InputEmbeddings - sess = SchedulerSession(0, seq_manager) 
+ sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size token_ids = torch.tensor([1]) @@ -142,8 +164,8 @@ def test_num_required_blocks(self, block_mgr, seq_manager, num_gpu_blocks): num_required = block_mgr.num_required_blocks(msg) assert num_required == 3 - def test_append_slot(self, block_mgr, seq_manager, num_gpu_blocks): - sess = SchedulerSession(0, seq_manager) + def test_append_slot(self, scheduler, block_mgr, num_gpu_blocks): + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size # test append @@ -168,8 +190,8 @@ def test_append_slot(self, block_mgr, seq_manager, num_gpu_blocks): assert len(block_table) == 2 assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2 - def test_swap(self, block_mgr, seq_manager, num_gpu_blocks): - sess = SchedulerSession(0, seq_manager) + def test_swap(self, scheduler, block_mgr, num_gpu_blocks): + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size token_ids = torch.tensor([1] * (block_size + 1)) @@ -227,18 +249,40 @@ def num_gpu_blocks(self): yield 4 @pytest.fixture - def seq_manager(self, block_size): + def max_batch_size(self): + yield 4 + + @pytest.fixture + def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size, window_size): + yield CacheConfig(max_batches=max_batch_size, + block_size=block_size, + num_cpu_blocks=num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks, + window_size=window_size) + + @pytest.fixture + def scheduler_config(self, max_batch_size): + yield SchedulerConfig(max_batches=max_batch_size, + max_session_len=128, + max_request_output_len=64, + eviction_type='recompute') + + @pytest.fixture + def seq_meta(self, block_size): from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy strategy = ARSequenceStrategy() - seq_meta = SequenceMeta(block_size, strategy=strategy) - yield SequenceManager(seq_meta) + yield SequenceMeta(block_size, strategy=strategy) + + @pytest.fixture + def scheduler(self, cache_config, scheduler_config, seq_meta): + yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) @pytest.fixture - def block_mgr(self, num_cpu_blocks, num_gpu_blocks, window_size): - yield WindowBlockManager(num_cpu_blocks, num_gpu_blocks, window_size) + def block_mgr(self, scheduler): + yield scheduler.block_manager - def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): - sess = SchedulerSession(0, seq_manager) + def test_alloc(self, scheduler, block_mgr, num_gpu_blocks): + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size # test alloc @@ -262,8 +306,8 @@ def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): msg = sess.add_sequence(token_ids) assert not block_mgr.can_allocate(msg) - def test_win_alloc(self, block_mgr, seq_manager, num_gpu_blocks, window_size): - sess = SchedulerSession(0, seq_manager) + def test_win_alloc(self, scheduler, block_mgr, num_gpu_blocks, window_size): + sess = scheduler.add_session(0) # 2 win block token_ids = torch.tensor([1] * window_size) diff --git a/tests/pytorch/paging/test_block_trie.py b/tests/pytorch/paging/test_block_trie.py index 7d20c96dab..5736e4d006 100644 --- a/tests/pytorch/paging/test_block_trie.py +++ b/tests/pytorch/paging/test_block_trie.py @@ -1,10 +1,9 @@ import numpy as np import pytest -from lmdeploy.pytorch.config import CacheConfig -from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta -from lmdeploy.pytorch.paging.block_manager import build_block_manager -from 
lmdeploy.pytorch.paging.block_trie import BlockTrie +from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig +from lmdeploy.pytorch.messages import SequenceMeta +from lmdeploy.pytorch.paging import Scheduler class TestBlockTire: @@ -22,31 +21,45 @@ def num_gpu_blocks(self): yield 16 @pytest.fixture - def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks): - yield CacheConfig(max_batches=256, + def max_batch_size(self): + yield 4 + + @pytest.fixture + def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size): + yield CacheConfig(max_batches=max_batch_size, block_size=block_size, num_cpu_blocks=num_cpu_blocks, num_gpu_blocks=num_gpu_blocks, enable_prefix_caching=True) @pytest.fixture - def block_mgr(self, cache_config): - yield build_block_manager(cache_config) + def scheduler_config(self, max_batch_size): + yield SchedulerConfig(max_batches=max_batch_size, + max_session_len=128, + max_request_output_len=64, + eviction_type='recompute') @pytest.fixture - def block_trie(self, cache_config, block_mgr): - yield BlockTrie(cache_config, block_mgr) - - @pytest.fixture - def seq_manager(self, block_size): + def seq_meta(self, block_size): from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy strategy = ARSequenceStrategy() - seq_meta = SequenceMeta(block_size, strategy=strategy) - yield SequenceManager(seq_meta) + yield SequenceMeta(block_size, strategy=strategy) + + @pytest.fixture + def scheduler(self, cache_config, scheduler_config, seq_meta): + yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + + @pytest.fixture + def block_mgr(self, scheduler): + yield scheduler.block_manager + + @pytest.fixture + def block_trie(self, scheduler): + yield scheduler.block_trie - def test_allocate(self, block_trie, block_mgr, seq_manager): + def test_allocate(self, block_trie, block_mgr, scheduler): allocator = block_trie.allocator - sess = SchedulerSession(0, seq_manager) + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size token_ids = ([1] * block_size + [2] * block_size) token_ids += [3] * (block_size // 2) @@ -83,9 +96,9 @@ def test_allocate(self, block_trie, block_mgr, seq_manager): assert node in block_trie.leaves assert len(block_trie.leaves) == 1 - def test_match(self, block_trie, block_mgr, seq_manager): + def test_match(self, block_trie, block_mgr, scheduler): allocator = block_trie.allocator - sess = SchedulerSession(0, seq_manager) + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size # initialize cache @@ -121,9 +134,9 @@ def test_match(self, block_trie, block_mgr, seq_manager): ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) assert np.array_equal(ref_cnt, [4, 3]) - def test_evict(self, block_trie, seq_manager, num_gpu_blocks): + def test_evict(self, block_trie, scheduler, num_gpu_blocks): block_mgr = block_trie.block_manager - sess = SchedulerSession(0, seq_manager) + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size token_ids = ([1] * block_size * (num_gpu_blocks - 1)) token_ids += [2] * (block_size // 2) diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py index a0acf5f054..3d07c2c019 100644 --- a/tests/pytorch/paging/test_scheduler.py +++ b/tests/pytorch/paging/test_scheduler.py @@ -21,15 +21,22 @@ def num_gpu_blocks(self): yield 4 @pytest.fixture - def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks): - yield CacheConfig(max_batches=256, + def 
max_batch_size(self): + yield 4 + + @pytest.fixture + def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size): + yield CacheConfig(max_batches=max_batch_size, block_size=block_size, num_cpu_blocks=num_cpu_blocks, num_gpu_blocks=num_gpu_blocks) @pytest.fixture - def scheduler_config(self): - yield SchedulerConfig(max_batches=4, max_session_len=128, max_request_output_len=64, eviction_type='recompute') + def scheduler_config(self, max_batch_size): + yield SchedulerConfig(max_batches=max_batch_size, + max_session_len=128, + max_request_output_len=64, + eviction_type='recompute') @pytest.fixture def seq_meta(self, block_size): @@ -51,7 +58,6 @@ def test_schedule_base(self, scheduler, block_size, num_gpu_blocks): num_blocks = 2 token_ids = torch.tensor([0] * block_size * num_blocks) seq = session.add_sequence(token_ids) - scheduler.add_sequence(seq) assert seq.status == MessageStatus.WAITING assert seq in scheduler.waiting @@ -59,7 +65,7 @@ def test_schedule_base(self, scheduler, block_size, num_gpu_blocks): output = scheduler.schedule(is_prefill=True) block_tables = scheduler.get_block_tables(output.running) - assert seq.status == MessageStatus.RUNNING + assert seq.status == MessageStatus.READY assert seq in output.running assert len(block_tables) == 1 assert len(block_tables[0]) == num_blocks @@ -73,38 +79,34 @@ def test_update(self, scheduler, block_size, num_gpu_blocks): session1 = scheduler.add_session(session_id1) token_ids1 = torch.tensor([0] * block_size * 1) seq1 = session1.add_sequence(token_ids1) - scheduler.add_sequence(seq1) session_id2 = 1 session2 = scheduler.add_session(session_id2) token_ids2 = torch.tensor([0] * block_size * 2) seq2 = session2.add_sequence(token_ids2) - scheduler.add_sequence(seq2) token_ids3 = torch.tensor([0] * block_size * 3) seq3 = session2.add_sequence(token_ids3) - scheduler.add_sequence(seq3) scheduler.schedule(is_prefill=True) - assert seq1.status == MessageStatus.RUNNING - assert seq2.status == MessageStatus.RUNNING + assert seq1.status == MessageStatus.READY + assert seq2.status == MessageStatus.READY assert seq3.status == MessageStatus.WAITING # stop seq - seq1.status = MessageStatus.STOPPED - assert len(scheduler.running) == 1 + seq1.state.stop() + assert len(scheduler.ready) == 1 assert seq1 in scheduler.hanging # end seq - seq1.status = MessageStatus.ENDED - scheduler._remove_sequence(seq1) + seq1.session.remove_sequence(seq1) assert session_id1 in scheduler.sessions - assert seq1 not in scheduler.running + assert seq1 not in scheduler.ready assert seq1 not in scheduler.hanging assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 2 # stop session scheduler.stop_session(session_id2) - assert len(scheduler.running) == 0 + assert len(scheduler.ready) == 0 assert len(scheduler.waiting) == 0 assert len(scheduler.hanging) == 2 @@ -122,25 +124,22 @@ def test_evict(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks): # test: add 3 seq token_ids1 = torch.tensor([0] * block_size * 1) seq1 = session.add_sequence(token_ids1) - scheduler.add_sequence(seq1) token_ids2 = torch.tensor([0] * block_size * 2) seq2 = session.add_sequence(token_ids2) - scheduler.add_sequence(seq2) token_ids3 = torch.tensor([0] * block_size * 3) seq3 = session.add_sequence(token_ids3) - scheduler.add_sequence(seq3) scheduler.schedule(is_prefill=True) # seq1: 1 running gpu # seq2: 2 running gpu # seq3: 3 waiting empty - assert seq1.status == MessageStatus.RUNNING - assert seq2.status == MessageStatus.RUNNING + assert seq1.status == 
MessageStatus.READY + assert seq2.status == MessageStatus.READY assert seq3.status == MessageStatus.WAITING assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 3 # test: waiting alloc - seq2.status = MessageStatus.STOPPED - assert len(scheduler.running) == 1 + seq2.state.stop() + assert len(scheduler.ready) == 1 assert len(scheduler.waiting) == 1 assert len(scheduler.hanging) == 1 @@ -148,17 +147,16 @@ def test_evict(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks): # seq1: 1 running gpu # seq2: 2 hanging cpu # seq3: 3 running gpu - assert seq1.status == MessageStatus.RUNNING + assert seq1.status == MessageStatus.READY assert seq2.status == MessageStatus.STOPPED - assert seq3.status == MessageStatus.RUNNING + assert seq3.status == MessageStatus.READY assert block_manager.get_num_free_gpu_blocks() == 0 # test: waiting append token - seq2.status = MessageStatus.WAITING - seq3.status = MessageStatus.ENDED - scheduler._remove_sequence(seq3) + seq2.state.activate() + seq3.session.remove_sequence(seq3) seq2.update_token_ids(torch.tensor([1] * block_size)) - assert len(scheduler.running) == 1 + assert len(scheduler.ready) == 1 assert len(scheduler.waiting) == 1 assert len(scheduler.hanging) == 0 @@ -166,18 +164,18 @@ def test_evict(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks): # seq1: 1 running gpu # seq2: 3 running gpu # seq3: 3 nan - assert seq1.status == MessageStatus.RUNNING - assert seq2.status == MessageStatus.RUNNING + assert seq1.status == MessageStatus.READY + assert seq2.status == MessageStatus.READY assert block_manager.get_num_free_gpu_blocks() == 0 # test running append seq1.update_token_ids(torch.tensor([1] * block_size)) seq2.update_token_ids(torch.tensor([1] * block_size)) - assert len(scheduler.running) == 2 + assert len(scheduler.ready) == 2 scheduler.schedule(is_prefill=False) # seq1: 1 waiting cpu # seq2: 4 running gpu # seq3: 3 nan assert seq1.status == MessageStatus.WAITING - assert seq2.status == MessageStatus.RUNNING + assert seq2.status == MessageStatus.READY assert block_manager.get_num_free_gpu_blocks() == 0 From d3259fb5f2610e266079b176b9485333f43e5e22 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 27 Nov 2025 21:19:02 +0800 Subject: [PATCH 2/6] fix pd, better property --- lmdeploy/pytorch/paging/scheduler.py | 106 ++++++++----------- lmdeploy/pytorch/paging/seq_states/states.py | 9 ++ 2 files changed, 53 insertions(+), 62 deletions(-) diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 1359c88ae5..21917881ab 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -61,41 +61,54 @@ def __init__( seq_meta = seq_meta or SequenceMeta(self.cache_config.block_size) self.seq_manager = SequenceManager(seq_meta) - @property - def waiting(self): - """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.WAITING) - return list(seq_map.values()) + @staticmethod + def create_status_list_property(status: MessageStatus): + """Create status list property.""" - @property - def ready(self): - """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.READY) - return list(seq_map.values()) + def _get_status_list(self): + seq_map = self.seq_manager.get_sequences(status) + return list(seq_map.values()) - @property - def hanging(self): - """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.STOPPED) - return list(seq_map.values()) + return property(_get_status_list) - @property - 
def running(self): - """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING) - return list(seq_map.values()) + @staticmethod + def create_num_status_method(status: MessageStatus): + """Create num status method.""" - @property - def migration_waiting(self): - """Get migration sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_WAITING) - return list(seq_map.values()) + def _num_status(self): + return self.seq_manager.num_sequences(status) - @property - def migration_done(self): - """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_DONE) - return list(seq_map.values()) + return _num_status + + @staticmethod + def create_has_status_method(status: MessageStatus): + """Create has status method.""" + + def _has_status(self): + return self.seq_manager.num_sequences(status) > 0 + + return _has_status + + # status list properties + waiting = create_status_list_property(MessageStatus.WAITING) + ready = create_status_list_property(MessageStatus.READY) + hanging = create_status_list_property(MessageStatus.STOPPED) + running = create_status_list_property(MessageStatus.RUNNING) + migration_waiting = create_status_list_property(MessageStatus.MIGRATION_WAITING) + migration_done = create_status_list_property(MessageStatus.MIGRATION_DONE) + + # num status methods + num_waiting = create_num_status_method(MessageStatus.WAITING) + num_ready = create_num_status_method(MessageStatus.READY) + num_running = create_num_status_method(MessageStatus.RUNNING) + num_migration_waiting = create_num_status_method(MessageStatus.MIGRATION_WAITING) + num_migration_done = create_num_status_method(MessageStatus.MIGRATION_DONE) + + # has status methods + has_waiting = create_has_status_method(MessageStatus.WAITING) + has_ready = create_has_status_method(MessageStatus.READY) + has_migration_waiting = create_has_status_method(MessageStatus.MIGRATION_WAITING) + has_migration_done = create_has_status_method(MessageStatus.MIGRATION_DONE) def add_session(self, session_id: int): """Add new session. 
@@ -274,6 +287,7 @@ def end_session(self, session_id: int): session = self.sessions[session_id] seqs = list(session.sequences.values()) for seq in seqs: + # stop session so it won't get scheduled again seq.state.stop() session.remove_sequence(seq) self.sessions.pop(session_id) @@ -282,42 +296,10 @@ def has_unfinished(self): """Check if there are any unfinished message.""" return self.has_ready() or self.has_waiting() or self.has_migration_done() - def has_ready(self): - return self.num_ready() > 0 - - def has_waiting(self): - return self.num_waiting() > 0 - - def has_migration_waiting(self): - return self.num_migration_waiting() > 0 - - def has_migration_done(self): - return self.num_migration_done() > 0 - def get_block_tables(self, seqs: SeqList): """Get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] - def num_ready(self): - """Num running.""" - return self.seq_manager.num_sequences(MessageStatus.READY) - - def num_waiting(self): - """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.WAITING) - - def num_migration_done(self): - """Num migration done.""" - return self.seq_manager.num_sequences(MessageStatus.MIGRATION_DONE) - - def num_migration_waiting(self): - """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.MIGRATION_WAITING) - - def num_running(self): - """Num locked.""" - return self.seq_manager.num_sequences(MessageStatus.RUNNING) - def active_seqs(self, running: SeqList): """Lock running sequence.""" for seq in running: diff --git a/lmdeploy/pytorch/paging/seq_states/states.py b/lmdeploy/pytorch/paging/seq_states/states.py index 4ab7b1f154..1f44f02111 100644 --- a/lmdeploy/pytorch/paging/seq_states/states.py +++ b/lmdeploy/pytorch/paging/seq_states/states.py @@ -122,6 +122,9 @@ class ToBeMigratedState(StateBase): """State for to be migrated sequences.""" status = MessageStatus.TO_BE_MIGRATED + def finish(self): + self.to_state(StoppedState) + class MigrationWaitingState(StateBase): """State for migration waiting sequences.""" @@ -149,6 +152,9 @@ class MigrationDoneState(StateBase): """State for migration done sequences.""" status = MessageStatus.MIGRATION_DONE + def activate(self): + self.to_state(ReadyState) + def finish(self): self.to_state(ReadyState) @@ -157,6 +163,9 @@ class MigrationRunningState(StateBase): """State for migration running sequences.""" status = MessageStatus.MIGRATION_RUNNING + def deactivate(self): + self.to_state(MigrationDoneState) + def finish(self): self.to_state(MigrationDoneState) From 59138541055f00844acdfadc90ca6a9e7b2ad5cb Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 27 Nov 2025 21:21:59 +0800 Subject: [PATCH 3/6] skip decoding warmup --- lmdeploy/pytorch/engine/model_agent.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index c6c9345539..6be141fd7f 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -440,6 +440,9 @@ def warmup(self): # warmup decoding(with cuda graph) capture_batch_sizes = self.patched_model.get_capture_batch_sizes() capture_batch_sizes = sorted(capture_batch_sizes, reverse=True) + if self.cache_config.role == EngineRole.Prefill: + # do not warmup decoding for prefill engine + capture_batch_sizes = [] for num_tokens in capture_batch_sizes: inputs = self.inputs_strategy.make_dummy(num_tokens, is_decoding=True, From b7b9f4b56fe84192cb5a3e65f9375e7944b9ad9b Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 
1 Dec 2025 19:58:48 +0800 Subject: [PATCH 4/6] rename --- lmdeploy/pytorch/engine/engine.py | 16 ++++--------- lmdeploy/pytorch/engine/request.py | 2 +- lmdeploy/pytorch/messages.py | 10 ++++---- lmdeploy/pytorch/paging/scheduler.py | 35 +++++++++++++--------------- 4 files changed, 27 insertions(+), 36 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 609f8fa274..2fe9017d88 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -656,11 +656,10 @@ def __update_max_new_tokens(msg): scheduler = self.scheduler for req in reqs: session_id = req.data['session_id'] - if scheduler is None: + sess = scheduler.sessions.get(session_id, None) + if sess is None: self._response(req.resp, ResponseType.SESSION_NOT_EXIST) continue - session_id = req.data['session_id'] - sess = scheduler.sessions[session_id] # TODO: support 1 session n sequence sampling_param = req.data['sampling_param'] if len(sess.sequences) == 0: @@ -675,7 +674,6 @@ def __update_max_new_tokens(msg): resp_cache=req.data.get('with_cache'), preserve_cache=req.data.get('preserve_cache')) msg = next(iter(sess.sequences.values())) - __update_max_new_tokens(msg) if migration_request: self.migration_event.set() else: @@ -688,8 +686,8 @@ def __update_max_new_tokens(msg): ) msg.sampling_param = sampling_param msg.state.activate() - __update_max_new_tokens(msg) + __update_max_new_tokens(msg) msg.resp = req.resp @property @@ -697,10 +695,6 @@ def model_config(self) -> ModelConfig: """Model config.""" return self.executor.model_config - @property - def gpu_count(self): - return self.dist_config.world_size - @property def torch_int_dtype(self): """Return int32 for cuda, int64 for others.""" @@ -1309,8 +1303,8 @@ async def async_loop(self): forward_event=forward_event, has_runable_event=has_runable_event, inputs_maker=inputs_maker) - except Exception as e: - logger.exception(f'exception happened: {type(e)} {e}') + except Exception: + logger.exception('Engine main loop failed.') finally: self._loop_finally() diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py index 466e102e22..268d9556dd 100644 --- a/lmdeploy/pytorch/engine/request.py +++ b/lmdeploy/pytorch/engine/request.py @@ -108,7 +108,7 @@ def _gather_request(self, req_types: List[RequestType], data: List[Any]): resps = [] for rtype, rdata in zip(req_types, data): event = asyncio.Event() - resp = Response(type=ResponseType.HANDLER_NOT_EXIST, + resp = Response(type=ResponseType.INTERNAL_ENGINE_ERROR, sender_id=self.sender_id, event=event, data=None, diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 1fd26ea2cd..699988f9b6 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import enum +from collections import defaultdict from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -179,9 +180,7 @@ class SequenceManager: def __init__(self, seq_meta: SequenceMeta) -> None: self._seq_map: SeqMap = dict() - self._status_seq_map: Dict[MessageStatus, SeqMap] = dict() - for status in MessageStatus: - self._status_seq_map[status] = dict() + self._status_seq_map: Dict[MessageStatus, SeqMap] = defaultdict(dict) self.seq_meta = seq_meta self._seq_count = 0 @@ -203,9 +202,10 @@ def num_sequences(self, status: MessageStatus): """Num sequences.""" return len(self.get_sequences(status)) - def add_sequence(self, seq: 'SchedulerSequence', status: MessageStatus): + def add_sequence(self, seq: 'SchedulerSequence'): """Add sequence.""" seq_id = seq.seq_id + status = seq.status status_map = self._status_seq_map[status] self._seq_map[seq_id] = seq status_map[seq_id] = seq @@ -288,7 +288,7 @@ def add_sequence(self, # update seq manager status = MessageStatus.WAITING if migration_request is None else MessageStatus.MIGRATION_WAITING seq.set_state(build_seq_state(self.scheduler, seq, status)) - self.seq_manager.add_sequence(seq, status) + self.seq_manager.add_sequence(seq) # metrics seq.record_event(EventType.QUEUED) diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 637dfa5ab6..e60fdc358f 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -17,6 +17,7 @@ logger = get_logger('lmdeploy') +MapType = Dict[int, int] SeqList = List[SchedulerSequence] @@ -25,9 +26,9 @@ class SchedulerOutput: """Output of schedule.""" running: SeqList - swap_in_map: Dict[int, int] - swap_out_map: Dict[int, int] - copy_map: Dict[int, int] + swap_in_map: MapType + swap_out_map: MapType + copy_map: MapType class Scheduler: @@ -165,9 +166,9 @@ def _schedule_prefill(self, prealloc_size: int = 0): max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running() eviction_helper = self.eviction_helper - swap_out_map: Dict[int, int] = dict() - swap_in_map: Dict[int, int] = dict() - copy_map: Dict[int, int] = dict() + swap_out_map: MapType = dict() + swap_in_map: MapType = dict() + copy_map: MapType = dict() running: SeqList = [] token_count = 0 @@ -227,9 +228,9 @@ def _reorder_running(): assert len(running) != 0 eviction_helper = self.eviction_helper - swap_out_map: Dict[int, int] = dict() - swap_in_map: Dict[int, int] = dict() - copy_map: Dict[int, int] = dict() + swap_out_map: MapType = dict() + swap_in_map: MapType = dict() + copy_map: MapType = dict() def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): """Evict until can append.""" @@ -312,28 +313,24 @@ def get_block_tables(self, seqs: SeqList): """Get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] - def active_seqs(self, running: SeqList): + def active_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.READY): """Lock running sequence.""" for seq in running: - if seq.status == MessageStatus.READY: + if seq.status == filter_status: seq.state.activate() - def deactive_seqs(self, running: SeqList): + def deactive_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.RUNNING): for seq in running: - if seq.status == MessageStatus.RUNNING: + if seq.status == filter_status: seq.state.deactivate() def activate_migration_seqs(self, running: SeqList): """Lock running sequence.""" - for seq in running: - if 
seq.status == MessageStatus.MIGRATION_READY: - seq.state.activate() + return self.active_seqs(running, filter_status=MessageStatus.MIGRATION_READY) def deactivate_migration_seqs(self, running: SeqList): """Unlock running migration.""" - for seq in running: - if seq.status == MessageStatus.MIGRATION_RUNNING: - seq.state.deactivate() + return self.deactive_seqs(running, filter_status=MessageStatus.MIGRATION_RUNNING) def collect_migration_done(self): for seq in self.migration_done: From 686452aaee8a98903511483fa1f078fbe6ff1578 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 1 Dec 2025 21:23:41 +0800 Subject: [PATCH 5/6] add more profile logs --- lmdeploy/pytorch/engine/engine.py | 3 ++- lmdeploy/pytorch/paging/scheduler.py | 4 ++++ lmdeploy/pytorch/strategies/ar/sampling.py | 2 ++ lmdeploy/pytorch/strategies/dllm/sampling.py | 3 +++ 4 files changed, 11 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 2fe9017d88..9de78855ca 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -767,7 +767,7 @@ def __has_values(input_multimodals): return vision_embedding_inputs @torch.inference_mode() - @record_function('CreateModelInputs') + @record_function('create_model_inputs') def create_model_inputs(self, messages: SeqList, is_prefill: bool): """Create model inputs from messages. @@ -932,6 +932,7 @@ def _make_infer_outputs( outputs[session_id].logits = logits.split(seq_length)[idx] return outputs + @record_function('make_forward_inputs') def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False): """Make forward inputs.""" diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index e60fdc358f..bbf7ff903a 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -5,6 +5,8 @@ from dataclasses import dataclass from typing import Dict, List +from torch.profiler import record_function + from lmdeploy.messages import EventType, ScheduleMetrics from lmdeploy.utils import get_logger @@ -161,6 +163,7 @@ def _reorder_migrating(): return migration_ready + @record_function('schedule_prefill') def _schedule_prefill(self, prealloc_size: int = 0): """Schedule for prefilling.""" @@ -217,6 +220,7 @@ def _reorder_waiting(): return running, swap_in_map, swap_out_map, copy_map + @record_function('schedule_decoding') def _schedule_decoding(self, prealloc_size: int = 0): """Schedule decoding.""" diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py index 3b97940e9b..b37ce85e8f 100644 --- a/lmdeploy/pytorch/strategies/ar/sampling.py +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -2,6 +2,7 @@ from typing import List import torch +from torch.profiler import record_function from lmdeploy.pytorch.engine.logits_process import SamplingInputs from lmdeploy.pytorch.messages import SchedulerSequence @@ -41,6 +42,7 @@ def __init__(self, pad_token_id: int) -> None: self.pad_token_id = pad_token_id self.session_to_cleanup = [] + @record_function('make_sampling_inputs') def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: """Create sampling inputs from the sequences.""" batch_size = len(seqs) diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py index 45048e25a5..c181704b0a 100644 --- a/lmdeploy/pytorch/strategies/dllm/sampling.py +++ b/lmdeploy/pytorch/strategies/dllm/sampling.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. 
All rights reserved. from typing import List +from torch.profiler import record_function + from lmdeploy.pytorch.engine.logits_process import SamplingInputs from lmdeploy.pytorch.messages import SchedulerSequence @@ -16,6 +18,7 @@ def __init__(self, pad_token_id: int, dllm_block_length: int) -> None: super().__init__(pad_token_id) self.dllm_block_length = dllm_block_length + @record_function('make_sampling_inputs') def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: """Create sampling inputs from the sequences.""" out = super().make_sampling_inputs(seqs) From 3385d04a39e193bcda4c213d77eafc387e223fe4 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 2 Dec 2025 14:09:34 +0800 Subject: [PATCH 6/6] add config builder --- lmdeploy/pytorch/engine/config_builder.py | 106 +++++++++++++++++++ lmdeploy/pytorch/engine/engine.py | 119 +++------------------- 2 files changed, 119 insertions(+), 106 deletions(-) create mode 100644 lmdeploy/pytorch/engine/config_builder.py diff --git a/lmdeploy/pytorch/engine/config_builder.py b/lmdeploy/pytorch/engine/config_builder.py new file mode 100644 index 0000000000..d5c0fd7241 --- /dev/null +++ b/lmdeploy/pytorch/engine/config_builder.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os + +from lmdeploy.messages import PytorchEngineConfig, SpeculativeConfig +from lmdeploy.pytorch.config import (BackendConfig, CacheConfig, DistConfig, MiscConfig, SchedulerConfig, + SpecDecodeConfig) +from lmdeploy.utils import get_logger, get_max_batch_size, get_model + + +class ConfigBuilder: + + @staticmethod + def update_engine_config(engine_config: PytorchEngineConfig): + """Update pytorch engine config.""" + logger = get_logger('lmdeploy') + + # make sure engine exits + if engine_config is None: + engine_config = PytorchEngineConfig() + else: + engine_config = copy.deepcopy(engine_config) + + if engine_config.max_batch_size is None: + engine_config.max_batch_size = get_max_batch_size(engine_config.device_type) + + if engine_config.dllm_block_length is not None: + max_prefill_token_num = engine_config.max_prefill_token_num + max_batch_size = engine_config.max_batch_size + if max_batch_size * engine_config.dllm_block_length > max_prefill_token_num: + engine_config.max_batch_size = max_prefill_token_num // engine_config.dllm_block_length + logger.warning(f'Update max_batch_size to {engine_config.max_batch_size} ' + f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size ' + f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).') + + if engine_config.dp != 1: + if engine_config.tp == 1 and engine_config.ep == 1: + engine_config.dp = 1 + engine_config.dp_rank = 0 + + return engine_config + + @staticmethod + def build_scheduler_config(engine_config: PytorchEngineConfig): + """Build scheduler config.""" + scheduler_config = SchedulerConfig(max_batches=engine_config.max_batch_size, + max_session_len=engine_config.session_len, + prefill_interval=engine_config.prefill_interval) + return scheduler_config + + @staticmethod + def build_cache_config(engine_config: PytorchEngineConfig): + """Build cache config.""" + cache_config = CacheConfig(max_batches=engine_config.max_batch_size, + block_size=engine_config.block_size, + num_cpu_blocks=engine_config.num_cpu_blocks, + num_gpu_blocks=engine_config.num_gpu_blocks, + cache_max_entry_count=engine_config.cache_max_entry_count, + max_prefill_token_num=engine_config.max_prefill_token_num, + enable_prefix_caching=engine_config.enable_prefix_caching, + 
quant_policy=engine_config.quant_policy, + device_type=engine_config.device_type, + migration_backend=engine_config.migration_backend, + role=engine_config.role) + return cache_config + + @staticmethod + def build_backend_config(engine_config: PytorchEngineConfig): + """Build backend config.""" + backend_config = BackendConfig( + eager_mode=engine_config.eager_mode, + device_type=engine_config.device_type, + ) + return backend_config + + @staticmethod + def build_dist_config(engine_config: PytorchEngineConfig): + """Build dist config.""" + dist_config = DistConfig.from_engine_config(engine_config=engine_config) + return dist_config + + @staticmethod + def build_misc_config(engine_config: PytorchEngineConfig): + """Build misc config.""" + misc_config = MiscConfig.from_engine_config(engine_config) + return misc_config + + @staticmethod + def build_specdecode_config(target_model, speculative_config: SpeculativeConfig, engine_config: PytorchEngineConfig, + cache_config: CacheConfig): + """Build spec decode config.""" + specdecode_config = None + if speculative_config is not None: + draft_model = speculative_config.model + if draft_model and not os.path.exists(speculative_config.model): + draft_model = get_model(draft_model, engine_config.download_dir, engine_config.revision) + + specdecode_config = SpecDecodeConfig.from_config( + method=speculative_config.method, + num_speculative_tokens=speculative_config.num_speculative_tokens, + model=draft_model, + target_model=target_model, + target_cache_cfg=cache_config, + dtype=engine_config.dtype, + ) + return specdecode_config diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 9de78855ca..a28c72bf5d 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import asyncio -import copy import gc import logging import os @@ -18,15 +17,16 @@ from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, DistServeInitRequest) from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch -from lmdeploy.utils import get_logger, get_max_batch_size, get_model +from lmdeploy.utils import get_logger, get_model from ..adapter.adapter import AdapterManager -from ..config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SchedulerConfig, SpecDecodeConfig +from ..config import CacheConfig, ModelConfig from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode from ..model_inputs import ModelInputs, VisionModelInputs from ..paging import Scheduler from ..strategies import build_strategy_factory from .base import EngineBase +from .config_builder import ConfigBuilder from .engine_checker import EngineChecker from .executor import build_executor from .model_agent import BatchedOutputs @@ -75,99 +75,6 @@ def _tensorlize_block_offsets(block_offsets, dtype=torch.int32): return torch.as_tensor(out, dtype=dtype) -def _update_engine_config(engine_config: PytorchEngineConfig): - """Update pytorch engine config.""" - # make sure engine exits - if engine_config is None: - engine_config = PytorchEngineConfig() - else: - engine_config = copy.deepcopy(engine_config) - - if engine_config.max_batch_size is None: - engine_config.max_batch_size = get_max_batch_size(engine_config.device_type) - - if engine_config.dllm_block_length is not None: - max_prefill_token_num = engine_config.max_prefill_token_num - max_batch_size = engine_config.max_batch_size - if max_batch_size * engine_config.dllm_block_length > max_prefill_token_num: - engine_config.max_batch_size = max_prefill_token_num // engine_config.dllm_block_length - logger.warning(f'Update max_batch_size to {engine_config.max_batch_size} ' - f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size ' - f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).') - - if engine_config.dp != 1: - if engine_config.tp == 1 and engine_config.ep == 1: - engine_config.dp = 1 - engine_config.dp_rank = 0 - - return engine_config - - -def _build_scheduler_config(engine_config: PytorchEngineConfig): - """Build scheduler config.""" - scheduler_config = SchedulerConfig(max_batches=engine_config.max_batch_size, - max_session_len=engine_config.session_len, - prefill_interval=engine_config.prefill_interval) - return scheduler_config - - -def _build_cache_config(engine_config: PytorchEngineConfig): - """Build cache config.""" - cache_config = CacheConfig(max_batches=engine_config.max_batch_size, - block_size=engine_config.block_size, - num_cpu_blocks=engine_config.num_cpu_blocks, - num_gpu_blocks=engine_config.num_gpu_blocks, - cache_max_entry_count=engine_config.cache_max_entry_count, - max_prefill_token_num=engine_config.max_prefill_token_num, - enable_prefix_caching=engine_config.enable_prefix_caching, - quant_policy=engine_config.quant_policy, - device_type=engine_config.device_type, - migration_backend=engine_config.migration_backend, - role=engine_config.role) - return cache_config - - -def _build_backend_config(engine_config: PytorchEngineConfig): - """Build backend config.""" - backend_config = BackendConfig( - eager_mode=engine_config.eager_mode, - device_type=engine_config.device_type, - ) - return backend_config - - -def _build_dist_config(engine_config: PytorchEngineConfig): - """Build dist config.""" - 
dist_config = DistConfig.from_engine_config(engine_config=engine_config) - return dist_config - - -def _build_misc_config(engine_config: PytorchEngineConfig): - """Build misc config.""" - misc_config = MiscConfig.from_engine_config(engine_config) - return misc_config - - -def _build_specdecode_config(target_model, speculative_config: SpeculativeConfig, engine_config: PytorchEngineConfig, - cache_config: CacheConfig): - """Build spec decode config.""" - specdecode_config = None - if speculative_config is not None: - draft_model = speculative_config.model - if draft_model and not os.path.exists(speculative_config.model): - draft_model = get_model(draft_model, engine_config.download_dir, engine_config.revision) - - specdecode_config = SpecDecodeConfig.from_config( - method=speculative_config.method, - num_speculative_tokens=speculative_config.num_speculative_tokens, - model=draft_model, - target_model=target_model, - target_cache_cfg=cache_config, - dtype=engine_config.dtype, - ) - return specdecode_config - - def _build_seq_meta(cache_config: CacheConfig, strategy: Any): from lmdeploy.pytorch.messages import SequenceMeta @@ -202,11 +109,11 @@ def clear(self): class RunableEventBase: """Runable event base.""" - async def wait(self, idx: int): + async def wait(self): """Wait event.""" raise NotImplementedError('Not implemented.') - def set(self, idx: int = None): + def set(self): """Set event.""" raise NotImplementedError('Not implemented.') @@ -365,7 +272,7 @@ def __init__( speculative_config: SpeculativeConfig = None, ) -> None: # make sure engine config exist - engine_config = _update_engine_config(engine_config) + engine_config = ConfigBuilder.update_engine_config(engine_config) # frequently gc would cause latency spike # default threshold (700, 10, 10) @@ -393,14 +300,14 @@ def __init__( checker.handle() # build configs - scheduler_config = _build_scheduler_config(engine_config) - cache_config = _build_cache_config(engine_config) - backend_config = _build_backend_config(engine_config) - dist_config = _build_dist_config(engine_config) - misc_config = _build_misc_config(engine_config) - + scheduler_config = ConfigBuilder.build_scheduler_config(engine_config) + cache_config = ConfigBuilder.build_cache_config(engine_config) + backend_config = ConfigBuilder.build_backend_config(engine_config) + dist_config = ConfigBuilder.build_dist_config(engine_config) + misc_config = ConfigBuilder.build_misc_config(engine_config) # spec decode - self.specdecode_config = _build_specdecode_config(model_path, speculative_config, engine_config, cache_config) + self.specdecode_config = ConfigBuilder.build_specdecode_config(model_path, speculative_config, engine_config, + cache_config) # build model agent self.executor = build_executor(
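For reference, a minimal usage sketch of the ConfigBuilder helpers introduced in PATCH 6/6, mirroring the call sequence that the Engine.__init__ hunk above switches to; this assumes the patch is applied, and the model path string is a placeholder:

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.engine.config_builder import ConfigBuilder

# Normalize the user-supplied engine config (fills max_batch_size,
# clamps it for dllm_block_length, and resets dp when tp == ep == 1).
engine_config = ConfigBuilder.update_engine_config(PytorchEngineConfig())

# Derive the per-component configs from the normalized engine config.
scheduler_config = ConfigBuilder.build_scheduler_config(engine_config)
cache_config = ConfigBuilder.build_cache_config(engine_config)
backend_config = ConfigBuilder.build_backend_config(engine_config)
dist_config = ConfigBuilder.build_dist_config(engine_config)
misc_config = ConfigBuilder.build_misc_config(engine_config)

# Speculative decoding is optional; with speculative_config=None this returns None.
specdecode_config = ConfigBuilder.build_specdecode_config('path/to/model', None,
                                                          engine_config, cache_config)

Keeping these builders as @staticmethods on ConfigBuilder leaves engine.py with only the orchestration shown in the Engine.__init__ diff above.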