diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py index 590ad685..78cabff5 100644 --- a/openviking/storage/collection_schemas.py +++ b/openviking/storage/collection_schemas.py @@ -16,13 +16,17 @@ from typing import Any, Dict, List, Optional from openviking.models.embedder.base import EmbedResult -from openviking.models.embedder.volcengine_embedders import is_429_error from openviking.server.identity import RequestContext, Role from openviking.storage.errors import CollectionNotFoundError from openviking.storage.queuefs.embedding_msg import EmbeddingMsg from openviking.storage.queuefs.named_queue import DequeueHandlerBase from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend from openviking.telemetry import bind_telemetry, resolve_telemetry +from openviking.utils.circuit_breaker import ( + CircuitBreaker, + CircuitBreakerOpen, + classify_api_error, +) from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils import get_logger from openviking_cli.utils.config.open_viking_config import OpenVikingConfig @@ -162,6 +166,7 @@ def __init__(self, vikingdb: VikingVectorIndexBackend): self._collection_name = config.storage.vectordb.name self._vector_dim = config.embedding.dimension self._initialize_embedder(config) + self._circuit_breaker = CircuitBreaker() def _initialize_embedder(self, config: "OpenVikingConfig"): """Initialize the embedder instance from config.""" @@ -236,6 +241,24 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, self.report_success() return data + # Circuit breaker: if API is known-broken, re-enqueue and wait + try: + self._circuit_breaker.check() + except CircuitBreakerOpen: + logger.warning( + f"Circuit breaker is open, re-enqueueing embedding: {embedding_msg.id}" + ) + if self._vikingdb.has_queue_manager: + wait = self._circuit_breaker.retry_after + if wait > 0: + await asyncio.sleep(wait) + await self._vikingdb.enqueue_embedding_msg(embedding_msg) + self.report_success() + return None + # No queue manager — cannot re-enqueue, drop with error + self.report_error("Circuit breaker open and no queue manager", data) + return None + # Initialize embedder if not already initialized if not self._embedder: from openviking_cli.utils.config import get_openviking_config @@ -253,13 +276,23 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, ) except Exception as embed_err: error_msg = f"Failed to generate embedding: {embed_err}" - logger.error(error_msg) + error_class = classify_api_error(embed_err) + + if error_class == "permanent": + logger.critical(error_msg) + self._circuit_breaker.record_failure(embed_err) + self._merge_request_stats(embedding_msg.telemetry_id, error_count=1) + self.report_error(error_msg, data) + return None - if is_429_error(embed_err) and self._vikingdb.has_queue_manager: + # Transient or unknown — re-enqueue for retry + logger.warning(error_msg) + self._circuit_breaker.record_failure(embed_err) + if self._vikingdb.has_queue_manager: try: await self._vikingdb.enqueue_embedding_msg(embedding_msg) logger.info( - f"Re-enqueued embedding message after rate limit: {embedding_msg.id}" + f"Re-enqueued embedding message after transient error: {embedding_msg.id}" ) self.report_success() return None @@ -342,6 +375,7 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, self._merge_request_stats(embedding_msg.telemetry_id, processed=1) self.report_success() + self._circuit_breaker.record_success() return inserted_data except Exception as e: diff --git a/openviking/storage/queuefs/semantic_processor.py b/openviking/storage/queuefs/semantic_processor.py index 2b92b73b..1457e340 100644 --- a/openviking/storage/queuefs/semantic_processor.py +++ b/openviking/storage/queuefs/semantic_processor.py @@ -28,6 +28,11 @@ from openviking.storage.queuefs.semantic_msg import SemanticMsg from openviking.storage.viking_fs import get_viking_fs from openviking.telemetry import bind_telemetry, resolve_telemetry +from openviking.utils.circuit_breaker import ( + CircuitBreaker, + CircuitBreakerOpen, + classify_api_error, +) from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils import VikingURI from openviking_cli.utils.config import get_openviking_config @@ -82,6 +87,7 @@ def __init__(self, max_concurrent_llm: int = 100): self._dag_executor: Optional[SemanticDagExecutor] = None self._current_ctx = RequestContext(user=UserIdentifier.the_default_user(), role=Role.ROOT) self._current_msg: Optional[SemanticMsg] = None + self._circuit_breaker = CircuitBreaker() @classmethod def _cache_dag_stats(cls, telemetry_id: str, uri: str, stats: DagStats) -> None: @@ -204,6 +210,29 @@ async def _check_file_content_changed( except Exception: return True + async def _reenqueue_semantic_msg(self, msg: SemanticMsg) -> None: + """Re-enqueue a semantic message for later processing. + + Throttles with a sleep when the circuit breaker is open to prevent + re-enqueue storms (messages cycling at 5/sec during OPEN window). + """ + import asyncio + + from openviking.storage.queuefs import get_queue_manager + + # Throttle to prevent re-enqueue storm during OPEN window + wait = self._circuit_breaker.retry_after + if wait > 0: + await asyncio.sleep(wait) + + queue_manager = get_queue_manager() + if queue_manager is not None: + semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC) + await semantic_queue.enqueue(msg) + logger.info(f"Re-enqueued semantic message: {msg.uri}") + else: + logger.warning(f"No queue manager available, cannot re-enqueue: {msg.uri}") + async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: """Process dequeued SemanticMsg, recursively process all subdirectories.""" msg: Optional[SemanticMsg] = None @@ -219,6 +248,16 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, assert data is not None msg = SemanticMsg.from_dict(data) + # Circuit breaker: if API is known-broken, re-enqueue and wait + try: + self._circuit_breaker.check() + except CircuitBreakerOpen: + logger.warning( + f"Circuit breaker is open, re-enqueueing semantic message: {msg.uri}" + ) + await self._reenqueue_semantic_msg(msg) + self.report_success() + return None collector = resolve_telemetry(msg.telemetry_id) telemetry_ctx = bind_telemetry(collector) if collector is not None else nullcontext() with telemetry_ctx: @@ -275,13 +314,38 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, self._merge_request_stats(msg.telemetry_id, processed=1) logger.info(f"Completed semantic generation for: {msg.uri}") self.report_success() + self._circuit_breaker.record_success() return None except Exception as e: - logger.error(f"Failed to process semantic message: {e}", exc_info=True) - if msg is not None: - self._merge_request_stats(msg.telemetry_id, error_count=1) - self.report_error(str(e), data) + error_class = classify_api_error(e) + if error_class == "permanent": + logger.critical( + f"Permanent API error processing semantic message, dropping: {e}", + exc_info=True, + ) + self._circuit_breaker.record_failure(e) + if msg is not None: + self._merge_request_stats(msg.telemetry_id, error_count=1) + self.report_error(str(e), data) + else: + # Transient or unknown — re-enqueue for retry + logger.warning( + f"Transient API error processing semantic message, re-enqueueing: {e}", + exc_info=True, + ) + self._circuit_breaker.record_failure(e) + if msg is not None: + try: + await self._reenqueue_semantic_msg(msg) + except Exception as requeue_err: + logger.error(f"Failed to re-enqueue semantic message: {requeue_err}") + self._merge_request_stats(msg.telemetry_id, error_count=1) + self.report_error(str(e), data) + return None + self.report_success() + else: + self.report_error(str(e), data) return None finally: # Safety net: release lifecycle lock if still held (e.g. on exception diff --git a/openviking/utils/circuit_breaker.py b/openviking/utils/circuit_breaker.py new file mode 100644 index 00000000..6f9fbdc2 --- /dev/null +++ b/openviking/utils/circuit_breaker.py @@ -0,0 +1,149 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Circuit breaker and error classification for API call protection.""" + +from __future__ import annotations + +import threading +import time + +from openviking_cli.utils.logger import get_logger + +logger = get_logger(__name__) + +# --- Error classification --- + +_PERMANENT_PATTERNS = ("403", "401", "Forbidden", "Unauthorized", "AccountOverdue") +_TRANSIENT_PATTERNS = ( + "429", + "500", + "502", + "503", + "504", + "TooManyRequests", + "RateLimit", + "timeout", + "Timeout", + "ConnectionError", + "Connection refused", + "Connection reset", +) + + +def classify_api_error(error: Exception) -> str: + """Classify an API error as permanent, transient, or unknown. + + Checks both str(error) and str(error.__cause__) for known patterns. + + Returns: + "permanent" — 403/401, never retry. + "transient" — 429/5xx/timeout, safe to retry. + "unknown" — unrecognized, treated as transient by callers. + """ + texts = [str(error)] + if error.__cause__ is not None: + texts.append(str(error.__cause__)) + + for text in texts: + for pattern in _PERMANENT_PATTERNS: + if pattern in text: + return "permanent" + + for text in texts: + for pattern in _TRANSIENT_PATTERNS: + if pattern in text: + return "transient" + + return "unknown" + + +# --- Circuit breaker --- + +_STATE_CLOSED = "CLOSED" +_STATE_OPEN = "OPEN" +_STATE_HALF_OPEN = "HALF_OPEN" + + +class CircuitBreakerOpen(Exception): + """Raised when the circuit breaker is open and blocking requests.""" + + +class CircuitBreaker: + """Thread-safe circuit breaker for API call protection. + + Trips after ``failure_threshold`` consecutive failures (or immediately for + permanent errors like 403/401). After ``reset_timeout`` seconds, allows one + probe request (HALF_OPEN). If the probe succeeds, the breaker closes; if it + fails, the breaker reopens. + """ + + def __init__(self, failure_threshold: int = 5, reset_timeout: float = 300): + self._failure_threshold = failure_threshold + self._reset_timeout = reset_timeout + self._lock = threading.Lock() + self._state = _STATE_CLOSED + self._failure_count = 0 + self._last_failure_time: float = 0 + + def check(self) -> None: + """Allow the request through, or raise ``CircuitBreakerOpen``.""" + with self._lock: + if self._state == _STATE_CLOSED: + return + if self._state == _STATE_HALF_OPEN: + return # allow probe request + # OPEN — check if timeout elapsed + elapsed = time.monotonic() - self._last_failure_time + if elapsed >= self._reset_timeout: + self._state = _STATE_HALF_OPEN + logger.info("Circuit breaker transitioning OPEN -> HALF_OPEN (timeout elapsed)") + return + raise CircuitBreakerOpen( + f"Circuit breaker is OPEN, retry after {self._reset_timeout - elapsed:.0f}s" + ) + + @property + def retry_after(self) -> float: + """Seconds until the breaker may transition to HALF_OPEN, capped at 30s. + + Returns 0 if the breaker is CLOSED or HALF_OPEN. + """ + with self._lock: + if self._state != _STATE_OPEN: + return 0 + remaining = self._reset_timeout - (time.monotonic() - self._last_failure_time) + return min(max(remaining, 0), 30) + + def record_success(self) -> None: + """Record a successful API call. Resets failure count.""" + with self._lock: + if self._state == _STATE_HALF_OPEN: + logger.info("Circuit breaker transitioning HALF_OPEN -> CLOSED (probe succeeded)") + self._failure_count = 0 + self._state = _STATE_CLOSED + + def record_failure(self, error: Exception) -> None: + """Record a failed API call. May trip the breaker.""" + error_class = classify_api_error(error) + with self._lock: + self._failure_count += 1 + self._last_failure_time = time.monotonic() + + if self._state == _STATE_HALF_OPEN: + self._state = _STATE_OPEN + logger.info( + f"Circuit breaker transitioning HALF_OPEN -> OPEN (probe failed: {error})" + ) + return + + if error_class == "permanent": + self._state = _STATE_OPEN + logger.info(f"Circuit breaker tripped immediately on permanent error: {error}") + return + + if self._failure_count >= self._failure_threshold: + self._state = _STATE_OPEN + logger.info( + f"Circuit breaker tripped after {self._failure_count} consecutive " + f"failures: {error}" + ) diff --git a/tests/utils/test_circuit_breaker.py b/tests/utils/test_circuit_breaker.py new file mode 100644 index 00000000..0b487405 --- /dev/null +++ b/tests/utils/test_circuit_breaker.py @@ -0,0 +1,159 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +import threading +import time + +import pytest + + +def test_circuit_breaker_starts_closed(): + from openviking.utils.circuit_breaker import CircuitBreaker + + cb = CircuitBreaker(failure_threshold=3, reset_timeout=10) + cb.check() # should not raise + + +def test_circuit_breaker_opens_after_threshold(): + from openviking.utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen + + cb = CircuitBreaker(failure_threshold=3, reset_timeout=10) + for _ in range(3): + cb.record_failure(RuntimeError("500 Internal Server Error")) + with pytest.raises(CircuitBreakerOpen): + cb.check() + + +def test_circuit_breaker_resets_on_success(): + from openviking.utils.circuit_breaker import CircuitBreaker + + cb = CircuitBreaker(failure_threshold=3, reset_timeout=10) + cb.record_failure(RuntimeError("timeout")) + cb.record_failure(RuntimeError("timeout")) + cb.record_success() # resets count + cb.record_failure(RuntimeError("timeout")) + cb.record_failure(RuntimeError("timeout")) + cb.check() # should not raise — only 2 consecutive failures + + +def test_circuit_breaker_half_open_after_timeout(monkeypatch): + from openviking.utils.circuit_breaker import CircuitBreaker + + cb = CircuitBreaker(failure_threshold=1, reset_timeout=5) + cb.record_failure(RuntimeError("500")) + # Simulate time passing — capture original before patching to avoid recursion + future = time.monotonic() + 6 + monkeypatch.setattr(time, "monotonic", lambda: future) + cb.check() # should not raise — transitions to HALF_OPEN + + +def test_circuit_breaker_half_open_success_closes(): + from openviking.utils.circuit_breaker import CircuitBreaker + + cb = CircuitBreaker(failure_threshold=1, reset_timeout=0) + cb.record_failure(RuntimeError("500")) + # reset_timeout=0 means immediate HALF_OPEN + cb.check() # transitions to HALF_OPEN + cb.record_success() # transitions to CLOSED + cb.check() # should not raise + + +def test_circuit_breaker_half_open_failure_reopens(monkeypatch): + from openviking.utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen + + cb = CircuitBreaker(failure_threshold=1, reset_timeout=5) + cb.record_failure(RuntimeError("500")) + # Fast-forward past reset_timeout to reach HALF_OPEN + future = time.monotonic() + 6 + monkeypatch.setattr(time, "monotonic", lambda: future) + cb.check() # transitions to HALF_OPEN + cb.record_failure(RuntimeError("500 again")) + # Now the breaker is OPEN again, and last_failure_time is `future`, + # so elapsed is 0 which is < reset_timeout(5) — should raise. + with pytest.raises(CircuitBreakerOpen): + cb.check() + + +def test_permanent_error_trips_immediately(): + from openviking.utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen + + cb = CircuitBreaker(failure_threshold=10, reset_timeout=10) + cb.record_failure(RuntimeError("403 Forbidden AccountOverdueError")) + with pytest.raises(CircuitBreakerOpen): + cb.check() + + +def test_retry_after_returns_capped_value(): + from openviking.utils.circuit_breaker import CircuitBreaker + + cb = CircuitBreaker(failure_threshold=1, reset_timeout=300) + cb.record_failure(RuntimeError("500")) + # retry_after should be capped at 30 + assert 0 < cb.retry_after <= 30 + + +def test_retry_after_zero_when_closed(): + from openviking.utils.circuit_breaker import CircuitBreaker + + cb = CircuitBreaker() + assert cb.retry_after == 0 + + +def test_thread_safety(): + from openviking.utils.circuit_breaker import CircuitBreaker + + cb = CircuitBreaker(failure_threshold=100, reset_timeout=300) + errors = [] + + def record_failures(): + try: + for _ in range(50): + cb.record_failure(RuntimeError("500")) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=record_failures) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors + assert cb._failure_count == 200 + + +def test_classify_permanent_errors(): + from openviking.utils.circuit_breaker import classify_api_error + + assert classify_api_error(RuntimeError("403 Forbidden")) == "permanent" + assert classify_api_error(RuntimeError("AccountOverdueError: 403")) == "permanent" + assert classify_api_error(RuntimeError("401 Unauthorized")) == "permanent" + assert classify_api_error(RuntimeError("Forbidden")) == "permanent" + + +def test_classify_transient_errors(): + from openviking.utils.circuit_breaker import classify_api_error + + assert classify_api_error(RuntimeError("429 TooManyRequests")) == "transient" + assert classify_api_error(RuntimeError("RateLimitError")) == "transient" + assert classify_api_error(RuntimeError("500 Internal Server Error")) == "transient" + assert classify_api_error(RuntimeError("502 Bad Gateway")) == "transient" + assert classify_api_error(RuntimeError("503 Service Unavailable")) == "transient" + assert classify_api_error(RuntimeError("Connection timeout")) == "transient" + assert classify_api_error(RuntimeError("ConnectionError: refused")) == "transient" + + +def test_classify_unknown_errors(): + from openviking.utils.circuit_breaker import classify_api_error + + assert classify_api_error(RuntimeError("something unexpected")) == "unknown" + assert classify_api_error(ValueError("bad value")) == "unknown" + + +def test_classify_chained_exception(): + from openviking.utils.circuit_breaker import classify_api_error + + cause = RuntimeError("403 Forbidden") + wrapper = RuntimeError("API call failed") + wrapper.__cause__ = cause + assert classify_api_error(wrapper) == "permanent"