8 changes: 6 additions & 2 deletions paddle/operators/accuracy_op.cu
@@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
return;
}

AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS>>>(
num_samples, infer_width, inference_data, label_data, accuracy_data);
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
1, PADDLE_CUDA_NUM_THREADS, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(num_samples, infer_width, inference_data, label_data,
accuracy_data);
}
};
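The change above launches AccuracyCudaKernel on the stream owned by the operator's CUDADeviceContext instead of the default stream, so the kernel is ordered with the rest of the work queued on that context. The same downcast-and-.stream() pattern recurs in cross_entropy_op.cu below; a small helper could factor it out (a sketch only, CudaStreamOf is a hypothetical name and not part of this PR):

namespace {
// Hypothetical helper: pull the CUDA stream out of an ExecutionContext,
// using the same reinterpret_cast the launch sites above already rely on.
inline cudaStream_t CudaStreamOf(const framework::ExecutionContext& ctx) {
  return reinterpret_cast<const platform::CUDADeviceContext&>(
             ctx.device_context())
      .stream();
}
}  // namespace

// Hypothetical usage at the launch site:
//   AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS>
//       <<<1, PADDLE_CUDA_NUM_THREADS, 0, CudaStreamOf(ctx)>>>(
//           num_samples, infer_width, inference_data, label_data,
//           accuracy_data);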

83 changes: 51 additions & 32 deletions paddle/operators/cross_entropy_op.cc
@@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
"Input(Label) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
"Output(Y) should be not null.");

auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank must be 2.");
"Input(Label)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must "
"The 1st dimension of Input(X) and Input(Label) should "
"be equal.");
if (ctx.Attr<bool>("soft_label")) {
if (ctx.Attr<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == true, The 2nd dimension of "
"Input(X) and Input(Label) must be equal.");
"If Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == false, The 2nd dimension of "
"Input(Label) must be 1.");
"If Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}

ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1});
@@ -57,35 +58,38 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) must not be null.");
"Input(Label) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
"Input(Y@GRAD) must not be null.");
"Input(Y@GRAD) shoudl be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");

auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label");
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(dy->dims().size(), 2,
"Input(Y@Grad)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank must be 2.");
"Input(Label)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must "
"The 1st dimension of Input(X) and Input(Label) should "
"be equal.");
PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
"The 1st dimension of Input(X) and Input(Y@Grad) must "
"The 1st dimension of Input(X) and Input(Y@Grad) should "
"be equal.");
PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
"The 2nd dimension of Input(Y@Grad) must be 1.");
if (ctx.Attr<bool>("soft_label")) {
"The 2nd dimension of Input(Y@Grad) should be 1.");
if (ctx.Attr<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == true, The 2nd dimension of "
"Input(X) and Input(Label) must be equal.");
"When Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == false, The 2nd dimension of "
"Input(Label) must be 1.");
"When Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}

auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
@@ -98,24 +102,39 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
CrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of CrossEntropyOp");
AddInput("Label", "The second input of CrossEntropyOp");
AddOutput("Y", "The output of CrossEntropyOp");
AddAttr<bool>("soft_label", "Is soft label. Default zero.")
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
"where N is the batch size and D is the number of classes. "
"This input is a probability computed by the previous operator, "
"which is almost always the result of a softmax operator.");
AddInput(
"Label",
"(Tensor, default Tensor<int>), the ground truth which is "
"a 2-D tensor. "
"When softLabel is set to false, `Label` is a Tensor<int> with shape "
"[N x 1]. "
"When softLabel is set to true, `Label` is a Tensor<float/double> "
"with shape [N x K].");
AddOutput("Y",
"(Tensor, default Tensor<float>), a 2-D tensor "
"with shape [N x 1]. The cross entropy loss.");
AddAttr<bool>(
"softLabel",
"(bool, default false), a flag to indicate whether to interpretate "
"the given labels as soft labels.")
.SetDefault(false);

AddComment(R"DOC(
CrossEntropy Operator.

It supports both standard cross-entropy and soft-label cross-entropy loss
computation.
1) One-hot cross-entropy:
soft_label = False, Label[i, 0] indicates the class index for sample i:
softLabel = false, Label[i, 0] indicates the class index for sample i:

Y[i] = -log(X[i, Label[i]])

2) Soft-label cross-entropy:
soft_label = True, Label[i, j] indicates the soft label of class j
softLabel = true, Label[i, j] indicates the soft label of class j
for sample i:

Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
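For readers of the rewritten operator comment above, a plain C++ reference of what the two modes compute may help. This is an illustrative sketch only; it mirrors the formulas in the DOC string and omits the TolerableValue clamping the kernels apply:

#include <cmath>
#include <vector>

// Reference (CPU) versions of the two modes; x is row-major with shape [N x D].
std::vector<float> HardCrossEntropy(const std::vector<float>& x,
                                    const std::vector<int>& label,
                                    int n, int d) {
  std::vector<float> y(n);
  for (int i = 0; i < n; ++i) y[i] = -std::log(x[i * d + label[i]]);
  return y;
}

std::vector<float> SoftCrossEntropy(const std::vector<float>& x,
                                    const std::vector<float>& label,
                                    int n, int d) {
  std::vector<float> y(n, 0.f);
  for (int i = 0; i < n; ++i) {
    float s = 0.f;
    for (int j = 0; j < d; ++j) s += label[i * d + j] * std::log(x[i * d + j]);
    y[i] = -s;
  }
  return y;
}

For example, with N = 1, D = 3 and X = [0.2, 0.5, 0.3]: the one-hot mode with Label = [1] gives Y = -log(0.5) ≈ 0.693, while the soft-label mode with Label = [0.1, 0.7, 0.2] gives Y = -(0.1*log(0.2) + 0.7*log(0.5) + 0.2*log(0.3)) ≈ 0.887.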
147 changes: 92 additions & 55 deletions paddle/operators/cross_entropy_op.cu
@@ -28,26 +28,49 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
Y[i] = -tolerable_value(log(X[i * D + label[i]]));
Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]]));
}
}

template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
val += __shfl_down(val, 16);
val += __shfl_down(val, 8);
val += __shfl_down(val, 4);
val += __shfl_down(val, 2);
val += __shfl_down(val, 1);
return val;
}
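sum_single_warp folds the 32 partial sums held by one warp into lane 0 with shuffle-down moves, halving the shuffle distance at each step. One aside for readers on newer toolkits (not a change this PR needs to make): CUDA 9 deprecated the unsynchronized __shfl_down in favour of __shfl_down_sync, so an equivalent sketch under that API would be:

template <typename T>
__device__ __forceinline__ T sum_single_warp_sync(T val) {
  // 0xffffffff: all 32 lanes of the warp participate in the shuffle.
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;
}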

template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
const int N, const int D) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
T sum = static_cast<T>(0);
for (int j = 0; j < D; j++) {
sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
}
Y[i] = -sum;
const int class_num) {
int tid = threadIdx.x;
extern __shared__ T d_sum[];
d_sum[tid] = 0;

int cur_idx = tid;
int next_idx = blockIdx.x * class_num + tid;
while (cur_idx < class_num) {
d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
next_idx += blockDim.x;
cur_idx += blockDim.x;
}
__syncthreads();

for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
if (tid < stride) d_sum[tid] += d_sum[tid + stride];
__syncthreads();
}

T val = d_sum[tid];
val = sum_single_warp<T>(val);
if (tid == 0) Y[blockIdx.x] = -val;
}
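SoftCrossEntropyKernel reduces one sample per block: threads stride across the class dimension accumulating partial products into dynamic shared memory, the shared array is folded in halves down to 32 entries, and the last 32 are folded by sum_single_warp. The shared-memory loop assumes blockDim.x is a power of two; the launch site later in this file rounds the class count down to a power of two and caps it at 512. Factored into a standalone helper, the same block-wide reduction looks roughly like this (a sketch; BlockSum is a hypothetical name, not part of this PR):

template <typename T>
__device__ T BlockSum(T thread_val) {
  extern __shared__ T buf[];  // blockDim.x elements, supplied at launch time
  int tid = threadIdx.x;
  buf[tid] = thread_val;
  __syncthreads();
  // Tree reduction in shared memory until one warp's worth of values remains.
  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
    if (tid < stride) buf[tid] += buf[tid + stride];
    __syncthreads();
  }
  // Shuffle reduction inside the first warp; the result is valid in thread 0.
  return sum_single_warp<T>(buf[tid]);
}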

// TODO(qingqing): make zero setting an common function.
// TODO(qingqing): make zero setting a common function.
template <typename T>
__global__ void zero(T* X, const int N) {
__global__ void Zero(T* X, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
X[i] = 0.0;
@@ -71,13 +94,10 @@ template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const T* label, const int N,
const int D) {
// TODO(qingqing): optimize for this kernel
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
for (int j = 0; j < D; ++j) {
int idx = i * D + j;
dX[idx] = -label[idx] * dY[i] / X[idx];
}
int ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < N * D) {
int row_ids = ids / D;
dX[ids] = -label[ids] * dY[row_ids] / X[ids];
}
}
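The rewritten gradient kernel assigns one thread per element of the N x D matrix rather than one thread per row looping over classes (the grid at the launch site below is sized over batch_size * class_num accordingly). The element-wise assignment follows directly from the soft-label loss documented in cross_entropy_op.cc:

Y[i]           = -\sum_j Label[i, j] * log(X[i, j])
dY[i]/dX[i, j] = -Label[i, j] / X[i, j]
dX[i, j]       = dY[i] * (dY[i]/dX[i, j]) = -Label[i, j] * dY[i] / X[i, j]

which is exactly the expression computed above.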

@@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"This kernel only runs on GPU device.");

auto x = ctx.Input<Tensor>("X");
auto y = ctx.Output<Tensor>("Y");
auto label = ctx.Input<Tensor>("Label");
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* label = ctx.Input<Tensor>("Label");
Tensor* y = ctx.Output<Tensor>("Y");

auto* x_data = x->data<T>();
y->mutable_data<T>(ctx.GetPlace());
auto* y_data = y->data<T>();
const T* x_data = x->data<T>();
T* y_data = y->mutable_data<T>(ctx.GetPlace());

int n = x->dims()[0];
int d = x->dims()[1];
int block = 512;
int grid = (n + block - 1) / block;
// TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext.
if (ctx.Attr<bool>("soft_label")) {
int batch_size = x->dims()[0];
int class_num = x->dims()[1];

if (ctx.Attr<bool>("softLabel")) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
d);
int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));

SoftCrossEntropyKernel<
T><<<batch_size, block, block * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(y_data, x_data, label_data, class_num);
} else {
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
int block = 512;
int grid = (batch_size + block - 1) / block;
CrossEntropyKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(y_data, x_data, label_data,
batch_size, class_num);
}
}
};
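In the soft-label branch above, the block size is chosen as the largest power of two not exceeding class_num, capped at 512, so SoftCrossEntropyKernel's shared-memory reduction always sees a power-of-two blockDim.x; the dynamic shared-memory size block * sizeof(T) matches the extern __shared__ array in that kernel. The pow/std::log2 expression takes a detour through floating point; an integer-only equivalent would be (a sketch only, not part of this PR):

inline int LargestPowerOfTwoAtMost(int n, int cap = 512) {
  // Hypothetical helper: round n down to a power of two, clamped to cap.
  int block = 1;
  while (block * 2 <= n && block * 2 <= cap) block *= 2;
  return block;
}
// e.g. LargestPowerOfTwoAtMost(1000) == 512, LargestPowerOfTwoAtMost(10) == 8.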
@@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"This kernel only runs on GPU device.");

const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* label = ctx.Input<Tensor>("Label");
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

auto x = ctx.Input<Tensor>("X");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto label = ctx.Input<Tensor>("Label");
const T* dy_data =
ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
const T* x_data = x->data<T>();

auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dy_data = dy->data<T>();
auto* x_data = x->data<T>();
int batch_size = x->dims()[0];
int class_num = x->dims()[1];

int n = x->dims()[0];
int d = x->dims()[1];
int block = 512;
int grid = (n * d + block - 1) / block;
zero<T><<<grid, block>>>(dx_data, n * d);
grid = (n + block - 1) / block;
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
if (ctx.Attr<bool>("soft_label")) {
int grid = (batch_size * class_num + block - 1) / block;

if (ctx.Attr<bool>("softLabel")) {
auto* label_data = label->data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
dx_data, dy_data, x_data, label_data, n, d);
SoftCrossEntropyGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, dy_data, x_data, label_data,
batch_size, class_num);
} else {
Zero<T><<<grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, batch_size * class_num);

auto* label_data = label->data<int>();
CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
label_data, n, d);
grid = (batch_size + block - 1) / block;
CrossEntropyGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, dy_data, x_data, label_data,
batch_size, class_num);
}
}
};