Einsum still falls back in some cases #4032

@steventk-g

Description

🐛 Bug

Two issues:

  1. xla::Einsum does not take more than 2 inputs.

  2. xla::Einsum does not support the same inputs as the PyTorch implementation of einsum, forcing us to fall back sometimes, e.g. in the case of TestEinsumPyTorchLowerRepeatedAxisBackward.

Elaboration on the second problem:

To execute einsum_backward, we need to run einsum with "backward" equations derived from the original input equation. For the forward equation ijj,k->ik, the backward pass needs einsum with ik,k->ijj (gradient of the first operand) and ijj,ik->k (gradient of the second). However, ik,k->ijj does not work on xla::Einsum; it fails with: INVALID_ARGUMENT: Transpose dimensions [0,0] are not a permutation of the operand dimensions (operand shape is f32[2]). So today we fall back to the ATen native implementation. The sketch below illustrates where these backward equations come from.
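
As a hedged illustration (CPU-only, plain PyTorch, no torch_xla), the following verifies those two backward equations against autograd. Note that torch.einsum itself rejects a repeated subscript such as ijj on the output side, so ik,k->ijj is emulated here by computing ik,k->i and placing the result on the diagonal with torch.diag_embed; the shapes (2, 3, 3) and (4,) are arbitrary choices for the sketch.

import torch

x = torch.rand(2, 3, 3, requires_grad=True)
y = torch.rand(4, requires_grad=True)

out = torch.einsum('ijj,k->ik', x, y)  # forward: out[i,k] = sum_j x[i,j,j] * y[k]
grad_out = torch.ones_like(out)
out.backward(grad_out)

# "ik,k->ijj": contract grad_out with y, then broadcast onto the diagonal.
diag = torch.einsum('ik,k->i', grad_out, y).unsqueeze(-1).expand(-1, 3)
manual_grad_x = torch.diag_embed(diag)

# "ijj,ik->k": contract x's diagonal against the incoming gradient.
manual_grad_y = torch.einsum('ijj,ik->k', x, grad_out)

assert torch.allclose(x.grad, manual_grad_x)
assert torch.allclose(y.grad, manual_grad_y)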

To Reproduce

Problem 1

Pass three or more XLA tensors into torch.einsum.

Problem 2

When I try
xla::Einsum(x, y, "ik,k->ijj", xla::PrecisionConfig::Precision::PrecisionConfig_Precision_DEFAULT, xla::PrimitiveType::F32);

I get
INVALID_ARGUMENT: Transpose dimensions [0,0] are not a permutation of the operand dimensions (operand shape is f32[2]).

Steps to reproduce the behavior:

For problem 1:

Try

import torch
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.rand(4, requires_grad=True, device=xm.xla_device())
y = torch.rand(4, requires_grad=True, device=xm.xla_device())
z = torch.rand(4, requires_grad=True, device=xm.xla_device())
torch.einsum('i,j,k->ijk', x, y, z)

Notice that xla::Einsum is not used (e.g., by printing the IR or HLO; see the sketch below).
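
As one hedged way to do that inspection: the private torch_xla._XLAC helpers below are what the torch/xla test suite uses to dump the lazy IR and lowered HLO for a set of tensors; being private, they may change between releases.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.rand(4, device=xm.xla_device())
y = torch.rand(4, device=xm.xla_device())
z = torch.rand(4, device=xm.xla_device())
out = torch.einsum('i,j,k->ijk', x, y, z)

# Dump the lazy IR and the lowered HLO for `out`; a lowering through
# xla::Einsum shows an einsum node, while a fallback shows aten::einsum.
print(torch_xla._XLAC._get_xla_tensors_text([out]))
print(torch_xla._XLAC._get_xla_tensors_hlo([out]))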

For problem 2:

Remove the fallback logic from XLANativeFunctions::einsum, then either:

a. Run TestEinsumPyTorchLowerRepeatedAxisBackward

OR

b. Try

import torch
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.rand(2, 4, requires_grad=True, device=xm.xla_device())
y = torch.rand(4, requires_grad=True, device=xm.xla_device())
torch.einsum('ik,k->ijj', x, y)

The program should crash inside the xla::Einsum function.

Expected behavior

  1. I would expect XLA to handle more than two input tensors, or for PyTorch to break the inputs down into smaller einsum ops for dispatch (a pairwise decomposition is sketched after this list).
  2. I would expect XLA to support these kinds of equations, or for the einsum backward implementation to change so that these types of equations are no longer necessary.
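
For reference, a minimal CPU sketch of the pairwise decomposition mentioned in point 1, folding in one operand at a time using only two-input einsum calls (contracting x and y first is an arbitrary choice for the sketch; a real implementation would pick a contraction order):

import torch

x, y, z = torch.rand(4), torch.rand(4), torch.rand(4)

# Two two-operand einsums reproduce the three-operand result.
xy = torch.einsum('i,j->ij', x, y)
out = torch.einsum('ij,k->ijk', xy, z)

assert torch.allclose(out, torch.einsum('i,j,k->ijk', x, y, z))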

Environment

This should be reproducible on CPU, GPU, and TPU.

Torch Version: 1.13

Additional context

#3843
#4027

Labels

enhancement (New feature or request), nostale (Do not consider for staleness)
