diff --git a/libs/core/langchain_core/tracers/root_listeners.py b/libs/core/langchain_core/tracers/root_listeners.py index 923cd1c16f691..24602806e7a99 100644 --- a/libs/core/langchain_core/tracers/root_listeners.py +++ b/libs/core/langchain_core/tracers/root_listeners.py @@ -1,7 +1,8 @@ """Tracers that call listeners.""" -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING +from collections.abc import Awaitable +from typing import TYPE_CHECKING, Callable, Optional, Union +import contextvars from langchain_core.runnables.config import ( RunnableConfig, @@ -11,28 +12,41 @@ from langchain_core.tracers.base import AsyncBaseTracer, BaseTracer from langchain_core.tracers.schemas import Run +# Re-entrancy guard to prevent listeners from triggering nested listener calls +_IN_LISTENER: contextvars.ContextVar[bool] = contextvars.ContextVar( + "root_listeners_in_listener", default=False +) + if TYPE_CHECKING: from uuid import UUID -Listener = Callable[[Run], None] | Callable[[Run, RunnableConfig], None] -AsyncListener = ( - Callable[[Run], Awaitable[None]] | Callable[[Run, RunnableConfig], Awaitable[None]] -) +Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] +AsyncListener = Union[ + Callable[[Run], Awaitable[None]], Callable[[Run, RunnableConfig], Awaitable[None]] +] class RootListenersTracer(BaseTracer): - """Tracer that calls listeners on run start, end, and error.""" + """Tracer that calls listeners on run start, end, and error. + + Parameters: + log_missing_parent: Whether to log a warning if the parent is missing. + Default is False. + config: The runnable config. + on_start: The listener to call on run start. + on_end: The listener to call on run end. + on_error: The listener to call on run error. + """ log_missing_parent = False - """Whether to log a warning if the parent is missing.""" def __init__( self, *, config: RunnableConfig, - on_start: Listener | None, - on_end: Listener | None, - on_error: Listener | None, + on_start: Optional[Listener], + on_end: Optional[Listener], + on_error: Optional[Listener], ) -> None: """Initialize the tracer. @@ -48,7 +62,7 @@ def __init__( self._arg_on_start = on_start self._arg_on_end = on_end self._arg_on_error = on_error - self.root_id: UUID | None = None + self.root_id: Optional[UUID] = None def _persist_run(self, run: Run) -> None: # This is a legacy method only called once for an entire run tree @@ -61,33 +75,55 @@ def _on_run_create(self, run: Run) -> None: self.root_id = run.id - if self._arg_on_start is not None: - call_func_with_variable_args(self._arg_on_start, run, self.config) + if self._arg_on_start is not None and not _IN_LISTENER.get(): + token = _IN_LISTENER.set(True) + try: + call_func_with_variable_args(self._arg_on_start, run, self.config) + finally: + _IN_LISTENER.reset(token) def _on_run_update(self, run: Run) -> None: if run.id != self.root_id: return + if _IN_LISTENER.get(): + return if run.error is None: if self._arg_on_end is not None: - call_func_with_variable_args(self._arg_on_end, run, self.config) + token = _IN_LISTENER.set(True) + try: + call_func_with_variable_args(self._arg_on_end, run, self.config) + finally: + _IN_LISTENER.reset(token) elif self._arg_on_error is not None: - call_func_with_variable_args(self._arg_on_error, run, self.config) + token = _IN_LISTENER.set(True) + try: + call_func_with_variable_args(self._arg_on_error, run, self.config) + finally: + _IN_LISTENER.reset(token) class AsyncRootListenersTracer(AsyncBaseTracer): - """Async Tracer that calls listeners on run start, end, and error.""" + """Async Tracer that calls listeners on run start, end, and error. + + Parameters: + log_missing_parent: Whether to log a warning if the parent is missing. + Default is False. + config: The runnable config. + on_start: The listener to call on run start. + on_end: The listener to call on run end. + on_error: The listener to call on run error. + """ log_missing_parent = False - """Whether to log a warning if the parent is missing.""" def __init__( self, *, config: RunnableConfig, - on_start: AsyncListener | None, - on_end: AsyncListener | None, - on_error: AsyncListener | None, + on_start: Optional[AsyncListener], + on_end: Optional[AsyncListener], + on_error: Optional[AsyncListener], ) -> None: """Initialize the tracer. @@ -103,7 +139,7 @@ def __init__( self._arg_on_start = on_start self._arg_on_end = on_end self._arg_on_error = on_error - self.root_id: UUID | None = None + self.root_id: Optional[UUID] = None async def _persist_run(self, run: Run) -> None: # This is a legacy method only called once for an entire run tree @@ -116,15 +152,29 @@ async def _on_run_create(self, run: Run) -> None: self.root_id = run.id - if self._arg_on_start is not None: - await acall_func_with_variable_args(self._arg_on_start, run, self.config) + if self._arg_on_start is not None and not _IN_LISTENER.get(): + token = _IN_LISTENER.set(True) + try: + await acall_func_with_variable_args(self._arg_on_start, run, self.config) + finally: + _IN_LISTENER.reset(token) async def _on_run_update(self, run: Run) -> None: if run.id != self.root_id: return + if _IN_LISTENER.get(): + return if run.error is None: if self._arg_on_end is not None: - await acall_func_with_variable_args(self._arg_on_end, run, self.config) + token = _IN_LISTENER.set(True) + try: + await acall_func_with_variable_args(self._arg_on_end, run, self.config) + finally: + _IN_LISTENER.reset(token) elif self._arg_on_error is not None: - await acall_func_with_variable_args(self._arg_on_error, run, self.config) + token = _IN_LISTENER.set(True) + try: + await acall_func_with_variable_args(self._arg_on_error, run, self.config) + finally: + _IN_LISTENER.reset(token)