@lcy-seso
Contributor

@lcy-seso lcy-seso commented Sep 20, 2017

fixes #4236

  • optimize the forward kernel with soft labels
  • optimize the backward kernel with soft labels
  • implement the CPU forward kernel by directly calling Eigen
  • implement the CPU backward kernel by directly calling Eigen
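
For readers who want the math behind the four items above, here is a minimal pure-Python reference sketch of soft-label cross entropy. It is not the PR's Eigen or CUDA code; the function names are hypothetical, and the forward formula is the standard soft-label definition, assumed to match what the kernels compute. The backward formula is the one quoted later in this review.

```python
import math

def soft_cross_entropy_forward(x, label):
    """Y[i] = -sum_j label[i][j] * log(x[i][j]) for each row i."""
    return [-sum(l * math.log(p) for p, l in zip(row_x, row_l))
            for row_x, row_l in zip(x, label)]

def soft_cross_entropy_backward(x, label, dy):
    """dX[i][j] = -label[i][j] * dY[i] / x[i][j]."""
    return [[-l * g / p for p, l in zip(row_x, row_l)]
            for row_x, row_l, g in zip(x, label, dy)]

# Two samples, three classes; each row of x and label sums to 1.
x = [[0.5, 0.25, 0.25], [0.1, 0.8, 0.1]]
label = [[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]]

y = soft_cross_entropy_forward(x, label)
dx = soft_cross_entropy_backward(x, label, [1.0, 1.0])
```

A reference like this is mainly useful for checking optimized kernels against small hand-computable inputs.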
@lcy-seso lcy-seso requested a review from Xreki September 20, 2017 15:25
@lcy-seso lcy-seso force-pushed the optimize_cross_entropy_kernel branch from c6a995f to a3a8a09 Compare September 20, 2017 15:26
self.inputs = {'X': X, 'Label': label}
self.outputs = {'Y': cross_entropy}
self.attrs = {'soft_label': 1}
self.inputs = {"X": X, "Label": label}
Collaborator

https://www.python.org/dev/peps/pep-0008/#string-quotes

We follow PEP 8 as our style guide. In Python, both '' and "" are fine.

Contributor Author

@lcy-seso lcy-seso Sep 20, 2017

Yes, I know. Personally, though, I prefer to keep the quoting style consistent within one file.

@lcy-seso lcy-seso force-pushed the optimize_cross_entropy_kernel branch 2 times, most recently from 4b55fd2 to 1904577 Compare September 21, 2017 00:24
Contributor

@qingqing01 qingqing01 Sep 22, 2017

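Allocating a two-dimensional block like this leaves more idle threads than a one-dimensional block would. A one-dimensional block is actually enough for this computation; you can use the same approach as the non-soft-label case: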

int block = 512; int grid = (n * d + block - 1) / block;

Change the kernel to:

for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N * D; i += blockDim.x * gridDim.x) {
  int id_dy = i / D;  // row index of element i
  dX[i] = -label[i] * dY[id_dy] / X[i];
}

Also, if it is convenient, could you help fix this at the same time:

// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.

Thanks. When I wrote this, ExecutionContext did not have a stream yet.
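
To make the index arithmetic in the suggestion above concrete, here is a small Python emulation of a one-dimensional launch with a grid-stride loop. It is only a sketch: `launch_config` and `soft_cross_entropy_grad_flat` are hypothetical names, the tensors are assumed to be flattened row-major lists, and the simulated threads run sequentially rather than in parallel. The optional `grid` argument is an addition for illustration, so that the stride loop can be exercised with fewer threads than elements.

```python
def launch_config(n, d, block=512):
    """Ceiling division: enough blocks to cover all n * d elements."""
    grid = (n * d + block - 1) // block
    return grid, block

def soft_cross_entropy_grad_flat(x, label, dy, n, d, block=512, grid=None):
    """Emulate the 1-D kernel: dX[i] = -label[i] * dY[i // d] / X[i]."""
    if grid is None:
        grid, block = launch_config(n, d, block)
    dx = [0.0] * (n * d)
    stride = block * grid  # blockDim.x * gridDim.x
    for block_idx in range(grid):
        for thread_idx in range(block):
            i = block_idx * block + thread_idx
            while i < n * d:       # grid-stride loop
                row = i // d       # flat index -> row of dY
                dx[i] = -label[i] * dy[row] / x[i]
                i += stride
    return dx

n, d = 2, 3
x = [0.5, 0.25, 0.25, 0.1, 0.8, 0.1]
label = [1.0, 0.0, 0.0, 0.0, 0.5, 0.5]
dy = [1.0, 2.0]

dx_full = soft_cross_entropy_grad_flat(x, label, dy, n, d, block=4)
# A single block of 2 threads still covers all 6 elements via the stride.
dx_small = soft_cross_entropy_grad_flat(x, label, dy, n, d, block=2, grid=1)
```

Because every flat index maps to exactly one thread and iteration, no thread sits idle waiting on an unused second block dimension, which is the point of the review comment.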

Contributor Author

done.

Contributor

Remove line 123 and write it as:

SoftCrossEntropyKernel<T, 512><<<d, block>>>
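
The kernel body is not shown in this conversation. One common reason to pass the block size as a template argument is a block-wide reduction over shared memory, so the sketch below emulates that pattern in Python under that assumption: each block handles one row, and the block size is a power of two. The name `block_row_sum` is hypothetical and the threads are simulated sequentially.

```python
import math

def block_row_sum(values, block_size=512):
    """Emulate a block-wide sum: strided partial sums, then a tree reduction."""
    if block_size <= 0 or block_size & (block_size - 1):
        raise ValueError("block_size must be a power of two")
    # Phase 1: each simulated thread accumulates a strided slice of the row.
    shared = [0.0] * block_size
    for tid in range(block_size):
        for j in range(tid, len(values), block_size):
            shared[tid] += values[j]
    # Phase 2: halve the number of active threads until one value remains.
    stride = block_size // 2
    while stride > 0:
        for tid in range(stride):
            shared[tid] += shared[tid + stride]
        stride //= 2
    return shared[0]

# Soft-label loss for one row: -sum_j label[j] * log(x[j]).
row_x = [0.5, 0.25, 0.25]
row_label = [1.0, 0.0, 0.0]
loss = -block_row_sum([l * math.log(p) for p, l in zip(row_x, row_label)],
                      block_size=4)
```

On a GPU the inner loops of each phase run concurrently with a barrier between reduction steps; the emulation only shows which elements each thread would combine.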
Contributor Author

done.

@lcy-seso lcy-seso force-pushed the optimize_cross_entropy_kernel branch 4 times, most recently from a637d86 to 031d7eb Compare September 22, 2017 15:48
@lcy-seso
Contributor Author

lcy-seso commented Sep 22, 2017

Not finished yet, please do not review. Thanks, everyone.

@lcy-seso lcy-seso force-pushed the optimize_cross_entropy_kernel branch 15 times, most recently from e8ed548 to fa3e373 Compare September 23, 2017 07:49
@lcy-seso
Contributor Author

lcy-seso commented Sep 23, 2017

All finished. Please review this. Thanks, everyone. @qingqing01 @Xreki

@lcy-seso lcy-seso force-pushed the optimize_cross_entropy_kernel branch 4 times, most recently from 11516af to 058dfbb Compare September 23, 2017 08:34
@lcy-seso lcy-seso force-pushed the optimize_cross_entropy_kernel branch 3 times, most recently from 8f776c8 to f5be351 Compare September 23, 2017 09:10
@lcy-seso lcy-seso force-pushed the optimize_cross_entropy_kernel branch 5 times, most recently from 1a4eff4 to ae59260 Compare September 25, 2017 11:07
qingqing01 previously approved these changes Sep 26, 2017
Contributor

@qingqing01 qingqing01 left a comment

LGTM.

Contributor

set to 0 -> false.

Contributor Author

done

Contributor

set to 1 -> set to true.

Contributor Author

done

@lcy-seso lcy-seso force-pushed the optimize_cross_entropy_kernel branch from c80489a to 000d751 Compare September 26, 2017 05:43
@lcy-seso lcy-seso merged commit 7d65321 into PaddlePaddle:develop Sep 26, 2017
@lcy-seso lcy-seso deleted the optimize_cross_entropy_kernel branch September 27, 2017 02:41
3 participants