@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
1313limitations under the License. */
1414
1515#include < thrust/device_vector.h>
16+ #include " paddle/fluid/framework/tensor_util.h"
1617#include " paddle/fluid/operators/math/math_function.h"
1718#include " paddle/fluid/operators/slice_op.h"
1819#include " paddle/fluid/platform/cuda_device_function.h"
1920#include " paddle/fluid/platform/cuda_primitives.h"
2021#include " paddle/fluid/platform/float16.h"
21-
2222namespace paddle {
2323namespace operators {
2424
@@ -94,17 +94,22 @@ class SliceGradKernel<paddle::platform::CUDADeviceContext,
9494 dim3 blocks ((numel - 1 ) / PADDLE_CUDA_NUM_THREADS + 1 );
9595 dim3 threads (PADDLE_CUDA_NUM_THREADS);
9696 auto stream = ctx.cuda_device_context ().stream ();
97-
98- auto out_shape = framework::vectorize<int64_t >(out_dims);
99- thrust::device_vector<int64_t > out_dims_vec (out_shape.begin (),
100- out_shape.end ());
101- auto in_shape = framework::vectorize<int64_t >(in_dims);
102- thrust::device_vector<int64_t > in_dims_vec (in_shape.begin (),
103- in_shape.end ());
104- thrust::device_vector<int64_t > offsets_vec (offsets.begin (), offsets.end ());
105- const int64_t * out_dims_ptr = thrust::raw_pointer_cast (out_dims_vec.data ());
106- const int64_t * in_dims_ptr = thrust::raw_pointer_cast (in_dims_vec.data ());
107- const int64_t * offsets_ptr = thrust::raw_pointer_cast (offsets_vec.data ());
97+ const std::vector<int64_t > out_shape =
98+ framework::vectorize<int64_t >(out_dims);
99+ const std::vector<int64_t > in_shape =
100+ framework::vectorize<int64_t >(in_dims);
101+
102+ framework::Tensor out_dims_tensor;
103+ framework::Tensor in_dims_tensor;
104+ framework::Tensor offsets_tensor;
105+ framework::TensorFromVector (out_shape, ctx.device_context (),
106+ &out_dims_tensor);
107+ framework::TensorFromVector (in_shape, ctx.device_context (),
108+ &in_dims_tensor);
109+ framework::TensorFromVector (offsets, ctx.device_context (), &offsets_tensor);
110+ const int64_t * out_dims_ptr = out_dims_tensor.data <int64_t >();
111+ const int64_t * in_dims_ptr = in_dims_tensor.data <int64_t >();
112+ const int64_t * offsets_ptr = offsets_tensor.data <int64_t >();
108113
109114 switch (rank) {
110115 case 1 :
0 commit comments