diff --git a/pyrit/auth/__init__.py b/pyrit/auth/__init__.py index 4074809e5..3b17fcef6 100644 --- a/pyrit/auth/__init__.py +++ b/pyrit/auth/__init__.py @@ -7,8 +7,10 @@ from pyrit.auth.authenticator import Authenticator from pyrit.auth.azure_auth import ( + AsyncTokenProviderCredential, AzureAuth, TokenProviderCredential, + ensure_async_token_provider, get_azure_async_token_provider, get_azure_openai_auth, get_azure_token_provider, @@ -19,12 +21,14 @@ from pyrit.auth.manual_copilot_authenticator import ManualCopilotAuthenticator __all__ = [ + "AsyncTokenProviderCredential", "Authenticator", "AzureAuth", "AzureStorageAuth", "CopilotAuthenticator", "ManualCopilotAuthenticator", "TokenProviderCredential", + "ensure_async_token_provider", "get_azure_token_provider", "get_azure_async_token_provider", "get_default_azure_scope", diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 00e2f8d6f..4149749e4 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -3,6 +3,7 @@ from __future__ import annotations +import inspect import logging import time from typing import TYPE_CHECKING, Any, Union, cast @@ -22,7 +23,7 @@ ) if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Awaitable, Callable import azure.cognitiveservices.speech as speechsdk @@ -66,6 +67,103 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: return AccessToken(str(token), expires_on) +class AsyncTokenProviderCredential: + """ + Async wrapper to convert a token provider callable into an Azure AsyncTokenCredential. + + This class bridges the gap between token provider functions (sync or async) and Azure SDK + async clients that require an AsyncTokenCredential object (with async def get_token). + """ + + def __init__(self, token_provider: Callable[[], Union[str, Awaitable[str]]]) -> None: + """ + Initialize AsyncTokenProviderCredential. + + Args: + token_provider: A callable that returns a token string (sync) or an awaitable that + returns a token string (async). Both are supported transparently. + """ + self._token_provider = token_provider + + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """ + Get an access token asynchronously. + + Args: + scopes: Token scopes (ignored as the scope is already configured in the token provider). + kwargs: Additional arguments (ignored). + + Returns: + AccessToken: The access token with expiration time. + """ + result = self._token_provider() + if inspect.isawaitable(result): + token = await result + else: + token = result + expires_on = int(time.time()) + 3600 + return AccessToken(str(token), expires_on) + + async def close(self) -> None: + """No-op close for protocol compliance. The callable provider does not hold resources.""" + + async def __aenter__(self) -> AsyncTokenProviderCredential: + """ + Enter the async context manager. + + Returns: + AsyncTokenProviderCredential: This credential instance. + """ + return self + + async def __aexit__(self, *args: Any) -> None: + """Exit the async context manager.""" + await self.close() + + +def ensure_async_token_provider( + api_key: str | Callable[[], str | Awaitable[str]] | None, +) -> str | Callable[[], Awaitable[str]] | None: + """ + Ensure the api_key is either a string or an async callable. + + If a synchronous callable token provider is provided, it's automatically wrapped + in an async function to make it compatible with async Azure SDK clients. + + Args: + api_key: Either a string API key or a callable that returns a token (sync or async). + + Returns: + Either a string API key or an async callable that returns a token. + """ + if api_key is None or isinstance(api_key, str) or not callable(api_key): + return api_key + + # Check if the callable is already async + if inspect.iscoroutinefunction(api_key): + return api_key + + # Wrap synchronous token provider in async function + logger.debug( + "Detected synchronous token provider." + " Automatically wrapping in async function for compatibility with async client." + ) + + async def async_token_provider() -> str: + """ + Async wrapper for synchronous token provider. + + Returns: + str: The token string from the synchronous provider. + """ + result = api_key() + if inspect.isawaitable(result): + return await result + return result + + return async_token_provider + + class AzureAuth(Authenticator): """ Azure CLI Authentication. diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 0128991e3..0549cd8f6 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import asyncio import json import logging import re @@ -23,7 +22,7 @@ AuthenticationError, ) -from pyrit.auth import get_azure_openai_auth +from pyrit.auth import ensure_async_token_provider, get_azure_openai_auth from pyrit.common import default_values from pyrit.exceptions.exception_classes import ( RateLimitException, @@ -41,46 +40,6 @@ logger = logging.getLogger(__name__) -def _ensure_async_token_provider( - api_key: Optional[str | Callable[[], str | Awaitable[str]]], -) -> Optional[str | Callable[[], Awaitable[str]]]: - """ - Ensure the api_key is either a string or an async callable. - - If a synchronous callable token provider is provided, it's automatically wrapped - in an async function to make it compatible with AsyncOpenAI. - - Args: - api_key: Either a string API key or a callable that returns a token (sync or async). - - Returns: - Either a string API key or an async callable that returns a token. - """ - if api_key is None or isinstance(api_key, str) or not callable(api_key): - return api_key - - # Check if the callable is already async - if asyncio.iscoroutinefunction(api_key): - return api_key - - # Wrap synchronous token provider in async function - logger.info( - "Detected synchronous token provider." - " Automatically wrapping in async function for compatibility with AsyncOpenAI." - ) - - async def async_token_provider() -> str: - """ - Async wrapper for synchronous token provider. - - Returns: - str: The token string from the synchronous provider. - """ - return api_key() # type: ignore[return-value] - - return async_token_provider - - class OpenAITarget(PromptTarget): """ Abstract base class for OpenAI-based prompt targets. @@ -198,7 +157,7 @@ def __init__( ) # Ensure api_key is async-compatible (wrap sync token providers if needed) - self._api_key = _ensure_async_token_provider(resolved_api_key) + self._api_key = ensure_async_token_provider(resolved_api_key) self._initialize_openai_client() diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 6b587d4f3..16aa3d75a 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -1,13 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import asyncio import base64 -import inspect -from collections.abc import Callable +import logging +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Optional -from azure.ai.contentsafety import ContentSafetyClient +from azure.ai.contentsafety.aio import ContentSafetyClient from azure.ai.contentsafety.models import ( AnalyzeImageOptions, AnalyzeImageResult, @@ -18,7 +17,7 @@ ) from azure.core.credentials import AzureKeyCredential -from pyrit.auth import TokenProviderCredential, get_azure_token_provider +from pyrit.auth import AsyncTokenProviderCredential, ensure_async_token_provider, get_azure_async_token_provider from pyrit.common import default_values from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( @@ -38,6 +37,8 @@ from pyrit.score.scorer_evaluation.scorer_evaluator import ScorerEvalDatasetFiles from pyrit.score.scorer_evaluation.scorer_metrics import ScorerMetrics +logger = logging.getLogger(__name__) + class AzureContentFilterScorer(FloatScaleScorer): """ @@ -94,7 +95,7 @@ def __init__( self, *, endpoint: Optional[str | None] = None, - api_key: Optional[str | Callable[[], str] | None] = None, + api_key: Optional[str | Callable[[], str | Awaitable[str]] | None] = None, harm_categories: Optional[list[TextCategory]] = None, validator: Optional[ScorerPromptValidator] = None, ) -> None: @@ -104,13 +105,12 @@ def __init__( Args: endpoint (Optional[str | None]): The endpoint URL for the Azure Content Safety service. Defaults to the `ENDPOINT_URI_ENVIRONMENT_VARIABLE` environment variable. - api_key (Optional[str | Callable[[], str] | None]): + api_key (Optional[str | Callable[[], str | Awaitable[str]] | None]): The API key for accessing the Azure Content Safety service, - or a synchronous callable that returns an access token. Async token providers - are not supported. If not provided (via parameter - or environment variable), Entra ID authentication is used automatically. - You can also explicitly pass a token provider from pyrit.auth - (e.g., get_azure_token_provider('https://cognitiveservices.azure.com/.default')). + or a callable that returns an access token. Both synchronous and asynchronous + token providers are supported. Sync providers are automatically wrapped for + async compatibility. If not provided (via parameter or environment variable), + Entra ID authentication is used automatically. Defaults to the `API_KEY_ENVIRONMENT_VARIABLE` environment variable. harm_categories (Optional[list[TextCategory]]): The harm categories you want to query for as defined in azure.ai.contentsafety.models.TextCategory. If not provided, defaults to all categories. @@ -129,36 +129,25 @@ def __init__( ) # API key: use passed value, env var, or fall back to Entra ID for Azure endpoints - resolved_api_key: str | Callable[[], str] + resolved_api_key: str | Callable[[], str | Awaitable[str]] if api_key is not None and callable(api_key): - if asyncio.iscoroutinefunction(api_key): - raise ValueError( - "Async token providers are not supported by AzureContentFilterScorer. " - "Use a synchronous token provider (e.g., get_azure_token_provider) instead." - ) - # Guard against sync callables that return coroutines/awaitables (e.g., lambda: async_fn()) - test_result = api_key() - if inspect.isawaitable(test_result): - if hasattr(test_result, "close"): - test_result.close() # prevent "coroutine was never awaited" warning - raise ValueError( - "The provided token provider returns a coroutine/awaitable, which is not supported " - "by AzureContentFilterScorer. Use a synchronous token provider instead." - ) resolved_api_key = api_key else: api_key_value = default_values.get_non_required_value( env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key ) - resolved_api_key = api_key_value or get_azure_token_provider("https://cognitiveservices.azure.com/.default") + resolved_api_key = api_key_value or get_azure_async_token_provider( + "https://cognitiveservices.azure.com/.default" + ) - self._api_key = resolved_api_key + # Ensure api_key is async-compatible (wrap sync token providers if needed) + self._api_key = ensure_async_token_provider(resolved_api_key) # Create ContentSafetyClient with appropriate credential if self._endpoint is not None: if callable(self._api_key): - # Token provider - create a TokenCredential wrapper - credential = TokenProviderCredential(self._api_key) + # Token provider - create an AsyncTokenCredential wrapper + credential = AsyncTokenProviderCredential(self._api_key) self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) else: # String API key @@ -291,7 +280,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op categories=self._category_values, output_type="EightSeverityLevels", ) - text_result = self._azure_cf_client.analyze_text(text_request_options) + text_result = await self._azure_cf_client.analyze_text(text_request_options) filter_results.append(text_result) elif message_piece.converted_value_data_type == "image_path": @@ -301,7 +290,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op image_request_options = AnalyzeImageOptions( image=image_data, categories=self._category_values, output_type="FourSeverityLevels" ) - image_result = self._azure_cf_client.analyze_image(image_request_options) + image_result = await self._azure_cf_client.analyze_image(image_request_options) filter_results.append(image_result) # Collect all scores from all chunks/images diff --git a/tests/integration/score/test_azure_content_filter_integration.py b/tests/integration/score/test_azure_content_filter_integration.py index 9f7fdef20..b0b0d1f9c 100644 --- a/tests/integration/score/test_azure_content_filter_integration.py +++ b/tests/integration/score/test_azure_content_filter_integration.py @@ -23,17 +23,12 @@ async def test_azure_content_filter_scorer_image_integration(memory) -> None: """ Integration test for Azure Content Filter Scorer with image input. - This test requires AZURE_CONTENT_SAFETY_API_KEY and AZURE_CONTENT_SAFETY_API_ENDPOINT - environment variables to be set. Uses a sample image from the assets folder. + This test requires AZURE_CONTENT_SAFETY_API_ENDPOINT to be set. + Authentication uses Entra ID by default (via `az login`). Alternatively, + set AZURE_CONTENT_SAFETY_API_KEY for API key auth. + Uses a sample image from the assets folder. """ with patch.object(CentralMemory, "get_memory_instance", return_value=memory): - # Verify required environment variables are set - api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY") - endpoint = os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT") - - if not api_key or not endpoint: - pytest.skip("Azure Content Safety credentials not configured") - scorer = AzureContentFilterScorer() image_path = HOME_PATH / "assets" / "architecture_components.png" @@ -62,13 +57,6 @@ async def test_azure_content_filter_scorer_long_text_chunking_integration(memory This verifies that the chunking and aggregation logic works correctly with the real API. """ with patch.object(CentralMemory, "get_memory_instance", return_value=memory): - # Verify required environment variables are set - api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY") - endpoint = os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT") - - if not api_key or not endpoint: - pytest.skip("Azure Content Safety credentials not configured") - scorer = AzureContentFilterScorer() # This should be greater than the rate limit @@ -91,3 +79,62 @@ async def test_azure_content_filter_scorer_long_text_chunking_integration(memory # Long benign text should still have low scores max_score = max(float(s.score_value) for s in scores) assert max_score > 0, "text should have > 0 score" + + +@pytest.mark.asyncio +async def test_azure_content_filter_scorer_image_with_api_key(memory) -> None: + """ + Integration test for Azure Content Filter Scorer image scoring with explicit API key auth. + + This test requires AZURE_CONTENT_SAFETY_API_KEY to be set. + """ + api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY") + assert api_key, "AZURE_CONTENT_SAFETY_API_KEY must be set to test API key auth" + + with patch.object(CentralMemory, "get_memory_instance", return_value=memory): + scorer = AzureContentFilterScorer(api_key=api_key) + + image_path = HOME_PATH / "assets" / "architecture_components.png" + scores = await scorer.score_image_async(image_path=str(image_path)) + + assert len(scores) > 0, "Expected scores for harm categories" + + for score in scores: + assert score.score_type == "float_scale" + assert score.score_category is not None + assert 0.0 <= float(score.score_value) <= 1.0 + assert "azure_severity" in score.score_metadata + + max_score = max(float(s.score_value) for s in scores) + assert max_score < 0.5, "Architecture diagram should have low harm scores" + + +@pytest.mark.asyncio +async def test_azure_content_filter_scorer_text_with_api_key(memory) -> None: + """ + Integration test for Azure Content Filter Scorer text scoring with explicit API key auth. + + This test requires AZURE_CONTENT_SAFETY_API_KEY to be set. + """ + api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY") + assert api_key, "AZURE_CONTENT_SAFETY_API_KEY must be set to test API key auth" + + with patch.object(CentralMemory, "get_memory_instance", return_value=memory): + scorer = AzureContentFilterScorer(api_key=api_key) + + base_text = "I hate you " * 10000 + + assert len(base_text) > AzureContentFilterScorer.MAX_TEXT_LENGTH + + scores = await scorer.score_text_async(text=base_text) + + assert len(scores) > 0, "Expected aggregated scores for harm categories" + + for score in scores: + assert score.score_type == "float_scale" + assert score.score_category is not None + assert 0.0 <= float(score.score_value) <= 1.0 + assert "azure_severity" in score.score_metadata + + max_score = max(float(s.score_value) for s in scores) + assert max_score > 0, "text should have > 0 score" diff --git a/tests/unit/score/test_azure_content_filter.py b/tests/unit/score/test_azure_content_filter.py index 27c0cc298..9cd391398 100644 --- a/tests/unit/score/test_azure_content_filter.py +++ b/tests/unit/score/test_azure_content_filter.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. +import inspect import os from unittest.mock import AsyncMock, MagicMock, patch @@ -55,7 +56,7 @@ async def test_score_async_unsupported_data_type_returns_empty_list( @pytest.mark.asyncio async def test_score_piece_async_text(patch_central_database, text_message_piece: MessagePiece): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "2", "category": "Hate"}]} scorer._azure_cf_client = mock_client scores = await scorer._score_piece_async(text_message_piece) @@ -72,7 +73,7 @@ async def test_score_piece_async_text(patch_central_database, text_message_piece @pytest.mark.asyncio async def test_score_piece_async_image(patch_central_database, image_message_piece: MessagePiece): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_image.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} scorer._azure_cf_client = mock_client # Patch _get_base64_image_data to avoid actual file IO @@ -102,25 +103,62 @@ def test_explicit_category(): assert len(scorer._harm_categories) == 1 -def test_async_callable_api_key_raises(): +def test_async_callable_api_key_accepted(): async def async_provider(): return "token" - with pytest.raises(ValueError, match="Async token providers are not supported"): - AzureContentFilterScorer(api_key=async_provider, endpoint="bar") + scorer = AzureContentFilterScorer(api_key=async_provider, endpoint="bar") + # Async callable should be passed through as-is + assert callable(scorer._api_key) + assert inspect.iscoroutinefunction(scorer._api_key) + + +@pytest.mark.asyncio +async def test_async_callable_api_key_returns_token(): + async def async_provider(): + return "token" + + scorer = AzureContentFilterScorer(api_key=async_provider, endpoint="bar") + result = await scorer._api_key() + assert result == "token" -def test_sync_callable_returning_coroutine_raises(): +def test_sync_callable_returning_coroutine_accepted(): async def async_fn(): return "token" - with pytest.raises(ValueError, match="returns a coroutine/awaitable"): - AzureContentFilterScorer(api_key=lambda: async_fn(), endpoint="bar") + sync_lambda = lambda: async_fn() # noqa: E731 + # Confirm the lambda itself is NOT a coroutine function (it's sync) + assert not inspect.iscoroutinefunction(sync_lambda) + + scorer = AzureContentFilterScorer(api_key=sync_lambda, endpoint="bar") + # After init, the sync callable should be wrapped in an async function + assert callable(scorer._api_key) + assert inspect.iscoroutinefunction(scorer._api_key) + + +@pytest.mark.asyncio +async def test_sync_callable_returning_coroutine_returns_token(): + async def async_fn(): + return "token" + + sync_lambda = lambda: async_fn() # noqa: E731 + scorer = AzureContentFilterScorer(api_key=sync_lambda, endpoint="bar") + result = await scorer._api_key() + assert result == "token" def test_sync_callable_api_key_accepted(): scorer = AzureContentFilterScorer(api_key=lambda: "token", endpoint="bar") assert callable(scorer._api_key) + assert inspect.iscoroutinefunction(scorer._api_key) + + +@pytest.mark.asyncio +async def test_sync_callable_api_key_returns_token(): + scorer = AzureContentFilterScorer(api_key=lambda: "token", endpoint="bar") + result = await scorer._api_key() + assert result == "token" @pytest.mark.asyncio @@ -129,7 +167,7 @@ async def test_azure_content_filter_scorer_adds_to_memory(): with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "2", "category": "Hate"}]} scorer._azure_cf_client = mock_client @@ -143,7 +181,7 @@ async def test_azure_content_filter_scorer_adds_to_memory(): async def test_azure_content_filter_scorer_score(patch_central_database): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "2", "category": "Hate"}]} scorer._azure_cf_client = mock_client @@ -181,7 +219,7 @@ async def test_azure_content_filter_scorer_chunks_long_text(patch_central_databa with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() # Mock returns for two chunks mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} scorer._azure_cf_client = mock_client @@ -205,7 +243,7 @@ async def test_azure_content_filter_scorer_accepts_short_text(patch_central_data with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} scorer._azure_cf_client = mock_client diff --git a/tests/unit/target/test_openai_target_auth.py b/tests/unit/target/test_openai_target_auth.py index 2045ae6e2..efe00a0ca 100644 --- a/tests/unit/target/test_openai_target_auth.py +++ b/tests/unit/target/test_openai_target_auth.py @@ -9,7 +9,8 @@ import pytest -from pyrit.prompt_target.openai.openai_target import OpenAITarget, _ensure_async_token_provider +from pyrit.auth import ensure_async_token_provider +from pyrit.prompt_target.openai.openai_target import OpenAITarget class _ConcreteOpenAITarget(OpenAITarget): @@ -126,30 +127,30 @@ def test_param_api_key_takes_precedence_over_env_var(self): class TestEnsureAsyncTokenProvider: - """Tests for the _ensure_async_token_provider helper function.""" + """Tests for the ensure_async_token_provider helper function.""" def test_none_returns_none(self): - assert _ensure_async_token_provider(None) is None + assert ensure_async_token_provider(None) is None def test_string_returns_string(self): - assert _ensure_async_token_provider("my-key") == "my-key" + assert ensure_async_token_provider("my-key") == "my-key" def test_async_callable_returned_as_is(self): async def provider() -> str: return "token" - result = _ensure_async_token_provider(provider) + result = ensure_async_token_provider(provider) assert result is provider def test_sync_callable_wrapped_to_async(self): def provider() -> str: return "sync-token" - result = _ensure_async_token_provider(provider) + result = ensure_async_token_provider(provider) assert asyncio.iscoroutinefunction(result) assert asyncio.run(result()) == "sync-token" def test_non_callable_non_string_returned_as_is(self): # Edge case: something that's not a string and not callable - result = _ensure_async_token_provider(42) # type: ignore[arg-type] + result = ensure_async_token_provider(42) # type: ignore[arg-type] assert result == 42 diff --git a/tests/unit/target/test_token_provider_wrapping.py b/tests/unit/target/test_token_provider_wrapping.py index 8ca874ce0..9f5657fed 100644 --- a/tests/unit/target/test_token_provider_wrapping.py +++ b/tests/unit/target/test_token_provider_wrapping.py @@ -6,7 +6,7 @@ import pytest -from pyrit.prompt_target.openai.openai_target import _ensure_async_token_provider +from pyrit.auth import ensure_async_token_provider class TestTokenProviderWrapping: @@ -15,13 +15,13 @@ class TestTokenProviderWrapping: def test_string_api_key_unchanged(self): """Test that string API keys are returned unchanged.""" api_key = "sk-test-key-12345" - result = _ensure_async_token_provider(api_key) + result = ensure_async_token_provider(api_key) assert result == api_key assert isinstance(result, str) def test_none_api_key_unchanged(self): """Test that None is returned unchanged.""" - result = _ensure_async_token_provider(None) + result = ensure_async_token_provider(None) assert result is None def test_async_token_provider_unchanged(self): @@ -30,7 +30,7 @@ def test_async_token_provider_unchanged(self): async def async_token_provider(): return "async-token" - result = _ensure_async_token_provider(async_token_provider) + result = ensure_async_token_provider(async_token_provider) assert result is async_token_provider assert asyncio.iscoroutinefunction(result) @@ -40,7 +40,7 @@ def test_sync_token_provider_wrapped(self): def sync_token_provider(): return "sync-token" - result = _ensure_async_token_provider(sync_token_provider) + result = ensure_async_token_provider(sync_token_provider) # Should return a different callable (the wrapper) assert result is not sync_token_provider @@ -54,7 +54,7 @@ async def test_wrapped_sync_provider_returns_correct_token(self): def sync_token_provider(): return "my-sync-token" - wrapped = _ensure_async_token_provider(sync_token_provider) + wrapped = ensure_async_token_provider(sync_token_provider) # Call the wrapped provider token = await wrapped() @@ -67,7 +67,7 @@ async def test_async_provider_returns_correct_token(self): async def async_token_provider(): return "my-async-token" - result = _ensure_async_token_provider(async_token_provider) + result = ensure_async_token_provider(async_token_provider) # Should be the same function assert result is async_token_provider @@ -86,7 +86,7 @@ def sync_token_provider(): call_count += 1 return f"token-{call_count}" - wrapped = _ensure_async_token_provider(sync_token_provider) + wrapped = ensure_async_token_provider(sync_token_provider) # Call multiple times token1 = await wrapped() @@ -97,15 +97,15 @@ def sync_token_provider(): assert call_count == 2 def test_sync_provider_wrapping_logs_info(self): - """Test that wrapping a sync provider logs an info message.""" + """Test that wrapping a sync provider logs a debug message.""" def sync_token_provider(): return "token" - with patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger: - _ensure_async_token_provider(sync_token_provider) - mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0][0] + with patch("pyrit.auth.azure_auth.logger") as mock_logger: + ensure_async_token_provider(sync_token_provider) + mock_logger.debug.assert_called_once() + call_args = mock_logger.debug.call_args[0][0] assert "synchronous token provider" in call_args.lower() assert "wrapping" in call_args.lower() @@ -124,7 +124,7 @@ def sync_token_provider(): with ( patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai, - patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger, + patch("pyrit.auth.azure_auth.logger") as mock_logger, ): mock_client = AsyncMock() mock_openai.return_value = mock_client @@ -135,14 +135,14 @@ def sync_token_provider(): api_key=sync_token_provider, ) - # Verify that info log was called about wrapping - mock_logger.info.assert_called() + # Verify that debug log was called about wrapping + mock_logger.debug.assert_called() info_call_found = False - for call in mock_logger.info.call_args_list: + for call in mock_logger.debug.call_args_list: if "synchronous token provider" in str(call).lower(): info_call_found = True break - assert info_call_found, "Expected info log about wrapping sync token provider" + assert info_call_found, "Expected debug log about wrapping sync token provider" # Verify AsyncOpenAI was initialized mock_openai.assert_called_once() @@ -223,7 +223,7 @@ def mock_sync_bearer_token_provider(): with ( patch("pyrit.prompt_target.openai.openai_target.AsyncOpenAI") as mock_openai, - patch("pyrit.prompt_target.openai.openai_target.logger") as mock_logger, + patch("pyrit.auth.azure_auth.logger") as mock_logger, ): mock_client = AsyncMock() mock_openai.return_value = mock_client @@ -235,7 +235,7 @@ def mock_sync_bearer_token_provider(): ) # Verify that sync provider was wrapped - mock_logger.info.assert_called() + mock_logger.debug.assert_called() # Get the wrapped api_key call_kwargs = mock_openai.call_args[1]