Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 31 additions & 28 deletions redis/auth/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,20 +145,40 @@ def start(
except RuntimeError:
# Run loop in a separate thread to unblock main thread.
loop = asyncio.new_event_loop()
thread = threading.Thread(
target=_start_event_loop_in_thread, args=(loop,), daemon=True
)

# Use threading.Event to signal when loop is ready
loop_ready = threading.Event()

def start_loop():
# This runs in the background thread. First, bind the event loop to
# this thread, then signal that the loop is ready so the calling
# thread can safely schedule work (via call_soon_threadsafe) before
# we block in run_forever().
asyncio.set_event_loop(loop)
loop_ready.set() # Signal that loop is ready for cross-thread use
loop.run_forever()

thread = threading.Thread(target=start_loop, daemon=True)
thread.start()

# Event to block for initial execution.
init_event = asyncio.Event()
self._init_timer = loop.call_later(
0, self._renew_token, skip_initial, init_event
)
# Wait for the loop to be ready before scheduling
loop_ready.wait()

# Use thread-safe Event for cross-thread synchronization
init_done = threading.Event()

def renew_with_callback():
try:
self._renew_token(skip_initial)
finally:
init_done.set()

# Schedule using call_soon_threadsafe for thread-safe scheduling
self._init_timer = loop.call_soon_threadsafe(renew_with_callback)
logger.info("Token manager started")

# Blocks in thread-safe manner.
asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result()
# Blocks using thread-safe Event
init_done.wait()
return self.stop

async def start_async(
Expand Down Expand Up @@ -247,9 +267,7 @@ def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float):
- (datetime.now(timezone.utc).timestamp() * 1000)
)

def _renew_token(
self, skip_initial: bool = False, init_event: asyncio.Event = None
):
def _renew_token(self, skip_initial: bool = False):
"""
Task to renew token from identity provider.
Schedules renewal tasks based on token TTL.
Expand Down Expand Up @@ -289,9 +307,6 @@ def _renew_token(
raise e

self._listener.on_error(e)
finally:
if init_event:
init_event.set()

async def _renew_token_async(
self, skip_initial: bool = False, init_event: asyncio.Event = None
Expand Down Expand Up @@ -356,15 +371,3 @@ def wrapped():
asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)

return wrapped


def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop):
"""
Starts event loop in a thread.
Used to be able to schedule tasks using loop.call_later.

:param event_loop:
:return:
"""
asyncio.set_event_loop(event_loop)
event_loop.run_forever()
Loading