From 50ca77f45e30e83f296d7ad755457c751d39e894 Mon Sep 17 00:00:00 2001 From: Matt Van Horn <455140+mvanhorn@users.noreply.github.com> Date: Tue, 17 Mar 2026 07:15:54 -0700 Subject: [PATCH 1/3] feat(parse): implement audio resource parser with Whisper transcription Replace the audio parser stub with a working implementation that: - Extracts metadata (duration, sample rate, channels, bitrate) via mutagen - Transcribes speech via Whisper API with timestamped segments - Builds structured ResourceNode tree with L0/L1/L2 content tiers - Falls back to metadata-only output when Whisper is unavailable - Adds mutagen as optional dependency under [audio] extra - Adds audio_summary prompt template for semantic indexing - Includes unit tests with mocked Whisper API and mutagen --- openviking/parse/parsers/media/audio.py | 480 ++++++++++++++---- .../templates/parsing/audio_summary.yaml | 44 ++ pyproject.toml | 3 + tests/unit/parse/__init__.py | 0 tests/unit/parse/test_audio_parser.py | 288 +++++++++++ 5 files changed, 708 insertions(+), 107 deletions(-) create mode 100644 openviking/prompts/templates/parsing/audio_summary.yaml create mode 100644 tests/unit/parse/__init__.py create mode 100644 tests/unit/parse/test_audio_parser.py diff --git a/openviking/parse/parsers/media/audio.py b/openviking/parse/parsers/media/audio.py index 5c57dfc2..fbe944c9 100644 --- a/openviking/parse/parsers/media/audio.py +++ b/openviking/parse/parsers/media/audio.py @@ -1,41 +1,117 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: Apache-2.0 """ -Audio parser - Future implementation. - -Planned Features: -1. Speech-to-text transcription using ASR models -2. Audio metadata extraction (duration, sample rate, channels) -3. Speaker diarization (identify different speakers) -4. Timestamp alignment for transcribed text -5. Generate structured ResourceNode with transcript - -Example workflow: - 1. Load audio file - 2. Extract metadata (duration, format, sample rate) - 3. Transcribe speech to text using Whisper or similar - 4. (Optional) Perform speaker diarization - 5. Create ResourceNode with: - - type: NodeType.ROOT - - children: sections for each speaker/timestamp - - meta: audio metadata and timestamps - 6. Return ParseResult - -Supported formats: MP3, WAV, OGG, FLAC, AAC, M4A +Audio parser with metadata extraction and Whisper transcription. + +Features: +1. Speech-to-text transcription using Whisper API +2. Audio metadata extraction (duration, sample rate, channels) via mutagen +3. Timestamp alignment for transcribed text +4. Generate structured ResourceNode with transcript segments + +Supported formats: MP3, WAV, OGG, FLAC, AAC, M4A, OPUS """ +import io +import time from pathlib import Path -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union from openviking.parse.base import NodeType, ParseResult, ResourceNode from openviking.parse.parsers.base_parser import BaseParser from openviking.parse.parsers.media.constants import AUDIO_EXTENSIONS from openviking_cli.utils.config.parser_config import AudioConfig +from openviking_cli.utils.logger import get_logger + +logger = get_logger(__name__) + +# Magic bytes for audio format validation +AUDIO_MAGIC_BYTES: Dict[str, List[bytes]] = { + ".mp3": [b"ID3", b"\xff\xfb", b"\xff\xf3", b"\xff\xf2"], + ".wav": [b"RIFF"], + ".ogg": [b"OggS"], + ".flac": [b"fLaC"], + ".aac": [b"\xff\xf1", b"\xff\xf9"], + ".m4a": [b"\x00\x00\x00", b"ftypM4A", b"ftypisom"], + ".opus": [b"OggS"], +} + + +def _try_import_mutagen(): + """Lazily import mutagen, returning None if not installed.""" + try: + import mutagen + + return mutagen + except ImportError: + return None + + +def _format_timestamp(seconds: float) -> str: + """Format seconds as MM:SS or H:MM:SS.""" + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + if hours > 0: + return f"{hours}:{minutes:02d}:{secs:02d}" + return f"{minutes}:{secs:02d}" + + +def _extract_metadata_mutagen(file_path: Path) -> Dict[str, Any]: + """ + Extract audio metadata using mutagen. + + Args: + file_path: Path to audio file + + Returns: + Dictionary with duration, sample_rate, channels, bitrate, format + """ + mutagen = _try_import_mutagen() + if mutagen is None: + logger.warning( + "[AudioParser] mutagen not installed, skipping metadata extraction. " + "Install with: pip install mutagen" + ) + return {} + + try: + audio = mutagen.File(str(file_path)) + if audio is None: + logger.warning(f"[AudioParser] mutagen could not identify file: {file_path}") + return {} + + meta: Dict[str, Any] = {} + + # Duration + if hasattr(audio.info, "length"): + meta["duration"] = round(audio.info.length, 2) + + # Sample rate + if hasattr(audio.info, "sample_rate"): + meta["sample_rate"] = audio.info.sample_rate + + # Channels + if hasattr(audio.info, "channels"): + meta["channels"] = audio.info.channels + + # Bitrate (bits per second) + if hasattr(audio.info, "bitrate"): + meta["bitrate"] = audio.info.bitrate + + return meta + + except Exception as e: + logger.warning(f"[AudioParser] mutagen metadata extraction failed: {e}") + return {} class AudioParser(BaseParser): """ Audio parser for audio files. + + Extracts metadata via mutagen and transcribes speech via Whisper API. + Falls back to metadata-only output when transcription is unavailable. """ def __init__(self, config: Optional[AudioConfig] = None, **kwargs): @@ -53,23 +129,28 @@ def supported_extensions(self) -> List[str]: """Return supported audio file extensions.""" return AUDIO_EXTENSIONS - async def parse(self, source: Union[str, Path], instruction: str = "", **kwargs) -> ParseResult: + async def parse( + self, source: Union[str, Path], instruction: str = "", **kwargs + ) -> ParseResult: """ - Parse audio file - only copy original file and extract basic metadata, no content understanding. + Parse audio file - extract metadata, transcribe via Whisper, build ResourceNode tree. Args: source: Audio file path + instruction: Processing instruction **kwargs: Additional parsing parameters Returns: - ParseResult with audio content + ParseResult with audio content tree Raises: FileNotFoundError: If source file does not exist - IOError: If audio processing fails + ValueError: If file signature does not match expected format """ from openviking.storage.viking_fs import get_viking_fs + start_time = time.monotonic() + # Convert to Path object file_path = Path(source) if isinstance(source, str) else source if not file_path.exists(): @@ -78,160 +159,339 @@ async def parse(self, source: Union[str, Path], instruction: str = "", **kwargs) viking_fs = get_viking_fs() temp_uri = viking_fs.create_temp_uri() - # Phase 1: Generate temporary files + # Read audio bytes audio_bytes = file_path.read_bytes() ext = file_path.suffix + # Validate magic bytes + self._validate_audio_bytes(audio_bytes, ext, file_path) + from openviking_cli.utils.uri import VikingURI # Sanitize original filename (replace spaces with underscores) original_filename = file_path.name.replace(" ", "_") - # Root directory name: filename stem + _ + extension (without dot) stem = file_path.stem.replace(" ", "_") ext_no_dot = ext[1:] if ext else "" root_dir_name = VikingURI.sanitize_segment(f"{stem}_{ext_no_dot}") root_dir_uri = f"{temp_uri}/{root_dir_name}" await viking_fs.mkdir(root_dir_uri, exist_ok=True) - # 1.1 Save original audio with original filename (sanitized) + # Save original audio await viking_fs.write_file_bytes(f"{root_dir_uri}/{original_filename}", audio_bytes) - # 1.2 Validate audio file using magic bytes - # Define magic bytes for supported audio formats - audio_magic_bytes = { - ".mp3": [b"ID3", b"\xff\xfb", b"\xff\xf3", b"\xff\xf2"], - ".wav": [b"RIFF"], - ".ogg": [b"OggS"], - ".flac": [b"fLaC"], - ".aac": [b"\xff\xf1", b"\xff\xf9"], - ".m4a": [b"\x00\x00\x00", b"ftypM4A", b"ftypisom"], - ".opus": [b"OggS"], - } - - # Check magic bytes - valid = False - ext_lower = ext.lower() - magic_list = audio_magic_bytes.get(ext_lower, []) - for magic in magic_list: - if len(audio_bytes) >= len(magic) and audio_bytes.startswith(magic): - valid = True - break - - if not valid: - raise ValueError( - f"Invalid audio file: {file_path}. File signature does not match expected format {ext_lower}" + # Extract metadata via mutagen + mutagen_meta = _extract_metadata_mutagen(file_path) + duration = mutagen_meta.get("duration", 0) + sample_rate = mutagen_meta.get("sample_rate", 0) + channels = mutagen_meta.get("channels", 0) + bitrate = mutagen_meta.get("bitrate", 0) + format_str = ext_no_dot.lower() + + # Attempt transcription + transcript_segments: List[Dict[str, Any]] = [] + full_transcript = "" + warnings: List[str] = [] + + if self.config.enable_transcription: + try: + transcript_segments = await self._asr_transcribe_with_timestamps( + audio_bytes, self.config.transcription_model, ext + ) + if transcript_segments: + full_transcript = "\n".join( + seg["text"] for seg in transcript_segments + ) + else: + # Try plain transcription + full_transcript = await self._asr_transcribe( + audio_bytes, self.config.transcription_model, ext + ) + except Exception as e: + logger.warning(f"[AudioParser] Transcription failed: {e}") + warnings.append(f"Transcription unavailable: {e}") + + has_transcript = bool(full_transcript.strip()) + + # Save transcript file if available + if has_transcript: + transcript_md = self._build_transcript_markdown( + transcript_segments, full_transcript, file_path.stem ) + await viking_fs.write_file(f"{root_dir_uri}/transcript.md", transcript_md) + + # Build segment child nodes + children = [] + if transcript_segments: + for i, seg in enumerate(transcript_segments): + seg_start = seg.get("start", 0) + seg_end = seg.get("end", 0) + seg_text = seg.get("text", "").strip() + if not seg_text: + continue + + child = ResourceNode( + type=NodeType.SECTION, + title=f"segment_{i + 1:03d} ({_format_timestamp(seg_start)}-{_format_timestamp(seg_end)})", + level=1, + detail_file=None, + content_path=None, + children=[], + content_type="text", + meta={ + "start": seg_start, + "end": seg_end, + "text": seg_text, + }, + ) + children.append(child) + + # Build root node meta + root_meta: Dict[str, Any] = { + "duration": duration, + "sample_rate": sample_rate, + "channels": channels, + "bitrate": bitrate, + "format": format_str, + "content_type": "audio", + "source_title": file_path.stem, + "semantic_name": file_path.stem, + "original_filename": original_filename, + "has_transcript": has_transcript, + "segment_count": len(children), + } - # Extract audio metadata (placeholder) - duration = 0 - sample_rate = 0 - channels = 0 - format_str = ext[1:].upper() - - # Create ResourceNode - metadata only, no content understanding yet + # Create root ResourceNode root_node = ResourceNode( type=NodeType.ROOT, title=file_path.stem, level=0, detail_file=None, content_path=None, - children=[], - meta={ - "duration": duration, - "sample_rate": sample_rate, - "channels": channels, - "format": format_str.lower(), - "content_type": "audio", - "source_title": file_path.stem, - "semantic_name": file_path.stem, - "original_filename": original_filename, - }, + children=children, + content_type="audio", + meta=root_meta, + ) + + # Generate semantic info (L0 abstract, L1 overview) + description = full_transcript if has_transcript else f"Audio file: {file_path.name}" + await self._generate_semantic_info( + root_node, description, viking_fs, has_transcript ) - # Phase 3: Build directory structure (handled by TreeBuilder) + if not has_transcript: + warnings.append( + "No transcript available. Metadata-only output. " + "Configure Whisper API or install openai-whisper for transcription." + ) + + parse_time = time.monotonic() - start_time + return ParseResult( root=root_node, source_path=str(file_path), temp_dir_path=temp_uri, source_format="audio", parser_name="AudioParser", - meta={"content_type": "audio", "format": format_str.lower()}, + parse_time=parse_time, + meta={"content_type": "audio", "format": format_str}, + warnings=warnings, + ) + + def _validate_audio_bytes( + self, audio_bytes: bytes, ext: str, file_path: Path + ) -> None: + """Validate audio file using magic bytes.""" + ext_lower = ext.lower() + magic_list = AUDIO_MAGIC_BYTES.get(ext_lower, []) + for magic in magic_list: + if len(audio_bytes) >= len(magic) and audio_bytes.startswith(magic): + return + # If no magic bytes defined for this extension, skip validation + if not magic_list: + return + raise ValueError( + f"Invalid audio file: {file_path}. " + f"File signature does not match expected format {ext_lower}" ) - async def _asr_transcribe(self, audio_bytes: bytes, model: Optional[str]) -> str: + async def _asr_transcribe( + self, audio_bytes: bytes, model: Optional[str], ext: str = ".mp3" + ) -> str: """ - Generate audio transcription using ASR. + Transcribe audio using Whisper API via OpenAI client. Args: audio_bytes: Audio binary data - model: ASR model name + model: Whisper model name + ext: File extension for mime type hint Returns: - Audio transcription in markdown format - - TODO: Integrate with actual ASR API (Whisper, etc.) + Transcription text """ - # Fallback implementation - returns basic placeholder - return "Audio transcription (ASR integration pending)\n\nThis is an audio. ASR transcription feature has not yet integrated external API." + try: + from openviking_cli.utils.config import get_openviking_config + + config = get_openviking_config() + import openai + + client = openai.AsyncOpenAI( + api_key=config.llm.api_key if hasattr(config, "llm") else None, + ) + + audio_file = io.BytesIO(audio_bytes) + audio_file.name = f"audio{ext}" + + response = await client.audio.transcriptions.create( + model=model or "whisper-1", + file=audio_file, + language=self.config.language, + ) + + return response.text + + except Exception as e: + logger.warning(f"[AudioParser._asr_transcribe] Whisper API call failed: {e}") + return "" async def _asr_transcribe_with_timestamps( - self, audio_bytes: bytes, model: Optional[str] - ) -> Optional[str]: + self, audio_bytes: bytes, model: Optional[str], ext: str = ".mp3" + ) -> List[Dict[str, Any]]: """ - Extract transcription with timestamps from audio using ASR. + Transcribe audio with timestamps using Whisper API verbose_json format. Args: audio_bytes: Audio binary data - model: ASR model name + model: Whisper model name + ext: File extension Returns: - Transcript with timestamps in markdown format, or None if not available + List of segment dicts with keys: start, end, text + """ + try: + from openviking_cli.utils.config import get_openviking_config + + config = get_openviking_config() + import openai - TODO: Integrate with ASR API + client = openai.AsyncOpenAI( + api_key=config.llm.api_key if hasattr(config, "llm") else None, + ) + + audio_file = io.BytesIO(audio_bytes) + audio_file.name = f"audio{ext}" + + response = await client.audio.transcriptions.create( + model=model or "whisper-1", + file=audio_file, + response_format="verbose_json", + timestamp_granularities=["segment"], + language=self.config.language, + ) + + segments = [] + if hasattr(response, "segments") and response.segments: + for seg in response.segments: + segments.append({ + "start": seg.get("start", 0) if isinstance(seg, dict) else getattr(seg, "start", 0), + "end": seg.get("end", 0) if isinstance(seg, dict) else getattr(seg, "end", 0), + "text": seg.get("text", "") if isinstance(seg, dict) else getattr(seg, "text", ""), + }) + + return segments + + except Exception as e: + logger.warning( + f"[AudioParser._asr_transcribe_with_timestamps] Whisper API call failed: {e}" + ) + return [] + + def _build_transcript_markdown( + self, + segments: List[Dict[str, Any]], + full_transcript: str, + title: str, + ) -> str: """ - # Not implemented - return None - return None + Build a markdown transcript file from segments or plain text. + + Args: + segments: Timestamped transcript segments + full_transcript: Full transcript text (used if no segments) + title: Audio file title + + Returns: + Markdown-formatted transcript + """ + parts = [f"# Transcript: {title}\n"] + + if segments: + for seg in segments: + start = _format_timestamp(seg.get("start", 0)) + end = _format_timestamp(seg.get("end", 0)) + text = seg.get("text", "").strip() + if text: + parts.append(f"**[{start} - {end}]** {text}\n") + elif full_transcript.strip(): + parts.append(full_transcript.strip()) + parts.append("") + + return "\n".join(parts) async def _generate_semantic_info( - self, node: ResourceNode, description: str, viking_fs, has_transcript: bool - ): + self, + node: ResourceNode, + description: str, + viking_fs: Any, + has_transcript: bool, + ) -> None: """ - Phase 2: Generate abstract and overview. + Generate L0 abstract and L1 overview for the audio resource. Args: node: ResourceNode to update - description: Audio description + description: Audio transcript or description text viking_fs: VikingFS instance - has_transcript: Whether transcript file exists + has_transcript: Whether transcript is available """ - # Generate abstract (short summary, < 100 tokens) - abstract = description[:200] if len(description) > 200 else description - - # Generate overview (content summary + file list + usage instructions) + # L0 abstract: short summary (< 256 chars) + if has_transcript and len(description) > 50: + first_sentence_end = description.find(".", 20) + if 20 < first_sentence_end < 256: + abstract = description[: first_sentence_end + 1] + else: + abstract = description[:253] + "..." if len(description) > 256 else description + else: + abstract = description[:253] + "..." if len(description) > 256 else description + + # L1 overview overview_parts = [ "## Content Summary\n", - description, + abstract, "\n\n## Available Files\n", - f"- {node.meta['original_filename']}: Original audio file ({node.meta['duration']}s, {node.meta['sample_rate']}Hz, {node.meta['channels']}ch, {node.meta['format'].upper()} format)\n", + ( + f"- {node.meta['original_filename']}: Original audio file " + f"({node.meta['duration']}s, {node.meta['sample_rate']}Hz, " + f"{node.meta['channels']}ch, {node.meta['format'].upper()} format)\n" + ), ] if has_transcript: - overview_parts.append("- transcript.md: Transcript with timestamps from the audio\n") + overview_parts.append( + "- transcript.md: Timestamped transcript from the audio\n" + ) overview_parts.append("\n## Usage\n") overview_parts.append("### Play Audio\n") overview_parts.append("```python\n") overview_parts.append("audio_bytes = await audio_resource.play()\n") overview_parts.append("# Returns: Audio file binary data\n") - overview_parts.append("# Purpose: Play or save the audio\n") overview_parts.append("```\n\n") if has_transcript: - overview_parts.append("### Get Timestamps Transcript\n") + overview_parts.append("### Get Timestamped Transcript\n") overview_parts.append("```python\n") overview_parts.append("timestamps = await audio_resource.timestamps()\n") overview_parts.append("# Returns: FileContent object or None\n") - overview_parts.append("# Purpose: Extract timestamped transcript from the audio\n") overview_parts.append("```\n\n") overview_parts.append("### Get Audio Metadata\n") @@ -245,17 +505,22 @@ async def _generate_semantic_info( overview_parts.append( f"channels = audio_resource.get_channels() # {node.meta['channels']}\n" ) - overview_parts.append(f'format = audio_resource.get_format() # "{node.meta["format"]}"\n') + overview_parts.append( + f'format = audio_resource.get_format() # "{node.meta["format"]}"\n' + ) overview_parts.append("```\n") overview = "".join(overview_parts) - # Store in node meta node.meta["abstract"] = abstract node.meta["overview"] = overview async def parse_content( - self, content: str, source_path: Optional[str] = None, instruction: str = "", **kwargs + self, + content: str, + source_path: Optional[str] = None, + instruction: str = "", + **kwargs, ) -> ParseResult: """ Parse audio from content string - Not yet implemented. @@ -263,6 +528,7 @@ async def parse_content( Args: content: Audio content (base64 or binary string) source_path: Optional source path for metadata + instruction: Processing instruction **kwargs: Additional parsing parameters Returns: diff --git a/openviking/prompts/templates/parsing/audio_summary.yaml b/openviking/prompts/templates/parsing/audio_summary.yaml new file mode 100644 index 00000000..06009d47 --- /dev/null +++ b/openviking/prompts/templates/parsing/audio_summary.yaml @@ -0,0 +1,44 @@ +metadata: + id: "parsing.audio_summary" + name: "Audio Summary" + description: "Generate concise audio summary from transcript for semantic parsing" + version: "1.0.0" + language: "en" + category: "parsing" + +variables: + - name: "transcript" + type: "string" + description: "Full audio transcript text" + required: true + max_length: 30000 + - name: "duration" + type: "string" + description: "Audio duration in seconds" + default: "unknown" + required: false + - name: "format" + type: "string" + description: "Audio file format" + default: "unknown" + required: false + +template: | + Please analyze this audio transcript and generate a concise summary for semantic indexing. + + Audio duration: {{ duration }}s + Audio format: {{ format }} + + Transcript: + {{ transcript }} + + Generate a comprehensive summary that includes: + 1. Main topic or subject of the audio + 2. Key points discussed + 3. Any notable speakers or perspectives + 4. Important conclusions or takeaways + + Keep the summary clear and factual, suitable for semantic search and understanding. + +llm_config: + temperature: 0.0 diff --git a/pyproject.toml b/pyproject.toml index c26aa7cf..241997d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,9 @@ dependencies = [ pyagfs = { path = "third_party/agfs/agfs-sdk/python" } [project.optional-dependencies] +audio = [ + "mutagen>=1.47.0", +] test = [ "pytest>=7.0.0", "pytest-asyncio>=0.21.0", diff --git a/tests/unit/parse/__init__.py b/tests/unit/parse/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/parse/test_audio_parser.py b/tests/unit/parse/test_audio_parser.py new file mode 100644 index 00000000..2b34e8e4 --- /dev/null +++ b/tests/unit/parse/test_audio_parser.py @@ -0,0 +1,288 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for AudioParser with mocked Whisper API and mutagen.""" + +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.parse.base import NodeType +from openviking.parse.parsers.media.audio import ( + AUDIO_MAGIC_BYTES, + AudioParser, + _extract_metadata_mutagen, + _format_timestamp, +) +from openviking_cli.utils.config.parser_config import AudioConfig + + +class TestFormatTimestamp: + def test_seconds_only(self): + assert _format_timestamp(45) == "0:45" + + def test_minutes_and_seconds(self): + assert _format_timestamp(125) == "2:05" + + def test_hours(self): + assert _format_timestamp(3661) == "1:01:01" + + def test_zero(self): + assert _format_timestamp(0) == "0:00" + + +class TestExtractMetadataMutagen: + @patch("openviking.parse.parsers.media.audio._try_import_mutagen") + def test_mutagen_not_installed(self, mock_import): + mock_import.return_value = None + result = _extract_metadata_mutagen(Path("/fake/audio.mp3")) + assert result == {} + + @patch("openviking.parse.parsers.media.audio._try_import_mutagen") + def test_mutagen_returns_metadata(self, mock_import): + mock_mutagen = MagicMock() + mock_audio = MagicMock() + mock_audio.info.length = 120.5 + mock_audio.info.sample_rate = 44100 + mock_audio.info.channels = 2 + mock_audio.info.bitrate = 320000 + mock_mutagen.File.return_value = mock_audio + mock_import.return_value = mock_mutagen + + result = _extract_metadata_mutagen(Path("/fake/audio.mp3")) + assert result["duration"] == 120.5 + assert result["sample_rate"] == 44100 + assert result["channels"] == 2 + assert result["bitrate"] == 320000 + + @patch("openviking.parse.parsers.media.audio._try_import_mutagen") + def test_mutagen_file_returns_none(self, mock_import): + mock_mutagen = MagicMock() + mock_mutagen.File.return_value = None + mock_import.return_value = mock_mutagen + + result = _extract_metadata_mutagen(Path("/fake/audio.mp3")) + assert result == {} + + @patch("openviking.parse.parsers.media.audio._try_import_mutagen") + def test_mutagen_raises_exception(self, mock_import): + mock_mutagen = MagicMock() + mock_mutagen.File.side_effect = Exception("corrupt file") + mock_import.return_value = mock_mutagen + + result = _extract_metadata_mutagen(Path("/fake/audio.mp3")) + assert result == {} + + +class TestAudioParserInit: + def test_default_config(self): + parser = AudioParser() + assert parser.config.enable_transcription is True + assert parser.config.transcription_model == "whisper-large-v3" + + def test_custom_config(self): + config = AudioConfig(enable_transcription=False, language="en") + parser = AudioParser(config=config) + assert parser.config.enable_transcription is False + assert parser.config.language == "en" + + def test_supported_extensions(self): + parser = AudioParser() + exts = parser.supported_extensions + assert ".mp3" in exts + assert ".wav" in exts + assert ".ogg" in exts + assert ".flac" in exts + assert ".aac" in exts + assert ".m4a" in exts + + def test_can_parse(self): + parser = AudioParser() + assert parser.can_parse("test.mp3") is True + assert parser.can_parse("test.wav") is True + assert parser.can_parse("test.txt") is False + assert parser.can_parse("test.pdf") is False + + +class TestAudioParserValidation: + def test_validate_mp3_id3(self): + parser = AudioParser() + audio_bytes = b"ID3" + b"\x00" * 100 + parser._validate_audio_bytes(audio_bytes, ".mp3", Path("test.mp3")) + + def test_validate_wav_riff(self): + parser = AudioParser() + audio_bytes = b"RIFF" + b"\x00" * 100 + parser._validate_audio_bytes(audio_bytes, ".wav", Path("test.wav")) + + def test_validate_flac(self): + parser = AudioParser() + audio_bytes = b"fLaC" + b"\x00" * 100 + parser._validate_audio_bytes(audio_bytes, ".flac", Path("test.flac")) + + def test_validate_ogg(self): + parser = AudioParser() + audio_bytes = b"OggS" + b"\x00" * 100 + parser._validate_audio_bytes(audio_bytes, ".ogg", Path("test.ogg")) + + def test_invalid_mp3_raises(self): + parser = AudioParser() + audio_bytes = b"NOT_MP3" + b"\x00" * 100 + with pytest.raises(ValueError, match="Invalid audio file"): + parser._validate_audio_bytes(audio_bytes, ".mp3", Path("test.mp3")) + + def test_unknown_extension_skips_validation(self): + parser = AudioParser() + audio_bytes = b"anything" + parser._validate_audio_bytes(audio_bytes, ".xyz", Path("test.xyz")) + + +class TestAudioParserParse: + @pytest.mark.asyncio + async def test_file_not_found(self): + parser = AudioParser() + with pytest.raises(FileNotFoundError, match="Audio file not found"): + await parser.parse("/nonexistent/audio.mp3") + + @pytest.mark.asyncio + @patch("openviking.parse.parsers.media.audio._extract_metadata_mutagen") + async def test_parse_metadata_only(self, mock_metadata): + """Test parsing with transcription disabled - metadata only.""" + mock_metadata.return_value = { + "duration": 60.0, + "sample_rate": 44100, + "channels": 2, + "bitrate": 128000, + } + + config = AudioConfig(enable_transcription=False) + parser = AudioParser(config=config) + + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f: + f.write(b"ID3" + b"\x00" * 200) + tmp_path = f.name + + try: + mock_viking_fs = MagicMock() + mock_viking_fs.create_temp_uri.return_value = "viking://temp/test123" + mock_viking_fs.mkdir = AsyncMock() + mock_viking_fs.write_file_bytes = AsyncMock() + mock_viking_fs.write_file = AsyncMock() + + with patch( + "openviking.parse.parsers.media.audio.get_viking_fs", + return_value=mock_viking_fs, + ): + result = await parser.parse(tmp_path) + + assert result.parser_name == "AudioParser" + assert result.source_format == "audio" + assert result.root.type == NodeType.ROOT + assert result.root.meta["duration"] == 60.0 + assert result.root.meta["sample_rate"] == 44100 + assert result.root.meta["channels"] == 2 + assert result.root.meta["has_transcript"] is False + assert len(result.warnings) > 0 + finally: + Path(tmp_path).unlink(missing_ok=True) + + @pytest.mark.asyncio + @patch("openviking.parse.parsers.media.audio._extract_metadata_mutagen") + async def test_parse_with_transcript_segments(self, mock_metadata): + """Test parsing with mocked Whisper returning timestamped segments.""" + mock_metadata.return_value = { + "duration": 30.0, + "sample_rate": 16000, + "channels": 1, + "bitrate": 64000, + } + + config = AudioConfig(enable_transcription=True) + parser = AudioParser(config=config) + + segments = [ + {"start": 0.0, "end": 10.0, "text": "Hello world."}, + {"start": 10.0, "end": 20.0, "text": "This is a test."}, + {"start": 20.0, "end": 30.0, "text": "Goodbye."}, + ] + + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f: + f.write(b"ID3" + b"\x00" * 200) + tmp_path = f.name + + try: + mock_viking_fs = MagicMock() + mock_viking_fs.create_temp_uri.return_value = "viking://temp/test456" + mock_viking_fs.mkdir = AsyncMock() + mock_viking_fs.write_file_bytes = AsyncMock() + mock_viking_fs.write_file = AsyncMock() + + with ( + patch( + "openviking.parse.parsers.media.audio.get_viking_fs", + return_value=mock_viking_fs, + ), + patch.object( + parser, + "_asr_transcribe_with_timestamps", + new_callable=AsyncMock, + return_value=segments, + ), + ): + result = await parser.parse(tmp_path) + + assert result.root.meta["has_transcript"] is True + assert result.root.meta["segment_count"] == 3 + assert len(result.root.children) == 3 + assert result.root.children[0].type == NodeType.SECTION + assert "0:00" in result.root.children[0].title + assert result.root.children[0].meta["text"] == "Hello world." + assert len(result.warnings) == 0 + + mock_viking_fs.write_file.assert_called_once() + call_args = mock_viking_fs.write_file.call_args + assert "transcript.md" in call_args[0][0] + finally: + Path(tmp_path).unlink(missing_ok=True) + + +class TestAudioParserTranscript: + def test_build_transcript_markdown_with_segments(self): + parser = AudioParser() + segments = [ + {"start": 0.0, "end": 15.0, "text": "First segment."}, + {"start": 15.0, "end": 30.0, "text": "Second segment."}, + ] + md = parser._build_transcript_markdown(segments, "", "test_audio") + assert "# Transcript: test_audio" in md + assert "**[0:00 - 0:15]** First segment." in md + assert "**[0:15 - 0:30]** Second segment." in md + + def test_build_transcript_markdown_plain(self): + parser = AudioParser() + md = parser._build_transcript_markdown( + [], "This is the full transcript text.", "test_audio" + ) + assert "# Transcript: test_audio" in md + assert "This is the full transcript text." in md + + +class TestAudioParserParseContent: + @pytest.mark.asyncio + async def test_parse_content_not_implemented(self): + parser = AudioParser() + with pytest.raises(NotImplementedError): + await parser.parse_content("base64data") + + +class TestAudioMagicBytes: + def test_magic_bytes_defined(self): + """Verify magic bytes are defined for all supported formats.""" + assert ".mp3" in AUDIO_MAGIC_BYTES + assert ".wav" in AUDIO_MAGIC_BYTES + assert ".ogg" in AUDIO_MAGIC_BYTES + assert ".flac" in AUDIO_MAGIC_BYTES + assert ".aac" in AUDIO_MAGIC_BYTES + assert ".m4a" in AUDIO_MAGIC_BYTES + assert ".opus" in AUDIO_MAGIC_BYTES From 100083b7bf0be209a84c70072275b009f3b41a38 Mon Sep 17 00:00:00 2001 From: Matt Van Horn <455140+mvanhorn@users.noreply.github.com> Date: Tue, 17 Mar 2026 08:11:26 -0700 Subject: [PATCH 2/3] style: format with ruff --- bot/vikingbot/config/loader.py | 5 ++- bot/vikingbot/config/schema.py | 13 ++++-- bot/vikingbot/openviking_mount/ov_server.py | 15 +++---- openviking/models/embedder/__init__.py | 2 +- openviking/parse/parsers/excel.py | 13 ++++-- openviking/parse/parsers/legacy_doc.py | 21 ++++++---- openviking/parse/parsers/media/audio.py | 42 +++++++++---------- openviking/parse/registry.py | 6 +-- openviking/service/core.py | 4 +- openviking/session/compressor.py | 4 +- .../vectordb/service/server_fastapi.py | 29 +++++++------ tests/misc/test_embedding_input_type.py | 4 +- .../session/test_memory_extractor_language.py | 7 +++- tests/unit/test_ollama_embedding_factory.py | 3 +- tests/unit/test_openai_embedder.py | 2 - tests/unit/test_openai_embedder_chunking.py | 4 +- 16 files changed, 97 insertions(+), 77 deletions(-) diff --git a/bot/vikingbot/config/loader.py b/bot/vikingbot/config/loader.py index e57c3879..130e151b 100644 --- a/bot/vikingbot/config/loader.py +++ b/bot/vikingbot/config/loader.py @@ -4,11 +4,14 @@ import os from pathlib import Path from typing import Any + from loguru import logger + from vikingbot.config.schema import Config CONFIG_PATH = None + def get_config_path() -> Path: """Get the path to ov.conf config file. @@ -217,4 +220,4 @@ def camel_to_snake(name: str) -> str: def snake_to_camel(name: str) -> str: """Convert snake_case to camelCase.""" components = name.split("_") - return components[0] + "".join(x.title() for x in components[1:]) \ No newline at end of file + return components[0] + "".join(x.title() for x in components[1:]) diff --git a/bot/vikingbot/config/schema.py b/bot/vikingbot/config/schema.py index 90a2b10c..3dc1a8d5 100644 --- a/bot/vikingbot/config/schema.py +++ b/bot/vikingbot/config/schema.py @@ -40,8 +40,10 @@ class SandboxMode(str, Enum): SHARED = "shared" PER_CHANNEL = "per-channel" + class AgentMemoryMode(str, Enum): """Agent memory mode enumeration.""" + PER_SESSION = "per-session" SHARED = "shared" PER_CHANNEL = "per-channel" @@ -109,7 +111,10 @@ class FeishuChannelConfig(BaseChannelConfig): encrypt_key: str = "" verification_token: str = "" allow_from: list[str] = Field(default_factory=list) ## 允许更新Agent对话的Feishu用户ID列表 - thread_require_mention: bool = Field(default=True, description="话题群模式下是否需要@才响应:默认True=所有消息必须@才响应;False=新话题首条消息无需@,后续回复必须@") + thread_require_mention: bool = Field( + default=True, + description="话题群模式下是否需要@才响应:默认True=所有消息必须@才响应;False=新话题首条消息无需@,后续回复必须@", + ) def channel_id(self) -> str: # Use app_id directly as the ID @@ -396,7 +401,9 @@ class ProviderConfig(BaseModel): api_key: str = "" api_base: Optional[str] = None - extra_headers: Optional[dict[str, str]] = Field(default_factory=dict) # Custom headers (e.g. APP-Code for AiHubMix) + extra_headers: Optional[dict[str, str]] = Field( + default_factory=dict + ) # Custom headers (e.g. APP-Code for AiHubMix) class ProvidersConfig(BaseModel): @@ -734,4 +741,4 @@ def from_safe_name(safe_name: str): file_name_split = safe_name.split("__") return SessionKey( type=file_name_split[0], channel_id=file_name_split[1], chat_id=file_name_split[2] - ) \ No newline at end of file + ) diff --git a/bot/vikingbot/openviking_mount/ov_server.py b/bot/vikingbot/openviking_mount/ov_server.py index d3f6383c..514beb47 100644 --- a/bot/vikingbot/openviking_mount/ov_server.py +++ b/bot/vikingbot/openviking_mount/ov_server.py @@ -1,9 +1,9 @@ import asyncio import hashlib -from typing import List, Dict, Any, Optional +import time +from typing import Any, Dict, List, Optional from loguru import logger -import time import openviking as ov from vikingbot.config.loader import load_config @@ -99,9 +99,7 @@ async def find(self, query: str, target_uri: Optional[str] = None): return await self.client.find(query, target_uri=target_uri) return await self.client.find(query) - async def add_resource( - self, local_path: str, desc: str - ) -> Optional[Dict[str, Any]]: + async def add_resource(self, local_path: str, desc: str) -> Optional[Dict[str, Any]]: """添加资源到 Viking""" result = await self.client.add_resource(path=local_path, reason=desc) return result @@ -327,7 +325,9 @@ async def search_memory( async def grep(self, uri: str, pattern: str, case_insensitive: bool = False) -> Dict[str, Any]: """通过模式(正则表达式)搜索内容""" - return await self.client.grep(uri, pattern, case_insensitive=case_insensitive, node_limit=10) + return await self.client.grep( + uri, pattern, case_insensitive=case_insensitive, node_limit=10 + ) async def glob(self, pattern: str, uri: Optional[str] = None) -> Dict[str, Any]: """通过 glob 模式匹配文件""" @@ -337,7 +337,8 @@ async def commit(self, session_id: str, messages: list[dict[str, Any]], user_id: """提交会话""" import re import uuid - from openviking.message.part import Part, TextPart, ToolPart + + from openviking.message.part import TextPart, ToolPart user_exists = await self._check_user_exists(user_id) if not user_exists: diff --git a/openviking/models/embedder/__init__.py b/openviking/models/embedder/__init__.py index b418b809..02ccdfc0 100644 --- a/openviking/models/embedder/__init__.py +++ b/openviking/models/embedder/__init__.py @@ -25,7 +25,6 @@ ) from openviking.models.embedder.jina_embedders import JinaDenseEmbedder from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder -from openviking.models.embedder.voyage_embedders import VoyageDenseEmbedder from openviking.models.embedder.vikingdb_embedders import ( VikingDBDenseEmbedder, VikingDBHybridEmbedder, @@ -36,6 +35,7 @@ VolcengineHybridEmbedder, VolcengineSparseEmbedder, ) +from openviking.models.embedder.voyage_embedders import VoyageDenseEmbedder __all__ = [ # Base classes diff --git a/openviking/parse/parsers/excel.py b/openviking/parse/parsers/excel.py index 2da44ce3..a904f786 100644 --- a/openviking/parse/parsers/excel.py +++ b/openviking/parse/parsers/excel.py @@ -142,7 +142,9 @@ def _format_xls_cell(cell, wb, xlrd) -> str: dt = xlrd.xldate_as_tuple(cell.value, wb.datemode) # Include time component if non-zero if dt[3] or dt[4] or dt[5]: - return f"{dt[0]:04d}-{dt[1]:02d}-{dt[2]:02d} {dt[3]:02d}:{dt[4]:02d}:{dt[5]:02d}" + return ( + f"{dt[0]:04d}-{dt[1]:02d}-{dt[2]:02d} {dt[3]:02d}:{dt[4]:02d}:{dt[5]:02d}" + ) return f"{dt[0]:04d}-{dt[1]:02d}-{dt[2]:02d}" except Exception: return str(cell.value) @@ -151,8 +153,13 @@ def _format_xls_cell(cell, wb, xlrd) -> str: if cell.ctype == xlrd.XL_CELL_ERROR: # xlrd error code map error_map = { - 0x00: "#NULL!", 0x07: "#DIV/0!", 0x0F: "#VALUE!", - 0x17: "#REF!", 0x1D: "#NAME?", 0x24: "#NUM!", 0x2A: "#N/A", + 0x00: "#NULL!", + 0x07: "#DIV/0!", + 0x0F: "#VALUE!", + 0x17: "#REF!", + 0x1D: "#NAME?", + 0x24: "#NUM!", + 0x2A: "#N/A", } return error_map.get(cell.value, f"#ERR({cell.value})") if cell.ctype == xlrd.XL_CELL_NUMBER: diff --git a/openviking/parse/parsers/legacy_doc.py b/openviking/parse/parsers/legacy_doc.py index 95b770ae..025216f9 100644 --- a/openviking/parse/parsers/legacy_doc.py +++ b/openviking/parse/parsers/legacy_doc.py @@ -19,7 +19,7 @@ logger = get_logger(__name__) - # Max stream size to read (50MB) — prevents DoS from crafted files +# Max stream size to read (50MB) — prevents DoS from crafted files _MAX_STREAM_SIZE = 50 * 1024 * 1024 # Max character count sanity cap for ccpText _MAX_CCP_TEXT = 10_000_000 @@ -154,9 +154,7 @@ def _extract_from_ole(self, ole) -> str: if fc_clx <= 0 or lcb_clx <= 0 or fc_clx + lcb_clx > len(table_data): return self._simple_text_extract(word_doc, ccp_text) - return self._extract_via_clx( - word_doc, table_data, fc_clx, lcb_clx, ccp_text - ) + return self._extract_via_clx(word_doc, table_data, fc_clx, lcb_clx, ccp_text) def _simple_text_extract(self, word_doc: bytes, ccp_text: int) -> str: """ @@ -177,7 +175,10 @@ def _simple_text_extract(self, word_doc: bytes, ccp_text: int) -> str: raw = word_doc[text_start:end] text = raw.decode("utf-16-le", errors="replace") # Sanity: if mostly printable, it's likely correct - if sum(1 for c in text[:200] if c.isprintable() or c in "\n\r\t") > len(text[:200]) * 0.5: + if ( + sum(1 for c in text[:200] if c.isprintable() or c in "\n\r\t") + > len(text[:200]) * 0.5 + ): return self._clean_word_text(text) # Fall back to CP1252 single-byte @@ -277,7 +278,9 @@ def _extract_via_clx( raw = word_doc[byte_offset:byte_end] text_parts.append(self._decode_cp1252(raw)) else: - logger.warning(f"Piece {i} extends beyond stream ({byte_end} > {len(word_doc)})") + logger.warning( + f"Piece {i} extends beyond stream ({byte_end} > {len(word_doc)})" + ) else: # UTF-16LE byte_offset = fc_real @@ -286,7 +289,9 @@ def _extract_via_clx( raw = word_doc[byte_offset:byte_end] text_parts.append(raw.decode("utf-16-le", errors="replace")) else: - logger.warning(f"Piece {i} extends beyond stream ({byte_end} > {len(word_doc)})") + logger.warning( + f"Piece {i} extends beyond stream ({byte_end} > {len(word_doc)})" + ) chars_extracted += piece_char_count @@ -305,7 +310,7 @@ def _clean_word_text(text: str) -> str: """Normalize Word control characters to readable equivalents.""" text = text.replace("\r\n", "\n").replace("\r", "\n") # \x07 = cell/row end, \x0B = soft line break, \x0C = section break - text = text.replace("\x07", "\t").replace("\x0B", "\n").replace("\x0C", "\n\n") + text = text.replace("\x07", "\t").replace("\x0b", "\n").replace("\x0c", "\n\n") return text def _fallback_extract(self, path: Path) -> str: diff --git a/openviking/parse/parsers/media/audio.py b/openviking/parse/parsers/media/audio.py index fbe944c9..d702efd9 100644 --- a/openviking/parse/parsers/media/audio.py +++ b/openviking/parse/parsers/media/audio.py @@ -129,9 +129,7 @@ def supported_extensions(self) -> List[str]: """Return supported audio file extensions.""" return AUDIO_EXTENSIONS - async def parse( - self, source: Union[str, Path], instruction: str = "", **kwargs - ) -> ParseResult: + async def parse(self, source: Union[str, Path], instruction: str = "", **kwargs) -> ParseResult: """ Parse audio file - extract metadata, transcribe via Whisper, build ResourceNode tree. @@ -198,9 +196,7 @@ async def parse( audio_bytes, self.config.transcription_model, ext ) if transcript_segments: - full_transcript = "\n".join( - seg["text"] for seg in transcript_segments - ) + full_transcript = "\n".join(seg["text"] for seg in transcript_segments) else: # Try plain transcription full_transcript = await self._asr_transcribe( @@ -274,9 +270,7 @@ async def parse( # Generate semantic info (L0 abstract, L1 overview) description = full_transcript if has_transcript else f"Audio file: {file_path.name}" - await self._generate_semantic_info( - root_node, description, viking_fs, has_transcript - ) + await self._generate_semantic_info(root_node, description, viking_fs, has_transcript) if not has_transcript: warnings.append( @@ -297,9 +291,7 @@ async def parse( warnings=warnings, ) - def _validate_audio_bytes( - self, audio_bytes: bytes, ext: str, file_path: Path - ) -> None: + def _validate_audio_bytes(self, audio_bytes: bytes, ext: str, file_path: Path) -> None: """Validate audio file using magic bytes.""" ext_lower = ext.lower() magic_list = AUDIO_MAGIC_BYTES.get(ext_lower, []) @@ -391,11 +383,19 @@ async def _asr_transcribe_with_timestamps( segments = [] if hasattr(response, "segments") and response.segments: for seg in response.segments: - segments.append({ - "start": seg.get("start", 0) if isinstance(seg, dict) else getattr(seg, "start", 0), - "end": seg.get("end", 0) if isinstance(seg, dict) else getattr(seg, "end", 0), - "text": seg.get("text", "") if isinstance(seg, dict) else getattr(seg, "text", ""), - }) + segments.append( + { + "start": seg.get("start", 0) + if isinstance(seg, dict) + else getattr(seg, "start", 0), + "end": seg.get("end", 0) + if isinstance(seg, dict) + else getattr(seg, "end", 0), + "text": seg.get("text", "") + if isinstance(seg, dict) + else getattr(seg, "text", ""), + } + ) return segments @@ -476,9 +476,7 @@ async def _generate_semantic_info( ] if has_transcript: - overview_parts.append( - "- transcript.md: Timestamped transcript from the audio\n" - ) + overview_parts.append("- transcript.md: Timestamped transcript from the audio\n") overview_parts.append("\n## Usage\n") overview_parts.append("### Play Audio\n") @@ -505,9 +503,7 @@ async def _generate_semantic_info( overview_parts.append( f"channels = audio_resource.get_channels() # {node.meta['channels']}\n" ) - overview_parts.append( - f'format = audio_resource.get_format() # "{node.meta["format"]}"\n' - ) + overview_parts.append(f'format = audio_resource.get_format() # "{node.meta["format"]}"\n') overview_parts.append("```\n") overview = "".join(overview_parts) diff --git a/openviking/parse/registry.py b/openviking/parse/registry.py index 4c6b3e4a..40d351fb 100644 --- a/openviking/parse/registry.py +++ b/openviking/parse/registry.py @@ -19,14 +19,14 @@ # Import will be handled dynamically to avoid dependency issues from openviking.parse.parsers.html import HTMLParser + +# Import markitdown-inspired parsers +from openviking.parse.parsers.legacy_doc import LegacyDocParser from openviking.parse.parsers.markdown import MarkdownParser from openviking.parse.parsers.media import AudioParser, ImageParser, VideoParser from openviking.parse.parsers.pdf import PDFParser from openviking.parse.parsers.powerpoint import PowerPointParser from openviking.parse.parsers.text import TextParser - -# Import markitdown-inspired parsers -from openviking.parse.parsers.legacy_doc import LegacyDocParser from openviking.parse.parsers.word import WordParser from openviking.parse.parsers.zip_parser import ZipParser diff --git a/openviking/service/core.py b/openviking/service/core.py index b07bdb84..3a7d1f22 100644 --- a/openviking/service/core.py +++ b/openviking/service/core.py @@ -132,7 +132,9 @@ def _init_storage( logger.warning("AGFS client not initialized, skipping queue manager") # Initialize VikingDBManager with QueueManager - self._vikingdb_manager = VikingDBManager(vectordb_config=config.vectordb, queue_manager=self._queue_manager) + self._vikingdb_manager = VikingDBManager( + vectordb_config=config.vectordb, queue_manager=self._queue_manager + ) # Configure queues if QueueManager is available if self._queue_manager: diff --git a/openviking/session/compressor.py b/openviking/session/compressor.py index 7fd066c7..09c924d4 100644 --- a/openviking/session/compressor.py +++ b/openviking/session/compressor.py @@ -379,9 +379,7 @@ async def extract_long_term_memories( merged_text = ( f"{action.memory.abstract} {candidate.content}" ) - merged_embed = self.deduplicator.embedder.embed( - merged_text - ) + merged_embed = self.deduplicator.embedder.embed(merged_text) batch_memories.append( (merged_embed.dense_vector, action.memory) ) diff --git a/openviking/storage/vectordb/service/server_fastapi.py b/openviking/storage/vectordb/service/server_fastapi.py index 34574e60..2b6edb04 100644 --- a/openviking/storage/vectordb/service/server_fastapi.py +++ b/openviking/storage/vectordb/service/server_fastapi.py @@ -11,7 +11,7 @@ import random import time from contextlib import asynccontextmanager -from typing import Dict, Any +from typing import Any, Dict import uvicorn from fastapi import FastAPI, Request @@ -28,19 +28,19 @@ @asynccontextmanager async def lifespan(app: FastAPI): """Handle application startup and shutdown events. - + Manages resource initialization and cleanup, ensuring graceful shutdown by waiting for all active requests to complete. - + Args: app: The FastAPI application instance """ # Startup logger.info("============ VikingDB Server Starting =============") random.seed(time.time_ns()) - + yield - + # Shutdown logger.info("Waiting for active requests to complete...") while _active_requests > 0: @@ -61,31 +61,30 @@ async def lifespan(app: FastAPI): @app.exception_handler(VikingDBException) async def vikingdb_exception_handler(request: Request, exc: VikingDBException) -> JSONResponse: """Handle VikingDB-specific exceptions. - + Args: request: The incoming HTTP request exc: The VikingDBException that was raised - + Returns: JSONResponse with error details """ return JSONResponse( - status_code=200, - content=error_response(exc.message, exc.code.value, request=request) + status_code=200, content=error_response(exc.message, exc.code.value, request=request) ) @app.middleware("http") async def request_tracking_middleware(request: Request, call_next): """Middleware to track request processing time and active request count. - + Increments active request counter, measures processing time, and adds processing time header to response. - + Args: request: The incoming HTTP request call_next: The next middleware/handler in the chain - + Returns: Response with added X-Process-Time header """ @@ -118,7 +117,7 @@ async def request_tracking_middleware(request: Request, call_next): @app.get("/") async def root() -> Dict[str, str]: """Root endpoint providing basic server information. - + Returns: Dict containing server name and version """ @@ -128,7 +127,7 @@ async def root() -> Dict[str, str]: @app.get("/health") async def health() -> Dict[str, Any]: """Health check endpoint for monitoring server status. - + Returns: Dict containing health status and current active request count """ @@ -140,4 +139,4 @@ async def health() -> Dict[str, Any]: logger.info("Starting VikingDB server on 0.0.0.0:5000") uvicorn.run(app, host="0.0.0.0", port=5000, log_level="info") except Exception as e: - logger.error(f"Failed to start VikingDB server: {e}") \ No newline at end of file + logger.error(f"Failed to start VikingDB server: {e}") diff --git a/tests/misc/test_embedding_input_type.py b/tests/misc/test_embedding_input_type.py index 4ee8f660..34e4897f 100644 --- a/tests/misc/test_embedding_input_type.py +++ b/tests/misc/test_embedding_input_type.py @@ -9,8 +9,6 @@ from unittest.mock import MagicMock, patch -import pytest - from openviking_cli.utils.config.embedding_config import EmbeddingConfig, EmbeddingModelConfig @@ -86,7 +84,7 @@ def test_legacy_input_type_lowercase_normalization(self): ) assert config.input_type == "search_query" - def test_query_document_param_lowercase_normalization(self): + def test_query_document_param_lowercase_normalization_jina(self): """Query/document task values should be normalized to lowercase.""" config = EmbeddingModelConfig( model="jina-embeddings-v5-text-small", diff --git a/tests/session/test_memory_extractor_language.py b/tests/session/test_memory_extractor_language.py index d1d98021..22edc865 100644 --- a/tests/session/test_memory_extractor_language.py +++ b/tests/session/test_memory_extractor_language.py @@ -104,7 +104,12 @@ def test_detect_output_language_japanese_with_single_cyrillic(): def test_detect_output_language_russian_with_threshold(): """Russian text with sufficient Cyrillic chars should be detected as Russian.""" - messages = [_msg("user", "\u042d\u0442\u043e \u0440\u0443\u0441\u0441\u043a\u0438\u0439 \u0442\u0435\u043a\u0441\u0442")] + messages = [ + _msg( + "user", + "\u042d\u0442\u043e \u0440\u0443\u0441\u0441\u043a\u0438\u0439 \u0442\u0435\u043a\u0441\u0442", + ) + ] language = MemoryExtractor._detect_output_language(messages, fallback_language="en") assert language == "ru" diff --git a/tests/unit/test_ollama_embedding_factory.py b/tests/unit/test_ollama_embedding_factory.py index d2b3caf1..dd2db492 100644 --- a/tests/unit/test_ollama_embedding_factory.py +++ b/tests/unit/test_ollama_embedding_factory.py @@ -9,7 +9,6 @@ with the openai factory and the placeholder used inside OpenAIDenseEmbedder. """ -import pytest from unittest.mock import MagicMock, patch from openviking_cli.utils.config.embedding_config import EmbeddingConfig, EmbeddingModelConfig @@ -27,7 +26,7 @@ def _make_mock_openai_class(): def _make_ollama_cfg(**kwargs) -> EmbeddingModelConfig: - defaults = dict(provider="ollama", model="nomic-embed-text", dimension=768) + defaults = {"provider": "ollama", "model": "nomic-embed-text", "dimension": 768} defaults.update(kwargs) return EmbeddingModelConfig(**defaults) diff --git a/tests/unit/test_openai_embedder.py b/tests/unit/test_openai_embedder.py index 8e9b72a8..bd5493ca 100644 --- a/tests/unit/test_openai_embedder.py +++ b/tests/unit/test_openai_embedder.py @@ -4,8 +4,6 @@ from unittest.mock import MagicMock, patch -import pytest - from openviking.models.embedder import OpenAIDenseEmbedder diff --git a/tests/unit/test_openai_embedder_chunking.py b/tests/unit/test_openai_embedder_chunking.py index 514aa4df..05a85d36 100644 --- a/tests/unit/test_openai_embedder_chunking.py +++ b/tests/unit/test_openai_embedder_chunking.py @@ -409,7 +409,9 @@ def test_low_custom_max_tokens_triggers_chunking(self, mock_openai_class): # Need text long enough to produce multiple chunks. # Fallback estimation: len(text)//3. With max_tokens=5, need >5 tokens. # Fixed-length split has min chunk_size=100, so text must be >100 chars to split. - text = "Hello world test. " * 30 # 540 chars -> 180 estimated tokens, well over max_tokens=5 + text = ( + "Hello world test. " * 30 + ) # 540 chars -> 180 estimated tokens, well over max_tokens=5 mock_client.embeddings.create.reset_mock() result = embedder.embed(text) From a6a57ce549ff887ada3dbef4287aca09eec15542 Mon Sep 17 00:00:00 2001 From: Matt Van Horn <455140+mvanhorn@users.noreply.github.com> Date: Wed, 18 Mar 2026 06:53:48 -0700 Subject: [PATCH 3/3] Address review feedback: fix Whisper API integration and reduce duplication - Change default transcription_model from "whisper-large-v3" (HuggingFace name) to "whisper-1" (OpenAI API compatible) - Add base_url support via ProviderConfig.api_base so custom Whisper deployments (Azure, local server) work correctly - Extract _get_whisper_client() and _prepare_audio_file() helpers to eliminate duplicate boilerplate between _asr_transcribe and _asr_transcribe_with_timestamps - Remove unused audio_summary.yaml template (will integrate with render_prompt in a follow-up if LLM-powered summarization is desired) Co-Authored-By: Claude Opus 4.6 (1M context) --- openviking/parse/parsers/media/audio.py | 75 ++++++++----------- .../templates/parsing/audio_summary.yaml | 44 ----------- openviking_cli/utils/config/parser_config.py | 5 +- 3 files changed, 34 insertions(+), 90 deletions(-) delete mode 100644 openviking/prompts/templates/parsing/audio_summary.yaml diff --git a/openviking/parse/parsers/media/audio.py b/openviking/parse/parsers/media/audio.py index d702efd9..3b77446e 100644 --- a/openviking/parse/parsers/media/audio.py +++ b/openviking/parse/parsers/media/audio.py @@ -306,32 +306,38 @@ def _validate_audio_bytes(self, audio_bytes: bytes, ext: str, file_path: Path) - f"File signature does not match expected format {ext_lower}" ) + def _get_whisper_client(self): + """Create an AsyncOpenAI client using ProviderConfig settings. + + Reads api_key and api_base from the project config so custom + Whisper deployments (Azure, local server) work correctly. + """ + import openai + from openviking_cli.utils.config import get_openviking_config + + config = get_openviking_config() + kwargs: Dict[str, Any] = {} + if hasattr(config, "llm"): + if config.llm.api_key: + kwargs["api_key"] = config.llm.api_key + if hasattr(config.llm, "api_base") and config.llm.api_base: + kwargs["base_url"] = config.llm.api_base + return openai.AsyncOpenAI(**kwargs) + + @staticmethod + def _prepare_audio_file(audio_bytes: bytes, ext: str) -> io.BytesIO: + """Wrap raw audio bytes in a named BytesIO for the Whisper API.""" + audio_file = io.BytesIO(audio_bytes) + audio_file.name = f"audio{ext}" + return audio_file + async def _asr_transcribe( self, audio_bytes: bytes, model: Optional[str], ext: str = ".mp3" ) -> str: - """ - Transcribe audio using Whisper API via OpenAI client. - - Args: - audio_bytes: Audio binary data - model: Whisper model name - ext: File extension for mime type hint - - Returns: - Transcription text - """ + """Transcribe audio using Whisper API via OpenAI client.""" try: - from openviking_cli.utils.config import get_openviking_config - - config = get_openviking_config() - import openai - - client = openai.AsyncOpenAI( - api_key=config.llm.api_key if hasattr(config, "llm") else None, - ) - - audio_file = io.BytesIO(audio_bytes) - audio_file.name = f"audio{ext}" + client = self._get_whisper_client() + audio_file = self._prepare_audio_file(audio_bytes, ext) response = await client.audio.transcriptions.create( model=model or "whisper-1", @@ -348,29 +354,10 @@ async def _asr_transcribe( async def _asr_transcribe_with_timestamps( self, audio_bytes: bytes, model: Optional[str], ext: str = ".mp3" ) -> List[Dict[str, Any]]: - """ - Transcribe audio with timestamps using Whisper API verbose_json format. - - Args: - audio_bytes: Audio binary data - model: Whisper model name - ext: File extension - - Returns: - List of segment dicts with keys: start, end, text - """ + """Transcribe audio with timestamps using Whisper API verbose_json format.""" try: - from openviking_cli.utils.config import get_openviking_config - - config = get_openviking_config() - import openai - - client = openai.AsyncOpenAI( - api_key=config.llm.api_key if hasattr(config, "llm") else None, - ) - - audio_file = io.BytesIO(audio_bytes) - audio_file.name = f"audio{ext}" + client = self._get_whisper_client() + audio_file = self._prepare_audio_file(audio_bytes, ext) response = await client.audio.transcriptions.create( model=model or "whisper-1", diff --git a/openviking/prompts/templates/parsing/audio_summary.yaml b/openviking/prompts/templates/parsing/audio_summary.yaml deleted file mode 100644 index 06009d47..00000000 --- a/openviking/prompts/templates/parsing/audio_summary.yaml +++ /dev/null @@ -1,44 +0,0 @@ -metadata: - id: "parsing.audio_summary" - name: "Audio Summary" - description: "Generate concise audio summary from transcript for semantic parsing" - version: "1.0.0" - language: "en" - category: "parsing" - -variables: - - name: "transcript" - type: "string" - description: "Full audio transcript text" - required: true - max_length: 30000 - - name: "duration" - type: "string" - description: "Audio duration in seconds" - default: "unknown" - required: false - - name: "format" - type: "string" - description: "Audio file format" - default: "unknown" - required: false - -template: | - Please analyze this audio transcript and generate a concise summary for semantic indexing. - - Audio duration: {{ duration }}s - Audio format: {{ format }} - - Transcript: - {{ transcript }} - - Generate a comprehensive summary that includes: - 1. Main topic or subject of the audio - 2. Key points discussed - 3. Any notable speakers or perspectives - 4. Important conclusions or takeaways - - Keep the summary clear and factual, suitable for semantic search and understanding. - -llm_config: - temperature: 0.0 diff --git a/openviking_cli/utils/config/parser_config.py b/openviking_cli/utils/config/parser_config.py index c8ff46aa..2c5f199f 100644 --- a/openviking_cli/utils/config/parser_config.py +++ b/openviking_cli/utils/config/parser_config.py @@ -309,13 +309,14 @@ class AudioConfig(ParserConfig): Attributes: enable_transcription: Whether to transcribe speech to text - transcription_model: Model to use (e.g., "whisper-large-v3") + transcription_model: Whisper model name. Use "whisper-1" for + OpenAI API, or a custom model name for self-hosted endpoints. language: Audio language (None for auto-detection) extract_metadata: Whether to extract audio metadata """ enable_transcription: bool = True - transcription_model: str = "whisper-large-v3" + transcription_model: str = "whisper-1" language: Optional[str] = None extract_metadata: bool = True