1414import logging
1515import os
1616import re
17+ import sys
1718from argparse import Namespace
1819from typing import Any , Optional
1920
3132
3233_log = logging .getLogger (__name__ )
3334
34- _CLICK_AVAILABLE = RequirementCache ("click " )
35+ _JSONARGPARSE_AVAILABLE = RequirementCache ("jsonargparse " )
3536_LIGHTNING_SDK_AVAILABLE = RequirementCache ("lightning_sdk" )
3637
38+ if _JSONARGPARSE_AVAILABLE :
39+ from jsonargparse import ArgumentParser
40+
3741_SUPPORTED_ACCELERATORS = ("cpu" , "gpu" , "cuda" , "mps" , "tpu" , "auto" )
3842
3943
@@ -45,127 +49,112 @@ def _get_supported_strategies() -> list[str]:
4549 return [strategy for strategy in available_strategies if not re .match (excluded , strategy )]
4650
4751
48- if _CLICK_AVAILABLE :
49- import click
def _add_run_arguments(run_parser: "ArgumentParser") -> None:
    """Register the options and the positional ``script`` argument of the ``run`` subcommand.

    Args:
        run_parser: The parser of the ``run`` subcommand; it is modified in place.

    """
    run_parser.add_argument(
        "--accelerator",
        type=str,
        choices=_SUPPORTED_ACCELERATORS,
        default=None,
        help="The hardware accelerator to run on.",
    )
    run_parser.add_argument(
        "--strategy",
        type=str,
        choices=_get_supported_strategies(),
        default=None,
        help="Strategy for how to run across multiple devices.",
    )
    run_parser.add_argument(
        "--devices",
        type=str,
        default="1",
        help=(
            "Number of devices to run on (int), which devices to run on (list or str), or 'auto'. "
            "The value applies per node."
        ),
    )
    # The options below accept both an underscore and a dash spelling. The underscore spelling is
    # listed first; presumably this keeps the parsed attribute names (`num_nodes`, `node_rank`, ...)
    # stable -- TODO confirm against the jsonargparse destination-name rules.
    run_parser.add_argument(
        "--num_nodes",
        "--num-nodes",
        type=int,
        default=1,
        help="Number of machines (nodes) for distributed execution.",
    )
    run_parser.add_argument(
        "--node_rank",
        "--node-rank",
        type=int,
        default=0,
        help=(
            "The index of the machine (node) this command gets started on. Must be a number in the range "
            "0, ..., num_nodes - 1."
        ),
    )
    run_parser.add_argument(
        "--main_address",
        "--main-address",
        type=str,
        default="127.0.0.1",
        help="The hostname or IP address of the main machine (usually the one with node_rank = 0).",
    )
    run_parser.add_argument(
        "--main_port",
        "--main-port",
        type=int,
        default=29400,
        help="The main port to connect to the main machine.",
    )
    run_parser.add_argument(
        "--precision",
        type=str,
        choices=list(get_args(_PRECISION_INPUT_STR)) + list(get_args(_PRECISION_INPUT_STR_ALIAS)),
        default=None,
        help=(
            "Double precision ('64-true' or '64'), full precision ('32-true' or '32'), "
            "half precision ('16-mixed' or '16') or bfloat16 precision ('bf16-mixed' or 'bf16')."
        ),
    )
    run_parser.add_argument(
        "script",
        type=str,
        help="Path to the Python script with the code to run. The script must contain a Fabric object.",
    )


def _add_consolidate_arguments(con_parser: "ArgumentParser") -> None:
    """Register the arguments of the ``consolidate`` subcommand.

    Args:
        con_parser: The parser of the ``consolidate`` subcommand; it is modified in place.

    """
    con_parser.add_argument(
        "checkpoint_folder",
        type=str,
        help="Path to the checkpoint folder to consolidate.",
    )
    con_parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help=(
            "Path to the file where the converted checkpoint should be saved. The file should not already exist. "
            "If not provided, the file will be saved next to the input checkpoint folder with the same name and a "
            "'.consolidated' suffix."
        ),
    )


def _build_parser() -> "ArgumentParser":
    """Build the jsonargparse-based CLI parser with subcommands.

    The returned parser exposes two subcommands: ``run`` and ``consolidate``.

    Returns:
        The top-level parser with both subcommands attached.

    Raises:
        RuntimeError: If ``jsonargparse`` is not installed.

    """
    if not _JSONARGPARSE_AVAILABLE:  # pragma: no cover
        raise RuntimeError(
            "To use the Lightning Fabric CLI, you must have `jsonargparse` installed. "
            "Install it by running `pip install -U jsonargparse`."
        )

    parser = ArgumentParser(description="Lightning Fabric command line tool")
    subcommands = parser.add_subcommands()

    # run subcommand
    run_parser = ArgumentParser(description="Run a Lightning Fabric script.")
    _add_run_arguments(run_parser)
    subcommands.add_subcommand("run", run_parser, help="Run a Lightning Fabric script")

    # consolidate subcommand
    con_parser = ArgumentParser(
        description="Convert a distributed/sharded checkpoint into a single file that can be loaded with torch.load()."
    )
    _add_consolidate_arguments(con_parser)
    subcommands.add_subcommand("consolidate", con_parser, help="Consolidate a distributed checkpoint")

    return parser
169158
170159
171160def _set_env_variables (args : Namespace ) -> None :
@@ -234,12 +223,44 @@ def main(args: Namespace, script_args: Optional[list[str]] = None) -> None:
234223 _torchrun_launch (args , script_args or [])
235224
236225
237- if __name__ == "__main__" :
238- if not _CLICK_AVAILABLE : # pragma: no cover
def _run_command(cfg: Namespace, script_args: list[str]) -> None:
    """Execute the 'run' subcommand with the provided config and extra script args.

    Args:
        cfg: The parsed options of the ``run`` subcommand. Both mapping-like namespaces (exposing
            ``keys()`` and item access) and plain ``argparse.Namespace`` objects are accepted.
        script_args: Arguments that are forwarded untouched to the user script.

    """
    if hasattr(cfg, "keys"):
        # Same protocol that ``**cfg`` unpacking relies on: ``keys()`` plus item lookup.
        options = {key: cfg[key] for key in cfg.keys()}
    else:
        # A plain ``argparse.Namespace`` is not a mapping; read its attributes instead.
        options = vars(cfg)
    main(args=Namespace(**options), script_args=script_args)
229+
230+
def _consolidate_command(cfg: Namespace) -> None:
    """Execute the 'consolidate' subcommand with the provided config.

    Args:
        cfg: The parsed options of the ``consolidate`` subcommand. Only its ``checkpoint_folder``
            and ``output_file`` attributes are read.

    """
    # Hand only the two relevant fields to the shared argument processing.
    config = _process_cli_args(
        Namespace(checkpoint_folder=cfg.checkpoint_folder, output_file=cfg.output_file)
    )
    merged_checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
    torch.save(merged_checkpoint, config.output_file)
237+
238+
def cli_main(argv: Optional[list[str]] = None) -> None:
    """Entry point for the Fabric CLI using jsonargparse.

    Args:
        argv: The command line arguments without the program name. The value is handed to the
            parser as is, including ``None``.

    Raises:
        SystemExit: With exit code 1 if ``jsonargparse`` is not installed.

    """
    if not _JSONARGPARSE_AVAILABLE:  # pragma: no cover
        _log.error(
            "To use the Lightning Fabric CLI, you must have `jsonargparse` installed."
            " Install it by running `pip install -U jsonargparse`."
        )
        raise SystemExit(1)

    parser = _build_parser()
    # parse_known_args so that for 'run' we can forward unknown args to the user script
    cfg, unknown = parser.parse_known_args(argv)

    subcommand = getattr(cfg, "subcommand", None)
    if not subcommand:
        parser.print_help()
        return

    if subcommand == "run":
        # unknown contains the script's own args
        _run_command(cfg.run, unknown)
    elif subcommand == "consolidate":
        if unknown:
            # 'consolidate' has nothing to forward the leftovers to, so a mistyped option must
            # not be dropped silently. `parser.error` is expected to abort here.
            parser.error(f"Unrecognized arguments: {' '.join(unknown)}")
        _consolidate_command(cfg.consolidate)
    else:  # pragma: no cover
        parser.print_help()
264+
if __name__ == "__main__":
    # Pass everything after the program name to the CLI entry point.
    cli_main(argv=sys.argv[1:])
0 commit comments