|
5 | 5 | import logging
|
6 | 6 | import core.logger as Logger
|
7 | 7 | import core.metrics as Metrics
|
| 8 | +from core.wandb_logger import WandbLogger |
8 | 9 | from tensorboardX import SummaryWriter
|
9 | 10 | import os
|
10 | 11 | import numpy as np
|
| 12 | +import wandb |
11 | 13 |
|
12 | 14 | if __name__ == "__main__":
|
13 | 15 | parser = argparse.ArgumentParser()
|
|
17 | 19 | help='Run either train(training) or val(generation)', default='train')
|
18 | 20 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
|
19 | 21 | parser.add_argument('-debug', '-d', action='store_true')
|
| 22 | + parser.add_argument('-enable_wandb', action='store_true') |
| 23 | + parser.add_argument('-log_wandb_ckpt', action='store_true') |
| 24 | + parser.add_argument('-log_eval', action='store_true') |
20 | 25 |
|
21 | 26 | # parse configs
|
22 | 27 | args = parser.parse_args()
|
|
35 | 40 | logger.info(Logger.dict2str(opt))
|
36 | 41 | tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])
|
37 | 42 |
|
| 43 | + # Initialize WandbLogger |
| 44 | + if opt['enable_wandb']: |
| 45 | + wandb_logger = WandbLogger(opt) |
| 46 | + # wandb.define_metric('validation/val_step') |
| 47 | + # wandb.define_metric('epoch') |
| 48 | + # wandb.define_metric("validation/*", step_metric="val_step") |
| 49 | + val_step = 0 |
| 50 | + else: |
| 51 | + wandb_logger = None |
| 52 | + |
38 | 53 | # dataset
|
39 | 54 | for phase, dataset_opt in opt['datasets'].items():
|
40 | 55 | if phase == 'train' and args.phase != 'val':
|
|
78 | 93 | tb_logger.add_scalar(k, v, current_step)
|
79 | 94 | logger.info(message)
|
80 | 95 |
|
| 96 | + if wandb_logger: |
| 97 | + wandb_logger.log_metrics(logs) |
| 98 | + |
81 | 99 | # validation
|
82 | 100 | if current_step % opt['train']['val_freq'] == 0:
|
83 | 101 | result_path = '{}/{}'.format(opt['path']
|
|
100 | 118 | 'Iter_{}'.format(current_step),
|
101 | 119 | np.transpose(sample_img, [2, 0, 1]),
|
102 | 120 | idx)
|
| 121 | + |
| 122 | + if wandb_logger: |
| 123 | + wandb_logger.log_image(f'validation_{idx}', sample_img) |
| 124 | + |
103 | 125 | diffusion.set_new_noise_schedule(
|
104 | 126 | opt['model']['beta_schedule']['train'], schedule_phase='train')
|
105 | 127 |
|
106 | 128 | if current_step % opt['train']['save_checkpoint_freq'] == 0:
|
107 | 129 | logger.info('Saving models and training states.')
|
108 | 130 | diffusion.save_network(current_epoch, current_step)
|
| 131 | + |
| 132 | + if wandb_logger and opt['log_wandb_ckpt']: |
| 133 | + wandb_logger.log_checkpoint(current_epoch, current_step) |
| 134 | + |
109 | 135 | # save model
|
110 | 136 | logger.info('End of training.')
|
111 | 137 | else:
|
|
0 commit comments