Skip to content

Conversation

@ysiraichi
Copy link
Collaborator

Fix: #6008

This PR, in addition to pytorch/pytorch#129691, add support for embedding bag calls when its sparse parameter is true.

Problem: when sparse=true, _embedding_bag_backward called at::_sparse_coo_tensor_unsafe_symint function which returns a sparse tensor. Since PyTorch/XLA does not support sparse tensors, this resulted in a dispatch error (see the original issue).

Solution: although, ideally we should support sparse tensors, in the short-term we decided (in an offline discussion) to fallback to the dense backwards function.

cc @miladm @JackCaoG

call(grad, indices_, offsets_, offset2bag, bag_size_, max_indices_,
num_weights, scale_grad_by_freq, mode, /*sparse=*/false,
per_sample_weights_opt, padding_idx);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a test to test/test_operations.py, run the embedding bag fwd and bwd and make sure it doesn't crash? this way we can prevent it from regressing.

@ysiraichi ysiraichi force-pushed the ysiraichi/fallback-embedding-bag-backward branch from cf3a4a3 to 6e5131a Compare June 27, 2024 20:45
@ysiraichi ysiraichi force-pushed the ysiraichi/fallback-embedding-bag-backward branch from 6e5131a to 5ffaadd Compare June 28, 2024 15:03
fresh = tensor.clone()
# Make this tensor a leaf tensor by detaching and reseting its
# requires_grad property.
fresh = fresh.detach()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool. I wonder why we need to make the tensor a leaf tensor. What would happen if we don't do it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's so we can grab its tensor.grad and compare.

@ysiraichi ysiraichi merged commit 3bcb1fb into master Jul 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

4 participants