Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions src/sentry/identity/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import re
import secrets
from time import time
from typing import Any
Expand Down Expand Up @@ -34,6 +35,7 @@
from sentry.shared_integrations.exceptions import ApiError, ApiInvalidRequestError, ApiUnauthorized
from sentry.users.models.identity import Identity
from sentry.utils.http import absolute_uri
from sentry.utils.strings import to_single_line_str

from .base import Provider

Expand All @@ -43,6 +45,9 @@
ERR_INVALID_STATE = "An error occurred while validating your request."
ERR_TOKEN_RETRIEVAL = "Failed to retrieve token from the upstream service."

_SAFE_OAUTH_ERROR_RE = re.compile(r"^[A-Za-z0-9._:/ -]{1,128}$")
_SCRUBBED_OAUTH_ERROR = "<redacted>"


def _redirect_url(pipeline: IdentityPipeline) -> str:
associate_url = reverse(
Expand Down Expand Up @@ -264,6 +269,23 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp
return HttpResponseRedirect(redirect_uri)


def _sanitize_oauth_error_param(error: str | None) -> str | None:
"""
Return an allow-listed OAuth error string or None if the value is not display-safe.
"""
if not error:
return None

normalized = to_single_line_str(error)
if not normalized:
return None

if _SAFE_OAUTH_ERROR_RE.fullmatch(normalized):
return normalized

return None


class OAuth2CallbackView:
access_token_url: str | None = None
client_id: str | None = None
Expand Down Expand Up @@ -359,11 +381,23 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp
code = request.GET.get("code")

if error:
sanitized_error = _sanitize_oauth_error_param(error)
lifecycle.record_failure(
IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE,
extra={"error": error},
extra={
"error": (
sanitized_error
if sanitized_error is not None
else _SCRUBBED_OAUTH_ERROR
)
},
)
message = (
f"{ERR_INVALID_STATE}\nError: {sanitized_error}"
if sanitized_error
else ERR_INVALID_STATE
)
return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}")
return pipeline.error(message)

if state != pipeline.fetch_state("state"):
extra = {
Expand Down
51 changes: 51 additions & 0 deletions tests/sentry/identity/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from requests.exceptions import SSLError

import sentry.identity
from sentry.identity import oauth2
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView
from sentry.identity.pipeline import IdentityPipeline
from sentry.identity.providers.dummy import DummyProvider
Expand Down Expand Up @@ -41,6 +42,19 @@ def view(self):
client_secret="secret-value",
)

def _mock_pipeline(self):
pipeline = MagicMock()
pipeline.provider = MagicMock(key="dummy")
pipeline.error.return_value = MagicMock()
return pipeline

def _setup_record_event(self):
lifecycle = MagicMock()
context_manager = MagicMock()
context_manager.__enter__.return_value = lifecycle
context_manager.__exit__.return_value = None
return lifecycle, context_manager

@responses.activate
def test_exchange_token_success(self, mock_record: MagicMock) -> None:
responses.add(
Expand Down Expand Up @@ -157,6 +171,43 @@ def test_api_error(self, mock_record: MagicMock) -> None:

assert_failure_metric(mock_record, ApiUnauthorized('{"token": "a-fake-token"}'))

def test_malicious_error_param_is_scrubbed(self, mock_record: MagicMock) -> None:
request = RequestFactory().get(
"/", {"error": "1ivlcrqyy') OR 679=(SELECT 679 FROM PG_SLEEP(15))--"}
)
request.subdomain = None
pipeline = self._mock_pipeline()

with patch("sentry.identity.oauth2.record_event") as mock_event:
lifecycle, context_manager = self._setup_record_event()
mock_event.return_value.capture.return_value = context_manager

response = self.view.dispatch(request, pipeline)

assert response == pipeline.error.return_value
pipeline.error.assert_called_once_with(oauth2.ERR_INVALID_STATE)
lifecycle.record_failure.assert_called_once()
_, kwargs = lifecycle.record_failure.call_args
assert kwargs["extra"]["error"] == "<redacted>"

def test_safe_error_param_is_preserved(self, mock_record: MagicMock) -> None:
request = RequestFactory().get("/", {"error": "access_denied"})
request.subdomain = None
pipeline = self._mock_pipeline()

with patch("sentry.identity.oauth2.record_event") as mock_event:
lifecycle, context_manager = self._setup_record_event()
mock_event.return_value.capture.return_value = context_manager

response = self.view.dispatch(request, pipeline)

expected_message = f"{oauth2.ERR_INVALID_STATE}\nError: access_denied"
assert response == pipeline.error.return_value
pipeline.error.assert_called_once_with(expected_message)
lifecycle.record_failure.assert_called_once()
_, kwargs = lifecycle.record_failure.call_args
assert kwargs["extra"]["error"] == "access_denied"


@control_silo_test
class OAuth2LoginViewTest(TestCase):
Expand Down
Loading