Skip to content

Commit 6a69f70

Browse files
apaszkesoumith
authored andcommitted
Revert "add keyword out for autograd function Concat to match torch.cat (pytorch#1336)" (pytorch#1340)
This reverts commit 71b9dea.
1 parent 71b9dea commit 6a69f70

File tree

3 files changed

+4
-6
lines changed

3 files changed

+4
-6
lines changed

test/test_autograd.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,6 @@ def prod_single_zero(dim_size):
10741074
(Concat, (0,), ((1, S, S), (2, S, S), (3, S, S))),
10751075
(Concat, (-1,), ((S, S, 1), (S, S, 2), (S, S, 3)), 'negdim-1'),
10761076
(Concat, (-2,), ((S, 1, S), (S, 2, S), (S, 3, S)), 'negdim-2'),
1077-
(Concat, (0, None), ((1, S, S), (2, S, S), (3, S, S)), 'out'),
10781077
(Resize, (S * S, S), ((S, S, S),)),
10791078
(Diag, (), ((S, S),), '2d'),
10801079
(Diag, (), ((S,),), '1d'),

torch/autograd/_functions/tensor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,14 +301,13 @@ def backward(self, grad_output):
301301

302302
class Concat(Function):
303303

304-
def __init__(self, dim, out=None):
304+
def __init__(self, dim):
305305
super(Concat, self).__init__()
306306
self.dim = dim
307-
self.out = out
308307

309308
def forward(self, *inputs):
310309
self.input_sizes = [i.size(self.dim) for i in inputs]
311-
return torch.cat(inputs, self.dim, out=self.out)
310+
return torch.cat(inputs, self.dim)
312311

313312
def backward(self, grad_output):
314313
return tuple(grad_output.narrow(self.dim, end - size, size) for size, end

torch/autograd/variable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -837,8 +837,8 @@ def __hash__(self):
837837
class _torch(object):
838838

839839
@staticmethod
840-
def cat(iterable, dim=0, out=None):
841-
return Concat(dim, out)(*iterable)
840+
def cat(iterable, dim=0):
841+
return Concat(dim)(*iterable)
842842

843843
@staticmethod
844844
def normal(means, std=1):

0 commit comments

Comments
 (0)