Skip to content

Commit e62c21f

Browse files
committed
wandb integrate sample.py
1 parent 27a02f2 commit e62c21f

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

config/sr_sr3_16_128.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"name": "FFHQ",
1818
"mode": "HR", // whether need LR img
1919
"dataroot": "dataset/ffhq_16_128",
20-
"datatype": "img", //lmdb or img, path of img files
20+
"datatype": "lmdb", //lmdb or img, path of img files
2121
"l_resolution": 16, // low resolution need to super_resolution
2222
"r_resolution": 128, // high resolution
2323
"batch_size": 4,
@@ -28,8 +28,8 @@
2828
"val": {
2929
"name": "CelebaHQ",
3030
"mode": "LRHR",
31-
"dataroot": "dataset/ffhq_16_128",
32-
"datatype": "img", //lmdb or img, path of img files
31+
"dataroot": "dataset/celebahq_16_128",
32+
"datatype": "lmdb", //lmdb or img, path of img files
3333
"l_resolution": 16,
3434
"r_resolution": 128,
3535
"data_len": 50 // data length in validation

sample.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import logging
66
import core.logger as Logger
77
import core.metrics as Metrics
8+
from core.wandb_logger import WandbLogger
89
from tensorboardX import SummaryWriter
910
import os
1011
import numpy as np
12+
import wandb
1113

1214
if __name__ == "__main__":
1315
parser = argparse.ArgumentParser()
@@ -17,6 +19,9 @@
1719
help='Run either train(training) or val(generation)', default='train')
1820
parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
1921
parser.add_argument('-debug', '-d', action='store_true')
22+
parser.add_argument('-enable_wandb', action='store_true')
23+
parser.add_argument('-log_wandb_ckpt', action='store_true')
24+
parser.add_argument('-log_eval', action='store_true')
2025

2126
# parse configs
2227
args = parser.parse_args()
@@ -35,6 +40,16 @@
3540
logger.info(Logger.dict2str(opt))
3641
tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])
3742

43+
# Initialize WandbLogger
44+
if opt['enable_wandb']:
45+
wandb_logger = WandbLogger(opt)
46+
# wandb.define_metric('validation/val_step')
47+
# wandb.define_metric('epoch')
48+
# wandb.define_metric("validation/*", step_metric="val_step")
49+
val_step = 0
50+
else:
51+
wandb_logger = None
52+
3853
# dataset
3954
for phase, dataset_opt in opt['datasets'].items():
4055
if phase == 'train' and args.phase != 'val':
@@ -78,6 +93,9 @@
7893
tb_logger.add_scalar(k, v, current_step)
7994
logger.info(message)
8095

96+
if wandb_logger:
97+
wandb_logger.log_metrics(logs)
98+
8199
# validation
82100
if current_step % opt['train']['val_freq'] == 0:
83101
result_path = '{}/{}'.format(opt['path']
@@ -100,12 +118,20 @@
100118
'Iter_{}'.format(current_step),
101119
np.transpose(sample_img, [2, 0, 1]),
102120
idx)
121+
122+
if wandb_logger:
123+
wandb_logger.log_image(f'validation_{idx}', sample_img)
124+
103125
diffusion.set_new_noise_schedule(
104126
opt['model']['beta_schedule']['train'], schedule_phase='train')
105127

106128
if current_step % opt['train']['save_checkpoint_freq'] == 0:
107129
logger.info('Saving models and training states.')
108130
diffusion.save_network(current_epoch, current_step)
131+
132+
if wandb_logger and opt['log_wandb_ckpt']:
133+
wandb_logger.log_checkpoint(current_epoch, current_step)
134+
109135
# save model
110136
logger.info('End of training.')
111137
else:

0 commit comments

Comments
 (0)