@@ -245,7 +245,7 @@ void Randomize(const CuMatrixBase<double> &src,
245245template <typename Real>
246246void NormalizePerRow (const CuMatrixBase<Real>& in, const Real target_rms,
247247 const bool add_log_stddev, CuMatrixBase<Real>* out) {
248- const Real kSquaredNormFloor = 1.35525271560688e -20 ; // 2^-66
248+ const Real kSquaredNormFloor = 1.3552527156068805425e -20 ; // 2^-66
249249 if (add_log_stddev) {
250250 KALDI_ASSERT (in.NumRows () == out->NumRows ());
251251 KALDI_ASSERT (in.NumCols () + 1 == out->NumCols ());
@@ -291,6 +291,100 @@ void NormalizePerRow(const CuMatrixBase<double>& in, const double target_rms,
291291 const bool add_log_stddev, CuMatrixBase<double >* out);
292292
293293
294+ // A note on the derivative of NormalizeComponent...
295+ // let both row_in and row_out be vectors of dimension D.
296+ // Let p = row_in^T row_in / (D * target_rms^2), and let
297+ // f = 1.0 / sqrt(max(kSquaredNormFloor, p)), and we compute row_out as:
298+ // row_out = f row_in.
299+ // Suppose we have a quantity deriv_out which is the derivative
300+ // of the objective function w.r.t. row_out. We want to compute
301+ // deriv_in which is the derivative of the objective function w.r.t.
302+ // row_in. Let the objective function be F. One term is obvious: we have
303+ // deriv_in = f deriv_out + ....
304+ // next we have to take into account the derivative that gets back-propagated
305+ // through f. Obviously, dF/df = deriv_out^T row_in.
306+ // And df/dp = (p <= kSquaredNormFloor ? 0.0 : -0.5 p^{-1.5}) = (f == 1.0 / sqrt(kSquaredNormFloor) ? 0.0 : -0.5 f^3),
307+ // and dp/d(row_in) = 2/(D * target_rms^2) row_in. [it's vector_valued].
308+ // So this term in dF/d(row_in) equals:
309+ // dF/df df/dp dp/d(row_in) = 2/(D * target_rms^2) (f == 1.0 / sqrt(kSquaredNormFloor) ? 0.0 : -0.5 f^3) (deriv_out^T row_in) row_in
310+ // So
311+ // deriv_in = f deriv_out + (f == 1.0 ? 0.0 : -f^3 / (D * target_rms^2) ) (deriv_out^T row_in) row_in
312+ // if add_log_stddev_ true, the deriv_in has another term as
313+ // dF/dx_i = dF/df . df/dx_i => df/dx_i = x_i/(x^T x)
314+ template <typename Real>
315+ void DiffNormalizePerRow (const CuMatrixBase<Real> &in_value,
316+ const CuMatrixBase<Real> &out_deriv,
317+ const Real target_rms, const bool add_log_stddev,
318+ CuMatrixBase<Real>* in_deriv) {
319+ const Real kSquaredNormFloor = 1.3552527156068805425e-20 ; // 2^-66
320+ #if HAVE_CUDA == 1
321+ if (CuDevice::Instantiate ().Enabled ()) {
322+ Timer tim;
323+ size_t dimBlock = CU1DBLOCK;
324+ size_t dimGrid = in_deriv->NumRows ();
325+ cuda_diff_normalize_per_row (dimGrid, dimBlock, in_deriv->Data (),
326+ in_deriv->Stride (), in_value.Data (),
327+ in_value.Dim (), out_deriv.Data (),
328+ out_deriv.Stride (), target_rms, add_log_stddev);
329+ CU_SAFE_CALL (cudaGetLastError ());
330+ CuDevice::Instantiate ().AccuProfile (__func__, tim.Elapsed ());
331+ } else
332+ #endif
333+ {
334+ const CuSubMatrix<Real> out_deriv_no_log (out_deriv, 0 , out_deriv.NumRows (),
335+ 0 , in_value.NumCols ());
336+ CuVector<Real> dot_products (out_deriv.NumRows ());
337+ dot_products.AddDiagMatMat (1.0 , out_deriv_no_log, kNoTrans , in_value,
338+ kTrans , 0.0 );
339+ CuVector<Real> in_norm (in_value.NumRows ());
340+ Real d_scaled = (in_value.NumCols () * target_rms * target_rms);
341+ in_norm.AddDiagMat2 (1.0 , in_value, kNoTrans , 0.0 );
342+
343+ if (add_log_stddev) {
344+ CuVector<Real> log_stddev_deriv (in_norm), // log_stddev deriv as dF/dy .* (x^T x)^-1
345+ out_deriv_for_stddev (out_deriv.NumRows (), kUndefined );
346+ // f = log(sqrt(max(epsi, x^T x / D)))
347+ // df/dx = epsi^2 * D < x^T x ? (1/(x^T x)) * x : 0.
348+ // we don't compute this exactly below for the case when x^2 x is very
349+ // small, but we do make sure that the deriv isn't infinity when the input
350+ // is zero.
351+ log_stddev_deriv.ApplyFloor (in_value.NumCols () * kSquaredNormFloor );
352+ log_stddev_deriv.ApplyPow (-1.0 );
353+ out_deriv_for_stddev.CopyColFromMat (out_deriv, (out_deriv.NumCols () - 1 ));
354+ log_stddev_deriv.MulElements (out_deriv_for_stddev);
355+ if (in_deriv)
356+ in_deriv->AddDiagVecMat (1.0 , log_stddev_deriv, in_value, kNoTrans , 1.0 );
357+ }
358+ in_norm.Scale (1.0 / d_scaled);
359+ in_norm.ApplyFloor (kSquaredNormFloor );
360+ in_norm.ApplyPow (-0.5 );
361+ if (in_deriv) {
362+ if (in_deriv->Data () != out_deriv_no_log.Data ())
363+ in_deriv->AddDiagVecMat (1.0 , in_norm, out_deriv_no_log, kNoTrans , 1.0 );
364+ else
365+ in_deriv->MulRowsVec (in_norm);
366+ in_norm.ReplaceValue (1.0 / sqrt (kSquaredNormFloor ), 0.0 );
367+ in_norm.ApplyPow (3.0 );
368+ dot_products.MulElements (in_norm);
369+
370+ in_deriv->AddDiagVecMat (-1.0 / d_scaled, dot_products, in_value, kNoTrans ,
371+ 1.0 );
372+ }
373+ }
374+ }
375+
376+ template
377+ void DiffNormalizePerRow (const CuMatrixBase<float > &in_value,
378+ const CuMatrixBase<float > &out_deriv,
379+ const float target_rms, const bool add_log_stddev,
380+ CuMatrixBase<float >* in_deriv);
381+ template
382+ void DiffNormalizePerRow (const CuMatrixBase<double > &in_value,
383+ const CuMatrixBase<double > &out_deriv,
384+ const double target_rms, const bool add_log_stddev,
385+ CuMatrixBase<double >* in_deriv);
386+
387+
294388// not calling this Sigmoid to reduce the chance of future collisions.
295389template <typename Real>
296390static inline Real ScalarSigmoid (Real a) {
0 commit comments