Skip to content

Commit 0f139ab

Browse files
committed
[TEST ONLY] print statements for test_zero1.py to debug
1 parent 498bce4 commit 0f139ab

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

test/test_zero1.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
66
from torch_xla import runtime as xr
77
from torch.testing._internal.common_utils import TestCase
8+
from copy import deepcopy
89

910
import unittest
1011

@@ -34,19 +35,41 @@ def test_zero1(self):
3435

3536
opt1.step()
3637
opt2.step()
37-
self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])
38-
3938
s1 = opt1.state_dict()
4039
s2 = opt2.state_dict()
40+
print("AFTER STEPPING ONCE")
41+
print("opt1.state", opt1.state)
42+
print("opt1.state_dict()", s1)
43+
print("opt2.state[base]", opt2.state['base'])
44+
print("opt2.state_dict()[base]", s2['base'])
45+
self.assertEqual(s1, s2['base'])
46+
47+
# s1_clone = deepcopy(s1)
48+
# s2_clone = deepcopy(s2)
4149
opt1.load_state_dict(s1)
4250
opt2.load_state_dict(s2)
51+
print("AFTER LOADING THE STATE_DICTs, should be same as before")
52+
print("opt1.state", opt1.state)
53+
print("opt1.state_dict()", opt1.state_dict())
54+
print("opt2.state", opt2.state['base'])
55+
print("opt2.state_dict()[base]", opt2.state_dict()['base'])
4356
self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])
4457

4558
# step still runnable
4659
opt1.step()
4760
opt2.step()
61+
print("AFTER STEPPING AGAIN, WILL be different")
62+
print("opt1.state", opt1.state)
63+
print("opt1.state_dict()", opt1.state_dict())
64+
print("opt2.state", opt2.state['base'])
65+
print("opt2.state_dict()[base]", opt2.state_dict()['base'])
4866
opt1.load_state_dict(s1)
4967
opt2.load_state_dict(s2)
68+
print("AFTER LOADING THE STATE_DICTs, should be same as before")
69+
print("opt1.state", opt1.state)
70+
print("opt1.state_dict()", opt1.state_dict())
71+
print("opt2.state", opt2.state['base'])
72+
print("opt2.state_dict()[base]", opt2.state_dict()['base'])
5073
self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])
5174

5275
# step still runnable

torch_patches/.torch_pin

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#106082

0 commit comments

Comments
 (0)