diff --git a/temporalio/contrib/openai_agents/_trace_interceptor.py b/temporalio/contrib/openai_agents/_trace_interceptor.py index 20d489b65..f9a16d858 100644 --- a/temporalio/contrib/openai_agents/_trace_interceptor.py +++ b/temporalio/contrib/openai_agents/_trace_interceptor.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextvars import random import uuid from contextlib import contextmanager @@ -400,48 +401,57 @@ async def signal_external_workflow( def start_activity( self, input: temporalio.worker.StartActivityInput ) -> temporalio.workflow.ActivityHandle: - trace = get_trace_provider().get_current_trace() - span: Optional[Span] = None - if trace: - span = custom_span( - name="temporal:startActivity", data={"activity": input.activity} - ) - span.start(mark_as_current=True) - - set_header_from_context(input, temporalio.workflow.payload_converter()) - handle = self.next.start_activity(input) + ctx = contextvars.copy_context() + span = ctx.run( + self._create_span, + name="temporal:startActivity", + data={"activity": input.activity}, + input=input, + ) + handle = ctx.run(self.next.start_activity, input) if span: - handle.add_done_callback(lambda _: span.finish()) # type: ignore + handle.add_done_callback(lambda _: span.finish(), context=ctx) + return handle async def start_child_workflow( self, input: temporalio.worker.StartChildWorkflowInput ) -> temporalio.workflow.ChildWorkflowHandle: - trace = get_trace_provider().get_current_trace() - span: Optional[Span] = None - if trace: - span = custom_span( - name="temporal:startChildWorkflow", data={"workflow": input.workflow} - ) - span.start(mark_as_current=True) - set_header_from_context(input, temporalio.workflow.payload_converter()) - handle = await self.next.start_child_workflow(input) + ctx = contextvars.copy_context() + span = ctx.run( + self._create_span, + name="temporal:startChildWorkflow", + data={"workflow": input.workflow}, + input=input, + ) + handle = await ctx.run(self.next.start_child_workflow, input) if span: - handle.add_done_callback(lambda _: span.finish()) # type: ignore + handle.add_done_callback(lambda _: span.finish(), context=ctx) return handle def start_local_activity( self, input: temporalio.worker.StartLocalActivityInput ) -> temporalio.workflow.ActivityHandle: + ctx = contextvars.copy_context() + span = ctx.run( + self._create_span, + name="temporal:startLocalActivity", + data={"activity": input.activity}, + input=input, + ) + handle = ctx.run(self.next.start_local_activity, input) + if span: + handle.add_done_callback(lambda _: span.finish(), context=ctx) + return handle + + @staticmethod + def _create_span( + name: str, data: dict[str, Any], input: _InputWithHeaders + ) -> Optional[Span]: trace = get_trace_provider().get_current_trace() span: Optional[Span] = None if trace: - span = custom_span( - name="temporal:startLocalActivity", data={"activity": input.activity} - ) + span = custom_span(name=name, data=data) span.start(mark_as_current=True) set_header_from_context(input, temporalio.workflow.payload_converter()) - handle = self.next.start_local_activity(input) - if span: - handle.add_done_callback(lambda _: span.finish()) # type: ignore - return handle + return span diff --git a/tests/contrib/openai_agents/test_openai_tracing.py b/tests/contrib/openai_agents/test_openai_tracing.py index f0a98c2b6..ec95310cd 100644 --- a/tests/contrib/openai_agents/test_openai_tracing.py +++ b/tests/contrib/openai_agents/test_openai_tracing.py @@ -2,9 +2,10 @@ from datetime import timedelta from typing import Any -from agents import Span, Trace, TracingProcessor +from agents import Span, Trace, TracingProcessor, trace from agents.tracing import get_trace_provider +from temporalio import workflow from temporalio.client import Client from temporalio.contrib.openai_agents.testing import ( AgentEnvironment, @@ -21,6 +22,10 @@ class MemoryTracingProcessor(TracingProcessor): trace_events: list[tuple[Trace, bool]] = [] span_events: list[tuple[Span, bool]] = [] + def __init__(self): + self.trace_events = [] + self.span_events = [] + def on_trace_start(self, trace: Trace) -> None: self.trace_events.append((trace, True)) @@ -40,6 +45,12 @@ def force_flush(self) -> None: pass +def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None: + assert a[0].trace_id == b[0].trace_id + assert a[1] + assert not b[1] + + async def test_tracing(client: Client): async with AgentEnvironment(model=research_mock_model()) as env: client = env.applied_on_client(client) @@ -71,11 +82,6 @@ async def test_tracing(client: Client): assert processor.trace_events[0][1] assert not processor.trace_events[1][1] - def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None: - assert a[0].trace_id == b[0].trace_id - assert a[1] - assert not b[1] - # Initial planner spans - There are only 3 because we don't make an actual model call paired_span(processor.span_events[0], processor.span_events[5]) assert ( @@ -142,3 +148,68 @@ def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None: processor.span_events[-4][0].span_data.export().get("name") == "temporal:executeActivity" ) + + +@workflow.defn +class ChildWorkflow: + @workflow.run + async def run(self) -> str: + return "A" + + +@workflow.defn +class ParentWorkflow: + @workflow.run + async def run(self) -> str: + with trace("Parent trace"): + return await workflow.execute_child_workflow(ChildWorkflow.run) + + +async def test_tracing_child_workflow(client: Client): + async with AgentEnvironment(model=research_mock_model()) as env: + client = env.applied_on_client(client) + + provider = get_trace_provider() + + processor = MemoryTracingProcessor() + provider.set_processors([processor]) + + async with new_worker( + client, + ParentWorkflow, + ChildWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + ParentWorkflow.run, + id=f"openai-tracing-child-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=120), + ) + result = await workflow_handle.result() + + # There is one closed root trace + assert len(processor.trace_events) == 2 + assert ( + processor.trace_events[0][0].trace_id + == processor.trace_events[1][0].trace_id + ) + assert processor.trace_events[0][1] + assert not processor.trace_events[1][1] + + for span, _ in processor.span_events: + print( + f"Span: {span.span_id}, parent: {span.parent_id}, data: {span.span_data.export()}" + ) + + # Two spans - startChildWorkflow > executeWorkflow + paired_span(processor.span_events[0], processor.span_events[3]) + assert ( + processor.span_events[0][0].span_data.export().get("name") + == "temporal:startChildWorkflow" + ) + + paired_span(processor.span_events[1], processor.span_events[2]) + assert ( + processor.span_events[1][0].span_data.export().get("name") + == "temporal:executeWorkflow" + )