Skip to content

Commit 141c008

Browse files
authored
feat: Add training_utils folder and environment_variables for training
1 parent 0477f5a commit 141c008

File tree

3 files changed

+285
-0
lines changed

3 files changed

+285
-0
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import json
19+
import os
20+
21+
from typing import Dict, Optional
22+
23+
24+
def _json_helper(env_var: str) -> Optional[Dict]:
25+
"""Helper to convert a dictionary represented as a string to a dictionary.
26+
27+
Args:
28+
env_var (str):
29+
Required. The name of the environment variable.
30+
31+
Returns:
32+
A dictionary if the variable was found, None otherwise.
33+
"""
34+
env = os.environ.get(env_var)
35+
if env is not None:
36+
return json.loads(env)
37+
else:
38+
return None
39+
40+
41+
# Cloud Storage URI of a directory intended for training data.
42+
training_data_uri = os.environ.get("AIP_TRAINING_DATA_URI")
43+
44+
# Cloud Storage URI of a directory intended for validation data.
45+
validation_data_uri = os.environ.get("AIP_VALIDATION_DATA_URI")
46+
47+
# Cloud Storage URI of a directory intended for test data.
48+
test_data_uri = os.environ.get("AIP_TEST_DATA_URI")
49+
50+
# Cloud Storage URI of a directory intended for saving model artefacts.
51+
model_dir = os.environ.get("AIP_MODEL_DIR")
52+
53+
# Cloud Storage URI of a directory intended for saving checkpoints.
54+
checkpoint_dir = os.environ.get("AIP_CHECKPOINT_DIR")
55+
56+
# Cloud Storage URI of a directory intended for saving TensorBoard logs.
57+
tensorboard_log_dir = os.environ.get("AIP_TENSORBOARD_LOG_DIR")
58+
59+
# json string as described in https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#cluster-variables
60+
cluster_spec = _json_helper("CLUSTER_SPEC")
61+
62+
# json string as described in https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#tf-config
63+
tf_config = _json_helper("TF_CONFIG")
64+
65+
# Profiler port used for capturing profiling samples.
66+
tf_profiler_port = os.environ.get("AIP_TF_PROFILER_PORT")
67+
68+
# API URI used for the tensorboard uploader.
69+
tensorboard_api_uri = os.environ.get("AIP_TENSORBOARD_API_URI")
70+
71+
# The name of the tensorboard resource, in the form:
72+
# `projects/{project_id}/locations/{location}/tensorboards/{tensorboard_name}`
73+
tensorboard_resource_name = os.environ.get("AIP_TENSORBOARD_RESOURCE_NAME")
74+
75+
# The name given to the training job.
76+
cloud_ml_job_id = os.environ.get("CLOUD_ML_JOB_ID")
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from importlib import reload
19+
import json
20+
import os
21+
import pytest
22+
23+
from google.cloud.aiplatform.training_utils import environment_variables
24+
from unittest import mock
25+
26+
_TEST_TRAINING_DATA_URI = "gs://training-data-uri"
27+
_TEST_VALIDATION_DATA_URI = "gs://test-validation-data-uri"
28+
_TEST_TEST_DATA_URI = "gs://test-data-uri"
29+
_TEST_MODEL_DIR = "gs://test-model-dir"
30+
_TEST_CHECKPOINT_DIR = "gs://test-checkpoint-dir"
31+
_TEST_TENSORBOARD_LOG_DIR = "gs://test-tensorboard-log-dir"
32+
_TEST_CLUSTER_SPEC = """{
33+
"cluster": {
34+
"worker_pools":[
35+
{
36+
"index":0,
37+
"replicas":[
38+
"training-workerpool0-ab-0:2222"
39+
]
40+
},
41+
{
42+
"index":1,
43+
"replicas":[
44+
"training-workerpool1-ab-0:2222",
45+
"training-workerpool1-ab-1:2222"
46+
]
47+
}
48+
]
49+
},
50+
"environment": "cloud",
51+
"task": {
52+
"worker_pool_index":0,
53+
"replica_index":0,
54+
"trial":"TRIAL_ID"
55+
}
56+
}"""
57+
_TEST_AIP_TF_PROFILER_PORT = "1234"
58+
_TEST_TENSORBOARD_API_URI = "http://testuri.com"
59+
_TEST_TENSORBOARD_RESOURCE_NAME = (
60+
"projects/myproj/locations/us-central1/tensorboards/1234"
61+
)
62+
_TEST_CLOUD_ML_JOB_ID = "myjob"
63+
64+
65+
class TestTrainingUtils:
66+
@pytest.fixture
67+
def mock_environment(self):
68+
env_vars = {
69+
"AIP_TRAINING_DATA_URI": _TEST_TRAINING_DATA_URI,
70+
"AIP_VALIDATION_DATA_URI": _TEST_VALIDATION_DATA_URI,
71+
"AIP_TEST_DATA_URI": _TEST_TEST_DATA_URI,
72+
"AIP_MODEL_DIR": _TEST_MODEL_DIR,
73+
"AIP_CHECKPOINT_DIR": _TEST_CHECKPOINT_DIR,
74+
"AIP_TENSORBOARD_LOG_DIR": _TEST_TENSORBOARD_LOG_DIR,
75+
"AIP_TF_PROFILER_PORT": _TEST_AIP_TF_PROFILER_PORT,
76+
"AIP_TENSORBOARD_API_URI": _TEST_TENSORBOARD_API_URI,
77+
"AIP_TENSORBOARD_RESOURCE_NAME": _TEST_TENSORBOARD_RESOURCE_NAME,
78+
"CLOUD_ML_JOB_ID": _TEST_CLOUD_ML_JOB_ID,
79+
"CLUSTER_SPEC": _TEST_CLUSTER_SPEC,
80+
"TF_CONFIG": _TEST_CLUSTER_SPEC,
81+
}
82+
with mock.patch.dict(os.environ, env_vars, clear=True):
83+
yield
84+
85+
@pytest.mark.usefixtures("mock_environment")
86+
def test_training_data_uri(self):
87+
reload(environment_variables)
88+
assert environment_variables.training_data_uri == _TEST_TRAINING_DATA_URI
89+
90+
def test_training_data_uri_none(self):
91+
reload(environment_variables)
92+
assert environment_variables.training_data_uri is None
93+
94+
@pytest.mark.usefixtures("mock_environment")
95+
def test_validation_data_uri(self):
96+
reload(environment_variables)
97+
assert environment_variables.validation_data_uri == _TEST_VALIDATION_DATA_URI
98+
99+
def test_validation_data_uri_none(self):
100+
reload(environment_variables)
101+
assert environment_variables.validation_data_uri is None
102+
103+
@pytest.mark.usefixtures("mock_environment")
104+
def test_test_data_uri(self):
105+
reload(environment_variables)
106+
assert environment_variables.test_data_uri == _TEST_TEST_DATA_URI
107+
108+
def test_test_data_uri_none(self):
109+
reload(environment_variables)
110+
assert environment_variables.test_data_uri is None
111+
112+
@pytest.mark.usefixtures("mock_environment")
113+
def test_model_dir(self):
114+
reload(environment_variables)
115+
assert environment_variables.model_dir == _TEST_MODEL_DIR
116+
117+
def test_model_dir_none(self):
118+
reload(environment_variables)
119+
assert environment_variables.model_dir is None
120+
121+
@pytest.mark.usefixtures("mock_environment")
122+
def test_checkpoint_dir(self):
123+
reload(environment_variables)
124+
assert environment_variables.checkpoint_dir == _TEST_CHECKPOINT_DIR
125+
126+
def test_checkpoint_dir_none(self):
127+
reload(environment_variables)
128+
assert environment_variables.checkpoint_dir is None
129+
130+
@pytest.mark.usefixtures("mock_environment")
131+
def test_tensorboard_log_dir(self):
132+
reload(environment_variables)
133+
assert environment_variables.tensorboard_log_dir == _TEST_TENSORBOARD_LOG_DIR
134+
135+
def test_tensorboard_log_dir_none(self):
136+
reload(environment_variables)
137+
assert environment_variables.tensorboard_log_dir is None
138+
139+
@pytest.mark.usefixtures("mock_environment")
140+
def test_cluster_spec(self):
141+
reload(environment_variables)
142+
assert environment_variables.cluster_spec == json.loads(_TEST_CLUSTER_SPEC)
143+
144+
def test_cluster_spec_none(self):
145+
reload(environment_variables)
146+
assert environment_variables.cluster_spec is None
147+
148+
@pytest.mark.usefixtures("mock_environment")
149+
def test_tf_config(self):
150+
reload(environment_variables)
151+
assert environment_variables.tf_config == json.loads(_TEST_CLUSTER_SPEC)
152+
153+
def test_tf_config_none(self):
154+
reload(environment_variables)
155+
assert environment_variables.tf_config is None
156+
157+
@pytest.mark.usefixtures("mock_environment")
158+
def test_tf_profiler_port(self):
159+
reload(environment_variables)
160+
assert environment_variables.tf_profiler_port == _TEST_AIP_TF_PROFILER_PORT
161+
162+
def test_tf_profiler_port_none(self):
163+
reload(environment_variables)
164+
assert environment_variables.tf_profiler_port is None
165+
166+
@pytest.mark.usefixtures("mock_environment")
167+
def test_tensorboard_api_uri(self):
168+
reload(environment_variables)
169+
assert environment_variables.tensorboard_api_uri == _TEST_TENSORBOARD_API_URI
170+
171+
def test_tensorboard_api_uri_none(self):
172+
reload(environment_variables)
173+
assert environment_variables.tensorboard_api_uri is None
174+
175+
@pytest.mark.usefixtures("mock_environment")
176+
def test_tensorboard_resource_name(self):
177+
reload(environment_variables)
178+
assert (
179+
environment_variables.tensorboard_resource_name
180+
== _TEST_TENSORBOARD_RESOURCE_NAME
181+
)
182+
183+
def test_tensorboard_resource_name_none(self):
184+
reload(environment_variables)
185+
assert environment_variables.tensorboard_resource_name is None
186+
187+
@pytest.mark.usefixtures("mock_environment")
188+
def test_cloud_ml_job_id(self):
189+
reload(environment_variables)
190+
assert environment_variables.cloud_ml_job_id == _TEST_CLOUD_ML_JOB_ID
191+
192+
def test_cloud_ml_job_id_none(self):
193+
reload(environment_variables)
194+
assert environment_variables.cloud_ml_job_id is None

0 commit comments

Comments
 (0)