@@ -91,6 +91,12 @@ void SetTensorValueKernelV2(const Context& dev_ctx,
9191 if (value.numel () == 1 ) {
9292 expand_tensor = value;
9393 expand_tensor.Resize (phi::make_ddim ({1 }));
94+ } else if (product (value.dims ()) == product (phi::make_ddim (new_out_shape))) {
95+ expand_tensor = value;
96+ if (value.dims () != phi::make_ddim (new_out_shape)) {
97+ expand_tensor.Resize (phi::make_ddim (new_out_shape));
98+ }
99+
94100 } else {
95101 auto value_dims = phi::vectorize<int64_t >(value.dims ());
96102 DenseTensor value_tensor = Empty<T>(dev_ctx, IntArray{value_dims});
@@ -109,12 +115,21 @@ void SetTensorValueKernelV2(const Context& dev_ctx,
109115
110116 out->ResetHolder (in.Holder ());
111117 out->ShareInplaceVersionCounterWith (in);
112- StridedCopyKernel<T, Context>(dev_ctx,
113- expand_tensor,
114- new_out_shape,
115- new_out_stride,
116- output_offset,
117- out);
118+ if (starts_local.empty () && ends_local.empty () && steps_local.empty ()) {
119+ if (expand_tensor.numel () == 1 ) {
120+ ExpandKernel<T, Context>(
121+ dev_ctx, expand_tensor, IntArray{new_out_shape}, out);
122+ } else {
123+ Copy<Context>(dev_ctx, expand_tensor, dev_ctx.GetPlace (), false , out);
124+ }
125+ } else {
126+ StridedCopyKernel<T, Context>(dev_ctx,
127+ expand_tensor,
128+ new_out_shape,
129+ new_out_stride,
130+ output_offset,
131+ out);
132+ }
118133 out->set_meta (meta);
119134}
120135
0 commit comments