Merged
Changes from 1 commit
Commits
151 commits
3480718
initial implementation for pytest benchmark discovery
alvin-r Feb 25, 2025
7f917a0
Merge branch 'refs/heads/main' into pytest-benchmark
alvin-r Feb 27, 2025
133a9e3
initial implementation for tracing benchmarks using a plugin, and pro…
alvin-r Feb 28, 2025
2f26695
initial implementation of tracing benchmarks via the plugin
alvin-r Mar 4, 2025
32b0d3b
Merge branch 'main' into pytest-benchmark
alvin-r Mar 11, 2025
6b4b68a
basic version working on bubble sort
alvin-r Mar 11, 2025
887e3cb
initial attempt for codeflash_trace_decorator
alvin-r Mar 11, 2025
84bd0f0
improvements
alvin-r Mar 12, 2025
1c3919d
Merge branch 'refs/heads/main' into codeflash-trace-decorator
alvin-r Mar 12, 2025
c4694b7
work on new replay_test logic
alvin-r Mar 12, 2025
c150c05
Merge branch 'main' into codeflash-trace-decorator
alvin-r Mar 14, 2025
1801d41
initial replay test version working
alvin-r Mar 14, 2025
88a11d3
Merge branch 'main' into codeflash-trace-decorator
alvin-r Mar 14, 2025
f7466a5
replay test functionality working for functions, methods, static meth…
alvin-r Mar 14, 2025
4c19e6f
restored overwritten logic
alvin-r Mar 14, 2025
7eba031
functioning end to end, gets the function impact on benchmarks
alvin-r Mar 18, 2025
896aa52
modified printing of results, handle errors when collecting benchmarks
alvin-r Mar 19, 2025
ad17de4
tests pass
alvin-r Mar 19, 2025
92e6bf5
revert pyproject.toml
alvin-r Mar 19, 2025
4784723
mypy fixes
alvin-r Mar 19, 2025
5f05711
Merge branch 'main' into codeflash-trace-decorator
alvin-r Mar 19, 2025
b77a979
import changes
alvin-r Mar 20, 2025
8878baf
Merge branch 'pytest-plugin-blocker' into codeflash-trace-decorator
alvin-r Mar 20, 2025
0c2a3b6
removed benchmark skip command
alvin-r Mar 20, 2025
9a41bdd
shifted benchmark class in plugin, improved display of benchmark info
alvin-r Mar 20, 2025
5577cd5
cleanup tests better
alvin-r Mar 20, 2025
83f1c1c
Merge branch 'main' into codeflash-trace-decorator
alvin-r Mar 20, 2025
80730f9
modified paths in test
alvin-r Mar 20, 2025
d610f8c
typing fix
alvin-r Mar 20, 2025
93f583c
typing fix for 3.9
alvin-r Mar 21, 2025
d422e35
typing fix for 3.9
alvin-r Mar 21, 2025
d664040
Merge branch 'main' into codeflash-trace-decorator
alvin-r Mar 21, 2025
1637810
works with multithreading, added test
alvin-r Mar 24, 2025
684acf8
Merge branch 'refs/heads/main' into codeflash-trace-decorator
alvin-r Mar 24, 2025
6180c9d
refactored get_function_benchmark_timings and get_benchmark_timings i…
alvin-r Mar 25, 2025
fa93df6
Merge branch 'refs/heads/main' into codeflash-trace-decorator
alvin-r Mar 25, 2025
67d3f19
fixed isort
alvin-r Mar 25, 2025
f4be9be
modified PR info
alvin-r Mar 26, 2025
77f43a5
mypy fix
alvin-r Mar 26, 2025
da6385f
use dill instead of pickle
alvin-r Mar 26, 2025
f34f22f
modified the benchmarking approach. codeflash_trace and codeflash_ben…
alvin-r Mar 28, 2025
57b80ec
started implementing group by benchmark
alvin-r Mar 28, 2025
87ad743
Merge branch 'refs/heads/merge_test_results_into_models' into codefla…
alvin-r Mar 28, 2025
d03ed96
Merge branch 'merge_test_results_into_models' into codeflash-trace-de…
alvin-r Mar 28, 2025
8d95b18
Merge branch 'main' into codeflash-trace-decorator
alvin-r Mar 28, 2025
56e3447
reworked matching benchmark key to test results.
alvin-r Mar 31, 2025
5a34697
Merge branch 'main' into codeflash-trace-decorator
alvin-r Apr 1, 2025
5f86bdd
PRAGMA journal to memory to make it faster
alvin-r Apr 1, 2025
20890fa
Merge branch 'main' into codeflash-trace-decorator
alvin-r Apr 1, 2025
9764c25
benchmarks root must be subdir of tests root
alvin-r Apr 1, 2025
d703b13
replay tests are now grouped by benchmark file. each benchmark test f…
alvin-r Apr 1, 2025
14c33f9
Merge branch 'main' into codeflash-trace-decorator
alvin-r Apr 1, 2025
c6a201b
Merge branch 'main' into codeflash-trace-decorator
alvin-r Apr 1, 2025
30ec0c4
Use module path instead of file path for benchmarks, improved display…
alvin-r Apr 2, 2025
cf00212
Merge branch 'refs/heads/main' into codeflash-trace-decorator
alvin-r Apr 2, 2025
bb9c5db
benchmark flow is working. changed paths to use module_path instead o…
alvin-r Apr 2, 2025
7d9c4e1
Merge branch 'main' into codeflash-trace-decorator
alvin-r Apr 2, 2025
1928dc4
fixed string error
alvin-r Apr 2, 2025
217e239
fixed mypy error
alvin-r Apr 2, 2025
96dd780
new end to end test for benchmarking bubble sort
alvin-r Apr 2, 2025
5785875
renamed test
alvin-r Apr 2, 2025
d656d3b
fixed e2e test
alvin-r Apr 2, 2025
4d0eb3d
printing issues on github actions
alvin-r Apr 2, 2025
6100620
attempt to use horizontals for rows
alvin-r Apr 2, 2025
21a79eb
added row lines
alvin-r Apr 2, 2025
b374b6e
made benchmarks-root use resolve()
alvin-r Apr 3, 2025
27a6488
handled edge case for instrumenting codeflash trace
alvin-r Apr 3, 2025
4a24f2c
fixed slight bug with formatting table
alvin-r Apr 3, 2025
9de664b
improved file removal after errors
alvin-r Apr 3, 2025
a8d4fda
fixed a return bug
alvin-r Apr 4, 2025
4d53330
Merge branch 'jedi_ctx_fix' into codeflash-trace-decorator
alvin-r Apr 4, 2025
c82a3a3
Merge branch 'refs/heads/main' into codeflash-trace-decorator
alvin-r Apr 7, 2025
1f3fcff
Support recursive functions, and @benchmark / @pytest.mark.benchmark …
alvin-r Apr 7, 2025
fe63652
basic pickle patch version working
alvin-r Apr 8, 2025
d653d0d
draft of end to end test
alvin-r Apr 8, 2025
a73b541
initial implementation for pytest benchmark discovery
alvin-r Feb 25, 2025
965e2c8
initial implementation for tracing benchmarks using a plugin, and pro…
alvin-r Feb 28, 2025
7590c29
initial implementation of tracing benchmarks via the plugin
alvin-r Mar 4, 2025
034bed3
basic version working on bubble sort
alvin-r Mar 11, 2025
1f3fd4d
initial attempt for codeflash_trace_decorator
alvin-r Mar 11, 2025
5faccd8
improvements
alvin-r Mar 12, 2025
d6217e8
work on new replay_test logic
alvin-r Mar 12, 2025
26b2c4f
initial replay test version working
alvin-r Mar 14, 2025
adffb9d
replay test functionality working for functions, methods, static meth…
alvin-r Mar 14, 2025
f9144ec
restored overwritten logic
alvin-r Mar 14, 2025
c29c8bf
functioning end to end, gets the function impact on benchmarks
alvin-r Mar 18, 2025
54fe71f
modified printing of results, handle errors when collecting benchmarks
alvin-r Mar 19, 2025
5fd112a
tests pass
alvin-r Mar 19, 2025
8194554
revert pyproject.toml
alvin-r Mar 19, 2025
4c1d2af
mypy fixes
alvin-r Mar 19, 2025
6e676e9
import changes
alvin-r Mar 20, 2025
62f3b36
removed benchmark skip command
alvin-r Mar 20, 2025
a614972
shifted benchmark class in plugin, improved display of benchmark info
alvin-r Mar 20, 2025
82cb7a9
cleanup tests better
alvin-r Mar 20, 2025
7601895
modified paths in test
alvin-r Mar 20, 2025
4d69427
typing fix
alvin-r Mar 20, 2025
ebe3e12
typing fix for 3.9
alvin-r Mar 21, 2025
0449d0d
typing fix for 3.9
alvin-r Mar 21, 2025
baac964
works with multithreading, added test
alvin-r Mar 24, 2025
357f586
refactored get_function_benchmark_timings and get_benchmark_timings i…
alvin-r Mar 25, 2025
9efa47f
fixed isort
alvin-r Mar 25, 2025
64b4c64
modified PR info
alvin-r Mar 26, 2025
4c61de9
mypy fix
alvin-r Mar 26, 2025
eda0d46
use dill instead of pickle
alvin-r Mar 26, 2025
a82e9f0
modified the benchmarking approach. codeflash_trace and codeflash_ben…
alvin-r Mar 28, 2025
582bea0
started implementing group by benchmark
alvin-r Mar 28, 2025
e5a8260
reworked matching benchmark key to test results.
alvin-r Mar 31, 2025
0937329
PRAGMA journal to memory to make it faster
alvin-r Apr 1, 2025
ed8f5ef
benchmarks root must be subdir of tests root
alvin-r Apr 1, 2025
75c1be7
replay tests are now grouped by benchmark file. each benchmark test f…
alvin-r Apr 1, 2025
b3c8320
Use module path instead of file path for benchmarks, improved display…
alvin-r Apr 2, 2025
972ef46
benchmark flow is working. changed paths to use module_path instead o…
alvin-r Apr 2, 2025
06b3818
fixed string error
alvin-r Apr 2, 2025
37577e7
fixed mypy error
alvin-r Apr 2, 2025
5c30d3e
new end to end test for benchmarking bubble sort
alvin-r Apr 2, 2025
906e434
renamed test
alvin-r Apr 2, 2025
821fa47
fixed e2e test
alvin-r Apr 2, 2025
41f7e0a
printing issues on github actions
alvin-r Apr 2, 2025
c20f29a
attempt to use horizontals for rows
alvin-r Apr 2, 2025
d1a8d25
added row lines
alvin-r Apr 2, 2025
705105c
made benchmarks-root use resolve()
alvin-r Apr 3, 2025
26546de
handled edge case for instrumenting codeflash trace
alvin-r Apr 3, 2025
0c04adf
fixed slight bug with formatting table
alvin-r Apr 3, 2025
30d32bb
improved file removal after errors
alvin-r Apr 3, 2025
c997b90
fixed a return bug
alvin-r Apr 4, 2025
d6ed1c3
Support recursive functions, and @benchmark / @pytest.mark.benchmark …
alvin-r Apr 7, 2025
a4c4c2d
Merge remote-tracking branch 'origin/codeflash-trace-decorator' into …
alvin-r Apr 10, 2025
40e416e
Merge branch 'refs/heads/main' into codeflash-trace-decorator
alvin-r Apr 11, 2025
3158f9c
end to end test that proves picklepatcher works. example shown is a s…
alvin-r Apr 11, 2025
9578854
Merge branch 'main' into codeflash-trace-decorator
alvin-r Apr 11, 2025
4bb0aad
minor fix for removing files
alvin-r Apr 11, 2025
790d77c
fixes to sync with main
alvin-r Apr 11, 2025
fce641e
Merge branch 'main' into codeflash-trace-decorator
alvin-r Apr 11, 2025
b70c4c9
Merge branch 'main' into codeflash-trace-decorator
alvin-r Apr 15, 2025
28fd746
cmd init changes
alvin-r Apr 15, 2025
4e8483b
created benchmarks for codeflash, modified codeflash-optimize to use …
alvin-r Apr 16, 2025
efc91d6
Merge branch 'main' into codeflash-trace-decorator
alvin-r Apr 16, 2025
0680f79
added benchmarks root
alvin-r Apr 16, 2025
583b464
removed comment
alvin-r Apr 16, 2025
1eaaad7
debugging
alvin-r Apr 16, 2025
ab9079b
debugging
alvin-r Apr 16, 2025
d7274ec
removed benchmark-skip
alvin-r Apr 16, 2025
a624221
added pytest-benchmark as dependency
alvin-r Apr 16, 2025
605d078
updated pyproject
alvin-r Apr 16, 2025
78871fe
gha failing on multithreaded test
alvin-r Apr 16, 2025
0146d82
line number test is off by 1 for python versions 39 and 310, removed …
alvin-r Apr 17, 2025
6c1a369
Merge branch 'main' into codeflash-trace-decorator
alvin-r Apr 17, 2025
3017ccf
100 max function calls before flushing to disk instead of 1000
alvin-r Apr 17, 2025
f14cf01
skip multithreaded benchmark test if machine is single threaded (fixe…
alvin-r Apr 17, 2025
e5ca10f
marked multithreaded trace benchmarks test to be skipped during CI as…
alvin-r Apr 17, 2025
683c9f6
shift check for pickle placerholder access error in comparator
alvin-r Apr 17, 2025
Support recursive functions, and @benchmark / @pytest.mark.benchmark ways of using benchmark. created tests for all of them
alvin-r committed Apr 10, 2025
commit d6ed1c33c4a307bbf7ae3be57d22dc6ed25951cb
18 changes: 18 additions & 0 deletions code_to_optimize/bubble_sort_codeflash_trace.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,24 @@ def sorter(arr):
arr[j + 1] = temp
return arr

@codeflash_trace
def recursive_bubble_sort(arr, n=None):
# Initialize n if not provided
if n is None:
n = len(arr)

# Base case: if n is 1, the array is already sorted
if n == 1:
return arr

# One pass of bubble sort - move the largest element to the end
for i in range(n - 1):
if arr[i] > arr[i + 1]:
arr[i], arr[i + 1] = arr[i + 1], arr[i]

# Recursively sort the remaining n-1 elements
return recursive_bubble_sort(arr, n - 1)

class Sorter:
@codeflash_trace
def __init__(self, arr):
@@ -0,0 +1,6 @@
from code_to_optimize.bubble_sort_codeflash_trace import recursive_bubble_sort


def test_recursive_sort(benchmark):
result = benchmark(recursive_bubble_sort, list(reversed(range(500))))
assert result == list(range(500))
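Stripped of the `@codeflash_trace` decorator, the recursive variant added above is self-contained; a minimal standalone sketch (depth kept small to stay well under the default recursion limit):

```python
def recursive_bubble_sort(arr, n=None):
    # Initialize n on the outermost call
    if n is None:
        n = len(arr)
    # Base case: a single element is already sorted
    if n == 1:
        return arr
    # One bubble pass moves the largest of the first n elements to index n-1
    for i in range(n - 1):
        if arr[i] > arr[i + 1]:
            arr[i], arr[i + 1] = arr[i + 1], arr[i]
    # Recurse on the remaining n-1 elements
    return recursive_bubble_sort(arr, n - 1)

print(recursive_bubble_sort(list(reversed(range(10)))))
```

Each level of recursion performs one pass, so the call depth equals `len(arr)`, which is why the traced wrapper needs a recursion guard.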
@@ -0,0 +1,11 @@
import pytest
from code_to_optimize.bubble_sort_codeflash_trace import sorter

def test_benchmark_sort(benchmark):
@benchmark
def do_sort():
sorter(list(reversed(range(500))))

@pytest.mark.benchmark(group="benchmark_decorator")
def test_pytest_mark(benchmark):
benchmark(sorter, list(reversed(range(500))))
79 changes: 55 additions & 24 deletions codeflash/benchmarking/codeflash_trace.py
@@ -3,6 +3,7 @@
import pickle
import sqlite3
import sys
import threading
import time
from typing import Callable

@@ -18,6 +19,8 @@ def __init__(self) -> None:
self.pickle_count_limit = 1000
self._connection = None
self._trace_path = None
self._thread_local = threading.local()
self._thread_local.active_functions = set()

def setup(self, trace_path: str) -> None:
"""Set up the database connection for direct writing.
@@ -98,23 +101,29 @@ def __call__(self, func: Callable) -> Callable:
The wrapped function

"""
func_id = (func.__module__,func.__name__)
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Initialize thread-local active functions set if it doesn't exist
if not hasattr(self._thread_local, "active_functions"):
self._thread_local.active_functions = set()
# If it's in a recursive function, just return the result
if func_id in self._thread_local.active_functions:
return func(*args, **kwargs)
# Track active functions so we can detect recursive functions
self._thread_local.active_functions.add(func_id)
# Measure execution time
start_time = time.thread_time_ns()
result = func(*args, **kwargs)
end_time = time.thread_time_ns()
# Calculate execution time
execution_time = end_time - start_time

self.function_call_count += 1

# Measure overhead
original_recursion_limit = sys.getrecursionlimit()
# Check if currently in pytest benchmark fixture
if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False":
self._thread_local.active_functions.remove(func_id)
return result

# Get benchmark info from environment
benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
@@ -125,32 +134,54 @@ def wrapper(*args, **kwargs):
if "." in qualname:
class_name = qualname.split(".")[0]

if self.function_call_count <= self.pickle_count_limit:
# Limit pickle count so memory does not explode
if self.function_call_count > self.pickle_count_limit:
print("Pickle limit reached")
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
)
return result

try:
original_recursion_limit = sys.getrecursionlimit()
sys.setrecursionlimit(10000)
# args = dict(args.items())
# if class_name and func.__name__ == "__init__" and "self" in args:
# del args["self"]
# Pickle the arguments
pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
# Retry with dill if pickle fails. It's slower but more comprehensive
try:
sys.setrecursionlimit(1000000)
args = dict(args.items())
if class_name and func.__name__ == "__init__" and "self" in args:
del args["self"]
# Pickle the arguments
pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
# we retry with dill if pickle fails. It's slower but more comprehensive
try:
pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)

except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
print(f"Error pickling arguments for function {func.__name__}: {e}")
return result

except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
print(f"Error pickling arguments for function {func.__name__}: {e}")
# Add to the list of function calls without pickled args. Used for timing info only
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
)
return result

# Flush to database every 1000 calls
if len(self.function_calls_data) > 1000:
self.write_function_timings()
# Calculate overhead time
overhead_time = time.thread_time_ns() - end_time

# Add to the list of function calls with pickled args, to be used for replay tests
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
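The wrapper above combines a per-thread set of active function ids (so recursive re-entries are not re-traced) with a pickle-then-dill serialization fallback. A condensed, self-contained sketch of those two mechanisms — `codeflash_like_trace` and its bookkeeping are hypothetical names, and `dill` is treated as optional:

```python
import functools
import pickle
import threading
import time

try:
    import dill  # slower but broader-coverage fallback serializer
except ImportError:
    dill = None

_local = threading.local()

def codeflash_like_trace(func):
    """Sketch: trace only the outermost call per thread, capturing args."""
    func_id = (func.__module__, func.__qualname__)
    calls = []

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        active = getattr(_local, "active", None)
        if active is None:
            active = _local.active = set()
        if func_id in active:  # recursive re-entry: run untraced
            return func(*args, **kwargs)
        active.add(func_id)
        try:
            start = time.thread_time_ns()
            result = func(*args, **kwargs)
            elapsed = time.thread_time_ns() - start
            try:
                blob = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
            except (TypeError, pickle.PicklingError, AttributeError, RecursionError):
                blob = dill.dumps(args) if dill else None
            calls.append((func_id, elapsed, blob))
            return result
        finally:
            active.discard(func_id)

    wrapper.calls = calls
    return wrapper

@codeflash_like_trace
def fact(n):
    return 1 if n <= 1 else n * fact(n - 1)

assert fact(5) == 120
assert len(fact.calls) == 1  # only the outermost call was traced
```

The `try/finally` around `active.discard` is a small hardening over the diff's explicit `remove` calls: it guarantees the guard is cleared even if the traced function raises.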
71 changes: 41 additions & 30 deletions codeflash/benchmarking/plugin/plugin.py
@@ -175,6 +175,7 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
# Subtract overhead from total time
overhead = overhead_by_benchmark.get(benchmark_key, 0)
print("benchmark_func:", benchmark_func, "Total time:", time_ns, "Overhead:", overhead, "Result:", time_ns - overhead)
result[benchmark_key] = time_ns - overhead

finally:
@@ -210,61 +211,71 @@ def pytest_plugin_registered(plugin, manager):
manager.unregister(plugin)

@staticmethod
def pytest_configure(config):
"""Register the benchmark marker."""
config.addinivalue_line(
"markers",
"benchmark: mark test as a benchmark that should be run with codeflash tracing"
)
@staticmethod
def pytest_collection_modifyitems(config, items):
# Skip tests that don't have the benchmark fixture
if not config.getoption("--codeflash-trace"):
return

skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
for item in items:
if hasattr(item, "fixturenames") and "benchmark" in item.fixturenames:
continue
item.add_marker(skip_no_benchmark)
# Check for direct benchmark fixture usage
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames

# Check for @pytest.mark.benchmark marker
has_marker = False
if hasattr(item, "get_closest_marker"):
marker = item.get_closest_marker("benchmark")
if marker is not None:
has_marker = True

# Skip if neither fixture nor marker is present
if not (has_fixture or has_marker):
item.add_marker(skip_no_benchmark)
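The modified collection hook keeps a test when it either requests the `benchmark` fixture or carries the `@pytest.mark.benchmark` marker; the keep/skip predicate reduces to the following sketch, with fixture and marker names passed in as plain lists rather than pytest item attributes:

```python
def should_collect(fixturenames, marker_names):
    """Mirror of the hook's test: keep when the benchmark fixture is
    requested or the benchmark marker is applied; otherwise skip."""
    has_fixture = "benchmark" in fixturenames
    has_marker = "benchmark" in marker_names
    return has_fixture or has_marker

print(should_collect(["benchmark"], []))       # fixture form
print(should_collect([], ["benchmark"]))       # @pytest.mark.benchmark form
print(should_collect(["tmp_path"], ["slow"]))  # neither: gets skip marker
```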

# Benchmark fixture
class Benchmark:
def __init__(self, request):
self.request = request

def __call__(self, func, *args, **kwargs):
"""Handle behaviour for the benchmark fixture in pytest.

For example,

def test_something(benchmark):
benchmark(sorter, [3,2,1])

Args:
func: The function to benchmark (e.g. sorter)
args: The arguments to pass to the function (e.g. [3,2,1])
kwargs: The keyword arguments to pass to the function

Returns:
The return value of the function

"""
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
"""Handle both direct function calls and decorator usage."""
if args or kwargs:
# Used as benchmark(func, *args, **kwargs)
return self._run_benchmark(func, *args, **kwargs)
# Used as @benchmark decorator
def wrapped_func(*args, **kwargs):
return func(*args, **kwargs)
result = self._run_benchmark(func)
return wrapped_func

def _run_benchmark(self, func, *args, **kwargs):
"""Actual benchmark implementation."""
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
Path(codeflash_benchmark_plugin.project_root))
benchmark_function_name = self.request.node.name
line_number = int(str(sys._getframe(1).f_lineno)) # 1 frame up in the call stack

# Set env vars so codeflash decorator can identify what benchmark its being run in
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack
# Set env vars
os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
os.environ["CODEFLASH_BENCHMARKING"] = "True"

# Run the function
start = time.perf_counter_ns()
# Run the function
start = time.thread_time_ns()
result = func(*args, **kwargs)
end = time.perf_counter_ns()

end = time.thread_time_ns()
# Reset the environment variable
os.environ["CODEFLASH_BENCHMARKING"] = "False"

# Write function calls
codeflash_trace.write_function_timings()
# Reset function call count after a benchmark is run
# Reset function call count
codeflash_trace.function_call_count = 0
# Add to the benchmark timings buffer
codeflash_benchmark_plugin.benchmark_timings.append(
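The dual-mode `__call__` above dispatches on whether extra arguments were supplied. A simplified, self-contained sketch of that dispatch — env-var bookkeeping and trace flushing omitted, and the decorator form here returns the original function rather than a wrapper:

```python
import time

class Benchmark:
    """Callable usable both as benchmark(func, *args) and as @benchmark."""

    def __init__(self):
        self.timings = []

    def __call__(self, func, *args, **kwargs):
        if args or kwargs:
            # Used as benchmark(func, *args, **kwargs): run and return result
            return self._run(func, *args, **kwargs)
        # Used as @benchmark: run once now, hand the function back to the test
        self._run(func)
        return func

    def _run(self, func, *args, **kwargs):
        start = time.perf_counter_ns()
        result = func(*args, **kwargs)
        self.timings.append(time.perf_counter_ns() - start)
        return result

bench = Benchmark()
assert bench(sorted, [3, 1, 2]) == [1, 2, 3]

@bench
def do_work():
    return sum(range(100))

assert len(bench.timings) == 2
```

The no-argument branch is the subtle one: a bare `@benchmark` decoration means `__call__` receives only the function, so the body must both execute the benchmark and return something callable.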
@@ -16,7 +16,7 @@
codeflash_benchmark_plugin.setup(trace_file, project_root)
codeflash_trace.setup(trace_file)
exitcode = pytest.main(
[benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
[benchmarks_root, "--codeflash-trace", "-p", "no:benchmark","-p", "no:codspeed","-p", "no:cov-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
) # Errors will be printed to stdout, not stderr

except Exception as e:
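The widened `pytest.main` invocation disables plugins that would otherwise claim the `benchmark` fixture. Assuming each plugin is disabled with a separate `-p no:<name>` pair (the diff's `"no:cov-s"` token appears to run `"no:cov"` and `"-s"` together), the argv can be assembled as:

```python
def build_pytest_args(benchmarks_root, disabled_plugins=("benchmark", "codspeed", "cov")):
    """Build a pytest argv that disables conflicting benchmark plugins."""
    args = [str(benchmarks_root), "--codeflash-trace"]
    for name in disabled_plugins:
        args += ["-p", f"no:{name}"]  # one -p no:<name> pair per plugin
    args += ["-s", "-o", "addopts="]  # no capture, ignore project addopts
    return args

print(build_pytest_args("tests/benchmarks"))
```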
2 changes: 1 addition & 1 deletion codeflash/benchmarking/replay_test.py
Expand Up @@ -34,7 +34,7 @@ def get_next_arg_and_return(
)

while (val := cursor.fetchone()) is not None:
yield val[9], val[10] # args and kwargs are at indices 7 and 8
yield val[9], val[10] # pickled_args, pickled_kwargs


def get_function_alias(module: str, function_name: str) -> str:
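The corrected comment notes that the pickled args/kwargs live in columns 9 and 10 of the fetched row. A reduced sketch of the same fetch-and-yield pattern against an in-memory table (the three-column schema here is hypothetical, standing in for the real trace schema):

```python
import pickle
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE calls (function_name TEXT, args BLOB, kwargs BLOB)")
conn.execute(
    "INSERT INTO calls VALUES (?, ?, ?)",
    ("sorter", pickle.dumps(([3, 1, 2],)), pickle.dumps({})),
)

def get_next_arg_and_return(cursor, function_name):
    cursor.execute("SELECT * FROM calls WHERE function_name = ?", (function_name,))
    while (val := cursor.fetchone()) is not None:
        yield val[1], val[2]  # pickled_args, pickled_kwargs

for blob_args, blob_kwargs in get_next_arg_and_return(conn.cursor(), "sorter"):
    print(pickle.loads(blob_args), pickle.loads(blob_kwargs))
```

The replay test later unpickles each pair and re-invokes the traced function with them, which is why the trace must store the serialized arguments and not just timings.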
62 changes: 61 additions & 1 deletion tests/test_trace_benchmarks.py
@@ -31,7 +31,7 @@ def test_trace_benchmarks():
function_calls = cursor.fetchall()

# Assert the length of function calls
assert len(function_calls) == 7, f"Expected 6 function calls, but got {len(function_calls)}"
assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"

bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
@@ -64,6 +64,10 @@
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8),

("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5),
]
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
@@ -222,6 +226,62 @@ def test_trace_multithreaded_benchmark() -> None:
# Close connection
conn.close()

finally:
# cleanup
output_file.unlink(missing_ok=True)

def test_trace_benchmark_decorator() -> None:
project_root = Path(__file__).parent.parent / "code_to_optimize"
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test_decorator"
tests_root = project_root / "tests"
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
try:
# check contents of trace file
# connect to database
conn = sqlite3.connect(output_file.as_posix())
cursor = conn.cursor()

# Get the count of records
# Get all records
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
function_calls = cursor.fetchall()

# Assert the length of function calls
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results

test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0

bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
# Expected function calls
expected_calls = [
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_benchmark_sort", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 5),
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_pytest_mark", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 11),
]
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
# Close connection
conn.close()

finally:
# cleanup
output_file.unlink(missing_ok=True)