Description
🚀 Feature
Problem
I am using multiple engines in a nested way: a child engine is attached to an event of the main engine, e.g. `Events.EPOCH_COMPLETED`, and should run exactly one epoch each time that event fires. One solution would be to run the child engine with `engine.run(max_epochs=1)`, but then the engine fires the setup and teardown events `Events.STARTED` and `Events.COMPLETED` on every call to `engine.run(max_epochs=1)`, even though, as far as I understand, those events are meant to fire only once.
Since my child engine must set up and tear down things, I could attach those handlers to the main engine instead, but the handlers I want to attach do not know that a main engine exists. The handlers shouldn't have any access to the main engine.
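To make the problem concrete, here is a minimal sketch using a toy `MiniEngine` class. It is a hypothetical stand-in written for illustration, not ignite's `Engine`, and its string event names are assumptions. It only shows that calling `run(max_epochs=1)` from a parent handler repeats the child's setup and teardown on every parent epoch.

```python
from collections import defaultdict


class MiniEngine:
    """Toy stand-in for an event-driven engine (NOT ignite's Engine)."""

    def __init__(self, name):
        self.name = name
        self.handlers = defaultdict(list)

    def on(self, event, handler):
        self.handlers[event].append(handler)

    def fire(self, event):
        for handler in self.handlers[event]:
            handler(self)

    def run(self, max_epochs=1):
        # Every call to run() fires the full setup/teardown pair.
        self.fire("STARTED")
        for _ in range(max_epochs):
            self.fire("EPOCH_STARTED")
            self.fire("EPOCH_COMPLETED")
        self.fire("COMPLETED")


log = []
main = MiniEngine("main")
child = MiniEngine("child")

# The child's setup/teardown handlers know nothing about the main engine.
child.on("STARTED", lambda engine: log.append("child setup"))
child.on("COMPLETED", lambda engine: log.append("child teardown"))

# Nesting: each main epoch triggers one child epoch through a fresh run() call.
main.on("EPOCH_COMPLETED", lambda engine: child.run(max_epochs=1))

main.run(max_epochs=3)
print(log.count("child setup"), log.count("child teardown"))
```

In this toy setup the child's setup and teardown handlers each run three times, once per main epoch, rather than once overall.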
Solution
I need the engine to support something like the following (this is just an example of a crude but possible way to implement it):
```python
engine.run_epoch(max_epochs=3)  # runs setup and first epoch, fires events from `STARTED` to `EPOCH_COMPLETED`
engine.run_epoch(max_epochs=3)  # runs second epoch, fires events from `EPOCH_STARTED` to `EPOCH_COMPLETED`
engine.run_epoch(max_epochs=3)  # runs last epoch and teardown, fires events from `EPOCH_STARTED` to `COMPLETED`
```
Instead of calling a function repeatedly, one could obtain an iterator from `engine.run` and get the same behavior in a nicer way:
```python
epoch_iterator = iterable_engine.run(max_epochs=3)
next(epoch_iterator)  # runs setup and first epoch, fires events from `STARTED` to `EPOCH_COMPLETED`
next(epoch_iterator)  # runs second epoch, fires events from `EPOCH_STARTED` to `EPOCH_COMPLETED`
next(epoch_iterator)  # runs last epoch and teardown, fires events from `EPOCH_STARTED` to `COMPLETED`
```
Or one can use loops:
```python
iterable_engine = IterableEngine(lambda x, y: 0.)
iterable_engine.add_event_handler(Events.STARTED, lambda x: print("started"))
iterable_engine.add_event_handler(Events.EPOCH_STARTED, lambda x: print("epoch started"))
iterable_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda x: print("epoch completed"))
iterable_engine.add_event_handler(Events.COMPLETED, lambda x: print("completed"))

epoch_iterator = iterable_engine.run([1], max_epochs=3)
for state in epoch_iterator:
    print("This is outside engine.run")
```
The output is:
```
started
epoch started
epoch completed
This is outside engine.run
epoch started
epoch completed
This is outside engine.run
epoch started
epoch completed
This is outside engine.run
completed
```
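The position of `completed` in this output follows from how Python generators suspend at `yield`. The sketch below uses a plain generator function, `epochs`, which is a hypothetical stand-in and not ignite code, with the `yield` inside the epoch loop as in the `IterableEngine` code in this issue. Nothing runs until the first `next()`, and the teardown step only runs when the caller asks for an item after the last epoch. With a `for` loop that extra request is implicit; with exactly three explicit `next()` calls it would not happen.

```python
def epochs(max_epochs, log):
    """Toy generator: setup, one epoch per step, teardown on exhaustion."""
    log.append("started")
    for epoch in range(1, max_epochs + 1):
        log.append("epoch started")
        log.append("epoch completed")
        yield epoch  # control returns to the caller here
    log.append("completed")


log = []
iterator = epochs(3, log)
log_before_first_next = list(log)  # creating the generator runs no code yet

for _ in range(3):
    next(iterator)
log_after_three_steps = list(log)  # teardown has not run at this point

# A further request exhausts the generator and runs the teardown step.
exhausted = False
try:
    next(iterator)
except StopIteration:
    exhausted = True

print(log_before_first_next)
print(log_after_three_steps[-1], "->", log[-1])
```

If teardown should be guaranteed after the last epoch, a caller that steps manually would need to exhaust the iterator (or the implementation would need to detect the final epoch before yielding).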
I added the code at the bottom, where I subclass `Engine` and override the `_internal_run` method with a copy of the original method plus one added line containing the `yield` statement. You can execute it, and it prints the output from the example.
To switch between the current behavior and this one, one could put the `yield` into an `if` statement and pass an additional argument to `engine.run`, e.g. `engine.run(max_epochs=3, return_generator=True)`, or set a flag on the engine to enable this functionality.
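One caveat when guarding `yield` with an `if`: in Python, any function whose body contains `yield` is a generator function, even if the `yield` is never reached, so calling it always returns a generator. A possible workaround, sketched here with hypothetical names (`_run_epochs`, `run`, and the `return_generator` flag from the suggestion above, none of which are existing ignite API), is to keep the `yield` in an inner generator and let a plain wrapper either return that generator or drain it.

```python
def _run_epochs(max_epochs, log):
    """Toy generator: setup, one epoch per step, teardown on exhaustion."""
    log.append("started")
    for epoch in range(1, max_epochs + 1):
        log.append(f"epoch {epoch}")
        yield epoch
    log.append("completed")


def run(max_epochs, log, return_generator=False):
    """Plain function (no `yield` in its body), so it can return either kind of result."""
    gen = _run_epochs(max_epochs, log)
    if return_generator:
        return gen  # the caller drives the epochs with next() or a for loop
    last_epoch = None
    for last_epoch in gen:  # drain internally: behaves like a blocking run
        pass
    return last_epoch


# Default mode: runs everything and returns the last epoch number.
blocking_log = []
result = run(2, blocking_log)

# Generator mode: the caller advances one epoch at a time.
stepwise_log = []
steps = run(2, stepwise_log, return_generator=True)
first = next(steps)
```

With this split, the default call keeps its blocking semantics, and only callers that opt in receive an iterator.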
What do you think?
Code:
```python
import time

from ignite._utils import _to_hours_mins_secs
from ignite.engine import Engine
from ignite.engine import Events
from ignite.engine import State


class IterableEngine(Engine):
    def _internal_run(self) -> State:
        self.should_terminate = self.should_terminate_single_epoch = False
        self._init_timers(self.state)
        try:
            start_time = time.time()
            self._fire_event(Events.STARTED)
            while self.state.epoch < self.state.max_epochs and not self.should_terminate:
                self.state.epoch += 1
                self._fire_event(Events.EPOCH_STARTED)

                if self._dataloader_iter is None:
                    self._setup_engine()

                time_taken = self._run_once_on_dataset()
                # time is available for handlers but must be update after fire
                self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
                handlers_start_time = time.time()
                if self.should_terminate:
                    self._fire_event(Events.TERMINATE)
                else:
                    self._fire_event(Events.EPOCH_COMPLETED)
                time_taken += time.time() - handlers_start_time
                # update time wrt handlers
                self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
                hours, mins, secs = _to_hours_mins_secs(time_taken)
                self.logger.info(
                    "Epoch[%s] Complete. Time taken: %02d:%02d:%02d" % (self.state.epoch, hours, mins, secs)
                )
                if self.should_terminate:
                    break
                yield self.state

            time_taken = time.time() - start_time
            # time is available for handlers but must be update after fire
            self.state.times[Events.COMPLETED.name] = time_taken
            handlers_start_time = time.time()
            self._fire_event(Events.COMPLETED)
            time_taken += time.time() - handlers_start_time
            # update time wrt handlers
            self.state.times[Events.COMPLETED.name] = time_taken
            hours, mins, secs = _to_hours_mins_secs(time_taken)
            self.logger.info("Engine run complete. Time taken: %02d:%02d:%02d" % (hours, mins, secs))

        except BaseException as e:
            self._dataloader_iter = None
            self.logger.error("Engine run is terminating due to exception: %s.", str(e))
            self._handle_exception(e)

        self._dataloader_iter = None
        return self.state


if __name__ == '__main__':
    iterable_engine = IterableEngine(lambda x, y: 0.)
    iterable_engine.add_event_handler(Events.STARTED, lambda x: print("started"))
    iterable_engine.add_event_handler(Events.EPOCH_STARTED, lambda x: print("epoch started"))
    iterable_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda x: print("epoch completed"))
    iterable_engine.add_event_handler(Events.COMPLETED, lambda x: print("completed"))

    epoch_iterator = iterable_engine.run([1], max_epochs=3)
    for state in epoch_iterator:
        print("This is outside engine.run")
```