@@ -66,7 +66,11 @@ def auto_complete_config(auto_complete_model_config):
6666 "optional" : True ,
6767 },
6868 ]
69- outputs = [{"name" : "text_output" , "data_type" : "TYPE_STRING" , "dims" : [- 1 ]}]
69+ outputs = [
70+ {"name" : "text_output" , "data_type" : "TYPE_STRING" , "dims" : [- 1 ]},
71+ {"name" : "input_tokens" , "data_type" : "TYPE_INT32" , "dims" : [- 1 ]},
72+ {"name" : "output_tokens" , "data_type" : "TYPE_INT32" , "dims" : [- 1 ]},
73+ ]
7074
7175 # Store the model configuration as a dictionary.
7276 config = auto_complete_model_config .as_dict ()
@@ -151,6 +155,15 @@ def initialize(self, args):
         )
         self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
 
+        output_tokens_config = pb_utils.get_output_config_by_name(
+            self.model_config, "output_tokens"
+        )
+        self.output_tokens_dtype = pb_utils.triton_string_to_numpy(output_tokens_config["data_type"])
+        input_tokens_config = pb_utils.get_output_config_by_name(
+            self.model_config, "input_tokens"
+        )
+        self.input_tokens_dtype = pb_utils.triton_string_to_numpy(input_tokens_config["data_type"])
+
         # Counter to keep track of ongoing request counts
         self.ongoing_request_count = 0
 
@@ -246,10 +259,17 @@ def create_response(self, vllm_output, prepend_input):
         text_outputs = [
             (prompt + output.text).encode("utf-8") for output in vllm_output.outputs
         ]
+        output_tokens = sum([len(output.token_ids) for output in vllm_output.outputs])
         triton_output_tensor = pb_utils.Tensor(
-            "text_output", np.asarray(text_outputs, dtype=self.output_dtype)
+            "text_output", np.asarray(text_outputs, dtype=self.output_dtype),
         )
-        return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
+        triton_tokens_tensor = pb_utils.Tensor(
+            "output_tokens", np.asarray(output_tokens, dtype=self.output_tokens_dtype),
+        )
+        triton_input_tokens_tensor = pb_utils.Tensor(
+            "input_tokens", np.asarray(len(vllm_output.prompt_token_ids), dtype=self.input_tokens_dtype),
+        )
+        return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor, triton_tokens_tensor, triton_input_tokens_tensor])
 
     def create_stream_response(self, vllm_output, previous_outputs_lengths):
         """
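
For reference, a minimal client-side sketch of reading the two new token-count outputs. The model name "vllm_model", the server URL, and the prompt are placeholders and not part of this commit; the "text_input" input name matches the backend's existing config, and the tritonclient calls follow the standard Triton HTTP client API (TYPE_STRING tensors travel as BYTES on the client side):

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")  # assumed server address

# Triton TYPE_STRING inputs are sent as BYTES from the client.
text_input = httpclient.InferInput("text_input", [1], "BYTES")
text_input.set_data_from_numpy(np.array([b"What is Triton?"], dtype=np.object_))

requested = [
    httpclient.InferRequestedOutput("text_output"),
    httpclient.InferRequestedOutput("input_tokens"),
    httpclient.InferRequestedOutput("output_tokens"),
]

# "vllm_model" is a placeholder model name, not taken from this commit.
result = client.infer(model_name="vllm_model", inputs=[text_input], outputs=requested)
print(result.as_numpy("text_output"))
print(result.as_numpy("input_tokens"))   # prompt token count (TYPE_INT32)
print(result.as_numpy("output_tokens"))  # completion token count (TYPE_INT32)

Note that this covers only the non-streaming path: the hunks above change create_response, while the streaming create_stream_response whose signature closes the last hunk is not modified in this excerpt.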