Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 55 additions & 6 deletions src/sentry/identity/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import hashlib
import logging
import re
import secrets
import string
from time import time
from typing import Any
from urllib.parse import parse_qsl, urlencode
Expand Down Expand Up @@ -42,6 +45,34 @@
logger = logging.getLogger(__name__)
ERR_INVALID_STATE = "An error occurred while validating your request."
ERR_TOKEN_RETRIEVAL = "Failed to retrieve token from the upstream service."
_STATE_VALUE_PATTERN = re.compile(r"^[a-f0-9]{8,128}$")
_SAFE_PROVIDER_ERROR_CHARS = frozenset(string.ascii_letters + string.digits + " ._-/:")


def _summarize_sensitive_value(
value: str | None, *, prefix: str, pattern: re.Pattern[str] | None = None
) -> dict[str, Any]:
summary: dict[str, Any] = {f"{prefix}_present": value is not None}
if value is None:
return summary

summary[f"{prefix}_length"] = len(value)
summary[f"{prefix}_sha256"] = hashlib.sha256(value.encode("utf-8")).hexdigest()
if pattern is not None:
summary[f"{prefix}_matches_expected_format"] = bool(pattern.fullmatch(value))

return summary


def _sanitize_provider_error(error: str | None) -> str | None:
if not error:
return None

trimmed = error.strip()
if trimmed and len(trimmed) <= 128 and all(ch in _SAFE_PROVIDER_ERROR_CHARS for ch in trimmed):
return trimmed

return None


def _redirect_url(pipeline: IdentityPipeline) -> str:
Expand Down Expand Up @@ -359,19 +390,37 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp
code = request.GET.get("code")

if error:
sanitized_error = _sanitize_provider_error(error)
extra = {
"error": sanitized_error or "provider_error_redacted",
}
extra.update(_summarize_sensitive_value(error, prefix="provider_error"))
lifecycle.record_failure(
IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE,
extra={"error": error},
extra=extra,
)
return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}")
message = ERR_INVALID_STATE
if sanitized_error:
message = f"{ERR_INVALID_STATE}\nError: {sanitized_error}"
return pipeline.error(message)

if state != pipeline.fetch_state("state"):
expected_state = pipeline.fetch_state("state")
if state != expected_state:
extra = {
"error": "invalid_state",
"state": state,
"pipeline_state": pipeline.fetch_state("state"),
"code": code,
}
extra.update(
_summarize_sensitive_value(
state, prefix="provided_state", pattern=_STATE_VALUE_PATTERN
)
)
extra.update(
_summarize_sensitive_value(
expected_state, prefix="expected_state", pattern=_STATE_VALUE_PATTERN
)
)
if code:
extra.update(_summarize_sensitive_value(code, prefix="code"))
lifecycle.record_failure(
IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, extra=extra
)
Expand Down
91 changes: 90 additions & 1 deletion tests/sentry/identity/test_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
from collections import namedtuple
from functools import cached_property
from unittest import TestCase
from unittest.mock import MagicMock, patch
from urllib.parse import parse_qs, parse_qsl, urlparse

import responses
from django.http import HttpResponse
from django.test import Client, RequestFactory
from requests.exceptions import SSLError

import sentry.identity
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView
from sentry.identity.oauth2 import ERR_INVALID_STATE, OAuth2CallbackView, OAuth2LoginView
from sentry.identity.pipeline import IdentityPipeline
from sentry.identity.providers.dummy import DummyProvider
from sentry.integrations.types import EventLifecycleOutcome
Expand Down Expand Up @@ -157,6 +159,93 @@ def test_api_error(self, mock_record: MagicMock) -> None:

assert_failure_metric(mock_record, ApiUnauthorized('{"token": "a-fake-token"}'))

@patch("sentry.integrations.utils.metrics.IntegrationEventLifecycle.record_failure")
def test_state_mismatch_sanitizes_logged_value(
self,
mock_record_failure: MagicMock,
mock_record_event: MagicMock,
) -> None:
pipeline = IdentityPipeline(request=self.request, provider_key="dummy")
pipeline.bind_state("state", "59bd69f591011a0cb6b64e0c0d271731")

malicious_state = "1-1); waitfor delay '0:0:15' --"
request = RequestFactory().get("/", {"state": malicious_state, "code": "auth-code"})
request.subdomain = None

with patch.object(
IdentityPipeline, "error", return_value=HttpResponse("error")
) as mock_error:
self.view.dispatch(request, pipeline)

mock_error.assert_called_once()
_, message = mock_error.call_args.args
assert message == ERR_INVALID_STATE

mock_record_failure.assert_called_once()
extra = mock_record_failure.call_args.kwargs["extra"]
assert extra["error"] == "invalid_state"
serialized_extra = json.dumps(extra)
assert "waitfor delay" not in serialized_extra
assert "auth-code" not in serialized_extra
assert extra["provided_state_present"] is True
assert extra["expected_state_present"] is True
assert extra["provided_state_matches_expected_format"] is False
assert extra["expected_state_matches_expected_format"] is True
assert "provided_state_sha256" in extra
assert "expected_state_sha256" in extra

@patch("sentry.integrations.utils.metrics.IntegrationEventLifecycle.record_failure")
def test_provider_error_is_redacted_when_invalid(
self,
mock_record_failure: MagicMock,
mock_record_event: MagicMock,
) -> None:
pipeline = IdentityPipeline(request=self.request, provider_key="dummy")
request = RequestFactory().get("/", {"error": "1-1); waitfor delay '0:0:15' --"})
request.subdomain = None

with patch.object(
IdentityPipeline, "error", return_value=HttpResponse("error")
) as mock_error:
self.view.dispatch(request, pipeline)

mock_error.assert_called_once()
_, message = mock_error.call_args.args
assert message == ERR_INVALID_STATE

mock_record_failure.assert_called_once()
extra = mock_record_failure.call_args.kwargs["extra"]
assert extra["error"] == "provider_error_redacted"
serialized_extra = json.dumps(extra)
assert "waitfor delay" not in serialized_extra
assert extra["provider_error_present"] is True
assert "provider_error_sha256" in extra

@patch("sentry.integrations.utils.metrics.IntegrationEventLifecycle.record_failure")
def test_provider_error_keeps_safe_message(
self,
mock_record_failure: MagicMock,
mock_record_event: MagicMock,
) -> None:
pipeline = IdentityPipeline(request=self.request, provider_key="dummy")
request = RequestFactory().get("/", {"error": "access_denied"})
request.subdomain = None

with patch.object(
IdentityPipeline, "error", return_value=HttpResponse("error")
) as mock_error:
self.view.dispatch(request, pipeline)

mock_error.assert_called_once()
_, message = mock_error.call_args.args
assert message == f"{ERR_INVALID_STATE}\nError: access_denied"

mock_record_failure.assert_called_once()
extra = mock_record_failure.call_args.kwargs["extra"]
assert extra["error"] == "access_denied"
assert extra["provider_error_present"] is True
assert "provider_error_sha256" in extra


@control_silo_test
class OAuth2LoginViewTest(TestCase):
Expand Down
Loading