Skip to content

Commit 9a1f720

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Validate inputs to _nested_view_from_buffer to prevent overflows (pytorch#147356)
Pull Request resolved: pytorch#147356 Approved by: https://github.com/albanD, https://github.com/jbschlosser ghstack dependencies: pytorch#147352, pytorch#147354
1 parent 536bce5 commit 9a1f720

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

aten/src/ATen/native/nested/NestedTensorMath.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <ATen/WrapDimUtilsMulti.h>
1313
#include <ATen/core/Tensor.h>
1414
#include <ATen/native/layer_norm.h>
15+
#include <ATen/native/Resize.h>
1516
#include <ATen/native/nested/NestedTensorUtils.h>
1617

1718
#include <tuple>
@@ -879,19 +880,28 @@ Tensor _nested_view_from_buffer(
879880
"Can only a create Nested Tensor from a normal tensor buffer");
880881
TORCH_INTERNAL_ASSERT(buffer.dim() == 1, "The input buffer must be flat");
881882
TORCH_INTERNAL_ASSERT(nested_sizes.dim() == 2, "Expected the nested size tensor to be two dimensional.");
882-
uint64_t num_elements_nested_size = at::prod(nested_sizes, 1).sum().item<int64_t>();
883-
uint64_t buffer_storage_size = buffer.storage().nbytes()/buffer.dtype().itemsize();
884-
TORCH_INTERNAL_ASSERT(
885-
buffer_storage_size == num_elements_nested_size,
886-
"The number of elements in the buffer must equal the nested tensor size but buffer size: ",
887-
buffer_storage_size,
888-
" and nested tensor size: ",
889-
num_elements_nested_size,
890-
".");
891-
892883
TORCH_INTERNAL_ASSERT(nested_strides.dim() == 2, "Expected the nested stride tensor to be two dimensional.");
893884
TORCH_INTERNAL_ASSERT(nested_sizes.size(0) == nested_strides.size(0), "Expected the first dimension of nested size and nested stride tensor to be equal.");
894885
TORCH_INTERNAL_ASSERT(nested_strides.size(0) == storage_offsets.size(0), "Expected the first dimension of nested stride tensor to equal the length of offsets.");
886+
887+
888+
std::vector<at::Tensor> all_sizes = nested_sizes.unbind();
889+
std::vector<at::Tensor> all_strides = nested_strides.unbind();
890+
std::vector<at::Tensor> all_offsets = storage_offsets.unbind();
891+
auto size_dim = nested_sizes.size(1);
892+
893+
for (const auto i : c10::irange(nested_sizes.size(0))) {
894+
const int64_t* sizemat_ptr = all_sizes[i].const_data_ptr<int64_t>();
895+
const int64_t* stridemat_ptr = all_strides[i].const_data_ptr<int64_t>();
896+
const int64_t* offset_ptr = all_offsets[i].const_data_ptr<int64_t>();
897+
checkInBoundsForStorage(
898+
IntArrayRef(sizemat_ptr, sizemat_ptr + size_dim),
899+
IntArrayRef(stridemat_ptr, stridemat_ptr + size_dim),
900+
*offset_ptr,
901+
buffer.dtype(),
902+
buffer.storage());
903+
}
904+
895905
return at::detail::make_tensor<NestedTensorImpl>(
896906
c10::TensorImpl::VIEW,
897907
buffer,

test/test_nestedtensor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,21 @@ def test_cat(self):
857857
):
858858
torch.cat([x, y], dim=-1)
859859

860+
def test_nested_view_from_buffer_overflow_errors(self):
861+
buffer = torch.tensor([1])
862+
sizes = torch.tensor([[2**63 - 1], [2**63 - 1], [3]], dtype=torch.int64)
863+
strides = torch.tensor(
864+
[[0x41414141], [0x41414141], [0x41414141]], dtype=torch.int64
865+
)
866+
offsets = torch.tensor(
867+
[[0x41414141], [0x41414141], [0x41414141]], dtype=torch.int64
868+
)
869+
with self.assertRaisesRegex(
870+
RuntimeError,
871+
r"Storage size calculation overflowed with sizes=\[9223372036854775807\] and strides=\[1094795585\]",
872+
):
873+
nt = torch._nested_view_from_buffer(buffer, sizes, strides, offsets)
874+
860875

861876
@markDynamoStrictTest
862877
class TestNestedTensorDeviceType(NestedTensorTestCase):

0 commit comments

Comments
 (0)