Skip to content

Commit ad17de4

Browse files
committed
tests pass
1 parent 896aa52 commit ad17de4

File tree

5 files changed

+52
-23
lines changed

5 files changed

+52
-23
lines changed

code_to_optimize/bubble_sort.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
def sorter(arr):
2+
print("codeflash stdout: Sorting list")
23
for i in range(len(arr)):
34
for j in range(len(arr) - 1):
45
if arr[j] > arr[j + 1]:
56
temp = arr[j]
67
arr[j] = arr[j + 1]
78
arr[j + 1] = temp
9+
print(f"result: {arr}")
810
return arr
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from code_to_optimize.bubble_sort import sorter
2+
from codeflash.benchmarking.codeflash_trace import codeflash_trace
3+
4+
def calculate_pairwise_products(arr):
5+
"""
6+
Calculate the average of all pairwise products in the array.
7+
"""
8+
sum_of_products = 0
9+
count = 0
10+
11+
for i in range(len(arr)):
12+
for j in range(len(arr)):
13+
if i != j:
14+
sum_of_products += arr[i] * arr[j]
15+
count += 1
16+
17+
# The average of all pairwise products
18+
return sum_of_products / count if count > 0 else 0
19+
20+
@codeflash_trace
21+
def compute_and_sort(arr):
22+
# Compute pairwise sums average
23+
pairwise_average = calculate_pairwise_products(arr)
24+
25+
# Call sorter function
26+
sorter(arr.copy())
27+
28+
return pairwise_average

codeflash/discovery/functions_to_optimize.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -363,23 +363,25 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
363363
for decorator in body_node.decorator_list
364364
):
365365
self.is_staticmethod = True
366+
print(f"static method found: {self.function_name}")
367+
return
368+
elif self.line_no:
369+
# If we have line number info, check if class has a static method with the same line number
370+
# This way, if we don't have the class name, we can still find the static method
371+
for body_node in node.body:
372+
if (
373+
isinstance(body_node, ast.FunctionDef)
374+
and body_node.name == self.function_name
375+
and body_node.lineno in {self.line_no, self.line_no + 1}
376+
and any(
377+
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
378+
for decorator in body_node.decorator_list
379+
)
380+
):
381+
self.is_staticmethod = True
382+
self.is_top_level = True
383+
self.class_name = node.name
366384
return
367-
# else:
368-
# # search if the class has a staticmethod with the same name and on the same line number
369-
# for body_node in node.body:
370-
# if (
371-
# isinstance(body_node, ast.FunctionDef)
372-
# and body_node.name == self.function_name
373-
# # and body_node.lineno in {self.line_no, self.line_no + 1}
374-
# and any(
375-
# isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
376-
# for decorator in body_node.decorator_list
377-
# )
378-
# ):
379-
# self.is_staticmethod = True
380-
# self.is_top_level = True
381-
# self.class_name = node.name
382-
# return
383385

384386
return
385387

codeflash/discovery/pytest_new_process_discovery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, s
3434

3535
try:
3636
exitcode = pytest.main(
37-
[tests_root, "-pno:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()]
37+
[tests_root, "-pno:logging", "--collect-only", "-m", "not skip", "--benchmark-skip"], plugins=[PytestCollectionPlugin()]
3838
)
3939
except Exception as e: # noqa: BLE001
4040
print(f"Failed to collect tests: {e!s}") # noqa: T201

tests/test_unit_test_discovery.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,19 @@ def test_unit_test_discovery_pytest():
1818
)
1919
tests = discover_unit_tests(test_config)
2020
assert len(tests) > 0
21-
# print(tests)
21+
2222

2323
def test_benchmark_test_discovery_pytest():
2424
project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
25-
tests_path = project_path / "tests" / "pytest" / "benchmarks" / "test_benchmark_bubble_sort.py"
25+
tests_path = project_path / "tests" / "pytest" / "benchmarks"
2626
test_config = TestConfig(
2727
tests_root=tests_path,
2828
project_root_path=project_path,
2929
test_framework="pytest",
3030
tests_project_rootdir=tests_path.parent,
3131
)
3232
tests = discover_unit_tests(test_config)
33-
assert len(tests) > 0
34-
assert 'bubble_sort.sorter' in tests
35-
benchmark_tests = sum(1 for test in tests['bubble_sort.sorter'] if test.tests_in_file.test_type == TestType.BENCHMARK_TEST)
36-
assert benchmark_tests == 1
33+
assert len(tests) == 1 # Should not discover benchmark tests
3734

3835

3936
def test_unit_test_discovery_unittest():

0 commit comments

Comments
 (0)