Skip to content

Commit 4675912

Browse files
committed
Add tests for early rewrite bailout code and handle patterns with subdirectories
1 parent d53e449 commit 4675912

File tree

3 files changed

+118
-22
lines changed

3 files changed

+118
-22
lines changed

src/_pytest/assertion/rewrite.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,18 @@ def __init__(self, config):
6767
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
6868
# which might result in infinite recursion (#3506)
6969
self._writing_pyc = False
70-
self._basenames_to_check_rewrite = set('conftest',)
70+
self._basenames_to_check_rewrite = {"conftest"}
7171
self._marked_for_rewrite_cache = {}
7272
self._session_paths_checked = False
7373

7474
def set_session(self, session):
7575
self.session = session
7676
self._session_paths_checked = False
7777

78+
def _imp_find_module(self, name, path=None):
79+
"""Indirection so we can mock calls to find_module originated from the hook during testing"""
80+
return imp.find_module(name, path)
81+
7882
def find_module(self, name, path=None):
7983
if self._writing_pyc:
8084
return None
@@ -93,7 +97,7 @@ def find_module(self, name, path=None):
9397
pth = path[0]
9498
if pth is None:
9599
try:
96-
fd, fn, desc = imp.find_module(lastname, path)
100+
fd, fn, desc = self._imp_find_module(lastname, path)
97101
except ImportError:
98102
return None
99103
if fd is not None:
@@ -179,8 +183,7 @@ def _early_rewrite_bailout(self, name, state):
179183
from this class) is a major slowdown, so, this method tries to
180184
filter what we're sure won't be rewritten before getting to it.
181185
"""
182-
if not self._session_paths_checked and self.session is not None \
183-
and hasattr(self.session, '_initialpaths'):
186+
if self.session is not None and not self._session_paths_checked:
184187
self._session_paths_checked = True
185188
for path in self.session._initialpaths:
186189
# Make something as c:/projects/my_project/path.py ->
@@ -190,14 +193,18 @@ def _early_rewrite_bailout(self, name, state):
190193
self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])
191194

192195
# Note: conftest already by default in _basenames_to_check_rewrite.
193-
parts = name.split('.')
196+
parts = name.split(".")
194197
if parts[-1] in self._basenames_to_check_rewrite:
195198
return False
196199

197200
# For matching the name it must be as if it was a filename.
198-
parts[-1] = parts[-1] + '.py'
201+
parts[-1] = parts[-1] + ".py"
199202
fn_pypath = py.path.local(os.path.sep.join(parts))
200203
for pat in self.fnpats:
204+
# if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
205+
# on the name alone because we need to match against the full path
206+
if os.path.dirname(pat):
207+
return False
201208
if fn_pypath.fnmatch(pat):
202209
return False
203210

@@ -237,7 +244,7 @@ def _is_marked_for_rewrite(self, name, state):
237244
state.trace("matched marked file %r (from %r)" % (name, marked))
238245
self._marked_for_rewrite_cache[name] = True
239246
return True
240-
247+
241248
self._marked_for_rewrite_cache[name] = False
242249
return False
243250

@@ -289,6 +296,16 @@ def load_module(self, name):
289296
raise
290297
return sys.modules[name]
291298

299+
def is_package(self, name):
300+
try:
301+
fd, fn, desc = imp.find_module(name)
302+
except ImportError:
303+
return False
304+
if fd is not None:
305+
fd.close()
306+
tp = desc[2]
307+
return tp == imp.PKG_DIRECTORY
308+
292309
@classmethod
293310
def _register_with_pkg_resources(cls):
294311
"""

src/_pytest/main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ def __init__(self, config):
383383
self.trace = config.trace.root.get("collection")
384384
self._norecursepatterns = config.getini("norecursedirs")
385385
self.startdir = py.path.local()
386+
self._initialpaths = frozenset()
386387
# Keep track of any collected nodes in here, so we don't duplicate fixtures
387388
self._node_cache = {}
388389

@@ -565,7 +566,6 @@ def _tryconvertpyarg(self, x):
565566
"""Convert a dotted module name to path.
566567
567568
"""
568-
569569
try:
570570
with _patched_find_module():
571571
loader = pkgutil.find_loader(x)

testing/test_assertrewrite.py

+93-14
Original file line numberDiff line numberDiff line change
@@ -1106,22 +1106,21 @@ def test_ternary_display():
11061106

11071107

11081108
class TestIssue2121:
1109-
def test_simple(self, testdir):
1110-
testdir.tmpdir.join("tests/file.py").ensure().write(
1111-
"""
1112-
def test_simple_failure():
1113-
assert 1 + 1 == 3
1114-
"""
1115-
)
1116-
testdir.tmpdir.join("pytest.ini").write(
1117-
textwrap.dedent(
1109+
def test_rewrite_python_files_contain_subdirs(self, testdir):
1110+
testdir.makepyfile(
1111+
**{
1112+
"tests/file.py": """
1113+
def test_simple_failure():
1114+
assert 1 + 1 == 3
11181115
"""
1119-
[pytest]
1120-
python_files = tests/**.py
1121-
"""
1122-
)
1116+
}
1117+
)
1118+
testdir.makeini(
1119+
"""
1120+
[pytest]
1121+
python_files = tests/**.py
1122+
"""
11231123
)
1124-
11251124
result = testdir.runpytest()
11261125
result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3")
11271126

@@ -1153,3 +1152,83 @@ def spy_write_pyc(*args, **kwargs):
11531152
hook = AssertionRewritingHook(pytestconfig)
11541153
assert hook.find_module("test_foo") is not None
11551154
assert len(write_pyc_called) == 1
1155+
1156+
1157+
class TestEarlyRewriteBailout(object):
1158+
@pytest.fixture
1159+
def hook(self, pytestconfig, monkeypatch, testdir):
1160+
"""Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track
1161+
if imp.find_module has been called.
1162+
"""
1163+
import imp
1164+
1165+
self.find_module_calls = []
1166+
self.initial_paths = set()
1167+
1168+
class StubSession(object):
1169+
_initialpaths = self.initial_paths
1170+
1171+
def isinitpath(self, p):
1172+
return p in self._initialpaths
1173+
1174+
def spy_imp_find_module(name, path):
1175+
self.find_module_calls.append(name)
1176+
return imp.find_module(name, path)
1177+
1178+
hook = AssertionRewritingHook(pytestconfig)
1179+
# use default patterns, otherwise we inherit pytest's testing config
1180+
hook.fnpats[:] = ["test_*.py", "*_test.py"]
1181+
monkeypatch.setattr(hook, "_imp_find_module", spy_imp_find_module)
1182+
hook.set_session(StubSession())
1183+
testdir.syspathinsert()
1184+
return hook
1185+
1186+
def test_basic(self, testdir, hook):
1187+
"""
1188+
Ensure we avoid calling imp.find_module when we know for sure a certain module will not be rewritten
1189+
to optimize assertion rewriting (#3918).
1190+
"""
1191+
testdir.makeconftest(
1192+
"""
1193+
import pytest
1194+
@pytest.fixture
1195+
def fix(): return 1
1196+
"""
1197+
)
1198+
testdir.makepyfile(test_foo="def test_foo(): pass")
1199+
testdir.makepyfile(bar="def bar(): pass")
1200+
foobar_path = testdir.makepyfile(foobar="def foobar(): pass")
1201+
self.initial_paths.add(foobar_path)
1202+
1203+
# conftest files should always be rewritten
1204+
assert hook.find_module("conftest") is not None
1205+
assert self.find_module_calls == ["conftest"]
1206+
1207+
# files matching "python_files" mask should always be rewritten
1208+
assert hook.find_module("test_foo") is not None
1209+
assert self.find_module_calls == ["conftest", "test_foo"]
1210+
1211+
# file does not match "python_files": early bailout
1212+
assert hook.find_module("bar") is None
1213+
assert self.find_module_calls == ["conftest", "test_foo"]
1214+
1215+
# file is an initial path (passed on the command-line): should be rewritten
1216+
assert hook.find_module("foobar") is not None
1217+
assert self.find_module_calls == ["conftest", "test_foo", "foobar"]
1218+
1219+
def test_pattern_contains_subdirectories(self, testdir, hook):
1220+
"""If one of the python_files patterns contain subdirectories ("tests/**.py") we can't bailout early
1221+
because we need to match with the full path, which can only be found by calling imp.find_module.
1222+
"""
1223+
p = testdir.makepyfile(
1224+
**{
1225+
"tests/file.py": """
1226+
def test_simple_failure():
1227+
assert 1 + 1 == 3
1228+
"""
1229+
}
1230+
)
1231+
testdir.syspathinsert(p.dirpath())
1232+
hook.fnpats[:] = ["tests/**.py"]
1233+
assert hook.find_module("file") is not None
1234+
assert self.find_module_calls == ["file"]

0 commit comments

Comments
 (0)