
Commit 7deba74

gchanan authored and soumith committed
Implement MaxPool{1d,2d,3d}Backwards (non-differentiable) functions.
1 parent 48bb07a commit 7deba74

File tree: 1 file changed (+76, -14 lines)

torch/nn/_functions/thnn/pooling.py

Lines changed: 76 additions & 14 deletions
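What the commit does: each MaxPool{1d,2d,3d}.backward loses its @once_differentiable decorator and instead delegates to a new MaxPool{1d,2d,3d}Backward Function, whose forward performs the THNN gradient computation and whose backward raises, so differentiating a max-pooling gradient a second time now fails with an explicit ValueError. A minimal, self-contained sketch of the same pattern (illustrative names and the current saved_tensors API, not this file's code):

import torch
from torch.autograd import Function


class Square(Function):
    # Plays the role of MaxPool1d/2d/3d: backward delegates to a second Function
    # instead of being decorated with @once_differentiable.
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input * input

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return SquareBackward.apply(input, grad_output)


class SquareBackward(Function):
    # Plays the role of MaxPool1dBackward etc.: forward is the gradient
    # computation, backward refuses a second differentiation.
    @staticmethod
    def forward(ctx, input, grad_output):
        return 2 * input * grad_output

    @staticmethod
    def backward(ctx, ggI):
        raise ValueError("Square cannot be differentiated twice")


x = torch.randn(3, requires_grad=True)
g, = torch.autograd.grad(Square.apply(x).sum(), x, create_graph=True)
# g.sum().backward()  # raises ValueError: Square cannot be differentiated twice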
@@ -43,14 +43,28 @@ def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1,
         return output
 
     @staticmethod
-    @once_differentiable
     def backward(ctx, grad_output, _indices_grad=None):
         if ctx.return_indices:
-            input, indices = ctx.saved_tensors
+            input, indices = ctx.saved_variables
         else:
-            input, = ctx.saved_tensors
+            input, = ctx.saved_variables
             indices = ctx.indices
 
+        grad_input = MaxPool1dBackward.apply(input, indices, grad_output, ctx.kernel_size, ctx.stride, ctx.pad,
+                                             ctx.dilation, ctx.return_indices, ctx.ceil_mode)
+        return grad_input, None, None, None, None, None, None
+
+
+class MaxPool1dBackward(Function):
+
+    @staticmethod
+    def forward(ctx, input, indices, grad_output, kernel_size, stride, padding, dilation, return_indices, ceil_mode):
+        ctx.kernel_size = kernel_size
+        ctx.stride = stride
+        ctx.pad = padding
+        ctx.dilation = dilation
+        ctx.return_indices = return_indices
+        ctx.ceil_mode = ceil_mode
         input2d = input.unsqueeze(2)
         indices2d = indices.unsqueeze(2)
         grad_output2d = grad_output.unsqueeze(2)
@@ -64,7 +78,11 @@ def backward(ctx, grad_output, _indices_grad=None):
                                                          ctx.dilation, 1,
                                                          ctx.ceil_mode)
         grad_input = grad_input.squeeze(2)
-        return grad_input, None, None, None, None, None, None
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, ggI, ggIndices=None):
+        raise ValueError("MaxPool1d cannot be differentiated twice")
 
 
 class MaxPool2d(Function):
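With these two hunks applied, a first backward through MaxPool1d still produces grad_input, but the graph built through that gradient now ends in MaxPool1dBackward, whose backward raises. A hypothetical call sequence against this Function (Variable-era API; the kernel size and shapes are illustrative, and torch.autograd.grad with create_graph is assumed to be available in this build):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(2, 3, 8), requires_grad=True)
out = MaxPool1d.apply(x, 2)  # kernel_size=2, remaining arguments left at their defaults
g, = torch.autograd.grad(out.sum(), x, create_graph=True)
# g.sum().backward()  # ValueError: MaxPool1d cannot be differentiated twice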
@@ -97,13 +115,29 @@ def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1,
         return output
 
     @staticmethod
-    @once_differentiable
     def backward(ctx, grad_output, _indices_grad=None):
         if ctx.return_indices:
-            input, indices = ctx.saved_tensors
+            input, indices = ctx.saved_variables
         else:
-            input, = ctx.saved_tensors
+            input, = ctx.saved_variables
             indices = ctx.indices
+        grad_input = MaxPool2dBackward.apply(input, indices, grad_output, ctx.kernel_size, ctx.stride, ctx.padding,
+                                             ctx.dilation, ctx.return_indices, ctx.ceil_mode)
+        return grad_input, None, None, None, None, None, None
+
+
+class MaxPool2dBackward(Function):
+
+    @staticmethod
+    def forward(ctx, input, indices, grad_output, kernel_size, stride, padding, dilation,
+                return_indices, ceil_mode):
+        ctx.kernel_size = kernel_size
+        ctx.stride = stride
+        ctx.padding = padding
+        ctx.dilation = dilation
+        ctx.return_indices = return_indices
+        ctx.ceil_mode = ceil_mode
+
         grad_input = grad_output.new()
         backend = type2backend[type(input)]
         backend.SpatialDilatedMaxPooling_updateGradInput(backend.library_state,
@@ -113,7 +147,11 @@ def backward(ctx, grad_output, _indices_grad=None):
                                                          ctx.padding[1], ctx.padding[0],
                                                          ctx.dilation[1], ctx.dilation[0],
                                                          ctx.ceil_mode)
-        return grad_input, None, None, None, None, None, None
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, ggI, _ggIndices=None):
+        raise ValueError("MaxPool2d cannot be differentiated twice")
 
 
 class MaxPool3d(Function):
@@ -146,13 +184,28 @@ def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1,
         return output
 
     @staticmethod
-    @once_differentiable
     def backward(ctx, grad_output, _indices_grad=None):
         if ctx.return_indices:
-            input, indices = ctx.saved_tensors
+            input, indices = ctx.saved_variables
         else:
-            input, = ctx.saved_tensors
+            input, = ctx.saved_variables
             indices = ctx.indices
+        grad_input = MaxPool3dBackward.apply(input, indices, grad_output, ctx.kernel_size, ctx.stride,
+                                             ctx.padding, ctx.dilation, ctx.return_indices, ctx.ceil_mode)
+        return grad_input, None, None, None, None, None, None
+
+
+class MaxPool3dBackward(Function):
+
+    @staticmethod
+    def forward(ctx, input, indices, grad_output, kernel_size, stride, padding, dilation,
+                return_indices, ceil_mode):
+        ctx.kernel_size = kernel_size
+        ctx.stride = stride
+        ctx.padding = padding
+        ctx.dilation = dilation
+        ctx.return_indices = return_indices
+        ctx.ceil_mode = ceil_mode
         grad_input = grad_output.new()
         backend = type2backend[type(input)]
         backend.VolumetricDilatedMaxPooling_updateGradInput(backend.library_state,
@@ -163,7 +216,11 @@ def backward(ctx, grad_output, _indices_grad=None):
                                                             ctx.padding[0], ctx.padding[2], ctx.padding[1],
                                                             ctx.dilation[0], ctx.dilation[2], ctx.dilation[1],
                                                             ctx.ceil_mode)
-        return grad_input, None, None, None, None, None, None
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, ggI, _ggIndices=None):
+        raise ValueError("MaxPool3d cannot be differentiated twice")
 
 
 class MaxUnpool2d(Function):
@@ -382,11 +439,11 @@ def forward(ctx, input, kernel_size, stride=None):
     @staticmethod
     def backward(ctx, grad_output):
         input, = ctx.saved_variables
-        grad_input = AvgPool3dBackwards.apply(input, grad_output, ctx.kernel_size, ctx.stride)
+        grad_input = AvgPool3dBackward.apply(input, grad_output, ctx.kernel_size, ctx.stride)
         return grad_input, None, None
 
 
-class AvgPool3dBackwards(Function):
+class AvgPool3dBackward(Function):
 
     @staticmethod
     def forward(ctx, input, grad_output, kernel_size, stride):
@@ -408,6 +465,7 @@ def backward(ctx, ggI):
         ggO = AvgPool3d.apply(ggI, ctx.kernel_size, ctx.stride)
         return gI, ggO, None, None
 
+
 class AdaptiveMaxPool1d(Function):
 
     @staticmethod
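The rename above also makes the contrast visible: AvgPool3dBackward keeps a real backward (it returns gI and ggO, computing ggO by reusing AvgPool3d.apply), while the new MaxPool*Backward classes deliberately raise. A schematic of that differentiable-backward shape for a linear op, with illustrative names rather than the file's code:

from torch.autograd import Function


class Scale(Function):
    # A linear op; like average pooling, its gradient is itself a linear op.
    @staticmethod
    def forward(ctx, input, alpha):
        ctx.alpha = alpha
        return input * alpha

    @staticmethod
    def backward(ctx, grad_output):
        return ScaleBackward.apply(grad_output, ctx.alpha), None


class ScaleBackward(Function):
    @staticmethod
    def forward(ctx, grad_output, alpha):
        ctx.alpha = alpha
        return grad_output * alpha

    @staticmethod
    def backward(ctx, ggI):
        # Gradient w.r.t. grad_output (the ggO role above); alpha gets None.
        return ggI * ctx.alpha, None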
@@ -553,9 +611,13 @@ def backward(ctx, grad_output):
 _all_functions.append(AvgPool2d)
 _all_functions.append(AvgPool2dBackward)
 _all_functions.append(AvgPool3d)
+_all_functions.append(AvgPool3dBackward)
 _all_functions.append(MaxPool1d)
+_all_functions.append(MaxPool1dBackward)
 _all_functions.append(MaxPool2d)
+_all_functions.append(MaxPool2dBackward)
 _all_functions.append(MaxPool3d)
+_all_functions.append(MaxPool3dBackward)
 _all_functions.append(MaxUnpool2d)
 _all_functions.append(MaxUnpool3d)
 _all_functions.append(FractionalMaxPool2d)
