Skip to content

Commit 86e2570

Browse files
authored
Merge pull request #753 from codeflash-ai/test_cache_revival
Test cache revival
2 parents 3b522fa + 882a2e0 commit 86e2570

File tree

6 files changed, with 168 additions and 63 deletions.

codeflash/code_utils/code_utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -304,7 +304,9 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
304304
return True, function_names
305305

306306

307-
def get_run_tmp_file(file_path: Path) -> Path:
307+
def get_run_tmp_file(file_path: Path | str) -> Path:
308+
if isinstance(file_path, str):
309+
file_path = Path(file_path)
308310
if not hasattr(get_run_tmp_file, "tmpdir"):
309311
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
310312
return Path(get_run_tmp_file.tmpdir.name) / file_path

codeflash/discovery/discover_unit_tests.py

Lines changed: 119 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -67,13 +67,39 @@ class TestFunction:
6767

6868

6969
class TestsCache:
70-
def __init__(self) -> None:
70+
SCHEMA_VERSION = 1 # Increment this when schema changes
71+
72+
def __init__(self, project_root_path: str | Path) -> None:
73+
self.project_root_path = Path(project_root_path).resolve().as_posix()
7174
self.connection = sqlite3.connect(codeflash_cache_db)
7275
self.cur = self.connection.cursor()
7376

77+
self.cur.execute(
78+
"""
79+
CREATE TABLE IF NOT EXISTS schema_version(
80+
version INTEGER PRIMARY KEY
81+
)
82+
"""
83+
)
84+
85+
self.cur.execute("SELECT version FROM schema_version")
86+
result = self.cur.fetchone()
87+
current_version = result[0] if result else None
88+
89+
if current_version != self.SCHEMA_VERSION:
90+
logger.debug(
91+
f"Schema version mismatch (current: {current_version}, expected: {self.SCHEMA_VERSION}). Recreating tables."
92+
)
93+
self.cur.execute("DROP TABLE IF EXISTS discovered_tests")
94+
self.cur.execute("DROP INDEX IF EXISTS idx_discovered_tests_project_file_path_hash")
95+
self.cur.execute("DELETE FROM schema_version")
96+
self.cur.execute("INSERT INTO schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,))
97+
self.connection.commit()
98+
7499
self.cur.execute(
75100
"""
76101
CREATE TABLE IF NOT EXISTS discovered_tests(
102+
project_root_path TEXT,
77103
file_path TEXT,
78104
file_hash TEXT,
79105
qualified_name_with_modules_from_root TEXT,
@@ -88,11 +114,12 @@ def __init__(self) -> None:
88114
)
89115
self.cur.execute(
90116
"""
91-
CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
92-
ON discovered_tests (file_path, file_hash)
117+
CREATE INDEX IF NOT EXISTS idx_discovered_tests_project_file_path_hash
118+
ON discovered_tests (project_root_path, file_path, file_hash)
93119
"""
94120
)
95-
self._memory_cache = {}
121+
122+
self.memory_cache = {}
96123

97124
def insert_test(
98125
self,
@@ -108,8 +135,9 @@ def insert_test(
108135
) -> None:
109136
test_type_value = test_type.value if hasattr(test_type, "value") else test_type
110137
self.cur.execute(
111-
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
138+
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
112139
(
140+
self.project_root_path,
113141
file_path,
114142
file_hash,
115143
qualified_name_with_modules_from_root,
@@ -123,32 +151,48 @@ def insert_test(
123151
)
124152
self.connection.commit()
125153

126-
def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest]:
127-
cache_key = (file_path, file_hash)
128-
if cache_key in self._memory_cache:
129-
return self._memory_cache[cache_key]
130-
self.cur.execute("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (file_path, file_hash))
131-
result = [
132-
FunctionCalledInTest(
154+
def get_function_to_test_map_for_file(
155+
self, file_path: str, file_hash: str
156+
) -> dict[str, set[FunctionCalledInTest]] | None:
157+
cache_key = (self.project_root_path, file_path, file_hash)
158+
if cache_key in self.memory_cache:
159+
return self.memory_cache[cache_key]
160+
161+
self.cur.execute(
162+
"SELECT * FROM discovered_tests WHERE project_root_path = ? AND file_path = ? AND file_hash = ?",
163+
(self.project_root_path, file_path, file_hash),
164+
)
165+
rows = self.cur.fetchall()
166+
if not rows:
167+
return None
168+
169+
function_to_test_map = defaultdict(set)
170+
171+
for row in rows:
172+
qualified_name_with_modules_from_root = row[3]
173+
function_called_in_test = FunctionCalledInTest(
133174
tests_in_file=TestsInFile(
134-
test_file=Path(row[0]), test_class=row[4], test_function=row[5], test_type=TestType(int(row[6]))
175+
test_file=Path(row[1]), test_class=row[5], test_function=row[6], test_type=TestType(int(row[7]))
135176
),
136-
position=CodePosition(line_no=row[7], col_no=row[8]),
177+
position=CodePosition(line_no=row[8], col_no=row[9]),
137178
)
138-
for row in self.cur.fetchall()
139-
]
140-
self._memory_cache[cache_key] = result
179+
function_to_test_map[qualified_name_with_modules_from_root].add(function_called_in_test)
180+
181+
result = dict(function_to_test_map)
182+
self.memory_cache[cache_key] = result
141183
return result
142184

143185
@staticmethod
144-
def compute_file_hash(path: str) -> str:
186+
def compute_file_hash(path: Path) -> str:
145187
h = hashlib.sha256(usedforsecurity=False)
146-
with Path(path).open("rb") as f:
188+
with path.open("rb", buffering=0) as f:
189+
buf = bytearray(8192)
190+
mv = memoryview(buf)
147191
while True:
148-
chunk = f.read(8192)
149-
if not chunk:
192+
n = f.readinto(mv)
193+
if n == 0:
150194
break
151-
h.update(chunk)
195+
h.update(mv[:n])
152196
return h.hexdigest()
153197

154198
def close(self) -> None:
@@ -394,7 +438,7 @@ def discover_tests_pytest(
394438
cfg: TestConfig,
395439
discover_only_these_tests: list[Path] | None = None,
396440
functions_to_optimize: list[FunctionToOptimize] | None = None,
397-
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
441+
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
398442
tests_root = cfg.tests_root
399443
project_root = cfg.project_root_path
400444

@@ -432,9 +476,11 @@ def discover_tests_pytest(
432476
f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}\n {error_section}"
433477
)
434478
if "ModuleNotFoundError" in result.stdout:
435-
match = ImportErrorPattern.search(result.stdout).group()
436-
panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False)
437-
console.print(panel)
479+
match = ImportErrorPattern.search(result.stdout)
480+
if match:
481+
error_message = match.group()
482+
panel = Panel(Text.from_markup(f"⚠️ {error_message} ", style="bold red"), expand=False)
483+
console.print(panel)
438484

439485
elif 0 <= exitcode <= 5:
440486
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}")
@@ -469,13 +515,13 @@ def discover_tests_pytest(
469515

470516
def discover_tests_unittest(
471517
cfg: TestConfig,
472-
discover_only_these_tests: list[str] | None = None,
518+
discover_only_these_tests: list[Path] | None = None,
473519
functions_to_optimize: list[FunctionToOptimize] | None = None,
474-
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
520+
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
475521
tests_root: Path = cfg.tests_root
476522
loader: unittest.TestLoader = unittest.TestLoader()
477523
tests: unittest.TestSuite = loader.discover(str(tests_root))
478-
file_to_test_map: defaultdict[str, list[TestsInFile]] = defaultdict(list)
524+
file_to_test_map: defaultdict[Path, list[TestsInFile]] = defaultdict(list)
479525

480526
def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
481527
_test_function, _test_module, _test_suite_name = (
@@ -487,7 +533,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
487533
_test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py")
488534
_test_module_path = tests_root / _test_module_path
489535
if not _test_module_path.exists() or (
490-
discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests
536+
discover_only_these_tests and _test_module_path not in discover_only_these_tests
491537
):
492538
return None
493539
if "__replay_test" in str(_test_module_path):
@@ -497,10 +543,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
497543
else:
498544
test_type = TestType.EXISTING_UNIT_TEST
499545
return TestsInFile(
500-
test_file=str(_test_module_path),
501-
test_function=_test_function,
502-
test_type=test_type,
503-
test_class=_test_suite_name,
546+
test_file=_test_module_path, test_function=_test_function, test_type=test_type, test_class=_test_suite_name
504547
)
505548

506549
for _test_suite in tests._tests:
@@ -518,18 +561,18 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
518561
continue
519562
details = get_test_details(test_2)
520563
if details is not None:
521-
file_to_test_map[str(details.test_file)].append(details)
564+
file_to_test_map[details.test_file].append(details)
522565
else:
523566
details = get_test_details(test)
524567
if details is not None:
525-
file_to_test_map[str(details.test_file)].append(details)
568+
file_to_test_map[details.test_file].append(details)
526569
return process_test_files(file_to_test_map, cfg, functions_to_optimize)
527570

528571

529572
def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]:
530-
function_name = function_name.split("_")
531-
if len(function_name) > 1 and function_name[-1].isdigit():
532-
return True, "_".join(function_name[:-1]), function_name[-1]
573+
function_parts = function_name.split("_")
574+
if len(function_parts) > 1 and function_parts[-1].isdigit():
575+
return True, "_".join(function_parts[:-1]), function_parts[-1]
533576

534577
return False, function_name, None
535578

@@ -538,7 +581,7 @@ def process_test_files(
538581
file_to_test_map: dict[Path, list[TestsInFile]],
539582
cfg: TestConfig,
540583
functions_to_optimize: list[FunctionToOptimize] | None = None,
541-
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
584+
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
542585
import jedi
543586

544587
project_root_path = cfg.project_root_path
@@ -553,29 +596,39 @@ def process_test_files(
553596
num_discovered_replay_tests = 0
554597
jedi_project = jedi.Project(path=project_root_path)
555598

599+
tests_cache = TestsCache(project_root_path)
600+
logger.info("!lsp|Discovering tests and processing unit tests")
556601
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
557602
progress,
558603
task_id,
559604
):
560605
for test_file, functions in file_to_test_map.items():
606+
file_hash = TestsCache.compute_file_hash(test_file)
607+
608+
cached_function_to_test_map = tests_cache.get_function_to_test_map_for_file(str(test_file), file_hash)
609+
610+
if cfg.use_cache and cached_function_to_test_map:
611+
for qualified_name, test_set in cached_function_to_test_map.items():
612+
function_to_test_map[qualified_name].update(test_set)
613+
614+
for function_called_in_test in test_set:
615+
if function_called_in_test.tests_in_file.test_type == TestType.REPLAY_TEST:
616+
num_discovered_replay_tests += 1
617+
num_discovered_tests += 1
618+
619+
progress.advance(task_id)
620+
continue
561621
try:
562622
script = jedi.Script(path=test_file, project=jedi_project)
563623
test_functions = set()
564624

565-
# Single call to get all names with references and definitions
566-
all_names = script.get_names(all_scopes=True, references=True, definitions=True)
625+
all_names = script.get_names(all_scopes=True, references=True)
626+
all_defs = script.get_names(all_scopes=True, definitions=True)
627+
all_names_top = script.get_names(all_scopes=True)
567628

568-
# Filter once and create lookup dictionaries
569-
top_level_functions = {}
570-
top_level_classes = {}
571-
all_defs = []
629+
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
630+
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
572631

573-
for name in all_names:
574-
if name.type == "function":
575-
top_level_functions[name.name] = name
576-
all_defs.append(name)
577-
elif name.type == "class":
578-
top_level_classes[name.name] = name
579632
except Exception as e:
580633
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
581634
progress.advance(task_id)
@@ -697,6 +750,18 @@ def process_test_files(
697750
position=CodePosition(line_no=name.line, col_no=name.column),
698751
)
699752
)
753+
tests_cache.insert_test(
754+
file_path=str(test_file),
755+
file_hash=file_hash,
756+
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
757+
function_name=scope,
758+
test_class=test_func.test_class or "",
759+
test_function=scope_test_function,
760+
test_type=test_func.test_type,
761+
line_number=name.line,
762+
col_number=name.column,
763+
)
764+
700765
if test_func.test_type == TestType.REPLAY_TEST:
701766
num_discovered_replay_tests += 1
702767

@@ -707,4 +772,6 @@ def process_test_files(
707772

708773
progress.advance(task_id)
709774

775+
tests_cache.close()
776+
710777
return dict(function_to_test_map), num_discovered_tests, num_discovered_replay_tests

codeflash/optimization/optimizer.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -239,15 +239,15 @@ def discover_tests(
239239
from codeflash.discovery.discover_unit_tests import discover_unit_tests
240240

241241
console.rule()
242-
with progress_bar("Discovering existing function tests..."):
243-
start_time = time.time()
244-
function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests(
245-
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
246-
)
247-
console.rule()
248-
logger.info(
249-
f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
250-
)
242+
start_time = time.time()
243+
logger.info("lsp,loading|Discovering existing function tests...")
244+
function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests(
245+
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
246+
)
247+
console.rule()
248+
logger.info(
249+
f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
250+
)
251251
console.rule()
252252
ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
253253
return function_to_tests, num_discovered_tests

codeflash/verification/verification_utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -76,3 +76,4 @@ class TestConfig:
7676
concolic_test_root_dir: Optional[Path] = None
7777
pytest_cmd: str = "pytest"
7878
benchmark_tests_root: Optional[Path] = None
79+
use_cache: bool = True

tests/scripts/end_to_end_test_init_optimization.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,4 +21,4 @@ def run_test(expected_improvement_pct: int) -> bool:
2121

2222

2323
if __name__ == "__main__":
24-
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 5))))
24+
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))

0 commit comments

Comments
 (0)