 #include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/native/layer_norm.h>
+#include <ATen/native/Resize.h>
 #include <ATen/native/nested/NestedTensorUtils.h>
 
 #include <tuple>
@@ -879,19 +880,28 @@ Tensor _nested_view_from_buffer(
       "Can only a create Nested Tensor from a normal tensor buffer");
   TORCH_INTERNAL_ASSERT(buffer.dim() == 1, "The input buffer must be flat");
   TORCH_INTERNAL_ASSERT(nested_sizes.dim() == 2, "Expected the nested size tensor to be two dimensional.");
-  uint64_t num_elements_nested_size = at::prod(nested_sizes, 1).sum().item<int64_t>();
-  uint64_t buffer_storage_size = buffer.storage().nbytes()/buffer.dtype().itemsize();
-  TORCH_INTERNAL_ASSERT(
-      buffer_storage_size == num_elements_nested_size,
-      "The number of elements in the buffer must equal the nested tensor size but buffer size: ",
-      buffer_storage_size,
-      " and nested tensor size: ",
-      num_elements_nested_size,
-      ".");
-
   TORCH_INTERNAL_ASSERT(nested_strides.dim() == 2, "Expected the nested stride tensor to be two dimensional.");
   TORCH_INTERNAL_ASSERT(nested_sizes.size(0) == nested_strides.size(0), "Expected the first dimension of nested size and nested stride tensor to be equal.");
   TORCH_INTERNAL_ASSERT(nested_strides.size(0) == storage_offsets.size(0), "Expected the first dimension of nested stride tensor to equal the length of offsets.");
+
+
+  std::vector<at::Tensor> all_sizes = nested_sizes.unbind();
+  std::vector<at::Tensor> all_strides = nested_strides.unbind();
+  std::vector<at::Tensor> all_offsets = storage_offsets.unbind();
+  auto size_dim = nested_sizes.size(1);
+
+  for (const auto i : c10::irange(nested_sizes.size(0))) {
+    const int64_t* sizemat_ptr = all_sizes[i].const_data_ptr<int64_t>();
+    const int64_t* stridemat_ptr = all_strides[i].const_data_ptr<int64_t>();
+    const int64_t* offset_ptr = all_offsets[i].const_data_ptr<int64_t>();
+    checkInBoundsForStorage(
+        IntArrayRef(sizemat_ptr, sizemat_ptr + size_dim),
+        IntArrayRef(stridemat_ptr, stridemat_ptr + size_dim),
+        *offset_ptr,
+        buffer.dtype(),
+        buffer.storage());
+  }
+
   return at::detail::make_tensor<NestedTensorImpl>(
       c10::TensorImpl::VIEW,
       buffer,
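
For context, the loop added above delegates to checkInBoundsForStorage from ATen/native/Resize.h (hence the new include), which rejects any size/stride/offset combination whose furthest reachable element lies past the end of the backing storage. The standalone helper below is a hypothetical sketch of that bound computation, not ATen's implementation, and it assumes non-negative strides for brevity:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical illustration of the per-constituent bounds check: compute the
// largest element index reachable through (sizes, strides, storage_offset)
// and require the byte just past it to fit inside the buffer's storage.
void check_in_bounds_sketch(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides,
    int64_t storage_offset,
    int64_t itemsize_bytes,
    int64_t storage_nbytes) {
  int64_t max_index = storage_offset;
  for (size_t d = 0; d < sizes.size(); ++d) {
    // A constituent with any zero-sized dimension touches no memory at all.
    if (sizes[d] == 0) {
      return;
    }
    max_index += (sizes[d] - 1) * strides[d];  // assumes strides[d] >= 0
  }
  if ((max_index + 1) * itemsize_bytes > storage_nbytes) {
    throw std::out_of_range(
        "nested constituent would read past the end of the buffer");
  }
}

In the ATen code path, buffer.storage().nbytes() plays the role of storage_nbytes and buffer.dtype().itemsize() the role of itemsize_bytes; the old check only compared total element counts, so a crafted size/stride/offset triple could still index outside the buffer.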