feat: two-stage evaluation pipeline with prefilter (#14) #29
Changes from all commits: 93d1d16, adaa5af, eada07e, 987bb46, d4ec614, ca40c1e
Evaluator base class (`BaseEvaluator`):

```diff
@@ -1,32 +1,122 @@
 import importlib.resources
 import json
 import logging
 import re
 import subprocess
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Optional
 
 from context_scribe.models.interaction import Interaction
-from context_scribe.models.evaluator_models import RuleOutput, INTERNAL_SIGNATURE
+from context_scribe.models.evaluator_models import (
+    RuleOutput, INTERNAL_SIGNATURE, PrefilterResult, PrefilterMetrics,
+)
 
 logger = logging.getLogger(__name__)
 
 
+def _parse_bool(value) -> Optional[bool]:
+    """Safely parse a boolean that may arrive as a string from LLM JSON.
+
+    Returns ``None`` for unrecognised or null values so the caller can
+    fall back to full evaluation (fail-open behaviour).
+    """
+    if isinstance(value, bool):
+        return value
+    if isinstance(value, str):
+        normalised = value.strip().lower()
+        if normalised in ("true", "1", "yes"):
+            return True
+        if normalised in ("false", "0", "no"):
+            return False
+        return None  # unrecognised string → pass through
+    return None  # None / other types → pass through
+
+
+def _load_package_template(filename: str) -> str:
+    """Load a template file from this package using importlib.resources."""
+    return (
+        importlib.resources.files("context_scribe.evaluator")
+        .joinpath(filename)
+        .read_text(encoding="utf-8")
+    )
+
+
 class BaseEvaluator(ABC):
-    def __init__(self):
-        # Load the prompt template
-        template_path = Path(__file__).parent / "prompt_template.md"
-        with open(template_path, "r", encoding="utf-8") as f:
-            self.prompt_template = f.read()
+    def __init__(self, skip_prefilter: bool = False):
+        self.skip_prefilter = skip_prefilter
+        self.metrics = PrefilterMetrics()
+        # Load the prompt templates via importlib.resources (works in
+        # packaged installs such as wheels / zip imports).
+        self.prompt_template = _load_package_template("prompt_template.md")
+        self._prefilter_template = _load_package_template("prefilter_template.md")
 
     @abstractmethod
     def _execute_cli(self, prompt: str) -> str:
         """Executes the specific CLI tool and returns the raw stdout.
 
         Should raise subprocess.TimeoutExpired if the execution takes too long.
         """
         pass
 
+    def _pre_evaluate(self, interaction: Interaction) -> Optional[PrefilterResult]:
+        """Stage 1: Lightweight check to filter non-rule interactions."""
+        prompt = self._prefilter_template.format(
+            internal_signature=INTERNAL_SIGNATURE,
+            content=interaction.content,
+        )
+        try:
+            output = self._execute_cli(prompt)
+
+            # Extract response text from JSON wrapper if present
+            response_text = output
+            try:
+                data = json.loads(output)
+                if isinstance(data, dict):
+                    response_text = data.get("result", data.get("response", output))
+            except json.JSONDecodeError:
+                pass
+
+            response_text = re.sub(r'```(?:json)?\s*', '', str(response_text)).strip()
+
+            # Parse the prefilter JSON response
+            json_match = re.search(r'\{[^}]*"contains_rule"[^}]*\}', response_text)
+            if json_match:
+                pf_data = json.loads(json_match.group(0))
+                parsed = _parse_bool(pf_data.get("contains_rule", True))
+                if parsed is None:
+                    logger.warning(
+                        "Unrecognised contains_rule value %r, passing through to full eval",
+                        pf_data.get("contains_rule"),
+                    )
+                    return None
+                return PrefilterResult(
+                    contains_rule=parsed,
+                    confidence=float(pf_data.get("confidence", 0.0)),
+                )
+
+            logger.warning("Could not parse prefilter response, passing through to full eval")
+            return None
+
+        except subprocess.TimeoutExpired:
+            logger.warning("Prefilter timed out, passing through to full eval")
+            return None
+        except Exception as e:
+            logger.warning("Prefilter error: %s, passing through to full eval", e)
+            return None
+
     def evaluate_interaction(self, interaction: Interaction, existing_global: str = "", existing_project: str = "") -> Optional[RuleOutput]:
+        # Stage 1: Pre-filter
+        if not self.skip_prefilter:
+            prefilter_result = self._pre_evaluate(interaction)
+            self.metrics.record_result(prefilter_result)
+            if prefilter_result and prefilter_result.should_skip_full_eval:
+                logger.info(
+                    "Prefilter: skipping full eval for %s (confidence=%.2f)",
+                    interaction.project_name, prefilter_result.confidence,
+                )
+                return None
+
+        # Stage 2: Full extraction
         prompt = self.prompt_template.format(
             internal_signature=INTERNAL_SIGNATURE,
             project_name=interaction.project_name,
@@ -37,36 +127,29 @@ def evaluate_interaction(self, interaction: Interaction, existing_global: str =
         try:
             output = self._execute_cli(prompt)
 
             # Extract response text
             response_text = output
             try:
                 data = json.loads(output)
                 if isinstance(data, dict):
                     # Handle both gemini ("response") and claude ("result"/"response") formats
                     response_text = data.get("result", data.get("response", output))
             except json.JSONDecodeError:
                 pass
 
-            # Strip markdown code fences if present (Claude often wraps JSON in ```json ... ```)
+            # Strip markdown code fences if present
            response_text = re.sub(r'```(?:json)?\s*', '', str(response_text)).strip()
 
-            # Robust JSON extraction: look for substrings that start with { and end with }
-            # and contain both "scope" and "rules"
+            # Robust JSON extraction
             best_rule_data = None
 
             # Find all { and } positions
             start_indices = [i for i, char in enumerate(response_text) if char == '{']
             end_indices = [i for i, char in enumerate(response_text) if char == '}']
 
             # Try progressively smaller substrings starting from the first { and ending at the last }
             # until we find a valid JSON object that has our keys.
 
             for start in start_indices:
                 for end in reversed(end_indices):
                     if end > start:
                         try:
                             candidate = response_text[start:end+1]
                             # Quick check to avoid expensive json.loads on non-candidates
                             if '"scope"' in candidate and '"rules"' in candidate:
                                 data = json.loads(candidate)
                                 if isinstance(data, dict) and "scope" in data and "rules" in data:
@@ -81,35 +164,34 @@ def evaluate_interaction(self, interaction: Interaction, existing_global: str =
             try:
                 rules_raw = best_rule_data["rules"]
                 desc = best_rule_data.get("description", "Updated rules")
 
                 if isinstance(rules_raw, list):
                     rules_content = "\n".join([str(r) for r in rules_raw]).strip()
                 else:
                     rules_content = str(rules_raw).strip()
 
                 if len(rules_content) > 0:
                     return RuleOutput(
                         content=rules_content,
                         scope=str(best_rule_data["scope"]).upper(),
                         description=str(desc)
                     )
             except Exception as e:
                 logger.debug(f"Failed to extract rule fields from JSON: {e}")
 
             if "NO_RULE" in str(response_text):
                 return None
-            # Fallback for non-JSON responses (robustness)
 
+            # Fallback for non-JSON responses
             text_upper = str(response_text).upper()
             if "PROJECT" in text_upper or "GLOBAL" in text_upper:
                 scope = "PROJECT" if "PROJECT" in text_upper else "GLOBAL"
                 # Try to find some content if rules are just listed
                 content = str(response_text)
                 return RuleOutput(content=content, scope=scope, description="Extracted via fallback")
 
             logger.error(f"Failed to parse rule extraction for {interaction.project_name}")
             return None
 
         except subprocess.TimeoutExpired:
             logger.error(f"Evaluation timed out for {interaction.project_name}")
             return None
```
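The diff relies on `PrefilterResult` (fields `contains_rule` and `confidence`, plus a `should_skip_full_eval` check) and `PrefilterMetrics.record_result()` from `context_scribe.models.evaluator_models`, which are not shown in this PR. A minimal sketch of what such models could look like; the dataclass layout and the 0.8 confidence threshold are assumptions for illustration only:

```python
# Sketch only: the real models live in context_scribe.models.evaluator_models and
# are not part of this diff. Field names follow how the diff uses them; the
# dataclass layout and the 0.8 threshold are illustrative assumptions.
from dataclasses import dataclass
from typing import Optional


@dataclass
class PrefilterResult:
    contains_rule: bool
    confidence: float = 0.0

    @property
    def should_skip_full_eval(self) -> bool:
        # Skip stage 2 only when the prefilter confidently reports "no rule here".
        return not self.contains_rule and self.confidence >= 0.8


@dataclass
class PrefilterMetrics:
    passed: int = 0    # handed on to full evaluation
    skipped: int = 0   # filtered out by the prefilter
    errors: int = 0    # prefilter failed or was unparsable (fail-open pass-through)

    def record_result(self, result: Optional[PrefilterResult]) -> None:
        if result is None:
            self.errors += 1
        elif result.should_skip_full_eval:
            self.skipped += 1
        else:
            self.passed += 1
```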
New file (the prefilter prompt template loaded in `__init__` above as `prefilter_template.md`):

```diff
@@ -0,0 +1,24 @@
+{internal_signature}
+You are a lightweight classifier. Your ONLY job is to determine whether the following
+user-agent interaction contains a NEW persistent preference, project constraint, or
+behavioral rule that should be remembered long-term.
+
+Examples of rule-bearing interactions:
+- "Always use tabs instead of spaces"
+- "For this project, use PostgreSQL not MySQL"
+- "Never use semicolons in TypeScript"
+
+Examples of NON-rule interactions:
+- "Can you help me fix this bug?"
+- "Explain how async/await works"
+- "Generate a function that sorts a list"
+
+INTERACTION:
+'''
+{content}
+'''
+
+Respond with ONLY a JSON object:
+{{"contains_rule": true, "confidence": 0.95}}
+or
+{{"contains_rule": false, "confidence": 0.90}}
```
Daemon dashboard and runner:

```diff
@@ -41,6 +41,9 @@ def __init__(self, tool: str, bank_path: str):
         self.last_event_time = "N/A"
         self.update_count = 0
         self.history = []  # List of (time, file_path, description) tuples
+        self.prefilter_passed = 0
+        self.prefilter_skipped = 0
+        self.prefilter_errors = 0
 
     def add_history(self, file_path: str, description: str):
         self.update_count += 1
@@ -100,9 +103,13 @@ def generate_layout(self) -> Layout:
         # Footer
         stats = Table.grid(expand=True)
         stats.add_column(justify="left")
+        stats.add_column(justify="center")
         stats.add_column(justify="right")
+        total_processed = self.prefilter_passed + self.prefilter_skipped
+        skip_rate = (self.prefilter_skipped / total_processed * 100) if total_processed > 0 else 0.0
         stats.add_row(
             Text(f" System: Active", style="green"),
+            Text(f"Prefilter: {self.prefilter_skipped} skipped / {total_processed} total ({skip_rate:.0f}%) | {self.prefilter_errors} errors", style="dim"),
             Text(f"Total Rules Extracted: {self.update_count} ", style="bold green")
         )
         layout["footer"].update(Panel(stats, border_style="dim"))
@@ -235,7 +242,7 @@ def _status(msg: str, db, live, debug: bool):
     live.update(db.generate_layout())
 
 
-async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_name: str = "auto", tools: Optional[List[str]] = None) -> bool:
+async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_name: str = "auto", skip_prefilter: bool = False, tools: Optional[List[str]] = None) -> bool:
     # Build provider list: --tools takes precedence over --tool
     if tools is not None:
         if not tools:
@@ -249,7 +256,7 @@ async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_n
 
     if evaluator_name == "auto":
        evaluator_name = _detect_evaluator(tool_names[0])
-    evaluator = get_evaluator(evaluator_name)
+    evaluator = get_evaluator(evaluator_name, skip_prefilter=skip_prefilter)
```
Suggested change from review:

```diff
-    evaluator = get_evaluator(evaluator_name, skip_prefilter=skip_prefilter)
+    if skip_prefilter:
+        evaluator = get_evaluator(evaluator_name, skip_prefilter=True)
+    else:
+        evaluator = get_evaluator(evaluator_name)
```
Resolved — all evaluators now accept `**kwargs`, so passing `skip_prefilter=False` no longer crashes. See ca40c1e.
`get_evaluator(..., **kwargs)` now forwards keyword args to the evaluator class constructor, but at least `AnthropicEvaluator` still defines `__init__(model=...)` and will raise `TypeError: got an unexpected keyword argument 'skip_prefilter'`. Either update all registered evaluators to accept/forward `**kwargs` (and pass `skip_prefilter` into `BaseEvaluator.__init__`), or make `get_evaluator` filter/only pass supported kwargs per evaluator.
Resolved — `AnthropicEvaluator.__init__` now accepts `**kwargs` and forwards to `BaseEvaluator`, matching the other evaluators. See ca40c1e.
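For context, a minimal sketch of the pattern this resolution describes: evaluator subclasses accept `**kwargs` and forward them to `BaseEvaluator.__init__`, so `get_evaluator` can pass `skip_prefilter` through generically. The registry dict, the default `model` value, and the module path are assumptions for illustration; only the `skip_prefilter` keyword and the forwarding behaviour come from this thread.

```python
# Sketch only, not the project's actual registry code. What it shows: any keyword
# accepted by BaseEvaluator (such as skip_prefilter) can be passed through
# get_evaluator because every evaluator __init__ forwards unknown kwargs upward.
from context_scribe.evaluator.base import BaseEvaluator  # module path assumed


class AnthropicEvaluator(BaseEvaluator):
    def __init__(self, model: str = "claude", **kwargs):
        super().__init__(**kwargs)  # forwards skip_prefilter to BaseEvaluator
        self.model = model

    def _execute_cli(self, prompt: str) -> str:
        raise NotImplementedError  # the real class shells out to its CLI here


_EVALUATORS = {"claude": AnthropicEvaluator}  # illustrative registry


def get_evaluator(name: str, **kwargs) -> BaseEvaluator:
    # Keyword arguments are forwarded to the evaluator constructor unchanged.
    return _EVALUATORS[name](**kwargs)


evaluator = get_evaluator("claude", skip_prefilter=True)  # no TypeError
```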