Skip to content

Commit 6356e96

Browse files
sasha-gitgmorgandu
authored andcommitted
fix: add v1 conversion value rule
1 parent 3ce0163 commit 6356e96

File tree

7 files changed

+35
-37
lines changed

7 files changed

+35
-37
lines changed

google/cloud/aiplatform/helpers/_decorators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,5 @@ def _from_map(map_):
6868

6969
marshal = Marshal(name="google.cloud.aiplatform.v1beta1")
7070
marshal.register(Value, ConversionValueRule(marshal=marshal))
71+
marshal = Marshal(name="google.cloud.aiplatform.v1")
72+
marshal.register(Value, ConversionValueRule(marshal=marshal))

google/cloud/aiplatform/initializer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
encryption_spec_v1beta1 as gca_encryption_spec_v1beta1,
3939
)
4040

41+
4142
class _Config:
4243
"""Stores common parameters and options for API calls."""
4344

@@ -100,7 +101,12 @@ def get_encryption_spec(
100101
self,
101102
encryption_spec_key_name: Optional[str],
102103
select_version: Optional[str] = compat.DEFAULT_VERSION,
103-
) -> Optional[Union[gca_encryption_spec_v1.EncryptionSpec, gca_encryption_spec_v1beta1.EncryptionSpec]]:
104+
) -> Optional[
105+
Union[
106+
gca_encryption_spec_v1.EncryptionSpec,
107+
gca_encryption_spec_v1beta1.EncryptionSpec,
108+
]
109+
]:
104110
"""Creates a gca_encryption_spec.EncryptionSpec instance from the given key name.
105111
If the provided key name is None, it uses the default key name if provided.
106112

google/cloud/aiplatform/jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def create(
487487
# Optional Fields
488488
gapic_batch_prediction_job.encryption_spec = initializer.global_config.get_encryption_spec(
489489
encryption_spec_key_name=encryption_spec_key_name,
490-
select_version=select_version
490+
select_version=select_version,
491491
)
492492

493493
if model_parameters:

tests/unit/aiplatform/test_end_to_end.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,6 @@
2525
from google.cloud.aiplatform import schema
2626
from google.cloud.aiplatform import training_jobs
2727

28-
from google.cloud.aiplatform_v1beta1.types import (
29-
dataset as gca_dataset_v1beta1,
30-
encryption_spec as gca_encryption_spec_v1beta1,
31-
io as gca_io_v1beta1,
32-
model as gca_model_v1beta1,
33-
pipeline_state as gca_pipeline_state_v1beta1,
34-
training_pipeline as gca_training_pipeline_v1beta1,
35-
)
36-
3728
from google.cloud.aiplatform_v1.types import (
3829
dataset as gca_dataset,
3930
encryption_spec as gca_encryption_spec,

tests/unit/aiplatform/test_endpoints.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@
2929
from google.cloud.aiplatform import models
3030
from google.cloud.aiplatform import utils
3131

32-
from google.cloud.aiplatform_v1beta1.services.model_service import (
33-
client as model_service_client_v1beta1,
34-
)
3532
from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
3633
client as endpoint_service_client_v1beta1,
3734
)
@@ -40,11 +37,9 @@
4037
)
4138
from google.cloud.aiplatform_v1beta1.types import (
4239
endpoint as gca_endpoint_v1beta1,
43-
model as gca_model_v1beta1,
4440
machine_resources as gca_machine_resources_v1beta1,
4541
prediction_service as gca_prediction_service_v1beta1,
4642
endpoint_service as gca_endpoint_service_v1beta1,
47-
encryption_spec as gca_encryption_spec_v1beta1,
4843
)
4944

5045
from google.cloud.aiplatform_v1.services.model_service import (
@@ -99,7 +94,9 @@
9994
_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_P100"
10095
_TEST_ACCELERATOR_COUNT = 2
10196

102-
_TEST_EXPLANATIONS = [gca_prediction_service_v1beta1.explanation.Explanation(attributions=[])]
97+
_TEST_EXPLANATIONS = [
98+
gca_prediction_service_v1beta1.explanation.Explanation(attributions=[])
99+
]
103100

104101
_TEST_ATTRIBUTIONS = [
105102
gca_prediction_service_v1beta1.explanation.Attribution(

tests/unit/aiplatform/test_jobs.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,15 @@
3939
batch_prediction_job as gca_batch_prediction_job_v1beta1,
4040
explanation as gca_explanation_v1beta1,
4141
io as gca_io_v1beta1,
42-
job_state as gca_job_state_v1beta1,
4342
machine_resources as gca_machine_resources_v1beta1,
4443
)
4544

46-
from google.cloud.aiplatform_v1.services.job_service import (
47-
client as job_service_client,
48-
)
45+
from google.cloud.aiplatform_v1.services.job_service import client as job_service_client
4946

5047
from google.cloud.aiplatform_v1.types import (
5148
batch_prediction_job as gca_batch_prediction_job,
5249
io as gca_io,
5350
job_state as gca_job_state,
54-
machine_resources as gca_machine_resources,
5551
)
5652

5753
_TEST_PROJECT = "test-project"
@@ -485,7 +481,9 @@ def test_batch_predict_gcs_source_bq_dest(
485481

486482
@pytest.mark.parametrize("sync", [True, False])
487483
@pytest.mark.usefixtures("get_batch_prediction_job_mock")
488-
def test_batch_predict_with_all_args(self, create_batch_prediction_job_with_explanations_mock, sync):
484+
def test_batch_predict_with_all_args(
485+
self, create_batch_prediction_job_with_explanations_mock, sync
486+
):
489487
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
490488
creds = auth_credentials.AnonymousCredentials()
491489

@@ -518,7 +516,9 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_with_expl
518516
model=_TEST_MODEL_NAME,
519517
input_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig(
520518
instances_format="jsonl",
521-
gcs_source=gca_io_v1beta1.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
519+
gcs_source=gca_io_v1beta1.GcsSource(
520+
uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
521+
),
522522
),
523523
output_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig(
524524
gcs_destination=gca_io_v1beta1.GcsDestination(

tests/unit/aiplatform/test_models.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
env_var as gca_env_var_v1beta1,
4444
explanation as gca_explanation_v1beta1,
4545
io as gca_io_v1beta1,
46-
job_state as gca_job_state_v1beta1,
4746
model as gca_model_v1beta1,
4847
endpoint as gca_endpoint_v1beta1,
4948
machine_resources as gca_machine_resources_v1beta1,
@@ -55,15 +54,12 @@
5554
from google.cloud.aiplatform_v1.services.endpoint_service import (
5655
client as endpoint_service_client,
5756
)
58-
from google.cloud.aiplatform_v1.services.job_service import (
59-
client as job_service_client,
60-
)
57+
from google.cloud.aiplatform_v1.services.job_service import client as job_service_client
6158
from google.cloud.aiplatform_v1.services.model_service import (
6259
client as model_service_client,
6360
)
6461
from google.cloud.aiplatform_v1.types import (
6562
batch_prediction_job as gca_batch_prediction_job,
66-
env_var as gca_env_var,
6763
io as gca_io,
6864
job_state as gca_job_state,
6965
model as gca_model,
@@ -184,6 +180,7 @@ def get_model_mock():
184180
)
185181
yield get_model_mock
186182

183+
187184
@pytest.fixture
188185
def get_model_with_explanations_mock():
189186
with mock.patch.object(
@@ -194,6 +191,7 @@ def get_model_with_explanations_mock():
194191
)
195192
yield get_model_mock
196193

194+
197195
@pytest.fixture
198196
def get_model_with_custom_location_mock():
199197
with mock.patch.object(
@@ -244,7 +242,6 @@ def upload_model_with_explanations_mock():
244242
yield upload_model_mock
245243

246244

247-
248245
@pytest.fixture
249246
def upload_model_with_custom_project_mock():
250247
with mock.patch.object(
@@ -300,7 +297,6 @@ def deploy_model_mock():
300297
yield deploy_model_mock
301298

302299

303-
304300
@pytest.fixture
305301
def deploy_model_with_explanations_mock():
306302
with mock.patch.object(
@@ -343,6 +339,7 @@ def create_batch_prediction_job_mock():
343339
create_batch_prediction_job_mock.return_value = batch_prediction_job_mock
344340
yield create_batch_prediction_job_mock
345341

342+
346343
@pytest.fixture
347344
def create_batch_prediction_job_with_explanations_mock():
348345
with mock.patch.object(
@@ -355,6 +352,7 @@ def create_batch_prediction_job_with_explanations_mock():
355352
create_batch_prediction_job_mock.return_value = batch_prediction_job_mock
356353
yield create_batch_prediction_job_mock
357354

355+
358356
@pytest.fixture
359357
def create_client_mock():
360358
with mock.patch.object(
@@ -746,7 +744,9 @@ def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync):
746744
"get_endpoint_mock", "get_model_mock", "create_endpoint_mock"
747745
)
748746
@pytest.mark.parametrize("sync", [True, False])
749-
def test_deploy_no_endpoint_with_explanations(self, deploy_model_with_explanations_mock, sync):
747+
def test_deploy_no_endpoint_with_explanations(
748+
self, deploy_model_with_explanations_mock, sync
749+
):
750750
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
751751
test_model = models.Model(_TEST_ID)
752752
test_endpoint = test_model.deploy(
@@ -834,9 +834,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
834834
),
835835
input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
836836
instances_format="jsonl",
837-
gcs_source=gca_io.GcsSource(
838-
uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
839-
),
837+
gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
840838
),
841839
output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
842840
gcs_destination=gca_io.GcsDestination(
@@ -940,7 +938,9 @@ def test_batch_predict_gcs_source_bq_dest(
940938

941939
@pytest.mark.parametrize("sync", [True, False])
942940
@pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
943-
def test_batch_predict_with_all_args(self, create_batch_prediction_job_with_explanations_mock, sync):
941+
def test_batch_predict_with_all_args(
942+
self, create_batch_prediction_job_with_explanations_mock, sync
943+
):
944944
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
945945
test_model = models.Model(_TEST_ID)
946946
creds = auth_credentials.AnonymousCredentials()
@@ -977,7 +977,9 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_with_expl
977977
),
978978
input_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig(
979979
instances_format="jsonl",
980-
gcs_source=gca_io_v1beta1.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
980+
gcs_source=gca_io_v1beta1.GcsSource(
981+
uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
982+
),
981983
),
982984
output_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig(
983985
gcs_destination=gca_io_v1beta1.GcsDestination(

0 commit comments

Comments
 (0)