From 4adf91d58bbb0d088dc9c974bebba2d87d550ef9 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 19 Nov 2025 04:22:21 +0000 Subject: [PATCH 1/2] Refactor OAuth2 state validation and error handling Co-authored-by: jenn.muengtaweepongsa --- src/sentry/identity/oauth2.py | 68 ++++++++++++++++----- tests/sentry/identity/test_oauth2.py | 91 ++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 15 deletions(-) diff --git a/src/sentry/identity/oauth2.py b/src/sentry/identity/oauth2.py index 6c1f6f0815b97a..9d0160c314d795 100644 --- a/src/sentry/identity/oauth2.py +++ b/src/sentry/identity/oauth2.py @@ -246,9 +246,34 @@ def get_authorize_params(self, state, redirect_uri): @method_decorator(csrf_exempt) def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResponseBase: with record_event(IntegrationPipelineViewType.OAUTH_LOGIN, pipeline.provider.key).capture(): - for param in ("code", "error", "state"): - if param in request.GET: - return pipeline.next_step() + pipeline_state = pipeline.fetch_state("state") + request_state = request.GET.get("state") + callback_params_present = any(param in request.GET for param in ("code", "error", "state")) + + if callback_params_present: + if not pipeline_state or not request_state: + logger.info( + "identity.oauth-login.missing-state", + extra={ + "provider": pipeline.provider.key, + "has_code": "code" in request.GET, + "has_error": "error" in request.GET, + "has_state": bool(request_state), + }, + ) + return pipeline.error(ERR_INVALID_STATE) + + if not secrets.compare_digest(request_state, pipeline_state): + logger.info( + "identity.oauth-login.state-mismatch", + extra={ + "provider": pipeline.provider.key, + "request_state": request_state, + }, + ) + return pipeline.error(ERR_INVALID_STATE) + + return pipeline.next_step() state = secrets.token_hex() @@ -354,29 +379,42 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp with record_event( IntegrationPipelineViewType.OAUTH_CALLBACK, pipeline.provider.key ).capture() as lifecycle: - error = request.GET.get("error") state = request.GET.get("state") + expected_state = pipeline.fetch_state("state") code = request.GET.get("code") + error = request.GET.get("error") - if error: + if not expected_state or not state: lifecycle.record_failure( IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, - extra={"error": error}, + extra={ + "error": "missing_state", + "state": state, + "pipeline_state": expected_state, + "code": code, + }, ) - return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}") + return pipeline.error(ERR_INVALID_STATE) - if state != pipeline.fetch_state("state"): - extra = { - "error": "invalid_state", - "state": state, - "pipeline_state": pipeline.fetch_state("state"), - "code": code, - } + if not secrets.compare_digest(state, expected_state): lifecycle.record_failure( - IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, extra=extra + IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, + extra={ + "error": "invalid_state", + "state": state, + "pipeline_state": expected_state, + "code": code, + }, ) return pipeline.error(ERR_INVALID_STATE) + if error: + lifecycle.record_failure( + IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, + extra={"error": error}, + ) + return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}") + if code is None: lifecycle.record_halt(IntegrationPipelineHaltReason.NO_CODE_PROVIDED) return pipeline.error("no code was provided") diff --git a/tests/sentry/identity/test_oauth2.py b/tests/sentry/identity/test_oauth2.py index 3df0c8881e6d90..dd7e05b539b647 100644 --- a/tests/sentry/identity/test_oauth2.py +++ b/tests/sentry/identity/test_oauth2.py @@ -1,14 +1,17 @@ from collections import namedtuple from functools import cached_property +from types import SimpleNamespace from unittest import TestCase from unittest.mock import MagicMock, patch from urllib.parse import parse_qs, parse_qsl, urlparse import responses +from django.http import HttpResponse from django.test import Client, RequestFactory from requests.exceptions import SSLError import sentry.identity +from sentry.identity import oauth2 as oauth2_module from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView from sentry.identity.pipeline import IdentityPipeline from sentry.identity.providers.dummy import DummyProvider @@ -20,6 +23,23 @@ MockResponse = namedtuple("MockResponse", ["headers", "content"]) +class StubPipeline: + def __init__(self, state: str | None = None): + self.provider = SimpleNamespace(key="dummy") + self.config: dict[str, object] = {} + self._state: dict[str, str] = {} + if state is not None: + self._state["state"] = state + self.error = MagicMock(return_value=HttpResponse("error")) + self.next_step = MagicMock(return_value=HttpResponse("next")) + + def bind_state(self, key: str, value: str) -> None: + self._state[key] = value + + def fetch_state(self, key: str): + return self._state.get(key) + + @control_silo_test @patch("sentry.integrations.utils.metrics.EventLifecycle.record_event") class OAuth2CallbackViewTest(TestCase): @@ -157,6 +177,41 @@ def test_api_error(self, mock_record: MagicMock) -> None: assert_failure_metric(mock_record, ApiUnauthorized('{"token": "a-fake-token"}')) + def test_dispatch_requires_state_before_error(self, mock_record: MagicMock) -> None: + pipeline = StubPipeline(state="expected-state") + request = RequestFactory().get("/?error=boom") + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE) + pipeline.next_step.assert_not_called() + + def test_dispatch_rejects_mismatched_state(self, mock_record: MagicMock) -> None: + pipeline = StubPipeline(state="expected-state") + request = RequestFactory().get("/?error=boom&state=wrong") + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE) + pipeline.next_step.assert_not_called() + + def test_dispatch_passes_through_error_with_valid_state(self, mock_record: MagicMock) -> None: + pipeline = StubPipeline(state="expected-state") + request = RequestFactory().get("/?error=boom&state=expected-state") + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + pipeline.error.assert_called_once() + error_message = pipeline.error.call_args[0][0] + assert "boom" in error_message + assert error_message.startswith(oauth2_module.ERR_INVALID_STATE) + pipeline.next_step.assert_not_called() @control_silo_test class OAuth2LoginViewTest(TestCase): @@ -209,3 +264,39 @@ def test_customer_domains(self) -> None: assert query["response_type"][0] == "code" assert query["scope"][0] == "all-the-things" assert "state" in query + + def test_error_callback_without_state_is_rejected(self) -> None: + pipeline = StubPipeline() + request = RequestFactory().get("/?error=bad") + request.session = Client().session + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE) + pipeline.next_step.assert_not_called() + + def test_error_callback_with_valid_state_advances(self) -> None: + pipeline = StubPipeline(state="expected-state") + request = RequestFactory().get("/?error=bad&state=expected-state") + request.session = Client().session + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.next_step.return_value + pipeline.next_step.assert_called_once() + pipeline.error.assert_not_called() + + def test_error_callback_with_state_mismatch_is_rejected(self) -> None: + pipeline = StubPipeline(state="expected-state") + request = RequestFactory().get("/?error=bad&state=wrong") + request.session = Client().session + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE) + pipeline.next_step.assert_not_called() From 5e363235c3ae3668690bbd7a8b255604dbadc364 Mon Sep 17 00:00:00 2001 From: "getsantry[bot]" <66042841+getsantry[bot]@users.noreply.github.com> Date: Wed, 19 Nov 2025 04:23:36 +0000 Subject: [PATCH 2/2] :hammer_and_wrench: apply pre-commit fixes --- src/sentry/identity/oauth2.py | 4 +++- tests/sentry/identity/test_oauth2.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/sentry/identity/oauth2.py b/src/sentry/identity/oauth2.py index 9d0160c314d795..fe0a54c2d66216 100644 --- a/src/sentry/identity/oauth2.py +++ b/src/sentry/identity/oauth2.py @@ -248,7 +248,9 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp with record_event(IntegrationPipelineViewType.OAUTH_LOGIN, pipeline.provider.key).capture(): pipeline_state = pipeline.fetch_state("state") request_state = request.GET.get("state") - callback_params_present = any(param in request.GET for param in ("code", "error", "state")) + callback_params_present = any( + param in request.GET for param in ("code", "error", "state") + ) if callback_params_present: if not pipeline_state or not request_state: diff --git a/tests/sentry/identity/test_oauth2.py b/tests/sentry/identity/test_oauth2.py index dd7e05b539b647..610aca2efc4812 100644 --- a/tests/sentry/identity/test_oauth2.py +++ b/tests/sentry/identity/test_oauth2.py @@ -213,6 +213,7 @@ def test_dispatch_passes_through_error_with_valid_state(self, mock_record: Magic assert error_message.startswith(oauth2_module.ERR_INVALID_STATE) pipeline.next_step.assert_not_called() + @control_silo_test class OAuth2LoginViewTest(TestCase): def setUp(self) -> None: