@@ -91,8 +91,8 @@ def train(self):
9191 for idx , (images ,labels ) in enumerate (self .data_loader ['train' ]):
9292 self .global_iter += 1
9393
94- x = cuda ( Variable (images ) , self .cuda )
95- y = cuda ( Variable (labels ) , self .cuda )
94+ x = Variable (cuda ( images , self .cuda ) )
95+ y = Variable (cuda ( labels , self .cuda ) )
9696 (mu , std ), logit = self .toynet (x )
9797
9898 class_loss = F .cross_entropy (logit ,y ).div (math .log (2 ))
@@ -114,7 +114,7 @@ def train(self):
114114 _ , avg_soft_logit = self .toynet (x ,self .num_avg ,softmax = True )
115115 avg_prediction = F .softmax (avg_soft_logit ,dim = 1 ).max (1 )[1 ]
116116 avg_accuracy = torch .eq (avg_prediction ,y ).float ().mean ()
117- else : avg_accuracy = Variable (torch .zeros (accuracy .size ()))
117+ else : avg_accuracy = Variable (cuda ( torch .zeros (accuracy .size ()), self . cuda ))
118118
119119 IZY .append (izy_bound .data )
120120 IZX .append (izx_bound .data )
@@ -165,9 +165,8 @@ def test(self):
165165 izx_bound = 0
166166 for idx , (images ,labels ) in enumerate (self .data_loader ['test' ]):
167167
168- x = cuda (Variable (images ), self .cuda )
169- y = cuda (Variable (labels ), self .cuda )
170- #(mu, std), logit = self.toynet(x)
168+ x = Variable (cuda (images , self .cuda ))
169+ y = Variable (cuda (labels , self .cuda ))
171170 (mu , std ), logit = self .toynet_ema .model (x )
172171
173172 class_loss += F .cross_entropy (logit ,y ,size_average = False ).div (math .log (2 ))
@@ -186,7 +185,7 @@ def test(self):
186185 avg_prediction = F .softmax (avg_soft_logit ,dim = 1 ).max (1 )[1 ]
187186 avg_correct += torch .eq (avg_prediction ,y ).float ().sum ()
188187 else :
189- avg_correct = Variable (torch .zeros (correct .size ()))
188+ avg_correct = Variable (cuda ( torch .zeros (correct .size ()), self . cuda ))
190189
191190 accuracy = correct / total_num
192191 avg_accuracy = avg_correct / total_num
0 commit comments