@@ -29,20 +29,21 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
2929 auto * dist_t = ctx.Input <framework::Tensor>(" PriorDist" );
3030 auto label_dim = in_t ->dims ()[in_t ->dims ().size () - 1 ];
3131 out_t ->mutable_data <T>(ctx.GetPlace ());
32-
33- auto epsilon = ctx.Attr <float >(" epsilon" );
34- auto out = framework::EigenVector<T>::Flatten (*out_t );
35- auto in = framework::EigenVector<T>::Flatten (*in_t );
36- auto & dev = *ctx.template device_context <DeviceContext>().eigen_device ();
37- if (dist_t ) {
38- auto dist = framework::EigenVector<T>::Flatten (*dist_t );
39- out.device (dev) =
40- static_cast <T>(1 - epsilon) * in +
41- static_cast <T>(epsilon) *
42- dist.broadcast (Eigen::DSizes<int , 1 >(in_t ->numel () / label_dim));
43- } else {
44- out.device (dev) = static_cast <T>(1 - epsilon) * in +
45- static_cast <T>(epsilon / label_dim);
32+ if (label_dim != 0 ) {
33+ auto epsilon = ctx.Attr <float >(" epsilon" );
34+ auto out = framework::EigenVector<T>::Flatten (*out_t );
35+ auto in = framework::EigenVector<T>::Flatten (*in_t );
36+ auto & dev = *ctx.template device_context <DeviceContext>().eigen_device ();
37+ if (dist_t ) {
38+ auto dist = framework::EigenVector<T>::Flatten (*dist_t );
39+ out.device (dev) = static_cast <T>(1 - epsilon) * in +
40+ static_cast <T>(epsilon) *
41+ dist.broadcast (Eigen::DSizes<int , 1 >(
42+ in_t ->numel () / label_dim));
43+ } else {
44+ out.device (dev) = static_cast <T>(1 - epsilon) * in +
45+ static_cast <T>(epsilon / label_dim);
46+ }
4647 }
4748 }
4849};
@@ -54,13 +55,15 @@ class LabelSmoothGradKernel : public framework::OpKernel<T> {
5455 auto * d_out_t = ctx.Input <framework::Tensor>(framework::GradVarName (" Out" ));
5556 auto * d_in_t = ctx.Output <framework::Tensor>(framework::GradVarName (" X" ));
5657 d_in_t ->mutable_data <T>(ctx.GetPlace ());
58+ auto d_out_dim = d_out_t ->dims ()[d_out_t ->dims ().size () - 1 ];
59+ if (d_out_dim != 0 ) {
60+ auto d_out = framework::EigenVector<T>::Flatten (*d_out_t );
61+ auto d_in = framework::EigenVector<T>::Flatten (*d_in_t );
5762
58- auto d_out = framework::EigenVector<T>::Flatten (*d_out_t );
59- auto d_in = framework::EigenVector<T>::Flatten (*d_in_t );
60-
61- auto epsilon = ctx.Attr <float >(" epsilon" );
62- auto & dev = *ctx.template device_context <DeviceContext>().eigen_device ();
63- d_in.device (dev) = static_cast <T>(1 - epsilon) * d_out;
63+ auto epsilon = ctx.Attr <float >(" epsilon" );
64+ auto & dev = *ctx.template device_context <DeviceContext>().eigen_device ();
65+ d_in.device (dev) = static_cast <T>(1 - epsilon) * d_out;
66+ }
6467 }
6568};
6669} // namespace operators
0 commit comments