Merged
32 changes: 14 additions & 18 deletions paddle/fluid/pybind/slice_utils.h
@@ -782,6 +782,7 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
}

AdvancedIndex ad = AdvancedIndex(tensor, indices_int64);
const bool is_combined = false;
const bool accumulate = false;

return index_elementwise_get_ad_func(tensor,
@@ -791,7 +792,8 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
ad.indexed_sizes,
ad.indexed_strides,
slice_offset,
accumulate);
accumulate,
is_combined);
} else {
if (bool_index.shape().size() == 1)
return gather_ad_func(tensor, bool_2_idx);
@@ -1238,23 +1240,17 @@ static void ApplyGetitem(const int index_size,
&transed_index_int64);

AdvancedIndex ad = AdvancedIndex(*transed_tensor, transed_index_int64);
if (index_size == 1) {
paddle::Tensor flattened_tensor =
flatten_ad_func((*transed_index)[0], 0, -1);
*out = gather_ad_func(*transed_tensor, flattened_tensor);
*out = reshape_ad_func(*out, ad.src_sizes);
} else {
const bool accumulate = true;
*out = index_elementwise_get_ad_func(*self_tensor,
ad.indices,
ad.src_sizes,
ad.src_strides,
ad.indexed_sizes,
ad.indexed_strides,
slice_offset,
accumulate);
}

// is_combined: false for a single plain index, true for combined indexing;
// the single-index case lets the backward pass use the faster
// IndexPutWithSortKernel.
const bool is_combined = (index_size == 1) ? false : true;
Contributor: What does is_combined mean? Please add a comment explaining it.

Contributor (Author): is_combined distinguishes a plain index from a combined index; when there is only a single plain index, the backward pass uses the better-performing IndexPutWithSortKernel. A comment has been added.
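For illustration, a compilable sketch of the dispatch described in the reply; only IndexPutWithSortKernel is named in the discussion, and the other identifiers here are hypothetical placeholders:

#include <cstdio>

// Hypothetical stubs: IndexPutWithSortKernel is named in the reply above;
// the generic fallback name is a placeholder, not Paddle's actual symbol.
void IndexPutWithSortKernel() { std::puts("sort-based backward (fast path)"); }
void GenericIndexPutBackward() { std::puts("generic backward path"); }

// A single plain index (index_size == 1) is not "combined", so the backward
// pass may take the faster sort-based kernel.
void DispatchBackward(int index_size) {
  const bool is_combined = (index_size != 1);
  if (is_combined) {
    GenericIndexPutBackward();
  } else {
    IndexPutWithSortKernel();
  }
}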

const bool accumulate = true;
*out = index_elementwise_get_ad_func(*self_tensor,
ad.indices,
ad.src_sizes,
ad.src_strides,
ad.indexed_sizes,
ad.indexed_strides,
slice_offset,
accumulate,
is_combined);
return;
} else {
paddle::Tensor transed_advanced_index_tensor;
1 change: 1 addition & 0 deletions paddle/phi/infermeta/backward.cc
@@ -2168,6 +2168,7 @@ void IndexElementwiseGetGradInferMeta(
const std::vector<int64_t>& index_strides,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
MetaTensor* x_grad) {
if (x_grad) {
x_grad->share_meta(x);
1 change: 1 addition & 0 deletions paddle/phi/infermeta/backward.h
@@ -788,5 +788,6 @@ void IndexElementwiseGetGradInferMeta(
const std::vector<int64_t>& index_strides,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
MetaTensor* x_grad);
} // namespace phi
1 change: 1 addition & 0 deletions paddle/phi/infermeta/binary.cc
@@ -2599,6 +2599,7 @@ void IndexElementwiseGetInferMeta(const MetaTensor& x,
const std::vector<int64_t>& index_stride,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
MetaTensor* out) {
out->set_dims(common::make_ddim(input_dims));
out->set_dtype(x.dtype());
1 change: 1 addition & 0 deletions paddle/phi/infermeta/binary.h
@@ -493,6 +493,7 @@ void IndexElementwiseGetInferMeta(const MetaTensor& x,
const std::vector<int64_t>& index_stride,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
MetaTensor* out);

void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
@@ -131,6 +131,7 @@ void IndexElementwiseGetGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& index_strides,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
auto dxt = phi::EigenVector<T>::Flatten(*x_grad);
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/index_elementwise_get_kernel.cc
@@ -100,6 +100,7 @@ void IndexElementwiseGetKernel(const Context& dev_ctx,
const std::vector<int64_t>& index_stride,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
DenseTensor* out) {
const auto& index_type = index[0]->dtype();
PADDLE_ENFORCE_EQ(index_type == phi::DataType::INT64,
117 changes: 117 additions & 0 deletions paddle/phi/kernels/funcs/radix_sort.cu
@@ -0,0 +1,117 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/funcs/radix_sort.h"
#include "paddle/phi/common/memory_utils.h"

namespace phi {
namespace funcs {

namespace {

// CUB's radix sort dispatches on built-in types, so map key types to their
// built-in equivalents (int64_t becomes long long below).
template <typename T>
struct CudaType {
  using type = T;
};

template <>
struct CudaType<int64_t> {
using type = long long; // NOLINT
};

#define PADDLE_CUB_WRAPPER(func, ...) \
do { \
size_t temp_storage_bytes = 0; \
func(nullptr, temp_storage_bytes, __VA_ARGS__); \
auto temp_storage = \
phi::memory_utils::Alloc(dev_ctx.GetPlace(), temp_storage_bytes); \
func(temp_storage->ptr(), temp_storage_bytes, __VA_ARGS__); \
} while (0)
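// The wrapper above follows CUB's two-pass convention: the first call, made
// with a null workspace pointer, only writes the required size into
// temp_storage_bytes; the second call runs the algorithm in the allocated
// workspace.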

} // namespace

template <typename key_t, int value_size>
void RadixSortPairsImpl(const phi::GPUContext& dev_ctx,
const key_t* keys_in,
key_t* keys_out,
const OpaqueTypeRadix<value_size>* values_in,
OpaqueTypeRadix<value_size>* values_out,
int64_t n,
bool descending,
int64_t begin_bit,
int64_t end_bit) {
PADDLE_ENFORCE_LE(
n,
std::numeric_limits<int>::max(),
phi::errors::InvalidArgument(
"CUB sort does not support sorting more than INT_MAX elements"));

using key_t_ = typename CudaType<key_t>::type;

phi::Allocator::AllocationPtr keys_out_owner;
if (keys_out == nullptr) {
keys_out_owner =
phi::memory_utils::Alloc(dev_ctx.GetPlace(), n * sizeof(key_t));
keys_out = reinterpret_cast<key_t*>(keys_out_owner->ptr());
}

const key_t_* keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
key_t_* keys_out_ = reinterpret_cast<key_t_*>(keys_out);

if (descending) {
PADDLE_CUB_WRAPPER(cub::DeviceRadixSort::SortPairsDescending,
keys_in_,
keys_out_,
values_in,
values_out,
static_cast<int>(n),
begin_bit,
end_bit,
dev_ctx.stream());
} else {
PADDLE_CUB_WRAPPER(cub::DeviceRadixSort::SortPairs,
keys_in_,
keys_out_,
values_in,
values_out,
static_cast<int>(n),
begin_bit,
end_bit,
dev_ctx.stream());
}
}

#define INSTANTIATE_SORT_PAIRS(key_t, value_size) \
template void RadixSortPairsImpl<key_t, value_size>( \
const phi::GPUContext&, \
const key_t*, \
key_t*, \
const OpaqueTypeRadix<value_size>*, \
OpaqueTypeRadix<value_size>*, \
int64_t, \
bool, \
int64_t, \
int64_t);

INSTANTIATE_SORT_PAIRS(int32_t, 1)
INSTANTIATE_SORT_PAIRS(int32_t, 2)
INSTANTIATE_SORT_PAIRS(int32_t, 4)
INSTANTIATE_SORT_PAIRS(int64_t, 1)
INSTANTIATE_SORT_PAIRS(int64_t, 2)
INSTANTIATE_SORT_PAIRS(int64_t, 4)
INSTANTIATE_SORT_PAIRS(int32_t, 8)
INSTANTIATE_SORT_PAIRS(int64_t, 8)

} // namespace funcs
} // namespace phi
80 changes: 80 additions & 0 deletions paddle/phi/kernels/funcs/radix_sort.h
@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cub/cub.cuh>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace funcs {

// Fixed-size opaque payload: lets RadixSortPairs reuse one template
// instantiation per value size (1, 2, 4, or 8 bytes) instead of one per
// concrete value type.
template <int kValueSize>
struct OpaqueTypeRadix {
  uint8_t data[kValueSize];
  __device__ __host__ OpaqueTypeRadix() = default;
};

template <typename key_t, int kValueSize>
void RadixSortPairsImpl(const phi::GPUContext& dev_ctx,
const key_t* keys_in,
key_t* keys_out,
const OpaqueTypeRadix<kValueSize>* values_in,
OpaqueTypeRadix<kValueSize>* values_out,
int64_t n,
bool descending = false,
int64_t begin_bit = 0,
int64_t end_bit = sizeof(key_t) * 8);

template <typename key_t, typename value_t>
void RadixSortPairs(const phi::GPUContext& dev_ctx,
const key_t* keys_in,
key_t* keys_out,
const value_t* values_in,
value_t* values_out,
int64_t n,
bool descending = false,
int64_t begin_bit = 0,
int64_t end_bit = sizeof(key_t) * 8) {
PADDLE_ENFORCE_EQ(
std::is_trivially_copyable<value_t>::value,
true,
phi::errors::InvalidArgument(
"RadixSortPairs value type must be trivially copyable"));

using opaque_t = OpaqueTypeRadix<sizeof(value_t)>;
PADDLE_ENFORCE_EQ(
sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0,
true,
phi::errors::InvalidArgument(
"Unsupported value_t size (must be 1, 2, 4, or 8 bytes)"));
PADDLE_ENFORCE_EQ(
sizeof(value_t),
alignof(value_t),
phi::errors::InvalidArgument("Expected value_t to be size-aligned"));

RadixSortPairsImpl<key_t, sizeof(value_t)>(
dev_ctx,
keys_in,
keys_out,
reinterpret_cast<const opaque_t*>(values_in),
reinterpret_cast<opaque_t*>(values_out),
n,
descending,
begin_bit,
end_bit);
}

} // namespace funcs
} // namespace phi
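A hedged usage sketch of the new RadixSortPairs helper (not part of this PR; assumes a live phi::GPUContext and device-resident input/output buffers):

#include "paddle/phi/kernels/funcs/radix_sort.h"

// Sort int64 keys ascending while carrying float payloads alongside them.
// float is trivially copyable, 4 bytes, and size-aligned, so it is routed
// through OpaqueTypeRadix<4>.
void SortPairsExample(const phi::GPUContext& dev_ctx,
                      const int64_t* d_keys_in, int64_t* d_keys_out,
                      const float* d_vals_in, float* d_vals_out, int64_t n) {
  phi::funcs::RadixSortPairs(dev_ctx, d_keys_in, d_keys_out,
                             d_vals_in, d_vals_out, n);
}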