diff --git a/src/sentry/identity/oauth2.py b/src/sentry/identity/oauth2.py index 6c1f6f0815b97a..5c6cbae6e63abd 100644 --- a/src/sentry/identity/oauth2.py +++ b/src/sentry/identity/oauth2.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import re import secrets from time import time from typing import Any @@ -34,6 +35,7 @@ from sentry.shared_integrations.exceptions import ApiError, ApiInvalidRequestError, ApiUnauthorized from sentry.users.models.identity import Identity from sentry.utils.http import absolute_uri +from sentry.utils.strings import to_single_line_str from .base import Provider @@ -43,6 +45,9 @@ ERR_INVALID_STATE = "An error occurred while validating your request." ERR_TOKEN_RETRIEVAL = "Failed to retrieve token from the upstream service." +_SAFE_OAUTH_ERROR_RE = re.compile(r"^[A-Za-z0-9._:/ -]{1,128}$") +_SCRUBBED_OAUTH_ERROR = "" + def _redirect_url(pipeline: IdentityPipeline) -> str: associate_url = reverse( @@ -264,6 +269,23 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp return HttpResponseRedirect(redirect_uri) +def _sanitize_oauth_error_param(error: str | None) -> str | None: + """ + Return an allow-listed OAuth error string or None if the value is not display-safe. + """ + if not error: + return None + + normalized = to_single_line_str(error) + if not normalized: + return None + + if _SAFE_OAUTH_ERROR_RE.fullmatch(normalized): + return normalized + + return None + + class OAuth2CallbackView: access_token_url: str | None = None client_id: str | None = None @@ -359,11 +381,23 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp code = request.GET.get("code") if error: + sanitized_error = _sanitize_oauth_error_param(error) lifecycle.record_failure( IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, - extra={"error": error}, + extra={ + "error": ( + sanitized_error + if sanitized_error is not None + else _SCRUBBED_OAUTH_ERROR + ) + }, + ) + message = ( + f"{ERR_INVALID_STATE}\nError: {sanitized_error}" + if sanitized_error + else ERR_INVALID_STATE ) - return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}") + return pipeline.error(message) if state != pipeline.fetch_state("state"): extra = { diff --git a/tests/sentry/identity/test_oauth2.py b/tests/sentry/identity/test_oauth2.py index 3df0c8881e6d90..e5606070fba5cf 100644 --- a/tests/sentry/identity/test_oauth2.py +++ b/tests/sentry/identity/test_oauth2.py @@ -9,6 +9,7 @@ from requests.exceptions import SSLError import sentry.identity +from sentry.identity import oauth2 from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView from sentry.identity.pipeline import IdentityPipeline from sentry.identity.providers.dummy import DummyProvider @@ -41,6 +42,19 @@ def view(self): client_secret="secret-value", ) + def _mock_pipeline(self): + pipeline = MagicMock() + pipeline.provider = MagicMock(key="dummy") + pipeline.error.return_value = MagicMock() + return pipeline + + def _setup_record_event(self): + lifecycle = MagicMock() + context_manager = MagicMock() + context_manager.__enter__.return_value = lifecycle + context_manager.__exit__.return_value = None + return lifecycle, context_manager + @responses.activate def test_exchange_token_success(self, mock_record: MagicMock) -> None: responses.add( @@ -157,6 +171,43 @@ def test_api_error(self, mock_record: MagicMock) -> None: assert_failure_metric(mock_record, ApiUnauthorized('{"token": "a-fake-token"}')) + def test_malicious_error_param_is_scrubbed(self, mock_record: MagicMock) -> None: + request = RequestFactory().get( + "/", {"error": "1ivlcrqyy') OR 679=(SELECT 679 FROM PG_SLEEP(15))--"} + ) + request.subdomain = None + pipeline = self._mock_pipeline() + + with patch("sentry.identity.oauth2.record_event") as mock_event: + lifecycle, context_manager = self._setup_record_event() + mock_event.return_value.capture.return_value = context_manager + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + pipeline.error.assert_called_once_with(oauth2.ERR_INVALID_STATE) + lifecycle.record_failure.assert_called_once() + _, kwargs = lifecycle.record_failure.call_args + assert kwargs["extra"]["error"] == "" + + def test_safe_error_param_is_preserved(self, mock_record: MagicMock) -> None: + request = RequestFactory().get("/", {"error": "access_denied"}) + request.subdomain = None + pipeline = self._mock_pipeline() + + with patch("sentry.identity.oauth2.record_event") as mock_event: + lifecycle, context_manager = self._setup_record_event() + mock_event.return_value.capture.return_value = context_manager + + response = self.view.dispatch(request, pipeline) + + expected_message = f"{oauth2.ERR_INVALID_STATE}\nError: access_denied" + assert response == pipeline.error.return_value + pipeline.error.assert_called_once_with(expected_message) + lifecycle.record_failure.assert_called_once() + _, kwargs = lifecycle.record_failure.call_args + assert kwargs["extra"]["error"] == "access_denied" + @control_silo_test class OAuth2LoginViewTest(TestCase):