Skip to content

Commit c8e46d1

Browse files
authored
[src] Fix nnet1 proj-lstm bug where gradient clipping not used; thx:@cbtpkzm (#2696)
1 parent 7741f7c commit c8e46d1

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

src/nnet/nnet-lstm-projected.h

Lines changed: 15 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -655,21 +655,21 @@ class LstmProjected : public MultistreamComponent {
655655
const CuMatrixBase<BaseFloat> &diff) {
656656

657657
// apply the gradient clipping,
658-
if (clip_gradient_ > 0.0) {
659-
w_gifo_x_corr_.ApplyFloor(-clip_gradient_);
660-
w_gifo_x_corr_.ApplyCeiling(clip_gradient_);
661-
w_gifo_r_corr_.ApplyFloor(-clip_gradient_);
662-
w_gifo_r_corr_.ApplyCeiling(clip_gradient_);
663-
bias_corr_.ApplyFloor(-clip_gradient_);
664-
bias_corr_.ApplyCeiling(clip_gradient_);
665-
w_r_m_corr_.ApplyFloor(-clip_gradient_);
666-
w_r_m_corr_.ApplyCeiling(clip_gradient_);
667-
peephole_i_c_corr_.ApplyFloor(-clip_gradient_);
668-
peephole_i_c_corr_.ApplyCeiling(clip_gradient_);
669-
peephole_f_c_corr_.ApplyFloor(-clip_gradient_);
670-
peephole_f_c_corr_.ApplyCeiling(clip_gradient_);
671-
peephole_o_c_corr_.ApplyFloor(-clip_gradient_);
672-
peephole_o_c_corr_.ApplyCeiling(clip_gradient_);
658+
if (grad_clip_ > 0.0) {
659+
w_gifo_x_corr_.ApplyFloor(-grad_clip_);
660+
w_gifo_x_corr_.ApplyCeiling(grad_clip_);
661+
w_gifo_r_corr_.ApplyFloor(-grad_clip_);
662+
w_gifo_r_corr_.ApplyCeiling(grad_clip_);
663+
bias_corr_.ApplyFloor(-grad_clip_);
664+
bias_corr_.ApplyCeiling(grad_clip_);
665+
w_r_m_corr_.ApplyFloor(-grad_clip_);
666+
w_r_m_corr_.ApplyCeiling(grad_clip_);
667+
peephole_i_c_corr_.ApplyFloor(-grad_clip_);
668+
peephole_i_c_corr_.ApplyCeiling(grad_clip_);
669+
peephole_f_c_corr_.ApplyFloor(-grad_clip_);
670+
peephole_f_c_corr_.ApplyCeiling(grad_clip_);
671+
peephole_o_c_corr_.ApplyFloor(-grad_clip_);
672+
peephole_o_c_corr_.ApplyCeiling(grad_clip_);
673673
}
674674

675675
const BaseFloat lr = opts_.learn_rate;
@@ -698,9 +698,6 @@ class LstmProjected : public MultistreamComponent {
698698
// buffer for transferring state across batches,
699699
CuMatrix<BaseFloat> prev_nnet_state_;
700700

701-
// gradient-clipping value,
702-
BaseFloat clip_gradient_;
703-
704701
// feed-forward connections: from x to [g, i, f, o]
705702
CuMatrix<BaseFloat> w_gifo_x_;
706703
CuMatrix<BaseFloat> w_gifo_x_corr_;

0 commit comments

Comments
 (0)