@@ -194,14 +194,34 @@ def test_volatile(self):
194194
195195 def test_indexing (self ):
196196 x = torch .range (1 , 16 ).resize_ (4 , 4 )
197- y = Variable (x )
198- self .assertEqual (x [1 ], y [1 ].data )
199- self .assertEqual (x [1 , 1 ], y [1 , 1 ].data [0 ])
200- self .assertEqual (x [1 :], y [1 :].data )
201- self .assertEqual (x [:2 ], y [:2 ].data )
202- self .assertEqual (x [:2 , 2 ], y [:2 , 2 ].data )
203- self .assertEqual (x [1 :2 , 2 ], y [1 :2 , 2 ].data )
204- self .assertEqual (x [1 , 2 :], y [1 , 2 :].data )
197+ y = Variable (x , requires_grad = True )
198+
199+ def check_index (idx ):
200+ y .grad .data .zero_ ()
201+ indexed_tensor = x [idx ]
202+ indexed_var = y [idx ]
203+
204+ indexed_var_t = indexed_var .data
205+ if not torch .is_tensor (indexed_tensor ):
206+ indexed_var_t = indexed_var_t [0 ]
207+ self .assertEqual (indexed_tensor , indexed_var )
208+
209+ indexed_var .sum ().backward ()
210+ expected_grad = torch .zeros (4 , 4 )
211+ expected_grad [idx ] = 1
212+ self .assertEqual (y .grad .data , expected_grad )
213+
214+ check_index (1 )
215+ check_index ((1 , 1 ))
216+ check_index (slice (1 , None ))
217+ check_index (slice (None , 2 ))
218+ check_index ((slice (None , 2 ), 2 ))
219+ check_index ((slice (1 , 2 ), 2 ))
220+ check_index ((1 , slice (2 , None )))
221+ check_index ((slice (None , None ), slice (2 , None )))
222+ check_index (torch .LongTensor ([0 , 2 ]))
223+ check_index (torch .rand (4 , 4 ).bernoulli ().byte ())
224+ check_index ((Ellipsis , slice (2 , None )))
205225
206226 def test_requires_grad (self ):
207227 x = Variable (torch .randn (5 , 5 ))
0 commit comments