Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 37 additions & 22 deletions cadence/_internal/activity/_activity_executor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
import inspect
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from traceback import format_exception
from typing import Any, Callable
from google.protobuf.duration import to_timedelta
from google.protobuf.timestamp import to_datetime

from cadence._internal.type_utils import get_fn_parameters
from cadence._internal.activity._context import _Context, _SyncContext
from cadence.activity import ActivityInfo
from cadence.api.v1.common_pb2 import Failure
from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \
RespondActivityTaskCompletedRequest
Expand All @@ -19,35 +21,31 @@ def __init__(self, client: Client, task_list: str, identity: str, max_workers: i
self._data_converter = client.data_converter
self._registry = registry
self._identity = identity
self._task_list = task_list
self._thread_pool = ThreadPoolExecutor(max_workers=max_workers,
thread_name_prefix=f'{task_list}-activity-')

async def execute(self, task: PollForActivityTaskResponse):
activity_type = task.activity_type.name
try:
activity_fn = self._registry(activity_type)
except KeyError as e:
_logger.error("Activity type not found.", extra={'activity_type': activity_type})
await self._report_failure(task, e)
return

await self._execute_fn(activity_fn, task)

async def _execute_fn(self, activity_fn: Callable[[Any], Any], task: PollForActivityTaskResponse):
try:
type_hints = get_fn_parameters(activity_fn)
params = await self._client.data_converter.from_data(task.input, type_hints)
if inspect.iscoroutinefunction(activity_fn):
result = await activity_fn(*params)
else:
result = await self._invoke_sync_activity(activity_fn, params)
context = self._create_context(task)
result = await context.execute(task.input)
await self._report_success(task, result)
except Exception as e:
await self._report_failure(task, e)

async def _invoke_sync_activity(self, activity_fn: Callable[[Any], Any], params: list[Any]) -> Any:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(self._thread_pool, activity_fn, *params)
def _create_context(self, task: PollForActivityTaskResponse) -> _Context:
activity_type = task.activity_type.name
try:
activity_fn = self._registry(activity_type)
except KeyError:
raise KeyError(f"Activity type not found: {activity_type}") from None

info = self._create_info(task)

if inspect.iscoroutinefunction(activity_fn):
return _Context(self._client, info, activity_fn)
else:
return _SyncContext(self._client, info, activity_fn, self._thread_pool)

async def _report_failure(self, task: PollForActivityTaskResponse, error: Exception):
try:
Expand All @@ -71,6 +69,23 @@ async def _report_success(self, task: PollForActivityTaskResponse, result: Any):
except Exception:
_logger.exception('Exception reporting activity complete')

def _create_info(self, task: PollForActivityTaskResponse) -> ActivityInfo:
return ActivityInfo(
task_token=task.task_token,
workflow_type=task.workflow_type.name,
workflow_domain=task.workflow_domain,
workflow_id=task.workflow_execution.workflow_id,
workflow_run_id=task.workflow_execution.run_id,
activity_id=task.activity_id,
activity_type=task.activity_type.name,
task_list=self._task_list,
heartbeat_timeout=to_timedelta(task.heartbeat_timeout),
scheduled_timestamp=to_datetime(task.scheduled_time),
started_timestamp=to_datetime(task.started_time),
start_to_close_timeout=to_timedelta(task.start_to_close_timeout),
attempt=task.attempt,
)


def _to_failure(exception: Exception) -> Failure:
stacktrace = "".join(format_exception(exception))
Expand Down
48 changes: 48 additions & 0 deletions cadence/_internal/activity/_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import asyncio
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Callable, Any

from cadence import Client
from cadence._internal.type_utils import get_fn_parameters
from cadence.activity import ActivityInfo, ActivityContext
from cadence.api.v1.common_pb2 import Payload


class _Context(ActivityContext):
def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any]):
self._client = client
self._info = info
self._activity_fn = activity_fn

async def execute(self, payload: Payload) -> Any:
params = await self._to_params(payload)
with self._activate():
return await self._activity_fn(*params)

async def _to_params(self, payload: Payload) -> list[Any]:
type_hints = get_fn_parameters(self._activity_fn)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: store in self._activity_fn_args_type_hints to avoid unnecessary evaluations.

return await self._client.data_converter.from_data(payload, type_hints)

def client(self) -> Client:
return self._client

def info(self) -> ActivityInfo:
return self._info

class _SyncContext(_Context):
def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any], executor: ThreadPoolExecutor):
super().__init__(client, info, activity_fn)
self._executor = executor

async def execute(self, payload: Payload) -> Any:
params = await self._to_params(payload)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(self._executor, self._run, params)

def _run(self, args: list[Any]) -> Any:
with self._activate():
return self._activity_fn(*args)

def client(self) -> Client:
raise RuntimeError("client is only supported in async activities")

61 changes: 61 additions & 0 deletions cadence/activity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import timedelta, datetime
from typing import Iterator

from cadence import Client


@dataclass(frozen=True)
class ActivityInfo:
task_token: bytes
workflow_type: str
workflow_domain: str
workflow_id: str
workflow_run_id: str
activity_id: str
activity_type: str
task_list: str
heartbeat_timeout: timedelta
scheduled_timestamp: datetime
started_timestamp: datetime
start_to_close_timeout: timedelta
attempt: int

def client() -> Client:
return ActivityContext.get().client()

def in_activity() -> bool:
return ActivityContext.is_set()

def info() -> ActivityInfo:
return ActivityContext.get().info()



class ActivityContext(ABC):
_var: ContextVar['ActivityContext'] = ContextVar("activity")

@abstractmethod
def info(self) -> ActivityInfo:
...

@abstractmethod
def client(self) -> Client:
...

@contextmanager
def _activate(self) -> Iterator[None]:
token = ActivityContext._var.set(self)
yield None
ActivityContext._var.reset(token)

@staticmethod
def is_set() -> bool:
return ActivityContext._var.get(None) is not None

@staticmethod
def get() -> 'ActivityContext':
return ActivityContext._var.get()
91 changes: 79 additions & 12 deletions tests/cadence/_internal/activity/test_activity_executor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import asyncio
from datetime import timedelta, datetime
from unittest.mock import Mock, AsyncMock, PropertyMock

import pytest
from google.protobuf.timestamp_pb2 import Timestamp
from google.protobuf.duration import from_timedelta

from cadence import Client
from cadence import activity, Client
from cadence._internal.activity import ActivityExecutor
from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure
from cadence.activity import ActivityInfo
from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure, WorkflowType
from cadence.api.v1.service_worker_pb2 import RespondActivityTaskCompletedResponse, PollForActivityTaskResponse, \
RespondActivityTaskCompletedRequest, RespondActivityTaskFailedResponse, RespondActivityTaskFailedRequest
from cadence.data_converter import DefaultDataConverter
Expand All @@ -19,7 +23,6 @@ def client() -> Client:
return client


@pytest.mark.asyncio
async def test_activity_async_success(client):
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
Expand All @@ -37,7 +40,6 @@ async def activity_fn():
identity='identity',
))

@pytest.mark.asyncio
async def test_activity_async_failure(client):
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse())
Expand All @@ -64,7 +66,6 @@ async def activity_fn():
identity='identity',
)

@pytest.mark.asyncio
async def test_activity_args(client):
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
Expand All @@ -82,8 +83,6 @@ async def activity_fn(first: str, second: str):
identity='identity',
))


@pytest.mark.asyncio
async def test_activity_sync_success(client):
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
Expand All @@ -105,7 +104,6 @@ def activity_fn():
identity='identity',
))

@pytest.mark.asyncio
async def test_activity_sync_failure(client):
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse())
Expand All @@ -132,7 +130,6 @@ def activity_fn():
identity='identity',
)

@pytest.mark.asyncio
async def test_activity_unknown(client):
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse())
Expand All @@ -148,7 +145,7 @@ def registry(name: str):

call = worker_stub.RespondActivityTaskFailed.call_args[0][0]

assert 'unknown activity: any' in call.failure.details.decode()
assert 'Activity type not found: any' in call.failure.details.decode()
call.failure.details = bytes()
assert call == RespondActivityTaskFailedRequest(
task_token=b'task_token',
Expand All @@ -158,15 +155,85 @@ def registry(name: str):
identity='identity',
)

async def test_activity_context(client):
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())

async def activity_fn():
assert fake_info("activity_type") == activity.info()
assert activity.in_activity()
assert activity.client() is not None
return "success"

executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn)

await executor.execute(fake_task("activity_type", ""))

worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest(
task_token=b'task_token',
result=Payload(data='"success"'.encode()),
identity='identity',
))

async def test_activity_context_sync(client):
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())

def activity_fn():
assert fake_info("activity_type") == activity.info()
assert activity.in_activity()
with pytest.raises(RuntimeError):
activity.client()
return "success"

executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn)

await executor.execute(fake_task("activity_type", ""))

worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest(
task_token=b'task_token',
result=Payload(data='"success"'.encode()),
identity='identity',
))


def fake_info(activity_type: str) -> ActivityInfo:
return ActivityInfo(
task_token=b'task_token',
workflow_domain="workflow_domain",
workflow_id="workflow_id",
workflow_run_id="run_id",
activity_id="activity_id",
activity_type=activity_type,
attempt=1,
workflow_type="workflow_type",
task_list="task_list",
heartbeat_timeout=timedelta(seconds=1),
scheduled_timestamp=datetime(2020, 1, 2 ,3),
started_timestamp=datetime(2020, 1, 2 ,4),
start_to_close_timeout=timedelta(seconds=2),
)

def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskResponse:
return PollForActivityTaskResponse(
task_token=b'task_token',
workflow_domain="workflow_domain",
workflow_type=WorkflowType(name="workflow_type"),
workflow_execution=WorkflowExecution(
workflow_id="workflow_id",
run_id="run_id",
),
activity_id="activity_id",
activity_type=ActivityType(name=activity_type),
input=Payload(data=input_json.encode()),
attempt=0,
)
attempt=1,
heartbeat_timeout=from_timedelta(timedelta(seconds=1)),
scheduled_time=from_datetime(datetime(2020, 1, 2, 3)),
started_time=from_datetime(datetime(2020, 1, 2, 4)),
start_to_close_timeout=from_timedelta(timedelta(seconds=2)),
)

def from_datetime(time: datetime) -> Timestamp:
t = Timestamp()
t.FromDatetime(time)
return t