4141
4242_TEST_PROJECT = "test-project"
4343_TEST_LOCATION = "us-central1"
44+ _TEST_PROJECT_NUMBER = "1234567890"
4445
4546_TEST_MODEL_FULL_RESOURCE_NAME = (
4647 "publishers/google/models/paligemma@paligemma-224-float32"
@@ -398,6 +399,40 @@ def list_publisher_models_mock():
398399 yield list_publisher_models
399400
400401
402+ @pytest .fixture
403+ def check_license_agreement_status_mock ():
404+ """Mocks the check_license_agreement_status method."""
405+ with mock .patch .object (
406+ model_garden_service .ModelGardenServiceClient ,
407+ "check_publisher_model_eula_acceptance" ,
408+ ) as check_license_agreement_status :
409+ check_license_agreement_status .return_value = (
410+ types .PublisherModelEulaAcceptance (
411+ project_number = _TEST_PROJECT_NUMBER ,
412+ publisher_model = _TEST_MODEL_FULL_RESOURCE_NAME ,
413+ publisher_model_eula_acked = True ,
414+ )
415+ )
416+ yield check_license_agreement_status
417+
418+
419+ @pytest .fixture
420+ def accept_model_license_agreement_mock ():
421+ """Mocks the accept_model_license_agreement method."""
422+ with mock .patch .object (
423+ model_garden_service .ModelGardenServiceClient ,
424+ "accept_publisher_model_eula" ,
425+ ) as accept_model_license_agreement :
426+ accept_model_license_agreement .return_value = (
427+ types .PublisherModelEulaAcceptance (
428+ project_number = _TEST_PROJECT_NUMBER ,
429+ publisher_model = _TEST_MODEL_FULL_RESOURCE_NAME ,
430+ publisher_model_eula_acked = True ,
431+ )
432+ )
433+ yield accept_model_license_agreement
434+
435+
401436@pytest .mark .usefixtures (
402437 "google_auth_mock" ,
403438 "deploy_mock" ,
@@ -406,6 +441,8 @@ def list_publisher_models_mock():
406441 "export_publisher_model_mock" ,
407442 "batch_prediction_mock" ,
408443 "complete_bq_uri_mock" ,
444+ "check_license_agreement_status_mock" ,
445+ "accept_model_license_agreement_mock" ,
409446)
410447class TestModelGarden :
411448 """Test cases for ModelGarden class."""
@@ -999,3 +1036,43 @@ def test_batch_prediction_success(self, batch_prediction_mock):
9991036 batch_prediction_job = expected_gapic_batch_prediction_job ,
10001037 timeout = None ,
10011038 )
1039+
1040+ def test_check_license_agreement_status_success (
1041+ self , check_license_agreement_status_mock
1042+ ):
1043+ """Tests checking EULA acceptance for a model."""
1044+ aiplatform .init (
1045+ project = _TEST_PROJECT ,
1046+ location = _TEST_LOCATION ,
1047+ )
1048+ model = model_garden .OpenModel (model_name = _TEST_MODEL_FULL_RESOURCE_NAME )
1049+ eula_acceptance = model .check_license_agreement_status ()
1050+ check_license_agreement_status_mock .assert_called_once_with (
1051+ types .CheckPublisherModelEulaAcceptanceRequest (
1052+ parent = f"projects/{ _TEST_PROJECT } " ,
1053+ publisher_model = _TEST_MODEL_FULL_RESOURCE_NAME ,
1054+ )
1055+ )
1056+ assert eula_acceptance
1057+
1058+ def test_accept_model_license_agreement_success (
1059+ self , accept_model_license_agreement_mock
1060+ ):
1061+ """Tests accepting EULA for a model."""
1062+ aiplatform .init (
1063+ project = _TEST_PROJECT ,
1064+ location = _TEST_LOCATION ,
1065+ )
1066+ model = model_garden .OpenModel (model_name = _TEST_MODEL_FULL_RESOURCE_NAME )
1067+ eula_acceptance = model .accept_model_license_agreement ()
1068+ accept_model_license_agreement_mock .assert_called_once_with (
1069+ types .AcceptPublisherModelEulaRequest (
1070+ parent = f"projects/{ _TEST_PROJECT } " ,
1071+ publisher_model = _TEST_MODEL_FULL_RESOURCE_NAME ,
1072+ )
1073+ )
1074+ assert eula_acceptance == types .PublisherModelEulaAcceptance (
1075+ project_number = _TEST_PROJECT_NUMBER ,
1076+ publisher_model = _TEST_MODEL_FULL_RESOURCE_NAME ,
1077+ publisher_model_eula_acked = True ,
1078+ )
0 commit comments