Skip to content
6 changes: 6 additions & 0 deletions doc/fluid/api/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,12 @@ argmax
.. autofunction:: paddle.fluid.layers.argmax
:noindex:

argsort
------

.. autofunction:: paddle.fluid.layers.argsort
:noindex:

ones
----

Expand Down
87 changes: 87 additions & 0 deletions paddle/fluid/operators/argsort_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/argsort_op.h"

namespace paddle {
namespace operators {

class ArgsortOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ArgsortOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ArgsortOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Indices"),
"Output(Indices) of ArgsortOp should not be null.");

auto in_dims = ctx->GetInputDim("X");
int axis = ctx->Attrs().Get<int>("axis");

auto num_dims = in_dims.size();
PADDLE_ENFORCE(axis < num_dims,
"Attr(axis) %d of ArgsortOp is out of bounds for Input(X)'s "
"rank %d.",
axis, num_dims);
PADDLE_ENFORCE(axis >= -num_dims,
"Attr(axis) %d of ArgsortOp must be not less than "
"-rank(Input(X)) (%d).",
axis, num_dims);

ctx->SetOutputDim("Out", in_dims);
ctx->SetOutputDim("Indices", in_dims);
ctx->ShareLoD("X", "Out");
ctx->ShareLoD("X", "Indices");
}
};

class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input of Argsort op.");
AddOutput("Out",
"(Tensor) The sorted tensor of Argsort op, with the same "
"shape as Input(X).");
AddOutput("Indices",
"(Tensor) The indices of a tensor giving the sorted order, with "
"the same shape as Input(X).");
AddComment(R"DOC(
Argsort operator

Performs sorting on the input tensor along the given axis and outputs two
tensors, Output(Out) and Output(Indices). They reserve the same shape
with Input(X), and Output(Out) represents the sorted tensor while
Output(Indices) gives the sorted order along the given axis Attr(axis).

)DOC");
AddAttr<int>("axis",
"(int, default -1) The axis along which to sort the tensor. "
"When axis < 0, the actual axis will be the |axis|'th "
"counting backwards. Default -1, the last dimension.")
.SetDefault(-1);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(argsort,
ops::ArgsortKernel<paddle::platform::CPUPlace, float>,
ops::ArgsortKernel<paddle::platform::CPUPlace, double>);
152 changes: 152 additions & 0 deletions paddle/fluid/operators/argsort_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/argsort_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using platform::PADDLE_CUDA_NUM_THREADS;

__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
int axis, int64_t n, int64_t* trg_idx,
int64_t* med_ids) {
int64_t index = threadIdx.x + blockDim.x * blockIdx.x;
if (index < n) {
const int max_rank = 9; // Max rank of a tensor allow in Fluid
Copy link
Contributor

@qingqing01 qingqing01 Jun 29, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this constant variable before line 19.

const int kMaxRank = 6; 
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean outside the kernel function? Then done.

int64_t shape_out_axis[max_rank - 1] = {0};
int64_t dims_out_axis[max_rank - 1] = {0};
int64_t tmp = index;
int64_t pos_in_axis = 0;
int64_t i = dims_size - 2;
int64_t dim_axis = 0;
for (int64_t j = dims_size - 1; j >= 0; --j) {
int64_t dim = in_dims[j];
if (j != axis) {
shape_out_axis[i] = tmp % dim;
dims_out_axis[i] = dim;
i--;
} else {
dim_axis = dim;
pos_in_axis = tmp % dim_axis;
}
tmp /= dim;
}
int64_t group = (dims_size > 1) ? shape_out_axis[0] : 0;
for (int64_t j = 0; j < dims_size - 2; ++j) {
group = group * dims_out_axis[j + 1] + shape_out_axis[j + 1];
}

int64_t traget_idx = group * dim_axis + pos_in_axis;
trg_idx[index] = traget_idx;
med_ids[traget_idx] = pos_in_axis;
}
}

template <typename T>
__global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n,
T* med_out) {
int index = threadIdx.x + blockDim.x * blockIdx.x;
if (index < n) {
med_out[trg_idx[index]] = in[index];
}
}

template <typename T>
__global__ void Sort(int64_t axis_dim, int64_t groups, T* med_out,
int64_t* med_ids) {
int index = threadIdx.x + blockDim.x * blockIdx.x;
if (index < groups) {
thrust::sort_by_key(thrust::device, med_out + index * axis_dim,
med_out + axis_dim * (1 + index),
med_ids + index * axis_dim);
}
}

template <typename T>
__global__ void PermuteMediateData(const T* med_out, const int64_t* med_ids,
const int64_t* trg_idx, int64_t n, T* out,
int64_t* indices) {
int index = threadIdx.x + blockDim.x * blockIdx.x;
if (index < n) {
out[index] = med_out[trg_idx[index]];
indices[index] = med_ids[trg_idx[index]];
}
}

template <typename T>
class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
int axis = ctx.Attr<int>("axis");

auto in_dims = input->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;

const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());

int64_t numel = input->numel();
int64_t groups = numel / in_dims[axis];

std::vector<int64_t> in_dims_vec = vectorize(in_dims);
thrust::device_vector<int64_t> in_dims_dev(in_dims_vec.begin(),
in_dims_vec.end());
int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data());
// Mediate tensor for sorting data and indices
Tensor mediate_output, mediate_indices;
T* med_out_data =
mediate_output.mutable_data<T>(input->dims(), ctx.GetPlace());
int64_t* med_ids_data =
mediate_indices.mutable_data<int64_t>(in_dims, ctx.GetPlace());
// Target index of each element along the given axis in the mediate tensors
Tensor trg_idx_t;
int64_t* trg_idx = trg_idx_t.mutable_data<int64_t>(in_dims, ctx.GetPlace());

auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto stream = ctx.cuda_device_context().stream();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

int num_threads = PADDLE_CUDA_NUM_THREADS;

ComputeTargetIdx<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data);

PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
in_data, trg_idx, numel, med_out_data);

Sort<<<(groups - 1) / num_threads + 1, num_threads, 0, stream>>>(
in_dims[axis], groups, med_out_data, med_ids_data);

PermuteMediateData<<<(numel - 1) / num_threads + 1, num_threads, 0,
stream>>>(med_out_data, med_ids_data, trg_idx, numel,
out_data, ids_data);
}
};

} // namespace operators
} // namespace paddle

REGISTER_OP_CUDA_KERNEL(argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
paddle::operators::ArgsortOpCUDAKernel<double>);
81 changes: 81 additions & 0 deletions paddle/fluid/operators/argsort_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class ArgsortKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X");
auto* output = ctx.Output<framework::Tensor>("Out");
auto* indices = ctx.Output<framework::Tensor>("Indices");
int axis = ctx.Attr<int>("axis");

auto in_dims = input->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;

const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());

int64_t groups = input->numel() / in_dims[axis];
int64_t stride = (axis == in_dims.size() - 1)
? 1
: framework::product(framework::slice_ddim(
in_dims, axis + 1, in_dims.size()));

for (int64_t i = 0; i < groups; ++i) {
int64_t idx = i;
std::vector<int64_t> shape_vec(in_dims.size(), 0);
for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) {
if (dim != axis) {
shape_vec[dim] = idx % in_dims[dim];
idx /= in_dims[dim];
}
}

int64_t start_index = shape_vec[0];
for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) {
start_index = start_index * in_dims[dim + 1] + shape_vec[dim + 1];
}

std::vector<int64_t> org_index_vec(in_dims[axis], start_index);
for (int64_t j = 1; j < in_dims[axis]; ++j) {
org_index_vec[j] += j * stride;
}

std::sort(org_index_vec.begin(), org_index_vec.end(),
[in_data](const int64_t v1, const int64_t v2) {
return in_data[v1] < in_data[v2];
});

for (size_t j = 0; j < org_index_vec.size(); ++j) {
int64_t index = start_index + j * stride;
out_data[index] = in_data[org_index_vec[j]];
ids_data[index] = (org_index_vec[j] - start_index) / stride;
}
}
Copy link
Collaborator

@sneaxiy sneaxiy Jun 11, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 40-73 can be changed to be more efficient and save memory used.

int64_t part_dims_prod = input->numel() / in_dims[axis]; int64_t step = 1; for (int64_t i = in_dims.size()-1; i > axis; --i) step *= in_dims[i]; std::vector<int64_t> org_index_vec(in_dims.size()); std::vector<int64_t> idx_vec(in_dims.size()); idx_vec[axis] = 0; for (int64_t i = 0; i < part_dims_prod; ++i) { for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) { if (dim != axis) { idx_vec[dim] = idx % in_dims[dim]; idx /= in_dims[dim]; } } int64_t start_index = idx_vec[0]; for (int64_t dim = 1; dim < in_dims.size(); ++dim) { start_index = start_index * in_dims[dim] + idx_vec[dim]; } for (int64_t j = 0; j < in_dims.size(); ++j) { org_index_vec[j] = start_index + j*step; } std::sort( org_index_vec.begin(), org_index_vec.end(), [in_data](int64_t idx1, int64_t idx2) { return in_data[idx1] < in_data[idx2]; }); for (size_t j = 0; j < org_index_vec.size(); ++j) { int64_t org_index = org_index_vec[j]; int64_t ret_index = start_index + j*step; out_data[ret_index] = in_data[org_index]; idx_data[ret_index] = org_index; } }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! It is a good idea to only sort the index, and I made the change. Please take a look.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent!

}
};

} // namespace operators
} // namespace paddle
Loading