diff --git a/src/sentry/identity/oauth2.py b/src/sentry/identity/oauth2.py index 6c1f6f0815b97a..e3f47cafa39505 100644 --- a/src/sentry/identity/oauth2.py +++ b/src/sentry/identity/oauth2.py @@ -245,10 +245,28 @@ def get_authorize_params(self, state, redirect_uri): @method_decorator(csrf_exempt) def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResponseBase: - with record_event(IntegrationPipelineViewType.OAUTH_LOGIN, pipeline.provider.key).capture(): - for param in ("code", "error", "state"): - if param in request.GET: - return pipeline.next_step() + with record_event( + IntegrationPipelineViewType.OAUTH_LOGIN, pipeline.provider.key + ).capture() as lifecycle: + callback_attempt = any(param in request.GET for param in ("code", "error", "state")) + if callback_attempt: + request_state = request.GET.get("state") + pipeline_state = pipeline.fetch_state("state") + + if not request_state or request_state != pipeline_state: + extra = { + "request_state": request_state, + "pipeline_state": pipeline_state, + "has_code": "code" in request.GET, + "has_error": "error" in request.GET, + } + lifecycle.record_failure( + IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, + extra=extra, + ) + return pipeline.error(ERR_INVALID_STATE) + + return pipeline.next_step() state = secrets.token_hex() diff --git a/tests/sentry/identity/test_oauth2.py b/tests/sentry/identity/test_oauth2.py index 3df0c8881e6d90..589eca408a5f45 100644 --- a/tests/sentry/identity/test_oauth2.py +++ b/tests/sentry/identity/test_oauth2.py @@ -5,11 +5,12 @@ from urllib.parse import parse_qs, parse_qsl, urlparse import responses +from django.http import HttpResponse from django.test import Client, RequestFactory from requests.exceptions import SSLError import sentry.identity -from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView +from sentry.identity.oauth2 import ERR_INVALID_STATE, OAuth2CallbackView, OAuth2LoginView from sentry.identity.pipeline import IdentityPipeline from sentry.identity.providers.dummy import DummyProvider from sentry.integrations.types import EventLifecycleOutcome @@ -194,6 +195,45 @@ def test_simple(self) -> None: assert query["scope"][0] == "all-the-things" assert "state" in query + def _build_request(self, params: dict[str, str] | None = None): + request = RequestFactory().get("/", params or {}) + request.session = Client().session + request.subdomain = None + return request + + def test_callback_missing_state_is_rejected(self) -> None: + request = self._build_request({"code": "auth-code"}) + pipeline = IdentityPipeline(request=request, provider_key="dummy") + pipeline.bind_state("state", "expected") + + with patch.object(pipeline, "error", return_value=HttpResponse("error")) as mock_error: + response = self.view.dispatch(request, pipeline) + + assert response.content == b"error" + mock_error.assert_called_once_with(ERR_INVALID_STATE) + + def test_callback_mismatched_state_is_rejected(self) -> None: + request = self._build_request({"code": "auth-code", "state": "unexpected"}) + pipeline = IdentityPipeline(request=request, provider_key="dummy") + pipeline.bind_state("state", "expected") + + with patch.object(pipeline, "error", return_value=HttpResponse("error")) as mock_error: + response = self.view.dispatch(request, pipeline) + + assert response.content == b"error" + mock_error.assert_called_once_with(ERR_INVALID_STATE) + + def test_callback_with_matching_state_advances_pipeline(self) -> None: + request = self._build_request({"code": "auth-code", "state": "expected"}) + pipeline = IdentityPipeline(request=request, provider_key="dummy") + pipeline.bind_state("state", "expected") + + with patch.object(pipeline, "next_step", return_value=HttpResponse("ok")) as mock_next: + response = self.view.dispatch(request, pipeline) + + assert response.content == b"ok" + mock_next.assert_called_once_with() + def test_customer_domains(self) -> None: self.request.subdomain = "albertos-apples" pipeline = IdentityPipeline(request=self.request, provider_key="dummy")