1 change: 1 addition & 0 deletions lmdeploy/messages.py
@@ -546,6 +546,7 @@ class ScheduleMetrics:
active_blocks: int = 0
cached_blocks: int = 0
free_blocks: int = 0
prefix_cache_hit_rate: float = 0


@dataclass
4 changes: 3 additions & 1 deletion lmdeploy/metrics/loggers.py
@@ -118,7 +118,9 @@ def log(self):
f'Unfinished: {scheduler_stats.num_total_reqs-scheduler_stats.num_finished_reqs} reqs, '
f'Running: {scheduler_stats.num_running_reqs} reqs, '
f'Waiting: {scheduler_stats.num_waiting_reqs} reqs, '
f'GPU KV cache usage: {scheduler_stats.gpu_cache_usage * 100 :.1f}%')
f'GPU KV cache usage: {scheduler_stats.gpu_cache_usage * 100 :.1f}%, '
f'Prefix cache hit rate: {scheduler_stats.prefix_cache_hit_rate * 100 :.1f}%')

print(log_msg, flush=True)
if spec_log_msg:
print(spec_log_msg, flush=True)
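For reference, with this change the periodic scheduler log line gains the new field; a sketch of the expected output (the numbers are illustrative, not taken from a real run):

    Unfinished: 3 reqs, Running: 2 reqs, Waiting: 1 reqs, GPU KV cache usage: 41.2%, Prefix cache hit rate: 63.5%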
4 changes: 4 additions & 0 deletions lmdeploy/metrics/stats.py
@@ -20,13 +20,15 @@ class SchedulerStats:
num_running_reqs: currently executing requests.
num_waiting_reqs: Requests queued waiting for execution.
gpu_cache_usage: Fraction of GPU KV blocks utilized (0.0 to 1.0).
prefix_cache_hit_rate: Prefix caching hit rate.
"""

num_total_reqs: int = 0
num_finished_reqs: int = 0
num_running_reqs: int = 0
num_waiting_reqs: int = 0
gpu_cache_usage: float = 0.0
prefix_cache_hit_rate: float = 0.0

def __repr__(self):
"""Return a human-readable string representation."""
@@ -36,12 +38,14 @@ def __repr__(self):
f' num_running_reqs={self.num_running_reqs},\n'
f' num_waiting_reqs={self.num_waiting_reqs},\n'
f' gpu_cache_usage={self.gpu_cache_usage:.6f},\n'
f' prefix_cache_hit_rate={self.prefix_cache_hit_rate:.6f},\n'
')')

def update_from_schedule_metrics(self, scheduled_metrics: ScheduleMetrics):
self.num_running_reqs = scheduled_metrics.active_seqs
self.num_waiting_reqs = scheduled_metrics.waiting_seqs
self.gpu_cache_usage = 1.0 - (scheduled_metrics.free_blocks / scheduled_metrics.total_blocks)
self.prefix_cache_hit_rate = scheduled_metrics.prefix_cache_hit_rate


class RequestState:
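To illustrate how the new field propagates from the scheduler into the logged stats, a minimal sketch (assuming the remaining ScheduleMetrics fields keep their dataclass defaults; the values are made up):

    from lmdeploy.messages import ScheduleMetrics
    from lmdeploy.metrics.stats import SchedulerStats

    stats = SchedulerStats()
    metrics = ScheduleMetrics(active_seqs=2, waiting_seqs=1, total_blocks=1024,
                              free_blocks=512, prefix_cache_hit_rate=0.25)
    stats.update_from_schedule_metrics(metrics)
    print(stats.gpu_cache_usage)        # 1.0 - 512/1024 = 0.5
    print(stats.prefix_cache_hit_rate)  # 0.25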
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/messages.py
@@ -758,6 +758,6 @@ def update_token_ids(self,
"""Update token ids, old token ids will be added to history."""
raise NotImplementedError('NotImplemented')

def set_step(self, step: int):
def set_step(self, step: int, routed_experts: np.ndarray = None):
"""Set step."""
raise NotImplementedError('NotImplemented')
54 changes: 49 additions & 5 deletions lmdeploy/pytorch/paging/block_trie.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import heapq
from dataclasses import dataclass
from typing import Dict, Set

import numpy as np
@@ -10,16 +11,36 @@
from .block_manager import BaseBlockManager


@dataclass
class PrefixCacheStats:
"""Prefix caching stats."""
num_query_tokens: int = 0
num_hit_tokens: int = 0

def reset(self):
self.num_query_tokens = 0
self.num_hit_tokens = 0

def hit_rate(self):
return 0.0 if self.num_query_tokens <= 0 else float(self.num_hit_tokens) / self.num_query_tokens


class Node:
"""Node of block trie."""

def __init__(self, hash_key: int, block: int, tokens: np.ndarray, num_matched: int = 0):
def __init__(self,
hash_key: int,
block: int,
tokens: np.ndarray,
num_matched: int = 0,
routed_experts: np.ndarray = None):
self.hash_key = hash_key
self.block = block
self.tokens = tokens
self.num_matched = num_matched
self.children: Dict[int, 'Node'] = dict()
self._parent: 'Node' = None
self.routed_experts = routed_experts

@property
def parent(self):
@@ -54,6 +75,11 @@ def __init__(self, cache_config: CacheConfig, block_manager: BaseBlockManager):
# caches with different adapter should not be shared.
self._roots: Dict[str, Node] = dict()
self.leaves: Set[Node] = set()
self.stats = PrefixCacheStats()

def hit_rate(self):
"""Get hit rate."""
return self.stats.hit_rate()

def get_root(self, adapter_name: str):
"""Get root by adapter name."""
@@ -73,14 +99,19 @@ def match(self, seq: SchedulerSequence):
curr: Node = getattr(logical_blocks, 'last_shared_node', None)
if curr is None:
curr = self.get_root(seq.adapter_name)
init_num_matched = curr.num_matched
num_matched = curr.num_matched

def __match_success(node: Node):
nonlocal curr, num_matched
nonlocal curr, num_matched, matched_routed_experts
matched_blocks.append(node.block)
if seq.return_routed_experts and node.routed_experts is not None:
matched_routed_experts.append(node.routed_experts)
curr = node
num_matched += block_size

matched_routed_experts = []

while num_matched + block_size < seq.num_valid_ids:
curr_tokens = seq.history_cache[num_matched:num_matched + block_size]

@@ -99,7 +130,15 @@ def __match_success(node: Node):
self.allocator.update_access_time(matched_blocks)
self.allocator.add_ref_count(matched_blocks, 1)
seq.logical_blocks.append(matched_blocks)
seq.set_step(num_matched)
if len(matched_routed_experts) > 0:
matched_routed_experts = np.concatenate(matched_routed_experts, axis=0)
else:
matched_routed_experts = None
seq.set_step(num_matched, routed_experts=matched_routed_experts)

# record prefix hit
self.stats.num_query_tokens += seq.num_all_ids - init_num_matched
self.stats.num_hit_tokens += num_matched - init_num_matched

seq.logical_blocks.last_shared_node = curr

@@ -129,7 +168,8 @@ def allocate(self, seq: SchedulerSequence):
free_blocks = []
while num_matched + block_size <= num_valid_ids:
curr_tokens = seq.history_cache[num_matched:num_matched + block_size]

routed_experts = seq.all_routed_experts.get_real()[num_matched:num_matched +
block_size] if seq.return_routed_experts else None
block = logical_blocks[block_id]

hash_key = hash(('random', tuple(curr_tokens)))
@@ -142,7 +182,11 @@
free_blocks.append(block)
logical_blocks[block_id] = node.block
else:
node = Node(hash_key=hash_key, block=block, tokens=curr_tokens, num_matched=num_matched + block_size)
node = Node(hash_key=hash_key,
block=block,
tokens=curr_tokens,
num_matched=num_matched + block_size,
routed_experts=routed_experts)
node.parent = parent
blocks.append(node.block)
num_matched += block_size
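As a quick sanity check of the new bookkeeping, a minimal sketch of how the counters accumulate across requests (the token counts are invented for the example):

    from lmdeploy.pytorch.paging.block_trie import PrefixCacheStats

    stats = PrefixCacheStats()
    # First request: 96 tokens queried, nothing in the trie yet.
    stats.num_query_tokens += 96
    # Second request: 96 tokens queried, a 64-token prefix already cached.
    stats.num_query_tokens += 96
    stats.num_hit_tokens += 64
    print(stats.hit_rate())  # 64 / 192 = 0.333...
    stats.reset()
    print(stats.hit_rate())  # 0.0 once the counters are cleared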
1 change: 1 addition & 0 deletions lmdeploy/pytorch/paging/scheduler.py
@@ -441,4 +441,5 @@ def schedule_metrics(self):
waiting_seqs=self.num_waiting() + self.num_running(),
total_blocks=self.block_manager.num_gpu_blocks,
free_blocks=self.block_manager.get_num_free_gpu_blocks(),
prefix_cache_hit_rate=self.block_trie.hit_rate(),
)
7 changes: 5 additions & 2 deletions lmdeploy/pytorch/strategies/ar/sequence.py
@@ -58,7 +58,7 @@ def update_token_ids(self,
if model_meta is not None:
self.model_meta = model_meta

def set_step(self, step: int):
def set_step(self, step: int, routed_experts: np.ndarray = None):
"""Set step."""
num_all_ids = self.num_all_ids
# update step for vlm
@@ -79,7 +79,10 @@ def set_step(self, step: int):
self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids)

if self.return_routed_experts:
self.all_routed_experts.resize(step)
if routed_experts is not None:
self.all_routed_experts.append(routed_experts)
else:
self.all_routed_experts.resize(step)


class ARSequenceStrategy(SequenceStrategy):
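The routed_experts array passed into set_step here is what BlockTrie.match restores from the matched blocks: one (block_size, num_moe_layers, experts_topk) record per block, concatenated along the token axis. A small shape sketch (the sizes are arbitrary, mirroring the test below):

    import numpy as np

    block_size, num_moe_layers, experts_topk = 4, 2, 8
    # One record of routed expert ids per matched block.
    matched_routed_experts = [
        np.zeros((block_size, num_moe_layers, experts_topk), dtype=np.int32),
        np.ones((block_size, num_moe_layers, experts_topk), dtype=np.int32),
    ]
    restored = np.concatenate(matched_routed_experts, axis=0)
    print(restored.shape)  # (8, 2, 8) -- first axis is one entry per matched token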
44 changes: 43 additions & 1 deletion tests/pytorch/paging/test_block_trie.py
@@ -2,7 +2,7 @@
import pytest

from lmdeploy.pytorch.config import CacheConfig
from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta
from lmdeploy.pytorch.messages import SamplingParam, SchedulerSession, SequenceManager, SequenceMeta
from lmdeploy.pytorch.paging.block_manager import build_block_manager
from lmdeploy.pytorch.paging.block_trie import BlockTrie

@@ -37,13 +37,55 @@ def block_mgr(self, cache_config):
def block_trie(self, cache_config, block_mgr):
yield BlockTrie(cache_config, block_mgr)

@pytest.fixture
def num_moe_layers(self):
yield 4

@pytest.fixture
def experts_topk(self):
yield 4

@pytest.fixture
def seq_manager(self, block_size):
from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
strategy = ARSequenceStrategy()
seq_meta = SequenceMeta(block_size, strategy=strategy)
yield SequenceManager(seq_meta)

def test_with_routed_experts(self, block_trie, block_mgr, seq_manager, num_moe_layers, experts_topk):

def _get_routed_experts(size, value):
return np.full((size, num_moe_layers, experts_topk), value, dtype=np.int32)

sess = SchedulerSession(0, seq_manager)
block_size = sess.seq_meta.block_size
token_ids = ([1] * block_size + [2] * block_size)
all_routed_experts = [_get_routed_experts(block_size, 1), _get_routed_experts(block_size, 2)]
token_ids += [3] * (block_size // 2)
all_routed_experts += [_get_routed_experts(block_size // 2, 3)]
seq = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True))
all_routed_experts += [_get_routed_experts(block_size - 1, 4)]
routed_experts = np.concatenate(all_routed_experts, axis=0)
seq.update_token_ids([4] * block_size, routed_experts=routed_experts)

# test allocate
block_mgr.allocate(seq)
block_trie.allocate(seq)
node = getattr(seq.logical_blocks, 'last_shared_node', None)
assert node is not None
assert node.routed_experts is not None
target_routed_experts = np.concatenate(
[_get_routed_experts(block_size // 2, 3),
_get_routed_experts(block_size // 2, 4)], axis=0)
assert np.array_equal(node.routed_experts, target_routed_experts)

# test match
seq_query = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True))
block_trie.match(seq_query)
assert seq_query.all_routed_experts is not None
assert len(seq_query.all_routed_experts) == block_size * 2
assert np.array_equal(seq_query.all_routed_experts.get_real(), np.concatenate(all_routed_experts[:2], axis=0))

def test_allocate(self, block_trie, block_mgr, seq_manager):
allocator = block_trie.allocator
sess = SchedulerSession(0, seq_manager)