|
10 | 10 | import yaml |
11 | 11 | from ignite.contrib.engines import common |
12 | 12 | from ignite.engine import Engine |
| 13 | + |
| 14 | +#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.limit_sec) { :::# |
13 | 15 | from ignite.engine.events import Events |
14 | | -from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine |
| 16 | + |
| 17 | +#::: } :::# |
| 18 | +#::: if (it.save_training || it.save_evaluation) { :::# |
| 19 | +from ignite.handlers import ( |
| 20 | + Checkpoint, |
| 21 | + DiskSaver, |
| 22 | + global_step_from_engine, |
| 23 | +) # usort: skip |
| 24 | + |
| 25 | +#::: } else { :::# |
| 26 | +from ignite.handlers import Checkpoint |
| 27 | + |
| 28 | +#::: } :::# |
| 29 | +#::: if (it.patience) { :::# |
15 | 30 | from ignite.handlers.early_stopping import EarlyStopping |
| 31 | + |
| 32 | +#::: } :::# |
| 33 | +#::: if (it.terminate_on_nan) { :::# |
16 | 34 | from ignite.handlers.terminate_on_nan import TerminateOnNan |
| 35 | + |
| 36 | +#::: } :::# |
| 37 | +#::: if (it.limit_sec) { :::# |
17 | 38 | from ignite.handlers.time_limit import TimeLimit |
| 39 | + |
| 40 | +#::: } :::# |
18 | 41 | from ignite.utils import setup_logger |
19 | 42 |
|
20 | 43 |
|
@@ -141,72 +164,6 @@ def setup_logging(config: Any) -> Logger: |
141 | 164 | return logger |
142 | 165 |
|
143 | 166 |
|
144 | | -#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.limit_sec) { :::# |
145 | | - |
146 | | - |
147 | | -def setup_handlers( |
148 | | - trainer: Engine, |
149 | | - evaluator: Engine, |
150 | | - config: Any, |
151 | | - to_save_train: Optional[dict] = None, |
152 | | - to_save_eval: Optional[dict] = None, |
153 | | -): |
154 | | - """Setup Ignite handlers.""" |
155 | | - |
156 | | - ckpt_handler_train = ckpt_handler_eval = None |
157 | | - #::: if (it.save_training || it.save_evaluation) { :::# |
158 | | - # checkpointing |
159 | | - saver = DiskSaver(config.output_dir / "checkpoints", require_empty=False) |
160 | | - #::: if (it.save_training) { :::# |
161 | | - ckpt_handler_train = Checkpoint( |
162 | | - to_save_train, |
163 | | - saver, |
164 | | - filename_prefix=config.filename_prefix, |
165 | | - n_saved=config.n_saved, |
166 | | - ) |
167 | | - trainer.add_event_handler( |
168 | | - Events.ITERATION_COMPLETED(every=config.save_every_iters), |
169 | | - ckpt_handler_train, |
170 | | - ) |
171 | | - #::: } :::# |
172 | | - #::: if (it.save_evaluation) { :::# |
173 | | - global_step_transform = None |
174 | | - if to_save_train.get("trainer", None) is not None: |
175 | | - global_step_transform = global_step_from_engine(to_save_train["trainer"]) |
176 | | - ckpt_handler_eval = Checkpoint( |
177 | | - to_save_eval, |
178 | | - saver, |
179 | | - filename_prefix="best", |
180 | | - n_saved=config.n_saved, |
181 | | - global_step_transform=global_step_transform, |
182 | | - ) |
183 | | - evaluator.add_event_handler(Events.EPOCH_COMPLETED(every=1), ckpt_handler_eval) |
184 | | - #::: } :::# |
185 | | - #::: } :::# |
186 | | - |
187 | | - #::: if (it.patience) { :::# |
188 | | - # early stopping |
189 | | - |
190 | | - es = EarlyStopping(config.patience, score_fn, trainer) |
191 | | - evaluator.add_event_handler(Events.EPOCH_COMPLETED, es) |
192 | | - #::: } :::# |
193 | | - |
194 | | - #::: if (it.terminate_on_nan) { :::# |
195 | | - # terminate on nan |
196 | | - trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) |
197 | | - #::: } :::# |
198 | | - |
199 | | - #::: if (it.limit_sec) { :::# |
200 | | - # time limit |
201 | | - trainer.add_event_handler(Events.ITERATION_COMPLETED, TimeLimit(config.limit_sec)) |
202 | | - #::: } :::# |
203 | | - #::: if (it.save_training || it.save_evaluation) { :::# |
204 | | - return ckpt_handler_train, ckpt_handler_eval |
205 | | - #::: } :::# |
206 | | - |
207 | | - |
208 | | -#::: } :::# |
209 | | - |
210 | 167 | #::: if (it.logger) { :::# |
211 | 168 |
|
212 | 169 |
|
|
0 commit comments