Skip to content

Commit 59601fa

Browse files
[async] Applied #2596 to async code - part 3 - basic tests working
1 parent aa56ecb commit 59601fa

File tree

4 files changed

+93
-94
lines changed

4 files changed

+93
-94
lines changed

src/snowflake/connector/aio/_session_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ async def use_session(
467467
finally:
468468
await session.close()
469469
else:
470-
with self._yield_session_from_pool(url) as session_from_pool:
470+
for session_from_pool in self._yield_session_from_pool(url):
471471
yield session_from_pool
472472

473473
async def request(

src/snowflake/connector/url_util.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
from logging import getLogger
66
from typing import Iterable
77

8-
import vendored.requests as requests
9-
108
from .constants import _TOP_LEVEL_DOMAIN_REGEX
9+
from .vendored import requests
1110

1211
logger = getLogger(__name__)
1312

test/unit/aio/test_connection_async_unit.py

Lines changed: 0 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from cryptography.hazmat.primitives.asymmetric import rsa
2727

2828
import snowflake.connector.aio
29-
from snowflake.connector.aio import connect as async_connect
3029
from snowflake.connector.aio._network import SnowflakeRestful
3130
from snowflake.connector.aio.auth import (
3231
AuthByDefault,
@@ -842,93 +841,3 @@ async def test_invalid_authenticator():
842841
)
843842
await conn.connect()
844843
assert "Unknown authenticator: INVALID" in str(excinfo.value)
845-
846-
847-
@pytest.mark.skipolddriver
848-
@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"])
849-
async def test_large_query_through_proxy_async(
850-
wiremock_generic_mappings_dir,
851-
wiremock_target_proxy_pair,
852-
wiremock_mapping_dir,
853-
proxy_env_vars,
854-
proxy_method,
855-
):
856-
target_wm, proxy_wm = wiremock_target_proxy_pair
857-
858-
password_mapping = wiremock_mapping_dir / "auth/password/successful_flow.json"
859-
multi_chunk_request_mapping = (
860-
wiremock_mapping_dir / "queries/select_large_request_successful.json"
861-
)
862-
disconnect_mapping = (
863-
wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
864-
)
865-
telemetry_mapping = wiremock_generic_mappings_dir / "telemetry.json"
866-
chunk_1_mapping = wiremock_mapping_dir / "queries/chunk_1.json"
867-
chunk_2_mapping = wiremock_mapping_dir / "queries/chunk_2.json"
868-
869-
expected_headers = {"Via": {"contains": "wiremock"}}
870-
871-
target_wm.import_mapping(password_mapping, expected_headers=expected_headers)
872-
target_wm.add_mapping_with_default_placeholders(
873-
multi_chunk_request_mapping, expected_headers
874-
)
875-
target_wm.add_mapping(disconnect_mapping, expected_headers=expected_headers)
876-
target_wm.add_mapping(telemetry_mapping, expected_headers=expected_headers)
877-
target_wm.add_mapping_with_default_placeholders(chunk_1_mapping, expected_headers)
878-
target_wm.add_mapping_with_default_placeholders(chunk_2_mapping, expected_headers)
879-
880-
set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars
881-
connect_kwargs = {
882-
"user": "testUser",
883-
"password": "testPassword",
884-
"account": "testAccount",
885-
"host": target_wm.wiremock_host,
886-
"port": target_wm.wiremock_http_port,
887-
"protocol": "http",
888-
"warehouse": "TEST_WH",
889-
}
890-
891-
if proxy_method == "explicit_args":
892-
connect_kwargs.update(
893-
{
894-
"proxy_host": proxy_wm.wiremock_host,
895-
"proxy_port": str(proxy_wm.wiremock_http_port),
896-
"proxy_user": "proxyUser",
897-
"proxy_password": "proxyPass",
898-
}
899-
)
900-
clear_proxy_env_vars()
901-
else:
902-
proxy_url = f"http://proxyUser:proxyPass@{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}"
903-
set_proxy_env_vars(proxy_url)
904-
905-
row_count = 50_000
906-
conn = await async_connect(**connect_kwargs)
907-
try:
908-
cur = conn.cursor()
909-
await cur.execute(
910-
f"select seq4() as n from table(generator(rowcount => {row_count}));"
911-
)
912-
assert len(cur._result_set.batches) > 1
913-
_ = [r async for r in cur]
914-
finally:
915-
await conn.close()
916-
917-
async with aiohttp.ClientSession() as session:
918-
async with session.get(
919-
f"{proxy_wm.http_host_with_port}/__admin/requests"
920-
) as resp:
921-
proxy_reqs = await resp.json()
922-
assert any(
923-
"/queries/v1/query-request" in r["request"]["url"]
924-
for r in proxy_reqs["requests"]
925-
)
926-
927-
async with session.get(
928-
f"{target_wm.http_host_with_port}/__admin/requests"
929-
) as resp:
930-
target_reqs = await resp.json()
931-
assert any(
932-
"/queries/v1/query-request" in r["request"]["url"]
933-
for r in target_reqs["requests"]
934-
)

test/unit/aio/test_proxies_async.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from snowflake.connector.aio import connect
7+
from snowflake.connector.aio import connect as async_connect
78

89
pytestmark = pytest.mark.asyncio
910

@@ -86,3 +87,93 @@ async def test_basic_query_through_proxy_async(
8687
"/queries/v1/query-request" in r["request"]["url"]
8788
for r in target_reqs["requests"]
8889
)
90+
91+
92+
@pytest.mark.skipolddriver
93+
@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"])
94+
async def test_large_query_through_proxy_async(
95+
wiremock_generic_mappings_dir,
96+
wiremock_target_proxy_pair,
97+
wiremock_mapping_dir,
98+
proxy_env_vars,
99+
proxy_method,
100+
):
101+
target_wm, proxy_wm = wiremock_target_proxy_pair
102+
103+
password_mapping = wiremock_mapping_dir / "auth/password/successful_flow.json"
104+
multi_chunk_request_mapping = (
105+
wiremock_mapping_dir / "queries/select_large_request_successful.json"
106+
)
107+
disconnect_mapping = (
108+
wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
109+
)
110+
telemetry_mapping = wiremock_generic_mappings_dir / "telemetry.json"
111+
chunk_1_mapping = wiremock_mapping_dir / "queries/chunk_1.json"
112+
chunk_2_mapping = wiremock_mapping_dir / "queries/chunk_2.json"
113+
114+
expected_headers = {"Via": {"contains": "wiremock"}}
115+
116+
target_wm.import_mapping(password_mapping, expected_headers=expected_headers)
117+
target_wm.add_mapping_with_default_placeholders(
118+
multi_chunk_request_mapping, expected_headers
119+
)
120+
target_wm.add_mapping(disconnect_mapping, expected_headers=expected_headers)
121+
target_wm.add_mapping(telemetry_mapping, expected_headers=expected_headers)
122+
target_wm.add_mapping_with_default_placeholders(chunk_1_mapping, expected_headers)
123+
target_wm.add_mapping_with_default_placeholders(chunk_2_mapping, expected_headers)
124+
125+
set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars
126+
connect_kwargs = {
127+
"user": "testUser",
128+
"password": "testPassword",
129+
"account": "testAccount",
130+
"host": target_wm.wiremock_host,
131+
"port": target_wm.wiremock_http_port,
132+
"protocol": "http",
133+
"warehouse": "TEST_WH",
134+
}
135+
136+
if proxy_method == "explicit_args":
137+
connect_kwargs.update(
138+
{
139+
"proxy_host": proxy_wm.wiremock_host,
140+
"proxy_port": str(proxy_wm.wiremock_http_port),
141+
"proxy_user": "proxyUser",
142+
"proxy_password": "proxyPass",
143+
}
144+
)
145+
clear_proxy_env_vars()
146+
else:
147+
proxy_url = f"http://proxyUser:proxyPass@{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}"
148+
set_proxy_env_vars(proxy_url)
149+
150+
row_count = 50_000
151+
conn = await async_connect(**connect_kwargs)
152+
try:
153+
cur = conn.cursor()
154+
await cur.execute(
155+
f"select seq4() as n from table(generator(rowcount => {row_count}));"
156+
)
157+
assert len(cur._result_set.batches) > 1
158+
_ = [r async for r in cur]
159+
finally:
160+
await conn.close()
161+
162+
async with aiohttp.ClientSession() as session:
163+
async with session.get(
164+
f"{proxy_wm.http_host_with_port}/__admin/requests"
165+
) as resp:
166+
proxy_reqs = await resp.json()
167+
assert any(
168+
"/queries/v1/query-request" in r["request"]["url"]
169+
for r in proxy_reqs["requests"]
170+
)
171+
172+
async with session.get(
173+
f"{target_wm.http_host_with_port}/__admin/requests"
174+
) as resp:
175+
target_reqs = await resp.json()
176+
assert any(
177+
"/queries/v1/query-request" in r["request"]["url"]
178+
for r in target_reqs["requests"]
179+
)

0 commit comments

Comments
 (0)