
Commit b0e33fb

cudnn + THNN match with parameters
1 parent d58b627 commit b0e33fb

4 files changed: +245 −126 lines changed

torch/backends/cudnn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -199,6 +199,7 @@ def __del__(self):
         check_error(lib.cudnnDestroyDropoutDescriptor(self))
 
     def set(self, handle, dropout, dropout_states, seed):
+        self.dropout_states = dropout_states  # make sure it's retained
         check_error(lib.cudnnSetDropoutDescriptor(
             self,
             handle,
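
The single added line addresses an object-lifetime hazard: the descriptor hands cuDNN a raw pointer into dropout_states, and ctypes does not keep the Python object alive on its own, hence the "make sure it's retained" comment. A minimal sketch of the retention pattern, using a hypothetical stand-in class rather than the real cudnn bindings:

import ctypes
from array import array

class DescriptorSketch(object):
    """Hypothetical stand-in for cudnn.DropoutDescriptor (illustration only)."""

    def set(self, states):
        # Retain a Python reference: the C library only ever sees the raw
        # address below, which does not keep `states` alive, so without
        # this assignment the buffer could be garbage-collected while the
        # descriptor still points into it.
        self.states = states
        raw_ptr = ctypes.c_void_p(states.buffer_info()[0])
        return raw_ptr  # the real code would pass this to the C API

desc = DescriptorSketch()
desc.set(array('b', [0] * 16))  # the buffer now lives as long as `desc`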

torch/backends/cudnn/rnn.py

Lines changed: 104 additions & 96 deletions
@@ -5,27 +5,27 @@
 
 
 def initDropoutDescriptor(fn, handle):
-    fn.dropout_desc = cudnn.DropoutDescriptor()
+    dropout_desc = cudnn.DropoutDescriptor()
 
     dropout_states_size = ctypes.c_long()
     check_error(cudnn.lib.cudnnDropoutGetStatesSize(
         handle,
         ctypes.byref(dropout_states_size)))
 
-    fn.dropout_states = torch.cuda.ByteTensor(dropout_states_size.value)
-
+    dropout_states = torch.cuda.ByteTensor(dropout_states_size.value)
     fn.dropout_desc.set(
         handle,
         fn.dropout,
-        fn.dropout_states,
+        dropout_states,
         fn.seed
     )
+    return dropout_desc
 
 
 def initRNNDescriptor(fn):
-    fn.rnn_desc = cudnn.RNNDescriptor()
+    rnn_desc = cudnn.RNNDescriptor()
 
-    fn.rnn_desc.set(
+    rnn_desc.set(
         fn.hidden_size,
         fn.num_layers,
         fn.dropout_desc,
@@ -34,30 +34,15 @@ def initRNNDescriptor(fn):
         fn.mode,
         fn.datatype
     )
+    return rnn_desc
 
 
 def initWeightDescriptor(fn, weight):
-    fn.w_desc = cudnn.FilterDescriptor()
+    w_desc = cudnn.FilterDescriptor()
     w_view = weight.view(-1, 1, 1)  # seems that filters require >=3 dimensions
-    fn.w_desc.set(w_view)
-
-
-def initIODescriptor(fn, input, output):
-    fn.x_descs = cudnn.descriptor(input[0], fn.seq_length)
-    fn.y_descs = cudnn.descriptor(output[0], fn.seq_length)
-
+    w_desc.set(w_view)
+    return w_desc
 
-def initHiddenDescriptors(fn, hx):
-    fn.hx_desc = cudnn.descriptor(hx)
-    fn.hy_desc = cudnn.descriptor(hx)
-
-
-def initCellDescriptors(fn, cx):
-    if cx:
-        fn.cx_desc = cudnn.descriptor(cx)
-        fn.cy_desc = cudnn.descriptor(cx)
-    else:
-        fn.cx_desc = fn.cy_desc = None
 
 def _inputSize(fn):
     return (fn.seq_length, fn.mini_batch, fn.input_size)
@@ -71,73 +56,88 @@ def _outputSize(fn):
     return (fn.seq_length, fn.mini_batch, fn.hidden_size * fn.num_directions)
 
 
-def getNumWeights(fn, handle):
+def getNumWeights(handle, rnn_desc, x_desc, datatype):
     weight_size = ctypes.c_long()
     check_error(cudnn.lib.cudnnGetRNNParamsSize(
         handle,
-        fn.rnn_desc,
-        fn.x_descs[0],
+        rnn_desc,
+        x_desc,
         ctypes.byref(weight_size),
-        fn.datatype
+        datatype
     ))
     elem_size = cudnn._sizeofmap[fn.datatype]
     assert(weight_size.value % elem_size == 0)
     return weight_size.value // elem_size
 
 
-def _parametersHelper(fn, handle, weight, cudnn_method):
-    linear_params = []
+def getParameters(fn, handle, weight_buf):
+
+    param_types = {
+        'weight': cudnn.lib.cudnnGetRNNLinLayerMatrixParams,
+        'bias'  : cudnn.lib.cudnnGetRNNLinLayerBiasParams
+    }
+
+    # if fn.mode == cudnn.CUDNN_RNN_RELU or fn.mode == cudnn.CUDNN_RNN_TANH:
+    #     linear_name = ["ih", "hh"]
+    # elif fn.mode == cudnn.CUDNN_LSTM:
+    #     linear_name = ["ii", "if", "ic", "io", "hi", "hf", "hc", "ho"]
+    # elif fn.mode == cudnn.CUDNN_GRU:
+    #     linear_name = ["ir", "iu", "ic", "hr", "hu", "hc"]
+    # else:
+    #     raise Exception("Unknown mode: {}".format(fn.mode))
+
+    params = []
     num_linear_layers = _numLinearLayers(fn)
     num_layers = fn.num_directions * fn.num_layers
     for layer in range(num_layers):
-        layer_info = []
-        for layer_id in range(num_linear_layers):
-            lin_layer_mat_desc = cudnn.FilterDescriptor()
-            matrix_pointer = ctypes.c_void_p()
-            check_error(cudnn_method(
-                handle,
-                fn.rnn_desc,
-                layer,
-                fn.x_descs[0],
-                fn.w_desc,
-                ctypes.c_void_p(weight.data_ptr()),
-                layer_id,
-                lin_layer_mat_desc,
-                ctypes.byref(matrix_pointer)))
-
-            data_type = ctypes.c_int()
-            format = ctypes.c_int()
-            nb_dims = ctypes.c_int()
-            min_dim = 3
-            filter_dim_a = torch.IntStorage(min_dim)
-            check_error(cudnn.lib.cudnnGetFilterNdDescriptor(
-                lin_layer_mat_desc,
-                min_dim,
-                ctypes.byref(data_type),
-                ctypes.byref(format),
-                ctypes.byref(nb_dims),
-                ctypes.c_void_p(filter_dim_a.data_ptr())))
-
-            filter_dim_a.resize_(nb_dims.value)
-            elem_size = cudnn._sizeofmap[fn.datatype]
-            offset_bytes = (matrix_pointer.value - weight.data_ptr())
-            assert(offset_bytes % elem_size == 0)
-            offset = offset_bytes // elem_size
-            params = weight.new().set_(weight.storage(), offset, filter_dim_a.long())
-            layer_info.append(params)
-
-        linear_params.append(layer_info)
-
-    return linear_params
-
-
-def parameters(fn, handle, weight):
-    parameters = {}
-    parameters['weight'] = _parametersHelper(
-        fn, handle, weight, cudnn.lib.cudnnGetRNNLinLayerMatrixParams)
-    parameters['bias'] = _parametersHelper(
-        fn, handle, weight, cudnn.lib.cudnnGetRNNLinLayerBiasParams)
-    return parameters
-
+        layer_params = []
+        for param_type, cudnn_method in enumerate(param_types):
+            for layer_id in range(num_linear_layers):
+                lin_layer_mat_desc = cudnn.FilterDescriptor()
+                matrix_pointer = ctypes.c_void_p()
+                check_error(cudnn_method(
+                    handle,
+                    fn.rnn_desc,
+                    layer,
+                    fn.x_desc,
+                    fn.w_desc,
+                    ctypes.c_void_p(weight_buf.data_ptr()),
+                    layer_id,
+                    lin_layer_mat_desc,
+                    ctypes.byref(matrix_pointer)))
+
+                data_type = ctypes.c_int()
+                format = ctypes.c_int()
+                nb_dims = ctypes.c_int()
+                min_dim = 3
+                filter_dim_a = torch.IntStorage(min_dim)
+                check_error(cudnn.lib.cudnnGetFilterNdDescriptor(
+                    lin_layer_mat_desc,
+                    min_dim,
+                    ctypes.byref(data_type),
+                    ctypes.byref(format),
+                    ctypes.byref(nb_dims),
+                    ctypes.c_void_p(filter_dim_a.data_ptr())))
+
+                filter_dim_a.resize_(nb_dims.value)
+                elem_size = cudnn._sizeofmap[fn.datatype]
+                offset_bytes = (matrix_pointer.value - weight.data_ptr())
+                assert(offset_bytes % elem_size == 0)
+                offset = offset_bytes // elem_size
+                param = fn.weight_buf.new().set_(weight.storage(), offset, filter_dim_a.long())
+                # name = "l{}.{}.{}".format(layer, linear_name[layer_id], param_type)
+                layer_params.append(param)
+
+        params.append(layer_params)
+
+    return params
+
+
+def _copyParams(params_from, params_to):
+    for layer_params_from, layer_weights_from in zip(params_from, params_to):
+        for param_from, param_to in zip(layer_params_from, layer_params_to):
+            assert(param_from.type() == tuple(param_to.type()))
+            param_to.copy(param_from)
 
 def forward(fn, input, hx, cx, weight, output, hy, cy):
     with torch.cuda.device_of(input):
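
A note on the new parameter plumbing: getParameters returns one list of tensor views per layer, each aliasing a slice of the flat cuDNN weight buffer, and _copyParams is meant to copy THNN-style weights into those views so the two backends match. As committed, _copyParams zips into a name (layer_params_to) it never defines and calls copy rather than the in-place copy_; a sketch of the apparent intent over plain nested lists of tensors (not the author's final code):

import torch

def copy_params_sketch(params_from, params_to):
    # Walk two parallel [layer][param] nested lists, copying values
    # in place so the flat cuDNN buffer picks up the THNN weights.
    for layer_from, layer_to in zip(params_from, params_to):
        for param_from, param_to in zip(layer_from, layer_to):
            assert param_from.type() == param_to.type()
            param_to.copy_(param_from)  # in-place: writes through the view

src = [[torch.ones(2, 3)]]
dst = [[torch.zeros(2, 3)]]
copy_params_sketch(src, dst)
assert bool(dst[0][0].eq(1).all())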
@@ -167,30 +167,35 @@ def forward(fn, input, hx, cx, weight, output, hy, cy):
         if cy:
             cy.resize_(*hidden_size).zero_()
         y = output
-        w = weight
 
-        initDropoutDescriptor(fn, handle)
-        initRNNDescriptor(fn)
-        initIODescriptor(fn, x, y)
-        initHiddenDescriptors(fn, hx)
-        initCellDescriptors(fn, cx)
-        initWeightDescriptor(fn, weight)
+        # init descriptors
+        fn.dropout_desc = initDropoutDescriptor(fn, handle)
+        fn.rnn_desc = initRNNDescriptor(fn)
+        fn.x_descs = cudnn.descriptor(x[0], fn.seq_length)
+        fn.y_descs = cudnn.descriptor(y[0], fn.seq_length)
+        fn.hx_desc = cudnn.descriptor(hx)
+        fn.hy_desc = cudnn.descriptor(hx)
+        fn.cx_desc = cudnn.descriptor(cx) if cx else None
+        fn.cy_desc = cudnn.descriptor(cx) if cx else None
+
+        num_weights = getNumWeights(handle, fn.rnn_desc, fn.x_desc, fn.datatype)
+        fn.weight_buf = input.new(num_weights)
+        fn.w_desc = initWeightDescriptor(fn, fn.weight_buf)
+        w = fn.weight_buf
 
-        params = parameters(fn, handle, weight)
-        for ptype in params:
-            for l, layer in enumerate(params[ptype]):
-                for ll, param in enumerate(layer):
-                    print(ptype, l, ll, param.storage_offset(), tuple(param.size()))
+        params = getParameters(fn, handle, weight)
+        _copyParams(weight, params)
 
         if tuple(hx.size()) != hidden_size:
             raise Exception('Expected hx size {}, got {}'.format(
                 tuple(hidden_size, hx.size())))
         if cx and tuple(cx.size()) != hidden_size:
             raise Exception('Expected cx size {}, got {}'.format(
                 tuple(hidden_size, cx.size())))
-        if weight.nelement() != getNumWeights(fn, handle):
+        expected_num_weights = getNumWeights(handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
+        if weight.nelement() != expected_num_weights:
             raise Exception("Expected #weights {}, got {}".format(
-                getNumWeights(fn, handle), weight.nelement()))
+                expected_num_weights, weight.nelement()))
         workspace_size = ctypes.c_long()
         check_error(lib.cudnnGetRNNWorkspaceSize(
             handle,
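
The forward pass now sizes one flat weight_buf from cudnnGetRNNParamsSize and exposes each matrix and bias as a view built with new().set_(storage, offset, size); because the views share storage with the buffer, copying user weights into them is what makes cuDNN execute with those values. A small self-contained sketch of that storage-sharing trick (made-up offset and shape, not the actual cuDNN layout):

import torch

weight_buf = torch.zeros(12)            # stands in for the flat cuDNN buffer
offset, shape = 4, torch.Size([2, 3])   # as if reported for one weight matrix

# Build a view over six elements of the buffer starting at `offset`.
param = weight_buf.new().set_(weight_buf.storage(), offset, shape)

param.fill_(7)            # writing the view...
print(weight_buf[4:10])   # ...mutates the flat buffer cuDNN would read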
@@ -259,7 +264,7 @@ def backward_grad(fn, input, hx, cx, weight, output, grad_output, grad_hy, grad_
         x = input.contiguous()
         dy = grad_output.contiguous()
         y = output
-        w = weight
+        w = fn.weight_buf
         dx = grad_input.resize_as_(input)
         dhy = grad_hy.resize_(*hidden_size)
         dcy = grad_cy.resize_(*hidden_size) if grad_cy else None
@@ -351,7 +356,7 @@ def backward_weight(fn, input, hx, output, weight, grad_weight):
 
         x = input.contiguous()
         y = output
-        dw = grad_weight.resize_as_(weight)
+        dw = fn.weight_buf.new().resize_as_(weight_buf)
 
         check_error(cudnn.lib.cudnnRNNBackwardWeights(
             handle,
@@ -365,4 +370,7 @@ def backward_weight(fn, input, hx, output, weight, grad_weight):
             ctypes.c_void_p(fn.reserve.data_ptr()), fn.reserve.size(0)
         ))
 
-        return dw
+        params = getParameters(fn, handle, dw)
+        _copyParams(grad_params, grad_weight)
+
+        return grad_weight
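
The tail of backward_weight mirrors the forward copy in reverse: cuDNN accumulates all gradients into the flat dw buffer, getParameters slices it into per-parameter views, and a copy moves each slice out into the THNN-shaped grad_weight tensors (as committed, the copy is invoked with an undefined grad_params; params appears intended). A sketch of that scatter-out step under those assumptions, with hypothetical offsets and shapes:

import torch

def scatter_grads_sketch(flat_grad, slices, grad_weight):
    # Slice the flat gradient buffer the way getParameters slices
    # weight_buf, then copy each slice into its THNN-shaped tensor.
    for (offset, shape), g_out in zip(slices, grad_weight):
        view = flat_grad.new().set_(flat_grad.storage(), offset, shape)
        g_out.copy_(view)

flat = torch.arange(0, 6, dtype=torch.float)   # pretend cuDNN wrote this
outs = [torch.zeros(2), torch.zeros(2, 2)]
scatter_grads_sketch(flat, [(0, torch.Size([2])), (2, torch.Size([2, 2]))], outs)
print(outs[1])   # tensor([[2., 3.], [4., 5.]])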
