Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 182 additions & 1 deletion openviking/storage/collection_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import asyncio
import hashlib
import json
import re
from functools import lru_cache
from typing import Any, Dict, Optional

from openviking.models.embedder.base import EmbedResult
Expand All @@ -25,6 +27,108 @@

logger = get_logger(__name__)

_TOKEN_LIMIT_RE = re.compile(
r"passed\s+(?P<input_tokens>\d+)\s+input tokens.*?maximum input length of\s+"
r"(?P<max_tokens>\d+)\s+tokens",
re.IGNORECASE | re.DOTALL,
)
_EMBEDDING_TRUNCATION_HEADROOM = 512


def _parse_input_token_limit_error(error: Exception) -> Optional[tuple[int, int]]:
"""Extract input-token and max-token values from provider errors."""
match = _TOKEN_LIMIT_RE.search(str(error))
if not match:
return None
return int(match.group("input_tokens")), int(match.group("max_tokens"))


@lru_cache(maxsize=16)
def _get_token_encoder(model_name: str):
"""Best-effort tokenizer lookup for provider-compatible embedding models."""
try:
import tiktoken
except ImportError:
return None

if model_name:
try:
return tiktoken.encoding_for_model(model_name)
except KeyError:
pass

try:
return tiktoken.get_encoding("cl100k_base")
except Exception:
return None


def _truncate_text_to_token_limit(
    text: str,
    model_name: str,
    max_tokens: int,
    *,
    observed_input_tokens: Optional[int] = None,
) -> str:
    """Trim *text* so it fits within the provider's token budget.

    Strategy, in order of preference:
    1. If the provider already reported an observed token count above the
       budget, shrink proportionally by characters with a 10% safety margin
       (the provider's own count is authoritative over any local tokenizer).
    2. If a tokenizer is available for *model_name*, truncate exactly by
       tokens.
    3. Otherwise fall back to a conservative byte-based estimate of roughly
       two bytes per token.

    Args:
        text: Input text to trim.
        model_name: Embedding model name used for tokenizer lookup.
        max_tokens: Target token budget; non-positive values disable trimming.
        observed_input_tokens: Token count the provider reported for *text*,
            when known.

    Returns:
        The original text when it already fits, otherwise a truncated copy.
    """
    if max_tokens <= 0 or not text:
        return text

    # Provider-reported overage: scale character count by the token ratio.
    # For any text longer than one character this always returns here, so the
    # tokenizer path below never needs to re-check observed_input_tokens.
    if observed_input_tokens and observed_input_tokens > max_tokens:
        shrink_ratio = max_tokens / observed_input_tokens
        target_chars = max(1, int(len(text) * shrink_ratio * 0.9))
        if target_chars < len(text):
            return text[:target_chars]

    # Exact truncation when a compatible tokenizer is available.
    encoder = _get_token_encoder(model_name)
    if encoder is not None:
        token_ids = encoder.encode(text)
        if len(token_ids) <= max_tokens:
            return text
        return encoder.decode(token_ids[:max_tokens])

    # No tokenizer: assume ~2 bytes per token as a conservative estimate.
    estimated_tokens = max(1, len(text.encode("utf-8")) // 2)
    if estimated_tokens <= max_tokens:
        return text

    shrink_ratio = max_tokens / estimated_tokens
    target_chars = max(1, int(len(text) * shrink_ratio * 0.9))
    return text[:target_chars]


def _resolve_embedder_dimension(
embedder: Any, configured_dimension: int, *, warn_prefix: str
) -> int:
"""Prefer the embedder-reported dimension over config defaults."""
if embedder and hasattr(embedder, "get_dimension"):
try:
actual_dimension = int(embedder.get_dimension())
if actual_dimension > 0:
if configured_dimension and configured_dimension != actual_dimension:
logger.warning(
"%s embedding dimension mismatch: config=%s, embedder=%s. "
"Using embedder dimension.",
warn_prefix,
configured_dimension,
actual_dimension,
)
return actual_dimension
except Exception as exc:
logger.warning(
"%s failed to resolve embedding dimension from embedder, "
"falling back to config=%s: %s",
warn_prefix,
configured_dimension,
exc,
)

return configured_dimension


class CollectionSchemas:
"""
Expand Down Expand Up @@ -114,6 +218,18 @@ async def init_context_collection(storage) -> bool:
config = get_openviking_config()
name = config.storage.vectordb.name
vector_dim = config.embedding.dimension
try:
embedder = config.embedding.get_embedder()
vector_dim = _resolve_embedder_dimension(
embedder, vector_dim, warn_prefix="init_context_collection"
)
except Exception as exc:
logger.warning(
"init_context_collection failed to initialize embedder for dimension "
"detection, using config dimension=%s: %s",
vector_dim,
exc,
)
schema = CollectionSchemas.context_collection(name, vector_dim)
return await storage.create_collection(name, schema)

Expand Down Expand Up @@ -148,6 +264,68 @@ def __init__(self, vikingdb: VikingVectorIndexBackend):
def _initialize_embedder(self, config: "OpenVikingConfig"):
"""Initialize the embedder instance from config."""
self._embedder = config.embedding.get_embedder()
self._vector_dim = _resolve_embedder_dimension(
self._embedder, self._vector_dim, warn_prefix="TextEmbeddingHandler"
)

def _embed_with_retry(
self,
text: str,
uri: str = "",
fallback_text: str = "",
) -> EmbedResult:
"""Retry with progressively smaller text when the provider rejects overlong input."""
current_text = text
model_name = getattr(self._embedder, "model_name", "")
last_error: Optional[Exception] = None

for attempt in range(5):
try:
return self._embedder.embed(current_text)
except Exception as exc:
last_error = exc
limit_info = _parse_input_token_limit_error(exc)
if not limit_info:
raise

input_tokens, max_tokens = limit_info
retry_budget = max(
1,
int((max_tokens - _EMBEDDING_TRUNCATION_HEADROOM) * (0.85**attempt)),
)
truncated_text = _truncate_text_to_token_limit(
current_text,
model_name,
retry_budget,
observed_input_tokens=input_tokens,
)
if len(truncated_text) >= len(current_text):
fallback_chars = max(1, int(len(current_text) * 0.5))
truncated_text = current_text[:fallback_chars]
if len(truncated_text) >= len(current_text):
raise exc

logger.warning(
"Embedding input too long for uri=%s model=%s (%s > %s tokens). "
"Attempt %s/5: truncating to ~%s tokens and retrying.",
uri or "<unknown>",
model_name or "<unknown>",
input_tokens,
max_tokens,
attempt + 1,
retry_budget,
)
current_text = truncated_text

if fallback_text and fallback_text.strip() and fallback_text.strip() != text.strip():
logger.warning(
"Embedding retries exhausted for uri=%s model=%s. Falling back to abstract/summary text.",
uri or "<unknown>",
model_name or "<unknown>",
)
return self._embedder.embed(fallback_text)

raise RuntimeError("Failed to embed text after token-limit retries") from last_error

@staticmethod
def _seed_uri_for_id(uri: str, level: Any) -> str:
Expand Down Expand Up @@ -198,7 +376,10 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str,
# embed() is a blocking HTTP call; offload to thread pool to avoid
# blocking the event loop and allow real concurrency.
result: EmbedResult = await asyncio.to_thread(
self._embedder.embed, embedding_msg.message
self._embed_with_retry,
embedding_msg.message,
inserted_data.get("uri", ""),
inserted_data.get("abstract", ""),
)

# Add dense vector
Expand Down
64 changes: 56 additions & 8 deletions openviking/storage/queuefs/semantic_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""SemanticProcessor: Processes messages from SemanticQueue, generates .abstract.md and .overview.md."""

import asyncio
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

from openviking.parse.parsers.constants import (
Expand Down Expand Up @@ -31,6 +32,52 @@

logger = get_logger(__name__)

_SEMANTIC_CONTENT_TOKEN_BUDGET = 6000


@lru_cache(maxsize=16)
def _get_token_encoder(model_name: str):
"""Best-effort tokenizer lookup for LLM prompt budgeting."""
try:
import tiktoken
except ImportError:
return None

if model_name:
try:
return tiktoken.encoding_for_model(model_name)
except KeyError:
pass

try:
return tiktoken.get_encoding("cl100k_base")
except Exception:
return None


def _truncate_text_by_token_budget(text: str, model_name: str, max_tokens: int) -> str:
    """Trim file content before rendering it into a summary prompt.

    Uses an exact tokenizer when one is available for *model_name*, otherwise
    estimates roughly two bytes per token. Truncated output carries a trailing
    "...(truncated)" marker; text already within budget is returned unchanged.
    """
    if not text or max_tokens <= 0:
        return text

    marker = "\n...(truncated)"
    tokenizer = _get_token_encoder(model_name)

    if tokenizer is not None:
        tokens = tokenizer.encode(text)
        if len(tokens) <= max_tokens:
            return text
        # Reserve room for the marker inside the same budget.
        keep = max(1, max_tokens - len(tokenizer.encode(marker)))
        return tokenizer.decode(tokens[:keep]) + marker

    approx_tokens = max(1, len(text.encode("utf-8")) // 2)
    if approx_tokens <= max_tokens:
        return text

    keep_chars = max(1, int(len(text) * (max_tokens / approx_tokens) * 0.9))
    return text[:keep_chars] + marker


class SemanticProcessor(DequeueHandlerBase):
"""
Expand Down Expand Up @@ -321,10 +368,6 @@ async def _generate_text_summary(

# Read file content (limit length)
content = await viking_fs.read_file(file_path, ctx=active_ctx)

# Limit content length (about 10000 tokens)
max_chars = 30000
content = await viking_fs.read_file(file_path, ctx=active_ctx)
if isinstance(content, bytes):
# Try to decode with error handling for text files
try:
Expand All @@ -333,10 +376,15 @@ async def _generate_text_summary(
logger.warning(f"Failed to decode file as UTF-8, skipping: {file_path}")
return {"name": file_name, "summary": ""}

# Limit content length (about 10000 tokens)
max_chars = 30000
if len(content) > max_chars:
content = content[:max_chars] + "\n...(truncated)"
model_name = getattr(vlm.get_vlm_instance(), "model_name", "") if vlm.is_available() else ""
truncated_content = _truncate_text_by_token_budget(
content, model_name, _SEMANTIC_CONTENT_TOKEN_BUDGET
)
if truncated_content != content:
logger.debug(
"Truncated file content for summary generation: %s", file_path
)
content = truncated_content

# Generate summary
if not vlm.is_available():
Expand Down
Loading