Skip to content

Commit d11b8e6

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: Allow setting default service account
PiperOrigin-RevId: 559266585
1 parent 7eaa1d4 commit d11b8e6

File tree

15 files changed

+102
-7
lines changed

15 files changed

+102
-7
lines changed

google/cloud/aiplatform/initializer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(self):
9898
self._credentials = None
9999
self._encryption_spec_key_name = None
100100
self._network = None
101+
self._service_account = None
101102

102103
def init(
103104
self,
@@ -113,6 +114,7 @@ def init(
113114
credentials: Optional[auth_credentials.Credentials] = None,
114115
encryption_spec_key_name: Optional[str] = None,
115116
network: Optional[str] = None,
117+
service_account: Optional[str] = None,
116118
):
117119
"""Updates common initialization parameters with provided options.
118120
@@ -155,6 +157,12 @@ def init(
155157
Private services access must already be configured for the network.
156158
If specified, all eligible jobs and resources created will be peered
157159
with this VPC.
160+
service_account (str):
161+
Optional. The service account used to launch jobs and deploy models.
162+
Jobs that use service_account: BatchPredictionJob, CustomJob,
163+
PipelineJob, HyperparameterTuningJob, CustomTrainingJob,
164+
CustomPythonPackageTrainingJob, CustomContainerTrainingJob,
165+
ModelEvaluationJob.
158166
Raises:
159167
ValueError:
160168
If experiment_description is provided but experiment is not.
@@ -194,6 +202,8 @@ def init(
194202
self._encryption_spec_key_name = encryption_spec_key_name
195203
if network is not None:
196204
self._network = network
205+
if service_account is not None:
206+
self._service_account = service_account
197207

198208
if experiment:
199209
metadata._experiment_tracker.set_experiment(
@@ -297,6 +307,11 @@ def network(self) -> Optional[str]:
297307
"""Default Compute Engine network to peer to, if provided."""
298308
return self._network
299309

310+
@property
311+
def service_account(self) -> Optional[str]:
312+
"""Default service account, if provided."""
313+
return self._service_account
314+
300315
@property
301316
def experiment_name(self) -> Optional[str]:
302317
"""Default experiment name, if provided."""

google/cloud/aiplatform/jobs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,7 @@ def create(
761761
)
762762
gapic_batch_prediction_job.explanation_spec = explanation_spec
763763

764+
service_account = service_account or initializer.global_config.service_account
764765
if service_account:
765766
gapic_batch_prediction_job.service_account = service_account
766767

@@ -1693,6 +1694,7 @@ def run(
16931694
`restart_job_on_worker_restart` to False.
16941695
"""
16951696
network = network or initializer.global_config.network
1697+
service_account = service_account or initializer.global_config.service_account
16961698

16971699
self._run(
16981700
service_account=service_account,
@@ -1880,6 +1882,8 @@ def submit(
18801882
raise ValueError(
18811883
"'experiment' is required since you've enabled autolog in 'from_local_script'."
18821884
)
1885+
1886+
service_account = service_account or initializer.global_config.service_account
18831887
if service_account:
18841888
self._gca_resource.job_spec.service_account = service_account
18851889

@@ -2356,6 +2360,7 @@ def run(
23562360
`restart_job_on_worker_restart` to False.
23572361
"""
23582362
network = network or initializer.global_config.network
2363+
service_account = service_account or initializer.global_config.service_account
23592364

23602365
self._run(
23612366
service_account=service_account,

google/cloud/aiplatform/model_evaluation/model_evaluation_job.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def submit(
278278
Returns:
279279
(ModelEvaluationJob): Instantiated represnetation of the model evaluation job.
280280
"""
281+
service_account = service_account or initializer.global_config.service_account
281282

282283
if isinstance(model_name, aiplatform.Model):
283284
model_resource_name = model_name.versioned_resource_name

google/cloud/aiplatform/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,7 @@ def _deploy_call(
10961096
to the resource project.
10971097
Users deploying the Model must have the `iam.serviceAccounts.actAs`
10981098
permission on this service account.
1099+
If not specified, uses the service account set in aiplatform.init.
10991100
explanation_spec (aiplatform.explain.ExplanationSpec):
11001101
Optional. Specification of Model explanation.
11011102
metadata (Sequence[Tuple[str, str]]):
@@ -1120,6 +1121,8 @@ def _deploy_call(
11201121
is not 0 or 100.
11211122
"""
11221123

1124+
service_account = service_account or initializer.global_config.service_account
1125+
11231126
max_replica_count = max(min_replica_count, max_replica_count)
11241127

11251128
if bool(accelerator_type) != bool(accelerator_count):

google/cloud/aiplatform/pipeline_job_schedules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def _create(
226226
if max_concurrent_run_count:
227227
self._gca_resource.max_concurrent_run_count = max_concurrent_run_count
228228

229+
service_account = service_account or initializer.global_config.service_account
229230
network = network or initializer.global_config.network
230231

231232
if service_account:

google/cloud/aiplatform/pipeline_jobs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ def submit(
382382
current Experiment Run.
383383
"""
384384
network = network or initializer.global_config.network
385+
service_account = service_account or initializer.global_config.service_account
385386

386387
if service_account:
387388
self._gca_resource.service_account = service_account

google/cloud/aiplatform/training_jobs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3223,6 +3223,7 @@ def run(
32233223
produce a Vertex AI Model.
32243224
"""
32253225
network = network or initializer.global_config.network
3226+
service_account = service_account or initializer.global_config.service_account
32263227

32273228
worker_pool_specs, managed_model = self._prepare_and_validate_run(
32283229
model_display_name=model_display_name,
@@ -4579,6 +4580,7 @@ def run(
45794580
were not provided in constructor.
45804581
"""
45814582
network = network or initializer.global_config.network
4583+
service_account = service_account or initializer.global_config.service_account
45824584

45834585
worker_pool_specs, managed_model = self._prepare_and_validate_run(
45844586
model_display_name=model_display_name,
@@ -7348,6 +7350,7 @@ def run(
73487350
service_account (str):
73497351
Specifies the service account for workload run-as account.
73507352
Users submitting jobs must have act-as permission on this run-as account.
7353+
If not specified, uses the service account set in aiplatform.init.
73517354
network (str):
73527355
The full name of the Compute Engine network to which the job
73537356
should be peered. For example, projects/12345/global/networks/myVPC.
@@ -7501,6 +7504,7 @@ def run(
75017504
produce a Vertex AI Model.
75027505
"""
75037506
network = network or initializer.global_config.network
7507+
service_account = service_account or initializer.global_config.service_account
75047508

75057509
worker_pool_specs, managed_model = self._prepare_and_validate_run(
75067510
model_display_name=model_display_name,

google/cloud/aiplatform/utils/gcs_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
217217
"""
218218
project = project or initializer.global_config.project
219219
location = location or initializer.global_config.location
220+
service_account = service_account or initializer.global_config.service_account
220221
credentials = credentials or initializer.global_config.credentials
221222

222223
output_artifacts_gcs_dir = (

samples/model-builder/init_sample.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def init_sample(
2525
staging_bucket: Optional[str] = None,
2626
credentials: Optional[auth_credentials.Credentials] = None,
2727
encryption_spec_key_name: Optional[str] = None,
28+
service_account: Optional[str] = None,
2829
):
2930

3031
from google.cloud import aiplatform
@@ -36,6 +37,7 @@ def init_sample(
3637
staging_bucket=staging_bucket,
3738
credentials=credentials,
3839
encryption_spec_key_name=encryption_spec_key_name,
40+
service_account=service_account,
3941
)
4042

4143

samples/model-builder/init_sample_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def test_init_sample(mock_sdk_init):
2626
staging_bucket=constants.STAGING_BUCKET,
2727
credentials=constants.CREDENTIALS,
2828
encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME,
29+
service_account=constants.SERVICE_ACCOUNT,
2930
)
3031

3132
mock_sdk_init.assert_called_once_with(
@@ -35,4 +36,5 @@ def test_init_sample(mock_sdk_init):
3536
staging_bucket=constants.STAGING_BUCKET,
3637
credentials=constants.CREDENTIALS,
3738
encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME,
39+
service_account=constants.SERVICE_ACCOUNT,
3840
)

0 commit comments

Comments
 (0)