Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions config/examples/generic_categories.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Example: Using generic categories with MMLU-Pro mapping
# This file demonstrates how to declare free-style categories and map them to
# MMLU-Pro categories expected by the classifier model.

bert_model:
model_id: sentence-transformers/all-MiniLM-L12-v2
threshold: 0.6
use_cpu: true

classifier:
category_model:
model_id: "models/category_classifier_modernbert-base_model"
use_modernbert: true
threshold: 0.6
use_cpu: true
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"

# Define your generic categories and map them to MMLU-Pro categories.
# The classifier will translate predicted MMLU categories into these generic names.
categories:
- name: tech
mmlu_categories: ["computer science", "engineering"]
model_scores:
- model: phi4
score: 0.9
- model: mistral-small3.1
score: 0.7
- name: finance
mmlu_categories: ["economics"]
model_scores:
- model: gemma3:27b
score: 0.8
- name: politics
# If mmlu_categories is omitted, identity mapping applies when this category's name matches an MMLU-Pro category
model_scores:
- model: gemma3:27b
score: 0.6

# A default model is recommended for fallback
default_model: mistral-small3.1
4 changes: 4 additions & 0 deletions src/semantic-router/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ type Category struct {
ReasoningDescription string `yaml:"reasoning_description,omitempty"`
ReasoningEffort string `yaml:"reasoning_effort,omitempty"` // Configurable reasoning effort level (low, medium, high)
ModelScores []ModelScore `yaml:"model_scores"`
// MMLUCategories optionally maps this generic category to one or more MMLU-Pro categories
// used by the classifier model. When provided, classifier outputs will be translated
// from these MMLU categories to this generic category name.
MMLUCategories []string `yaml:"mmlu_categories,omitempty"`
}

// Legacy types - can be removed once migration is complete
Expand Down
50 changes: 50 additions & 0 deletions src/semantic-router/pkg/config/mmlu_categories_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package config_test

import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

"gopkg.in/yaml.v3"

"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
)

// Spec: the mmlu_categories YAML key on a category entry is decoded into
// Category.MMLUCategories, and the field stays empty when the key is absent.
var _ = Describe("MMLU categories in config YAML", func() {
It("should unmarshal mmlu_categories into Category struct", func() {
// Fixture with three categories: two declare mmlu_categories explicitly,
// the third ("politics") omits the key entirely.
yamlContent := `
categories:
- name: "tech"
mmlu_categories: ["computer science", "engineering"]
model_scores:
- model: "phi4"
score: 0.9
use_reasoning: false
- name: "finance"
mmlu_categories: ["economics"]
model_scores:
- model: "gemma3:27b"
score: 0.8
use_reasoning: true
- name: "politics"
model_scores:
- model: "gemma3:27b"
score: 0.6
use_reasoning: false
`

// Decode straight into RouterConfig; any unmarshal error fails the spec.
var cfg config.RouterConfig
Expect(yaml.Unmarshal([]byte(yamlContent), &cfg)).To(Succeed())

// Every category entry in the fixture must be present after decoding.
Expect(cfg.Categories).To(HaveLen(3))

// First entry: both listed MMLU names are captured (ConsistOf ignores order),
// and the sibling model_scores list is still populated.
Expect(cfg.Categories[0].Name).To(Equal("tech"))
Expect(cfg.Categories[0].MMLUCategories).To(ConsistOf("computer science", "engineering"))
Expect(cfg.Categories[0].ModelScores).ToNot(BeEmpty())

// Second entry: a single-element mmlu_categories list.
Expect(cfg.Categories[1].Name).To(Equal("finance"))
Expect(cfg.Categories[1].MMLUCategories).To(ConsistOf("economics"))

// Third entry: no mmlu_categories key, so the field must be empty.
Expect(cfg.Categories[2].Name).To(Equal("politics"))
Expect(cfg.Categories[2].MMLUCategories).To(BeEmpty())
})
})
82 changes: 70 additions & 12 deletions src/semantic-router/pkg/utils/classification/classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ type Classifier struct {
CategoryMapping *CategoryMapping
PIIMapping *PIIMapping
JailbreakMapping *JailbreakMapping

// Category name mapping layer to support generic categories in config
// Maps MMLU-Pro category names -> generic category names (as defined in config.Categories)
MMLUToGeneric map[string]string
// Maps generic category names -> MMLU-Pro category names
GenericToMMLU map[string][]string
}

type option func(*Classifier)
Expand Down Expand Up @@ -272,6 +278,9 @@ func newClassifierWithOptions(cfg *config.RouterConfig, options ...option) (*Cla
option(classifier)
}

// Build category name mappings to support generic categories in config
classifier.buildCategoryNameMappings()

return initModels(classifier)
}

Expand Down Expand Up @@ -331,18 +340,21 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
return "", float64(result.Confidence), nil
}

// Convert class index to category name
// Convert class index to category name (MMLU-Pro)
categoryName, ok := c.CategoryMapping.GetCategoryFromIndex(result.Class)
if !ok {
observability.Warnf("Class index %d not found in category mapping", result.Class)
return "", float64(result.Confidence), nil
}

// Record the category classification metric
metrics.RecordCategoryClassification(categoryName)
// Translate to generic category if mapping is configured
genericCategory := c.translateMMLUToGeneric(categoryName)

observability.Infof("Classified as category: %s", categoryName)
return categoryName, float64(result.Confidence), nil
// Record the category classification metric using generic name when available
metrics.RecordCategoryClassification(genericCategory)

observability.Infof("Classified as category: %s (mmlu=%s)", genericCategory, categoryName)
return genericCategory, float64(result.Confidence), nil
}

// IsJailbreakEnabled checks if jailbreak detection is enabled and properly configured
Expand Down Expand Up @@ -485,11 +497,11 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64,
observability.Infof("Classification result: class=%d, confidence=%.4f, entropy_available=%t",
result.Class, result.Confidence, len(result.Probabilities) > 0)

// Get category names for all classes
// Get category names for all classes and translate to generic names when configured
categoryNames := make([]string, len(result.Probabilities))
for i := range result.Probabilities {
if name, ok := c.CategoryMapping.GetCategoryFromIndex(i); ok {
categoryNames[i] = name
categoryNames[i] = c.translateMMLUToGeneric(name)
} else {
categoryNames[i] = fmt.Sprintf("unknown_%d", i)
}
Expand Down Expand Up @@ -580,20 +592,21 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64,
return "", float64(result.Confidence), reasoningDecision, nil
}

// Convert class index to category name
// Convert class index to category name and translate to generic
categoryName, ok := c.CategoryMapping.GetCategoryFromIndex(result.Class)
if !ok {
observability.Warnf("Class index %d not found in category mapping", result.Class)
return "", float64(result.Confidence), reasoningDecision, nil
}
genericCategory := c.translateMMLUToGeneric(categoryName)

// Record the category classification metric
metrics.RecordCategoryClassification(categoryName)
metrics.RecordCategoryClassification(genericCategory)

observability.Infof("Classified as category: %s, reasoning_decision: use=%t, confidence=%.3f, reason=%s",
categoryName, reasoningDecision.UseReasoning, reasoningDecision.Confidence, reasoningDecision.DecisionReason)
observability.Infof("Classified as category: %s (mmlu=%s), reasoning_decision: use=%t, confidence=%.3f, reason=%s",
genericCategory, categoryName, reasoningDecision.UseReasoning, reasoningDecision.Confidence, reasoningDecision.DecisionReason)

return categoryName, float64(result.Confidence), reasoningDecision, nil
return genericCategory, float64(result.Confidence), reasoningDecision, nil
}

// ClassifyPII performs PII token classification on the given text and returns detected PII types
Expand Down Expand Up @@ -772,6 +785,51 @@ func (c *Classifier) findCategory(categoryName string) *config.Category {
return nil
}

// buildCategoryNameMappings (re)creates the two lookup tables that connect the
// labels emitted by the category model with the category names declared in the
// router configuration:
//
//   - MMLUToGeneric: lower-cased model label -> configured category name
//   - GenericToMMLU: configured category name -> labels it was declared with
//
// A category that lists MMLUCategories is mapped from each listed label. A
// category without such a list is mapped to itself, but only when its
// lower-cased name is one of the labels found in CategoryMapping.IdxToCategory.
// When two categories claim the same label, the one appearing later in the
// configuration wins.
func (c *Classifier) buildCategoryNameMappings() {
	c.MMLUToGeneric = map[string]string{}
	c.GenericToMMLU = map[string][]string{}

	// Collect the lower-cased labels the model can emit; this stays empty when
	// no category mapping has been attached to the classifier.
	modelLabels := map[string]bool{}
	if c.CategoryMapping != nil {
		for _, label := range c.CategoryMapping.IdxToCategory {
			modelLabels[strings.ToLower(label)] = true
		}
	}

	for _, category := range c.Config.Categories {
		generic := category.Name

		if len(category.MMLUCategories) == 0 {
			// Identity fallback: only applies when the configured name is
			// itself a label known to the model.
			lowered := strings.ToLower(generic)
			if !modelLabels[lowered] {
				continue
			}
			c.MMLUToGeneric[lowered] = generic
			c.GenericToMMLU[generic] = append(c.GenericToMMLU[generic], generic)
			continue
		}

		// Explicit mapping: keys are lower-cased for lookup, while the reverse
		// table keeps the labels exactly as written in the configuration.
		for _, label := range category.MMLUCategories {
			c.MMLUToGeneric[strings.ToLower(label)] = generic
			c.GenericToMMLU[generic] = append(c.GenericToMMLU[generic], label)
		}
	}
}

// translateMMLUToGeneric returns the configured category name registered for
// the given model label. The lookup is case-insensitive because the label is
// lower-cased before it is used as the MMLUToGeneric key. The input is returned
// unchanged when it is empty, when no table has been built, or when the label
// has no entry.
func (c *Classifier) translateMMLUToGeneric(mmluCategory string) string {
	if mmluCategory != "" {
		// Indexing a nil map is safe in Go and simply reports a miss, so the
		// "table not built yet" case needs no separate branch.
		if mapped, found := c.MMLUToGeneric[strings.ToLower(mmluCategory)]; found {
			return mapped
		}
	}
	return mmluCategory
}

// selectBestModelInternal performs the core model selection logic
//
// modelFilter is optional - if provided, only models passing the filter will be considered
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package classification

import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

candle_binding "github.com/vllm-project/semantic-router/candle-binding"
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
)

// Specs for the MMLU-Pro -> generic category translation layer: they check the
// lookup tables built at construction time and that both classification entry
// points return the generic (configured) category name.
var _ = Describe("generic category mapping (MMLU-Pro -> generic)", func() {
// Shared fixtures, rebuilt by BeforeEach before every spec.
var (
classifier *Classifier
mockCategoryInitializer *MockCategoryInitializer
mockCategoryModel *MockCategoryInference
)

BeforeEach(func() {
// Mock initializer configured with no error, and a mock model whose
// results are set per spec.
mockCategoryInitializer = &MockCategoryInitializer{InitError: nil}
mockCategoryModel = &MockCategoryInference{}

// Placeholder model id/path: the mocks stand in for the real model.
cfg := &config.RouterConfig{}
cfg.Classifier.CategoryModel.ModelID = "model-id"
cfg.Classifier.CategoryModel.CategoryMappingPath = "category-mapping-path"
cfg.Classifier.CategoryModel.Threshold = 0.5

// Define generic categories with MMLU-Pro mappings
cfg.Categories = []config.Category{
{
Name: "tech",
MMLUCategories: []string{"computer science", "engineering"},
ModelScores: []config.ModelScore{{Model: "phi4", Score: 0.9, UseReasoning: config.BoolPtr(false)}},
ReasoningEffort: "low",
},
{
Name: "finance",
MMLUCategories: []string{"economics"},
ModelScores: []config.ModelScore{{Model: "gemma3:27b", Score: 0.8, UseReasoning: config.BoolPtr(true)}},
},
{
Name: "politics",
// No explicit mmlu_categories -> identity fallback when label exists in mapping
ModelScores: []config.ModelScore{{Model: "gemma3:27b", Score: 0.6, UseReasoning: config.BoolPtr(false)}},
},
}

// Category mapping represents labels coming from the MMLU-Pro model
categoryMapping := &CategoryMapping{
CategoryToIdx: map[string]int{
"computer science": 0,
"economics": 1,
"politics": 2,
},
IdxToCategory: map[string]string{
"0": "Computer Science", // different case to assert case-insensitive mapping
"1": "economics",
"2": "politics",
},
}

// Build the classifier under test; construction must succeed for every spec.
// NOTE(review): withCategory is defined outside this view — presumably it
// injects the mapping and the mocks into the classifier; confirm.
var err error
classifier, err = newClassifierWithOptions(
cfg,
withCategory(categoryMapping, mockCategoryInitializer, mockCategoryModel),
)
Expect(err).ToNot(HaveOccurred())
})

// Checks the tables produced at construction time, in both directions.
It("builds expected MMLU<->generic maps", func() {
// Forward table: keys are the lower-cased MMLU labels.
Expect(classifier.MMLUToGeneric).To(HaveKeyWithValue("computer science", "tech"))
Expect(classifier.MMLUToGeneric).To(HaveKeyWithValue("engineering", "tech"))
Expect(classifier.MMLUToGeneric).To(HaveKeyWithValue("economics", "finance"))
// identity fallback for a generic name that exists as an MMLU label
Expect(classifier.MMLUToGeneric).To(HaveKeyWithValue("politics", "politics"))

// Reverse table: each generic name lists the MMLU labels mapped to it.
Expect(classifier.GenericToMMLU).To(HaveKey("tech"))
Expect(classifier.GenericToMMLU["tech"]).To(ConsistOf("computer science", "engineering"))
Expect(classifier.GenericToMMLU).To(HaveKeyWithValue("finance", ConsistOf("economics")))
Expect(classifier.GenericToMMLU).To(HaveKeyWithValue("politics", ConsistOf("politics")))
})

// Plain classification path: the returned name is the generic one.
It("translates ClassifyCategory result to generic category", func() {
// Model returns class index 0 -> "Computer Science" (MMLU) which maps to generic "tech"
mockCategoryModel.classifyResult = candle_binding.ClassResult{Class: 0, Confidence: 0.92}

category, score, err := classifier.ClassifyCategory("This text is about GPUs and compilers")
Expect(err).ToNot(HaveOccurred())
Expect(category).To(Equal("tech"))
// Confidence is passed through from the mock result, within float tolerance.
Expect(score).To(BeNumerically("~", 0.92, 0.001))
})

// Entropy path: both the returned category and the top-ranked entry of the
// reasoning decision must carry the generic name.
It("translates names in entropy flow and returns generic top category", func() {
// Probabilities favor index 0 -> generic should be "tech"
mockCategoryModel.classifyWithProbsResult = candle_binding.ClassResultWithProbs{
Class: 0,
Confidence: 0.88,
Probabilities: []float32{0.7, 0.2, 0.1},
NumClasses: 3,
}

category, confidence, decision, err := classifier.ClassifyCategoryWithEntropy("Economic policies in computer science education")
Expect(err).ToNot(HaveOccurred())
Expect(category).To(Equal("tech"))
Expect(confidence).To(BeNumerically("~", 0.88, 0.001))
Expect(decision.TopCategories).ToNot(BeEmpty())
Expect(decision.TopCategories[0].Category).To(Equal("tech"))
})

// A category without explicit mmlu_categories keeps its own name.
It("falls back to identity when no mapping exists for an MMLU label", func() {
// index 2 -> "politics" (no explicit mapping provided, but present in MMLU set)
mockCategoryModel.classifyResult = candle_binding.ClassResult{Class: 2, Confidence: 0.91}

category, score, err := classifier.ClassifyCategory("This is a political debate")
Expect(err).ToNot(HaveOccurred())
Expect(category).To(Equal("politics"))
Expect(score).To(BeNumerically("~", 0.91, 0.001))
})
})
Loading
Loading