Skip to content

Commit f88c3e9

Browse files
committed
fix some missing features in pytorch needed for RNNs
1 parent 942ca47 commit f88c3e9

File tree

4 files changed

+24
-7
lines changed

4 files changed

+24
-7
lines changed

torch/autograd/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22

33
from .variable import Variable
4-
from .function import Function, NestedInputFunction
4+
from .function import Function, NestedIOFunction
55

66
assert torch._C._autograd_init()

torch/autograd/functions/tensor.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,19 @@ def backward(self, grad_output):
2828
return grad_input
2929

3030

31-
class SetValue(InplaceFunction):
31+
class SetItem(InplaceFunction):
3232

33-
def __init__(self, index, value):
34-
super(SetValue, self).__init__(True)
33+
def __init__(self, index, value=None):
34+
super(SetItem, self).__init__(True)
3535
self.index = index
3636
self.value = value
3737

38-
def forward(self, i):
38+
def forward(self, i, value=None):
3939
self.mark_dirty(i)
40-
i[self.index].fill_(self.value)
40+
if self.value is None:
41+
i[self.index].copy_(value)
42+
else:
43+
i[self.index].fill_(self.value)
4144
return i
4245

4346
def backward(self, grad_output):

torch/autograd/variable.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@ def __setitem__(self, key, value):
5656
if (isinstance(key, Variable) and
5757
type(key.data).__name__ == 'ByteTensor'):
5858
return MaskedFill(value, inplace=True)(self, key)
59-
return SetValue(key, value)(self)
59+
if isinstance(value, Variable):
60+
return SetItem(key)(self, value)
61+
return SetItem(key, value)(self)
62+
63+
def __iter__(self):
64+
return iter(map(lambda i: self[i], range(self.size(0))))
6065

6166
def __deepcopy__(self, memo):
6267
if self.creator is None:

torch/nn/modules/module.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,25 @@ def _apply(self, fn):
3737
self._buffers[key] = fn(buf)
3838
return self
3939

40+
def _deleteGradParameters(self):
41+
for param in self._parameters.values():
42+
if hasattr(param, '_grad'):
43+
param._grad = None
44+
4045
def cuda(self, device_id=None):
46+
self._deleteGradParameters()
4147
return self._apply(lambda t: t.cuda(device_id))
4248

49+
4350
def cpu(self, device_id=None):
51+
self._deleteGradParameters()
4452
return self._apply(lambda t: t.cpu())
4553

4654
def float(self):
4755
return self._apply(lambda t: t.float())
4856

4957
def double(self):
58+
self._deleteGradParameters()
5059
return self._apply(lambda t: t.double())
5160

5261
def register_backward_hook(self, name, hook):

0 commit comments

Comments
 (0)