@@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
100100 )
101101
102102
103- def _compile_and_extract_symbols (
104- cpp_content : str , compile_flags : list [str ], exclude_list : list [str ] | None = None
105- ) -> list [str ]:
106- """
107- Helper to compile a C++ file and extract all symbols.
108-
109- Args:
110- cpp_content: C++ source code to compile
111- compile_flags: Compilation flags
112- exclude_list: List of symbol names to exclude. Defaults to ["main"].
113-
114- Returns:
115- List of all symbols found in the object file (excluding those in exclude_list).
116- """
117- import subprocess
118- import tempfile
119-
120- if exclude_list is None :
121- exclude_list = ["main" ]
122-
123- with tempfile .TemporaryDirectory () as tmpdir :
124- tmppath = Path (tmpdir )
125- cpp_file = tmppath / "test.cpp"
126- obj_file = tmppath / "test.o"
127-
128- cpp_file .write_text (cpp_content )
129-
130- result = subprocess .run (
131- compile_flags + [str (cpp_file ), "-o" , str (obj_file )],
132- capture_output = True ,
133- text = True ,
134- timeout = 60 ,
135- )
136-
137- if result .returncode != 0 :
138- raise RuntimeError (f"Compilation failed: { result .stderr } " )
139-
140- symbols = get_symbols (str (obj_file ))
141-
142- # Return all symbol names, excluding those in the exclude list
143- return [name for _addr , _stype , name in symbols if name not in exclude_list ]
144-
145-
146- def check_stable_only_symbols (install_root : Path ) -> None :
147- """
148- Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
149-
150- This approach tests:
151- 1. WITHOUT macros -> many torch symbols exposed
152- 2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
153- 3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
154- 4. WITH both macros -> zero torch symbols (all hidden)
155- """
156- include_dir = install_root / "include"
157- assert include_dir .exists (), f"Expected { include_dir } to be present"
158-
159- test_cpp_content = """
160- // Main torch C++ API headers
161- #include <torch/torch.h>
162- #include <torch/all.h>
163-
164- // ATen tensor library
165- #include <ATen/ATen.h>
166-
167- // Core c10 headers (commonly used)
168- #include <c10/core/Device.h>
169- #include <c10/core/DeviceType.h>
170- #include <c10/core/ScalarType.h>
171- #include <c10/core/TensorOptions.h>
172- #include <c10/util/Optional.h>
173-
174- int main() { return 0; }
175- """
176-
177- base_compile_flags = [
178- "g++" ,
179- "-std=c++17" ,
180- f"-I{ include_dir } " ,
181- f"-I{ include_dir } /torch/csrc/api/include" ,
182- "-c" , # Compile only, don't link
183- ]
184-
185- # Compile WITHOUT any macros
186- symbols_without = _compile_and_extract_symbols (
187- cpp_content = test_cpp_content ,
188- compile_flags = base_compile_flags ,
189- )
190-
191- # We expect constexpr symbols, inline functions used by other headers etc.
192- # to produce symbols
193- num_symbols_without = len (symbols_without )
194- print (f"Found { num_symbols_without } symbols without any macros defined" )
195- assert num_symbols_without != 0 , (
196- "Expected a non-zero number of symbols without any macros"
197- )
198-
199- # Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
200- compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY" ]
201-
202- symbols_with_stable_only = _compile_and_extract_symbols (
203- cpp_content = test_cpp_content ,
204- compile_flags = compile_flags_with_stable_only ,
205- )
206-
207- num_symbols_with_stable_only = len (symbols_with_stable_only )
208- assert num_symbols_with_stable_only == 0 , (
209- f"Expected no symbols with TORCH_STABLE_ONLY macro, but found { num_symbols_with_stable_only } "
210- )
211-
212- # Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
213- compile_flags_with_target_version = base_compile_flags + [
214- "-DTORCH_TARGET_VERSION=1"
215- ]
216-
217- symbols_with_target_version = _compile_and_extract_symbols (
218- cpp_content = test_cpp_content ,
219- compile_flags = compile_flags_with_target_version ,
220- )
221-
222- num_symbols_with_target_version = len (symbols_with_target_version )
223- assert num_symbols_with_target_version == 0 , (
224- f"Expected no symbols with TORCH_TARGET_VERSION macro, but found { num_symbols_with_target_version } "
225- )
226-
227- # Compile WITH both macros (expect 0 symbols)
228- compile_flags_with_both = base_compile_flags + [
229- "-DTORCH_STABLE_ONLY" ,
230- "-DTORCH_TARGET_VERSION=1" ,
231- ]
232-
233- symbols_with_both = _compile_and_extract_symbols (
234- cpp_content = test_cpp_content ,
235- compile_flags = compile_flags_with_both ,
236- )
237-
238- num_symbols_with_both = len (symbols_with_both )
239- assert num_symbols_with_both == 0 , (
240- f"Expected no symbols with both macros, but found { num_symbols_with_both } "
241- )
242-
243-
244- def check_stable_api_symbols (install_root : Path ) -> None :
245- """
246- Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
247- The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
248- """
249- include_dir = install_root / "include"
250- assert include_dir .exists (), f"Expected { include_dir } to be present"
251-
252- stable_dir = include_dir / "torch" / "csrc" / "stable"
253- assert stable_dir .exists (), f"Expected { stable_dir } to be present"
254-
255- stable_headers = list (stable_dir .rglob ("*.h" ))
256- if not stable_headers :
257- raise RuntimeError ("Could not find any stable headers" )
258-
259- includes = []
260- for header in stable_headers :
261- rel_path = header .relative_to (include_dir )
262- includes .append (f"#include <{ rel_path .as_posix ()} >" )
263-
264- includes_str = "\n " .join (includes )
265- test_stable_content = f"""
266- { includes_str }
267- int main() {{ return 0; }}
268- """
269-
270- compile_flags = [
271- "g++" ,
272- "-std=c++17" ,
273- f"-I{ include_dir } " ,
274- f"-I{ include_dir } /torch/csrc/api/include" ,
275- "-c" ,
276- "-DTORCH_STABLE_ONLY" ,
277- ]
278-
279- symbols_stable = _compile_and_extract_symbols (
280- cpp_content = test_stable_content ,
281- compile_flags = compile_flags ,
282- )
283- num_symbols_stable = len (symbols_stable )
284- print (f"Found { num_symbols_stable } symbols in torch/csrc/stable" )
285- assert num_symbols_stable > 0 , (
286- f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
287- f"but found { num_symbols_stable } symbols"
288- )
289-
290-
291- def check_headeronly_symbols (install_root : Path ) -> None :
292- """
293- Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
294- """
295- include_dir = install_root / "include"
296- assert include_dir .exists (), f"Expected { include_dir } to be present"
297-
298- # Find all headers in torch/headeronly
299- headeronly_dir = include_dir / "torch" / "headeronly"
300- assert headeronly_dir .exists (), f"Expected { headeronly_dir } to be present"
301- headeronly_headers = list (headeronly_dir .rglob ("*.h" ))
302- if not headeronly_headers :
303- raise RuntimeError ("Could not find any headeronly headers" )
304-
305- # Filter out platform-specific headers that may not compile everywhere
306- platform_specific_keywords = [
307- "cpu/vec" ,
308- ]
309-
310- filtered_headers = []
311- for header in headeronly_headers :
312- rel_path = header .relative_to (include_dir ).as_posix ()
313- if not any (
314- keyword in rel_path .lower () for keyword in platform_specific_keywords
315- ):
316- filtered_headers .append (header )
317-
318- includes = []
319- for header in filtered_headers :
320- rel_path = header .relative_to (include_dir )
321- includes .append (f"#include <{ rel_path .as_posix ()} >" )
322-
323- includes_str = "\n " .join (includes )
324- test_headeronly_content = f"""
325- { includes_str }
326- int main() {{ return 0; }}
327- """
328-
329- compile_flags = [
330- "g++" ,
331- "-std=c++17" ,
332- f"-I{ include_dir } " ,
333- f"-I{ include_dir } /torch/csrc/api/include" ,
334- "-c" ,
335- "-DTORCH_STABLE_ONLY" ,
336- ]
337-
338- symbols_headeronly = _compile_and_extract_symbols (
339- cpp_content = test_headeronly_content ,
340- compile_flags = compile_flags ,
341- )
342- num_symbols_headeronly = len (symbols_headeronly )
343- print (f"Found { num_symbols_headeronly } symbols in torch/headeronly" )
344- assert num_symbols_headeronly > 0 , (
345- f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
346- f"but found { num_symbols_headeronly } symbols"
347- )
348-
349-
350- def check_aoti_shim_symbols (install_root : Path ) -> None :
351- """
352- Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
353- """
354- include_dir = install_root / "include"
355- assert include_dir .exists (), f"Expected { include_dir } to be present"
356-
357- # There are no constexpr symbols etc., so we need to actually use functions
358- # so that some symbols are found.
359- test_shim_content = """
360- #include <torch/csrc/inductor/aoti_torch/c/shim.h>
361- int main() {
362- int32_t (*fp1)() = &aoti_torch_device_type_cpu;
363- int32_t (*fp2)() = &aoti_torch_dtype_float32;
364- (void)fp1; (void)fp2;
365- return 0;
366- }
367- """
368-
369- compile_flags = [
370- "g++" ,
371- "-std=c++17" ,
372- f"-I{ include_dir } " ,
373- f"-I{ include_dir } /torch/csrc/api/include" ,
374- "-c" ,
375- "-DTORCH_STABLE_ONLY" ,
376- ]
377-
378- symbols_shim = _compile_and_extract_symbols (
379- cpp_content = test_shim_content ,
380- compile_flags = compile_flags ,
381- )
382- num_symbols_shim = len (symbols_shim )
383- assert num_symbols_shim > 0 , (
384- f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
385- f"but found { num_symbols_shim } symbols"
386- )
387-
388-
389- def check_stable_c_shim_symbols (install_root : Path ) -> None :
390- """
391- Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
392- """
393- include_dir = install_root / "include"
394- assert include_dir .exists (), f"Expected { include_dir } to be present"
395-
396- # Check if the stable C shim exists
397- stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
398- if not stable_shim .exists ():
399- raise RuntimeError ("Could not find stable c shim" )
400-
401- # There are no constexpr symbols etc., so we need to actually use functions
402- # so that some symbols are found.
403- test_stable_shim_content = """
404- #include <torch/csrc/stable/c/shim.h>
405- int main() {
406- // Reference stable C API functions to create undefined symbols
407- AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
408- AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
409- (void)fp1; (void)fp2;
410- return 0;
411- }
412- """
413-
414- compile_flags = [
415- "g++" ,
416- "-std=c++17" ,
417- f"-I{ include_dir } " ,
418- f"-I{ include_dir } /torch/csrc/api/include" ,
419- "-c" ,
420- "-DTORCH_STABLE_ONLY" ,
421- ]
422-
423- symbols_stable_shim = _compile_and_extract_symbols (
424- cpp_content = test_stable_shim_content ,
425- compile_flags = compile_flags ,
426- )
427- num_symbols_stable_shim = len (symbols_stable_shim )
428- assert num_symbols_stable_shim > 0 , (
429- f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
430- f"but found { num_symbols_stable_shim } symbols"
431- )
432-
433-
434103def check_lib_symbols_for_abi_correctness (lib : str ) -> None :
435104 print (f"lib: { lib } " )
436105 cxx11_symbols = grep_symbols (lib , LIBTORCH_CXX11_PATTERNS )
@@ -460,13 +129,6 @@ def main() -> None:
460129 check_lib_symbols_for_abi_correctness (libtorch_cpu_path )
461130 check_lib_statically_linked_libstdc_cxx_abi_symbols (libtorch_cpu_path )
462131
463- # Check symbols when TORCH_STABLE_ONLY is defined
464- check_stable_only_symbols (install_root )
465- check_stable_api_symbols (install_root )
466- check_headeronly_symbols (install_root )
467- check_aoti_shim_symbols (install_root )
468- check_stable_c_shim_symbols (install_root )
469-
470132
471133if __name__ == "__main__" :
472134 main ()
0 commit comments