Skip to content

Commit ad12fb2

Browse files
Merge branch 'master' into docs/ci-to-chore-pashma
2 parents 3a865de + 9f227c1 commit ad12fb2

File tree

13 files changed

+158
-118
lines changed

13 files changed

+158
-118
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
<!--next-version-placeholder-->
44

5+
## v1.2.1 (2021-03-03)
6+
### Fix
7+
* Fix crash in test function runtimes ([`ad4c1f3`](https://github.com/F-Secure/pytest-rts/commit/ad4c1f3820a72bf2b9cbc8583c94bba6d2b2dcc2))
8+
59
## v1.2.0 (2021-02-15)
610
### Feature
711
* Handle usage in non-git directory ([`8ad080b`](https://github.com/F-Secure/pytest-rts/commit/8ad080be7eb31b96e1047a4aadabe9fe1a944085))

pytest_rts/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""pytest-rts: avoid already imported warning: PYTEST_DONT_REWRITE"""
2-
__version__ = "1.2.0"
2+
__version__ = "1.2.1"
Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,19 @@
11
"""This module contains code for initializing the mapping database"""
22
import os
3-
from timeit import default_timer as timer
43

54
import coverage
6-
import pytest
7-
from _pytest.python import Function
85

6+
from pytest_rts.pytest.mapper_plugin import MapperPlugin
97
from pytest_rts.utils.common import calculate_func_lines
108
from pytest_rts.utils.git import get_current_head_hash
11-
from pytest_rts.utils.mappinghelper import TestrunData
129

1310

14-
class InitPhasePlugin:
11+
class InitPhasePlugin(MapperPlugin):
1512
"""Class to handle mapping database initialization"""
1613

1714
def __init__(self, mappinghelper):
1815
""""Constructor calls database and Coverage.py initialization"""
19-
self.cov = coverage.Coverage(data_file=None)
20-
self.cov._warn_unimported_source = False
21-
self.testfiles = None
22-
self.test_func_lines = None
23-
24-
self.mappinghelper = mappinghelper
25-
self.mappinghelper.init_mapping()
16+
super().__init__(mappinghelper)
2617
self.mappinghelper.set_last_update_hash(get_current_head_hash())
2718

2819
def pytest_collection_modifyitems(self, session, config, items):
@@ -35,27 +26,3 @@ def pytest_collection_modifyitems(self, session, config, items):
3526
)
3627
for testfile_path in self.testfiles
3728
}
38-
39-
@pytest.hookimpl(hookwrapper=True)
40-
def pytest_runtest_protocol(self, item, nextitem):
41-
"""Start coverage collection for each test function run and save data"""
42-
del nextitem
43-
if isinstance(item, Function):
44-
start = timer()
45-
self.cov.erase()
46-
self.cov.start()
47-
yield
48-
self.cov.stop()
49-
end = timer()
50-
elapsed = round(end - start, 4)
51-
52-
testrun_data = TestrunData(
53-
pytest_item=item,
54-
elapsed_time=elapsed,
55-
coverage_data=self.cov.get_data(),
56-
found_testfiles=self.testfiles,
57-
test_function_lines=self.test_func_lines,
58-
)
59-
self.mappinghelper.save_testrun_data(testrun_data)
60-
else:
61-
yield

pytest_rts/pytest/mapper_plugin.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""This module contains code for initializing the mapping database"""
2+
# pylint: disable=too-few-public-methods
3+
from timeit import default_timer as timer
4+
5+
import coverage
6+
import pytest
7+
from _pytest.python import Function
8+
9+
from pytest_rts.utils.mappinghelper import TestrunData
10+
11+
12+
class MapperPlugin:
13+
"""Class to handle mapping database initialization"""
14+
15+
def __init__(self, mappinghelper):
16+
""""Constructor calls database and Coverage.py initialization"""
17+
self.cov = coverage.Coverage(data_file=None)
18+
self.cov._warn_unimported_source = False
19+
self.mappinghelper = mappinghelper
20+
self.mappinghelper.init_mapping()
21+
self.testfiles = {testfile[1] for testfile in self.mappinghelper.testfiles}
22+
self.test_func_lines = None
23+
24+
@pytest.hookimpl(hookwrapper=True)
25+
def pytest_runtest_protocol(self, item, nextitem):
26+
"""Start coverage collection for each test function run and save data"""
27+
del nextitem
28+
if isinstance(item, Function):
29+
start = timer()
30+
self.cov.erase()
31+
self.cov.start()
32+
yield
33+
self.cov.stop()
34+
end = timer()
35+
elapsed = round(end - start, 4)
36+
37+
testrun_data = TestrunData(
38+
pytest_item=item,
39+
elapsed_time=elapsed,
40+
coverage_data=self.cov.get_data(),
41+
found_testfiles=self.testfiles,
42+
test_function_lines=self.test_func_lines,
43+
)
44+
self.mappinghelper.save_testrun_data(testrun_data)
45+
else:
46+
yield
Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""This module contains code for running a specific test set without mapping database updating"""
22
# pylint: disable=too-few-public-methods
3-
import sys
4-
53
from pytest_rts.pytest.fake_item import FakeItem
4+
from pytest_rts.utils.common import filter_and_sort_pytest_items
65

76

87
class NormalPhasePlugin:
@@ -18,16 +17,9 @@ def pytest_collection_modifyitems(self, session, config, items):
1817
"""Select only specific tests for running and prioritize them based on queried times"""
1918
del config
2019
original_length = len(items)
21-
selected = list(filter(lambda item: item.nodeid in self.test_set, items))
22-
updated_runtimes = {
23-
item.nodeid: self.test_func_times[item.nodeid]
24-
if item.nodeid in self.test_func_times
25-
else sys.maxsize
26-
for item in selected
27-
}
28-
29-
items[:] = sorted(selected, key=lambda item: updated_runtimes[item.nodeid])
30-
20+
items[:] = filter_and_sort_pytest_items(
21+
self.test_set, items, self.test_func_times
22+
)
3123
session.config.hook.pytest_deselected(
32-
items=([FakeItem(session.config)] * (original_length - len(selected)))
24+
items=([FakeItem(session.config)] * (original_length - len(items)))
3325
)
Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
"""This module contains code for running a specific test set with mapping database updating"""
22
import os
3-
import sys
4-
from timeit import default_timer as timer
53

64
import coverage
7-
import pytest
8-
from _pytest.python import Function
95

106
from pytest_rts.pytest.fake_item import FakeItem
11-
from pytest_rts.utils.common import calculate_func_lines
12-
from pytest_rts.utils.mappinghelper import TestrunData
7+
from pytest_rts.utils.common import calculate_func_lines, filter_and_sort_pytest_items
8+
from pytest_rts.pytest.mapper_plugin import MapperPlugin
139

1410

1511
def _read_testfile_functions(testfile_path):
@@ -22,20 +18,14 @@ def _read_testfile_functions(testfile_path):
2218
return {}
2319

2420

25-
class UpdatePhasePlugin:
21+
class UpdatePhasePlugin(MapperPlugin):
2622
"""Class to handle running of selected tests and updating mapping with the results"""
2723

2824
def __init__(self, test_set, mappinghelper, testgetter):
2925
"""Constructor opens database connection and initializes Coverage.py"""
30-
self.cov = coverage.Coverage()
31-
self.cov._warn_unimported_source = False
26+
super().__init__(mappinghelper)
3227
self.test_set = test_set
33-
34-
self.mappinghelper = mappinghelper
3528
self.testgetter = testgetter
36-
37-
self.testfiles = {testfile[1] for testfile in self.mappinghelper.testfiles}
38-
self.test_func_lines = None
3929
self.test_func_times = self.testgetter.test_function_runtimes
4030

4131
def pytest_collection_modifyitems(self, session, config, items):
@@ -45,15 +35,9 @@ def pytest_collection_modifyitems(self, session, config, items):
4535
"""
4636
del config
4737
original_length = len(items)
48-
selected = list(filter(lambda item: item.nodeid in self.test_set, items))
49-
updated_runtimes = {
50-
item.nodeid: self.test_func_times[item.nodeid]
51-
if item.nodeid in self.test_func_times
52-
else sys.maxsize
53-
for item in selected
54-
}
55-
56-
items[:] = sorted(selected, key=lambda item: updated_runtimes[item.nodeid])
38+
items[:] = filter_and_sort_pytest_items(
39+
self.test_set, items, self.test_func_times
40+
)
5741

5842
self.testfiles.update({os.path.relpath(item.location[0]) for item in items})
5943
self.test_func_lines = {
@@ -62,29 +46,5 @@ def pytest_collection_modifyitems(self, session, config, items):
6246
}
6347

6448
session.config.hook.pytest_deselected(
65-
items=([FakeItem(session.config)] * (original_length - len(selected)))
49+
items=([FakeItem(session.config)] * (original_length - len(items)))
6650
)
67-
68-
@pytest.hookimpl(hookwrapper=True)
69-
def pytest_runtest_protocol(self, item, nextitem):
70-
"""Start coverage collection for each test function run and save data"""
71-
del nextitem
72-
if isinstance(item, Function):
73-
start = timer()
74-
self.cov.erase()
75-
self.cov.start()
76-
yield
77-
self.cov.stop()
78-
end = timer()
79-
elapsed = round(end - start, 4)
80-
81-
testrun_data = TestrunData(
82-
pytest_item=item,
83-
elapsed_time=elapsed,
84-
coverage_data=self.cov.get_data(),
85-
found_testfiles=self.testfiles,
86-
test_function_lines=self.test_func_lines,
87-
)
88-
self.mappinghelper.save_testrun_data(testrun_data)
89-
else:
90-
yield
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
import sys
3+
sys.path.append(".")
4+
from src.shop import Shop
5+
6+
def test_normal_shop_purchase():
7+
shop = Shop(5,10,0)
8+
shop.buy_item()
9+
shop.buy_item()
10+
11+
item_price = shop.get_item_price()
12+
assert shop.get_items() == 8
13+
assert shop.get_money() == item_price * 2
14+
15+
def test_normal_shop_purchase2():
16+
shop = Shop(5,10,0)
17+
18+
19+
shop.buy_item()
20+
shop.buy_item()
21+
shop.buy_item()
22+
shop.buy_item()
23+
24+
item_price = shop.get_item_price()
25+
26+
assert shop.get_items() == 6
27+
assert shop.get_money() == item_price * 4
28+
29+
def test_empty_shop_purchase():
30+
shop = Shop(5,0,0)
31+
shop.buy_item()
32+
assert shop.get_money() == 0

pytest_rts/tests/test_e2e.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,20 @@ def test_decorated_with_param_tracked(helper):
196196
assert helper.get_tests_from_tool_current() == {
197197
"tests/test_decorated.py::test_decorated_2"
198198
}
199+
200+
201+
def test_testfunction_runtimes_not_wiped(helper):
202+
"""Test that checks that test function runtimes are not removed
203+
from database when running the tool for committed changes
204+
and a testfile is changed
205+
"""
206+
orig_runtimes = helper.get_test_function_runtimes()
207+
helper.checkout_new_branch()
208+
209+
helper.change_file("changes/test_shop/shift_two_forward.txt", "tests/test_shop.py")
210+
helper.commit_change("tests/test_shop.py", "shift")
211+
212+
helper.run_tool()
213+
new_runtimes = helper.get_test_function_runtimes()
214+
215+
assert orig_runtimes == new_runtimes

pytest_rts/tests/testhelper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,10 @@ def delete_branch(self, branchname):
119119
stdout=subprocess.DEVNULL,
120120
stderr=subprocess.DEVNULL,
121121
)
122+
123+
def get_test_function_runtimes(self):
124+
conn = sqlite3.connect(DB_FILE_NAME)
125+
testgetter = TestGetter(conn)
126+
runtimes = testgetter.test_function_runtimes
127+
conn.close()
128+
return runtimes

pytest_rts/utils/common.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
import ast
33
import os
44
import subprocess
5+
import sys
56
from typing import Dict, List, Set, Tuple
67

8+
from _pytest.nodes import Item
9+
710
from pytest_rts.utils.git import (
811
file_diff_data_between_commits,
912
file_diff_data_current,
@@ -71,10 +74,7 @@ def tests_from_changed_srcfiles(
7174
new_line_map[file_id] = line_mapping(updates_to_lines, filename)
7275

7376
if not all(
74-
[
75-
mappinghelper.line_exists(file_id, line_number)
76-
for line_number in new_lines
77-
]
77+
mappinghelper.line_exists(file_id, line_number) for line_number in new_lines
7878
):
7979
files_to_warn.append(filename)
8080

@@ -162,3 +162,15 @@ def calculate_func_lines(src_code) -> Dict[str, Tuple[int, int]]:
162162
parsed_src_code = ast.parse(src_code)
163163
func_lines = function_lines(parsed_src_code, len(src_code.splitlines()))
164164
return {x[0]: (x[1], x[2]) for x in func_lines}
165+
166+
167+
def filter_and_sort_pytest_items(test_set, pytest_items, runtimes) -> List[Item]:
168+
"""Selected pytest items based on found tests
169+
ordered by their runtimes
170+
"""
171+
selected = list(filter(lambda item: item.nodeid in test_set, pytest_items))
172+
updated_runtimes = {
173+
item.nodeid: runtimes[item.nodeid] if item.nodeid in runtimes else sys.maxsize
174+
for item in pytest_items
175+
}
176+
return sorted(selected, key=lambda item: updated_runtimes[item.nodeid])

0 commit comments

Comments
 (0)