Skip to content

Commit 46b2488

Browse files
committed
reduce vector operations
1 parent ed77742 commit 46b2488

File tree

2 files changed

+15
-21
lines changed

2 files changed

+15
-21
lines changed

paddle/phi/kernels/funcs/index_put_utils.h

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -241,31 +241,27 @@ void DealWithIndices(const Context& dev_ctx,
241241
}
242242

243243
} else {
244-
std::vector<DenseTensor> int_indices_v_tmp;
245-
246244
for (size_t i = 0; i < int_indices_v.size(); ++i) {
245+
phi::DenseTensor index_tensor;
246+
phi::DenseTensor expand_index;
247247
if (int_indices_v[i]->dtype() == phi::DataType::INT32) {
248-
int_indices_v_tmp.emplace_back(phi::Cast<int, Context>(
249-
dev_ctx, *int_indices_v[i], phi::DataType::INT64));
248+
index_tensor = phi::Cast<int, Context>(
249+
dev_ctx, *int_indices_v[i], phi::DataType::INT64);
250250
} else {
251-
int_indices_v_tmp.emplace_back(*int_indices_v[i]);
251+
index_tensor = *int_indices_v[i];
252252
}
253-
}
254-
255-
for (size_t i = 0; i < int_indices_v.size(); ++i) {
256253
if (bd_dim != int_indices_v[i]->dims()) {
257-
tmp_res_indices_v->emplace_back(
258-
DenseTensor(phi::DataType::INT64).Resize(bd_dim));
254+
expand_index = DenseTensor(phi::DataType::INT64).Resize(bd_dim);
259255
ExpandKernel<int64_t, Context>(
260256
dev_ctx,
261-
int_indices_v_tmp[i],
257+
index_tensor,
262258
IntArray(common::vectorize<int64_t>(bd_dim)),
263-
&(*tmp_res_indices_v)[i]);
259+
&expand_index);
264260
} else {
265-
tmp_res_indices_v->emplace_back(int_indices_v_tmp[i]);
261+
expand_index = index_tensor;
266262
}
263+
tmp_res_indices_v->emplace_back(expand_index);
267264
}
268-
269265
for (size_t i = 0; i < res_indices_v->size(); ++i) {
270266
(*res_indices_v)[i] = &(*tmp_res_indices_v)[i];
271267
}

paddle/phi/kernels/gpu/index_put_kernel.cu

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ void IndexPutKernel(const Context& dev_ctx,
137137
std::vector<int64_t> res_dim_v(common::vectorize(bd_dim));
138138
std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr);
139139
std::vector<DenseTensor> tmp_res_indices_v;
140-
std::vector<DenseTensor> tmp_value_v;
140+
141141
std::vector<DenseTensor> range_tensor_v;
142142
const DenseTensor* ptr_value = nullptr;
143143

@@ -154,13 +154,11 @@ void IndexPutKernel(const Context& dev_ctx,
154154
range_tensor_v,
155155
bd_dim,
156156
&res_dim_v);
157-
157+
phi::DenseTensor tmp_value;
158158
if (value.numel() != 1) {
159-
tmp_value_v.emplace_back(
160-
DenseTensor(value.dtype()).Resize(common::make_ddim(res_dim_v)));
161-
ExpandKernel<T, Context>(
162-
dev_ctx, value, IntArray(res_dim_v), &tmp_value_v[0]);
163-
ptr_value = &tmp_value_v[0];
159+
tmp_value = DenseTensor(value.dtype()).Resize(common::make_ddim(res_dim_v));
160+
ExpandKernel<T, Context>(dev_ctx, value, IntArray(res_dim_v), &tmp_value);
161+
ptr_value = &tmp_value;
164162
} else {
165163
ptr_value = &value;
166164
}

0 commit comments

Comments
 (0)