diff --git a/examples/transformer/client.py b/examples/transformer/client.py
new file mode 100644
index 00000000..88299834
--- /dev/null
+++ b/examples/transformer/client.py
@@ -0,0 +1,204 @@
+#!/usr/bin/env python3
+import numpy as np
+import sys
+import argparse
+from typing import List, Tuple
+import tritonclient.http as httpclient
+from tritonclient.utils import InferenceServerException
+
+
+class SimpleTokenizer:
+    """
+    A simple character-level tokenizer for demo purposes.
+    """
+
+    def __init__(self, vocab_size=10000):
+        self.vocab_size = vocab_size
+        self.pad_token_id = 0
+        self.unk_token_id = 1
+
+    def encode(self, text: str, max_length: int = 128) -> Tuple[List[int], List[int]]:
+        """
+        Encode text to token IDs and create attention mask.
+
+        Args:
+            text: Input text string
+            max_length: Maximum sequence length
+
+        Returns:
+            Tuple of (input_ids, attention_mask)
+        """
+        # Simple character-level encoding that maps each character to an ID based on its ASCII value
+        input_ids = [min(ord(c), self.vocab_size - 1) for c in text.lower()]
+
+        # Truncate if too long
+        if len(input_ids) > max_length:
+            input_ids = input_ids[:max_length]
+
+        # Create attention mask (1 for real tokens, 0 for padding)
+        attention_mask = [1] * len(input_ids)
+
+        # Pad to max_length
+        padding_length = max_length - len(input_ids)
+        input_ids.extend([self.pad_token_id] * padding_length)
+        attention_mask.extend([0] * padding_length)
+
+        return input_ids, attention_mask
+
+
+class SentimentClient:
+    """
+    Client for the Transformer Sentiment Classifier on Triton Inference Server.
+    """
+
+    def __init__(self, url: str = "localhost:8000", model_name: str = "transformer"):
+        """
+        Initialize the client.
+
+        Args:
+            url: Triton server URL (e.g., "localhost:8000")
+            model_name: Name of the model
+        """
+        self.url = url
+        self.model_name = model_name
+        self.client = httpclient.InferenceServerClient(url=url, verbose=False)
+        self.tokenizer = SimpleTokenizer()
+        self.max_seq_length = 128
+        self.class_names = ["Negative", "Neutral", "Positive"]
+
+    def check_server_ready(self) -> bool:
+        """Check if the Triton server is ready."""
+        try:
+            if self.client.is_server_ready():
+                print(f"Server at {self.url} is ready")
+                return True
+            else:
+                print(f"Server at {self.url} is not ready")
+                return False
+        except InferenceServerException as e:
+            print(f"Failed to connect to server at {self.url}")
+            print(f"  Error: {e}")
+            return False
+
+    def check_model_ready(self) -> bool:
+        """Check if the model is ready."""
+        try:
+            if self.client.is_model_ready(self.model_name):
+                print(f"Model '{self.model_name}' is ready")
+                return True
+            else:
+                print(f"Model '{self.model_name}' is not ready")
+                return False
+        except InferenceServerException as e:
+            print("Failed to check model status")
+            print(f"  Error: {e}")
+            return False
+
+    def predict(self, text: str) -> Tuple[np.ndarray, int, str]:
+        """
+        Run inference on a single text input.
+
+        Args:
+            text: Input text string
+
+        Returns:
+            Tuple of (probabilities, predicted_class, class_name)
+        """
+        # Tokenize input
+        input_ids, attention_mask = self.tokenizer.encode(text, self.max_seq_length)
+
+        # Convert to numpy arrays with batch dimension
+        input_ids_np = np.array([input_ids], dtype=np.int64)
+        attention_mask_np = np.array([attention_mask], dtype=np.int64)
+
+        # Create input objects
+        inputs = [
+            httpclient.InferInput("INPUT_IDS", input_ids_np.shape, "INT64"),
+            httpclient.InferInput("ATTENTION_MASK", attention_mask_np.shape, "INT64")
+        ]
+
+        # Set data
+        inputs[0].set_data_from_numpy(input_ids_np)
+        inputs[1].set_data_from_numpy(attention_mask_np)
+
+        # Create output object
+        outputs = [httpclient.InferRequestedOutput("OUTPUT")]
+
+        # Send inference request
+        try:
+            response = self.client.infer(
+                model_name=self.model_name,
+                inputs=inputs,
+                outputs=outputs
+            )
+
+            # Get output
+            output = response.as_numpy("OUTPUT")[0]  # Remove batch dimension
+            predicted_class = int(np.argmax(output))
+            class_name = self.class_names[predicted_class]
+
+            return output, predicted_class, class_name
+
+        except InferenceServerException as e:
+            print(f"Inference failed: {e}")
+            raise
+
+    def predict_batch(self, texts: List[str]) -> List[Tuple[np.ndarray, int, str]]:
+        """
+        Run inference on a batch of text inputs.
+
+        Args:
+            texts: List of input text strings
+
+        Returns:
+            List of tuples (probabilities, predicted_class, class_name) for each input
+        """
+        # Tokenize all inputs
+        input_ids_batch = []
+        attention_mask_batch = []
+
+        for text in texts:
+            input_ids, attention_mask = self.tokenizer.encode(text, self.max_seq_length)
+            input_ids_batch.append(input_ids)
+            attention_mask_batch.append(attention_mask)
+
+        # Convert to numpy arrays
+        input_ids_np = np.array(input_ids_batch, dtype=np.int64)
+        attention_mask_np = np.array(attention_mask_batch, dtype=np.int64)
+
+        # Create input objects
+        inputs = [
+            httpclient.InferInput("INPUT_IDS", input_ids_np.shape, "INT64"),
+            httpclient.InferInput("ATTENTION_MASK", attention_mask_np.shape, "INT64")
+        ]
+
+        # Set data
+        inputs[0].set_data_from_numpy(input_ids_np)
+        inputs[1].set_data_from_numpy(attention_mask_np)
+
+        # Create output object
+        outputs = [httpclient.InferRequestedOutput("OUTPUT")]
+
+        # Send inference request
+        try:
+            response = self.client.infer(
+                model_name=self.model_name,
+                inputs=inputs,
+                outputs=outputs
+            )
+
+            # Get outputs
+            outputs_np = response.as_numpy("OUTPUT")
+
+            results = []
+            for output in outputs_np:
+                predicted_class = int(np.argmax(output))
+                class_name = self.class_names[predicted_class]
+                results.append((output, predicted_class, class_name))
+
+            return results
+
+        except InferenceServerException as e:
+            print(f"Batch inference failed: {e}")
+            raise
+
diff --git a/examples/transformer/config.pbtxt b/examples/transformer/config.pbtxt
new file mode 100644
index 00000000..0c178fa2
--- /dev/null
+++ b/examples/transformer/config.pbtxt
@@ -0,0 +1,47 @@
+name: "transformer"
+backend: "python"
+max_batch_size: 8  # maximum batch size the model supports for Triton batching
+
+# Input tensor specifications
+input [
+  {
+    name: "INPUT_IDS"
+    data_type: TYPE_INT64
+    dims: [ 128 ]  # max_seq_length
+  },
+  {
+    name: "ATTENTION_MASK"
+    data_type: TYPE_INT64
+    dims: [ 128 ]  # max_seq_length
+  }
+]
+
+# Output tensor specifications
+output [
+  {
+    name: "OUTPUT"
+    data_type: TYPE_FP32
+    dims: [ 3 ]  # num_classes (Negative, Neutral, Positive)
+  }
+]
+
+# Instance group configuration
+# For GPUs: Use KIND_GPU
+# For CPU-only: Use KIND_CPU
+instance_group [
+  {
+    count: 1
+    kind: KIND_GPU
+    gpus: [ 0 ]
+  }
+]
+
+# Dynamic batching configuration for better throughput
+dynamic_batching {
+  preferred_batch_size: [ 4, 8 ]
+  max_queue_delay_microseconds: 100
+}
+
+# Model version policy - serve the latest version
+version_policy: { latest: { num_versions: 1 } }
+
diff --git a/examples/transformer/model.py b/examples/transformer/model.py
new file mode 100644
index 00000000..67413dfb
--- /dev/null
+++ b/examples/transformer/model.py
@@ -0,0 +1,214 @@
+import json
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import triton_python_backend_utils as pb_utils
+import random
+
+
+class SentimentClassifier(nn.Module):
+    """
+    A transformer-based sentiment classifier model.
+    This model takes tokenized text sequences as input and outputs sentiment scores.
+    """
+    def __init__(self, vocab_size=10000, embed_dim=256, num_heads=8,
+                 num_layers=4, max_seq_length=128, num_classes=3):
+        """
+        Initialize the sentiment classifier model.
+
+        Args:
+            vocab_size: Size of the vocabulary
+            embed_dim: Embedding dimension
+            num_heads: Number of attention heads
+            num_layers: Number of transformer layers
+            max_seq_length: Maximum sequence length
+            num_classes: Number of sentiment classes
+        """
+        super(SentimentClassifier, self).__init__()
+
+        self.vocab_size = vocab_size
+        self.embed_dim = embed_dim
+        self.max_seq_length = max_seq_length
+
+        # Embedding layers
+        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
+        self.position_embedding = nn.Embedding(max_seq_length, embed_dim)
+
+        # Transformer encoder
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=num_heads,
+            dim_feedforward=embed_dim * 4,
+            dropout=0.1,
+            activation='relu',
+            batch_first=True
+        )
+        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+        # Classification head
+        self.fc1 = nn.Linear(embed_dim, embed_dim // 2)
+        self.dropout = nn.Dropout(0.1)
+        self.fc2 = nn.Linear(embed_dim // 2, num_classes)
+
+    def forward(self, input_ids, attention_mask=None):
+        """
+        Args:
+            input_ids: Token IDs [batch_size, seq_length]
+            attention_mask: Attention mask [batch_size, seq_length]
+        Returns:
+            logits: Classification logits [batch_size, num_classes]
+        """
+        batch_size, seq_length = input_ids.shape
+
+        # Create position IDs
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+        # Embeddings
+        token_embeds = self.token_embedding(input_ids)
+        position_embeds = self.position_embedding(position_ids)
+        embeddings = token_embeds + position_embeds
+
+        # Create attention mask for transformer (inverted: 1 -> can attend, 0 -> cannot attend)
+        if attention_mask is not None:
+            # Convert to boolean mask (True = masked position)
+            src_key_padding_mask = (attention_mask == 0)
+        else:
+            src_key_padding_mask = None
+
+        # Transformer encoding
+        encoded = self.transformer_encoder(
+            embeddings,
+            src_key_padding_mask=src_key_padding_mask
+        )
+
+        # Pool: take the mean of all non-padded tokens
+        if attention_mask is not None:
+            mask_expanded = attention_mask.unsqueeze(-1).expand(encoded.size()).float()
+            sum_embeddings = torch.sum(encoded * mask_expanded, dim=1)
+            sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
+            pooled = sum_embeddings / sum_mask
+        else:
+            pooled = encoded.mean(dim=1)
+
+        # Classification
+        x = self.fc1(pooled)
+        x = F.relu(x)
+        x = self.dropout(x)
+        logits = self.fc2(x)
+
+        return logits
+
+
+class TritonPythonModel:
+    """
+    Triton Python backend model wrapper for the transformer sentiment classifier.
+    """
+    def initialize(self, args):
+        """
+        Initialize the model. This function is called once when the model is loaded.
+
+        Args:
+            args: Dictionary containing initialization parameters
+        """
+        self.model_config = model_config = json.loads(args['model_config'])
+
+        # Get output configuration
+        output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT")
+        self.output_dtype = pb_utils.triton_string_to_numpy(output_config['data_type'])
+
+        # Model parameters
+        self.vocab_size = 10000
+        self.embed_dim = 256
+        self.num_heads = 8
+        self.num_layers = 4
+        self.max_seq_length = 128
+        self.num_classes = 3  # Negative, Neutral, Positive
+
+        # Device configuration
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Set all random seeds for reproducibility BEFORE model initialization
+        seed = 42
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
+            torch.cuda.manual_seed_all(seed)  # For multi-GPU
+            # Make CUDA operations deterministic
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+        # Initialize the model
+        self.model = SentimentClassifier(
+            vocab_size=self.vocab_size,
+            embed_dim=self.embed_dim,
+            num_heads=self.num_heads,
+            num_layers=self.num_layers,
+            max_seq_length=self.max_seq_length,
+            num_classes=self.num_classes
+        ).to(self.device)
+
+        # Set model to evaluation mode to disable dropout and ensure consistent outputs
+        self.model.eval()
+
+        print(f"Model initialized on device: {self.device}")
+        print(f"Model parameters: vocab_size={self.vocab_size}, embed_dim={self.embed_dim}, "
+              f"num_heads={self.num_heads}, num_layers={self.num_layers}, "
+              f"max_seq_length={self.max_seq_length}, num_classes={self.num_classes}")
+
+    def execute(self, requests):
+        """
+        Execute inference on a batch of requests.
+
+        Args:
+            requests: List of pb_utils.InferenceRequest objects
+
+        Returns:
+            List of pb_utils.InferenceResponse objects
+        """
+        responses = []
+
+        for request in requests:
+            # Get input tensors
+            input_ids_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT_IDS")
+            attention_mask_tensor = pb_utils.get_input_tensor_by_name(request, "ATTENTION_MASK")
+
+            # Convert to numpy
+            input_ids_np = input_ids_tensor.as_numpy()
+            attention_mask_np = attention_mask_tensor.as_numpy()
+
+            # Convert to torch tensors
+            input_ids = torch.from_numpy(input_ids_np).long().to(self.device)
+            attention_mask = torch.from_numpy(attention_mask_np).long().to(self.device)
+
+            # Run inference
+            with torch.no_grad():
+                logits = self.model(input_ids, attention_mask)
+
+            # Apply softmax to get probabilities
+            probabilities = F.softmax(logits, dim=-1)
+
+            # Convert output to numpy
+            output_np = probabilities.cpu().numpy().astype(self.output_dtype)
+
+            # Create output tensor
+            output_tensor = pb_utils.Tensor("OUTPUT", output_np)
+
+            # Create inference response
+            inference_response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
+            responses.append(inference_response)
+
+        return responses
+
+    def finalize(self):
+        """
+        Clean up resources when the model is unloaded.
+        """
+        print("Cleaning up model resources...")
+        del self.model
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
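A minimal usage sketch of the client classes above, assuming model.py and config.pbtxt are deployed in a Triton model repository under the name "transformer" and the server is reachable on localhost:8000 (the exact invocation below is illustrative, not part of the diff):

    # sketch: run from examples/transformer so client.py is importable
    from client import SentimentClient

    client = SentimentClient(url="localhost:8000", model_name="transformer")
    if client.check_server_ready() and client.check_model_ready():
        probabilities, predicted_class, class_name = client.predict("This movie was fantastic!")
        print(f"Prediction: {class_name} ({probabilities})")

        for probs, cls_id, cls_name in client.predict_batch(["Great service", "Terrible food", "It was okay"]):
            print(f"{cls_name}: {probs}")

Note that the SentimentClassifier in model.py is randomly initialized (with fixed seeds for reproducibility), so the returned class only demonstrates the serving path, not a meaningful sentiment prediction.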