@@ -303,9 +303,9 @@ def get_config(self, gindex, pindex, group):
303303 config ["eps" ] = group ["eps" ]
304304 config ["weight_decay" ] = group ["weight_decay" ]
305305 config ["lr" ] = group ["lr" ]
306- config ["alpha" ] = group .get ("alpha" )
307- config ["t_alpha" ] = group .get ("t_alpha" )
308- config ["t_beta3" ] = group .get ("t_beta3" )
306+ config ["alpha" ] = group .get ("alpha" , 0.0 )
307+ config ["t_alpha" ] = group .get ("t_alpha" , 0 )
308+ config ["t_beta3" ] = group .get ("t_beta3" , 0 )
309309 config ["optim_bits" ] = self .args .optim_bits
310310 config ["min_8bit_size" ] = self .args .min_8bit_size
311311 config ["percentile_clipping" ] = self .args .percentile_clipping
@@ -530,7 +530,7 @@ def update_step(self, group, p, gindex, pindex):
530530 state ["state2" ],
531531 config ["betas" ][1 ],
532532 config ["betas" ][2 ] if len (config ["betas" ]) >= 3 else 0.0 ,
533- config [ "alpha" ] ,
533+ config . get ( "alpha" , 0.0 ) ,
534534 config ["weight_decay" ],
535535 gnorm_scale ,
536536 state ["unorm_vec" ] if config ["max_unorm" ] > 0.0 else None ,
@@ -575,7 +575,7 @@ def update_step(self, group, p, gindex, pindex):
575575 config ["betas" ][0 ],
576576 config ["betas" ][1 ],
577577 config ["betas" ][2 ] if len (config ["betas" ]) >= 3 else 0.0 ,
578- config [ "alpha" ] ,
578+ config . get ( "alpha" , 0.0 ) ,
579579 config ["eps" ],
580580 step ,
581581 config ["lr" ],
0 commit comments