diff --git a/grain/_src/python/dataset/transformations/process_prefetch.py b/grain/_src/python/dataset/transformations/process_prefetch.py index 4b271ff1e..e254bd658 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch.py +++ b/grain/_src/python/dataset/transformations/process_prefetch.py @@ -449,6 +449,8 @@ def _stop_prefetch(self): return self._prefetch_should_stop.set() + _clear_queue_and_maybe_unlink_shm(self._buffer) + self._clear_set_state_queue() # Not joining here will cause the children to be zombie after they finish. # Need to join or call active_children. @@ -458,9 +460,9 @@ def _stop_prefetch(self): # kill the child processes. if self._prefetch_process.is_alive(): self._prefetch_process.kill() + else: + _clear_queue_and_maybe_unlink_shm(self._buffer) self._prefetch_process = None - _clear_queue_and_maybe_unlink_shm(self._buffer) - self._clear_set_state_queue() self._set_state_count = 0 def get_state(self) -> StateT: diff --git a/grain/_src/python/dataset/transformations/process_prefetch_test.py b/grain/_src/python/dataset/transformations/process_prefetch_test.py index 390a4be72..6d4fe45d8 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch_test.py +++ b/grain/_src/python/dataset/transformations/process_prefetch_test.py @@ -17,6 +17,7 @@ import os import sys import time +import types from typing import TypeVar from unittest import mock