State
A state is introduced in Engine to store the output of the process_function, the current epoch, the iteration and other helpful information. Each Engine contains a State, which includes the following:
- engine.state.seed: Seed to set at each data “epoch”.
- engine.state.epoch: Number of epochs the engine has completed. Initialized as 0 and the first epoch is 1.
- engine.state.iteration: Number of iterations the engine has completed. Initialized as 0 and the first iteration is 1.
- engine.state.max_epochs: Number of epochs to run for. Initialized as 1.
- engine.state.output: The output of the process_function defined for the Engine. See below.
- etc.
Other attributes can be found in the docs of State.
In the code below, engine.state.output will store the batch loss. This output is used to print the loss at every iteration.
    from ignite.engine import Engine, Events

    # model, loss_fn and optimizer are assumed to be defined elsewhere.
    def update(engine, batch):
        x, y = batch
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()  # becomes engine.state.output

    trainer = Engine(update)

    def on_iteration_completed(engine):
        iteration = engine.state.iteration
        epoch = engine.state.epoch
        loss = engine.state.output
        print(f"Epoch: {epoch}, Iteration: {iteration}, Loss: {loss}")

    trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)
Since there are no restrictions on the output of process_function, Ignite provides an output_transform argument for its ignite.metrics and ignite.handlers. The output_transform argument is a function that transforms engine.state.output into the form the metric or handler expects. Below we’ll see different types of engine.state.output and how to transform them.
In the code below, engine.state.output will be a tuple of loss, y_pred, y for the processed batch. If we want to attach Accuracy to the engine, output_transform is needed to extract y_pred and y from engine.state.output. Let’s see how that is done:
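    from ignite.engine import Engine, Events
    from ignite.metrics import Accuracy

    # model, loss_fn, optimizer and data are assumed to be defined elsewhere.
    def update(engine, batch):
        x, y = batch
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item(), y_pred, y

    trainer = Engine(update)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_loss(engine):
        epoch = engine.state.epoch
        # engine.state.output holds the output of the last processed batch.
        loss = engine.state.output[0]
        print(f"Epoch {epoch}: train_loss = {loss}")

    # Pick (y_pred, y) out of the (loss, y_pred, y) tuple.
    accuracy = Accuracy(output_transform=lambda x: (x[1], x[2]))
    accuracy.attach(trainer, "acc")
    trainer.run(data, max_epochs=10)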
As above, but this time the output of the process_function is a dictionary of loss, y_pred, y for the processed batch. Here output_transform extracts y_pred and y from engine.state.output by key. See below:
    from ignite.engine import Engine, Events
    from ignite.metrics import Accuracy

    # model, loss_fn, optimizer and data are assumed to be defined elsewhere.
    def update(engine, batch):
        x, y = batch
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return {"loss": loss.item(), "y_pred": y_pred, "y": y}

    trainer = Engine(update)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_loss(engine):
        epoch = engine.state.epoch
        # engine.state.output holds the output of the last processed batch.
        loss = engine.state.output["loss"]
        print(f"Epoch {epoch}: train_loss = {loss}")

    # Pick (y_pred, y) out of the output dictionary.
    accuracy = Accuracy(output_transform=lambda x: (x["y_pred"], x["y"]))
    accuracy.attach(trainer, "acc")
    trainer.run(data, max_epochs=10)
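Because an output_transform is an ordinary function of engine.state.output, it can be written and checked on its own before being passed to a metric. The sketch below does this without Ignite: the function names `from_tuple` and `from_dict` are hypothetical, and plain Python lists stand in for the prediction and target tensors.

```python
# Stand-alone versions of the two transforms used above.
def from_tuple(output):
    # output is (loss, y_pred, y)
    _loss, y_pred, y = output
    return y_pred, y

def from_dict(output):
    # output is {"loss": ..., "y_pred": ..., "y": ...}
    return output["y_pred"], output["y"]

# Sample outputs for one batch; lists stand in for tensors here.
tuple_output = (0.25, [0.9, 0.1], [1, 0])
dict_output = {"loss": 0.25, "y_pred": [0.9, 0.1], "y": [1, 0]}

# Both transforms yield the same (y_pred, y) pair.
print(from_tuple(tuple_output))
print(from_dict(dict_output))
```

Naming the transform instead of writing it as a lambda also makes it easy to reuse the same function for several metrics attached to one engine.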
Note: A good practice is to also use State as storage for user data created in update or handler functions. For example, to save new_attribute in the state:
    def user_handler_function(engine):
        engine.state.new_attribute = 12345