@@ -18,7 +18,8 @@
 import warnings
 from collections import defaultdict
 from timeit import default_timer
-from typing import Any, Optional
+from typing import Any, Optional, AsyncGenerator
+import asyncio

 from neo4j_graphrag.utils.logging import prettify

@@ -47,7 +48,12 @@
 )
 from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult
 from neo4j_graphrag.experimental.pipeline.types.context import RunContext
-from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol
+from neo4j_graphrag.experimental.pipeline.notification import (
+    EventCallbackProtocol,
+    Event,
+    PipelineEvent,
+    EventType,
+)


 logger = logging.getLogger(__name__)
@@ -117,7 +123,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.store = store or InMemoryStore()
-        self.callback = callback
+        self.callbacks = [callback] if callback else []
         self.final_results = InMemoryStore()
         self.is_validated = False
         self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
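Note on the hunk above: generalising the single `callback` into a `callbacks` list lets `stream()` (added below) register and later remove its own listener without clobbering a user-supplied callback. The dispatch code is not part of this diff; the following is only a minimal sketch of how a notifier might fan an event out to such a list, with illustrative names that are not the library's API:

import asyncio
from typing import Awaitable, Callable, Iterable

async def notify_all(
    callbacks: Iterable[Callable[[object], Awaitable[None]]], event: object
) -> None:
    # Hypothetical helper: await every registered callback with the event,
    # collecting exceptions so one failing callback cannot break the others.
    await asyncio.gather(*(cb(event) for cb in callbacks), return_exceptions=True)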
@@ -412,6 +418,76 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
     async def get_final_results(self, run_id: str) -> dict[str, Any]:
         return await self.final_results.get(run_id)  # type: ignore[no-any-return]

+    async def stream(
+        self, data: dict[str, Any], raise_exception: bool = True
+    ) -> AsyncGenerator[Event, None]:
+        """Run the pipeline and stream events reporting task progress.
+
+        Args:
+            data (dict): Input data for the pipeline components.
+            raise_exception (bool): Set to False to prevent this method from
+                propagating pipeline exceptions.
+
+        Yields:
+            Event: Pipeline and task events, including start, progress, and completion.
+        """
+        # Create a queue to collect events emitted by the pipeline
+        event_queue: asyncio.Queue[Event] = asyncio.Queue()
+        run_id = None
+
+        async def event_stream(event: Event) -> None:
+            # Put the event in the queue for streaming
+            await event_queue.put(event)
+
+        # Register the event-streaming callback
+        self.callbacks.append(event_stream)
+
+        event_queue_getter_task = None
+        try:
+            # Start pipeline execution in a background task
+            run_task = asyncio.create_task(self.run(data))
+
+            # Loop until the run task is done and there are no more
+            # pending events in the queue
+            is_run_task_running = True
+            is_queue_empty = False
+            while is_run_task_running or not is_queue_empty:
+                # Wait for the next event or for pipeline completion
+                event_queue_getter_task = asyncio.create_task(event_queue.get())
+                done, pending = await asyncio.wait(
+                    [run_task, event_queue_getter_task],
+                    return_when=asyncio.FIRST_COMPLETED,
+                )
+
+                is_run_task_running = run_task not in done
+                is_queue_empty = event_queue.empty()
+
+                for event_future in done:
+                    if event_future == run_task:
+                        continue
+                    # We are sure to get an Event here, since this is the only
+                    # thing we put in the queue, but mypy still complains
+                    event = event_future.result()
+                    run_id = getattr(event, "run_id", None)
+                    yield event  # type: ignore
+
+            if exc := run_task.exception():
+                yield PipelineEvent(
+                    event_type=EventType.PIPELINE_FAILED,
+                    # run_id is None if the pipeline fails before even starting,
+                    # i.e. during pipeline validation
+                    run_id=run_id or "",
+                    message=str(exc),
+                )
+                if raise_exception:
+                    raise exc
+
+        finally:
+            # Remove the event-streaming callback added above
+            self.callbacks.remove(event_stream)
+            if event_queue_getter_task and not event_queue_getter_task.done():
+                event_queue_getter_task.cancel()
+
     async def run(self, data: dict[str, Any]) -> PipelineResult:
         logger.debug("PIPELINE START")
         start_time = default_timer()
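Taken together, the diff adds an asynchronous streaming entry point to the pipeline. Below is a minimal sketch of how the new `stream()` generator might be consumed; the `Pipeline` setup, component name, and input data are placeholders rather than part of this change, and the import path is inferred from the module paths visible in the diff:

import asyncio

from neo4j_graphrag.experimental.pipeline import Pipeline


async def main() -> None:
    pipeline = Pipeline()
    # ... add components and connect them here (omitted) ...

    # Iterate over events as the pipeline runs; each item is an Event
    # (e.g. a PipelineEvent carrying event_type, run_id and message).
    async for event in pipeline.stream({"my_component": {"text": "..."}}):
        print(event.event_type, getattr(event, "message", None))


asyncio.run(main())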