Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions test/dynamo/test_dynamo_aliasing.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,41 +185,6 @@ def test_buffer_donation_on_non_data_tensor(self):
self.assertNotIn('XlaSetBufferDonation', met.counter_names())


class TestNonDynamoBufferDonationAliasing(unittest.TestCase):

def dummy_fn(self, input):
return torch.cos(torch.sin(input))

# Currently let's skip buffer donation api for the non-dynamo use case
def test_buffer_donation_skip_for_non_dynamo(self):
device = xm.xla_device()
input = torch.randn(5, 5).to(device)
xm.mark_step()
met.clear_all()

# We should be able to set buffer donation for input tensor, but when mark_step
# triggered, the buffer donation should be ignored.
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
res = self.dummy_fn(input)
xm.mark_step()
# Make sure that input buffer is not aliased and can be used for other compuations.
# Also make sure that buffer_donation will not trigger recompilation in non-dynamo.
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, False))
res2 = self.dummy_fn(input)
xm.mark_step()
torch.allclose(res.cpu(), res2.cpu())
self.assertEqual(met.metric_data('CompileTime')[0], 1)

def test_no_op_mark_step_keep_buffer_donation(self):
device = xm.xla_device()
input = torch.randn(5, 5).to(device)
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
xm.mark_step()
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
xm.mark_step()
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
71 changes: 53 additions & 18 deletions test/test_input_output_aliases.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import contextlib
import copy
import os
import sys
import unittest
from absl.testing import parameterized

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import unittest
import contextlib
import copy


def create_xla_config_context(set_func, get_func):
Expand All @@ -34,7 +35,7 @@ def config_context(value):


# TODO(alanwaketan): add test for views.
class InputOutputAliasesTest(unittest.TestCase):
class InputOutputAliasesTest(parameterized.TestCase):

def test_non_view(self):
xla_device = xm.xla_device()
Expand Down Expand Up @@ -233,34 +234,59 @@ def test_device_data_cache_no_aliasing(self):
self.assertEqual(t1.item(), 43)

def test_user_config_donation_with_ltc_donation(self):
with alias_with_buffer_donor_config_context(True):
met.clear_all()
xla_device = xm.xla_device()
t0 = torch.randn(4, 2, 2).to(xla_device)
t1 = torch.randn(4, 2, 2).to(xla_device)
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
self.assertFalse(torch_xla._XLAC._get_buffer_donation(t1))
t2 = t0 + t1
t1 += 2
xm.mark_step(wait=True)

# We surface the C++ runtime error by checking that the backend data is
# no longer present for the IR node.
self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)

@parameterized.parameters(True, False)
def test_user_config_donation_with_ltc_donation_graph_sync(
self, enable_buffer_donor_config):
with alias_with_buffer_donor_config_context(enable_buffer_donor_config):
met.clear_all()
xla_device = xm.xla_device()
t0 = torch.randn(4, 2, 2).to(xla_device)
t1 = torch.randn(4, 2, 2).to(xla_device)
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
self.assertFalse(torch_xla._XLAC._get_buffer_donation(t1))
t3 = t0 + t1
t2 = t0 + t1
t1 += 2
xm.mark_step(wait=True)
# We use _xla_sync_multi to explicitly disable sync_xla_data, which will
# in turn avoid using LTC aliasings. This ensures that the resulting
# aliasings are due to the buffer donation.
torch_xla._XLAC._xla_sync_multi([t0, t1, t2], [str(xla_device)], True,
False)

# We surface the C++ runtime error by checking that the backend data is
# no longer present for the IR node.
self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
self.assertEqual(
torch_xla._XLAC._is_placecholder(t0), enable_buffer_donor_config)
self.assertEqual(
met.metric_data("InputOutputAliasCount")[1],
enable_buffer_donor_config)

def test_user_config_donation_with_ltc_donation_overlap(self):
with alias_with_buffer_donor_config_context(True):
met.clear_all()
xla_device = xm.xla_device()
t0 = torch.randn(4, 2, 2).to(xla_device)
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
t0 += 2
xm.mark_step()
met.clear_all()
xla_device = xm.xla_device()
t0 = torch.randn(4, 2, 2).to(xla_device)
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
t0 += 2
xm.mark_step()

self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)

def test_user_config_donation(self):
with alias_with_buffer_donor_config_context(True):
Expand Down Expand Up @@ -304,6 +330,15 @@ def test_user_config_donation_no_op_mark_step(self):
xm.mark_step()
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))

def test_no_op_mark_step_keep_buffer_donation(self):
xla_device = xm.xla_device()
input = torch.randn(5, 5).to(xla_device)
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
xm.mark_step()
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
xm.mark_step()
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))


if __name__ == '__main__':
test = unittest.main()
Expand Down
11 changes: 8 additions & 3 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1305,9 +1305,11 @@ std::vector<size_t> XLAGraphExecutor::GetBufferDonors(
return {};
}

bool donate_ltc_data =
coll.config.sync_ltc_data && coll.config.force_ltc_data;
std::vector<size_t> ltc_buffer_donor_indices;
if (coll.config.sync_ltc_data && coll.config.force_ltc_data) {
// We can only alias at the step barrier, when force_ltc_data is true.
if (donate_ltc_data) {
// We can only alias at the step barrier, when donate_ltc_data is true.
// Consider the case:
// 1. Tensor A(DEVICE_DATA)
// 2. Tensor B = A + 0.9
Expand Down Expand Up @@ -1336,7 +1338,10 @@ std::vector<size_t> XLAGraphExecutor::GetBufferDonors(
}

std::vector<size_t> user_config_buffer_donor_indices;
if (GetAliasWithBufferDonorConfig()) {
if (donate_ltc_data || GetAliasWithBufferDonorConfig()) {
// In case any tensor is explicitly marked for donation, we ensure that it
// is donated during step barrier, or if explicitly forced to donate via
// GetAliasWithBufferDonorConfig().
user_config_buffer_donor_indices =
GetBufferDonorIndexFromUserConfig(parameters_data);
}
Expand Down
Loading