@@ -354,13 +354,13 @@ def train(rank, world_size, config, master_port, run_dataset):
354354
355355 torch .manual_seed (rank )
356356 dataloader = None
357- total_progress_bar = tqdm (total = opt .n_epochs , desc = "Total progress" , dynamic_ncols = True )
358- total_progress_bar .update (discriminator .epoch )
359- interior_step_bar = tqdm (dynamic_ncols = True )
357+ total_progress_bar = tqdm (total = opt .total_iters , desc = "Total progress" , dynamic_ncols = True )
358+ total_progress_bar .update (discriminator .step )
360359
361360 for epoch_i in range (opt .n_epochs ):
362361
363- total_progress_bar .update (1 )
362+ if discriminator .step > opt .total_iters :
363+ break
364364
365365 metadata = curriculums .extract_metadata (curriculum , discriminator .step )
366366
@@ -402,20 +402,19 @@ def train(rank, world_size, config, master_port, run_dataset):
402402 step_next_upsample = curriculums .next_upsample_step (curriculum , discriminator .step )
403403 step_last_upsample = curriculums .last_upsample_step (curriculum , discriminator .step )
404404
405- interior_step_bar .reset (total = (step_next_upsample - step_last_upsample ))
406- interior_step_bar .set_description (f"Progress to next stage" )
407- interior_step_bar .update ((discriminator .step - step_last_upsample ))
408-
409- if rank == 0 :
410- logger .info (f"\n step_next_upsample: { step_next_upsample } , { step_next_upsample } \n " )
411-
412405 logger .info (f"New epoch { epoch_i } .\n \n " )
413406
414407 # NOTE: this is requred to make distributed sampler shuffle dataset.
415408 data_sampler .set_epoch (epoch_i )
416409
417410 for batch_i , batch_data in enumerate (dataloader ):
418411
412+ if discriminator .step > opt .total_iters :
413+ break
414+
415+ if rank == 0 :
416+ total_progress_bar .update (1 )
417+
419418 # pred_yaws_real/pred_pitches_real: [B, 1]
420419 imgs , flat_w2c_mats_real , _ , pred_yaws_real , pred_pitches_real = batch_data
421420 # NOTE: we only condition on rotation
@@ -786,11 +785,10 @@ def train(rank, world_size, config, master_port, run_dataset):
786785 ema2 .update (generator_ddp .parameters ())
787786
788787 if rank == 0 :
789- interior_step_bar .update (1 )
790788 if discriminator .step % 10 == 0 :
791789 tqdm .write (
792790 f"[Experiment: { opt .output_dir } ] "
793- f"[Epoch: { discriminator .epoch } /{ opt .n_epochs } ] "
791+ # f"[Epoch: {discriminator.epoch}/{opt.n_epochs}] "
794792 f"[D loss: { d_loss .item ()} ] [G loss: { g_loss .item ()} ] "
795793 f"[Step: { discriminator .step } ] "
796794 f"[Img Size: { metadata ['img_size' ]} ] [Batch Size: { metadata ['batch_size' ]} ] "
@@ -1008,7 +1006,7 @@ def train(rank, world_size, config, master_port, run_dataset):
10081006
10091007 # fmt: on
10101008
1011- if opt .eval_freq > 0 and (discriminator .step + 1 ) % opt .eval_freq == 0 :
1009+ if ( opt .eval_freq > 0 ) and (discriminator .step > 0 ) and ( discriminator . step % opt .eval_freq == 0 ) :
10121010 generated_dir = os .path .join ("./evaluation/generated" )
10131011
10141012 if rank == 0 :
0 commit comments