|
5 | 5 | from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer |
6 | 6 | from torch_xla import runtime as xr |
7 | 7 | from torch.testing._internal.common_utils import TestCase |
| 8 | +from copy import deepcopy |
8 | 9 |
|
9 | 10 | import unittest |
10 | 11 |
|
@@ -34,19 +35,41 @@ def test_zero1(self): |
34 | 35 |
|
35 | 36 | opt1.step() |
36 | 37 | opt2.step() |
37 | | - self.assertEqual(opt1.state_dict(), opt2.state_dict()['base']) |
38 | | - |
39 | 38 | s1 = opt1.state_dict() |
40 | 39 | s2 = opt2.state_dict() |
| 40 | + print("AFTER STEPPING ONCE") |
| 41 | + print("opt1.state", opt1.state) |
| 42 | + print("opt1.state_dict()", s1) |
| 43 | + print("opt2.state[base]", opt2.state['base']) |
| 44 | + print("opt2.state_dict()[base]", s2['base']) |
| 45 | + self.assertEqual(s1, s2['base']) |
| 46 | + |
| 47 | + # s1_clone = deepcopy(s1) |
| 48 | + # s2_clone = deepcopy(s2) |
41 | 49 | opt1.load_state_dict(s1) |
42 | 50 | opt2.load_state_dict(s2) |
| 51 | + print("AFTER LOADING THE STATE_DICTs, should be same as before") |
| 52 | + print("opt1.state", opt1.state) |
| 53 | + print("opt1.state_dict()", opt1.state_dict()) |
| 54 | + print("opt2.state", opt2.state['base']) |
| 55 | + print("opt2.state_dict()[base]", opt2.state_dict()['base']) |
43 | 56 | self.assertEqual(opt1.state_dict(), opt2.state_dict()['base']) |
44 | 57 |
|
45 | 58 | # step still runnable |
46 | 59 | opt1.step() |
47 | 60 | opt2.step() |
| 61 | + print("AFTER STEPPING AGAIN, WILL be different") |
| 62 | + print("opt1.state", opt1.state) |
| 63 | + print("opt1.state_dict()", opt1.state_dict()) |
| 64 | + print("opt2.state", opt2.state['base']) |
| 65 | + print("opt2.state_dict()[base]", opt2.state_dict()['base']) |
48 | 66 | opt1.load_state_dict(s1) |
49 | 67 | opt2.load_state_dict(s2) |
| 68 | + print("AFTER LOADING THE STATE_DICTs, should be same as before") |
| 69 | + print("opt1.state", opt1.state) |
| 70 | + print("opt1.state_dict()", opt1.state_dict()) |
| 71 | + print("opt2.state", opt2.state['base']) |
| 72 | + print("opt2.state_dict()[base]", opt2.state_dict()['base']) |
50 | 73 | self.assertEqual(opt1.state_dict(), opt2.state_dict()['base']) |
51 | 74 |
|
52 | 75 | # step still runnable |
|
0 commit comments