Skip to content

Commit b825e1d

Browse files
suopytorchmergebot
authored andcommitted
Revert autoformat of tools/fast_nvcc/fast_nvcc.py
This was an Meta-internal change that seems to have deleted a bunch of types and is thus causing us to fail mypy type checking. Reverting that portion of the change. Pull Request resolved: pytorch#77327 Approved by: https://github.com/qihqi
1 parent 257c55f commit b825e1d

File tree

1 file changed

+94
-38
lines changed

1 file changed

+94
-38
lines changed

tools/fast_nvcc/fast_nvcc.py

Lines changed: 94 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
import subprocess
1515
import sys
1616
import time
17+
from typing import Awaitable, DefaultDict, Dict, List, Match, Optional, Set, cast
1718

19+
from typing_extensions import TypedDict
1820

1921
help_msg = """fast_nvcc [OPTION]... -- [NVCC_ARG]...
2022
@@ -37,7 +39,7 @@
3739
)
3840
parser.add_argument(
3941
"--graph",
40-
metavar="FILE.dot",
42+
metavar="FILE.gv",
4143
help="write Graphviz DOT file with execution graph",
4244
)
4345
parser.add_argument(
@@ -78,14 +80,14 @@
7880
re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)"
7981

8082

81-
def fast_nvcc_warn(warning):
83+
def fast_nvcc_warn(warning: str) -> None:
8284
"""
8385
Warn the user about something regarding fast_nvcc.
8486
"""
8587
print(f"warning (fast_nvcc): {warning}", file=sys.stderr)
8688

8789

88-
def warn_if_windows():
90+
def warn_if_windows() -> None:
8991
"""
9092
Warn the user that using fast_nvcc on Windows might not work.
9193
"""
@@ -97,7 +99,7 @@ def warn_if_windows():
9799
fast_nvcc_warn(url_vars)
98100

99101

100-
def warn_if_tmpdir_flag(args):
102+
def warn_if_tmpdir_flag(args: List[str]) -> None:
101103
"""
102104
Warn the user that using fast_nvcc with some flags might not work.
103105
"""
@@ -121,11 +123,17 @@ def warn_if_tmpdir_flag(args):
121123
fast_nvcc_warn(f"{url_base}#{frag}")
122124

123125

124-
def nvcc_dryrun_data(binary, args):
126+
class DryunData(TypedDict):
127+
env: Dict[str, str]
128+
commands: List[str]
129+
exit_code: int
130+
131+
132+
def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData:
125133
"""
126134
Return parsed environment variables and commands from nvcc --dryrun.
127135
"""
128-
result = subprocess.run(
136+
result = subprocess.run( # type: ignore[call-overload]
129137
[binary, "--dryrun"] + args,
130138
capture_output=True,
131139
encoding="ascii", # this is just a guess
@@ -148,7 +156,7 @@ def nvcc_dryrun_data(binary, args):
148156
return {"env": env, "commands": commands, "exit_code": result.returncode}
149157

150158

151-
def warn_if_tmpdir_set(env):
159+
def warn_if_tmpdir_set(env: Dict[str, str]) -> None:
152160
"""
153161
Warn the user that setting TMPDIR with fast_nvcc might not work.
154162
"""
@@ -157,7 +165,7 @@ def warn_if_tmpdir_set(env):
157165
fast_nvcc_warn(url_vars)
158166

159167

160-
def contains_non_executable(commands):
168+
def contains_non_executable(commands: List[str]) -> bool:
161169
for command in commands:
162170
# This is to deal with special command dry-run result from NVCC such as:
163171
# ```
@@ -170,7 +178,7 @@ def contains_non_executable(commands):
170178
return False
171179

172180

173-
def module_id_contents(command):
181+
def module_id_contents(command: List[str]) -> str:
174182
"""
175183
Guess the contents of the .module_id file contained within command.
176184
"""
@@ -187,7 +195,7 @@ def module_id_contents(command):
187195
return f"_{len(middle)}_{middle}_{suffix}"
188196

189197

190-
def unique_module_id_files(commands):
198+
def unique_module_id_files(commands: List[str]) -> List[str]:
191199
"""
192200
Give each command its own .module_id filename instead of sharing.
193201
"""
@@ -196,7 +204,7 @@ def unique_module_id_files(commands):
196204
for i, line in enumerate(commands):
197205
arr = []
198206

199-
def uniqueify(s):
207+
def uniqueify(s: Match[str]) -> str:
200208
filename = re.sub(r"\-(\d+)", r"-\1-" + str(i), s.group(0))
201209
arr.append(filename)
202210
return filename
@@ -212,14 +220,19 @@ def uniqueify(s):
212220
return uniqueified
213221

214222

215-
def make_rm_force(commands):
223+
def make_rm_force(commands: List[str]) -> List[str]:
216224
"""
217225
Add --force to all rm commands.
218226
"""
219227
return [f"{c} --force" if c.startswith("rm ") else c for c in commands]
220228

221229

222-
def print_verbose_output(*, env, commands, filename):
230+
def print_verbose_output(
231+
*,
232+
env: Dict[str, str],
233+
commands: List[List[str]],
234+
filename: str,
235+
) -> None:
223236
"""
224237
Human-readably write nvcc --dryrun data to stderr.
225238
"""
@@ -234,21 +247,24 @@ def print_verbose_output(*, env, commands, filename):
234247
print(f'#{" "*len(prefix)}{part}', file=f)
235248

236249

237-
def straight_line_dependencies(commands):
250+
Graph = List[Set[int]]
251+
252+
253+
def straight_line_dependencies(commands: List[str]) -> Graph:
238254
"""
239255
Return a straight-line dependency graph.
240256
"""
241257
return [({i - 1} if i > 0 else set()) for i in range(len(commands))]
242258

243259

244-
def files_mentioned(command):
260+
def files_mentioned(command: str) -> List[str]:
245261
"""
246262
Return fully-qualified names of all tmp files referenced by command.
247263
"""
248264
return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)]
249265

250266

251-
def nvcc_data_dependencies(commands):
267+
def nvcc_data_dependencies(commands: List[str]) -> Graph:
252268
"""
253269
Return a list of the set of dependencies for each command.
254270
"""
@@ -261,8 +277,8 @@ def nvcc_data_dependencies(commands):
261277
# data dependency is sort of flipped, because the steps that use the
262278
# files generated by cicc need to wait for the fatbinary step to
263279
# finish first
264-
tmp_files = {}
265-
fatbins = collections.defaultdict(set)
280+
tmp_files: Dict[str, int] = {}
281+
fatbins: DefaultDict[int, Set[str]] = collections.defaultdict(set)
266282
graph = []
267283
for i, line in enumerate(commands):
268284
deps = set()
@@ -284,13 +300,13 @@ def nvcc_data_dependencies(commands):
284300
return graph
285301

286302

287-
def is_weakly_connected(graph):
303+
def is_weakly_connected(graph: Graph) -> bool:
288304
"""
289305
Return true iff graph is weakly connected.
290306
"""
291307
if not graph:
292308
return True
293-
neighbors = [set() for _ in graph]
309+
neighbors: List[Set[int]] = [set() for _ in graph]
294310
for node, predecessors in enumerate(graph):
295311
for pred in predecessors:
296312
neighbors[pred].add(node)
@@ -307,20 +323,25 @@ def is_weakly_connected(graph):
307323
return len(found) == len(graph)
308324

309325

310-
def warn_if_not_weakly_connected(graph):
326+
def warn_if_not_weakly_connected(graph: Graph) -> None:
311327
"""
312328
Warn the user if the execution graph is not weakly connected.
313329
"""
314330
if not is_weakly_connected(graph):
315331
fast_nvcc_warn("execution graph is not (weakly) connected")
316332

317333

318-
def print_dot_graph(*, commands, graph, filename):
334+
def print_dot_graph(
335+
*,
336+
commands: List[List[str]],
337+
graph: Graph,
338+
filename: str,
339+
) -> None:
319340
"""
320341
Print a DOT file displaying short versions of the commands in graph.
321342
"""
322343

323-
def name(k):
344+
def name(k: int) -> str:
324345
return f'"{k} {os.path.basename(commands[k][0])}"'
325346

326347
with open(filename, "w") as f:
@@ -334,7 +355,23 @@ def name(k):
334355
print("}", file=f)
335356

336357

337-
async def run_command(command, *, env, deps, gather_data, i, save):
358+
class Result(TypedDict, total=False):
359+
exit_code: int
360+
stdout: bytes
361+
stderr: bytes
362+
time: float
363+
files: Dict[str, int]
364+
365+
366+
async def run_command(
367+
command: str,
368+
*,
369+
env: Dict[str, str],
370+
deps: Set[Awaitable[Result]],
371+
gather_data: bool,
372+
i: int,
373+
save: Optional[str],
374+
) -> Result:
338375
"""
339376
Run the command with the given env after waiting for deps.
340377
"""
@@ -352,8 +389,8 @@ async def run_command(command, *, env, deps, gather_data, i, save):
352389
stderr=asyncio.subprocess.PIPE,
353390
)
354391
stdout, stderr = await proc.communicate()
355-
code = proc.returncode
356-
results = {"exit_code": code, "stdout": stdout, "stderr": stderr}
392+
code = cast(int, proc.returncode)
393+
results: Result = {"exit_code": code, "stdout": stdout, "stderr": stderr}
357394
if gather_data:
358395
t2 = time.monotonic()
359396
results["time"] = t2 - t1
@@ -373,16 +410,23 @@ async def run_command(command, *, env, deps, gather_data, i, save):
373410
return results
374411

375412

376-
async def run_graph(*, env, commands, graph, gather_data=False, save=None):
413+
async def run_graph(
414+
*,
415+
env: Dict[str, str],
416+
commands: List[str],
417+
graph: Graph,
418+
gather_data: bool = False,
419+
save: Optional[str] = None,
420+
) -> List[Result]:
377421
"""
378422
Return outputs/errors (and optionally time/file info) from commands.
379423
"""
380-
tasks = []
424+
tasks: List[Awaitable[Result]] = []
381425
for i, (command, indices) in enumerate(zip(commands, graph)):
382426
deps = {tasks[j] for j in indices}
383427
tasks.append(
384428
asyncio.create_task(
385-
run_command(
429+
run_command( # type: ignore[attr-defined]
386430
command,
387431
env=env,
388432
deps=deps,
@@ -395,7 +439,7 @@ async def run_graph(*, env, commands, graph, gather_data=False, save=None):
395439
return [await task for task in tasks]
396440

397441

398-
def print_command_outputs(command_results):
442+
def print_command_outputs(command_results: List[Result]) -> None:
399443
"""
400444
Print captured stdout and stderr from commands.
401445
"""
@@ -404,11 +448,16 @@ def print_command_outputs(command_results):
404448
sys.stderr.write(result.get("stderr", b"").decode("ascii"))
405449

406450

407-
def write_log_csv(command_parts, command_results, *, filename):
451+
def write_log_csv(
452+
command_parts: List[List[str]],
453+
command_results: List[Result],
454+
*,
455+
filename: str,
456+
) -> None:
408457
"""
409458
Write a CSV file of the times and /tmp file sizes from each command.
410459
"""
411-
tmp_files = []
460+
tmp_files: List[str] = []
412461
for result in command_results:
413462
tmp_files.extend(result.get("files", {}).keys())
414463
with open(filename, "w", newline="") as csvfile:
@@ -421,7 +470,7 @@ def write_log_csv(command_parts, command_results, *, filename):
421470
writer.writerow({**row, **result.get("files", {})})
422471

423472

424-
def exit_code(results):
473+
def exit_code(results: List[Result]) -> int:
425474
"""
426475
Aggregate individual exit codes into a single code.
427476
"""
@@ -432,11 +481,18 @@ def exit_code(results):
432481
return 0
433482

434483

435-
def wrap_nvcc(args, config=default_config):
484+
def wrap_nvcc(
485+
args: List[str],
486+
config: argparse.Namespace = default_config,
487+
) -> int:
436488
return subprocess.call([config.nvcc] + args)
437489

438490

439-
def fast_nvcc(args, *, config=default_config):
491+
def fast_nvcc(
492+
args: List[str],
493+
*,
494+
config: argparse.Namespace = default_config,
495+
) -> int:
440496
"""
441497
Emulate the result of calling the given nvcc binary with args.
442498
@@ -472,7 +528,7 @@ def fast_nvcc(args, *, config=default_config):
472528
if config.sequential:
473529
graph = straight_line_dependencies(commands)
474530
results = asyncio.run(
475-
run_graph(
531+
run_graph( # type: ignore[attr-defined]
476532
env=env,
477533
commands=commands,
478534
graph=graph,
@@ -483,10 +539,10 @@ def fast_nvcc(args, *, config=default_config):
483539
print_command_outputs(results)
484540
if config.table:
485541
write_log_csv(command_parts, results, filename=config.table)
486-
return exit_code([dryrun_data] + results)
542+
return exit_code([dryrun_data] + results) # type: ignore[arg-type, operator]
487543

488544

489-
def our_arg(arg):
545+
def our_arg(arg: str) -> bool:
490546
return arg != "--"
491547

492548

0 commit comments

Comments
 (0)