Skip to content

Commit d5a9d5c

Browse files
kangshiyin authored and danpovey committed
[src] add CUDA kernel for backprop of NormalizeLayer (#1458)
1 parent f6b011f commit d5a9d5c

File tree

8 files changed

+381
-85
lines changed

8 files changed

+381
-85
lines changed

src/cudamatrix/cu-kernels-ansi.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,17 @@ void cudaD_copy_cols_from_vec(dim3 Gr, dim3 Bl, double *mat_out,
705705
void cudaF_copy_cols_from_vec(dim3 Gr, dim3 Bl, float *mat_out, MatrixDim d_out,
706706
const float *v_in);
707707

708+
void cudaF_diff_normalize_per_row(size_t Gr, size_t Bl, float *id,
709+
int id_stride, const float *iv,
710+
MatrixDim iv_dim, const float* od,
711+
int od_stride, float target_rms,
712+
bool add_log_stddev);
713+
void cudaD_diff_normalize_per_row(size_t Gr, size_t Bl, double *id,
714+
int id_stride, const double *iv,
715+
MatrixDim iv_dim, const double* od,
716+
int od_stride, double target_rms,
717+
bool add_log_stddev);
718+
708719
} // extern "C"
709720

710721
#endif // HAVE_CUDA

src/cudamatrix/cu-kernels.cu

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2292,7 +2292,7 @@ static void _normalize_per_row(Real *y, int y_stride, const Real *x,
22922292
}
22932293
}
22942294

2295-
const Real kSquaredNormFloor = 1.35525271560688e-20; // 2^-66
2295+
const Real kSquaredNormFloor = 1.3552527156068805425e-20; // 2^-66
22962296
if (tid == 0) {
22972297
ssum[0] = sqrt(
22982298
fmax(ssum[0] / (target_rms * target_rms * x_d.cols), kSquaredNormFloor));
@@ -2315,6 +2315,87 @@ static void _normalize_per_row(Real *y, int y_stride, const Real *x,
23152315
}
23162316

23172317

2318+
template<typename Real>
2319+
__global__
2320+
static void _diff_normalize_per_row(Real *id, int id_stride, const Real *iv,
2321+
MatrixDim iv_dim, const Real* od,
2322+
int od_stride, Real target_rms,
2323+
bool add_log_stddev) {
2324+
2325+
const Real kSquaredNormFloor = 1.3552527156068805425e-20; // 2^-66
2326+
const Real kInvNormFloor = 8589934592.0;
2327+
2328+
const int tid = threadIdx.x;
2329+
const int i = blockIdx.x;
2330+
const Real* iv_row = iv + i * iv_dim.stride;
2331+
const Real* od_row = od + i * od_stride;
2332+
2333+
// reduce to CU1DBLOCK elements per row
2334+
Real dot_products = Real(0);
2335+
Real in_norm = Real(0);
2336+
for (int j = tid; j < iv_dim.cols; j += CU1DBLOCK) {
2337+
const Real iv_ij = iv_row[j];
2338+
dot_products += iv_ij * od_row[j];
2339+
in_norm += iv_ij * iv_ij;
2340+
}
2341+
__shared__ Real sprod[CU1DBLOCK];
2342+
__shared__ Real snorm[CU1DBLOCK];
2343+
sprod[tid] = dot_products;
2344+
snorm[tid] = in_norm;
2345+
__syncthreads();
2346+
2347+
// reduce to 2x warpSize elements per row
2348+
# pragma unroll
2349+
for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) {
2350+
if (tid < shift) {
2351+
sprod[tid] += sprod[tid + shift];
2352+
snorm[tid] += snorm[tid + shift];
2353+
}
2354+
__syncthreads();
2355+
}
2356+
2357+
// reduce to 1 element per row
2358+
if (tid < warpSize) {
2359+
# pragma unroll
2360+
for (int shift = warpSize; shift > 0; shift >>= 1) {
2361+
sprod[tid] += sprod[tid + shift];
2362+
snorm[tid] += snorm[tid + shift];
2363+
}
2364+
}
2365+
2366+
// broadcast the sum results
2367+
__syncthreads();
2368+
dot_products = sprod[0];
2369+
in_norm = snorm[0];
2370+
2371+
Real log_stddev_deriv;
2372+
if (add_log_stddev) {
2373+
log_stddev_deriv = Real(1) / max(in_norm, iv_dim.cols * kSquaredNormFloor)
2374+
* od_row[iv_dim.cols];
2375+
}
2376+
2377+
const Real inv_d_scaled = Real(1) / (iv_dim.cols * target_rms * target_rms);
2378+
in_norm = Real(1) / sqrt(max(in_norm * inv_d_scaled, kSquaredNormFloor));
2379+
2380+
const Real f = in_norm == kInvNormFloor ? Real(0) : in_norm;
2381+
dot_products *= f * f * f * inv_d_scaled;
2382+
2383+
for (int j = tid; j < iv_dim.cols; j += CU1DBLOCK) {
2384+
const Real iv_ij = iv_row[j];
2385+
Real id_ij = id[i * id_stride + j];
2386+
if (add_log_stddev) {
2387+
id_ij += log_stddev_deriv * iv_ij;
2388+
}
2389+
if (id != od) {
2390+
id_ij += in_norm * od_row[j];
2391+
} else {
2392+
id_ij *= in_norm;
2393+
}
2394+
id_ij -= dot_products * iv_ij;
2395+
id[i * id_stride + j] = id_ij;
2396+
}
2397+
}
2398+
23182399
// Per-row log-softmax operation on 'x', with writing to 'y'.
23192400
// note, x and y may point to the same memory. This is equivalent to setting
23202401
// matrix y to matrix x and then, for each row of y, subtracting the offset that
@@ -4690,3 +4771,20 @@ void cudaF_copy_cols_from_vec(dim3 Gr, dim3 Bl, float *mat_out, MatrixDim d_out,
46904771
const float *v_in) {
46914772
_copy_cols_from_vec<<<Gr, Bl>>>(mat_out, d_out, v_in);
46924773
}
4774+
4775+
// Single-precision launcher for _diff_normalize_per_row.  Gr and Bl are
// forwarded unchanged as the grid and block dimensions of the launch.
void cudaF_diff_normalize_per_row(size_t Gr, size_t Bl, float *id,
                                  int id_stride, const float *iv,
                                  MatrixDim iv_dim, const float* od,
                                  int od_stride, float target_rms,
                                  bool add_log_stddev) {
  _diff_normalize_per_row<float><<<Gr, Bl>>>(
      id, id_stride, iv, iv_dim, od, od_stride, target_rms, add_log_stddev);
}
4783+
// Double-precision launcher for _diff_normalize_per_row.  Gr and Bl are
// forwarded unchanged as the grid and block dimensions of the launch.
void cudaD_diff_normalize_per_row(size_t Gr, size_t Bl, double *id,
                                  int id_stride, const double *iv,
                                  MatrixDim iv_dim, const double* od,
                                  int od_stride, double target_rms,
                                  bool add_log_stddev) {
  _diff_normalize_per_row<double><<<Gr, Bl>>>(
      id, id_stride, iv, iv_dim, od, od_stride, target_rms, add_log_stddev);
}

src/cudamatrix/cu-kernels.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,6 +1348,23 @@ inline void cuda_copy_cols_from_vec(dim3 Gr, dim3 Bl, float *mat_out,
13481348
cudaF_copy_cols_from_vec(Gr, Bl, mat_out, d_out, v_in);
13491349
}
13501350

1351+
// Overload for double: forwards every argument, in order, to
// cudaD_diff_normalize_per_row so templated callers can use a single name.
inline void cuda_diff_normalize_per_row(size_t Gr, size_t Bl, double *id,
                                        int id_stride, const double *iv,
                                        MatrixDim iv_dim, const double* od,
                                        int od_stride, double target_rms,
                                        bool add_log_stddev) {
  cudaD_diff_normalize_per_row(Gr, Bl,
                               id, id_stride,
                               iv, iv_dim,
                               od, od_stride,
                               target_rms, add_log_stddev);
}
1359+
// Overload for float: forwards every argument, in order, to
// cudaF_diff_normalize_per_row so templated callers can use a single name.
inline void cuda_diff_normalize_per_row(size_t Gr, size_t Bl, float *id,
                                        int id_stride, const float *iv,
                                        MatrixDim iv_dim, const float* od,
                                        int od_stride, float target_rms,
                                        bool add_log_stddev) {
  cudaF_diff_normalize_per_row(Gr, Bl,
                               id, id_stride,
                               iv, iv_dim,
                               od, od_stride,
                               target_rms, add_log_stddev);
}
1367+
13511368
} // namespace kaldi
13521369

13531370
#endif // HAVE_CUDA

src/cudamatrix/cu-math-test.cc

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,14 +510,115 @@ static void UnitTestCuMathNormalizePerRow() {
510510

511511
BaseFloat gflops = ((BaseFloat) dim * dim * iter)
512512
/ (tim.Elapsed() * 1.0e+09);
513-
KALDI_LOG << "For CuMatrix::NormalizePerRow"
513+
KALDI_LOG << "For CuMath::NormalizePerRow"
514514
<< (sizeof(Real)==8?"<double>":"<float>") << ", for dim = "
515515
<< dim << ", speed was " << gflops << " gigaflops.";
516516
if (tim.Elapsed() > 0.05)
517517
break;
518518
}
519519
}
520520

521+
template<typename Real>
522+
static void UnitTestCuDiffNormalizePerRow() {
523+
for (int32 i = 0; i < 2; i++) {
524+
int row = 10 + Rand() % 40;
525+
int col = 10 + Rand() % 50;
526+
527+
Matrix<Real> Hi(row, col);
528+
Matrix<Real> Ho(row, col + 1);
529+
Matrix<Real> Hid(row, col);
530+
Matrix<Real> Hod(row, col + 1);
531+
Hi.SetRandn();
532+
Hod.SetRandn();
533+
Hi.Scale(5.0);
534+
535+
CuMatrix<Real> Di(row, col);
536+
CuMatrix<Real> Do(row, col + 1);
537+
CuMatrix<Real> Did(row, col);
538+
CuMatrix<Real> Dod(row, col + 1);
539+
Di.CopyFromMat(Hi);
540+
Dod.CopyFromMat(Hod);
541+
542+
Real target_rms = 0.3456;
543+
bool add_log_stddev = true;
544+
const Real kSquaredNormFloor = 1.3552527156068805425e-20; // 2^-66
545+
546+
//gpu
547+
cu::DiffNormalizePerRow(Di, Dod, target_rms, add_log_stddev, &Did);
548+
549+
//cpu
550+
{
551+
MatrixBase<Real>* in_deriv = &Hid;
552+
MatrixBase<Real>& out_deriv(Hod);
553+
MatrixBase<Real>& in_value(Hi);
554+
555+
const SubMatrix<Real> out_deriv_no_log(out_deriv, 0, out_deriv.NumRows(),
556+
0, in_value.NumCols());
557+
Vector<Real> dot_products(out_deriv.NumRows());
558+
dot_products.AddDiagMatMat(1.0, out_deriv_no_log, kNoTrans, in_value,
559+
kTrans, 0.0);
560+
Vector<Real> in_norm(in_value.NumRows());
561+
Real d_scaled = (in_value.NumCols() * target_rms * target_rms);
562+
in_norm.AddDiagMat2(1.0, in_value, kNoTrans, 0.0);
563+
if (add_log_stddev) {
564+
Vector<Real> log_stddev_deriv(in_norm), // log_stddev deriv as dF/dy .* (x^T x)^-1
565+
out_deriv_for_stddev(out_deriv.NumRows(), kUndefined);
566+
// f = log(sqrt(max(epsi, x^T x / D)))
567+
// df/dx = epsi^2 * D < x^T x ? (1/(x^T x)) * x : 0.
568+
// we don't compute this exactly below for the case when x^2 x is very
569+
// small, but we do make sure that the deriv isn't infinity when the input
570+
// is zero.
571+
log_stddev_deriv.ApplyFloor(in_value.NumCols() * kSquaredNormFloor);
572+
log_stddev_deriv.ApplyPow(-1.0);
573+
out_deriv_for_stddev.CopyColFromMat(out_deriv,
574+
(out_deriv.NumCols() - 1));
575+
log_stddev_deriv.MulElements(out_deriv_for_stddev);
576+
if (in_deriv)
577+
in_deriv->AddDiagVecMat(1.0, log_stddev_deriv, in_value, kNoTrans,
578+
1.0);
579+
}
580+
in_norm.Scale(1.0 / d_scaled);
581+
in_norm.ApplyFloor(kSquaredNormFloor);
582+
in_norm.ApplyPow(-0.5);
583+
if (in_deriv) {
584+
if (in_deriv->Data() != out_deriv_no_log.Data())
585+
in_deriv->AddDiagVecMat(1.0, in_norm, out_deriv_no_log, kNoTrans,
586+
1.0);
587+
else
588+
in_deriv->MulRowsVec(in_norm);
589+
in_norm.ReplaceValue(1.0 / sqrt(kSquaredNormFloor), 0.0);
590+
in_norm.ApplyPow(3.0);
591+
dot_products.MulElements(in_norm);
592+
593+
in_deriv->AddDiagVecMat(-1.0 / d_scaled, dot_products, in_value,
594+
kNoTrans, 1.0);
595+
}
596+
597+
Matrix<Real> Hid2(Did);
598+
AssertEqual(Hid, Hid2, 0.00001);
599+
}
600+
}
601+
602+
for (int dim = 16; dim <= 1024; dim *= 2) {
603+
BaseFloat time_in_secs = 0.025;
604+
CuMatrix<Real> id(dim, dim), iv(dim, dim), od(dim, dim + 1);
605+
iv.SetRandn();
606+
od.SetRandn();
607+
Timer tim;
608+
int32 iter = 0;
609+
for (; tim.Elapsed() < time_in_secs; iter++) {
610+
cu::DiffNormalizePerRow(iv, od, Real(0.456), true, &id);
611+
}
612+
BaseFloat fdim = dim;
613+
BaseFloat gflops = (fdim * fdim * iter) / (tim.Elapsed() * 1.0e+09);
614+
KALDI_LOG << "For CuMath::DiffNormalizePerRow"
615+
<< (sizeof(Real)==8?"<double>":"<float>")
616+
<< ", for dim = " << dim << ", speed was " << gflops
617+
<< " gigaflops.";
618+
}
619+
}
620+
621+
521622

522623
template<typename Real> void CudaMathUnitTest() {
523624
#if HAVE_CUDA == 1
@@ -531,6 +632,7 @@ template<typename Real> void CudaMathUnitTest() {
531632
UnitTestLstmNonlinearity();
532633
UnitTestBackpropLstmNonlinearity<Real>();
533634
UnitTestCuMathNormalizePerRow<Real>();
635+
UnitTestCuDiffNormalizePerRow<Real>();
534636
}
535637

536638
} // namespace kaldi

src/cudamatrix/cu-math.cc

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ void Randomize(const CuMatrixBase<double> &src,
245245
template<typename Real>
246246
void NormalizePerRow(const CuMatrixBase<Real>& in, const Real target_rms,
247247
const bool add_log_stddev, CuMatrixBase<Real>* out) {
248-
const Real kSquaredNormFloor = 1.35525271560688e-20; // 2^-66
248+
const Real kSquaredNormFloor = 1.3552527156068805425e-20; // 2^-66
249249
if (add_log_stddev) {
250250
KALDI_ASSERT(in.NumRows() == out->NumRows());
251251
KALDI_ASSERT(in.NumCols() + 1 == out->NumCols());
@@ -291,6 +291,100 @@ void NormalizePerRow(const CuMatrixBase<double>& in, const double target_rms,
291291
const bool add_log_stddev, CuMatrixBase<double>* out);
292292

293293

294+
// A note on the derivative of NormalizeComponent...
295+
// let both row_in and row_out be vectors of dimension D.
296+
// Let p = row_in^T row_in / (D * target_rms^2), and let
297+
// f = 1.0 / sqrt(max(kSquaredNormFloor, p)), and we compute row_out as:
298+
// row_out = f row_in.
299+
// Suppose we have a quantity deriv_out which is the derivative
300+
// of the objective function w.r.t. row_out. We want to compute
301+
// deriv_in which is the derivative of the objective function w.r.t.
302+
// row_in. Let the objective function be F. One term is obvious: we have
303+
// deriv_in = f deriv_out + ....
304+
// next we have to take into account the derivative that gets back-propagated
305+
// through f. Obviously, dF/df = deriv_out^T row_in.
306+
// And df/dp = (p <= kSquaredNormFloor ? 0.0 : -0.5 p^{-1.5}) = (f == 1.0 / sqrt(kSquaredNormFloor) ? 0.0 : -0.5 f^3),
307+
// and dp/d(row_in) = 2/(D * target_rms^2) row_in. [it's vector_valued].
308+
// So this term in dF/d(row_in) equals:
309+
// dF/df df/dp dp/d(row_in) = 2/(D * target_rms^2) (f == 1.0 / sqrt(kSquaredNormFloor) ? 0.0 : -0.5 f^3) (deriv_out^T row_in) row_in
310+
// So
311+
// deriv_in = f deriv_out + (f == 1.0 / sqrt(kSquaredNormFloor) ? 0.0 : -f^3 / (D * target_rms^2) ) (deriv_out^T row_in) row_in
312+
// If add_log_stddev is true, deriv_in has one more term, which comes from the
// extra log-stddev output (called f in the next line, x being row_in):
313+
// dF/dx_i = dF/df . df/dx_i => df/dx_i = x_i/(x^T x)
314+
template<typename Real>
315+
void DiffNormalizePerRow(const CuMatrixBase<Real> &in_value,
316+
const CuMatrixBase<Real> &out_deriv,
317+
const Real target_rms, const bool add_log_stddev,
318+
CuMatrixBase<Real>* in_deriv) {
319+
const Real kSquaredNormFloor = 1.3552527156068805425e-20; // 2^-66
320+
#if HAVE_CUDA == 1
321+
if (CuDevice::Instantiate().Enabled()) {
322+
Timer tim;
323+
size_t dimBlock = CU1DBLOCK;
324+
size_t dimGrid = in_deriv->NumRows();
325+
cuda_diff_normalize_per_row(dimGrid, dimBlock, in_deriv->Data(),
326+
in_deriv->Stride(), in_value.Data(),
327+
in_value.Dim(), out_deriv.Data(),
328+
out_deriv.Stride(), target_rms, add_log_stddev);
329+
CU_SAFE_CALL(cudaGetLastError());
330+
CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
331+
} else
332+
#endif
333+
{
334+
const CuSubMatrix<Real> out_deriv_no_log(out_deriv, 0, out_deriv.NumRows(),
335+
0, in_value.NumCols());
336+
CuVector<Real> dot_products(out_deriv.NumRows());
337+
dot_products.AddDiagMatMat(1.0, out_deriv_no_log, kNoTrans, in_value,
338+
kTrans, 0.0);
339+
CuVector<Real> in_norm(in_value.NumRows());
340+
Real d_scaled = (in_value.NumCols() * target_rms * target_rms);
341+
in_norm.AddDiagMat2(1.0, in_value, kNoTrans, 0.0);
342+
343+
if (add_log_stddev) {
344+
CuVector<Real> log_stddev_deriv(in_norm), // log_stddev deriv as dF/dy .* (x^T x)^-1
345+
out_deriv_for_stddev(out_deriv.NumRows(), kUndefined);
346+
// f = log(sqrt(max(epsi, x^T x / D)))
347+
// df/dx = epsi^2 * D < x^T x ? (1/(x^T x)) * x : 0.
348+
// we don't compute this exactly below for the case when x^2 x is very
349+
// small, but we do make sure that the deriv isn't infinity when the input
350+
// is zero.
351+
log_stddev_deriv.ApplyFloor(in_value.NumCols() * kSquaredNormFloor);
352+
log_stddev_deriv.ApplyPow(-1.0);
353+
out_deriv_for_stddev.CopyColFromMat(out_deriv, (out_deriv.NumCols() - 1));
354+
log_stddev_deriv.MulElements(out_deriv_for_stddev);
355+
if (in_deriv)
356+
in_deriv->AddDiagVecMat(1.0, log_stddev_deriv, in_value, kNoTrans, 1.0);
357+
}
358+
in_norm.Scale(1.0 / d_scaled);
359+
in_norm.ApplyFloor(kSquaredNormFloor);
360+
in_norm.ApplyPow(-0.5);
361+
if (in_deriv) {
362+
if (in_deriv->Data() != out_deriv_no_log.Data())
363+
in_deriv->AddDiagVecMat(1.0, in_norm, out_deriv_no_log, kNoTrans, 1.0);
364+
else
365+
in_deriv->MulRowsVec(in_norm);
366+
in_norm.ReplaceValue(1.0 / sqrt(kSquaredNormFloor), 0.0);
367+
in_norm.ApplyPow(3.0);
368+
dot_products.MulElements(in_norm);
369+
370+
in_deriv->AddDiagVecMat(-1.0 / d_scaled, dot_products, in_value, kNoTrans,
371+
1.0);
372+
}
373+
}
374+
}
375+
376+
template
377+
void DiffNormalizePerRow(const CuMatrixBase<float> &in_value,
378+
const CuMatrixBase<float> &out_deriv,
379+
const float target_rms, const bool add_log_stddev,
380+
CuMatrixBase<float>* in_deriv);
381+
template
382+
void DiffNormalizePerRow(const CuMatrixBase<double> &in_value,
383+
const CuMatrixBase<double> &out_deriv,
384+
const double target_rms, const bool add_log_stddev,
385+
CuMatrixBase<double>* in_deriv);
386+
387+
294388
// not calling this Sigmoid to reduce the chance of future collisions.
295389
template<typename Real>
296390
static inline Real ScalarSigmoid(Real a) {

0 commit comments

Comments
 (0)