@@ -72,6 +72,13 @@ def is_acceptable(tensor):
7272CUDNN_TENSOR_NCHW = 0
7373CUDNN_TENSOR_NHWC = 1
7474
75+ CUDNN_RNN_RELU = 0
76+ CUDNN_RNN_TANH = 1
77+ CUDNN_LSTM = 2
78+ CUDNN_GRU = 3
79+
80+ CUDNN_LINEAR_INPUT = 0
81+ CUDNN_SKIP_INPUT = 1
7582
7683class CuDNNHandle :
7784 def __init__ (self ):
@@ -88,6 +95,7 @@ def __init__(self, status):
8895 msg = '{}: {}' .format (status , get_error_string (status ))
8996 super (CuDNNError , self ).__init__ (msg )
9097
98+
9199class TensorDescriptor (object ):
92100 def __init__ (self ):
93101 ptr = ctypes .c_void_p ()
@@ -109,6 +117,38 @@ def set(self, tensor):
109117 def as_tuple (self ):
110118 return (self ._type , tuple (self ._size ), tuple (self ._stride ))
111119
120+ < << << << 9 cd68129da50023929aff0ca4e4ba667ae75d785
121+ == == == =
122+
123+ class TensorDescriptorArray (object ):
124+ def __init__ (self , N ):
125+ self .ptrs = (ctypes .c_void_p * N )()
126+ for i in range (N ):
127+ ptr = ctypes .byref (self .ptrs , i * ctypes .sizeof (ctypes .c_void_p ))
128+ check_error (lib .cudnnCreateTensorDescriptor (ptr ))
129+ self ._as_parameter_ = self .ptrs
130+
131+ def __del__ (self ):
132+ for ptr in self .ptrs :
133+ check_error (lib .cudnnDestroyTensorDescriptor (ptr ))
134+
135+ def __getitem__ (self , key ):
136+ return self .ptrs [key ]
137+
138+ def set (self , tensor ):
139+ self ._type = tensor .type ()
140+ self ._size = tensor .size ()
141+ self ._stride = tensor .stride ()
142+ for ptr in self .ptrs :
143+ check_error (lib .cudnnSetTensorNdDescriptor (
144+ ptr , _typemap [tensor .type ()], tensor .dim (),
145+ int_array (tensor .size ()), int_array (tensor .stride ())))
146+
147+ def as_tuple (self ):
148+ return (self ._type , tuple (self ._size ), tuple (self ._stride ))
149+
150+
151+ > >> >> >> CUDNN RNN bindings
112152class ConvolutionDescriptor (object ):
113153 def __init__ (self ):
114154 ptr = ctypes .c_void_p ()
@@ -144,11 +184,52 @@ def set(self, weight):
144184 self ._size = weight .size ()
145185 datatype = _typemap [weight .type ()]
146186 check_error (lib .cudnnSetFilterNdDescriptor (
147- self , datatype , CUDNN_TENSOR_NCHW , 4 , int_array (weight .size ())))
187+ self , datatype , CUDNN_TENSOR_NCHW , weight . ndimension () , int_array (weight .size ())))
148188
149189 def as_tuple (self ):
150190 return tuple (self ._size )
151191
192+ class DropoutDescriptor (object ):
193+ def __init__ (self ):
194+ ptr = ctypes .c_void_p ()
195+ check_error (lib .cudnnCreateDropoutDescriptor (ctypes .byref (ptr )))
196+ self ._as_parameter_ = ptr
197+
198+ def __del__ (self ):
199+ check_error (lib .cudnnDestroyDropoutDescriptor (self ))
200+
201+ def set (self , handle , dropout , dropout_states , seed ):
202+ check_error (lib .cudnnSetDropoutDescriptor (
203+ self ,
204+ handle ,
205+ dropout ,
206+ ctypes .c_void_p (dropout_states .data_ptr ()),
207+ dropout_states .size (0 ),
208+ seed
209+ ))
210+
211+ class RNNDescriptor (object ):
212+ def __init__ (self ):
213+ ptr = ctypes .c_void_p ()
214+ check_error (lib .cudnnCreateRNNDescriptor (ctypes .byref (ptr )))
215+ self ._as_parameter_ = ptr
216+
217+ def __del__ (self ):
218+ check_error (lib .cudnnDestroyRNNDescriptor (self ))
219+
220+ def set (self , hidden_size , num_layers , dropout_desc , input_mode ,
221+ bidirectional , mode , datatype ):
222+ check_error (lib .cudnnSetRNNDescriptor (
223+ self ,
224+ hidden_size ,
225+ num_layers ,
226+ dropout_desc ,
227+ input_mode ,
228+ bidirectional ,
229+ mode ,
230+ datatype
231+ ))
232+
152233class ConvolutionAlgoPerf (ctypes .Structure ):
153234 _fields_ = [
154235 ("algo" , ctypes .c_int ),
@@ -180,6 +261,12 @@ def get_handle():
180261 'torch.cuda.DoubleTensor' : CUDNN_DATA_DOUBLE ,
181262}
182263
264+ _sizeofmap = {
265+ CUDNN_DATA_HALF : 2 ,
266+ CUDNN_DATA_FLOAT : 4 ,
267+ CUDNN_DATA_DOUBLE : 8 ,
268+ }
269+
183270def c_type (tensor ):
184271 if isinstance (tensor , torch .cuda .HalfTensor ):
185272 return ctypes .c_float
@@ -194,8 +281,11 @@ def int_array(itr):
194281 array_type = ctypes .c_int * len (itr )
195282 return array_type (* itr )
196283
197- def descriptor (tensor ):
198- descriptor = TensorDescriptor ()
284+ def descriptor (tensor , N = None ):
285+ if N :
286+ descriptor = TensorDescriptorArray (N )
287+ else :
288+ descriptor = TensorDescriptor ()
199289 if tensor .dim () == 2 :
200290 tensor = tensor .view (tensor .size (0 ), tensor .size (1 ), 1 , 1 )
201291 elif tensor .dim () == 3 :
0 commit comments