Skip to content

Commit f870836

Browse files
committed
handle cuda variable gradient issue
1 parent de62e23 commit f870836

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

solver.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ def train(self):
9191
for idx, (images,labels) in enumerate(self.data_loader['train']):
9292
self.global_iter += 1
9393

94-
x = cuda(Variable(images), self.cuda)
95-
y = cuda(Variable(labels), self.cuda)
94+
x = Variable(cuda(images, self.cuda))
95+
y = Variable(cuda(labels, self.cuda))
9696
(mu, std), logit = self.toynet(x)
9797

9898
class_loss = F.cross_entropy(logit,y).div(math.log(2))
@@ -114,7 +114,7 @@ def train(self):
114114
_, avg_soft_logit = self.toynet(x,self.num_avg,softmax=True)
115115
avg_prediction = F.softmax(avg_soft_logit,dim=1).max(1)[1]
116116
avg_accuracy = torch.eq(avg_prediction,y).float().mean()
117-
else : avg_accuracy = Variable(torch.zeros(accuracy.size()))
117+
else : avg_accuracy = Variable(cuda(torch.zeros(accuracy.size()), self.cuda))
118118

119119
IZY.append(izy_bound.data)
120120
IZX.append(izx_bound.data)
@@ -165,9 +165,8 @@ def test(self):
165165
izx_bound = 0
166166
for idx, (images,labels) in enumerate(self.data_loader['test']):
167167

168-
x = cuda(Variable(images), self.cuda)
169-
y = cuda(Variable(labels), self.cuda)
170-
#(mu, std), logit = self.toynet(x)
168+
x = Variable(cuda(images, self.cuda))
169+
y = Variable(cuda(labels, self.cuda))
171170
(mu, std), logit = self.toynet_ema.model(x)
172171

173172
class_loss += F.cross_entropy(logit,y,size_average=False).div(math.log(2))
@@ -186,7 +185,7 @@ def test(self):
186185
avg_prediction = F.softmax(avg_soft_logit,dim=1).max(1)[1]
187186
avg_correct += torch.eq(avg_prediction,y).float().sum()
188187
else :
189-
avg_correct = Variable(torch.zeros(correct.size()))
188+
avg_correct = Variable(cuda(torch.zeros(correct.size()), self.cuda))
190189

191190
accuracy = correct/total_num
192191
avg_accuracy = avg_correct/total_num

0 commit comments

Comments
 (0)