1414import subprocess
1515import sys
1616import time
17+ from typing import Awaitable , DefaultDict , Dict , List , Match , Optional , Set , cast
1718
19+ from typing_extensions import TypedDict
1820
1921help_msg = """fast_nvcc [OPTION]... -- [NVCC_ARG]...
2022
3739)
# Optional output path for a Graphviz rendering of the execution graph.
parser.add_argument(
    "--graph",
    metavar="FILE.gv",
    help="write Graphviz DOT file with execution graph",
)
4345parser .add_argument (
# Matches a temporary-file reference inside a command line: an optional
# "/tmp/" prefix followed by a name starting with "tmp". Group 1 is the bare
# filename (without the "/tmp/" prefix). The lookbehind rejects matches that
# are glued to a preceding word character, "-" or "/".
re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)"
7981
8082
def fast_nvcc_warn(warning: str) -> None:
    """
    Warn the user about something regarding fast_nvcc.

    The message is printed to stderr, prefixed with "warning (fast_nvcc): ".
    """
    print(f"warning (fast_nvcc): {warning}", file=sys.stderr)
8688
8789
88- def warn_if_windows ():
90+ def warn_if_windows () -> None :
8991 """
9092 Warn the user that using fast_nvcc on Windows might not work.
9193 """
@@ -97,7 +99,7 @@ def warn_if_windows():
9799 fast_nvcc_warn (url_vars )
98100
99101
100- def warn_if_tmpdir_flag (args ) :
102+ def warn_if_tmpdir_flag (args : List [ str ]) -> None :
101103 """
102104 Warn the user that using fast_nvcc with some flags might not work.
103105 """
@@ -121,11 +123,17 @@ def warn_if_tmpdir_flag(args):
121123 fast_nvcc_warn (f"{ url_base } #{ frag } " )
122124
123125
124- def nvcc_dryrun_data (binary , args ):
class DryunData(TypedDict):
    # Shape of the dict returned by nvcc_dryrun_data.
    # NOTE(review): the class name looks like a typo for "DryrunData"; it is
    # kept as-is because other signatures in this file refer to it.

    # environment variables parsed from the nvcc --dryrun output
    env: Dict[str, str]
    # commands parsed from the nvcc --dryrun output
    commands: List[str]
    # return code of the nvcc --dryrun process itself
    exit_code: int
130+
131+
132+ def nvcc_dryrun_data (binary : str , args : List [str ]) -> DryunData :
125133 """
126134 Return parsed environment variables and commands from nvcc --dryrun.
127135 """
128- result = subprocess .run (
136+ result = subprocess .run ( # type: ignore[call-overload]
129137 [binary , "--dryrun" ] + args ,
130138 capture_output = True ,
131139 encoding = "ascii" , # this is just a guess
@@ -148,7 +156,7 @@ def nvcc_dryrun_data(binary, args):
148156 return {"env" : env , "commands" : commands , "exit_code" : result .returncode }
149157
150158
151- def warn_if_tmpdir_set (env ) :
159+ def warn_if_tmpdir_set (env : Dict [ str , str ]) -> None :
152160 """
153161 Warn the user that setting TMPDIR with fast_nvcc might not work.
154162 """
@@ -157,7 +165,7 @@ def warn_if_tmpdir_set(env):
157165 fast_nvcc_warn (url_vars )
158166
159167
160- def contains_non_executable (commands ) :
168+ def contains_non_executable (commands : List [ str ]) -> bool :
161169 for command in commands :
162170 # This is to deal with special command dry-run result from NVCC such as:
163171 # ```
@@ -170,7 +178,7 @@ def contains_non_executable(commands):
170178 return False
171179
172180
173- def module_id_contents (command ) :
181+ def module_id_contents (command : List [ str ]) -> str :
174182 """
175183 Guess the contents of the .module_id file contained within command.
176184 """
@@ -187,7 +195,7 @@ def module_id_contents(command):
187195 return f"_{ len (middle )} _{ middle } _{ suffix } "
188196
189197
190- def unique_module_id_files (commands ) :
198+ def unique_module_id_files (commands : List [ str ]) -> List [ str ] :
191199 """
192200 Give each command its own .module_id filename instead of sharing.
193201 """
@@ -196,7 +204,7 @@ def unique_module_id_files(commands):
196204 for i , line in enumerate (commands ):
197205 arr = []
198206
199- def uniqueify (s ) :
207+ def uniqueify (s : Match [ str ]) -> str :
200208 filename = re .sub (r"\-(\d+)" , r"-\1-" + str (i ), s .group (0 ))
201209 arr .append (filename )
202210 return filename
@@ -212,14 +220,19 @@ def uniqueify(s):
212220 return uniqueified
213221
214222
def make_rm_force(commands: List[str]) -> List[str]:
    """
    Add --force to all rm commands.

    A command counts as an rm command when it starts with "rm " (including
    the trailing space); every other command is returned unchanged. The
    input list is not modified.
    """
    return [f"{c} --force" if c.startswith("rm ") else c for c in commands]
220228
221229
222- def print_verbose_output (* , env , commands , filename ):
230+ def print_verbose_output (
231+ * ,
232+ env : Dict [str , str ],
233+ commands : List [List [str ]],
234+ filename : str ,
235+ ) -> None :
223236 """
224237 Human-readably write nvcc --dryrun data to stderr.
225238 """
@@ -234,21 +247,24 @@ def print_verbose_output(*, env, commands, filename):
234247 print (f'#{ " " * len (prefix )} { part } ' , file = f )
235248
236249
# Execution graph: entry i is the set of command indices that command i
# depends on.
Graph = List[Set[int]]


def straight_line_dependencies(commands: List[str]) -> Graph:
    """
    Return a straight-line dependency graph.

    Each command depends only on the one immediately before it; the first
    command has no dependencies.
    """
    return [({i - 1} if i > 0 else set()) for i in range(len(commands))]
242258
243259
def files_mentioned(command: str) -> List[str]:
    """
    Return fully-qualified names of all tmp files referenced by command.

    Every match of re_tmp in the command contributes one entry, in order of
    appearance (duplicates included); the captured filename is always
    reported under /tmp/.
    """
    return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)]
249265
250266
251- def nvcc_data_dependencies (commands ) :
267+ def nvcc_data_dependencies (commands : List [ str ]) -> Graph :
252268 """
253269 Return a list of the set of dependencies for each command.
254270 """
@@ -261,8 +277,8 @@ def nvcc_data_dependencies(commands):
261277 # data dependency is sort of flipped, because the steps that use the
262278 # files generated by cicc need to wait for the fatbinary step to
263279 # finish first
264- tmp_files = {}
265- fatbins = collections .defaultdict (set )
280+ tmp_files : Dict [ str , int ] = {}
281+ fatbins : DefaultDict [ int , Set [ str ]] = collections .defaultdict (set )
266282 graph = []
267283 for i , line in enumerate (commands ):
268284 deps = set ()
@@ -284,13 +300,13 @@ def nvcc_data_dependencies(commands):
284300 return graph
285301
286302
287- def is_weakly_connected (graph ) :
303+ def is_weakly_connected (graph : Graph ) -> bool :
288304 """
289305 Return true iff graph is weakly connected.
290306 """
291307 if not graph :
292308 return True
293- neighbors = [set () for _ in graph ]
309+ neighbors : List [ Set [ int ]] = [set () for _ in graph ]
294310 for node , predecessors in enumerate (graph ):
295311 for pred in predecessors :
296312 neighbors [pred ].add (node )
@@ -307,20 +323,25 @@ def is_weakly_connected(graph):
307323 return len (found ) == len (graph )
308324
309325
def warn_if_not_weakly_connected(graph: Graph) -> None:
    """
    Warn the user if the execution graph is not weakly connected.

    The warning is emitted through fast_nvcc_warn; nothing is returned.
    """
    if not is_weakly_connected(graph):
        fast_nvcc_warn("execution graph is not (weakly) connected")
316332
317333
318- def print_dot_graph (* , commands , graph , filename ):
334+ def print_dot_graph (
335+ * ,
336+ commands : List [List [str ]],
337+ graph : Graph ,
338+ filename : str ,
339+ ) -> None :
319340 """
320341 Print a DOT file displaying short versions of the commands in graph.
321342 """
322343
323- def name (k ) :
344+ def name (k : int ) -> str :
324345 return f'"{ k } { os .path .basename (commands [k ][0 ])} "'
325346
326347 with open (filename , "w" ) as f :
@@ -334,7 +355,23 @@ def name(k):
334355 print ("}" , file = f )
335356
336357
337- async def run_command (command , * , env , deps , gather_data , i , save ):
class Result(TypedDict, total=False):
    # Per-command outcome. total=False: every key is optional, so readers
    # should use .get() for keys that may be absent.

    # exit status of the command
    exit_code: int
    # captured standard output of the command
    stdout: bytes
    # captured standard error of the command
    stderr: bytes
    # elapsed time measured with time.monotonic(); only set when data
    # gathering is enabled
    time: float
    # per-file integers keyed by tmp filename — presumably file sizes in
    # bytes; TODO confirm against run_command
    files: Dict[str, int]
364+
365+
366+ async def run_command (
367+ command : str ,
368+ * ,
369+ env : Dict [str , str ],
370+ deps : Set [Awaitable [Result ]],
371+ gather_data : bool ,
372+ i : int ,
373+ save : Optional [str ],
374+ ) -> Result :
338375 """
339376 Run the command with the given env after waiting for deps.
340377 """
@@ -352,8 +389,8 @@ async def run_command(command, *, env, deps, gather_data, i, save):
352389 stderr = asyncio .subprocess .PIPE ,
353390 )
354391 stdout , stderr = await proc .communicate ()
355- code = proc .returncode
356- results = {"exit_code" : code , "stdout" : stdout , "stderr" : stderr }
392+ code = cast ( int , proc .returncode )
393+ results : Result = {"exit_code" : code , "stdout" : stdout , "stderr" : stderr }
357394 if gather_data :
358395 t2 = time .monotonic ()
359396 results ["time" ] = t2 - t1
@@ -373,16 +410,23 @@ async def run_command(command, *, env, deps, gather_data, i, save):
373410 return results
374411
375412
376- async def run_graph (* , env , commands , graph , gather_data = False , save = None ):
413+ async def run_graph (
414+ * ,
415+ env : Dict [str , str ],
416+ commands : List [str ],
417+ graph : Graph ,
418+ gather_data : bool = False ,
419+ save : Optional [str ] = None ,
420+ ) -> List [Result ]:
377421 """
378422 Return outputs/errors (and optionally time/file info) from commands.
379423 """
380- tasks = []
424+ tasks : List [ Awaitable [ Result ]] = []
381425 for i , (command , indices ) in enumerate (zip (commands , graph )):
382426 deps = {tasks [j ] for j in indices }
383427 tasks .append (
384428 asyncio .create_task (
385- run_command (
429+ run_command ( # type: ignore[attr-defined]
386430 command ,
387431 env = env ,
388432 deps = deps ,
@@ -395,7 +439,7 @@ async def run_graph(*, env, commands, graph, gather_data=False, save=None):
395439 return [await task for task in tasks ]
396440
397441
398- def print_command_outputs (command_results ) :
442+ def print_command_outputs (command_results : List [ Result ]) -> None :
399443 """
400444 Print captured stdout and stderr from commands.
401445 """
@@ -404,11 +448,16 @@ def print_command_outputs(command_results):
404448 sys .stderr .write (result .get ("stderr" , b"" ).decode ("ascii" ))
405449
406450
407- def write_log_csv (command_parts , command_results , * , filename ):
451+ def write_log_csv (
452+ command_parts : List [List [str ]],
453+ command_results : List [Result ],
454+ * ,
455+ filename : str ,
456+ ) -> None :
408457 """
409458 Write a CSV file of the times and /tmp file sizes from each command.
410459 """
411- tmp_files = []
460+ tmp_files : List [ str ] = []
412461 for result in command_results :
413462 tmp_files .extend (result .get ("files" , {}).keys ())
414463 with open (filename , "w" , newline = "" ) as csvfile :
@@ -421,7 +470,7 @@ def write_log_csv(command_parts, command_results, *, filename):
421470 writer .writerow ({** row , ** result .get ("files" , {})})
422471
423472
424- def exit_code (results ) :
473+ def exit_code (results : List [ Result ]) -> int :
425474 """
426475 Aggregate individual exit codes into a single code.
427476 """
@@ -432,11 +481,18 @@ def exit_code(results):
432481 return 0
433482
434483
def wrap_nvcc(
    args: List[str],
    config: argparse.Namespace = default_config,
) -> int:
    """
    Run the nvcc binary named by config.nvcc with args.

    Returns the return code reported by subprocess.call.
    """
    return subprocess.call([config.nvcc] + args)
437489
438490
439- def fast_nvcc (args , * , config = default_config ):
491+ def fast_nvcc (
492+ args : List [str ],
493+ * ,
494+ config : argparse .Namespace = default_config ,
495+ ) -> int :
440496 """
441497 Emulate the result of calling the given nvcc binary with args.
442498
@@ -472,7 +528,7 @@ def fast_nvcc(args, *, config=default_config):
472528 if config .sequential :
473529 graph = straight_line_dependencies (commands )
474530 results = asyncio .run (
475- run_graph (
531+ run_graph ( # type: ignore[attr-defined]
476532 env = env ,
477533 commands = commands ,
478534 graph = graph ,
@@ -483,10 +539,10 @@ def fast_nvcc(args, *, config=default_config):
483539 print_command_outputs (results )
484540 if config .table :
485541 write_log_csv (command_parts , results , filename = config .table )
486- return exit_code ([dryrun_data ] + results )
542+ return exit_code ([dryrun_data ] + results ) # type: ignore[arg-type, operator]
487543
488544
def our_arg(arg: str) -> bool:
    """
    Return True for every argument except the literal "--" separator.
    """
    return arg != "--"
491547
492548
0 commit comments