|
3 | 3 | Intended for internal use by the Prefect REST API.
|
4 | 4 | """
|
5 | 5 |
|
6 |
| -from typing import TYPE_CHECKING, Optional, Sequence |
| 6 | +from typing import TYPE_CHECKING, Optional, Sequence, TypeVar, Union |
7 | 7 | from uuid import UUID
|
8 | 8 |
|
9 | 9 | import sqlalchemy as sa
|
10 | 10 | from sqlalchemy import delete, select
|
11 | 11 | from sqlalchemy.ext.asyncio import AsyncSession
|
| 12 | +from sqlalchemy.sql import Select |
12 | 13 |
|
13 | 14 | import prefect.server.schemas as schemas
|
14 | 15 | from prefect.server.database.dependencies import db_injector
|
|
18 | 19 | from prefect.server.database.orm_models import ORMFlow
|
19 | 20 |
|
20 | 21 |
|
| 22 | +T = TypeVar("T", bound=tuple) |
| 23 | + |
| 24 | + |
21 | 25 | @db_injector
|
22 | 26 | async def create_flow(
|
23 | 27 | db: PrefectDBInterface, session: AsyncSession, flow: schemas.core.Flow
|
@@ -53,7 +57,7 @@ async def create_flow(
|
53 | 57 | .execution_options(populate_existing=True)
|
54 | 58 | )
|
55 | 59 | result = await session.execute(query)
|
56 |
| - model = result.scalar() |
| 60 | + model = result.scalar_one() |
57 | 61 | return model
|
58 | 62 |
|
59 | 63 |
|
@@ -125,13 +129,13 @@ async def read_flow_by_name(
|
125 | 129 | @db_injector
|
126 | 130 | async def _apply_flow_filters(
|
127 | 131 | db: PrefectDBInterface,
|
128 |
| - query, |
129 |
| - flow_filter: schemas.filters.FlowFilter = None, |
130 |
| - flow_run_filter: schemas.filters.FlowRunFilter = None, |
131 |
| - task_run_filter: schemas.filters.TaskRunFilter = None, |
132 |
| - deployment_filter: schemas.filters.DeploymentFilter = None, |
133 |
| - work_pool_filter: schemas.filters.WorkPoolFilter = None, |
134 |
| -): |
| 132 | + query: Select[T], |
| 133 | + flow_filter: Union[schemas.filters.FlowFilter, None] = None, |
| 134 | + flow_run_filter: Union[schemas.filters.FlowRunFilter, None] = None, |
| 135 | + task_run_filter: Union[schemas.filters.TaskRunFilter, None] = None, |
| 136 | + deployment_filter: Union[schemas.filters.DeploymentFilter, None] = None, |
| 137 | + work_pool_filter: Union[schemas.filters.WorkPoolFilter, None] = None, |
| 138 | +) -> Select[T]: |
135 | 139 | """
|
136 | 140 | Applies filters to a flow query as a combination of EXISTS subqueries.
|
137 | 141 | """
|
@@ -181,14 +185,14 @@ async def _apply_flow_filters(
|
181 | 185 | async def read_flows(
|
182 | 186 | db: PrefectDBInterface,
|
183 | 187 | session: AsyncSession,
|
184 |
| - flow_filter: schemas.filters.FlowFilter = None, |
185 |
| - flow_run_filter: schemas.filters.FlowRunFilter = None, |
186 |
| - task_run_filter: schemas.filters.TaskRunFilter = None, |
187 |
| - deployment_filter: schemas.filters.DeploymentFilter = None, |
188 |
| - work_pool_filter: schemas.filters.WorkPoolFilter = None, |
| 188 | + flow_filter: Union[schemas.filters.FlowFilter, None] = None, |
| 189 | + flow_run_filter: Union[schemas.filters.FlowRunFilter, None] = None, |
| 190 | + task_run_filter: Union[schemas.filters.TaskRunFilter, None] = None, |
| 191 | + deployment_filter: Union[schemas.filters.DeploymentFilter, None] = None, |
| 192 | + work_pool_filter: Union[schemas.filters.WorkPoolFilter, None] = None, |
189 | 193 | sort: schemas.sorting.FlowSort = schemas.sorting.FlowSort.NAME_ASC,
|
190 |
| - offset: int = None, |
191 |
| - limit: int = None, |
| 194 | + offset: Union[int, None] = None, |
| 195 | + limit: Union[int, None] = None, |
192 | 196 | ) -> Sequence["ORMFlow"]:
|
193 | 197 | """
|
194 | 198 | Read multiple flows.
|
@@ -232,11 +236,11 @@ async def read_flows(
|
232 | 236 | async def count_flows(
|
233 | 237 | db: PrefectDBInterface,
|
234 | 238 | session: AsyncSession,
|
235 |
| - flow_filter: schemas.filters.FlowFilter = None, |
236 |
| - flow_run_filter: schemas.filters.FlowRunFilter = None, |
237 |
| - task_run_filter: schemas.filters.TaskRunFilter = None, |
238 |
| - deployment_filter: schemas.filters.DeploymentFilter = None, |
239 |
| - work_pool_filter: schemas.filters.WorkPoolFilter = None, |
| 239 | + flow_filter: Union[schemas.filters.FlowFilter, None] = None, |
| 240 | + flow_run_filter: Union[schemas.filters.FlowRunFilter, None] = None, |
| 241 | + task_run_filter: Union[schemas.filters.TaskRunFilter, None] = None, |
| 242 | + deployment_filter: Union[schemas.filters.DeploymentFilter, None] = None, |
| 243 | + work_pool_filter: Union[schemas.filters.WorkPoolFilter, None] = None, |
240 | 244 | ) -> int:
|
241 | 245 | """
|
242 | 246 | Count flows.
|
@@ -265,7 +269,7 @@ async def count_flows(
|
265 | 269 | )
|
266 | 270 |
|
267 | 271 | result = await session.execute(query)
|
268 |
| - return result.scalar() |
| 272 | + return result.scalar_one() |
269 | 273 |
|
270 | 274 |
|
271 | 275 | @db_injector
|
|
0 commit comments