Skip to content

Commit 698330b

Browse files
committed
slice-check
1 parent 919f9cc commit 698330b

File tree

4 files changed

+22
-14
lines changed

4 files changed

+22
-14
lines changed

paddle/phi/kernels/funcs/radix_sort.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
namespace phi {
1919
namespace funcs {
2020

21+
#ifdef PADDLE_WITH_CUDA
2122
namespace {
22-
2323
template <typename T>
2424
struct CudaType {
2525
using type = T;
@@ -113,5 +113,6 @@ INSTANTIATE_SORT_PAIRS(int64_t, 4)
113113
INSTANTIATE_SORT_PAIRS(int32_t, 8)
114114
INSTANTIATE_SORT_PAIRS(int64_t, 8)
115115

116+
#endif
116117
} // namespace funcs
117118
} // namespace phi

paddle/phi/kernels/funcs/radix_sort.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
// limitations under the License.
1414

1515
#pragma once
16+
#ifdef PADDLE_WITH_CUDA
1617
#include <cub/cub.cuh>
18+
#endif
1719
#include "paddle/phi/backends/gpu/gpu_context.h"
1820
#include "paddle/phi/core/dense_tensor.h"
1921

2022
namespace phi {
2123
namespace funcs {
2224

25+
#ifdef PADDLE_WITH_CUDA
2326
template <int kValueSize>
2427
struct OpaqueTypeRadix {
2528
uint8_t data[kValueSize];
@@ -76,5 +79,6 @@ void RadixSortPairs(const phi::GPUContext& dev_ctx,
7679
end_bit);
7780
}
7881

82+
#endif
7983
} // namespace funcs
8084
} // namespace phi

paddle/phi/kernels/funcs/stride_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ static inline void ScatterAddStride(
544544
*numel = num;
545545
}
546546

547-
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
547+
#if defined(PADDLE_WITH_CUDA)
548548

549549
static inline std::vector<phi::DenseTensor> expandTensors(
550550
const phi::GPUContext& dev_ctx,

paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ void IndexElementwiseGetGradKernel(const Context& dev_ctx,
401401
phi::DataType::INT64));
402402

403403
if (accumulate && index.size() == 1 && !is_combined) {
404+
#ifdef PADDLE_WITH_CUDA
404405
IndexPutWithSortKernel<T, int64_t>(dev_ctx,
405406
x,
406407
out_grad,
@@ -412,19 +413,21 @@ void IndexElementwiseGetGradKernel(const Context& dev_ctx,
412413
slice_offset,
413414
accumulate,
414415
x_grad);
415-
} else {
416-
GPUIndexElementwiseGetGrad<T, int64_t>(dev_ctx,
417-
x,
418-
out_grad,
419-
index,
420-
input_dims,
421-
input_strides,
422-
index_dims,
423-
index_strides,
424-
slice_offset,
425-
accumulate,
426-
x_grad);
416+
return;
417+
#endif
427418
}
419+
420+
GPUIndexElementwiseGetGrad<T, int64_t>(dev_ctx,
421+
x,
422+
out_grad,
423+
index,
424+
input_dims,
425+
input_strides,
426+
index_dims,
427+
index_strides,
428+
slice_offset,
429+
accumulate,
430+
x_grad);
428431
}
429432

430433
} // namespace phi

0 commit comments

Comments
 (0)