Skip to content

Commit 674d759

Browse files
committed
Add implementation of AsyncSession.run_sync()
Currently, `sqlmodel.ext.asyncio.session.AsyncSession` doesn't implement `run_sync()`, which means that any call to `run_sync()` on a sqlmodel `AsyncSession` will be dispatched to the parent `sqlalchemy.ext.asyncio.AsyncSession`. The first argument to sqlalchemy's `AsyncSession.run_sync()` is a callable whose first argument is a `sqlalchemy.orm.Session` object. If we're using this in a repo that uses sqlmodel, we'll actually be passing a callable whose first argument is a `sqlmodel.orm.session.Session`. In practice this works fine - because `sqlmodel.orm.session.Session` is derived from `sqlalchemy.orm.Session`, the implementation of `sqlalchemy.ext.asyncio.AsyncSession.run_sync()` can use the sqlmodel `Session` object in place of the sqlalchemy `Session` object. However, static analysers will complain that the argument to `run_sync()` is of the wrong type. For example, here's a warning from pyright: ``` Pyright: Error: Argument of type "(session: Session, id: UUID) -> int" cannot be assigned to parameter "fn" of type "(Session, **_P@run_sync) -> _T@run_sync" in function "run_sync"   Type "(session: Session, id: UUID) -> int" is not assignable to type "(Session, id: UUID) -> int"     Parameter 1: type "Session" is incompatible with type "Session"       "sqlalchemy.orm.session.Session" is not assignable to "sqlmodel.orm.session.Session" [reportArgumentType] ``` This commit implements a `run_sync()` method on `sqlmodel.ext.asyncio.session.AsyncSession`, which casts the callable to the correct type before dispatching it to the base class. This satisfies the static type checks.
1 parent 6c0410e commit 674d759

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

sqlmodel/ext/asyncio/session.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from typing import (
22
Any,
3+
Callable,
4+
Concatenate,
35
Dict,
46
Mapping,
57
Optional,
8+
ParamSpec,
69
Sequence,
710
Type,
811
TypeVar,
@@ -17,6 +20,7 @@
1720
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
1821
from sqlalchemy.ext.asyncio.result import _ensure_sync_result
1922
from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS
23+
from sqlalchemy.orm import Session as _Session
2024
from sqlalchemy.orm._typing import OrmExecuteOptionsParameter
2125
from sqlalchemy.sql.base import Executable as _Executable
2226
from sqlalchemy.util.concurrency import greenlet_spawn
@@ -26,6 +30,7 @@
2630
from ...sql.base import Executable
2731
from ...sql.expression import Select, SelectOfScalar
2832

33+
_P = ParamSpec("_P")
2934
_TSelectParam = TypeVar("_TSelectParam", bound=Any)
3035

3136

@@ -148,3 +153,17 @@ async def execute( # type: ignore
148153
_parent_execute_state=_parent_execute_state,
149154
_add_event=_add_event,
150155
)
156+
157+
async def run_sync(
158+
self,
159+
fn: Callable[Concatenate[Session, _P], _TSelectParam],
160+
*arg: _P.args,
161+
**kw: _P.kwargs,
162+
) -> _TSelectParam:
163+
base_fn = cast(Callable[Concatenate[_Session, _P], _TSelectParam], fn)
164+
165+
return await super().run_sync(
166+
base_fn,
167+
*arg,
168+
**kw,
169+
)

0 commit comments

Comments
 (0)