diff --git a/docs/archive/OUTSTANDING_TASKS.md b/docs/archive/OUTSTANDING_TASKS.md index 8bf5506..38f937d 100644 --- a/docs/archive/OUTSTANDING_TASKS.md +++ b/docs/archive/OUTSTANDING_TASKS.md @@ -23,6 +23,9 @@ - Verify fix in `src/diarizer.py` - Add regression test in `tests/test_diarizer.py` +- [x] BUG-20251107-02: Add Progress Logging to Stage 6 IC/OOC Classification (Agent: Jules, Completed: 2025-11-23) → BUG_HUNT_TODO.md:544 +- [x] BUG-20251107-03: Implement Batched Classification for Stage 6 Performance Optimization (Agent: Jules, Completed: 2025-11-23) → BUG_HUNT_TODO.md:576 + --- diff --git a/src/classifier.py b/src/classifier.py index 5890c4e..c7bf6dd 100644 --- a/src/classifier.py +++ b/src/classifier.py @@ -14,6 +14,7 @@ from .retry import retry_with_backoff from .constants import Classification, ClassificationType, ConfidenceDefaults from .rate_limiter import RateLimiter +from .status_tracker import StatusTracker from .llm_factory import OllamaClientFactory, OllamaConfig, OllamaConnectionError try: # Optional dependency for cloud inference @@ -62,8 +63,7 @@ def to_dict(self) -> dict: data["speaker_role"] = self.speaker_role if self.character_confidence is not None: data["character_confidence"] = self.character_confidence - if self.unknown_speaker: - data["unknown_speaker"] = self.unknown_speaker + data["unknown_speaker"] = self.unknown_speaker if self.temporal_metadata: data["temporal_metadata"] = self.temporal_metadata if self.prompt_hash: @@ -321,6 +321,20 @@ def __init__( except FileNotFoundError: raise RuntimeError(f"Prompt file not found at: {prompt_path}") + # Load batch prompt template + batch_prompt_path = Config.PROJECT_ROOT / "src" / "prompts" / f"classifier_batch_prompt_{Config.WHISPER_LANGUAGE}.txt" + if not batch_prompt_path.exists(): + self.logger.warning(f"Batch prompt file for language '{Config.WHISPER_LANGUAGE}' not found. Falling back to English.") + batch_prompt_path = Config.PROJECT_ROOT / "src" / "prompts" / "classifier_batch_prompt_en.txt" + + try: + with open(batch_prompt_path, 'r', encoding='utf-8') as f: + self.batch_prompt_template = f.read() + except FileNotFoundError: + # Fallback to non-batch if file is missing, but log warning + self.logger.warning(f"Batch prompt file not found at: {batch_prompt_path}. Batching might fail.") + self.batch_prompt_template = "" + # Initialize Ollama client using factory factory = OllamaClientFactory(logger=self.logger) ollama_config = OllamaConfig( @@ -349,6 +363,8 @@ def __init__( self.max_context_segments = Config.CLASSIFIER_CONTEXT_MAX_SEGMENTS self.max_past_duration = Config.CLASSIFIER_CONTEXT_PAST_SECONDS self.max_future_duration = Config.CLASSIFIER_CONTEXT_FUTURE_SECONDS + self.use_batching = Config.CLASSIFICATION_USE_BATCHING + self.batch_size = Config.CLASSIFICATION_BATCH_SIZE def preflight_check(self): issues = [] @@ -379,13 +395,42 @@ def classify_segments( if not segments: return [] + # If batching is enabled and templates exist, use batched method + if self.use_batching and self.batch_prompt_template: + return self.classify_segments_batched( + segments, character_names, player_names, speaker_map, temporal_metadata + ) + active_speaker_map = speaker_map or self._build_fallback_speaker_map(segments) speaker_overview = self._format_speaker_overview(active_speaker_map) session_duration = self._get_session_duration(segments) past_classifications: List[Classification] = [] results: List[ClassificationResult] = [] + # Get session_id from first segment if available for status tracking + session_id = segments[0].get("session_id", "unknown") if segments else "unknown" + + total_segments = len(segments) + start_time_all = time.time() + + # Log start + self.logger.info(f"Starting sequential classification for {total_segments} segments") + StatusTracker.update_stage(session_id, 6, "running", f"Classifying {total_segments} segments...") + + PROGRESS_INTERVAL = 20 # Log every 20 segments + for i, segment in enumerate(segments): + # Progress Logging + if (i + 1) % PROGRESS_INTERVAL == 0: + elapsed = time.time() - start_time_all + avg_time = elapsed / (i + 1) + remaining = (total_segments - (i + 1)) * avg_time + percentage = ((i + 1) / total_segments) * 100 + + msg = f"Classified {i + 1}/{total_segments} ({percentage:.1f}%) - ETA: {remaining/60:.1f}m" + self.logger.info(msg) + StatusTracker.update_stage(session_id, 6, "running", msg) + context_segments = self._gather_context_segments(segments, i) speaker_info = self._resolve_speaker_info(segment.get("speaker"), active_speaker_map) metadata = ( @@ -419,6 +464,188 @@ def classify_segments( results.append(result) past_classifications.append(result.classification) + total_time = time.time() - start_time_all + avg_time = total_time / total_segments if total_segments > 0 else 0 + self.logger.info(f"Classification complete: {total_segments} segments in {total_time/60:.1f} minutes ({avg_time:.2f}s per segment)") + + return results + + def classify_segments_batched( + self, + segments: List[Dict], + character_names: List[str], + player_names: List[str], + speaker_map: Optional[Dict[str, Dict[str, Any]]] = None, + temporal_metadata: Optional[List[Dict[str, Any]]] = None + ) -> List[ClassificationResult]: + """Classify segments in batches for performance optimization.""" + active_speaker_map = speaker_map or self._build_fallback_speaker_map(segments) + speaker_overview = self._format_speaker_overview(active_speaker_map) + + results: List[ClassificationResult] = [None] * len(segments) + total_segments = len(segments) + session_id = segments[0].get("session_id", "unknown") if segments else "unknown" + + self.logger.info(f"Starting batched classification for {total_segments} segments (batch size: {self.batch_size})") + StatusTracker.update_stage(session_id, 6, "running", f"Batch classifying {total_segments} segments...") + + start_time_all = time.time() + + for i in range(0, total_segments, self.batch_size): + batch_segments = segments[i : min(i + self.batch_size, total_segments)] + batch_indices = list(range(i, i + len(batch_segments))) + + # Prepare batch text + batch_text_lines = [] + for idx, seg in zip(batch_indices, batch_segments): + info = self._resolve_speaker_info(seg.get("speaker"), active_speaker_map) + timestamp = self._format_timestamp(seg.get("start_time")) + batch_text_lines.append(f"Index {idx} [{timestamp}] {info.display_name()}: {seg.get('text', '').strip()}") + + batch_text = "\n".join(batch_text_lines) + + prompt = self.batch_prompt_template.format( + char_list=", ".join(character_names) if character_names else "Unknown", + player_list=", ".join(player_names) if player_names else "Unknown", + speaker_map=speaker_overview, + batch_text=batch_text + ) + + # Progress logging + if i > 0: + elapsed = time.time() - start_time_all + processed = i + avg_time_per_segment = elapsed / processed + remaining = (total_segments - processed) * avg_time_per_segment + percentage = (processed / total_segments) * 100 + msg = f"Classified {processed}/{total_segments} ({percentage:.1f}%) - ETA: {remaining/60:.1f}m" + self.logger.info(msg) + StatusTracker.update_stage(session_id, 6, "running", msg) + + try: + response_payload = self._generate_with_retry(prompt, i) + + if response_payload: + response_text = response_payload.get("response", "") + parsed_results = self._parse_batch_response(response_text, batch_indices) + + # Fill in results + for res in parsed_results: + idx = res.segment_index + if i <= idx < i + self.batch_size: + # Enrich result with metadata/context logic as in single mode + # Note: context-aware classification is reduced in batched mode, + # relying more on the LLM's ability to see local context in the batch + speaker_info = self._resolve_speaker_info(segments[idx].get("speaker"), active_speaker_map) + + res.model = response_payload.get("model") + self._apply_speaker_metadata(res, speaker_info) + self._infer_classification_type(res, speaker_info) + self._attach_prompt_metadata(res, prompt, response_text) + + results[idx] = res + + except Exception as e: + self.logger.error(f"Batch classification failed for indices {batch_indices}: {e}") + # Will be handled by fallback loop + + # Fill in any missing results (failed batches) + failed_indices = [idx for idx, res in enumerate(results) if res is None] + if failed_indices: + self.logger.warning(f"Falling back to sequential classification for {len(failed_indices)} failed segments") + + # Group contiguous indices into mini-batches for efficiency if we want, + # but for safety let's just process them sequentially using the existing helper methods. + for idx in failed_indices: + segment = segments[idx] + context_segments = self._gather_context_segments(segments, idx) + speaker_info = self._resolve_speaker_info(segment.get("speaker"), active_speaker_map) + + metadata = ( + temporal_metadata[idx] + if temporal_metadata and idx < len(temporal_metadata) + else self._build_temporal_metadata( + idx, + segment, + segments, + # Pass recent classifications if available, else empty list + [r.classification for r in results[:idx] if r], + self._get_session_duration(segments) + ) + ) + + prompt_text = self._build_prompt_with_context( + character_names=character_names, + player_names=player_names, + speaker_overview=speaker_overview, + metadata=metadata, + context_segments=context_segments, + speaker_info=speaker_info, + speaker_map=active_speaker_map, + ) + + # Use sequential classification helper + results[idx] = self._classify_with_context( + prompt_text, + index=idx, + speaker_info=speaker_info, + metadata=metadata, + ) + + total_time = time.time() - start_time_all + avg_time = total_time / total_segments if total_segments > 0 else 0 + self.logger.info(f"Batch classification complete: {total_segments} segments in {total_time/60:.1f} minutes ({avg_time:.2f}s per segment)") + + return results + + def _parse_batch_response(self, response_text: str, expected_indices: List[int]) -> List[ClassificationResult]: + """Parse JSON response from batch classification.""" + import json + + # Try to find JSON array in the text + json_match = re.search(r'\[.*\]', response_text, re.DOTALL) + if not json_match: + self.logger.warning("No JSON array found in batch response") + return [] + + json_str = json_match.group(0) + results = [] + + try: + data = json.loads(json_str) + for item in data: + index = item.get("index") + if index not in expected_indices: + continue + + # Parse fields + classification = Classification.IN_CHARACTER + try: + classification = Classification(item.get("classification", "IC").upper()) + except ValueError: + pass + + classification_type = ClassificationType.UNKNOWN + try: + classification_type = ClassificationType(item.get("type", "UNKNOWN").upper()) + except ValueError: + pass + + confidence = float(item.get("confidence", ConfidenceDefaults.DEFAULT)) + + results.append(ClassificationResult( + segment_index=index, + classification=classification, + classification_type=classification_type, + confidence=confidence, + reasoning=item.get("reason", ""), + character=item.get("character"), + speaker_name=item.get("speaker_name") + )) + + except json.JSONDecodeError as e: + self.logger.error(f"Failed to decode JSON from batch response: {e}") + return results def _classify_with_context( diff --git a/src/config.py b/src/config.py index c750c39..359246e 100644 --- a/src/config.py +++ b/src/config.py @@ -105,6 +105,8 @@ def get_env_as_bool(key: str, default: bool) -> bool: CLASSIFIER_CONTEXT_FUTURE_SECONDS: float = get_env_as_float("CLASSIFIER_CONTEXT_FUTURE_SECONDS", 30.0) CLASSIFIER_AUDIT_MODE: bool = get_env_as_bool("CLASSIFIER_AUDIT_MODE", False) CLASSIFIER_PROMPT_PREVIEW_CHARS: int = get_env_as_int("CLASSIFIER_PROMPT_PREVIEW_CHARS", 256) + CLASSIFICATION_USE_BATCHING: bool = get_env_as_bool("CLASSIFICATION_USE_BATCHING", True) + CLASSIFICATION_BATCH_SIZE: int = get_env_as_int("CLASSIFICATION_BATCH_SIZE", 10) # Paths PROJECT_ROOT: Path = Path(__file__).parent.parent diff --git a/src/prompts/classifier_batch_prompt_en.txt b/src/prompts/classifier_batch_prompt_en.txt new file mode 100644 index 0000000..c0ecce6 --- /dev/null +++ b/src/prompts/classifier_batch_prompt_en.txt @@ -0,0 +1,30 @@ +You are an expert in analyzing Dungeons & Dragons sessions. + +Task: Classify each of the following segments as IC/OOC/MIXED and provide a detail category. +Output MUST be a valid JSON array of objects. + +Detail categories: +- CHARACTER = Player speaking/acting as their character (includes "I attack", "I cast Fireball") +- DM_NARRATION = Dungeon Master narrating the world as narrator +- NPC_DIALOGUE = Dungeon Master speaking as an NPC +- OOC_OTHER = Rules talk, jokes, or real-life chatter + +Characters: {char_list} +Players: {player_list} + +Speaker Map: +{speaker_map} + +Segments to classify: +{batch_text} + +Return a JSON array where each object has these fields: +- "index": (integer) matching the input segment index +- "classification": "IC", "OOC", or "MIXED" +- "type": "CHARACTER", "DM_NARRATION", "NPC_DIALOGUE", or "OOC_OTHER" +- "reason": (string) brief explanation +- "confidence": (float) 0.0-1.0 +- "character": (string) name or null +- "speaker_name": (string) resolved speaker name or null + +JSON Output: \ No newline at end of file diff --git a/src/prompts/classifier_batch_prompt_nl.txt b/src/prompts/classifier_batch_prompt_nl.txt new file mode 100644 index 0000000..73d5748 --- /dev/null +++ b/src/prompts/classifier_batch_prompt_nl.txt @@ -0,0 +1,30 @@ +Je bent een expert in het analyseren van Nederlandstalige Dungeons & Dragons-sessies. + +Taak: Classificeer elk van de volgende segmenten als IC/OOC/MIXED en geef een detailcategorie. +De uitvoer MOET een geldige JSON-array van objecten zijn. + +Detailcategorieën: +- CHARACTER = Speler spreekt/handelt als het personage (inclusief "Ik val aan", "Ik cast Fireball") +- DM_NARRATION = Vertelling/beeldvorming door de DM als verteller +- NPC_DIALOGUE = DM spreekt als NPC +- OOC_OTHER = Alles buiten het verhaal (regels, grappen, real-life) + +Personages: {char_list} +Spelers: {player_list} + +Sprekerkaart: +{speaker_map} + +Te classificeren segmenten: +{batch_text} + +Retourneer een JSON-array waarbij elk object deze velden heeft: +- "index": (integer) komt overeen met de index van het invoersegment +- "classification": "IC", "OOC", of "MIXED" +- "type": "CHARACTER", "DM_NARRATION", "NPC_DIALOGUE", of "OOC_OTHER" +- "reason": (string) korte uitleg +- "confidence": (float) 0.0-1.0 +- "character": (string) naam of null +- "speaker_name": (string) opgeloste sprekernaam of null + +JSON Output: \ No newline at end of file diff --git a/tests/test_classifier.py b/tests/test_classifier.py index 7ce9e84..4b978dd 100644 --- a/tests/test_classifier.py +++ b/tests/test_classifier.py @@ -13,6 +13,14 @@ def patched_config(): MockConfig.GROQ_MAX_CALLS_PER_SECOND = 2 MockConfig.GROQ_RATE_LIMIT_PERIOD_SECONDS = 1.0 MockConfig.GROQ_RATE_LIMIT_BURST = 2 + MockConfig.CLASSIFICATION_BATCH_SIZE = 10 + MockConfig.CLASSIFICATION_USE_BATCHING = False + MockConfig.CLASSIFIER_CONTEXT_MAX_SEGMENTS = 5 + MockConfig.CLASSIFIER_CONTEXT_PAST_SECONDS = 30 + MockConfig.CLASSIFIER_CONTEXT_FUTURE_SECONDS = 30 + MockConfig.CLASSIFIER_PROMPT_PREVIEW_CHARS = 100 + MockConfig.CLASSIFIER_AUDIT_MODE = False + # Create a dummy prompt file path MockConfig.PROJECT_ROOT.return_value = MagicMock() type(MockConfig).PROJECT_ROOT = MagicMock() @@ -391,6 +399,8 @@ def test_to_dict(self): "confidence": 0.9, "reasoning": "Test reason", "character": "Aragorn", + "classification_type": "UNKNOWN", + "unknown_speaker": False } assert result.to_dict() == expected_dict