diff --git a/channels_graphql_ws/graphql_ws_consumer.py b/channels_graphql_ws/graphql_ws_consumer.py index 575ad58..37ac5b5 100644 --- a/channels_graphql_ws/graphql_ws_consumer.py +++ b/channels_graphql_ws/graphql_ws_consumer.py @@ -58,7 +58,7 @@ import promise import rx -from .scope_as_context import ScopeAsContext +from .operation_context import OperationContext from .serializer import Serializer # Module logger. @@ -547,7 +547,7 @@ async def _on_gql_start(self, operation_id, payload): # Create object-like context (like in `Query` or `Mutation`) # from the dict-like one provided by the Channels. - context = ScopeAsContext(self.scope) + context = OperationContext(self.scope) # Adding channel name to the context because it seems to be # useful for some use cases, take a loot at the issue from @@ -671,7 +671,12 @@ def register_middleware(next_middleware, root, info, *args, **kwds): await self._send_gql_complete(operation_id) async def _register_subscription( - self, operation_id, groups, publish_callback, unsubscribed_callback + self, + operation_id, + groups, + publish_callback, + unsubscribed_callback, + initial_payload, ): """Register a new subscription when client subscribes. @@ -701,61 +706,74 @@ async def _register_subscription( # `_sids_by_group` without any locks. self._assert_thread() - # The subject we will trigger on the `broadcast` message. - trigger = rx.subjects.Subject() - # The subscription notification queue. notification_queue = asyncio.Queue( maxsize=self.subscription_notification_queue_limit ) + # Enqueue the initial payload. + if initial_payload is not self.SKIP: + notification_queue.put_nowait(Serializer.serialize(initial_payload)) + # Start an endless task which listens the `notification_queue` # and invokes subscription "resolver" on new notifications. - async def notifier(): + async def notifier(observer: rx.Observer): """Watch the notification queue and notify clients.""" # Assert we run in a proper thread. 
self._assert_thread() while True: - payload = await notification_queue.get() + serialized_payload = await notification_queue.get() + # Run a subscription's `publish` method (invoked by the - # `trigger.on_next` function) within the threadpool used + # `observer.on_next` function) within the threadpool used # for processing other GraphQL resolver functions. - # NOTE: `lambda` is important to run the deserialization + # NOTE: it is important to run the deserialization # in the worker thread as well. - await self._run_in_worker( - lambda: trigger.on_next(Serializer.deserialize(payload)) - ) + def workload(): + try: + payload = Serializer.deserialize(serialized_payload) + except Exception as ex: # pylint: disable=broad-except + observer.on_error(f"Cannot deserialize payload. {ex}") + else: + observer.on_next(payload) + + await self._run_in_worker(workload) + # Message processed. This allows `Queue.join` to work. notification_queue.task_done() - # Enqueue the `publish` method execution. But do not notify - # clients when `publish` returns `SKIP`. - stream = trigger.map(publish_callback).filter( # pylint: disable=no-member - lambda publish_returned: publish_returned is not self.SKIP - ) + def push_payloads(observer: rx.Observer): + # Start listening for broadcasts (subscribe to the Channels + # groups), spawn the notification processing task and put + # subscription information into the registry. + # NOTE: Update of `_sids_by_group` & `_subscriptions` must be + # atomic i.e. without `awaits` in between. 
+ for group in groups: + self._sids_by_group.setdefault(group, []).append(operation_id) + notifier_task = self._spawn_background_task(notifier(observer)) + self._subscriptions[operation_id] = self._SubInf( + groups=groups, + sid=operation_id, + unsubscribed_callback=unsubscribed_callback, + notification_queue=notification_queue, + notifier_task=notifier_task, + ) - # Start listening for broadcasts (subscribe to the Channels - # groups), spawn the notification processing task and put - # subscription information into the registry. - # NOTE: Update of `_sids_by_group` & `_subscriptions` must be - # atomic i.e. without `awaits` in between. - waitlist = [] - for group in groups: - self._sids_by_group.setdefault(group, []).append(operation_id) - waitlist.append(self._channel_layer.group_add(group, self.channel_name)) - notifier_task = self._spawn_background_task(notifier()) - self._subscriptions[operation_id] = self._SubInf( - groups=groups, - sid=operation_id, - unsubscribed_callback=unsubscribed_callback, - notification_queue=notification_queue, - notifier_task=notifier_task, + await asyncio.wait( + [ + self._channel_layer.group_add(group, self.channel_name) + for group in groups + ] ) - await asyncio.wait(waitlist) - - return stream + # Enqueue the `publish` method execution. But do not notify + # clients when `publish` returns `SKIP`. + return ( + rx.Observable.create(push_payloads) # pylint: disable=no-member + .map(publish_callback) + .filter(lambda publish_returned: publish_returned is not self.SKIP) + ) async def _on_gql_stop(self, operation_id): """Process the STOP message. 
diff --git a/channels_graphql_ws/operation_context.py b/channels_graphql_ws/operation_context.py new file mode 100644 index 0000000..f4e253a --- /dev/null +++ b/channels_graphql_ws/operation_context.py @@ -0,0 +1,33 @@ +"""Just `OperationContext` class.""" + +from channels_graphql_ws.scope_as_context import ScopeAsContext + + +class OperationContext(ScopeAsContext): +    """ +    The context intended to be used in methods of Graphene classes as `info.context`. + +    This class provides two public properties: +    1. `scope` - per-connection context. This is the `scope` of Django Channels. +    2. `operation_context` - per-operation context. Empty. Feel free to store your +    data here. + +    For backward compatibility: +    - Method `_asdict` returns the `scope`. +    - Other attributes are routed to the `scope`. +    """ + +    def __init__(self, scope: dict): +        """Nothing interesting here.""" +        super().__init__(scope) +        self._operation_context: dict = {} + +    @property +    def scope(self) -> dict: +        """Return the scope.""" +        return self._scope + +    @property +    def operation_context(self) -> dict: +        """Return the per-operation context.""" +        return self._operation_context diff --git a/channels_graphql_ws/scope_as_context.py b/channels_graphql_ws/scope_as_context.py index 202a1dc..036bc60 100644 --- a/channels_graphql_ws/scope_as_context.py +++ b/channels_graphql_ws/scope_as_context.py @@ -25,7 +25,7 @@ class ScopeAsContext: """Wrapper to make Channels `scope` appear as an `info.context`.""" -    def __init__(self, scope): +    def __init__(self, scope: dict): """Remember given `scope`.""" self._scope = scope diff --git a/channels_graphql_ws/subscription.py b/channels_graphql_ws/subscription.py index cc95756..71941d4 100644 --- a/channels_graphql_ws/subscription.py +++ b/channels_graphql_ws/subscription.py @@ -356,6 +356,7 @@ def __init_subclass_with_meta__( _meta.subscribe = get_function(subscribe) _meta.publish = get_function(publish) _meta.unsubscribed = get_function(unsubscribed) +    _meta.initial_payload
= options.get("initial_payload", cls.SKIP) super().__init_subclass_with_meta__(_meta=_meta, **options) @@ -422,7 +423,9 @@ def unsubscribed_callback(): # `subscribe`. return result - return register_subscription(groups, publish_callback, unsubscribed_callback) + return register_subscription( + groups, publish_callback, unsubscribed_callback, cls._meta.initial_payload + ) @classmethod def _group_name(cls, group=None): diff --git a/tests/test_concurrent.py b/tests/test_concurrent.py index e6a2c15..ee845f4 100644 --- a/tests/test_concurrent.py +++ b/tests/test_concurrent.py @@ -732,9 +732,9 @@ async def test_message_order_in_subscribe_unsubscribe_all_loop( 'complete' message. """ - NUMBER_OF_UNSUBSCRIBE_CALLS = 50 # pylint: disable=invalid-name + NUMBER_OF_UNSUBSCRIBE_CALLS = 100 # pylint: disable=invalid-name # Delay in seconds. - DELAY_BETWEEN_UNSUBSCRIBE_CALLS = 0.01 # pylint: disable=invalid-name + DELAY_BETWEEN_UNSUBSCRIBE_CALLS = 0.02 # pylint: disable=invalid-name # Gradually stop the test if time is up. TIME_BORDER = 20 # pylint: disable=invalid-name