Skip to content

Commit 6894615

Browse files
author
Andrey Zelenchuk
committed
Fix losing initial payload because of the race.
1 parent 3c5db89 commit 6894615

File tree

1 file changed

+19
-26
lines changed

1 file changed

+19
-26
lines changed

channels_graphql_ws/graphql_ws_consumer.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,6 @@ async def _register_subscription(
700700
# `_sids_by_group` without any locks.
701701
self._assert_thread()
702702

703-
# The subject we will trigger on the `broadcast` message.
704-
trigger = rx.subjects.Subject()
705-
706703
# The subscription notification queue.
707704
notification_queue = asyncio.Queue(
708705
maxsize=self.subscription_notification_queue_limit
@@ -714,56 +711,41 @@ async def _register_subscription(
714711

715712
# Start an endless task which listens the `notification_queue`
716713
# and invokes subscription "resolver" on new notifications.
717-
async def notifier():
714+
async def notifier(observer: rx.Observer):
718715
"""Watch the notification queue and notify clients."""
719716

720717
# Assert we run in a proper thread.
721718
self._assert_thread()
722-
723-
# Dirty hack to partially workaround the race between:
724-
# 1) call to `result.subscribe` in `_on_gql_start`; and
725-
# 2) call to `trigger.on_next` below in this function.
726-
# The first call must be earlier. Otherwise, first one or more notifications
727-
# may be lost.
728-
await asyncio.sleep(1)
729-
730719
while True:
731720
serialized_payload = await notification_queue.get()
732721

733722
# Run a subscription's `publish` method (invoked by the
734-
# `trigger.on_next` function) within the threadpool used
723+
# `observer.on_next` function) within the threadpool used
735724
# for processing other GraphQL resolver functions.
736725
# NOTE: it is important to run the deserialization
737726
# in the worker thread as well.
738727
def workload():
739728
try:
740729
payload = Serializer.deserialize(serialized_payload)
741730
except Exception as ex: # pylint: disable=broad-except
742-
trigger.on_error(f"Cannot deserialize payload. {ex}")
731+
observer.on_error(f"Cannot deserialize payload. {ex}")
743732
else:
744-
trigger.on_next(payload)
733+
observer.on_next(payload)
745734

746735
await self._run_in_worker(workload)
747736

748737
# Message processed. This allows `Queue.join` to work.
749738
notification_queue.task_done()
750739

751-
# Enqueue the `publish` method execution. But do not notify
752-
# clients when `publish` returns `SKIP`.
753-
stream = trigger.map(publish_callback).filter( # pylint: disable=no-member
754-
lambda publish_returned: publish_returned is not self.SKIP
755-
)
756-
740+
def push_payloads(observer: rx.Observer):
757741
# Start listening for broadcasts (subscribe to the Channels
758742
# groups), spawn the notification processing task and put
759743
# subscription information into the registry.
760744
# NOTE: Update of `_sids_by_group` & `_subscriptions` must be
761745
# atomic i.e. without `awaits` in between.
762-
waitlist = []
763746
for group in groups:
764747
self._sids_by_group.setdefault(group, []).append(operation_id)
765-
waitlist.append(self._channel_layer.group_add(group, self.channel_name))
766-
notifier_task = self._spawn_background_task(notifier())
748+
notifier_task = self._spawn_background_task(notifier(observer))
767749
self._subscriptions[operation_id] = self._SubInf(
768750
groups=groups,
769751
sid=operation_id,
@@ -772,9 +754,20 @@ def workload():
772754
notifier_task=notifier_task,
773755
)
774756

775-
await asyncio.wait(waitlist)
757+
await asyncio.wait(
758+
[
759+
self._channel_layer.group_add(group, self.channel_name)
760+
for group in groups
761+
]
762+
)
776763

777-
return stream
764+
# Enqueue the `publish` method execution. But do not notify
765+
# clients when `publish` returns `SKIP`.
766+
return (
767+
rx.Observable.create(push_payloads) # pylint: disable=no-member
768+
.map(publish_callback)
769+
.filter(lambda publish_returned: publish_returned is not self.SKIP)
770+
)
778771

779772
async def _on_gql_stop(self, operation_id):
780773
"""Process the STOP message.

0 commit comments

Comments
 (0)