diff --git a/src/sentry/auth/providers/oauth2.py b/src/sentry/auth/providers/oauth2.py index 3f11932652e1a4..2612608a37d9fd 100644 --- a/src/sentry/auth/providers/oauth2.py +++ b/src/sentry/auth/providers/oauth2.py @@ -19,11 +19,20 @@ from sentry.http import safe_urlopen, safe_urlread from sentry.models.authidentity import AuthIdentity from sentry.utils.http import absolute_uri +from sentry.utils.oauth import KNOWN_OAUTH_ERROR_CODES, sanitize_oauth_error_code if TYPE_CHECKING: from sentry.auth.helper import AuthHelper +logger = logging.getLogger(__name__) ERR_INVALID_STATE = "An error occurred while validating your request." +ERR_PROVIDER_ERROR = "The identity provider returned an error while processing your request." + + +def _format_provider_error_message(error_code: str | None) -> str: + if error_code and error_code in KNOWN_OAUTH_ERROR_CODES: + return f"{ERR_PROVIDER_ERROR}\nError code: {error_code}" + return ERR_PROVIDER_ERROR def _get_redirect_url() -> str: @@ -129,7 +138,8 @@ def dispatch(self, request: HttpRequest, pipeline: AuthHelper) -> HttpResponseBa code = request.GET.get("code") if error: - return pipeline.error(error) + sanitized_error = sanitize_oauth_error_code(error) + return pipeline.error(_format_provider_error_message(sanitized_error)) if state != pipeline.fetch_state("state"): return pipeline.error(ERR_INVALID_STATE) @@ -143,8 +153,9 @@ def dispatch(self, request: HttpRequest, pipeline: AuthHelper) -> HttpResponseBa return pipeline.error(data["error_description"]) if "error" in data: - logging.info("Error exchanging token: %s", data["error"]) - return pipeline.error("Unable to retrieve your token") + sanitized_error = sanitize_oauth_error_code(data["error"]) + logger.info("Error exchanging token", extra={"error": sanitized_error}) + return pipeline.error(_format_provider_error_message(sanitized_error)) # we can either expect the API to be implicit and say "im looking for # blah within state data" or we need to pass implementation + call a diff --git a/src/sentry/identity/oauth2.py b/src/sentry/identity/oauth2.py index 6c1f6f0815b97a..ce1220409fce22 100644 --- a/src/sentry/identity/oauth2.py +++ b/src/sentry/identity/oauth2.py @@ -34,6 +34,7 @@ from sentry.shared_integrations.exceptions import ApiError, ApiInvalidRequestError, ApiUnauthorized from sentry.users.models.identity import Identity from sentry.utils.http import absolute_uri +from sentry.utils.oauth import KNOWN_OAUTH_ERROR_CODES, sanitize_oauth_error_code from .base import Provider @@ -41,9 +42,16 @@ logger = logging.getLogger(__name__) ERR_INVALID_STATE = "An error occurred while validating your request." +ERR_PROVIDER_ERROR = "The identity provider returned an error while processing your request." ERR_TOKEN_RETRIEVAL = "Failed to retrieve token from the upstream service." +def _format_provider_error_message(error_code: str | None) -> str: + if error_code and error_code in KNOWN_OAUTH_ERROR_CODES: + return f"{ERR_PROVIDER_ERROR}\nError code: {error_code}" + return ERR_PROVIDER_ERROR + + def _redirect_url(pipeline: IdentityPipeline) -> str: associate_url = reverse( "sentry-extension-setup", @@ -359,11 +367,12 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp code = request.GET.get("code") if error: + sanitized_error = sanitize_oauth_error_code(error) lifecycle.record_failure( IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, - extra={"error": error}, + extra={"error": sanitized_error or "unknown_error"}, ) - return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}") + return pipeline.error(_format_provider_error_message(sanitized_error)) if state != pipeline.fetch_state("state"): extra = { @@ -390,8 +399,9 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp return pipeline.error(data["error_description"]) if "error" in data: - logger.info("identity.token-exchange-error", extra={"error": data["error"]}) - return pipeline.error(f"{ERR_TOKEN_RETRIEVAL}\nError: {data['error']}") + sanitized_error = sanitize_oauth_error_code(data["error"]) + logger.info("identity.token-exchange-error", extra={"error": sanitized_error}) + return pipeline.error(_format_provider_error_message(sanitized_error)) # we can either expect the API to be implicit and say "im looking for # blah within state data" or we need to pass implementation + call a diff --git a/src/sentry/utils/oauth.py b/src/sentry/utils/oauth.py new file mode 100644 index 00000000000000..01682e4fb90ebc --- /dev/null +++ b/src/sentry/utils/oauth.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import re + +__all__ = ["KNOWN_OAUTH_ERROR_CODES", "sanitize_oauth_error_code"] + +_MAX_ERROR_LENGTH = 64 +_INVALID_ERROR_CHARS_RE = re.compile(r"[^a-z0-9._-]") + +# Common error codes returned by OAuth/OIDC providers. This list combines values from +# RFC 6749, RFC 8628, and popular provider-specific extensions so we can safely display +# known codes back to the user. +KNOWN_OAUTH_ERROR_CODES = frozenset( + { + "account_selection_required", + "access_denied", + "authorization_pending", + "consent_required", + "expired_token", + "interaction_required", + "invalid_client", + "invalid_grant", + "invalid_request", + "invalid_scope", + "login_required", + "mfa_enrollment_required", + "mfa_required", + "registration_not_supported", + "server_error", + "slow_down", + "temporarily_unavailable", + "unauthorized_client", + "unsupported_response_type", + "user_cancelled_authorize", + } +) + + +def sanitize_oauth_error_code(error: str | None) -> str | None: + """ + Normalize an OAuth error code so it is safe to log or compare. + + - Downcases the value + - Removes all characters outside of ``[a-z0-9._-]`` + - Truncates overly long values + """ + + if not error: + return None + + sanitized = _INVALID_ERROR_CHARS_RE.sub("", error.lower()) + sanitized = sanitized[:_MAX_ERROR_LENGTH] + return sanitized or None diff --git a/tests/sentry/auth/providers/test_oauth2.py b/tests/sentry/auth/providers/test_oauth2.py index 9a0dedc496d0f8..6a2f8c984890d7 100644 --- a/tests/sentry/auth/providers/test_oauth2.py +++ b/tests/sentry/auth/providers/test_oauth2.py @@ -1,11 +1,14 @@ from collections.abc import Mapping from functools import cached_property from typing import Any +from unittest.mock import MagicMock, patch import pytest +from django.http import HttpResponse +from django.test import RequestFactory from sentry.auth.exceptions import IdentityNotValid -from sentry.auth.providers.oauth2 import OAuth2Provider +from sentry.auth.providers.oauth2 import ERR_PROVIDER_ERROR, OAuth2Callback, OAuth2Provider from sentry.models.authidentity import AuthIdentity from sentry.models.authprovider import AuthProvider from sentry.testutils.cases import TestCase @@ -48,3 +51,45 @@ def test_refresh_identity_without_refresh_token(self) -> None: provider = DummyOAuth2Provider() with pytest.raises(IdentityNotValid): provider.refresh_identity(auth_identity) + + +@control_silo_test +class OAuth2CallbackErrorHandlingTest(TestCase): + def setUp(self) -> None: + super().setUp() + self.callback = OAuth2Callback(client_id="client-id", client_secret="secret") + self.factory = RequestFactory() + + def test_error_query_param_not_reflected(self) -> None: + request = self.factory.get("/", {"error": "1-1)) OR 114=(SELECT 114 FROM PG_SLEEP(15))--"}) + pipeline = MagicMock() + pipeline.error.return_value = HttpResponse() + + response = self.callback.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + (message,) = pipeline.error.call_args[0] + assert message == ERR_PROVIDER_ERROR + assert "PG_SLEEP" not in message + + def test_error_query_param_known_code(self) -> None: + request = self.factory.get("/", {"error": "access_denied"}) + pipeline = MagicMock() + pipeline.error.return_value = HttpResponse() + + self.callback.dispatch(request, pipeline) + + (message,) = pipeline.error.call_args[0] + assert message.endswith("access_denied") + + def test_exchange_token_error_known_code(self) -> None: + request = self.factory.get("/", {"state": "abc", "code": "123"}) + pipeline = MagicMock() + pipeline.error.return_value = HttpResponse() + pipeline.fetch_state.return_value = "abc" + + with patch.object(self.callback, "exchange_token", return_value={"error": "access_denied"}): + self.callback.dispatch(request, pipeline) + + (message,) = pipeline.error.call_args[0] + assert "access_denied" in message diff --git a/tests/sentry/identity/test_oauth2.py b/tests/sentry/identity/test_oauth2.py index 3df0c8881e6d90..1b6851f6f8187c 100644 --- a/tests/sentry/identity/test_oauth2.py +++ b/tests/sentry/identity/test_oauth2.py @@ -5,14 +5,16 @@ from urllib.parse import parse_qs, parse_qsl, urlparse import responses +from django.http import HttpResponse from django.test import Client, RequestFactory from requests.exceptions import SSLError import sentry.identity -from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView +from sentry.identity.oauth2 import ERR_PROVIDER_ERROR, OAuth2CallbackView, OAuth2LoginView from sentry.identity.pipeline import IdentityPipeline from sentry.identity.providers.dummy import DummyProvider from sentry.integrations.types import EventLifecycleOutcome +from sentry.integrations.utils.metrics import IntegrationPipelineErrorReason from sentry.shared_integrations.exceptions import ApiUnauthorized from sentry.testutils.asserts import assert_failure_metric, assert_slo_metric from sentry.testutils.silo import control_silo_test @@ -157,6 +159,32 @@ def test_api_error(self, mock_record: MagicMock) -> None: assert_failure_metric(mock_record, ApiUnauthorized('{"token": "a-fake-token"}')) + def test_error_query_param_not_reflected(self, mock_record: MagicMock) -> None: + pipeline = IdentityPipeline(request=self.request, provider_key="dummy") + malicious_error = "1-1)) OR 114=(SELECT 114 FROM PG_SLEEP(15))--" + request = RequestFactory().get("/", {"error": malicious_error}) + + with patch.object(pipeline, "error", return_value=HttpResponse()) as error_mock: + response = self.view.dispatch(request, pipeline) + + assert response == error_mock.return_value + (message,) = error_mock.call_args[0] + assert message == ERR_PROVIDER_ERROR + assert "PG_SLEEP" not in message + assert_failure_metric( + mock_record, IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE + ) + + def test_error_query_param_known_code(self, mock_record: MagicMock) -> None: + pipeline = IdentityPipeline(request=self.request, provider_key="dummy") + request = RequestFactory().get("/", {"error": "access_denied"}) + + with patch.object(pipeline, "error", return_value=HttpResponse()) as error_mock: + self.view.dispatch(request, pipeline) + + (message,) = error_mock.call_args[0] + assert "access_denied" in message + @control_silo_test class OAuth2LoginViewTest(TestCase):