Skip to content

Commit 3289d92

Browse files
haomengchaocopybara-github
authored andcommitted
feat: Implement check_license_agreement_status and accept_model_license_agreement for Model Garden OpenModel.
PiperOrigin-RevId: 753629633
1 parent 063f868 commit 3289d92

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed

tests/unit/vertexai/model_garden/test_model_garden.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
_TEST_PROJECT = "test-project"
4343
_TEST_LOCATION = "us-central1"
44+
_TEST_PROJECT_NUMBER = "1234567890"
4445

4546
_TEST_MODEL_FULL_RESOURCE_NAME = (
4647
"publishers/google/models/paligemma@paligemma-224-float32"
@@ -398,6 +399,40 @@ def list_publisher_models_mock():
398399
yield list_publisher_models
399400

400401

402+
@pytest.fixture
403+
def check_license_agreement_status_mock():
404+
"""Mocks the check_license_agreement_status method."""
405+
with mock.patch.object(
406+
model_garden_service.ModelGardenServiceClient,
407+
"check_publisher_model_eula_acceptance",
408+
) as check_license_agreement_status:
409+
check_license_agreement_status.return_value = (
410+
types.PublisherModelEulaAcceptance(
411+
project_number=_TEST_PROJECT_NUMBER,
412+
publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME,
413+
publisher_model_eula_acked=True,
414+
)
415+
)
416+
yield check_license_agreement_status
417+
418+
419+
@pytest.fixture
420+
def accept_model_license_agreement_mock():
421+
"""Mocks the accept_model_license_agreement method."""
422+
with mock.patch.object(
423+
model_garden_service.ModelGardenServiceClient,
424+
"accept_publisher_model_eula",
425+
) as accept_model_license_agreement:
426+
accept_model_license_agreement.return_value = (
427+
types.PublisherModelEulaAcceptance(
428+
project_number=_TEST_PROJECT_NUMBER,
429+
publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME,
430+
publisher_model_eula_acked=True,
431+
)
432+
)
433+
yield accept_model_license_agreement
434+
435+
401436
@pytest.mark.usefixtures(
402437
"google_auth_mock",
403438
"deploy_mock",
@@ -406,6 +441,8 @@ def list_publisher_models_mock():
406441
"export_publisher_model_mock",
407442
"batch_prediction_mock",
408443
"complete_bq_uri_mock",
444+
"check_license_agreement_status_mock",
445+
"accept_model_license_agreement_mock",
409446
)
410447
class TestModelGarden:
411448
"""Test cases for ModelGarden class."""
@@ -999,3 +1036,43 @@ def test_batch_prediction_success(self, batch_prediction_mock):
9991036
batch_prediction_job=expected_gapic_batch_prediction_job,
10001037
timeout=None,
10011038
)
1039+
1040+
def test_check_license_agreement_status_success(
1041+
self, check_license_agreement_status_mock
1042+
):
1043+
"""Tests checking EULA acceptance for a model."""
1044+
aiplatform.init(
1045+
project=_TEST_PROJECT,
1046+
location=_TEST_LOCATION,
1047+
)
1048+
model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME)
1049+
eula_acceptance = model.check_license_agreement_status()
1050+
check_license_agreement_status_mock.assert_called_once_with(
1051+
types.CheckPublisherModelEulaAcceptanceRequest(
1052+
parent=f"projects/{_TEST_PROJECT}",
1053+
publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME,
1054+
)
1055+
)
1056+
assert eula_acceptance
1057+
1058+
def test_accept_model_license_agreement_success(
1059+
self, accept_model_license_agreement_mock
1060+
):
1061+
"""Tests accepting EULA for a model."""
1062+
aiplatform.init(
1063+
project=_TEST_PROJECT,
1064+
location=_TEST_LOCATION,
1065+
)
1066+
model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME)
1067+
eula_acceptance = model.accept_model_license_agreement()
1068+
accept_model_license_agreement_mock.assert_called_once_with(
1069+
types.AcceptPublisherModelEulaRequest(
1070+
parent=f"projects/{_TEST_PROJECT}",
1071+
publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME,
1072+
)
1073+
)
1074+
assert eula_acceptance == types.PublisherModelEulaAcceptance(
1075+
project_number=_TEST_PROJECT_NUMBER,
1076+
publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME,
1077+
publisher_model_eula_acked=True,
1078+
)

vertexai/model_garden/_model_garden.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,3 +715,44 @@ def batch_predict(
715715
starting_replica_count=starting_replica_count,
716716
max_replica_count=max_replica_count,
717717
)
718+
719+
def check_license_agreement_status(self) -> bool:
720+
"""Check whether the project has accepted the license agreement of the model.
721+
722+
EULA (End User License Agreement) is a legal document that the user must
723+
accept before using the model. For Models having license restrictions,
724+
the user must accept the EULA before using the model. You can check the
725+
details of the License in Model Garden.
726+
727+
Returns:
728+
bool : True if the project has accepted the End User License
729+
Agreement, False otherwise.
730+
"""
731+
request = types.CheckPublisherModelEulaAcceptanceRequest(
732+
parent=f"projects/{self._project}",
733+
publisher_model=self._publisher_model_name,
734+
)
735+
response = self._model_garden_client.check_publisher_model_eula_acceptance(
736+
request
737+
)
738+
return response.publisher_model_eula_acked
739+
740+
def accept_model_license_agreement(
741+
self,
742+
) -> types.model_garden_service.PublisherModelEulaAcceptance:
743+
"""Accepts the EULA(End User License Agreement) of the model for the project.
744+
745+
For Models having license restrictions, the user must accept the EULA
746+
before using the model. Calling this method will mark the EULA as accepted
747+
for the project.
748+
749+
Returns:
750+
types.model_garden_service.PublisherModelEulaAcceptance:
751+
The response of the accept_eula call, containing project number,
752+
model name and acceptance status.
753+
"""
754+
request = types.AcceptPublisherModelEulaRequest(
755+
parent=f"projects/{self._project}",
756+
publisher_model=self._publisher_model_name,
757+
)
758+
return self._model_garden_client.accept_publisher_model_eula(request)

0 commit comments

Comments
 (0)