Skip to content

Commit 4aeca3a

Browse files
authored
fix mllm device_map ut (#1000)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent d1bf7e8 commit 4aeca3a

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

test/test_cuda/test_multiple_card.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -362,24 +362,22 @@ def test_mllm_device_map(self):
362362
device_map = "0,1"
363363
ar = AutoRoundMLLM(model_name, device_map=device_map)
364364
self.assertEqual(ar.device, "cuda:0")
365-
self.assertEqual(ar.device_map, "auto")
366-
self.assertEqual(ar.device_list, [0, 1])
365+
self.assertEqual(ar.device_map, device_map)
367366

368367
device_map = 1
369-
ar = AutoRoundMLLM(ar.model, ar.tokenizer, ar.processor, device_map=device_map)
368+
ar = AutoRoundMLLM(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map)
370369
self.assertEqual(ar.device, "cuda:1")
371-
self.assertEqual(ar.device_map, None)
372-
self.assertFalse(hasattr(ar, "device_list"))
370+
self.assertEqual(ar.device_map, device_map)
373371

374372
device_map = "auto"
375-
ar = AutoRoundMLLM(ar.model, ar.tokenizer, ar.processor, device_map=device_map)
373+
ar = AutoRoundMLLM(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map)
376374
self.assertEqual(ar.device, "cuda")
377-
self.assertEqual(ar.device_map, "auto")
375+
self.assertEqual(ar.device_map, device_map)
378376

379377
device_map = {"model.language_model.layers": 0, "model.visual.blocks": 1}
380-
ar = AutoRoundMLLM(ar.model, ar.tokenizer, ar.processor, device_map=device_map)
381-
self.assertEqual(ar.model.model.language_model.layers.tuning_device, "cuda:0")
382-
self.assertEqual(ar.model.model.visual.blocks.tuning_device, "cuda:1")
378+
ar = AutoRoundMLLM(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map)
379+
self.assertEqual(ar.model.model.language_model.layers[0].self_attn.q_proj.tuning_device, "cuda:0")
380+
self.assertEqual(ar.model.model.visual.blocks[0].mlp.fc1.tuning_device, "cuda:1")
383381

384382

385383
if __name__ == "__main__":

0 commit comments

Comments
 (0)