77from agents .tracing import (
88 get_trace_provider ,
99)
10- from agents .tracing .provider import DefaultTraceProvider
10+ from agents .tracing .provider import (
11+ DefaultTraceProvider ,
12+ SynchronousMultiTracingProcessor ,
13+ )
1114from agents .tracing .spans import Span
1215
1316from temporalio import workflow
@@ -72,22 +75,35 @@ def activity_span(
7275 )
7376
7477
75- class _TemporalTracingProcessor (TracingProcessor ):
76- def __init__ (self , impl : TracingProcessor ):
78+ class _TemporalTracingProcessor (SynchronousMultiTracingProcessor ):
79+ def __init__ (
80+ self , impl : SynchronousMultiTracingProcessor , auto_close_in_workflows : bool
81+ ):
7782 super ().__init__ ()
7883 self ._impl = impl
84+ self ._auto_close_in_workflows = auto_close_in_workflows
85+
86+ def add_tracing_processor (self , tracing_processor : TracingProcessor ):
87+ self ._impl .add_tracing_processor (tracing_processor )
88+
89+ def set_processors (self , processors : list [TracingProcessor ]):
90+ self ._impl .set_processors (processors )
7991
8092 def on_trace_start (self , trace : Trace ) -> None :
8193 if workflow .in_workflow () and workflow .unsafe .is_replaying ():
8294 # In replay mode, don't report
8395 return
8496
8597 self ._impl .on_trace_start (trace )
98+ if self ._auto_close_in_workflows and workflow .in_workflow ():
99+ self ._impl .on_trace_end (trace )
86100
87101 def on_trace_end (self , trace : Trace ) -> None :
88102 if workflow .in_workflow () and workflow .unsafe .is_replaying ():
89103 # In replay mode, don't report
90104 return
105+ if self ._auto_close_in_workflows and workflow .in_workflow ():
106+ return
91107
92108 self ._impl .on_trace_end (trace )
93109
@@ -97,11 +113,16 @@ def on_span_start(self, span: Span[Any]) -> None:
97113 return
98114
99115 self ._impl .on_span_start (span )
116+ if self ._auto_close_in_workflows and workflow .in_workflow ():
117+ self ._impl .on_span_end (span )
100118
101119 def on_span_end (self , span : Span [Any ]) -> None :
102120 if workflow .in_workflow () and workflow .unsafe .is_replaying ():
103121 # In replay mode, don't report
104122 return
123+ if self ._auto_close_in_workflows and workflow .in_workflow ():
124+ return
125+
105126 self ._impl .on_span_end (span )
106127
107128 def shutdown (self ) -> None :
@@ -114,12 +135,13 @@ def force_flush(self) -> None:
114135class TemporalTraceProvider (DefaultTraceProvider ):
115136 """A trace provider that integrates with Temporal workflows."""
116137
117- def __init__ (self ):
138+ def __init__ (self , auto_close_in_workflows : bool = False ):
118139 """Initialize the TemporalTraceProvider."""
119140 super ().__init__ ()
120141 self ._original_provider = cast (DefaultTraceProvider , get_trace_provider ())
121- self ._multi_processor = _TemporalTracingProcessor ( # type: ignore[assignment]
122- self ._original_provider ._multi_processor
142+ self ._multi_processor = _TemporalTracingProcessor (
143+ self ._original_provider ._multi_processor ,
144+ auto_close_in_workflows ,
123145 )
124146
125147 def time_iso (self ) -> str :
0 commit comments