Skip to content

Commit b4f30bc

Browse files
[async] Applied #2596 to async code - part 1
1 parent e06dc52 commit b4f30bc

File tree

5 files changed

+28
-9
lines changed

5 files changed

+28
-9
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,15 @@ async def connect(self, **kwargs) -> None:
10471047
proxy_password=self.proxy_password,
10481048
snowflake_ocsp_mode=self._ocsp_mode(),
10491049
trust_env=True, # Required for proxy support via environment variables
1050+
no_proxy=(
1051+
",".join(str(x) for x in self.no_proxy)
1052+
if (
1053+
self.no_proxy is not None
1054+
and isinstance(self.no_proxy, Iterable)
1055+
and not isinstance(self.no_proxy, (str, bytes))
1056+
)
1057+
else self.no_proxy
1058+
),
10501059
)
10511060
self._session_manager = SessionManagerFactory.get_manager(self._http_config)
10521061

src/snowflake/connector/aio/_ocsp_snowflake.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def _download_ocsp_response_cache(
9898
if sf_cache_server_url is not None:
9999
url = sf_cache_server_url
100100

101-
async with session_manager.use_session() as session:
101+
async with session_manager.use_session(url) as session:
102102
max_retry = SnowflakeOCSP.OCSP_CACHE_SERVER_MAX_RETRY if do_retry else 1
103103
sleep_time = 1
104104
backoff = exponential_backoff()()
@@ -544,7 +544,7 @@ async def _fetch_ocsp_response(
544544
if not self.is_enabled_fail_open():
545545
sf_max_retry = SnowflakeOCSP.CA_OCSP_RESPONDER_MAX_RETRY_FC
546546

547-
async with session_manager.use_session() as session:
547+
async with session_manager.use_session(target_url) as session:
548548
max_retry = sf_max_retry if do_retry else 1
549549
sleep_time = 1
550550
backoff = exponential_backoff()()

src/snowflake/connector/aio/_result_batch.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,14 +247,18 @@ async def download_chunk(http_session):
247247
and connection.rest.session_manager is not None
248248
):
249249
# If connection was explicitly passed and not closed yet - we can reuse SessionManager with session pooling
250-
async with connection.rest.use_session() as session:
250+
async with connection.rest.use_session(
251+
request_data["url"]
252+
) as session:
251253
logger.debug(
252254
f"downloading result batch id: {self.id} with existing session {session}"
253255
)
254256
response, content, encoding = await download_chunk(session)
255257
elif self._session_manager is not None:
256258
# If connection is not accessible or was already closed, but cursors are now used to fetch the data - we will only reuse the http setup (through cloned SessionManager without session pooling)
257-
async with self._session_manager.use_session() as session:
259+
async with self._session_manager.use_session(
260+
request_data["url"]
261+
) as session:
258262
response, content, encoding = await download_chunk(session)
259263
else:
260264
# If there was no session manager cloned, then we are using a default Session Manager setup, since it is very unlikely to enter this part outside of testing
@@ -264,7 +268,9 @@ async def download_chunk(http_session):
264268
local_session_manager = SessionManagerFactory.get_manager(
265269
use_pooling=False
266270
)
267-
async with local_session_manager.use_session() as session:
271+
async with local_session_manager.use_session(
272+
request_data["url"]
273+
) as session:
268274
response, content, encoding = await download_chunk(session)
269275

270276
if response.status == OK:

src/snowflake/connector/aio/_session_manager.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ class _RequestVerbsUsingSessionMixin(abc.ABC):
257257

258258
@abc.abstractmethod
259259
async def use_session(
260-
self, url: str, use_pooling: bool
260+
self, url: str | bytes, use_pooling: bool
261261
) -> AsyncGenerator[aiohttp.ClientSession]: ...
262262

263263
async def get(
@@ -453,9 +453,12 @@ def make_session(self) -> aiohttp.ClientSession:
453453

454454
@contextlib.asynccontextmanager
455455
async def use_session(
456-
self, url: str | bytes | None = None, use_pooling: bool | None = None
456+
self, url: str | bytes, use_pooling: bool | None = None
457457
) -> AsyncGenerator[aiohttp.ClientSession]:
458-
"""Async version of use_session yielding aiohttp.ClientSession."""
458+
"""
459+
Async version of use_session yielding aiohttp.ClientSession.
460+
'url' is an obligatory parameter due to the need for correct proxy handling (i.e. bypassing caused by no_proxy settings).
461+
"""
459462
use_pooling = use_pooling if use_pooling is not None else self.use_pooling
460463
if not use_pooling:
461464
session = self.make_session()

src/snowflake/connector/aio/_wif_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
1616
from ..errors import MissingDependencyError, ProgrammingError
17+
from ..session_manager import SessionManagerFactory
1718
from ..wif_util import (
1819
DEFAULT_ENTRA_SNOWFLAKE_RESOURCE,
1920
SNOWFLAKE_AUDIENCE,
@@ -307,7 +308,7 @@ async def create_attestation(
307308
session_manager = (
308309
session_manager.clone()
309310
if session_manager
310-
else SessionManager(use_pooling=True, max_retries=0)
311+
else SessionManagerFactory.get_manager(use_pooling=True, max_retries=0)
311312
)
312313

313314
if provider == AttestationProvider.AWS:

0 commit comments

Comments
 (0)