Skip to content

Commit 109346a

Browse files
committed
Add invalidation-mode option
1 parent 6e814c8 commit 109346a

File tree

4 files changed

+134
-43
lines changed

4 files changed

+134
-43
lines changed

src/_pytest/assertion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class AssertionState:
8383
def __init__(self, config: Config, mode) -> None:
8484
self.mode = mode
8585
self.trace = config.trace.root.get("assertion")
86+
self.invalidation_mode = config.option.invalidationmode
8687
self.hook: Optional[rewrite.AssertionRewritingHook] = None
8788

8889

src/_pytest/assertion/rewrite.py

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Rewrite assertion AST to produce nice error messages."""
2+
import _imp
23
import ast
34
import errno
45
import functools
@@ -21,6 +22,7 @@
2122
from typing import Iterable
2223
from typing import Iterator
2324
from typing import List
25+
from typing import Literal
2426
from typing import Optional
2527
from typing import Sequence
2628
from typing import Set
@@ -290,23 +292,31 @@ def get_resource_reader(self, name: str) -> TraversableResources: # type: ignor
290292

291293

292294
def _write_pyc_fp(
293-
fp: IO[bytes], source_stat: os.stat_result, source_hash: bytes, co: types.CodeType
295+
fp: IO[bytes],
296+
source_stat: os.stat_result,
297+
source_hash: bytes,
298+
co: types.CodeType,
299+
invalidation_mode: Literal["timestamp", "checked-hash"],
294300
) -> None:
295301
# Technically, we don't have to have the same pyc format as
296302
# (C)Python, since these "pycs" should never be seen by builtin
297303
# import. However, there's little reason to deviate.
298304
fp.write(importlib.util.MAGIC_NUMBER)
299305
# https://www.python.org/dev/peps/pep-0552/
300-
flags = b"\x00\x00\x00\x00"
301-
fp.write(flags)
302-
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
303-
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
304-
size = source_stat.st_size & 0xFFFFFFFF
305-
# 64-bit source file hash
306-
source_hash = source_hash[:8]
307-
# "<LL" stands for 2 unsigned longs, little-endian.
308-
fp.write(struct.pack("<LL", mtime, size))
309-
fp.write(source_hash)
306+
if invalidation_mode == "timestamp":
307+
flags = b"\x00\x00\x00\x00"
308+
fp.write(flags)
309+
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
310+
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
311+
size = source_stat.st_size & 0xFFFFFFFF
312+
# "<LL" stands for 2 unsigned longs, little-endian.
313+
fp.write(struct.pack("<LL", mtime, size))
314+
elif invalidation_mode == "checked-hash":
315+
flags = b"\x03\x00\x00\x00"
316+
fp.write(flags)
317+
# 64-bit source file hash
318+
source_hash = source_hash[:8]
319+
fp.write(source_hash)
310320
fp.write(marshal.dumps(co))
311321

312322

@@ -320,7 +330,7 @@ def _write_pyc(
320330
proc_pyc = f"{pyc}.{os.getpid()}"
321331
try:
322332
with open(proc_pyc, "wb") as fp:
323-
_write_pyc_fp(fp, source_stat, source_hash, co)
333+
_write_pyc_fp(fp, source_stat, source_hash, co, state.invalidation_mode)
324334
except OSError as e:
325335
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
326336
return False
@@ -367,35 +377,41 @@ def _read_pyc(
367377
stat_result = os.stat(source)
368378
mtime = int(stat_result.st_mtime)
369379
size = stat_result.st_size
370-
data = fp.read(24)
380+
data = fp.read(16)
371381
except OSError as e:
372382
trace(f"_read_pyc({source}): OSError {e}")
373383
return None
374384
# Check for invalid or out of date pyc file.
375-
if len(data) != (24):
385+
if len(data) != (16):
376386
trace("_read_pyc(%s): invalid pyc (too short)" % source)
377387
return None
378388
if data[:4] != importlib.util.MAGIC_NUMBER:
379389
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
380390
return None
381-
if data[4:8] != b"\x00\x00\x00\x00":
382-
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
383-
return None
384-
size_data = data[12:16]
385-
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
386-
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
387-
return None
388-
mtime_data = data[8:12]
389-
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
390-
trace("_read_pyc(%s): out of date" % source)
391-
hash = data[16:24]
391+
392+
hash_based = getattr(_imp, "check_hash_based_pycs", "default") == "always"
393+
if data[4:8] == b"\x00\x00\x00\x00" and not hash_based:
394+
trace("_read_pyc(%s): timestamp based" % source)
395+
mtime_data = data[8:12]
396+
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
397+
trace("_read_pyc(%s): out of date" % source)
398+
return None
399+
size_data = data[12:16]
400+
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
401+
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
402+
return None
403+
elif data[4:8] == b"\x03\x00\x00\x00":
404+
trace("_read_pyc(%s): hash based" % source)
405+
hash = data[8:16]
392406
# source_hash returns bytes not int: https://github.com/python/typeshed/pull/10686
393407
source_hash: bytes = importlib.util.source_hash(source.read_bytes()) # type: ignore[assignment]
394-
if source_hash[:8] == hash:
395-
trace("_read_pyc(%s): source hash match (no change detected)" % source)
396-
else:
408+
if source_hash[:8] != hash:
397409
trace("_read_pyc(%s): hash doesn't match" % source)
398410
return None
411+
else:
412+
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
413+
return None
414+
399415
try:
400416
co = marshal.load(fp)
401417
except Exception as e:

src/_pytest/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,13 @@ def pytest_addoption(parser: Parser) -> None:
215215
help="Prepend/append to sys.path when importing test modules and conftest "
216216
"files. Default: prepend.",
217217
)
218+
group.addoption(
219+
"--invalidation-mode",
220+
default="timestamp",
221+
choices=["timestamp", "checked-hash"],
222+
dest="invalidationmode",
223+
help="Pytest pyc cache invalidation mode. Default: timestamp.",
224+
)
218225

219226
group = parser.getgroup("debugconfig", "test session debugging and configuration")
220227
group.addoption(

testing/test_assertrewrite.py

Lines changed: 82 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import _imp
12
import ast
23
import errno
34
import glob
@@ -1124,13 +1125,37 @@ def test_read_pyc_success(self, tmp_path: Path, pytester: Pytester) -> None:
11241125
_write_pyc(state, co, source_stat, hash, pyc)
11251126
assert _read_pyc(fn, pyc, state.trace) is not None
11261127

1127-
# pyc read should still work if only the mtime changed
1128-
# Fallback to hash comparison
1129-
new_mtime = source_stat.st_mtime + 1.2
1130-
os.utime(fn, (new_mtime, new_mtime))
1131-
assert source_stat.st_mtime != os.stat(fn).st_mtime
1128+
pyc_bytes = pyc.read_bytes()
1129+
assert pyc_bytes[4] == 0 # timestamp flag set
1130+
1131+
def test_read_pyc_success_hash(self, tmp_path: Path, pytester: Pytester) -> None:
1132+
from _pytest.assertion import AssertionState
1133+
from _pytest.assertion.rewrite import _read_pyc
1134+
from _pytest.assertion.rewrite import _rewrite_test
1135+
from _pytest.assertion.rewrite import _write_pyc
1136+
1137+
config = pytester.parseconfig("--invalidation-mode=checked-hash")
1138+
state = AssertionState(config, "rewrite")
1139+
1140+
fn = tmp_path / "source.py"
1141+
pyc = Path(str(fn) + "c")
1142+
1143+
# Test private attribute didn't change
1144+
assert getattr(_imp, "check_hash_based_pycs", None) in {
1145+
"default",
1146+
"always",
1147+
"never",
1148+
}
1149+
1150+
fn.write_text("def test(): assert True", encoding="utf-8")
1151+
source_stat, hash, co = _rewrite_test(fn, config)
1152+
_write_pyc(state, co, source_stat, hash, pyc)
11321153
assert _read_pyc(fn, pyc, state.trace) is not None
11331154

1155+
pyc_bytes = pyc.read_bytes()
1156+
assert pyc_bytes[4] == 3 # checked-hash flag set
1157+
assert pyc_bytes[8:16] == hash
1158+
11341159
def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
11351160
from _pytest.assertion.rewrite import _read_pyc
11361161

@@ -1149,36 +1174,78 @@ def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
11491174
os.utime(source, (mtime_int, mtime_int))
11501175

11511176
size = len(source_bytes).to_bytes(4, "little")
1152-
# source_hash returns bytes not int: https://github.com/python/typeshed/pull/10686
1153-
hash: bytes = source_hash(source_bytes) # type: ignore[assignment]
1154-
hash = hash[:8]
11551177

11561178
code = marshal.dumps(compile(source_bytes, str(source), "exec"))
11571179

11581180
# Good header.
1159-
pyc.write_bytes(magic + flags + mtime + size + hash + code)
1181+
pyc.write_bytes(magic + flags + mtime + size + code)
11601182
assert _read_pyc(source, pyc, print) is not None
11611183

11621184
# Too short.
11631185
pyc.write_bytes(magic + flags + mtime)
11641186
assert _read_pyc(source, pyc, print) is None
11651187

11661188
# Bad magic.
1167-
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + hash + code)
1189+
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + code)
11681190
assert _read_pyc(source, pyc, print) is None
11691191

11701192
# Unsupported flags.
1171-
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + hash + code)
1193+
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + code)
11721194
assert _read_pyc(source, pyc, print) is None
11731195

1174-
# Bad size.
1175-
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + hash + code)
1196+
# Bad mtime.
1197+
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + code)
11761198
assert _read_pyc(source, pyc, print) is None
11771199

1178-
# Bad mtime + bad hash.
1179-
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + b"\x00" * 8 + code)
1200+
# Bad size.
1201+
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
11801202
assert _read_pyc(source, pyc, print) is None
11811203

1204+
def test_read_pyc_more_invalid_hash(self, tmp_path: Path) -> None:
1205+
from _pytest.assertion.rewrite import _read_pyc
1206+
1207+
source = tmp_path / "source.py"
1208+
pyc = tmp_path / "source.pyc"
1209+
1210+
source_bytes = b"def test(): pass\n"
1211+
source.write_bytes(source_bytes)
1212+
1213+
magic = importlib.util.MAGIC_NUMBER
1214+
1215+
flags = b"\x00\x00\x00\x00"
1216+
flags_hash = b"\x03\x00\x00\x00"
1217+
1218+
mtime = b"\x58\x3c\xb0\x5f"
1219+
mtime_int = int.from_bytes(mtime, "little")
1220+
os.utime(source, (mtime_int, mtime_int))
1221+
1222+
size = len(source_bytes).to_bytes(4, "little")
1223+
1224+
# source_hash returns bytes not int: https://github.com/python/typeshed/pull/10686
1225+
hash: bytes = source_hash(source_bytes) # type: ignore[assignment]
1226+
hash = hash[:8]
1227+
1228+
code = marshal.dumps(compile(source_bytes, str(source), "exec"))
1229+
1230+
# check_hash_based_pycs == "default" with hash based pyc file.
1231+
pyc.write_bytes(magic + flags_hash + hash + code)
1232+
assert _read_pyc(source, pyc, print) is not None
1233+
1234+
# check_hash_based_pycs == "always" with hash based pyc file.
1235+
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
1236+
pyc.write_bytes(magic + flags_hash + hash + code)
1237+
assert _read_pyc(source, pyc, print) is not None
1238+
1239+
# Bad hash.
1240+
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
1241+
pyc.write_bytes(magic + flags_hash + b"\x00" * 8 + code)
1242+
assert _read_pyc(source, pyc, print) is None
1243+
1244+
# check_hash_based_pycs == "always" with timestamp based pyc file.
1245+
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
1246+
pyc.write_bytes(magic + flags + mtime + size + code)
1247+
assert _read_pyc(source, pyc, print) is None
1248+
11821249
def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None:
11831250
"""Reloading a (collected) module after change picks up the change."""
11841251
pytester.makeini(

0 commit comments

Comments
 (0)