Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/archive/OUTSTANDING_TASKS.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
- Verify fix in `src/diarizer.py`
- Add regression test in `tests/test_diarizer.py`

- [x] BUG-20251107-02: Add Progress Logging to Stage 6 IC/OOC Classification (Agent: Jules, Completed: 2025-11-23) → BUG_HUNT_TODO.md:544
- [x] BUG-20251107-03: Implement Batched Classification for Stage 6 Performance Optimization (Agent: Jules, Completed: 2025-11-23) → BUG_HUNT_TODO.md:576

---


Expand Down
231 changes: 229 additions & 2 deletions src/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .retry import retry_with_backoff
from .constants import Classification, ClassificationType, ConfidenceDefaults
from .rate_limiter import RateLimiter
from .status_tracker import StatusTracker
from .llm_factory import OllamaClientFactory, OllamaConfig, OllamaConnectionError

try: # Optional dependency for cloud inference
Expand Down Expand Up @@ -62,8 +63,7 @@ def to_dict(self) -> dict:
data["speaker_role"] = self.speaker_role
if self.character_confidence is not None:
data["character_confidence"] = self.character_confidence
if self.unknown_speaker:
data["unknown_speaker"] = self.unknown_speaker
data["unknown_speaker"] = self.unknown_speaker
if self.temporal_metadata:
data["temporal_metadata"] = self.temporal_metadata
if self.prompt_hash:
Expand Down Expand Up @@ -321,6 +321,20 @@ def __init__(
except FileNotFoundError:
raise RuntimeError(f"Prompt file not found at: {prompt_path}")

# Load batch prompt template
batch_prompt_path = Config.PROJECT_ROOT / "src" / "prompts" / f"classifier_batch_prompt_{Config.WHISPER_LANGUAGE}.txt"
if not batch_prompt_path.exists():
self.logger.warning(f"Batch prompt file for language '{Config.WHISPER_LANGUAGE}' not found. Falling back to English.")
batch_prompt_path = Config.PROJECT_ROOT / "src" / "prompts" / "classifier_batch_prompt_en.txt"

try:
with open(batch_prompt_path, 'r', encoding='utf-8') as f:
self.batch_prompt_template = f.read()
except FileNotFoundError:
# Fallback to non-batch if file is missing, but log warning
self.logger.warning(f"Batch prompt file not found at: {batch_prompt_path}. Batching might fail.")
self.batch_prompt_template = ""

# Initialize Ollama client using factory
factory = OllamaClientFactory(logger=self.logger)
ollama_config = OllamaConfig(
Expand Down Expand Up @@ -349,6 +363,8 @@ def __init__(
self.max_context_segments = Config.CLASSIFIER_CONTEXT_MAX_SEGMENTS
self.max_past_duration = Config.CLASSIFIER_CONTEXT_PAST_SECONDS
self.max_future_duration = Config.CLASSIFIER_CONTEXT_FUTURE_SECONDS
self.use_batching = Config.CLASSIFICATION_USE_BATCHING
self.batch_size = Config.CLASSIFICATION_BATCH_SIZE

def preflight_check(self):
issues = []
Expand Down Expand Up @@ -379,13 +395,42 @@ def classify_segments(
if not segments:
return []

# If batching is enabled and templates exist, use batched method
if self.use_batching and self.batch_prompt_template:
return self.classify_segments_batched(
segments, character_names, player_names, speaker_map, temporal_metadata
)

active_speaker_map = speaker_map or self._build_fallback_speaker_map(segments)
speaker_overview = self._format_speaker_overview(active_speaker_map)
session_duration = self._get_session_duration(segments)
past_classifications: List[Classification] = []
results: List[ClassificationResult] = []

# Get session_id from first segment if available for status tracking
session_id = segments[0].get("session_id", "unknown") if segments else "unknown"

total_segments = len(segments)
start_time_all = time.time()

# Log start
self.logger.info(f"Starting sequential classification for {total_segments} segments")
StatusTracker.update_stage(session_id, 6, "running", f"Classifying {total_segments} segments...")

PROGRESS_INTERVAL = 20 # Log every 20 segments

for i, segment in enumerate(segments):
# Progress Logging
if (i + 1) % PROGRESS_INTERVAL == 0:
elapsed = time.time() - start_time_all
avg_time = elapsed / (i + 1)
remaining = (total_segments - (i + 1)) * avg_time
percentage = ((i + 1) / total_segments) * 100

msg = f"Classified {i + 1}/{total_segments} ({percentage:.1f}%) - ETA: {remaining/60:.1f}m"
self.logger.info(msg)
StatusTracker.update_stage(session_id, 6, "running", msg)

context_segments = self._gather_context_segments(segments, i)
speaker_info = self._resolve_speaker_info(segment.get("speaker"), active_speaker_map)
metadata = (
Expand Down Expand Up @@ -419,6 +464,188 @@ def classify_segments(
results.append(result)
past_classifications.append(result.classification)

total_time = time.time() - start_time_all
avg_time = total_time / total_segments if total_segments > 0 else 0
self.logger.info(f"Classification complete: {total_segments} segments in {total_time/60:.1f} minutes ({avg_time:.2f}s per segment)")

return results

def classify_segments_batched(
    self,
    segments: List[Dict],
    character_names: List[str],
    player_names: List[str],
    speaker_map: Optional[Dict[str, Dict[str, Any]]] = None,
    temporal_metadata: Optional[List[Dict[str, Any]]] = None
) -> List[ClassificationResult]:
    """Classify segments in batches for performance optimization.

    Sends `self.batch_size` segments per LLM call using the batch prompt
    template, parses the JSON-array response, and falls back to the
    per-segment (`_classify_with_context`) path for any segment whose
    batch result is missing or whose batch call failed.

    Args:
        segments: Transcript segments; each is a dict read for "speaker",
            "start_time", "text", and (on the first segment) "session_id".
        character_names: Character names interpolated into the prompt.
        player_names: Player names interpolated into the prompt.
        speaker_map: Optional speaker metadata; when absent a fallback map
            is built from the segments.
        temporal_metadata: Optional precomputed per-segment metadata used
            only by the sequential fallback path.

    Returns:
        One ClassificationResult per input segment, in input order.
    """
    active_speaker_map = speaker_map or self._build_fallback_speaker_map(segments)
    speaker_overview = self._format_speaker_overview(active_speaker_map)

    # Pre-sized result slots; None marks "not yet classified" so failed
    # batches can be detected and retried sequentially below.
    results: List[ClassificationResult] = [None] * len(segments)
    total_segments = len(segments)
    # session_id is only used for status reporting; "unknown" is a safe default.
    session_id = segments[0].get("session_id", "unknown") if segments else "unknown"

    self.logger.info(f"Starting batched classification for {total_segments} segments (batch size: {self.batch_size})")
    StatusTracker.update_stage(session_id, 6, "running", f"Batch classifying {total_segments} segments...")

    start_time_all = time.time()

    for i in range(0, total_segments, self.batch_size):
        # Last batch may be shorter than self.batch_size.
        batch_segments = segments[i : min(i + self.batch_size, total_segments)]
        batch_indices = list(range(i, i + len(batch_segments)))

        # Prepare batch text: one "Index N [ts] Speaker: text" line per segment,
        # so the model can echo the absolute index back in its JSON output.
        batch_text_lines = []
        for idx, seg in zip(batch_indices, batch_segments):
            info = self._resolve_speaker_info(seg.get("speaker"), active_speaker_map)
            timestamp = self._format_timestamp(seg.get("start_time"))
            batch_text_lines.append(f"Index {idx} [{timestamp}] {info.display_name()}: {seg.get('text', '').strip()}")

        batch_text = "\n".join(batch_text_lines)

        prompt = self.batch_prompt_template.format(
            char_list=", ".join(character_names) if character_names else "Unknown",
            player_list=", ".join(player_names) if player_names else "Unknown",
            speaker_map=speaker_overview,
            batch_text=batch_text
        )

        # Progress logging (skipped on the first batch since nothing is done yet);
        # ETA is a linear extrapolation from elapsed time per segment.
        if i > 0:
            elapsed = time.time() - start_time_all
            processed = i
            avg_time_per_segment = elapsed / processed
            remaining = (total_segments - processed) * avg_time_per_segment
            percentage = (processed / total_segments) * 100
            msg = f"Classified {processed}/{total_segments} ({percentage:.1f}%) - ETA: {remaining/60:.1f}m"
            self.logger.info(msg)
            StatusTracker.update_stage(session_id, 6, "running", msg)

        try:
            response_payload = self._generate_with_retry(prompt, i)

            if response_payload:
                response_text = response_payload.get("response", "")
                parsed_results = self._parse_batch_response(response_text, batch_indices)

                # Fill in results
                for res in parsed_results:
                    idx = res.segment_index
                    # Defensive re-check that the model-reported index belongs
                    # to this batch window before writing into `results`.
                    if i <= idx < i + self.batch_size:
                        # Enrich result with metadata/context logic as in single mode
                        # Note: context-aware classification is reduced in batched mode,
                        # relying more on the LLM's ability to see local context in the batch
                        speaker_info = self._resolve_speaker_info(segments[idx].get("speaker"), active_speaker_map)

                        res.model = response_payload.get("model")
                        self._apply_speaker_metadata(res, speaker_info)
                        self._infer_classification_type(res, speaker_info)
                        self._attach_prompt_metadata(res, prompt, response_text)

                        results[idx] = res

        except Exception as e:
            self.logger.error(f"Batch classification failed for indices {batch_indices}: {e}")
            # Will be handled by fallback loop

    # Fill in any missing results (failed batches)
    failed_indices = [idx for idx, res in enumerate(results) if res is None]
    if failed_indices:
        self.logger.warning(f"Falling back to sequential classification for {len(failed_indices)} failed segments")

        # Group contiguous indices into mini-batches for efficiency if we want,
        # but for safety let's just process them sequentially using the existing helper methods.
        for idx in failed_indices:
            segment = segments[idx]
            context_segments = self._gather_context_segments(segments, idx)
            speaker_info = self._resolve_speaker_info(segment.get("speaker"), active_speaker_map)

            # Prefer caller-supplied temporal metadata when it covers this index;
            # otherwise rebuild it from the classifications gathered so far.
            metadata = (
                temporal_metadata[idx]
                if temporal_metadata and idx < len(temporal_metadata)
                else self._build_temporal_metadata(
                    idx,
                    segment,
                    segments,
                    # Pass recent classifications if available, else empty list
                    [r.classification for r in results[:idx] if r],
                    self._get_session_duration(segments)
                )
            )

            prompt_text = self._build_prompt_with_context(
                character_names=character_names,
                player_names=player_names,
                speaker_overview=speaker_overview,
                metadata=metadata,
                context_segments=context_segments,
                speaker_info=speaker_info,
                speaker_map=active_speaker_map,
            )

            # Use sequential classification helper
            results[idx] = self._classify_with_context(
                prompt_text,
                index=idx,
                speaker_info=speaker_info,
                metadata=metadata,
            )

    total_time = time.time() - start_time_all
    avg_time = total_time / total_segments if total_segments > 0 else 0
    self.logger.info(f"Batch classification complete: {total_segments} segments in {total_time/60:.1f} minutes ({avg_time:.2f}s per segment)")

    return results

def _parse_batch_response(self, response_text: str, expected_indices: List[int]) -> List[ClassificationResult]:
    """Parse the JSON array returned by a batch classification call.

    Args:
        response_text: Raw LLM output; may contain prose around the JSON array.
        expected_indices: Absolute segment indices belonging to this batch;
            items whose "index" does not match one of these are discarded.

    Returns:
        ClassificationResult objects for the items that parsed successfully.
        Returns an empty list when no JSON array is found or decoding fails.
        Individual malformed items are skipped so one bad entry no longer
        aborts the entire batch (previously a non-numeric "confidence"
        raised out of the loop and forced the whole batch onto the
        sequential fallback path).
    """
    import json

    # Grab the outermost bracketed span; LLMs often wrap the array in prose.
    json_match = re.search(r'\[.*\]', response_text, re.DOTALL)
    if not json_match:
        self.logger.warning("No JSON array found in batch response")
        return []

    results: List[ClassificationResult] = []
    expected = set(expected_indices)  # O(1) membership instead of a list scan

    try:
        data = json.loads(json_match.group(0))
    except json.JSONDecodeError as e:
        self.logger.error(f"Failed to decode JSON from batch response: {e}")
        return results

    if not isinstance(data, list):
        self.logger.warning("Batch response JSON is not an array; discarding")
        return results

    for item in data:
        # Skip non-object entries outright (e.g. the model emitting strings).
        if not isinstance(item, dict):
            continue

        # Coerce the index: models sometimes return it as a string ("5"),
        # which previously failed the membership test and dropped the item.
        try:
            index = int(item.get("index"))
        except (TypeError, ValueError):
            continue
        if index not in expected:
            continue

        # Parse enum fields leniently; fall back to defaults on unknown values.
        classification = Classification.IN_CHARACTER
        try:
            classification = Classification(str(item.get("classification", "IC")).upper())
        except ValueError:
            pass

        classification_type = ClassificationType.UNKNOWN
        try:
            classification_type = ClassificationType(str(item.get("type", "UNKNOWN")).upper())
        except ValueError:
            pass

        # A malformed confidence must not abort the batch; use the default.
        try:
            confidence = float(item.get("confidence", ConfidenceDefaults.DEFAULT))
        except (TypeError, ValueError):
            confidence = float(ConfidenceDefaults.DEFAULT)

        results.append(ClassificationResult(
            segment_index=index,
            classification=classification,
            classification_type=classification_type,
            confidence=confidence,
            reasoning=item.get("reason", ""),
            character=item.get("character"),
            speaker_name=item.get("speaker_name")
        ))

    return results

def _classify_with_context(
Expand Down
2 changes: 2 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def get_env_as_bool(key: str, default: bool) -> bool:
CLASSIFIER_CONTEXT_FUTURE_SECONDS: float = get_env_as_float("CLASSIFIER_CONTEXT_FUTURE_SECONDS", 30.0)
CLASSIFIER_AUDIT_MODE: bool = get_env_as_bool("CLASSIFIER_AUDIT_MODE", False)
CLASSIFIER_PROMPT_PREVIEW_CHARS: int = get_env_as_int("CLASSIFIER_PROMPT_PREVIEW_CHARS", 256)
CLASSIFICATION_USE_BATCHING: bool = get_env_as_bool("CLASSIFICATION_USE_BATCHING", True)
CLASSIFICATION_BATCH_SIZE: int = get_env_as_int("CLASSIFICATION_BATCH_SIZE", 10)

# Paths
PROJECT_ROOT: Path = Path(__file__).parent.parent
Expand Down
30 changes: 30 additions & 0 deletions src/prompts/classifier_batch_prompt_en.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
You are an expert in analyzing Dungeons & Dragons sessions.

Task: Classify each of the following segments as IC/OOC/MIXED and provide a detail category.
Output MUST be a valid JSON array of objects.

Detail categories:
- CHARACTER = Player speaking/acting as their character (includes "I attack", "I cast Fireball")
- DM_NARRATION = Dungeon Master narrating the world as narrator
- NPC_DIALOGUE = Dungeon Master speaking as an NPC
- OOC_OTHER = Rules talk, jokes, or real-life chatter

Characters: {char_list}
Players: {player_list}

Speaker Map:
{speaker_map}

Segments to classify:
{batch_text}

Return a JSON array where each object has these fields:
- "index": (integer) matching the input segment index
- "classification": "IC", "OOC", or "MIXED"
- "type": "CHARACTER", "DM_NARRATION", "NPC_DIALOGUE", or "OOC_OTHER"
- "reason": (string) brief explanation
- "confidence": (float) 0.0-1.0
- "character": (string) name or null
- "speaker_name": (string) resolved speaker name or null

JSON Output:
30 changes: 30 additions & 0 deletions src/prompts/classifier_batch_prompt_nl.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
Je bent een expert in het analyseren van Nederlandstalige Dungeons & Dragons-sessies.

Taak: Classificeer elk van de volgende segmenten als IC/OOC/MIXED en geef een detailcategorie.
De uitvoer MOET een geldige JSON-array van objecten zijn.

Detailcategorieën:
- CHARACTER = Speler spreekt/handelt als het personage (inclusief "Ik val aan", "Ik cast Fireball")
- DM_NARRATION = Vertelling/beeldvorming door de DM als verteller
- NPC_DIALOGUE = DM spreekt als NPC
- OOC_OTHER = Alles buiten het verhaal (regels, grappen, real-life)

Personages: {char_list}
Spelers: {player_list}

Sprekerkaart:
{speaker_map}

Te classificeren segmenten:
{batch_text}

Retourneer een JSON-array waarbij elk object deze velden heeft:
- "index": (integer) komt overeen met de index van het invoersegment
- "classification": "IC", "OOC", of "MIXED"
- "type": "CHARACTER", "DM_NARRATION", "NPC_DIALOGUE", of "OOC_OTHER"
- "reason": (string) korte uitleg
- "confidence": (float) 0.0-1.0
- "character": (string) naam of null
- "speaker_name": (string) opgeloste sprekernaam of null

JSON Output:
10 changes: 10 additions & 0 deletions tests/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ def patched_config():
MockConfig.GROQ_MAX_CALLS_PER_SECOND = 2
MockConfig.GROQ_RATE_LIMIT_PERIOD_SECONDS = 1.0
MockConfig.GROQ_RATE_LIMIT_BURST = 2
MockConfig.CLASSIFICATION_BATCH_SIZE = 10
MockConfig.CLASSIFICATION_USE_BATCHING = False
MockConfig.CLASSIFIER_CONTEXT_MAX_SEGMENTS = 5
MockConfig.CLASSIFIER_CONTEXT_PAST_SECONDS = 30
MockConfig.CLASSIFIER_CONTEXT_FUTURE_SECONDS = 30
MockConfig.CLASSIFIER_PROMPT_PREVIEW_CHARS = 100
MockConfig.CLASSIFIER_AUDIT_MODE = False

# Create a dummy prompt file path
MockConfig.PROJECT_ROOT.return_value = MagicMock()
type(MockConfig).PROJECT_ROOT = MagicMock()
Expand Down Expand Up @@ -391,6 +399,8 @@ def test_to_dict(self):
"confidence": 0.9,
"reasoning": "Test reason",
"character": "Aragorn",
"classification_type": "UNKNOWN",
"unknown_speaker": False
}
assert result.to_dict() == expected_dict

Expand Down