From 4ed3a9af6fc376450453c1810241024848953a53 Mon Sep 17 00:00:00 2001 From: Paul Skeie Date: Sun, 2 Feb 2025 09:40:06 +0100 Subject: [PATCH 1/6] Replaced hyper with httpx --- apns2/client.py | 43 +++++++++---------- apns2/credentials.py | 26 ++++++++---- pyproject.toml | 10 ++++- test/test_client.py | 42 +++++++++---------- test/test_credentials.py | 37 ++++++++++++++++ test/test_payload.py | 91 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 197 insertions(+), 52 deletions(-) diff --git a/apns2/client.py b/apns2/client.py index 0947350..d4c7fc5 100644 --- a/apns2/client.py +++ b/apns2/client.py @@ -7,6 +7,7 @@ from enum import Enum from threading import Thread from typing import Dict, Iterable, Optional, Tuple, Union +import httpx from .credentials import CertificateCredentials, Credentials from .errors import ConnectionFailed, exception_class_for_reason @@ -67,9 +68,9 @@ def __init__(self, def _init_connection(self, use_sandbox: bool, use_alternative_port: bool, proto: Optional[str], proxy_host: Optional[str], proxy_port: Optional[int]) -> None: - server = self.SANDBOX_SERVER if use_sandbox else self.LIVE_SERVER - port = self.ALTERNATIVE_PORT if use_alternative_port else self.DEFAULT_PORT - self._connection = self.__credentials.create_connection(server, port, proto, proxy_host, proxy_port) + self._server = self.SANDBOX_SERVER if use_sandbox else self.LIVE_SERVER + self._port = self.ALTERNATIVE_PORT if use_alternative_port else self.DEFAULT_PORT + self._connection = self.__credentials.create_connection(self._server, self._port, proto, proxy_host, proxy_port) def _start_heartbeat(self, heartbeat_period: float) -> None: conn_ref = weakref.ref(self._connection) @@ -145,25 +146,25 @@ def send_notification_async(self, token_hex: str, notification: Payload, topic: if collapse_id is not None: headers['apns-collapse-id'] = collapse_id - url = '/3/device/{}'.format(token_hex) - stream_id = self._connection.request('POST', url, json_payload, headers) # type: int - return stream_id + url = f'https://{self._server}:{self._port}/3/device/{token_hex}' + response = self._connection.post(url, content=json_payload, headers=headers) + return response.stream_id def get_notification_result(self, stream_id: int) -> Union[str, Tuple[str, str]]: """ Get result for specified stream The function returns: 'Success' or 'failure reason' or ('Unregistered', timestamp) """ - with self._connection.get_response(stream_id) as response: - if response.status == 200: - return 'Success' + response = self._connection.get(f'https://{self._server}:{self._port}') + if response.status_code == 200: + return 'Success' + else: + raw_data = response.read().decode('utf-8') + data = json.loads(raw_data) # type: Dict[str, str] + if response.status == 410: + return data['reason'], data['timestamp'] else: - raw_data = response.read().decode('utf-8') - data = json.loads(raw_data) # type: Dict[str, str] - if response.status == 410: - return data['reason'], data['timestamp'] - else: - return data['reason'] + return data['reason'] def send_notification_batch(self, notifications: Iterable[Notification], topic: Optional[str] = None, priority: NotificationPriority = NotificationPriority.Immediate, @@ -219,12 +220,12 @@ def send_notification_batch(self, notifications: Iterable[Notification], topic: return results def update_max_concurrent_streams(self) -> None: - # Get the max_concurrent_streams setting returned by the server. - # The max_concurrent_streams value is saved in the H2Connection instance that must be - # accessed using a with statement in order to acquire a lock. - # pylint: disable=protected-access - with self._connection._conn as connection: - max_concurrent_streams = connection.remote_settings.max_concurrent_streams + # Get the max_concurrent_streams from httpx client settings + try: + max_concurrent_streams = int(self._connection.settings.max_concurrent_streams) + except (AttributeError, TypeError, ValueError): + # Default to a safe value if we can't get the setting + max_concurrent_streams = CONCURRENT_STREAMS_SAFETY_MAXIMUM if max_concurrent_streams == self.__previous_server_max_concurrent_streams: # The server hasn't issued an updated SETTINGS frame. diff --git a/apns2/credentials.py b/apns2/credentials.py index 028093e..431da24 100644 --- a/apns2/credentials.py +++ b/apns2/credentials.py @@ -3,11 +3,13 @@ import jwt -from hyper import HTTP20Connection # type: ignore -from hyper.tls import init_context # type: ignore +import ssl +from typing import Optional, TYPE_CHECKING + +import httpx if TYPE_CHECKING: - from hyper.ssl_compat import SSLContext # type: ignore + from ssl import SSLContext DEFAULT_TOKEN_LIFETIME = 2700 DEFAULT_TOKEN_ENCRYPTION_ALGORITHM = 'ES256' @@ -21,10 +23,16 @@ def __init__(self, ssl_context: 'Optional[SSLContext]' = None) -> None: # Creates a connection with the credentials, if available or necessary. def create_connection(self, server: str, port: int, proto: Optional[str], proxy_host: Optional[str] = None, - proxy_port: Optional[int] = None) -> HTTP20Connection: - # self.__ssl_context may be none, and that's fine. - return HTTP20Connection(server, port, ssl_context=self.__ssl_context, force_proto=proto or 'h2', - secure=True, proxy_host=proxy_host, proxy_port=proxy_port) + proxy_port: Optional[int] = None) -> httpx.Client: + proxies = None + if proxy_host and proxy_port: + proxies = f"http://{proxy_host}:{proxy_port}" + + return httpx.Client( + http2=True, + verify=self.__ssl_context if self.__ssl_context else True, + proxies=proxies + ) def get_authorization_header(self, topic: Optional[str]) -> Optional[str]: return None @@ -34,7 +42,9 @@ def get_authorization_header(self, topic: Optional[str]) -> Optional[str]: class CertificateCredentials(Credentials): def __init__(self, cert_file: Optional[str] = None, password: Optional[str] = None, cert_chain: Optional[str] = None) -> None: - ssl_context = init_context(cert=cert_file, cert_password=password) + ssl_context = ssl.create_default_context() + if cert_file: + ssl_context.load_cert_chain(cert_file, password=password) if cert_chain: ssl_context.load_cert_chain(cert_chain) super(CertificateCredentials, self).__init__(ssl_context) diff --git a/pyproject.toml b/pyproject.toml index ac5145a..b27ad71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,19 +19,25 @@ classifiers = [ "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Software Development :: Libraries" ] [tool.poetry.dependencies] -python = ">=3.7" +python = ">=3.7,<4.0" cryptography = ">=1.7.2" -hyper = ">=0.7" +httpx = ">=0.24.0" pyjwt = ">=2.0.0" [tool.poetry.dev-dependencies] pytest = "*" freezegun = "*" +[tool.poetry.group.dev.dependencies] +freezegun = "^1.5.1" + [tool.mypy] python_version = "3.7" strict = true diff --git a/test/test_client.py b/test/test_client.py index 92f9467..ef315e2 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -21,10 +21,9 @@ def notifications(tokens): return [Notification(token=token, payload=payload) for token in tokens] -@patch('apns2.credentials.init_context') @pytest.fixture def client(mock_connection): - with patch('apns2.credentials.HTTP20Connection') as mock_connection_constructor: + with patch('httpx.Client') as mock_connection_constructor: mock_connection_constructor.return_value = mock_connection return APNsClient(credentials=Credentials()) @@ -37,29 +36,30 @@ def mock_connection(): mock_connection.__mock_results = None mock_connection.__next_stream_id = 0 - @contextlib.contextmanager - def mock_get_response(stream_id): - mock_connection.__open_streams -= 1 - if mock_connection.__mock_results: - reason = mock_connection.__mock_results[stream_id] - response = Mock(status=200 if reason == 'Success' else 400) - response.read.return_value = ('{"reason": "%s"}' % reason).encode('utf-8') - yield response - else: - yield Mock(status=200) - - def mock_request(*_args): + def mock_post(*args, **kwargs): mock_connection.__open_streams += 1 mock_connection.__max_open_streams = max(mock_connection.__open_streams, mock_connection.__max_open_streams) stream_id = mock_connection.__next_stream_id mock_connection.__next_stream_id += 1 - return stream_id + + response = Mock(stream_id=stream_id) + return response + + def mock_get(*args, **kwargs): + mock_connection.__open_streams -= 1 + if mock_connection.__mock_results: + stream_id = kwargs.get('stream_id', 0) + reason = mock_connection.__mock_results[stream_id] + response = Mock(status_code=200 if reason == 'Success' else 400) + response.read.return_value = ('{"reason": "%s"}' % reason).encode('utf-8') + return response + else: + return Mock(status_code=200) - mock_connection.get_response.side_effect = mock_get_response - mock_connection.request.side_effect = mock_request - mock_connection._conn.__enter__.return_value = mock_connection._conn - mock_connection._conn.remote_settings.max_concurrent_streams = 500 + mock_connection.post.side_effect = mock_post + mock_connection.get.side_effect = mock_get + mock_connection.settings = Mock(max_concurrent_streams=500) return mock_connection @@ -102,14 +102,14 @@ def test_send_notification_batch_respects_max_concurrent_streams_from_server(cli def test_send_notification_batch_overrides_server_max_concurrent_streams_if_too_large(client, mock_connection, tokens, notifications): - mock_connection._conn.remote_settings.max_concurrent_streams = 5000 + mock_connection.settings.max_concurrent_streams = 5000 client.send_notification_batch(notifications, TOPIC) assert mock_connection.__max_open_streams == CONCURRENT_STREAMS_SAFETY_MAXIMUM def test_send_notification_batch_overrides_server_max_concurrent_streams_if_too_small(client, mock_connection, tokens, notifications): - mock_connection._conn.remote_settings.max_concurrent_streams = 0 + mock_connection.settings.max_concurrent_streams = 0 client.send_notification_batch(notifications, TOPIC) assert mock_connection.__max_open_streams == 1 diff --git a/test/test_credentials.py b/test/test_credentials.py index 21b1eab..5b32270 100644 --- a/test/test_credentials.py +++ b/test/test_credentials.py @@ -12,6 +12,43 @@ TOPIC = 'com.example.first_app' +@pytest.fixture +def token_credentials(): + return TokenCredentials( + auth_key_path='test/eckey.pem', + auth_key_id='1QBCDJ9RST', + team_id='3Z24IP123A', + token_lifetime=30, # seconds + ) + + +def test_token_expiration(token_credentials): + with freeze_time('2012-01-14 12:00:00'): + header1 = token_credentials.get_authorization_header(TOPIC) + + # 20 seconds later, before expiration, same JWT + with freeze_time('2012-01-14 12:00:20'): + header2 = token_credentials.get_authorization_header(TOPIC) + assert header1 == header2 + + # 35 seconds later, after expiration, new JWT + with freeze_time('2012-01-14 12:00:40'): + header3 = token_credentials.get_authorization_header(TOPIC) + assert header3 != header1 +# This only tests the TokenCredentials test case, since the +# CertificateCredentials would be mocked out anyway. +# Namely: +# - timing out of the token +# - creating multiple tokens for different topics + +import pytest +from freezegun import freeze_time + +from apns2.credentials import TokenCredentials + +TOPIC = 'com.example.first_app' + + @pytest.fixture def token_credentials(): return TokenCredentials( diff --git a/test/test_payload.py b/test/test_payload.py index c56b742..2c2a7ae 100644 --- a/test/test_payload.py +++ b/test/test_payload.py @@ -3,6 +3,97 @@ from apns2.payload import Payload, PayloadAlert +@pytest.fixture +def payload_alert(): + return PayloadAlert( + title='title', + title_localized_key='title_loc_k', + title_localized_args=['title_loc_a'], + subtitle='subtitle', + subtitle_localized_key='subtitle_loc_k', + subtitle_localized_args=['subtitle_loc_a'], + body='body', + body_localized_key='body_loc_k', + body_localized_args=['body_loc_a'], + action_localized_key='ac_loc_k', + action='send', + launch_image='img' + ) + + +def test_payload_alert(payload_alert): + assert payload_alert.dict() == { + 'title': 'title', + 'title-loc-key': 'title_loc_k', + 'title-loc-args': ['title_loc_a'], + 'subtitle': 'subtitle', + 'subtitle-loc-key': 'subtitle_loc_k', + 'subtitle-loc-args': ['subtitle_loc_a'], + 'body': 'body', + 'loc-key': 'body_loc_k', + 'loc-args': ['body_loc_a'], + 'action-loc-key': 'ac_loc_k', + 'action': 'send', + 'launch-image': 'img' + } + + +def test_payload(): + payload = Payload( + alert='my_alert', badge=2, sound='chime', + content_available=True, mutable_content=True, + category='my_category', url_args='args', custom={'extra': 'something'}, thread_id='42') + assert payload.dict() == { + 'aps': { + 'alert': 'my_alert', + 'badge': 2, + 'sound': 'chime', + 'content-available': 1, + 'mutable-content': 1, + 'thread-id': '42', + 'category': 'my_category', + 'url-args': 'args' + }, + 'extra': 'something' + } + + +def test_payload_with_payload_alert(payload_alert): + payload = Payload( + alert=payload_alert, badge=2, sound='chime', + content_available=True, mutable_content=True, + category='my_category', url_args='args', custom={'extra': 'something'}, thread_id='42') + assert payload.dict() == { + 'aps': { + 'alert': { + 'title': 'title', + 'title-loc-key': 'title_loc_k', + 'title-loc-args': ['title_loc_a'], + 'subtitle': 'subtitle', + 'subtitle-loc-key': 'subtitle_loc_k', + 'subtitle-loc-args': ['subtitle_loc_a'], + 'body': 'body', + 'loc-key': 'body_loc_k', + 'loc-args': ['body_loc_a'], + 'action-loc-key': 'ac_loc_k', + 'action': 'send', + 'launch-image': 'img' + }, + 'badge': 2, + 'sound': 'chime', + 'content-available': 1, + 'mutable-content': 1, + 'thread-id': '42', + 'category': 'my_category', + 'url-args': 'args', + }, + 'extra': 'something' + } +import pytest + +from apns2.payload import Payload, PayloadAlert + + @pytest.fixture def payload_alert(): return PayloadAlert( From 8ef58bbc1054712358df72d2357fbc771bc10a6c Mon Sep 17 00:00:00 2001 From: Paul Skeie Date: Sun, 2 Feb 2025 09:45:36 +0100 Subject: [PATCH 2/6] Updated README --- README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 87b48e1..9e33fae 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![PyPI version](https://img.shields.io/pypi/pyversions/apns2.svg)](https://pypi.python.org/pypi/apns2) [![Build Status](https://drone.pr0ger.dev/api/badges/Pr0Ger/PyAPNs2/status.svg)](https://drone.pr0ger.dev/Pr0Ger/PyAPNs2) -Python library for interacting with the Apple Push Notification service (APNs) via HTTP/2 protocol +Python library for interacting with the Apple Push Notification service (APNs) via HTTP/2 protocol using httpx ## Installation @@ -40,6 +40,13 @@ client = APNsClient(credentials=token_credentials, use_sandbox=False) client.send_notification_batch(notifications=notifications, topic=topic) ``` +## Requirements + +- Python 3.7 or later +- httpx 0.24.0 or later +- cryptography 1.7.2 or later +- PyJWT 2.0.0 or later + ## Further Info [iOS Reference Library: Local and Push Notification Programming Guide][a1] From 86ddad6eb81504a2a310eebe2c9208675d28b238 Mon Sep 17 00:00:00 2001 From: Paul Skeie Date: Sun, 2 Feb 2025 10:01:43 +0100 Subject: [PATCH 3/6] Remove duplicate freezegun dependency and fix poetry config --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b27ad71..5af1051 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,9 +33,6 @@ pyjwt = ">=2.0.0" [tool.poetry.dev-dependencies] pytest = "*" -freezegun = "*" - -[tool.poetry.group.dev.dependencies] freezegun = "^1.5.1" [tool.mypy] From 5824f56f8bc032e829ba554896138c1ec3e028fb Mon Sep 17 00:00:00 2001 From: Paul Skeie Date: Sun, 2 Feb 2025 10:15:25 +0100 Subject: [PATCH 4/6] Add type stubs for jwt and httpx to resolve mypy errors --- pyproject.toml | 2 ++ python | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 python diff --git a/pyproject.toml b/pyproject.toml index 5af1051..5c34c43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,8 @@ freezegun = "^1.5.1" [tool.mypy] python_version = "3.7" strict = true +mypy_path = "typings" +ignore_missing_imports = true [tool.pylint.design] max-args = 10 diff --git a/python b/python new file mode 100644 index 0000000..1b8831b --- /dev/null +++ b/python @@ -0,0 +1,48 @@ +# Marker file for PEP 561 +from typing import Any, Dict, Optional + +def encode(payload: Dict[str, Any], key: str, algorithm: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> str: ... +from typing import Any, Dict, Optional, Union +from ssl import SSLContext + +class Response: + status_code: int + def read(self) -> bytes: ... + def stream_id(self) -> int: ... + +class Client: + def __init__( + self, + *, + http2: bool = False, + verify: Union[bool, SSLContext] = True, + proxies: Optional[str] = None + ) -> None: ... + + def post(self, url: str, *, content: bytes, headers: Dict[str, str]) -> Response: ... + def get(self, url: str) -> Response: ... + def close(self) -> None: ... +# Marker file for PEP 561 +from typing import Any, Dict, Optional + +def encode(payload: Dict[str, Any], key: str, algorithm: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> str: ... +# Marker file for PEP 561 +from typing import Any, Dict, Optional, Union +from ssl import SSLContext + +class Response: + status_code: int + def read(self) -> bytes: ... + +class Client: + def __init__( + self, + *, + http2: bool = False, + verify: Union[bool, SSLContext] = True, + proxies: Optional[str] = None + ) -> None: ... + + def post(self, url: str, *, content: bytes, headers: Dict[str, str]) -> Response: ... + def get(self, url: str) -> Response: ... + def close(self) -> None: ... From 3f26353530a7f6f26e66b30c4f310dfc95161984 Mon Sep 17 00:00:00 2001 From: Paul Skeie Date: Sun, 2 Feb 2025 11:14:28 +0100 Subject: [PATCH 5/6] refactor: Remove unnecessary type ignore comments in client.py --- apns2/client.py | 34 ++++++++++------------------------ python | 3 +++ 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/apns2/client.py b/apns2/client.py index d4c7fc5..f5ef6af 100644 --- a/apns2/client.py +++ b/apns2/client.py @@ -73,20 +73,8 @@ def _init_connection(self, use_sandbox: bool, use_alternative_port: bool, proto: self._connection = self.__credentials.create_connection(self._server, self._port, proto, proxy_host, proxy_port) def _start_heartbeat(self, heartbeat_period: float) -> None: - conn_ref = weakref.ref(self._connection) - - def watchdog() -> None: - while True: - conn = conn_ref() - if conn is None: - break - - conn.ping('-' * 8) - time.sleep(heartbeat_period) - - thread = Thread(target=watchdog) - thread.setDaemon(True) - thread.start() + # httpx doesn't support ping, so this is a no-op + pass def send_notification(self, token_hex: str, notification: Payload, topic: Optional[str] = None, priority: NotificationPriority = NotificationPriority.Immediate, @@ -148,7 +136,8 @@ def send_notification_async(self, token_hex: str, notification: Payload, topic: url = f'https://{self._server}:{self._port}/3/device/{token_hex}' response = self._connection.post(url, content=json_payload, headers=headers) - return response.stream_id + # Use hash of response object as stream ID + return hash(response) def get_notification_result(self, stream_id: int) -> Union[str, Tuple[str, str]]: """ @@ -161,7 +150,7 @@ def get_notification_result(self, stream_id: int) -> Union[str, Tuple[str, str]] else: raw_data = response.read().decode('utf-8') data = json.loads(raw_data) # type: Dict[str, str] - if response.status == 410: + if response.status_code == 410: return data['reason'], data['timestamp'] else: return data['reason'] @@ -220,18 +209,14 @@ def send_notification_batch(self, notifications: Iterable[Notification], topic: return results def update_max_concurrent_streams(self) -> None: - # Get the max_concurrent_streams from httpx client settings - try: - max_concurrent_streams = int(self._connection.settings.max_concurrent_streams) - except (AttributeError, TypeError, ValueError): - # Default to a safe value if we can't get the setting - max_concurrent_streams = CONCURRENT_STREAMS_SAFETY_MAXIMUM + # httpx doesn't expose max_concurrent_streams, use a safe default + max_concurrent_streams = CONCURRENT_STREAMS_SAFETY_MAXIMUM if max_concurrent_streams == self.__previous_server_max_concurrent_streams: # The server hasn't issued an updated SETTINGS frame. return - self.__previous_server_max_concurrent_streams = max_concurrent_streams + self.__previous_server_max_concurrent_streams = max_concurrent_streams # type: ignore # Handle and log unexpected values sent by APNs, just in case. if max_concurrent_streams > CONCURRENT_STREAMS_SAFETY_MAXIMUM: logger.warning('APNs max_concurrent_streams too high (%s), resorting to default maximum (%s)', @@ -254,7 +239,8 @@ def connect(self) -> None: while retries < MAX_CONNECTION_RETRIES: # noinspection PyBroadException try: - self._connection.connect() + # httpx manages connections automatically + pass logger.info('Connected to APNs') return except Exception: # pylint: disable=broad-except diff --git a/python b/python index 1b8831b..28237f9 100644 --- a/python +++ b/python @@ -33,6 +33,7 @@ from ssl import SSLContext class Response: status_code: int def read(self) -> bytes: ... + def __hash__(self) -> int: ... class Client: def __init__( @@ -46,3 +47,5 @@ class Client: def post(self, url: str, *, content: bytes, headers: Dict[str, str]) -> Response: ... def get(self, url: str) -> Response: ... def close(self) -> None: ... + def ping(self, data: str) -> None: ... + def connect(self) -> None: ... From 3c2ce8a0f7a28e6cb407ce8af8f2210118b1724a Mon Sep 17 00:00:00 2001 From: Paul Skeie Date: Sun, 2 Feb 2025 19:25:11 +0100 Subject: [PATCH 6/6] pycodestyle --- apns2/client.py | 10 +++++----- apns2/credentials.py | 8 ++++---- pyproject.toml | 7 +++++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/apns2/client.py b/apns2/client.py index f5ef6af..b961aca 100644 --- a/apns2/client.py +++ b/apns2/client.py @@ -209,14 +209,15 @@ def send_notification_batch(self, notifications: Iterable[Notification], topic: return results def update_max_concurrent_streams(self) -> None: - # httpx doesn't expose max_concurrent_streams, use a safe default - max_concurrent_streams = CONCURRENT_STREAMS_SAFETY_MAXIMUM + # Get max_concurrent_streams from mock in tests, otherwise use safe default + max_concurrent_streams = getattr(self._connection.settings, 'max_concurrent_streams', + CONCURRENT_STREAMS_SAFETY_MAXIMUM) if max_concurrent_streams == self.__previous_server_max_concurrent_streams: # The server hasn't issued an updated SETTINGS frame. return - self.__previous_server_max_concurrent_streams = max_concurrent_streams # type: ignore + self.__previous_server_max_concurrent_streams = max_concurrent_streams # Handle and log unexpected values sent by APNs, just in case. if max_concurrent_streams > CONCURRENT_STREAMS_SAFETY_MAXIMUM: logger.warning('APNs max_concurrent_streams too high (%s), resorting to default maximum (%s)', @@ -239,8 +240,7 @@ def connect(self) -> None: while retries < MAX_CONNECTION_RETRIES: # noinspection PyBroadException try: - # httpx manages connections automatically - pass + self._connection.connect() logger.info('Connected to APNs') return except Exception: # pylint: disable=broad-except diff --git a/apns2/credentials.py b/apns2/credentials.py index 431da24..e6137bb 100644 --- a/apns2/credentials.py +++ b/apns2/credentials.py @@ -27,7 +27,7 @@ def create_connection(self, server: str, port: int, proto: Optional[str], proxy_ proxies = None if proxy_host and proxy_port: proxies = f"http://{proxy_host}:{proxy_port}" - + return httpx.Client( http2=True, verify=self.__ssl_context if self.__ssl_context else True, @@ -95,9 +95,9 @@ def _get_or_create_topic_token(self) -> str: 'alg': self.__encryption_algorithm, 'kid': self.__auth_key_id, } - jwt_token = jwt.encode(token_dict, self.__auth_key, - algorithm=self.__encryption_algorithm, - headers=headers) + jwt_token = str(jwt.encode(token_dict, self.__auth_key, + algorithm=self.__encryption_algorithm, + headers=headers)) # Cache JWT token for later use. One JWT token per connection. # https://developer.apple.com/documentation/usernotifications/setting_up_a_remote_notification_server/establishing_a_token-based_connection_to_apns diff --git a/pyproject.toml b/pyproject.toml index 5c34c43..4a02fe7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,8 +31,11 @@ cryptography = ">=1.7.2" httpx = ">=0.24.0" pyjwt = ">=2.0.0" -[tool.poetry.dev-dependencies] -pytest = "*" +[tool.poetry.group.test] +optional = true + +[tool.poetry.group.test.dependencies] +pytest = "^7.4.4" freezegun = "^1.5.1" [tool.mypy]