Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 5 additions & 145 deletions src/_pytest/warnings.py
Original file line number Diff line number Diff line change
@@ -1,148 +1,8 @@
import sys
import warnings
from contextlib import contextmanager
from typing import Generator
from typing import Optional
from typing import TYPE_CHECKING

import pytest
from _pytest.config import apply_warning_filters
from _pytest.config import Config
from _pytest.config import parse_warning_filter
from _pytest.main import Session
from _pytest.nodes import Item
from _pytest.terminal import TerminalReporter
from _pytest.warning_types import PytestWarning

if TYPE_CHECKING:
from typing_extensions import Literal


def pytest_configure(config: Config) -> None:
config.addinivalue_line(
"markers",
"filterwarnings(warning): add a warning filter to the given test. "
"see https://docs.pytest.org/en/stable/how-to/capture-warnings.html#pytest-mark-filterwarnings ",
)


@contextmanager
def catch_warnings_for_item(
config: Config,
ihook,
when: "Literal['config', 'collect', 'runtest']",
item: Optional[Item],
) -> Generator[None, None, None]:
"""Context manager that catches warnings generated in the contained execution block.

``item`` can be None if we are not in the context of an item execution.

Each warning captured triggers the ``pytest_warning_recorded`` hook.
class ReturnTestWarning(PytestWarning):
"""
A warning raised when a test function returns a value other than None,
and strict return value checking is enabled.
"""
config_filters = config.getini("filterwarnings")
cmdline_filters = config.known_args_namespace.pythonwarnings or []
with warnings.catch_warnings(record=True) as log:
# mypy can't infer that record=True means log is not None; help it.
assert log is not None

if not sys.warnoptions:
# If user is not explicitly configuring warning filters, show deprecation warnings by default (#2908).
warnings.filterwarnings("always", category=DeprecationWarning)
warnings.filterwarnings("always", category=PendingDeprecationWarning)

apply_warning_filters(config_filters, cmdline_filters)

# apply filters from "filterwarnings" marks
nodeid = "" if item is None else item.nodeid
if item is not None:
for mark in item.iter_markers(name="filterwarnings"):
for arg in mark.args:
warnings.filterwarnings(*parse_warning_filter(arg, escape=False))

yield

for warning_message in log:
ihook.pytest_warning_recorded.call_historic(
kwargs=dict(
warning_message=warning_message,
nodeid=nodeid,
when=when,
location=None,
)
)


def warning_record_to_str(warning_message: warnings.WarningMessage) -> str:
"""Convert a warnings.WarningMessage to a string."""
warn_msg = warning_message.message
msg = warnings.formatwarning(
str(warn_msg),
warning_message.category,
warning_message.filename,
warning_message.lineno,
warning_message.line,
)
if warning_message.source is not None:
try:
import tracemalloc
except ImportError:
pass
else:
tb = tracemalloc.get_object_traceback(warning_message.source)
if tb is not None:
formatted_tb = "\n".join(tb.format())
# Use a leading new line to better separate the (large) output
# from the traceback to the previous warning text.
msg += f"\nObject allocated at:\n{formatted_tb}"
else:
# No need for a leading new line.
url = "https://docs.pytest.org/en/stable/how-to/capture-warnings.html#resource-warnings"
msg += "Enable tracemalloc to get traceback where the object was allocated.\n"
msg += f"See {url} for more info."
return msg


@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
with catch_warnings_for_item(
config=item.config, ihook=item.ihook, when="runtest", item=item
):
yield


@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_collection(session: Session) -> Generator[None, None, None]:
config = session.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="collect", item=None
):
yield


@pytest.hookimpl(hookwrapper=True)
def pytest_terminal_summary(
terminalreporter: TerminalReporter,
) -> Generator[None, None, None]:
config = terminalreporter.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
):
yield


@pytest.hookimpl(hookwrapper=True)
def pytest_sessionfinish(session: Session) -> Generator[None, None, None]:
config = session.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
):
yield


@pytest.hookimpl(hookwrapper=True)
def pytest_load_initial_conftests(
early_config: "Config",
) -> Generator[None, None, None]:
with catch_warnings_for_item(
config=early_config, ihook=early_config.hook, when="config", item=None
):
yield
38 changes: 38 additions & 0 deletions testing/test_return_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

import pytest
from _pytest.config import Config
from _pytest.nodes import Item

def pytest_addoption(parser):
parser.addoption("--strict-return-values", action="store_true", help="Enforce no return value from test functions")

def pytest_collection_modifyitems(config, items):
if config.getoption("--strict-return-values"):
for item in items:
def check_item(item=item):
if hasattr(item.obj, "__wrapped__"):
return
if item._itestrunresult is not None:
if item._itestrunresult.ret is not None:
pytest.fail("Test function returned non-None value", pytrace=False)

item.addfinalizer(check_item)

@pytest.fixture
def return_value_testdir(testdir):
testdir.makepyfile(test_code="""
import pytest

@pytest.mark.parametrize('value', [123, None, True])
def test_return_values(value):
return value
""")
return testdir

def test_return_values_cause_failure(return_value_testdir):
result = return_value_testdir.runpytest("--strict-return-values")
result.assert_outcomes(failed=1, passed=2)

def test_return_values_pass_without_warning(return_value_testdir):
result = return_value_testdir.runpytest()
result.assert_outcomes(passed=3)