Skip to content

Commit 5c71669

Browse files
author
xz
committed
clearer doc and progress bar for reproducing results in the paper (#9)
1 parent 829ca7a commit 5c71669

File tree

3 files changed

+14
-15
lines changed

3 files changed

+14
-15
lines changed

configs/gmpi.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ GMPI:
1616
# ["FFHQ256", "FFHQ512", "FFHQ1024", "AFHQCat", "MetFaces"]
1717
dataset: "FFHQ256"
1818

19+
total_iters: 5001
1920
n_epochs: 3000
2021
sample_interval: 200
2122
output_dir: "debug"

docs/TRAIN_EVAL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ The command to evaluate the trained model is in [eval.sh](../gmpi/eval/eval.sh).
181181
- Depth metrics,
182182
- Pose accuracy metrics.
183183

184-
Run the following command to evalute the model:
184+
In the paper, all results come from checkpoints at 5000 iterations. Run the following command to evalute the model:
185185
```bash
186186
bash ${GMPI_ROOT}/gmpi/eval/eval.sh \
187187
${GMPI_ROOT} \

gmpi/train.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -354,13 +354,13 @@ def train(rank, world_size, config, master_port, run_dataset):
354354

355355
torch.manual_seed(rank)
356356
dataloader = None
357-
total_progress_bar = tqdm(total=opt.n_epochs, desc="Total progress", dynamic_ncols=True)
358-
total_progress_bar.update(discriminator.epoch)
359-
interior_step_bar = tqdm(dynamic_ncols=True)
357+
total_progress_bar = tqdm(total=opt.total_iters, desc="Total progress", dynamic_ncols=True)
358+
total_progress_bar.update(discriminator.step)
360359

361360
for epoch_i in range(opt.n_epochs):
362361

363-
total_progress_bar.update(1)
362+
if discriminator.step > opt.total_iters:
363+
break
364364

365365
metadata = curriculums.extract_metadata(curriculum, discriminator.step)
366366

@@ -402,20 +402,19 @@ def train(rank, world_size, config, master_port, run_dataset):
402402
step_next_upsample = curriculums.next_upsample_step(curriculum, discriminator.step)
403403
step_last_upsample = curriculums.last_upsample_step(curriculum, discriminator.step)
404404

405-
interior_step_bar.reset(total=(step_next_upsample - step_last_upsample))
406-
interior_step_bar.set_description(f"Progress to next stage")
407-
interior_step_bar.update((discriminator.step - step_last_upsample))
408-
409-
if rank == 0:
410-
logger.info(f"\nstep_next_upsample: {step_next_upsample}, {step_next_upsample}\n")
411-
412405
logger.info(f"New epoch {epoch_i}.\n\n")
413406

414407
# NOTE: this is requred to make distributed sampler shuffle dataset.
415408
data_sampler.set_epoch(epoch_i)
416409

417410
for batch_i, batch_data in enumerate(dataloader):
418411

412+
if discriminator.step > opt.total_iters:
413+
break
414+
415+
if rank == 0:
416+
total_progress_bar.update(1)
417+
419418
# pred_yaws_real/pred_pitches_real: [B, 1]
420419
imgs, flat_w2c_mats_real, _, pred_yaws_real, pred_pitches_real = batch_data
421420
# NOTE: we only condition on rotation
@@ -786,11 +785,10 @@ def train(rank, world_size, config, master_port, run_dataset):
786785
ema2.update(generator_ddp.parameters())
787786

788787
if rank == 0:
789-
interior_step_bar.update(1)
790788
if discriminator.step % 10 == 0:
791789
tqdm.write(
792790
f"[Experiment: {opt.output_dir}] "
793-
f"[Epoch: {discriminator.epoch}/{opt.n_epochs}] "
791+
# f"[Epoch: {discriminator.epoch}/{opt.n_epochs}] "
794792
f"[D loss: {d_loss.item()}] [G loss: {g_loss.item()}] "
795793
f"[Step: {discriminator.step}] "
796794
f"[Img Size: {metadata['img_size']}] [Batch Size: {metadata['batch_size']}] "
@@ -1008,7 +1006,7 @@ def train(rank, world_size, config, master_port, run_dataset):
10081006

10091007
# fmt: on
10101008

1011-
if opt.eval_freq > 0 and (discriminator.step + 1) % opt.eval_freq == 0:
1009+
if (opt.eval_freq > 0) and (discriminator.step > 0) and (discriminator.step % opt.eval_freq == 0):
10121010
generated_dir = os.path.join("./evaluation/generated")
10131011

10141012
if rank == 0:

0 commit comments

Comments
 (0)