Skip to content

Commit 368211d

Browse files
authored
Merge pull request #805 from Separius/patch-1
Remove duplicate code in create_scheduler
2 parents 3cdaf5e + abf3e04 commit 368211d

File tree

1 file changed

+11
-20
lines changed

1 file changed

+11
-20
lines changed

timm/scheduler/scheduler_factory.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ def create_scheduler(args, optimizer):
2121
noise_range = lr_noise * num_epochs
2222
else:
2323
noise_range = None
24+
noise_args = dict(
25+
noise_range_t=noise_range,
26+
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
27+
noise_std=getattr(args, 'lr_noise_std', 1.),
28+
noise_seed=getattr(args, 'seed', 42),
29+
)
2430

2531
lr_scheduler = None
2632
if args.sched == 'cosine':
@@ -34,10 +40,7 @@ def create_scheduler(args, optimizer):
3440
warmup_t=args.warmup_epochs,
3541
cycle_limit=getattr(args, 'lr_cycle_limit', 1),
3642
t_in_epochs=True,
37-
noise_range_t=noise_range,
38-
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
39-
noise_std=getattr(args, 'lr_noise_std', 1.),
40-
noise_seed=getattr(args, 'seed', 42),
43+
**noise_args,
4144
)
4245
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
4346
elif args.sched == 'tanh':
@@ -50,10 +53,7 @@ def create_scheduler(args, optimizer):
5053
warmup_t=args.warmup_epochs,
5154
cycle_limit=getattr(args, 'lr_cycle_limit', 1),
5255
t_in_epochs=True,
53-
noise_range_t=noise_range,
54-
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
55-
noise_std=getattr(args, 'lr_noise_std', 1.),
56-
noise_seed=getattr(args, 'seed', 42),
56+
**noise_args,
5757
)
5858
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
5959
elif args.sched == 'step':
@@ -63,10 +63,7 @@ def create_scheduler(args, optimizer):
6363
decay_rate=args.decay_rate,
6464
warmup_lr_init=args.warmup_lr,
6565
warmup_t=args.warmup_epochs,
66-
noise_range_t=noise_range,
67-
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
68-
noise_std=getattr(args, 'lr_noise_std', 1.),
69-
noise_seed=getattr(args, 'seed', 42),
66+
**noise_args,
7067
)
7168
elif args.sched == 'multistep':
7269
lr_scheduler = MultiStepLRScheduler(
@@ -75,10 +72,7 @@ def create_scheduler(args, optimizer):
7572
decay_rate=args.decay_rate,
7673
warmup_lr_init=args.warmup_lr,
7774
warmup_t=args.warmup_epochs,
78-
noise_range_t=noise_range,
79-
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
80-
noise_std=getattr(args, 'lr_noise_std', 1.),
81-
noise_seed=getattr(args, 'seed', 42),
75+
**noise_args,
8276
)
8377
elif args.sched == 'plateau':
8478
mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
@@ -91,10 +85,7 @@ def create_scheduler(args, optimizer):
9185
warmup_lr_init=args.warmup_lr,
9286
warmup_t=args.warmup_epochs,
9387
cooldown_t=0,
94-
noise_range_t=noise_range,
95-
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
96-
noise_std=getattr(args, 'lr_noise_std', 1.),
97-
noise_seed=getattr(args, 'seed', 42),
88+
**noise_args,
9889
)
9990

10091
return lr_scheduler, num_epochs

0 commit comments

Comments
 (0)