Skip to content

Commit acf5b20

Browse files
Revert "Hide all symbols (except stable/headeronly/shim) if TORCH_STABLE_ONLY is defined (pytorch#167496)"
This reverts commit 8f4dc30. Reverted pytorch#167496 on behalf of https://github.com/atalman due to Failing validations - https://github.com/pytorch/test-infra/actions/runs/19513141127/job/55857898996 ([comment](pytorch#167496 (comment)))
1 parent 2e1821b commit acf5b20

File tree

8 files changed

+94
-388
lines changed

8 files changed

+94
-388
lines changed

.ci/pytorch/smoke_test/check_binary_symbols.py

Lines changed: 0 additions & 338 deletions
Original file line numberDiff line numberDiff line change
@@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
100100
)
101101

102102

103-
def _compile_and_extract_symbols(
104-
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
105-
) -> list[str]:
106-
"""
107-
Helper to compile a C++ file and extract all symbols.
108-
109-
Args:
110-
cpp_content: C++ source code to compile
111-
compile_flags: Compilation flags
112-
exclude_list: List of symbol names to exclude. Defaults to ["main"].
113-
114-
Returns:
115-
List of all symbols found in the object file (excluding those in exclude_list).
116-
"""
117-
import subprocess
118-
import tempfile
119-
120-
if exclude_list is None:
121-
exclude_list = ["main"]
122-
123-
with tempfile.TemporaryDirectory() as tmpdir:
124-
tmppath = Path(tmpdir)
125-
cpp_file = tmppath / "test.cpp"
126-
obj_file = tmppath / "test.o"
127-
128-
cpp_file.write_text(cpp_content)
129-
130-
result = subprocess.run(
131-
compile_flags + [str(cpp_file), "-o", str(obj_file)],
132-
capture_output=True,
133-
text=True,
134-
timeout=60,
135-
)
136-
137-
if result.returncode != 0:
138-
raise RuntimeError(f"Compilation failed: {result.stderr}")
139-
140-
symbols = get_symbols(str(obj_file))
141-
142-
# Return all symbol names, excluding those in the exclude list
143-
return [name for _addr, _stype, name in symbols if name not in exclude_list]
144-
145-
146-
def check_stable_only_symbols(install_root: Path) -> None:
147-
"""
148-
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
149-
150-
This approach tests:
151-
1. WITHOUT macros -> many torch symbols exposed
152-
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
153-
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
154-
4. WITH both macros -> zero torch symbols (all hidden)
155-
"""
156-
include_dir = install_root / "include"
157-
assert include_dir.exists(), f"Expected {include_dir} to be present"
158-
159-
test_cpp_content = """
160-
// Main torch C++ API headers
161-
#include <torch/torch.h>
162-
#include <torch/all.h>
163-
164-
// ATen tensor library
165-
#include <ATen/ATen.h>
166-
167-
// Core c10 headers (commonly used)
168-
#include <c10/core/Device.h>
169-
#include <c10/core/DeviceType.h>
170-
#include <c10/core/ScalarType.h>
171-
#include <c10/core/TensorOptions.h>
172-
#include <c10/util/Optional.h>
173-
174-
int main() { return 0; }
175-
"""
176-
177-
base_compile_flags = [
178-
"g++",
179-
"-std=c++17",
180-
f"-I{include_dir}",
181-
f"-I{include_dir}/torch/csrc/api/include",
182-
"-c", # Compile only, don't link
183-
]
184-
185-
# Compile WITHOUT any macros
186-
symbols_without = _compile_and_extract_symbols(
187-
cpp_content=test_cpp_content,
188-
compile_flags=base_compile_flags,
189-
)
190-
191-
# We expect constexpr symbols, inline functions used by other headers etc.
192-
# to produce symbols
193-
num_symbols_without = len(symbols_without)
194-
print(f"Found {num_symbols_without} symbols without any macros defined")
195-
assert num_symbols_without != 0, (
196-
"Expected a non-zero number of symbols without any macros"
197-
)
198-
199-
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
200-
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
201-
202-
symbols_with_stable_only = _compile_and_extract_symbols(
203-
cpp_content=test_cpp_content,
204-
compile_flags=compile_flags_with_stable_only,
205-
)
206-
207-
num_symbols_with_stable_only = len(symbols_with_stable_only)
208-
assert num_symbols_with_stable_only == 0, (
209-
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
210-
)
211-
212-
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
213-
compile_flags_with_target_version = base_compile_flags + [
214-
"-DTORCH_TARGET_VERSION=1"
215-
]
216-
217-
symbols_with_target_version = _compile_and_extract_symbols(
218-
cpp_content=test_cpp_content,
219-
compile_flags=compile_flags_with_target_version,
220-
)
221-
222-
num_symbols_with_target_version = len(symbols_with_target_version)
223-
assert num_symbols_with_target_version == 0, (
224-
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
225-
)
226-
227-
# Compile WITH both macros (expect 0 symbols)
228-
compile_flags_with_both = base_compile_flags + [
229-
"-DTORCH_STABLE_ONLY",
230-
"-DTORCH_TARGET_VERSION=1",
231-
]
232-
233-
symbols_with_both = _compile_and_extract_symbols(
234-
cpp_content=test_cpp_content,
235-
compile_flags=compile_flags_with_both,
236-
)
237-
238-
num_symbols_with_both = len(symbols_with_both)
239-
assert num_symbols_with_both == 0, (
240-
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
241-
)
242-
243-
244-
def check_stable_api_symbols(install_root: Path) -> None:
245-
"""
246-
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
247-
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
248-
"""
249-
include_dir = install_root / "include"
250-
assert include_dir.exists(), f"Expected {include_dir} to be present"
251-
252-
stable_dir = include_dir / "torch" / "csrc" / "stable"
253-
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
254-
255-
stable_headers = list(stable_dir.rglob("*.h"))
256-
if not stable_headers:
257-
raise RuntimeError("Could not find any stable headers")
258-
259-
includes = []
260-
for header in stable_headers:
261-
rel_path = header.relative_to(include_dir)
262-
includes.append(f"#include <{rel_path.as_posix()}>")
263-
264-
includes_str = "\n".join(includes)
265-
test_stable_content = f"""
266-
{includes_str}
267-
int main() {{ return 0; }}
268-
"""
269-
270-
compile_flags = [
271-
"g++",
272-
"-std=c++17",
273-
f"-I{include_dir}",
274-
f"-I{include_dir}/torch/csrc/api/include",
275-
"-c",
276-
"-DTORCH_STABLE_ONLY",
277-
]
278-
279-
symbols_stable = _compile_and_extract_symbols(
280-
cpp_content=test_stable_content,
281-
compile_flags=compile_flags,
282-
)
283-
num_symbols_stable = len(symbols_stable)
284-
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
285-
assert num_symbols_stable > 0, (
286-
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
287-
f"but found {num_symbols_stable} symbols"
288-
)
289-
290-
291-
def check_headeronly_symbols(install_root: Path) -> None:
292-
"""
293-
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
294-
"""
295-
include_dir = install_root / "include"
296-
assert include_dir.exists(), f"Expected {include_dir} to be present"
297-
298-
# Find all headers in torch/headeronly
299-
headeronly_dir = include_dir / "torch" / "headeronly"
300-
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
301-
headeronly_headers = list(headeronly_dir.rglob("*.h"))
302-
if not headeronly_headers:
303-
raise RuntimeError("Could not find any headeronly headers")
304-
305-
# Filter out platform-specific headers that may not compile everywhere
306-
platform_specific_keywords = [
307-
"cpu/vec",
308-
]
309-
310-
filtered_headers = []
311-
for header in headeronly_headers:
312-
rel_path = header.relative_to(include_dir).as_posix()
313-
if not any(
314-
keyword in rel_path.lower() for keyword in platform_specific_keywords
315-
):
316-
filtered_headers.append(header)
317-
318-
includes = []
319-
for header in filtered_headers:
320-
rel_path = header.relative_to(include_dir)
321-
includes.append(f"#include <{rel_path.as_posix()}>")
322-
323-
includes_str = "\n".join(includes)
324-
test_headeronly_content = f"""
325-
{includes_str}
326-
int main() {{ return 0; }}
327-
"""
328-
329-
compile_flags = [
330-
"g++",
331-
"-std=c++17",
332-
f"-I{include_dir}",
333-
f"-I{include_dir}/torch/csrc/api/include",
334-
"-c",
335-
"-DTORCH_STABLE_ONLY",
336-
]
337-
338-
symbols_headeronly = _compile_and_extract_symbols(
339-
cpp_content=test_headeronly_content,
340-
compile_flags=compile_flags,
341-
)
342-
num_symbols_headeronly = len(symbols_headeronly)
343-
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
344-
assert num_symbols_headeronly > 0, (
345-
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
346-
f"but found {num_symbols_headeronly} symbols"
347-
)
348-
349-
350-
def check_aoti_shim_symbols(install_root: Path) -> None:
351-
"""
352-
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
353-
"""
354-
include_dir = install_root / "include"
355-
assert include_dir.exists(), f"Expected {include_dir} to be present"
356-
357-
# There are no constexpr symbols etc., so we need to actually use functions
358-
# so that some symbols are found.
359-
test_shim_content = """
360-
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
361-
int main() {
362-
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
363-
int32_t (*fp2)() = &aoti_torch_dtype_float32;
364-
(void)fp1; (void)fp2;
365-
return 0;
366-
}
367-
"""
368-
369-
compile_flags = [
370-
"g++",
371-
"-std=c++17",
372-
f"-I{include_dir}",
373-
f"-I{include_dir}/torch/csrc/api/include",
374-
"-c",
375-
"-DTORCH_STABLE_ONLY",
376-
]
377-
378-
symbols_shim = _compile_and_extract_symbols(
379-
cpp_content=test_shim_content,
380-
compile_flags=compile_flags,
381-
)
382-
num_symbols_shim = len(symbols_shim)
383-
assert num_symbols_shim > 0, (
384-
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
385-
f"but found {num_symbols_shim} symbols"
386-
)
387-
388-
389-
def check_stable_c_shim_symbols(install_root: Path) -> None:
390-
"""
391-
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
392-
"""
393-
include_dir = install_root / "include"
394-
assert include_dir.exists(), f"Expected {include_dir} to be present"
395-
396-
# Check if the stable C shim exists
397-
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
398-
if not stable_shim.exists():
399-
raise RuntimeError("Could not find stable c shim")
400-
401-
# There are no constexpr symbols etc., so we need to actually use functions
402-
# so that some symbols are found.
403-
test_stable_shim_content = """
404-
#include <torch/csrc/stable/c/shim.h>
405-
int main() {
406-
// Reference stable C API functions to create undefined symbols
407-
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
408-
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
409-
(void)fp1; (void)fp2;
410-
return 0;
411-
}
412-
"""
413-
414-
compile_flags = [
415-
"g++",
416-
"-std=c++17",
417-
f"-I{include_dir}",
418-
f"-I{include_dir}/torch/csrc/api/include",
419-
"-c",
420-
"-DTORCH_STABLE_ONLY",
421-
]
422-
423-
symbols_stable_shim = _compile_and_extract_symbols(
424-
cpp_content=test_stable_shim_content,
425-
compile_flags=compile_flags,
426-
)
427-
num_symbols_stable_shim = len(symbols_stable_shim)
428-
assert num_symbols_stable_shim > 0, (
429-
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
430-
f"but found {num_symbols_stable_shim} symbols"
431-
)
432-
433-
434103
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
435104
print(f"lib: {lib}")
436105
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
@@ -460,13 +129,6 @@ def main() -> None:
460129
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
461130
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
462131

463-
# Check symbols when TORCH_STABLE_ONLY is defined
464-
check_stable_only_symbols(install_root)
465-
check_stable_api_symbols(install_root)
466-
check_headeronly_symbols(install_root)
467-
check_aoti_shim_symbols(install_root)
468-
check_stable_c_shim_symbols(install_root)
469-
470132

471133
if __name__ == "__main__":
472134
main()

0 commit comments

Comments
 (0)