From 55ea89ff599dc5ebbbf0ab1ed292a0b5608ac65f Mon Sep 17 00:00:00 2001 From: Tanmay Mehta Date: Mon, 27 Apr 2026 16:57:40 +0000 Subject: [PATCH] initial changes --- CHANGELOG.md | 1 + docs/source/snowpark/secrets.rst | 1 + src/snowflake/snowpark/secrets.py | 39 +++++++++++++++++++++++++++++++ tests/integ/test_secrets.py | 3 +++ tests/unit/test_secrets.py | 12 ++++++++++ 5 files changed, 56 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 959915450f..6bceac7c61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ #### New Features - Added `artifact_repository` support to `udtf_configs` in `session.read.dbapi()`, enabling users to specify a custom artifact repository (e.g. PyPI) for packages used by the internal UDTF during distributed ingestion. +- Added `get_wif_token` to `snowflake.snowpark.secrets` for workload identity federation tokens on the Snowflake server (not available in SPCS file-based secret environments). #### Bug Fixes diff --git a/docs/source/snowpark/secrets.rst b/docs/source/snowpark/secrets.rst index 0c650171fe..2edcd41e19 100644 --- a/docs/source/snowpark/secrets.rst +++ b/docs/source/snowpark/secrets.rst @@ -22,3 +22,4 @@ Snowpark Secrets get_secret_type get_username_password get_cloud_provider_token + get_wif_token diff --git a/src/snowflake/snowpark/secrets.py b/src/snowflake/snowpark/secrets.py index cce9a15d7b..cae3521ba2 100644 --- a/src/snowflake/snowpark/secrets.py +++ b/src/snowflake/snowpark/secrets.py @@ -20,6 +20,7 @@ "get_secret_type", "get_username_password", "get_cloud_provider_token", + "get_wif_token", "UsernamePassword", "CloudProviderToken", ] @@ -61,6 +62,10 @@ def get_username_password(self, secret_name: str) -> UsernamePassword: def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken: pass + @abstractmethod + def get_wif_token(self, secret_name: str, audience: str) -> str: + pass + class _SnowflakeSecretsServer(_SnowflakeSecrets): """Secret instance for Snowflake server environment (using _snowflake module).""" @@ -89,6 +94,9 @@ def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken: secret_object.token, ) + def get_wif_token(self, secret_name: str, audience: str) -> str: + return self._snowflake.get_wif_token(secret_name, audience) + class _SnowflakeSecretsSPCS(_SnowflakeSecrets): """Secret instance for SPCS container environment (file-based secrets).""" @@ -173,6 +181,11 @@ def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken: "Cloud provider token secrets are not supported in SPCS container environments." ) + def get_wif_token(self, secret_name: str, audience: str) -> str: + raise NotImplementedError( + "WIF token secrets are not supported in SPCS container environments." + ) + def _is_spcs_environment() -> bool: return os.getenv(_SCLS_SPCS_SECRET_ENV_NAME, None) is not None @@ -259,3 +272,29 @@ def get_cloud_provider_token(secret_name: str) -> CloudProviderToken: NotImplementedError: If running outside Snowflake server environment. """ return _get_secrets_instance().get_cloud_provider_token(secret_name) + + +def get_wif_token(secret_name: str, audience: str) -> str: + """Get a workload identity federation (WIF) token from Snowflake. + + Note: + Requires a Snowflake environment with a WIF secret configured and an + external access integration that allows the UDF or stored procedure to + use that secret. The ``audience`` must match the token audience expected + by the external system (for example, an OAuth token endpoint URL). + + Args: + secret_name: The secret reference name bound to the WIF secret. + audience: The intended audience (``aud``) for the issued token. + + Returns: + The issued token as a string (typically a JWT). + + Raises: + NotImplementedError: If running outside the Snowflake server environment + (including SPCS file-based secret environments, where WIF tokens cannot + be minted). + ValueError: If the secret does not exist or is not authorized (when + applicable in supported environments). + """ + return _get_secrets_instance().get_wif_token(secret_name, audience) diff --git a/tests/integ/test_secrets.py b/tests/integ/test_secrets.py index f67b9b5701..1a7623bfe5 100644 --- a/tests/integ/test_secrets.py +++ b/tests/integ/test_secrets.py @@ -9,6 +9,7 @@ get_secret_type, get_cloud_provider_token, get_oauth_access_token, + get_wif_token, ) from snowflake.snowpark.types import BooleanType, StringType from tests.utils import IS_NOT_ON_GITHUB, RUNNING_ON_JENKINS, IS_IN_STORED_PROC, Utils @@ -169,3 +170,5 @@ def test_secrets_import_error(): get_cloud_provider_token("c1") with pytest.raises(NotImplementedError): get_oauth_access_token("o1") + with pytest.raises(NotImplementedError): + get_wif_token("w1", "https://audience") diff --git a/tests/unit/test_secrets.py b/tests/unit/test_secrets.py index 0964f212da..b5e393b860 100644 --- a/tests/unit/test_secrets.py +++ b/tests/unit/test_secrets.py @@ -12,6 +12,7 @@ get_secret_type, get_username_password, get_cloud_provider_token, + get_wif_token, UsernamePassword, CloudProviderToken, _SCLS_SPCS_SECRET_ENV_NAME, @@ -31,6 +32,7 @@ def _build_fake_snowflake_module() -> object: get_secret_type=lambda secret_name: "PASSWORD", get_username_password=lambda secret_name: fake_username_password, get_cloud_provider_token=lambda secret_name: fake_cloud_token, + get_wif_token=lambda secret_name, audience: f"wif:{secret_name}:{audience}", ) @@ -52,6 +54,11 @@ def test_secrets_mock_server_paths(): assert cloud.secret_access_key == "SECRET_TEST" assert cloud.token == "STS_TOKEN_TEST" + assert ( + get_wif_token("w1", "https://example.com/aud") + == "wif:w1:https://example.com/aud" + ) + @pytest.fixture def scls_spcs_mock_env(tmp_path): @@ -135,6 +142,9 @@ def test_secrets_mock_scls_spcs_error_cases(scls_spcs_mock_env): with pytest.raises(NotImplementedError): get_cloud_provider_token("any_secret") + with pytest.raises(NotImplementedError): + get_wif_token("any_secret", "https://audience") + with pytest.raises(ValueError, match="Unknown secret type"): get_secret_type("unknown_secret") @@ -159,6 +169,8 @@ def test_secrets_import_error_paths(): get_username_password("p1") with pytest.raises(NotImplementedError): get_cloud_provider_token("c1") + with pytest.raises(NotImplementedError): + get_wif_token("w1", "https://audience") finally: if original_env is not None: os.environ[_SCLS_SPCS_SECRET_ENV_NAME] = original_env