Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 44 additions & 31 deletions python/tach/pytest_plugin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from copy import copy
from pathlib import Path
from typing import Any, Protocol
from typing import Any, Generator, Protocol

import pytest
from pytest import Collector

from tach import filesystem as fs
from tach.errors import TachSetupError
Expand All @@ -19,6 +19,10 @@ class TachConfig(Protocol):
def getoption(self, name: str) -> Any: ...


class HasTachConfig(Protocol):
    """Structural type for any object that carries a tach-aware ``config``.

    Used to annotate the ``parent`` argument of collection hooks, which only
    needs to expose the pytest config object.
    """

    # The pytest config object, typed with the TachConfig protocol.
    config: TachConfig


def pytest_addoption(parser: pytest.Parser):
group = parser.getgroup("tach")
group.addoption(
Expand Down Expand Up @@ -59,35 +63,44 @@ def pytest_configure(config: TachConfig):
)


def pytest_collection_modifyitems(
    session: pytest.Session,
    config: TachConfig,
    items: list[pytest.Item],
):
    """Filter collected test items in place, dropping those the handler rejects.

    The decision is made once per test file: the first item of a file asks
    ``handler.should_remove_items``; later items of the same file reuse that
    outcome through ``handler.removed_test_paths`` (dropped files) or the
    local ``seen`` set (kept files).

    Args:
        session: The pytest session (unused; part of the hook signature).
        config: Pytest config exposing ``tach_handler``.
        items: Collected items. Mutated in place so that pytest, which holds
            a reference to this same list, observes the filtered result.
    """
    handler = config.tach_handler
    # Paths that have already been evaluated during this pass.
    seen: set[Path] = set()
    # Build the surviving list in one pass instead of calling
    # ``items.remove`` per dropped item, which is O(n) each time.
    kept: list[pytest.Item] = []
    for item in items:
        path = item.path
        if not path:
            # Items without a path cannot be attributed to a file; keep them.
            kept.append(item)
            continue
        if str(path) in handler.removed_test_paths:
            # File already marked as removed: drop and count this item.
            handler.num_removed_items += 1
            continue
        if path in seen:
            # File already evaluated and not registered as removed.
            kept.append(item)
            continue

        if str(path) in handler.all_affected_modules:
            # If this test file was changed,
            # then we know we need to rerun it
            seen.add(path)
            kept.append(item)
            continue

        if handler.should_remove_items(file_path=path.resolve()):
            handler.num_removed_items += 1
            handler.remove_test_path(path)
        else:
            kept.append(item)

        seen.add(path)

    # Slice assignment keeps the caller's list object identity.
    items[:] = kept
def _count_items(collector: Collector) -> int:
    """Return how many test items sit beneath ``collector``, at any depth.

    Each node yielded by ``collector.collect()`` is either another collector
    (e.g. a class), which is counted recursively, or a test item, which
    counts as one.
    """
    return sum(
        _count_items(node) if isinstance(node, Collector) else 1
        for node in collector.collect()
    )


@pytest.hookimpl(wrapper=True)
def pytest_collect_file(
    file_path: Path, parent: HasTachConfig
) -> Generator[None, list[Collector], list[Collector]]:
    """Wrap file collection and discard collectors for files the handler rejects.

    As a wrapper hook, this yields to the other implementations first and
    then decides whether to pass their collectors through or replace them
    with an empty list.
    """
    handler = parent.config.tach_handler
    # Let the other hook implementations run; nothing to do if they
    # produced no collectors for this path.
    collected = yield
    if not collected:
        return collected

    target = file_path.resolve()

    # Keep the file when it was itself changed, or when the handler does not
    # ask for removal. The short-circuit means the handler is only consulted
    # for files that are not in the affected set.
    was_changed = str(target) in handler.all_affected_modules
    if was_changed or not handler.should_remove_items(file_path=target):
        return collected

    # Tally every test item under the discarded collectors before dropping
    # them, so the removed-item count stays accurate.
    for node in collected:
        handler.num_removed_items += _count_items(node)
    handler.remove_test_path(file_path)
    return []


def pytest_report_collectionfinish(
Expand Down
169 changes: 169 additions & 0 deletions python/tests/test_pytest_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from __future__ import annotations

import pytest

pytest_plugins = ["pytester"]


def makepyfile(pytester: pytest.Pytester, *args: str | bytes, **kwargs: str | bytes):
    """Call ``pytester.makepyfile`` and discard its return value.

    Workaround for https://github.com/pytest-dev/pytest/pull/14080
    (presumably a typing issue in pytest, given the pyright ignore below;
    see the linked PR to confirm).
    """
    # Assigning to ``_`` marks the result as intentionally unused.
    _ = pytester.makepyfile(*args, **kwargs)  # pyright: ignore[reportUnknownMemberType]


@pytest.fixture
def tach_project(pytester: pytest.Pytester):
    """Create a basic tach project structure.

    Writes a ``tach`` TOML config, one source module and two test modules
    (one importing the source module, one standalone), then commits
    everything to a fresh git repository. Returns the same ``pytester``
    instance so tests can keep working in that directory.
    """
    # Config file with the current directory as the only source root.
    _ = pytester.makefile(".toml", tach='source_roots = ["."]')
    makepyfile(
        pytester,
        # Source module that test_with_import depends on.
        src_module="""
            def add(a, b):
                return a + b

            def subtract(a, b):
                return a - b
        """,
        # Three tests that import from src_module.
        test_with_import="""
            from src_module import add

            def test_add_basic():
                assert add(1, 2) == 3

            def test_add_zero():
                assert add(0, 0) == 0

            def test_add_negative():
                assert add(-1, 1) == 0
        """,
        # Two tests with no project imports.
        test_no_import="""
            def test_standalone_1():
                assert True

            def test_standalone_2():
                assert 1 + 1 == 2
        """,
    )
    # Initialize git repo
    _ = pytester.run("git", "init")
    # Set an identity locally so the commit does not depend on global git config.
    _ = pytester.run("git", "config", "user.email", "[email protected]")
    _ = pytester.run("git", "config", "user.name", "Test")
    _ = pytester.run("git", "add", "-A")
    _ = pytester.run("git", "commit", "-m", "initial")
    return pytester


def run_pytest(pytester: pytest.Pytester, *args: str) -> pytest.RunResult:
    """Run pytest in a subprocess with the tach plugin loaded.

    A subprocess is used to avoid PyO3 reinitialization issues. Any extra
    ``args`` are appended after the plugin flags.
    """
    plugin_flags = ("-p", "tach.pytest_plugin")
    return pytester.runpytest_subprocess(*plugin_flags, *args)


class TestPytestPluginSkipping:
    """Check which test files run or are skipped for a given git diff base."""

    def test_no_changes_skips_all_tests(self, tach_project: pytest.Pytester):
        """When there are no changes, all tests should be skipped."""
        # The base is HEAD, i.e. the commit the fixture just created.
        result = run_pytest(tach_project, "--tach-base", "HEAD")
        result.assert_outcomes(passed=0)
        # Both test files from the fixture are reported as skipped.
        result.stdout.fnmatch_lines(["*Skipped 2 test file*"])

    def test_source_change_runs_dependent_tests(self, tach_project: pytest.Pytester):
        """When a source file changes, only tests that import it should run."""
        # Modify the source file
        makepyfile(
            tach_project,
            src_module="""
                def add(a, b):
                    return a + b

                def subtract(a, b):
                    return a - b

                # Modified
            """,
        )
        _ = tach_project.run("git", "add", "src_module.py")
        _ = tach_project.run("git", "commit", "-m", "modify source")

        # Compare against the commit before the source modification.
        result = run_pytest(tach_project, "--tach-base", "HEAD~1")
        # The three tests in test_with_import run; test_no_import is skipped.
        result.assert_outcomes(passed=3)
        result.stdout.fnmatch_lines(
            [
                "*Skipped 1 test file*",
                "*test_no_import.py*",
            ]
        )

    def test_test_file_change_runs_that_file(self, tach_project: pytest.Pytester):
        """When a test file is directly modified, it should run."""
        # Modify a test file
        makepyfile(
            tach_project,
            test_no_import="""
                def test_standalone_1():
                    assert True

                def test_standalone_2():
                    assert 1 + 1 == 2

                def test_standalone_3():
                    assert "new test"
            """,
        )
        _ = tach_project.run("git", "add", "test_no_import.py")
        _ = tach_project.run("git", "commit", "-m", "add test")

        # Compare against the commit before the test file was modified.
        result = run_pytest(tach_project, "--tach-base", "HEAD~1")
        # The modified file now holds three tests; the other test file is skipped.
        result.assert_outcomes(passed=3)
        result.stdout.fnmatch_lines(["*Skipped 1 test file*"])


class TestPytestPluginCounting:
    """Check the per-test count reported alongside the skipped-file summary."""

    def test_counts_all_tests_in_file(self, tach_project: pytest.Pytester):
        """Should correctly count all tests including parametrized ones."""
        makepyfile(
            tach_project,
            test_parametrized="""
                import pytest

                @pytest.mark.parametrize("x,y,expected", [
                    (1, 2, 3),
                    (2, 3, 5),
                    (10, 20, 30),
                ])
                def test_param_add(x, y, expected):
                    assert x + y == expected

                def test_regular():
                    assert True
            """,
        )
        _ = tach_project.run("git", "add", "test_parametrized.py")
        # Amend the fixture's commit so the new file is part of HEAD.
        _ = tach_project.run("git", "commit", "--amend", "--no-edit")

        result = run_pytest(tach_project, "--tach-base", "HEAD")
        result.assert_outcomes(passed=0)
        # 3 (test_with_import) + 2 (test_no_import) + 4 (test_parametrized) = 9
        result.stdout.fnmatch_lines(["*Skipped 3 test file* (9 tests)*"])

    def test_counts_tests_in_classes(self, tach_project: pytest.Pytester):
        """Should correctly count tests inside test classes."""
        makepyfile(
            tach_project,
            test_class="""
                class TestGroup:
                    def test_one(self):
                        assert True

                    def test_two(self):
                        assert True

                class TestAnotherGroup:
                    def test_three(self):
                        assert True
            """,
        )
        _ = tach_project.run("git", "add", "test_class.py")
        # Amend the fixture's commit so the new file is part of HEAD.
        _ = tach_project.run("git", "commit", "--amend", "--no-edit")

        result = run_pytest(tach_project, "--tach-base", "HEAD")
        result.assert_outcomes(passed=0)
        # 3 (test_with_import) + 2 (test_no_import) + 3 (test_class) = 8
        result.stdout.fnmatch_lines(["*Skipped 3 test file* (8 tests)*"])
Loading