Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 76 additions & 2 deletions src/sentry/net/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@

import functools
import ipaddress
import os
import socket
import threading
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from socket import timeout as SocketTimeout
from typing import TYPE_CHECKING
from urllib.parse import urlparse

from django.conf import settings
from django.utils.encoding import force_str
from urllib3.exceptions import LocationParseError
from urllib3.util.connection import _set_socket_options, allowed_gai_family
from urllib3.util.timeout import _DEFAULT_TIMEOUT, _TYPE_DEFAULT
from urllib3.util.timeout import _DEFAULT_TIMEOUT, _TYPE_DEFAULT, Timeout

from sentry.exceptions import RestrictedIPAddress

Expand All @@ -22,6 +27,11 @@
ipaddress.ip_network(str(i), strict=False) for i in settings.SENTRY_DISALLOWED_IPS
)

_DNS_THREADPOOL_LOCK = threading.Lock()
_DNS_THREADPOOL: ThreadPoolExecutor | None = None
_DNS_THREADPOOL_PID: int | None = None
_DNS_THREADPOOL_SIZE = max(1, getattr(settings, "SENTRY_DNS_RESOLUTION_MAX_WORKERS", 4))


@functools.lru_cache(maxsize=100)
def is_ipaddress_allowed(ip: str) -> bool:
Expand Down Expand Up @@ -104,6 +114,70 @@ def is_safe_hostname(hostname: str | None) -> bool:
return True


def _get_dns_executor() -> ThreadPoolExecutor:
global _DNS_THREADPOOL, _DNS_THREADPOOL_PID
pid = os.getpid()
if _DNS_THREADPOOL is None or _DNS_THREADPOOL_PID != pid:
with _DNS_THREADPOOL_LOCK:
if _DNS_THREADPOOL is not None and _DNS_THREADPOOL_PID != pid:
_DNS_THREADPOOL.shutdown(wait=False)
_DNS_THREADPOOL = None
if _DNS_THREADPOOL is None:
_DNS_THREADPOOL = ThreadPoolExecutor(
max_workers=_DNS_THREADPOOL_SIZE, thread_name_prefix="dns-resolver"
)
_DNS_THREADPOOL_PID = pid
return _DNS_THREADPOOL


def _as_float_timeout(value: object) -> float | None:
if value is None or value is _DEFAULT_TIMEOUT:
return None
try:
timeout_value = float(value) # type: ignore[arg-type]
except (TypeError, ValueError):
return None
return timeout_value


def _extract_connect_timeout(timeout: _TYPE_DEFAULT | float | Timeout | None) -> float | None:
if timeout is None or timeout is _DEFAULT_TIMEOUT:
return None

if isinstance(timeout, Timeout):
connect_timeout_bound = timeout.connect_timeout
if callable(connect_timeout_bound):
connect_timeout_bound = connect_timeout_bound()
return _as_float_timeout(connect_timeout_bound)

if hasattr(timeout, "connect_timeout"):
connect_timeout_bound = getattr(timeout, "connect_timeout")
if callable(connect_timeout_bound):
connect_timeout_bound = connect_timeout_bound()
return _as_float_timeout(connect_timeout_bound)

return _as_float_timeout(timeout)


def _resolve_addrinfo_with_timeout(
host: str, port: int, family: int, timeout: _TYPE_DEFAULT | float | Timeout | None
) -> list[tuple[int, int, int, str, tuple[str, int]]]:
connect_timeout = _extract_connect_timeout(timeout)
if connect_timeout is None:
return socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)

if connect_timeout <= 0:
raise SocketTimeout(f"timed out while resolving DNS for {host}")

executor = _get_dns_executor()
future = executor.submit(socket.getaddrinfo, host, port, family, socket.SOCK_STREAM)
try:
return future.result(connect_timeout)
except FuturesTimeoutError as exc:
future.cancel()
raise SocketTimeout(f"timed out while resolving DNS for {host}") from exc


# Modifed version of urllib3.util.connection.create_connection.
def safe_create_connection(
address: tuple[str, int],
Expand Down Expand Up @@ -134,7 +208,7 @@ def safe_create_connection(
except UnicodeError:
raise LocationParseError("'{host}', label empty or too long") from None

for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
for res in _resolve_addrinfo_with_timeout(host, port, family, timeout):
af, socktype, proto, canonname, sa = res

# Begin custom code.
Expand Down
46 changes: 45 additions & 1 deletion tests/sentry/net/test_socket.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import socket
import time
from unittest.mock import MagicMock, patch

import pytest
from django.test import override_settings

from sentry.net.socket import ensure_fqdn, is_ipaddress_allowed, is_safe_hostname
from sentry.net.socket import (
ensure_fqdn,
is_ipaddress_allowed,
is_safe_hostname,
safe_create_connection,
)
from sentry.testutils.cases import TestCase
from sentry.testutils.helpers import override_blocklist

Expand Down Expand Up @@ -41,3 +49,39 @@ def test_ensure_fqdn(self) -> None:
assert ensure_fqdn("example.com") == "example.com."
assert ensure_fqdn("127.0.0.1") == "127.0.0.1"
assert ensure_fqdn("example.com.") == "example.com."

@patch("sentry.net.socket.socket.socket")
@patch("sentry.net.socket.socket.getaddrinfo")
def test_safe_create_connection_times_out_on_slow_dns(
self, mock_getaddrinfo: MagicMock, mock_socket_ctor: MagicMock
) -> None:
def slow_lookup(*args, **kwargs):
time.sleep(0.2)
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("1.1.1.1", 443))]

mock_getaddrinfo.side_effect = slow_lookup
mock_socket_ctor.return_value = MagicMock()

with pytest.raises(socket.timeout):
safe_create_connection(("example.com", 443), timeout=0.05)

mock_socket_ctor.assert_not_called()

@patch("sentry.net.socket.socket.socket")
@patch("sentry.net.socket.socket.getaddrinfo")
def test_safe_create_connection_uses_timeout_objects(
self, mock_getaddrinfo: MagicMock, mock_socket_ctor: MagicMock
) -> None:
from urllib3.util.timeout import Timeout

def slow_lookup(*args, **kwargs):
time.sleep(0.2)
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("1.1.1.1", 443))]

mock_getaddrinfo.side_effect = slow_lookup
mock_socket_ctor.return_value = MagicMock()

with pytest.raises(socket.timeout):
safe_create_connection(("example.com", 443), timeout=Timeout(connect=0.05))

mock_socket_ctor.assert_not_called()
Loading