 import pytest
 
 from lmdeploy.pytorch.config import CacheConfig
-from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta
+from lmdeploy.pytorch.messages import SamplingParam, SchedulerSession, SequenceManager, SequenceMeta
 from lmdeploy.pytorch.paging.block_manager import build_block_manager
 from lmdeploy.pytorch.paging.block_trie import BlockTrie
 
@@ -37,13 +37,55 @@ def block_mgr(self, cache_config): |
     def block_trie(self, cache_config, block_mgr):
         yield BlockTrie(cache_config, block_mgr)
 
+    @pytest.fixture
+    def num_moe_layers(self):
+        yield 4
+
+    @pytest.fixture
+    def experts_topk(self):
+        yield 4
+
     @pytest.fixture
     def seq_manager(self, block_size):
         from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
         strategy = ARSequenceStrategy()
         seq_meta = SequenceMeta(block_size, strategy=strategy)
         yield SequenceManager(seq_meta)
 
+    def test_with_routed_experts(self, block_trie, block_mgr, seq_manager, num_moe_layers, experts_topk):
+
+        def _get_routed_experts(size, value):
+            return np.full((size, num_moe_layers, experts_topk), value, dtype=np.int32)
+
+        sess = SchedulerSession(0, seq_manager)
+        block_size = sess.seq_meta.block_size
+        token_ids = ([1] * block_size + [2] * block_size)
+        all_routed_experts = [_get_routed_experts(block_size, 1), _get_routed_experts(block_size, 2)]
+        token_ids += [3] * (block_size // 2)
+        all_routed_experts += [_get_routed_experts(block_size // 2, 3)]
+        seq = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True))
+        all_routed_experts += [_get_routed_experts(block_size - 1, 4)]
+        routed_experts = np.concatenate(all_routed_experts, axis=0)
+        seq.update_token_ids([4] * block_size, routed_experts=routed_experts)
+
+        # test allocate
+        block_mgr.allocate(seq)
+        block_trie.allocate(seq)
+        node = getattr(seq.logical_blocks, 'last_shared_node', None)
+        assert node is not None
+        assert node.routed_experts is not None
+        target_routed_experts = np.concatenate(
+            [_get_routed_experts(block_size // 2, 3),
+             _get_routed_experts(block_size // 2, 4)], axis=0)
+        assert np.array_equal(node.routed_experts, target_routed_experts)
+
+        # test match
+        seq_query = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True))
+        block_trie.match(seq_query)
+        assert seq_query.all_routed_experts is not None
+        assert len(seq_query.all_routed_experts) == block_size * 2
+        assert np.array_equal(seq_query.all_routed_experts.get_real(), np.concatenate(all_routed_experts[:2], axis=0))
+
     def test_allocate(self, block_trie, block_mgr, seq_manager):
         allocator = block_trie.allocator
         sess = SchedulerSession(0, seq_manager)
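
Note on the added test: the arrays built by _get_routed_experts have shape (num_tokens, num_moe_layers, experts_topk), presumably one set of top-k expert ids per MoE layer for each token, and the test asserts that the trie node for a full block holds exactly the rows belonging to that block's tokens. Below is a minimal standalone sketch of that layout, not part of the patch; it assumes numpy is already imported as np at the top of the test file (that line is outside the shown hunks), and the variable names are illustrative only.

import numpy as np

block_size, num_moe_layers, experts_topk = 16, 4, 4

# One (num_moe_layers, experts_topk) entry of expert ids per token.
routed = np.concatenate([
    np.full((block_size, num_moe_layers, experts_topk), 1, dtype=np.int32),  # tokens of block 0
    np.full((block_size, num_moe_layers, experts_topk), 2, dtype=np.int32),  # tokens of block 1
], axis=0)

# A full logical block of tokens maps to a contiguous slice of rows; this is the
# slice the test expects the trie node to cache and a prefix match to return.
block_1 = routed[block_size:2 * block_size]
assert block_1.shape == (block_size, num_moe_layers, experts_topk)
assert (block_1 == 2).all()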