Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions codeflash/benchmarking/plugin/plugin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import os
import sqlite3
import sys
import time
from pathlib import Path

import pytest

from codeflash.benchmarking.codeflash_trace import codeflash_trace
from codeflash.models.models import BenchmarkKey

Expand All @@ -15,7 +18,7 @@ def __init__(self) -> None:
self._connection = None
self.benchmark_timings = []

def setup(self, trace_path:str) -> None:
def setup(self, trace_path: str) -> None:
try:
# Open connection
self._trace_path = trace_path
Expand All @@ -28,7 +31,7 @@ def setup(self, trace_path:str) -> None:
"benchmark_time_ns INTEGER)"
)
self._connection.commit()
self.close() # Reopen only at the end of pytest session
self.close() # Reopen only at the end of pytest session
except Exception as e:
print(f"Database setup error: {e}")
if self._connection:
Expand All @@ -42,20 +45,23 @@ def write_benchmark_timings(self) -> None:

if self._connection is None:
self._connection = sqlite3.connect(self._trace_path)
self._connection.execute("PRAGMA synchronous = OFF")
self._connection.execute("PRAGMA journal_mode = MEMORY")

try:
cur = self._connection.cursor()
# Insert data into the benchmark_timings table
cur.executemany(
"INSERT INTO benchmark_timings (benchmark_file_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
self.benchmark_timings
self.benchmark_timings,
)
self._connection.commit()
self.benchmark_timings = [] # Clear the benchmark timings list
self.benchmark_timings = [] # Clear the benchmark timings list
except Exception as e:
print(f"Error writing to benchmark timings database: {e}")
self._connection.rollback()
raise

def close(self) -> None:
if self._connection:
self._connection.close()
Expand Down Expand Up @@ -189,12 +195,7 @@ def pytest_sessionfinish(self, session, exitstatus):

@staticmethod
def pytest_addoption(parser):
parser.addoption(
"--codeflash-trace",
action="store_true",
default=False,
help="Enable CodeFlash tracing"
)
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")

@staticmethod
def pytest_plugin_registered(plugin, manager):
Expand Down Expand Up @@ -246,7 +247,7 @@ def test_something(benchmark):
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
os.environ["CODEFLASH_BENCHMARKING"] = "True"

# Run the function
# Run the function
start = time.perf_counter_ns()
result = func(*args, **kwargs)
end = time.perf_counter_ns()
Expand All @@ -260,7 +261,8 @@ def test_something(benchmark):
codeflash_trace.function_call_count = 0
# Add to the benchmark timings buffer
codeflash_benchmark_plugin.benchmark_timings.append(
(benchmark_file_path, benchmark_function_name, line_number, end - start))
(benchmark_file_path, benchmark_function_name, line_number, end - start)
)

return result

Expand All @@ -272,4 +274,5 @@ def benchmark(request):

return CodeFlashBenchmarkPlugin.Benchmark(request)

codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()

codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
Loading