Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/sentry/identity/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import re
import secrets
from time import time
from typing import Any
Expand Down Expand Up @@ -34,6 +35,7 @@
from sentry.shared_integrations.exceptions import ApiError, ApiInvalidRequestError, ApiUnauthorized
from sentry.users.models.identity import Identity
from sentry.utils.http import absolute_uri
from sentry.utils.strings import to_single_line_str, truncatechars

from .base import Provider

Expand All @@ -42,6 +44,23 @@
logger = logging.getLogger(__name__)
ERR_INVALID_STATE = "An error occurred while validating your request."
ERR_TOKEN_RETRIEVAL = "Failed to retrieve token from the upstream service."
_CONTROL_CHAR_RE = re.compile(r"[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f]")
_MAX_PROVIDER_ERROR_LENGTH = 256


def _sanitize_provider_error_message(raw_error: Any) -> str:
"""
Collapse control characters and multi-line payloads from untrusted provider errors so they can
be safely surfaced in logs and the pipeline error template.
"""

if raw_error is None:
return ""

error = str(raw_error)
error = _CONTROL_CHAR_RE.sub(" ", error)
error = to_single_line_str(error)
return truncatechars(error, _MAX_PROVIDER_ERROR_LENGTH) or ""


def _redirect_url(pipeline: IdentityPipeline) -> str:
Expand Down Expand Up @@ -359,11 +378,17 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp
code = request.GET.get("code")

if error:
sanitized_error = _sanitize_provider_error_message(error)
lifecycle.record_failure(
IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE,
extra={"error": error},
extra={"error": sanitized_error},
)
return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}")
error_message = (
f"{ERR_INVALID_STATE}\nError: {sanitized_error}"
if sanitized_error
else ERR_INVALID_STATE
)
return pipeline.error(error_message)

if state != pipeline.fetch_state("state"):
extra = {
Expand All @@ -386,12 +411,18 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp

# these errors are based off of the results of exchange_token, lifecycle errors are captured inside
if "error_description" in data:
error = data.get("error")
return pipeline.error(data["error_description"])
sanitized_description = _sanitize_provider_error_message(data.get("error_description"))
return pipeline.error(sanitized_description or ERR_TOKEN_RETRIEVAL)

if "error" in data:
logger.info("identity.token-exchange-error", extra={"error": data["error"]})
return pipeline.error(f"{ERR_TOKEN_RETRIEVAL}\nError: {data['error']}")
sanitized_error = _sanitize_provider_error_message(data["error"])
logger.info("identity.token-exchange-error", extra={"error": sanitized_error})
error_message = (
f"{ERR_TOKEN_RETRIEVAL}\nError: {sanitized_error}"
if sanitized_error
else ERR_TOKEN_RETRIEVAL
)
return pipeline.error(error_message)

# we can either expect the API to be implicit and say "im looking for
# blah within state data" or we need to pass implementation + call a
Expand Down
60 changes: 59 additions & 1 deletion tests/sentry/identity/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from requests.exceptions import SSLError

import sentry.identity
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView
from sentry.identity.oauth2 import (
ERR_INVALID_STATE,
ERR_TOKEN_RETRIEVAL,
OAuth2CallbackView,
OAuth2LoginView,
)
from sentry.identity.pipeline import IdentityPipeline
from sentry.identity.providers.dummy import DummyProvider
from sentry.integrations.types import EventLifecycleOutcome
Expand Down Expand Up @@ -41,6 +46,13 @@ def view(self):
client_secret="secret-value",
)

def _build_pipeline(self, state: str = "expected-state") -> MagicMock:
pipeline = MagicMock()
pipeline.provider.key = "dummy"
pipeline.config = {}
pipeline.fetch_state.return_value = state
return pipeline

@responses.activate
def test_exchange_token_success(self, mock_record: MagicMock) -> None:
responses.add(
Expand Down Expand Up @@ -157,6 +169,52 @@ def test_api_error(self, mock_record: MagicMock) -> None:

assert_failure_metric(mock_record, ApiUnauthorized('{"token": "a-fake-token"}'))

def test_callback_error_parameter_is_sanitized(self, mock_record: MagicMock) -> None:
pipeline = self._build_pipeline()
sentinel_response = object()
pipeline.error.return_value = sentinel_response

request = RequestFactory().get("/", {"error": "bad \n<script>alert(1)</script>"})

response = self.view.dispatch(request, pipeline)

expected_message = f"{ERR_INVALID_STATE}\nError: bad <script>alert(1)</script>"
pipeline.error.assert_called_once_with(expected_message)
assert response is sentinel_response

def test_error_description_is_sanitized(self, mock_record: MagicMock) -> None:
pipeline = self._build_pipeline()
sentinel_response = object()
pipeline.error.return_value = sentinel_response

request = RequestFactory().get("/", {"state": "expected-state", "code": "auth-code"})

with patch.object(
self.view, "exchange_token", return_value={"error_description": "bad \r\nvalue"}
) as mock_exchange:
response = self.view.dispatch(request, pipeline)

mock_exchange.assert_called_once()
pipeline.error.assert_called_once_with("bad value")
assert response is sentinel_response

def test_exchange_token_error_payload_is_sanitized(self, mock_record: MagicMock) -> None:
pipeline = self._build_pipeline()
sentinel_response = object()
pipeline.error.return_value = sentinel_response
request = RequestFactory().get("/", {"state": "expected-state", "code": "auth-code"})

with patch.object(self.view, "exchange_token", return_value={"error": "foo\nbar"}):
with patch("sentry.identity.oauth2.logger") as mock_logger:
response = self.view.dispatch(request, pipeline)

expected_error = "foo bar"
pipeline.error.assert_called_once_with(f"{ERR_TOKEN_RETRIEVAL}\nError: {expected_error}")
mock_logger.info.assert_called_once_with(
"identity.token-exchange-error", extra={"error": expected_error}
)
assert response is sentinel_response


@control_silo_test
class OAuth2LoginViewTest(TestCase):
Expand Down
Loading