Skip to content

Commit 20bb869

Browse files
authored
Pin nexusrpc and revert recent nexus updates (#1006)
* Pin nexus * Revert "Install nexusrpc from GitHub (#966)" This reverts commit 808a5f4.
1 parent 3cd7189 commit 20bb869

File tree

11 files changed

+147
-51
lines changed

11 files changed

+147
-51
lines changed

.github/workflows/build-binaries.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,3 @@ jobs:
7474
with:
7575
name: packages-${{ matrix.package-suffix }}
7676
path: dist
77-
78-
- name: Deliberately fail to prevent releasing nexus-rpc w/ GitHub link in pyproject.toml
79-
run: |
80-
echo "This is a deliberate failure to prevent releasing nexus-rpc with a GitHub link in pyproject.toml"
81-
exit 1

pyproject.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ keywords = [
1111
"workflow",
1212
]
1313
dependencies = [
14-
"nexus-rpc>=1.1.0",
14+
"nexus-rpc==1.1.0",
1515
"protobuf>=3.20,<6",
1616
"python-dateutil>=2.8.2,<3 ; python_version < '3.11'",
1717
"types-protobuf>=3.20",
@@ -231,6 +231,3 @@ exclude = [
231231
[tool.uv]
232232
# Prevent uv commands from building the package by default
233233
package = false
234-
235-
[tool.uv.sources]
236-
nexus-rpc = { git = "https://github.com/nexus-rpc/sdk-python.git", rev = "35f574c711193a6e2560d3e6665732a5bb7ae92c" }

temporalio/nexus/_decorators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ async def _start(
123123
return WorkflowRunOperationHandler(_start, input_type, output_type)
124124

125125
method_name = get_callable_name(start)
126-
nexusrpc.set_operation(
126+
nexusrpc.set_operation_definition(
127127
operation_handler_factory,
128128
nexusrpc.Operation(
129129
name=name or method_name,

temporalio/nexus/_util.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
TypeVar,
1414
)
1515

16+
import nexusrpc
1617
from nexusrpc import (
1718
InputT,
1819
OutputT,
@@ -78,10 +79,19 @@ def _get_start_method_input_and_output_type_annotations(
7879
try:
7980
type_annotations = typing.get_type_hints(start)
8081
except TypeError:
82+
warnings.warn(
83+
f"Expected decorated start method {start} to have type annotations"
84+
)
8185
return None, None
8286
output_type = type_annotations.pop("return", None)
8387

8488
if len(type_annotations) != 2:
89+
suffix = f": {type_annotations}" if type_annotations else ""
90+
warnings.warn(
91+
f"Expected decorated start method {start} to have exactly 2 "
92+
f"type-annotated parameters (ctx and input), but it has {len(type_annotations)}"
93+
f"{suffix}."
94+
)
8595
input_type = None
8696
else:
8797
ctx_type, input_type = type_annotations.values()
@@ -108,6 +118,28 @@ def get_callable_name(fn: Callable[..., Any]) -> str:
108118
return method_name
109119

110120

121+
# TODO(nexus-preview) Copied from nexusrpc
122+
def get_operation_factory(
123+
obj: Any,
124+
) -> tuple[
125+
Optional[Callable[[Any], Any]],
126+
Optional[nexusrpc.Operation[Any, Any]],
127+
]:
128+
"""Return the :py:class:`Operation` for the object along with the factory function.
129+
130+
``obj`` should be a decorated operation start method.
131+
"""
132+
op_defn = nexusrpc.get_operation_definition(obj)
133+
if op_defn:
134+
factory = obj
135+
else:
136+
if factory := getattr(obj, "__nexus_operation_factory__", None):
137+
op_defn = nexusrpc.get_operation_definition(factory)
138+
if not isinstance(op_defn, nexusrpc.Operation):
139+
return None, None
140+
return factory, op_defn
141+
142+
111143
# TODO(nexus-preview) Copied from nexusrpc
112144
def set_operation_factory(
113145
obj: Any,

temporalio/worker/_interceptor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,14 +299,15 @@ class StartNexusOperationInput(Generic[InputT, OutputT]):
299299
input: InputT
300300
schedule_to_close_timeout: Optional[timedelta]
301301
headers: Optional[Mapping[str, str]]
302-
output_type: Optional[type[OutputT]] = None
302+
output_type: Optional[Type[OutputT]] = None
303303

304304
def __post_init__(self) -> None:
305305
"""Initialize operation-specific attributes after dataclass creation."""
306306
if isinstance(self.operation, nexusrpc.Operation):
307307
self.output_type = self.operation.output_type
308308
elif callable(self.operation):
309-
if op := nexusrpc.get_operation(self.operation):
309+
_, op = temporalio.nexus._util.get_operation_factory(self.operation)
310+
if isinstance(op, nexusrpc.Operation):
310311
self.output_type = op.output_type
311312
else:
312313
raise ValueError(
@@ -325,7 +326,8 @@ def operation_name(self) -> str:
325326
elif isinstance(self.operation, str):
326327
return self.operation
327328
elif callable(self.operation):
328-
if op := nexusrpc.get_operation(self.operation):
329+
_, op = temporalio.nexus._util.get_operation_factory(self.operation)
330+
if isinstance(op, nexusrpc.Operation):
329331
return op.name
330332
else:
331333
raise ValueError(

temporalio/workflow.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5145,7 +5145,7 @@ async def start_operation(
51455145
operation: nexusrpc.Operation[InputT, OutputT],
51465146
input: InputT,
51475147
*,
5148-
output_type: Optional[type[OutputT]] = None,
5148+
output_type: Optional[Type[OutputT]] = None,
51495149
schedule_to_close_timeout: Optional[timedelta] = None,
51505150
headers: Optional[Mapping[str, str]] = None,
51515151
) -> NexusOperationHandle[OutputT]: ...
@@ -5158,7 +5158,7 @@ async def start_operation(
51585158
operation: str,
51595159
input: Any,
51605160
*,
5161-
output_type: Optional[type[OutputT]] = None,
5161+
output_type: Optional[Type[OutputT]] = None,
51625162
schedule_to_close_timeout: Optional[timedelta] = None,
51635163
headers: Optional[Mapping[str, str]] = None,
51645164
) -> NexusOperationHandle[OutputT]: ...
@@ -5174,7 +5174,7 @@ async def start_operation(
51745174
],
51755175
input: InputT,
51765176
*,
5177-
output_type: Optional[type[OutputT]] = None,
5177+
output_type: Optional[Type[OutputT]] = None,
51785178
schedule_to_close_timeout: Optional[timedelta] = None,
51795179
headers: Optional[Mapping[str, str]] = None,
51805180
) -> NexusOperationHandle[OutputT]: ...
@@ -5190,7 +5190,7 @@ async def start_operation(
51905190
],
51915191
input: InputT,
51925192
*,
5193-
output_type: Optional[type[OutputT]] = None,
5193+
output_type: Optional[Type[OutputT]] = None,
51945194
schedule_to_close_timeout: Optional[timedelta] = None,
51955195
headers: Optional[Mapping[str, str]] = None,
51965196
) -> NexusOperationHandle[OutputT]: ...
@@ -5206,7 +5206,7 @@ async def start_operation(
52065206
],
52075207
input: InputT,
52085208
*,
5209-
output_type: Optional[type[OutputT]] = None,
5209+
output_type: Optional[Type[OutputT]] = None,
52105210
schedule_to_close_timeout: Optional[timedelta] = None,
52115211
headers: Optional[Mapping[str, str]] = None,
52125212
) -> NexusOperationHandle[OutputT]: ...
@@ -5217,7 +5217,7 @@ async def start_operation(
52175217
operation: Any,
52185218
input: Any,
52195219
*,
5220-
output_type: Optional[type[OutputT]] = None,
5220+
output_type: Optional[Type[OutputT]] = None,
52215221
schedule_to_close_timeout: Optional[timedelta] = None,
52225222
headers: Optional[Mapping[str, str]] = None,
52235223
) -> Any:
@@ -5246,7 +5246,7 @@ async def execute_operation(
52465246
operation: nexusrpc.Operation[InputT, OutputT],
52475247
input: InputT,
52485248
*,
5249-
output_type: Optional[type[OutputT]] = None,
5249+
output_type: Optional[Type[OutputT]] = None,
52505250
schedule_to_close_timeout: Optional[timedelta] = None,
52515251
headers: Optional[Mapping[str, str]] = None,
52525252
) -> OutputT: ...
@@ -5259,7 +5259,7 @@ async def execute_operation(
52595259
operation: str,
52605260
input: Any,
52615261
*,
5262-
output_type: Optional[type[OutputT]] = None,
5262+
output_type: Optional[Type[OutputT]] = None,
52635263
schedule_to_close_timeout: Optional[timedelta] = None,
52645264
headers: Optional[Mapping[str, str]] = None,
52655265
) -> OutputT: ...
@@ -5275,7 +5275,7 @@ async def execute_operation(
52755275
],
52765276
input: InputT,
52775277
*,
5278-
output_type: Optional[type[OutputT]] = None,
5278+
output_type: Optional[Type[OutputT]] = None,
52795279
schedule_to_close_timeout: Optional[timedelta] = None,
52805280
headers: Optional[Mapping[str, str]] = None,
52815281
) -> OutputT: ...
@@ -5294,7 +5294,7 @@ async def execute_operation(
52945294
],
52955295
input: InputT,
52965296
*,
5297-
output_type: Optional[type[OutputT]] = None,
5297+
output_type: Optional[Type[OutputT]] = None,
52985298
schedule_to_close_timeout: Optional[timedelta] = None,
52995299
headers: Optional[Mapping[str, str]] = None,
53005300
) -> OutputT: ...
@@ -5310,7 +5310,7 @@ async def execute_operation(
53105310
],
53115311
input: InputT,
53125312
*,
5313-
output_type: Optional[type[OutputT]] = None,
5313+
output_type: Optional[Type[OutputT]] = None,
53145314
schedule_to_close_timeout: Optional[timedelta] = None,
53155315
headers: Optional[Mapping[str, str]] = None,
53165316
) -> OutputT: ...
@@ -5321,7 +5321,7 @@ async def execute_operation(
53215321
operation: Any,
53225322
input: Any,
53235323
*,
5324-
output_type: Optional[type[OutputT]] = None,
5324+
output_type: Optional[Type[OutputT]] = None,
53255325
schedule_to_close_timeout: Optional[timedelta] = None,
53265326
headers: Optional[Mapping[str, str]] = None,
53275327
) -> Any:
@@ -5345,7 +5345,7 @@ def __init__(
53455345
self,
53465346
*,
53475347
endpoint: str,
5348-
service: Union[type[ServiceT], str],
5348+
service: Union[Type[ServiceT], str],
53495349
) -> None:
53505350
"""Create a Nexus client.
53515351
@@ -5372,7 +5372,7 @@ async def start_operation(
53725372
operation: Any,
53735373
input: Any,
53745374
*,
5375-
output_type: Optional[type] = None,
5375+
output_type: Optional[Type] = None,
53765376
schedule_to_close_timeout: Optional[timedelta] = None,
53775377
headers: Optional[Mapping[str, str]] = None,
53785378
) -> Any:
@@ -5393,7 +5393,7 @@ async def execute_operation(
53935393
operation: Any,
53945394
input: Any,
53955395
*,
5396-
output_type: Optional[type] = None,
5396+
output_type: Optional[Type] = None,
53975397
schedule_to_close_timeout: Optional[timedelta] = None,
53985398
headers: Optional[Mapping[str, str]] = None,
53995399
) -> Any:
@@ -5410,7 +5410,7 @@ async def execute_operation(
54105410
@overload
54115411
def create_nexus_client(
54125412
*,
5413-
service: type[ServiceT],
5413+
service: Type[ServiceT],
54145414
endpoint: str,
54155415
) -> NexusClient[ServiceT]: ...
54165416

@@ -5425,9 +5425,9 @@ def create_nexus_client(
54255425

54265426
def create_nexus_client(
54275427
*,
5428-
service: Union[type[ServiceT], str],
5428+
service: Union[Type[ServiceT], str],
54295429
endpoint: str,
5430-
) -> NexusClient[Any]:
5430+
) -> NexusClient[ServiceT]:
54315431
"""Create a Nexus client.
54325432
54335433
.. warning::

tests/nexus/test_dynamic_creation_of_user_handler_classes.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import httpx
44
import nexusrpc.handler
55
import pytest
6+
from nexusrpc.handler import sync_operation
67

78
from temporalio import nexus, workflow
89
from temporalio.client import Client
10+
from temporalio.nexus._util import get_operation_factory
911
from temporalio.testing import WorkflowEnvironment
1012
from temporalio.worker import Worker
1113
from tests.helpers.nexus import ServiceClient, create_nexus_endpoint
@@ -76,8 +78,8 @@ async def test_run_nexus_service_from_programmatically_created_service_handler(
7678
service_handler = nexusrpc.handler._core.ServiceHandler(
7779
service=nexusrpc.ServiceDefinition(
7880
name="MyService",
79-
operation_definitions={
80-
"increment": nexusrpc.OperationDefinition[int, int](
81+
operations={
82+
"increment": nexusrpc.Operation[int, int](
8183
name="increment",
8284
method_name="increment",
8385
input_type=int,
@@ -105,3 +107,70 @@ async def test_run_nexus_service_from_programmatically_created_service_handler(
105107
json=1,
106108
)
107109
assert response.status_code == 201
110+
111+
112+
def make_incrementer_user_service_definition_and_service_handler_classes(
113+
op_names: list[str],
114+
) -> tuple[type, type]:
115+
#
116+
# service contract
117+
#
118+
119+
ops = {name: nexusrpc.Operation[int, int] for name in op_names}
120+
service_cls: type = nexusrpc.service(type("ServiceContract", (), ops))
121+
122+
#
123+
# service handler
124+
#
125+
@sync_operation
126+
async def _increment_op(
127+
self,
128+
ctx: nexusrpc.handler.StartOperationContext,
129+
input: int,
130+
) -> int:
131+
return input + 1
132+
133+
op_handler_factories = {}
134+
for name in op_names:
135+
op_handler_factory, _ = get_operation_factory(_increment_op)
136+
assert op_handler_factory
137+
op_handler_factories[name] = op_handler_factory
138+
139+
handler_cls: type = nexusrpc.handler.service_handler(service=service_cls)(
140+
type("ServiceImpl", (), op_handler_factories)
141+
)
142+
143+
return service_cls, handler_cls
144+
145+
146+
@pytest.mark.skip(
147+
reason="Dynamic creation of service contract using type() is not supported"
148+
)
149+
async def test_dynamic_creation_of_user_handler_classes(
150+
client: Client, env: WorkflowEnvironment
151+
):
152+
task_queue = str(uuid.uuid4())
153+
154+
service_cls, handler_cls = (
155+
make_incrementer_user_service_definition_and_service_handler_classes(
156+
["increment"]
157+
)
158+
)
159+
160+
assert (service_defn := nexusrpc.get_service_definition(service_cls))
161+
service_name = service_defn.name
162+
163+
endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id
164+
async with Worker(
165+
client,
166+
task_queue=task_queue,
167+
nexus_service_handlers=[handler_cls()],
168+
):
169+
server_address = ServiceClient.default_server_address(env)
170+
async with httpx.AsyncClient() as http_client:
171+
response = await http_client.post(
172+
f"http://{server_address}/nexus/endpoints/{endpoint}/services/{service_name}/increment",
173+
json=1,
174+
)
175+
assert response.status_code == 200
176+
assert response.json() == 2

0 commit comments

Comments
 (0)