
Commit db515e6

[BIG tensor] fix acc error in margin_cross_entropy (#74252)

1 parent 70bf749 commit db515e6
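
The kernels touched below previously did parts of their arithmetic directly in the storage type T (e.g. float16), rounding intermediate results at low precision; for big tensors those rounding errors accumulate. The fix threads the master precision type MPType (float for half-precision inputs) through every kernel as an explicit template parameter, performs all intermediate math in MPType, and casts back to T only on the final store. A minimal sketch of that pattern, using a hypothetical kernel that is not part of this diff:

    // Compute-in-MPType sketch (illustrative only): read T, do all math in
    // MPType, narrow back to T once on store.
    template <typename T, typename MPType>
    __global__ void SubtractInMasterPrecision(T* x, const T* bias, int64_t n) {
      int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
      if (i < n) {
        // Intermediate arithmetic in MPType, so error does not accumulate
        // at storage precision.
        MPType v = static_cast<MPType>(x[i]) - static_cast<MPType>(bias[i]);
        x[i] = static_cast<T>(v);  // single narrowing cast on store
      }
    }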

File tree: 1 file changed, +27 −31 lines

paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu

Lines changed: 27 additions & 31 deletions
@@ -11,11 +11,12 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/impl/margin_cross_entropy.cu.h"
 
 namespace phi {
 
-template <typename T, typename IndexT>
+template <typename T, typename MPType, typename IndexT>
 __global__ void AddMarginToPositiveLogitsKernel(T* logit,
                                                 const IndexT* label,
                                                 const float margin1,
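
MPType now arrives as an explicit template parameter instead of being derived inside the kernel via phi::dtype::MPTypeTrait. The trait maps each storage type to its compute type; a simplified standalone sketch (the real trait lives in phi and covers more types, including bfloat16):

    // Simplified master-precision trait, for illustration only.
    #include <cuda_fp16.h>

    template <typename T>
    struct MPTypeTraitSketch {
      using Type = T;  // float/double compute as themselves
    };
    template <>
    struct MPTypeTraitSketch<__half> {
      using Type = float;  // half-precision inputs compute in float
    };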
@@ -26,7 +27,6 @@ __global__ void AddMarginToPositiveLogitsKernel(T* logit,
                                                 const int64_t N,
                                                 const int64_t D,
                                                 const int* class_interval_ptr) {
-  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   int64_t start_index = class_interval_ptr[rank];
   int64_t end_index = class_interval_ptr[rank + 1];
   int num_classes = class_interval_ptr[nranks];
@@ -42,55 +42,48 @@ __global__ void AddMarginToPositiveLogitsKernel(T* logit,
 
     if (real_label >= start_index && real_label < end_index) {
       int64_t offset = i * D + real_label - start_index;
-      if (fabs(margin1 - 1.0) > 1e-8 || fabs(margin2) > 1e-8) {
-        MPType x = static_cast<MPType>(logit[offset]);
-        MPType theta = acos(x);
-        if (fabs(margin1 - 1.0) > 1e-8) {
-          theta *= static_cast<MPType>(margin1);
-        }
-        if (fabs(margin2) > 1e-8) {
-          theta += static_cast<MPType>(margin2);
-        }
-        logit[offset] = static_cast<T>(cos(theta));
-      }
-      if (fabs(margin3) > 1e-8) {
-        MPType y = static_cast<MPType>(logit[offset]);
-        y -= static_cast<MPType>(margin3);
-        logit[offset] = static_cast<T>(y);
-      }
+      MPType x = static_cast<MPType>(logit[offset]);
+      MPType theta = acos(x);
+      theta *= static_cast<MPType>(margin1);
+      theta += static_cast<MPType>(margin2);
+      MPType y = cos(theta) - static_cast<MPType>(margin3);
+      logit[offset] = static_cast<T>(y);
     }
   }
 }
 
-template <typename T>
+template <typename T, typename MPType>
 __global__ void ScaleLogitKernel(T* logits,
                                  const float scale,
                                  const int64_t N,
                                  const int64_t D) {
   CUDA_KERNEL_LOOP_TYPE(i, N * D, int64_t) {
-    logits[i] *= static_cast<T>(scale);
+    logits[i] = static_cast<MPType>(logits[i]) * (scale);
   }
 }
 
-template <typename T>
+template <typename T, typename MPType>
 __global__ void LogitsMinusMaxKernel(T* logits,
                                      const T* logits_max_per_row,
                                      const int64_t N,
                                      const int64_t D) {
   CUDA_KERNEL_LOOP_TYPE(i, N * D, int64_t) {
     auto row = i / D;
-    logits[i] -= logits_max_per_row[row];
+    logits[i] = static_cast<MPType>(logits[i]) -
+                static_cast<MPType>(logits_max_per_row[row]);
   }
 }
 
-template <typename T>
+template <typename T, typename MPType>
 __global__ void LogitsMinusLogSumKernel(T* logits,
                                         const T* logits_sum_per_row,
                                         const int64_t N,
                                         const int64_t D) {
   CUDA_KERNEL_LOOP_TYPE(i, N * D, int64_t) {
     auto row = i / D;
-    logits[i] -= phi::kps::details::Log(logits_sum_per_row[row]);
+    logits[i] =
+        static_cast<MPType>(logits[i]) -
+        static_cast<MPType>(phi::kps::details::Log(logits_sum_per_row[row]));
   }
 }
 
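The rewritten body applies the combined margin unconditionally: logit' = cos(margin1 · acos(logit) + margin2) − margin3, computed entirely in MPType. With margin1 = 1 and margin2 = margin3 = 0 the transform is the identity (cos(acos(x)) = x for x in [-1, 1]), so dropping the old fabs(...) > 1e-8 branches does not change results. It also removes the old code's round trip through T between the cos(theta) store and the margin3 subtraction, which is where accuracy was lost. A reference sketch with a hypothetical helper name:

    // Combined-margin transform (illustrative helper; margin1/2/3 match the
    // kernel parameters above).
    __host__ __device__ inline float CombinedMargin(float logit, float m1,
                                                    float m2, float m3) {
      // m1 = 1, m2 = m3 = 0 gives cosf(acosf(logit)) == logit (identity).
      return cosf(m1 * acosf(logit) + m2) - m3;
    }
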
@@ -132,6 +125,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
                               DenseTensor* softmax,
                               DenseTensor* loss) {
   const auto& place = dev_ctx.GetPlace();  // old code
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   phi::distributed::NCCLCommContext* comm_ctx = nullptr;
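
The MPType alias removed from the device code in the @@ -26,7 +27,6 @@ hunk now lives here in the host launcher, so every kernel below receives it as an explicit template argument. A stripped-down sketch of the threading pattern (simplified signature and placeholder scale value, for illustration):

    // Host-side sketch: derive MPType once, pass it to each instantiation
    // (real launches carry many more arguments).
    template <typename T, typename Context>
    void LaunchSketch(const Context& dev_ctx, T* logits_ptr,
                      int64_t N, int64_t D, int threads) {
      using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
      ScaleLogitKernel<T, MPType>
          <<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
              logits_ptr, /*scale=*/64.0f, N, D);
    }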
@@ -192,7 +186,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
   // save match_logits, used for gradient computation.
   if (label_type == phi::DataType::INT32) {
     typedef int32_t LabelT;
-    AddMarginToPositiveLogitsKernel<T>
+    AddMarginToPositiveLogitsKernel<T, MPType>
         <<<NumBlocks(N), threads, 0, dev_ctx.stream()>>>(
             logits_ptr,
             labels.data<LabelT>(),
@@ -206,7 +200,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
             class_interval.data<int>());
   } else if (label_type == phi::DataType::INT64) {
     typedef int64_t LabelT;
-    AddMarginToPositiveLogitsKernel<T>
+    AddMarginToPositiveLogitsKernel<T, MPType>
         <<<NumBlocks(N), threads, 0, dev_ctx.stream()>>>(
             logits_ptr,
             labels.data<LabelT>(),
@@ -226,8 +220,9 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
   }
 
   // scale by s
-  ScaleLogitKernel<T><<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
-      logits_ptr, scale, N, D);
+  ScaleLogitKernel<T, MPType>
+      <<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
+          logits_ptr, scale, N, D);
 
   // step 2, obtain logit_max
   DenseTensor logits_max;
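
These launches pair NumBlocks(N * D) with the CUDA_KERNEL_LOOP_TYPE grid-stride loop seen in the kernels above, so every one of the N * D elements is covered even if the grid is smaller than the element count. An illustrative reconstruction of the two pieces (assumed block size; the real phi helpers may size or cap the grid differently):

    // Illustrative reconstruction, not the exact phi definitions.
    constexpr int kThreadsPerBlock = 256;  // assumed block size

    inline int NumBlocksSketch(int64_t n) {
      return static_cast<int>((n + kThreadsPerBlock - 1) / kThreadsPerBlock);
    }

    // Grid-stride loop equivalent to CUDA_KERNEL_LOOP_TYPE(i, n, int64_t):
    // for (int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
    //      i < n;
    //      i += (int64_t)gridDim.x * blockDim.x) { ... }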
@@ -250,8 +245,9 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
 #endif
 
   // step 3, logit - logit_max
-  LogitsMinusMaxKernel<T><<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
-      logits_ptr, logits_max_buff, N, D);
+  LogitsMinusMaxKernel<T, MPType>
+      <<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
+          logits_ptr, logits_max_buff, N, D);
 
   // step 4, sum(exp(logit - logit_max))
   DenseTensor sum_exp_logits;
@@ -272,7 +268,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
 #endif
 
   // step 5, (logit - logit_max) - log(sum(exp(logit - logit_max)))
-  LogitsMinusLogSumKernel<T>
+  LogitsMinusLogSumKernel<T, MPType>
       <<<NumBlocks(N * D), threads, 0, dev_ctx.stream()>>>(
           logits_ptr, sum_exp_logits_buff, N, D);
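
Steps 2 through 5 together compute a numerically stable log-softmax per row: subtracting the row max before exponentiating keeps exp from overflowing, and subtracting log(sum) afterwards completes log_softmax(x)_j = (x_j - max(x)) - log(sum_k exp(x_k - max(x))). A single-row host sketch of the same arithmetic (illustrative; assumes a non-empty row):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> LogSoftmaxRow(const std::vector<float>& x) {
      float mx = *std::max_element(x.begin(), x.end());  // step 2: row max
      float sum = 0.0f;
      for (float v : x) sum += std::exp(v - mx);         // step 4: sum(exp)
      const float log_sum = std::log(sum);               // step 5: log(sum)
      std::vector<float> out(x.size());
      for (size_t j = 0; j < x.size(); ++j)
        out[j] = (x[j] - mx) - log_sum;                  // steps 3 and 5
      return out;
    }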
