Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions templates/gan/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,28 @@ def train_function(

Parameters
----------
- config: config object
- engine: Engine instance
- batch: batch in current iteration
- netD: discriminator model
- netG: generator model
- loss_fn: nn.Module loss
- optimizerD: discriminator optimizer
- optimizerG: generator optimizer
- device: device to use for training
- real_labels: real label tensor
- fake_labels: fake label tensor
config
config object
engine
Engine instance
batch
batch in current iteration
netD
discriminator model
netG
generator model
loss_fn
nn.Module loss
optimizerD
discriminator optimizer
optimizerG
generator optimizer
device
device to use for training
real_labels
real label tensor
fake_labels
fake label tensor

Returns
-------
Expand Down
4 changes: 3 additions & 1 deletion templates/gan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ def initialize(

Parameters
----------
config:
config
config object
num_channels
number of channels for Generator

Returns
-------
Expand Down
38 changes: 25 additions & 13 deletions templates/image_classification/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,20 @@ def train_function(

Parameters
----------
- config: config object
- engine: Engine instance
- batch: batch in current iteration
- model: nn.Module model
- loss_fn: nn.Module loss
- optimizer: torch optimizer
- device: device to use for training
config
config object
engine
Engine instance
batch
batch in current iteration
model
nn.Module model
loss_fn
nn.Module loss
optimizer
torch optimizer
device
device to use for training

Returns
-------
Expand Down Expand Up @@ -86,12 +93,16 @@ def evaluate_function(

Parameters
----------
- config: config object
- engine: Engine instance
- batch: batch in current iteration
- model: nn.Module model
- loss_fn: nn.Module loss
- device: device to use for training
config
config object
engine
Engine instance
batch
batch in current iteration
model
nn.Module model
device
device to use for training

Returns
-------
Expand Down Expand Up @@ -129,6 +140,7 @@ def create_trainers(**kwargs) -> Tuple[Engine, Engine]:
**kwargs,
)
)
kwargs.pop('optimizer')
eval_engine = Engine(
lambda e, b: evaluate_function(
engine=e,
Expand Down
40 changes: 27 additions & 13 deletions templates/single/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,20 @@ def train_function(

Parameters
----------
- config: config object
- engine: Engine instance
- batch: batch in current iteration
- model: nn.Module model
- loss_fn: nn.Module loss
- optimizer: torch optimizer
- device: device to use for training
config
config object
engine
Engine instance
batch
batch in current iteration
model
nn.Module model
loss_fn
nn.Module loss
optimizer
torch optimizer
device
device to use for training

Returns
-------
Expand Down Expand Up @@ -87,12 +94,18 @@ def evaluate_function(

Parameters
----------
- config: config object
- engine: Engine instance
- batch: batch in current iteration
- model: nn.Module model
- loss_fn: nn.Module loss
- device: device to use for training
config
config object
engine
Engine instance
batch
batch in current iteration
model
nn.Module model
loss_fn
nn.Module loss
device
device to use for training

Returns
-------
Expand Down Expand Up @@ -133,6 +146,7 @@ def create_trainers(**kwargs) -> Tuple[Engine, Engine]:
**kwargs,
)
)
kwargs.pop('optimizer')
eval_engine = Engine(
lambda e, b: evaluate_function(
engine=e,
Expand Down