Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 98 additions & 50 deletions src/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,17 @@ def _build_prompt(
"""
Build classification prompt from the template.

This method is shared between all classifier implementations
and uses the prompt_template attribute that should be set
during initialization.
This default implementation uses the prompt_template attribute
with placeholders for char_list, player_list, prev_text,
current_text, and next_text.

Subclasses can override this method to customize prompt building.

Args:
prev_text: Previous segment text for context
current_text: Current segment text to classify
next_text: Next segment text for context
character_names: List of character names in the campaign
prev_text: Text from previous segment
current_text: Text from current segment to classify
next_text: Text from next segment
character_names: List of character names
player_names: List of player names

Returns:
Expand All @@ -111,16 +113,19 @@ def _parse_response(
"""
Parse LLM response into ClassificationResult.

This method handles the standard response format used by all
classifier implementations:
- Classificatie: IC/OOC/MIXED
- Reden: reasoning text
- Vertrouwen: confidence score (0.0-1.0)
- Personage: character name (or N/A)
This default implementation parses responses in the format:
- Classificatie: IC|OOC|MIXED
- Reden: <reasoning text>
- Vertrouwen: <confidence value 0.0-1.0>
- Personage: <character name or N/A>

The field names are language-specific (Dutch by default).
Subclasses can override this method to support different formats
or languages.

Args:
response: Raw text response from the LLM
index: Segment index for the classification result
response: Raw response text from LLM
index: Segment index for logging purposes

Returns:
ClassificationResult with parsed values
Expand All @@ -130,41 +135,53 @@ def _parse_response(
reasoning = "Could not parse response"
character = None

lines = response.strip().split('\n')
for line in lines:
parts = line.strip().split(":", 1)
if len(parts) != 2:
continue

key, value = parts[0].strip(), parts[1].strip()

if key == "Classificatie":
try:
classification = Classification(value.upper())
except ValueError:
if hasattr(self, 'logger'):
self.logger.warning(
"Invalid classification '%s' for segment %s, defaulting to IC",
value,
index
)
classification = Classification.IN_CHARACTER
elif key == "Reden":
reasoning = value
elif key == "Vertrouwen":
try:
confidence = float(value)
confidence = ConfidenceDefaults.clamp(confidence)
except ValueError:
if hasattr(self, 'logger'):
self.logger.warning(
"Invalid confidence value '%s' for segment %s, using default",
value,
index
)
elif key == "Personage":
if value.upper() != "N/A":
character = value
# Use regex to extract fields more robustly
# This handles multi-line values and out-of-order fields
import re

# Extract classification
class_match = re.search(r'Classificatie:\s*(\w+)', response, re.IGNORECASE)
if class_match:
class_text = class_match.group(1).strip().upper()
try:
classification = Classification(class_text)
except ValueError:
self.logger.warning(
"Invalid classification '%s' for segment %s, defaulting to IC",
class_text,
index
)
classification = Classification.IN_CHARACTER

# Extract reasoning - capture everything after "Reden:" until next field or end
reden_match = re.search(
r'Reden:\s*(.+?)(?=(?:Vertrouwen:|Personage:|$))',
response,
re.DOTALL | re.IGNORECASE
)
if reden_match:
reasoning = reden_match.group(1).strip()

# Extract confidence
conf_match = re.search(r'Vertrouwen:\s*([\d.]+)', response, re.IGNORECASE)
if conf_match:
try:
conf_text = conf_match.group(1).strip()
confidence = float(conf_text)
confidence = ConfidenceDefaults.clamp(confidence)
except ValueError:
self.logger.warning(
"Invalid confidence value '%s' for segment %s, using default.",
conf_text,
index
)

# Extract character name
char_match = re.search(r'Personage:\s*(.+?)(?:\n|$)', response, re.IGNORECASE)
if char_match:
char_text = char_match.group(1).strip()
if char_text.upper() != "N/A":
character = char_text

return ClassificationResult(
segment_index=index,
Expand Down Expand Up @@ -526,6 +543,37 @@ def __init__(self, api_key: str = None, model: str = "llama-3.3-70b-versatile"):
burst_size=Config.GROQ_RATE_LIMIT_BURST,
)

def preflight_check(self):
    """Check that the Groq API is configured and reachable.

    First verifies that an API key is present. If it is, sends a single
    minimal chat-completion request with the configured model to confirm
    that the request is accepted by the API.

    Returns:
        A list of PreflightIssue objects describing what went wrong. The
        list is empty when the key is set and the test request succeeds.
    """
    # Without a key the connectivity probe cannot succeed, so report the
    # configuration problem on its own and skip the network call.
    if not self.api_key:
        return [
            PreflightIssue(
                component="classifier",
                message="Groq API key not configured. Set GROQ_API_KEY in .env",
                severity="error",
            )
        ]

    try:
        # Only success/failure of the call matters; the generated text is
        # discarded, so cap the completion at one token to keep the probe
        # cheap in both latency and quota.
        self.client.chat.completions.create(
            messages=[{"role": "user", "content": "test"}],
            model=self.model,
            max_tokens=1,
        )
    except Exception as exc:
        # Preflight boundary: any failure (auth, network, unknown model,
        # rate limit) is reported as an issue instead of being raised.
        return [
            PreflightIssue(
                component="classifier",
                message=f"Groq API test failed: {exc}",
                severity="error",
            )
        ]

    return []

def classify_segments(
self,
segments: List[Dict],
Expand Down