4 changes: 2 additions & 2 deletions paddle/phi/infermeta/unary.cc
@@ -1255,8 +1255,8 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
LabelMap labeltype(LabelType::Reduction);
std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
std::vector<char> all_labels;
std::vector<int> output_dims;
std::vector<std::vector<int>> broadcast_shapes(2);
std::vector<int64_t> output_dims;
std::vector<std::vector<int64_t>> broadcast_shapes(2);

std::vector<DDim> input_dims;
for (auto& i : inputs) {
33 changes: 17 additions & 16 deletions paddle/phi/kernels/impl/einsum_grad_kernel_impl.h
@@ -28,21 +28,21 @@ template <typename T, typename Context>
DenseTensor PerformTileAndReduction(const Context& dev_ctx,
const LabelMap& label2type,
const LabelMap& label2shape,
const std::vector<int>& broadcast_shape,
const std::vector<int> x_shape,
const std::vector<int64_t>& broadcast_shape,
const std::vector<int64_t> x_shape,
std::string equ, // value pass
DenseTensor& t) { // NOLINT
auto tmp_label = equ;
auto tmp_union = unique_labels(tmp_label);
auto op_label = std::string(tmp_union.begin(), tmp_union.end());
VLOG(5) << "Start PerformTileAndReduction equation " << equ
<< " with operand shape: "
<< paddle::string::join_strings(common::vectorize<int>(t.dims()),
<< paddle::string::join_strings(common::vectorize<int64_t>(t.dims()),
",");
DenseTensor ret;
std::vector<int> repeat_times;
std::vector<int> resize_dims;
std::vector<int> recover_shape;
std::vector<int64_t> repeat_times;
std::vector<int64_t> resize_dims;
std::vector<int64_t> recover_shape;
for (int c : op_label) {
if (label2type[c] == LabelType::Reduction) {
repeat_times.push_back(label2shape[c]);
@@ -56,7 +56,7 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
}
t.Resize(common::make_ddim(resize_dims));
DenseTensor after_tile;
if (std::all_of(repeat_times.begin(), repeat_times.end(), [](int x) {
if (std::all_of(repeat_times.begin(), repeat_times.end(), [](int64_t x) {
return x == 1;
})) {
after_tile = t;
@@ -83,15 +83,16 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
// call TileGradKernel to reverse broadcast operation.
VLOG(5) << "After diagonalize, we have tensor with shape: "
<< paddle::string::join_strings(
common::vectorize<int>(undiagonal_out.dims()), ',');
common::vectorize<int64_t>(undiagonal_out.dims()), ',');
repeat_times.clear();
for (size_t i = 0; i < x_shape.size(); ++i) {
VLOG(4) << "broadcast shape is " << broadcast_shape[i] << ", x_shape is "
<< x_shape[i];
repeat_times.push_back(broadcast_shape[i] / x_shape[i]);
}
bool is_all_ones = std::all_of(
repeat_times.begin(), repeat_times.end(), [](int x) { return x == 1; });
bool is_all_ones = std::all_of(repeat_times.begin(),
repeat_times.end(),
[](int64_t x) { return x == 1; });
if (is_all_ones) {
VLOG(4) << "don't need broadcast recover, we just return undiagonal_out.";
return undiagonal_out;
@@ -104,7 +105,7 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
dev_ctx, tmp_x, undiagonal_out, repeat_times, &broadcast_out);
VLOG(5) << "After broadcast recover, we have tensor with shape: "
<< paddle::string::join_strings(
common::vectorize<int>(broadcast_out.dims()), ',');
common::vectorize<int64_t>(broadcast_out.dims()), ',');
return broadcast_out;
}

@@ -120,8 +121,8 @@ void EinsumGradKernel(const Context& dev_ctx,
LabelMap labeltype(LabelType::Reduction);
std::vector<LabelMap> label2perms(x.size(), LabelMap(-1));
std::vector<char> all_labels; // order: ABO, AO, BO, AB, Reduce
std::vector<std::vector<int>> broadcast_shapes(2);
std::vector<int> output_dims;
std::vector<std::vector<int64_t>> broadcast_shapes(2);
std::vector<int64_t> output_dims;

std::vector<DDim> input_dims;
for (auto& i : x) {
@@ -165,7 +166,7 @@ void EinsumGradKernel(const Context& dev_ctx,
labeltype,
labelshape,
broadcast_shapes[0],
common::vectorize<int>(x[0]->dims()),
common::vectorize<int64_t>(x[0]->dims()),
left,
before_tile);
#ifndef PADDLE_WITH_XPU // xpu is not support conj now, we just disable it.
@@ -226,7 +227,7 @@ void EinsumGradKernel(const Context& dev_ctx,
labeltype,
labelshape,
broadcast_shapes[0],
common::vectorize<int>(x[0]->dims()),
common::vectorize<int64_t>(x[0]->dims()),
ops[0],
dA);
VLOG(4) << "After call dA";
@@ -240,7 +241,7 @@ void EinsumGradKernel(const Context& dev_ctx,
labeltype,
labelshape,
broadcast_shapes[1],
common::vectorize<int>(x[1]->dims()),
common::vectorize<int64_t>(x[1]->dims()),
ops[1],
dB);
#ifndef PADDLE_WITH_XPU // xpu is not support conj now, we just disable it.
80 changes: 41 additions & 39 deletions paddle/phi/kernels/impl/einsum_kernel_impl.h
@@ -72,23 +72,23 @@ enum LabelType {
Reduction, // A, B
};

// map a label('a' - 'z') -> int, O(1) speed.
// map a label('a' - 'z') -> int64_t, O(1) speed.
class LabelMap {
Comment on lines +75 to 76

Contributor:
Does this LabelMap also need to use int64? It looks like it only maps to an axis, not to a shape.

Contributor Author:
The key is an axis, represented by a label 'a'-'z', and the label is converted to an int when indexing the array; the value is a shape extent, so it also has to be changed to int64.
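
For illustration, a minimal usage sketch (not part of this diff) of the key/value distinction described above: the key is a plain int index derived from the label, while the stored value is a dimension extent, now held as int64_t. The labels and extents below are invented for the example.

// Illustrative sketch only; labels and extents are made up.
LabelMap labelshape(/*default_value=*/0);
labelshape['i'] = 3;                 // key: axis label 'i'; value: extent 3
labelshape['j'] = 5000000000LL;      // extents beyond INT_MAX now fit in the int64_t value
int64_t extent_j = labelshape['j'];  // the label is narrowed to an int index internally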

constexpr static int N =
26 + 1; // 'a' - 'z' + '.', '.' is for broadcast dims
int default_value;
int map[N];
int64_t default_value;
int64_t map[N];

public:
explicit LabelMap(int default_value = 0) {
explicit LabelMap(int64_t default_value = 0) {
this->default_value = default_value;
for (size_t i = 0; i < N; ++i) map[i] = default_value;
}
int& operator[](int label) {
int64_t& operator[](int label) {
int i = label - 'a';
return map[i];
}
int operator[](int label) const {
int64_t operator[](int label) const {
int i = label - 'a';
return map[i];
}
@@ -204,15 +204,15 @@ inline static void InferLabelShape(
const std::vector<std::string>& op_labels,
const std::vector<DDim>& inputs,
LabelMap* labelshape,
std::vector<std::vector<int>>* broadcast_shapes) {
std::vector<std::vector<int64_t>>* broadcast_shapes) {
VLOG(5) << "Start InferLabelShape";
for (size_t i = 0; i < op_labels.size(); ++i) {
auto& op_str = op_labels[i];
auto& op_dim = inputs[i];
int dim_ptr = 0;
for (auto& c : op_str) {
if (!labelshape->exist(c) || abs((*labelshape)[c]) == 1) {
(*labelshape)[c] = static_cast<int>(op_dim[dim_ptr]);
(*labelshape)[c] = op_dim[dim_ptr];
} else if (abs(op_dim[dim_ptr]) != 1) {
PADDLE_ENFORCE_EQ(
(*labelshape)[c],
@@ -248,7 +248,7 @@ inline static void InferLabelPerm(const CharIterable& op,

inline static void InferOutputDims(const std::string& right,
const LabelMap& labelshape,
std::vector<int>* output_dims) {
std::vector<int64_t>* output_dims) {
for (int c : right) {
output_dims->push_back(labelshape[c]);
}
@@ -261,8 +261,8 @@ inline static void ParseEinsumEquation(
LabelMap* labeltype,
std::vector<char>* all_labels,
std::vector<LabelMap>* label2perms,
std::vector<std::vector<int>>* broadcast_shapes,
std::vector<int>* output_dims,
std::vector<std::vector<int64_t>>* broadcast_shapes,
std::vector<int64_t>* output_dims,
std::string* right,
std::vector<std::string>* input_strs) {
VLOG(5) << "Start ParseEinsumEquation " << equation;
@@ -351,7 +351,7 @@ DenseTensor Undiagonal(const Context& dev_ctx,
// output is (3, 4, 5, 2, 1, 5)
VLOG(5) << "Start undiagonal with args: insert_pos = " << insert_pos
<< ", axis = " << axis;
std::vector<int> shape(tensor.dims().size() + 1);
std::vector<int64_t> shape(tensor.dims().size() + 1);
int point = 0; // point to the tensor.dims()
for (size_t i = 0; i < shape.size(); ++i) {
if (i == insert_pos)
@@ -391,26 +391,28 @@ DenseTensor PerformUndiagonal(const Context& dev_ctx,
}

template <typename T, typename Context>
DenseTensor PerformDiagonalAndReduction(const Context& dev_ctx,
const DenseTensor& tensor,
const std::string& equ,
const LabelMap& label2perm,
const std::vector<char>& all_labels,
const std::vector<int>& broadcast_shape,
const LabelMap& label2type) {
DenseTensor PerformDiagonalAndReduction(
const Context& dev_ctx,
const DenseTensor& tensor,
const std::string& equ,
const LabelMap& label2perm,
const std::vector<char>& all_labels,
const std::vector<int64_t>& broadcast_shape,
const LabelMap& label2type) {
auto res = tensor;
int tot = equ.size();
// tiling tensor for broadcast
std::vector<int> repeat_times;
std::vector<int64_t> repeat_times;
auto tensor_origin_shape = common::vectorize(tensor.dims());
for (size_t i = 0; i < tensor_origin_shape.size(); ++i) {
VLOG(4) << "broadcast shape is " << broadcast_shape[i]
<< ", tensor shape is " << tensor_origin_shape[i];
repeat_times.push_back(broadcast_shape[i] / tensor_origin_shape[i]);
}
DenseTensor after_tile;
bool is_all_ones = std::all_of(
repeat_times.begin(), repeat_times.end(), [](int x) { return x == 1; });
bool is_all_ones = std::all_of(repeat_times.begin(),
repeat_times.end(),
[](int64_t x) { return x == 1; });
if (!is_all_ones) {
TileKernel<T, Context>(dev_ctx, res, repeat_times, &after_tile);
res = after_tile;
@@ -423,7 +425,7 @@ DenseTensor PerformDiagonalAndReduction(const Context& dev_ctx,
// do diagonal, followed by movedim().
VLOG(5) << "Do diagonal with shape="
<< paddle::string::join_strings(
common::vectorize<int>(res.dims()), ',')
common::vectorize<int64_t>(res.dims()), ',')
<< ", axis1=" << cur << ", axis2=" << label2perm[c];
res = Diagonal<T, Context>(dev_ctx, res, 0, cur, label2perm[c]);
res = Transpose<T, Context>(
@@ -474,23 +476,23 @@ DenseTensor PerformContraction(
const std::vector<char>& all_labels,
const LabelMap& label2type,
const LabelMap& label2shape,
const std::vector<std::vector<int>>& broadcast_shapes,
const std::vector<std::vector<int64_t>>& broadcast_shapes,
std::vector<DenseTensor*> cache,
bool use_cache) {
auto all_valid = LabelMap(1);
auto recover_dim = GetShapeByType<int>(
auto recover_dim = GetShapeByType<int64_t>(
all_labels, label2type, all_valid, label2shape, {LabelType::Batch});
auto preprocess = [&](const DenseTensor& t,
const LabelMap& perm,
const std::vector<int>& broadcast,
const std::vector<int64_t>& broadcast,
int operand_idx) -> DenseTensor {
// reshape
auto frees = GetShapeByType<int>(all_labels,
label2type,
perm,
label2shape,
{LabelType::AO, LabelType::BO});
auto conts = GetShapeByType<int>(
auto frees = GetShapeByType<int64_t>(all_labels,
label2type,
perm,
label2shape,
{LabelType::AO, LabelType::BO});
auto conts = GetShapeByType<int64_t>(
all_labels, label2type, perm, label2shape, {LabelType::Contraction});
std::vector<char> reordered_all_labels = all_labels;
if (operand_idx == 1) {
@@ -526,19 +528,19 @@ DenseTensor PerformContraction(
<< "]: " << trans_t.dims();
}
}
auto mul_dims = GetShapeByType<int>(
auto mul_dims = GetShapeByType<int64_t>(
all_labels, label2type, perm, label2shape, {LabelType::Batch});
recover_dim.insert(recover_dim.end(), frees.begin(), frees.end());
if (operand_idx == 0) {
mul_dims.push_back(std::accumulate(
frees.begin(), frees.end(), 1, std::multiplies<int>()));
frees.begin(), frees.end(), 1, std::multiplies<int64_t>()));
mul_dims.push_back(std::accumulate(
conts.begin(), conts.end(), 1, std::multiplies<int>()));
conts.begin(), conts.end(), 1, std::multiplies<int64_t>()));
} else {
mul_dims.push_back(std::accumulate(
conts.begin(), conts.end(), 1, std::multiplies<int>()));
conts.begin(), conts.end(), 1, std::multiplies<int64_t>()));
mul_dims.push_back(std::accumulate(
frees.begin(), frees.end(), 1, std::multiplies<int>()));
frees.begin(), frees.end(), 1, std::multiplies<int64_t>()));
}
VLOG(5) << "PerformContraction: mul_dims: "
<< paddle::string::join_strings(mul_dims, ",");
@@ -608,8 +610,8 @@ void EinsumKernelImpl(const Context& dev_ctx,
LabelMap labeltype(LabelType::Reduction);
std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
std::vector<char> all_labels; // order: ABO, AO, BO, AB, Reduce
std::vector<std::vector<int>> broadcast_shapes(2);
std::vector<int> output_dims;
std::vector<std::vector<int64_t>> broadcast_shapes(2);
std::vector<int64_t> output_dims;

std::vector<DDim> input_dims;
for (auto& i : inputs) {
10 changes: 5 additions & 5 deletions paddle/phi/kernels/impl/tile_grad_kernel_impl.h
@@ -27,8 +27,8 @@ namespace phi {
template <typename Context, typename T, int Dims>
void TileBackward(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& reshape_dims_vec,
const std::vector<int>& reduce_dims_vec,
const std::vector<int64_t>& reshape_dims_vec,
const std::vector<int64_t>& reduce_dims_vec,
DenseTensor* x_grad) {
size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size();
@@ -86,7 +86,7 @@ void TileGradKernel(const Context& dev_ctx,
const IntArray& repeat_times,
DenseTensor* x_grad) {
auto x_dims = x.dims();
auto vec_x_dims = common::vectorize<int>(x_dims);
auto vec_x_dims = common::vectorize<int64_t>(x_dims);
auto repeat_times_data = repeat_times.GetData();
if (repeat_times_data.size() < vec_x_dims.size()) {
int diff = vec_x_dims.size() - repeat_times_data.size();
@@ -99,8 +99,8 @@ void TileGradKernel(const Context& dev_ctx,
// 2. reduce_dims_vec is the dimension parameter to compute gradients. For
// each dimension expanded, the gradients should be summed to original
// size.
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
std::vector<int64_t> reshape_dims_vec;
std::vector<int64_t> reduce_dims_vec;
for (size_t i = 0; i < repeat_times_data.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times_data[i]);
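
A worked sketch (not part of the diff) of the two vectors described in the comment above, assuming the lines hidden by this hunk also push the original extent vec_x_dims[i] into reshape_dims_vec: for x of shape [2, 3] tiled with repeat_times = [4, 2], the forward output has shape [8, 6], so out_grad is viewed as [4, 2, 2, 3] and summed over axes 0 and 2 to recover x_grad of shape [2, 3].

// Illustrative values only; the loop body below is an assumed reconstruction,
// including the elided push of the original extent.
std::vector<int64_t> vec_x_dims = {2, 3};
std::vector<int64_t> repeat_times_data = {4, 2};
std::vector<int64_t> reshape_dims_vec;
std::vector<int64_t> reduce_dims_vec;
for (size_t i = 0; i < repeat_times_data.size(); ++i) {
  reduce_dims_vec.push_back(reshape_dims_vec.size());  // records axes 0 and 2
  reshape_dims_vec.push_back(repeat_times_data[i]);    // repeat factor
  reshape_dims_vec.push_back(vec_x_dims[i]);           // original extent (assumed elided line)
}
// reshape_dims_vec == {4, 2, 2, 3}; reduce_dims_vec == {0, 2}.
// out_grad of shape [8, 6] is reshaped to [4, 2, 2, 3] and summed over
// axes 0 and 2, giving x_grad of shape [2, 3].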