Skip to content

Commit d2984bd

Browse files
[slice]add single path for Set value when value is tensor (#73390)
* decrease slice_ci time * update slice yaml * set_value slice none case * delete unuse expand
1 parent 4ad3b8f commit d2984bd

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

paddle/phi/kernels/gpu/set_value_kernel.cu

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ void SetTensorValueKernelV2(const Context& dev_ctx,
9191
if (value.numel() == 1) {
9292
expand_tensor = value;
9393
expand_tensor.Resize(phi::make_ddim({1}));
94+
} else if (product(value.dims()) == product(phi::make_ddim(new_out_shape))) {
95+
expand_tensor = value;
96+
if (value.dims() != phi::make_ddim(new_out_shape)) {
97+
expand_tensor.Resize(phi::make_ddim(new_out_shape));
98+
}
99+
94100
} else {
95101
auto value_dims = phi::vectorize<int64_t>(value.dims());
96102
DenseTensor value_tensor = Empty<T>(dev_ctx, IntArray{value_dims});
@@ -109,12 +115,21 @@ void SetTensorValueKernelV2(const Context& dev_ctx,
109115

110116
out->ResetHolder(in.Holder());
111117
out->ShareInplaceVersionCounterWith(in);
112-
StridedCopyKernel<T, Context>(dev_ctx,
113-
expand_tensor,
114-
new_out_shape,
115-
new_out_stride,
116-
output_offset,
117-
out);
118+
if (starts_local.empty() && ends_local.empty() && steps_local.empty()) {
119+
if (expand_tensor.numel() == 1) {
120+
ExpandKernel<T, Context>(
121+
dev_ctx, expand_tensor, IntArray{new_out_shape}, out);
122+
} else {
123+
Copy<Context>(dev_ctx, expand_tensor, dev_ctx.GetPlace(), false, out);
124+
}
125+
} else {
126+
StridedCopyKernel<T, Context>(dev_ctx,
127+
expand_tensor,
128+
new_out_shape,
129+
new_out_stride,
130+
output_offset,
131+
out);
132+
}
118133
out->set_meta(meta);
119134
}
120135

0 commit comments

Comments
 (0)