Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions ddtestpy/internal/ddtrace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import typing as t

from ddtestpy.internal.utils import DDTESTOPT_ROOT_SPAN_RESOURCE
from ddtestpy.internal.utils import DDTraceTestContext
from ddtestpy.internal.utils import PlainTestContext
from ddtestpy.internal.utils import TestContext
from ddtestpy.internal.utils import _gen_item_id
from ddtestpy.internal.writer import TestOptWriter


Expand Down Expand Up @@ -82,7 +83,7 @@ def trace_context(ddtrace_enabled: bool) -> t.ContextManager[TestContext]:


@contextlib.contextmanager
def _ddtrace_context() -> t.Generator[TestContext, None, None]:
def _ddtrace_context() -> t.Generator[DDTraceTestContext, None, None]:
import ddtrace

# TODO: check if this breaks async tests.
Expand All @@ -91,9 +92,9 @@ def _ddtrace_context() -> t.Generator[TestContext, None, None]:
ddtrace.tracer.context_provider.activate(None) # type: ignore[attr-defined]

with ddtrace.tracer.trace(DDTESTOPT_ROOT_SPAN_RESOURCE) as root_span: # type: ignore[attr-defined]
yield TestContext(trace_id=root_span.trace_id % (1 << 64), span_id=root_span.span_id % (1 << 64))
yield DDTraceTestContext(root_span)


@contextlib.contextmanager
def _plain_context() -> t.Generator[TestContext, None, None]:
yield TestContext(trace_id=_gen_item_id(), span_id=_gen_item_id())
def _plain_context() -> t.Generator[PlainTestContext, None, None]:
yield PlainTestContext()
5 changes: 5 additions & 0 deletions ddtestpy/internal/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def get_or_create_child(self, name: str) -> t.Tuple[TChildClass, bool]:
def set_tags(self, tags: t.Dict[str, str]) -> None:
self.tags.update(tags)

def set_metrics(self, metrics: t.Dict[str, float]) -> None:
self.metrics.update(metrics)


class TestRun(TestItem["Test", t.NoReturn]):
__test__ = False
Expand All @@ -151,6 +154,8 @@ def __init__(self, name: str, parent: Test) -> None:
def set_context(self, context: TestContext) -> None:
self.span_id = context.span_id
self.trace_id = context.trace_id
self.set_tags(context.get_tags())
self.set_metrics(context.get_metrics())


class Test(TestItem["TestSuite", "TestRun"]):
Expand Down
50 changes: 46 additions & 4 deletions ddtestpy/internal/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from dataclasses import dataclass
from __future__ import annotations

import random
import re
import typing as t


if t.TYPE_CHECKING:
from ddtrace.trace import Span


DDTESTOPT_ROOT_SPAN_RESOURCE = "ddtestpy_root_span"


Expand All @@ -21,15 +26,52 @@ def asbool(value: t.Union[str, bool, None]) -> bool:
return value.lower() in ("true", "1")


def ensure_text(s: t.Any) -> str:
if isinstance(s, str):
return s
if isinstance(s, bytes):
return s.decode("utf-8", errors="ignore")
return str(s)


_RE_URL = re.compile(r"(https?://|ssh://)[^/]*@")


def _filter_sensitive_info(url: t.Optional[str]) -> t.Optional[str]:
return _RE_URL.sub("\\1", url) if url is not None else None


@dataclass
class TestContext:
class TestContext(t.Protocol):
span_id: int
trace_id: int
__test__ = False

def get_tags(self) -> t.Dict[str, str]: ...

def get_metrics(self) -> t.Dict[str, float]: ...


class PlainTestContext(TestContext):
def __init__(self, span_id: t.Optional[int] = None, trace_id: t.Optional[int] = None):
self.span_id = span_id or _gen_item_id()
self.trace_id = trace_id or _gen_item_id()

def get_tags(self) -> t.Dict[str, str]:
return {}

def get_metrics(self) -> t.Dict[str, float]:
return {}


class DDTraceTestContext(TestContext):
def __init__(self, span: Span):
self.trace_id = span.trace_id % (1 << 64)
self.span_id = span.span_id % (1 << 64)
self._span = span

def get_tags(self) -> t.Dict[str, str]:
# DEV: in ddtrace < 4.x, key names can be bytes.
return {ensure_text(k): v for k, v in self._span.get_tags().items()}

def get_metrics(self) -> t.Dict[str, float]:
# DEV: in ddtrace < 4.x, key names can be bytes.
return {ensure_text(k): v for k, v in self._span.get_metrics().items()}
38 changes: 38 additions & 0 deletions tests/internal/pytest/test_pytest_ddtrace_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from unittest.mock import patch

from _pytest.pytester import Pytester
import pytest

from tests.mocks import EventCapture
from tests.mocks import mock_api_client_settings
from tests.mocks import setup_standard_mocks


class TestDDTraceTags:
@pytest.mark.slow
def test_ddtrace_tags_are_reflected_in_ddtestpy_events(self, pytester: Pytester) -> None:
pytester.makepyfile(
test_foo="""
def test_set_ddtrace_tags():
from ddtrace import tracer
tracer.current_span().set_tag("my_custom_tag", "foo")
tracer.current_span().set_tag("my_other_tag", "bar")
tracer.current_span().set_metric("my_custom_metric", 42)
"""
)

with patch(
"ddtestpy.internal.session_manager.APIClient",
return_value=mock_api_client_settings(),
), setup_standard_mocks():
with EventCapture.capture() as event_capture:
result = pytester.inline_run("--ddtestpy", "--ddtestpy-with-ddtrace", "-p", "no:ddtrace", "-v", "-s")

assert result.ret == 0

test_event = event_capture.event_by_test_name("test_set_ddtrace_tags")
assert test_event["content"]["meta"].get("my_custom_tag") == "foo"
assert test_event["content"]["meta"].get("my_other_tag") == "bar"
assert test_event["content"]["metrics"].get("my_custom_metric") == 42
21 changes: 7 additions & 14 deletions tests/internal/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for ddtestpy.internal.utils module."""

from ddtestpy.internal.utils import TestContext
from ddtestpy.internal.utils import PlainTestContext
from ddtestpy.internal.utils import _gen_item_id
from ddtestpy.internal.utils import asbool

Expand Down Expand Up @@ -68,23 +68,16 @@ def test_asbool_with_arbitrary_string(self) -> None:
assert asbool("hello") is False


class TestTestContext:
"""Tests for TestContext dataclass."""
class TestPlainTestContext:
"""Tests for PlainTestContext dataclass."""

def test_test_context_creation(self) -> None:
"""Test that TestContext can be created with span_id and trace_id."""
"""Test that PlainTestContext can be created with span_id and trace_id."""
span_id = 12345
trace_id = 67890
context = TestContext(span_id=span_id, trace_id=trace_id)
context = PlainTestContext(span_id=span_id, trace_id=trace_id)

assert context.span_id == span_id
assert context.trace_id == trace_id

def test_test_context_equality(self) -> None:
"""Test that TestContext instances with same values are equal."""
context1 = TestContext(span_id=123, trace_id=456)
context2 = TestContext(span_id=123, trace_id=456)
context3 = TestContext(span_id=123, trace_id=789)

assert context1 == context2
assert context1 != context3
assert context.get_tags() == {}
assert context.get_metrics() == {}