diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index 1a2e9f9..20d2a3e 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -76,8 +76,7 @@ def __init__( @property async def base_token(self) -> bytes: - if self._base_token is None: - await self._connected.wait() + await self._connected.wait() assert self._base_token is not None return self._base_token @@ -134,25 +133,21 @@ async def reconnect(self): self._tg.start_soon(self._receive_task) ack = await self._receive( - filters.chain( - filters.cmd(protocol.ConnectAck), - lambda c: ( - c - if ( - c.token == self._base_token - if self._base_token is not None - else True - ) - else None - ), - ) + filters.cmd(protocol.ConnectAck), ) logging.debug(f"Connected with ack: {ack}") assert ack.status == 0 - if self._base_token is None: + if ack.token is None: + # Server accepted the cached token without returning a new one + logging.debug(f"Base token accepted by server: {self._base_token.hex()}") + elif self._base_token is None: + self._base_token = ack.token + logging.debug(f"Base token assigned: {ack.token.hex()}") + elif ack.token != self._base_token: + logging.warning(f"Base token changed: {self._base_token.hex()} -> {ack.token.hex()}") self._base_token = ack.token else: - assert ack.token == self._base_token + logging.debug(f"Base token confirmed: {ack.token.hex()}") if not self._connected.is_set(): self._connected.set()