Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 75 additions & 81 deletions paddle/phi/kernels/funcs/index_put_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,67 +73,73 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices(
const Context& dev_ctx,
const std::vector<const phi::DenseTensor*>& indices_v,
std::vector<phi::DenseTensor>* tmp_indices_v) {
std::vector<const phi::DenseTensor*> res(indices_v.begin(), indices_v.end());
bool contains_bool_tensor = false;
std::vector<const phi::DenseTensor*> res;

bool contains_bool_tensor = false;
for (size_t i = 0; i < indices_v.size(); ++i) {
if (indices_v[i]->dtype() == phi::DataType::BOOL) {
contains_bool_tensor = true;
int rank = indices_v[i]->dims().size();
PADDLE_ENFORCE_GE(
rank,
1UL,
phi::errors::InvalidArgument("the only bool tensor in indices should "
"have number of dimension at least 1"));
phi::DenseTensor nonzero_indices(phi::DataType::INT64);
nonzero_indices.Resize(common::make_ddim({-1, rank}));
NonZeroKernel<bool, Context>(dev_ctx, *indices_v[i], &nonzero_indices);

if (nonzero_indices.numel() == 0) {
std::vector<const phi::DenseTensor*> empty_indices;
return empty_indices;
}
break;
}
}

std::vector<phi::DenseTensor*> integer_indices(rank, nullptr);
const int tmp_ix = tmp_indices_v->size();
for (int i = 0; i < rank; ++i) {
tmp_indices_v->emplace_back(
DenseTensor(phi::DataType::INT64)
.Resize(common::make_ddim({nonzero_indices.dims()[0]})));
}
for (int i = 0; i < rank; ++i) {
integer_indices[i] = &((*tmp_indices_v)[i + tmp_ix]);
}
SplitWithNumKernel<int64_t, Context>(
dev_ctx, nonzero_indices, rank, 1, integer_indices);
if (contains_bool_tensor) {
for (size_t i = 0; i < indices_v.size(); ++i) {
if (indices_v[i]->dtype() == phi::DataType::BOOL) {
int rank = indices_v[i]->dims().size();
PADDLE_ENFORCE_GE(rank,
1UL,
phi::errors::InvalidArgument(
"the only bool tensor in indices should "
"have number of dimension at least 1"));
phi::DenseTensor nonzero_indices(phi::DataType::INT64);
nonzero_indices.Resize(common::make_ddim({-1, rank}));
NonZeroKernel<bool, Context>(dev_ctx, *indices_v[i], &nonzero_indices);

if (nonzero_indices.numel() == 0) {
std::vector<const phi::DenseTensor*> empty_indices;
return empty_indices;
}

std::vector<phi::DenseTensor*> integer_indices(rank, nullptr);
const int tmp_ix = tmp_indices_v->size();
for (int i = 0; i < rank; ++i) {
tmp_indices_v->emplace_back(
DenseTensor(phi::DataType::INT64)
.Resize(common::make_ddim({nonzero_indices.dims()[0]})));
}
for (int i = 0; i < rank; ++i) {
integer_indices[i] = &((*tmp_indices_v)[i + tmp_ix]);
}
SplitWithNumKernel<int64_t, Context>(
dev_ctx, nonzero_indices, rank, 1, integer_indices);
#ifdef PADDLE_WITH_XPU
auto place = dev_ctx.GetPlace();
if (place.GetType() == phi::AllocationType::XPU) {
auto& pool = phi::DeviceContextPool::Instance();
auto* xpu_ctx = static_cast<phi::XPUContext*>(pool.Get(place));
if (xpu_ctx->x_context()->xpu_stream) {
dev_ctx.Wait();
auto place = dev_ctx.GetPlace();
if (place.GetType() == phi::AllocationType::XPU) {
auto& pool = phi::DeviceContextPool::Instance();
auto* xpu_ctx = static_cast<phi::XPUContext*>(pool.Get(place));
if (xpu_ctx->x_context()->xpu_stream) {
dev_ctx.Wait();
}
}
}
#endif

} else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
(indices_v[i]->dtype() == phi::DataType::INT32)) {
tmp_indices_v->emplace_back(*indices_v[i]);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"data type of tensor in indices must be int32, int64 or bool"));
} else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
(indices_v[i]->dtype() == phi::DataType::INT32)) {
tmp_indices_v->emplace_back(*indices_v[i]);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"data type of tensor in indices must be int32, int64 or bool"));
}
}
}
if (contains_bool_tensor) {
std::vector<const phi::DenseTensor*> res_tmp(tmp_indices_v->size(),
nullptr);
for (size_t i = 0; i < res_tmp.size(); ++i) {
res_tmp[i] = &((*tmp_indices_v)[i]);

res.reserve(tmp_indices_v->size());
for (size_t i = 0; i < tmp_indices_v->size(); ++i) {
res.emplace_back(&((*tmp_indices_v)[i]));
}
res.swap(res_tmp);
} else {
res = indices_v;
}

return res;
}

Expand Down Expand Up @@ -212,62 +218,50 @@ void DealWithIndices(const Context& dev_ctx,
res_dim_v->insert(res_dim_v->end(),
tmp_x_dims.begin() + int_indices_v.size(),
tmp_x_dims.end());

std::vector<DenseTensor> reshaped_indices_v;
phi::DDim res_dim = common::make_ddim(*res_dim_v);
for (size_t i = 0; i < int_indices_v.size(); ++i) {
phi::DenseTensor index_tensor;
if (int_indices_v[i]->dtype() == phi::DataType::INT32) {
reshaped_indices_v.emplace_back(phi::Cast<int, Context>(
dev_ctx, *int_indices_v[i], phi::DataType::INT64));
index_tensor = phi::Cast<int, Context>(
dev_ctx, *int_indices_v[i], phi::DataType::INT64);
} else {
reshaped_indices_v.emplace_back(*int_indices_v[i]);
index_tensor = *int_indices_v[i];
}
tmp_res_indices_v->emplace_back(
GetReshapeAndExpandTensor<int64_t, Context>(
dev_ctx, index_tensor, res_dim, bd_dim, 0));
}
reshaped_indices_v.insert(
reshaped_indices_v.end(), range_tensor_v.begin(), range_tensor_v.end());

phi::DDim res_dim = common::make_ddim(*res_dim_v);

for (size_t i = 0; i < reshaped_indices_v.size(); ++i) {
for (size_t i = 0; i < range_tensor_v.size(); ++i) {
tmp_res_indices_v->emplace_back(
GetReshapeAndExpandTensor<int64_t, Context>(
dev_ctx,
reshaped_indices_v[i],
res_dim,
bd_dim,
((i < int_indices_v.size())
? 0
: i - int_indices_v.size() + len_bd_dim)));
dev_ctx, range_tensor_v[i], res_dim, bd_dim, i + len_bd_dim));
}
for (size_t i = 0; i < res_indices_v->size(); ++i) {
(*res_indices_v)[i] = &(*tmp_res_indices_v)[i];
}

} else {
std::vector<DenseTensor> int_indices_v_tmp;

for (size_t i = 0; i < int_indices_v.size(); ++i) {
phi::DenseTensor index_tensor;
phi::DenseTensor expand_index;
if (int_indices_v[i]->dtype() == phi::DataType::INT32) {
int_indices_v_tmp.emplace_back(phi::Cast<int, Context>(
dev_ctx, *int_indices_v[i], phi::DataType::INT64));
index_tensor = phi::Cast<int, Context>(
dev_ctx, *int_indices_v[i], phi::DataType::INT64);
} else {
int_indices_v_tmp.emplace_back(*int_indices_v[i]);
index_tensor = *int_indices_v[i];
}
}

for (size_t i = 0; i < int_indices_v.size(); ++i) {
if (bd_dim != int_indices_v[i]->dims()) {
tmp_res_indices_v->emplace_back(
DenseTensor(phi::DataType::INT64).Resize(bd_dim));
expand_index = DenseTensor(phi::DataType::INT64).Resize(bd_dim);
ExpandKernel<int64_t, Context>(
dev_ctx,
int_indices_v_tmp[i],
index_tensor,
IntArray(common::vectorize<int64_t>(bd_dim)),
&(*tmp_res_indices_v)[i]);
&expand_index);
} else {
tmp_res_indices_v->emplace_back(int_indices_v_tmp[i]);
expand_index = index_tensor;
}
tmp_res_indices_v->emplace_back(expand_index);
}

for (size_t i = 0; i < res_indices_v->size(); ++i) {
(*res_indices_v)[i] = &(*tmp_res_indices_v)[i];
}
Expand Down