Description
🐛 Bug
Two issues:

- `xla::Einsum` does not take more than 2 inputs.
- `xla::Einsum` does not support the same inputs as the PyTorch implementation of einsum, forcing us to fall back sometimes, e.g. in the case of `TestEinsumPyTorchLowerRepeatedAxisBackward`.
Elaboration on the second problem:
In order to execute `einsum_backward`, we need to execute einsum using "backward" equations that we derive from the original input equation. For `ijj,k->ik` we need to execute einsum with `ik,k->ijj` and `ijj,ik->k`. However, `ik,k->ijj` will not work on `xla::Einsum`; we run into the following error: `INVALID_ARGUMENT: Transpose dimensions [0,0] are not a permutation of the operand dimensions (operand shape is f32[2]).` So today, we fall back to the ATen native implementation.
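For reference, a minimal sketch of the derivation rule (illustrative only; `derive_backward_equations` is a hypothetical helper, not the actual torch_xla code): the gradient equation for each input is obtained by swapping that input's index spec with the output spec.

```python
def derive_backward_equations(equation):
    # Hypothetical helper illustrating the rule described above;
    # not the actual torch_xla implementation. Assumes two operands
    # and an explicit output spec.
    lhs, out = equation.split('->')
    in1, in2 = lhs.split(',')
    # grad w.r.t. input 1: contract grad_output with input 2, targeting
    # input 1's index pattern (and symmetrically for input 2).
    return f'{out},{in2}->{in1}', f'{in1},{out}->{in2}'

print(derive_backward_equations('ijj,k->ik'))
# ('ik,k->ijj', 'ijj,ik->k')
```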
To Reproduce
Problem 1
Pass 3 or more XLA tensors into `torch.einsum`.
Problem 2
When I try
```cpp
xla::Einsum(x, y, "ik,k->ijj",
            xla::PrecisionConfig::Precision::PrecisionConfig_Precision_DEFAULT,
            xla::PrimitiveType::F32);
```
I get
```
INVALID_ARGUMENT: Transpose dimensions [0,0] are not a permutation of the operand dimensions (operand shape is f32[2]).
```
Steps to reproduce the behavior:
For problem 1:
Try
```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.rand(4, requires_grad=True, device=xm.xla_device())
y = torch.rand(4, requires_grad=True, device=xm.xla_device())
z = torch.rand(4, requires_grad=True, device=xm.xla_device())
torch.einsum('i,j,k->ijk', x, y, z)
```

Notice that `xla::Einsum` is not used (e.g. by printing IR or HLO).
For problem 2:
Remove the fallback logic from `XLANativeFunctions::einsum`, then either:

a. Run `TestEinsumPyTorchLowerRepeatedAxisBackward`

OR

b. Try
```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.rand(2, 4, requires_grad=True, device=xm.xla_device())
y = torch.rand(4, requires_grad=True, device=xm.xla_device())
torch.einsum('ik,k->ijj', x, y)
```

The program should crash inside of the `xla::Einsum` function.
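For what it's worth, the semantics that `ik,k->ijj` is meant to express (placing the contraction result on the diagonal of the output's last two axes) can be reproduced with a standard einsum plus `torch.diag_embed`. A minimal sketch, assuming the forward equation `ijj,k->ik` from above (`grad_x_for_repeated_axis` is a hypothetical helper, not a proposed patch):

```python
import torch

def grad_x_for_repeated_axis(grad_z, y, J):
    # Sketch only: computes the gradient of x for z = einsum('ijj,k->ik', x, y)
    # without the non-standard equation 'ik,k->ijj'. Shapes: x is (I, J, J),
    # y is (K,), grad_z is (I, K).
    g = torch.einsum('ik,k->i', grad_z, y)  # standard einsum
    # dz[i,k]/dx[i,j,j'] is nonzero only when j == j', so broadcast g over
    # the repeated axis and scatter it onto the diagonal.
    return torch.diag_embed(g.unsqueeze(-1).expand(-1, J))

x = torch.rand(2, 4, 4, requires_grad=True)
y = torch.rand(4)
z = torch.einsum('ijj,k->ik', x, y)
z.backward(torch.ones_like(z))
assert torch.allclose(x.grad, grad_x_for_repeated_axis(torch.ones_like(z), y, 4))
```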
Expected behavior
- I would expect XLA to handle more than 2 input tensors, or for PyTorch to break down the inputs into smaller einsum ops for dispatch (see the sketch after this list).
- I would expect XLA to support these kinds of equations, or for the implementation of einsum backward to change such that these types of equations are no longer necessary.
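A minimal sketch of the pairwise decomposition mentioned in the first point (`pairwise_einsum` is a hypothetical helper; the index bookkeeping is deliberately simplified in that every intermediate keeps the union of indices seen so far):

```python
import torch

def pairwise_einsum(equation, *operands):
    # Hypothetical helper, not the torch_xla implementation. Assumes an
    # explicit output spec and no repeated indices within a single operand.
    lhs, out = equation.split('->')
    specs = lhs.split(',')
    acc, acc_spec = operands[0], specs[0]
    for spec, op in zip(specs[1:], operands[1:]):
        # Keep the union of indices; a real implementation would drop
        # indices that no longer appear in later operands or the output.
        new_spec = acc_spec + ''.join(c for c in spec if c not in acc_spec)
        acc = torch.einsum(f'{acc_spec},{spec}->{new_spec}', acc, op)
        acc_spec = new_spec
    # Final reduction/permutation to the requested output.
    return torch.einsum(f'{acc_spec}->{out}', acc)

x, y, z = torch.rand(4), torch.rand(4), torch.rand(4)
assert torch.allclose(pairwise_einsum('i,j,k->ijk', x, y, z),
                      torch.einsum('i,j,k->ijk', x, y, z))
```

Each intermediate step is a 2-operand einsum, so every step could stay on `xla::Einsum`.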
Environment
- This should be reproducible on CPU/GPU/TPU.
- Torch version: 1.13