Skip to content

Commit ba5b225

Browse files
authored
Add watcher test. (#146)
1 parent b58637d commit ba5b225

File tree

3 files changed

+103
-38
lines changed

3 files changed

+103
-38
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import time
15+
import unittest
16+
from edl.tests.unittests import etcd_test_base
17+
from edl.utils import cluster as edl_cluster
18+
from edl.utils import constants
19+
from edl.utils import exceptions
20+
from edl.utils import cluster_watcher
21+
from edl.utils import pod as edl_pod
22+
23+
24+
class TestWatcher(etcd_test_base.EtcdTestBase):
    """Checks that ``cluster_watcher.Watcher`` reacts to cluster updates in etcd.

    ``self._etcd`` and ``self._job_env`` are not created in this class;
    presumably ``EtcdTestBase.setUp`` provides them -- confirm against the
    base class.
    """

    def setUp(self):
        """Run the base-class setup under the name ``"test_watcher"``."""
        super(TestWatcher, self).setUp("test_watcher")

    def _publish_cluster(self, cluster, label=None):
        """Store ``cluster`` as JSON under the etcd cluster key.

        Args:
            cluster: the ``edl_cluster.Cluster`` to serialize and store.
            label: optional prefix; when given, the cluster JSON and its pod
                id list are printed first as a debugging aid.
        """
        if label is not None:
            print(label, cluster.to_json(), cluster.get_pods_ids_list())
        self._etcd.set_server_permanent(constants.ETCD_CLUSTER,
                                        constants.ETCD_CLUSTER,
                                        cluster.to_json())

    def test_watcher_stage_changed(self):
        """A new ``_stage`` written to etcd must set ``watcher.changed``."""
        cluster = edl_cluster.Cluster()
        cluster._stage = "0"
        self._publish_cluster(cluster, label="cluster 0 ids:")
        watcher = cluster_watcher.Watcher(self._job_env, cluster)

        cluster._stage = "1"
        self._publish_cluster(cluster, label="cluster 1 ids:")
        # Give the watcher time to pick up the update
        # (presumably ETCD_TTL is a duration in seconds -- confirm).
        time.sleep(constants.ETCD_TTL)
        self.assertTrue(watcher.changed)

    def test_watch_valid(self):
        """Removing the cluster key while a watcher exists must not fail.

        ``EdlTableError`` is the only exception tolerated here; anything else
        propagates and fails the test.
        """
        try:
            cluster = edl_cluster.Cluster()
            self._publish_cluster(cluster)
            # Bound to a name only to hold the watcher for the whole test.
            watcher = cluster_watcher.Watcher(self._job_env, cluster)
            self._etcd.remove_server(constants.ETCD_CLUSTER,
                                     constants.ETCD_CLUSTER)
            time.sleep(constants.ETCD_TTL)
        except exceptions.EdlTableError:
            # Deliberately ignored: this is the one error the test accepts.
            pass

    def test_watcher_ids_changed(self):
        """Appending a pod to the stored cluster must set ``watcher.changed``."""
        cluster = edl_cluster.Cluster()
        self._publish_cluster(cluster, label="cluster 0 ids:")
        watcher = cluster_watcher.Watcher(self._job_env, cluster)

        pod = edl_pod.Pod()
        cluster._pods.append(pod)
        self._publish_cluster(cluster, label="cluster 1 ids:")
        # Give the watcher time to pick up the update
        # (presumably ETCD_TTL is a duration in seconds -- confirm).
        time.sleep(constants.ETCD_TTL)
        self.assertTrue(watcher.changed)
74+
75+
76+
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()

python/edl/utils/cluster.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,13 @@
2020
import time
2121
import traceback
2222
import uuid
23-
from edl.discovery import etcd_client
23+
2424
from edl.utils import constants
2525
from edl.utils import error_utils
2626
from edl.utils import exceptions
2727
from edl.utils import json_serializable
28-
from edl.utils import leader_pod
2928
from edl.utils import pod as edl_pod
30-
from edl.utils import resource_pods as edl_resource_pods
3129
from edl.utils import status as edl_status
32-
from edl.utils import train_status as edl_train_status
33-
from edl.utils.log_utils import logger
3430

3531

3632
class Cluster(json_serializable.Serializable):
@@ -163,3 +159,15 @@ def load_from_etcd(etcd, timeout=60):
163159
cluster = Cluster()
164160
cluster.from_json(value)
165161
return cluster
162+
163+
164+
@error_utils.handle_errors_until_timeout
def wait_to_load_from_etcd(etcd, timeout=60):
    """Load the cluster from etcd, raising if no cluster is stored.

    The ``handle_errors_until_timeout`` decorator presumably retries this
    call on errors until ``timeout`` elapses -- confirm in ``error_utils``.

    Args:
        etcd: etcd client; passed to ``load_from_etcd`` and used to build
            the key path for the error message.
        timeout: time budget forwarded to ``load_from_etcd`` (presumably
            seconds -- confirm).

    Returns:
        The cluster object returned by ``load_from_etcd``.

    Raises:
        exceptions.EdlTableError: if ``load_from_etcd`` returns ``None``.
    """
    # Forward the caller's timeout; it used to be hard-coded to 60 here,
    # which silently ignored the argument.
    cluster = load_from_etcd(etcd, timeout=timeout)
    if cluster is None:
        raise exceptions.EdlTableError(
            "can't load cluster from etcd path:{}".format(
                etcd.get_full_path(constants.ETCD_CLUSTER,
                                   constants.ETCD_CLUSTER)))

    return cluster

python/edl/utils/watcher.py renamed to python/edl/utils/cluster_watcher.py

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,64 +13,44 @@
1313
# limitations under the License.
1414

1515
import copy
16+
import threading
1617
import time
17-
from edl.discovery.etcd_client import EtcdClient
18+
from edl.discovery import etcd_client
1819
from edl.utils import cluster as edl_cluster
19-
from edl.utils import constants
2020
from edl.utils.log_utils import logger
21-
from threading import Lock, Thread, Event
2221

2322

2423
class Watcher(object):
25-
def __init__(self, job_env, cluster, pod):
26-
self._etcd = None
27-
24+
def __init__(self, job_env, cluster):
2825
self._job_id = job_env.job_id
2926

3027
# current context
3128
self._cluster = copy.copy(cluster)
32-
self._leader_id = cluster.get_pod_leader_id()
33-
self._current_pod = pod
3429

35-
self._new_cluster = None
36-
self._new_leader_id = None
30+
self._new_cluster = cluster
3731
self._changed = False
3832
logger.info("watcher gets the init cluster:{}".format(self._cluster))
3933

40-
self._lock = Lock()
41-
self._stop = Event()
34+
self._lock = threading.Lock()
35+
self._stop = threading.Event()
4236

37+
self._etcd = None
4338
self._t_watcher = None
39+
self._job_env = job_env
4440

4541
# assign value
46-
self._etcd = EtcdClient(self._job_env.etcd_endpoints, root=job_id)
42+
self._etcd = etcd_client.EtcdClient(
43+
self._job_env.etcd_endpoints, root=self._job_id)
4744
self._etcd.init()
4845

49-
self._t_watcher = Thread(target=self._watcher)
46+
self._t_watcher = threading.Thread(target=self._watcher)
5047
self._t_watcher.start()
5148

5249
def _watcher(self):
53-
begin = time.time()
5450
while not self._stop.is_set():
55-
# if leader_id changed?
56-
servers = self._etcd.get_service(constants.ETCD_POD_RANK)
57-
assert len(servers) <= 1
58-
if len(servers) == 0:
59-
time.sleep(1)
60-
continue
61-
62-
with self._lock:
63-
self._new_leader_id = s.info
64-
6551
# if cluster changed?
66-
value, _, _, _, _, = etcd._get_server(constants.ETCD_CLUSTER,
67-
self._new_leader_id)
68-
if value is None:
69-
time.sleep(1)
70-
continue
71-
new_cluster = edl_cluster.Cluster()
72-
new_cluster.from_json(value)
73-
52+
new_cluster = edl_cluster.wait_to_load_from_etcd(
53+
self._etcd, timeout=60)
7454
with self._lock:
7555
self._new_cluster = new_cluster
7656

0 commit comments

Comments
 (0)