
the init parameter "optimizer" of Trainer() should be a function #11157

@jacquesqiao

Description


Background

@seiriosPlus found a bug in the new high-level Trainer API while writing a model with learning rate decay: the optimizer is created in a different program from the one the trainer uses.

Reason and solution

```python
class Trainer(object):
    """
    Args:
        train_func(callable): A function which will return loss. The loss must be a scalar.
        optimizer(optimizer.Optimizer): The optimizer should be an instance of Optimizer.
        place: The device place of this trainer.
    """
    def __init__(self,
                 train_func,
                 optimizer,
                 param_path=None,
                 place=None,
                 parallel=False):
        self.__stop = False
```

In the design of the Trainer API, `train_func` is a function. Trainer calls it inside a `with` scope that makes its own train and startup programs current:

```python
with framework.program_guard(self.train_program, self.startup_program):
    program_func_outs = train_func()
```

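But the `optimizer` init parameter is an object, so the caller creates it outside that `with` scope. Anything it builds at construction time (for example, the learning rate decay in the model above) therefore ends up in a different program from the trainer's. This is not right, so `optimizer` should also be a function that returns an optimizer object.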

The interface should be:

```python
class Trainer(object):
    def __init__(self,
                 train_func,
                 optimizer_func,
                 param_path=None,
                 place=None,
                 parallel=False):
        ...
```
