diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b62b8b868c..eb5689f01e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,11 @@ repos: src/snowflake/connector/nanoarrow_cpp/ArrowIterator/flatcc/.*\.h| )$ - id: check-yaml - exclude: .github/repo_meta.yaml + exclude: > + (?x)^( + .github/repo_meta.yaml| + ci/anaconda/recipe/meta.yaml| + )$ - id: debug-statements - id: check-ast - repo: https://github.com/asottile/yesqa diff --git a/ci/anaconda/bld.bat b/ci/anaconda/bld.bat deleted file mode 100644 index 5a5aeeb48b..0000000000 --- a/ci/anaconda/bld.bat +++ /dev/null @@ -1 +0,0 @@ -$PYTHON setup.py install diff --git a/ci/anaconda/build.sh b/ci/anaconda/build.sh deleted file mode 100644 index a6609066d9..0000000000 --- a/ci/anaconda/build.sh +++ /dev/null @@ -1 +0,0 @@ -$PYTHON setup.py install --single-version-externally-managed --record=record.txt diff --git a/ci/anaconda/conda_build.sh b/ci/anaconda/conda_build.sh new file mode 100755 index 0000000000..275c30c6fe --- /dev/null +++ b/ci/anaconda/conda_build.sh @@ -0,0 +1,9 @@ +# Check https://snow-external.slack.com/archives/C02D68R4D0D/p1678899446863299 for context about using --numpy +conda install conda-build +conda install conda-verify +conda install diffutils +conda build ci/anaconda/recipe/ --python 3.9 +conda build ci/anaconda/recipe/ --python 3.10 +conda build ci/anaconda/recipe/ --python 3.11 +conda build ci/anaconda/recipe/ --python 3.12 +conda build ci/anaconda/recipe/ --python 3.13 diff --git a/ci/anaconda/meta.yaml b/ci/anaconda/meta.yaml deleted file mode 100644 index 09f7648495..0000000000 --- a/ci/anaconda/meta.yaml +++ /dev/null @@ -1,29 +0,0 @@ -package: - name: snowflake_connector_python - version: "1.2.3" - -source: - path: /tmp/anaconda_workspace/src - -requirements: - build: - - python - - setuptools - - run: - - python - - boto3 ==1.3.1 - - botocore ==1.4.26 - - future - - six - - pytz - - pycrypto ==2.6.1 - - pyopenssl ==0.15.1 - - cryptography ==1.2.3 - - cffi ==1.6.0 - -about: - home: https://www.snowflake.com/ - license: Apache 2.0 - license_file: /tmp/anaconda_workspace/src/LICENSE.txt - summary: Snowflake Connector for Python diff --git a/ci/anaconda/package_builder.sh b/ci/anaconda/package_builder.sh new file mode 100755 index 0000000000..dab6a63507 --- /dev/null +++ b/ci/anaconda/package_builder.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +export SNOWFLAKE_CONNECTOR_PYTHON_DIR=/repo/snowflake-connector-python +export CONDA_BLD_PATH=/repo/conda-bld + +mkdir -p $CONDA_BLD_PATH +cd "$SNOWFLAKE_CONNECTOR_PYTHON_DIR" +# Build with .tar.bz2 (pkg_format = 1) and .conda (pkg_format = 2). +conda config --set conda_build.pkg_format 1 +bash ./ci/anaconda/conda_build.sh +conda config --set conda_build.pkg_format 2 +bash ./ci/anaconda/conda_build.sh +conda config --remove-key conda_build.pkg_format +conda build purge +cd $CONDA_BLD_PATH +conda index . +chmod -R o+w,g+w $CONDA_BLD_PATH diff --git a/ci/anaconda/recipe/meta.yaml b/ci/anaconda/recipe/meta.yaml new file mode 100644 index 0000000000..43c22d8470 --- /dev/null +++ b/ci/anaconda/recipe/meta.yaml @@ -0,0 +1,81 @@ +{% set name = "snowflake-connector-python" %} +{% set version = os.environ.get('SNOWFLAKE_CONNECTOR_PYTHON_VERSION', 0) %} +{% set build_number = os.environ.get('PUBLIC_CONNECTOR_BUILD_NUMBER', 0) %} + +package: + name: {{ name|lower }} + version: {{ version }} + +source: + path: ../../.. 
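+  # illustrative note: the relative path resolves to the repository root, three levels up from ci/anaconda/recipe/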
+
+build:
+  number: {{ build_number }}
+  string: "py{{ py }}_{{ build_number }}"
+  script:
+    - export SF_NO_COPY_ARROW_LIB=1  # [unix]
+    - export SF_ARROW_LIBDIR="${PREFIX}/lib"  # [unix]
+    - export ENABLE_EXT_MODULES=true  # [unix]
+    - {{ PYTHON }} -m pip install . --no-use-pep517 --no-deps -vvv
+  entry_points:
+    - snowflake-dump-ocsp-response = snowflake.connector.tool.dump_ocsp_response:main
+    - snowflake-dump-ocsp-response-cache = snowflake.connector.tool.dump_ocsp_response_cache:main
+    - snowflake-dump-certs = snowflake.connector.tool.dump_certs:main
+
+requirements:
+  build:
+    - {{ compiler("c") }}
+    - {{ compiler("cxx") }}
+    - libgcc-ng
+    - libstdcxx-ng
+    - patch  # [not win]
+  host:
+    - setuptools >=40.6.0
+    - wheel
+    - cython
+    - python {{ python }}
+  run:
+    {% if py == 39 %}
+    - python >=3.9,<3.10.0a0
+    {% elif py == 310 %}
+    - python >=3.10,<3.11.0a0
+    {% elif py == 311 %}
+    - python >=3.11,<3.12.0a0
+    {% elif py == 312 %}
+    - python >=3.12,<3.13.0a0
+    {% elif py == 313 %}
+    - python >=3.13,<3.14.0a0
+    {% else %}
+    - python
+    {% endif %}
+    - asn1crypto >0.24.0,<2.0.0
+    - cryptography >=44.0.1
+    - pyOpenSSL >=24.0.0,<26.0.0
+    - pyjwt >=2.10.1,<3.0.0
+    - pytz
+    - requests >=2.32.4,<3.0.0
+    - packaging
+    - charset-normalizer >=2,<4
+    - idna >=3.7,<4
+    - urllib3 >=1.26.5,<2.0.0  # [py<310]
+    - certifi >=2024.7.4
+    - typing_extensions >=4.3,<5
+    - filelock >=3.5,<4
+    - sortedcontainers >=2.4.0
+    - platformdirs >=2.6.0,<5.0.0
+    - tomlkit
+    - boto3 >=1.24
+    - botocore >=1.24
+test:
+  requires:
+    - pip
+  imports:
+    - snowflake
+    - snowflake.connector
+    - snowflake.connector.nanoarrow_arrow_iterator  # [unix]
+  commands:
+    - pip check
+
+about:
+  home: https://github.com/snowflakedb/snowflake-connector-python
+  summary: Snowflake Connector for Python
diff --git a/ci/anaconda/run.sh b/ci/anaconda/run.sh
new file mode 100755
index 0000000000..f13fac9d6f
--- /dev/null
+++ b/ci/anaconda/run.sh
@@ -0,0 +1,106 @@
+
+# Before running manually, do something similar to the following on the command line:
+# export WORKSPACE=/home/jdoe/my_workspace;
+# export PUBLIC_CONNECTOR_BUILD_NUMBER=321;
+# export aarch64_base_image=;
+# export x86_base_image=;
+
+# Here miniconda-install.sh is just an installer downloaded from the Anaconda official site,
+# https://repo.anaconda.com/miniconda/
+
+
+if [[ -z $WORKSPACE ]]; then
+    # Development on dev machine
+    WORKSPACE=$HOME
+fi
+
+# ===== Build docker image =====
+cd $WORKSPACE
+
+# Validate dependency sync before building
+python3 -m venv tmp_validate_env
+source tmp_validate_env/bin/activate
+pip install pyyaml
+python3 $WORKSPACE/snowflake-connector-python/ci/anaconda/validate_deps_sync.py
+if [[ $? -ne 0 ]]; then
+    echo "[FAILURE] setup.cfg and meta.yaml dependencies are not in sync"
+    deactivate
+    rm -rf tmp_validate_env
+    exit 1
+fi
+deactivate
+rm -rf tmp_validate_env
+
+docker build \
+    --build-arg ARCH=$(uname -m) \
+    --build-arg AARCH64_BASE_IMAGE="${aarch64_base_image}" \
+    --build-arg X86_BASE_IMAGE="${x86_base_image}" \
+    -t snowflake_connector_python_image \
+    -f - . 
<<'DOCKERFILE'
+# Use different base images based on target platform
+
+ARG ARCH
+ARG AARCH64_BASE_IMAGE=artifactory.int.snowflakecomputing.com/development-docker-virtual/arm64v8/centos:8
+ARG X86_BASE_IMAGE=artifactory.int.snowflakecomputing.com/development-docker-virtual/centos:8
+
+FROM ${AARCH64_BASE_IMAGE} AS base-aarch64
+
+FROM ${X86_BASE_IMAGE} AS base-x86_64
+
+
+# Select the appropriate base image based on target architecture
+
+FROM base-${ARCH} AS base
+
+COPY miniconda-install.sh .
+
+
+RUN chmod 0755 miniconda-install.sh
+
+
+RUN mkdir -p /etc/miniconda && bash miniconda-install.sh -b -u -p /etc/miniconda/
+
+
+RUN ln -s /etc/miniconda/bin/conda /usr/bin/conda && rm miniconda-install.sh
+DOCKERFILE
+
+# Go back to the original directory
+cd $WORKSPACE
+
+
+# Check to make sure repos exist to build conda packages
+if [[ -d $WORKSPACE/snowflake-connector-python ]]; then
+    echo "Check snowflake-connector-python repo exists - PASSED"
+else
+    echo "[FAILURE] Please clone snowflake-connector-python repo at $WORKSPACE/snowflake-connector-python"
+    exit 1
+fi
+
+# Extract connector version if not provided
+if [[ -z "$SNOWFLAKE_CONNECTOR_PYTHON_VERSION" ]]; then
+    VERSION_FILE="$WORKSPACE/snowflake-connector-python/src/snowflake/connector/version.py"
+    if [[ -f "$VERSION_FILE" ]]; then
+        SNOWFLAKE_CONNECTOR_PYTHON_VERSION=$( \
+            grep -Eo 'VERSION\s*=\s*\([^)]*\)' "$VERSION_FILE" \
+            | grep -Eo '[0-9]+' \
+            | paste -sd '.' - \
+        )
+        export SNOWFLAKE_CONNECTOR_PYTHON_VERSION
+    fi
+fi
+
+# Run packager in docker image
+docker run \
+    -v $WORKSPACE/snowflake-connector-python/:/repo/snowflake-connector-python \
+    -v $WORKSPACE/conda-bld:/repo/conda-bld \
+    -e SNOWFLAKE_CONNECTOR_PYTHON_VERSION=${SNOWFLAKE_CONNECTOR_PYTHON_VERSION} \
+    -e PUBLIC_CONNECTOR_BUILD_NUMBER=${PUBLIC_CONNECTOR_BUILD_NUMBER} \
+    snowflake_connector_python_image \
+    /repo/snowflake-connector-python/ci/anaconda/package_builder.sh
+
+# Cleanup image for disk space
+docker container prune -f
+docker rmi snowflake_connector_python_image
diff --git a/ci/anaconda/validate_deps_sync.py b/ci/anaconda/validate_deps_sync.py
new file mode 100644
index 0000000000..d94c832c3c
--- /dev/null
+++ b/ci/anaconda/validate_deps_sync.py
@@ -0,0 +1,243 @@
+"""Validate conda run requirements match setup.cfg install_requires.
+
+Note that it assumes the default requirement set is install_requires + boto. If that assumption is no longer true,
+this script needs to be updated accordingly.
+
+Exit behavior:
+- If there is no diff: exit 0 with no output.
+- If there is a diff: print the diff and exit 1.
+"""
+
+import configparser
+import re
+import sys
+from pathlib import Path
+from typing import Dict, Iterable, List, Tuple
+
+import yaml
+
+
+def repo_root() -> Path:
+    """Return repository root based on this file location."""
+    return Path(__file__).resolve().parents[2]
+
+
+def normalize_name(name: str) -> str:
+    """Normalize a dependency name to a canonical form.
+
+    Replaces underscores with hyphens and lowercases the name.
+
+    Args:
+        name: Raw package name.
+
+    Returns:
+        Normalized package name.
+    """
+    return name.strip().lower().replace("_", "-")
+
+
+def split_requirement(req: str) -> Tuple[str, str]:
+    """Split a requirement into name and version specifier.
+
+    Drops PEP 508 markers (after ';') and conda selectors (after '#').
+
+    Args:
+        req: A single requirement line.
+
+    Returns:
+        Tuple of (normalized_name, normalized_spec). Spec contains no spaces.
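+
+    Example (illustrative):
+        >>> split_requirement("urllib3 >=1.26.5,<2.0.0  # [py<310]")
+        ('urllib3', '>=1.26.5,<2.0.0')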
+    """
+    # Drop markers and selectors
+    req = req.split(";", 1)[0]
+    req = req.split("#", 1)[0]
+    req = req.strip()
+    if not req:
+        return "", ""
+
+    # Find first comparator
+    m = re.search(r"(<=|>=|==|!=|~=|<|>|=)", req)
+    if m:
+        name = req[: m.start()].strip()
+        spec = req[m.start() :].strip()
+    else:
+        # No version specified
+        parts = req.split()
+        name = parts[0] if parts else ""
+        spec = ""
+
+    # Normalize
+    spec = re.sub(r"\s*,\s*", ",", spec)
+    spec = re.sub(r"\s+", "", spec)
+    return normalize_name(name), spec
+
+
+def get_setup_install_requires(cfg_path: Path) -> List[str]:
+    """Extract normalized install_requires entries from setup.cfg.
+
+    Args:
+        cfg_path: Path to setup.cfg.
+
+    Returns:
+        List of strings in the form "<name> <spec>" where spec may be empty.
+    """
+    parser = configparser.ConfigParser(interpolation=None)
+    parser.read(cfg_path, encoding="utf-8")
+    if not parser.has_section("options"):
+        raise RuntimeError(f"Missing [options] section in {cfg_path}")
+    if not parser.has_option("options", "install_requires"):
+        raise RuntimeError(f"Missing install_requires under [options] in {cfg_path}")
+    raw_value = parser.get("options", "install_requires")
+    deps: List[str] = []
+    for line in raw_value.splitlines():
+        item = line.strip()
+        if not item or item.startswith("#"):
+            continue
+        name, spec = split_requirement(item)
+        if name and name != "python":
+            deps.append(f"{name} {spec}".strip())
+    return deps
+
+
+def get_meta_run_requirements(meta_path: Path) -> List[str]:
+    """Extract normalized run requirements from meta.yaml.
+
+    Args:
+        meta_path: Path to meta.yaml.
+
+    Returns:
+        List of strings in the form "<name> <spec>" where spec may be empty.
+    """
+    text = meta_path.read_text(encoding="utf-8")
+    cleaned_lines: List[str] = []
+    for line in text.splitlines():
+        if "{%" in line or "%}" in line:
+            continue
+        if "{{" in line and "}}" in line:
+            continue
+        cleaned_lines.append(line)
+    cleaned = "\n".join(cleaned_lines)
+
+    try:
+        data = yaml.safe_load(cleaned) or {}
+    except Exception as exc:
+        raise RuntimeError(f"Failed to parse YAML for {meta_path}") from exc
+
+    reqs = data.get("requirements", {}) or {}
+    run_items = reqs.get("run", []) or []
+
+    deps: List[str] = []
+    for idx, it in enumerate(run_items):
+        if not isinstance(it, str):
+            raise TypeError(
+                f"requirements.run entry at index {idx} in {meta_path} "
+                f"must be a string; got {type(it).__name__}: {it!r}"
+            )
+        name, spec = split_requirement(it)
+        if name and name != "python":
+            deps.append(f"{name} {spec}".strip())
+    return deps
+
+
+def get_setup_extra_requires(cfg_path: Path, extra: str) -> List[str]:
+    """Extract normalized requirements for a given extra from setup.cfg.
+
+    Args:
+        cfg_path: Path to setup.cfg.
+        extra: The extras_require key to extract (e.g., "boto").
+
+    Returns:
+        List of strings in the form "<name> <spec>".
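+
+    Example (illustrative; actual entries depend on setup.cfg):
+        >>> get_setup_extra_requires(Path("setup.cfg"), "boto")
+        ['boto3 >=1.24', 'botocore >=1.24']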
+    """
+    parser = configparser.ConfigParser(interpolation=None)
+    parser.read(cfg_path, encoding="utf-8")
+    section = "options.extras_require"
+    if not parser.has_section(section):
+        raise RuntimeError(f"Missing [options.extras_require] section in {cfg_path}")
+    key = extra.strip().lower()
+    if not parser.has_option(section, key):
+        raise RuntimeError(
+            f"Missing extra '{key}' under [options.extras_require] in {cfg_path}"
+        )
+    raw_value = parser.get(section, key)
+    deps: List[str] = []
+    for line in raw_value.splitlines():
+        item = line.strip()
+        if not item or item.startswith("#"):
+            continue
+        name, spec = split_requirement(item)
+        if name and name != "python":
+            deps.append(f"{name} {spec}".strip())
+    return deps
+
+
+def compare_deps(setup_deps: Iterable[str], meta_deps: Iterable[str]) -> str:
+    """Compare two dependency lists and return a human-readable diff.
+
+    Args:
+        setup_deps: Normalized dependencies from setup.cfg.
+        meta_deps: Normalized dependencies from meta.yaml.
+
+    Returns:
+        Empty string if equal, otherwise a multi-line diff description.
+    """
+
+    def to_map(items: Iterable[str]) -> Dict[str, str]:
+        mapping: Dict[str, str] = {}
+        for it in items:
+            parts = it.split(" ", 1)
+            name = parts[0]
+            spec = parts[1] if len(parts) > 1 else ""
+            mapping[name] = spec
+        return mapping
+
+    s_map = to_map(setup_deps)
+    m_map = to_map(meta_deps)
+
+    s_names = set(s_map)
+    m_names = set(m_map)
+    missing = sorted(s_names - m_names)
+    extra = sorted(m_names - s_names)
+
+    mismatches: List[Tuple[str, str, str]] = []
+    for name in sorted(s_names & m_names):
+        if s_map.get(name, "") != m_map.get(name, ""):
+            mismatches.append((name, s_map.get(name, ""), m_map.get(name, "")))
+
+    if not (missing or extra or mismatches):
+        return ""
+
+    lines: List[str] = []
+    if missing:
+        lines.append("Missing in meta.yaml run:")
+        for n in missing:
+            lines.append(f"  - {n} ({s_map[n] or 'no spec'})")
+    if extra:
+        lines.append("Extra in meta.yaml run:")
+        for n in extra:
+            lines.append(f"  - {n} ({m_map[n] or 'no spec'})")
+    if mismatches:
+        lines.append("Version spec mismatches:")
+        for n, s, m in mismatches:
+            lines.append(f"  - {n}: setup.cfg='{s}' vs meta.yaml='{m}'")
+    return "\n".join(lines)
+
+
+def main() -> int:
+    root = repo_root()
+    setup_cfg_path = root / "setup.cfg"
+    setup_deps = get_setup_install_requires(setup_cfg_path)
+    boto_deps = get_setup_extra_requires(setup_cfg_path, "boto")
+    # Make sure to update ci/anaconda/recipe/meta.yaml accordingly whenever the dependency set is updated.
+    expected_deps = setup_deps + boto_deps
+    meta_deps = get_meta_run_requirements(
+        root / "ci" / "anaconda" / "recipe" / "meta.yaml"
+    )
+    diff = compare_deps(expected_deps, meta_deps)
+    if not diff:
+        return 0
+    print(diff)
+    return 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/setup.cfg b/setup.cfg
index 25c3a8dc91..87fde2c44e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -43,6 +43,7 @@ project_urls =
 python_requires = >=3.9
 packages = find_namespace:
 install_requires =
+    # Make sure to update ci/anaconda/recipe/meta.yaml accordingly whenever the dependency set is updated.
# [boto] extension is added by default unless SNOWFLAKE_NO_BOTO variable is set # check setup.py asn1crypto>0.24.0,<2.0.0 diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py index 145e19761b..8bfeca6f2e 100644 --- a/src/snowflake/connector/aio/_connection.py +++ b/src/snowflake/connector/aio/_connection.py @@ -375,6 +375,7 @@ async def __open_connection(self): host=self.host, port=self.port ), redirect_uri=self._oauth_redirect_uri, + uri=self._oauth_socket_uri, scope=self._oauth_scope, pkce_enabled=not self._oauth_disable_pkce, token_cache=( @@ -925,9 +926,8 @@ async def close(self, retry: bool = True) -> None: # close telemetry first, since it needs rest to send remaining data logger.debug("closed") - await self._telemetry.close( - send_on_close=bool(retry and self.telemetry_enabled) - ) + if self.telemetry_enabled: + await self._telemetry.close(retry=retry) if ( await self._all_async_queries_finished() and not self._server_session_keep_alive @@ -1054,6 +1054,15 @@ async def connect(self, **kwargs) -> None: else: self.__config(**self._conn_parameters) + no_proxy_csv_str = ( + ",".join(str(x) for x in self.no_proxy) + if ( + self.no_proxy is not None + and isinstance(self.no_proxy, Iterable) + and not isinstance(self.no_proxy, (str, bytes)) + ) + else self.no_proxy + ) self._http_config: AioHttpConfig = AioHttpConfig( connector_factory=SnowflakeSSLConnectorFactory(), use_pooling=not self.disable_request_pooling, @@ -1063,15 +1072,7 @@ async def connect(self, **kwargs) -> None: proxy_password=self.proxy_password, snowflake_ocsp_mode=self._ocsp_mode(), trust_env=True, # Required for proxy support via environment variables - no_proxy=( - ",".join(str(x) for x in self.no_proxy) - if ( - self.no_proxy is not None - and isinstance(self.no_proxy, Iterable) - and not isinstance(self.no_proxy, (str, bytes)) - ) - else self.no_proxy - ), + no_proxy=no_proxy_csv_str, ) self._session_manager = SessionManagerFactory.get_manager(self._http_config) diff --git a/src/snowflake/connector/aio/_session_manager.py b/src/snowflake/connector/aio/_session_manager.py index 919c8a6956..5eb2df7018 100644 --- a/src/snowflake/connector/aio/_session_manager.py +++ b/src/snowflake/connector/aio/_session_manager.py @@ -38,7 +38,7 @@ from ..session_manager import BaseHttpConfig from ..session_manager import SessionManager as SessionManagerSync from ..session_manager import SessionPool as SessionPoolSync -from ..session_manager import _ConfigDirectAccessMixin +from ..session_manager import _BaseConfigDirectAccessMixin logger = logging.getLogger(__name__) @@ -339,7 +339,7 @@ async def delete( ) -class _AsyncHttpConfigDirectAccessMixin(_ConfigDirectAccessMixin, abc.ABC): +class _AsyncHttpConfigDirectAccessMixin(_BaseConfigDirectAccessMixin, abc.ABC): @property @abc.abstractmethod def config(self) -> AioHttpConfig: ... @@ -543,17 +543,24 @@ def make_session(self, *, url: str | None = None) -> aiohttp.ClientSession: session_manager=self.clone(), snowflake_ocsp_mode=self._cfg.snowflake_ocsp_mode, ) - # We use requests.utils here (in asynch code) to keep the behaviour uniform for synch and asynch code. 
If we wanted each version to depict its http library's behaviour, we could use here: aiohttp.helpers.proxy_bypass(url, proxies={...}) here
-        proxy = (
-            None
-            if should_bypass_proxies(url, no_proxy=self.config.no_proxy)
-            else self.proxy_url
-        )
+
+        proxy_from_conn_params: str | None = None
+        if not aiohttp.helpers.proxies_from_env():
+            # TODO: This is only needed because we want to keep compatibility with the synch driver version.
+            # Otherwise, we could remove that condition and always pass the proxy from connection params to the Session constructor.
+            # But in that case the precedence would be reversed and the connection params would override the env var settings.
+
+            # We use requests.utils here (in asynch code) to keep the behaviour uniform for synch and asynch code. If we wanted each version to reflect its http library's behaviour, we could use aiohttp.helpers.proxy_bypass(url, proxies={...}) here.
+            proxy_from_conn_params = (
+                None
+                if should_bypass_proxies(url, no_proxy=self.config.no_proxy)
+                else self.proxy_url
+            )
 
         # Construct session with base proxy set, request() may override per-URL when bypassing
         return self.SessionWithProxy(
             connector=connector,
             trust_env=self._cfg.trust_env,
-            proxy=proxy,
+            proxy=proxy_from_conn_params,
         )
diff --git a/src/snowflake/connector/aio/_telemetry.py b/src/snowflake/connector/aio/_telemetry.py
index b9b46f2301..29fa7a93dc 100644
--- a/src/snowflake/connector/aio/_telemetry.py
+++ b/src/snowflake/connector/aio/_telemetry.py
@@ -38,7 +38,7 @@ async def add_log_to_batch(self, telemetry_data: TelemetryData) -> None:
         if len(self._log_batch) >= self._flush_size:
             await self.send_batch()
 
-    async def send_batch(self) -> None:
+    async def send_batch(self, retry: bool = False) -> None:
         if self.is_closed:
             raise Exception("Attempted to send batch when TelemetryClient is closed")
         elif not self._enabled:
@@ -69,6 +69,7 @@ async def send_batch(self) -> None:
             method="post",
             client=None,
             timeout=5,
+            _no_retry=not retry,
         )
         if not ret["success"]:
             logger.info(
@@ -89,9 +90,8 @@ async def try_add_log_to_batch(self, telemetry_data: TelemetryData) -> None:
         except Exception:
             logger.warning("Failed to add log to telemetry.", exc_info=True)
 
-    async def close(self, send_on_close: bool = True) -> None:
+    async def close(self, retry: bool = False) -> None:
         if not self.is_closed:
             logger.debug("Closing telemetry client.")
-            if send_on_close:
-                await self.send_batch()
+            await self.send_batch(retry=retry)
         self._rest = None
diff --git a/src/snowflake/connector/aio/auth/_oauth_code.py b/src/snowflake/connector/aio/auth/_oauth_code.py
index ce3b7bacbf..1cfafe1fe7 100644
--- a/src/snowflake/connector/aio/auth/_oauth_code.py
+++ b/src/snowflake/connector/aio/auth/_oauth_code.py
@@ -36,6 +36,7 @@ def __init__(
         external_browser_timeout: int | None = None,
         enable_single_use_refresh_tokens: bool = False,
         connection: SnowflakeConnection | None = None,
+        uri: str | None = None,
         **kwargs,
     ) -> None:
         """Initializes an instance with OAuth authorization code parameters."""
@@ -58,6 +59,7 @@ def __init__(
             external_browser_timeout=external_browser_timeout,
             enable_single_use_refresh_tokens=enable_single_use_refresh_tokens,
             connection=connection,
+            uri=uri,
             **kwargs,
         )
diff --git a/src/snowflake/connector/auth/_auth.py b/src/snowflake/connector/auth/_auth.py
index 51b26c5c8c..d774c03eb0 100644
--- a/src/snowflake/connector/auth/_auth.py
+++ b/src/snowflake/connector/auth/_auth.py
@@ -506,6 +506,16 @@ def read_temporary_credentials(
         user: str,
         session_parameters: dict[str, Any],
     ) -> None:
+        """Attempt to load cached 
credentials to skip interactive authentication. + + SSO (ID_TOKEN): If present, avoids opening browser for external authentication. + Controlled by client_store_temporary_credential parameter. + + MFA (MFA_TOKEN): If present, skips MFA prompt on next connection. + Controlled by client_request_mfa_token parameter. + + If cached tokens are expired/invalid, they're deleted and normal auth proceeds. + """ if session_parameters.get(PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL, False): self._rest.id_token = self._read_temporary_credential( host, @@ -541,6 +551,13 @@ def write_temporary_credentials( session_parameters: dict[str, Any], response: dict[str, Any], ) -> None: + """Cache credentials received from successful authentication for future use. + + Tokens are only cached if: + 1. Server returned the token in response (server-side caching must be enabled) + 2. Client has caching enabled via session parameters + 3. User consented to caching (consent_cache_id_token for ID tokens) + """ if ( self._rest._connection.auth_class.consent_cache_id_token and session_parameters.get( diff --git a/src/snowflake/connector/auth/_http_server.py b/src/snowflake/connector/auth/_http_server.py index a11662f25b..8b7162bfbc 100644 --- a/src/snowflake/connector/auth/_http_server.py +++ b/src/snowflake/connector/auth/_http_server.py @@ -70,8 +70,10 @@ def __init__( self, uri: str, buf_size: int = 16384, + redirect_uri: str | None = None, ) -> None: parsed_uri = urllib.parse.urlparse(uri) + parsed_redirect = urllib.parse.urlparse(redirect_uri) if redirect_uri else None self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.buf_size = buf_size if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true": @@ -82,30 +84,34 @@ def __init__( else: self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - port = parsed_uri.port or 0 + if parsed_redirect and self._is_local_uri(parsed_redirect): + server_port = parsed_redirect.port or 0 + else: + server_port = parsed_uri.port or 0 + for attempt in range(1, self.DEFAULT_MAX_ATTEMPTS + 1): try: self._socket.bind( ( parsed_uri.hostname, - port, + server_port, ) ) break except socket.gaierror as ex: logger.error( - f"Failed to bind authorization callback server to port {port}: {ex}" + f"Failed to bind authorization callback server to port {server_port}: {ex}" ) raise except OSError as ex: if attempt == self.DEFAULT_MAX_ATTEMPTS: logger.error( - f"Failed to bind authorization callback server to port {port}: {ex}" + f"Failed to bind authorization callback server to port {server_port}: {ex}" ) raise logger.warning( f"Attempt {attempt}/{self.DEFAULT_MAX_ATTEMPTS}. 
" - f"Failed to bind authorization callback server to port {port}: {ex}" + f"Failed to bind authorization callback server to port {server_port}: {ex}" ) time.sleep(self.PORT_BIND_TIMEOUT / self.PORT_BIND_MAX_ATTEMPTS) try: @@ -114,16 +120,47 @@ def __init__( logger.error(f"Failed to start listening for auth callback: {ex}") self.close() raise - port = self._socket.getsockname()[1] + + server_port = self._socket.getsockname()[1] self._uri = urllib.parse.ParseResult( scheme=parsed_uri.scheme, - netloc=parsed_uri.hostname + ":" + str(port), + netloc=parsed_uri.hostname + ":" + str(server_port), path=parsed_uri.path, params=parsed_uri.params, query=parsed_uri.query, fragment=parsed_uri.fragment, ) + if parsed_redirect: + if ( + self._is_local_uri(parsed_redirect) + and server_port != parsed_redirect.port + ): + logger.debug( + f"Updating redirect port {parsed_redirect.port} to match the server port {server_port}." + ) + self._redirect_uri = urllib.parse.ParseResult( + scheme=parsed_redirect.scheme, + netloc=parsed_redirect.hostname + ":" + str(server_port), + path=parsed_redirect.path, + params=parsed_redirect.params, + query=parsed_redirect.query, + fragment=parsed_redirect.fragment, + ) + else: + self._redirect_uri = parsed_redirect + else: + # For backwards compatibility + self._redirect_uri = self._uri + + @staticmethod + def _is_local_uri(uri): + return uri.hostname in ("localhost", "127.0.0.1") + + @property + def redirect_uri(self) -> str | None: + return self._redirect_uri.geturl() + @property def url(self) -> str: return self._uri.geturl() diff --git a/src/snowflake/connector/auth/_oauth_base.py b/src/snowflake/connector/auth/_oauth_base.py index 2ff1241638..dc7f853e02 100644 --- a/src/snowflake/connector/auth/_oauth_base.py +++ b/src/snowflake/connector/auth/_oauth_base.py @@ -35,6 +35,14 @@ class _OAuthTokensMixin: + """Manages OAuth token caching to avoid repeated browser authentication flows. + + Access tokens: Short-lived (typically 10 minutes), cached to avoid immediate re-auth. + Refresh tokens: Long-lived (hours/days), used to obtain new access tokens silently. + + Tokens are cached per (user, IDP host) to support multiple OAuth providers/accounts. + """ + def __init__( self, token_cache: TokenCache | None, @@ -77,12 +85,18 @@ def _pop_cached_token(self, key: TokenKey | None) -> str | None: return self._token_cache.retrieve(key) def _pop_cached_access_token(self) -> bool: - """Retrieves OAuth access token from the token cache if enabled""" + """Retrieves OAuth access token from the token cache if enabled, available and still valid. + + Returns True if cached token found, allowing authentication to skip OAuth flow. + """ self._access_token = self._pop_cached_token(self._get_access_token_cache_key()) return self._access_token is not None def _pop_cached_refresh_token(self) -> bool: - """Retrieves OAuth refresh token from the token cache if enabled""" + """Retrieves OAuth refresh token from the token cache (if enabled) to silently obtain new access token. + + Returns True if refresh token found, enabling automatic token renewal without user interaction. 
+        """
         if self._refresh_token_enabled:
             self._refresh_token = self._pop_cached_token(
                 self._get_refresh_token_cache_key()
diff --git a/src/snowflake/connector/auth/oauth_code.py b/src/snowflake/connector/auth/oauth_code.py
index a5aaf31fb9..4596000579 100644
--- a/src/snowflake/connector/auth/oauth_code.py
+++ b/src/snowflake/connector/auth/oauth_code.py
@@ -65,6 +65,7 @@ def __init__(
         external_browser_timeout: int | None = None,
         enable_single_use_refresh_tokens: bool = False,
         connection: SnowflakeConnection | None = None,
+        uri: str | None = None,
         **kwargs,
     ) -> None:
         authentication_url, redirect_uri = self._validate_oauth_code_uris(
@@ -92,6 +93,7 @@ def __init__(
         self._origin: str | None = None
         self._authentication_url = authentication_url
         self._redirect_uri = redirect_uri
+        self._uri = uri
         self._state = secrets.token_urlsafe(43)
         logger.debug("chose oauth state: %s", "".join("*" for _ in self._state))
         self._protocol = "http"
@@ -117,7 +119,10 @@ def _request_tokens(
     ) -> (str | None, str | None):
         """Web Browser based Authentication."""
         logger.debug("authenticating with OAuth authorization code flow")
-        with AuthHttpServer(self._redirect_uri) as callback_server:
+        with AuthHttpServer(
+            redirect_uri=self._redirect_uri,
+            uri=self._uri or self._redirect_uri,  # for backward compatibility
+        ) as callback_server:
             code = self._do_authorization_request(callback_server, conn)
             return self._do_token_request(code, callback_server, conn)
@@ -260,7 +265,7 @@ def _do_authorization_request(
         connection: SnowflakeConnection,
     ) -> str | None:
         authorization_request = self._construct_authorization_request(
-            callback_server.url
+            callback_server.redirect_uri
        )
         logger.debug("step 1: going to open authorization URL")
         print(
@@ -315,7 +320,7 @@ def _do_token_request(
         fields = {
             "grant_type": "authorization_code",
             "code": code,
-            "redirect_uri": callback_server.url,
+            "redirect_uri": callback_server.redirect_uri,
         }
         if self._enable_single_use_refresh_tokens:
             fields["enable_single_use_refresh_tokens"] = "true"
diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py
index 102a6cf3aa..8c6b32dcd0 100644
--- a/src/snowflake/connector/connection.py
+++ b/src/snowflake/connector/connection.py
@@ -279,8 +279,13 @@ def _get_private_bytes_from_file(
     "support_negative_year": (True, bool),  # snowflake
     "log_max_query_length": (LOG_MAX_QUERY_LENGTH, int),  # snowflake
     "disable_request_pooling": (False, bool),  # snowflake
-    # enable temporary credential file for Linux, default false. Mac/Win will overlook this
+    # Cache SSO ID tokens to avoid repeated browser popups. Must be enabled on the server side.
+    # Storage: keyring (macOS/Windows), file (Linux). Auto-enabled on macOS/Windows.
+    # Sets the session parameter PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL as well.
     "client_store_temporary_credential": (False, bool),
+    # Cache MFA tokens to skip MFA prompts on reconnect. Must be enabled on the server side.
+    # Storage: keyring (macOS/Windows), file (Linux). Auto-enabled on macOS/Windows.
+    # In the driver, we extract this from the session via PARAMETER_CLIENT_REQUEST_MFA_TOKEN.
"client_request_mfa_token": (False, bool), "use_openssl_only": ( True, @@ -383,6 +388,11 @@ def _get_private_bytes_from_file( # SNOW-1825621: OAUTH implementation ), "oauth_redirect_uri": ("http://127.0.0.1", str), + "oauth_socket_uri": ( + "http://127.0.0.1", + str, + # SNOW-2194055: Separate server and redirect URIs in AuthHttpServer + ), "oauth_scope": ( "", str, @@ -961,6 +971,15 @@ def connect(self, **kwargs) -> None: if len(kwargs) > 0: self.__config(**kwargs) + no_proxy_csv_str = ( + ",".join(str(x) for x in self.no_proxy) + if ( + self.no_proxy is not None + and isinstance(self.no_proxy, Iterable) + and not isinstance(self.no_proxy, (str, bytes)) + ) + else self.no_proxy + ) self._http_config = HttpConfig( adapter_factory=ProxySupportAdapterFactory(), use_pooling=(not self.disable_request_pooling), @@ -968,15 +987,7 @@ def connect(self, **kwargs) -> None: proxy_port=self.proxy_port, proxy_user=self.proxy_user, proxy_password=self.proxy_password, - no_proxy=( - ",".join(str(x) for x in self.no_proxy) - if ( - self.no_proxy is not None - and isinstance(self.no_proxy, Iterable) - and not isinstance(self.no_proxy, (str, bytes)) - ) - else self.no_proxy - ), + no_proxy=no_proxy_csv_str, ) self._session_manager = SessionManagerFactory.get_manager(self._http_config) @@ -1036,7 +1047,8 @@ def close(self, retry: bool = True) -> None: # close telemetry first, since it needs rest to send remaining data logger.debug("closed") - self._telemetry.close(send_on_close=bool(retry and self.telemetry_enabled)) + if self.telemetry_enabled: + self._telemetry.close(retry=retry) if ( self._all_async_queries_finished() and not self._server_session_keep_alive @@ -1269,9 +1281,11 @@ def __open_connection(self): backoff_generator=self._backoff_generator, ) elif self._authenticator == EXTERNAL_BROWSER_AUTHENTICATOR: + # Enable SSO credential caching self._session_parameters[ PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL ] = (self._client_store_temporary_credential if IS_LINUX else True) + # Try to load cached ID token to avoid browser popup auth.read_temporary_credentials( self.host, self.user, @@ -1335,6 +1349,7 @@ def __open_connection(self): host=self.host, port=self.port ), redirect_uri=self._oauth_redirect_uri, + uri=self._oauth_socket_uri, scope=self._oauth_scope, pkce_enabled=not self._oauth_disable_pkce, token_cache=( @@ -1362,9 +1377,11 @@ def __open_connection(self): connection=self, ) elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR: + # Enable MFA token caching self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = ( self._client_request_mfa_token if IS_LINUX else True ) + # Try to load cached MFA token to skip MFA prompt if self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN]: auth.read_temporary_credentials( self.host, diff --git a/src/snowflake/connector/constants.py b/src/snowflake/connector/constants.py index 47f07b9eb9..7429862d59 100644 --- a/src/snowflake/connector/constants.py +++ b/src/snowflake/connector/constants.py @@ -442,7 +442,6 @@ class IterUnit(Enum): ENV_VAR_PARTNER = "SF_PARTNER" ENV_VAR_TEST_MODE = "SNOWFLAKE_TEST_MODE" - _DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"} _CONNECTIVITY_ERR_MSG = ( diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py index edd843628f..a478b01cbd 100644 --- a/src/snowflake/connector/session_manager.py +++ b/src/snowflake/connector/session_manager.py @@ -219,7 +219,7 @@ def close(self) -> None: self._idle_sessions.clear() -class 
_ConfigDirectAccessMixin(abc.ABC): +class _BaseConfigDirectAccessMixin(abc.ABC): @property @abc.abstractmethod def config(self) -> HttpConfig: ... @@ -236,14 +236,6 @@ def use_pooling(self) -> bool: def use_pooling(self, value: bool) -> None: self.config = self.config.copy_with(use_pooling=value) - @property - def adapter_factory(self) -> Callable[..., HTTPAdapter]: - return self.config.adapter_factory - - @adapter_factory.setter - def adapter_factory(self, value: Callable[..., HTTPAdapter]) -> None: - self.config = self.config.copy_with(adapter_factory=value) - @property def max_retries(self) -> Retry | int: return self.config.max_retries @@ -253,6 +245,16 @@ def max_retries(self, value: Retry | int) -> None: self.config = self.config.copy_with(max_retries=value) +class _HttpConfigDirectAccessMixin(_BaseConfigDirectAccessMixin, abc.ABC): + @property + def adapter_factory(self) -> Callable[..., HTTPAdapter]: + return self.config.adapter_factory + + @adapter_factory.setter + def adapter_factory(self, value: Callable[..., HTTPAdapter]) -> None: + self.config = self.config.copy_with(adapter_factory=value) + + class _RequestVerbsUsingSessionMixin(abc.ABC): """ Mixin that provides HTTP methods (get, post, put, etc.) mirroring requests.Session, maintaining their default argument behavior (e.g., HEAD uses allow_redirects=False). @@ -363,7 +365,7 @@ def delete( return session.delete(url, headers=headers, timeout=timeout, **kwargs) -class SessionManager(_RequestVerbsUsingSessionMixin, _ConfigDirectAccessMixin): +class SessionManager(_RequestVerbsUsingSessionMixin, _HttpConfigDirectAccessMixin): """ Central HTTP session manager that handles all external requests from the Snowflake driver. @@ -557,7 +559,7 @@ def clone( Optional kwargs (e.g. *use_pooling* / *adapter_factory* / max_retries etc.) - overrides to create a modified copy of the HttpConfig before instantiation. 
""" - return SessionManager.from_config(self._cfg, **http_config_overrides) + return self.from_config(self._cfg, **http_config_overrides) def __getstate__(self): state = self.__dict__.copy() @@ -621,12 +623,6 @@ def make_session(self, *, url: str | None = None) -> Session: session.proxies = proxies return session - def clone( - self, - **http_config_overrides, - ) -> SessionManager: - return ProxySessionManager.from_config(self._cfg, **http_config_overrides) - class SessionManagerFactory: @staticmethod diff --git a/src/snowflake/connector/telemetry.py b/src/snowflake/connector/telemetry.py index a22cbdfbb6..37edd3fd41 100644 --- a/src/snowflake/connector/telemetry.py +++ b/src/snowflake/connector/telemetry.py @@ -155,7 +155,7 @@ def try_add_log_to_batch(self, telemetry_data: TelemetryData) -> None: except Exception: logger.warning("Failed to add log to telemetry.", exc_info=True) - def send_batch(self) -> None: + def send_batch(self, retry: bool = False) -> None: if self.is_closed: raise Exception("Attempted to send batch when TelemetryClient is closed") elif not self._enabled: @@ -186,6 +186,7 @@ def send_batch(self) -> None: method="post", client=None, timeout=5, + _no_retry=not retry, ) if not ret["success"]: logger.info( @@ -204,11 +205,10 @@ def send_batch(self) -> None: def is_closed(self) -> bool: return self._rest is None - def close(self, send_on_close: bool = True) -> None: + def close(self, retry: bool = False) -> None: if not self.is_closed: logger.debug("Closing telemetry client.") - if send_on_close: - self.send_batch() + self.send_batch(retry=retry) self._rest = None def disable(self) -> None: diff --git a/src/snowflake/connector/token_cache.py b/src/snowflake/connector/token_cache.py index b197fc51e0..5e71ebb386 100644 --- a/src/snowflake/connector/token_cache.py +++ b/src/snowflake/connector/token_cache.py @@ -22,6 +22,14 @@ class TokenType(Enum): + """Types of credentials that can be cached to avoid repeated authentication. + + - ID_TOKEN: SSO identity token from external browser/Okta authentication + - MFA_TOKEN: Multi-factor authentication token to skip MFA prompts + - OAUTH_ACCESS_TOKEN: Short-lived OAuth access token + - OAUTH_REFRESH_TOKEN: Long-lived OAuth token to obtain new access tokens + """ + ID_TOKEN = "ID_TOKEN" MFA_TOKEN = "MFA_TOKEN" OAUTH_ACCESS_TOKEN = "OAUTH_ACCESS_TOKEN" @@ -57,6 +65,16 @@ def _warn(warning: str) -> None: class TokenCache(ABC): + """Secure storage for authentication credentials to avoid repeated login prompts. + + Platform-specific implementations: + - macOS/Windows: Uses OS keyring (Keychain/Credential Manager) via 'keyring' library + - Linux: Uses encrypted JSON file in ~/.cache/snowflake/ with 0o600 permissions + - Fallback: NoopTokenCache (no caching) if secure storage unavailable + + Tokens are keyed by (host, user, token_type) to support multiple accounts. + """ + @staticmethod def make(skip_file_permissions_check: bool = False) -> TokenCache: if IS_MACOS or IS_WINDOWS: @@ -127,6 +145,17 @@ class _CacheFileWriteError(_FileTokenCacheError): class FileTokenCache(TokenCache): + """Linux implementation: stores tokens in JSON file with strict security. + + Cache location (in priority order): + 1. $SF_TEMPORARY_CREDENTIAL_CACHE_DIR/credential_cache_v1.json + 2. $XDG_CACHE_HOME/snowflake/credential_cache_v1.json + 3. $HOME/.cache/snowflake/credential_cache_v1.json + + Security: File must have 0o600 permissions and be owned by current user. + Uses file locks to prevent concurrent access corruption. 
+    """
+
     @staticmethod
     def make(skip_file_permissions_check: bool = False) -> FileTokenCache | None:
         cache_dir = FileTokenCache.find_cache_dir(skip_file_permissions_check)
@@ -364,6 +393,14 @@ def _ensure_permissions(self, fd: int, permissions: int) -> None:
 
 
 class KeyringTokenCache(TokenCache):
+    """macOS/Windows implementation: uses OS-native secure credential storage.
+
+    - macOS: Stores tokens in Keychain
+    - Windows: Stores tokens in Windows Credential Manager
+
+    Tokens are stored with service="{HOST}:{USER}:{TOKEN_TYPE}" and username="{USER}".
+    """
+
     def __init__(self) -> None:
         self.logger = logging.getLogger(__name__)
diff --git a/test/data/wiremock/mappings/generic/telemetry.json b/test/data/wiremock/mappings/generic/telemetry.json
index 9b734a0cf2..86ed49a5fb 100644
--- a/test/data/wiremock/mappings/generic/telemetry.json
+++ b/test/data/wiremock/mappings/generic/telemetry.json
@@ -10,9 +10,9 @@
       "data": {
         "code": null,
         "data": "Log Received",
-        "message": null,
-        "success": true
-      }
+        "message": null
+      },
+      "success": true
     }
   }
 }
diff --git a/test/test_utils/cross_module_fixtures/wiremock_fixtures.py b/test/test_utils/cross_module_fixtures/wiremock_fixtures.py
index d6330850e2..9421bf9280 100644
--- a/test/test_utils/cross_module_fixtures/wiremock_fixtures.py
+++ b/test/test_utils/cross_module_fixtures/wiremock_fixtures.py
@@ -12,6 +12,7 @@
     WiremockClient,
     get_clients_for_proxy_and_target,
     get_clients_for_proxy_target_and_storage,
+    get_clients_for_two_proxies_and_target,
 )
 
 
@@ -102,3 +103,23 @@ def wiremock_backend_storage_proxy(wiremock_generic_mappings_dir):
         proxy_mapping_template=wiremock_proxy_mapping_path
     ) as triple:
         yield triple
+
+
+@pytest.fixture
+def wiremock_two_proxies_backend(wiremock_generic_mappings_dir):
+    """Starts a backend (DB) and two proxy Wiremocks.
+
+    Returns a tuple ``(backend_wm, proxy1_wm, proxy2_wm)`` to make roles explicit.
+    - proxy1_wm: Configured to forward to backend
+    - proxy2_wm: Configured to forward to backend
+
+    Use when you need to test proxy selection logic with a simple setup,
+    such as connection parameters taking precedence over environment variables.
+    """
+    wiremock_proxy_mapping_path = (
+        wiremock_generic_mappings_dir / "proxy_forward_all.json"
+    )
+    with get_clients_for_two_proxies_and_target(
+        proxy_mapping_template=wiremock_proxy_mapping_path
+    ) as triple:
+        yield triple
diff --git a/test/test_utils/wiremock/wiremock_utils.py b/test/test_utils/wiremock/wiremock_utils.py
index 03fbe0adca..7a7b28bdce 100644
--- a/test/test_utils/wiremock/wiremock_utils.py
+++ b/test/test_utils/wiremock/wiremock_utils.py
@@ -388,3 +388,55 @@ def get_clients_for_proxy_target_and_storage(
         forbidden = [target_wm.wiremock_http_port, proxy_wm.wiremock_http_port]
         with WiremockClient(forbidden_ports=forbidden) as storage_wm:
             yield target_wm, storage_wm, proxy_wm
+
+
+@contextmanager
+def get_clients_for_two_proxies_and_target(
+    proxy_mapping_template: Union[str, dict, pathlib.Path, None] = None,
+    additional_proxy_placeholders: Optional[dict[str, object]] = None,
+    additional_proxy_args: Optional[Iterable[str]] = None,
+):
+    """Context manager that starts three Wiremock instances – one *target* (DB) and two *proxies*.
+
+    Both proxies are configured to forward all traffic to *target* using the same
+    mapping mechanism. This allows the test to verify which proxy was actually used
+    by checking the request history. 
+ + Yields a tuple ``(target_wm, proxy1_wm, proxy2_wm)`` where: + - target_wm: The backend/DB Wiremock + - proxy1_wm: First proxy configured to forward to target + - proxy2_wm: Second proxy configured to forward to target + + All processes are shut down automatically on context exit. + + Note: + Use this helper for tests that need to verify proxy selection logic, + such as connection parameters taking precedence over environment variables. + """ + # Reuse existing helper to set up target+proxy1 + with get_clients_for_proxy_and_target( + proxy_mapping_template=proxy_mapping_template, + additional_proxy_placeholders=additional_proxy_placeholders, + additional_proxy_args=additional_proxy_args, + ) as (target_wm, proxy1_wm): + # Start second proxy and configure it to forward to target as well + forbidden = [target_wm.wiremock_http_port, proxy1_wm.wiremock_http_port] + with WiremockClient(forbidden_ports=forbidden) as proxy2_wm: + # Configure proxy2 to forward to target with the same mapping + if proxy_mapping_template is None: + proxy_mapping_template = ( + pathlib.Path(__file__).parent.parent.parent.parent + / "test" + / "data" + / "wiremock" + / "mappings" + / "generic" + / "proxy_forward_all.json" + ) + placeholders: dict[str, object] = { + "{{TARGET_HTTP_HOST_WITH_PORT}}": target_wm.http_host_with_port + } + if additional_proxy_placeholders: + placeholders.update(additional_proxy_placeholders) + proxy2_wm.add_mapping(proxy_mapping_template, placeholders=placeholders) + yield target_wm, proxy1_wm, proxy2_wm diff --git a/test/unit/aio/test_auth_oauth_auth_code_async.py b/test/unit/aio/test_auth_oauth_auth_code_async.py index 7152e4de8d..2a149f92af 100644 --- a/test/unit/aio/test_auth_oauth_auth_code_async.py +++ b/test/unit/aio/test_auth_oauth_auth_code_async.py @@ -5,7 +5,7 @@ import unittest.mock as mock from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body -from unittest.mock import patch +from unittest.mock import PropertyMock, patch import pytest @@ -303,6 +303,73 @@ def test_mro(): ) +@pytest.mark.parametrize("redirect_uri", ["https://redirect/uri"]) +@pytest.mark.parametrize("rtr_enabled", [True, False]) +async def test_auth_oauth_auth_code_uses_redirect_uri( + redirect_uri, rtr_enabled: bool, omit_oauth_urls_check +): + """Test that the redirect URI is used correctly in the OAuth authorization code flow.""" + auth = AuthByOauthCode( + "app", + "clientId", + "clientSecret", + "auth_url", + "tokenRequestUrl", + redirect_uri, + "scope", + "host", + pkce_enabled=False, + enable_single_use_refresh_tokens=rtr_enabled, + uri="http://localhost:0", + ) + + def fake_get_request_token_response(_, fields: dict[str, str]): + if rtr_enabled: + assert fields.get("enable_single_use_refresh_tokens") == "true" + else: + assert "enable_single_use_refresh_tokens" not in fields + return ("access_token", "refresh_token") + + with patch( + "snowflake.connector.aio.auth.AuthByOauthCode._construct_authorization_request", + return_value="authorization_request", + ) as mock_construct_authorization_request: + with patch( + "snowflake.connector.aio.auth.AuthByOauthCode._receive_authorization_callback", + return_value=("code", auth._state), + ): + with patch( + "snowflake.connector.aio.auth.AuthByOauthCode._ask_authorization_callback_from_user", + return_value=("code", auth._state), + ): + with patch( + "snowflake.connector.aio.auth.AuthByOauthCode._get_request_token_response", + side_effect=fake_get_request_token_response, + ) as mock_get_request_token_response: + with patch( + 
"snowflake.connector.auth._http_server.AuthHttpServer.redirect_uri", + return_value=redirect_uri, + new_callable=PropertyMock, + ): + await auth.prepare( + conn=None, + authenticator=OAUTH_AUTHORIZATION_CODE, + service_name=None, + account="acc", + user="user", + ) + mock_construct_authorization_request.assert_called_once_with( + redirect_uri + ) + assert mock_get_request_token_response.call_count == 1 + assert ( + mock_get_request_token_response.call_args[0][1][ + "redirect_uri" + ] + == redirect_uri + ) + + @pytest.mark.skipolddriver async def test_oauth_authorization_code_allows_empty_user( monkeypatch, omit_oauth_urls_check @@ -349,3 +416,50 @@ def mock_request_tokens(self, **kwargs): assert isinstance(conn.auth_class, AuthByOauthCode) await conn.close() + + +@pytest.mark.parametrize( + "uri,redirect_uri", + [ + ("https://example.com/server", "http://localhost:8080"), + ("http://localhost:8080", "https://example.com/redirect"), + ("http://127.0.0.1:9090", "https://server.com/oauth/callback"), + (None, "https://redirect.example.com"), + ], +) +@mock.patch( + "snowflake.connector.aio.auth._oauth_code.AuthByOauthCode._do_authorization_request" +) +@mock.patch( + "snowflake.connector.aio.auth._oauth_code.AuthByOauthCode._do_token_request" +) +async def test_auth_oauth_auth_code_passes_uri_to_http_server( + _, __, uri, redirect_uri, omit_oauth_urls_check +): + """Test that uri and redirect_uri parameters are passed correctly to AuthHttpServer.""" + auth = AuthByOauthCode( + "app", + "clientId", + "clientSecret", + "https://auth_url", + "tokenRequestUrl", + redirect_uri, + "scope", + "host", + uri=uri, + ) + + with patch( + "snowflake.connector.auth.oauth_code.AuthHttpServer", + # return_value=None, + ) as mock_http_server_init: + auth._request_tokens( + conn=mock.MagicMock(), + authenticator="authenticator", + service_name="service_name", + account="account", + user="user", + ) + mock_http_server_init.assert_called_once_with( + uri=uri or redirect_uri, redirect_uri=redirect_uri + ) diff --git a/test/unit/aio/test_connection_async_unit.py b/test/unit/aio/test_connection_async_unit.py index 284997ca34..d3bfe489b8 100644 --- a/test/unit/aio/test_connection_async_unit.py +++ b/test/unit/aio/test_connection_async_unit.py @@ -844,7 +844,7 @@ async def test_invalid_authenticator(): @pytest.mark.skipolddriver -async def test_connect_metadata_preservation(): +def test_connect_metadata_preservation(): """Test that the async connect function preserves metadata from SnowflakeConnection.__init__. 
This test verifies that various inspection methods return consistent metadata, @@ -852,7 +852,9 @@ async def test_connect_metadata_preservation(): """ import inspect - from snowflake.connector.aio import SnowflakeConnection, connect + # Use already imported snowflake.connector.aio + connect = snowflake.connector.aio.connect + SnowflakeConnection = snowflake.connector.aio.SnowflakeConnection # Test 1: Check __name__ is correct assert ( @@ -910,7 +912,7 @@ async def test_connect_metadata_preservation(): connect_doc == source_doc ), "inspect.getdoc(connect) should match inspect.getdoc(SnowflakeConnection.__init__)" - # Test 8: Check that connect is callable and returns expected type + # Test 8: Check that connect is callable assert callable(connect), "connect should be callable" # Test 9: Check type() and __class__ values (important for user introspection) @@ -931,3 +933,4 @@ async def test_connect_metadata_preservation(): assert ( len(params) > 0 ), "connect should have parameters from SnowflakeConnection.__init__" + # Should have parameters like account, user, password, etc. diff --git a/test/unit/aio/test_proxies_async.py b/test/unit/aio/test_proxies_async.py index e7cc6b67b2..0bf88704c4 100644 --- a/test/unit/aio/test_proxies_async.py +++ b/test/unit/aio/test_proxies_async.py @@ -4,10 +4,13 @@ from collections import deque from test.unit.test_proxies import ( DbRequestFlags, + ProxyPrecedenceFlags, RequestFlags, _apply_no_proxy, _base_connect_kwargs, _configure_proxy, + _set_mappings_for_common_backend, + _set_mappings_for_query_and_chunks, _setup_backend_storage_mappings, ) @@ -234,6 +237,44 @@ async def _collect_db_request_flags_only(proxy_wm, target_wm) -> DbRequestFlags: return DbRequestFlags(proxy_saw_db=proxy_saw_db, target_saw_db=target_saw_db) +async def _collect_proxy_precedence_flags( + proxy1_wm, proxy2_wm, target_wm +) -> ProxyPrecedenceFlags: + """Async version of proxy precedence flags collection using aiohttp.""" + async with aiohttp.ClientSession() as session: + async with session.get( + f"{proxy1_wm.http_host_with_port}/__admin/requests" + ) as resp: + proxy1_reqs = await resp.json() + async with session.get( + f"{proxy2_wm.http_host_with_port}/__admin/requests" + ) as resp: + proxy2_reqs = await resp.json() + async with session.get( + f"{target_wm.http_host_with_port}/__admin/requests" + ) as resp: + target_reqs = await resp.json() + + proxy1_saw_request = any( + "/queries/v1/query-request" in r["request"]["url"] + for r in proxy1_reqs["requests"] + ) + proxy2_saw_request = any( + "/queries/v1/query-request" in r["request"]["url"] + for r in proxy2_reqs["requests"] + ) + backend_saw_request = any( + "/queries/v1/query-request" in r["request"]["url"] + for r in target_reqs["requests"] + ) + + return ProxyPrecedenceFlags( + proxy1_saw_request=proxy1_saw_request, + proxy2_saw_request=proxy2_saw_request, + backend_saw_request=backend_saw_request, + ) + + @pytest.mark.skipolddriver @pytest.mark.parametrize("no_proxy_source", ["param", "env"]) async def test_no_proxy_bypass_storage( @@ -577,3 +618,72 @@ async def test_no_proxy_bypass_backend_and_storage_param_only( assert flags.proxy_saw_db is False assert flags.storage_saw_storage assert flags.proxy_saw_storage is False + + +@pytest.mark.skipolddriver +async def test_proxy_env_vars_take_precedence_over_connection_params( + wiremock_two_proxies_backend, + wiremock_mapping_dir, + wiremock_generic_mappings_dir, + proxy_env_vars, + monkeypatch, + host_port_pooling, + fix_aiohttp_proxy_bypass, +): + """Verify that 
proxy env vars take precedence over the proxy_host/proxy_port connection parameters.
+
+    Setup:
+        - Set HTTP_PROXY env var to point to proxy_from_env_vars
+        - Set proxy_host param to point to proxy_from_conn_params
+
+    Expected outcome:
+        - proxy_from_env_vars should see the request (env vars take precedence)
+        - proxy_from_conn_params should NOT see the request
+        - backend should see the request
+    """
+    target_wm, proxy_from_conn_params, proxy_from_env_vars = (
+        wiremock_two_proxies_backend
+    )
+
+    # Setup backend mappings for large query with multiple chunks
+    _set_mappings_for_common_backend(target_wm, wiremock_generic_mappings_dir)
+    _set_mappings_for_query_and_chunks(
+        target_wm,
+        wiremock_mapping_dir,
+    )
+
+    # Set HTTP_PROXY env var AFTER Wiremock is running using monkeypatch
+    # This prevents Wiremock from inheriting it and forwarding through proxy2
+    set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars
+    clear_proxy_env_vars()  # Clear any existing ones first
+
+    env_proxy_url = f"http://{proxy_from_env_vars.wiremock_host}:{proxy_from_env_vars.wiremock_http_port}"
+
+    # Set connection params to point to proxy1 (should be ignored, since env vars win)
+    connect_kwargs = _base_connect_kwargs(target_wm)
+    connect_kwargs.update(
+        {
+            "proxy_host": proxy_from_conn_params.wiremock_host,
+            "proxy_port": str(proxy_from_conn_params.wiremock_http_port),
+        }
+    )
+
+    with monkeypatch.context() as m_context:
+        m_context.setenv("HTTP_PROXY", env_proxy_url)
+        m_context.setenv("HTTPS_PROXY", env_proxy_url)
+
+        # Execute query
+        await _execute_large_query(connect_kwargs, row_count=50_000)
+
+        # Verify proxy selection using named tuple flags
+        flags = await _collect_proxy_precedence_flags(
+            proxy_from_conn_params, proxy_from_env_vars, target_wm
+        )
+        assert not (
+            flags.proxy1_saw_request
+        ), "proxy_from_conn_params (connection param proxy) should NOT have seen the query request"
+        assert flags.proxy2_saw_request, (
+            "proxy_from_env_vars (env var proxy) should have seen the request "
+            "since env vars take precedence"
+        )
+        assert flags.backend_saw_request, "backend should have seen the query request"
diff --git a/test/unit/aio/test_telemetry_async.py b/test/unit/aio/test_telemetry_async.py
index 3dbe1197b0..7ae3bbb824 100644
--- a/test/unit/aio/test_telemetry_async.py
+++ b/test/unit/aio/test_telemetry_async.py
@@ -53,7 +53,7 @@ def test_telemetry_data_to_dict():
 
 
 def get_client_and_mock():
-    rest_call = Mock()
+    rest_call = AsyncMock()
     rest_call.return_value = {"success": True}
     rest = Mock()
     rest.attach_mock(rest_call, "request")
@@ -315,3 +315,27 @@ def get_mocked_telemetry_connection(telemetry_enabled: bool = True) -> AsyncMock
         mock_connection._telemetry = mock_telemetry
 
     return mock_connection
+
+
+async def test_telemetry_send_batch_with_retry_flag():
+    """Tests that send_batch respects the retry parameter."""
+    client, rest_call = get_client_and_mock()
+
+    await client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 2000))
+
+    # Test with retry=True
+    await client.send_batch(retry=True)
+
+    assert rest_call.call_count == 1
+    # Verify _no_retry parameter is False when retry=True
+    call_kwargs = rest_call.call_args[1]
+    assert call_kwargs["_no_retry"] is False
+
+    # Add another log and test with retry=False (default)
+    await client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 3000))
+    await client.send_batch(retry=False)
+
+    assert rest_call.call_count == 2
+    # Verify _no_retry parameter is True when retry=False
+    call_kwargs = 
rest_call.call_args[1] + assert call_kwargs["_no_retry"] is True diff --git a/test/unit/test_auth_callback_server.py b/test/unit/test_auth_callback_server.py index bf03a8d5f6..5a33ec0a80 100644 --- a/test/unit/test_auth_callback_server.py +++ b/test/unit/test_auth_callback_server.py @@ -24,7 +24,9 @@ def test_auth_callback_success(monkeypatch, dontwait, timeout, reuse_port) -> No monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) test_response: requests.Response | None = None - with AuthHttpServer("http://127.0.0.1/test_request") as callback_server: + with AuthHttpServer( + "http://127.0.0.1/test_request", + ) as callback_server: def request_callback(): nonlocal test_response @@ -57,7 +59,155 @@ def request_callback(): def test_auth_callback_timeout(monkeypatch, dontwait, timeout, reuse_port) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) - with AuthHttpServer("http://127.0.0.1/test_request") as callback_server: + with AuthHttpServer( + "http://127.0.0.1/test_request", + ) as callback_server: block, client_socket = callback_server.receive_block(timeout=timeout) assert block is None assert client_socket is None + + +@pytest.mark.parametrize( + "socket_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "socket_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "redirect_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "redirect_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "dontwait", + ["false", "true"], +) +@pytest.mark.parametrize("reuse_port", ["true", "false"]) +def test_auth_callback_server_updates_localhost_redirect_uri_port_to_match_socket_port( + monkeypatch, + socket_host, + socket_port, + redirect_host, + redirect_port, + dontwait, + reuse_port, +) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) + with AuthHttpServer( + uri=f"http://{socket_host}{socket_port}/test_request", + redirect_uri=f"http://{redirect_host}{redirect_port}/test_request", + ) as callback_server: + assert callback_server._redirect_uri.port == callback_server.port + + +@pytest.mark.parametrize( + "socket_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "socket_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "redirect_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "redirect_port", + [ + 54321, + 54320, + ], +) +@pytest.mark.parametrize( + "dontwait", + ["false", "true"], +) +@pytest.mark.parametrize("reuse_port", ["true", "false"]) +def test_auth_callback_server_uses_redirect_uri_port_when_specified( + monkeypatch, + socket_host, + socket_port, + redirect_host, + redirect_port, + dontwait, + reuse_port, +) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) + with AuthHttpServer( + uri=f"http://{socket_host}{socket_port}/test_request", + redirect_uri=f"http://{redirect_host}:{redirect_port}/test_request", + ) as callback_server: + assert callback_server.port == redirect_port + assert callback_server._redirect_uri.port == redirect_port + + +@pytest.mark.parametrize( + "socket_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + 
"socket_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "redirect_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "dontwait", + ["false", "true"], +) +@pytest.mark.parametrize("reuse_port", ["true", "false"]) +def test_auth_callback_server_does_not_updates_nonlocalhost_redirect_uri_port_to_match_socket_port( + monkeypatch, socket_host, socket_port, redirect_port, dontwait, reuse_port +) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) + redirect_uri = f"http://not_localhost{redirect_port}/test_request" + with AuthHttpServer( + uri=f"http://{socket_host}{socket_port}/test_request", redirect_uri=redirect_uri + ) as callback_server: + assert callback_server.redirect_uri == redirect_uri diff --git a/test/unit/test_auth_oauth_auth_code.py b/test/unit/test_auth_oauth_auth_code.py index 4342521afa..c2455bac6c 100644 --- a/test/unit/test_auth_oauth_auth_code.py +++ b/test/unit/test_auth_oauth_auth_code.py @@ -5,7 +5,7 @@ import unittest.mock as mock from test.helpers import apply_auth_class_update_body, create_mock_auth_body -from unittest.mock import patch +from unittest.mock import PropertyMock, patch import pytest @@ -287,6 +287,73 @@ def mock_request_tokens(self, **kwargs): conn.close() +@pytest.mark.parametrize("redirect_uri", ["https://redirect/uri"]) +@pytest.mark.parametrize("rtr_enabled", [True, False]) +def test_auth_oauth_auth_code_uses_redirect_uri( + redirect_uri, rtr_enabled: bool, omit_oauth_urls_check +): + """Test that the redirect URI is used correctly in the OAuth authorization code flow.""" + auth = AuthByOauthCode( + "app", + "clientId", + "clientSecret", + "auth_url", + "tokenRequestUrl", + redirect_uri, + "scope", + "host", + pkce_enabled=False, + enable_single_use_refresh_tokens=rtr_enabled, + uri="http://localhost:0", + ) + + def fake_get_request_token_response(_, fields: dict[str, str]): + if rtr_enabled: + assert fields.get("enable_single_use_refresh_tokens") == "true" + else: + assert "enable_single_use_refresh_tokens" not in fields + return ("access_token", "refresh_token") + + with patch( + "snowflake.connector.auth.AuthByOauthCode._construct_authorization_request", + return_value="authorization_request", + ) as mock_construct_authorization_request: + with patch( + "snowflake.connector.auth.AuthByOauthCode._receive_authorization_callback", + return_value=("code", auth._state), + ): + with patch( + "snowflake.connector.auth.AuthByOauthCode._ask_authorization_callback_from_user", + return_value=("code", auth._state), + ): + with patch( + "snowflake.connector.auth.AuthByOauthCode._get_request_token_response", + side_effect=fake_get_request_token_response, + ) as mock_get_request_token_response: + with patch( + "snowflake.connector.auth._http_server.AuthHttpServer.redirect_uri", + return_value=redirect_uri, + new_callable=PropertyMock, + ): + auth.prepare( + conn=None, + authenticator=OAUTH_AUTHORIZATION_CODE, + service_name=None, + account="acc", + user="user", + ) + mock_construct_authorization_request.assert_called_once_with( + redirect_uri + ) + assert mock_get_request_token_response.call_count == 1 + assert ( + mock_get_request_token_response.call_args[0][1][ + "redirect_uri" + ] + == redirect_uri + ) + + @pytest.mark.skipolddriver def test_oauth_authorization_code_allows_empty_user(monkeypatch, omit_oauth_urls_check): """Test that OAUTH_AUTHORIZATION_CODE authenticator allows connection without user parameter.""" @@ 
-328,3 +395,48 @@ def mock_request_tokens(self, **kwargs):

         assert isinstance(conn.auth_class, AuthByOauthCode)
         conn.close()
+
+
+@pytest.mark.parametrize(
+    "uri,redirect_uri",
+    [
+        ("https://example.com/server", "http://localhost:8080"),
+        ("http://localhost:8080", "https://example.com/redirect"),
+        ("http://127.0.0.1:9090", "https://server.com/oauth/callback"),
+        (None, "https://redirect.example.com"),
+    ],
+)
+@mock.patch(
+    "snowflake.connector.auth.oauth_code.AuthByOauthCode._do_authorization_request"
+)
+@mock.patch("snowflake.connector.auth.oauth_code.AuthByOauthCode._do_token_request")
+def test_auth_oauth_auth_code_passes_uri_to_http_server(
+    _, __, uri, redirect_uri, omit_oauth_urls_check
+):
+    """Test that uri and redirect_uri parameters are passed correctly to AuthHttpServer."""
+    auth = AuthByOauthCode(
+        "app",
+        "clientId",
+        "clientSecret",
+        "https://auth_url",
+        "tokenRequestUrl",
+        redirect_uri,
+        "scope",
+        "host",
+        uri=uri,
+    )
+
+    with patch(
+        "snowflake.connector.auth.oauth_code.AuthHttpServer",
+    ) as mock_http_server_init:
+        auth._request_tokens(
+            conn=mock.MagicMock(),
+            authenticator="authenticator",
+            service_name="service_name",
+            account="account",
+            user="user",
+        )
+        mock_http_server_init.assert_called_once_with(
+            uri=uri or redirect_uri, redirect_uri=redirect_uri
+        )
diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py
index e9d5da4e40..ce19e09181 100644
--- a/test/unit/test_connection.py
+++ b/test/unit/test_connection.py
@@ -17,6 +17,7 @@
 from cryptography.hazmat.primitives.asymmetric import rsa

 import snowflake.connector
+from snowflake.connector import SnowflakeConnection, connect
 from snowflake.connector.connection import DEFAULT_CONFIGURATION
 from snowflake.connector.errors import (
     Error,
@@ -868,3 +869,89 @@ def test_reraise_error_in_file_transfer_work_function_config(
     expected_value = bool(reraise_enabled)
     actual_value = conn._reraise_error_in_file_transfer_work_function
     assert actual_value == expected_value
+
+
+@pytest.mark.skipolddriver
+def test_connect_metadata_preservation():
+    """Test that the sync connect function preserves metadata from SnowflakeConnection.__init__.
+
+    This test verifies that various inspection methods return consistent metadata,
+    ensuring IDE support, type checking, and documentation generation work correctly.
+    """
+    import inspect
+
+    # Test 1: Check __name__ is correct
+    assert (
+        connect.__name__ == "__init__"
+    ), f"connect.__name__ should be '__init__', but got '{connect.__name__}'"
+
+    # Test 2: Check __wrapped__ points to SnowflakeConnection.__init__
+    assert hasattr(connect, "__wrapped__"), "connect should have __wrapped__ attribute"
+    assert (
+        connect.__wrapped__ is SnowflakeConnection.__init__
+    ), "connect.__wrapped__ should reference SnowflakeConnection.__init__"
+
+    # Test 3: Check __module__ is preserved
+    assert hasattr(connect, "__module__"), "connect should have __module__ attribute"
+    assert connect.__module__ == SnowflakeConnection.__init__.__module__, (
+        f"connect.__module__ should match SnowflakeConnection.__init__.__module__, "
+        f"but got '{connect.__module__}' vs '{SnowflakeConnection.__init__.__module__}'"
+    )
+
+    # Test 4: Check __doc__ is preserved
+    assert hasattr(connect, "__doc__"), "connect should have __doc__ attribute"
+    assert (
+        connect.__doc__ == SnowflakeConnection.__init__.__doc__
+    ), "connect.__doc__ should match SnowflakeConnection.__init__.__doc__"
+
+    # Test 5: Check __annotations__ are preserved (or at least available)
+    assert hasattr(
+        connect, "__annotations__"
+    ), "connect should have __annotations__ attribute"
+    src_annotations = getattr(SnowflakeConnection.__init__, "__annotations__", {})
+    connect_annotations = getattr(connect, "__annotations__", {})
+    assert connect_annotations == src_annotations, (
+        f"connect.__annotations__ should match SnowflakeConnection.__init__.__annotations__, "
+        f"but got {connect_annotations} vs {src_annotations}"
+    )
+
+    # Test 6: Check inspect.signature works correctly
+    try:
+        connect_sig = inspect.signature(connect)
+        source_sig = inspect.signature(SnowflakeConnection.__init__)
+        assert str(connect_sig) == str(source_sig), (
+            f"inspect.signature(connect) should match inspect.signature(SnowflakeConnection.__init__), "
+            f"but got '{connect_sig}' vs '{source_sig}'"
+        )
+    except Exception as e:
+        pytest.fail(f"inspect.signature(connect) failed: {e}")
+
+    # Test 7: Check inspect.getdoc works correctly
+    connect_doc = inspect.getdoc(connect)
+    source_doc = inspect.getdoc(SnowflakeConnection.__init__)
+    assert (
+        connect_doc == source_doc
+    ), "inspect.getdoc(connect) should match inspect.getdoc(SnowflakeConnection.__init__)"
+
+    # Test 8: Check that connect is callable
+    assert callable(connect), "connect should be callable"
+
+    # Test 9: Check type() and __class__ values (important for user introspection)
+    assert (
+        type(connect).__name__ == "function"
+    ), f"type(connect).__name__ should be 'function', but got '{type(connect).__name__}'"
+    assert (
+        connect.__class__.__name__ == "function"
+    ), f"connect.__class__.__name__ should be 'function', but got '{connect.__class__.__name__}'"
+    assert inspect.isfunction(
+        connect
+    ), "connect should be recognized as a function by inspect.isfunction()"
+
+    # Test 10: Verify the function has proper introspection capabilities
+    # IDEs and type checkers should be able to resolve parameters
+    sig = inspect.signature(connect)
+    params = list(sig.parameters.keys())
+    assert (
+        len(params) > 0
+    ), "connect should have parameters from SnowflakeConnection.__init__"
+    # Should have parameters like account, user, password, etc.
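Why Test 1 expects "__init__" rather than "connect": wrappers built with functools.wraps copy __name__, __doc__, __module__, and __annotations__ from the wrapped callable and record the original under __wrapped__, which is exactly the metadata surface the assertions above probe. A minimal, self-contained sketch of that pattern (hypothetical stand-in names; not the connector's actual implementation):

    import functools
    import inspect


    class SnowflakeConnection:  # hypothetical stand-in for the real class
        def __init__(self, account=None, user=None, password=None):
            """Create a connection (stand-in docstring)."""


    @functools.wraps(SnowflakeConnection.__init__)
    def connect(*args, **kwargs):
        # Forward to the constructor; wraps() has already copied the
        # constructor's metadata onto this function.
        return SnowflakeConnection(*args, **kwargs)


    # wraps() copies __name__ from the wrapped callable, hence "__init__":
    assert connect.__name__ == "__init__"
    assert connect.__wrapped__ is SnowflakeConnection.__init__
    assert connect.__doc__ == SnowflakeConnection.__init__.__doc__
    # __wrapped__ lets inspect.signature() resolve the real parameters:
    assert "account" in inspect.signature(connect).parameters

Because wraps() sets __wrapped__, inspect.signature() follows it to the constructor's parameter list, which is what Tests 6 and 10 rely on.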
diff --git a/test/unit/test_proxies.py b/test/unit/test_proxies.py
index 3f2f83e4f7..86d8199cb1 100644
--- a/test/unit/test_proxies.py
+++ b/test/unit/test_proxies.py
@@ -229,30 +229,14 @@ def _setup_backend_storage_mappings(
     wiremock_mapping_dir,
     wiremock_generic_mappings_dir,
 ):
-    password_mapping = wiremock_mapping_dir / "auth/password/successful_flow.json"
-    multi_chunk_request_mapping = (
-        wiremock_mapping_dir / "queries/select_large_request_successful.json"
-    )
-    chunk_1_mapping = wiremock_mapping_dir / "queries/chunk_1.json"
-    chunk_2_mapping = wiremock_mapping_dir / "queries/chunk_2.json"
-    disconnect_mapping = (
-        wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
-    )
-    telemetry_mapping = wiremock_generic_mappings_dir / "telemetry.json"
-
-    target_wm.import_mapping_with_default_placeholders(password_mapping)
-    target_wm.add_mapping(disconnect_mapping)
-    target_wm.add_mapping(telemetry_mapping)
-    target_wm.add_mapping(
-        multi_chunk_request_mapping,
-        placeholders={
-            "{{STORAGE_WIREMOCK_HTTP_HOST_WITH_PORT}}": storage_wm.http_host_with_port
-        },
+    """Set up backend, storage, and proxy mappings for large queries."""
+    _set_mappings_for_common_backend(target_wm, wiremock_generic_mappings_dir)
+    _set_mappings_for_query_and_chunks(
+        target_wm,
+        wiremock_mapping_dir,
+        storage_or_target_wm=storage_wm,
     )
-    storage_wm.add_mapping_with_default_placeholders(chunk_1_mapping)
-    storage_wm.add_mapping_with_default_placeholders(chunk_2_mapping)
-
     proxy_wm.add_mapping(
         {
             "request": {"method": "ANY", "urlPathPattern": "/amazonaws/.*"},
@@ -313,13 +297,61 @@ def _apply_no_proxy(no_proxy_source, no_proxy_value, connect_kwargs):
     )


+def _set_mappings_for_common_backend(target_wm, wiremock_generic_mappings_dir):
+    """Set common backend mappings: auth, disconnect, and telemetry."""
+    password_mapping = (
+        wiremock_generic_mappings_dir.parent / "auth/password/successful_flow.json"
+    )
+    disconnect_mapping = (
+        wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+    )
+    telemetry_mapping = wiremock_generic_mappings_dir / "telemetry.json"
+
+    target_wm.import_mapping_with_default_placeholders(password_mapping)
+    target_wm.add_mapping(disconnect_mapping)
+    target_wm.add_mapping(telemetry_mapping)
+
+
+def _set_mappings_for_query_and_chunks(
+    target_wm,
+    wiremock_mapping_dir,
+    storage_or_target_wm=None,
+):
+    """Set multi-chunk query mapping and chunk mappings.
+
+    Args:
+        target_wm: The target/backend Wiremock client
+        wiremock_mapping_dir: Path to wiremock mappings directory
+        storage_or_target_wm: Optional storage Wiremock client. If not provided, chunks are added to target_wm.
+    """
+    if storage_or_target_wm is None:
+        storage_or_target_wm = target_wm
+
+    multi_chunk_request_mapping = (
+        wiremock_mapping_dir / "queries/select_large_request_successful.json"
+    )
+    chunk_1_mapping = wiremock_mapping_dir / "queries/chunk_1.json"
+    chunk_2_mapping = wiremock_mapping_dir / "queries/chunk_2.json"
+
+    target_wm.add_mapping(
+        multi_chunk_request_mapping,
+        placeholders={
+            "{{STORAGE_WIREMOCK_HTTP_HOST_WITH_PORT}}": storage_or_target_wm.http_host_with_port
+        },
+    )
+
+    storage_or_target_wm.add_mapping_with_default_placeholders(chunk_1_mapping)
+    storage_or_target_wm.add_mapping_with_default_placeholders(chunk_2_mapping)
+
+
 def _execute_large_query(connect_kwargs, row_count: int):
     with snowflake.connector.connect(**connect_kwargs) as conn:
         cursors = conn.execute_string(
             f"select seq4() as n from table(generator(rowcount => {row_count}));"
         )
         assert len(cursors[0]._result_set.batches) > 1
-        assert list(cursors[0])
+        rs = list(cursors[0])
+        assert rs


 class RequestFlags(NamedTuple):
@@ -383,6 +415,46 @@ def _collect_db_request_flags_only(proxy_wm, target_wm) -> DbRequestFlags:
     return DbRequestFlags(proxy_saw_db=proxy_saw_db, target_saw_db=target_saw_db)


+class ProxyPrecedenceFlags(NamedTuple):
+    proxy1_saw_request: bool
+    proxy2_saw_request: bool
+    backend_saw_request: bool
+
+
+def _collect_proxy_precedence_flags(
+    proxy1_wm, proxy2_wm, target_wm
+) -> ProxyPrecedenceFlags:
+    """Collect flags for proxy precedence tests to see which proxy was used."""
+    proxy1_reqs = requests.get(
+        f"{proxy1_wm.http_host_with_port}/__admin/requests"
+    ).json()
+    proxy2_reqs = requests.get(
+        f"{proxy2_wm.http_host_with_port}/__admin/requests"
+    ).json()
+    target_reqs = requests.get(
+        f"{target_wm.http_host_with_port}/__admin/requests"
+    ).json()
+
+    proxy1_saw_request = any(
+        "/queries/v1/query-request" in r["request"]["url"]
+        for r in proxy1_reqs["requests"]
+    )
+    proxy2_saw_request = any(
+        "/queries/v1/query-request" in r["request"]["url"]
+        for r in proxy2_reqs["requests"]
+    )
+    backend_saw_request = any(
+        "/queries/v1/query-request" in r["request"]["url"]
+        for r in target_reqs["requests"]
+    )
+
+    return ProxyPrecedenceFlags(
+        proxy1_saw_request=proxy1_saw_request,
+        proxy2_saw_request=proxy2_saw_request,
+        backend_saw_request=backend_saw_request,
+    )
+
+
 @pytest.mark.skipolddriver
 @pytest.mark.parametrize("no_proxy_source", ["param", "env"])
 def test_no_proxy_bypass_storage(
@@ -672,3 +744,70 @@ def test_no_proxy_bypass_backend_and_storage_param_only(
     assert flags.proxy_saw_db is False
     assert flags.storage_saw_storage
     assert flags.proxy_saw_storage is False
+
+
+@pytest.mark.skipolddriver
+def test_proxy_env_vars_take_precedence_over_connection_params(
+    wiremock_two_proxies_backend,
+    wiremock_mapping_dir,
+    wiremock_generic_mappings_dir,
+    proxy_env_vars,
+    monkeypatch,
+):
+    """Verify that HTTP_PROXY/HTTPS_PROXY env vars take precedence over proxy_host/proxy_port connection parameters.
+
+    Setup:
+    - Set HTTP_PROXY env var to point to proxy_from_env_vars
+    - Set proxy_host param to point to proxy_from_conn_params
+
+    Expected outcome:
+    - proxy_from_env_vars should see the request (env vars take precedence)
+    - proxy_from_conn_params should NOT see the request
+    - backend should see the request
+    """
+    target_wm, proxy_from_conn_params, proxy_from_env_vars = (
+        wiremock_two_proxies_backend
+    )
+
+    # Set up backend mappings for large query with multiple chunks
+    _set_mappings_for_common_backend(target_wm, wiremock_generic_mappings_dir)
+    _set_mappings_for_query_and_chunks(
+        target_wm,
+        wiremock_mapping_dir,
+    )
+
+    # Set HTTP_PROXY env var AFTER Wiremock is running using monkeypatch
+    # This prevents Wiremock from inheriting it and forwarding through proxy2
+    set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars
+    clear_proxy_env_vars()  # Clear any existing ones first
+
+    env_proxy_url = f"http://{proxy_from_env_vars.wiremock_host}:{proxy_from_env_vars.wiremock_http_port}"
+
+    # Set connection params to point to proxy1 (expected to be overridden by env vars)
+    connect_kwargs = _base_connect_kwargs(target_wm)
+    connect_kwargs.update(
+        {
+            "proxy_host": proxy_from_conn_params.wiremock_host,
+            "proxy_port": str(proxy_from_conn_params.wiremock_http_port),
+        }
+    )
+
+    with monkeypatch.context() as m_context:
+        m_context.setenv("HTTP_PROXY", env_proxy_url)
+        m_context.setenv("HTTPS_PROXY", env_proxy_url)
+
+        # Execute query
+        _execute_large_query(connect_kwargs, row_count=50_000)
+
+    # Verify proxy selection using named tuple flags
+    flags = _collect_proxy_precedence_flags(
+        proxy_from_conn_params, proxy_from_env_vars, target_wm
+    )
+    assert not (
+        flags.proxy1_saw_request
+    ), "proxy_from_conn_params (connection param proxy) should NOT have seen the query request"
+    assert flags.proxy2_saw_request, (
+        "proxy_from_env_vars (env var proxy) should have seen the request "
+        "since env vars take precedence"
+    )
+    assert flags.backend_saw_request, "backend should have seen the query request"
diff --git a/test/unit/test_telemetry.py b/test/unit/test_telemetry.py
index 336a9d9c6e..d24834534e 100644
--- a/test/unit/test_telemetry.py
+++ b/test/unit/test_telemetry.py
@@ -157,6 +157,30 @@ def test_telemetry_send_batch_disabled():
     assert rest_call.call_count == 0


+def test_telemetry_send_batch_with_retry_flag():
+    """Tests that send_batch respects the retry parameter."""
+    client, rest_call = get_client_and_mock()
+
+    client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 2000))
+
+    # Test with retry=True
+    client.send_batch(retry=True)
+
+    assert rest_call.call_count == 1
+    # Verify _no_retry parameter is False when retry=True
+    call_kwargs = rest_call.call_args[1]
+    assert call_kwargs["_no_retry"] is False
+
+    # Add another log and test with retry=False (default)
+    client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 3000))
+    client.send_batch(retry=False)
+
+    assert rest_call.call_count == 2
+    # Verify _no_retry parameter is True when retry=False
+    call_kwargs = rest_call.call_args[1]
+    assert call_kwargs["_no_retry"] is True
+
+
 def test_generate_telemetry_data_dict_with_basic_info():
     assert snowflake.connector.telemetry.generate_telemetry_data_dict() == {
         snowflake.connector.telemetry.TelemetryField.KEY_DRIVER_TYPE.value: CLIENT_NAME,