Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion paddle/phi/kernels/gpu/argsort_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,13 @@ void ArgFullSort(const phi::GPUContext& dev_ctx,
const int64_t num_rows,
const int64_t num_cols,
const bool descending) {
PADDLE_ENFORCE_LE(num_cols,
std::numeric_limits<int>::max(),
::common::errors::PreconditionNotMet(
"The dimension being sorted should be less than "
"2^31, but got %lld. Please check the input tensor. ",
num_cols));

auto cu_stream = dev_ctx.stream();
auto ComputeBlockSize = [](IndType col) {
if (col > 512)
Expand Down Expand Up @@ -228,8 +235,14 @@ void ArgFullSort(const phi::GPUContext& dev_ctx,
const int64_t total_elements = num_cols * num_rows;
const int64_t segment_size = num_cols;
const int64_t element_per_call = std::min(max_elements, total_elements);

// make sure element_per_call >= segment_size
const int64_t adjusted_elements_per_call =
std::max(max_elements, segment_size);

// make sure batch size is the multiple of segment_size
const int64_t batch_size = (element_per_call / segment_size) * segment_size;
const int64_t batch_size =
(adjusted_elements_per_call / segment_size) * segment_size;
int64_t offset = 0;
DenseTensor input_indices;

Expand Down
Loading