diff --git a/src/sentry/identity/oauth2.py b/src/sentry/identity/oauth2.py index 6c1f6f0815b97a..fe0a54c2d66216 100644 --- a/src/sentry/identity/oauth2.py +++ b/src/sentry/identity/oauth2.py @@ -246,9 +246,36 @@ def get_authorize_params(self, state, redirect_uri): @method_decorator(csrf_exempt) def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResponseBase: with record_event(IntegrationPipelineViewType.OAUTH_LOGIN, pipeline.provider.key).capture(): - for param in ("code", "error", "state"): - if param in request.GET: - return pipeline.next_step() + pipeline_state = pipeline.fetch_state("state") + request_state = request.GET.get("state") + callback_params_present = any( + param in request.GET for param in ("code", "error", "state") + ) + + if callback_params_present: + if not pipeline_state or not request_state: + logger.info( + "identity.oauth-login.missing-state", + extra={ + "provider": pipeline.provider.key, + "has_code": "code" in request.GET, + "has_error": "error" in request.GET, + "has_state": bool(request_state), + }, + ) + return pipeline.error(ERR_INVALID_STATE) + + if not secrets.compare_digest(request_state, pipeline_state): + logger.info( + "identity.oauth-login.state-mismatch", + extra={ + "provider": pipeline.provider.key, + "request_state": request_state, + }, + ) + return pipeline.error(ERR_INVALID_STATE) + + return pipeline.next_step() state = secrets.token_hex() @@ -354,29 +381,42 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp with record_event( IntegrationPipelineViewType.OAUTH_CALLBACK, pipeline.provider.key ).capture() as lifecycle: - error = request.GET.get("error") state = request.GET.get("state") + expected_state = pipeline.fetch_state("state") code = request.GET.get("code") + error = request.GET.get("error") - if error: + if not expected_state or not state: lifecycle.record_failure( IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, - extra={"error": error}, + extra={ + "error": "missing_state", + "state": state, + "pipeline_state": expected_state, + "code": code, + }, ) - return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}") + return pipeline.error(ERR_INVALID_STATE) - if state != pipeline.fetch_state("state"): - extra = { - "error": "invalid_state", - "state": state, - "pipeline_state": pipeline.fetch_state("state"), - "code": code, - } + if not secrets.compare_digest(state, expected_state): lifecycle.record_failure( - IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, extra=extra + IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, + extra={ + "error": "invalid_state", + "state": state, + "pipeline_state": expected_state, + "code": code, + }, ) return pipeline.error(ERR_INVALID_STATE) + if error: + lifecycle.record_failure( + IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, + extra={"error": error}, + ) + return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}") + if code is None: lifecycle.record_halt(IntegrationPipelineHaltReason.NO_CODE_PROVIDED) return pipeline.error("no code was provided") diff --git a/tests/sentry/identity/test_oauth2.py b/tests/sentry/identity/test_oauth2.py index 3df0c8881e6d90..610aca2efc4812 100644 --- a/tests/sentry/identity/test_oauth2.py +++ b/tests/sentry/identity/test_oauth2.py @@ -1,14 +1,17 @@ from collections import namedtuple from functools import cached_property +from types import SimpleNamespace from unittest import TestCase from unittest.mock import MagicMock, patch from urllib.parse import parse_qs, parse_qsl, urlparse import responses +from django.http import HttpResponse from django.test import Client, RequestFactory from requests.exceptions import SSLError import sentry.identity +from sentry.identity import oauth2 as oauth2_module from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView from sentry.identity.pipeline import IdentityPipeline from sentry.identity.providers.dummy import DummyProvider @@ -20,6 +23,23 @@ MockResponse = namedtuple("MockResponse", ["headers", "content"]) +class StubPipeline: + def __init__(self, state: str | None = None): + self.provider = SimpleNamespace(key="dummy") + self.config: dict[str, object] = {} + self._state: dict[str, str] = {} + if state is not None: + self._state["state"] = state + self.error = MagicMock(return_value=HttpResponse("error")) + self.next_step = MagicMock(return_value=HttpResponse("next")) + + def bind_state(self, key: str, value: str) -> None: + self._state[key] = value + + def fetch_state(self, key: str): + return self._state.get(key) + + @control_silo_test @patch("sentry.integrations.utils.metrics.EventLifecycle.record_event") class OAuth2CallbackViewTest(TestCase): @@ -157,6 +177,42 @@ def test_api_error(self, mock_record: MagicMock) -> None: assert_failure_metric(mock_record, ApiUnauthorized('{"token": "a-fake-token"}')) + def test_dispatch_requires_state_before_error(self, mock_record: MagicMock) -> None: + pipeline = StubPipeline(state="expected-state") + request = RequestFactory().get("/?error=boom") + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE) + pipeline.next_step.assert_not_called() + + def test_dispatch_rejects_mismatched_state(self, mock_record: MagicMock) -> None: + pipeline = StubPipeline(state="expected-state") + request = RequestFactory().get("/?error=boom&state=wrong") + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE) + pipeline.next_step.assert_not_called() + + def test_dispatch_passes_through_error_with_valid_state(self, mock_record: MagicMock) -> None: + pipeline = StubPipeline(state="expected-state") + request = RequestFactory().get("/?error=boom&state=expected-state") + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + pipeline.error.assert_called_once() + error_message = pipeline.error.call_args[0][0] + assert "boom" in error_message + assert error_message.startswith(oauth2_module.ERR_INVALID_STATE) + pipeline.next_step.assert_not_called() + @control_silo_test class OAuth2LoginViewTest(TestCase): @@ -209,3 +265,39 @@ def test_customer_domains(self) -> None: assert query["response_type"][0] == "code" assert query["scope"][0] == "all-the-things" assert "state" in query + + def test_error_callback_without_state_is_rejected(self) -> None: + pipeline = StubPipeline() + request = RequestFactory().get("/?error=bad") + request.session = Client().session + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE) + pipeline.next_step.assert_not_called() + + def test_error_callback_with_valid_state_advances(self) -> None: + pipeline = StubPipeline(state="expected-state") + request = RequestFactory().get("/?error=bad&state=expected-state") + request.session = Client().session + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.next_step.return_value + pipeline.next_step.assert_called_once() + pipeline.error.assert_not_called() + + def test_error_callback_with_state_mismatch_is_rejected(self) -> None: + pipeline = StubPipeline(state="expected-state") + request = RequestFactory().get("/?error=bad&state=wrong") + request.session = Client().session + request.subdomain = None + + response = self.view.dispatch(request, pipeline) + + assert response == pipeline.error.return_value + pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE) + pipeline.next_step.assert_not_called()