File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change 44import torch_xla .core .xla_model as xm
55from torch_xla .distributed .zero_redundancy_optimizer import ZeroRedundancyOptimizer
66from torch_xla import runtime as xr
7+ from torch .testing ._internal .common_utils import TestCase
78
89import unittest
910
1011
11- class XlaZeRO1Test (unittest . TestCase ):
12+ class XlaZeRO1Test (TestCase ):
1213
1314 @unittest .skipIf (xr .device_type () == 'TPU' , "Crash on TPU" )
1415 @unittest .skipIf (xr .device_type () == 'GPU' ,
@@ -33,20 +34,20 @@ def test_zero1(self):
3334
3435 opt1 .step ()
3536 opt2 .step ()
36- assert str (opt1 .state_dict ()) == str ( opt2 .state_dict ()['base' ])
37+ self . assertEqual (opt1 .state_dict (), opt2 .state_dict ()['base' ])
3738
3839 s1 = opt1 .state_dict ()
3940 s2 = opt2 .state_dict ()
4041 opt1 .load_state_dict (s1 )
4142 opt2 .load_state_dict (s2 )
42- assert str (opt1 .state_dict ()) == str ( opt2 .state_dict ()['base' ])
43+ self . assertEqual (opt1 .state_dict (), opt2 .state_dict ()['base' ])
4344
4445 # step still runnable
4546 opt1 .step ()
4647 opt2 .step ()
4748 opt1 .load_state_dict (s1 )
4849 opt2 .load_state_dict (s2 )
49- assert str (opt1 .state_dict ()) == str ( opt2 .state_dict ()['base' ])
50+ self . assertEqual (opt1 .state_dict (), opt2 .state_dict ()['base' ])
5051
5152 # step still runnable
5253 opt1 .step ()
You can’t perform that action at this time.
0 commit comments