Skip to content

Commit bf57565

Browse files
authored
Fix GraphBuildingError when using g.stream multiple times (#3695)
1 parent f6d1152 commit bf57565

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

pydantic_graph/pydantic_graph/beta/graph_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,8 @@ def decorator(
284284
async def wrapper(ctx: StepContext[StateT, DepsT, InputT]):
285285
return call(ctx)
286286

287+
node_id = node_id or get_callable_name(call)
288+
287289
return self.step(call=wrapper, node_id=node_id, label=label)
288290

289291
@overload

tests/graph/beta/test_graph_builder.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,46 @@ async def orphan_step(ctx: StepContext[None, None, None]) -> int:
442442

443443
# Should not raise an error when validation is disabled
444444
g.build(validate_graph_structure=False)
445+
446+
447+
async def test_multiple_stream_decorators_without_node_id():
448+
"""Test that multiple @g.stream decorators without explicit node_id get unique IDs.
449+
450+
When using @g.stream without node_id, the node ID should be derived from the
451+
decorated function's name, not from the internal 'wrapper' function.
452+
"""
453+
from collections.abc import AsyncIterator
454+
455+
g = GraphBuilder(state_type=SimpleState, output_type=list[int])
456+
457+
@g.stream
458+
async def generate_stream(ctx: StepContext[SimpleState, None, None]) -> AsyncIterator[int]:
459+
"""Stream numbers from 1 to 3."""
460+
for i in range(1, 4):
461+
yield i
462+
463+
@g.stream
464+
async def square(ctx: StepContext[SimpleState, None, int]) -> AsyncIterator[int]:
465+
"""Square the input."""
466+
yield ctx.inputs * ctx.inputs
467+
468+
@g.step
469+
async def plus_one(ctx: StepContext[SimpleState, None, int]) -> int:
470+
return ctx.inputs + 1
471+
472+
collect = g.join(reduce_list_append, initial_factory=list[int])
473+
474+
# This should NOT raise GraphBuildingError about duplicate 'wrapper' node IDs
475+
g.add(
476+
g.edge_from(g.start_node).to(generate_stream),
477+
g.edge_from(generate_stream).map().to(square),
478+
g.edge_from(square).map().to(plus_one),
479+
g.edge_from(plus_one).to(collect),
480+
g.edge_from(collect).to(g.end_node),
481+
)
482+
483+
graph = g.build()
484+
485+
state = SimpleState()
486+
result = await graph.run(state=state)
487+
assert sorted(result) == [2, 5, 10]

0 commit comments

Comments
 (0)