Commit 8ded27a

Update code comments
1 parent 9f3de0b commit 8ded27a

2 files changed: 9 additions & 3 deletions

README.md

Lines changed: 6 additions & 0 deletions
@@ -77,6 +77,12 @@ else:
 ```
 #### Step 3: Apply both diffusion training loss and reward loss:
 ```python
+# reward model inference
+if args.task_name == 'canny':
+    outputs = reward_model(image.to(accelerator.device), low_threshold, high_threshold)
+else:
+    outputs = reward_model(image.to(accelerator.device))
+
 # Determine which samples in the current batch need to calculate reward loss
 timestep_mask = (args.min_timestep_rewarding <= timesteps.reshape(-1, 1)) & (timesteps.reshape(-1, 1) <= args.max_timestep_rewarding)
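
For orientation, a minimal sketch of how Step 3 might sum the two objectives; `combine_losses`, `target`, and `reward_weight` are illustrative names, and the actual loss-combination code is not part of this commit:

```python
import torch
import torch.nn.functional as F

def combine_losses(model_pred: torch.Tensor,
                   target: torch.Tensor,
                   reward_loss: torch.Tensor,
                   reward_weight: float = 1.0) -> torch.Tensor:
    # Standard diffusion (noise-prediction) loss over the whole batch
    diffusion_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    # Add the reward loss, scaled by an assumed weighting coefficient
    return diffusion_loss + reward_weight * reward_loss
```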

train/reward_control.py

Lines changed: 3 additions & 3 deletions
@@ -1428,14 +1428,14 @@ def load_model_hook(models, input_dir):
 """
 Rewarding ControlNet
 """
-# compute the original image
+# Predict the single-step denoised latents
 pred_original_sample = [
     noise_scheduler.step(noise, t, noisy_latent).pred_original_sample.to(weight_dtype) \
     for (noise, t, noisy_latent) in zip(model_pred, timesteps, noisy_latents)
 ]
 pred_original_sample = torch.stack(pred_original_sample)

-# compute the original image
+# Map the denoised latents into RGB images
 pred_original_sample = 1 / vae.config.scaling_factor * pred_original_sample
 image = vae.decode(pred_original_sample.to(weight_dtype)).sample
 image = (image / 2 + 0.5).clamp(0, 1)
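
For reference, `pred_original_sample` returned by a diffusers scheduler's `step()` call is the one-step estimate of the clean latent; for an epsilon-prediction model it reduces to the closed form sketched below (an illustration, not the repository's code):

```python
import torch

def predict_original_sample(noise_pred: torch.Tensor,
                            noisy_latent: torch.Tensor,
                            alpha_prod_t: torch.Tensor) -> torch.Tensor:
    # x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t)
    return (noisy_latent - (1.0 - alpha_prod_t) ** 0.5 * noise_pred) / alpha_prod_t ** 0.5
```
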
@@ -1500,7 +1500,7 @@ def load_model_hook(models, input_dir):

 labels = [x.to(accelerator.device) for x in labels] if isinstance(labels, list) else labels.to(accelerator.device)

-# timestep-based filtering
+# Determine which samples in the current batch need to calculate reward loss
 timestep_mask = (args.min_timestep_rewarding <= timesteps.reshape(-1, 1)) & (timesteps.reshape(-1, 1) <= args.max_timestep_rewarding)

 # calculate the reward loss
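
The `# calculate the reward loss` step itself is outside this diff; a minimal sketch of how `timestep_mask` could gate a per-sample reward loss, assuming an MSE-style objective between the reward-model `outputs` and `labels` (the actual per-task loss is not shown here):

```python
import torch
import torch.nn.functional as F

def masked_reward_loss(outputs: torch.Tensor,
                       labels: torch.Tensor,
                       timestep_mask: torch.Tensor) -> torch.Tensor:
    # Per-sample loss, averaged over all non-batch dimensions
    per_sample = F.mse_loss(outputs.float(), labels.float(), reduction="none")
    per_sample = per_sample.reshape(per_sample.shape[0], -1).mean(dim=1)
    # Keep only samples whose timestep falls inside the rewarding window
    mask = timestep_mask.reshape(-1).float()
    return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)
```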
