Skip to content

Commit b42d3a1

Browse files
authored
Refactor db_injector decorator (#16390)
1 parent f4a26c5 commit b42d3a1

File tree

8 files changed

+440
-145
lines changed

8 files changed

+440
-145
lines changed

src/prefect/server/database/dependencies.py

Lines changed: 248 additions & 52 deletions
Large diffs are not rendered by default.

src/prefect/server/events/ordering.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import pendulum
2020
import sqlalchemy as sa
2121
from cachetools import TTLCache
22-
from typing_extensions import Self
2322

2423
from prefect.logging import get_logger
2524
from prefect.server.database.dependencies import db_injector
@@ -75,7 +74,9 @@ async def record_event_as_seen(self, event: ReceivedEvent) -> None:
7574
self._seen_events[self.scope][event.id] = True
7675

7776
@db_injector
78-
async def record_follower(db: PrefectDBInterface, self: Self, event: ReceivedEvent):
77+
async def record_follower(
78+
self, db: PrefectDBInterface, event: ReceivedEvent
79+
) -> None:
7980
"""Remember that this event is waiting on another event to arrive"""
8081
assert event.follows
8182

@@ -92,8 +93,8 @@ async def record_follower(db: PrefectDBInterface, self: Self, event: ReceivedEve
9293

9394
@db_injector
9495
async def forget_follower(
95-
db: PrefectDBInterface, self: Self, follower: ReceivedEvent
96-
):
96+
self, db: PrefectDBInterface, follower: ReceivedEvent
97+
) -> None:
9798
"""Forget that this event is waiting on another event to arrive"""
9899
assert follower.follows
99100

@@ -107,7 +108,7 @@ async def forget_follower(
107108

108109
@db_injector
109110
async def get_followers(
110-
db: PrefectDBInterface, self: Self, leader: ReceivedEvent
111+
self, db: PrefectDBInterface, leader: ReceivedEvent
111112
) -> List[ReceivedEvent]:
112113
"""Returns events that were waiting on this leader event to arrive"""
113114
async with db.session_context() as session:
@@ -120,7 +121,7 @@ async def get_followers(
120121
return sorted(followers, key=lambda e: e.occurred)
121122

122123
@db_injector
123-
async def get_lost_followers(db: PrefectDBInterface, self) -> List[ReceivedEvent]:
124+
async def get_lost_followers(self, db: PrefectDBInterface) -> List[ReceivedEvent]:
124125
"""Returns events that were waiting on a leader event that never arrived"""
125126
earlier = pendulum.now("UTC") - PRECEDING_EVENT_LOOKBACK
126127

src/prefect/server/services/foreman.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import pendulum
99
import sqlalchemy as sa
10-
from typing_extensions import Self
1110

1211
from prefect.server import models
1312
from prefect.server.database.dependencies import db_injector
@@ -70,7 +69,7 @@ def __init__(
7069
)
7170

7271
@db_injector
73-
async def run_once(db: PrefectDBInterface, self: Self) -> None:
72+
async def run_once(self, db: PrefectDBInterface) -> None:
7473
"""
7574
Iterate over workers current marked as online. Mark workers as offline
7675
if they have an old last_heartbeat_time. Marks work pools as not ready
@@ -85,8 +84,7 @@ async def run_once(db: PrefectDBInterface, self: Self) -> None:
8584

8685
@db_injector
8786
async def _mark_online_workers_without_a_recent_heartbeat_as_offline(
88-
db: PrefectDBInterface,
89-
self: Self,
87+
self, db: PrefectDBInterface
9088
) -> None:
9189
"""
9290
Updates the status of workers that have an old last heartbeat time
@@ -147,7 +145,7 @@ async def _mark_online_workers_without_a_recent_heartbeat_as_offline(
147145
self.logger.info(f"Marked {result.rowcount} workers as offline.")
148146

149147
@db_injector
150-
async def _mark_work_pools_as_not_ready(db: PrefectDBInterface, self: Self):
148+
async def _mark_work_pools_as_not_ready(self, db: PrefectDBInterface):
151149
"""
152150
Marks a work pool as not ready.
153151
@@ -185,10 +183,7 @@ async def _mark_work_pools_as_not_ready(db: PrefectDBInterface, self: Self):
185183
self.logger.info(f"Marked work pool {work_pool.id} as NOT_READY.")
186184

187185
@db_injector
188-
async def _mark_deployments_as_not_ready(
189-
db: PrefectDBInterface,
190-
self: Self,
191-
):
186+
async def _mark_deployments_as_not_ready(self, db: PrefectDBInterface) -> None:
192187
"""
193188
Marks a deployment as NOT_READY and emits a deployment status event.
194189
Emits an event and updates any bookkeeping fields on the deployment.
@@ -231,10 +226,7 @@ async def _mark_deployments_as_not_ready(
231226
)
232227

233228
@db_injector
234-
async def _mark_work_queues_as_not_ready(
235-
db: PrefectDBInterface,
236-
self: Self,
237-
):
229+
async def _mark_work_queues_as_not_ready(self, db: PrefectDBInterface):
238230
"""
239231
Marks work queues as NOT_READY based on their last_polled field.
240232
Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
from .bases import (
2-
PrefectBaseModel,
2+
ActionBaseModel,
33
IDBaseModel,
44
ORMBaseModel,
5-
ActionBaseModel,
5+
PrefectBaseModel,
6+
PrefectDescriptorBase,
67
get_class_fields_only,
78
)
9+
10+
__all__ = [
11+
"ActionBaseModel",
12+
"IDBaseModel",
13+
"ORMBaseModel",
14+
"PrefectBaseModel",
15+
"PrefectDescriptorBase",
16+
"get_class_fields_only",
17+
]

src/prefect/server/utilities/schemas/bases.py

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,59 @@
11
import datetime
22
import os
3-
from typing import (
4-
TYPE_CHECKING,
5-
Any,
6-
ClassVar,
7-
Dict,
8-
Optional,
9-
Set,
10-
Type,
11-
TypeVar,
12-
)
3+
from abc import ABC, abstractmethod
4+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar
135
from uuid import UUID, uuid4
146

157
import pendulum
16-
from pydantic import (
17-
BaseModel,
18-
ConfigDict,
19-
Field,
20-
)
8+
from pydantic import BaseModel, ConfigDict, Field
219
from typing_extensions import Self
2210

2311
from prefect.types import DateTime
2412

2513
if TYPE_CHECKING:
2614
from pydantic.main import IncEx
15+
from rich.repr import RichReprResult
2716

2817
T = TypeVar("T")
2918
B = TypeVar("B", bound=BaseModel)
3019

3120

32-
def get_class_fields_only(model: Type[BaseModel]) -> set:
21+
def get_class_fields_only(model: type[BaseModel]) -> set[str]:
3322
"""
3423
Gets all the field names defined on the model class but not any parent classes.
3524
Any fields that are on the parent but redefined on the subclass are included.
3625
"""
37-
subclass_class_fields = set(model.__annotations__.keys())
38-
parent_class_fields = set()
26+
# the annotations keys fit all of these criteria without further processing
27+
return set(model.__annotations__)
3928

40-
for base in model.__class__.__bases__:
41-
if issubclass(base, BaseModel):
42-
parent_class_fields.update(base.__annotations__.keys())
4329

44-
return (subclass_class_fields - parent_class_fields) | (
45-
subclass_class_fields & parent_class_fields
46-
)
30+
class PrefectDescriptorBase(ABC):
31+
"""A base class for descriptor objects used with PrefectBaseModel
32+
33+
Pydantic needs to be told about any kind of non-standard descriptor
34+
objects used on a model, in order for these not to be treated as a field
35+
type instead.
36+
37+
This base class is registered as an ignored type with PrefectBaseModel
38+
and any classes that inherit from it will also be ignored. This allows
39+
such descriptors to be used as properties, methods or other bound
40+
descriptor use cases.
41+
42+
"""
43+
44+
@abstractmethod
45+
def __get__(
46+
self, __instance: Optional[Any], __owner: Optional[type[Any]] = None
47+
) -> Any:
48+
"""Base descriptor access.
49+
50+
The default implementation returns itself when the instance is None,
51+
and raises an attribute error when the instance is not not None.
52+
53+
"""
54+
if __instance is not None:
55+
raise AttributeError
56+
return self
4757

4858

4959
class PrefectBaseModel(BaseModel):
@@ -58,7 +68,7 @@ class PrefectBaseModel(BaseModel):
5868
subtle unintentional testing errors.
5969
"""
6070

61-
_reset_fields: ClassVar[Set[str]] = set()
71+
_reset_fields: ClassVar[set[str]] = set()
6272

6373
model_config = ConfigDict(
6474
ser_json_timedelta="float",
@@ -68,6 +78,7 @@ class PrefectBaseModel(BaseModel):
6878
and os.getenv("PREFECT_TESTING_TEST_MODE", "0").lower() not in ["true", "1"]
6979
else "forbid"
7080
),
81+
ignored_types=(PrefectDescriptorBase,),
7182
)
7283

7384
def __eq__(self, other: Any) -> bool:
@@ -84,22 +95,20 @@ def __eq__(self, other: Any) -> bool:
8495
else:
8596
return copy_dict == other
8697

87-
def __rich_repr__(self):
98+
def __rich_repr__(self) -> "RichReprResult":
8899
# Display all of the fields in the model if they differ from the default value
89100
for name, field in self.model_fields.items():
90101
value = getattr(self, name)
91102

92103
# Simplify the display of some common fields
93-
if field.annotation == UUID and value:
104+
if isinstance(value, UUID):
94105
value = str(value)
95-
elif (
96-
isinstance(field.annotation, datetime.datetime)
97-
and name == "timestamp"
98-
and value
99-
):
100-
value = pendulum.instance(value).isoformat()
101-
elif isinstance(field.annotation, datetime.datetime) and value:
102-
value = pendulum.instance(value).diff_for_humans()
106+
elif isinstance(value, datetime.datetime):
107+
value = (
108+
pendulum.instance(value).isoformat()
109+
if name == "timestamp"
110+
else pendulum.instance(value).diff_for_humans()
111+
)
103112

104113
yield name, value, field.get_default()
105114

@@ -126,7 +135,7 @@ def model_dump_for_orm(
126135
exclude_unset: bool = False,
127136
exclude_defaults: bool = False,
128137
exclude_none: bool = False,
129-
) -> Dict[str, Any]:
138+
) -> dict[str, Any]:
130139
"""
131140
Prefect extension to `BaseModel.model_dump`. Generate a Python dictionary
132141
representation of the model suitable for passing to SQLAlchemy model
@@ -179,7 +188,7 @@ class IDBaseModel(PrefectBaseModel):
179188
The ID is reset on copy() and not included in equality comparisons.
180189
"""
181190

182-
_reset_fields: ClassVar[Set[str]] = {"id"}
191+
_reset_fields: ClassVar[set[str]] = {"id"}
183192
id: UUID = Field(default_factory=uuid4)
184193

185194

@@ -192,7 +201,7 @@ class ORMBaseModel(IDBaseModel):
192201
equality comparisons.
193202
"""
194203

195-
_reset_fields: ClassVar[Set[str]] = {"id", "created", "updated"}
204+
_reset_fields: ClassVar[set[str]] = {"id", "created", "updated"}
196205

197206
model_config = ConfigDict(from_attributes=True)
198207

tests/server/database/test_dependencies.py

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
AsyncPostgresConfiguration,
1212
BaseDatabaseConfiguration,
1313
)
14-
from prefect.server.database.dependencies import inject_db
1514
from prefect.server.database.interface import PrefectDBInterface
1615
from prefect.server.database.orm_models import (
1716
AioSqliteORMConfiguration,
@@ -151,28 +150,6 @@ async def test_injecting_existing_orm_configs(ORMConfig):
151150
assert type(db.orm) == ORMConfig
152151

153152

154-
async def test_inject_db(db):
155-
"""
156-
Regression test for async-mangling behavior of inject_db() decorator.
157-
158-
Previously, when wrapping a coroutine function, the decorator returned
159-
that function's coroutine object, instead of the coroutine function.
160-
161-
This worked fine in most cases because both a coroutine function and a
162-
coroutine object can be awaited, but it broke our Pytest setup because
163-
we were auto-marking coroutine functions as async, and any async test
164-
wrapped by inject_db() was no longer a coroutine function, but instead
165-
a coroutine object, so we skipped marking it.
166-
"""
167-
168-
class Returner:
169-
@inject_db
170-
async def return_1(self, db):
171-
return 1
172-
173-
assert asyncio.iscoroutinefunction(Returner().return_1)
174-
175-
176153
async def test_inject_interface_class():
177154
class TestInterface(PrefectDBInterface):
178155
@property
@@ -182,3 +159,72 @@ def new_property(self):
182159
with dependencies.temporary_interface_class(TestInterface):
183160
db = dependencies.provide_database_interface()
184161
assert isinstance(db, TestInterface)
162+
163+
164+
class TestDBInject:
165+
@pytest.fixture(autouse=True)
166+
def _setup(self):
167+
self.db: PrefectDBInterface = dependencies.provide_database_interface()
168+
169+
def test_decorated_function(self):
170+
@dependencies.db_injector
171+
def function_with_injected_db(
172+
db: PrefectDBInterface, foo: int
173+
) -> PrefectDBInterface:
174+
"""The documentation is sublime"""
175+
return db
176+
177+
assert function_with_injected_db(42) is self.db
178+
179+
unwrapped = function_with_injected_db.__wrapped__
180+
assert function_with_injected_db.__doc__ == unwrapped.__doc__
181+
function_with_injected_db.__doc__ = "Something else"
182+
assert function_with_injected_db.__doc__ == "Something else"
183+
assert unwrapped.__doc__ == function_with_injected_db.__doc__
184+
del function_with_injected_db.__doc__
185+
assert function_with_injected_db.__doc__ is None
186+
assert unwrapped.__doc__ is function_with_injected_db.__doc__
187+
188+
class SomeClass:
189+
@dependencies.db_injector
190+
def method_with_injected_db(
191+
self, db: PrefectDBInterface, foo: int
192+
) -> PrefectDBInterface:
193+
"""The documentation is sublime"""
194+
return db
195+
196+
def test_decorated_method(self):
197+
instance = self.SomeClass()
198+
assert instance.method_with_injected_db(42) is self.db
199+
200+
def test_unbound_decorated_method(self):
201+
instance = self.SomeClass()
202+
# manually binding the unbound descriptor to an instance
203+
bound = self.SomeClass.method_with_injected_db.__get__(instance)
204+
assert bound(42) is self.db
205+
206+
def test_bound_method_attributes(self):
207+
instance = self.SomeClass()
208+
bound = instance.method_with_injected_db
209+
assert bound.__self__ is instance
210+
assert bound.__func__ is self.SomeClass.method_with_injected_db.__wrapped__
211+
212+
unwrapped = bound.__wrapped__
213+
assert bound.__doc__ == unwrapped.__doc__
214+
215+
before = bound.__doc__
216+
with pytest.raises(AttributeError, match="is not writable$"):
217+
bound.__doc__ = "Something else"
218+
with pytest.raises(AttributeError, match="is not writable$"):
219+
del bound.__doc__
220+
assert unwrapped.__doc__ == before
221+
222+
def test_decorated_coroutine_function(self):
223+
@dependencies.db_injector
224+
async def coroutine_with_injected_db(
225+
db: PrefectDBInterface, foo: int
226+
) -> PrefectDBInterface:
227+
return db
228+
229+
assert asyncio.iscoroutinefunction(coroutine_with_injected_db)
230+
assert asyncio.run(coroutine_with_injected_db(42)) is self.db

0 commit comments

Comments
 (0)