@@ -67,13 +67,39 @@ class TestFunction:
6767
6868
6969class TestsCache :
70- def __init__ (self ) -> None :
70+ SCHEMA_VERSION = 1 # Increment this when schema changes
71+
72+ def __init__ (self , project_root_path : str | Path ) -> None :
73+ self .project_root_path = Path (project_root_path ).resolve ().as_posix ()
7174 self .connection = sqlite3 .connect (codeflash_cache_db )
7275 self .cur = self .connection .cursor ()
7376
77+ self .cur .execute (
78+ """
79+ CREATE TABLE IF NOT EXISTS schema_version(
80+ version INTEGER PRIMARY KEY
81+ )
82+ """
83+ )
84+
85+ self .cur .execute ("SELECT version FROM schema_version" )
86+ result = self .cur .fetchone ()
87+ current_version = result [0 ] if result else None
88+
89+ if current_version != self .SCHEMA_VERSION :
90+ logger .debug (
91+ f"Schema version mismatch (current: { current_version } , expected: { self .SCHEMA_VERSION } ). Recreating tables."
92+ )
93+ self .cur .execute ("DROP TABLE IF EXISTS discovered_tests" )
94+ self .cur .execute ("DROP INDEX IF EXISTS idx_discovered_tests_project_file_path_hash" )
95+ self .cur .execute ("DELETE FROM schema_version" )
96+ self .cur .execute ("INSERT INTO schema_version (version) VALUES (?)" , (self .SCHEMA_VERSION ,))
97+ self .connection .commit ()
98+
7499 self .cur .execute (
75100 """
76101 CREATE TABLE IF NOT EXISTS discovered_tests(
102+ project_root_path TEXT,
77103 file_path TEXT,
78104 file_hash TEXT,
79105 qualified_name_with_modules_from_root TEXT,
@@ -88,11 +114,12 @@ def __init__(self) -> None:
88114 )
89115 self .cur .execute (
90116 """
91- CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
92- ON discovered_tests (file_path, file_hash)
117+ CREATE INDEX IF NOT EXISTS idx_discovered_tests_project_file_path_hash
118+ ON discovered_tests (project_root_path, file_path, file_hash)
93119 """
94120 )
95- self ._memory_cache = {}
121+
122+ self .memory_cache = {}
96123
97124 def insert_test (
98125 self ,
@@ -108,8 +135,9 @@ def insert_test(
108135 ) -> None :
109136 test_type_value = test_type .value if hasattr (test_type , "value" ) else test_type
110137 self .cur .execute (
111- "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
138+ "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ? )" ,
112139 (
140+ self .project_root_path ,
113141 file_path ,
114142 file_hash ,
115143 qualified_name_with_modules_from_root ,
@@ -123,32 +151,48 @@ def insert_test(
123151 )
124152 self .connection .commit ()
125153
126- def get_tests_for_file (self , file_path : str , file_hash : str ) -> list [FunctionCalledInTest ]:
127- cache_key = (file_path , file_hash )
128- if cache_key in self ._memory_cache :
129- return self ._memory_cache [cache_key ]
130- self .cur .execute ("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?" , (file_path , file_hash ))
131- result = [
132- FunctionCalledInTest (
154+ def get_function_to_test_map_for_file (
155+ self , file_path : str , file_hash : str
156+ ) -> dict [str , set [FunctionCalledInTest ]] | None :
157+ cache_key = (self .project_root_path , file_path , file_hash )
158+ if cache_key in self .memory_cache :
159+ return self .memory_cache [cache_key ]
160+
161+ self .cur .execute (
162+ "SELECT * FROM discovered_tests WHERE project_root_path = ? AND file_path = ? AND file_hash = ?" ,
163+ (self .project_root_path , file_path , file_hash ),
164+ )
165+ rows = self .cur .fetchall ()
166+ if not rows :
167+ return None
168+
169+ function_to_test_map = defaultdict (set )
170+
171+ for row in rows :
172+ qualified_name_with_modules_from_root = row [3 ]
173+ function_called_in_test = FunctionCalledInTest (
133174 tests_in_file = TestsInFile (
134- test_file = Path (row [0 ]), test_class = row [4 ], test_function = row [5 ], test_type = TestType (int (row [6 ]))
175+ test_file = Path (row [1 ]), test_class = row [5 ], test_function = row [6 ], test_type = TestType (int (row [7 ]))
135176 ),
136- position = CodePosition (line_no = row [7 ], col_no = row [8 ]),
177+ position = CodePosition (line_no = row [8 ], col_no = row [9 ]),
137178 )
138- for row in self .cur .fetchall ()
139- ]
140- self ._memory_cache [cache_key ] = result
179+ function_to_test_map [qualified_name_with_modules_from_root ].add (function_called_in_test )
180+
181+ result = dict (function_to_test_map )
182+ self .memory_cache [cache_key ] = result
141183 return result
142184
143185 @staticmethod
144- def compute_file_hash (path : str ) -> str :
186+ def compute_file_hash (path : Path ) -> str :
145187 h = hashlib .sha256 (usedforsecurity = False )
146- with Path (path ).open ("rb" ) as f :
188+ with path .open ("rb" , buffering = 0 ) as f :
189+ buf = bytearray (8192 )
190+ mv = memoryview (buf )
147191 while True :
148- chunk = f .read ( 8192 )
149- if not chunk :
192+ n = f .readinto ( mv )
193+ if n == 0 :
150194 break
151- h .update (chunk )
195+ h .update (mv [: n ] )
152196 return h .hexdigest ()
153197
154198 def close (self ) -> None :
@@ -394,7 +438,7 @@ def discover_tests_pytest(
394438 cfg : TestConfig ,
395439 discover_only_these_tests : list [Path ] | None = None ,
396440 functions_to_optimize : list [FunctionToOptimize ] | None = None ,
397- ) -> tuple [dict [str , set [FunctionCalledInTest ]], int ]:
441+ ) -> tuple [dict [str , set [FunctionCalledInTest ]], int , int ]:
398442 tests_root = cfg .tests_root
399443 project_root = cfg .project_root_path
400444
@@ -432,9 +476,11 @@ def discover_tests_pytest(
432476 f"Failed to collect tests. Pytest Exit code: { exitcode } ={ PytestExitCode (exitcode ).name } \n { error_section } "
433477 )
434478 if "ModuleNotFoundError" in result .stdout :
435- match = ImportErrorPattern .search (result .stdout ).group ()
436- panel = Panel (Text .from_markup (f"⚠️ { match } " , style = "bold red" ), expand = False )
437- console .print (panel )
479+ match = ImportErrorPattern .search (result .stdout )
480+ if match :
481+ error_message = match .group ()
482+ panel = Panel (Text .from_markup (f"⚠️ { error_message } " , style = "bold red" ), expand = False )
483+ console .print (panel )
438484
439485 elif 0 <= exitcode <= 5 :
440486 logger .warning (f"Failed to collect tests. Pytest Exit code: { exitcode } ={ PytestExitCode (exitcode ).name } " )
@@ -469,13 +515,13 @@ def discover_tests_pytest(
469515
470516def discover_tests_unittest (
471517 cfg : TestConfig ,
472- discover_only_these_tests : list [str ] | None = None ,
518+ discover_only_these_tests : list [Path ] | None = None ,
473519 functions_to_optimize : list [FunctionToOptimize ] | None = None ,
474- ) -> tuple [dict [str , set [FunctionCalledInTest ]], int ]:
520+ ) -> tuple [dict [str , set [FunctionCalledInTest ]], int , int ]:
475521 tests_root : Path = cfg .tests_root
476522 loader : unittest .TestLoader = unittest .TestLoader ()
477523 tests : unittest .TestSuite = loader .discover (str (tests_root ))
478- file_to_test_map : defaultdict [str , list [TestsInFile ]] = defaultdict (list )
524+ file_to_test_map : defaultdict [Path , list [TestsInFile ]] = defaultdict (list )
479525
480526 def get_test_details (_test : unittest .TestCase ) -> TestsInFile | None :
481527 _test_function , _test_module , _test_suite_name = (
@@ -487,7 +533,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
487533 _test_module_path = Path (_test_module .replace ("." , os .sep )).with_suffix (".py" )
488534 _test_module_path = tests_root / _test_module_path
489535 if not _test_module_path .exists () or (
490- discover_only_these_tests and str ( _test_module_path ) not in discover_only_these_tests
536+ discover_only_these_tests and _test_module_path not in discover_only_these_tests
491537 ):
492538 return None
493539 if "__replay_test" in str (_test_module_path ):
@@ -497,10 +543,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
497543 else :
498544 test_type = TestType .EXISTING_UNIT_TEST
499545 return TestsInFile (
500- test_file = str (_test_module_path ),
501- test_function = _test_function ,
502- test_type = test_type ,
503- test_class = _test_suite_name ,
546+ test_file = _test_module_path , test_function = _test_function , test_type = test_type , test_class = _test_suite_name
504547 )
505548
506549 for _test_suite in tests ._tests :
@@ -518,18 +561,18 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
518561 continue
519562 details = get_test_details (test_2 )
520563 if details is not None :
521- file_to_test_map [str ( details .test_file ) ].append (details )
564+ file_to_test_map [details .test_file ].append (details )
522565 else :
523566 details = get_test_details (test )
524567 if details is not None :
525- file_to_test_map [str ( details .test_file ) ].append (details )
568+ file_to_test_map [details .test_file ].append (details )
526569 return process_test_files (file_to_test_map , cfg , functions_to_optimize )
527570
528571
529572def discover_parameters_unittest (function_name : str ) -> tuple [bool , str , str | None ]:
530- function_name = function_name .split ("_" )
531- if len (function_name ) > 1 and function_name [- 1 ].isdigit ():
532- return True , "_" .join (function_name [:- 1 ]), function_name [- 1 ]
573+ function_parts = function_name .split ("_" )
574+ if len (function_parts ) > 1 and function_parts [- 1 ].isdigit ():
575+ return True , "_" .join (function_parts [:- 1 ]), function_parts [- 1 ]
533576
534577 return False , function_name , None
535578
@@ -538,7 +581,7 @@ def process_test_files(
538581 file_to_test_map : dict [Path , list [TestsInFile ]],
539582 cfg : TestConfig ,
540583 functions_to_optimize : list [FunctionToOptimize ] | None = None ,
541- ) -> tuple [dict [str , set [FunctionCalledInTest ]], int ]:
584+ ) -> tuple [dict [str , set [FunctionCalledInTest ]], int , int ]:
542585 import jedi
543586
544587 project_root_path = cfg .project_root_path
@@ -553,29 +596,39 @@ def process_test_files(
553596 num_discovered_replay_tests = 0
554597 jedi_project = jedi .Project (path = project_root_path )
555598
599+ tests_cache = TestsCache (project_root_path )
600+ logger .info ("!lsp|Discovering tests and processing unit tests" )
556601 with test_files_progress_bar (total = len (file_to_test_map ), description = "Processing test files" ) as (
557602 progress ,
558603 task_id ,
559604 ):
560605 for test_file , functions in file_to_test_map .items ():
606+ file_hash = TestsCache .compute_file_hash (test_file )
607+
608+ cached_function_to_test_map = tests_cache .get_function_to_test_map_for_file (str (test_file ), file_hash )
609+
610+ if cfg .use_cache and cached_function_to_test_map :
611+ for qualified_name , test_set in cached_function_to_test_map .items ():
612+ function_to_test_map [qualified_name ].update (test_set )
613+
614+ for function_called_in_test in test_set :
615+ if function_called_in_test .tests_in_file .test_type == TestType .REPLAY_TEST :
616+ num_discovered_replay_tests += 1
617+ num_discovered_tests += 1
618+
619+ progress .advance (task_id )
620+ continue
561621 try :
562622 script = jedi .Script (path = test_file , project = jedi_project )
563623 test_functions = set ()
564624
565- # Single call to get all names with references and definitions
566- all_names = script .get_names (all_scopes = True , references = True , definitions = True )
625+ all_names = script .get_names (all_scopes = True , references = True )
626+ all_defs = script .get_names (all_scopes = True , definitions = True )
627+ all_names_top = script .get_names (all_scopes = True )
567628
568- # Filter once and create lookup dictionaries
569- top_level_functions = {}
570- top_level_classes = {}
571- all_defs = []
629+ top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
630+ top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
572631
573- for name in all_names :
574- if name .type == "function" :
575- top_level_functions [name .name ] = name
576- all_defs .append (name )
577- elif name .type == "class" :
578- top_level_classes [name .name ] = name
579632 except Exception as e :
580633 logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
581634 progress .advance (task_id )
@@ -697,6 +750,18 @@ def process_test_files(
697750 position = CodePosition (line_no = name .line , col_no = name .column ),
698751 )
699752 )
753+ tests_cache .insert_test (
754+ file_path = str (test_file ),
755+ file_hash = file_hash ,
756+ qualified_name_with_modules_from_root = qualified_name_with_modules_from_root ,
757+ function_name = scope ,
758+ test_class = test_func .test_class or "" ,
759+ test_function = scope_test_function ,
760+ test_type = test_func .test_type ,
761+ line_number = name .line ,
762+ col_number = name .column ,
763+ )
764+
700765 if test_func .test_type == TestType .REPLAY_TEST :
701766 num_discovered_replay_tests += 1
702767
@@ -707,4 +772,6 @@ def process_test_files(
707772
708773 progress .advance (task_id )
709774
775+ tests_cache .close ()
776+
710777 return dict (function_to_test_map ), num_discovered_tests , num_discovered_replay_tests
0 commit comments