File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
torch/autograd/_functions Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -179,16 +179,17 @@ class Dot(Function):
179179
180180 def forward (self , vector1 , vector2 ):
181181 self .save_for_backward (vector1 , vector2 )
182+ self .sizes = (vector1 .size (), vector2 .size ())
182183 return vector1 .new ((vector1 .dot (vector2 ),))
183184
184185 def backward (self , grad_output ):
185186 vector1 , vector2 = self .saved_tensors
186187 grad_vector1 = grad_vector2 = None
187188
188189 if self .needs_input_grad [0 ]:
189- grad_vector1 = vector2 .mul (grad_output [0 ])
190+ grad_vector1 = vector2 .mul (grad_output [0 ]). view ( self . sizes [ 0 ])
190191
191192 if self .needs_input_grad [1 ]:
192- grad_vector2 = vector1 .mul (grad_output [0 ])
193+ grad_vector2 = vector1 .mul (grad_output [0 ]). view ( self . sizes [ 1 ])
193194
194195 return grad_vector1 , grad_vector2
You can’t perform that action at this time.
0 commit comments