
Conversation

@yitongh (Contributor) commented Aug 9, 2024

Fix the issue where the input/output alias does not work when using custom in-place operations such as `xm.optimization_barrier_`. The issue arises because custom in-place operations do not update the alias ID via `_propagate_xla_data`, so the outdated alias ID is still used. This situation is very common in FSDP and can result in parameters not being aliased.

Consider the case without this PR:

```python
t1 = torch.randn(4, 5).to(xla_device)  # Tensor ID 2, alias ID 2
t1 *= t1                               # Tensor ID 3, alias ID 2
xm.mark_step()
xm.optimization_barrier_([t1])         # Tensor ID 3, alias ID 2
t1 *= 100                              # Tensor ID 4, alias ID 2
```

In the second graph, `t1` cannot be aliased because the tensor ID of `t1` is 3, not 2.

With this PR:

```python
t1 = torch.randn(4, 5).to(xla_device)  # Tensor ID 2, alias ID 2
t1 *= t1                               # Tensor ID 3, alias ID 2
xm.mark_step()
xm.optimization_barrier_([t1])         # Tensor ID 3, alias ID 3
t1 *= 100                              # Tensor ID 4, alias ID 3
```

In the second graph, `t1` can be aliased because its output alias ID and input tensor ID are both 3.
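The ID bookkeeping above can be sketched in plain Python. This is a minimal model with hypothetical names (`FakeXLATensor`, `fresh_id`, `can_alias`), not the actual torch_xla implementation; it only illustrates why a custom in-place op must refresh the alias ID:

```python
# Minimal model: every new IR value gets a fresh tensor_id, while alias_id
# records which input buffer the value is allowed to reuse.
_next_id = 0

def fresh_id():
    global _next_id
    _next_id += 1
    return _next_id

class FakeXLATensor:
    def __init__(self):
        self.tensor_id = fresh_id()
        self.alias_id = self.tensor_id  # a fresh tensor aliases itself

    def inplace_mul(self, _other):
        # A built-in in-place op creates a new IR node (new tensor_id)
        # but keeps pointing at the original input buffer (alias_id).
        self.tensor_id = fresh_id()

def optimization_barrier_(tensors, propagate_alias):
    # After mark_step(), the tensor's current buffer holds the value named
    # by tensor_id. The fix: the custom in-place op updates alias_id to the
    # current tensor_id, which is what propagating the XLA data achieves.
    for t in tensors:
        if propagate_alias:
            t.alias_id = t.tensor_id

def can_alias(input_tensor_id, output_alias_id):
    # The executor donates the input buffer only when the IDs match.
    return input_tensor_id == output_alias_id

# Without the fix: alias_id stays stale and aliasing fails.
t1 = FakeXLATensor()                                # tensor_id 1, alias_id 1
t1.inplace_mul(t1)                                  # tensor_id 2, alias_id 1
input_id = t1.tensor_id                             # buffer materialized here
optimization_barrier_([t1], propagate_alias=False)  # alias_id still 1
t1.inplace_mul(100)
print(can_alias(input_id, t1.alias_id))             # False

# With the fix: alias_id is refreshed, so the second graph can alias.
t2 = FakeXLATensor()
t2.inplace_mul(t2)
input_id = t2.tensor_id
optimization_barrier_([t2], propagate_alias=True)   # alias_id == tensor_id
t2.inplace_mul(100)
print(can_alias(input_id, t2.alias_id))             # True
```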

```cpp
std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
    const std::vector<XLATensorPtr>& tensors, absl::Span<const size_t> indices,
    LoweringContext* lowering_ctx) {
  std::unordered_map<int64_t, size_t> output_tensor_id_map;
  std::vector<size_t> buffer_donor_indexs;
  // tensors[indices] represent all tensors that need to be updated after
  // the execution. We can only alias the current buffer of these tensors
  // since those buffers are no longer needed after execution.
  for (size_t i = 0; i < indices.size(); ++i) {
    size_t tensor_index = indices[i];
    int64_t tensor_id = tensors[tensor_index]->data()->alias_id;
    output_tensor_id_map[tensor_id] = i;
  }
  const auto& parameters_data = lowering_ctx->GetParametersData();
  std::vector<ssize_t> alias_map(indices.size(), -1);
  for (size_t i = 0; i < parameters_data.size(); ++i) {
    auto* data_info =
        static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
            parameters_data[i]->info());
    if (data_info != nullptr && !data_info->read_only) {
      auto it = output_tensor_id_map.find(data_info->tensor_id);
      // Parameter buffer's TensorId in output_tensor_id_map means
      // this buffer is not needed after execution since XLATensor will get a
      // new buffer.
      if (it != output_tensor_id_map.end()) {
        // ...
```
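The donor-selection loop can be modeled in a few lines of Python. The helper name `set_buffer_donors` and the list-based inputs are hypothetical, mirroring the C++ shown above rather than any real API:

```python
def set_buffer_donors(output_alias_ids, parameter_tensor_ids):
    # Map each output tensor's alias_id to its position in the output list,
    # mirroring output_tensor_id_map in the C++ snippet.
    output_tensor_id_map = {aid: i for i, aid in enumerate(output_alias_ids)}
    # A parameter whose tensor_id appears in the map will receive a new
    # buffer after execution, so its current buffer can be donated.
    return [i for i, tid in enumerate(parameter_tensor_ids)
            if tid in output_tensor_id_map]

# With the fix, the output t1 carries alias ID 3; a parameter with
# tensor ID 3 is matched and its buffer can be donated.
print(set_buffer_donors([3], [3, 7]))  # [0]

# Before the fix the output still carried the stale alias ID 2, so the
# parameter with tensor ID 3 was never matched and nothing was donated.
print(set_buffer_donors([2], [3, 7]))  # []
```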

@JackCaoG JackCaoG self-requested a review August 12, 2024 06:02
@JackCaoG (Collaborator) left a comment:

Thanks!

@JackCaoG JackCaoG merged commit 8f79488 into pytorch:master Aug 12, 2024
@yitongh yitongh deleted the fix_custom_alias branch August 13, 2024 01:55
yitongh added a commit to AlibabaPAI/xla that referenced this pull request Aug 13, 2024