From c39497b7a08ba2bd268fa45a1a653a5c3f9289f4 Mon Sep 17 00:00:00 2001 From: Grain Team Date: Thu, 22 Jan 2026 10:55:44 -0800 Subject: [PATCH] Internal PiperOrigin-RevId: 859684764 --- .../_src/python/dataset/transformations/process_prefetch.py | 6 ++++-- .../python/dataset/transformations/process_prefetch_test.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/grain/_src/python/dataset/transformations/process_prefetch.py b/grain/_src/python/dataset/transformations/process_prefetch.py index 4b271ff1e..e254bd658 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch.py +++ b/grain/_src/python/dataset/transformations/process_prefetch.py @@ -449,6 +449,8 @@ def _stop_prefetch(self): return self._prefetch_should_stop.set() + _clear_queue_and_maybe_unlink_shm(self._buffer) + self._clear_set_state_queue() # Not joining here will cause the children to be zombie after they finish. # Need to join or call active_children. @@ -458,9 +460,9 @@ def _stop_prefetch(self): # kill the child processes. if self._prefetch_process.is_alive(): self._prefetch_process.kill() + else: + _clear_queue_and_maybe_unlink_shm(self._buffer) self._prefetch_process = None - _clear_queue_and_maybe_unlink_shm(self._buffer) - self._clear_set_state_queue() self._set_state_count = 0 def get_state(self) -> StateT: diff --git a/grain/_src/python/dataset/transformations/process_prefetch_test.py b/grain/_src/python/dataset/transformations/process_prefetch_test.py index 390a4be72..6d4fe45d8 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch_test.py +++ b/grain/_src/python/dataset/transformations/process_prefetch_test.py @@ -17,6 +17,7 @@ import os import sys import time +import types from typing import TypeVar from unittest import mock