1
1
import asyncio
2
+ from datetime import timedelta , datetime
2
3
from unittest .mock import Mock , AsyncMock , PropertyMock
3
4
4
5
import pytest
6
+ from google .protobuf .timestamp_pb2 import Timestamp
7
+ from google .protobuf .duration import from_timedelta
5
8
6
- from cadence import Client
9
+ from cadence import activity , Client
7
10
from cadence ._internal .activity import ActivityExecutor
8
- from cadence .api .v1 .common_pb2 import WorkflowExecution , ActivityType , Payload , Failure
11
+ from cadence .activity import ActivityInfo
12
+ from cadence .api .v1 .common_pb2 import WorkflowExecution , ActivityType , Payload , Failure , WorkflowType
9
13
from cadence .api .v1 .service_worker_pb2 import RespondActivityTaskCompletedResponse , PollForActivityTaskResponse , \
10
14
RespondActivityTaskCompletedRequest , RespondActivityTaskFailedResponse , RespondActivityTaskFailedRequest
11
15
from cadence .data_converter import DefaultDataConverter
@@ -19,7 +23,6 @@ def client() -> Client:
19
23
return client
20
24
21
25
22
- @pytest .mark .asyncio
23
26
async def test_activity_async_success (client ):
24
27
worker_stub = client .worker_stub
25
28
worker_stub .RespondActivityTaskCompleted = AsyncMock (return_value = RespondActivityTaskCompletedResponse ())
@@ -37,7 +40,6 @@ async def activity_fn():
37
40
identity = 'identity' ,
38
41
))
39
42
40
- @pytest .mark .asyncio
41
43
async def test_activity_async_failure (client ):
42
44
worker_stub = client .worker_stub
43
45
worker_stub .RespondActivityTaskFailed = AsyncMock (return_value = RespondActivityTaskFailedResponse ())
@@ -64,7 +66,6 @@ async def activity_fn():
64
66
identity = 'identity' ,
65
67
)
66
68
67
- @pytest .mark .asyncio
68
69
async def test_activity_args (client ):
69
70
worker_stub = client .worker_stub
70
71
worker_stub .RespondActivityTaskCompleted = AsyncMock (return_value = RespondActivityTaskCompletedResponse ())
@@ -82,8 +83,6 @@ async def activity_fn(first: str, second: str):
82
83
identity = 'identity' ,
83
84
))
84
85
85
-
86
- @pytest .mark .asyncio
87
86
async def test_activity_sync_success (client ):
88
87
worker_stub = client .worker_stub
89
88
worker_stub .RespondActivityTaskCompleted = AsyncMock (return_value = RespondActivityTaskCompletedResponse ())
@@ -105,7 +104,6 @@ def activity_fn():
105
104
identity = 'identity' ,
106
105
))
107
106
108
- @pytest .mark .asyncio
109
107
async def test_activity_sync_failure (client ):
110
108
worker_stub = client .worker_stub
111
109
worker_stub .RespondActivityTaskFailed = AsyncMock (return_value = RespondActivityTaskFailedResponse ())
@@ -132,7 +130,6 @@ def activity_fn():
132
130
identity = 'identity' ,
133
131
)
134
132
135
- @pytest .mark .asyncio
136
133
async def test_activity_unknown (client ):
137
134
worker_stub = client .worker_stub
138
135
worker_stub .RespondActivityTaskFailed = AsyncMock (return_value = RespondActivityTaskFailedResponse ())
@@ -148,7 +145,7 @@ def registry(name: str):
148
145
149
146
call = worker_stub .RespondActivityTaskFailed .call_args [0 ][0 ]
150
147
151
- assert 'unknown activity : any' in call .failure .details .decode ()
148
+ assert 'Activity type not found : any' in call .failure .details .decode ()
152
149
call .failure .details = bytes ()
153
150
assert call == RespondActivityTaskFailedRequest (
154
151
task_token = b'task_token' ,
@@ -158,15 +155,85 @@ def registry(name: str):
158
155
identity = 'identity' ,
159
156
)
160
157
158
+ async def test_activity_context (client ):
159
+ worker_stub = client .worker_stub
160
+ worker_stub .RespondActivityTaskCompleted = AsyncMock (return_value = RespondActivityTaskCompletedResponse ())
161
+
162
+ async def activity_fn ():
163
+ assert fake_info ("activity_type" ) == activity .info ()
164
+ assert activity .in_activity ()
165
+ assert activity .client () is not None
166
+ return "success"
167
+
168
+ executor = ActivityExecutor (client , 'task_list' , 'identity' , 1 , lambda name : activity_fn )
169
+
170
+ await executor .execute (fake_task ("activity_type" , "" ))
171
+
172
+ worker_stub .RespondActivityTaskCompleted .assert_called_once_with (RespondActivityTaskCompletedRequest (
173
+ task_token = b'task_token' ,
174
+ result = Payload (data = '"success"' .encode ()),
175
+ identity = 'identity' ,
176
+ ))
177
+
178
+ async def test_activity_context_sync (client ):
179
+ worker_stub = client .worker_stub
180
+ worker_stub .RespondActivityTaskCompleted = AsyncMock (return_value = RespondActivityTaskCompletedResponse ())
181
+
182
+ def activity_fn ():
183
+ assert fake_info ("activity_type" ) == activity .info ()
184
+ assert activity .in_activity ()
185
+ with pytest .raises (RuntimeError ):
186
+ activity .client ()
187
+ return "success"
188
+
189
+ executor = ActivityExecutor (client , 'task_list' , 'identity' , 1 , lambda name : activity_fn )
190
+
191
+ await executor .execute (fake_task ("activity_type" , "" ))
192
+
193
+ worker_stub .RespondActivityTaskCompleted .assert_called_once_with (RespondActivityTaskCompletedRequest (
194
+ task_token = b'task_token' ,
195
+ result = Payload (data = '"success"' .encode ()),
196
+ identity = 'identity' ,
197
+ ))
198
+
199
+
200
+ def fake_info (activity_type : str ) -> ActivityInfo :
201
+ return ActivityInfo (
202
+ task_token = b'task_token' ,
203
+ workflow_domain = "workflow_domain" ,
204
+ workflow_id = "workflow_id" ,
205
+ workflow_run_id = "run_id" ,
206
+ activity_id = "activity_id" ,
207
+ activity_type = activity_type ,
208
+ attempt = 1 ,
209
+ workflow_type = "workflow_type" ,
210
+ task_list = "task_list" ,
211
+ heartbeat_timeout = timedelta (seconds = 1 ),
212
+ scheduled_timestamp = datetime (2020 , 1 , 2 ,3 ),
213
+ started_timestamp = datetime (2020 , 1 , 2 ,4 ),
214
+ start_to_close_timeout = timedelta (seconds = 2 ),
215
+ )
216
+
161
217
def fake_task (activity_type : str , input_json : str ) -> PollForActivityTaskResponse :
162
218
return PollForActivityTaskResponse (
163
219
task_token = b'task_token' ,
220
+ workflow_domain = "workflow_domain" ,
221
+ workflow_type = WorkflowType (name = "workflow_type" ),
164
222
workflow_execution = WorkflowExecution (
165
223
workflow_id = "workflow_id" ,
166
224
run_id = "run_id" ,
167
225
),
168
226
activity_id = "activity_id" ,
169
227
activity_type = ActivityType (name = activity_type ),
170
228
input = Payload (data = input_json .encode ()),
171
- attempt = 0 ,
172
- )
229
+ attempt = 1 ,
230
+ heartbeat_timeout = from_timedelta (timedelta (seconds = 1 )),
231
+ scheduled_time = from_datetime (datetime (2020 , 1 , 2 , 3 )),
232
+ started_time = from_datetime (datetime (2020 , 1 , 2 , 4 )),
233
+ start_to_close_timeout = from_timedelta (timedelta (seconds = 2 )),
234
+ )
235
+
236
+ def from_datetime (time : datetime ) -> Timestamp :
237
+ t = Timestamp ()
238
+ t .FromDatetime (time )
239
+ return t
0 commit comments