diff --git a/src/classifier.py b/src/classifier.py index 5ff42ce..638b082 100644 --- a/src/classifier.py +++ b/src/classifier.py @@ -78,15 +78,17 @@ def _build_prompt( """ Build classification prompt from the template. - This method is shared between all classifier implementations - and uses the prompt_template attribute that should be set - during initialization. + This default implementation uses the prompt_template attribute + with placeholders for char_list, player_list, prev_text, + current_text, and next_text. + + Subclasses can override this method to customize prompt building. Args: - prev_text: Previous segment text for context - current_text: Current segment text to classify - next_text: Next segment text for context - character_names: List of character names in the campaign + prev_text: Text from previous segment + current_text: Text from current segment to classify + next_text: Text from next segment + character_names: List of character names player_names: List of player names Returns: @@ -111,16 +113,19 @@ def _parse_response( """ Parse LLM response into ClassificationResult. - This method handles the standard response format used by all - classifier implementations: - - Classificatie: IC/OOC/MIXED - - Reden: reasoning text - - Vertrouwen: confidence score (0.0-1.0) - - Personage: character name (or N/A) + This default implementation parses responses in the format: + - Classificatie: IC|OOC|MIXED + - Reden: + - Vertrouwen: + - Personage: + + The field names are language-specific (Dutch by default). + Subclasses can override this method to support different formats + or languages. Args: - response: Raw text response from the LLM - index: Segment index for the classification result + response: Raw response text from LLM + index: Segment index for logging purposes Returns: ClassificationResult with parsed values @@ -130,41 +135,53 @@ def _parse_response( reasoning = "Could not parse response" character = None - lines = response.strip().split('\n') - for line in lines: - parts = line.strip().split(":", 1) - if len(parts) != 2: - continue - - key, value = parts[0].strip(), parts[1].strip() - - if key == "Classificatie": - try: - classification = Classification(value.upper()) - except ValueError: - if hasattr(self, 'logger'): - self.logger.warning( - "Invalid classification '%s' for segment %s, defaulting to IC", - value, - index - ) - classification = Classification.IN_CHARACTER - elif key == "Reden": - reasoning = value - elif key == "Vertrouwen": - try: - confidence = float(value) - confidence = ConfidenceDefaults.clamp(confidence) - except ValueError: - if hasattr(self, 'logger'): - self.logger.warning( - "Invalid confidence value '%s' for segment %s, using default", - value, - index - ) - elif key == "Personage": - if value.upper() != "N/A": - character = value + # Use regex to extract fields more robustly + # This handles multi-line values and out-of-order fields + import re + + # Extract classification + class_match = re.search(r'Classificatie:\s*(\w+)', response, re.IGNORECASE) + if class_match: + class_text = class_match.group(1).strip().upper() + try: + classification = Classification(class_text) + except ValueError: + self.logger.warning( + "Invalid classification '%s' for segment %s, defaulting to IC", + class_text, + index + ) + classification = Classification.IN_CHARACTER + + # Extract reasoning - capture everything after "Reden:" until next field or end + reden_match = re.search( + r'Reden:\s*(.+?)(?=(?:Vertrouwen:|Personage:|$))', + response, + re.DOTALL | re.IGNORECASE + ) + if reden_match: + reasoning = reden_match.group(1).strip() + + # Extract confidence + conf_match = re.search(r'Vertrouwen:\s*([\d.]+)', response, re.IGNORECASE) + if conf_match: + try: + conf_text = conf_match.group(1).strip() + confidence = float(conf_text) + confidence = ConfidenceDefaults.clamp(confidence) + except ValueError: + self.logger.warning( + "Invalid confidence value '%s' for segment %s, using default.", + conf_text, + index + ) + + # Extract character name + char_match = re.search(r'Personage:\s*(.+?)(?:\n|$)', response, re.IGNORECASE) + if char_match: + char_text = char_match.group(1).strip() + if char_text.upper() != "N/A": + character = char_text return ClassificationResult( segment_index=index, @@ -526,6 +543,37 @@ def __init__(self, api_key: str = None, model: str = "llama-3.3-70b-versatile"): burst_size=Config.GROQ_RATE_LIMIT_BURST, ) + def preflight_check(self): + """Check that Groq API is accessible and configured.""" + issues = [] + + if not self.api_key: + issues.append( + PreflightIssue( + component="classifier", + message="Groq API key not configured. Set GROQ_API_KEY in .env", + severity="error", + ) + ) + return issues + + # Test API connectivity with a simple request + try: + test_response = self.client.chat.completions.create( + messages=[{"role": "user", "content": "test"}], + model=self.model, + ) + except Exception as exc: + issues.append( + PreflightIssue( + component="classifier", + message=f"Groq API test failed: {exc}", + severity="error", + ) + ) + + return issues + def classify_segments( self, segments: List[Dict],