Skip to content

Commit 7861f58

Browse files
apaszkesoumith
authored andcommitted
Reshape grad in dot
1 parent 274b5c9 commit 7861f58

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torch/autograd/_functions/blas.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,16 +179,17 @@ class Dot(Function):
179179

180180
def forward(self, vector1, vector2):
181181
self.save_for_backward(vector1, vector2)
182+
self.sizes = (vector1.size(), vector2.size())
182183
return vector1.new((vector1.dot(vector2),))
183184

184185
def backward(self, grad_output):
185186
vector1, vector2 = self.saved_tensors
186187
grad_vector1 = grad_vector2 = None
187188

188189
if self.needs_input_grad[0]:
189-
grad_vector1 = vector2.mul(grad_output[0])
190+
grad_vector1 = vector2.mul(grad_output[0]).view(self.sizes[0])
190191

191192
if self.needs_input_grad[1]:
192-
grad_vector2 = vector1.mul(grad_output[0])
193+
grad_vector2 = vector1.mul(grad_output[0]).view(self.sizes[1])
193194

194195
return grad_vector1, grad_vector2

0 commit comments

Comments
 (0)