4646)
4747_TEST_RESPONSE_RUNNING_2_POOLS_RESIZE .resource_pools [1 ].replica_count = 1
4848
49+ _TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER = copy .deepcopy (
50+ tc .ClusterConstants ._TEST_RESPONSE_RUNNING_1_POOL
51+ )
52+ _TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER .resource_pools [0 ].replica_count = 1
53+
4954
5055@pytest .fixture
5156def create_persistent_resource_1_pool_mock ():
@@ -163,6 +168,22 @@ def update_persistent_resource_1_pool_mock():
163168 yield update_persistent_resource_1_pool_mock
164169
165170
171+ @pytest .fixture
172+ def update_persistent_resource_1_pool_0_worker_mock ():
173+ with mock .patch .object (
174+ PersistentResourceServiceClient ,
175+ "update_persistent_resource" ,
176+ ) as update_persistent_resource_1_pool_0_worker_mock :
177+ update_persistent_resource_lro_mock = mock .Mock (ga_operation .Operation )
178+ update_persistent_resource_lro_mock .result .return_value = (
179+ _TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER
180+ )
181+ update_persistent_resource_1_pool_0_worker_mock .return_value = (
182+ update_persistent_resource_lro_mock
183+ )
184+ yield update_persistent_resource_1_pool_0_worker_mock
185+
186+
166187@pytest .fixture
167188def update_persistent_resource_2_pools_mock ():
168189 with mock .patch .object (
@@ -472,6 +493,30 @@ def test_update_ray_cluster_1_pool(self, update_persistent_resource_1_pool_mock)
472493
473494 assert returned_name == tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS
474495
496+ @pytest .mark .usefixtures ("get_persistent_resource_1_pool_mock" )
497+ def test_update_ray_cluster_1_pool_to_0_worker (
498+ self , update_persistent_resource_1_pool_mock
499+ ):
500+
501+ new_worker_node_types = []
502+ for worker_node_type in tc .ClusterConstants ._TEST_CLUSTER .worker_node_types :
503+ # resize worker node to node_count = 0
504+ worker_node_type .node_count = 0
505+ new_worker_node_types .append (worker_node_type )
506+
507+ returned_name = vertex_ray .update_ray_cluster (
508+ cluster_resource_name = tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS ,
509+ worker_node_types = new_worker_node_types ,
510+ )
511+
512+ request = persistent_resource_service .UpdatePersistentResourceRequest (
513+ persistent_resource = _TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER ,
514+ update_mask = _EXPECTED_MASK ,
515+ )
516+ update_persistent_resource_1_pool_mock .assert_called_once_with (request )
517+
518+ assert returned_name == tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS
519+
475520 @pytest .mark .usefixtures ("get_persistent_resource_2_pools_mock" )
476521 def test_update_ray_cluster_2_pools (self , update_persistent_resource_2_pools_mock ):
477522
@@ -493,3 +538,49 @@ def test_update_ray_cluster_2_pools(self, update_persistent_resource_2_pools_moc
493538 update_persistent_resource_2_pools_mock .assert_called_once_with (request )
494539
495540 assert returned_name == tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS
541+
542+ @pytest .mark .usefixtures ("get_persistent_resource_2_pools_mock" )
543+ def test_update_ray_cluster_2_pools_0_worker_fail (self ):
544+
545+ new_worker_node_types = []
546+ for worker_node_type in tc .ClusterConstants ._TEST_CLUSTER_2 .worker_node_types :
547+ # resize worker node to node_count = 0
548+ worker_node_type .node_count = 0
549+ new_worker_node_types .append (worker_node_type )
550+
551+ with pytest .raises (ValueError ) as e :
552+ vertex_ray .update_ray_cluster (
553+ cluster_resource_name = tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS ,
554+ worker_node_types = new_worker_node_types ,
555+ )
556+
557+ e .match (regexp = r"must update to >= 1 nodes." )
558+
559+ @pytest .mark .usefixtures ("get_persistent_resource_1_pool_mock" )
560+ def test_update_ray_cluster_duplicate_worker_node_types_error (self ):
561+ new_worker_node_types = (
562+ tc .ClusterConstants ._TEST_CLUSTER_2 .worker_node_types
563+ + tc .ClusterConstants ._TEST_CLUSTER_2 .worker_node_types
564+ )
565+ with pytest .raises (ValueError ) as e :
566+ vertex_ray .update_ray_cluster (
567+ cluster_resource_name = tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS ,
568+ worker_node_types = new_worker_node_types ,
569+ )
570+
571+ e .match (regexp = r"Worker_node_types have duplicate machine specs" )
572+
573+ @pytest .mark .usefixtures ("get_persistent_resource_1_pool_mock" )
574+ def test_update_ray_cluster_mismatch_worker_node_types_count_error (self ):
575+ with pytest .raises (ValueError ) as e :
576+ new_worker_node_types = (
577+ tc .ClusterConstants ._TEST_CLUSTER_2 .worker_node_types
578+ )
579+ vertex_ray .update_ray_cluster (
580+ cluster_resource_name = tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS ,
581+ worker_node_types = new_worker_node_types ,
582+ )
583+
584+ e .match (
585+ regexp = r"does not match the number of the existing worker_node_type"
586+ )
0 commit comments