 
 
 def initDropoutDescriptor(fn, handle):
-    fn.dropout_desc = cudnn.DropoutDescriptor()
+    dropout_desc = cudnn.DropoutDescriptor()
 
     dropout_states_size = ctypes.c_long()
     check_error(cudnn.lib.cudnnDropoutGetStatesSize(
         handle,
         ctypes.byref(dropout_states_size)))
 
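+    # cuDNN's dropout RNG needs a persistent chunk of GPU memory; we assume
+    # DropoutDescriptor.set keeps a reference to dropout_states so the buffer
+    # outlives this call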
-    fn.dropout_states = torch.cuda.ByteTensor(dropout_states_size.value)
-
-    fn.dropout_desc.set(
+    dropout_states = torch.cuda.ByteTensor(dropout_states_size.value)
+    dropout_desc.set(
         handle,
         fn.dropout,
-        fn.dropout_states,
+        dropout_states,
         fn.seed
     )
+    return dropout_desc
 
 
 def initRNNDescriptor(fn):
-    fn.rnn_desc = cudnn.RNNDescriptor()
+    rnn_desc = cudnn.RNNDescriptor()
 
-    fn.rnn_desc.set(
+    rnn_desc.set(
         fn.hidden_size,
         fn.num_layers,
         fn.dropout_desc,
@@ -34,30 +34,15 @@ def initRNNDescriptor(fn):
         fn.mode,
         fn.datatype
     )
+    return rnn_desc
 
 
 def initWeightDescriptor(fn, weight):
-    fn.w_desc = cudnn.FilterDescriptor()
+    w_desc = cudnn.FilterDescriptor()
     w_view = weight.view(-1, 1, 1)  # seems that filters require >=3 dimensions
-    fn.w_desc.set(w_view)
-
-
-def initIODescriptor(fn, input, output):
-    fn.x_descs = cudnn.descriptor(input[0], fn.seq_length)
-    fn.y_descs = cudnn.descriptor(output[0], fn.seq_length)
-
+    w_desc.set(w_view)
+    return w_desc
 
-def initHiddenDescriptors(fn, hx):
-    fn.hx_desc = cudnn.descriptor(hx)
-    fn.hy_desc = cudnn.descriptor(hx)
-
-
-def initCellDescriptors(fn, cx):
-    if cx:
-        fn.cx_desc = cudnn.descriptor(cx)
-        fn.cy_desc = cudnn.descriptor(cx)
-    else:
-        fn.cx_desc = fn.cy_desc = None
 
 def _inputSize(fn):
     return (fn.seq_length, fn.mini_batch, fn.input_size)
@@ -71,73 +56,88 @@ def _outputSize(fn):
     return (fn.seq_length, fn.mini_batch, fn.hidden_size * fn.num_directions)
 
 
-def getNumWeights(fn, handle):
+def getNumWeights(handle, rnn_desc, x_desc, datatype):
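+    # cudnnGetRNNParamsSize reports the packed weight size in bytes; convert
+    # it to an element count for the given datatype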
     weight_size = ctypes.c_long()
     check_error(cudnn.lib.cudnnGetRNNParamsSize(
         handle,
-        fn.rnn_desc,
-        fn.x_descs[0],
+        rnn_desc,
+        x_desc,
         ctypes.byref(weight_size),
-        fn.datatype
+        datatype
     ))
-    elem_size = cudnn._sizeofmap[fn.datatype]
+    elem_size = cudnn._sizeofmap[datatype]
     assert weight_size.value % elem_size == 0
     return weight_size.value // elem_size
 
 
-def _parametersHelper(fn, handle, weight, cudnn_method):
-    linear_params = []
+def getParameters(fn, handle, weight_buf):
+
+    param_types = {
+        'weight': cudnn.lib.cudnnGetRNNLinLayerMatrixParams,
+        'bias': cudnn.lib.cudnnGetRNNLinLayerBiasParams
+    }
+
+    # if fn.mode == cudnn.CUDNN_RNN_RELU or fn.mode == cudnn.CUDNN_RNN_TANH:
+    #     linear_name = ["ih", "hh"]
+    # elif fn.mode == cudnn.CUDNN_LSTM:
+    #     linear_name = ["ii", "if", "ic", "io", "hi", "hf", "hc", "ho"]
+    # elif fn.mode == cudnn.CUDNN_GRU:
+    #     linear_name = ["ir", "iu", "ic", "hr", "hu", "hc"]
+    # else:
+    #     raise Exception("Unknown mode: {}".format(fn.mode))
+
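+    # cuDNN packs all weights and biases for every layer into one flat buffer;
+    # for each linear layer we ask cuDNN for a pointer into that buffer and
+    # wrap it in a tensor view, yielding a per-layer list of parameter views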
+    params = []
     num_linear_layers = _numLinearLayers(fn)
     num_layers = fn.num_directions * fn.num_layers
     for layer in range(num_layers):
-        layer_info = []
-        for layer_id in range(num_linear_layers):
-            lin_layer_mat_desc = cudnn.FilterDescriptor()
-            matrix_pointer = ctypes.c_void_p()
-            check_error(cudnn_method(
-                handle,
-                fn.rnn_desc,
-                layer,
-                fn.x_descs[0],
-                fn.w_desc,
-                ctypes.c_void_p(weight.data_ptr()),
-                layer_id,
-                lin_layer_mat_desc,
-                ctypes.byref(matrix_pointer)))
-
-            data_type = ctypes.c_int()
-            format = ctypes.c_int()
-            nb_dims = ctypes.c_int()
-            min_dim = 3
-            filter_dim_a = torch.IntStorage(min_dim)
-            check_error(cudnn.lib.cudnnGetFilterNdDescriptor(
-                lin_layer_mat_desc,
-                min_dim,
-                ctypes.byref(data_type),
-                ctypes.byref(format),
-                ctypes.byref(nb_dims),
-                ctypes.c_void_p(filter_dim_a.data_ptr())))
-
-            filter_dim_a.resize_(nb_dims.value)
-            elem_size = cudnn._sizeofmap[fn.datatype]
-            offset_bytes = (matrix_pointer.value - weight.data_ptr())
-            assert (offset_bytes % elem_size == 0)
-            offset = offset_bytes // elem_size
-            params = weight.new().set_(weight.storage(), offset, filter_dim_a.long())
-            layer_info.append(params)
-
-        linear_params.append(layer_info)
-
-    return linear_params
-
-def parameters(fn, handle, weight):
-    parameters = {}
-    parameters['weight'] = _parametersHelper(
-        fn, handle, weight, cudnn.lib.cudnnGetRNNLinLayerMatrixParams)
-    parameters['bias'] = _parametersHelper(
-        fn, handle, weight, cudnn.lib.cudnnGetRNNLinLayerBiasParams)
-    return parameters
-
+        layer_params = []
+        for param_type, cudnn_method in param_types.items():
+            for layer_id in range(num_linear_layers):
+                lin_layer_mat_desc = cudnn.FilterDescriptor()
+                matrix_pointer = ctypes.c_void_p()
+                check_error(cudnn_method(
+                    handle,
+                    fn.rnn_desc,
+                    layer,
+                    fn.x_descs[0],
+                    fn.w_desc,
+                    ctypes.c_void_p(weight_buf.data_ptr()),
+                    layer_id,
+                    lin_layer_mat_desc,
+                    ctypes.byref(matrix_pointer)))
+
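+                # query the filter descriptor to recover this parameter's shape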
+                data_type = ctypes.c_int()
+                format = ctypes.c_int()
+                nb_dims = ctypes.c_int()
+                min_dim = 3
+                filter_dim_a = torch.IntStorage(min_dim)
+                check_error(cudnn.lib.cudnnGetFilterNdDescriptor(
+                    lin_layer_mat_desc,
+                    min_dim,
+                    ctypes.byref(data_type),
+                    ctypes.byref(format),
+                    ctypes.byref(nb_dims),
+                    ctypes.c_void_p(filter_dim_a.data_ptr())))
+
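+                # convert the raw pointer cuDNN returned into an element offset
+                # so the parameter can be exposed as a view of weight_buf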
+                filter_dim_a.resize_(nb_dims.value)
+                elem_size = cudnn._sizeofmap[fn.datatype]
+                offset_bytes = matrix_pointer.value - weight_buf.data_ptr()
+                assert offset_bytes % elem_size == 0
+                offset = offset_bytes // elem_size
+                param = weight_buf.new().set_(weight_buf.storage(), offset, filter_dim_a.long())
+                # name = "l{}.{}.{}".format(layer, linear_name[layer_id], param_type)
+                layer_params.append(param)
+
+        params.append(layer_params)
+
+    return params
+
+
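+# Copy parameters between two per-layer lists of tensors, e.g. the user's
+# weights into the views of the packed cuDNN buffer (or gradients back out).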
+def _copyParams(params_from, params_to):
+    for layer_params_from, layer_params_to in zip(params_from, params_to):
+        for param_from, param_to in zip(layer_params_from, layer_params_to):
+            assert param_from.type() == param_to.type()
+            param_to.copy_(param_from)
 
 def forward(fn, input, hx, cx, weight, output, hy, cy):
     with torch.cuda.device_of(input):
@@ -167,30 +167,35 @@ def forward(fn, input, hx, cx, weight, output, hy, cy):
         if cy:
             cy.resize_(*hidden_size).zero_()
         y = output
-        w = weight
 
-        initDropoutDescriptor(fn, handle)
-        initRNNDescriptor(fn)
-        initIODescriptor(fn, x, y)
-        initHiddenDescriptors(fn, hx)
-        initCellDescriptors(fn, cx)
-        initWeightDescriptor(fn, weight)
+        # init descriptors
+        fn.dropout_desc = initDropoutDescriptor(fn, handle)
+        fn.rnn_desc = initRNNDescriptor(fn)
+        fn.x_descs = cudnn.descriptor(x[0], fn.seq_length)
+        fn.y_descs = cudnn.descriptor(y[0], fn.seq_length)
+        fn.hx_desc = cudnn.descriptor(hx)
+        fn.hy_desc = cudnn.descriptor(hx)
+        fn.cx_desc = cudnn.descriptor(cx) if cx else None
+        fn.cy_desc = cudnn.descriptor(cx) if cx else None
+
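+        # cuDNN wants all weights in one contiguous buffer; allocate it and
+        # describe it with a single filter descriptor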
+        num_weights = getNumWeights(handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
+        fn.weight_buf = input.new(num_weights)
+        fn.w_desc = initWeightDescriptor(fn, fn.weight_buf)
+        w = fn.weight_buf
 
-        params = parameters(fn, handle, weight)
-        for ptype in params:
-            for l, layer in enumerate(params[ptype]):
-                for ll, param in enumerate(layer):
-                    print(ptype, l, ll, param.storage_offset(), tuple(param.size()))
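+        # copy the user-supplied parameters into the packed cuDNN buffer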
+        params = getParameters(fn, handle, fn.weight_buf)
+        _copyParams(weight, params)
 
         if tuple(hx.size()) != hidden_size:
             raise Exception('Expected hx size {}, got {}'.format(
                 hidden_size, tuple(hx.size())))
         if cx and tuple(cx.size()) != hidden_size:
             raise Exception('Expected cx size {}, got {}'.format(
                 hidden_size, tuple(cx.size())))
-        if weight.nelement() != getNumWeights(fn, handle):
+        expected_num_weights = getNumWeights(handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
+        if weight.nelement() != expected_num_weights:
             raise Exception("Expected #weights {}, got {}".format(
-                getNumWeights(fn, handle), weight.nelement()))
+                expected_num_weights, weight.nelement()))
         workspace_size = ctypes.c_long()
         check_error(lib.cudnnGetRNNWorkspaceSize(
             handle,
@@ -259,7 +264,7 @@ def backward_grad(fn, input, hx, cx, weight, output, grad_output, grad_hy, grad_
         x = input.contiguous()
         dy = grad_output.contiguous()
         y = output
-        w = weight
+        w = fn.weight_buf
         dx = grad_input.resize_as_(input)
         dhy = grad_hy.resize_(*hidden_size)
         dcy = grad_cy.resize_(*hidden_size) if grad_cy else None
@@ -351,7 +356,7 @@ def backward_weight(fn, input, hx, output, weight, grad_weight):
 
         x = input.contiguous()
         y = output
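+        # cudnnRNNBackwardWeights accumulates into dw, so it must start zeroed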
-        dw = grad_weight.resize_as_(weight)
+        dw = fn.weight_buf.new().resize_as_(fn.weight_buf).zero_()
 
         check_error(cudnn.lib.cudnnRNNBackwardWeights(
             handle,
@@ -365,4 +370,7 @@ def backward_weight(fn, input, hx, output, weight, grad_weight):
             ctypes.c_void_p(fn.reserve.data_ptr()), fn.reserve.size(0)
         ))
 
-        return dw
+        grad_params = getParameters(fn, handle, dw)
+        _copyParams(grad_params, grad_weight)
+
+        return grad_weight