Skip to content

Commit 628ff34

Browse files
Fix FPE of label smooth op (#35861)
1 parent 7ff226f commit 628ff34

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

paddle/fluid/operators/label_smooth_op.h

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,21 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
2929
auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
3030
auto label_dim = in_t->dims()[in_t->dims().size() - 1];
3131
out_t->mutable_data<T>(ctx.GetPlace());
32-
33-
auto epsilon = ctx.Attr<float>("epsilon");
34-
auto out = framework::EigenVector<T>::Flatten(*out_t);
35-
auto in = framework::EigenVector<T>::Flatten(*in_t);
36-
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
37-
if (dist_t) {
38-
auto dist = framework::EigenVector<T>::Flatten(*dist_t);
39-
out.device(dev) =
40-
static_cast<T>(1 - epsilon) * in +
41-
static_cast<T>(epsilon) *
42-
dist.broadcast(Eigen::DSizes<int, 1>(in_t->numel() / label_dim));
43-
} else {
44-
out.device(dev) = static_cast<T>(1 - epsilon) * in +
45-
static_cast<T>(epsilon / label_dim);
32+
if (label_dim != 0) {
33+
auto epsilon = ctx.Attr<float>("epsilon");
34+
auto out = framework::EigenVector<T>::Flatten(*out_t);
35+
auto in = framework::EigenVector<T>::Flatten(*in_t);
36+
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
37+
if (dist_t) {
38+
auto dist = framework::EigenVector<T>::Flatten(*dist_t);
39+
out.device(dev) = static_cast<T>(1 - epsilon) * in +
40+
static_cast<T>(epsilon) *
41+
dist.broadcast(Eigen::DSizes<int, 1>(
42+
in_t->numel() / label_dim));
43+
} else {
44+
out.device(dev) = static_cast<T>(1 - epsilon) * in +
45+
static_cast<T>(epsilon / label_dim);
46+
}
4647
}
4748
}
4849
};
@@ -54,13 +55,15 @@ class LabelSmoothGradKernel : public framework::OpKernel<T> {
5455
auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
5556
auto* d_in_t = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
5657
d_in_t->mutable_data<T>(ctx.GetPlace());
58+
auto d_out_dim = d_out_t->dims()[d_out_t->dims().size() - 1];
59+
if (d_out_dim != 0) {
60+
auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
61+
auto d_in = framework::EigenVector<T>::Flatten(*d_in_t);
5762

58-
auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
59-
auto d_in = framework::EigenVector<T>::Flatten(*d_in_t);
60-
61-
auto epsilon = ctx.Attr<float>("epsilon");
62-
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
63-
d_in.device(dev) = static_cast<T>(1 - epsilon) * d_out;
63+
auto epsilon = ctx.Attr<float>("epsilon");
64+
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
65+
d_in.device(dev) = static_cast<T>(1 - epsilon) * d_out;
66+
}
6467
}
6568
};
6669
} // namespace operators

0 commit comments

Comments
 (0)