Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions test/test_zero1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
from torch_xla import runtime as xr
from torch.testing._internal.common_utils import TestCase
from copy import deepcopy

import unittest

Expand Down Expand Up @@ -34,18 +35,22 @@ def test_zero1(self):

opt1.step()
opt2.step()
self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])

s1 = opt1.state_dict()
s2 = opt2.state_dict()
self.assertEqual(s1, s2['base'])

# deepcopy s1 to load later because pytorch optimizers do not guarantee the input
# state_dict will not be modified. on the other hand, s2 has this guarantee.
s1_clone = deepcopy(s1)

opt1.load_state_dict(s1)
opt2.load_state_dict(s2)
self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])

# step still runnable
opt1.step()
opt2.step()
opt1.load_state_dict(s1)
opt1.load_state_dict(s1_clone)
opt2.load_state_dict(s2)
self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])

Expand Down