Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 89 additions & 26 deletions test/test_utils/wiremock/wiremock_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,29 @@ def add_mapping(
if response.status_code != requests.codes.created:
raise RuntimeError("Failed to add mapping")

def get_requests(self) -> dict:
    """Get all requests seen by this wiremock instance.

    Returns:
        dict: JSON response from wiremock's /__admin/requests endpoint

    Raises:
        RuntimeError: If the admin endpoint does not answer with HTTP 200.
    """
    response = requests.get(f"{self.http_host_with_port}/__admin/requests")
    # Mirror add_mapping: fail with a clear error on an unexpected status
    # instead of letting .json() choke on an error page.
    if response.status_code != requests.codes.ok:
        raise RuntimeError("Failed to get requests")
    return response.json()

def saw_urls_matching(self, patterns: Union[str, list[str]]) -> bool:
    """Check if this wiremock instance saw any requests matching the given URL patterns.

    Args:
        patterns: List of string patterns to search for in request URLs.
            A single string is also accepted and treated as a one-element list.

    Returns:
        bool: True if any request URL contains any of the patterns
    """
    # A bare string would otherwise be iterated character by character, so
    # nearly every URL would "match"; materialising also protects against
    # one-shot iterables being exhausted after the first request.
    needles = [patterns] if isinstance(patterns, str) else list(patterns)
    seen_urls = [
        entry["request"]["url"] for entry in self.get_requests()["requests"]
    ]
    return any(needle in url for url in seen_urls for needle in needles)

def _find_free_port(self, forbidden_ports: Union[List[int], None] = None) -> int:
max_retries = 1 if forbidden_ports is None else 3
if forbidden_ports is None:
Expand Down Expand Up @@ -287,6 +310,66 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._stop_wiremock()


@contextmanager
def get_configured_proxy_client(
    target_host_with_port: str,
    proxy_mapping_template: Union[str, dict, pathlib.Path, None] = None,
    additional_proxy_placeholders: Optional[dict[str, object]] = None,
    forbidden_ports: Optional[List[int]] = None,
    additional_proxy_args: Optional[Iterable[str]] = None,
):
    """Start a proxy Wiremock, point it at a target, and yield it.

    Parameters
    ----------
    target_host_with_port
        URL of the target (e.g., 'http://localhost:8080'); it is substituted for
        ``{{TARGET_HTTP_HOST_WITH_PORT}}`` in the proxy mapping.
    proxy_mapping_template
        Mapping (str / dict / pathlib.Path) registered on the proxy. When *None*,
        ``test/data/wiremock/mappings/generic/proxy_forward_all.json`` is used.
    additional_proxy_placeholders
        Extra placeholder substitutions merged on top of the target placeholder.
    forbidden_ports
        Ports passed to the proxy ``WiremockClient`` as ``forbidden_ports``
        (an empty list when omitted).
    additional_proxy_args
        Passed to the proxy ``WiremockClient`` as
        ``additional_wiremock_process_args``.

    Yields
    ------
    WiremockClient
        The proxy instance with the mapping already added.
    """
    if proxy_mapping_template is not None:
        mapping_template = proxy_mapping_template
    else:
        # Default mapping lives four directory levels above this file.
        base_dir = pathlib.Path(__file__).parent.parent.parent.parent
        mapping_template = base_dir.joinpath(
            "test", "data", "wiremock", "mappings", "generic", "proxy_forward_all.json"
        )

    proxy_ctx = WiremockClient(
        forbidden_ports=forbidden_ports or [],
        additional_wiremock_process_args=additional_proxy_args,
    )
    with proxy_ctx as proxy_wm:
        # The target placeholder is always present; caller-supplied entries
        # are applied afterwards and may therefore override it.
        substitutions: dict[str, object] = {}
        substitutions["{{TARGET_HTTP_HOST_WITH_PORT}}"] = target_host_with_port
        if additional_proxy_placeholders:
            substitutions.update(additional_proxy_placeholders)

        proxy_wm.add_mapping(mapping_template, placeholders=substitutions)

        yield proxy_wm


@contextmanager
def get_clients_for_proxy_and_target(
proxy_mapping_template: Union[str, dict, pathlib.Path, None] = None,
Expand All @@ -313,36 +396,16 @@ def get_clients_for_proxy_and_target(
Extra command-line arguments passed to the proxy Wiremock instance when it is
launched. Useful for tweaking Wiremock behaviour in specific tests.
"""

# Resolve default mapping template if none provided
if proxy_mapping_template is None:
proxy_mapping_template = (
pathlib.Path(__file__).parent.parent.parent.parent
/ "test"
/ "data"
/ "wiremock"
/ "mappings"
/ "generic"
/ "proxy_forward_all.json"
)

# Start the *target* Wiremock first – this will emulate Snowflake / IdP backend
with WiremockClient() as target_wm:
# Then start the *proxy* Wiremock and ensure it doesn't try to bind the same port
with WiremockClient(
# Start and configure proxy using extracted helper
with get_configured_proxy_client(
target_host_with_port=target_wm.http_host_with_port,
proxy_mapping_template=proxy_mapping_template,
additional_proxy_placeholders=additional_proxy_placeholders,
forbidden_ports=[target_wm.wiremock_http_port],
additional_wiremock_process_args=additional_proxy_args,
additional_proxy_args=additional_proxy_args,
) as proxy_wm:
# Prepare placeholders so that proxy forwards to the *target*
placeholders: dict[str, object] = {
"{{TARGET_HTTP_HOST_WITH_PORT}}": target_wm.http_host_with_port
}
if additional_proxy_placeholders:
placeholders.update(additional_proxy_placeholders)

# Configure proxy Wiremock to forward everything to target
proxy_wm.add_mapping(proxy_mapping_template, placeholders=placeholders)

# Yield control back to the caller with both Wiremocks ready
yield target_wm, proxy_wm

Expand Down
110 changes: 33 additions & 77 deletions test/unit/test_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pytest

import snowflake.connector
import snowflake.connector.vendored.requests as requests
from snowflake.connector.compat import urlparse as compat_urlparse
from snowflake.connector.errors import OperationalError
from snowflake.connector.session_manager import SessionManager
Expand Down Expand Up @@ -152,20 +151,10 @@ def test_basic_query_through_proxy(
cnx.close()

# Ensure proxy saw query
proxy_reqs = requests.get(f"{proxy_wm.http_host_with_port}/__admin/requests").json()
assert any(
"/queries/v1/query-request" in r["request"]["url"]
for r in proxy_reqs["requests"]
)
assert proxy_wm.saw_urls_matching(["/queries/v1/query-request"])

# Ensure backend saw query
target_reqs = requests.get(
f"{target_wm.http_host_with_port}/__admin/requests"
).json()
assert any(
"/queries/v1/query-request" in r["request"]["url"]
for r in target_reqs["requests"]
)
assert target_wm.saw_urls_matching(["/queries/v1/query-request"])


@pytest.mark.skipolddriver
Expand Down Expand Up @@ -350,13 +339,15 @@ def _set_mappings_for_query_and_chunks(


def _execute_large_query(connect_kwargs, row_count: int):
    """Execute a large query using connection kwargs.

    Creates a connection, executes the large query once on a cursor, and
    validates that it returned rows and used multiple batches.
    """
    with snowflake.connector.connect(**connect_kwargs) as conn:
        with conn.cursor() as cur:
            # Run the query a single time; the stale execute_string path
            # executed the same statement a second time for no benefit.
            rows = _execute_large_query_on_cursor(cur, row_count)
            assert rows
            # Verify that the query used multiple batches (remote storage)
            assert len(cur._result_set.batches) > 1


@pytest.fixture
Expand Down Expand Up @@ -387,30 +378,15 @@ class RequestFlags(NamedTuple):


def _collect_request_flags(proxy_wm, target_wm, storage_wm) -> RequestFlags:
proxy_reqs = requests.get(f"{proxy_wm.http_host_with_port}/__admin/requests").json()
target_reqs = requests.get(
f"{target_wm.http_host_with_port}/__admin/requests"
).json()
storage_reqs = requests.get(
f"{storage_wm.http_host_with_port}/__admin/requests"
).json()

proxy_saw_db = any(
"/queries/v1/query-request" in r["request"]["url"]
for r in proxy_reqs["requests"]
)
target_saw_db = any(
"/queries/v1/query-request" in r["request"]["url"]
for r in target_reqs["requests"]
)
proxy_saw_storage = any(
"/amazonaws/test/s3testaccount/stage/results/" in r["request"]["url"]
for r in proxy_reqs["requests"]
)
storage_saw_storage = any(
"/amazonaws/test/s3testaccount/stage/results/" in r["request"]["url"]
for r in storage_reqs["requests"]
proxy_saw_db = proxy_wm.saw_urls_matching(["/queries/v1/query-request"])
target_saw_db = target_wm.saw_urls_matching(["/queries/v1/query-request"])
proxy_saw_storage = proxy_wm.saw_urls_matching(
["/amazonaws/test/s3testaccount/stage/results/"]
)
storage_saw_storage = storage_wm.saw_urls_matching(
["/amazonaws/test/s3testaccount/stage/results/"]
)

return RequestFlags(
proxy_saw_db=proxy_saw_db,
target_saw_db=target_saw_db,
Expand All @@ -425,21 +401,16 @@ class DbRequestFlags(NamedTuple):


def _collect_db_request_flags_only(proxy_wm, target_wm) -> DbRequestFlags:
    """Report whether the proxy and the target each saw a query request.

    Uses each Wiremock's ``saw_urls_matching`` helper; the earlier direct
    fetches of ``/__admin/requests`` produced values that were immediately
    overwritten, so they are not performed here.
    """
    query_patterns = ["/queries/v1/query-request"]
    return DbRequestFlags(
        proxy_saw_db=proxy_wm.saw_urls_matching(query_patterns),
        target_saw_db=target_wm.saw_urls_matching(query_patterns),
    )


def _execute_large_query_on_cursor(cursor, row_count: int = 100000):
cursor.execute(f"SELECT seq4() as n FROM TABLE(GENERATOR(ROWCOUNT => {row_count}))")
return cursor.fetchall()


class ProxyPrecedenceFlags(NamedTuple):
proxy1_saw_request: bool
proxy2_saw_request: bool
Expand All @@ -449,29 +420,14 @@ class ProxyPrecedenceFlags(NamedTuple):
def _collect_proxy_precedence_flags(
proxy1_wm, proxy2_wm, target_wm
) -> ProxyPrecedenceFlags:
"""Collect flags for proxy precedence tests to see which proxy was used."""
proxy1_reqs = requests.get(
f"{proxy1_wm.http_host_with_port}/__admin/requests"
).json()
proxy2_reqs = requests.get(
f"{proxy2_wm.http_host_with_port}/__admin/requests"
).json()
target_reqs = requests.get(
f"{target_wm.http_host_with_port}/__admin/requests"
).json()

proxy1_saw_request = any(
"/queries/v1/query-request" in r["request"]["url"]
for r in proxy1_reqs["requests"]
)
proxy2_saw_request = any(
"/queries/v1/query-request" in r["request"]["url"]
for r in proxy2_reqs["requests"]
)
backend_saw_request = any(
"/queries/v1/query-request" in r["request"]["url"]
for r in target_reqs["requests"]
)
"""Collect flags for proxy precedence tests.

Checks which proxy (or target) saw query requests, useful for verifying
that connection parameters take precedence over environment variables.
"""
proxy1_saw_request = proxy1_wm.saw_urls_matching(["/queries/v1/query-request"])
proxy2_saw_request = proxy2_wm.saw_urls_matching(["/queries/v1/query-request"])
backend_saw_request = target_wm.saw_urls_matching(["/queries/v1/query-request"])

return ProxyPrecedenceFlags(
proxy1_saw_request=proxy1_saw_request,
Expand Down
Loading