
Commit 719938c

guangyey authored and pytorchmergebot committed
Generalize pin memory logic for accelerator when non-blocking copy happened (pytorch#143783)
# Motivation
Fixes pytorch#143641. Generalize the pin memory logic for accelerators when a non-blocking copy happens. Each accelerator has its own implementation of `empty_strided`; an accelerator that has no pin memory mechanism can ignore or mimic the pinning when `pin_out` is true.

Pull Request resolved: pytorch#143783
Approved by: https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: pytorch#144959
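For context, a minimal sketch of the user-visible behavior this change generalizes, mirroring the new test added below. The `torch.accelerator` calls are the ones used in that test; whether the host tensor actually ends up pinned depends on the backend having a pin-memory mechanism:

```python
import torch

# A non-blocking accelerator-to-CPU copy hands back a pinned host tensor so the
# asynchronous copy is safe. After this change, that decision keys off the
# generic accelerator check (excluding MPS) instead of being hard-coded to
# CUDA/PrivateUse1.
if torch.accelerator.is_available():
    dev = torch.accelerator.current_accelerator()
    t_acc = torch.randn(100, device=dev)
    t_host = t_acc.to("cpu", non_blocking=True)  # pin_out becomes True here
    torch.accelerator.synchronize()              # wait for the async copy
    print(t_host.is_pinned())                    # True when the backend can pin
```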
1 parent 28b6430

2 files changed: +11 −2

aten/src/ATen/native/TensorConversions.cpp (2 additions, 1 deletion)

```diff
@@ -343,7 +343,8 @@ Tensor _to_copy(
   }

   bool pin_out =
-      (non_blocking && (self.is_cuda() || self.is_privateuseone()) &&
+      (non_blocking &&
+       at::accelerator::isAcceleratorExcluded(self.device().type(), at::kMPS) &&
        options.device().is_cpu() && (options.layout() == c10::kStrided));

   if (memory_format == MemoryFormat::Preserve) {
```
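The `at::kMPS` carve-out exists because MPS has no pin-memory mechanism, which is also why the new test below is skipped there. An illustrative sketch of that asymmetry, assuming an MPS build; the unpinned result is inferred from the test's skip reason rather than stated in the diff:

```python
import torch

# On MPS the generalized pin_out check is excluded, so a non-blocking copy to
# CPU yields ordinary pageable memory instead of a pinned tensor.
if torch.backends.mps.is_available():
    t_host = torch.randn(8, device="mps").to("cpu", non_blocking=True)
    torch.mps.synchronize()
    print(t_host.is_pinned())  # expected False: MPS doesn't support pin memory
```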

test/test_accelerator.py (9 additions, 1 deletion)

```diff
@@ -4,7 +4,7 @@
 import unittest

 import torch
-from torch.testing._internal.common_utils import NoTest, run_tests, TestCase
+from torch.testing._internal.common_utils import NoTest, run_tests, TEST_MPS, TestCase


 if not torch.accelerator.is_available():
@@ -102,6 +102,14 @@ def test_multi_device_stream_context_manager(self):
         self.assertEqual(torch.accelerator.current_stream(), src_prev_stream)
         self.assertEqual(torch.accelerator.current_stream(dst_device), dst_prev_stream)

+    @unittest.skipIf(TEST_MPS, "MPS doesn't support pin memory!")
+    def test_pin_memory_on_non_blocking_copy(self):
+        t_acc = torch.randn(100).to(torch.accelerator.current_accelerator())
+        t_host = t_acc.to("cpu", non_blocking=True)
+        torch.accelerator.synchronize()
+        self.assertTrue(t_host.is_pinned())
+        self.assertEqual(t_acc.cpu(), t_host)
+

 if __name__ == "__main__":
     run_tests()
```
