diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5b0e399..fac70e7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -166,7 +166,6 @@ set(PT_LIBS
 
 if (${TRITON_PYTORCH_NVSHMEM})
   set(PT_LIBS
       ${PT_LIBS}
-      "libtorch_nvshmem.so"
   )
 endif() # TRITON_PYTORCH_NVSHMEM
diff --git a/model_repository/dict_model/1/model.pt b/model_repository/dict_model/1/model.pt
new file mode 100644
index 0000000..e01de04
Binary files /dev/null and b/model_repository/dict_model/1/model.pt differ
diff --git a/model_repository/dict_model/config.pbtxt b/model_repository/dict_model/config.pbtxt
new file mode 100644
index 0000000..c39e708
--- /dev/null
+++ b/model_repository/dict_model/config.pbtxt
@@ -0,0 +1,24 @@
+name: "dict_model"
+platform: "pytorch_libtorch"
+max_batch_size: 8
+
+input [
+  {
+    name: "INPUT__0"
+    data_type: TYPE_FP32
+    dims: [ 10 ]
+  }
+]
+
+output [
+  {
+    name: "logits"
+    data_type: TYPE_FP32
+    dims: [ 20 ]
+  },
+  {
+    name: "embeddings"
+    data_type: TYPE_FP32
+    dims: [ 5 ]
+  }
+]
diff --git a/src/model_instance_state.cc b/src/model_instance_state.cc
index d634f3b..f738980 100644
--- a/src/model_instance_state.cc
+++ b/src/model_instance_state.cc
@@ -52,6 +52,7 @@ ModelInstanceState::ModelInstanceState(
     ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance)
     : BackendModelInstance(model_state, triton_model_instance),
       model_state_(model_state), device_(torch::kCPU), is_dict_input_(false),
+      dict_output_validated_(false),
       device_cnt_(0)
 {
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -149,6 +150,47 @@ ModelInstanceState::ModelInstanceState(
   THROW_IF_BACKEND_INSTANCE_ERROR(ValidateOutputs());
 }
 
+TRITONSERVER_Error*
+ModelInstanceState::ValidateAndCacheDictOutput(
+    const c10::Dict<c10::IValue, c10::IValue>& dict_output)
+{
+  if (dict_output_validated_.load(std::memory_order_acquire)) {
+    return nullptr;
+  }
+  std::lock_guard<std::mutex> lock(dict_validation_mutex_);
+  if (dict_output_validated_.load(std::memory_order_acquire)) {
+    return nullptr;
+  }
+  if (dict_output.size() == 0) {
+    return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, "model returned an empty dict output");
+  }
+  std::vector<std::string> temp_keys;
+  std::unordered_map<std::string, size_t> temp_index;
+  size_t idx = 0;
+  for (auto it = dict_output.begin(); it != dict_output.end(); ++it) {
+    std::string key = it->key().toStringRef();
+    if (!it->value().isTensor()) {
+      return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, "dict output values must be tensors");
+    }
+    temp_keys.push_back(key);
+    temp_index[key] = idx++;
+  }
+  std::vector<std::string> missing;
+  for (auto& output : model_state_->ModelOutputs()) {
+    if (temp_index.find(output.first) == temp_index.end()) {
+      missing.push_back(output.first);
+    }
+  }
+  if (!missing.empty()) {
+    return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, "dict output is missing expected output keys");
+  }
+  output_dict_keys_ = std::move(temp_keys);
+  output_dict_key_to_index_ = std::move(temp_index);
+  dict_output_validated_.store(true, std::memory_order_release);
+  return nullptr;
+}
+
+
 ModelInstanceState::~ModelInstanceState()
 {
   torch_model_.reset();
@@ -345,6 +387,18 @@ ModelInstanceState::Execute(
             list_output.elementType()->str() + "]");
       }
       output_tensors->push_back(model_outputs_);
+    } else if (model_outputs_.isGenericDict()) {
+      auto dict_output = model_outputs_.toGenericDict();
+      if (!dict_output_validated_.load(std::memory_order_acquire)) {
+        TRITONSERVER_Error* err = ValidateAndCacheDictOutput(dict_output);
+        if (err != nullptr) {
+          SendErrorForResponses(responses, request_count, err);
+          return;
+        }
+      }
+      for (const auto& key : output_dict_keys_) {
+        output_tensors->push_back(dict_output.at(key));
+      }
     } else {
       throw std::invalid_argument(
           "output must be of type Tensor, List[str] or Tuple containing one of "
@@ -872,7 +926,14 @@ ModelInstanceState::ReadOutputTensors(
   // The serialized string buffer must be valid until output copies are done
   std::vector<std::unique_ptr<std::string>> string_buffer;
   for (auto& output : model_state_->ModelOutputs()) {
+    // Use the dict key-to-index mapping when the model returned a dictionary output
     int op_index = output_index_map_[output.first];
+    if (dict_output_validated_.load(std::memory_order_acquire)) {
+      auto it = output_dict_key_to_index_.find(output.first);
+      if (it != output_dict_key_to_index_.end()) {
+        op_index = it->second;
+      }
+    }
     auto name = output.first;
     auto output_tensor_pair = output.second;
diff --git a/src/model_instance_state.hh b/src/model_instance_state.hh
index b495510..092fa46 100644
--- a/src/model_instance_state.hh
+++ b/src/model_instance_state.hh
@@ -26,6 +26,9 @@
 
 #pragma once
 
+#include <atomic>
+#include <mutex>
+
 #include
 #include
@@ -73,6 +76,13 @@ class ModelInstanceState : public BackendModelInstance {
   // Map from configuration name for an output to the index of
   // that output in the model.
   std::unordered_map<std::string, int> output_index_map_;
+
+  // State for validating and caching dictionary (Dict[str, Tensor]) outputs.
+  std::atomic<bool> dict_output_validated_;
+  std::mutex dict_validation_mutex_;
+  std::vector<std::string> output_dict_keys_;
+  std::unordered_map<std::string, size_t> output_dict_key_to_index_;
+
   std::unordered_map<std::string, TRITONSERVER_DataType> output_dtype_map_;
 
   // If the input to the tensor is a dictionary of tensors.
@@ -92,6 +102,9 @@ class ModelInstanceState : public BackendModelInstance {
   int device_cnt_;
 
  public:
+  TRITONSERVER_Error* ValidateAndCacheDictOutput(
+      const c10::Dict<c10::IValue, c10::IValue>& dict_output);
+
   virtual ~ModelInstanceState();
 
   // Clear CUDA cache
diff --git a/test_client.py b/test_client.py
new file mode 100644
index 0000000..f11cb15
--- /dev/null
+++ b/test_client.py
@@ -0,0 +1,25 @@
+# test_client.py
+import tritonclient.http as httpclient
+import numpy as np
+
+# Create client
+client = httpclient.InferenceServerClient(url="localhost:8000")
+
+# Prepare input
+input_data = np.random.randn(5, 10).astype(np.float32)
+inputs = [httpclient.InferInput("INPUT__0", input_data.shape, "FP32")]
+inputs[0].set_data_from_numpy(input_data)
+
+# Request outputs by dict key names
+outputs = [
+    httpclient.InferRequestedOutput("logits"),
+    httpclient.InferRequestedOutput("embeddings")
+]
+
+# Infer
+results = client.infer("dict_model", inputs, outputs=outputs)
+
+# Check output names and shapes
+print("Output names:", results.get_response())
+print("Logits shape:", results.as_numpy("logits").shape)
+print("Embeddings shape:", results.as_numpy("embeddings").shape)
diff --git a/test_model.py b/test_model.py
new file mode 100644
index 0000000..e2f5b68
--- /dev/null
+++ b/test_model.py
@@ -0,0 +1,33 @@
+# test_model.py
+import torch
+import torch.nn as nn
+
+class DictOutputModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(10, 50)
+        self.fc2 = nn.Linear(50, 20)
+        self.fc3 = nn.Linear(50, 5)
+
+    def forward(self, x):
+        features = self.fc1(x)
+        logits = self.fc2(features)
+        embeddings = self.fc3(features)
+
+        # Return a dictionary of named tensors
+        return {
+            "logits": logits,
+            "embeddings": embeddings
+        }
+
+# Create and save model
+model = DictOutputModel()
+model.eval()
+
+# Trace with an example input; strict=False permits dict outputs
+example_input = torch.randn(1, 10)
+traced_model = torch.jit.trace(model, example_input, strict=False)
+
+# Save
+torch.jit.save(traced_model, "model.pt")
+print("Model saved!")
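Not part of the diff above: before copying model.pt into model_repository/dict_model/1/, it can help to reload the traced module locally and confirm it still returns a dict with the keys that config.pbtxt declares. The snippet below is a minimal sketch; the file name verify_model.py is hypothetical, and it assumes the model.pt produced by test_model.py.

# verify_model.py -- illustrative sanity check, assumes model.pt from test_model.py
import torch

loaded = torch.jit.load("model.pt")
loaded.eval()

with torch.no_grad():
    out = loaded(torch.randn(2, 10))

# A dict-returning TorchScript module yields a Python dict when called from Python
assert isinstance(out, dict), f"expected dict output, got {type(out)}"
assert set(out.keys()) == {"logits", "embeddings"}
print("logits:", tuple(out["logits"].shape))          # expected (2, 20)
print("embeddings:", tuple(out["embeddings"].shape))  # expected (2, 5)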