@@ -41,6 +41,7 @@ def setUp(self):
4141 self .triton_client = grpcclient .InferenceServerClient (url = "localhost:8001" )
4242 self .vllm_model_name = "vllm_opt"
4343 self .python_model_name = "add_sub"
44+ self .local_vllm_model_name = "vllm_local"
4445
4546 def test_vllm_triton_backend (self ):
4647 # Load both vllm and add_sub models
@@ -60,9 +61,21 @@ def test_vllm_triton_backend(self):
6061 self .assertFalse (self .triton_client .is_model_ready (self .python_model_name ))
6162
6263 # Test vllm model and unload vllm model
63- self ._test_vllm_model (send_parameters_as_tensor = True )
64- self ._test_vllm_model (send_parameters_as_tensor = False )
64+ self ._test_vllm_model (self . vllm_model_name , send_parameters_as_tensor = True )
65+ self ._test_vllm_model (self . vllm_model_name , send_parameters_as_tensor = False )
6566 self .triton_client .unload_model (self .vllm_model_name )
67+
68+ def test_local_vllm_model (self ):
69+ # Load local vllm model
70+ self .triton_client .load_model (self .local_vllm_model_name )
71+ self .assertTrue (self .triton_client .is_model_ready (self .local_vllm_model_name ))
72+
73+ # Test local vllm model
74+ self ._test_vllm_model (self .local_vllm_model_name , send_parameters_as_tensor = True )
75+ self ._test_vllm_model (self .local_vllm_model_name , send_parameters_as_tensor = False )
76+
77+ # Unload local vllm model
78+ self .triton_client .unload_model (self .local_vllm_model_name )
6679
6780 def test_model_with_invalid_attributes (self ):
6881 model_name = "vllm_invalid_1"
@@ -74,7 +87,7 @@ def test_vllm_invalid_model_name(self):
7487 with self .assertRaises (InferenceServerException ):
7588 self .triton_client .load_model (model_name )
7689
77- def _test_vllm_model (self , send_parameters_as_tensor ):
90+ def _test_vllm_model (self , model_name , send_parameters_as_tensor ):
7891 user_data = UserData ()
7992 stream = False
8093 prompts = [
@@ -92,11 +105,11 @@ def _test_vllm_model(self, send_parameters_as_tensor):
92105 i ,
93106 stream ,
94107 sampling_parameters ,
95- self . vllm_model_name ,
108+ model_name ,
96109 send_parameters_as_tensor ,
97110 )
98111 self .triton_client .async_stream_infer (
99- model_name = self . vllm_model_name ,
112+ model_name = model_name ,
100113 request_id = request_data ["request_id" ],
101114 inputs = request_data ["inputs" ],
102115 outputs = request_data ["outputs" ],
0 commit comments