@@ -43,14 +43,28 @@ def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1,
4343 return output
4444
4545 @staticmethod
46- @once_differentiable
4746 def backward (ctx , grad_output , _indices_grad = None ):
4847 if ctx .return_indices :
49- input , indices = ctx .saved_tensors
48+ input , indices = ctx .saved_variables
5049 else :
51- input , = ctx .saved_tensors
50+ input , = ctx .saved_variables
5251 indices = ctx .indices
5352
53+ grad_input = MaxPool1dBackward .apply (input , indices , grad_output , ctx .kernel_size , ctx .stride , ctx .pad ,
54+ ctx .dilation , ctx .return_indices , ctx .ceil_mode )
55+ return grad_input , None , None , None , None , None , None
56+
57+
58+ class MaxPool1dBackward (Function ):
59+
60+ @staticmethod
61+ def forward (ctx , input , indices , grad_output , kernel_size , stride , padding , dilation , return_indices , ceil_mode ):
62+ ctx .kernel_size = kernel_size
63+ ctx .stride = stride
64+ ctx .pad = padding
65+ ctx .dilation = dilation
66+ ctx .return_indices = return_indices
67+ ctx .ceil_mode = ceil_mode
5468 input2d = input .unsqueeze (2 )
5569 indices2d = indices .unsqueeze (2 )
5670 grad_output2d = grad_output .unsqueeze (2 )
@@ -64,7 +78,11 @@ def backward(ctx, grad_output, _indices_grad=None):
6478 ctx .dilation , 1 ,
6579 ctx .ceil_mode )
6680 grad_input = grad_input .squeeze (2 )
67- return grad_input , None , None , None , None , None , None
81+ return grad_input
82+
83+ @staticmethod
84+ def backward (ctx , ggI , ggIndices = None ):
85+ raise ValueError ("MaxPool1d cannot be differentiated twice" )
6886
6987
7088class MaxPool2d (Function ):
@@ -97,13 +115,29 @@ def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1,
97115 return output
98116
99117 @staticmethod
100- @once_differentiable
101118 def backward (ctx , grad_output , _indices_grad = None ):
102119 if ctx .return_indices :
103- input , indices = ctx .saved_tensors
120+ input , indices = ctx .saved_variables
104121 else :
105- input , = ctx .saved_tensors
122+ input , = ctx .saved_variables
106123 indices = ctx .indices
124+ grad_input = MaxPool2dBackward .apply (input , indices , grad_output , ctx .kernel_size , ctx .stride , ctx .padding ,
125+ ctx .dilation , ctx .return_indices , ctx .ceil_mode )
126+ return grad_input , None , None , None , None , None , None
127+
128+
129+ class MaxPool2dBackward (Function ):
130+
131+ @staticmethod
132+ def forward (ctx , input , indices , grad_output , kernel_size , stride , padding , dilation ,
133+ return_indices , ceil_mode ):
134+ ctx .kernel_size = kernel_size
135+ ctx .stride = stride
136+ ctx .padding = padding
137+ ctx .dilation = dilation
138+ ctx .return_indices = return_indices
139+ ctx .ceil_mode = ceil_mode
140+
107141 grad_input = grad_output .new ()
108142 backend = type2backend [type (input )]
109143 backend .SpatialDilatedMaxPooling_updateGradInput (backend .library_state ,
@@ -113,7 +147,11 @@ def backward(ctx, grad_output, _indices_grad=None):
113147 ctx .padding [1 ], ctx .padding [0 ],
114148 ctx .dilation [1 ], ctx .dilation [0 ],
115149 ctx .ceil_mode )
116- return grad_input , None , None , None , None , None , None
150+ return grad_input
151+
152+ @staticmethod
153+ def backward (ctx , ggI , _ggIndices = None ):
154+ raise ValueError ("MaxPool2d cannot be differentiated twice" )
117155
118156
119157class MaxPool3d (Function ):
@@ -146,13 +184,28 @@ def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1,
146184 return output
147185
148186 @staticmethod
149- @once_differentiable
150187 def backward (ctx , grad_output , _indices_grad = None ):
151188 if ctx .return_indices :
152- input , indices = ctx .saved_tensors
189+ input , indices = ctx .saved_variables
153190 else :
154- input , = ctx .saved_tensors
191+ input , = ctx .saved_variables
155192 indices = ctx .indices
193+ grad_input = MaxPool3dBackward .apply (input , indices , grad_output , ctx .kernel_size , ctx .stride ,
194+ ctx .padding , ctx .dilation , ctx .return_indices , ctx .ceil_mode )
195+ return grad_input , None , None , None , None , None , None
196+
197+
198+ class MaxPool3dBackward (Function ):
199+
200+ @staticmethod
201+ def forward (ctx , input , indices , grad_output , kernel_size , stride , padding , dilation ,
202+ return_indices , ceil_mode ):
203+ ctx .kernel_size = kernel_size
204+ ctx .stride = stride
205+ ctx .padding = padding
206+ ctx .dilation = dilation
207+ ctx .return_indices = return_indices
208+ ctx .ceil_mode = ceil_mode
156209 grad_input = grad_output .new ()
157210 backend = type2backend [type (input )]
158211 backend .VolumetricDilatedMaxPooling_updateGradInput (backend .library_state ,
@@ -163,7 +216,11 @@ def backward(ctx, grad_output, _indices_grad=None):
163216 ctx .padding [0 ], ctx .padding [2 ], ctx .padding [1 ],
164217 ctx .dilation [0 ], ctx .dilation [2 ], ctx .dilation [1 ],
165218 ctx .ceil_mode )
166- return grad_input , None , None , None , None , None , None
219+ return grad_input
220+
221+ @staticmethod
222+ def backward (ctx , ggI , _ggIndices = None ):
223+ raise ValueError ("MaxPool3d cannot be differentiated twice" )
167224
168225
169226class MaxUnpool2d (Function ):
@@ -382,11 +439,11 @@ def forward(ctx, input, kernel_size, stride=None):
382439 @staticmethod
383440 def backward (ctx , grad_output ):
384441 input , = ctx .saved_variables
385- grad_input = AvgPool3dBackwards .apply (input , grad_output , ctx .kernel_size , ctx .stride )
442+ grad_input = AvgPool3dBackward .apply (input , grad_output , ctx .kernel_size , ctx .stride )
386443 return grad_input , None , None
387444
388445
389- class AvgPool3dBackwards (Function ):
446+ class AvgPool3dBackward (Function ):
390447
391448 @staticmethod
392449 def forward (ctx , input , grad_output , kernel_size , stride ):
@@ -408,6 +465,7 @@ def backward(ctx, ggI):
408465 ggO = AvgPool3d .apply (ggI , ctx .kernel_size , ctx .stride )
409466 return gI , ggO , None , None
410467
468+
411469class AdaptiveMaxPool1d (Function ):
412470
413471 @staticmethod
@@ -553,9 +611,13 @@ def backward(ctx, grad_output):
553611_all_functions .append (AvgPool2d )
554612_all_functions .append (AvgPool2dBackward )
555613_all_functions .append (AvgPool3d )
614+ _all_functions .append (AvgPool3dBackward )
556615_all_functions .append (MaxPool1d )
616+ _all_functions .append (MaxPool1dBackward )
557617_all_functions .append (MaxPool2d )
618+ _all_functions .append (MaxPool2dBackward )
558619_all_functions .append (MaxPool3d )
620+ _all_functions .append (MaxPool3dBackward )
559621_all_functions .append (MaxUnpool2d )
560622_all_functions .append (MaxUnpool3d )
561623_all_functions .append (FractionalMaxPool2d )
0 commit comments