Skip to content

Commit c9866a0

Browse files
committed
Add ActivityContext
User code needs to be able to retrieve information about the running activity, and in the future will be able to interact with it (heartbeating). Create the initial structure for providing it.
1 parent a5a257c commit c9866a0

File tree

4 files changed

+225
-34
lines changed

4 files changed

+225
-34
lines changed

cadence/_internal/activity/_activity_executor.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
import asyncio
21
import inspect
32
from concurrent.futures import ThreadPoolExecutor
43
from logging import getLogger
54
from traceback import format_exception
65
from typing import Any, Callable
6+
from google.protobuf.duration import to_timedelta
7+
from google.protobuf.timestamp import to_datetime
78

8-
from cadence._internal.type_utils import get_fn_parameters
9+
from cadence._internal.activity._context import _Context, _SyncContext
10+
from cadence.activity import ActivityInfo
911
from cadence.api.v1.common_pb2 import Failure
1012
from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \
1113
RespondActivityTaskCompletedRequest
@@ -19,35 +21,31 @@ def __init__(self, client: Client, task_list: str, identity: str, max_workers: i
1921
self._data_converter = client.data_converter
2022
self._registry = registry
2123
self._identity = identity
24+
self._task_list = task_list
2225
self._thread_pool = ThreadPoolExecutor(max_workers=max_workers,
2326
thread_name_prefix=f'{task_list}-activity-')
2427

2528
async def execute(self, task: PollForActivityTaskResponse):
26-
activity_type = task.activity_type.name
27-
try:
28-
activity_fn = self._registry(activity_type)
29-
except KeyError as e:
30-
_logger.error("Activity type not found.", extra={'activity_type': activity_type})
31-
await self._report_failure(task, e)
32-
return
33-
34-
await self._execute_fn(activity_fn, task)
35-
36-
async def _execute_fn(self, activity_fn: Callable[[Any], Any], task: PollForActivityTaskResponse):
3729
try:
38-
type_hints = get_fn_parameters(activity_fn)
39-
params = await self._client.data_converter.from_data(task.input, type_hints)
40-
if inspect.iscoroutinefunction(activity_fn):
41-
result = await activity_fn(*params)
42-
else:
43-
result = await self._invoke_sync_activity(activity_fn, params)
30+
context = self._create_context(task)
31+
result = await context.execute(task.input)
4432
await self._report_success(task, result)
4533
except Exception as e:
4634
await self._report_failure(task, e)
4735

48-
async def _invoke_sync_activity(self, activity_fn: Callable[[Any], Any], params: list[Any]) -> Any:
49-
loop = asyncio.get_running_loop()
50-
return await loop.run_in_executor(self._thread_pool, activity_fn, *params)
36+
def _create_context(self, task: PollForActivityTaskResponse) -> _Context:
37+
activity_type = task.activity_type.name
38+
try:
39+
activity_fn = self._registry(activity_type)
40+
except KeyError:
41+
raise KeyError(f"Activity type not found: {activity_type}") from None
42+
43+
info = self._create_info(task)
44+
45+
if inspect.iscoroutinefunction(activity_fn):
46+
return _Context(self._client, info, activity_fn)
47+
else:
48+
return _SyncContext(self._client, info, activity_fn, self._thread_pool)
5149

5250
async def _report_failure(self, task: PollForActivityTaskResponse, error: Exception):
5351
try:
@@ -71,6 +69,23 @@ async def _report_success(self, task: PollForActivityTaskResponse, result: Any):
7169
except Exception:
7270
_logger.exception('Exception reporting activity complete')
7371

72+
def _create_info(self, task: PollForActivityTaskResponse) -> ActivityInfo:
73+
return ActivityInfo(
74+
task_token=task.task_token,
75+
workflow_type=task.workflow_type.name,
76+
workflow_domain=task.workflow_domain,
77+
workflow_id=task.workflow_execution.workflow_id,
78+
workflow_run_id=task.workflow_execution.run_id,
79+
activity_id=task.activity_id,
80+
activity_type=task.activity_type.name,
81+
task_list=self._task_list,
82+
heartbeat_timeout=to_timedelta(task.heartbeat_timeout),
83+
scheduled_timestamp=to_datetime(task.scheduled_time),
84+
started_timestamp=to_datetime(task.started_time),
85+
start_to_close_timeout=to_timedelta(task.start_to_close_timeout),
86+
attempt=task.attempt,
87+
)
88+
7489

7590
def _to_failure(exception: Exception) -> Failure:
7691
stacktrace = "".join(format_exception(exception))
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import asyncio
2+
from concurrent.futures.thread import ThreadPoolExecutor
3+
from typing import Callable, Any
4+
5+
from cadence import Client
6+
from cadence._internal.type_utils import get_fn_parameters
7+
from cadence.activity import ActivityInfo, ActivityContext
8+
from cadence.api.v1.common_pb2 import Payload
9+
10+
11+
class _Context(ActivityContext):
12+
def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any]):
13+
self._client = client
14+
self._info = info
15+
self._activity_fn = activity_fn
16+
17+
async def execute(self, payload: Payload) -> Any:
18+
params = await self._to_params(payload)
19+
with self._activate():
20+
return await self._activity_fn(*params)
21+
22+
async def _to_params(self, payload: Payload) -> list[Any]:
23+
type_hints = get_fn_parameters(self._activity_fn)
24+
return await self._client.data_converter.from_data(payload, type_hints)
25+
26+
def client(self) -> Client:
27+
return self._client
28+
29+
def info(self) -> ActivityInfo:
30+
return self._info
31+
32+
class _SyncContext(_Context):
33+
def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any], executor: ThreadPoolExecutor):
34+
super().__init__(client, info, activity_fn)
35+
self._executor = executor
36+
37+
async def execute(self, payload: Payload) -> Any:
38+
params = await self._to_params(payload)
39+
loop = asyncio.get_running_loop()
40+
return await loop.run_in_executor(self._executor, self._run, params)
41+
42+
def _run(self, args: list[Any]) -> Any:
43+
with self._activate():
44+
return self._activity_fn(*args)
45+
46+
def client(self) -> Client:
47+
raise RuntimeError("client is only supported in async activities")
48+

cadence/activity.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from abc import ABC, abstractmethod
2+
from contextlib import contextmanager
3+
from contextvars import ContextVar
4+
from dataclasses import dataclass
5+
from datetime import timedelta, datetime
6+
from typing import Iterator
7+
8+
from cadence import Client
9+
10+
11+
@dataclass(frozen=True)
12+
class ActivityInfo:
13+
task_token: bytes
14+
workflow_type: str
15+
workflow_domain: str
16+
workflow_id: str
17+
workflow_run_id: str
18+
activity_id: str
19+
activity_type: str
20+
task_list: str
21+
heartbeat_timeout: timedelta
22+
scheduled_timestamp: datetime
23+
started_timestamp: datetime
24+
start_to_close_timeout: timedelta
25+
attempt: int
26+
27+
def client() -> Client:
28+
return ActivityContext.get().client()
29+
30+
def in_activity() -> bool:
31+
return ActivityContext.is_set()
32+
33+
def info() -> ActivityInfo:
34+
return ActivityContext.get().info()
35+
36+
37+
38+
class ActivityContext(ABC):
39+
_var: ContextVar['ActivityContext'] = ContextVar("activity")
40+
41+
@abstractmethod
42+
def info(self) -> ActivityInfo:
43+
...
44+
45+
@abstractmethod
46+
def client(self) -> Client:
47+
...
48+
49+
@contextmanager
50+
def _activate(self) -> Iterator[None]:
51+
token = ActivityContext._var.set(self)
52+
yield None
53+
ActivityContext._var.reset(token)
54+
55+
@staticmethod
56+
def is_set() -> bool:
57+
return ActivityContext._var.get(None) is not None
58+
59+
@staticmethod
60+
def get() -> 'ActivityContext':
61+
return ActivityContext._var.get()

tests/cadence/_internal/activity/test_activity_executor.py

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import asyncio
2+
from datetime import timedelta, datetime
23
from unittest.mock import Mock, AsyncMock, PropertyMock
34

45
import pytest
6+
from google.protobuf.timestamp_pb2 import Timestamp
7+
from google.protobuf.duration import from_timedelta
58

6-
from cadence import Client
9+
from cadence import activity, Client
710
from cadence._internal.activity import ActivityExecutor
8-
from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure
11+
from cadence.activity import ActivityInfo
12+
from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure, WorkflowType
913
from cadence.api.v1.service_worker_pb2 import RespondActivityTaskCompletedResponse, PollForActivityTaskResponse, \
1014
RespondActivityTaskCompletedRequest, RespondActivityTaskFailedResponse, RespondActivityTaskFailedRequest
1115
from cadence.data_converter import DefaultDataConverter
@@ -19,7 +23,6 @@ def client() -> Client:
1923
return client
2024

2125

22-
@pytest.mark.asyncio
2326
async def test_activity_async_success(client):
2427
worker_stub = client.worker_stub
2528
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
@@ -37,7 +40,6 @@ async def activity_fn():
3740
identity='identity',
3841
))
3942

40-
@pytest.mark.asyncio
4143
async def test_activity_async_failure(client):
4244
worker_stub = client.worker_stub
4345
worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse())
@@ -64,7 +66,6 @@ async def activity_fn():
6466
identity='identity',
6567
)
6668

67-
@pytest.mark.asyncio
6869
async def test_activity_args(client):
6970
worker_stub = client.worker_stub
7071
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
@@ -82,8 +83,6 @@ async def activity_fn(first: str, second: str):
8283
identity='identity',
8384
))
8485

85-
86-
@pytest.mark.asyncio
8786
async def test_activity_sync_success(client):
8887
worker_stub = client.worker_stub
8988
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
@@ -105,7 +104,6 @@ def activity_fn():
105104
identity='identity',
106105
))
107106

108-
@pytest.mark.asyncio
109107
async def test_activity_sync_failure(client):
110108
worker_stub = client.worker_stub
111109
worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse())
@@ -132,7 +130,6 @@ def activity_fn():
132130
identity='identity',
133131
)
134132

135-
@pytest.mark.asyncio
136133
async def test_activity_unknown(client):
137134
worker_stub = client.worker_stub
138135
worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse())
@@ -148,7 +145,7 @@ def registry(name: str):
148145

149146
call = worker_stub.RespondActivityTaskFailed.call_args[0][0]
150147

151-
assert 'unknown activity: any' in call.failure.details.decode()
148+
assert 'Activity type not found: any' in call.failure.details.decode()
152149
call.failure.details = bytes()
153150
assert call == RespondActivityTaskFailedRequest(
154151
task_token=b'task_token',
@@ -158,15 +155,85 @@ def registry(name: str):
158155
identity='identity',
159156
)
160157

158+
async def test_activity_context(client):
159+
worker_stub = client.worker_stub
160+
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
161+
162+
async def activity_fn():
163+
assert fake_info("activity_type") == activity.info()
164+
assert activity.in_activity()
165+
assert activity.client() is not None
166+
return "success"
167+
168+
executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn)
169+
170+
await executor.execute(fake_task("activity_type", ""))
171+
172+
worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest(
173+
task_token=b'task_token',
174+
result=Payload(data='"success"'.encode()),
175+
identity='identity',
176+
))
177+
178+
async def test_activity_context_sync(client):
179+
worker_stub = client.worker_stub
180+
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
181+
182+
def activity_fn():
183+
assert fake_info("activity_type") == activity.info()
184+
assert activity.in_activity()
185+
with pytest.raises(RuntimeError):
186+
activity.client()
187+
return "success"
188+
189+
executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn)
190+
191+
await executor.execute(fake_task("activity_type", ""))
192+
193+
worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest(
194+
task_token=b'task_token',
195+
result=Payload(data='"success"'.encode()),
196+
identity='identity',
197+
))
198+
199+
200+
def fake_info(activity_type: str) -> ActivityInfo:
201+
return ActivityInfo(
202+
task_token=b'task_token',
203+
workflow_domain="workflow_domain",
204+
workflow_id="workflow_id",
205+
workflow_run_id="run_id",
206+
activity_id="activity_id",
207+
activity_type=activity_type,
208+
attempt=1,
209+
workflow_type="workflow_type",
210+
task_list="task_list",
211+
heartbeat_timeout=timedelta(seconds=1),
212+
scheduled_timestamp=datetime(2020, 1, 2 ,3),
213+
started_timestamp=datetime(2020, 1, 2 ,4),
214+
start_to_close_timeout=timedelta(seconds=2),
215+
)
216+
161217
def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskResponse:
162218
return PollForActivityTaskResponse(
163219
task_token=b'task_token',
220+
workflow_domain="workflow_domain",
221+
workflow_type=WorkflowType(name="workflow_type"),
164222
workflow_execution=WorkflowExecution(
165223
workflow_id="workflow_id",
166224
run_id="run_id",
167225
),
168226
activity_id="activity_id",
169227
activity_type=ActivityType(name=activity_type),
170228
input=Payload(data=input_json.encode()),
171-
attempt=0,
172-
)
229+
attempt=1,
230+
heartbeat_timeout=from_timedelta(timedelta(seconds=1)),
231+
scheduled_time=from_datetime(datetime(2020, 1, 2, 3)),
232+
started_time=from_datetime(datetime(2020, 1, 2, 4)),
233+
start_to_close_timeout=from_timedelta(timedelta(seconds=2)),
234+
)
235+
236+
def from_datetime(time: datetime) -> Timestamp:
237+
t = Timestamp()
238+
t.FromDatetime(time)
239+
return t

0 commit comments

Comments
 (0)