Skip to content

Commit c4d57f3

Browse files
committed
use phi::Copy
1 parent 8aef993 commit c4d57f3

File tree

4 files changed

+8
-7
lines changed

4 files changed

+8
-7
lines changed

paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
1616

1717
#include "paddle/fluid/framework/convert_utils.h"
18-
#include "paddle/fluid/framework/tensor_util.h"
1918
#include "paddle/fluid/operators/gather_scatter_kernel.h"
2019
#include "paddle/fluid/platform/place.h"
2120
#include "paddle/phi/backends/cpu/cpu_context.h"
2221
#include "paddle/phi/core/kernel_registry.h"
22+
#include "paddle/phi/kernels/copy_kernel.h"
2323

2424
namespace phi {
2525

@@ -40,7 +40,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
4040
const auto& index_type =
4141
paddle::framework::TransToProtoVarType(index.dtype());
4242
if (x_grad) {
43-
paddle::framework::TensorCopy(out_grad, dev_ctx.GetPlace(), x_grad);
43+
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
4444
if (index_type == paddle::framework::proto::VarType::INT32) {
4545
paddle::operators::cpu_scatter_input_grad_kernel<T, int32_t>(
4646
// Here passing an unused argument out_grad, because it's

paddle/phi/kernels/cpu/put_along_axis_kernel.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
#include "paddle/phi/kernels/put_along_axis_kernel.h"
1616

1717
#include "paddle/fluid/framework/convert_utils.h"
18-
#include "paddle/fluid/framework/tensor_util.h"
1918
#include "paddle/fluid/operators/gather_scatter_kernel.h"
2019
#include "paddle/fluid/platform/place.h"
2120
#include "paddle/phi/backends/cpu/cpu_context.h"
2221
#include "paddle/phi/core/kernel_registry.h"
22+
#include "paddle/phi/kernels/copy_kernel.h"
2323

2424
namespace phi {
2525

@@ -36,7 +36,7 @@ void PutAlongAxisKernel(const Context& dev_ctx,
3636
true,
3737
errors::PreconditionNotMet("PutAlongAxisOpKernel only runs on CPU."));
3838

39-
paddle::framework::TensorCopy(x, dev_ctx.GetPlace(), out);
39+
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
4040
const auto& index_type =
4141
paddle::framework::TransToProtoVarType(index.dtype());
4242
if (reduce == "add") {

paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
1616

1717
#include "paddle/fluid/framework/convert_utils.h"
18-
#include "paddle/fluid/framework/tensor_util.h"
1918
#include "paddle/fluid/operators/gather_scatter_kernel.h"
2019
#include "paddle/fluid/platform/place.h"
2120
#include "paddle/phi/backends/gpu/gpu_context.h"
2221
#include "paddle/phi/core/kernel_registry.h"
22+
#include "paddle/phi/kernels/copy_kernel.h"
2323

2424
namespace phi {
2525

@@ -40,7 +40,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
4040
const auto& index_type =
4141
paddle::framework::TransToProtoVarType(index.dtype());
4242
if (x_grad) {
43-
paddle::framework::TensorCopy(out_grad, dev_ctx.GetPlace(), x_grad);
43+
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
4444
if (index_type == paddle::framework::proto::VarType::INT32) {
4545
paddle::operators::gpu_scatter_input_grad_kernel<T, int32_t>(
4646
out_grad, axis, index, *x_grad, dev_ctx);

paddle/phi/kernels/gpu/put_along_axis_kernel.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "paddle/fluid/platform/place.h"
2020
#include "paddle/phi/backends/gpu/gpu_context.h"
2121
#include "paddle/phi/core/kernel_registry.h"
22+
#include "paddle/phi/kernels/copy_kernel.h"
2223

2324
namespace phi {
2425

@@ -38,7 +39,7 @@ void PutAlongAxisKernel(const Context& dev_ctx,
3839
const auto& index_type =
3940
paddle::framework::TransToProtoVarType(index.dtype());
4041

41-
paddle::framework::TensorCopy(x, dev_ctx.GetPlace(), out);
42+
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
4243
if (reduce == "add") {
4344
if (index_type == paddle::framework::proto::VarType::INT32) {
4445
paddle::operators::gpu_scatter_add_kernel<T, int32_t>(

0 commit comments

Comments
 (0)