Skip to content

Commit aa56ecb

Browse files
[async] Applied #2596 to async code - part 2
1 parent b4f30bc commit aa56ecb

File tree

4 files changed

+24
-15
lines changed

4 files changed

+24
-15
lines changed

src/snowflake/connector/aio/_session_manager.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
resolve_cafile,
2222
)
2323
from ._crl import CRLValidator
24+
from ..url_util import should_bypass_proxies
2425
from ._ocsp_asn1crypto import SnowflakeOCSPAsn1Crypto
2526

2627
if TYPE_CHECKING:
@@ -439,7 +440,7 @@ def from_config(cls, cfg: AioHttpConfig, **overrides: Any) -> SessionManager:
439440
cfg = cfg.copy_with(**overrides)
440441
return cls(config=cfg)
441442

442-
def make_session(self) -> aiohttp.ClientSession:
443+
def make_session(self, *, url: str | None = None) -> aiohttp.ClientSession:
443444
"""Create a new aiohttp.ClientSession with configured connector."""
444445
connector = self._cfg.get_connector(
445446
session_manager=self.clone(),
@@ -456,24 +457,18 @@ async def use_session(
456457
self, url: str | bytes, use_pooling: bool | None = None
457458
) -> AsyncGenerator[aiohttp.ClientSession]:
458459
"""
459-
Async version of use_session yielding aiohttp.ClientSession.
460460
'url' is an obligatory parameter due to the need for correct proxy handling (i.e. bypassing caused by no_proxy settings).
461461
"""
462462
use_pooling = use_pooling if use_pooling is not None else self.use_pooling
463463
if not use_pooling:
464-
session = self.make_session()
464+
session = self.make_session(url=url)
465465
try:
466466
yield session
467467
finally:
468468
await session.close()
469469
else:
470-
hostname = urlparse(url).hostname if url else None
471-
pool = self._sessions_map[hostname]
472-
session = pool.get_session()
473-
try:
474-
yield session
475-
finally:
476-
pool.return_session(session)
470+
with self._yield_session_from_pool(url) as session_from_pool:
471+
yield session_from_pool
477472

478473
async def request(
479474
self,
@@ -585,16 +580,22 @@ def request(
585580
)
586581
return super().request(method, url, **kwargs)
587582

588-
def make_session(self) -> aiohttp.ClientSession:
583+
def make_session(self, *, url: str | None = None) -> aiohttp.ClientSession:
589584
connector = self._cfg.get_connector(
590585
session_manager=self.clone(),
591586
snowflake_ocsp_mode=self._cfg.snowflake_ocsp_mode,
592587
)
588+
# We use requests.utils here (in asynch code) to keep the behaviour uniform for synch and asynch code. If we wanted each version to depict its http library's behaviour, we could use here: aiohttp.helpers.proxy_bypass(url, proxies={...}) here
589+
proxy = (
590+
None
591+
if should_bypass_proxies(url, no_proxy=self.config.no_proxy)
592+
else self.proxy_url
593+
)
593594
# Construct session with base proxy set, request() may override per-URL when bypassing
594595
return self.SessionWithProxy(
595596
connector=connector,
596597
trust_env=self._cfg.trust_env,
597-
proxy=self.proxy_url,
598+
proxy=proxy,
598599
)
599600

600601

src/snowflake/connector/session_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from .compat import urlparse
1313
from .proxy import get_proxy_url
14+
from .url_util import should_bypass_proxies
1415
from .vendored import requests
1516
from .vendored.requests import Response, Session
1617
from .vendored.requests.adapters import BaseAdapter, HTTPAdapter
@@ -610,7 +611,7 @@ def make_session(self, *, url: str | None = None) -> Session:
610611
{
611612
"no_proxy": self._cfg.no_proxy,
612613
}
613-
if requests.utils.should_bypass_proxies(url, no_proxy=self.config.no_proxy)
614+
if should_bypass_proxies(url, no_proxy=self.config.no_proxy)
614615
else {
615616
"http": self.proxy_url,
616617
"https": self.proxy_url,

src/snowflake/connector/url_util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import re
44
import urllib.parse
55
from logging import getLogger
6+
from typing import Iterable
7+
8+
import vendored.requests as requests
69

710
from .constants import _TOP_LEVEL_DOMAIN_REGEX
811

@@ -47,3 +50,7 @@ def extract_top_level_domain_from_hostname(hostname: str | None = None) -> str:
4750
# RFC1034 for TLD spec, and https://data.iana.org/TLD/tlds-alpha-by-domain.txt for full TLD list
4851
match = re.search(_TOP_LEVEL_DOMAIN_REGEX, hostname)
4952
return (match.group(0)[1:] if match else "com").lower()
53+
54+
55+
def should_bypass_proxies(url: str | bytes, no_proxy: Iterable[str] | None) -> bool:
56+
return requests.utils.should_bypass_proxies(url, no_proxy)

test/unit/aio/mock_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ async def forbidden_connect(*args, **kwargs):
3434
raise NotImplementedError("Unit test tried to make real network connection")
3535

3636
class MockSessionManager(SessionManager):
37-
def make_session(self):
38-
session = super().make_session()
37+
def make_session(self, *, url: str | None = None):
38+
session = super().make_session(url)
3939
if not allow_send:
4040
# Block at connector._connect level (like sync blocks session.send)
4141
# This allows patches on session.request to work

0 commit comments

Comments
 (0)