Skip to content

Commit a6ef875

Browse files
committed
refine conv
1 parent 5ba231d commit a6ef875

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

paddle/operators/conv_op.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,11 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
260260

261261
if (input_grad) {
262262
input_grad->mutable_data<T>(context.GetPlace());
263-
set_zero(context.device_context(), input_grad, static_cast<T>(0));
264-
263+
// if is_expand is false, the operation of set_zero is unnecessary,
264+
// because math::matmul will reset input_grad.
265+
if (is_expand) {
266+
set_zero(context.device_context(), input_grad, static_cast<T>(0));
267+
}
265268
math::Col2VolFunctor<Place, T> col2vol;
266269
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
267270

paddle/operators/conv_transpose_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
225225

226226
if (input_grad) {
227227
input_grad->mutable_data<T>(context.GetPlace());
228-
set_zero(context.device_context(), input_grad, static_cast<T>(0));
228+
// set_zero is unnecessary, math::matmul will reset input_grad.
229229
}
230230
if (filter_grad) { // filter size (m, c, k_h, k_w)
231231
filter_grad->mutable_data<T>(context.GetPlace());

0 commit comments

Comments
 (0)