Commit 0542b2d

add ut
1 parent ab7d5c4 commit 0542b2d

2 files changed: +43, -2 lines

lmdeploy/pytorch/strategies/ar/sequence.py (0 additions, 1 deletion)

@@ -81,7 +81,6 @@ def set_step(self, step: int, routed_experts: np.ndarray = None):
         if self.return_routed_experts:
             if routed_experts is not None:
                 self.all_routed_experts.append(routed_experts)
-                assert routed_experts.shape[0] == len(self.all_routed_experts)
             else:
                 self.all_routed_experts.resize(step)
 
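Note on the deletion: `set_step` appends one chunk of routed-expert ids per step, so the removed assertion compared the newest chunk's token count (`routed_experts.shape[0]`) against the accumulated length of the whole container, which can only agree on the first step. A minimal sketch of that mismatch, assuming `all_routed_experts` reports its length in tokens (suggested by the neighbouring `resize(step)` call; the real container class is not part of this diff):

```python
import numpy as np

# Hypothetical stand-in for the container behind `all_routed_experts`;
# the actual lmdeploy class is not shown in this diff.
class RoutedExpertsStore:
    def __init__(self):
        self._chunks = []

    def append(self, chunk):
        self._chunks.append(chunk)

    def __len__(self):
        # Length counts accumulated tokens, not appended chunks.
        return sum(chunk.shape[0] for chunk in self._chunks)

store = RoutedExpertsStore()
step1 = np.zeros((4, 2, 8), dtype=np.int32)  # 4 tokens, 2 MoE layers, top-8
store.append(step1)
assert step1.shape[0] == len(store)   # holds only while one chunk is stored

step2 = np.zeros((4, 2, 8), dtype=np.int32)
store.append(step2)
assert step2.shape[0] != len(store)   # 4 != 8: the deleted assert would fail here
```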

tests/pytorch/paging/test_block_trie.py (43 additions, 1 deletion)

@@ -2,7 +2,7 @@
 import pytest
 
 from lmdeploy.pytorch.config import CacheConfig
-from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta
+from lmdeploy.pytorch.messages import SamplingParam, SchedulerSession, SequenceManager, SequenceMeta
 from lmdeploy.pytorch.paging.block_manager import build_block_manager
 from lmdeploy.pytorch.paging.block_trie import BlockTrie
 
@@ -37,13 +37,55 @@ def block_mgr(self, cache_config):
     def block_trie(self, cache_config, block_mgr):
         yield BlockTrie(cache_config, block_mgr)
 
+    @pytest.fixture
+    def num_moe_layers(self):
+        yield 4
+
+    @pytest.fixture
+    def experts_topk(self):
+        yield 4
+
     @pytest.fixture
     def seq_manager(self, block_size):
         from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
         strategy = ARSequenceStrategy()
         seq_meta = SequenceMeta(block_size, strategy=strategy)
         yield SequenceManager(seq_meta)
 
+    def test_with_routed_experts(self, block_trie, block_mgr, seq_manager, num_moe_layers, experts_topk):
+
+        def _get_routed_experts(size, value):
+            return np.full((size, num_moe_layers, experts_topk), value, dtype=np.int32)
+
+        sess = SchedulerSession(0, seq_manager)
+        block_size = sess.seq_meta.block_size
+        token_ids = ([1] * block_size + [2] * block_size)
+        all_routed_experts = [_get_routed_experts(block_size, 1), _get_routed_experts(block_size, 2)]
+        token_ids += [3] * (block_size // 2)
+        all_routed_experts += [_get_routed_experts(block_size // 2, 3)]
+        seq = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True))
+        all_routed_experts += [_get_routed_experts(block_size - 1, 4)]
+        routed_experts = np.concatenate(all_routed_experts, axis=0)
+        seq.update_token_ids([4] * block_size, routed_experts=routed_experts)
+
+        # test allocate
+        block_mgr.allocate(seq)
+        block_trie.allocate(seq)
+        node = getattr(seq.logical_blocks, 'last_shared_node', None)
+        assert node is not None
+        assert node.routed_experts is not None
+        target_routed_experts = np.concatenate(
+            [_get_routed_experts(block_size // 2, 3),
+             _get_routed_experts(block_size // 2, 4)], axis=0)
+        assert np.array_equal(node.routed_experts, target_routed_experts)
+
+        # test match
+        seq_query = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True))
+        block_trie.match(seq_query)
+        assert seq_query.all_routed_experts is not None
+        assert len(seq_query.all_routed_experts) == block_size * 2
+        assert np.array_equal(seq_query.all_routed_experts.get_real(), np.concatenate(all_routed_experts[:2], axis=0))
+
     def test_allocate(self, block_trie, block_mgr, seq_manager):
         allocator = block_trie.allocator
         sess = SchedulerSession(0, seq_manager)
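A quick sanity check of the shape arithmetic the new test relies on: the query prompt covers two full blocks plus a half block, and only fully-filled blocks are shareable through the trie, so a match should recover exactly `block_size * 2` rows of expert ids. A standalone sketch, with `block_size = 16` assumed (the suite's `cache_config`/`block_size` fixtures are outside this diff; `num_moe_layers` and `experts_topk` are 4 per the fixtures above):

```python
import numpy as np

# Assumed values: num_moe_layers and experts_topk are pinned to 4 by the
# fixtures; block_size = 16 is a guess, since its fixture is not shown here.
block_size, num_moe_layers, experts_topk = 16, 4, 4

def routed(size, value):
    return np.full((size, num_moe_layers, experts_topk), value, dtype=np.int32)

# Mirrors the chunks built up in test_with_routed_experts.
chunks = [routed(block_size, 1), routed(block_size, 2),
          routed(block_size // 2, 3), routed(block_size - 1, 4)]

# The query prompt only spans the first two full blocks, so the trie match
# recovers exactly block_size * 2 rows of expert ids.
matched = np.concatenate(chunks[:2], axis=0)
assert matched.shape == (block_size * 2, num_moe_layers, experts_topk)
```

The new test itself can be run in isolation with the usual pytest selection, e.g. `pytest tests/pytorch/paging/test_block_trie.py -k test_with_routed_experts`.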
