diff --git a/test/test_utils/wiremock/wiremock_utils.py b/test/test_utils/wiremock/wiremock_utils.py index 7a7b28bdc..0ec1359df 100644 --- a/test/test_utils/wiremock/wiremock_utils.py +++ b/test/test_utils/wiremock/wiremock_utils.py @@ -257,6 +257,29 @@ def add_mapping( if response.status_code != requests.codes.created: raise RuntimeError("Failed to add mapping") + def get_requests(self) -> dict: + """Get all requests seen by this wiremock instance. + + Returns: + dict: JSON response from wiremock's /__admin/requests endpoint + """ + return requests.get(f"{self.http_host_with_port}/__admin/requests").json() + + def saw_urls_matching(self, patterns: list[str]) -> bool: + """Check if this wiremock instance saw any requests matching the given URL patterns. + + Args: + patterns: List of string patterns to search for in request URLs + + Returns: + bool: True if any request URL contains any of the patterns + """ + reqs = self.get_requests() + return any( + any(pattern in r["request"]["url"] for pattern in patterns) + for r in reqs["requests"] + ) + def _find_free_port(self, forbidden_ports: Union[List[int], None] = None) -> int: max_retries = 1 if forbidden_ports is None else 3 if forbidden_ports is None: @@ -287,6 +310,66 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._stop_wiremock() +@contextmanager +def get_configured_proxy_client( + target_host_with_port: str, + proxy_mapping_template: Union[str, dict, pathlib.Path, None] = None, + additional_proxy_placeholders: Optional[dict[str, object]] = None, + forbidden_ports: Optional[List[int]] = None, + additional_proxy_args: Optional[Iterable[str]] = None, +): + """Context manager that starts and configures a proxy wiremock to forward to a target. + + Parameters + ---------- + target_host_with_port + The target URL (e.g., 'http://localhost:8080') that the proxy should forward to. + proxy_mapping_template + Mapping JSON (str / dict / pathlib.Path) to be used for configuring the proxy. 
+ If *None*, the default proxy_forward_all.json template is used. + additional_proxy_placeholders + Optional placeholders to be replaced in the proxy mapping *in addition* to + ``{{TARGET_HTTP_HOST_WITH_PORT}}``. + forbidden_ports + List of ports that the proxy should avoid binding to. + additional_proxy_args + Extra command-line arguments passed to the proxy Wiremock instance. + + Yields + ------ + WiremockClient + A configured proxy wiremock instance. + """ + # Resolve default mapping template if none provided + if proxy_mapping_template is None: + proxy_mapping_template = ( + pathlib.Path(__file__).parent.parent.parent.parent + / "test" + / "data" + / "wiremock" + / "mappings" + / "generic" + / "proxy_forward_all.json" + ) + + # Start the *proxy* Wiremock + with WiremockClient( + forbidden_ports=forbidden_ports or [], + additional_wiremock_process_args=additional_proxy_args, + ) as proxy_wm: + # Prepare placeholders so that proxy forwards to the target + placeholders: dict[str, object] = { + "{{TARGET_HTTP_HOST_WITH_PORT}}": target_host_with_port + } + if additional_proxy_placeholders: + placeholders.update(additional_proxy_placeholders) + + # Configure proxy Wiremock to forward everything to target + proxy_wm.add_mapping(proxy_mapping_template, placeholders=placeholders) + + yield proxy_wm + + @contextmanager def get_clients_for_proxy_and_target( proxy_mapping_template: Union[str, dict, pathlib.Path, None] = None, @@ -313,36 +396,16 @@ def get_clients_for_proxy_and_target( Extra command-line arguments passed to the proxy Wiremock instance when it is launched. Useful for tweaking Wiremock behaviour in specific tests. 
""" - - # Resolve default mapping template if none provided - if proxy_mapping_template is None: - proxy_mapping_template = ( - pathlib.Path(__file__).parent.parent.parent.parent - / "test" - / "data" - / "wiremock" - / "mappings" - / "generic" - / "proxy_forward_all.json" - ) - # Start the *target* Wiremock first – this will emulate Snowflake / IdP backend with WiremockClient() as target_wm: - # Then start the *proxy* Wiremock and ensure it doesn't try to bind the same port - with WiremockClient( + # Start and configure proxy using extracted helper + with get_configured_proxy_client( + target_host_with_port=target_wm.http_host_with_port, + proxy_mapping_template=proxy_mapping_template, + additional_proxy_placeholders=additional_proxy_placeholders, forbidden_ports=[target_wm.wiremock_http_port], - additional_wiremock_process_args=additional_proxy_args, + additional_proxy_args=additional_proxy_args, ) as proxy_wm: - # Prepare placeholders so that proxy forwards to the *target* - placeholders: dict[str, object] = { - "{{TARGET_HTTP_HOST_WITH_PORT}}": target_wm.http_host_with_port - } - if additional_proxy_placeholders: - placeholders.update(additional_proxy_placeholders) - - # Configure proxy Wiremock to forward everything to target - proxy_wm.add_mapping(proxy_mapping_template, placeholders=placeholders) - # Yield control back to the caller with both Wiremocks ready yield target_wm, proxy_wm diff --git a/test/unit/test_proxies.py b/test/unit/test_proxies.py index 7bbcaf5d9..7124a250c 100644 --- a/test/unit/test_proxies.py +++ b/test/unit/test_proxies.py @@ -10,7 +10,6 @@ import pytest import snowflake.connector -import snowflake.connector.vendored.requests as requests from snowflake.connector.compat import urlparse as compat_urlparse from snowflake.connector.errors import OperationalError from snowflake.connector.session_manager import SessionManager @@ -152,20 +151,10 @@ def test_basic_query_through_proxy( cnx.close() # Ensure proxy saw query - proxy_reqs = 
requests.get(f"{proxy_wm.http_host_with_port}/__admin/requests").json() - assert any( - "/queries/v1/query-request" in r["request"]["url"] - for r in proxy_reqs["requests"] - ) + assert proxy_wm.saw_urls_matching(["/queries/v1/query-request"]) # Ensure backend saw query - target_reqs = requests.get( - f"{target_wm.http_host_with_port}/__admin/requests" - ).json() - assert any( - "/queries/v1/query-request" in r["request"]["url"] - for r in target_reqs["requests"] - ) + assert target_wm.saw_urls_matching(["/queries/v1/query-request"]) @pytest.mark.skipolddriver @@ -350,13 +339,15 @@ def _set_mappings_for_query_and_chunks( def _execute_large_query(connect_kwargs, row_count: int): + """Execute a large query using connection kwargs. + + Creates a connection, executes the large query, and validates it uses multiple batches. + """ with snowflake.connector.connect(**connect_kwargs) as conn: - cursors = conn.execute_string( - f"select seq4() as n from table(generator(rowcount => {row_count}));" - ) - assert len(cursors[0]._result_set.batches) > 1 - rs = list(cursors[0]) - assert rs + with conn.cursor() as cur: + _execute_large_query_on_cursor(cur, row_count) + # Verify that the query used multiple batches (remote storage) + assert len(cur._result_set.batches) > 1 @pytest.fixture @@ -387,30 +378,15 @@ class RequestFlags(NamedTuple): def _collect_request_flags(proxy_wm, target_wm, storage_wm) -> RequestFlags: - proxy_reqs = requests.get(f"{proxy_wm.http_host_with_port}/__admin/requests").json() - target_reqs = requests.get( - f"{target_wm.http_host_with_port}/__admin/requests" - ).json() - storage_reqs = requests.get( - f"{storage_wm.http_host_with_port}/__admin/requests" - ).json() - - proxy_saw_db = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in proxy_reqs["requests"] - ) - target_saw_db = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in target_reqs["requests"] - ) - proxy_saw_storage = any( - 
"/amazonaws/test/s3testaccount/stage/results/" in r["request"]["url"] - for r in proxy_reqs["requests"] - ) - storage_saw_storage = any( - "/amazonaws/test/s3testaccount/stage/results/" in r["request"]["url"] - for r in storage_reqs["requests"] + proxy_saw_db = proxy_wm.saw_urls_matching(["/queries/v1/query-request"]) + target_saw_db = target_wm.saw_urls_matching(["/queries/v1/query-request"]) + proxy_saw_storage = proxy_wm.saw_urls_matching( + ["/amazonaws/test/s3testaccount/stage/results/"] + ) + storage_saw_storage = storage_wm.saw_urls_matching( + ["/amazonaws/test/s3testaccount/stage/results/"] ) + return RequestFlags( proxy_saw_db=proxy_saw_db, target_saw_db=target_saw_db, @@ -425,21 +401,16 @@ class DbRequestFlags(NamedTuple): def _collect_db_request_flags_only(proxy_wm, target_wm) -> DbRequestFlags: - proxy_reqs = requests.get(f"{proxy_wm.http_host_with_port}/__admin/requests").json() - target_reqs = requests.get( - f"{target_wm.http_host_with_port}/__admin/requests" - ).json() - proxy_saw_db = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in proxy_reqs["requests"] - ) - target_saw_db = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in target_reqs["requests"] - ) + proxy_saw_db = proxy_wm.saw_urls_matching(["/queries/v1/query-request"]) + target_saw_db = target_wm.saw_urls_matching(["/queries/v1/query-request"]) return DbRequestFlags(proxy_saw_db=proxy_saw_db, target_saw_db=target_saw_db) +def _execute_large_query_on_cursor(cursor, row_count: int = 100000): + cursor.execute(f"SELECT seq4() as n FROM TABLE(GENERATOR(ROWCOUNT => {row_count}))") + return cursor.fetchall() + + class ProxyPrecedenceFlags(NamedTuple): proxy1_saw_request: bool proxy2_saw_request: bool @@ -449,29 +420,14 @@ class ProxyPrecedenceFlags(NamedTuple): def _collect_proxy_precedence_flags( proxy1_wm, proxy2_wm, target_wm ) -> ProxyPrecedenceFlags: - """Collect flags for proxy precedence tests to see which proxy was used.""" - proxy1_reqs = 
requests.get( - f"{proxy1_wm.http_host_with_port}/__admin/requests" - ).json() - proxy2_reqs = requests.get( - f"{proxy2_wm.http_host_with_port}/__admin/requests" - ).json() - target_reqs = requests.get( - f"{target_wm.http_host_with_port}/__admin/requests" - ).json() - - proxy1_saw_request = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in proxy1_reqs["requests"] - ) - proxy2_saw_request = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in proxy2_reqs["requests"] - ) - backend_saw_request = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in target_reqs["requests"] - ) + """Collect flags for proxy precedence tests. + + Checks which proxy (or target) saw query requests, useful for verifying + that connection parameters take precedence over environment variables. + """ + proxy1_saw_request = proxy1_wm.saw_urls_matching(["/queries/v1/query-request"]) + proxy2_saw_request = proxy2_wm.saw_urls_matching(["/queries/v1/query-request"]) + backend_saw_request = target_wm.saw_urls_matching(["/queries/v1/query-request"]) return ProxyPrecedenceFlags( proxy1_saw_request=proxy1_saw_request,