Skip to content

Commit 498bce4

Browse files
authored
[BE] use self.assertEquals instead of str equality in test_zero1.py (#5367)
* [BE] use self.assertEquals instead of str equality in test_zero1.py * Use our own assertEqual * Remove print statements
1 parent 4f4978c commit 498bce4

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

test/test_zero1.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import torch_xla.core.xla_model as xm
55
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
66
from torch_xla import runtime as xr
7+
from torch.testing._internal.common_utils import TestCase
78

89
import unittest
910

1011

11-
class XlaZeRO1Test(unittest.TestCase):
12+
class XlaZeRO1Test(TestCase):
1213

1314
@unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU")
1415
@unittest.skipIf(xr.device_type() == 'GPU',
@@ -33,20 +34,20 @@ def test_zero1(self):
3334

3435
opt1.step()
3536
opt2.step()
36-
assert str(opt1.state_dict()) == str(opt2.state_dict()['base'])
37+
self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])
3738

3839
s1 = opt1.state_dict()
3940
s2 = opt2.state_dict()
4041
opt1.load_state_dict(s1)
4142
opt2.load_state_dict(s2)
42-
assert str(opt1.state_dict()) == str(opt2.state_dict()['base'])
43+
self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])
4344

4445
# step still runnable
4546
opt1.step()
4647
opt2.step()
4748
opt1.load_state_dict(s1)
4849
opt2.load_state_dict(s2)
49-
assert str(opt1.state_dict()) == str(opt2.state_dict()['base'])
50+
self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])
5051

5152
# step still runnable
5253
opt1.step()

0 commit comments

Comments
 (0)