Skip to content

Commit e58e27c

Browse files
committed
Add 'torch/lib/ATen/' from commit '9d0c674cb7bcfae989d69f988363c1688c22fa89'
git-subtree-dir: torch/lib/ATen git-subtree-mainline: 3314d51 git-subtree-split: 9d0c674
2 parents 3314d51 + 9d0c674 commit e58e27c

Some content is hidden

Large commits have some of their content hidden by default. Use the search box below to find content that may be hidden.

52 files changed

+4002
-0
lines changed

torch/lib/ATen/ATen.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#pragma once
2+
3+
#include "ATen/Scalar.h"
4+
#include "ATen/Type.h"
5+
#include "ATen/Generator.h"
6+
#include "ATen/Context.h"
7+
#include "ATen/Storage.h"
8+
#include "ATen/Tensor.h"
9+
#include "ATen/Functions.h"
10+
#include "ATen/Formatting.h"
11+
#include "ATen/TensorOperators.h"
12+
#include "ATen/TensorMethods.h"

torch/lib/ATen/ArrayRef.h

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===//
2+
//
3+
// The LLVM Compiler Infrastructure
4+
//
5+
// This file is distributed under the University of Illinois Open Source
6+
// License. See LICENSE.TXT for details.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
// ATen: modified from llvm::ArrayRef.
11+
// removed llvm-specific functionality
12+
// removed some implicit const -> non-const conversions that rely on
13+
// complicated std::enable_if meta-programming
14+
// removed a bunch of slice variants for simplicity...
15+
16+
#pragma once
#include <assert.h>
#include <algorithm>         // std::equal
#include <array>
#include <cstddef>           // size_t
#include <initializer_list>
#include <iterator>          // std::reverse_iterator
#include <vector>

namespace at {
/// ArrayRef - Represent a constant reference to an array (0 or more elements
/// consecutively in memory), i.e. a start pointer and a length. It allows
/// various APIs to take consecutive elements easily and conveniently.
///
/// This class does not own the underlying data, it is expected to be used in
/// situations where the data resides in some other buffer, whose lifetime
/// extends past that of the ArrayRef. For this reason, it is not in general
/// safe to store an ArrayRef.
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
template<typename T>
class ArrayRef {
public:
  using iterator = const T *;
  using const_iterator = const T *;
  using size_type = size_t;

  using reverse_iterator = std::reverse_iterator<iterator>;

private:
  /// The start of the array, in an external buffer.
  const T *Data;

  /// The number of elements.
  size_type Length;

public:
  /// @name Constructors
  /// @{

  /// Construct an empty ArrayRef.
  /*implicit*/ ArrayRef() : Data(nullptr), Length(0) {}

  /// Construct an ArrayRef from a single element.
  /// NOTE(review): binding to a temporary dangles once the full expression
  /// ends; callers must pass an lvalue that outlives the ArrayRef.
  /*implicit*/ ArrayRef(const T &OneElt)
    : Data(&OneElt), Length(1) {}

  /// Construct an ArrayRef from a pointer and length.
  /*implicit*/ ArrayRef(const T *data, size_t length)
    : Data(data), Length(length) {}

  /// Construct an ArrayRef from a range [begin, end).
  ArrayRef(const T *begin, const T *end)
    : Data(begin), Length(end - begin) {}

  /// Construct an ArrayRef from a std::vector.
  template<typename A>
  /*implicit*/ ArrayRef(const std::vector<T, A> &Vec)
    : Data(Vec.data()), Length(Vec.size()) {}

  /// Construct an ArrayRef from a std::array
  template <size_t N>
  /*implicit*/ constexpr ArrayRef(const std::array<T, N> &Arr)
    : Data(Arr.data()), Length(N) {}

  /// Construct an ArrayRef from a C array.
  template <size_t N>
  /*implicit*/ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}

  /// Construct an ArrayRef from a std::initializer_list.
  /// An empty list has no storage, so normalize Data to nullptr.
  /*implicit*/ ArrayRef(const std::initializer_list<T> &Vec)
    : Data(Vec.begin() == Vec.end() ? static_cast<const T*>(nullptr) : Vec.begin()),
      Length(Vec.size()) {}

  /// @}
  /// @name Simple Operations
  /// @{

  iterator begin() const { return Data; }
  iterator end() const { return Data + Length; }

  reverse_iterator rbegin() const { return reverse_iterator(end()); }
  reverse_iterator rend() const { return reverse_iterator(begin()); }

  /// empty - Check if the array is empty.
  bool empty() const { return Length == 0; }

  const T *data() const { return Data; }

  /// size - Get the array size.
  size_t size() const { return Length; }

  /// front - Get the first element.
  const T &front() const {
    assert(!empty());
    return Data[0];
  }

  /// back - Get the last element.
  const T &back() const {
    assert(!empty());
    return Data[Length-1];
  }

  /// equals - Check for element-wise equality.
  bool equals(ArrayRef RHS) const {
    if (Length != RHS.Length)
      return false;
    return std::equal(begin(), end(), RHS.begin());
  }

  /// slice(n, m) - Chop off the first N elements of the array, and keep M
  /// elements in the array.
  ArrayRef<T> slice(size_t N, size_t M) const {
    assert(N+M <= size() && "Invalid specifier");
    return ArrayRef<T>(data()+N, M);
  }

  /// slice(n) - Chop off the first N elements of the array.
  ArrayRef<T> slice(size_t N) const { return slice(N, size() - N); }

  /// @}
  /// @name Operator Overloads
  /// @{
  const T &operator[](size_t Index) const {
    assert(Index < Length && "Invalid index!");
    return Data[Index];
  }

  /// Disallow accidental assignment from a temporary.
  ///
  /// The declaration here is extra complicated so that "arrayRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type &
  operator=(U &&Temporary) = delete;

  /// Disallow accidental assignment from a temporary.
  ///
  /// The declaration here is extra complicated so that "arrayRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type &
  operator=(std::initializer_list<U>) = delete;

  /// @}
  /// @name Expensive Operations
  /// @{
  /// Copy the referenced elements into an owning std::vector.
  std::vector<T> vec() const {
    return std::vector<T>(Data, Data+Length);
  }

  /// @}
  /// @name Conversion operators
  /// @{
  operator std::vector<T>() const {
    return std::vector<T>(Data, Data+Length);
  }

  /// @}
};

} // end namespace at

torch/lib/ATen/CMakeLists.txt

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
2+
CMAKE_MINIMUM_REQUIRED(VERSION 2.8)
SET(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH})

# Request C++11 using the mechanism available for the running CMake version:
# raw flag (< 2.8.12), add_compile_options (2.8.12 – 3.1), or the
# CXX_STANDARD target property set later in this file (>= 3.1).
if(${CMAKE_VERSION} VERSION_LESS "2.8.12")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
else(${CMAKE_VERSION} VERSION_LESS "2.8.12")
  if(${CMAKE_VERSION} VERSION_LESS "3.1")
    add_compile_options(-std=c++11) # CMake 2.8.12 to 3.1
  endif(${CMAKE_VERSION} VERSION_LESS "3.1")
endif(${CMAKE_VERSION} VERSION_LESS "2.8.12")

# Warnings; -Wno-vla because generated wrappers use variable-length arrays.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -pedantic -Wno-vla")

################################################################################
# Helper functions
################################################################################
19+
# EXCLUDE_DIR(<list_name> <dir_name>)
# Removes, in the caller's scope, every entry of list <list_name> whose path
# matches the regex <dir_name>, logging each exclusion.
FUNCTION(EXCLUDE_DIR list_name dir_name)
  # A helper that excludes all files that contain dir_name in their file path
  SET(local_list ${${list_name}})
  FOREACH(source ${local_list})
    IF(${source} MATCHES ${dir_name})
      MESSAGE(STATUS "Excluding " ${source} " from the build")
      LIST(REMOVE_ITEM local_list ${source})
    ENDIF()
  ENDFOREACH()
  # Write the filtered list back into the caller's variable.
  SET(${list_name} ${local_list} PARENT_SCOPE)
ENDFUNCTION()
30+
31+
# filter_list(<output> <input> <pattern>...)
# Sets <output> (in the caller's scope) to the entries of list <input> that
# match any of the given regex patterns. Used below to pick the generated
# headers out of the full generated-file list.
function(filter_list output input)
  unset(result)
  foreach(filename ${${input}})
    foreach(pattern ${ARGN})
      # The loop variable must be referenced here; the pasted version's
      # "$(unknown)" matched a literal string and never used ${filename}.
      if("${filename}" MATCHES "${pattern}")
        list(APPEND result "${filename}")
      endif()
    endforeach()
  endforeach()
  set(${output} ${result} PARENT_SCOPE)
endfunction()
42+
43+
# Locate Torch unless an enclosing build has already found it.
IF(NOT Torch_FOUND)
  FIND_PACKAGE(Torch REQUIRED)
ENDIF()

# Backend library names default to the in-tree targets; an enclosing build
# may override them with full paths.
IF(NOT TH_LIBRARIES)
  SET(TH_LIBRARIES "TH")
ENDIF(NOT TH_LIBRARIES)
MESSAGE(STATUS "TH_LIBRARIES: ${TH_LIBRARIES}")

IF(NOT THS_LIBRARIES)
  SET(THS_LIBRARIES "THS")
ENDIF()

IF(NOT THNN_LIBRARIES)
  SET(THNN_LIBRARIES "THNN")
ENDIF(NOT THNN_LIBRARIES)
MESSAGE(STATUS "THNN_LIBRARIES: ${THNN_LIBRARIES}")

# CUDA setup: NO_CUDA passes -n to gen.py so it skips CUDA code generation;
# otherwise enable the CUDA backends when CUDA >= 5.5 is found.
IF(NO_CUDA)
  MESSAGE(STATUS "ignoring CUDA")
  SET(CUDA_FLAG -n)
ELSE()
  FIND_PACKAGE(CUDA 5.5)
  IF(CUDA_FOUND)
    ADD_DEFINITIONS(-DAT_CUDA_ENABLED)
    INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS})

    IF(NOT THC_LIBRARIES)
      SET(THC_LIBRARIES "THC")
    ENDIF(NOT THC_LIBRARIES)
    MESSAGE(STATUS "THC_LIBRARIES: ${THC_LIBRARIES}")

    IF(NOT THCS_LIBRARIES)
      SET(THCS_LIBRARIES "THCS")
    ENDIF(NOT THCS_LIBRARIES)
    MESSAGE(STATUS "THCS_LIBRARIES: ${THCS_LIBRARIES}")

    IF(NOT THCUNN_LIBRARIES)
      SET(THCUNN_LIBRARIES "THCUNN")
    ENDIF(NOT THCUNN_LIBRARIES)
    MESSAGE(STATUS "THCUNN_LIBRARIES: ${THCUNN_LIBRARIES}")
  ENDIF()
ENDIF()

# Can be compiled standalone
IF(NOT TENSOR_LIB_INSTALL_BIN_DIR OR NOT TENSOR_LIB_INSTALL_LIB_DIR OR NOT TENSOR_LIB_INSTALL_INCLUDE_DIR)
  SET(TENSOR_LIB_INSTALL_BIN_DIR "bin" CACHE PATH "TENSOR_LIB install binary subdirectory")
  SET(TENSOR_LIB_INSTALL_LIB_DIR "lib" CACHE PATH "TENSOR_LIB install library subdirectory")
  SET(TENSOR_LIB_INSTALL_INCLUDE_DIR "include" CACHE PATH "TENSOR_LIB install include subdirectory")
ENDIF()

# Hand-written sources in this directory.
FILE(GLOB base_h RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.h")
FILE(GLOB base_cpp RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")

# The code generator and its inputs (used as dependencies of the custom command).
FILE(GLOB all_python RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.py")

IF(NOT DEFINED cwrap_files)
  SET(CWRAP_FILES_BASE ${CMAKE_CURRENT_SOURCE_DIR}/../../csrc )
  SET(cwrap_files
    # ${CWRAP_FILES_BASE}/cudnn/cuDNN.cwrap
    ${CWRAP_FILES_BASE}/generic/TensorMethods.cwrap
    # ${CWRAP_FILES_BASE}/generic/methods/SparseTensor.cwrap
    ${CWRAP_FILES_BASE}/generic/methods/Tensor.cwrap
    ${CWRAP_FILES_BASE}/generic/methods/TensorApply.cwrap
    ${CWRAP_FILES_BASE}/generic/methods/TensorCompare.cwrap
    ${CWRAP_FILES_BASE}/generic/methods/TensorCuda.cwrap
    ${CWRAP_FILES_BASE}/generic/methods/TensorMath.cwrap
    ${CWRAP_FILES_BASE}/generic/methods/TensorRandom.cwrap
    # ${CWRAP_FILES_BASE}/generic/methods/TensorSerialization.cwrap
    ${CMAKE_CURRENT_SOURCE_DIR}/Local.cwrap
    ${CMAKE_CURRENT_SOURCE_DIR}/../THNN/generic/THNN.h
    ${CMAKE_CURRENT_SOURCE_DIR}/../THCUNN/generic/THCUNN.h
  )
ENDIF()

# Ask gen.py (at configure time) for the list of files it will generate.
EXECUTE_PROCESS(
  COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen.py ${CUDA_FLAG} -s ${CMAKE_CURRENT_SOURCE_DIR} --print-dependencies ${cwrap_files}
  # use stderr rather than stdout so we can still debug the script with print
  ERROR_VARIABLE generated_cpp
  RESULT_VARIABLE RETURN_VALUE
)
if (NOT RETURN_VALUE EQUAL 0)
  message(STATUS ${generated_cpp})
  message(FATAL_ERROR "Failed to get generated_cpp list")
endif()

FILE(GLOB_RECURSE all_templates "templates/*")

FILE(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ATen)

# Run gen.py at build time to actually produce the generated sources.
ADD_CUSTOM_COMMAND(OUTPUT ${generated_cpp}
  COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen.py ${CUDA_FLAG} -s ${CMAKE_CURRENT_SOURCE_DIR} ${cwrap_files}
  DEPENDS ${all_python} ${all_templates} ${cwrap_files})

SET(all_cpp ${base_cpp} ${generated_cpp})
# Split the generated headers out of the generated-file list for installation.
filter_list(generated_h generated_cpp "\\.h$")

INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}/..)
# so the build can find the generated header files
INCLUDE_DIRECTORIES(${CMAKE_CURRENT_BINARY_DIR})
ADD_LIBRARY(ATen SHARED ${all_cpp})
SET_TARGET_PROPERTIES(ATen PROPERTIES VERSION 1 SOVERSION 1)

if(NOT ${CMAKE_VERSION} VERSION_LESS "3.1")
  SET_PROPERTY(TARGET ATen PROPERTY CXX_STANDARD 11)
endif(NOT ${CMAKE_VERSION} VERSION_LESS "3.1")

TARGET_LINK_LIBRARIES(ATen ${TH_LIBRARIES} ${THNN_LIBRARIES} ${THS_LIBRARIES})
IF(CUDA_FOUND)
  TARGET_LINK_LIBRARIES(ATen ${THC_LIBRARIES} ${THCUNN_LIBRARIES} ${THCS_LIBRARIES})
  TARGET_LINK_LIBRARIES(ATen ${CUDA_LIBRARIES})
ENDIF()

INSTALL(TARGETS ATen
  RUNTIME DESTINATION "${TENSOR_LIB_INSTALL_BIN_DIR}"
  LIBRARY DESTINATION "${TENSOR_LIB_INSTALL_LIB_DIR}"
  ARCHIVE DESTINATION "${TENSOR_LIB_INSTALL_LIB_DIR}")

# Install hand-written headers from the source tree and generated headers
# from the build tree, all under <prefix>/include/ATen.
FOREACH(HEADER ${base_h})
  INSTALL(FILES ${HEADER} DESTINATION ${TENSOR_LIB_INSTALL_INCLUDE_DIR}/ATen)
ENDFOREACH()
FOREACH(HEADER ${generated_h})
  INSTALL(FILES ${CMAKE_CURRENT_BINARY_DIR}/${HEADER}
          DESTINATION ${TENSOR_LIB_INSTALL_INCLUDE_DIR}/ATen)
ENDFOREACH()

torch/lib/ATen/CPUFixedAllocator.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#pragma once

#include "TH/TH.h"

#include <cstddef>     // ptrdiff_t
#include <functional>  // std::function used as the allocator's state

// This file creates a fake allocator that just throws exceptions if
// it is actually used.

// state passed to the allocator is the std::function<void(void*)> called
// when the blob is released by ATen

namespace at {

// THAllocator's malloc hook returns void*; the original code omitted the
// return type entirely, which is not valid C++. runtime_error (an ATen
// helper declared elsewhere) throws, so the return is never reached.
static void *cpu_fixed_malloc(void *, ptrdiff_t) {
  runtime_error("attempting to resize a tensor view of an external blob");
  return nullptr; // unreachable
}

static void *cpu_fixed_realloc(void *, void *, ptrdiff_t) {
  runtime_error("attempting to resize a tensor view of an external blob");
  return nullptr; // unreachable
}

// Called when ATen releases the blob: invoke the owner's release callback
// (stored as heap-allocated state) and then dispose of the callback itself.
static void cpu_fixed_free(void * state, void * allocation) {
  auto on_release = static_cast<std::function<void(void*)>*>(state);
  (*on_release)(allocation);
  delete on_release;
}

static THAllocator CPU_fixed_allocator =
  { cpu_fixed_malloc, cpu_fixed_realloc, cpu_fixed_free };

}

0 commit comments

Comments
 (0)