optimize cross entropy kernel by using reduce. #4237
self.inputs = {'X': X, 'Label': label}
self.outputs = {'Y': cross_entropy}
self.attrs = {'soft_label': 1}
self.inputs = {"X": X, "Label": label}
https://www.python.org/dev/peps/pep-0008/#string-quotes
We follow PEP 8 as our style guide. In Python, both '' and "" are fine.
Yes, I know. Personally, I just prefer to keep the quoting consistent within one file.
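For context on what the test under review checks, here is a minimal reference sketch of the soft-label cross entropy that `cross_entropy` would hold. It assumes the usual definition `Y[i] = -sum_j label[i][j] * log(X[i][j])`, which is consistent with the gradient formula discussed later in this thread; the helper name `soft_label_cross_entropy` and the sample values are illustrative, not taken from the PR.

```python
import math


def soft_label_cross_entropy(X, label):
    """Return one loss per row: Y[i] = -sum_j label[i][j] * log(X[i][j])."""
    return [
        -sum(l * math.log(x) for x, l in zip(x_row, l_row))
        for x_row, l_row in zip(X, label)
    ]


# Illustrative inputs: each row of X and of label sums to 1.
X = [[0.5, 0.25, 0.25], [0.1, 0.7, 0.2]]
label = [[1.0, 0.0, 0.0], [0.2, 0.5, 0.3]]
Y = soft_label_cross_entropy(X, label)
```

With a one-hot row such as `label[0]`, the result reduces to `-log(X[0][0])`, the non-soft-label case.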
paddle/operators/cross_entropy_op.cu Outdated
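Allocating a two-dimensional block this way leaves more idle threads than a one-dimensional block would. A one-dimensional block is enough for this computation, so you can use the same approach as the non-soft-label case:
int block = 512; int grid = (n * d + block - 1) / block;
Change the kernel to:
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N * D; i += blockDim.x * gridDim.x) { int id_dy = i / D; dX[i] = -label[i] * dY[id_dy] / X[i]; }
Also, if it is convenient, could you fix the following at the same time:
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
Thanks. When I wrote this, ExecutionContext did not have a stream yet.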
done.
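The one-dimensional launch and grid-stride loop suggested above can be checked on the CPU. The sketch below is a plain Python emulation, not the CUDA kernel from the PR: `launch_config` and `soft_label_grad` are hypothetical names, and the arrays are assumed to be flat and row-major with `n` rows and `d` columns, so `i // d` plays the role of `id_dy`.

```python
def launch_config(n, d, block=512):
    """1-D launch configuration: enough blocks of `block` threads to cover n * d."""
    grid = (n * d + block - 1) // block  # ceiling division
    return grid, block


def soft_label_grad(X, label, dY, n, d, grid, block):
    """Emulate the grid-stride loop: dX[i] = -label[i] * dY[i // d] / X[i]."""
    dX = [0.0] * (n * d)
    for block_idx in range(grid):
        for thread_idx in range(block):
            # Same start index and stride as the CUDA loop.
            i = block_idx * block + thread_idx
            while i < n * d:
                row = i // d  # which element of dY this entry belongs to
                dX[i] = -label[i] * dY[row] / X[i]
                i += block * grid
    return dX


n, d = 2, 3
X = [0.5, 0.25, 0.25, 0.1, 0.7, 0.2]
label = [1.0, 0.0, 0.0, 0.2, 0.5, 0.3]
dY = [1.0, 2.0]
grid, block = launch_config(n, d)
dX = soft_label_grad(X, label, dY, n, d, grid, block)
```

Because of the stride, the loop stays correct even when `grid * block` is smaller than `n * d`; each emulated thread then handles more than one element.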
paddle/operators/cross_entropy_op.cu Outdated
Remove line 123 and write:
SoftCrossEntropyKernel<T, 512><<<d, block>>>
done.
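The body of `SoftCrossEntropyKernel` is not shown in this thread, so the following is only a sketch of the kind of block-level reduce the PR title refers to. It is a Python emulation under stated assumptions: one thread block handles one row, the block size is a power of two (512 in the launch above), and the function name `block_reduce_row` is hypothetical.

```python
import math


def block_reduce_row(x_row, label_row, block_size=512):
    """Emulate one block computing -sum_j label[j] * log(x[j]) for a single row."""
    d = len(x_row)
    # Step 1: each emulated thread accumulates a strided partial sum
    # into its own slot of "shared memory".
    shared = [0.0] * block_size
    for tid in range(block_size):
        for j in range(tid, d, block_size):
            shared[tid] += label_row[j] * math.log(x_row[j])
    # Step 2: tree reduction; the number of active threads halves each round.
    stride = block_size // 2
    while stride > 0:
        for tid in range(stride):
            shared[tid] += shared[tid + stride]
        stride //= 2
    return -shared[0]


x_row = [0.1, 0.7, 0.2, 0.5, 0.25, 0.25]
label_row = [0.2, 0.1, 0.3, 0.1, 0.2, 0.1]
loss = block_reduce_row(x_row, label_row, block_size=4)
```

On a GPU each round of step 2 would be separated by a barrier such as `__syncthreads()`; the sequential emulation does not need one because a round only reads slots that it does not write.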
All finished. Please review this. Thanks, everyone. @qingqing01 @Xreki
LGTM.
paddle/operators/cross_entropy_op.cc Outdated
set to 0 -> set to false.
done
paddle/operators/cross_entropy_op.cc Outdated
set to 1 -> set to true.
done
fixes #4236