diff --git a/src/sentry/identity/oauth2.py b/src/sentry/identity/oauth2.py index 6c1f6f0815b97a..9100ebf6e6ae8d 100644 --- a/src/sentry/identity/oauth2.py +++ b/src/sentry/identity/oauth2.py @@ -1,7 +1,10 @@ from __future__ import annotations +import hashlib import logging +import re import secrets +import string from time import time from typing import Any from urllib.parse import parse_qsl, urlencode @@ -42,6 +45,34 @@ logger = logging.getLogger(__name__) ERR_INVALID_STATE = "An error occurred while validating your request." ERR_TOKEN_RETRIEVAL = "Failed to retrieve token from the upstream service." +_STATE_VALUE_PATTERN = re.compile(r"^[a-f0-9]{8,128}$") +_SAFE_PROVIDER_ERROR_CHARS = frozenset(string.ascii_letters + string.digits + " ._-/:") + + +def _summarize_sensitive_value( + value: str | None, *, prefix: str, pattern: re.Pattern[str] | None = None +) -> dict[str, Any]: + summary: dict[str, Any] = {f"{prefix}_present": value is not None} + if value is None: + return summary + + summary[f"{prefix}_length"] = len(value) + summary[f"{prefix}_sha256"] = hashlib.sha256(value.encode("utf-8")).hexdigest() + if pattern is not None: + summary[f"{prefix}_matches_expected_format"] = bool(pattern.fullmatch(value)) + + return summary + + +def _sanitize_provider_error(error: str | None) -> str | None: + if not error: + return None + + trimmed = error.strip() + if trimmed and len(trimmed) <= 128 and all(ch in _SAFE_PROVIDER_ERROR_CHARS for ch in trimmed): + return trimmed + + return None def _redirect_url(pipeline: IdentityPipeline) -> str: @@ -359,19 +390,37 @@ def dispatch(self, request: HttpRequest, pipeline: IdentityPipeline) -> HttpResp code = request.GET.get("code") if error: + sanitized_error = _sanitize_provider_error(error) + extra = { + "error": sanitized_error or "provider_error_redacted", + } + extra.update(_summarize_sensitive_value(error, prefix="provider_error")) lifecycle.record_failure( IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, - extra={"error": error}, + extra=extra, ) - return pipeline.error(f"{ERR_INVALID_STATE}\nError: {error}") + message = ERR_INVALID_STATE + if sanitized_error: + message = f"{ERR_INVALID_STATE}\nError: {sanitized_error}" + return pipeline.error(message) - if state != pipeline.fetch_state("state"): + expected_state = pipeline.fetch_state("state") + if state != expected_state: extra = { "error": "invalid_state", - "state": state, - "pipeline_state": pipeline.fetch_state("state"), - "code": code, } + extra.update( + _summarize_sensitive_value( + state, prefix="provided_state", pattern=_STATE_VALUE_PATTERN + ) + ) + extra.update( + _summarize_sensitive_value( + expected_state, prefix="expected_state", pattern=_STATE_VALUE_PATTERN + ) + ) + if code: + extra.update(_summarize_sensitive_value(code, prefix="code")) lifecycle.record_failure( IntegrationPipelineErrorReason.TOKEN_EXCHANGE_MISMATCHED_STATE, extra=extra ) diff --git a/tests/sentry/identity/test_oauth2.py b/tests/sentry/identity/test_oauth2.py index 3df0c8881e6d90..71dd229fce1b04 100644 --- a/tests/sentry/identity/test_oauth2.py +++ b/tests/sentry/identity/test_oauth2.py @@ -1,3 +1,4 @@ +import json from collections import namedtuple from functools import cached_property from unittest import TestCase @@ -5,11 +6,12 @@ from urllib.parse import parse_qs, parse_qsl, urlparse import responses +from django.http import HttpResponse from django.test import Client, RequestFactory from requests.exceptions import SSLError import sentry.identity -from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView +from sentry.identity.oauth2 import ERR_INVALID_STATE, OAuth2CallbackView, OAuth2LoginView from sentry.identity.pipeline import IdentityPipeline from sentry.identity.providers.dummy import DummyProvider from sentry.integrations.types import EventLifecycleOutcome @@ -157,6 +159,93 @@ def test_api_error(self, mock_record: MagicMock) -> None: assert_failure_metric(mock_record, ApiUnauthorized('{"token": "a-fake-token"}')) + @patch("sentry.integrations.utils.metrics.IntegrationEventLifecycle.record_failure") + def test_state_mismatch_sanitizes_logged_value( + self, + mock_record_failure: MagicMock, + mock_record_event: MagicMock, + ) -> None: + pipeline = IdentityPipeline(request=self.request, provider_key="dummy") + pipeline.bind_state("state", "59bd69f591011a0cb6b64e0c0d271731") + + malicious_state = "1-1); waitfor delay '0:0:15' --" + request = RequestFactory().get("/", {"state": malicious_state, "code": "auth-code"}) + request.subdomain = None + + with patch.object( + IdentityPipeline, "error", return_value=HttpResponse("error") + ) as mock_error: + self.view.dispatch(request, pipeline) + + mock_error.assert_called_once() + _, message = mock_error.call_args.args + assert message == ERR_INVALID_STATE + + mock_record_failure.assert_called_once() + extra = mock_record_failure.call_args.kwargs["extra"] + assert extra["error"] == "invalid_state" + serialized_extra = json.dumps(extra) + assert "waitfor delay" not in serialized_extra + assert "auth-code" not in serialized_extra + assert extra["provided_state_present"] is True + assert extra["expected_state_present"] is True + assert extra["provided_state_matches_expected_format"] is False + assert extra["expected_state_matches_expected_format"] is True + assert "provided_state_sha256" in extra + assert "expected_state_sha256" in extra + + @patch("sentry.integrations.utils.metrics.IntegrationEventLifecycle.record_failure") + def test_provider_error_is_redacted_when_invalid( + self, + mock_record_failure: MagicMock, + mock_record_event: MagicMock, + ) -> None: + pipeline = IdentityPipeline(request=self.request, provider_key="dummy") + request = RequestFactory().get("/", {"error": "1-1); waitfor delay '0:0:15' --"}) + request.subdomain = None + + with patch.object( + IdentityPipeline, "error", return_value=HttpResponse("error") + ) as mock_error: + self.view.dispatch(request, pipeline) + + mock_error.assert_called_once() + _, message = mock_error.call_args.args + assert message == ERR_INVALID_STATE + + mock_record_failure.assert_called_once() + extra = mock_record_failure.call_args.kwargs["extra"] + assert extra["error"] == "provider_error_redacted" + serialized_extra = json.dumps(extra) + assert "waitfor delay" not in serialized_extra + assert extra["provider_error_present"] is True + assert "provider_error_sha256" in extra + + @patch("sentry.integrations.utils.metrics.IntegrationEventLifecycle.record_failure") + def test_provider_error_keeps_safe_message( + self, + mock_record_failure: MagicMock, + mock_record_event: MagicMock, + ) -> None: + pipeline = IdentityPipeline(request=self.request, provider_key="dummy") + request = RequestFactory().get("/", {"error": "access_denied"}) + request.subdomain = None + + with patch.object( + IdentityPipeline, "error", return_value=HttpResponse("error") + ) as mock_error: + self.view.dispatch(request, pipeline) + + mock_error.assert_called_once() + _, message = mock_error.call_args.args + assert message == f"{ERR_INVALID_STATE}\nError: access_denied" + + mock_record_failure.assert_called_once() + extra = mock_record_failure.call_args.kwargs["extra"] + assert extra["error"] == "access_denied" + assert extra["provider_error_present"] is True + assert "provider_error_sha256" in extra + @control_silo_test class OAuth2LoginViewTest(TestCase):