Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/sentry/identity/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,28 @@ def get_authorize_params(self, state, redirect_uri):

@method_decorator(csrf_exempt)
def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResponseBase:
with record_event(IntegrationPipelineViewType.OAUTH_LOGIN, pipeline.provider.key).capture():
for param in ("code", "error", "state"):
if param in request.GET:
return pipeline.next_step()
with record_event(
IntegrationPipelineViewType.OAUTH_LOGIN, pipeline.provider.key
).capture() as lifecycle:
callback_attempt = any(param in request.GET for param in ("code", "error", "state"))
if callback_attempt:
request_state = request.GET.get("state")
pipeline_state = pipeline.fetch_state("state")

if not request_state or request_state != pipeline_state:
extra = {
"request_state": request_state,
"pipeline_state": pipeline_state,
"has_code": "code" in request.GET,
"has_error": "error" in request.GET,
}
lifecycle.record_failure(
IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE,
extra=extra,
)
return pipeline.error(ERR_INVALID_STATE)

return pipeline.next_step()

state = secrets.token_hex()

Expand Down
42 changes: 41 additions & 1 deletion tests/sentry/identity/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from urllib.parse import parse_qs, parse_qsl, urlparse

import responses
from django.http import HttpResponse
from django.test import Client, RequestFactory
from requests.exceptions import SSLError

import sentry.identity
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView
from sentry.identity.oauth2 import ERR_INVALID_STATE, OAuth2CallbackView, OAuth2LoginView
from sentry.identity.pipeline import IdentityPipeline
from sentry.identity.providers.dummy import DummyProvider
from sentry.integrations.types import EventLifecycleOutcome
Expand Down Expand Up @@ -194,6 +195,45 @@ def test_simple(self) -> None:
assert query["scope"][0] == "all-the-things"
assert "state" in query

def _build_request(self, params: dict[str, str] | None = None):
request = RequestFactory().get("/", params or {})
request.session = Client().session
request.subdomain = None
return request

def test_callback_missing_state_is_rejected(self) -> None:
request = self._build_request({"code": "auth-code"})
pipeline = IdentityPipeline(request=request, provider_key="dummy")
pipeline.bind_state("state", "expected")

with patch.object(pipeline, "error", return_value=HttpResponse("error")) as mock_error:
response = self.view.dispatch(request, pipeline)

assert response.content == b"error"
mock_error.assert_called_once_with(ERR_INVALID_STATE)

def test_callback_mismatched_state_is_rejected(self) -> None:
request = self._build_request({"code": "auth-code", "state": "unexpected"})
pipeline = IdentityPipeline(request=request, provider_key="dummy")
pipeline.bind_state("state", "expected")

with patch.object(pipeline, "error", return_value=HttpResponse("error")) as mock_error:
response = self.view.dispatch(request, pipeline)

assert response.content == b"error"
mock_error.assert_called_once_with(ERR_INVALID_STATE)

def test_callback_with_matching_state_advances_pipeline(self) -> None:
request = self._build_request({"code": "auth-code", "state": "expected"})
pipeline = IdentityPipeline(request=request, provider_key="dummy")
pipeline.bind_state("state", "expected")

with patch.object(pipeline, "next_step", return_value=HttpResponse("ok")) as mock_next:
response = self.view.dispatch(request, pipeline)

assert response.content == b"ok"
mock_next.assert_called_once_with()

def test_customer_domains(self) -> None:
self.request.subdomain = "albertos-apples"
pipeline = IdentityPipeline(request=self.request, provider_key="dummy")
Expand Down
Loading