Skip to content

Commit d58b627

Browse files
committed
CUDNN RNN bindings
1 parent b85fc35 commit d58b627

File tree

6 files changed

+635
-4
lines changed

6 files changed

+635
-4
lines changed

torch/backends/cudnn/__init__.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ def is_acceptable(tensor):
7272
CUDNN_TENSOR_NCHW = 0
7373
CUDNN_TENSOR_NHWC = 1
7474

75+
# Integer codes naming the recurrent cell type. Presumably these mirror
# cuDNN's RNN-mode enum and are the values passed as `mode` to
# RNNDescriptor.set -- TODO confirm against cudnn.h and the callers.
CUDNN_RNN_RELU = 0
76+
CUDNN_RNN_TANH = 1
77+
CUDNN_LSTM = 2
78+
CUDNN_GRU = 3
79+
80+
# Integer codes naming how the first layer treats its input. Presumably
# these mirror cuDNN's RNN input-mode enum and are the values passed as
# `input_mode` to RNNDescriptor.set -- TODO confirm.
CUDNN_LINEAR_INPUT = 0
81+
CUDNN_SKIP_INPUT = 1
7582

7683
class CuDNNHandle:
7784
def __init__(self):
@@ -88,6 +95,7 @@ def __init__(self, status):
8895
msg = '{}: {}'.format(status, get_error_string(status))
8996
super(CuDNNError, self).__init__(msg)
9097

98+
9199
class TensorDescriptor(object):
92100
def __init__(self):
93101
ptr = ctypes.c_void_p()
@@ -109,6 +117,38 @@ def set(self, tensor):
109117
def as_tuple(self):
110118
return (self._type, tuple(self._size), tuple(self._stride))
111119

120+
class TensorDescriptorArray(object):
    """Owns ``N`` cuDNN tensor descriptors held in one contiguous C array.

    The instance can be handed directly to a ctypes foreign function:
    ``_as_parameter_`` is the underlying ``c_void_p`` array, so the callee
    receives a pointer to the first descriptor handle.
    """

    def __init__(self, N):
        """Create ``N`` descriptors, one per slot of the backing array."""
        self.ptrs = (ctypes.c_void_p * N)()
        for i in range(N):
            # byref() with a byte offset yields the address of slot i, so the
            # library writes each new handle straight into the array.
            ptr = ctypes.byref(self.ptrs, i * ctypes.sizeof(ctypes.c_void_p))
            check_error(lib.cudnnCreateTensorDescriptor(ptr))
        self._as_parameter_ = self.ptrs

    def __del__(self):
        """Destroy every descriptor that was actually created."""
        for ptr in self.ptrs:
            # A slot that was never filled (creation failed part-way) reads
            # back as None; there is nothing to destroy for it.
            if ptr is None:
                continue
            # Iterating a c_void_p array yields plain Python ints. Without
            # argtypes, ctypes would marshal an int as a C int and truncate a
            # 64-bit handle, so re-wrap it as a pointer first.
            check_error(lib.cudnnDestroyTensorDescriptor(ctypes.c_void_p(ptr)))

    def __getitem__(self, key):
        """Return the raw handle value(s) stored at ``key``.

        Indexing a ``c_void_p`` array gives a plain Python int (``None`` for
        NULL), not a ctypes object; wrap the result in ``ctypes.c_void_p``
        before passing it to a foreign function.
        """
        return self.ptrs[key]

    def set(self, tensor):
        """Configure every descriptor in the array from ``tensor``.

        The tensor's type string, size and stride are remembered for
        ``as_tuple`` and applied identically to each descriptor.
        """
        self._type = tensor.type()
        self._size = tensor.size()
        self._stride = tensor.stride()

        # These values are the same for every descriptor; marshal them once.
        datatype = _typemap[self._type]
        ndim = tensor.dim()
        size = int_array(self._size)
        stride = int_array(self._stride)
        for ptr in self.ptrs:
            # See __del__: the handle must be passed as a pointer, not an int.
            check_error(lib.cudnnSetTensorNdDescriptor(
                ctypes.c_void_p(ptr), datatype, ndim, size, stride))

    def as_tuple(self):
        """Return ``(type, size, stride)`` as recorded by the last ``set``."""
        return (self._type, tuple(self._size), tuple(self._stride))
112152
class ConvolutionDescriptor(object):
113153
def __init__(self):
114154
ptr = ctypes.c_void_p()
@@ -144,11 +184,52 @@ def set(self, weight):
144184
self._size = weight.size()
145185
datatype = _typemap[weight.type()]
146186
check_error(lib.cudnnSetFilterNdDescriptor(
147-
self, datatype, CUDNN_TENSOR_NCHW, 4, int_array(weight.size())))
187+
self, datatype, CUDNN_TENSOR_NCHW, weight.ndimension(), int_array(weight.size())))
148188

149189
def as_tuple(self):
150190
return tuple(self._size)
151191

192+
class DropoutDescriptor(object):
    """Owns a single cuDNN dropout descriptor handle.

    The instance can be passed directly to a ctypes foreign function through
    ``_as_parameter_``.
    """

    def __init__(self):
        ptr = ctypes.c_void_p()
        check_error(lib.cudnnCreateDropoutDescriptor(ctypes.byref(ptr)))
        self._as_parameter_ = ptr

    def __del__(self):
        check_error(lib.cudnnDestroyDropoutDescriptor(self))

    @staticmethod
    def _as_ctype(ctype, value):
        """Wrap a plain Python number in ``ctype``; pass anything else through.

        Callers that already supply a ctypes object keep full control over
        how their value is marshalled.
        """
        if isinstance(value, (int, float)):
            return ctype(value)
        return value

    def set(self, handle, dropout, dropout_states, seed):
        """Configure the descriptor.

        Args:
            handle: cuDNN handle, forwarded unchanged.
            dropout: dropout probability.
            dropout_states: buffer object exposing ``data_ptr()`` and
                ``size(0)``; ``size(0)`` is forwarded as the buffer size
                (presumably a byte tensor, so that this is a size in bytes --
                TODO confirm against the caller).
            seed: seed for the random number generator.
        """
        check_error(lib.cudnnSetDropoutDescriptor(
            self,
            handle,
            # ctypes cannot convert a native Python float on its own and
            # raises ArgumentError; the C parameter is a float.
            self._as_ctype(ctypes.c_float, dropout),
            ctypes.c_void_p(dropout_states.data_ptr()),
            # Bare Python ints are marshalled as C int; the C parameters are
            # size_t and unsigned long long, so state the width explicitly.
            self._as_ctype(ctypes.c_size_t, dropout_states.size(0)),
            self._as_ctype(ctypes.c_ulonglong, seed),
        ))
210+
211+
class RNNDescriptor(object):
    """Owns a single cuDNN RNN descriptor handle.

    The handle is exposed through ``_as_parameter_`` so the instance itself
    can be used as an argument to ctypes foreign functions.
    """

    def __init__(self):
        handle = ctypes.c_void_p()
        status = lib.cudnnCreateRNNDescriptor(ctypes.byref(handle))
        check_error(status)
        self._as_parameter_ = handle

    def __del__(self):
        status = lib.cudnnDestroyRNNDescriptor(self)
        check_error(status)

    def set(self, hidden_size, num_layers, dropout_desc, input_mode,
            bidirectional, mode, datatype):
        """Configure the descriptor.

        Every argument is forwarded to ``cudnnSetRNNDescriptor`` positionally
        and unchanged, in the order given, after the descriptor itself.
        """
        call_args = (
            self, hidden_size, num_layers, dropout_desc,
            input_mode, bidirectional, mode, datatype,
        )
        check_error(lib.cudnnSetRNNDescriptor(*call_args))
232+
152233
class ConvolutionAlgoPerf(ctypes.Structure):
153234
_fields_ = [
154235
("algo", ctypes.c_int),
@@ -180,6 +261,12 @@ def get_handle():
180261
'torch.cuda.DoubleTensor': CUDNN_DATA_DOUBLE,
181262
}
182263

264+
# Maps each cuDNN data-type code to an integer -- presumably the size in
# bytes of one element of that type (half=2, float=4, double=8); TODO confirm
# against the code that reads this table.
_sizeofmap = {
265+
CUDNN_DATA_HALF : 2,
266+
CUDNN_DATA_FLOAT : 4,
267+
CUDNN_DATA_DOUBLE : 8,
268+
}
269+
183270
def c_type(tensor):
184271
if isinstance(tensor, torch.cuda.HalfTensor):
185272
return ctypes.c_float
@@ -194,8 +281,11 @@ def int_array(itr):
194281
array_type = ctypes.c_int * len(itr)
195282
return array_type(*itr)
196283

197-
def descriptor(tensor):
198-
descriptor = TensorDescriptor()
284+
def descriptor(tensor, N=None):
285+
if N:
286+
descriptor = TensorDescriptorArray(N)
287+
else:
288+
descriptor = TensorDescriptor()
199289
if tensor.dim() == 2:
200290
tensor = tensor.view(tensor.size(0), tensor.size(1), 1, 1)
201291
elif tensor.dim() == 3:

0 commit comments

Comments
 (0)