Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/sentry/auth/providers/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,20 @@
from sentry.http import safe_urlopen, safe_urlread
from sentry.models.authidentity import AuthIdentity
from sentry.utils.http import absolute_uri
from sentry.utils.oauth import KNOWN_OAUTH_ERROR_CODES, sanitize_oauth_error_code

if TYPE_CHECKING:
from sentry.auth.helper import AuthHelper

logger = logging.getLogger(__name__)
ERR_INVALID_STATE = "An error occurred while validating your request."
ERR_PROVIDER_ERROR = "The identity provider returned an error while processing your request."


def _format_provider_error_message(error_code: str | None) -> str:
if error_code and error_code in KNOWN_OAUTH_ERROR_CODES:
return f"{ERR_PROVIDER_ERROR}\nError code: {error_code}"
return ERR_PROVIDER_ERROR


def _get_redirect_url() -> str:
Expand Down Expand Up @@ -129,7 +138,8 @@ def dispatch(self, request: HttpRequest, pipeline: AuthHelper) -> HttpResponseBa
code = request.GET.get("code")

if error:
return pipeline.error(error)
sanitized_error = sanitize_oauth_error_code(error)
return pipeline.error(_format_provider_error_message(sanitized_error))

if state != pipeline.fetch_state("state"):
return pipeline.error(ERR_INVALID_STATE)
Expand All @@ -143,8 +153,9 @@ def dispatch(self, request: HttpRequest, pipeline: AuthHelper) -> HttpResponseBa
return pipeline.error(data["error_description"])

if "error" in data:
logging.info("Error exchanging token: %s", data["error"])
return pipeline.error("Unable to retrieve your token")
sanitized_error = sanitize_oauth_error_code(data["error"])
logger.info("Error exchanging token", extra={"error": sanitized_error})
return pipeline.error(_format_provider_error_message(sanitized_error))

# we can either expect the API to be implicit and say "im looking for
# blah within state data" or we need to pass implementation + call a
Expand Down
18 changes: 14 additions & 4 deletions src/sentry/identity/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,24 @@
from sentry.shared_integrations.exceptions import ApiError, ApiInvalidRequestError, ApiUnauthorized
from sentry.users.models.identity import Identity
from sentry.utils.http import absolute_uri
from sentry.utils.oauth import KNOWN_OAUTH_ERROR_CODES, sanitize_oauth_error_code

from .base import Provider

__all__ = ["OAuth2Provider", "OAuth2CallbackView", "OAuth2LoginView"]

logger = logging.getLogger(__name__)
ERR_INVALID_STATE = "An error occurred while validating your request."
ERR_PROVIDER_ERROR = "The identity provider returned an error while processing your request."
ERR_TOKEN_RETRIEVAL = "Failed to retrieve token from the upstream service."


def _format_provider_error_message(error_code: str | None) -> str:
if error_code and error_code in KNOWN_OAUTH_ERROR_CODES:
return f"{ERR_PROVIDER_ERROR}\nError code: {error_code}"
return ERR_PROVIDER_ERROR


def _redirect_url(pipeline: IdentityPipeline) -> str:
associate_url = reverse(
"sentry-extension-setup",
Expand Down Expand Up @@ -359,11 +367,12 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp
code = request.GET.get("code")

if error:
sanitized_error = sanitize_oauth_error_code(error)
lifecycle.record_failure(
IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE,
extra={"error": error},
extra={"error": sanitized_error or "unknown_error"},
)
return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}")
return pipeline.error(_format_provider_error_message(sanitized_error))

if state != pipeline.fetch_state("state"):
extra = {
Expand All @@ -390,8 +399,9 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp
return pipeline.error(data["error_description"])

if "error" in data:
logger.info("identity.token-exchange-error", extra={"error": data["error"]})
return pipeline.error(f"{ERR_TOKEN_RETRIEVAL}\nError: {data['error']}")
sanitized_error = sanitize_oauth_error_code(data["error"])
logger.info("identity.token-exchange-error", extra={"error": sanitized_error})
return pipeline.error(_format_provider_error_message(sanitized_error))

# we can either expect the API to be implicit and say "im looking for
# blah within state data" or we need to pass implementation + call a
Expand Down
53 changes: 53 additions & 0 deletions src/sentry/utils/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

import re

__all__ = ["KNOWN_OAUTH_ERROR_CODES", "sanitize_oauth_error_code"]

_MAX_ERROR_LENGTH = 64
_INVALID_ERROR_CHARS_RE = re.compile(r"[^a-z0-9._-]")

# Common error codes returned by OAuth/OIDC providers. This list combines values from
# RFC 6749, RFC 8628, and popular provider-specific extensions so we can safely display
# known codes back to the user.
KNOWN_OAUTH_ERROR_CODES = frozenset(
{
"account_selection_required",
"access_denied",
"authorization_pending",
"consent_required",
"expired_token",
"interaction_required",
"invalid_client",
"invalid_grant",
"invalid_request",
"invalid_scope",
"login_required",
"mfa_enrollment_required",
"mfa_required",
"registration_not_supported",
"server_error",
"slow_down",
"temporarily_unavailable",
"unauthorized_client",
"unsupported_response_type",
"user_cancelled_authorize",
}
)


def sanitize_oauth_error_code(error: str | None) -> str | None:
"""
Normalize an OAuth error code so it is safe to log or compare.

- Downcases the value
- Removes all characters outside of ``[a-z0-9._-]``
- Truncates overly long values
"""

if not error:
return None

sanitized = _INVALID_ERROR_CHARS_RE.sub("", error.lower())
sanitized = sanitized[:_MAX_ERROR_LENGTH]
return sanitized or None
47 changes: 46 additions & 1 deletion tests/sentry/auth/providers/test_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from collections.abc import Mapping
from functools import cached_property
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from django.http import HttpResponse
from django.test import RequestFactory

from sentry.auth.exceptions import IdentityNotValid
from sentry.auth.providers.oauth2 import OAuth2Provider
from sentry.auth.providers.oauth2 import ERR_PROVIDER_ERROR, OAuth2Callback, OAuth2Provider
from sentry.models.authidentity import AuthIdentity
from sentry.models.authprovider import AuthProvider
from sentry.testutils.cases import TestCase
Expand Down Expand Up @@ -48,3 +51,45 @@ def test_refresh_identity_without_refresh_token(self) -> None:
provider = DummyOAuth2Provider()
with pytest.raises(IdentityNotValid):
provider.refresh_identity(auth_identity)


@control_silo_test
class OAuth2CallbackErrorHandlingTest(TestCase):
def setUp(self) -> None:
super().setUp()
self.callback = OAuth2Callback(client_id="client-id", client_secret="secret")
self.factory = RequestFactory()

def test_error_query_param_not_reflected(self) -> None:
request = self.factory.get("/", {"error": "1-1)) OR 114=(SELECT 114 FROM PG_SLEEP(15))--"})
pipeline = MagicMock()
pipeline.error.return_value = HttpResponse()

response = self.callback.dispatch(request, pipeline)

assert response == pipeline.error.return_value
(message,) = pipeline.error.call_args[0]
assert message == ERR_PROVIDER_ERROR
assert "PG_SLEEP" not in message

def test_error_query_param_known_code(self) -> None:
request = self.factory.get("/", {"error": "access_denied"})
pipeline = MagicMock()
pipeline.error.return_value = HttpResponse()

self.callback.dispatch(request, pipeline)

(message,) = pipeline.error.call_args[0]
assert message.endswith("access_denied")

def test_exchange_token_error_known_code(self) -> None:
request = self.factory.get("/", {"state": "abc", "code": "123"})
pipeline = MagicMock()
pipeline.error.return_value = HttpResponse()
pipeline.fetch_state.return_value = "abc"

with patch.object(self.callback, "exchange_token", return_value={"error": "access_denied"}):
self.callback.dispatch(request, pipeline)

(message,) = pipeline.error.call_args[0]
assert "access_denied" in message
30 changes: 29 additions & 1 deletion tests/sentry/identity/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
from urllib.parse import parse_qs, parse_qsl, urlparse

import responses
from django.http import HttpResponse
from django.test import Client, RequestFactory
from requests.exceptions import SSLError

import sentry.identity
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView
from sentry.identity.oauth2 import ERR_PROVIDER_ERROR, OAuth2CallbackView, OAuth2LoginView
from sentry.identity.pipeline import IdentityPipeline
from sentry.identity.providers.dummy import DummyProvider
from sentry.integrations.types import EventLifecycleOutcome
from sentry.integrations.utils.metrics import IntegrationPipelineErrorReason
from sentry.shared_integrations.exceptions import ApiUnauthorized
from sentry.testutils.asserts import assert_failure_metric, assert_slo_metric
from sentry.testutils.silo import control_silo_test
Expand Down Expand Up @@ -157,6 +159,32 @@ def test_api_error(self, mock_record: MagicMock) -> None:

assert_failure_metric(mock_record, ApiUnauthorized('{"token": "a-fake-token"}'))

def test_error_query_param_not_reflected(self, mock_record: MagicMock) -> None:
pipeline = IdentityPipeline(request=self.request, provider_key="dummy")
malicious_error = "1-1)) OR 114=(SELECT 114 FROM PG_SLEEP(15))--"
request = RequestFactory().get("/", {"error": malicious_error})

with patch.object(pipeline, "error", return_value=HttpResponse()) as error_mock:
response = self.view.dispatch(request, pipeline)

assert response == error_mock.return_value
(message,) = error_mock.call_args[0]
assert message == ERR_PROVIDER_ERROR
assert "PG_SLEEP" not in message
assert_failure_metric(
mock_record, IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE
)

def test_error_query_param_known_code(self, mock_record: MagicMock) -> None:
pipeline = IdentityPipeline(request=self.request, provider_key="dummy")
request = RequestFactory().get("/", {"error": "access_denied"})

with patch.object(pipeline, "error", return_value=HttpResponse()) as error_mock:
self.view.dispatch(request, pipeline)

(message,) = error_mock.call_args[0]
assert "access_denied" in message


@control_silo_test
class OAuth2LoginViewTest(TestCase):
Expand Down
Loading