Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 55 additions & 15 deletions src/sentry/identity/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,36 @@ def get_authorize_params(self, state, redirect_uri):
@method_decorator(csrf_exempt)
def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResponseBase:
with record_event(IntegrationPipelineViewType.OAUTH_LOGIN, pipeline.provider.key).capture():
for param in ("code", "error", "state"):
if param in request.GET:
return pipeline.next_step()
pipeline_state = pipeline.fetch_state("state")
request_state = request.GET.get("state")
callback_params_present = any(
param in request.GET for param in ("code", "error", "state")
)

if callback_params_present:
if not pipeline_state or not request_state:
logger.info(
"identity.oauth-login.missing-state",
extra={
"provider": pipeline.provider.key,
"has_code": "code" in request.GET,
"has_error": "error" in request.GET,
"has_state": bool(request_state),
},
)
return pipeline.error(ERR_INVALID_STATE)

if not secrets.compare_digest(request_state, pipeline_state):
logger.info(
"identity.oauth-login.state-mismatch",
extra={
"provider": pipeline.provider.key,
"request_state": request_state,
},
)
return pipeline.error(ERR_INVALID_STATE)

return pipeline.next_step()

state = secrets.token_hex()

Expand Down Expand Up @@ -354,29 +381,42 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp
with record_event(
IntegrationPipelineViewType.OAUTH_CALLBACK, pipeline.provider.key
).capture() as lifecycle:
error = request.GET.get("error")
state = request.GET.get("state")
expected_state = pipeline.fetch_state("state")
code = request.GET.get("code")
error = request.GET.get("error")

if error:
if not expected_state or not state:
lifecycle.record_failure(
IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE,
extra={"error": error},
extra={
"error": "missing_state",
"state": state,
"pipeline_state": expected_state,
"code": code,
},
)
return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}")
return pipeline.error(ERR_INVALID_STATE)

if state != pipeline.fetch_state("state"):
extra = {
"error": "invalid_state",
"state": state,
"pipeline_state": pipeline.fetch_state("state"),
"code": code,
}
if not secrets.compare_digest(state, expected_state):
lifecycle.record_failure(
IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, extra=extra
IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE,
extra={
"error": "invalid_state",
"state": state,
"pipeline_state": expected_state,
"code": code,
},
)
return pipeline.error(ERR_INVALID_STATE)

if error:
lifecycle.record_failure(
IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE,
extra={"error": error},
)
return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}")

if code is None:
lifecycle.record_halt(IntegrationPipelineHaltReason.NO_CODE_PROVIDED)
return pipeline.error("no code was provided")
Expand Down
92 changes: 92 additions & 0 deletions tests/sentry/identity/test_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from collections import namedtuple
from functools import cached_property
from types import SimpleNamespace
from unittest import TestCase
from unittest.mock import MagicMock, patch
from urllib.parse import parse_qs, parse_qsl, urlparse

import responses
from django.http import HttpResponse
from django.test import Client, RequestFactory
from requests.exceptions import SSLError

import sentry.identity
from sentry.identity import oauth2 as oauth2_module
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView
from sentry.identity.pipeline import IdentityPipeline
from sentry.identity.providers.dummy import DummyProvider
Expand All @@ -20,6 +23,23 @@
MockResponse = namedtuple("MockResponse", ["headers", "content"])


class StubPipeline:
def __init__(self, state: str | None = None):
self.provider = SimpleNamespace(key="dummy")
self.config: dict[str, object] = {}
self._state: dict[str, str] = {}
if state is not None:
self._state["state"] = state
self.error = MagicMock(return_value=HttpResponse("error"))
self.next_step = MagicMock(return_value=HttpResponse("next"))

def bind_state(self, key: str, value: str) -> None:
self._state[key] = value

def fetch_state(self, key: str):
return self._state.get(key)


@control_silo_test
@patch("sentry.integrations.utils.metrics.EventLifecycle.record_event")
class OAuth2CallbackViewTest(TestCase):
Expand Down Expand Up @@ -157,6 +177,42 @@ def test_api_error(self, mock_record: MagicMock) -> None:

assert_failure_metric(mock_record, ApiUnauthorized('{"token": "a-fake-token"}'))

def test_dispatch_requires_state_before_error(self, mock_record: MagicMock) -> None:
pipeline = StubPipeline(state="expected-state")
request = RequestFactory().get("/?error=boom")
request.subdomain = None

response = self.view.dispatch(request, pipeline)

assert response == pipeline.error.return_value
pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE)
pipeline.next_step.assert_not_called()

def test_dispatch_rejects_mismatched_state(self, mock_record: MagicMock) -> None:
pipeline = StubPipeline(state="expected-state")
request = RequestFactory().get("/?error=boom&state=wrong")
request.subdomain = None

response = self.view.dispatch(request, pipeline)

assert response == pipeline.error.return_value
pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE)
pipeline.next_step.assert_not_called()

def test_dispatch_passes_through_error_with_valid_state(self, mock_record: MagicMock) -> None:
pipeline = StubPipeline(state="expected-state")
request = RequestFactory().get("/?error=boom&state=expected-state")
request.subdomain = None

response = self.view.dispatch(request, pipeline)

assert response == pipeline.error.return_value
pipeline.error.assert_called_once()
error_message = pipeline.error.call_args[0][0]
assert "boom" in error_message
assert error_message.startswith(oauth2_module.ERR_INVALID_STATE)
pipeline.next_step.assert_not_called()


@control_silo_test
class OAuth2LoginViewTest(TestCase):
Expand Down Expand Up @@ -209,3 +265,39 @@ def test_customer_domains(self) -> None:
assert query["response_type"][0] == "code"
assert query["scope"][0] == "all-the-things"
assert "state" in query

def test_error_callback_without_state_is_rejected(self) -> None:
pipeline = StubPipeline()
request = RequestFactory().get("/?error=bad")
request.session = Client().session
request.subdomain = None

response = self.view.dispatch(request, pipeline)

assert response == pipeline.error.return_value
pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE)
pipeline.next_step.assert_not_called()

def test_error_callback_with_valid_state_advances(self) -> None:
pipeline = StubPipeline(state="expected-state")
request = RequestFactory().get("/?error=bad&state=expected-state")
request.session = Client().session
request.subdomain = None

response = self.view.dispatch(request, pipeline)

assert response == pipeline.next_step.return_value
pipeline.next_step.assert_called_once()
pipeline.error.assert_not_called()

def test_error_callback_with_state_mismatch_is_rejected(self) -> None:
pipeline = StubPipeline(state="expected-state")
request = RequestFactory().get("/?error=bad&state=wrong")
request.session = Client().session
request.subdomain = None

response = self.view.dispatch(request, pipeline)

assert response == pipeline.error.return_value
pipeline.error.assert_called_once_with(oauth2_module.ERR_INVALID_STATE)
pipeline.next_step.assert_not_called()
Loading