diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index b83158f2..97a51a1c 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -48,4 +48,4 @@ jobs: push: ${{ github.event_name != 'pull_request' }} # Only push on merge to main, not on PRs tags: | ${{ inputs.is_nightly == true && format('ghcr.io/{0}/semantic-router/extproc:nightly-{1}', github.repository_owner, steps.date.outputs.date_tag) || format('ghcr.io/{0}/semantic-router/extproc:{1}', github.repository_owner, github.sha) }} - ${{ inputs.is_nightly != true && format('ghcr.io/{0}/semantic-router/extproc:latest', github.repository_owner) || '' }} \ No newline at end of file + ${{ inputs.is_nightly != true && format('ghcr.io/{0}/semantic-router/extproc:latest', github.repository_owner) || '' }} diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 50591853..be401987 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -45,7 +45,7 @@ jobs: sudo apt-get install -y \ make \ build-essential \ - pkg-config + pkg-config npm install -g markdownlint-cli pip install --user yamllint codespell @@ -81,31 +81,13 @@ jobs: key: ${{ runner.os }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} - name: Install pre-commit - run: pip install pre-commit + run: make precommit-install - name: Run Code Spell Check run: make codespell - - name: Run pre-commit on Go, Rust, JavaScript, Markdown, Yaml and Python files - run: | - # Find all Go, Rust, JavaScripts, Markdown and Python files (excluding vendored/generated code) - FILES=$(find . -type f \( -name "*.go" -o -name "*.rs" -o -name "*.py" -o -name "*.js" -o -name "*.md" -o -name "*.yaml" -o -name "*.yml" \) \ - ! -path "./target/*" \ - ! -path "./candle-binding/target/*" \ - ! -path "./.git/*" \ - ! -path "./node_modules/*" \ - ! -path "./vendor/*" \ - ! -path "./__pycache__/*" \ - ! -path "./site/*" \ - ! -name "*.pb.go" \ - | tr '\n' ' ') - - if [ -n "$FILES" ]; then - echo "Running pre-commit on files: $FILES" - pre-commit run --files $FILES - else - echo "No Go, Rust, JavaScript, Markdown, Yaml, or Python files found to check" - fi + - name: Run pre-commit check + run: make precommit-check - name: Show pre-commit results if: failure() diff --git a/.github/workflows/precommit-publish.yml b/.github/workflows/precommit-publish.yml new file mode 100644 index 00000000..985112b3 --- /dev/null +++ b/.github/workflows/precommit-publish.yml @@ -0,0 +1,40 @@ +name: Create and publish Precommit Image + +on: + push: + branches: [ "main" ] + pull_request: + paths: + - 'Dockerfile.precommit' + +jobs: + build_and_push: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - name: Check out the repo + uses: actions/checkout@v4 + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Generate date tag for nightly builds + id: date + if: inputs.is_nightly == true + run: echo "date_tag=$(date +'%Y%m%d')" >> $GITHUB_OUTPUT + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + context: . + file: ./Dockerfile.precommit + push: ${{ github.event_name != 'pull_request' }} # Only push on merge to main, not on PRs + tags: | + ${{ inputs.is_nightly != true && format('ghcr.io/{0}/semantic-router/precommit:latest', github.repository_owner) || '' }} diff --git a/Dockerfile.extproc b/Dockerfile.extproc index f06d05e7..1ba8b45e 100644 --- a/Dockerfile.extproc +++ b/Dockerfile.extproc @@ -15,7 +15,6 @@ COPY tools/make/ tools/make/ COPY Makefile ./ COPY candle-binding/Cargo.toml candle-binding/ COPY candle-binding/src/ candle-binding/src/ -Copy tools ./tools # Use Makefile to build the Rust library RUN make rust diff --git a/Dockerfile.precommit b/Dockerfile.precommit new file mode 100644 index 00000000..ec6ead26 --- /dev/null +++ b/Dockerfile.precommit @@ -0,0 +1,29 @@ +FROM golang:1.24 + +# Install Base env +RUN apt-get update && apt-get install -y \ + make \ + build-essential \ + pkg-config \ + python3 \ + python3-pip + +# Install Node.js and npm +RUN curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \ + apt-get install -y nodejs + +# Install Rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \ + . $HOME/.cargo/env + +# Markdown +RUN npm install -g markdownlint-cli + +# Install pre-commit and tools +RUN pip install --break-system-packages pre-commit + +# Yamllint +RUN pip install --break-system-packages yamllint + +# CodeSpell +RUN pip install --break-system-packages codespell diff --git a/Makefile b/Makefile index 729ef207..257156d1 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,7 @@ _run: -f tools/make/linter.mk \ -f tools/make/milvus.mk \ -f tools/make/models.mk \ + -f tools/make/pre-commit.mk \ $(MAKECMDGOALS) .PHONY: _run diff --git a/bench/vllm_semantic_router_bench/router_reason_bench_multi_dataset.py b/bench/vllm_semantic_router_bench/router_reason_bench_multi_dataset.py index 270fe8ea..5a963ff0 100644 --- a/bench/vllm_semantic_router_bench/router_reason_bench_multi_dataset.py +++ b/bench/vllm_semantic_router_bench/router_reason_bench_multi_dataset.py @@ -178,35 +178,62 @@ def parse_args(): return parser.parse_args() -def get_dataset_optimal_tokens(dataset_info): +def get_dataset_optimal_tokens(dataset_info, model_name=None): """ - Determine optimal token limit based on dataset complexity and reasoning requirements. + Determine optimal token limit based on dataset complexity, reasoning requirements, and model capabilities. Token limits are optimized for structured response generation while maintaining - efficiency across different reasoning complexity levels. + efficiency across different reasoning complexity levels and model architectures. + + Args: + dataset_info: Dataset information object + model_name: Model identifier (e.g., "openai/gpt-oss-20b", "Qwen/Qwen3-30B-A3B") """ dataset_name = dataset_info.name.lower() difficulty = dataset_info.difficulty_level.lower() - # Optimized token limits per dataset (increased for reasoning mode support) - dataset_tokens = { - "gpqa": 1500, # Graduate-level scientific reasoning + # Determine model type and capabilities + model_multiplier = 1.0 + if model_name: + model_lower = model_name.lower() + if "qwen" in model_lower: + # Qwen models are more efficient and can handle longer contexts + model_multiplier = 1.5 + elif "deepseek" in model_lower: + # DeepSeek models (e.g., V3.1) are capable and can handle longer contexts + model_multiplier = 1.5 + elif "gpt-oss" in model_lower: + # GPT-OSS models use baseline token limits + model_multiplier = 1.0 + # Default to baseline for unknown models + + # Base token limits per dataset (optimized for gpt-oss20b baseline) + base_dataset_tokens = { + "gpqa": 3000, # Graduate-level scientific reasoning (increased for complex multi-step reasoning) "truthfulqa": 800, # Misconception analysis "hellaswag": 800, # Natural continuation reasoning "arc": 800, # Elementary/middle school science "commonsenseqa": 1000, # Common sense reasoning - "mmlu": 600 if difficulty == "undergraduate" else 800, # Academic knowledge + "mmlu": 3000, # Academic knowledge (increased for complex technical domains like engineering/chemistry) } - # Find matching dataset - for dataset_key, tokens in dataset_tokens.items(): + # Find matching dataset and apply model multiplier + base_tokens = None + for dataset_key, tokens in base_dataset_tokens.items(): if dataset_key in dataset_name: - return tokens + base_tokens = tokens + break + + # Fallback to difficulty-based tokens if dataset not found + if base_tokens is None: + difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150} + base_tokens = difficulty_tokens.get(difficulty, 200) - # Default based on difficulty level - difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150} + # Apply model-specific multiplier and round to nearest 50 + final_tokens = int(base_tokens * model_multiplier) + final_tokens = ((final_tokens + 25) // 50) * 50 # Round to nearest 50 - return difficulty_tokens.get(difficulty, 200) + return final_tokens def get_available_models(endpoint: str, api_key: str = "") -> List[str]: @@ -507,6 +534,20 @@ def evaluate_model_vllm_multimode( q.cot_content is not None and q.cot_content.strip() for q in questions[:10] ) + # Debug: Show CoT content status for first few questions + print(f" CoT Debug - Checking first 10 questions:") + for i, q in enumerate(questions[:10]): + cot_status = ( + "None" + if q.cot_content is None + else ( + f"'{q.cot_content[:50]}...'" + if len(q.cot_content) > 50 + else f"'{q.cot_content}'" + ) + ) + print(f" Q{i+1}: CoT = {cot_status}") + if has_cot_content: print(f" Dataset has CoT content - using 3 modes: NR, XC, NR_REASONING") else: @@ -827,20 +868,23 @@ def main(): print(f"Router models: {router_models}") print(f"vLLM models: {vllm_models}") - # Determine optimal token limit for this dataset - if args.max_tokens: - optimal_tokens = args.max_tokens - print(f"Using user-specified max_tokens: {optimal_tokens}") - else: - optimal_tokens = get_dataset_optimal_tokens(dataset_info) - print( - f"Using dataset-optimal max_tokens: {optimal_tokens} (for {dataset_info.name})" - ) + # Function to get optimal tokens for a specific model + # For fair comparison, use consistent token limits regardless of model name + def get_model_optimal_tokens(model_name): + if args.max_tokens: + return args.max_tokens + else: + # Use base dataset tokens without model-specific multipliers for fair comparison + return get_dataset_optimal_tokens(dataset_info, model_name=None) # Router evaluation (NR-only) if args.run_router and router_endpoint and router_models: for model in router_models: + model_tokens = get_model_optimal_tokens(model) print(f"\nEvaluating router model: {model}") + print( + f"Using max_tokens: {model_tokens} (dataset-optimized for fair comparison)" + ) rt_df = evaluate_model_router_transparent( questions=questions, dataset=dataset, @@ -848,7 +892,7 @@ def main(): endpoint=router_endpoint, api_key=router_api_key, concurrent_requests=args.concurrent_requests, - max_tokens=optimal_tokens, + max_tokens=model_tokens, temperature=args.temperature, ) analysis = analyze_results(rt_df) @@ -863,7 +907,11 @@ def main(): # Direct vLLM evaluation (NR/XC with reasoning ON/OFF) if args.run_vllm and vllm_endpoint and vllm_models: for model in vllm_models: + model_tokens = get_model_optimal_tokens(model) print(f"\nEvaluating vLLM model: {model}") + print( + f"Using max_tokens: {model_tokens} (dataset-optimized for fair comparison)" + ) vdf = evaluate_model_vllm_multimode( questions=questions, dataset=dataset, @@ -871,7 +919,7 @@ def main(): endpoint=vllm_endpoint, api_key=vllm_api_key, concurrent_requests=args.concurrent_requests, - max_tokens=optimal_tokens, + max_tokens=model_tokens, temperature=args.temperature, exec_modes=args.vllm_exec_modes, ) diff --git a/config/config.yaml b/config/config.yaml index faabb985..2b44e57d 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -8,6 +8,7 @@ semantic_cache: similarity_threshold: 0.8 max_entries: 1000 # Only applies to memory backend ttl_seconds: 3600 + eviction_policy: "fifo" # "fifo", "lru", "lfu", currently only supports memory backend # For production environments, use Milvus for scalable caching: # backend_type: "milvus" @@ -46,14 +47,14 @@ vllm_endpoints: - "phi4" - "gemma3:27b" weight: 1 # Load balancing weight - health_check_path: "/health" # Optional health check endpoint + health_check_path: "/api/version" # Optional health check endpoint - name: "endpoint2" address: "127.0.0.1" port: 11434 models: - "mistral-small3.1" weight: 1 - health_check_path: "/health" + health_check_path: "/api/version" - name: "endpoint3" address: "127.0.0.1" port: 11434 diff --git a/deploy/kubernetes/crds/vllm.ai_semanticroutes.yaml b/deploy/kubernetes/crds/vllm.ai_semanticroutes.yaml new file mode 100644 index 00000000..c943e699 --- /dev/null +++ b/deploy/kubernetes/crds/vllm.ai_semanticroutes.yaml @@ -0,0 +1,293 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.19.0 + name: semanticroutes.vllm.ai +spec: + group: vllm.ai + names: + kind: SemanticRoute + listKind: SemanticRouteList + plural: semanticroutes + shortNames: + - sr + singular: semanticroute + scope: Namespaced + versions: + - additionalPrinterColumns: + - description: Number of routing rules + jsonPath: .spec.rules + name: Rules + type: integer + - jsonPath: .metadata.creationTimestamp + name: Age + type: date + name: v1alpha1 + schema: + openAPIV3Schema: + description: SemanticRoute defines a semantic routing rule for LLM requests + properties: + apiVersion: + description: |- + APIVersion defines the versioned schema of this representation of an object. + Servers should convert recognized schemas to the latest internal value, and + may reject unrecognized values. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources + type: string + kind: + description: |- + Kind is a string value representing the REST resource this object represents. + Servers may infer this from the endpoint the client submits requests to. + Cannot be updated. + In CamelCase. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + type: string + metadata: + type: object + spec: + description: SemanticRouteSpec defines the desired state of SemanticRoute + properties: + rules: + description: Rules defines the routing rules to be applied + items: + description: RouteRule defines a single routing rule + properties: + defaultModel: + description: DefaultModel defines the fallback model if no modelRefs + are available + properties: + address: + description: Address defines the endpoint address + maxLength: 255 + minLength: 1 + type: string + modelName: + description: ModelName defines the name of the model + maxLength: 100 + minLength: 1 + type: string + port: + description: Port defines the endpoint port + format: int32 + maximum: 65535 + minimum: 1 + type: integer + priority: + description: Priority defines the priority of this model + reference (higher values = higher priority) + format: int32 + maximum: 1000 + minimum: 0 + type: integer + weight: + default: 100 + description: Weight defines the traffic weight for this + model (0-100) + format: int32 + maximum: 100 + minimum: 0 + type: integer + required: + - address + - modelName + - port + type: object + filters: + description: Filters defines the optional filters to be applied + to requests matching this rule + items: + description: Filter defines a filter to be applied to requests + properties: + config: + description: Config defines the filter-specific configuration + type: object + x-kubernetes-preserve-unknown-fields: true + enabled: + default: true + description: Enabled defines whether this filter is enabled + type: boolean + type: + allOf: + - enum: + - PIIDetection + - PromptGuard + - SemanticCache + - ReasoningControl + - ToolSelection + - enum: + - PIIDetection + - PromptGuard + - SemanticCache + - ReasoningControl + description: Type defines the filter type + type: string + required: + - type + type: object + maxItems: 20 + type: array + intents: + description: Intents defines the intent categories that this + rule should match + items: + description: Intent defines an intent category for routing + properties: + category: + description: Category defines the intent category name + (e.g., "math", "computer science", "creative") + maxLength: 100 + minLength: 1 + pattern: ^[a-zA-Z0-9\s\-_]+$ + type: string + description: + description: Description provides an optional description + of this intent category + maxLength: 500 + type: string + threshold: + default: 0.7 + description: Threshold defines the confidence threshold + for this intent (0.0-1.0) + maximum: 1 + minimum: 0 + type: number + required: + - category + type: object + maxItems: 50 + minItems: 1 + type: array + modelRefs: + description: ModelRefs defines the target models for this routing + rule + items: + description: ModelRef defines a reference to a model endpoint + properties: + address: + description: Address defines the endpoint address + maxLength: 255 + minLength: 1 + type: string + modelName: + description: ModelName defines the name of the model + maxLength: 100 + minLength: 1 + type: string + port: + description: Port defines the endpoint port + format: int32 + maximum: 65535 + minimum: 1 + type: integer + priority: + description: Priority defines the priority of this model + reference (higher values = higher priority) + format: int32 + maximum: 1000 + minimum: 0 + type: integer + weight: + default: 100 + description: Weight defines the traffic weight for this + model (0-100) + format: int32 + maximum: 100 + minimum: 0 + type: integer + required: + - address + - modelName + - port + type: object + maxItems: 10 + minItems: 1 + type: array + required: + - intents + - modelRefs + type: object + maxItems: 100 + minItems: 1 + type: array + required: + - rules + type: object + status: + description: SemanticRouteStatus defines the observed state of SemanticRoute + properties: + activeRules: + description: ActiveRules indicates the number of currently active + routing rules + format: int32 + type: integer + conditions: + description: Conditions represent the latest available observations + of the SemanticRoute's current state + items: + description: Condition contains details for one aspect of the current + state of this API Resource. + properties: + lastTransitionTime: + description: |- + lastTransitionTime is the last time the condition transitioned from one status to another. + This should be when the underlying condition changed. If that is not known, then using the time when the API field changed is acceptable. + format: date-time + type: string + message: + description: |- + message is a human readable message indicating details about the transition. + This may be an empty string. + maxLength: 32768 + type: string + observedGeneration: + description: |- + observedGeneration represents the .metadata.generation that the condition was set based upon. + For instance, if .metadata.generation is currently 12, but the .status.conditions[x].observedGeneration is 9, the condition is out of date + with respect to the current state of the instance. + format: int64 + minimum: 0 + type: integer + reason: + description: |- + reason contains a programmatic identifier indicating the reason for the condition's last transition. + Producers of specific condition types may define expected values and meanings for this field, + and whether the values are considered a guaranteed API. + The value should be a CamelCase string. + This field may not be empty. + maxLength: 1024 + minLength: 1 + pattern: ^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$ + type: string + status: + description: status of the condition, one of True, False, Unknown. + enum: + - "True" + - "False" + - Unknown + type: string + type: + description: type of condition in CamelCase or in foo.example.com/CamelCase. + maxLength: 316 + pattern: ^([a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*/)?(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])$ + type: string + required: + - lastTransitionTime + - message + - reason + - status + - type + type: object + type: array + observedGeneration: + description: ObservedGeneration reflects the generation of the most + recently observed SemanticRoute + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/deploy/llm-router-dashboard.json b/deploy/llm-router-dashboard.json index a23a21e6..44bfb9a4 100644 --- a/deploy/llm-router-dashboard.json +++ b/deploy/llm-router-dashboard.json @@ -94,7 +94,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "sum by(category) (llm_category_classifications_total)", + "expr": "sum by(category) (llm_category_classifications_count)", "fullMetaSearch": false, "includeNullMetadata": true, "instant": false, @@ -440,4 +440,4 @@ "uid": "llm-router-metrics", "version": 12, "weekStart": "" -} \ No newline at end of file +} diff --git a/docker/README.md b/docker/README.md index cc868f9b..22b6ec95 100644 --- a/docker/README.md +++ b/docker/README.md @@ -12,7 +12,7 @@ This Docker Compose configuration allows you to quickly run Semantic Router + En 1. **Clone the repository and navigate to the project directory** ```bash - git clone + git clone https://github.com/vllm-project/semantic-router.git cd semantic_router ``` diff --git a/e2e-tests/00-client-request-test.py b/e2e-tests/00-client-request-test.py index bd33b788..7073b788 100644 --- a/e2e-tests/00-client-request-test.py +++ b/e2e-tests/00-client-request-test.py @@ -4,6 +4,8 @@ This test validates that the Envoy proxy is running and accepting requests, and that basic request formatting works correctly. + +Signed-off-by: Yossi Ovadia """ import json @@ -14,14 +16,13 @@ import requests -# Add parent directory to path to allow importing common test utilities -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from tests.test_base import SemanticRouterTestBase +# Import test base from same directory +from test_base import SemanticRouterTestBase # Constants ENVOY_URL = "http://localhost:8801" OPENAI_ENDPOINT = "/v1/chat/completions" -DEFAULT_MODEL = "qwen2.5:32b" # Changed to match other tests +DEFAULT_MODEL = "gemma3:27b" # Use configured model MAX_RETRIES = 3 RETRY_DELAY = 2 @@ -177,14 +178,14 @@ def test_malformed_request(self): try: response = self._make_request( - payload, timeout=10 + payload, timeout=30 ) # Reduced timeout for error cases if not response: response = requests.post( f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers={"Content-Type": "application/json"}, json=payload, - timeout=10, + timeout=30, ) passed = 400 <= response.status_code < 500 diff --git a/e2e-tests/01-envoy-extproc-test.py b/e2e-tests/01-envoy-extproc-test.py index 34e6f472..0f910ac0 100644 --- a/e2e-tests/01-envoy-extproc-test.py +++ b/e2e-tests/01-envoy-extproc-test.py @@ -5,23 +5,25 @@ This test verifies that Envoy is correctly forwarding requests to the ExtProc, and that the ExtProc is responding with appropriate routing decisions. These tests use custom headers to trace request processing. + +Signed-off-by: Yossi Ovadia """ import json import os import sys +import unittest import uuid import requests -# Add parent directory to path to allow importing common test utilities -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from tests.test_base import SemanticRouterTestBase +# Import test base from same directory +from test_base import SemanticRouterTestBase # Constants ENVOY_URL = "http://localhost:8801" OPENAI_ENDPOINT = "/v1/chat/completions" -DEFAULT_MODEL = "qwen2.5:32b" # Changed from gemma3:27b to match make test-prompt +DEFAULT_MODEL = "gemma3:27b" # Use configured model class EnvoyExtProcTest(SemanticRouterTestBase): diff --git a/e2e-tests/02-router-classification-test.py b/e2e-tests/02-router-classification-test.py index 040a522c..eac8eca1 100644 --- a/e2e-tests/02-router-classification-test.py +++ b/e2e-tests/02-router-classification-test.py @@ -4,25 +4,27 @@ This test validates the router's ability to classify different types of queries and select the appropriate model based on the content. + +Signed-off-by: Yossi Ovadia """ import json import os import sys import time +import unittest from collections import defaultdict import requests -# Add parent directory to path to allow importing common test utilities -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from tests.test_base import SemanticRouterTestBase +# Import test base from same directory +from test_base import SemanticRouterTestBase # Constants ENVOY_URL = "http://localhost:8801" OPENAI_ENDPOINT = "/v1/chat/completions" ROUTER_METRICS_URL = "http://localhost:9190/metrics" -DEFAULT_MODEL = "qwen2.5:32b" # Changed from gemma3:27b to match make test-prompt +DEFAULT_MODEL = "gemma3:27b" # Use configured model # Category test cases - each designed to trigger a specific classifier category CATEGORY_TEST_CASES = [ @@ -38,6 +40,20 @@ }, ] # Reduced to just 2 test cases to avoid timeouts +# Auto routing test cases - should trigger different model selection +AUTO_ROUTING_TEST_CASES = [ + { + "name": "Math Problem (should route to phi4)", + "expected_model": "phi4", + "content": "Calculate the derivative of f(x) = x^3 + 2x^2 - 5x + 7", + }, + { + "name": "Creative Writing (should route to another model)", + "expected_model_not": "phi4", # Should NOT be phi4 since phi4 is optimized for math + "content": "Write a poem about the ocean at sunset", + }, +] + class RouterClassificationTest(SemanticRouterTestBase): """Test the router's classification functionality.""" @@ -68,7 +84,7 @@ def setUp(self): f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers={"Content-Type": "application/json"}, json=payload, - timeout=60, + timeout=(10, 60), # (connect timeout, read timeout) ) if response.status_code >= 500: @@ -129,7 +145,7 @@ def test_classification_consistency(self): f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers={"Content-Type": "application/json"}, json=payload, - timeout=10, + timeout=(10, 60), # (connect timeout, read timeout) ) passed = response.status_code < 400 @@ -185,7 +201,7 @@ def test_category_classification(self): f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers={"Content-Type": "application/json"}, json=payload, - timeout=60, + timeout=(10, 60), # (connect timeout, read timeout) ) passed = response.status_code < 400 @@ -267,6 +283,105 @@ def test_classifier_metrics(self): self.assertGreaterEqual(metrics_found, 0, "No classification metrics found") + def test_auto_routing_intelligence(self): + """Test that auto model selection actually routes different queries to different specialized models.""" + self.print_test_header( + "Auto Routing Intelligence Test", + "Verifies that model='auto' actually routes different query types to different specialized models", + ) + + results = {} + + for test_case in AUTO_ROUTING_TEST_CASES: + self.print_subtest_header(test_case["name"]) + + payload = { + "model": "auto", # This should trigger intelligent routing + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": test_case["content"]}, + ], + "temperature": 0.7, + } + + self.print_request_info( + payload=payload, + expectations=f"Expect: Auto routing to select appropriate specialized model", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=30, + ) + + passed = response.status_code == 200 + + try: + response_json = response.json() + selected_model = response_json.get("model", "unknown") + except: + selected_model = "unknown" + + results[test_case["name"]] = selected_model + + # Check if routing met expectations + if "expected_model" in test_case: + routing_correct = selected_model == test_case["expected_model"] + routing_message = f"Expected {test_case['expected_model']}, got {selected_model}" + elif "expected_model_not" in test_case: + routing_correct = selected_model != test_case["expected_model_not"] + routing_message = f"Should NOT be {test_case['expected_model_not']}, got {selected_model}" + else: + routing_correct = True + routing_message = f"Got {selected_model}" + + self.print_response_info( + response, + { + "Query Type": test_case["name"], + "Selected Model": selected_model, + "Routing Expectation": routing_message, + "Routing Correct": routing_correct, + }, + ) + + self.print_test_result( + passed=passed and routing_correct, + message=( + f"Auto routing working: {selected_model} for {test_case['name']}" + if passed and routing_correct + else f"Auto routing failed: {routing_message}" + ), + ) + + self.assertEqual(response.status_code, 200, f"Auto routing request failed with status {response.status_code}") + + # Check routing intelligence + if "expected_model" in test_case: + self.assertEqual( + selected_model, + test_case["expected_model"], + f"Auto routing failed: expected {test_case['expected_model']}, got {selected_model}" + ) + elif "expected_model_not" in test_case: + self.assertNotEqual( + selected_model, + test_case["expected_model_not"], + f"Auto routing failed: got {selected_model}, should not be {test_case['expected_model_not']}" + ) + + # Print summary of routing decisions + print(f"\nAuto Routing Summary:") + for test_name, model in results.items(): + print(f" {test_name}: {model}") + + # Ensure we got different models for different query types (intelligence test) + unique_models = set(results.values()) + if len(unique_models) == 1: + self.fail(f"Auto routing not working - all queries routed to same model: {list(unique_models)[0]}") + if __name__ == "__main__": unittest.main() diff --git a/e2e-tests/04-cache-test.py b/e2e-tests/04-cache-test.py index ce76e377..9291d004 100644 --- a/e2e-tests/04-cache-test.py +++ b/e2e-tests/04-cache-test.py @@ -4,19 +4,21 @@ This test validates the semantic cache functionality by sending similar queries and checking if cache hits occur as expected. + +Signed-off-by: Yossi Ovadia """ import json import os import sys import time +import unittest import uuid import requests -# Add parent directory to path to allow importing common test utilities -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from tests.test_base import SemanticRouterTestBase +# Import test base from same directory +from test_base import SemanticRouterTestBase # Constants ENVOY_URL = "http://localhost:8801" @@ -136,7 +138,7 @@ def test_cache_hit_with_identical_query(self): # First request self.print_subtest_header("First Request (Expected Cache Miss)") response1 = requests.post( - f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers=headers, json=payload, timeout=10 + f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers=headers, json=payload, timeout=30 ) response1_json = response1.json() @@ -157,7 +159,7 @@ def test_cache_hit_with_identical_query(self): # Second identical request self.print_subtest_header("Second Request (Expected Cache Hit)") response2 = requests.post( - f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers=headers, json=payload, timeout=10 + f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers=headers, json=payload, timeout=30 ) response2_json = response2.json() @@ -237,7 +239,7 @@ def test_cache_hit_with_similar_query(self): headers = {"Content-Type": "application/json", "X-Session-ID": session_id} response1 = requests.post( - f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers=headers, json=payload1, timeout=10 + f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers=headers, json=payload1, timeout=30 ) response1_json = response1.json() @@ -273,7 +275,7 @@ def test_cache_hit_with_similar_query(self): ) response2 = requests.post( - f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers=headers, json=payload2, timeout=10 + f"{ENVOY_URL}{OPENAI_ENDPOINT}", headers=headers, json=payload2, timeout=30 ) response2_json = response2.json() diff --git a/e2e-tests/05-pii-policy-test.py b/e2e-tests/05-pii-policy-test.py new file mode 100644 index 00000000..7a215739 --- /dev/null +++ b/e2e-tests/05-pii-policy-test.py @@ -0,0 +1,503 @@ +#!/usr/bin/env python3 +""" +05-pii-policy-test.py - PII detection and policy tests + +This test validates the router's ability to detect PII in requests and enforce +model-specific PII policies correctly. + +Signed-off-by: Yossi Ovadia +""" + +import json +import os +import sys +import time +import unittest +import uuid + +import requests + +# Import test base from same directory +from test_base import SemanticRouterTestBase + +# Constants +ENVOY_URL = "http://localhost:8801" +OPENAI_ENDPOINT = "/v1/chat/completions" +ROUTER_METRICS_URL = "http://localhost:9190/metrics" +DEFAULT_MODEL = "gemma3:27b" # Use configured model + +# PII test cases - based on gemma3:27b policy: allows ["EMAIL_ADDRESS", "PERSON", "GPE", "PHONE_NUMBER"] +PII_TEST_CASES = [ + { + "name": "Allowed PII - Email", + "expected_result": "allowed", + "content": "Please send the report to john.doe@company.com for review.", + "expected_pii_types": ["EMAIL_ADDRESS"], + }, + { + "name": "Allowed PII - Person Name", + "expected_result": "allowed", + "content": "John Smith submitted the application yesterday.", + "expected_pii_types": ["PERSON"], + }, + { + "name": "Allowed PII - Location (GPE)", + "expected_result": "allowed", + "content": "The conference will be held in New York City.", + "expected_pii_types": ["GPE"], + }, + { + "name": "Allowed PII - Phone Number", + "expected_result": "allowed", + "content": "You can reach me at (555) 123-4567 for more information.", + "expected_pii_types": ["PHONE_NUMBER"], + }, + { + "name": "Mixed Allowed PII", + "expected_result": "allowed", + "content": "Contact John Smith in New York at john@company.com or (555) 123-4567.", + "expected_pii_types": ["PERSON", "GPE", "EMAIL_ADDRESS", "PHONE_NUMBER"], + }, +] + +# Cases that might contain PII types not allowed by gemma3:27b policy +POTENTIAL_BLOCKED_PII_CASES = [ + { + "name": "Credit Card Number", + "content": "My credit card number is 4532-1234-5678-9012.", + "potential_pii_types": ["CREDIT_CARD"], + }, + { + "name": "Social Security Number", + "content": "My SSN is 123-45-6789 for verification.", + "potential_pii_types": ["SSN"], + }, + { + "name": "Medical Information", + "content": "Patient ID 12345 has diabetes and takes insulin.", + "potential_pii_types": ["MEDICAL_LICENSE"], + }, +] + +# No PII cases - should always be allowed +NO_PII_CASES = [ + { + "name": "General Question", + "expected_result": "allowed", + "content": "What is the weather like today?", + }, + { + "name": "Math Problem", + "expected_result": "allowed", + "content": "What is 2 + 2?", + }, + { + "name": "Technical Question", + "expected_result": "allowed", + "content": "How does machine learning work?", + }, +] + + +class PIIDetectionPolicyTest(SemanticRouterTestBase): + """Test the router's PII detection and policy enforcement functionality.""" + + def setUp(self): + """Check if the services are running before running tests.""" + self.print_test_header( + "Setup Check", + "Verifying that required services (Envoy and Router) are running and PII detection is enabled", + ) + + # Check Envoy + try: + payload = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "assistant", "content": "You are a helpful assistant."}, + {"role": "user", "content": "test"}, + ], + } + + self.print_request_info( + payload=payload, + expectations="Expect: Service health check to succeed with 2xx status code", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=(10, 60), # (connect timeout, read timeout) + ) + + if response.status_code >= 500: + self.skipTest( + f"Envoy server returned server error: {response.status_code}" + ) + + self.print_response_info(response) + + except requests.exceptions.ConnectionError: + self.skipTest("Cannot connect to Envoy server. Is it running?") + + # Check router metrics endpoint + try: + response = requests.get(ROUTER_METRICS_URL, timeout=2) + if response.status_code != 200: + self.skipTest( + "Router metrics server is not responding. Is the router running?" + ) + + self.print_response_info(response, {"Service": "Router Metrics"}) + + except requests.exceptions.ConnectionError: + self.skipTest( + "Cannot connect to router metrics server. Is the router running?" + ) + + # Check if PII detection metrics exist + response = requests.get(ROUTER_METRICS_URL) + metrics_text = response.text + if "pii" not in metrics_text.lower(): + self.skipTest("PII metrics not found. PII detection may be disabled.") + + def test_no_pii_requests_allowed(self): + """Test that requests with no PII are always allowed.""" + self.print_test_header( + "No PII Requests Test", + "Verifies that requests without PII are processed normally", + ) + + for test_case in NO_PII_CASES: + with self.subTest(test_case["name"]): + self.print_subtest_header(test_case["name"]) + + session_id = str(uuid.uuid4()) + payload = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": test_case["content"]}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + } + + self.print_request_info( + payload=payload, + expectations="Expect: Request with no PII to be processed normally", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=(10, 60), # (connect timeout, read timeout) + ) + + # No PII requests should be processed successfully + passed = response.status_code == 200 + + try: + response_json = response.json() + model = response_json.get("model", "unknown") + except: + model = "N/A" + + self.print_response_info( + response, + { + "Content": test_case["content"][:50] + "...", + "Selected Model": model, + "Session ID": session_id, + "PII Status": "Expected: No PII", + }, + ) + + self.print_test_result( + passed=passed, + message=( + f"No PII request processed normally (status: {response.status_code})" + if passed + else f"No PII request blocked unexpectedly (status: {response.status_code})" + ), + ) + + self.assertEqual( + response.status_code, + 200, + f"No PII request '{test_case['name']}' failed with status {response.status_code}. Expected: 200 (service must be working)", + ) + + def test_allowed_pii_requests(self): + """Test that requests with allowed PII types are processed.""" + self.print_test_header( + "Allowed PII Requests Test", + "Verifies that requests with allowed PII types (EMAIL_ADDRESS, PERSON, GPE, PHONE_NUMBER) are processed", + ) + + for test_case in PII_TEST_CASES: + with self.subTest(test_case["name"]): + self.print_subtest_header(test_case["name"]) + + session_id = str(uuid.uuid4()) + payload = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": test_case["content"]}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + } + + self.print_request_info( + payload=payload, + expectations=f"Expect: Request with allowed PII types {test_case['expected_pii_types']} to be processed", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=(10, 60), # (connect timeout, read timeout) + ) + + # Allowed PII requests should be processed successfully + passed = response.status_code == 200 + + try: + response_json = response.json() + model = response_json.get("model", "unknown") + except: + model = "N/A" + + self.print_response_info( + response, + { + "Content": test_case["content"][:50] + "...", + "Selected Model": model, + "Session ID": session_id, + "Expected PII Types": test_case["expected_pii_types"], + "PII Status": "Expected: Allowed", + }, + ) + + self.print_test_result( + passed=passed, + message=( + f"Allowed PII request processed normally (status: {response.status_code})" + if passed + else f"Allowed PII request blocked unexpectedly (status: {response.status_code})" + ), + ) + + self.assertEqual( + response.status_code, + 200, + f"Allowed PII request '{test_case['name']}' failed with status {response.status_code}. Expected: 200 (service must be working)", + ) + + def test_pii_policy_consistency(self): + """Test that PII policy decisions are consistent for the same content.""" + self.print_test_header( + "PII Policy Consistency Test", + "Verifies that the same content consistently triggers the same PII policy decision", + ) + + test_content = "Please contact John Smith at john@company.com for assistance." + + results = [] + for i in range(3): + self.print_subtest_header(f"Consistency Test {i+1}") + + session_id = str(uuid.uuid4()) + payload = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": test_content}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + } + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=(10, 60), # (connect timeout, read timeout) + ) + + # Record the response status for consistency checking + results.append(response.status_code) + + self.print_response_info( + response, + { + "Attempt": i + 1, + "Status Code": response.status_code, + "Content": test_content[:50] + "...", + }, + ) + + time.sleep(1) # Small delay between requests + + # Check consistency - all results should be the same + is_consistent = len(set(results)) == 1 + self.print_test_result( + passed=is_consistent, + message=f"PII policy consistency: {is_consistent}. Results: {results}", + ) + + self.assertEqual( + len(set(results)), 1, f"Inconsistent PII policy results: {results}" + ) + + def test_pii_detection_metrics(self): + """Test that PII detection metrics are being recorded properly.""" + self.print_test_header( + "PII Detection Metrics Test", + "Verifies that PII detection metrics are being properly recorded and exposed", + ) + + # Get baseline metrics + response = requests.get(ROUTER_METRICS_URL) + metrics_text = response.text + + # Look for specific PII metrics + pii_metrics = [ + "llm_classifier_latency_seconds_count", # Classification timing + "llm_request_errors_total", # Blocked requests with reason="pii_block" + "llm_model_requests_total", # Total requests + ] + + metrics_found = {} + for metric in pii_metrics: + for line in metrics_text.split("\n"): + if metric in line and not line.startswith("#"): + # For classifier metrics, ensure it's specifically for pii + if "classifier" in metric and "pii" not in line: + continue + # For error metrics, ensure it's specifically pii_block + if "errors" in metric and "pii_block" not in line: + continue + # Extract metric value + try: + parts = line.strip().split() + if len(parts) >= 2: + metrics_found[metric] = float(parts[-1]) + break + except ValueError: + continue + + self.print_response_info( + response, + {"Metrics Found": len(metrics_found), "Total Expected": len(pii_metrics)}, + ) + + # Print detailed metrics information + for metric, value in metrics_found.items(): + print(f"\nMetric: {metric}") + print(f" Value: {value}") + + # Print any metrics that contain "pii" even if not in our expected list + print(f"\nAll PII-related metrics found:") + for line in metrics_text.split("\n"): + if "pii" in line.lower() and not line.startswith("#") and line.strip(): + print(f" {line.strip()}") + + passed = len(metrics_found) > 0 + self.print_test_result( + passed=passed, + message=f"Found {len(metrics_found)} out of {len(pii_metrics)} expected PII metrics", + ) + + self.assertGreater(len(metrics_found), 0, "No PII metrics found") + + def test_model_pii_policy_configuration(self): + """Test that different models have different PII policies configured.""" + self.print_test_header( + "Model PII Policy Configuration Test", + "Verifies that the router correctly applies different PII policies for different models", + ) + + # Test with gemma3:27b (has restrictive PII policy) + test_models = [DEFAULT_MODEL] + test_content = "Contact John at john@company.com" + + for model in test_models: + self.print_subtest_header(f"Testing Model: {model}") + + session_id = str(uuid.uuid4()) + payload = { + "model": model, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": test_content}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + } + + self.print_request_info( + payload=payload, + expectations=f"Expect: Model {model} to apply its specific PII policy", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=(10, 60), # (connect timeout, read timeout) + ) + + try: + response_json = response.json() + selected_model = response_json.get("model", "unknown") + except: + selected_model = "N/A" + + self.print_response_info( + response, + { + "Requested Model": model, + "Selected Model": selected_model, + "Status Code": response.status_code, + "Content": test_content, + }, + ) + + # The request should be processed successfully + passed = response.status_code == 200 + self.print_test_result( + passed=passed, + message=f"Model {model} PII policy applied correctly", + ) + + self.assertEqual( + response.status_code, + 200, + f"Model {model} PII policy failed with status {response.status_code}. Expected: 200 (service must be working)", + ) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/e2e-tests/06-tools-test.py b/e2e-tests/06-tools-test.py new file mode 100644 index 00000000..8c1cb46b --- /dev/null +++ b/e2e-tests/06-tools-test.py @@ -0,0 +1,496 @@ +#!/usr/bin/env python3 +""" +06-tools-test.py - Tools selection tests + +This test validates the router's ability to automatically select appropriate tools +based on request content using semantic similarity matching. + +Signed-off-by: Yossi Ovadia +""" + +import json +import os +import sys +import time +import unittest +import uuid + +import requests + +# Import test base from same directory +from test_base import SemanticRouterTestBase + +# Constants +ENVOY_URL = "http://localhost:8801" +OPENAI_ENDPOINT = "/v1/chat/completions" +ROUTER_METRICS_URL = "http://localhost:9190/metrics" +DEFAULT_MODEL = "gemma3:27b" # Use configured model + +# Tool test cases - based on the 5 tools configured in config/tools_db.json +TOOL_TEST_CASES = [ + { + "name": "Weather Query", + "expected_tools": ["get_weather"], + "content": "What's the weather forecast for tomorrow in San Francisco?", + "description": "Should match get_weather tool", + }, + { + "name": "Search Request", + "expected_tools": ["search_web"], + "content": "Can you find information about the latest AI research papers?", + "description": "Should match search_web tool", + }, + { + "name": "Mathematical Calculation", + "expected_tools": ["calculate"], + "content": "What is 25% of 847 plus the square root of 169?", + "description": "Should match calculate tool", + }, + { + "name": "Email Request", + "expected_tools": ["send_email"], + "content": "Please send an email to the team about the meeting update.", + "description": "Should match send_email tool", + }, + { + "name": "Calendar/Scheduling", + "expected_tools": ["create_calendar_event"], + "content": "Schedule a meeting with the development team for next Friday at 2 PM.", + "description": "Should match create_calendar_event tool", + }, + { + "name": "Multi-tool Request", + "expected_tools": ["get_weather", "create_calendar_event"], + "content": "Check the weather for the conference location and schedule a meeting to discuss it.", + "description": "Should match multiple tools", + }, +] + +# Cases that should not match any tools strongly +NO_TOOL_CASES = [ + { + "name": "General Conversation", + "content": "Hello, how are you doing today?", + "description": "Generic greeting, should not match specific tools", + }, + { + "name": "Simple Question", + "content": "What is artificial intelligence?", + "description": "General knowledge question, no specific tool needed", + }, + { + "name": "Creative Writing", + "content": "Write a short poem about the ocean.", + "description": "Creative task, no specific tool needed", + }, +] + + +class ToolsSelectionTest(SemanticRouterTestBase): + """Test the router's automatic tool selection functionality.""" + + def setUp(self): + """Check if the services are running before running tests.""" + self.print_test_header( + "Setup Check", + "Verifying that required services (Envoy and Router) are running and tools selection is enabled", + ) + + # Check Envoy + try: + payload = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "assistant", "content": "You are a helpful assistant."}, + {"role": "user", "content": "test"}, + ], + } + + self.print_request_info( + payload=payload, + expectations="Expect: Service health check to succeed with 2xx status code", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=60, + ) + + if response.status_code >= 500: + self.skipTest( + f"Envoy server returned server error: {response.status_code}" + ) + + self.print_response_info(response) + + except requests.exceptions.ConnectionError: + self.skipTest("Cannot connect to Envoy server. Is it running?") + + # Check router metrics endpoint + try: + response = requests.get(ROUTER_METRICS_URL, timeout=2) + if response.status_code != 200: + self.skipTest( + "Router metrics server is not responding. Is the router running?" + ) + + self.print_response_info(response, {"Service": "Router Metrics"}) + + except requests.exceptions.ConnectionError: + self.skipTest( + "Cannot connect to router metrics server. Is the router running?" + ) + + # Check if tools metrics exist + response = requests.get(ROUTER_METRICS_URL) + metrics_text = response.text + if "tool" not in metrics_text.lower(): + self.skipTest("Tools metrics not found. Tools selection may be disabled.") + + def test_specific_tool_selection(self): + """Test that specific requests match their expected tools.""" + self.print_test_header( + "Specific Tool Selection Test", + "Verifies that requests for specific functionality match the appropriate tools", + ) + + for test_case in TOOL_TEST_CASES: + with self.subTest(test_case["name"]): + self.print_subtest_header(test_case["name"]) + + session_id = str(uuid.uuid4()) + payload = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "system", "content": "You are a helpful assistant with access to various tools."}, + {"role": "user", "content": test_case["content"]}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + "X-Tools-Enabled": "true", # Explicitly request tool selection + } + + self.print_request_info( + payload=payload, + expectations=f"Expect: Request to match tools {test_case['expected_tools']}", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=30, + ) + + # Tool selection should work regardless of vLLM backend availability + # Tool selection should work successfully + passed = response.status_code == 200 + + try: + response_json = response.json() + model = response_json.get("model", "unknown") + tools = response_json.get("tools", []) + except: + model = "N/A" + tools = [] + + self.print_response_info( + response, + { + "Content": test_case["content"][:50] + "...", + "Selected Model": model, + "Expected Tools": test_case["expected_tools"], + "Selected Tools": tools if tools else "N/A", + "Session ID": session_id, + "Description": test_case["description"], + }, + ) + + self.print_test_result( + passed=passed, + message=( + f"Tool selection request processed (status: {response.status_code})" + if passed + else f"Tool selection request failed (status: {response.status_code})" + ), + ) + + self.assertIn( + response.status_code, + [200, 503], + f"Tool selection request '{test_case['name']}' failed. Status: {response.status_code}", + ) + + def test_no_tool_requests(self): + """Test that generic requests don't unnecessarily trigger tool selection.""" + self.print_test_header( + "No Tool Requests Test", + "Verifies that generic requests don't inappropriately match specific tools", + ) + + for test_case in NO_TOOL_CASES: + with self.subTest(test_case["name"]): + self.print_subtest_header(test_case["name"]) + + session_id = str(uuid.uuid4()) + payload = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": test_case["content"]}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + } + + self.print_request_info( + payload=payload, + expectations="Expect: Generic request to be processed without specific tool selection", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=30, + ) + + # Tool selection should work successfully + passed = response.status_code == 200 + + try: + response_json = response.json() + model = response_json.get("model", "unknown") + tools = response_json.get("tools", []) + except: + model = "N/A" + tools = [] + + self.print_response_info( + response, + { + "Content": test_case["content"][:50] + "...", + "Selected Model": model, + "Selected Tools": tools if tools else "None", + "Session ID": session_id, + "Description": test_case["description"], + }, + ) + + self.print_test_result( + passed=passed, + message=( + f"Generic request processed normally (status: {response.status_code})" + if passed + else f"Generic request failed (status: {response.status_code})" + ), + ) + + self.assertIn( + response.status_code, + [200, 503], + f"Generic request '{test_case['name']}' failed. Status: {response.status_code}", + ) + + def test_tools_configuration_validation(self): + """Test that the tools configuration is properly loaded and accessible.""" + self.print_test_header( + "Tools Configuration Validation Test", + "Verifies that the tools database is properly loaded with expected tools", + ) + + # Make a request that should trigger tool processing + session_id = str(uuid.uuid4()) + payload = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "system", "content": "You are a helpful assistant with access to tools."}, + {"role": "user", "content": "I need help with weather, calculations, and scheduling."}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + "X-Tools-Debug": "true", # Request debug information + } + + self.print_request_info( + payload=payload, + expectations="Expect: Request to be processed and tools configuration to be validated", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=30, + ) + + passed = response.status_code in [200, 503] + + try: + response_json = response.json() + model = response_json.get("model", "unknown") + except: + model = "N/A" + + self.print_response_info( + response, + { + "Selected Model": model, + "Status Code": response.status_code, + "Session ID": session_id, + "Tools Config": "Expected: 5 tools loaded", + }, + ) + + self.print_test_result( + passed=passed, + message=f"Tools configuration validation completed (status: {response.status_code})", + ) + + self.assertIn( + response.status_code, + [200, 503], + f"Tools configuration validation failed. Status: {response.status_code}", + ) + + def test_tools_metrics(self): + """Test that tools selection metrics are being recorded properly.""" + self.print_test_header( + "Tools Selection Metrics Test", + "Verifies that tools selection metrics are being properly recorded and exposed", + ) + + # Get baseline metrics + response = requests.get(ROUTER_METRICS_URL) + metrics_text = response.text + + # Look for specific tools metrics + tools_metrics = [ + "llm_router_tools_selected_total", + "llm_router_tools_selection_duration_seconds", + "llm_router_tools_database_size", + "llm_router_requests_total", + ] + + metrics_found = {} + for metric in tools_metrics: + for line in metrics_text.split("\n"): + if metric in line and not line.startswith("#"): + # Extract metric value + try: + parts = line.strip().split() + if len(parts) >= 2: + metrics_found[metric] = float(parts[-1]) + break + except ValueError: + continue + + self.print_response_info( + response, + {"Metrics Found": len(metrics_found), "Total Expected": len(tools_metrics)}, + ) + + # Print detailed metrics information + for metric, value in metrics_found.items(): + print(f"\nMetric: {metric}") + print(f" Value: {value}") + + # Print any metrics that contain "tool" even if not in our expected list + print(f"\nAll tools-related metrics found:") + for line in metrics_text.split("\n"): + if "tool" in line.lower() and not line.startswith("#") and line.strip(): + print(f" {line.strip()}") + + passed = len(metrics_found) > 0 + self.print_test_result( + passed=passed, + message=f"Found {len(metrics_found)} out of {len(tools_metrics)} expected tools metrics", + ) + + self.assertGreater(len(metrics_found), 0, "No tools metrics found") + + def test_tools_selection_consistency(self): + """Test that tool selection is consistent for the same content.""" + self.print_test_header( + "Tools Selection Consistency Test", + "Verifies that the same content consistently triggers the same tool selection", + ) + + test_content = "What's the weather like today in New York?" + + results = [] + for i in range(3): + self.print_subtest_header(f"Consistency Test {i+1}") + + session_id = str(uuid.uuid4()) + payload = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "system", "content": "You are a helpful assistant with access to tools."}, + {"role": "user", "content": test_content}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + "X-Tools-Enabled": "true", + } + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=30, + ) + + # Record the response status for consistency checking + results.append(response.status_code) + + try: + response_json = response.json() + tools = response_json.get("tools", []) + except: + tools = [] + + self.print_response_info( + response, + { + "Attempt": i + 1, + "Status Code": response.status_code, + "Content": test_content[:50] + "...", + "Selected Tools": tools if tools else "None", + }, + ) + + time.sleep(1) # Small delay between requests + + # Check consistency - all results should be the same + is_consistent = len(set(results)) == 1 + self.print_test_result( + passed=is_consistent, + message=f"Tools selection consistency: {is_consistent}. Results: {results}", + ) + + self.assertEqual( + len(set(results)), 1, f"Inconsistent tools selection results: {results}" + ) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/e2e-tests/07-model-selection-test.py b/e2e-tests/07-model-selection-test.py new file mode 100644 index 00000000..801daea8 --- /dev/null +++ b/e2e-tests/07-model-selection-test.py @@ -0,0 +1,544 @@ +#!/usr/bin/env python3 +""" +07-model-selection-test.py - Model selection and scoring tests + +This test validates the router's ability to select appropriate models based on +content categories and configured scoring rules. + +Signed-off-by: Yossi Ovadia +""" + +import json +import os +import sys +import time +import unittest +import uuid + +import requests + +# Import test base from same directory +from test_base import SemanticRouterTestBase + +# Constants +ENVOY_URL = "http://localhost:8801" +OPENAI_ENDPOINT = "/v1/chat/completions" +ROUTER_METRICS_URL = "http://localhost:9190/metrics" +DEFAULT_MODEL = "gemma3:27b" # Use configured model + +# Model selection test cases - based on category scores in config.yaml +MODEL_SELECTION_TEST_CASES = [ + { + "name": "Math Query - Should Prefer phi4", + "category": "math", + "content": "Solve this differential equation: dy/dx + 2y = x^2", + "preferred_models": ["phi4"], # phi4 has score 1.0 for math + "reasoning_enabled": True, + }, + { + "name": "Business Query - Should Prefer phi4", + "category": "business", + "content": "What are the key strategies for improving quarterly revenue in a SaaS company?", + "preferred_models": ["phi4"], # phi4 has score 0.8 for business + "reasoning_enabled": False, + }, + { + "name": "Law Query - Should Prefer gemma3:27b", + "category": "law", + "content": "What are the legal implications of GDPR compliance for international data transfers?", + "preferred_models": ["gemma3:27b"], # gemma3:27b has score 0.8 for law + "reasoning_enabled": False, + }, + { + "name": "Chemistry Query - Should Prefer mistral-small3.1", + "category": "chemistry", + "content": "Explain the mechanism of nucleophilic substitution in organic chemistry.", + "preferred_models": ["mistral-small3.1"], # mistral has score 0.8 for chemistry + "reasoning_enabled": True, + }, + { + "name": "History Query - Should Prefer mistral-small3.1", + "category": "history", + "content": "Analyze the causes and consequences of the French Revolution.", + "preferred_models": ["mistral-small3.1"], # mistral has score 0.8 for history + "reasoning_enabled": False, + }, + { + "name": "General Query - Should Use Default", + "category": "other", + "content": "What is artificial intelligence and how does it work?", + "preferred_models": ["gemma3:27b"], # gemma3:27b has score 0.8 for other + "reasoning_enabled": False, + }, +] + +# Reasoning test cases - categories that should enable reasoning +REASONING_TEST_CASES = [ + { + "name": "Physics Problem - Reasoning Enabled", + "category": "physics", + "content": "Calculate the trajectory of a projectile launched at 45 degrees with initial velocity 20 m/s.", + "reasoning_enabled": True, + }, + { + "name": "Computer Science Problem - Reasoning Enabled", + "category": "computer science", + "content": "Design an algorithm to find the shortest path in a weighted directed graph.", + "reasoning_enabled": True, + }, + { + "name": "Engineering Problem - Reasoning Enabled", + "category": "engineering", + "content": "Design a bridge to span 100 meters with maximum load capacity of 50 tons.", + "reasoning_enabled": True, + }, + { + "name": "Biology Analysis - Reasoning Enabled", + "category": "biology", + "content": "Explain the process of protein synthesis from DNA to functional protein.", + "reasoning_enabled": True, + }, +] + +# Non-reasoning categories +NON_REASONING_CASES = [ + { + "name": "Psychology Discussion - No Reasoning", + "category": "psychology", + "content": "Discuss the psychological effects of social media on teenagers.", + "reasoning_enabled": False, + }, + { + "name": "Philosophy Discussion - No Reasoning", + "category": "philosophy", + "content": "What is the meaning of consciousness in modern philosophy?", + "reasoning_enabled": False, + }, +] + + +class ModelSelectionTest(SemanticRouterTestBase): + """Test the router's model selection and scoring functionality.""" + + def setUp(self): + """Check if the services are running before running tests.""" + self.print_test_header( + "Setup Check", + "Verifying that required services (Envoy and Router) are running and model selection is enabled", + ) + + # Check Envoy + try: + payload = { + "model": DEFAULT_MODEL, + "messages": [ + {"role": "assistant", "content": "You are a helpful assistant."}, + {"role": "user", "content": "test"}, + ], + } + + self.print_request_info( + payload=payload, + expectations="Expect: Service health check to succeed with 2xx status code", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=60, + ) + + if response.status_code >= 500: + self.skipTest( + f"Envoy server returned server error: {response.status_code}" + ) + + self.print_response_info(response) + + except requests.exceptions.ConnectionError: + self.skipTest("Cannot connect to Envoy server. Is it running?") + + # Check router metrics endpoint + try: + response = requests.get(ROUTER_METRICS_URL, timeout=2) + if response.status_code != 200: + self.skipTest( + "Router metrics server is not responding. Is the router running?" + ) + + self.print_response_info(response, {"Service": "Router Metrics"}) + + except requests.exceptions.ConnectionError: + self.skipTest( + "Cannot connect to router metrics server. Is the router running?" + ) + + # Check if model selection metrics exist + response = requests.get(ROUTER_METRICS_URL) + metrics_text = response.text + if "model_selection" not in metrics_text.lower(): + self.skipTest("Model selection metrics not found. Model selection may be disabled.") + + def test_category_based_model_selection(self): + """Test that models are selected based on category scores.""" + self.print_test_header( + "Category-Based Model Selection Test", + "Verifies that the router selects models based on category scoring rules", + ) + + for test_case in MODEL_SELECTION_TEST_CASES: + with self.subTest(test_case["name"]): + self.print_subtest_header(test_case["name"]) + + session_id = str(uuid.uuid4()) + payload = { + "model": "auto", # Request automatic model selection + "messages": [ + {"role": "system", "content": f"You are an expert in {test_case['category']}."}, + {"role": "user", "content": test_case["content"]}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + "X-Category-Hint": test_case["category"], # Provide category hint + } + + self.print_request_info( + payload=payload, + expectations=f"Expect: Auto-selection to prefer {test_case['preferred_models']} for {test_case['category']}", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=30, + ) + + # Model selection should work successfully + passed = response.status_code == 200 + + try: + response_json = response.json() + selected_model = response_json.get("model", "unknown") + reasoning_enabled = response_json.get("reasoning_enabled", False) + except: + selected_model = "N/A" + reasoning_enabled = False + + self.print_response_info( + response, + { + "Category": test_case["category"], + "Content": test_case["content"][:50] + "...", + "Selected Model": selected_model, + "Preferred Models": test_case["preferred_models"], + "Reasoning Enabled": reasoning_enabled, + "Expected Reasoning": test_case["reasoning_enabled"], + "Session ID": session_id, + }, + ) + + self.print_test_result( + passed=passed, + message=( + f"Model selection processed (status: {response.status_code}, model: {selected_model})" + if passed + else f"Model selection failed (status: {response.status_code})" + ), + ) + + self.assertEqual( + response.status_code, + 200, + f"Model selection request '{test_case['name']}' failed with status {response.status_code}. Expected: 200 (service must be working)", + ) + + def test_reasoning_mode_selection(self): + """Test that reasoning mode is enabled for appropriate categories.""" + self.print_test_header( + "Reasoning Mode Selection Test", + "Verifies that reasoning mode is enabled for categories that require structured thinking", + ) + + all_reasoning_cases = REASONING_TEST_CASES + NON_REASONING_CASES + + for test_case in all_reasoning_cases: + with self.subTest(test_case["name"]): + self.print_subtest_header(test_case["name"]) + + session_id = str(uuid.uuid4()) + payload = { + "model": "auto", + "messages": [ + {"role": "system", "content": f"You are an expert in {test_case['category']}."}, + {"role": "user", "content": test_case["content"]}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + "X-Category-Hint": test_case["category"], + } + + self.print_request_info( + payload=payload, + expectations=f"Expect: Reasoning mode {'enabled' if test_case['reasoning_enabled'] else 'disabled'} for {test_case['category']}", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=30, + ) + + # Reasoning mode should work successfully + passed = response.status_code == 200 + + try: + response_json = response.json() + selected_model = response_json.get("model", "unknown") + reasoning_enabled = response_json.get("reasoning_enabled", False) + except: + selected_model = "N/A" + reasoning_enabled = False + + self.print_response_info( + response, + { + "Category": test_case["category"], + "Content": test_case["content"][:50] + "...", + "Selected Model": selected_model, + "Reasoning Enabled": reasoning_enabled, + "Expected Reasoning": test_case["reasoning_enabled"], + "Session ID": session_id, + }, + ) + + self.print_test_result( + passed=passed, + message=( + f"Reasoning mode selection processed correctly" + if passed + else f"Reasoning mode selection failed" + ), + ) + + self.assertEqual( + response.status_code, + 200, + f"Reasoning mode test '{test_case['name']}' failed with status {response.status_code}. Expected: 200 (service must be working)", + ) + + def test_model_fallback_behavior(self): + """Test model fallback when preferred models are unavailable.""" + self.print_test_header( + "Model Fallback Behavior Test", + "Verifies that the router falls back to available models when preferred models are unavailable", + ) + + # Test with a specific model that might not be available + session_id = str(uuid.uuid4()) + payload = { + "model": "non-existent-model", # Request a non-existent model + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is machine learning?"}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + } + + self.print_request_info( + payload=payload, + expectations="Expect: Router to fallback to default model when requested model is unavailable", + ) + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=30, + ) + + # Fallback should work - 400 is acceptable for invalid model request + passed = response.status_code in [200, 400] # 400 is acceptable for invalid model + + try: + response_json = response.json() + selected_model = response_json.get("model", "unknown") + except: + selected_model = "N/A" + + self.print_response_info( + response, + { + "Requested Model": "non-existent-model", + "Selected Model": selected_model, + "Status Code": response.status_code, + "Session ID": session_id, + }, + ) + + self.print_test_result( + passed=passed, + message=f"Model fallback behavior tested (status: {response.status_code}, model: {selected_model})", + ) + + self.assertIn( + response.status_code, + [200, 400], + f"Model fallback test failed with status {response.status_code}. Expected: 200 (fallback) or 400 (invalid model)", + ) + + def test_model_selection_metrics(self): + """Test that model selection metrics are being recorded properly.""" + self.print_test_header( + "Model Selection Metrics Test", + "Verifies that model selection metrics are being properly recorded and exposed", + ) + + # Get baseline metrics + response = requests.get(ROUTER_METRICS_URL) + metrics_text = response.text + + # Look for specific model selection metrics + model_metrics = [ + "llm_router_model_selection_count", + "llm_router_model_selection_duration_seconds", + "llm_router_category_classification_total", + "llm_router_reasoning_enabled_total", + "llm_router_requests_total", + ] + + metrics_found = {} + for metric in model_metrics: + for line in metrics_text.split("\n"): + if metric in line and not line.startswith("#"): + # Extract metric value + try: + parts = line.strip().split() + if len(parts) >= 2: + metrics_found[metric] = float(parts[-1]) + break + except ValueError: + continue + + self.print_response_info( + response, + {"Metrics Found": len(metrics_found), "Total Expected": len(model_metrics)}, + ) + + # Print detailed metrics information + for metric, value in metrics_found.items(): + print(f"\nMetric: {metric}") + print(f" Value: {value}") + + # Print any metrics that contain "model" or "category" even if not in our expected list + print(f"\nAll model/category-related metrics found:") + for line in metrics_text.split("\n"): + if (("model" in line.lower() or "category" in line.lower() or "reasoning" in line.lower()) + and not line.startswith("#") and line.strip()): + print(f" {line.strip()}") + + passed = len(metrics_found) > 0 + self.print_test_result( + passed=passed, + message=f"Found {len(metrics_found)} out of {len(model_metrics)} expected model selection metrics", + ) + + self.assertGreater(len(metrics_found), 0, "No model selection metrics found") + + def test_model_selection_consistency(self): + """Test that model selection is consistent for the same content and category.""" + self.print_test_header( + "Model Selection Consistency Test", + "Verifies that the same content consistently triggers the same model selection", + ) + + test_content = "Solve this quadratic equation: x^2 + 5x + 6 = 0" + category = "math" + + results = [] + for i in range(3): + self.print_subtest_header(f"Consistency Test {i+1}") + + session_id = str(uuid.uuid4()) + payload = { + "model": "auto", + "messages": [ + {"role": "system", "content": f"You are an expert in {category}."}, + {"role": "user", "content": test_content}, + ], + "temperature": 0.7, + } + + headers = { + "Content-Type": "application/json", + "X-Session-ID": session_id, + "X-Category-Hint": category, + } + + response = requests.post( + f"{ENVOY_URL}{OPENAI_ENDPOINT}", + headers=headers, + json=payload, + timeout=30, + ) + + # Record the response status and model for consistency checking + try: + response_json = response.json() + selected_model = response_json.get("model", "unknown") + results.append((response.status_code, selected_model)) + except: + results.append((response.status_code, "N/A")) + + self.print_response_info( + response, + { + "Attempt": i + 1, + "Status Code": response.status_code, + "Selected Model": results[-1][1], + "Category": category, + "Content": test_content[:50] + "...", + }, + ) + + time.sleep(1) # Small delay between requests + + # Check consistency - all results should be the same + status_codes = [r[0] for r in results] + models = [r[1] for r in results] + + status_consistent = len(set(status_codes)) == 1 + model_consistent = len(set(models)) == 1 + + is_consistent = status_consistent and model_consistent + + self.print_test_result( + passed=is_consistent, + message=f"Model selection consistency: {is_consistent}. Status codes: {status_codes}, Models: {models}", + ) + + self.assertEqual( + len(set(status_codes)), 1, f"Inconsistent status codes: {status_codes}" + ) + self.assertEqual( + len(set(models)), 1, f"Inconsistent model selection: {models}" + ) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/e2e-tests/TEST_STATUS_REPORT.md b/e2e-tests/TEST_STATUS_REPORT.md new file mode 100644 index 00000000..23ec414e --- /dev/null +++ b/e2e-tests/TEST_STATUS_REPORT.md @@ -0,0 +1,36 @@ +# E2E Test Status Report +*Generated: 2024-09-22* + + + +## ✅ **PASSING TESTS** + +**00-client-request-test.py** - Basic client connectivity and request/response validation + +**01-envoy-extproc-test.py** - Envoy ExtProc integration and request/response modification + +**02-router-classification-test.py** - Semantic routing intelligence and model selection based on query type + +**04-cache-test.py** - Semantic caching functionality (skipped - cache disabled as expected) + +**05-pii-policy-test.py** - PII detection and policy enforcement for allowed/blocked data types + +**06-tools-test.py** - Automatic tool selection based on semantic similarity matching + +**07-model-selection-test.py** - Category-based model selection and fallback behavior + +**test_base.py** - Base test utilities and helper functions + +--- + +## 📋 **TEST COVERAGE** + +This test suite validates the core functionality of the vLLM Semantic Router system: + +- **Client Integration**: Basic request/response handling through Envoy proxy +- **ExtProc Interface**: Envoy external processing integration +- **Semantic Routing**: Intelligent model selection based on content classification +- **Caching**: Semantic caching system (currently disabled) +- **PII Detection**: Privacy protection and data filtering +- **Tool Selection**: Automatic tool matching based on request content +- **Model Selection**: Category-based routing and fallback mechanisms \ No newline at end of file diff --git a/examples/semanticroute/README.md b/examples/semanticroute/README.md new file mode 100644 index 00000000..d232b3c1 --- /dev/null +++ b/examples/semanticroute/README.md @@ -0,0 +1,179 @@ +# SemanticRoute Examples + +This directory contains various examples of SemanticRoute configurations demonstrating different routing scenarios and capabilities. + +## Examples Overview + +### 1. Simple Intent Routing (`simple-intent-routing.yaml`) + +A basic example showing intent-based routing for math and computer science queries. + +**Features:** + +- Simple intent matching with categories +- Single model reference with fallback +- Minimal configuration + +**Use Case:** Basic routing based on query categories without complex filtering. + +### 2. Complex Filter Chain (`complex-filter-chain.yaml`) + +Demonstrates a comprehensive filter chain with multiple security and performance filters. + +**Features:** + +- PII detection with custom allowed types +- Prompt guard with custom security rules +- Semantic caching for performance +- Reasoning control configuration + +**Use Case:** Production environments requiring security, privacy, and performance optimizations. + +### 3. Multiple Routes (`multiple-routes.yaml`) + +Shows how to define multiple routing rules within a single SemanticRoute resource. + +**Features:** + +- Separate rules for technical vs. creative queries +- Different reasoning configurations per rule +- Rule-specific caching strategies + +**Use Case:** Applications serving diverse query types with different processing requirements. + +### 4. Weighted Routing (`weighted-routing.yaml`) + +Demonstrates traffic distribution across multiple model endpoints using weights and priorities. + +**Features:** + +- Traffic splitting (80/20) between models +- Priority-based failover +- Load balancing configuration + +**Use Case:** A/B testing, gradual rollouts, or load distribution across model endpoints. + +### 5. Tool Selection Example (`tool-selection-example.yaml`) + +Demonstrates automatic tool selection based on semantic similarity to user queries. + +**Features:** + +- Automatic tool selection with configurable similarity threshold +- Tool filtering by categories and tags +- Fallback behavior configuration +- Integration with semantic caching and reasoning control + +**Use Case:** Applications requiring dynamic tool selection based on user intent and query content. + +### 6. Comprehensive Example (`comprehensive-example.yaml`) + +A production-ready configuration showcasing all SemanticRoute features. + +**Features:** + +- Multiple rules with different configurations +- Advanced filtering with custom rules +- External cache backend (Redis) +- High-availability model setup +- Comprehensive security policies + +**Use Case:** Enterprise production deployments requiring full feature utilization. + +## Deployment Instructions + +### Prerequisites + +1. Kubernetes cluster with SemanticRoute CRD installed: + + ```bash + kubectl apply -f ../../deploy/kubernetes/crds/vllm.ai_semanticroutes.yaml + ``` + +2. Ensure your model endpoints are accessible from the cluster. + +### Deploy Examples + +1. **Deploy a single example:** + + ```bash + kubectl apply -f simple-intent-routing.yaml + ``` + +2. **Deploy all examples:** + + ```bash + kubectl apply -f . + ``` + +3. **Verify deployment:** + + ```bash + kubectl get semanticroutes + kubectl describe semanticroute reasoning-route + ``` + +## Configuration Reference + +### Intent Configuration + +```yaml +intents: +- category: "math" # Required: Intent category name + description: "Mathematics queries" # Optional: Human-readable description + threshold: 0.7 # Optional: Confidence threshold (0.0-1.0) +``` + +### Model Reference Configuration + +```yaml +modelRefs: +- modelName: "gpt-oss" # Required: Model identifier + address: "127.0.0.1" # Required: Endpoint address + port: 8080 # Required: Endpoint port + weight: 80 # Optional: Traffic weight (0-100) + priority: 100 # Optional: Priority for failover +``` + +### Filter Configuration + +Each filter type has specific configuration options: + +- **PIIDetection**: Controls PII detection and handling +- **PromptGuard**: Provides security and jailbreak protection +- **SemanticCache**: Enables response caching for performance +- **ReasoningControl**: Manages reasoning mode behavior +- **ToolSelection**: Enables automatic tool selection based on semantic similarity + +## Best Practices + +1. **Start Simple**: Begin with basic intent routing and add filters as needed. + +2. **Test Thoroughly**: Validate routing behavior with representative queries. + +3. **Monitor Performance**: Use appropriate cache settings and monitor hit rates. + +4. **Security First**: Enable PII detection and prompt guard in production. + +5. **Gradual Rollout**: Use weighted routing for safe model deployments. + +## Troubleshooting + +### Common Issues + +1. **Route Not Matching**: Check intent categories and thresholds. +2. **Model Unreachable**: Verify endpoint addresses and network connectivity. +3. **Filter Errors**: Validate filter configurations against the schema. + +### Debugging Commands + +```bash +# Check SemanticRoute status +kubectl get sr -o wide + +# View detailed configuration +kubectl describe semanticroute + +# Check logs (if controller is deployed) +kubectl logs -l app=semantic-router-controller +``` diff --git a/examples/semanticroute/complex-filter-chain.yaml b/examples/semanticroute/complex-filter-chain.yaml new file mode 100644 index 00000000..6eabacfc --- /dev/null +++ b/examples/semanticroute/complex-filter-chain.yaml @@ -0,0 +1,59 @@ +apiVersion: vllm.ai/v1alpha1 +kind: SemanticRoute +metadata: + name: complex-route + namespace: default + labels: + app: semantic-router + scenario: complex-filter-chain +spec: + rules: + - intents: + - category: "computer science" + description: "Programming, algorithms, data structures" + threshold: 0.7 + - category: "math" + description: "Mathematics, calculus, algebra" + threshold: 0.7 + modelRefs: + - modelName: gpt-oss + address: 127.0.0.1 + port: 8080 + weight: 100 + filters: + - type: PIIDetection + enabled: true + config: + allowByDefault: false + pii_types_allowed: ["EMAIL_ADDRESS", "PERSON"] + threshold: 0.7 + action: "block" + - type: PromptGuard + enabled: true + config: + threshold: 0.7 + action: "block" + customRules: + - name: "sensitive-data-rule" + pattern: "(?i)(password|secret|token|key)" + action: "block" + description: "Block requests containing sensitive data keywords" + - type: SemanticCache + enabled: true + config: + similarityThreshold: 0.8 + maxEntries: 1000 + ttlSeconds: 3600 + backend: "memory" + - type: ReasoningControl + enabled: true + config: + reasonFamily: "gpt-oss" + enableReasoning: true + reasoningEffort: "medium" + maxReasoningSteps: 10 + reasoningTimeout: 30 + defaultModel: + modelName: deepseek-v31 + address: 127.0.0.1 + port: 8088 diff --git a/examples/semanticroute/comprehensive-example.yaml b/examples/semanticroute/comprehensive-example.yaml new file mode 100644 index 00000000..fd5db4d9 --- /dev/null +++ b/examples/semanticroute/comprehensive-example.yaml @@ -0,0 +1,109 @@ +apiVersion: vllm.ai/v1alpha1 +kind: SemanticRoute +metadata: + name: comprehensive-example + namespace: default + labels: + app: semantic-router + scenario: comprehensive + environment: production +spec: + rules: + # Rule 1: High-performance reasoning route for technical queries + - intents: + - category: "computer science" + description: "Programming, algorithms, software engineering" + threshold: 0.75 + - category: "math" + description: "Advanced mathematics, calculus, statistics" + threshold: 0.75 + modelRefs: + - modelName: gpt-oss-premium + address: 127.0.0.1 + port: 8080 + weight: 70 + priority: 100 + - modelName: claude-reasoning + address: 127.0.0.1 + port: 8082 + weight: 30 + priority: 95 + filters: + - type: PIIDetection + enabled: true + config: + allowByDefault: false + pii_types_allowed: ["EMAIL_ADDRESS", "PERSON", "GPE"] + threshold: 0.8 + action: "block" + - type: PromptGuard + enabled: true + config: + threshold: 0.75 + action: "block" + customRules: + - name: "code-injection-rule" + pattern: "(?i)(eval|exec|system|shell|cmd)" + action: "warn" + description: "Detect potential code injection attempts" + - type: SemanticCache + enabled: true + config: + similarityThreshold: 0.85 + maxEntries: 2000 + ttlSeconds: 7200 + backend: "redis" + backendConfig: + host: "redis.cache.svc.cluster.local" + port: "6379" + - type: ReasoningControl + enabled: true + config: + reasonFamily: "gpt-oss" + enableReasoning: true + reasoningEffort: "high" + maxReasoningSteps: 20 + reasoningTimeout: 60 + defaultModel: + modelName: deepseek-v31 + address: 127.0.0.1 + port: 8088 + + # Rule 2: Creative and general purpose route + - intents: + - category: "creative" + description: "Creative writing, storytelling, art generation" + threshold: 0.6 + - category: "other" + description: "General purpose conversations" + threshold: 0.5 + modelRefs: + - modelName: creative-model + address: 127.0.0.1 + port: 8081 + weight: 100 + filters: + - type: PIIDetection + enabled: true + config: + allowByDefault: true + pii_types_allowed: ["EMAIL_ADDRESS", "PERSON", "GPE", "PHONE_NUMBER"] + threshold: 0.7 + action: "mask" + - type: ReasoningControl + enabled: true + config: + reasonFamily: "gpt-oss" + enableReasoning: false + reasoningEffort: "low" + - type: SemanticCache + enabled: true + config: + similarityThreshold: 0.75 + maxEntries: 1000 + ttlSeconds: 3600 + backend: "memory" + defaultModel: + modelName: general-model + address: 127.0.0.1 + port: 8089 diff --git a/examples/semanticroute/multiple-routes.yaml b/examples/semanticroute/multiple-routes.yaml new file mode 100644 index 00000000..011c8901 --- /dev/null +++ b/examples/semanticroute/multiple-routes.yaml @@ -0,0 +1,72 @@ +apiVersion: vllm.ai/v1alpha1 +kind: SemanticRoute +metadata: + name: multiple-routes + namespace: default + labels: + app: semantic-router + scenario: multiple-routes +spec: + rules: + # Rule 1: Reasoning-enabled route for technical queries + - intents: + - category: "computer science" + description: "Programming, algorithms, data structures" + threshold: 0.7 + - category: "math" + description: "Mathematics, calculus, algebra" + threshold: 0.7 + modelRefs: + - modelName: gpt-oss + address: 127.0.0.1 + port: 8080 + weight: 100 + filters: + - type: ReasoningControl + enabled: true + config: + reasonFamily: "gpt-oss" + enableReasoning: true + reasoningEffort: "high" + maxReasoningSteps: 15 + - type: SemanticCache + enabled: true + config: + similarityThreshold: 0.85 + maxEntries: 500 + ttlSeconds: 7200 + defaultModel: + modelName: deepseek-v31 + address: 127.0.0.1 + port: 8088 + + # Rule 2: Lightweight route for creative and general queries + - intents: + - category: "creative" + description: "Creative writing, storytelling, art" + threshold: 0.6 + - category: "other" + description: "General purpose queries" + threshold: 0.5 + modelRefs: + - modelName: lightweight-model + address: 127.0.0.1 + port: 8081 + weight: 100 + filters: + - type: ReasoningControl + enabled: true + config: + reasonFamily: "gpt-oss" + enableReasoning: false + reasoningEffort: "low" + - type: PIIDetection + enabled: true + config: + allowByDefault: true + threshold: 0.8 + action: "mask" + defaultModel: + modelName: general-model + address: 127.0.0.1 + port: 8089 diff --git a/examples/semanticroute/simple-intent-routing.yaml b/examples/semanticroute/simple-intent-routing.yaml new file mode 100644 index 00000000..99abd5b0 --- /dev/null +++ b/examples/semanticroute/simple-intent-routing.yaml @@ -0,0 +1,26 @@ +apiVersion: vllm.ai/v1alpha1 +kind: SemanticRoute +metadata: + name: reasoning-route + namespace: default + labels: + app: semantic-router + scenario: simple-intent +spec: + rules: + - intents: + - category: "computer science" + description: "Programming, algorithms, data structures, software engineering" + threshold: 0.7 + - category: "math" + description: "Mathematics, calculus, algebra, statistics" + threshold: 0.7 + modelRefs: + - modelName: gpt-oss + address: 127.0.0.1 + port: 8080 + weight: 100 + defaultModel: + modelName: deepseek-v31 + address: 127.0.0.1 + port: 8088 diff --git a/examples/semanticroute/tool-selection-example.yaml b/examples/semanticroute/tool-selection-example.yaml new file mode 100644 index 00000000..7cd4c219 --- /dev/null +++ b/examples/semanticroute/tool-selection-example.yaml @@ -0,0 +1,50 @@ +apiVersion: vllm.ai/v1alpha1 +kind: SemanticRoute +metadata: + name: tool-selection-example + namespace: default + labels: + app: semantic-router + scenario: tool-selection +spec: + rules: + - intents: + - category: "computer science" + description: "Programming, algorithms, data structures" + threshold: 0.7 + - category: "math" + description: "Mathematics, calculus, algebra" + threshold: 0.7 + modelRefs: + - modelName: gpt-oss + address: 127.0.0.1 + port: 8080 + weight: 100 + filters: + - type: ToolSelection + enabled: true + config: + topK: 3 + similarityThreshold: 0.8 + toolsDBPath: "config/tools_db.json" + fallbackToEmpty: true + categories: ["weather", "calculation", "search"] + tags: ["utility", "api", "function"] + - type: SemanticCache + enabled: true + config: + similarityThreshold: 0.85 + maxEntries: 1000 + ttlSeconds: 3600 + backend: "memory" + - type: ReasoningControl + enabled: true + config: + reasonFamily: "gpt-oss" + enableReasoning: true + reasoningEffort: "medium" + maxReasoningSteps: 10 + defaultModel: + modelName: deepseek-v31 + address: 127.0.0.1 + port: 8088 diff --git a/examples/semanticroute/weighted-routing.yaml b/examples/semanticroute/weighted-routing.yaml new file mode 100644 index 00000000..19f381cb --- /dev/null +++ b/examples/semanticroute/weighted-routing.yaml @@ -0,0 +1,49 @@ +apiVersion: vllm.ai/v1alpha1 +kind: SemanticRoute +metadata: + name: weighted-routing + namespace: default + labels: + app: semantic-router + scenario: weighted-routing +spec: + rules: + - intents: + - category: "computer science" + description: "Programming, algorithms, data structures" + threshold: 0.7 + - category: "math" + description: "Mathematics, calculus, algebra" + threshold: 0.7 + modelRefs: + # Primary model gets 80% of traffic + - modelName: gpt-oss + address: 127.0.0.1 + port: 8080 + weight: 80 + priority: 100 + # Secondary model gets 20% of traffic + - modelName: qwen3 + address: 127.0.0.1 + port: 8089 + weight: 20 + priority: 90 + filters: + - type: ReasoningControl + enabled: true + config: + reasonFamily: "gpt-oss" + enableReasoning: true + reasoningEffort: "medium" + maxReasoningSteps: 10 + - type: SemanticCache + enabled: true + config: + similarityThreshold: 0.8 + maxEntries: 1000 + ttlSeconds: 3600 + backend: "memory" + defaultModel: + modelName: deepseek-v31 + address: 127.0.0.1 + port: 8088 diff --git a/src/semantic-router/go.mod b/src/semantic-router/go.mod index c467d425..e3406d7b 100644 --- a/src/semantic-router/go.mod +++ b/src/semantic-router/go.mod @@ -24,6 +24,7 @@ require ( go.uber.org/zap v1.27.0 google.golang.org/grpc v1.71.1 gopkg.in/yaml.v3 v3.0.1 + k8s.io/apimachinery v0.31.4 ) require ( @@ -34,17 +35,22 @@ require ( github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect github.com/cockroachdb/redact v1.1.3 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect + github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/getsentry/sentry-go v0.12.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/go-cmp v0.7.0 // indirect + github.com/google/gofuzz v1.2.0 // indirect github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect @@ -55,8 +61,9 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/x448/float16 v0.8.4 // indirect go.uber.org/automaxprocs v1.6.0 // indirect - go.uber.org/multierr v1.10.0 // indirect + go.uber.org/multierr v1.11.0 // indirect golang.org/x/net v0.41.0 // indirect golang.org/x/sync v0.15.0 // indirect golang.org/x/sys v0.33.0 // indirect @@ -64,4 +71,10 @@ require ( golang.org/x/tools v0.33.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect google.golang.org/protobuf v1.36.6 // indirect + gopkg.in/inf.v0 v0.9.1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + k8s.io/klog/v2 v2.130.1 // indirect + k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 // indirect + sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect + sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect ) diff --git a/src/semantic-router/go.sum b/src/semantic-router/go.sum index acb526d9..42ee628e 100644 --- a/src/semantic-router/go.sum +++ b/src/semantic-router/go.sum @@ -32,8 +32,9 @@ github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3Ee github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -53,6 +54,8 @@ github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc= github.com/getsentry/sentry-go v0.12.0 h1:era7g0re5iY13bHSdN/xMkyV+5zZppjRVQhZrXCaEIk= github.com/getsentry/sentry-go v0.12.0/go.mod h1:NSap0JBYWzHND8oMbyi0+XZhUalc1TBdRL1M71JZW2c= @@ -107,10 +110,13 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8= github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -133,6 +139,8 @@ github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0Gqw github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= github.com/kataras/golog v0.0.10/go.mod h1:yJ8YKCmyL+nWjERB90Qwn+bdyBZsaQwU3bTVFgkFIp8= @@ -180,9 +188,12 @@ github.com/milvus-io/milvus-sdk-go/v2 v2.4.2/go.mod h1:ulO1YUXKH0PGg50q27grw048G github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= @@ -209,8 +220,9 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= github.com/prometheus/client_golang v1.23.0 h1:ust4zpdl9r4trLY/gSjlm07PuiBq2ynaXXlptpfy8Uc= @@ -240,6 +252,8 @@ github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkU github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -272,6 +286,8 @@ github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBn github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= @@ -301,8 +317,8 @@ go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwE go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= -go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= @@ -448,11 +464,16 @@ gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= gopkg.in/go-playground/validator.v8 v8.18.2/go.mod h1:RX2a/7Ha8BgOhfk7j780h4/u/RRjR0eouCJSH80/M2Y= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.51.1/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20191120175047-4206685974f2/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= @@ -460,3 +481,15 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +k8s.io/apimachinery v0.31.4 h1:8xjE2C4CzhYVm9DGf60yohpNUh5AEBnPxCryPBECmlM= +k8s.io/apimachinery v0.31.4/go.mod h1:rsPdaZJfTfLsNJSQzNHQvYoTmxhoOEofxtOsF3rtsMo= +k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= +k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= +k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 h1:pUdcCO1Lk/tbT5ztQWOBi5HBgbBP1J8+AsQnQCKsi8A= +k8s.io/utils v0.0.0-20240711033017-18e509b52bc8/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= +sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= +sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4= +sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08= +sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= +sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY= diff --git a/src/semantic-router/hack/boilerplate.go.txt b/src/semantic-router/hack/boilerplate.go.txt new file mode 100644 index 00000000..8f48a295 --- /dev/null +++ b/src/semantic-router/hack/boilerplate.go.txt @@ -0,0 +1,15 @@ +/* +Copyright 2025 vLLM Semantic Router. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ diff --git a/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/doc.go b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/doc.go new file mode 100644 index 00000000..fa6e3ba6 --- /dev/null +++ b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/doc.go @@ -0,0 +1,4 @@ +// Package v1alpha1 contains API Schema definitions for the v1alpha1 API group +// +kubebuilder:object:generate=true +// +groupName=vllm.ai +package v1alpha1 diff --git a/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/filter_helpers.go b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/filter_helpers.go new file mode 100644 index 00000000..fef12657 --- /dev/null +++ b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/filter_helpers.go @@ -0,0 +1,253 @@ +/* +Copyright 2025 vLLM Semantic Router. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1alpha1 + +import ( + "encoding/json" + "fmt" + + "k8s.io/apimachinery/pkg/runtime" +) + +// FilterConfigHelper provides helper methods for working with filter configurations +type FilterConfigHelper struct{} + +// NewFilterConfigHelper creates a new FilterConfigHelper +func NewFilterConfigHelper() *FilterConfigHelper { + return &FilterConfigHelper{} +} + +// MarshalFilterConfig marshals a filter configuration to RawExtension +func (h *FilterConfigHelper) MarshalFilterConfig(config interface{}) (*runtime.RawExtension, error) { + if config == nil { + return nil, nil + } + + data, err := json.Marshal(config) + if err != nil { + return nil, fmt.Errorf("failed to marshal filter config: %w", err) + } + + return &runtime.RawExtension{Raw: data}, nil +} + +// UnmarshalPIIDetectionConfig unmarshals a PIIDetectionConfig from RawExtension +func (h *FilterConfigHelper) UnmarshalPIIDetectionConfig(raw *runtime.RawExtension) (*PIIDetectionConfig, error) { + if raw == nil || len(raw.Raw) == 0 { + return &PIIDetectionConfig{}, nil + } + + var config PIIDetectionConfig + if err := json.Unmarshal(raw.Raw, &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal PIIDetectionConfig: %w", err) + } + + return &config, nil +} + +// UnmarshalPromptGuardConfig unmarshals a PromptGuardConfig from RawExtension +func (h *FilterConfigHelper) UnmarshalPromptGuardConfig(raw *runtime.RawExtension) (*PromptGuardConfig, error) { + if raw == nil || len(raw.Raw) == 0 { + return &PromptGuardConfig{}, nil + } + + var config PromptGuardConfig + if err := json.Unmarshal(raw.Raw, &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal PromptGuardConfig: %w", err) + } + + return &config, nil +} + +// UnmarshalSemanticCacheConfig unmarshals a SemanticCacheConfig from RawExtension +func (h *FilterConfigHelper) UnmarshalSemanticCacheConfig(raw *runtime.RawExtension) (*SemanticCacheConfig, error) { + if raw == nil || len(raw.Raw) == 0 { + return &SemanticCacheConfig{}, nil + } + + var config SemanticCacheConfig + if err := json.Unmarshal(raw.Raw, &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal SemanticCacheConfig: %w", err) + } + + return &config, nil +} + +// UnmarshalReasoningControlConfig unmarshals a ReasoningControlConfig from RawExtension +func (h *FilterConfigHelper) UnmarshalReasoningControlConfig(raw *runtime.RawExtension) (*ReasoningControlConfig, error) { + if raw == nil || len(raw.Raw) == 0 { + return &ReasoningControlConfig{}, nil + } + + var config ReasoningControlConfig + if err := json.Unmarshal(raw.Raw, &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal ReasoningControlConfig: %w", err) + } + + return &config, nil +} + +// MarshalToolSelectionConfig marshals a ToolSelectionConfig to RawExtension +func (h *FilterConfigHelper) MarshalToolSelectionConfig(config *ToolSelectionConfig) (*runtime.RawExtension, error) { + if config == nil { + return &runtime.RawExtension{}, nil + } + + data, err := json.Marshal(config) + if err != nil { + return nil, fmt.Errorf("failed to marshal ToolSelectionConfig: %w", err) + } + + return &runtime.RawExtension{Raw: data}, nil +} + +// UnmarshalToolSelectionConfig unmarshals a ToolSelectionConfig from RawExtension +func (h *FilterConfigHelper) UnmarshalToolSelectionConfig(raw *runtime.RawExtension) (*ToolSelectionConfig, error) { + if raw == nil || len(raw.Raw) == 0 { + return &ToolSelectionConfig{}, nil + } + + var config ToolSelectionConfig + if err := json.Unmarshal(raw.Raw, &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal ToolSelectionConfig: %w", err) + } + + return &config, nil +} + +// UnmarshalFilterConfig unmarshals a filter configuration based on the filter type +func (h *FilterConfigHelper) UnmarshalFilterConfig(filterType FilterType, raw *runtime.RawExtension) (interface{}, error) { + switch filterType { + case FilterTypePIIDetection: + return h.UnmarshalPIIDetectionConfig(raw) + case FilterTypePromptGuard: + return h.UnmarshalPromptGuardConfig(raw) + case FilterTypeSemanticCache: + return h.UnmarshalSemanticCacheConfig(raw) + case FilterTypeReasoningControl: + return h.UnmarshalReasoningControlConfig(raw) + case FilterTypeToolSelection: + return h.UnmarshalToolSelectionConfig(raw) + default: + return nil, fmt.Errorf("unsupported filter type: %s", filterType) + } +} + +// ValidateFilterConfig validates a filter configuration +func (h *FilterConfigHelper) ValidateFilterConfig(filter *Filter) error { + if filter == nil { + return fmt.Errorf("filter cannot be nil") + } + + // Validate filter type + switch filter.Type { + case FilterTypePIIDetection, FilterTypePromptGuard, FilterTypeSemanticCache, FilterTypeReasoningControl, FilterTypeToolSelection: + // Valid filter types + default: + return fmt.Errorf("invalid filter type: %s", filter.Type) + } + + // If config is provided, try to unmarshal it to validate structure + if filter.Config != nil { + _, err := h.UnmarshalFilterConfig(filter.Type, filter.Config) + if err != nil { + return fmt.Errorf("invalid filter config for type %s: %w", filter.Type, err) + } + } + + return nil +} + +// CreatePIIDetectionFilter creates a PIIDetection filter with the given configuration +func CreatePIIDetectionFilter(config *PIIDetectionConfig) (*Filter, error) { + helper := NewFilterConfigHelper() + rawConfig, err := helper.MarshalFilterConfig(config) + if err != nil { + return nil, err + } + + enabled := true + return &Filter{ + Type: FilterTypePIIDetection, + Config: rawConfig, + Enabled: &enabled, + }, nil +} + +// CreatePromptGuardFilter creates a PromptGuard filter with the given configuration +func CreatePromptGuardFilter(config *PromptGuardConfig) (*Filter, error) { + helper := NewFilterConfigHelper() + rawConfig, err := helper.MarshalFilterConfig(config) + if err != nil { + return nil, err + } + + enabled := true + return &Filter{ + Type: FilterTypePromptGuard, + Config: rawConfig, + Enabled: &enabled, + }, nil +} + +// CreateSemanticCacheFilter creates a SemanticCache filter with the given configuration +func CreateSemanticCacheFilter(config *SemanticCacheConfig) (*Filter, error) { + helper := NewFilterConfigHelper() + rawConfig, err := helper.MarshalFilterConfig(config) + if err != nil { + return nil, err + } + + enabled := true + return &Filter{ + Type: FilterTypeSemanticCache, + Config: rawConfig, + Enabled: &enabled, + }, nil +} + +// CreateReasoningControlFilter creates a ReasoningControl filter with the given configuration +func CreateReasoningControlFilter(config *ReasoningControlConfig) (*Filter, error) { + helper := NewFilterConfigHelper() + rawConfig, err := helper.MarshalFilterConfig(config) + if err != nil { + return nil, err + } + + enabled := true + return &Filter{ + Type: FilterTypeReasoningControl, + Config: rawConfig, + Enabled: &enabled, + }, nil +} + +// CreateToolSelectionFilter creates a ToolSelection filter with the given configuration +func CreateToolSelectionFilter(config *ToolSelectionConfig) (*Filter, error) { + helper := NewFilterConfigHelper() + rawConfig, err := helper.MarshalFilterConfig(config) + if err != nil { + return nil, err + } + + enabled := true + return &Filter{ + Type: FilterTypeToolSelection, + Config: rawConfig, + Enabled: &enabled, + }, nil +} diff --git a/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/filter_types.go b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/filter_types.go new file mode 100644 index 00000000..7394954a --- /dev/null +++ b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/filter_types.go @@ -0,0 +1,220 @@ +/* +Copyright 2025 vLLM Semantic Router. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1alpha1 + +// PIIDetectionConfig defines the configuration for PII detection filter +type PIIDetectionConfig struct { + // AllowByDefault defines whether PII is allowed by default + // +optional + // +kubebuilder:default=false + AllowByDefault *bool `json:"allowByDefault,omitempty"` + + // PIITypesAllowed defines the list of PII types that are allowed + // +optional + // +kubebuilder:validation:MaxItems=50 + PIITypesAllowed []string `json:"pii_types_allowed,omitempty"` + + // Threshold defines the confidence threshold for PII detection (0.0-1.0) + // +optional + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Maximum=1 + // +kubebuilder:default=0.7 + Threshold *float64 `json:"threshold,omitempty"` + + // Action defines what to do when PII is detected + // +optional + // +kubebuilder:validation:Enum=block;mask;allow + // +kubebuilder:default=block + Action *string `json:"action,omitempty"` +} + +// PromptGuardConfig defines the configuration for prompt guard filter +type PromptGuardConfig struct { + // Threshold defines the confidence threshold for jailbreak detection (0.0-1.0) + // +optional + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Maximum=1 + // +kubebuilder:default=0.7 + Threshold *float64 `json:"threshold,omitempty"` + + // Action defines what to do when a jailbreak attempt is detected + // +optional + // +kubebuilder:validation:Enum=block;warn;allow + // +kubebuilder:default=block + Action *string `json:"action,omitempty"` + + // CustomRules defines additional custom security rules + // +optional + // +kubebuilder:validation:MaxItems=100 + CustomRules []SecurityRule `json:"customRules,omitempty"` +} + +// SecurityRule defines a custom security rule +type SecurityRule struct { + // Name defines the name of the security rule + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 + // +kubebuilder:validation:MaxLength=100 + Name string `json:"name"` + + // Pattern defines the regex pattern to match + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 + // +kubebuilder:validation:MaxLength=1000 + Pattern string `json:"pattern"` + + // Action defines the action to take when this rule matches + // +kubebuilder:validation:Required + // +kubebuilder:validation:Enum=block;warn;allow + Action string `json:"action"` + + // Description provides an optional description of this rule + // +optional + // +kubebuilder:validation:MaxLength=500 + Description string `json:"description,omitempty"` +} + +// SemanticCacheConfig defines the configuration for semantic cache filter +type SemanticCacheConfig struct { + // SimilarityThreshold defines the similarity threshold for cache hits (0.0-1.0) + // +optional + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Maximum=1 + // +kubebuilder:default=0.8 + SimilarityThreshold *float64 `json:"similarityThreshold,omitempty"` + + // MaxEntries defines the maximum number of cache entries + // +optional + // +kubebuilder:validation:Minimum=1 + // +kubebuilder:validation:Maximum=1000000 + // +kubebuilder:default=1000 + MaxEntries *int32 `json:"maxEntries,omitempty"` + + // TTLSeconds defines the time-to-live for cache entries in seconds + // +optional + // +kubebuilder:validation:Minimum=1 + // +kubebuilder:validation:Maximum=86400 + // +kubebuilder:default=3600 + TTLSeconds *int32 `json:"ttlSeconds,omitempty"` + + // Backend defines the cache backend type + // +optional + // +kubebuilder:validation:Enum=memory;redis;milvus + // +kubebuilder:default=memory + Backend *string `json:"backend,omitempty"` + + // BackendConfig defines backend-specific configuration + // +optional + BackendConfig map[string]string `json:"backendConfig,omitempty"` +} + +// ReasoningControlConfig defines the configuration for reasoning control filter +type ReasoningControlConfig struct { + // ReasonFamily defines the reasoning family to use + // +optional + // +kubebuilder:validation:Enum=gpt-oss;deepseek;qwen3;claude + ReasonFamily *string `json:"reasonFamily,omitempty"` + + // EnableReasoning defines whether reasoning mode is enabled + // +optional + // +kubebuilder:default=true + EnableReasoning *bool `json:"enableReasoning,omitempty"` + + // ReasoningEffort defines the reasoning effort level + // +optional + // +kubebuilder:validation:Enum=low;medium;high + // +kubebuilder:default=medium + ReasoningEffort *string `json:"reasoningEffort,omitempty"` + + // MaxReasoningSteps defines the maximum number of reasoning steps + // +optional + // +kubebuilder:validation:Minimum=1 + // +kubebuilder:validation:Maximum=100 + // +kubebuilder:default=10 + MaxReasoningSteps *int32 `json:"maxReasoningSteps,omitempty"` + + // ReasoningTimeout defines the timeout for reasoning in seconds + // +optional + // +kubebuilder:validation:Minimum=1 + // +kubebuilder:validation:Maximum=300 + // +kubebuilder:default=30 + ReasoningTimeout *int32 `json:"reasoningTimeout,omitempty"` +} + +// ToolSelectionConfig defines the configuration for automatic tool selection filter +type ToolSelectionConfig struct { + // TopK defines the number of top tools to select based on similarity + // +optional + // +kubebuilder:validation:Minimum=1 + // +kubebuilder:validation:Maximum=20 + // +kubebuilder:default=3 + TopK *int32 `json:"topK,omitempty"` + + // SimilarityThreshold defines the similarity threshold for tool selection (0.0-1.0) + // +optional + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Maximum=1 + // +kubebuilder:default=0.2 + SimilarityThreshold *float64 `json:"similarityThreshold,omitempty"` + + // ToolsDBPath defines the path to the tools database file + // +optional + // +kubebuilder:default="config/tools_db.json" + ToolsDBPath *string `json:"toolsDBPath,omitempty"` + + // FallbackToEmpty defines whether to return empty tools on failure + // +optional + // +kubebuilder:default=true + FallbackToEmpty *bool `json:"fallbackToEmpty,omitempty"` + + // Categories defines the tool categories to include in selection + // +optional + // +kubebuilder:validation:MaxItems=20 + Categories []string `json:"categories,omitempty"` + + // Tags defines the tool tags to include in selection + // +optional + // +kubebuilder:validation:MaxItems=50 + Tags []string `json:"tags,omitempty"` +} + +// FilterCondition defines a condition for applying filters +type FilterCondition struct { + // Type defines the condition type + // +kubebuilder:validation:Required + // +kubebuilder:validation:Enum=Always;Never;OnMatch;OnNoMatch + Type FilterConditionType `json:"type"` + + // Value defines the condition value (used with OnMatch/OnNoMatch) + // +optional + Value string `json:"value,omitempty"` +} + +// FilterConditionType defines the supported filter condition types +// +kubebuilder:validation:Enum=Always;Never;OnMatch;OnNoMatch +type FilterConditionType string + +const ( + // FilterConditionAlways means the filter is always applied + FilterConditionAlways FilterConditionType = "Always" + // FilterConditionNever means the filter is never applied + FilterConditionNever FilterConditionType = "Never" + // FilterConditionOnMatch means the filter is applied when a condition matches + FilterConditionOnMatch FilterConditionType = "OnMatch" + // FilterConditionOnNoMatch means the filter is applied when a condition doesn't match + FilterConditionOnNoMatch FilterConditionType = "OnNoMatch" +) diff --git a/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/register.go b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/register.go new file mode 100644 index 00000000..368bfb9a --- /dev/null +++ b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/register.go @@ -0,0 +1,47 @@ +/* +Copyright 2025 vLLM Semantic Router. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +// GroupVersion is group version used to register these objects +var GroupVersion = schema.GroupVersion{Group: "vllm.ai", Version: "v1alpha1"} + +// SchemeBuilder is used to add go types to the GroupVersionKind scheme +var ( + SchemeBuilder = runtime.NewSchemeBuilder(addKnownTypes) + AddToScheme = SchemeBuilder.AddToScheme +) + +// Resource takes an unqualified resource and returns a Group qualified GroupResource +func Resource(resource string) schema.GroupResource { + return GroupVersion.WithResource(resource).GroupResource() +} + +// addKnownTypes adds the set of types defined in this package to the supplied scheme. +func addKnownTypes(scheme *runtime.Scheme) error { + scheme.AddKnownTypes(GroupVersion, + &SemanticRoute{}, + &SemanticRouteList{}, + ) + metav1.AddToGroupVersion(scheme, GroupVersion) + return nil +} diff --git a/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/types.go b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/types.go new file mode 100644 index 00000000..ba4ba0b5 --- /dev/null +++ b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/types.go @@ -0,0 +1,179 @@ +/* +Copyright 2025 vLLM Semantic Router. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" +) + +// SemanticRoute defines a semantic routing rule for LLM requests +// +kubebuilder:object:root=true +// +kubebuilder:subresource:status +// +kubebuilder:resource:scope=Namespaced,shortName=sr +// +kubebuilder:printcolumn:name="Rules",type="integer",JSONPath=".spec.rules",description="Number of routing rules" +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" +type SemanticRoute struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec SemanticRouteSpec `json:"spec,omitempty"` + Status SemanticRouteStatus `json:"status,omitempty"` +} + +// SemanticRouteSpec defines the desired state of SemanticRoute +type SemanticRouteSpec struct { + // Rules defines the routing rules to be applied + // +kubebuilder:validation:MinItems=1 + // +kubebuilder:validation:MaxItems=100 + Rules []RouteRule `json:"rules"` +} + +// SemanticRouteStatus defines the observed state of SemanticRoute +type SemanticRouteStatus struct { + // Conditions represent the latest available observations of the SemanticRoute's current state + // +optional + Conditions []metav1.Condition `json:"conditions,omitempty"` + + // ObservedGeneration reflects the generation of the most recently observed SemanticRoute + // +optional + ObservedGeneration int64 `json:"observedGeneration,omitempty"` + + // ActiveRules indicates the number of currently active routing rules + // +optional + ActiveRules int32 `json:"activeRules,omitempty"` +} + +// RouteRule defines a single routing rule +type RouteRule struct { + // Intents defines the intent categories that this rule should match + // +kubebuilder:validation:MinItems=1 + // +kubebuilder:validation:MaxItems=50 + Intents []Intent `json:"intents"` + + // ModelRefs defines the target models for this routing rule + // +kubebuilder:validation:MinItems=1 + // +kubebuilder:validation:MaxItems=10 + ModelRefs []ModelRef `json:"modelRefs"` + + // Filters defines the optional filters to be applied to requests matching this rule + // +optional + // +kubebuilder:validation:MaxItems=20 + Filters []Filter `json:"filters,omitempty"` + + // DefaultModel defines the fallback model if no modelRefs are available + // +optional + DefaultModel *ModelRef `json:"defaultModel,omitempty"` +} + +// Intent defines an intent category for routing +type Intent struct { + // Category defines the intent category name (e.g., "math", "computer science", "creative") + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 + // +kubebuilder:validation:MaxLength=100 + // +kubebuilder:validation:Pattern=^[a-zA-Z0-9\s\-_]+$ + Category string `json:"category"` + + // Description provides an optional description of this intent category + // +optional + // +kubebuilder:validation:MaxLength=500 + Description string `json:"description,omitempty"` + + // Threshold defines the confidence threshold for this intent (0.0-1.0) + // +optional + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Maximum=1 + // +kubebuilder:default=0.7 + Threshold *float64 `json:"threshold,omitempty"` +} + +// ModelRef defines a reference to a model endpoint +type ModelRef struct { + // ModelName defines the name of the model + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 + // +kubebuilder:validation:MaxLength=100 + ModelName string `json:"modelName"` + + // Address defines the endpoint address + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 + // +kubebuilder:validation:MaxLength=255 + Address string `json:"address"` + + // Port defines the endpoint port + // +kubebuilder:validation:Required + // +kubebuilder:validation:Minimum=1 + // +kubebuilder:validation:Maximum=65535 + Port int32 `json:"port"` + + // Weight defines the traffic weight for this model (0-100) + // +optional + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Maximum=100 + // +kubebuilder:default=100 + Weight *int32 `json:"weight,omitempty"` + + // Priority defines the priority of this model reference (higher values = higher priority) + // +optional + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Maximum=1000 + Priority *int32 `json:"priority,omitempty"` +} + +// Filter defines a filter to be applied to requests +type Filter struct { + // Type defines the filter type + // +kubebuilder:validation:Required + // +kubebuilder:validation:Enum=PIIDetection;PromptGuard;SemanticCache;ReasoningControl + Type FilterType `json:"type"` + + // Config defines the filter-specific configuration + // +optional + Config *runtime.RawExtension `json:"config,omitempty"` + + // Enabled defines whether this filter is enabled + // +optional + // +kubebuilder:default=true + Enabled *bool `json:"enabled,omitempty"` +} + +// FilterType defines the supported filter types +// +kubebuilder:validation:Enum=PIIDetection;PromptGuard;SemanticCache;ReasoningControl;ToolSelection +type FilterType string + +const ( + // FilterTypePIIDetection enables PII detection and filtering + FilterTypePIIDetection FilterType = "PIIDetection" + // FilterTypePromptGuard enables prompt security and jailbreak detection + FilterTypePromptGuard FilterType = "PromptGuard" + // FilterTypeSemanticCache enables semantic caching for performance optimization + FilterTypeSemanticCache FilterType = "SemanticCache" + // FilterTypeReasoningControl enables reasoning mode control + FilterTypeReasoningControl FilterType = "ReasoningControl" + // FilterTypeToolSelection enables automatic tool selection based on semantic similarity + FilterTypeToolSelection FilterType = "ToolSelection" +) + +// SemanticRouteList contains a list of SemanticRoute +// +kubebuilder:object:root=true +type SemanticRouteList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []SemanticRoute `json:"items"` +} diff --git a/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/zz_generated.deepcopy.go b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/zz_generated.deepcopy.go new file mode 100644 index 00000000..d4aab7df --- /dev/null +++ b/src/semantic-router/pkg/apis/vllm.ai/v1alpha1/zz_generated.deepcopy.go @@ -0,0 +1,477 @@ +//go:build !ignore_autogenerated + +/* +Copyright 2025 vLLM Semantic Router. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Code generated by controller-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" +) + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Filter) DeepCopyInto(out *Filter) { + *out = *in + if in.Config != nil { + in, out := &in.Config, &out.Config + *out = new(runtime.RawExtension) + (*in).DeepCopyInto(*out) + } + if in.Enabled != nil { + in, out := &in.Enabled, &out.Enabled + *out = new(bool) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Filter. +func (in *Filter) DeepCopy() *Filter { + if in == nil { + return nil + } + out := new(Filter) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FilterCondition) DeepCopyInto(out *FilterCondition) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FilterCondition. +func (in *FilterCondition) DeepCopy() *FilterCondition { + if in == nil { + return nil + } + out := new(FilterCondition) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FilterConfigHelper) DeepCopyInto(out *FilterConfigHelper) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FilterConfigHelper. +func (in *FilterConfigHelper) DeepCopy() *FilterConfigHelper { + if in == nil { + return nil + } + out := new(FilterConfigHelper) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Intent) DeepCopyInto(out *Intent) { + *out = *in + if in.Threshold != nil { + in, out := &in.Threshold, &out.Threshold + *out = new(float64) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Intent. +func (in *Intent) DeepCopy() *Intent { + if in == nil { + return nil + } + out := new(Intent) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ModelRef) DeepCopyInto(out *ModelRef) { + *out = *in + if in.Weight != nil { + in, out := &in.Weight, &out.Weight + *out = new(int32) + **out = **in + } + if in.Priority != nil { + in, out := &in.Priority, &out.Priority + *out = new(int32) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelRef. +func (in *ModelRef) DeepCopy() *ModelRef { + if in == nil { + return nil + } + out := new(ModelRef) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *PIIDetectionConfig) DeepCopyInto(out *PIIDetectionConfig) { + *out = *in + if in.AllowByDefault != nil { + in, out := &in.AllowByDefault, &out.AllowByDefault + *out = new(bool) + **out = **in + } + if in.PIITypesAllowed != nil { + in, out := &in.PIITypesAllowed, &out.PIITypesAllowed + *out = make([]string, len(*in)) + copy(*out, *in) + } + if in.Threshold != nil { + in, out := &in.Threshold, &out.Threshold + *out = new(float64) + **out = **in + } + if in.Action != nil { + in, out := &in.Action, &out.Action + *out = new(string) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PIIDetectionConfig. +func (in *PIIDetectionConfig) DeepCopy() *PIIDetectionConfig { + if in == nil { + return nil + } + out := new(PIIDetectionConfig) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *PromptGuardConfig) DeepCopyInto(out *PromptGuardConfig) { + *out = *in + if in.Threshold != nil { + in, out := &in.Threshold, &out.Threshold + *out = new(float64) + **out = **in + } + if in.Action != nil { + in, out := &in.Action, &out.Action + *out = new(string) + **out = **in + } + if in.CustomRules != nil { + in, out := &in.CustomRules, &out.CustomRules + *out = make([]SecurityRule, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PromptGuardConfig. +func (in *PromptGuardConfig) DeepCopy() *PromptGuardConfig { + if in == nil { + return nil + } + out := new(PromptGuardConfig) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ReasoningControlConfig) DeepCopyInto(out *ReasoningControlConfig) { + *out = *in + if in.ReasonFamily != nil { + in, out := &in.ReasonFamily, &out.ReasonFamily + *out = new(string) + **out = **in + } + if in.EnableReasoning != nil { + in, out := &in.EnableReasoning, &out.EnableReasoning + *out = new(bool) + **out = **in + } + if in.ReasoningEffort != nil { + in, out := &in.ReasoningEffort, &out.ReasoningEffort + *out = new(string) + **out = **in + } + if in.MaxReasoningSteps != nil { + in, out := &in.MaxReasoningSteps, &out.MaxReasoningSteps + *out = new(int32) + **out = **in + } + if in.ReasoningTimeout != nil { + in, out := &in.ReasoningTimeout, &out.ReasoningTimeout + *out = new(int32) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ReasoningControlConfig. +func (in *ReasoningControlConfig) DeepCopy() *ReasoningControlConfig { + if in == nil { + return nil + } + out := new(ReasoningControlConfig) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *RouteRule) DeepCopyInto(out *RouteRule) { + *out = *in + if in.Intents != nil { + in, out := &in.Intents, &out.Intents + *out = make([]Intent, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.ModelRefs != nil { + in, out := &in.ModelRefs, &out.ModelRefs + *out = make([]ModelRef, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.Filters != nil { + in, out := &in.Filters, &out.Filters + *out = make([]Filter, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.DefaultModel != nil { + in, out := &in.DefaultModel, &out.DefaultModel + *out = new(ModelRef) + (*in).DeepCopyInto(*out) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RouteRule. +func (in *RouteRule) DeepCopy() *RouteRule { + if in == nil { + return nil + } + out := new(RouteRule) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SecurityRule) DeepCopyInto(out *SecurityRule) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SecurityRule. +func (in *SecurityRule) DeepCopy() *SecurityRule { + if in == nil { + return nil + } + out := new(SecurityRule) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SemanticCacheConfig) DeepCopyInto(out *SemanticCacheConfig) { + *out = *in + if in.SimilarityThreshold != nil { + in, out := &in.SimilarityThreshold, &out.SimilarityThreshold + *out = new(float64) + **out = **in + } + if in.MaxEntries != nil { + in, out := &in.MaxEntries, &out.MaxEntries + *out = new(int32) + **out = **in + } + if in.TTLSeconds != nil { + in, out := &in.TTLSeconds, &out.TTLSeconds + *out = new(int32) + **out = **in + } + if in.Backend != nil { + in, out := &in.Backend, &out.Backend + *out = new(string) + **out = **in + } + if in.BackendConfig != nil { + in, out := &in.BackendConfig, &out.BackendConfig + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SemanticCacheConfig. +func (in *SemanticCacheConfig) DeepCopy() *SemanticCacheConfig { + if in == nil { + return nil + } + out := new(SemanticCacheConfig) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SemanticRoute) DeepCopyInto(out *SemanticRoute) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + in.Spec.DeepCopyInto(&out.Spec) + in.Status.DeepCopyInto(&out.Status) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SemanticRoute. +func (in *SemanticRoute) DeepCopy() *SemanticRoute { + if in == nil { + return nil + } + out := new(SemanticRoute) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *SemanticRoute) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SemanticRouteList) DeepCopyInto(out *SemanticRouteList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]SemanticRoute, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SemanticRouteList. +func (in *SemanticRouteList) DeepCopy() *SemanticRouteList { + if in == nil { + return nil + } + out := new(SemanticRouteList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *SemanticRouteList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SemanticRouteSpec) DeepCopyInto(out *SemanticRouteSpec) { + *out = *in + if in.Rules != nil { + in, out := &in.Rules, &out.Rules + *out = make([]RouteRule, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SemanticRouteSpec. +func (in *SemanticRouteSpec) DeepCopy() *SemanticRouteSpec { + if in == nil { + return nil + } + out := new(SemanticRouteSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SemanticRouteStatus) DeepCopyInto(out *SemanticRouteStatus) { + *out = *in + if in.Conditions != nil { + in, out := &in.Conditions, &out.Conditions + *out = make([]v1.Condition, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SemanticRouteStatus. +func (in *SemanticRouteStatus) DeepCopy() *SemanticRouteStatus { + if in == nil { + return nil + } + out := new(SemanticRouteStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ToolSelectionConfig) DeepCopyInto(out *ToolSelectionConfig) { + *out = *in + if in.TopK != nil { + in, out := &in.TopK, &out.TopK + *out = new(int32) + **out = **in + } + if in.SimilarityThreshold != nil { + in, out := &in.SimilarityThreshold, &out.SimilarityThreshold + *out = new(float64) + **out = **in + } + if in.ToolsDBPath != nil { + in, out := &in.ToolsDBPath, &out.ToolsDBPath + *out = new(string) + **out = **in + } + if in.FallbackToEmpty != nil { + in, out := &in.FallbackToEmpty, &out.FallbackToEmpty + *out = new(bool) + **out = **in + } + if in.Categories != nil { + in, out := &in.Categories, &out.Categories + *out = make([]string, len(*in)) + copy(*out, *in) + } + if in.Tags != nil { + in, out := &in.Tags, &out.Tags + *out = make([]string, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ToolSelectionConfig. +func (in *ToolSelectionConfig) DeepCopy() *ToolSelectionConfig { + if in == nil { + return nil + } + out := new(ToolSelectionConfig) + in.DeepCopyInto(out) + return out +} diff --git a/src/semantic-router/pkg/cache/cache_factory.go b/src/semantic-router/pkg/cache/cache_factory.go index 396d1eb6..c72f7a8b 100644 --- a/src/semantic-router/pkg/cache/cache_factory.go +++ b/src/semantic-router/pkg/cache/cache_factory.go @@ -27,6 +27,7 @@ func NewCacheBackend(config CacheConfig) (CacheBackend, error) { SimilarityThreshold: config.SimilarityThreshold, MaxEntries: config.MaxEntries, TTLSeconds: config.TTLSeconds, + EvictionPolicy: config.EvictionPolicy, } return NewInMemoryCache(options), nil diff --git a/src/semantic-router/pkg/cache/cache_interface.go b/src/semantic-router/pkg/cache/cache_interface.go index ceb18c26..f35e165c 100644 --- a/src/semantic-router/pkg/cache/cache_interface.go +++ b/src/semantic-router/pkg/cache/cache_interface.go @@ -10,7 +10,9 @@ type CacheEntry struct { Model string Query string Embedding []float32 - Timestamp time.Time + Timestamp time.Time // Creation time (when the entry was added or completed with a response) + LastAccessAt time.Time // Last access time + HitCount int64 // Access count } // CacheBackend defines the interface for semantic cache implementations @@ -58,6 +60,20 @@ const ( MilvusCacheType CacheBackendType = "milvus" ) +// EvictionPolicyType defines the available eviction policies +type EvictionPolicyType string + +const ( + // FIFOEvictionPolicyType specifies the FIFO eviction policy + FIFOEvictionPolicyType EvictionPolicyType = "fifo" + + // LRUEvictionPolicyType specifies the LRU eviction policy + LRUEvictionPolicyType EvictionPolicyType = "lru" + + // LFUEvictionPolicyType specifies the LFU eviction policy + LFUEvictionPolicyType EvictionPolicyType = "lfu" +) + // CacheConfig contains configuration settings shared across all cache backends type CacheConfig struct { // BackendType specifies which cache implementation to use @@ -75,6 +91,9 @@ type CacheConfig struct { // TTLSeconds sets cache entry expiration time (0 disables expiration) TTLSeconds int `yaml:"ttl_seconds,omitempty"` + // EvictionPolicy defines the eviction policy for in-memory cache ("fifo", "lru", "lfu") + EvictionPolicy EvictionPolicyType `yaml:"eviction_policy,omitempty"` + // BackendConfigPath points to backend-specific configuration files BackendConfigPath string `yaml:"backend_config_path,omitempty"` } diff --git a/src/semantic-router/pkg/cache/eviction_policy.go b/src/semantic-router/pkg/cache/eviction_policy.go new file mode 100644 index 00000000..921f17f6 --- /dev/null +++ b/src/semantic-router/pkg/cache/eviction_policy.go @@ -0,0 +1,58 @@ +package cache + +type EvictionPolicy interface { + SelectVictim(entries []CacheEntry) int +} + +type FIFOPolicy struct{} + +func (p *FIFOPolicy) SelectVictim(entries []CacheEntry) int { + if len(entries) == 0 { + return -1 + } + + oldestIdx := 0 + for i := 1; i < len(entries); i++ { + if entries[i].Timestamp.Before(entries[oldestIdx].Timestamp) { + oldestIdx = i + } + } + return oldestIdx +} + +type LRUPolicy struct{} + +func (p *LRUPolicy) SelectVictim(entries []CacheEntry) int { + if len(entries) == 0 { + return -1 + } + + oldestIdx := 0 + for i := 1; i < len(entries); i++ { + if entries[i].LastAccessAt.Before(entries[oldestIdx].LastAccessAt) { + oldestIdx = i + } + } + return oldestIdx +} + +type LFUPolicy struct{} + +func (p *LFUPolicy) SelectVictim(entries []CacheEntry) int { + if len(entries) == 0 { + return -1 + } + + victimIdx := 0 + for i := 1; i < len(entries); i++ { + if entries[i].HitCount < entries[victimIdx].HitCount { + victimIdx = i + } else if entries[i].HitCount == entries[victimIdx].HitCount { + // Use LRU as tiebreaker to avoid random selection + if entries[i].LastAccessAt.Before(entries[victimIdx].LastAccessAt) { + victimIdx = i + } + } + } + return victimIdx +} diff --git a/src/semantic-router/pkg/cache/eviction_policy_test.go b/src/semantic-router/pkg/cache/eviction_policy_test.go new file mode 100644 index 00000000..91d5504a --- /dev/null +++ b/src/semantic-router/pkg/cache/eviction_policy_test.go @@ -0,0 +1,89 @@ +package cache + +import ( + "testing" + "time" +) + +func TestFIFOPolicy(t *testing.T) { + policy := &FIFOPolicy{} + + // Test empty entries + if victim := policy.SelectVictim([]CacheEntry{}); victim != -1 { + t.Errorf("Expected -1 for empty entries, got %d", victim) + } + + // Test with entries + now := time.Now() + entries := []CacheEntry{ + {Query: "query1", Timestamp: now.Add(-3 * time.Second)}, + {Query: "query2", Timestamp: now.Add(-1 * time.Second)}, + {Query: "query3", Timestamp: now.Add(-2 * time.Second)}, + } + + victim := policy.SelectVictim(entries) + if victim != 0 { + t.Errorf("Expected victim index 0 (oldest), got %d", victim) + } +} + +func TestLRUPolicy(t *testing.T) { + policy := &LRUPolicy{} + + // Test empty entries + if victim := policy.SelectVictim([]CacheEntry{}); victim != -1 { + t.Errorf("Expected -1 for empty entries, got %d", victim) + } + + // Test with entries + now := time.Now() + entries := []CacheEntry{ + {Query: "query1", LastAccessAt: now.Add(-3 * time.Second)}, + {Query: "query2", LastAccessAt: now.Add(-1 * time.Second)}, + {Query: "query3", LastAccessAt: now.Add(-2 * time.Second)}, + } + + victim := policy.SelectVictim(entries) + if victim != 0 { + t.Errorf("Expected victim index 0 (least recently used), got %d", victim) + } +} + +func TestLFUPolicy(t *testing.T) { + policy := &LFUPolicy{} + + // Test empty entries + if victim := policy.SelectVictim([]CacheEntry{}); victim != -1 { + t.Errorf("Expected -1 for empty entries, got %d", victim) + } + + // Test with entries + now := time.Now() + entries := []CacheEntry{ + {Query: "query1", HitCount: 5, LastAccessAt: now.Add(-2 * time.Second)}, + {Query: "query2", HitCount: 1, LastAccessAt: now.Add(-3 * time.Second)}, + {Query: "query3", HitCount: 3, LastAccessAt: now.Add(-1 * time.Second)}, + } + + victim := policy.SelectVictim(entries) + if victim != 1 { + t.Errorf("Expected victim index 1 (least frequently used), got %d", victim) + } +} + +func TestLFUPolicyTiebreaker(t *testing.T) { + policy := &LFUPolicy{} + + // Test tiebreaker: same frequency, choose least recently used + now := time.Now() + entries := []CacheEntry{ + {Query: "query1", HitCount: 2, LastAccessAt: now.Add(-1 * time.Second)}, + {Query: "query2", HitCount: 2, LastAccessAt: now.Add(-3 * time.Second)}, + {Query: "query3", HitCount: 5, LastAccessAt: now.Add(-2 * time.Second)}, + } + + victim := policy.SelectVictim(entries) + if victim != 1 { + t.Errorf("Expected victim index 1 (LRU tiebreaker), got %d", victim) + } +} diff --git a/src/semantic-router/pkg/cache/inmemory_cache.go b/src/semantic-router/pkg/cache/inmemory_cache.go index 61fb8773..07595a52 100644 --- a/src/semantic-router/pkg/cache/inmemory_cache.go +++ b/src/semantic-router/pkg/cache/inmemory_cache.go @@ -26,6 +26,7 @@ type InMemoryCache struct { hitCount int64 missCount int64 lastCleanupTime *time.Time + evictionPolicy EvictionPolicy } // InMemoryCacheOptions contains configuration parameters for the in-memory cache @@ -34,18 +35,31 @@ type InMemoryCacheOptions struct { MaxEntries int TTLSeconds int Enabled bool + EvictionPolicy EvictionPolicyType } // NewInMemoryCache initializes a new in-memory semantic cache instance func NewInMemoryCache(options InMemoryCacheOptions) *InMemoryCache { - observability.Debugf("Initializing in-memory cache: enabled=%t, maxEntries=%d, ttlSeconds=%d, threshold=%.3f", - options.Enabled, options.MaxEntries, options.TTLSeconds, options.SimilarityThreshold) + observability.Debugf("Initializing in-memory cache: enabled=%t, maxEntries=%d, ttlSeconds=%d, threshold=%.3f, eviction_policy=%s", + options.Enabled, options.MaxEntries, options.TTLSeconds, options.SimilarityThreshold, options.EvictionPolicy) + + var evictionPolicy EvictionPolicy + switch options.EvictionPolicy { + case LRUEvictionPolicyType: + evictionPolicy = &LRUPolicy{} + case LFUEvictionPolicyType: + evictionPolicy = &LFUPolicy{} + default: // FIFOEvictionPolicyType + evictionPolicy = &FIFOPolicy{} + } + return &InMemoryCache{ entries: []CacheEntry{}, similarityThreshold: options.SimilarityThreshold, maxEntries: options.MaxEntries, ttlSeconds: options.TTLSeconds, enabled: options.Enabled, + evictionPolicy: evictionPolicy, } } @@ -75,38 +89,28 @@ func (c *InMemoryCache) AddPendingRequest(requestID string, model string, query // Remove expired entries to maintain cache hygiene c.cleanupExpiredEntries() + // Check if eviction is needed before adding the new entry + if c.maxEntries > 0 && len(c.entries) >= c.maxEntries { + c.evictOne() + } + // Create cache entry for the pending request + now := time.Now() entry := CacheEntry{ - RequestID: requestID, - RequestBody: requestBody, - Model: model, - Query: query, - Embedding: embedding, - Timestamp: time.Now(), + RequestID: requestID, + RequestBody: requestBody, + Model: model, + Query: query, + Embedding: embedding, + Timestamp: now, + LastAccessAt: now, + HitCount: 0, } c.entries = append(c.entries, entry) observability.Debugf("InMemoryCache.AddPendingRequest: added pending entry (total entries: %d, embedding_dim: %d)", len(c.entries), len(embedding)) - // Apply entry limit to prevent unbounded memory growth - if c.maxEntries > 0 && len(c.entries) > c.maxEntries { - // Sort entries by timestamp to identify oldest - sort.Slice(c.entries, func(i, j int) bool { - return c.entries[i].Timestamp.Before(c.entries[j].Timestamp) - }) - // Keep only the most recent entries - removedCount := len(c.entries) - c.maxEntries - c.entries = c.entries[len(c.entries)-c.maxEntries:] - observability.Debugf("InMemoryCache: size limit exceeded, removed %d oldest entries (limit: %d)", - removedCount, c.maxEntries) - observability.LogEvent("cache_trimmed", map[string]interface{}{ - "backend": "memory", - "removed_count": removedCount, - "max_entries": c.maxEntries, - }) - } - // Record metrics metrics.RecordCacheOperation("memory", "add_pending", "success", time.Since(start).Seconds()) metrics.UpdateCacheEntries("memory", len(c.entries)) @@ -134,6 +138,7 @@ func (c *InMemoryCache) UpdateWithResponse(requestID string, responseBody []byte // Complete the cache entry with the response c.entries[i].ResponseBody = responseBody c.entries[i].Timestamp = time.Now() + c.entries[i].LastAccessAt = time.Now() observability.Debugf("InMemoryCache.UpdateWithResponse: updated entry with response (response_size: %d bytes)", len(responseBody)) @@ -163,6 +168,18 @@ func (c *InMemoryCache) AddEntry(requestID string, model string, query string, r return fmt.Errorf("failed to generate embedding: %w", err) } + c.mu.Lock() + defer c.mu.Unlock() + + // Clean up expired entries before adding new one + c.cleanupExpiredEntries() + + // Check if eviction is needed before adding the new entry + if c.maxEntries > 0 && len(c.entries) >= c.maxEntries { + c.evictOne() + } + + now := time.Now() entry := CacheEntry{ RequestID: requestID, RequestBody: requestBody, @@ -170,15 +187,11 @@ func (c *InMemoryCache) AddEntry(requestID string, model string, query string, r Model: model, Query: query, Embedding: embedding, - Timestamp: time.Now(), + Timestamp: now, + LastAccessAt: now, + HitCount: 0, } - c.mu.Lock() - defer c.mu.Unlock() - - // Clean up expired entries before adding new one - c.cleanupExpiredEntries() - c.entries = append(c.entries, entry) observability.Debugf("InMemoryCache.AddEntry: added complete entry (total entries: %d, request_size: %d, response_size: %d)", len(c.entries), len(requestBody), len(responseBody)) @@ -188,16 +201,6 @@ func (c *InMemoryCache) AddEntry(requestID string, model string, query string, r "model": model, }) - // Apply entry limit if configured - if c.maxEntries > 0 && len(c.entries) > c.maxEntries { - // Sort by timestamp to identify oldest entries - sort.Slice(c.entries, func(i, j int) bool { - return c.entries[i].Timestamp.Before(c.entries[j].Timestamp) - }) - // Keep only the most recent entries - c.entries = c.entries[len(c.entries)-c.maxEntries:] - } - // Record success metrics metrics.RecordCacheOperation("memory", "add_entry", "success", time.Since(start).Seconds()) metrics.UpdateCacheEntries("memory", len(c.entries)) @@ -228,19 +231,19 @@ func (c *InMemoryCache) FindSimilar(model string, query string) ([]byte, bool, e } c.mu.RLock() - defer c.mu.RUnlock() // Check for expired entries during search c.cleanupExpiredEntriesReadOnly() type SimilarityResult struct { + EntryIndex int Entry CacheEntry Similarity float32 } // Compare with completed entries for the same model results := make([]SimilarityResult, 0, len(c.entries)) - for _, entry := range c.entries { + for entryIndex, entry := range c.entries { if entry.ResponseBody == nil { continue // Skip incomplete entries } @@ -257,15 +260,19 @@ func (c *InMemoryCache) FindSimilar(model string, query string) ([]byte, bool, e } results = append(results, SimilarityResult{ + EntryIndex: entryIndex, Entry: entry, Similarity: dotProduct, }) } + // unlock the read lock since we need the write lock to update the access info + c.mu.RUnlock() + // Handle case where no suitable entries exist if len(results) == 0 { atomic.AddInt64(&c.missCount, 1) - observability.Debugf("InMemoryCache.FindSimilar: no entries found with responses (total entries: %d)", len(c.entries)) + observability.Debugf("InMemoryCache.FindSimilar: no entries found with responses") metrics.RecordCacheOperation("memory", "find_similar", "miss", time.Since(start).Seconds()) metrics.RecordCacheMiss() return nil, false, nil @@ -279,6 +286,11 @@ func (c *InMemoryCache) FindSimilar(model string, query string) ([]byte, bool, e // Check if the best match meets the similarity threshold if results[0].Similarity >= c.similarityThreshold { atomic.AddInt64(&c.hitCount, 1) + + c.mu.Lock() + c.updateAccessInfo(results[0].EntryIndex, results[0].Entry) + c.mu.Unlock() + observability.Debugf("InMemoryCache.FindSimilar: CACHE HIT - similarity=%.4f >= threshold=%.4f, response_size=%d bytes", results[0].Similarity, c.similarityThreshold, len(results[0].Entry.ResponseBody)) observability.LogEvent("cache_hit", map[string]interface{}{ @@ -356,8 +368,8 @@ func (c *InMemoryCache) cleanupExpiredEntries() { validEntries := make([]CacheEntry, 0, len(c.entries)) for _, entry := range c.entries { - // Retain entries that are still within their TTL - if now.Sub(entry.Timestamp).Seconds() < float64(c.ttlSeconds) { + // Retain entries that are still within their TTL based on last access + if now.Sub(entry.LastAccessAt).Seconds() < float64(c.ttlSeconds) { validEntries = append(validEntries, entry) } } @@ -389,7 +401,7 @@ func (c *InMemoryCache) cleanupExpiredEntriesReadOnly() { expiredCount := 0 for _, entry := range c.entries { - if now.Sub(entry.Timestamp).Seconds() >= float64(c.ttlSeconds) { + if now.Sub(entry.LastAccessAt).Seconds() >= float64(c.ttlSeconds) { expiredCount++ } } @@ -404,3 +416,45 @@ func (c *InMemoryCache) cleanupExpiredEntriesReadOnly() { }) } } + +// updateAccessInfo updates the access information for the given entry index +func (c *InMemoryCache) updateAccessInfo(entryIndex int, target CacheEntry) { + // fast path + if entryIndex < len(c.entries) && c.entries[entryIndex].RequestID == target.RequestID { + c.entries[entryIndex].LastAccessAt = time.Now() + c.entries[entryIndex].HitCount++ + return + } + + // fallback to linear search + for i := range c.entries { + if c.entries[i].RequestID == target.RequestID { + c.entries[i].LastAccessAt = time.Now() + c.entries[i].HitCount++ + break + } + } +} + +// evictOne removes one entry based on the configured eviction policy +func (c *InMemoryCache) evictOne() { + if len(c.entries) == 0 { + return + } + + victimIdx := c.evictionPolicy.SelectVictim(c.entries) + if victimIdx < 0 || victimIdx >= len(c.entries) { + return + } + + evictedRequestID := c.entries[victimIdx].RequestID + + c.entries[victimIdx] = c.entries[len(c.entries)-1] + c.entries = c.entries[:len(c.entries)-1] + + observability.LogEvent("cache_evicted", map[string]any{ + "backend": "memory", + "request_id": evictedRequestID, + "max_entries": c.maxEntries, + }) +} diff --git a/src/semantic-router/pkg/cache/inmemory_cache_integration_test.go b/src/semantic-router/pkg/cache/inmemory_cache_integration_test.go new file mode 100644 index 00000000..caffe6b9 --- /dev/null +++ b/src/semantic-router/pkg/cache/inmemory_cache_integration_test.go @@ -0,0 +1,173 @@ +package cache + +import ( + "fmt" + "testing" + + candle_binding "github.com/vllm-project/semantic-router/candle-binding" +) + +// TestInMemoryCacheIntegration tests the in-memory cache integration +func TestInMemoryCacheIntegration(t *testing.T) { + if err := candle_binding.InitModel("sentence-transformers/all-MiniLM-L6-v2", true); err != nil { + t.Skipf("Failed to initialize BERT model: %v", err) + } + + cache := NewInMemoryCache(InMemoryCacheOptions{ + Enabled: true, + MaxEntries: 2, + SimilarityThreshold: 0.9, + EvictionPolicy: "lfu", + TTLSeconds: 0, + }) + + t.Run("InMemoryCacheIntegration", func(t *testing.T) { + // Step 1: Add first entry + err := cache.AddEntry("req1", "test-model", "Hello world", + []byte("request1"), []byte("response1")) + if err != nil { + t.Fatalf("Failed to add first entry: %v", err) + } + + // Step 2: Add second entry (cache at capacity) + err = cache.AddEntry("req2", "test-model", "Good morning", + []byte("request2"), []byte("response2")) + if err != nil { + t.Fatalf("Failed to add second entry: %v", err) + } + + // Verify + if len(cache.entries) != 2 { + t.Errorf("Expected 2 entries, got %d", len(cache.entries)) + } + if cache.entries[1].RequestID != "req2" { + t.Errorf("Expected req2 to be the second entry, got %s", cache.entries[1].RequestID) + } + + // Step 3: Access first entry multiple times to increase its frequency + for range 2 { + responseBody, found, err := cache.FindSimilar("test-model", "Hello world") + if err != nil { + t.Logf("FindSimilar failed (expected due to high threshold): %v", err) + } + if !found { + t.Errorf("Expected to find similar entry for first query") + } + if string(responseBody) != "response1" { + t.Errorf("Expected response1, got %s", string(responseBody)) + } + } + + // Step 4: Access second entry once + responseBody, found, err := cache.FindSimilar("test-model", "Good morning") + if err != nil { + t.Logf("FindSimilar failed (expected due to high threshold): %v", err) + } + if !found { + t.Errorf("Expected to find similar entry for second query") + } + if string(responseBody) != "response2" { + t.Errorf("Expected response2, got %s", string(responseBody)) + } + + // Step 5: Add third entry - should trigger LFU eviction + err = cache.AddEntry("req3", "test-model", "Bye", + []byte("request3"), []byte("response3")) + if err != nil { + t.Fatalf("Failed to add third entry: %v", err) + } + + // Verify + if len(cache.entries) != 2 { + t.Errorf("Expected 2 entries after eviction, got %d", len(cache.entries)) + } + if cache.entries[0].RequestID != "req1" { + t.Errorf("Expected req1 to be the first entry, got %s", cache.entries[0].RequestID) + } + if cache.entries[1].RequestID != "req3" { + t.Errorf("Expected req3 to be the second entry, got %s", cache.entries[1].RequestID) + } + if cache.entries[0].HitCount != 2 { + t.Errorf("Expected HitCount to be 2, got %d", cache.entries[0].HitCount) + } + if cache.entries[1].HitCount != 0 { + t.Errorf("Expected HitCount to be 0, got %d", cache.entries[1].HitCount) + } + }) +} + +// TestInMemoryCachePendingRequestWorkflow tests the in-memory cache pending request workflow +func TestInMemoryCachePendingRequestWorkflow(t *testing.T) { + if err := candle_binding.InitModel("sentence-transformers/all-MiniLM-L6-v2", true); err != nil { + t.Skipf("Failed to initialize BERT model: %v", err) + } + + cache := NewInMemoryCache(InMemoryCacheOptions{ + Enabled: true, + MaxEntries: 2, + EvictionPolicy: "lru", + }) + + t.Run("PendingRequestFlow", func(t *testing.T) { + // Step 1: Add pending request + err := cache.AddPendingRequest("req1", "test-model", "test query", []byte("request")) + if err != nil { + t.Fatalf("Failed to add pending request: %v", err) + } + + // Verify + if len(cache.entries) != 1 { + t.Errorf("Expected 1 entry after AddPendingRequest, got %d", len(cache.entries)) + } + + if string(cache.entries[0].ResponseBody) != "" { + t.Error("Expected ResponseBody to be empty for pending request") + } + + // Step 2: Update with response + err = cache.UpdateWithResponse("req1", []byte("response1")) + if err != nil { + t.Fatalf("Failed to update with response: %v", err) + } + + // Step 3: Try to find similar + response, found, err := cache.FindSimilar("test-model", "test query") + if err != nil { + t.Logf("FindSimilar error (may be due to embedding): %v", err) + } + + if !found { + t.Errorf("Expected to find completed entry after UpdateWithResponse") + } + if string(response) != "response1" { + t.Errorf("Expected response1, got %s", string(response)) + } + }) +} + +// TestEvictionPolicySelection tests that the correct policy is selected +func TestEvictionPolicySelection(t *testing.T) { + testCases := []struct { + policy string + expected string + }{ + {"lru", "*cache.LRUPolicy"}, + {"lfu", "*cache.LFUPolicy"}, + {"fifo", "*cache.FIFOPolicy"}, + {"", "*cache.FIFOPolicy"}, // Default + {"invalid", "*cache.FIFOPolicy"}, // Default fallback + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("Policy_%s", tc.policy), func(t *testing.T) { + cache := NewInMemoryCache(InMemoryCacheOptions{ + EvictionPolicy: EvictionPolicyType(tc.policy), + }) + + policyType := fmt.Sprintf("%T", cache.evictionPolicy) + if policyType != tc.expected { + t.Errorf("Expected policy type %s, got %s", tc.expected, policyType) + } + }) + } +} diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go index 43d929e5..9a3bfb70 100644 --- a/src/semantic-router/pkg/config/config.go +++ b/src/semantic-router/pkg/config/config.go @@ -66,6 +66,9 @@ type RouterConfig struct { // Time-to-live for cache entries in seconds (0 means no expiration) TTLSeconds int `yaml:"ttl_seconds,omitempty"` + // Eviction policy for in-memory cache ("fifo", "lru", "lfu") + EvictionPolicy string `yaml:"eviction_policy,omitempty"` + // Path to backend-specific configuration file BackendConfigPath string `yaml:"backend_config_path,omitempty"` } `yaml:"semantic_cache"` diff --git a/src/semantic-router/pkg/consts/consts.go b/src/semantic-router/pkg/consts/consts.go new file mode 100644 index 00000000..cf1486bf --- /dev/null +++ b/src/semantic-router/pkg/consts/consts.go @@ -0,0 +1,5 @@ +package consts + +// UnknownLabel is a canonical fallback label value used across the codebase +// when a more specific value (e.g., model, category, reason) is not available. +const UnknownLabel = "unknown" diff --git a/src/semantic-router/pkg/extproc/metrics_integration_test.go b/src/semantic-router/pkg/extproc/metrics_integration_test.go index 397318a8..addf21c2 100644 --- a/src/semantic-router/pkg/extproc/metrics_integration_test.go +++ b/src/semantic-router/pkg/extproc/metrics_integration_test.go @@ -81,7 +81,10 @@ var _ = Describe("Metrics recording", func() { StartTime: time.Now().Add(-1 * time.Second), } - before := getHistogramSampleCount("llm_model_tpot_seconds", ctx.RequestModel) + beforeTPOT := getHistogramSampleCount("llm_model_tpot_seconds", ctx.RequestModel) + + beforePrompt := getHistogramSampleCount("llm_prompt_tokens_per_request", ctx.RequestModel) + beforeCompletion := getHistogramSampleCount("llm_completion_tokens_per_request", ctx.RequestModel) openAIResponse := map[string]interface{}{ "id": "chatcmpl-xyz", @@ -111,7 +114,13 @@ var _ = Describe("Metrics recording", func() { Expect(err).NotTo(HaveOccurred()) Expect(response.GetResponseBody()).NotTo(BeNil()) - after := getHistogramSampleCount("llm_model_tpot_seconds", ctx.RequestModel) - Expect(after).To(BeNumerically(">", before)) + afterTPOT := getHistogramSampleCount("llm_model_tpot_seconds", ctx.RequestModel) + Expect(afterTPOT).To(BeNumerically(">", beforeTPOT)) + + // New per-request token histograms should also be recorded + afterPrompt := getHistogramSampleCount("llm_prompt_tokens_per_request", ctx.RequestModel) + afterCompletion := getHistogramSampleCount("llm_completion_tokens_per_request", ctx.RequestModel) + Expect(afterPrompt).To(BeNumerically(">", beforePrompt)) + Expect(afterCompletion).To(BeNumerically(">", beforeCompletion)) }) }) diff --git a/src/semantic-router/pkg/extproc/reason_mode_selector.go b/src/semantic-router/pkg/extproc/reason_mode_selector.go index 7380f54b..58f880a3 100644 --- a/src/semantic-router/pkg/extproc/reason_mode_selector.go +++ b/src/semantic-router/pkg/extproc/reason_mode_selector.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/consts" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/metrics" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/entropy" @@ -132,7 +133,7 @@ func (r *OpenAIRouter) setReasoningModeToRequestBody(requestBody []byte, enabled } // Determine model for kwargs and logging - model := "unknown" + model := consts.UnknownLabel if modelValue, ok := requestMap["model"]; ok { if modelStr, ok := modelValue.(string); ok { model = modelStr @@ -191,7 +192,7 @@ func (r *OpenAIRouter) setReasoningModeToRequestBody(requestBody []byte, enabled // Record metrics for template usage and effort when enabled if enabled { familyConfig := r.getModelReasoningFamily(model) - modelFamily := "unknown" + modelFamily := consts.UnknownLabel templateParam := "reasoning_effort" // default fallback if familyConfig != nil { diff --git a/src/semantic-router/pkg/extproc/request_handler.go b/src/semantic-router/pkg/extproc/request_handler.go index a1ced99d..867333de 100644 --- a/src/semantic-router/pkg/extproc/request_handler.go +++ b/src/semantic-router/pkg/extproc/request_handler.go @@ -263,8 +263,7 @@ func (r *OpenAIRouter) handleCaching(ctx *RequestContext) (*ext_proc.ProcessingR if err != nil { observability.Errorf("Error searching cache: %v", err) } else if found { - // Record and log cache hit - metrics.RecordCacheHit() + // Log cache hit observability.LogEvent("cache_hit", map[string]interface{}{ "request_id": ctx.RequestID, "model": requestModel, diff --git a/src/semantic-router/pkg/extproc/router.go b/src/semantic-router/pkg/extproc/router.go index e333cd48..90eed7c5 100644 --- a/src/semantic-router/pkg/extproc/router.go +++ b/src/semantic-router/pkg/extproc/router.go @@ -84,6 +84,7 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) { SimilarityThreshold: cfg.GetCacheSimilarityThreshold(), MaxEntries: cfg.SemanticCache.MaxEntries, TTLSeconds: cfg.SemanticCache.TTLSeconds, + EvictionPolicy: cache.EvictionPolicyType(cfg.SemanticCache.EvictionPolicy), BackendConfigPath: cfg.SemanticCache.BackendConfigPath, } diff --git a/src/semantic-router/pkg/extproc/test_utils_test.go b/src/semantic-router/pkg/extproc/test_utils_test.go index 924acfe3..d8d76d55 100644 --- a/src/semantic-router/pkg/extproc/test_utils_test.go +++ b/src/semantic-router/pkg/extproc/test_utils_test.go @@ -135,12 +135,14 @@ func CreateTestConfig() *config.RouterConfig { SimilarityThreshold *float32 `yaml:"similarity_threshold,omitempty"` MaxEntries int `yaml:"max_entries,omitempty"` TTLSeconds int `yaml:"ttl_seconds,omitempty"` + EvictionPolicy string `yaml:"eviction_policy,omitempty"` BackendConfigPath string `yaml:"backend_config_path,omitempty"` }{ BackendType: "memory", Enabled: false, // Disable for most tests SimilarityThreshold: &[]float32{0.9}[0], MaxEntries: 100, + EvictionPolicy: "lru", TTLSeconds: 3600, }, PromptGuard: config.PromptGuardConfig{ @@ -214,6 +216,7 @@ func CreateTestRouter(cfg *config.RouterConfig) (*extproc.OpenAIRouter, error) { SimilarityThreshold: cfg.GetCacheSimilarityThreshold(), MaxEntries: cfg.SemanticCache.MaxEntries, TTLSeconds: cfg.SemanticCache.TTLSeconds, + EvictionPolicy: cache.EvictionPolicyType(cfg.SemanticCache.EvictionPolicy), } semanticCache, err := cache.NewCacheBackend(cacheConfig) if err != nil { diff --git a/src/semantic-router/pkg/metrics/metrics.go b/src/semantic-router/pkg/metrics/metrics.go index f5e2db21..50fdd637 100644 --- a/src/semantic-router/pkg/metrics/metrics.go +++ b/src/semantic-router/pkg/metrics/metrics.go @@ -9,6 +9,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/consts" ) // Minimal fallback bucket configurations - used only when configuration is completely missing @@ -147,6 +148,26 @@ var ( []string{"model"}, ) + // PromptTokensPerRequest tracks the distribution of prompt tokens per request by model + PromptTokensPerRequest = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "llm_prompt_tokens_per_request", + Help: "Distribution of prompt tokens per request by model", + Buckets: []float64{0, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384}, + }, + []string{"model"}, + ) + + // CompletionTokensPerRequest tracks the distribution of completion tokens per request by model + CompletionTokensPerRequest = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "llm_completion_tokens_per_request", + Help: "Distribution of completion tokens per request by model", + Buckets: []float64{0, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384}, + }, + []string{"model"}, + ) + // ModelRoutingModifications tracks when a model is changed from one to another ModelRoutingModifications = promauto.NewCounterVec( prometheus.CounterOpts{ @@ -258,11 +279,12 @@ var ( []string{"backend"}, ) - // CategoryClassifications tracks the number of times each category is classified - CategoryClassifications = promauto.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "llm_category_classifications_total", - Help: "The total number of times each category is classified", + // CategoryClassificationsCount is an alias with a name preferred by the issue request. + // It mirrors CategoryClassifications and is incremented alongside it for compatibility. + CategoryClassificationsCount = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "llm_category_classifications_count", + Help: "The total number of times each category is classified (alias metric)", }, []string{"category"}, ) @@ -363,7 +385,7 @@ var ( // RecordModelRequest increments the counter for requests to a specific model func RecordModelRequest(model string) { if model == "" { - model = "unknown" + model = consts.UnknownLabel } ModelRequests.WithLabelValues(model).Inc() } @@ -371,10 +393,10 @@ func RecordModelRequest(model string) { // RecordRequestError increments request error counters labeled by model and normalized reason func RecordRequestError(model, reason string) { if model == "" { - model = "unknown" + model = consts.UnknownLabel } if reason == "" { - reason = "unknown" + reason = consts.UnknownLabel } // Normalize a few common variants to canonical reasons switch reason { @@ -414,10 +436,10 @@ func RecordModelCost(model string, currency string, amount float64) { // RecordRoutingReasonCode increments the counter for a routing decision reason code and model func RecordRoutingReasonCode(reasonCode, model string) { if reasonCode == "" { - reasonCode = "unknown" + reasonCode = consts.UnknownLabel } if model == "" { - model = "unknown" + model = consts.UnknownLabel } RoutingReasonCodes.WithLabelValues(reasonCode, model).Inc() } @@ -429,6 +451,13 @@ func RecordModelTokensDetailed(model string, promptTokens, completionTokens floa ModelTokens.WithLabelValues(model).Add(totalTokens) ModelPromptTokens.WithLabelValues(model).Add(promptTokens) ModelCompletionTokens.WithLabelValues(model).Add(completionTokens) + + // Also record per-request histograms for visibility into distribution + if model == "" { + model = consts.UnknownLabel + } + PromptTokensPerRequest.WithLabelValues(model).Observe(promptTokens) + CompletionTokensPerRequest.WithLabelValues(model).Observe(completionTokens) } // RecordModelCompletionLatency records the latency of a model completion @@ -442,7 +471,7 @@ func RecordModelTTFT(model string, seconds float64) { return } if model == "" { - model = "unknown" + model = consts.UnknownLabel } ModelTTFT.WithLabelValues(model).Observe(seconds) } @@ -453,7 +482,7 @@ func RecordModelTPOT(model string, secondsPerToken float64) { return } if model == "" { - model = "unknown" + model = consts.UnknownLabel } ModelTPOT.WithLabelValues(model).Observe(secondsPerToken) } @@ -484,9 +513,12 @@ func UpdateCacheEntries(backend string, count int) { CacheEntriesTotal.WithLabelValues(backend).Set(float64(count)) } -// RecordCategoryClassification increments the gauge for a specific category classification +// RecordCategoryClassification increments the counter for a specific category classification func RecordCategoryClassification(category string) { - CategoryClassifications.WithLabelValues(category).Inc() + if category == "" { + category = consts.UnknownLabel + } + CategoryClassificationsCount.WithLabelValues(category).Inc() } // RecordPIIViolation records a PII policy violation for a specific model and PII data type @@ -544,7 +576,7 @@ func GetBatchSizeRange(size int) string { } // Fallback for unexpected cases - return "unknown" + return consts.UnknownLabel } // GetBatchSizeRangeFromBuckets generates range labels based on size buckets @@ -725,7 +757,7 @@ func RecordReasoningDecision(category, model string, enabled bool, effort string // RecordReasoningTemplateUsage records usage of a model-family-specific template parameter func RecordReasoningTemplateUsage(family, param string) { if family == "" { - family = "unknown" + family = consts.UnknownLabel } if param == "" { param = "none" @@ -736,7 +768,7 @@ func RecordReasoningTemplateUsage(family, param string) { // RecordReasoningEffortUsage records the effort usage by model family func RecordReasoningEffortUsage(family, effort string) { if family == "" { - family = "unknown" + family = consts.UnknownLabel } if effort == "" { effort = "unspecified" @@ -747,7 +779,7 @@ func RecordReasoningEffortUsage(family, effort string) { // RecordEntropyClassificationDecision records an entropy-based classification decision func RecordEntropyClassificationDecision(uncertaintyLevel string, reasoningEnabled bool, decisionReason string, topCategory string) { if uncertaintyLevel == "" { - uncertaintyLevel = "unknown" + uncertaintyLevel = consts.UnknownLabel } if decisionReason == "" { decisionReason = "unspecified" @@ -767,7 +799,7 @@ func RecordEntropyClassificationDecision(uncertaintyLevel string, reasoningEnabl // RecordEntropyValue records the entropy value for a classification func RecordEntropyValue(category string, classificationType string, entropyValue float64) { if category == "" { - category = "unknown" + category = consts.UnknownLabel } if classificationType == "" { classificationType = "standard" @@ -779,7 +811,7 @@ func RecordEntropyValue(category string, classificationType string, entropyValue // RecordClassificationConfidence records the confidence score from classification func RecordClassificationConfidence(category string, classificationMethod string, confidence float64) { if category == "" { - category = "unknown" + category = consts.UnknownLabel } if classificationMethod == "" { classificationMethod = "traditional" @@ -796,10 +828,10 @@ func RecordEntropyClassificationLatency(seconds float64) { // RecordProbabilityDistributionQuality records quality checks for probability distributions func RecordProbabilityDistributionQuality(qualityCheck string, status string) { if qualityCheck == "" { - qualityCheck = "unknown" + qualityCheck = consts.UnknownLabel } if status == "" { - status = "unknown" + status = consts.UnknownLabel } ProbabilityDistributionQuality.WithLabelValues(qualityCheck, status).Inc() @@ -808,7 +840,7 @@ func RecordProbabilityDistributionQuality(qualityCheck string, status string) { // RecordEntropyFallback records when entropy-based routing falls back to traditional methods func RecordEntropyFallback(fallbackReason string, fallbackStrategy string) { if fallbackReason == "" { - fallbackReason = "unknown" + fallbackReason = consts.UnknownLabel } if fallbackStrategy == "" { fallbackStrategy = "unspecified" diff --git a/tools/make/golang.mk b/tools/make/golang.mk index 134c67e3..06441e20 100644 --- a/tools/make/golang.mk +++ b/tools/make/golang.mk @@ -26,3 +26,19 @@ check-go-mod-tidy: fi @echo "✅ src/semantic-router go mod tidy check passed" @echo "✅ All go mod tidy checks passed" + +# Controller-gen targets +install-controller-gen: + @echo "Installing controller-gen..." + @cd src/semantic-router && go install sigs.k8s.io/controller-tools/cmd/controller-gen@latest + +generate-crd: install-controller-gen + @echo "Generating CRD manifests..." + @cd src/semantic-router && controller-gen crd:crdVersions=v1,allowDangerousTypes=true paths=./pkg/apis/vllm.ai/v1alpha1 output:crd:artifacts:config=../../deploy/kubernetes/crds + +generate-deepcopy: install-controller-gen + @echo "Generating deepcopy methods..." + @cd src/semantic-router && controller-gen object:headerFile=./hack/boilerplate.go.txt paths=./pkg/apis/vllm.ai/v1alpha1 + +generate-api: generate-deepcopy generate-crd + @echo "Generated all API artifacts" \ No newline at end of file diff --git a/tools/make/linter.mk b/tools/make/linter.mk index 9c8114b5..940b8b35 100644 --- a/tools/make/linter.mk +++ b/tools/make/linter.mk @@ -2,11 +2,11 @@ # = Everything For Project Linter, markdown, yaml, code spell etc. = # =============================== linter.mk ========================== -docs-lint: +docs-lint: docs-install @$(LOG_TARGET) cd website && npm run lint -docs-lint-fix: +docs-lint-fix: docs-install @$(LOG_TARGET) cd website && npm run lint:fix diff --git a/tools/make/pre-commit.mk b/tools/make/pre-commit.mk new file mode 100644 index 00000000..97735adf --- /dev/null +++ b/tools/make/pre-commit.mk @@ -0,0 +1,24 @@ +precommit-install: + pip install pre-commit + +precommit-check: + @FILES=$$(find . -type f \( -name "*.go" -o -name "*.rs" -o -name "*.py" -o -name "*.js" -o -name "*.md" -o -name "*.yaml" -o -name "*.yml" \) \ + ! -path "./target/*" \ + ! -path "./candle-binding/target/*" \ + ! -path "./.git/*" \ + ! -path "./node_modules/*" \ + ! -path "./vendor/*" \ + ! -path "./__pycache__/*" \ + ! -path "./site/*" \ + ! -name "*.pb.go" \ + | tr '\n' ' '); \ + if [ -n "$$FILES" ]; then \ + echo "Running pre-commit on files: $$FILES"; \ + pre-commit run --files $$FILES; \ + else \ + echo "No Go, Rust, JavaScript, Markdown, Yaml, or Python files found to check"; \ + fi + +precommit-local: + docker pull ghcr.io/vllm/semantic-router/precommit:latest + docker run --rm -v $$(pwd):/data ghcr.io/vllm-project/semantic-router/precommit:latest pre-commit run --all-files diff --git a/website/docs/getting-started/configuration.md b/website/docs/getting-started/configuration.md index ebeb8633..de224553 100644 --- a/website/docs/getting-started/configuration.md +++ b/website/docs/getting-started/configuration.md @@ -17,10 +17,12 @@ bert_model: # Semantic caching semantic_cache: + backend_type: "memory" # Options: "memory" or "milvus" enabled: false similarity_threshold: 0.8 max_entries: 1000 ttl_seconds: 3600 + eviction_policy: "fifo" # Options: "fifo", "lru", "lfu" # Tool auto-selection tools: @@ -365,10 +367,12 @@ Configure additional features: ```yaml # Semantic Caching semantic_cache: - enabled: true # Enable semantic caching - similarity_threshold: 0.8 # Cache hit threshold + enabled: true # Enable semantic caching + backend_type: "memory" # Options: "memory" or "milvus" + similarity_threshold: 0.8 # Cache hit threshold max_entries: 1000 # Maximum cache entries ttl_seconds: 3600 # Cache expiration time + eviction_policy: "fifo" # Options: "fifo", "lru", "lfu" # Tool Auto-Selection tools: @@ -539,9 +543,11 @@ model_config: # Enable caching semantic_cache: enabled: true + backend_type: "memory" similarity_threshold: 0.85 # Higher = more cache hits max_entries: 5000 - ttl_seconds: 7200 # 2 hour cache + ttl_seconds: 7200 # 2 hour cache + eviction_policy: "fifo" # Options: "fifo", "lru", "lfu" # Enable tool selection tools: @@ -656,9 +662,11 @@ For high-traffic scenarios: # Enable caching semantic_cache: enabled: true + backend_type: "memory" similarity_threshold: 0.85 # Higher = more cache hits max_entries: 10000 ttl_seconds: 3600 + eviction_policy: "lru" # Optimize classification classifier: diff --git a/website/docs/getting-started/installation.md b/website/docs/getting-started/installation.md index 2af4cced..0970c832 100644 --- a/website/docs/getting-started/installation.md +++ b/website/docs/getting-started/installation.md @@ -108,9 +108,6 @@ vllm_endpoints: model_config: "your-model-name": - param_count: 671000000000 # 671B parameters for DeepSeek-V3.1 - batch_size: 512.0 # vLLM default batch size - context_size: 65536.0 # DeepSeek-V3.1 context length pii_policy: allow_by_default: false # Deny all PII by default pii_types_allowed: ["EMAIL_ADDRESS", "PERSON", "GPE", "PHONE_NUMBER"] # Only allow these specific PII types diff --git a/website/docs/training/model-performance-eval.md b/website/docs/training/model-performance-eval.md new file mode 100644 index 00000000..87cd0a05 --- /dev/null +++ b/website/docs/training/model-performance-eval.md @@ -0,0 +1,323 @@ +# Model Performance Evaluation +## Why evaluate? +Evaluation makes routing data-driven. By measuring per-category accuracy on MMLU-Pro (and doing a quick sanity check with ARC), you can: + +- Select the right model for each category and rank them into categories.model_scores +- Pick a sensible default_model based on overall performance +- Decide when CoT prompting is worth the latency/cost tradeoff +- Catch regressions when models, prompts, or parameters change +- Keep changes reproducible and auditable for CI and releases + +In short, evaluation converts anecdotes into measurable signals that improve quality, cost efficiency, and reliability of the router. + +--- + +This guide documents the automated workflow to evaluate models (MMLU-Pro and ARC Challenge) via a vLLM-compatible OpenAI endpoint, generate a performance-based routing config, and update `categories.model_scores` in config. + +see code in [/src/training/model_eval](https://github.com/vllm-project/semantic-router/tree/main/src/training/model_eval) + +### What you'll run end-to-end +#### 1) Evaluate models + +- per-category accuracies +- ARC Challenge: overall accuracy + +#### 2) Visualize results + +- bar/heatmap plot of per-category accuracies + +![Bar](/img/bar.png) +![Heatmap](/img/heatmap.png) + +#### 3) Generate an updated config.yaml + +- Rank models per category into categories.model_scores +- Set default_model to the best average performer +- Keep or apply category-level reasioning settings + +## 1.Prerequisites + +- A running vLLM-compatible OpenAI endpoint serving your models + - Endpoint URL like http://localhost:8000/v1 + - Optional API key if your endpoint requires one + + ```bash + # Terminal 1 + vllm serve microsoft/phi-4 --port 11434 --served_model_name phi4 + + # Terminal 2 + vllm serve Qwen/Qwen3-0.6B --port 11435 --served_model_name qwen3-0.6B + ``` + +- Python packages for evaluation scripts: + - From the repo root: matplotlib in [requirements.txt](https://github.com/vllm-project/semantic-router/blob/main/requirements.txt) + - From `/src/training/model_eval`: [requirements.txt](https://github.com/vllm-project/semantic-router/blob/main/src/training/model_eval/requirements.txt) + + ```bash + # We will work at this dir in this guide + cd /src/training/model_eval + pip install -r requirements.txt + ``` + +**Optional tip:** + +- Ensure your `config/config.yaml` includes your deployed model names under `vllm_endpoints[].models` and any pricing/policy under `model_config` if you plan to use the generated config directly. + +## 2.Evaluate on MMLU-Pro +see script in [mmul_pro_vllm_eval.py](https://github.com/vllm-project/semantic-router/blob/main/src/training/model_eval/mmlu_pro_vllm_eval.py) + +### Example usage patterns + +```bash +# Evaluate a few models, few samples per category, direct prompting +python mmlu_pro_vllm_eval.py \ + --endpoint http://localhost:11434/v1 \ + --models phi4 \ + --samples-per-category 10 + +python mmlu_pro_vllm_eval.py \ + --endpoint http://localhost:11435/v1 \ + --models qwen3-0.6B \ + --samples-per-category 10 + +# Evaluate with CoT (results saved under *_cot) +python mmlu_pro_vllm_eval.py \ + --endpoint http://localhost:11435/v1 \ + --models qwen3-0.6B \ + --samples-per-category 10 + --use-cot + +# If you have set up Semantic Router properly, you can run in one go +python mmlu_pro_vllm_eval.py \ + --endpoint http://localhost:8801/v1 \ + --models qwen3-0.6B, phi4 \ + --samples-per-category + # --use-cot # Uncomment this line if use CoT +``` + +### Key flags + +- **--endpoint**: vLLM OpenAI URL (default http://localhost:8000/v1) +- **--models**: space-separated list OR a single comma-separated string; if omitted, the script queries /models from the endpoint +- **--categories**: restrict evaluation to specific categories; if omitted, uses all categories in the dataset +- **--samples-per-category**: limit questions per category (useful for quick runs) +- **--use-cot**: enables Chain-of-Thought prompting variant; results are saved in a separate subfolder suffix (_cot vs _direct) +- **--concurrent-requests**: concurrency for throughput +- **--output-dir**: where results are saved (default results) +- **--max-tokens**, **--temperature**, **--seed**: generation and reproducibility knobs + +### What it outputs per model + +- **results/Model_Name_(direct|cot)/** + - **detailed_results.csv**: one row per question with is_correct and category + - **analysis.json**: overall_accuracy, category_accuracy map, avg_response_time, counts + - **summary.json**: condensed metrics +- **mmlu_pro_vllm_eval.txt**: prompts and answers log (debug/inspection) + +**Note** + +- **Model naming**: slashes are replaced with underscores for folder names; e.g., gemma3:27b -> gemma3:27b_direct directory. +- Category accuracy is computed on successful queries only; failed requests are excluded. + +## 3.Evaluate on ARC Challenge (optional, overall sanity check) +see script in [arc_challenge_vllm_eval.py](https://github.com/vllm-project/semantic-router/blob/main/src/training/model_eval/arc_challenge_vllm_eval.py) + +### Example usage patterns + +``` bash +python arc_challenge_vllm_eval.py \ + --endpoint http://localhost:8801/v1\ + --models qwen3-0.6B,phi4 + --output-dir arc_results +``` + +### Key flags + +- **--samples**: total questions to sample (default 20); ARC is not categorized in our script +- Other flags mirror the **MMLU-Pro** script + +### What it outputs per model + +- **results/Model_Name_(direct|cot)/** + - **detailed_results.csv**: one row per question with is_correct and category + - **analysis.json**: overall_accuracy, avg_response_time + - **summary.json**: condensed metrics +- **arc_challenge_vllm_eval.txt**: prompts and answers log (debug/inspection) + +**Note** + +- ARC results do not feed `categories[].model_scores` directly, but they can help spot regressions. + +## 4.Visualize per-category performance +see script in [plot_category_accuracies.py](https://github.com/vllm-project/semantic-router/blob/main/src/training/model_eval/plot_category_accuracies.py) + +### Example usage patterns: + +```bash +# Use results/ to generate bar plot +python plot_category_accuracies.py \ + --results-dir results \ + --plot-type bar \ + --output-file results/bar.png + +# Use results/ to generate heatmap plot +python plot_category_accuracies.py \ + --results-dir results \ + --plot-type heatmap \ + --output-file results/heatmap.png + +# Use sample-data to generate example plot +python src/training/model_eval/plot_category_accuracies.py \ + --sample-data \ + --plot-type heatmap \ + --output-file results/category_accuracies.png +``` + +### Key flags + +- **--results-dir**: where analysis.json files are +- **--plot-type**: bar or heatmap +- **--output-file**: output image path (default model_eval/category_accuracies.png) +- **--sample-data**: if no results exist, generates fake data to preview the plot + +### What it does + +- Finds all `results/**/analysis.json`, aggregates analysis["category_accuracy"] per model +- Adds an Overall column representing the average across categories +- Produces a figure to quickly compare model/category performance + +**Note** + +- It merges `direct` and `cot` as distinct model variants by appending `:direct` or `:cot` to the label; the legend hides `:direct` for brevity. + +## 5.Generate performance-based routing config +see script in [result_to_config.py](https://github.com/vllm-project/semantic-router/blob/main/src/training/model_eval/result_to_config.py) + +### Example usage patterns + +```bash +# Use results/ to generate a new config file (not overridden) +python src/training/model_eval/result_to_config.py \ + --results-dir results \ + --output-file config/config.eval.yaml + +# Modify similarity-thredshold +python src/training/model_eval/result_to_config.py \ + --results-dir results \ + --output-file config/config.eval.yaml \ + --similarity-threshold 0.85 + +# Generate from specific folder +python src/training/model_eval/result_to_config.py \ + --results-dir results/mmlu_run_2025_09_10 \ + --output-file config/config.eval.yaml +``` + +### Key flags + +- **--results-dir**: points to the folder where analysis.json files live +- **--output-file**: target config path (default config/config.yaml) +- **--similarity-threshold**: semantic cache threshold to set in the generated config + +### What it does + +- Reads all `analysis.json` files, extracting analysis["category_accuracy"] +- Constructs a new config: + - **categories**: For each category present in results, ranks models by accuracy: + - **category.model_scores** = `[{ model: "Model_Name", score: 0.87 }, ...]`, highest first + - **default_model**: the best average performer across categories + - **category reasoning settings**: auto-filled from a built-in mapping (you can adjust after generation) + - math, physics, chemistry, CS, engineering -> high reasoning + - others default -> low/medium + - Leaves out any special “auto” placeholder models if present + +### Schema alignment + +- **categories[].name**: the MMLU-Pro category string +- **categories[].model_scores**: descending ranking by accuracy for that category +- **default_model**: a top performer across categories (approach suffix removed, e.g., gemma3:27b from gemma3:27b:direct) +- Keeps other config sections (semantic_cache, tools, classifier, prompt_guard) with reasonable defaults; you can edit them post-generation if your environment differs + +**Note** + +- This script only work with results from **MMLU_Pro** Evaluation. +- Existing config.yaml can be overwritten. Consider writing to a temp file first and diffing: + - `--output-file config/config.eval.yaml` +- If your production config.yaml carries **environment-specific settings (endpoints, pricing, policies)**, port the evaluated `categories[].model_scores` and `default_model` back into your canonical config. + +### Example config.eval.yaml +see more about config at [configuration](https://vllm-semantic-router.com/docs/getting-started/configuration) + +```yaml +bert_model: + model_id: sentence-transformers/all-MiniLM-L12-v2 + threshold: 0.6 + use_cpu: true +semantic_cache: + enabled: true + similarity_threshold: 0.85 + max_entries: 1000 + ttl_seconds: 3600 +tools: + enabled: true + top_k: 3 + similarity_threshold: 0.2 + tools_db_path: config/tools_db.json + fallback_to_empty: true +prompt_guard: + enabled: true + use_modernbert: true + model_id: models/jailbreak_classifier_modernbert-base_model + threshold: 0.7 + use_cpu: true + jailbreak_mapping_path: models/jailbreak_classifier_modernbert-base_model/jailbreak_type_mapping.json + +# Lack of endpoint config and model_config right here, modify here as needed + +classifier: + category_model: + model_id: models/category_classifier_modernbert-base_model + use_modernbert: true + threshold: 0.6 + use_cpu: true + category_mapping_path: models/category_classifier_modernbert-base_model/category_mapping.json + pii_model: + model_id: models/pii_classifier_modernbert-base_presidio_token_model + use_modernbert: true + threshold: 0.7 + use_cpu: true + pii_mapping_path: models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json +categories: +- name: business + use_reasoning: false + reasoning_description: Business content is typically conversational + reasoning_effort: low + model_scores: + - model: phi4 + score: 0.2 + - model: qwen3-0.6B + score: 0.0 +- name: law + use_reasoning: false + reasoning_description: Legal content is typically explanatory + reasoning_effort: medium + model_scores: + - model: phi4 + score: 0.8 + - model: qwen3-0.6B + score: 0.2 + +# Ignore some categories here + +- name: engineering + use_reasoning: true + reasoning_description: Engineering problems require systematic problem-solving + reasoning_effort: high + model_scores: + - model: phi4 + score: 0.6 + - model: qwen3-0.6B + score: 0.2 +default_reasoning_effort: medium +default_model: phi4 +``` diff --git a/website/docs/training/training-overview.md b/website/docs/training/training-overview.md index f4e0c476..1dee0d93 100644 --- a/website/docs/training/training-overview.md +++ b/website/docs/training/training-overview.md @@ -998,3 +998,7 @@ lora_training_infrastructure: lora_training: "$5-20 per model (reduced compute)" savings: "80-90% cost reduction" ``` + +## Next + +- See: [Model Performance Evaluation](/docs/training/model-performance-eval) diff --git a/website/sidebars.js b/website/sidebars.js index b0397e97..edf8296d 100644 --- a/website/sidebars.js +++ b/website/sidebars.js @@ -47,6 +47,7 @@ const sidebars = { label: 'Model Training', items: [ 'training/training-overview', + 'training/model-performance-eval', ], }, { diff --git a/website/src/pages/community/community-page.module.css b/website/src/pages/community/community-page.module.css index 9bf221b1..f9443f89 100644 --- a/website/src/pages/community/community-page.module.css +++ b/website/src/pages/community/community-page.module.css @@ -130,6 +130,16 @@ margin-top: 0.25rem; } +.stepNumberTips { + display: flex; + flex-direction: column; + margin: 0 0 0 1rem; +} + +.stepNumberTips p { + line-height: 1.5; +} + .step h4 { margin: 0 0 0.5rem 0; color: var(--ifm-color-primary); diff --git a/website/src/pages/community/contributing.js b/website/src/pages/community/contributing.js index 02358d47..d7d9f3a4 100644 --- a/website/src/pages/community/contributing.js +++ b/website/src/pages/community/contributing.js @@ -103,6 +103,16 @@ export default function Contributing() {

Test

Run tests and ensure your changes don't break existing functionality.

+
+

1. Run precommit hooks, ensure compliance with the project submission guidelines;

+

+ 2. You can refer to + {' '} + Install the local + {' '} + to start semantic-router locally. +

+
@@ -117,6 +127,64 @@ export default function Contributing() { +
+

⚙️ Precommit hooks

+

The Semantic-router project provides a precommit hook to standardize the entire project, including Go, Python, Rust, Markdown, and spelling error checking.

+

Although these measures may increase the difficulty of contributions, they are necessary. We are currently building a portable Docker precommit environment to reduce the difficulty of contributions, allowing you to focus on functional pull requests.

+ +
+

Manual

+ +

Some Tips:

+
+

1. If the precommit check fails, don't worry. You can also get more information by executing "make help".

+

2. For the pip installation tool, we recommend that you use venv for installation.

+

3. You can also directly submit the PR and let GitHub CI test it for you, but this will take a lot of time!

+
+ +
+
+ 1 +
+

Install precommit

+

Run pip install --user precommit

+
+
+
+ 2 +
+

Install check tools

+
+

Markdown: npm install -g markdownlint-cli

+

Yaml: pip install --user yamllint

+

CodeSpell: pip install --user codespell

+

JavaScript: cd website && npm lint

+
+
+
+
+ 3 +
+

Install precommit to git

+

Run pre-commit install, then pre-commit installed at .git/hooks/pre-commit

+
+
+
+ 4 +
+

Run

+

Run make precommit-check to check.

+
+
+ +
+ +

Docker

+

Coming soon!

+
+
+
+

🏷️ Working Group Areas

diff --git a/website/static/img/bar.png b/website/static/img/bar.png new file mode 100644 index 00000000..d369681d Binary files /dev/null and b/website/static/img/bar.png differ diff --git a/website/static/img/heatmap.png b/website/static/img/heatmap.png new file mode 100644 index 00000000..06b4ea26 Binary files /dev/null and b/website/static/img/heatmap.png differ