 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/impl/margin_cross_entropy.cu.h"
 
 namespace phi {
 
-template <typename T, typename IndexT>
+template <typename T, typename MPType, typename IndexT>
 __global__ void AddMarginToPositiveLogitsKernel(T* logit,
                                                 const IndexT* label,
                                                 const float margin1,
@@ -26,7 +27,6 @@ __global__ void AddMarginToPositiveLogitsKernel(T* logit,
                                                 const int64_t N,
                                                 const int64_t D,
                                                 const int* class_interval_ptr) {
-  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   int64_t start_index = class_interval_ptr[rank];
   int64_t end_index = class_interval_ptr[rank + 1];
   int num_classes = class_interval_ptr[nranks];
@@ -42,55 +42,48 @@ __global__ void AddMarginToPositiveLogitsKernel(T* logit,
 
     if (real_label >= start_index && real_label < end_index) {
       int64_t offset = i * D + real_label - start_index;
-      if (fabs(margin1 - 1.0) > 1e-8 || fabs(margin2) > 1e-8) {
-        MPType x = static_cast<MPType>(logit[offset]);
-        MPType theta = acos(x);
-        if (fabs(margin1 - 1.0) > 1e-8) {
-          theta *= static_cast<MPType>(margin1);
-        }
-        if (fabs(margin2) > 1e-8) {
-          theta += static_cast<MPType>(margin2);
-        }
-        logit[offset] = static_cast<T>(cos(theta));
-      }
-      if (fabs(margin3) > 1e-8) {
-        MPType y = static_cast<MPType>(logit[offset]);
-        y -= static_cast<MPType>(margin3);
-        logit[offset] = static_cast<T>(y);
-      }
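+      // Combined margin, applied unconditionally and computed in MPType so
+      // acos/cos run in full precision:
+      //   logit = cos(margin1 * acos(logit) + margin2) - margin3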
+      MPType x = static_cast<MPType>(logit[offset]);
+      MPType theta = acos(x);
+      theta *= static_cast<MPType>(margin1);
+      theta += static_cast<MPType>(margin2);
+      MPType y = cos(theta) - static_cast<MPType>(margin3);
+      logit[offset] = static_cast<T>(y);
     }
   }
 }
 
-template <typename T>
+template <typename T, typename MPType>
 __global__ void ScaleLogitKernel(T* logits,
                                  const float scale,
                                  const int64_t N,
                                  const int64_t D) {
   CUDA_KERNEL_LOOP_TYPE(i, N * D, int64_t) {
-    logits[i] *= static_cast<T>(scale);
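+    // Upcast to MPType (float when T is float16/bfloat16) so the scale is
+    // applied in full precision before the result is cast back to T.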
+    logits[i] = static_cast<T>(static_cast<MPType>(logits[i]) * scale);
   }
 }
 
-template <typename T>
+template <typename T, typename MPType>
 __global__ void LogitsMinusMaxKernel(T* logits,
                                      const T* logits_max_per_row,
                                      const int64_t N,
                                      const int64_t D) {
   CUDA_KERNEL_LOOP_TYPE(i, N * D, int64_t) {
     auto row = i / D;
-    logits[i] -= logits_max_per_row[row];
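+    // Subtract the per-row max in MPType; this keeps the later exp() in the
+    // softmax numerically stable without rounding through T.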
+    logits[i] = static_cast<T>(static_cast<MPType>(logits[i]) -
+                               static_cast<MPType>(logits_max_per_row[row]));
   }
 }
 
-template <typename T>
+template <typename T, typename MPType>
 __global__ void LogitsMinusLogSumKernel(T* logits,
                                         const T* logits_sum_per_row,
                                         const int64_t N,
                                         const int64_t D) {
   CUDA_KERNEL_LOOP_TYPE(i, N * D, int64_t) {
     auto row = i / D;
-    logits[i] -= phi::kps::details::Log(logits_sum_per_row[row]);
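+    // Final log-softmax step: subtract log(sum(exp(...))) for the row, doing
+    // the arithmetic in MPType before casting back to T.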
+    logits[i] = static_cast<T>(
+        static_cast<MPType>(logits[i]) -
+        static_cast<MPType>(phi::kps::details::Log(logits_sum_per_row[row])));
   }
 }
 
@@ -132,6 +125,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
                               DenseTensor* softmax,
                               DenseTensor* loss) {
   const auto& place = dev_ctx.GetPlace();
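+  // MPType is the higher-precision compute type for T: float when T is
+  // float16/bfloat16, otherwise T itself.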
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   phi::distributed::NCCLCommContext* comm_ctx = nullptr;
@@ -192,7 +186,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
   // save match_logits, used for gradient computation.
   if (label_type == phi::DataType::INT32) {
     typedef int32_t LabelT;
-    AddMarginToPositiveLogitsKernel<T>
+    AddMarginToPositiveLogitsKernel<T, MPType>
         <<<NumBlocks(N), threads, 0, dev_ctx.stream()>>>(
             logits_ptr,
             labels.data<LabelT>(),
@@ -206,7 +200,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
             class_interval.data<int>());
   } else if (label_type == phi::DataType::INT64) {
     typedef int64_t LabelT;
-    AddMarginToPositiveLogitsKernel<T>
+    AddMarginToPositiveLogitsKernel<T, MPType>
         <<<NumBlocks(N), threads, 0, dev_ctx.stream()>>>(
             logits_ptr,
             labels.data<LabelT>(),
@@ -226,8 +220,9 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
   }
 
   // scale by s
-  ScaleLogitKernel<T><<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
-      logits_ptr, scale, N, D);
+  ScaleLogitKernel<T, MPType>
+      <<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
+          logits_ptr, scale, N, D);
 
   // step 2, obtain logit_max
   DenseTensor logits_max;
@@ -250,8 +245,9 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
 #endif
 
   // step 3, logit - logit_max
-  LogitsMinusMaxKernel<T><<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
-      logits_ptr, logits_max_buff, N, D);
+  LogitsMinusMaxKernel<T, MPType>
+      <<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
+          logits_ptr, logits_max_buff, N, D);
 
   // step 4, sum(exp(logit - logit_max))
   DenseTensor sum_exp_logits;
@@ -272,7 +268,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
 #endif
 
   // step 5, (logit - logit_max) - log(sum(exp(logit - logit_max)))
-  LogitsMinusLogSumKernel<T>
+  LogitsMinusLogSumKernel<T, MPType>
       <<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
           logits_ptr, sum_exp_logits_buff, N, D);
 