Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/sentry/auth/providers/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sentry.http import safe_urlopen, safe_urlread
from sentry.models.authidentity import AuthIdentity
from sentry.utils.http import absolute_uri
from sentry.utils.oauth import sanitize_oauth_error

if TYPE_CHECKING:
from sentry.auth.helper import AuthHelper
Expand Down Expand Up @@ -129,7 +130,11 @@ def dispatch(self, request: HttpRequest, pipeline: AuthHelper) -> HttpResponseBa
code = request.GET.get("code")

if error:
return pipeline.error(error)
safe_error = sanitize_oauth_error(error)
message = ERR_INVALID_STATE
if safe_error:
message = f"{message}\nError code: {safe_error}"
return pipeline.error(message)

if state != pipeline.fetch_state("state"):
return pipeline.error(ERR_INVALID_STATE)
Expand Down
9 changes: 7 additions & 2 deletions src/sentry/identity/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sentry.shared_integrations.exceptions import ApiError, ApiInvalidRequestError, ApiUnauthorized
from sentry.users.models.identity import Identity
from sentry.utils.http import absolute_uri
from sentry.utils.oauth import sanitize_oauth_error

from .base import Provider

Expand Down Expand Up @@ -359,11 +360,15 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp
code = request.GET.get("code")

if error:
safe_error = sanitize_oauth_error(error)
lifecycle.record_failure(
IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE,
extra={"error": error},
extra={"error": safe_error or "invalid_oauth_error_param"},
)
return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}")
message = ERR_INVALID_STATE
if safe_error:
message = f"{message}\nError code: {safe_error}"
return pipeline.error(message)

if state != pipeline.fetch_state("state"):
extra = {
Expand Down
37 changes: 37 additions & 0 deletions src/sentry/utils/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

import string

__all__ = ["sanitize_oauth_error"]

_ALLOWED_OAUTH_ERROR_CHARS = frozenset(string.ascii_letters + string.digits + "-._~")
_MAX_OAUTH_ERROR_LENGTH = 128


def sanitize_oauth_error(
error: str | None, *, max_length: int = _MAX_OAUTH_ERROR_LENGTH
) -> str | None:
"""
Normalize an OAuth ``error`` query parameter value.

The OAuth 2.0 specification (RFC 6749) restricts the ``error`` token to a fixed set
of short identifiers comprised of unreserved URI characters. To prevent attackers
from echoing arbitrary strings back to the user or into logs, we only return a value
if it stays within the allowed character set and length budget. The result is
lower-cased for consistency.
"""

if error is None:
return None

token = error.strip()
if not token:
return None

if len(token) > max_length:
return None

if not set(token).issubset(_ALLOWED_OAUTH_ERROR_CHARS):
return None

return token.lower()
32 changes: 31 additions & 1 deletion tests/sentry/auth/providers/test_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from collections.abc import Mapping
from functools import cached_property
from typing import Any
from unittest.mock import MagicMock

import pytest
from django.test import RequestFactory

from sentry.auth.exceptions import IdentityNotValid
from sentry.auth.providers.oauth2 import OAuth2Provider
from sentry.auth.providers.oauth2 import ERR_INVALID_STATE, OAuth2Callback, OAuth2Provider
from sentry.models.authidentity import AuthIdentity
from sentry.models.authprovider import AuthProvider
from sentry.testutils.cases import TestCase
Expand Down Expand Up @@ -48,3 +50,31 @@ def test_refresh_identity_without_refresh_token(self) -> None:
provider = DummyOAuth2Provider()
with pytest.raises(IdentityNotValid):
provider.refresh_identity(auth_identity)


class OAuth2CallbackTest(TestCase):
def setUp(self) -> None:
self.request_factory = RequestFactory()
self.callback = OAuth2Callback(
access_token_url="https://example.org/token",
client_id="client-id",
client_secret="client-secret",
)

def test_error_param_uses_sanitized_code(self) -> None:
request = self.request_factory.get("/", {"error": "access_denied"})
request.subdomain = None
pipeline = MagicMock()

self.callback.dispatch(request, pipeline)

pipeline.error.assert_called_once_with(f"{ERR_INVALID_STATE}\nError code: access_denied")

def test_error_param_omitted_when_invalid(self) -> None:
request = self.request_factory.get("/", {"error": "(select(0)from(select(sleep(15)))v)"})
request.subdomain = None
pipeline = MagicMock()

self.callback.dispatch(request, pipeline)

pipeline.error.assert_called_once_with(ERR_INVALID_STATE)
31 changes: 30 additions & 1 deletion tests/sentry/identity/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from requests.exceptions import SSLError

import sentry.identity
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView
from sentry.identity.oauth2 import ERR_INVALID_STATE, OAuth2CallbackView, OAuth2LoginView
from sentry.identity.pipeline import IdentityPipeline
from sentry.identity.providers.dummy import DummyProvider
from sentry.integrations.types import EventLifecycleOutcome
Expand Down Expand Up @@ -157,6 +157,35 @@ def test_api_error(self, mock_record: MagicMock) -> None:

assert_failure_metric(mock_record, ApiUnauthorized('{"token": "a-fake-token"}'))

def test_callback_error_param_uses_sanitized_code(self, mock_record: MagicMock) -> None:
request = RequestFactory().get("/", {"error": "access_denied"})
request.subdomain = None
pipeline = IdentityPipeline(request=request, provider_key="dummy")
pipeline.error = MagicMock()

with patch(
"sentry.integrations.utils.metrics.EventLifecycle.record_failure"
) as mock_failure:
self.view.dispatch(request, pipeline)

pipeline.error.assert_called_once_with(f"{ERR_INVALID_STATE}\nError code: access_denied")
assert mock_failure.call_args.kwargs["extra"] == {"error": "access_denied"}

def test_callback_error_param_rejected_when_invalid(self, mock_record: MagicMock) -> None:
malicious_payload = "(select(0)from(select(sleep(15)))v)"
request = RequestFactory().get("/", {"error": malicious_payload})
request.subdomain = None
pipeline = IdentityPipeline(request=request, provider_key="dummy")
pipeline.error = MagicMock()

with patch(
"sentry.integrations.utils.metrics.EventLifecycle.record_failure"
) as mock_failure:
self.view.dispatch(request, pipeline)

pipeline.error.assert_called_once_with(ERR_INVALID_STATE)
assert mock_failure.call_args.kwargs["extra"] == {"error": "invalid_oauth_error_param"}


@control_silo_test
class OAuth2LoginViewTest(TestCase):
Expand Down
Loading