Skip to content

Commit 3ecc137

Browse files
Nicoshev authored and pytorchmergebot committed
[Caffe2] Improve AddMomentsVec and UpdateMomentsVec (pytorch#167664)
Summary: RowwiseMomentsImpl accounts for about 0.4% of AdRanker's CPU time: https://fburl.com/strobelight/ywf79nw3. It primarily calls AddMomentsVec and UpdateMomentsVec. These two routines are written using PyTorch's VecLib, meaning the operators they use translate into intrinsics. Unfortunately, the compiler performs fewer transformations and optimizations when intrinsics are used. Therefore, if we carefully decouple and re-order operations, the emitted instruction sequence improves. The disassembly for the old and new AddMomentsVec (https://godbolt.org/z/83fxYvKfv) shows that the new implementation achieves a much better instruction sequence. Test Plan: AdRanker ServiceLab Differential Revision: D86805648 Pull Request resolved: pytorch#167664 Approved by: https://github.com/mcfi
1 parent 84a7a34 commit 3ecc137

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

aten/src/ATen/native/cpu/moments_utils.h

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,11 @@ C10_ALWAYS_INLINE void AddMomentsVec(
4646
const T c = n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
4747
const Vec c_vec(c);
4848
const Vec delta = m1_add - m1;
49-
m1 += c_vec * delta;
50-
m2 += m2_add + delta * delta * c_vec * Vec(static_cast<T>(m0));
49+
const Vec m2_tmp = m2 + m2_add;
50+
const Vec c_vec_delta = c_vec * delta;
51+
const Vec m0_delta = delta * Vec(static_cast<T>(m0));
52+
m1 = m1 + c_vec_delta;
53+
m2 = fmadd(m0_delta, c_vec_delta, m2_tmp);
5154
m0 = n;
5255
}
5356

@@ -65,9 +68,11 @@ UpdateMomentsVec(
6568
Vec m2_vec(0);
6669
for (const auto j : c10::irange(m0)) {
6770
const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size());
71+
const Vec tmpVec = c_vecs[j];
6872
const Vec delta_vec = x_vec - m1_vec;
69-
m1_vec += delta_vec * c_vecs[j];
70-
m2_vec += delta_vec * (x_vec - m1_vec);
73+
m1_vec = fmadd(tmpVec, delta_vec, m1_vec);
74+
const Vec tmpVec2 = x_vec - m1_vec;
75+
m2_vec = fmadd(delta_vec, tmpVec2, m2_vec);
7176
}
7277
AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
7378
}
@@ -89,13 +94,16 @@ UpdateMomentsVec(
8994
fVec m2_fvec0(0), m2_fvec1(0);
9095
for (const auto j : c10::irange(m0)) {
9196
const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size());
97+
const fVec tmpVec = c_vecs[j];
9298
auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
9399
const fVec delta_fvec0 = x_fvec0 - m1_fvec0;
94100
const fVec delta_fvec1 = x_fvec1 - m1_fvec1;
95-
m1_fvec0 += delta_fvec0 * c_vecs[j];
96-
m1_fvec1 += delta_fvec1 * c_vecs[j];
97-
m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0);
98-
m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1);
101+
m1_fvec0 = fmadd(delta_fvec0, tmpVec, m1_fvec0);
102+
m1_fvec1 = fmadd(delta_fvec1, tmpVec, m1_fvec1);
103+
const fVec delta_fvec2 = x_fvec0 - m1_fvec0;
104+
const fVec delta_fvec3 = x_fvec1 - m1_fvec1;
105+
m2_fvec0 = fmadd(delta_fvec0, delta_fvec2, m2_fvec0);
106+
m2_fvec1 = fmadd(delta_fvec1, delta_fvec3, m2_fvec1);
99107
}
100108
AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0);
101109
AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0);

0 commit comments

Comments
 (0)