Skip to content

Commit 4c61674

Browse files
Get mypy passing for models/flows.py (#12919)
1 parent fd74517 commit 4c61674

File tree

2 files changed

+27
-22
lines changed

2 files changed

+27
-22
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ plugins=
8686
pydantic.mypy
8787

8888
files =
89+
src/prefect/server/models/flows.py,
8990
src/prefect/concurrency/,
9091
src/prefect/events/,
9192
src/prefect/input/

src/prefect/server/models/flows.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
Intended for internal use by the Prefect REST API.
44
"""
55

6-
from typing import TYPE_CHECKING, Optional, Sequence
6+
from typing import TYPE_CHECKING, Optional, Sequence, TypeVar, Union
77
from uuid import UUID
88

99
import sqlalchemy as sa
1010
from sqlalchemy import delete, select
1111
from sqlalchemy.ext.asyncio import AsyncSession
12+
from sqlalchemy.sql import Select
1213

1314
import prefect.server.schemas as schemas
1415
from prefect.server.database.dependencies import db_injector
@@ -18,6 +19,9 @@
1819
from prefect.server.database.orm_models import ORMFlow
1920

2021

22+
T = TypeVar("T", bound=tuple)
23+
24+
2125
@db_injector
2226
async def create_flow(
2327
db: PrefectDBInterface, session: AsyncSession, flow: schemas.core.Flow
@@ -53,7 +57,7 @@ async def create_flow(
5357
.execution_options(populate_existing=True)
5458
)
5559
result = await session.execute(query)
56-
model = result.scalar()
60+
model = result.scalar_one()
5761
return model
5862

5963

@@ -125,13 +129,13 @@ async def read_flow_by_name(
125129
@db_injector
126130
async def _apply_flow_filters(
127131
db: PrefectDBInterface,
128-
query,
129-
flow_filter: schemas.filters.FlowFilter = None,
130-
flow_run_filter: schemas.filters.FlowRunFilter = None,
131-
task_run_filter: schemas.filters.TaskRunFilter = None,
132-
deployment_filter: schemas.filters.DeploymentFilter = None,
133-
work_pool_filter: schemas.filters.WorkPoolFilter = None,
134-
):
132+
query: Select[T],
133+
flow_filter: Union[schemas.filters.FlowFilter, None] = None,
134+
flow_run_filter: Union[schemas.filters.FlowRunFilter, None] = None,
135+
task_run_filter: Union[schemas.filters.TaskRunFilter, None] = None,
136+
deployment_filter: Union[schemas.filters.DeploymentFilter, None] = None,
137+
work_pool_filter: Union[schemas.filters.WorkPoolFilter, None] = None,
138+
) -> Select[T]:
135139
"""
136140
Applies filters to a flow query as a combination of EXISTS subqueries.
137141
"""
@@ -181,14 +185,14 @@ async def _apply_flow_filters(
181185
async def read_flows(
182186
db: PrefectDBInterface,
183187
session: AsyncSession,
184-
flow_filter: schemas.filters.FlowFilter = None,
185-
flow_run_filter: schemas.filters.FlowRunFilter = None,
186-
task_run_filter: schemas.filters.TaskRunFilter = None,
187-
deployment_filter: schemas.filters.DeploymentFilter = None,
188-
work_pool_filter: schemas.filters.WorkPoolFilter = None,
188+
flow_filter: Union[schemas.filters.FlowFilter, None] = None,
189+
flow_run_filter: Union[schemas.filters.FlowRunFilter, None] = None,
190+
task_run_filter: Union[schemas.filters.TaskRunFilter, None] = None,
191+
deployment_filter: Union[schemas.filters.DeploymentFilter, None] = None,
192+
work_pool_filter: Union[schemas.filters.WorkPoolFilter, None] = None,
189193
sort: schemas.sorting.FlowSort = schemas.sorting.FlowSort.NAME_ASC,
190-
offset: int = None,
191-
limit: int = None,
194+
offset: Union[int, None] = None,
195+
limit: Union[int, None] = None,
192196
) -> Sequence["ORMFlow"]:
193197
"""
194198
Read multiple flows.
@@ -232,11 +236,11 @@ async def read_flows(
232236
async def count_flows(
233237
db: PrefectDBInterface,
234238
session: AsyncSession,
235-
flow_filter: schemas.filters.FlowFilter = None,
236-
flow_run_filter: schemas.filters.FlowRunFilter = None,
237-
task_run_filter: schemas.filters.TaskRunFilter = None,
238-
deployment_filter: schemas.filters.DeploymentFilter = None,
239-
work_pool_filter: schemas.filters.WorkPoolFilter = None,
239+
flow_filter: Union[schemas.filters.FlowFilter, None] = None,
240+
flow_run_filter: Union[schemas.filters.FlowRunFilter, None] = None,
241+
task_run_filter: Union[schemas.filters.TaskRunFilter, None] = None,
242+
deployment_filter: Union[schemas.filters.DeploymentFilter, None] = None,
243+
work_pool_filter: Union[schemas.filters.WorkPoolFilter, None] = None,
240244
) -> int:
241245
"""
242246
Count flows.
@@ -265,7 +269,7 @@ async def count_flows(
265269
)
266270

267271
result = await session.execute(query)
268-
return result.scalar()
272+
return result.scalar_one()
269273

270274

271275
@db_injector

0 commit comments

Comments
 (0)