2 changes: 2 additions & 0 deletions paddle/phi/kernels/impl/einsum_grad_kernel_impl.h
@@ -215,6 +215,7 @@ void EinsumGradKernel(const Context& dev_ctx,
}
EinsumKernelImpl<T, Context>(dev_ctx,
all_labels,
labelshape,
operands_for_A,
equation_for_A,
&dA,
@@ -223,6 +224,7 @@ void EinsumGradKernel(const Context& dev_ctx,

EinsumKernelImpl<T, Context>(dev_ctx,
all_labels,
labelshape,
operands_for_B,
equation_for_B,
&dB,
44 changes: 38 additions & 6 deletions paddle/phi/kernels/impl/einsum_kernel_impl.h
@@ -204,7 +204,9 @@ inline static void InferLabelShape(
const std::vector<std::string>& op_labels,
const std::vector<DDim>& inputs,
LabelMap* labelshape,
std::vector<std::vector<int64_t>>* broadcast_shapes) {
std::vector<std::vector<int64_t>>* broadcast_shapes,
LabelMap* labeltype) {
LabelMap labelshape_copy = *labelshape;
VLOG(5) << "Start InferLabelShape";
for (size_t i = 0; i < op_labels.size(); ++i) {
auto& op_str = op_labels[i];
Expand Down Expand Up @@ -233,6 +235,20 @@ inline static void InferLabelShape(
}
for (size_t i = 0; i < op_labels.size(); ++i) {
for (auto& c : op_labels[i]) {
// Note: When broadcasting is involved, ensure the gradient is calculated
// with respect to the broadcasted shape. For example, in
// einsum("ij,ij->j", x(2,2), y(1,2)), y is broadcast to (2,2). The
// gradient calculation for x must use this broadcasted shape of y.
if (labelshape_copy.exist(c) && labelshape_copy[c] > (*labelshape)[c]) {
        // Strict check: only a size-1 dim whose label type is AO or BO may be broadcast here.
PADDLE_ENFORCE_EQ(
(*labelshape)[c] == 1 && ((*labeltype)[c] == LabelType::AO ||
(*labeltype)[c] == LabelType::BO),
true,
common::errors::InvalidArgument(
"Broadcast dims must be 1 for label: `%c`", c));
(*labelshape)[c] = labelshape_copy[c];
}
(*broadcast_shapes)[i].push_back((*labelshape)[c]);
}
}
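
For context on the note above: with broadcasting, the gradient of the non-broadcast operand must be computed against the broadcasted shape of the other operand, while the gradient of the broadcast operand is reduced back to its original shape. A minimal NumPy sketch of the einsum("ij,ij->j", x(2,2), y(1,2)) case from the comment (illustrative only, not part of the kernel code):

import numpy as np

x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)  # shape (2, 2)
y = np.array([[0.5, 0.6]], dtype=np.float32)              # shape (1, 2), broadcast over i

# Forward: out[j] = sum_i x[i, j] * broadcast(y)[i, j]
out = np.einsum("ij,ij->j", x, np.broadcast_to(y, x.shape))

# Backward for loss = out.sum():
dx = np.broadcast_to(y, x.shape)        # (2, 2): uses the broadcasted shape of y
dy = x.sum(axis=0, keepdims=True)       # (1, 2): reduced back to y's original shape
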
@@ -282,7 +298,7 @@ inline static void ParseEinsumEquation(
// split_string("->") -> [], we push back a "".
if (op_labels.empty()) op_labels.emplace_back("");
GlobalInfo(op_labels, *right, labeltype, all_labels);
InferLabelShape(op_labels, inputs, labelshape, broadcast_shapes);
InferLabelShape(op_labels, inputs, labelshape, broadcast_shapes, labeltype);
VLOG(5) << "Einsum Infershape: right:" << *right;
VLOG(5) << "Einsum Infershape: left :"
<< paddle::string::join_strings(op_labels, '\n');
@@ -603,6 +619,7 @@ DenseTensor TransposeToOutput(const Context& dev_ctx,
template <typename T, typename Context>
void EinsumKernelImpl(const Context& dev_ctx,
const std::vector<char>& forward_all_labels,
const LabelMap& forward_label_shape,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
@@ -629,6 +646,7 @@ void EinsumKernelImpl(const Context& dev_ctx,
std::string right;
if (!is_forward) {
all_labels = forward_all_labels;
labelshape = forward_label_shape;
}
ParseEinsumEquation(equation,
input_dims,
@@ -680,15 +698,22 @@ void EinsumKernel(const Context& dev_ctx,
}
}
std::vector<char> tmp;
LabelMap labelshape_holder;
  // For backward compatibility we may load and run a v2.3 EinsumOp, whose
  // outputs may contain nullptr and whose cache.size() may differ from
  // inputs.size(); see BuildPhiKernelContext for details.
int diff = inputs.size() - cache.size();
for (int i = 0; i < diff; ++i) {
cache.push_back(nullptr);
}
EinsumKernelImpl<T, Context>(
dev_ctx, tmp, inputs, equation, out, cache, /*forward=*/true);
EinsumKernelImpl<T, Context>(dev_ctx,
tmp,
labelshape_holder,
inputs,
equation,
out,
cache,
/*forward=*/true);
}

template <typename T, typename Context>
@@ -697,13 +722,20 @@ void EinsumInferKernel(const Context& dev_ctx,
const std::string& equation,
DenseTensor* out) {
std::vector<char> place_holder;
LabelMap labelshape_holder;
std::vector<DenseTensor*> cache_tensor(
inputs.size()); // set empty; TA, TB, TdC
for (size_t i = 0; i < inputs.size(); ++i) {
cache_tensor[i] = nullptr;
}
EinsumKernelImpl<T, Context>(
dev_ctx, place_holder, inputs, equation, out, cache_tensor, true);
EinsumKernelImpl<T, Context>(dev_ctx,
place_holder,
labelshape_holder,
inputs,
equation,
out,
cache_tensor,
true);
}

} // namespace phi
45 changes: 45 additions & 0 deletions test/legacy_test/test_einsum.py
@@ -532,5 +532,50 @@ def test_static_graph(self):
self.check_output_equal(a, e)


class TestContractionBroadcastGrad(unittest.TestCase):
def setUp(self):
self.place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)

def test_case1(self):
with paddle.base.dygraph.guard(self.place):
# paddle.einsum("i, i", Tensor([2],"float32"), Tensor([1],"float32"), )
x_np = np.array([0.1, 0.2]).astype(np.float32)
y_np = np.array([0.5]).astype(np.float32)
except_res = np.einsum("i, i", x_np, y_np)
except_grad_x = np.array([0.5, 0.5]).astype(np.float32)
except_grad_y = np.array([0.3]).astype(np.float32)
x = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.to_tensor(y_np, stop_gradient=False)
res = paddle.einsum("i, i", x, y)
            np.testing.assert_allclose(res.numpy(), expect_res)
            res.sum().backward()
            x.grad.get_tensor()  # would fail if the gradient memory were never allocated
            np.testing.assert_allclose(x.grad.numpy(), expect_grad_x)
            np.testing.assert_allclose(y.grad.numpy(), expect_grad_y)

def test_case2(self):
with paddle.base.dygraph.guard(self.place):
# paddle.einsum("ij,ij->j", Tensor([2, 2],"float32"), Tensor([1, 2],"float32"), )
x_np = np.array([[0.1, 0.2], [0.3, 0.4]]).astype(np.float32)
y_np = np.array([[0.5, 0.6]]).astype(np.float32)
except_res = np.einsum("ij,ij->j", x_np, y_np)
except_grad_x = np.array([[0.5, 0.6], [0.5, 0.6]]).astype(
np.float32
)
except_grad_y = np.array([[0.4, 0.6]]).astype(np.float32)
x = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.to_tensor(y_np, stop_gradient=False)
res = paddle.einsum("ij,ij->j", x, y)
            np.testing.assert_allclose(res.numpy(), expect_res)
            res.sum().backward()
            x.grad.get_tensor()  # would fail if the gradient memory were never allocated
            np.testing.assert_allclose(x.grad.numpy(), expect_grad_x)
            np.testing.assert_allclose(y.grad.numpy(), expect_grad_y)


if __name__ == "__main__":
unittest.main()
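
As a quick sanity check of the hard-coded gradients in test_case1 above (a hand derivation, not part of the test file): for loss = sum_i x[i] * broadcast(y)[i], d loss / d x[i] = y[0] = 0.5 for every i, and d loss / d y[0] = sum_i x[i] = 0.1 + 0.2 = 0.3, matching the expected [0.5, 0.5] and [0.3]. The same values in NumPy:

import numpy as np

x_np = np.array([0.1, 0.2], dtype=np.float32)
y_np = np.array([0.5], dtype=np.float32)

# Gradient of the non-broadcast operand takes the broadcasted shape of the other.
grad_x = np.broadcast_to(y_np, x_np.shape)   # [0.5, 0.5]
# Gradient of the broadcast operand is reduced back to its original shape.
grad_y = x_np.sum(keepdims=True)             # [0.3]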