From 3ab3937506eba396f304b46ed07a335c1d279de7 Mon Sep 17 00:00:00 2001 From: mohitpalsingh Date: Tue, 15 Jul 2025 22:49:34 +0530 Subject: [PATCH 1/2] feat: add max-num-batched-tokens configuration and implement request handling constraints --- README.md | 1 + manifests/basic-config.yaml | 1 + manifests/config.yaml | 1 + pkg/llm-d-inference-sim/config.go | 5 + pkg/llm-d-inference-sim/config_test.go | 36 +++++ pkg/llm-d-inference-sim/request.go | 1 + pkg/llm-d-inference-sim/simulator.go | 153 +++++++++++++++++- pkg/llm-d-inference-sim/simulator_test.go | 187 ++++++++++++++++++++++ 8 files changed, 380 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 93388ac7..3e6aa3b0 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ For more details see the = int64(s.config.MaxNumSeqs) { + return false + } + + // If max-num-batched-tokens is not configured (0), only check max-num-seqs + if s.config.MaxNumBatchedTokens <= 0 { + return true + } + + // Calculate tokens needed for this request + requestTokens := s.calculateProcessingTokens(req) + currentTokens := atomic.LoadInt64(&s.processingTokensCount) + + // Check max-num-batched-tokens constraint + return currentTokens+int64(requestTokens) <= int64(s.config.MaxNumBatchedTokens) +} + +// addRunningRequest adds a request to the running requests tracking +func (s *VllmSimulator) addRunningRequest(reqID string, req completionRequest) { + processingTokens := s.calculateProcessingTokens(req) + + runningReq := runningRequest{ + promptTokens: req.getNumberOfPromptTokens(), + maxTokens: processingTokens, + totalTokens: processingTokens, + } + + s.runningRequestsMap.Store(reqID, runningReq) + atomic.AddInt64(&s.processingTokensCount, int64(processingTokens)) + atomic.AddInt64(&s.nRunningReqs, 1) +} + +// removeRunningRequest removes a request from the running requests tracking +func (s *VllmSimulator) removeRunningRequest(reqID string) { + if value, ok := s.runningRequestsMap.LoadAndDelete(reqID); ok { + runningReq := value.(runningRequest) + atomic.AddInt64(&s.processingTokensCount, -int64(runningReq.totalTokens)) + atomic.AddInt64(&s.nRunningReqs, -1) + } +} + // handleCompletions general completion requests handler, support both text and chat completion APIs func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) { vllmReq, err := s.readRequest(ctx, isChatCompletion) @@ -400,6 +484,16 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple return } + // Validate max-num-batched-tokens constraint - reject requests that would never be accepted + if s.config.MaxNumBatchedTokens > 0 { + requestTokens := s.calculateProcessingTokens(vllmReq) + if requestTokens > s.config.MaxNumBatchedTokens { + s.sendCompletionError(ctx, fmt.Sprintf("Request requires %d tokens, but max-num-batched-tokens is set to %d. This request would never be accepted. 
Please reduce max_tokens or increase max-num-batched-tokens", + requestTokens, s.config.MaxNumBatchedTokens), "BadRequestError", fasthttp.StatusBadRequest) + return + } + } + var wg sync.WaitGroup wg.Add(1) reqCtx := &completionReqCtx{ @@ -414,15 +508,60 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple wg.Wait() } +func (s *VllmSimulator) queueManager(ctx context.Context) { + // Use a slice to maintain the queue of waiting requests + var waitingQueue []*completionReqCtx + ticker := time.NewTicker(10 * time.Millisecond) // Check every 10ms if we can process waiting requests + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + s.logger.Info("queueManager stopped") + return + case reqCtx := <-s.reqChan: + // Add new request to the waiting queue + waitingQueue = append(waitingQueue, reqCtx) + case <-ticker.C: + // Periodically check if we can process waiting requests + if len(waitingQueue) == 0 { + continue + } + + // Try to process requests from the front of the queue + var newQueue []*completionReqCtx + for _, reqCtx := range waitingQueue { + if s.canAcceptRequest(reqCtx.completionReq) { + // Generate a unique ID for this request + reqID := uuid.New().String() + + // Add to running requests tracking + s.addRunningRequest(reqID, reqCtx.completionReq) + + // Add the request ID to the context so workers can use it + reqCtx.requestID = reqID + + // Send to processing channel + s.processingChan <- reqCtx + } else { + // Can't process yet, keep in queue + newQueue = append(newQueue, reqCtx) + } + } + waitingQueue = newQueue + } + } +} + func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { for { select { case <-ctx.Done(): s.logger.Info("reqProcessingWorker stopped:", "worker id", id) return - case reqCtx, ok := <-s.reqChan: + case reqCtx, ok := <-s.processingChan: if !ok { - s.logger.Info("reqProcessingWorker worker exiting: reqChan closed") + s.logger.Info("reqProcessingWorker worker exiting: processingChan closed") return } atomic.StoreInt64(&(s.nWaitingReqs), int64(len(s.reqChan))) @@ -449,7 +588,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { // TODO - check if this request went to the waiting queue - add it to waiting map s.reportLoras() } - atomic.AddInt64(&(s.nRunningReqs), 1) + + // Note: we don't increment nRunningReqs here because it's already done in addRunningRequest s.reportRunningRequests() var responseTokens []string @@ -514,6 +654,10 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { req.doRemotePrefill()) } } + + // Clean up the running request tracking + s.removeRunningRequest(reqCtx.requestID) + reqCtx.wg.Done() } } @@ -521,8 +665,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { // decrease model usage reference number func (s *VllmSimulator) responseSentCallback(model string) { - - atomic.AddInt64(&(s.nRunningReqs), -1) + // Note: nRunningReqs is now decremented in removeRunningRequest s.reportRunningRequests() // Only LoRA models require reference-count handling. 
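Reviewer note: the admission logic above combines two limits, a running-sequence cap (--max-num-seqs) and an optional token budget (--max-num-batched-tokens). The standalone sketch below illustrates that check-and-reserve flow under the assumption that a single goroutine (the queue manager) performs admission; the names admitter, tryAdmit, and release are illustrative placeholders and not part of the simulator's API.

package main

import (
	"fmt"
	"sync/atomic"
)

// admitter mirrors the two constraints enforced by canAcceptRequest:
// a cap on concurrent sequences and an optional batched-token budget.
type admitter struct {
	maxSeqs          int64 // --max-num-seqs
	maxBatchedTokens int64 // --max-num-batched-tokens; 0 disables the budget
	runningReqs      int64
	processingTokens int64
}

// tryAdmit reserves capacity for a request needing reqTokens
// (prompt tokens + max completion tokens). It assumes a single
// admitting goroutine, so the check and the reservation do not race.
func (a *admitter) tryAdmit(reqTokens int64) bool {
	if atomic.LoadInt64(&a.runningReqs) >= a.maxSeqs {
		return false
	}
	if a.maxBatchedTokens > 0 &&
		atomic.LoadInt64(&a.processingTokens)+reqTokens > a.maxBatchedTokens {
		return false
	}
	atomic.AddInt64(&a.runningReqs, 1)
	atomic.AddInt64(&a.processingTokens, reqTokens)
	return true
}

// release returns the reserved capacity once the request completes.
func (a *admitter) release(reqTokens int64) {
	atomic.AddInt64(&a.runningReqs, -1)
	atomic.AddInt64(&a.processingTokens, -reqTokens)
}

func main() {
	a := &admitter{maxSeqs: 5, maxBatchedTokens: 2048}
	fmt.Println(a.tryAdmit(1024)) // true: 1024 <= 2048
	fmt.Println(a.tryAdmit(1500)) // false: 1024 + 1500 > 2048
	a.release(1024)
	fmt.Println(a.tryAdmit(1500)) // true again after capacity is released
}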
diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 22a507ae..573688f2 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -65,6 +65,9 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http return nil, err } + // run queue manager that handles request constraints + go s.queueManager(ctx) + // run request processing workers for i := 1; i <= s.config.MaxNumSeqs; i++ { go s.reqProcessingWorker(ctx, i) @@ -489,4 +492,188 @@ var _ = Describe("Simulator", func() { Expect(string(body)).To(ContainSubstring("BadRequestError")) }) }) + + Context("max-num-batched-tokens functionality", func() { + var simulator *VllmSimulator + + BeforeEach(func() { + var err error + simulator, err = New(klog.Background()) + Expect(err).NotTo(HaveOccurred()) + + // Setup basic configuration + simulator.config = newConfig() + simulator.config.Model = "test-model" + simulator.config.MaxModelLen = 1024 + simulator.config.MaxNumSeqs = 5 + simulator.config.MaxNumBatchedTokens = 2048 + }) + + Describe("calculateProcessingTokens", func() { + It("should calculate tokens with explicit max_tokens", func() { + req := &chatCompletionRequest{ + baseCompletionRequest: baseCompletionRequest{ + Model: "test-model", + }, + Messages: []message{ + {Role: "user", Content: content{Raw: "Hello world"}}, + }, + MaxTokens: int64Ptr(100), + } + + // Mock the token counting (in real implementation, this would tokenize the message) + // For test purposes, assume "Hello world" = 2 tokens + tokens := simulator.calculateProcessingTokens(req) + + // Should be prompt tokens (2) + max tokens (100) = 102 + // Note: In real implementation, this depends on the actual tokenization + Expect(tokens).To(BeNumerically(">=", 100)) + }) + + It("should calculate tokens without max_tokens using max-model-len", func() { + req := &chatCompletionRequest{ + baseCompletionRequest: baseCompletionRequest{ + Model: "test-model", + }, + Messages: []message{ + {Role: "user", Content: content{Raw: "Hello world"}}, + }, + } + + tokens := simulator.calculateProcessingTokens(req) + + // Should be prompt tokens + (max-model-len - prompt tokens) + // which equals max-model-len = 1024 + Expect(tokens).To(Equal(1024)) + }) + }) + + Describe("canAcceptRequest", func() { + It("should accept request when within both constraints", func() { + simulator.config.MaxNumSeqs = 2 + simulator.config.MaxNumBatchedTokens = 2048 + + req := &chatCompletionRequest{ + baseCompletionRequest: baseCompletionRequest{ + Model: "test-model", + }, + Messages: []message{ + {Role: "user", Content: content{Raw: "Hello"}}, + }, + MaxTokens: int64Ptr(100), + } + + canAccept := simulator.canAcceptRequest(req) + Expect(canAccept).To(BeTrue()) + }) + + It("should reject request when max-num-seqs is exceeded", func() { + simulator.config.MaxNumSeqs = 1 + simulator.config.MaxNumBatchedTokens = 2048 + + // Simulate one request already running + simulator.nRunningReqs = 1 + + req := &chatCompletionRequest{ + baseCompletionRequest: baseCompletionRequest{ + Model: "test-model", + }, + Messages: []message{ + {Role: "user", Content: content{Raw: "Hello"}}, + }, + MaxTokens: int64Ptr(100), + } + + canAccept := simulator.canAcceptRequest(req) + Expect(canAccept).To(BeFalse()) + }) + + It("should reject request when max-num-batched-tokens would be exceeded", func() { + simulator.config.MaxNumSeqs = 5 + simulator.config.MaxNumBatchedTokens = 500 + + // Simulate tokens already being used + 
simulator.processingTokensCount = 400 + + req := &chatCompletionRequest{ + baseCompletionRequest: baseCompletionRequest{ + Model: "test-model", + }, + Messages: []message{ + {Role: "user", Content: content{Raw: "Hello"}}, + }, + MaxTokens: int64Ptr(200), // This would exceed the limit (400 + 200+ > 500) + } + + canAccept := simulator.canAcceptRequest(req) + Expect(canAccept).To(BeFalse()) + }) + + It("should ignore batched tokens constraint when MaxNumBatchedTokens is 0", func() { + simulator.config.MaxNumSeqs = 5 + simulator.config.MaxNumBatchedTokens = 0 // Disabled + + // Simulate a lot of tokens being used + simulator.processingTokensCount = 10000 + + req := &chatCompletionRequest{ + baseCompletionRequest: baseCompletionRequest{ + Model: "test-model", + }, + Messages: []message{ + {Role: "user", Content: content{Raw: "Hello"}}, + }, + MaxTokens: int64Ptr(200), + } + + canAccept := simulator.canAcceptRequest(req) + Expect(canAccept).To(BeTrue()) // Should only check max-num-seqs + }) + }) + + It("Should start with max-num-batched-tokens parameter", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-num-batched-tokens", "1024"} + client, err := startServerWithArgs(ctx, modeRandom, args) + Expect(err).NotTo(HaveOccurred()) + Expect(client).NotTo(BeNil()) + }) + + It("Should reject requests that exceed max-num-batched-tokens immediately", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-num-batched-tokens", "10"} + client, err := startServerWithArgs(ctx, modeRandom, args) + Expect(err).NotTo(HaveOccurred()) + Expect(client).NotTo(BeNil()) + + // Create a request that requires more than 10 tokens (4 prompt + 20 max_tokens = 24 tokens) + reqBody := `{ + "messages": [ + {"role": "user", "content": "Hello world test prompt"} + ], + "model": "my_model", + "max_tokens": 20 + }` + + resp, err := client.Post("http://localhost/v1/chat/completions", "application/json", strings.NewReader(reqBody)) + Expect(err).NotTo(HaveOccurred()) + defer func() { + err := resp.Body.Close() + Expect(err).NotTo(HaveOccurred()) + }() + + body, err := io.ReadAll(resp.Body) + Expect(err).NotTo(HaveOccurred()) + + Expect(resp.StatusCode).To(Equal(400)) + Expect(string(body)).To(ContainSubstring("Request requires")) + Expect(string(body)).To(ContainSubstring("max-num-batched-tokens is set to 10")) + Expect(string(body)).To(ContainSubstring("would never be accepted")) + }) + }) }) + +// Helper function to create int64 pointer +func int64Ptr(i int64) *int64 { + return &i +} From 9d2b6b6c6f8d57cc91b83b494736da54dce3d634 Mon Sep 17 00:00:00 2001 From: mohitpalsingh Date: Wed, 16 Jul 2025 14:32:13 +0530 Subject: [PATCH 2/2] feat: add MaxNumBatchedTokens to configuration and refactor request handling --- pkg/llm-d-inference-sim/config_test.go | 7 +--- pkg/llm-d-inference-sim/request.go | 2 +- pkg/llm-d-inference-sim/simulator.go | 57 +++++++------------------- 3 files changed, 16 insertions(+), 50 deletions(-) diff --git a/pkg/llm-d-inference-sim/config_test.go b/pkg/llm-d-inference-sim/config_test.go index 6e5d5442..60711027 100644 --- a/pkg/llm-d-inference-sim/config_test.go +++ b/pkg/llm-d-inference-sim/config_test.go @@ -53,6 +53,7 @@ func createDefaultConfig(model string) *configuration { c.MaxNumSeqs = 5 c.MaxLoras = 2 c.MaxCPULoras = 5 + c.MaxNumBatchedTokens = 2048 c.TimeToFirstToken = 2000 c.InterTokenLatency = 1000 c.KVCacheTransferLatency = 100 @@ -88,7 +89,6 @@ var _ = Describe("Simulator configuration", 
func() { c = createDefaultConfig(qwenModelName) c.Port = 8001 c.ServedModelNames = []string{"model1", "model2"} - c.MaxNumBatchedTokens = 2048 c.LoraModules = []loraModule{{Name: "lora1", Path: "/path/to/lora1"}, {Name: "lora2", Path: "/path/to/lora2"}} test = testCase{ name: "config file", @@ -106,7 +106,6 @@ var _ = Describe("Simulator configuration", func() { c.Port = 8002 c.ServedModelNames = []string{"alias1", "alias2"} c.Seed = 100 - c.MaxNumBatchedTokens = 2048 c.LoraModules = []loraModule{{Name: "lora3", Path: "/path/to/lora3"}, {Name: "lora4", Path: "/path/to/lora4"}} c.LoraModulesString = []string{ "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", @@ -125,7 +124,6 @@ var _ = Describe("Simulator configuration", func() { // Config from config.yaml file plus command line args with different format c = createDefaultConfig(model) c.Port = 8002 - c.MaxNumBatchedTokens = 2048 c.LoraModules = []loraModule{{Name: "lora3", Path: "/path/to/lora3"}} c.LoraModulesString = []string{ "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", @@ -143,7 +141,6 @@ var _ = Describe("Simulator configuration", func() { // Config from config.yaml file plus command line args with empty string c = createDefaultConfig(model) c.Port = 8002 - c.MaxNumBatchedTokens = 2048 c.LoraModules = []loraModule{{Name: "lora3", Path: "/path/to/lora3"}} c.LoraModulesString = []string{ "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", @@ -162,7 +159,6 @@ var _ = Describe("Simulator configuration", func() { c = createDefaultConfig(qwenModelName) c.Port = 8001 c.ServedModelNames = []string{"model1", "model2"} - c.MaxNumBatchedTokens = 2048 c.LoraModulesString = []string{} test = testCase{ name: "config file with command line args with empty string for loras", @@ -175,7 +171,6 @@ var _ = Describe("Simulator configuration", func() { c = createDefaultConfig(qwenModelName) c.Port = 8001 c.ServedModelNames = []string{"model1", "model2"} - c.MaxNumBatchedTokens = 2048 c.LoraModulesString = []string{} test = testCase{ name: "config file with command line args with empty parameter for loras", diff --git a/pkg/llm-d-inference-sim/request.go b/pkg/llm-d-inference-sim/request.go index 36e6f431..e9f30d19 100644 --- a/pkg/llm-d-inference-sim/request.go +++ b/pkg/llm-d-inference-sim/request.go @@ -105,7 +105,7 @@ type completionReqCtx struct { httpReqCtx *fasthttp.RequestCtx isChatCompletion bool wg *sync.WaitGroup - requestID string + processingTokens int } // chatCompletionRequest defines structure of /chat/completion request diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 017dd038..33c23052 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -60,13 +60,6 @@ const ( toolChoiceRequired = "required" ) -// runningRequest tracks token usage for a currently running request -type runningRequest struct { - promptTokens int - maxTokens int - totalTokens int -} - // VllmSimulator simulates vLLM server supporting OpenAI API type VllmSimulator struct { // logger is used for information and errors logging @@ -83,8 +76,6 @@ type VllmSimulator struct { nRunningReqs int64 // nWaitingReqs is the number of inference requests that are waiting to be processed nWaitingReqs int64 - // runningRequestsMap tracks token usage for currently running requests - runningRequestsMap sync.Map // processingTokensCount tracks the total number of tokens being processed by running requests processingTokensCount int64 // loraInfo is prometheus gauge @@ -394,23 +385,18 @@ func (s 
*VllmSimulator) isLora(model string) bool { } // calculateProcessingTokens calculates the total number of processing tokens for a request -// Returns prompt tokens + max output tokens +// Returns prompt tokens + max output tokens, or MaxModelLen if max_tokens is not specified func (s *VllmSimulator) calculateProcessingTokens(req completionRequest) int { promptTokens := req.getNumberOfPromptTokens() maxCompletionTokens := req.getMaxCompletionTokens() - // If max_tokens is not specified, calculate it as max-model-len - prompt-len - outputTokens := 0 - if maxCompletionTokens != nil { - outputTokens = int(*maxCompletionTokens) - } else { - outputTokens = s.config.MaxModelLen - promptTokens - if outputTokens < 0 { - outputTokens = 0 - } + // If max_tokens is not specified, return the maximum possible tokens (MaxModelLen) + if maxCompletionTokens == nil { + return s.config.MaxModelLen } - return promptTokens + outputTokens + // If max_tokens is specified, return prompt tokens + specified max completion tokens + return promptTokens + int(*maxCompletionTokens) } // canAcceptRequest checks if a new request can be accepted based on max-num-seqs and max-num-batched-tokens constraints @@ -436,27 +422,18 @@ func (s *VllmSimulator) canAcceptRequest(req completionRequest) bool { } // addRunningRequest adds a request to the running requests tracking -func (s *VllmSimulator) addRunningRequest(reqID string, req completionRequest) { - processingTokens := s.calculateProcessingTokens(req) +func (s *VllmSimulator) addRunningRequest(reqCtx *completionReqCtx) { + processingTokens := s.calculateProcessingTokens(reqCtx.completionReq) + reqCtx.processingTokens = processingTokens - runningReq := runningRequest{ - promptTokens: req.getNumberOfPromptTokens(), - maxTokens: processingTokens, - totalTokens: processingTokens, - } - - s.runningRequestsMap.Store(reqID, runningReq) atomic.AddInt64(&s.processingTokensCount, int64(processingTokens)) atomic.AddInt64(&s.nRunningReqs, 1) } // removeRunningRequest removes a request from the running requests tracking -func (s *VllmSimulator) removeRunningRequest(reqID string) { - if value, ok := s.runningRequestsMap.LoadAndDelete(reqID); ok { - runningReq := value.(runningRequest) - atomic.AddInt64(&s.processingTokensCount, -int64(runningReq.totalTokens)) - atomic.AddInt64(&s.nRunningReqs, -1) - } +func (s *VllmSimulator) removeRunningRequest(reqCtx *completionReqCtx) { + atomic.AddInt64(&s.processingTokensCount, -int64(reqCtx.processingTokens)) + atomic.AddInt64(&s.nRunningReqs, -1) } // handleCompletions general completion requests handler, support both text and chat completion APIs @@ -532,14 +509,8 @@ func (s *VllmSimulator) queueManager(ctx context.Context) { var newQueue []*completionReqCtx for _, reqCtx := range waitingQueue { if s.canAcceptRequest(reqCtx.completionReq) { - // Generate a unique ID for this request - reqID := uuid.New().String() - // Add to running requests tracking - s.addRunningRequest(reqID, reqCtx.completionReq) - - // Add the request ID to the context so workers can use it - reqCtx.requestID = reqID + s.addRunningRequest(reqCtx) // Send to processing channel s.processingChan <- reqCtx @@ -656,7 +627,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { } // Clean up the running request tracking - s.removeRunningRequest(reqCtx.requestID) + s.removeRunningRequest(reqCtx) reqCtx.wg.Done() }
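
For context on the overall flow, here is a minimal, self-contained sketch of the queue-manager pattern the patches introduce: new arrivals are parked in a slice, a 10 ms ticker retries admission, and admitted requests are handed off to the worker channel. The request type and the canAccept callback are illustrative placeholders, not the simulator's types.

package main

import (
	"context"
	"fmt"
	"time"
)

// request stands in for the simulator's completionReqCtx; tokens is the
// precomputed prompt + max-completion token estimate.
type request struct{ tokens int64 }

// queueManager parks arrivals until canAccept reports capacity, then hands
// them to the processing channel, mirroring the loop added in simulator.go.
func queueManager(ctx context.Context, in <-chan *request, out chan<- *request,
	canAccept func(*request) bool) {
	var waiting []*request
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case r := <-in:
			waiting = append(waiting, r) // park new arrivals
		case <-ticker.C:
			var still []*request
			for _, r := range waiting {
				if canAccept(r) {
					out <- r // capacity available: hand off to a worker
				} else {
					still = append(still, r) // retry on the next tick
				}
			}
			waiting = still
		}
	}
}

func main() {
	in := make(chan *request, 8)
	out := make(chan *request, 8)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go queueManager(ctx, in, out, func(r *request) bool { return r.tokens <= 100 })

	in <- &request{tokens: 42}
	fmt.Println("admitted request with", (<-out).tokens, "tokens")
}

One consequence of this design, exercised by the tests above, is that a request whose estimated token count can never fit under max-num-batched-tokens must be rejected up front with HTTP 400 rather than parked in the waiting queue indefinitely.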