From 6cb286af9fc95cdf1adf69e033381ed2399a176a Mon Sep 17 00:00:00 2001 From: gnzsnz <8376642+gnzsnz@users.noreply.github.com> Date: Sat, 27 Sep 2025 18:06:35 +0200 Subject: [PATCH] fix for issue #10 with additional test cases --- eventkit/ops/transform.py | 10 +++++----- tests/transform_test.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/eventkit/ops/transform.py b/eventkit/ops/transform.py index afcdd38..3976179 100644 --- a/eventkit/ops/transform.py +++ b/eventkit/ops/transform.py @@ -3,7 +3,7 @@ import time from collections import deque -from ..util import NO_VALUE +from ..util import NO_VALUE, get_event_loop from .combine import Chain, Concat, Merge, Switch from .op import Op @@ -226,7 +226,7 @@ def __init__(self, func, timeout=0, ordered=True, task_limit=None, source=None): def on_source(self, *args): obj = self._func(*args) - if asyncio.iscoroutine(obj): + if hasattr(obj, "__await__"): # function returns an awaitable if not self._task_limit or len(self._tasks) < self._task_limit: # schedule right away @@ -245,12 +245,12 @@ def on_source_done(self, source): self._source = None - def _create_task(self, coro): + def _create_task(self, awaitable): # schedule a task to be run if self._timeout: - coro = asyncio.wait_for(coro, self._timeout) + awaitable = asyncio.wait_for(awaitable, self._timeout) - task = asyncio.create_task(coro) + task = asyncio.ensure_future(awaitable, loop=get_event_loop()) task.add_done_callback(self._on_task_done) self._tasks.append(task) diff --git a/tests/transform_test.py b/tests/transform_test.py index 8b7ffae..b54357a 100644 --- a/tests/transform_test.py +++ b/tests/transform_test.py @@ -4,11 +4,18 @@ import numpy as np +import eventkit as ev from eventkit import Event +from eventkit.util import get_event_loop array = list(range(20)) +def run(*args, **kwargs): + loop = get_event_loop() + return loop.run_until_complete(*args, **kwargs) + + class TransformTest(unittest.TestCase): def test_constant(self): event = Event.sequence(array).constant(42) @@ -151,3 +158,31 @@ def test_switchmap(self): ] event = Event.range(3).switchmap(lambda v: Event.marble(marbles[v])) self.assertEqual(event.run(), ["A", "B", "1", "2", "K", "L", "M", "N"]) + + def test_map_with_future(self): + """Verify that Map correctly handles functions that return a Future.""" + # Create a future that we will complete manually + my_future = asyncio.Future() + + # The map function will just return our future + def map_func(x): + return my_future + + event = Event.sequence([1]).map(map_func) + result = [] + event.connect(result.append) + + # Give the event loop a chance to run the map + run(asyncio.sleep(0)) + + # The event should not have emitted yet, as the future is not done + self.assertEqual(result, []) + + # Now, complete the future + my_future.set_result(42) + + # Give the event loop a chance to process the completion + run(asyncio.sleep(0)) + + # The event should now have emitted the future's result + self.assertEqual(result, [42])