Skip to content

Commit 474a830

Browse files
committed
stop assigning task run early
1 parent 0c783f4 commit 474a830

File tree

2 files changed

+30
-49
lines changed

2 files changed

+30
-49
lines changed

src/prefect/tasks.py

Lines changed: 3 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,7 @@ async def create_local_run(
888888
from prefect.utilities._engine import dynamic_key_for_task_run
889889
from prefect.utilities.engine import (
890890
collect_task_run_inputs_sync,
891+
record_task_assets,
891892
)
892893

893894
if flow_run_context is None:
@@ -928,7 +929,7 @@ async def create_local_run(
928929

929930
store = await ResultStore(
930931
result_storage=await get_or_create_default_task_scheduling_storage()
931-
).update_for_task(task)
932+
).update_for_task(self)
932933
context = serialize_context()
933934
data: dict[str, Any] = {"context": context}
934935
if parameters:
@@ -995,55 +996,10 @@ async def create_local_run(
995996
)
996997

997998
# Record task assets after creating the task run
998-
self.task_run = task_run
999-
self._record_task_assets()
999+
record_task_assets(self, task_run)
10001000

10011001
return task_run
10021002

1003-
def _record_task_assets(self) -> None:
1004-
"""Record direct assets and conditionally propagate upstream assets based on task type."""
1005-
ctx = FlowRunContext.get()
1006-
if not ctx or not hasattr(self, "task_run") or not self.task_run:
1007-
return
1008-
1009-
direct_assets = []
1010-
1011-
# TODO don't do hasattr
1012-
if hasattr(self, "asset_deps") and self.asset_deps:
1013-
from prefect.assets import Asset
1014-
1015-
for asset in self.asset_deps:
1016-
asset_obj = asset if isinstance(asset, Asset) else Asset(key=asset)
1017-
direct_assets.append(asset_obj)
1018-
1019-
if hasattr(self, "assets"):
1020-
direct_assets.extend(self.assets)
1021-
assets_for_downstream = self.assets[:]
1022-
else:
1023-
upstream_assets = self._get_upstream_assets_from_inputs()
1024-
assets_for_downstream = direct_assets + list(upstream_assets)
1025-
1026-
ctx.task_run_assets[self.task_run.id] = assets_for_downstream
1027-
1028-
def _get_upstream_assets_from_inputs(self) -> set[Any]:
1029-
"""Extract upstream assets from task inputs"""
1030-
if (
1031-
not hasattr(self, "task_run")
1032-
or not self.task_run
1033-
or not self.task_run.task_inputs
1034-
):
1035-
return set()
1036-
1037-
upstream_assets = set()
1038-
for input_list in self.task_run.task_inputs.values():
1039-
# TODO make sure we're only checking TaskRunResult
1040-
# TODO I think I can just get rid of this whole method...
1041-
for task_input in input_list:
1042-
if hasattr(task_input, "assets") and task_input.assets:
1043-
upstream_assets.update(task_input.assets)
1044-
1045-
return upstream_assets
1046-
10471003
@overload
10481004
def __call__(
10491005
self: "Task[P, NoReturn]",

src/prefect/utilities/engine.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -852,14 +852,14 @@ def resolve_inputs_sync(
852852

853853

854854
def get_upstream_assets_from_task_inputs(task_run: TaskRun) -> set[Any]:
    """Collect the assets attached to a task run's upstream inputs.

    Args:
        task_run: The task run whose ``task_inputs`` mapping is scanned.

    Returns:
        The union of the ``assets`` of every ``TaskRunResult`` found in
        ``task_run.task_inputs`` that carries a non-empty ``assets`` value.
        An empty set is returned when ``task_run`` is falsy or has no
        task inputs.
    """
    if not task_run or not task_run.task_inputs:
        return set()

    upstream_assets: set[Any] = set()
    for input_list in task_run.task_inputs.values():
        for task_input in input_list:
            # Only TaskRunResult inputs are inspected for assets; any other
            # kind of input is ignored.
            if isinstance(task_input, TaskRunResult) and task_input.assets:
                upstream_assets.update(task_input.assets)

    return upstream_assets
@@ -912,3 +912,28 @@ def emit_asset_events(task: Any, task_run: TaskRun, succeeded: bool) -> None:
912912
resource=asset_as_resource(asset),
913913
related=all_related,
914914
)
915+
916+
917+
def record_task_assets(task: Any, task_run: TaskRun) -> None:
    """Record the assets a task run makes visible to downstream task runs.

    The result is stored in ``task_run_assets`` on the active
    ``FlowRunContext``, keyed by ``task_run.id``. Nothing is recorded when
    there is no active flow run context or ``task_run`` is falsy.

    Two modes, selected by whether ``task`` has an ``assets`` attribute:

    * With ``assets``: only a copy of ``task.assets`` is recorded; the
      task's ``asset_deps`` and any upstream assets are not forwarded.
    * Without ``assets``: the task's ``asset_deps`` (non-``Asset`` entries
      are wrapped as ``Asset(key=...)``) followed by the assets found on the
      task run's inputs are recorded.

    Args:
        task: The task object; ``asset_deps`` and ``assets`` are read from it
            when present.
        task_run: The task run whose id keys the recorded asset list.
    """
    ctx = FlowRunContext.get()
    if not ctx or not task_run:
        return

    direct_assets: list[Any] = []

    # Normalization runs before the branch below, as it did originally, so
    # any error raised while building an Asset still surfaces for every task.
    asset_deps = getattr(task, "asset_deps", None)
    if asset_deps:
        from prefect.assets import Asset

        direct_assets = [
            dep if isinstance(dep, Asset) else Asset(key=dep) for dep in asset_deps
        ]

    if hasattr(task, "assets"):
        # Copy so later changes to the recorded list do not touch task.assets.
        assets_for_downstream = task.assets[:]
    else:
        upstream_assets = get_upstream_assets_from_task_inputs(task_run)
        assets_for_downstream = direct_assets + list(upstream_assets)

    ctx.task_run_assets[task_run.id] = assets_for_downstream

0 commit comments

Comments
 (0)