diff --git a/lazyflow/request/request.py b/lazyflow/request/request.py index 4df7a92c2..5602d36e5 100644 --- a/lazyflow/request/request.py +++ b/lazyflow/request/request.py @@ -33,8 +33,8 @@ import traceback import io from random import randrange -from typing import Callable from numpy import ma +from typing import Callable, Optional import logging @@ -107,6 +107,49 @@ def log_exception(logger, msg=None, exc_info=None, level=logging.ERROR): logger.log(level, msg) +class CancellationException(Exception): + """ + This is raised when the whole request has been cancelled. + If you catch this exception from within a request, clean up and return immediately. + If you have nothing to clean up, you are not required to handle this exception. + + Implementation details: + This exception is raised when the cancel flag is checked in the wait() function: + - immediately before the request is suspended OR + - immediately after the request is woken up from suspension + """ + + pass + + +class CancellationToken: + __slots__ = ("_cancelled",) + + def __init__(self): + self._cancelled = False + + @property + def cancelled(self): + return self._cancelled + + def raise_if_cancelled(self): + if self.cancelled: + raise CancellationException() + + def __repr__(self): + return f"CancellationToken(id={id(self)}, cancelled={self._cancelled})" + + +class CancellationTokenSource: + __slots__ = ("token",) + + def __init__(self): + self.token = CancellationToken() + + def cancel(self): + self.token._cancelled = True + + class Request(object): # One thread pool shared by all requests. @@ -144,20 +187,6 @@ def reset_thread_pool(cls, num_workers=min(multiprocessing.cpu_count(), 8)): cls.global_thread_pool.stop() cls.global_thread_pool = threadPool.ThreadPool(num_workers) - class CancellationException(Exception): - """ - This is raised when the whole request has been cancelled. - If you catch this exception from within a request, clean up and return immediately. - If you have nothing to clean up, you are not required to handle this exception. - - Implementation details: - This exception is raised when the cancel flag is checked in the wait() function: - - immediately before the request is suspended OR - - immediately after the request is woken up from suspension - """ - - pass - class InvalidRequestException(Exception): """ This is raised when calling wait on a request that has already been cancelled, @@ -197,7 +226,7 @@ class InternalError(Exception): _root_request_counter = itertools.count() - def __init__(self, fn, root_priority=[0]): + def __init__(self, fn, root_priority=[0], cancel_token=None): """ Constructor. Postconditions: The request has the same cancelled status as its parent (the request that is creating this one). @@ -208,6 +237,7 @@ def __init__(self, fn, root_priority=[0]): self._sig_cancelled = SimpleSignal() self._sig_finished = SimpleSignal() self._sig_execution_complete = SimpleSignal() + self.cancel_token = cancel_token # Workload self.fn = fn @@ -217,8 +247,6 @@ def __init__(self, fn, root_priority=[0]): # State self.started = False - self.cancelled = False - self.uncancellable = False self.finished = False self.execution_complete = False self.finished_event = threading.Event() @@ -249,11 +277,15 @@ def __init__(self, fn, root_priority=[0]): with current_request._lock: current_request.child_requests.add(self) # We must ensure that we get the same cancelled status as our parent. - self.cancelled = current_request.cancelled + self.cancel_token = current_request.cancel_token # We acquire the same priority as our parent, plus our own sub-priority current_request._max_child_priority += 1 self._priority = current_request._priority + root_priority + [current_request._max_child_priority] + @property + def uncancellable(self): + return self.cancel_token is None + def __lt__(self, other): """ Request comparison is by priority. @@ -291,6 +323,10 @@ def with_value(cls, value): """ return _ValueRequest(value) + @property + def cancelled(self): + return self.cancel_token and self.cancel_token.cancelled + def clean(self, _fullClean=True): """ Delete all state from the request, for cleanup purposes. @@ -358,7 +394,7 @@ def _execute(self): try: # Do the actual work self._result = self.fn() - except Request.CancellationException: + except CancellationException: # Don't propagate cancellations back to the worker thread, # even if the user didn't catch them. pass @@ -577,8 +613,6 @@ def _wait_within_foreign_thread(self, timeout): Here, we rely on an ordinary threading.Event primitive: ``self.finished_event`` """ # Don't allow this request to be cancelled, since a real thread is waiting for it. - self.uncancellable = True - with self._lock: direct_execute_needed = not self.started and (timeout is None) if direct_execute_needed: @@ -610,26 +644,30 @@ def _wait_within_foreign_thread(self, timeout): else: self.submit() - # This is a non-worker thread, so just block the old-fashioned way completed = self.finished_event.wait(timeout) - if not completed: - raise Request.TimeoutException() if self.cancelled: # It turns out this request was already cancelled. raise Request.InvalidRequestException() + if not completed: + raise Request.TimeoutException() + if self.exception is not None: exc_type, exc_value, exc_tb = self.exception_info raise_with_traceback(exc_value, exc_tb) + def raise_if_cancelled(self): + if self.cancel_token is not None: + self.cancel_token.raise_if_cancelled() + def _wait_within_request(self, current_request): """ This is the implementation of wait() when executed from another request. If we have to wait, suspend the current request instead of blocking the whole worker thread. """ # Before we suspend the current request, check to see if it's been cancelled since it last blocked - Request.raise_if_cancelled() + self.raise_if_cancelled() if current_request == self: # It's usually nonsense for a request to wait for itself, @@ -690,7 +728,7 @@ def _wait_within_request(self, current_request): # Now we're back (no longer suspended) # Was the current request cancelled while it was waiting for us? - Request.raise_if_cancelled() + self.raise_if_cancelled() # Are we back because we failed? if self.exception is not None: @@ -766,32 +804,6 @@ def notify_failed(self, fn): # Call immediately fn(self.exception, self.exception_info) - def cancel(self): - """ - Attempt to cancel this request and all requests that it spawned. - No request will be cancelled if other non-cancelled requests are waiting for its results. - """ - # We can only be cancelled if: - # (1) There are no foreign threads blocking for us (flagged via self.uncancellable) AND - # (2) our parent request (if any) is already cancelled AND - # (3) all requests that are pending for this one are already cancelled - with self._lock: - cancelled = not self.uncancellable - cancelled &= self.parent_request is None or self.parent_request.cancelled - for r in self.pending_requests: - cancelled &= r.cancelled - - self.cancelled = cancelled - if cancelled: - # Any children added after this point will receive our same cancelled status - child_requests = self.child_requests - self.child_requests = set() - - if self.cancelled: - # Cancel all requests that were spawned from this one. - for child in child_requests: - child.cancel() - @classmethod def _current_request(cls): """ @@ -814,14 +826,6 @@ def current_request_is_cancelled(cls): current_request = Request._current_request() return current_request and current_request.cancelled - @classmethod - def raise_if_cancelled(cls): - """ - If called from the context of a cancelled request, raise a CancellationException immediately. - """ - if Request.current_request_is_cancelled(): - raise Request.CancellationException() - ########################################## #### Backwards-compatible API support #### ########################################## @@ -990,7 +994,7 @@ def _acquire_from_within_request(self, current_request, blocking): # Try to get it immediately. got_it = self._modelLock.acquire(False) if not blocking: - Request.raise_if_cancelled() + current_request.raise_if_cancelled() return got_it if not got_it: # We have to wait. Add ourselves to the list of waiters. @@ -1003,7 +1007,7 @@ def _acquire_from_within_request(self, current_request, blocking): # Now we're back (no longer suspended) # Was the current request cancelled while it was waiting for the lock? - Request.raise_if_cancelled() + current_request.raise_if_cancelled() # Guaranteed to own _modelLock now (see release()). return True @@ -1153,7 +1157,7 @@ def _debug_mode_init(self): def __enter__(self): try: return self._ownership_lock.__enter__() - except Request.CancellationException: + except CancellationException: self._notify_nocheck() raise @@ -1236,10 +1240,11 @@ class RequestPoolError(Exception): pass - def __init__(self, max_active=None): + def __init__(self, max_active=None, cancel_token: Optional[CancellationToken] = None): """ max_active: The number of Requests to launch in parallel. """ + self._cancel_token = cancel_token self._started = False self._failed = False self._finished = True diff --git a/tests/testRequest.py b/tests/testRequest.py index da0f71971..b6454ebcf 100644 --- a/tests/testRequest.py +++ b/tests/testRequest.py @@ -22,7 +22,14 @@ # This information is also available on the ilastik web site at: # http://ilastik.org/license/ ############################################################################### -from lazyflow.request.request import Request, RequestLock, SimpleRequestCondition, RequestPool +from lazyflow.request.request import ( + Request, + RequestLock, + SimpleRequestCondition, + RequestPool, + CancellationTokenSource, + CancellationException, +) import os import time import random @@ -204,6 +211,7 @@ def workload(): got_cancel = [False] workcounter = [0] + cancel_source = CancellationTokenSource() def big_workload(): try: @@ -219,7 +227,7 @@ def big_workload(): ), "Shouldn't get to this line. This test is designed so that big_workload should be cancelled before it finishes all its work" for r in requests: assert not r.cancelled - except Request.CancellationException: + except CancellationException: got_cancel[0] = True except Exception as ex: import traceback @@ -232,14 +240,14 @@ def big_workload(): def handle_complete(result): completed[0] = True - req = Request(big_workload) + req = Request(big_workload, cancel_token=cancel_source.token) req.notify_finished(handle_complete) req.submit() while workcounter[0] == 0: time.sleep(0.001) - req.cancel() + cancel_source.cancel() time.sleep(1) assert req.cancelled diff --git a/tests/test_request/test_request.py b/tests/test_request/test_request.py index 5d0ec70c8..0d4f22f3a 100644 --- a/tests/test_request/test_request.py +++ b/tests/test_request/test_request.py @@ -7,7 +7,7 @@ import numpy as np from numpy.testing import assert_array_equal -from lazyflow.request.request import Request, SimpleSignal +from lazyflow.request.request import Request, SimpleSignal, CancellationTokenSource, CancellationException class TExc(Exception): @@ -88,10 +88,13 @@ def __call__(self): self.result = self.work_fn() return self.result except Exception as e: + print("EXC") self.exception = e raise finally: + print("DONE") self.done.set() + print("OUTDONE") class TestRequest: @@ -141,7 +144,7 @@ def test_signal_finished_called_on_completion(self, work_fn): req.notify_finished(recv) req.submit() - assert work.done.wait(timeout=1) + assert req.wait(timeout=1) recv.assert_called_once_with(42) @@ -150,7 +153,7 @@ def test_signal_finished_called_when_subscription_happened_after_completion(self recv = mock.Mock() req.submit() - assert work.done.wait(timeout=1) + assert req.wait(timeout=1) req.notify_finished(recv) @@ -164,6 +167,7 @@ def test_signal_finished_should_not_be_called_on_exception(self, broken_fn): req.submit() with pytest.raises(TExc): + print("HELLO") assert req.wait() == 42 recv.assert_not_called() @@ -198,7 +202,12 @@ def test_signal_failed_called_even_when_subscription_happened_after_completion(s @pytest.fixture -def work(): +def cancel_source(): + return CancellationTokenSource() + + +@pytest.fixture +def work(cancel_source): unpause = threading.Event() children = [] @@ -212,7 +221,7 @@ def work_fn(): work = Work(work_fn) work.unpause = unpause - work.request = Request(work) + work.request = Request(work, cancel_token=cancel_source.token) work.request.submit() work.children = children @@ -232,9 +241,9 @@ def test_requests_created_within_request_considired_child_requests(work): @pytest.fixture -def cancelled_work(work): +def cancelled_work(work, cancel_source): assert not work.request.cancelled - work.request.cancel() + cancel_source.cancel() assert work.request.cancelled work.unpause.set() @@ -252,7 +261,7 @@ def test_wait_for_cancelled_rq_raises_invalid_request_exception(cancelled_work): def test_cancel_raises_exception_on_yield_point(cancelled_work): work_rq = cancelled_work.request - assert isinstance(cancelled_work.exception, Request.CancellationException) + assert isinstance(cancelled_work.exception, CancellationException) def test_cancels_child_requests(cancelled_work):