@@ -241,31 +241,27 @@ void DealWithIndices(const Context& dev_ctx,
241241 }
242242
243243 } else {
244- std::vector<DenseTensor> int_indices_v_tmp;
245-
246244 for (size_t i = 0 ; i < int_indices_v.size (); ++i) {
245+ phi::DenseTensor index_tensor;
246+ phi::DenseTensor expand_index;
247247 if (int_indices_v[i]->dtype () == phi::DataType::INT32) {
248- int_indices_v_tmp. emplace_back ( phi::Cast<int , Context>(
249- dev_ctx, *int_indices_v[i], phi::DataType::INT64)) ;
248+ index_tensor = phi::Cast<int , Context>(
249+ dev_ctx, *int_indices_v[i], phi::DataType::INT64);
250250 } else {
251- int_indices_v_tmp. emplace_back ( *int_indices_v[i]) ;
251+ index_tensor = *int_indices_v[i];
252252 }
253- }
254-
255- for (size_t i = 0 ; i < int_indices_v.size (); ++i) {
256253 if (bd_dim != int_indices_v[i]->dims ()) {
257- tmp_res_indices_v->emplace_back (
258- DenseTensor (phi::DataType::INT64).Resize (bd_dim));
254+ expand_index = DenseTensor (phi::DataType::INT64).Resize (bd_dim);
259255 ExpandKernel<int64_t , Context>(
260256 dev_ctx,
261- int_indices_v_tmp[i] ,
257+ index_tensor ,
262258 IntArray (common::vectorize<int64_t >(bd_dim)),
263- &(*tmp_res_indices_v)[i] );
259+ &expand_index );
264260 } else {
265- tmp_res_indices_v-> emplace_back (int_indices_v_tmp[i]) ;
261+ expand_index = index_tensor ;
266262 }
263+ tmp_res_indices_v->emplace_back (expand_index);
267264 }
268-
269265 for (size_t i = 0 ; i < res_indices_v->size (); ++i) {
270266 (*res_indices_v)[i] = &(*tmp_res_indices_v)[i];
271267 }
0 commit comments