From 4905a1bf99d25eafe8318354f4d2ed4210bf399c Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 24 Jun 2025 23:46:10 +0200 Subject: [PATCH 01/15] Add rate limit handler --- src/neo4j_graphrag/llm/rate_limit.py | 299 +++++++++++++++++++++++++++ 1 file changed, 299 insertions(+) create mode 100644 src/neo4j_graphrag/llm/rate_limit.py diff --git a/src/neo4j_graphrag/llm/rate_limit.py b/src/neo4j_graphrag/llm/rate_limit.py new file mode 100644 index 000000000..4c310cd24 --- /dev/null +++ b/src/neo4j_graphrag/llm/rate_limit.py @@ -0,0 +1,299 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import functools +import logging +from abc import ABC, abstractmethod +from typing import Any, Awaitable, Callable, Optional, TypeVar + +from neo4j_graphrag.exceptions import RateLimitError + +try: + from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, + before_sleep_log, + ) + + TENACITY_AVAILABLE = True +except ImportError: + TENACITY_AVAILABLE = False + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) +AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) + + +class RateLimitHandler(ABC): + """Abstract base class for rate limit handling strategies.""" + + @abstractmethod + def handle_sync(self, func: F) -> F: + """Apply rate limit handling to a synchronous function. + + Args: + func: The function to wrap with rate limit handling. + + Returns: + The wrapped function. + """ + pass + + @abstractmethod + def handle_async(self, func: AF) -> AF: + """Apply rate limit handling to an asynchronous function. + + Args: + func: The async function to wrap with rate limit handling. + + Returns: + The wrapped async function. + """ + pass + + +class NoOpRateLimitHandler(RateLimitHandler): + """A no-op rate limit handler that does not apply any rate limiting.""" + + def handle_sync(self, func: F) -> F: + """Return the function unchanged.""" + return func + + def handle_async(self, func: AF) -> AF: + """Return the async function unchanged.""" + return func + + +class RetryRateLimitHandler(RateLimitHandler): + """Rate limit handler using exponential backoff retry strategy. + + This handler uses tenacity for retry logic with exponential backoff. + Falls back to NoOpRateLimitHandler if tenacity is not available. + + Args: + max_attempts: Maximum number of retry attempts. Defaults to 3. + min_wait: Minimum wait time between retries in seconds. Defaults to 1. + max_wait: Maximum wait time between retries in seconds. Defaults to 60. + multiplier: Exponential backoff multiplier. Defaults to 2. + """ + + def __init__( + self, + max_attempts: int = 3, + min_wait: float = 1.0, + max_wait: float = 60.0, + multiplier: float = 2.0, + ): + if not TENACITY_AVAILABLE: + logger.warning( + "tenacity is not installed. Rate limit handling will be disabled. " + "Install it with: pip install tenacity" + ) + self._fallback_handler = NoOpRateLimitHandler() + self._use_fallback = True + else: + self._use_fallback = False + self.max_attempts = max_attempts + self.min_wait = min_wait + self.max_wait = max_wait + self.multiplier = multiplier + + def handle_sync(self, func: F) -> F: + """Apply retry logic to a synchronous function.""" + if self._use_fallback: + return self._fallback_handler.handle_sync(func) + + @retry( + retry=retry_if_exception_type(RateLimitError), + stop=stop_after_attempt(self.max_attempts), + wait=wait_exponential( + multiplier=self.multiplier, + min=self.min_wait, + max=self.max_wait, + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return func(*args, **kwargs) + + return wrapper # type: ignore + + def handle_async(self, func: AF) -> AF: + """Apply retry logic to an asynchronous function.""" + if self._use_fallback: + return self._fallback_handler.handle_async(func) + + @retry( + retry=retry_if_exception_type(RateLimitError), + stop=stop_after_attempt(self.max_attempts), + wait=wait_exponential( + multiplier=self.multiplier, + min=self.min_wait, + max=self.max_wait, + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + return await func(*args, **kwargs) + + return wrapper # type: ignore + + +def is_rate_limit_error(exception: Exception) -> bool: + """Check if an exception is a rate limit error from any LLM provider. + + Args: + exception: The exception to check. + + Returns: + True if the exception indicates a rate limit error, False otherwise. + """ + # Already converted to RateLimitError + if isinstance(exception, RateLimitError): + return True + + error_type = type(exception).__name__.lower() + exception_str = str(exception).lower() + + # OpenAI - specific error type + if error_type == "ratelimiterror": + return True + + # Check for HTTP 429 status code (various providers) + if hasattr(exception, "status_code") and getattr(exception, "status_code") == 429: + return True + + if hasattr(exception, "response"): + response = getattr(exception, "response") + if hasattr(response, "status_code") and response.status_code == 429: + return True + + # Provider-specific error types with message checks + rate_limit_error_types = { + "apierror": "too many requests", # Anthropic, Cohere + "sdkerror": "too many requests", # MistralAI + "responseerror": "too many requests", # Ollama + "responsevalidationerror": "resource exhausted", # VertexAI (special case) + } + + if error_type in rate_limit_error_types: + required_message = rate_limit_error_types[error_type] + return required_message in exception_str + + return False + + +def convert_to_rate_limit_error(exception: Exception) -> RateLimitError: + """Convert a provider-specific rate limit exception to RateLimitError. + + Args: + exception: The original exception from the LLM provider. + + Returns: + A RateLimitError with the original exception message. + """ + return RateLimitError(f"Rate limit exceeded: {exception}") + + +def rate_limit_handler( + handler: Optional[RateLimitHandler] = None, +) -> Callable[[F], F]: + """Decorator to apply rate limit handling to synchronous methods. + + This decorator works with instance methods and uses the instance's rate limit handler + if available, falling back to the provided handler or default. + + Args: + handler: The rate limit handler to use. If None, uses the instance's handler or default. + + Returns: + A decorator function. + """ + + def decorator(func: F) -> F: + @functools.wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + # Use instance handler if available, otherwise use provided handler or default + instance_handler = ( + getattr(self, "_rate_limit_handler", None) + if hasattr(self, "_rate_limit_handler") + else None + ) + active_handler = handler or instance_handler or DEFAULT_RATE_LIMIT_HANDLER + + def inner_func() -> Any: + try: + return func(self, *args, **kwargs) + except Exception as e: + if is_rate_limit_error(e): + raise convert_to_rate_limit_error(e) + raise + + return active_handler.handle_sync(inner_func)() + + return wrapper # type: ignore + + return decorator + + +def async_rate_limit_handler( + handler: Optional[RateLimitHandler] = None, +) -> Callable[[AF], AF]: + """Decorator to apply rate limit handling to asynchronous methods. + + This decorator works with instance methods and uses the instance's rate limit handler + if available, falling back to the provided handler or default. + + Args: + handler: The rate limit handler to use. If None, uses the instance's handler or default. + + Returns: + A decorator function. + """ + + def decorator(func: AF) -> AF: + @functools.wraps(func) + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + # Use instance handler if available, otherwise use provided handler or default + instance_handler = ( + getattr(self, "_rate_limit_handler", None) + if hasattr(self, "_rate_limit_handler") + else None + ) + active_handler = handler or instance_handler or DEFAULT_RATE_LIMIT_HANDLER + + async def inner_func() -> Any: + try: + return await func(self, *args, **kwargs) + except Exception as e: + if is_rate_limit_error(e): + raise convert_to_rate_limit_error(e) + raise + + return await active_handler.handle_async(inner_func)() + + return wrapper # type: ignore + + return decorator + + +# Default rate limit handler instance +DEFAULT_RATE_LIMIT_HANDLER = RetryRateLimitHandler() From 5a30b4fb5e7dfb90daec484a6681bf93483a316d Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 24 Jun 2025 23:50:30 +0200 Subject: [PATCH 02/15] Update LLM interfaces --- src/neo4j_graphrag/exceptions.py | 4 ++++ src/neo4j_graphrag/llm/__init__.py | 13 +++++++++++++ src/neo4j_graphrag/llm/anthropic_llm.py | 4 +++- src/neo4j_graphrag/llm/base.py | 18 ++++++++++++++++++ src/neo4j_graphrag/llm/cohere_llm.py | 4 +++- src/neo4j_graphrag/llm/mistralai_llm.py | 5 ++++- src/neo4j_graphrag/llm/ollama_llm.py | 4 +++- src/neo4j_graphrag/llm/openai_llm.py | 13 ++++++++++--- src/neo4j_graphrag/llm/vertexai_llm.py | 4 +++- 9 files changed, 61 insertions(+), 8 deletions(-) diff --git a/src/neo4j_graphrag/exceptions.py b/src/neo4j_graphrag/exceptions.py index 681b20eec..9faffff99 100644 --- a/src/neo4j_graphrag/exceptions.py +++ b/src/neo4j_graphrag/exceptions.py @@ -138,3 +138,7 @@ class InvalidHybridSearchRankerError(Neo4jGraphRagError): class SearchQueryParseError(Neo4jGraphRagError): """Exception raised when there is a query parse error in the text search string.""" + + +class RateLimitError(LLMGenerationError): + """Exception raised when API rate limit is exceeded.""" diff --git a/src/neo4j_graphrag/llm/__init__.py b/src/neo4j_graphrag/llm/__init__.py index a9ece5ccb..3c4f65d9a 100644 --- a/src/neo4j_graphrag/llm/__init__.py +++ b/src/neo4j_graphrag/llm/__init__.py @@ -18,6 +18,13 @@ from .mistralai_llm import MistralAILLM from .ollama_llm import OllamaLLM from .openai_llm import AzureOpenAILLM, OpenAILLM +from .rate_limit import ( + RateLimitHandler, + NoOpRateLimitHandler, + RetryRateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from .types import LLMResponse from .vertexai_llm import VertexAILLM @@ -31,4 +38,10 @@ "VertexAILLM", "AzureOpenAILLM", "MistralAILLM", + # Rate limiting components + "RateLimitHandler", + "NoOpRateLimitHandler", + "RetryRateLimitHandler", + "rate_limit_handler", + "async_rate_limit_handler", ] diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index 881156e3f..cb96c24e3 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -19,6 +19,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.llm.rate_limit import RateLimitHandler from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -62,6 +63,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): try: @@ -71,7 +73,7 @@ def __init__( """Could not import Anthropic Python client. Please install it with `pip install "neo4j-graphrag[anthropic]"`.""" ) - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) self.anthropic = anthropic self.client = anthropic.Anthropic(**kwargs) self.async_client = anthropic.AsyncAnthropic(**kwargs) diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 87d281794..7b3114c6c 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -21,9 +21,16 @@ from neo4j_graphrag.types import LLMMessage from .types import LLMResponse, ToolCallResponse +from .rate_limit import ( + rate_limit_handler, + async_rate_limit_handler, + DEFAULT_RATE_LIMIT_HANDLER, +) from neo4j_graphrag.tool import Tool +from .rate_limit import RateLimitHandler + class LLMInterface(ABC): """Interface for large language models. @@ -31,6 +38,7 @@ class LLMInterface(ABC): Args: model_name (str): The name of the language model. model_params (Optional[dict]): Additional parameters passed to the model when text is sent to it. Defaults to None. + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. **kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None. """ @@ -38,11 +46,18 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): self.model_name = model_name self.model_params = model_params or {} + if rate_limit_handler is not None: + self._rate_limit_handler = rate_limit_handler + else: + self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER + + @rate_limit_handler() @abstractmethod def invoke( self, @@ -65,6 +80,7 @@ def invoke( LLMGenerationError: If anything goes wrong. """ + @async_rate_limit_handler() @abstractmethod async def ainvoke( self, @@ -87,6 +103,7 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. """ + @rate_limit_handler() def invoke_with_tools( self, input: str, @@ -114,6 +131,7 @@ def invoke_with_tools( """ raise NotImplementedError("This LLM provider does not support tool calling.") + @async_rate_limit_handler() async def ainvoke_with_tools( self, input: str, diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index ecddd53ea..0bf6a1f35 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -20,6 +20,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.llm.rate_limit import RateLimitHandler from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -60,6 +61,7 @@ def __init__( self, model_name: str = "", model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ) -> None: try: @@ -69,7 +71,7 @@ def __init__( """Could not import cohere python client. Please install it with `pip install "neo4j-graphrag[cohere]"`.""" ) - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) self.cohere = cohere self.cohere_api_error = cohere.core.api_error.ApiError diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index 9e44287bc..5e4869e05 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -21,6 +21,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.llm.rate_limit import RateLimitHandler from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -44,6 +45,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): """ @@ -52,6 +54,7 @@ def __init__( model_name (str): model_params (str): Parameters like temperature and such that will be passed to the chat completions endpoint + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. kwargs: All other parameters will be passed to the Mistral client. """ @@ -60,7 +63,7 @@ def __init__( """Could not import Mistral Python client. Please install it with `pip install "neo4j-graphrag[mistralai]"`.""" ) - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) api_key = kwargs.pop("api_key", None) if api_key is None: api_key = os.getenv("MISTRAL_API_KEY", "") diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 5abb13d8f..80dd2c4e5 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -23,6 +23,7 @@ from neo4j_graphrag.types import LLMMessage from .base import LLMInterface +from .rate_limit import RateLimitHandler from .types import ( BaseMessage, LLMResponse, @@ -40,6 +41,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): try: @@ -49,7 +51,7 @@ def __init__( "Could not import ollama Python client. " "Please install it with `pip install ollama`." ) - super().__init__(model_name, model_params, **kwargs) + super().__init__(model_name, model_params, rate_limit_handler) self.ollama = ollama self.client = ollama.Client( **kwargs, diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 1e0228e45..3f77a87cf 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -39,6 +39,7 @@ from ..exceptions import LLMGenerationError from .base import LLMInterface +from .rate_limit import RateLimitHandler from .types import ( BaseMessage, LLMResponse, @@ -63,6 +64,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, ): """ Base class for OpenAI LLM. @@ -72,6 +74,7 @@ def __init__( Args: model_name (str): model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. """ try: import openai @@ -81,7 +84,7 @@ def __init__( Please install it with `pip install "neo4j-graphrag[openai]"`.""" ) self.openai = openai - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) def get_messages( self, @@ -347,6 +350,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): """OpenAI LLM @@ -356,9 +360,10 @@ def __init__( Args: model_name (str): model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) self.client = self.openai.OpenAI(**kwargs) self.async_client = self.openai.AsyncOpenAI(**kwargs) @@ -369,6 +374,7 @@ def __init__( model_name: str, model_params: Optional[dict[str, Any]] = None, system_instruction: Optional[str] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): """Azure OpenAI LLM. Use this class when using an OpenAI model @@ -377,8 +383,9 @@ def __init__( Args: model_name (str): model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) self.client = self.openai.AzureOpenAI(**kwargs) self.async_client = self.openai.AsyncAzureOpenAI(**kwargs) diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 39d483915..ce695872e 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -19,6 +19,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.llm.rate_limit import RateLimitHandler from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -78,6 +79,7 @@ def __init__( model_name: str = "gemini-1.5-flash-001", model_params: Optional[dict[str, Any]] = None, system_instruction: Optional[str] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): if GenerativeModel is None or ResponseValidationError is None: @@ -85,7 +87,7 @@ def __init__( """Could not import Vertex AI Python client. Please install it with `pip install "neo4j-graphrag[google]"`.""" ) - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) self.model_name = model_name self.system_instruction = system_instruction self.options = kwargs From 7cab32019dc19eb1041cb234e148f8e667efc80f Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 25 Jun 2025 10:55:58 +0200 Subject: [PATCH 03/15] Update Changelog and docs --- CHANGELOG.md | 1 + docs/source/api.rst | 31 ++++++++++++++++ docs/source/user_guide_rag.rst | 65 ++++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c5eb72654..49de506c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Support for Python 3.13 - Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default. - Added ability to return a user-defined message if context is empty in GraphRAG (which skips the LLM call). +- Added automatic rate limiting with retry logic and exponential backoff for all LLM providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely. ### Fixed diff --git a/docs/source/api.rst b/docs/source/api.rst index 55a5d1cc4..d8280b1cc 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -347,6 +347,28 @@ MistralAILLM :members: +Rate Limiting +============= + +RateLimitHandler +---------------- + +.. autoclass:: neo4j_graphrag.llm.rate_limit.RateLimitHandler + :members: + +RetryRateLimitHandler +--------------------- + +.. autoclass:: neo4j_graphrag.llm.rate_limit.RetryRateLimitHandler + :members: + +NoOpRateLimitHandler +-------------------- + +.. autoclass:: neo4j_graphrag.llm.rate_limit.NoOpRateLimitHandler + :members: + + PromptTemplate ============== @@ -473,6 +495,8 @@ Errors * :class:`neo4j_graphrag.exceptions.LLMGenerationError` + * :class:`neo4j_graphrag.exceptions.RateLimitError` + * :class:`neo4j_graphrag.exceptions.SchemaValidationError` * :class:`neo4j_graphrag.exceptions.PdfLoaderError` @@ -597,6 +621,13 @@ LLMGenerationError :show-inheritance: +RateLimitError +============== + +.. autoclass:: neo4j_graphrag.exceptions.RateLimitError + :show-inheritance: + + SchemaValidationError ===================== diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index 1ad76ef91..49f76fc19 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -294,6 +294,71 @@ Here's an example using the Python Ollama client: See :ref:`llminterface`. +Rate Limit Handling +=================== + +All LLM implementations include automatic rate limiting that uses retry logic with exponential backoff by default. This feature helps handle API rate limits from LLM providers gracefully by automatically retrying failed requests with increasing wait times between attempts. + +Default Rate Limit Handler +-------------------------- + +Rate limiting is enabled by default for all LLM instances with the following configuration: + +- **Max attempts**: 3 +- **Min wait**: 1.0 seconds +- **Max wait**: 60.0 seconds +- **Multiplier**: 2.0 (exponential backoff) + +.. code:: python + + from neo4j_graphrag.llm import OpenAILLM + + # Rate limiting is automatically enabled + llm = OpenAILLM(model_name="gpt-4o") + + # The LLM will automatically retry on rate limit errors + response = llm.invoke("Hello, world!") + +Custom Rate Limiting +-------------------- + +You can customize the rate limiting behavior by creating your own rate limit handler: + +.. code:: python + + from neo4j_graphrag.llm import AnthropicLLM + from neo4j_graphrag.llm.rate_limit import RateLimitHandler + + class CustomRateLimitHandler(RateLimitHandler): + """Implement your custom rate limiting strategy.""" + # Implement required methods: handle_sync, handle_async + pass + + # Create custom rate limit handler and pass it to the LLM interface + custom_handler = CustomRateLimitHandler() + + llm = AnthropicLLM( + model_name="claude-3-sonnet-20240229", + rate_limit_handler=custom_handler, + ) + +Disabling Rate Limiting +----------------------- + +For high-throughput applications or when you handle rate limiting externally, you can disable it: + +.. code:: python + + from neo4j_graphrag.llm import CohereLLM, NoOpRateLimitHandler + + # Disable rate limiting completely + llm = CohereLLM( + model_name="command-r-plus", + rate_limit_handler=NoOpRateLimitHandler(), + ) + llm.invoke("Hello, world!") + + Configuring the Prompt ======================== From 0f4dc41fec2e1b18668a836a7a859bf11722e107 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 26 Jun 2025 11:03:08 +0200 Subject: [PATCH 04/15] Add unit tests for rate limit handler --- tests/unit/llm/test_rate_limit.py | 173 ++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 tests/unit/llm/test_rate_limit.py diff --git a/tests/unit/llm/test_rate_limit.py b/tests/unit/llm/test_rate_limit.py new file mode 100644 index 000000000..b677544f5 --- /dev/null +++ b/tests/unit/llm/test_rate_limit.py @@ -0,0 +1,173 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Awaitable + +import pytest +from unittest.mock import Mock +from tenacity import RetryError + +from neo4j_graphrag.llm.rate_limit import ( + RateLimitHandler, + NoOpRateLimitHandler, + DEFAULT_RATE_LIMIT_HANDLER, +) +from neo4j_graphrag.exceptions import RateLimitError + + +def test_default_handler_retries_sync() -> None: + call_count = 0 + + def mock_func() -> None: + nonlocal call_count + call_count += 1 + raise RateLimitError("Rate limit exceeded") + + wrapped_func = DEFAULT_RATE_LIMIT_HANDLER.handle_sync(mock_func) + + with pytest.raises(RetryError): + wrapped_func() + + assert call_count == 3 + + +@pytest.mark.asyncio +async def test_default_handler_retries_async() -> None: + call_count = 0 + + async def mock_func() -> None: + nonlocal call_count + call_count += 1 + raise RateLimitError("Rate limit exceeded") + + wrapped_func = DEFAULT_RATE_LIMIT_HANDLER.handle_async(mock_func) + + with pytest.raises(RetryError): + await wrapped_func() + + assert call_count == 3 + + +def test_other_errors_pass_through_sync() -> None: + call_count = 0 + + def mock_func() -> None: + nonlocal call_count + call_count += 1 + raise ValueError("Some other error") + + wrapped_func = DEFAULT_RATE_LIMIT_HANDLER.handle_sync(mock_func) + + with pytest.raises(ValueError): + wrapped_func() + + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_other_errors_pass_through_async() -> None: + call_count = 0 + + async def mock_func() -> None: + nonlocal call_count + call_count += 1 + raise ValueError("Some other error") + + wrapped_func = DEFAULT_RATE_LIMIT_HANDLER.handle_async(mock_func) + + with pytest.raises(ValueError): + await wrapped_func() + + assert call_count == 1 + + +def test_noop_handler_sync() -> None: + def mock_func() -> str: + return "test result" + + handler = NoOpRateLimitHandler() + wrapped_func = handler.handle_sync(mock_func) + + assert wrapped_func() == "test result" + assert wrapped_func is mock_func + + +@pytest.mark.asyncio +async def test_noop_handler_async() -> None: + async def mock_func() -> str: + return "async test result" + + handler = NoOpRateLimitHandler() + wrapped_func = handler.handle_async(mock_func) + + assert await wrapped_func() == "async test result" + assert wrapped_func is mock_func + + +def test_custom_handler_sync_retry_override() -> None: + call_count = 0 + + def mock_func() -> str: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RateLimitError("Rate limit exceeded") + return "success after custom retry" + + # Custom handler with single retry + def custom_handle_sync(func: Callable[[], Any]) -> Callable[[], Any]: + def wrapper() -> Any: + try: + return func() + except RateLimitError: + return func() # Retry once + + return wrapper + + handler = Mock(spec=RateLimitHandler) + handler.handle_sync = custom_handle_sync + + result = handler.handle_sync(mock_func)() + assert result == "success after custom retry" + assert call_count == 2 + + +@pytest.mark.asyncio +async def test_custom_handler_async_retry_override() -> None: + call_count = 0 + + async def mock_func() -> str: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RateLimitError("Rate limit exceeded") + return "success after custom retry" + + # Custom handler with single retry + def custom_handle_async(func: Callable[[], Awaitable[Any]]) -> Callable[[], Awaitable[Any]]: + async def wrapper() -> Any: + try: + return await func() + except RateLimitError: + return await func() # Retry once + + return wrapper + + handler = Mock(spec=RateLimitHandler) + handler.handle_async = custom_handle_async + + result = await handler.handle_async(mock_func)() + assert result == "success after custom retry" + assert call_count == 2 From 6eefd0f5c256b5b268aa7c9b67583edc398cc957 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 26 Jun 2025 14:36:42 +0200 Subject: [PATCH 05/15] Improve rate limit handler --- src/neo4j_graphrag/llm/rate_limit.py | 138 ++++++++++----------------- 1 file changed, 51 insertions(+), 87 deletions(-) diff --git a/src/neo4j_graphrag/llm/rate_limit.py b/src/neo4j_graphrag/llm/rate_limit.py index 4c310cd24..cf457dc31 100644 --- a/src/neo4j_graphrag/llm/rate_limit.py +++ b/src/neo4j_graphrag/llm/rate_limit.py @@ -21,18 +21,13 @@ from neo4j_graphrag.exceptions import RateLimitError -try: - from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - retry_if_exception_type, - before_sleep_log, - ) - - TENACITY_AVAILABLE = True -except ImportError: - TENACITY_AVAILABLE = False +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, + before_sleep_log, +) logger = logging.getLogger(__name__) @@ -100,24 +95,13 @@ def __init__( max_wait: float = 60.0, multiplier: float = 2.0, ): - if not TENACITY_AVAILABLE: - logger.warning( - "tenacity is not installed. Rate limit handling will be disabled. " - "Install it with: pip install tenacity" - ) - self._fallback_handler = NoOpRateLimitHandler() - self._use_fallback = True - else: - self._use_fallback = False - self.max_attempts = max_attempts - self.min_wait = min_wait - self.max_wait = max_wait - self.multiplier = multiplier + self.max_attempts = max_attempts + self.min_wait = min_wait + self.max_wait = max_wait + self.multiplier = multiplier def handle_sync(self, func: F) -> F: """Apply retry logic to a synchronous function.""" - if self._use_fallback: - return self._fallback_handler.handle_sync(func) @retry( retry=retry_if_exception_type(RateLimitError), @@ -137,8 +121,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: def handle_async(self, func: AF) -> AF: """Apply retry logic to an asynchronous function.""" - if self._use_fallback: - return self._fallback_handler.handle_async(func) @retry( retry=retry_if_exception_type(RateLimitError), @@ -213,86 +195,68 @@ def convert_to_rate_limit_error(exception: Exception) -> RateLimitError: return RateLimitError(f"Rate limit exceeded: {exception}") -def rate_limit_handler( - handler: Optional[RateLimitHandler] = None, -) -> Callable[[F], F]: +def rate_limit_handler(func: F) -> F: """Decorator to apply rate limit handling to synchronous methods. - This decorator works with instance methods and uses the instance's rate limit handler - if available, falling back to the provided handler or default. + This decorator works with instance methods and uses the instance's rate limit handler. Args: - handler: The rate limit handler to use. If None, uses the instance's handler or default. + func: The function to wrap with rate limit handling. Returns: - A decorator function. + The wrapped function. """ - def decorator(func: F) -> F: - @functools.wraps(func) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - # Use instance handler if available, otherwise use provided handler or default - instance_handler = ( - getattr(self, "_rate_limit_handler", None) - if hasattr(self, "_rate_limit_handler") - else None - ) - active_handler = handler or instance_handler or DEFAULT_RATE_LIMIT_HANDLER - - def inner_func() -> Any: - try: - return func(self, *args, **kwargs) - except Exception as e: - if is_rate_limit_error(e): - raise convert_to_rate_limit_error(e) - raise - - return active_handler.handle_sync(inner_func)() + @functools.wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + # Use instance handler or default + active_handler = getattr( + self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER + ) - return wrapper # type: ignore + def inner_func() -> Any: + try: + return func(self, *args, **kwargs) + except Exception as e: + if is_rate_limit_error(e): + raise convert_to_rate_limit_error(e) + raise + + return active_handler.handle_sync(inner_func)() - return decorator + return wrapper # type: ignore -def async_rate_limit_handler( - handler: Optional[RateLimitHandler] = None, -) -> Callable[[AF], AF]: +def async_rate_limit_handler(func: AF) -> AF: """Decorator to apply rate limit handling to asynchronous methods. - This decorator works with instance methods and uses the instance's rate limit handler - if available, falling back to the provided handler or default. + This decorator works with instance methods and uses the instance's rate limit handler. Args: - handler: The rate limit handler to use. If None, uses the instance's handler or default. + func: The async function to wrap with rate limit handling. Returns: - A decorator function. + The wrapped async function. """ - def decorator(func: AF) -> AF: - @functools.wraps(func) - async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - # Use instance handler if available, otherwise use provided handler or default - instance_handler = ( - getattr(self, "_rate_limit_handler", None) - if hasattr(self, "_rate_limit_handler") - else None - ) - active_handler = handler or instance_handler or DEFAULT_RATE_LIMIT_HANDLER - - async def inner_func() -> Any: - try: - return await func(self, *args, **kwargs) - except Exception as e: - if is_rate_limit_error(e): - raise convert_to_rate_limit_error(e) - raise - - return await active_handler.handle_async(inner_func)() + @functools.wraps(func) + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + # Use instance handler or default + active_handler = getattr( + self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER + ) - return wrapper # type: ignore + async def inner_func() -> Any: + try: + return await func(self, *args, **kwargs) + except Exception as e: + if is_rate_limit_error(e): + raise convert_to_rate_limit_error(e) + raise + + return await active_handler.handle_async(inner_func)() - return decorator + return wrapper # type: ignore # Default rate limit handler instance From 465732cf7c1a96e44d5cf68597387591db16acd5 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 26 Jun 2025 14:38:26 +0200 Subject: [PATCH 06/15] Remove decorators from absract methods and add them to methods of LLM provider classes --- src/neo4j_graphrag/llm/anthropic_llm.py | 8 +++++++- src/neo4j_graphrag/llm/base.py | 6 ------ src/neo4j_graphrag/llm/cohere_llm.py | 8 +++++++- src/neo4j_graphrag/llm/mistralai_llm.py | 8 +++++++- src/neo4j_graphrag/llm/ollama_llm.py | 4 +++- src/neo4j_graphrag/llm/openai_llm.py | 6 +++++- src/neo4j_graphrag/llm/vertexai_llm.py | 8 +++++++- 7 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index cb96c24e3..6bafef85b 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -19,7 +19,11 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import RateLimitHandler +from neo4j_graphrag.llm.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -95,6 +99,7 @@ def get_messages( messages.append(UserMessage(content=input).model_dump()) return messages # type: ignore + @rate_limit_handler def invoke( self, input: str, @@ -131,6 +136,7 @@ def invoke( except self.anthropic.APIError as e: raise LLMGenerationError(e) + @async_rate_limit_handler async def ainvoke( self, input: str, diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 7b3114c6c..cca710bc9 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -22,8 +22,6 @@ from .types import LLMResponse, ToolCallResponse from .rate_limit import ( - rate_limit_handler, - async_rate_limit_handler, DEFAULT_RATE_LIMIT_HANDLER, ) @@ -57,7 +55,6 @@ def __init__( else: self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER - @rate_limit_handler() @abstractmethod def invoke( self, @@ -80,7 +77,6 @@ def invoke( LLMGenerationError: If anything goes wrong. """ - @async_rate_limit_handler() @abstractmethod async def ainvoke( self, @@ -103,7 +99,6 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. """ - @rate_limit_handler() def invoke_with_tools( self, input: str, @@ -131,7 +126,6 @@ def invoke_with_tools( """ raise NotImplementedError("This LLM provider does not support tool calling.") - @async_rate_limit_handler() async def ainvoke_with_tools( self, input: str, diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index 0bf6a1f35..7c3905500 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -20,7 +20,11 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import RateLimitHandler +from neo4j_graphrag.llm.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -98,6 +102,7 @@ def get_messages( messages.append(UserMessage(content=input).model_dump()) return messages # type: ignore + @rate_limit_handler def invoke( self, input: str, @@ -129,6 +134,7 @@ def invoke( content=res.message.content[0].text if res.message.content else "", ) + @async_rate_limit_handler async def ainvoke( self, input: str, diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index 5e4869e05..ae2a6312f 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -21,7 +21,11 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import RateLimitHandler +from neo4j_graphrag.llm.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -89,6 +93,7 @@ def get_messages( messages.append(UserMessage(content=input).model_dump()) return cast(list[Messages], messages) + @rate_limit_handler def invoke( self, input: str, @@ -127,6 +132,7 @@ def invoke( except SDKError as e: raise LLMGenerationError(e) + @async_rate_limit_handler async def ainvoke( self, input: str, diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 80dd2c4e5..6c4728888 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -23,7 +23,7 @@ from neo4j_graphrag.types import LLMMessage from .base import LLMInterface -from .rate_limit import RateLimitHandler +from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler from .types import ( BaseMessage, LLMResponse, @@ -80,6 +80,7 @@ def get_messages( messages.append(UserMessage(content=input).model_dump()) return messages # type: ignore + @rate_limit_handler def invoke( self, input: str, @@ -110,6 +111,7 @@ def invoke( except self.ollama.ResponseError as e: raise LLMGenerationError(e) + @async_rate_limit_handler async def ainvoke( self, input: str, diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 3f77a87cf..ed8af1958 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -39,7 +39,7 @@ from ..exceptions import LLMGenerationError from .base import LLMInterface -from .rate_limit import RateLimitHandler +from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler from .types import ( BaseMessage, LLMResponse, @@ -127,6 +127,7 @@ def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]: except AttributeError: raise LLMGenerationError(f"Tool {tool} is not a valid Tool object") + @rate_limit_handler def invoke( self, input: str, @@ -161,6 +162,7 @@ def invoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) + @rate_limit_handler def invoke_with_tools( self, input: str, @@ -235,6 +237,7 @@ def invoke_with_tools( except self.openai.OpenAIError as e: raise LLMGenerationError(e) + @async_rate_limit_handler async def ainvoke( self, input: str, @@ -269,6 +272,7 @@ async def ainvoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) + @async_rate_limit_handler async def ainvoke_with_tools( self, input: str, diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index ce695872e..5b772c35b 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -19,7 +19,11 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import RateLimitHandler +from neo4j_graphrag.llm.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -123,6 +127,7 @@ def get_messages( messages.append(Content(role="user", parts=[Part.from_text(input)])) return messages + @rate_limit_handler def invoke( self, input: str, @@ -152,6 +157,7 @@ def invoke( except ResponseValidationError as e: raise LLMGenerationError("Error calling VertexAILLM") from e + @async_rate_limit_handler async def ainvoke( self, input: str, From 073e31f079c9181377a536078503f13c422b7393 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 26 Jun 2025 14:38:59 +0200 Subject: [PATCH 07/15] Improve documentation --- docs/source/user_guide_rag.rst | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index 49f76fc19..937a7099c 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -319,6 +319,26 @@ Rate limiting is enabled by default for all LLM instances with the following con # The LLM will automatically retry on rate limit errors response = llm.invoke("Hello, world!") +.. note:: + + To change the default configuration of `RetryRateLimitHandler`: + + .. code:: python + + from neo4j_graphrag.llm import OpenAILLM + from neo4j_graphrag.llm.rate_limit import RetryRateLimitHandler + + # Customize rate limiting parameters + llm = OpenAILLM( + model_name="gpt-4o", + rate_limit_handler=RetryRateLimitHandler( + max_attempts=10, # Increase max retry attempts + min_wait=2.0, # Increase minimum wait time + max_wait=120.0, # Increase maximum wait time + multiplier=3.0 # More aggressive backoff + ) + ) + Custom Rate Limiting -------------------- From af3a5d2774e067d7cc6e0a3c36b8fe81f1c48a12 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 30 Jun 2025 17:51:16 +0200 Subject: [PATCH 08/15] Improve wait strategy for concurrent mode --- src/neo4j_graphrag/llm/rate_limit.py | 59 ++++++++++++++++------------ 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/src/neo4j_graphrag/llm/rate_limit.py b/src/neo4j_graphrag/llm/rate_limit.py index cf457dc31..af6dacfb9 100644 --- a/src/neo4j_graphrag/llm/rate_limit.py +++ b/src/neo4j_graphrag/llm/rate_limit.py @@ -17,7 +17,7 @@ import functools import logging from abc import ABC, abstractmethod -from typing import Any, Awaitable, Callable, Optional, TypeVar +from typing import Any, Awaitable, Callable, TypeVar from neo4j_graphrag.exceptions import RateLimitError @@ -25,10 +25,12 @@ retry, stop_after_attempt, wait_exponential, + wait_random_exponential, retry_if_exception_type, before_sleep_log, ) + logger = logging.getLogger(__name__) F = TypeVar("F", bound=Callable[..., Any]) @@ -79,13 +81,13 @@ class RetryRateLimitHandler(RateLimitHandler): """Rate limit handler using exponential backoff retry strategy. This handler uses tenacity for retry logic with exponential backoff. - Falls back to NoOpRateLimitHandler if tenacity is not available. Args: max_attempts: Maximum number of retry attempts. Defaults to 3. min_wait: Minimum wait time between retries in seconds. Defaults to 1. max_wait: Maximum wait time between retries in seconds. Defaults to 60. multiplier: Exponential backoff multiplier. Defaults to 2. + jitter: Whether to add random jitter to retry delays to prevent thundering herd. Defaults to True. """ def __init__( @@ -94,49 +96,54 @@ def __init__( min_wait: float = 1.0, max_wait: float = 60.0, multiplier: float = 2.0, + jitter: bool = True, ): self.max_attempts = max_attempts self.min_wait = min_wait self.max_wait = max_wait self.multiplier = multiplier + self.jitter = jitter - def handle_sync(self, func: F) -> F: - """Apply retry logic to a synchronous function.""" + def _get_wait_strategy(self) -> Any: + """Get the appropriate wait strategy based on jitter setting. - @retry( - retry=retry_if_exception_type(RateLimitError), - stop=stop_after_attempt(self.max_attempts), - wait=wait_exponential( + Returns: + The configured wait strategy for tenacity retry. + """ + if self.jitter: + # Use built-in random exponential backoff with jitter + return wait_random_exponential( multiplier=self.multiplier, min=self.min_wait, max=self.max_wait, - ), + ) + else: + # Use standard exponential backoff without jitter + return wait_exponential( + multiplier=self.multiplier, + min=self.min_wait, + max=self.max_wait, + ) + + def handle_sync(self, func: F) -> F: + """Apply retry logic to a synchronous function.""" + decorator = retry( + retry=retry_if_exception_type(RateLimitError), + stop=stop_after_attempt(self.max_attempts), + wait=self._get_wait_strategy(), before_sleep=before_sleep_log(logger, logging.WARNING), ) - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - return func(*args, **kwargs) - - return wrapper # type: ignore + return decorator(func) def handle_async(self, func: AF) -> AF: """Apply retry logic to an asynchronous function.""" - - @retry( + decorator = retry( retry=retry_if_exception_type(RateLimitError), stop=stop_after_attempt(self.max_attempts), - wait=wait_exponential( - multiplier=self.multiplier, - min=self.min_wait, - max=self.max_wait, - ), + wait=self._get_wait_strategy(), before_sleep=before_sleep_log(logger, logging.WARNING), ) - @functools.wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> Any: - return await func(*args, **kwargs) - - return wrapper # type: ignore + return decorator(func) def is_rate_limit_error(exception: Exception) -> bool: From ddc71fb61bd73eb663a823a9ca7ff578237d4881 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 3 Jul 2025 11:19:24 +0200 Subject: [PATCH 09/15] Fix tenacity dependency --- poetry.lock | 10 +++++----- pyproject.toml | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index cbd375ac0..5cb194fcb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -4028,8 +4028,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.23.2", markers = "python_version == \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -4942,8 +4942,8 @@ grpcio = ">=1.41.0" httpx = {version = ">=0.20.0", extras = ["http2"]} numpy = [ {version = ">=1.21", markers = "python_version >= \"3.10\" and python_version < \"3.12\""}, - {version = ">=1.26", markers = "python_version == \"3.12\""}, {version = ">=1.21,<2.1.0", markers = "python_version < \"3.10\""}, + {version = ">=1.26", markers = "python_version == \"3.12\""}, {version = ">=2.1.0", markers = "python_version >= \"3.13\""}, ] portalocker = ">=2.7.0,<3.0.0" @@ -6281,7 +6281,7 @@ widechars = ["wcwidth"] name = "tenacity" version = "9.1.2" description = "Retry code until it succeeds" -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138"}, @@ -7370,4 +7370,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.14" -content-hash = "f53f3dfff909ce5fadc0f38896354f2952cc22098bd2dcd043a7de8e89026375" +content-hash = "83b68416feaf289d06e1af48ec8b7a3ac20ec0585be6d80f5bb0fb5b7deda025" diff --git a/pyproject.toml b/pyproject.toml index 320fc11e0..b44c2fa64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ scipy = [ { version = "^1.13.0", python = ">=3.9,<3.13" }, { version = "^1.15.0", python = ">=3.13,<3.14" } ] +tenacity = "^9.1.2" [tool.poetry.group.dev.dependencies] urllib3 = "<2" From 843ac1a5000eadf4f26badb1554a823cba8c5ca4 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 3 Jul 2025 11:25:42 +0200 Subject: [PATCH 10/15] Ruff --- tests/unit/llm/test_rate_limit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/llm/test_rate_limit.py b/tests/unit/llm/test_rate_limit.py index b677544f5..f1f4b133b 100644 --- a/tests/unit/llm/test_rate_limit.py +++ b/tests/unit/llm/test_rate_limit.py @@ -156,7 +156,9 @@ async def mock_func() -> str: return "success after custom retry" # Custom handler with single retry - def custom_handle_async(func: Callable[[], Awaitable[Any]]) -> Callable[[], Awaitable[Any]]: + def custom_handle_async( + func: Callable[[], Awaitable[Any]], + ) -> Callable[[], Awaitable[Any]]: async def wrapper() -> Any: try: return await func() From b7bf2223c969c401f19531664b43898adb594bcb Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 7 Jul 2025 09:54:01 +0200 Subject: [PATCH 11/15] Simplify is rate limit error --- src/neo4j_graphrag/llm/rate_limit.py | 40 +++++++++------------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/src/neo4j_graphrag/llm/rate_limit.py b/src/neo4j_graphrag/llm/rate_limit.py index af6dacfb9..098597f78 100644 --- a/src/neo4j_graphrag/llm/rate_limit.py +++ b/src/neo4j_graphrag/llm/rate_limit.py @@ -155,37 +155,21 @@ def is_rate_limit_error(exception: Exception) -> bool: Returns: True if the exception indicates a rate limit error, False otherwise. """ - # Already converted to RateLimitError - if isinstance(exception, RateLimitError): - return True - error_type = type(exception).__name__.lower() exception_str = str(exception).lower() - # OpenAI - specific error type - if error_type == "ratelimiterror": - return True - - # Check for HTTP 429 status code (various providers) - if hasattr(exception, "status_code") and getattr(exception, "status_code") == 429: - return True - - if hasattr(exception, "response"): - response = getattr(exception, "response") - if hasattr(response, "status_code") and response.status_code == 429: - return True - - # Provider-specific error types with message checks - rate_limit_error_types = { - "apierror": "too many requests", # Anthropic, Cohere - "sdkerror": "too many requests", # MistralAI - "responseerror": "too many requests", # Ollama - "responsevalidationerror": "resource exhausted", # VertexAI (special case) - } - - if error_type in rate_limit_error_types: - required_message = rate_limit_error_types[error_type] - return required_message in exception_str + # For LLMGenerationError (which wraps all provider errors), check provider-specific patterns + if error_type == "llmgenerationerror": + # Check for various rate limit patterns from different providers + rate_limit_patterns = [ + "error code: 429", # Azure OpenAI + "too many requests", # Anthropic, Cohere, MistralAI + "resource exhausted", # VertexAI + "rate limit", # Generic rate limit messages + "429", # Generic rate limit messages + ] + + return any(pattern in exception_str for pattern in rate_limit_patterns) return False From 204a52e3499ad19d01b7849e0cc2cb327789b11a Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 7 Jul 2025 12:28:10 +0200 Subject: [PATCH 12/15] Update doc related to VertexAILLM --- docs/source/user_guide_rag.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index 937a7099c..c1402db88 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -125,7 +125,7 @@ To use VertexAI, instantiate the `VertexAILLM` class: generation_config = GenerationConfig(temperature=0.0) llm = VertexAILLM( - model_name="gemini-1.5-flash-001", generation_config=generation_config + model_name="gemini-2.5-flash", generation_config=generation_config ) llm.invoke("say something") @@ -133,7 +133,7 @@ To use VertexAI, instantiate the `VertexAILLM` class: .. note:: In order to run this code, the `google-cloud-aiplatform` Python package needs to be installed: - `pip install "neo4j_grpahrag[vertexai]"` + `pip install "neo4j_graphrag[google]"` See :ref:`vertexaillm`. From dca21783d6ee76c8a55c05df3a2ea32707a96cde Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 9 Jul 2025 12:47:12 +0200 Subject: [PATCH 13/15] Update custom_llm.py --- examples/customize/llms/custom_llm.py | 38 +++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py index 9366243e6..d3cbb87fe 100644 --- a/examples/customize/llms/custom_llm.py +++ b/examples/customize/llms/custom_llm.py @@ -1,8 +1,13 @@ import random import string -from typing import Any, List, Optional, Union +from typing import Any, Awaitable, Callable, List, Optional, TypeVar, Union from neo4j_graphrag.llm import LLMInterface, LLMResponse +from neo4j_graphrag.llm.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage @@ -13,6 +18,8 @@ def __init__( ): super().__init__(model_name, **kwargs) + # Optional: Apply rate limit handling to synchronous invoke method + # @rate_limit_handler def invoke( self, input: str, @@ -24,6 +31,8 @@ def invoke( ) return LLMResponse(content=content) + # Optional: Apply rate limit handling to asynchronous ainvoke method + # @async_rate_limit_handler async def ainvoke( self, input: str, @@ -33,6 +42,31 @@ async def ainvoke( raise NotImplementedError() -llm = CustomLLM("") +llm = CustomLLM( + "" +) # if rate_limit_handler and async_rate_limit_handler decorators are used, the default rate limit handler will be applied automatically (retry with exponential backoff) res: LLMResponse = llm.invoke("text") print(res.content) + +# If rate_limit_handler and async_rate_limit_handler decorators are used and you want to use a custom rate limit handler +# Type variables for function signatures used in rate limit handlers +# F = TypeVar("F", bound=Callable[..., Any]) +# AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) + + +# class CustomRateLimitHandler(RateLimitHandler): +# def __init__(self): +# super().__init__() + +# def handle_sync(self, func: F) -> F: +# # error handling here +# return func + +# def handle_async(self, func: AF) -> AF: +# # error handling here +# return func + + +# llm = CustomLLM("", rate_limit_handler=CustomRateLimitHandler()) +# res: LLMResponse = llm.invoke("text") +# print(res.content) From bc69242bf8262b036f0b1d6a06b5171c3e4260cc Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 9 Jul 2025 14:38:32 +0200 Subject: [PATCH 14/15] Fix linter issues --- examples/customize/llms/custom_llm.py | 30 ++++++++++++++------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py index d3cbb87fe..3bab6b0b8 100644 --- a/examples/customize/llms/custom_llm.py +++ b/examples/customize/llms/custom_llm.py @@ -50,23 +50,25 @@ async def ainvoke( # If rate_limit_handler and async_rate_limit_handler decorators are used and you want to use a custom rate limit handler # Type variables for function signatures used in rate limit handlers -# F = TypeVar("F", bound=Callable[..., Any]) -# AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) +F = TypeVar("F", bound=Callable[..., Any]) +AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) -# class CustomRateLimitHandler(RateLimitHandler): -# def __init__(self): -# super().__init__() +class CustomRateLimitHandler(RateLimitHandler): + def __init__(self) -> None: + super().__init__() -# def handle_sync(self, func: F) -> F: -# # error handling here -# return func + def handle_sync(self, func: F) -> F: + # error handling here + return func -# def handle_async(self, func: AF) -> AF: -# # error handling here -# return func + def handle_async(self, func: AF) -> AF: + # error handling here + return func -# llm = CustomLLM("", rate_limit_handler=CustomRateLimitHandler()) -# res: LLMResponse = llm.invoke("text") -# print(res.content) +llm_with_custom_rate_limit_handler = CustomLLM( + "", rate_limit_handler=CustomRateLimitHandler() +) +result: LLMResponse = llm_with_custom_rate_limit_handler.invoke("text") +print(result.content) From f20c514dc044a19d503eedec3365776f5d85ee53 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 9 Jul 2025 14:47:08 +0200 Subject: [PATCH 15/15] Fix more linter issues --- examples/customize/llms/custom_llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py index 3bab6b0b8..0eecfd878 100644 --- a/examples/customize/llms/custom_llm.py +++ b/examples/customize/llms/custom_llm.py @@ -5,8 +5,8 @@ from neo4j_graphrag.llm import LLMInterface, LLMResponse from neo4j_graphrag.llm.rate_limit import ( RateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, + # rate_limit_handler, + # async_rate_limit_handler, ) from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage