Skip to content

Commit 1f6403d

Browse files
【CPU】change Batch norm reserve_space dtype (#72314)
* batch_norm nouse output * Apply suggestions from code review
1 parent 8bd5227 commit 1f6403d

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

paddle/phi/kernels/cpu/batch_norm_kernel.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ void BatchNormKernel(const Context& ctx,
101101
ctx.template Alloc<T>(saved_variance), C);
102102
saved_mean_e.setZero();
103103
saved_variance_e.setZero();
104-
EigenVectorArrayMap<T> reserve_space_e(ctx.template Alloc<T>(reserve_space),
105-
0);
104+
EigenVectorArrayMap<uint8_t> reserve_space_e(
105+
ctx.template Alloc<uint8_t>(reserve_space), 0);
106106
reserve_space_e.setZero();
107107

108108
EigenVectorArrayMap<T> running_mean_arr(ctx.template Alloc<T>(mean_out), C);
@@ -222,4 +222,6 @@ void BatchNormKernel(const Context& ctx,
222222
} // namespace phi
223223

224224
PD_REGISTER_KERNEL(
225-
batch_norm, CPU, ALL_LAYOUT, phi::BatchNormKernel, float, double) {}
225+
batch_norm, CPU, ALL_LAYOUT, phi::BatchNormKernel, float, double) {
226+
kernel->OutputAt(5).SetDataType(phi::DataType::UINT8);
227+
}

paddle/phi/kernels/xpu/batch_norm_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ void BatchNormKernel(const Context& dev_ctx,
9898
dev_ctx.template Alloc<float>(variance_out);
9999
dev_ctx.template Alloc<float>(saved_mean);
100100
dev_ctx.template Alloc<float>(saved_variance);
101-
102101
PADDLE_ENFORCE_LE(
103102
x_dims.size(),
104103
5,
@@ -164,4 +163,5 @@ PD_REGISTER_KERNEL(batch_norm,
164163
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
165164
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
166165
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
166+
kernel->OutputAt(5).SetDataType(phi::DataType::UINT8);
167167
}

0 commit comments

Comments
 (0)