Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions automl/google/cloud/automl_v1beta1/tables/gcs_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

"""Wraps the Google Cloud Storage client library for use in tables helper."""

import logging
import time

from google.api_core import exceptions
Expand All @@ -30,19 +31,22 @@
except ImportError: # pragma: NO COVER
storage = None

_LOGGER = logging.getLogger(__name__)
_PANDAS_REQUIRED = "pandas is required to verify type DataFrame."
_STORAGE_REQUIRED = (
"google-cloud-storage is required to create Google Cloud Storage client."
"google-cloud-storage is required to create a Google Cloud Storage client."
)


class GcsClient(object):
"""Uploads Pandas DataFrame to a bucket in Google Cloud Storage."""

def __init__(self, client=None, credentials=None):
def __init__(self, bucket_name=None, client=None, credentials=None):
"""Constructor.

Args:
bucket_name (Optional[str]): The name of Google Cloud Storage
bucket for this client to send requests to.
client (Optional[storage.Client]): A Google Cloud Storage Client
instance.
credentials (Optional[google.auth.credentials.Credentials]): The
Expand All @@ -61,10 +65,18 @@ def __init__(self, client=None, credentials=None):
else:
self.client = storage.Client()

self.bucket_name = bucket_name

def ensure_bucket_exists(self, project, region):
"""Checks if a bucket named '{project}-automl-tables-staging' exists.

Creates this bucket if it doesn't exist.
If this bucket doesn't exist, creates one.
If this bucket already exists in `project`, does nothing.
If this bucket exists in a different project that we don't have
access to, creates a bucket named
'{project}-automl-tables-staging-{create_timestamp}' instead, because
bucket names must be globally unique.
Saves the created bucket's name and reuses it for future requests.

Args:
project (str): The project that stores the bucket.
Expand All @@ -73,20 +85,30 @@ def ensure_bucket_exists(self, project, region):
Returns:
A string representing the created bucket name.
"""
bucket_name = "{}-automl-tables-staging".format(project)
if self.bucket_name is None:
self.bucket_name = "{}-automl-tables-staging".format(project)

try:
self.client.get_bucket(bucket_name)
except exceptions.NotFound:
bucket = self.client.bucket(bucket_name)
self.client.get_bucket(self.bucket_name)
except (exceptions.Forbidden, exceptions.NotFound) as e:
if isinstance(e, exceptions.Forbidden):
used_bucket_name = self.bucket_name
self.bucket_name = used_bucket_name + "-{}".format(int(time.time()))
_LOGGER.warning(
"Created a bucket named {} because a bucket named {} already exists in a different project.".format(
self.bucket_name, used_bucket_name
)
)

bucket = self.client.bucket(self.bucket_name)
bucket.create(project=project, location=region)
return bucket_name

def upload_pandas_dataframe(self, bucket_name, dataframe, uploaded_csv_name=None):
return self.bucket_name

def upload_pandas_dataframe(self, dataframe, uploaded_csv_name=None):
"""Uploads a Pandas DataFrame as CSV to the bucket.

Args:
bucket_name (str): The bucket name to upload the CSV to.
dataframe (pandas.DataFrame): The Pandas Dataframe to be uploaded.
uploaded_csv_name (Optional[str]): The name for the uploaded CSV.

Expand All @@ -99,14 +121,17 @@ def upload_pandas_dataframe(self, bucket_name, dataframe, uploaded_csv_name=None
if not isinstance(dataframe, pandas.DataFrame):
raise ValueError("'dataframe' must be a pandas.DataFrame instance.")

if self.bucket_name is None:
raise ValueError("Must ensure a bucket exists before uploading data.")

if uploaded_csv_name is None:
uploaded_csv_name = "automl-tables-dataframe-{}.csv".format(
int(time.time())
)
csv_string = dataframe.to_csv()

bucket = self.client.get_bucket(bucket_name)
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(uploaded_csv_name)
blob.upload_from_string(csv_string)

return "gs://{}/{}".format(bucket_name, uploaded_csv_name)
return "gs://{}/{}".format(self.bucket_name, uploaded_csv_name)
15 changes: 4 additions & 11 deletions automl/google/cloud/automl_v1beta1/tables/tables_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import pkg_resources
import logging
import google.auth

from google.api_core.gapic_v1 import client_info
from google.api_core import exceptions
Expand Down Expand Up @@ -418,8 +417,6 @@ def __ensure_gcs_client_is_initialized(self, credentials=None):
credentials from the environment.
"""
if self.gcs_client is None:
if credentials is None:
credentials, _ = google.auth.default()
self.gcs_client = gcs_client.GcsClient(credentials=credentials)

def list_datasets(self, project=None, region=None, **kwargs):
Expand Down Expand Up @@ -757,10 +754,8 @@ def import_data(

if pandas_dataframe is not None:
self.__ensure_gcs_client_is_initialized(credentials)
bucket_name = self.gcs_client.ensure_bucket_exists(project, region)
gcs_input_uri = self.gcs_client.upload_pandas_dataframe(
bucket_name, pandas_dataframe
)
self.gcs_client.ensure_bucket_exists(project, region)
gcs_input_uri = self.gcs_client.upload_pandas_dataframe(pandas_dataframe)
request = {"gcs_source": {"input_uris": [gcs_input_uri]}}
elif gcs_input_uris is not None:
if type(gcs_input_uris) != list:
Expand Down Expand Up @@ -2750,10 +2745,8 @@ def batch_predict(

if pandas_dataframe is not None:
self.__ensure_gcs_client_is_initialized(credentials)
bucket_name = self.gcs_client.ensure_bucket_exists(project, region)
gcs_input_uri = self.gcs_client.upload_pandas_dataframe(
bucket_name, pandas_dataframe
)
self.gcs_client.ensure_bucket_exists(project, region)
gcs_input_uri = self.gcs_client.upload_pandas_dataframe(pandas_dataframe)
input_request = {"gcs_source": {"input_uris": [gcs_input_uri]}}
elif gcs_input_uris is not None:
if type(gcs_input_uris) != list:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def test_import_data(self):
self.cancel_and_wait(op)
client.delete_dataset(dataset=dataset)

@unittest.skipIf(RUNNING_IN_VPCSC, "Test is not VPCSC compatible.")
def test_import_pandas_dataframe(self):
client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION)
display_name = _id("t_import_pandas")
Expand Down
80 changes: 67 additions & 13 deletions automl/tests/unit/gapic/v1beta1/test_gcs_client_v1beta1.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@


class TestGcsClient(object):
def gcs_client(self, client_attrs={}):
def gcs_client(self, bucket_name=None, client_attrs={}):
client_mock = mock.Mock(**client_attrs)
return automl_v1beta1.tables.gcs_client.GcsClient(client=client_mock)
return automl_v1beta1.tables.gcs_client.GcsClient(
bucket_name=bucket_name, client=client_mock
)

def test_ensure_bucket_exists(self):
mock_bucket = mock.Mock()
gcs_client = self.gcs_client(
{
client_attrs={
"get_bucket.side_effect": exceptions.NotFound("err"),
"bucket.return_value": mock_bucket,
}
Expand All @@ -48,10 +50,32 @@ def test_ensure_bucket_exists(self):
mock_bucket.create.assert_called_with(
project="my-project", location="us-central1"
)

assert returned_bucket_name == "my-project-automl-tables-staging"

def test_ensure_bucket_exists_bucket_already_exists(self):
def test_ensure_bucket_exists_bucket_already_exists_in_different_project(self):
mock_bucket = mock.Mock()
gcs_client = self.gcs_client(
client_attrs={
"get_bucket.side_effect": exceptions.Forbidden("err"),
"bucket.return_value": mock_bucket,
}
)
returned_bucket_name = gcs_client.ensure_bucket_exists(
"my-project", "us-central1"
)
gcs_client.client.get_bucket.assert_called_with(
"my-project-automl-tables-staging"
)
gcs_client.client.bucket.assert_called_with(returned_bucket_name)
mock_bucket.create.assert_called_with(
project="my-project", location="us-central1"
)

assert re.match(
"^my-project-automl-tables-staging-[0-9]*$", returned_bucket_name
)

def test_ensure_bucket_exists_bucket_already_exists_in_current_project(self):
gcs_client = self.gcs_client()
returned_bucket_name = gcs_client.ensure_bucket_exists(
"my-project", "us-central1"
Expand All @@ -62,15 +86,35 @@ def test_ensure_bucket_exists_bucket_already_exists(self):
gcs_client.client.bucket.assert_not_called()
assert returned_bucket_name == "my-project-automl-tables-staging"

def test_ensure_bucket_exists_custom_bucket_name(self):
mock_bucket = mock.Mock()
gcs_client = self.gcs_client(
bucket_name="my-bucket",
client_attrs={
"get_bucket.side_effect": exceptions.NotFound("err"),
"bucket.return_value": mock_bucket,
},
)
returned_bucket_name = gcs_client.ensure_bucket_exists(
"my-project", "us-central1"
)
gcs_client.client.get_bucket.assert_called_with("my-bucket")
gcs_client.client.bucket.assert_called_with("my-bucket")
mock_bucket.create.assert_called_with(
project="my-project", location="us-central1"
)
assert returned_bucket_name == "my-bucket"

def test_upload_pandas_dataframe(self):
mock_blob = mock.Mock()
mock_bucket = mock.Mock(**{"blob.return_value": mock_blob})
gcs_client = self.gcs_client({"get_bucket.return_value": mock_bucket})
gcs_client = self.gcs_client(
bucket_name="my-bucket",
client_attrs={"get_bucket.return_value": mock_bucket},
)
dataframe = pandas.DataFrame({"col1": [1, 2], "col2": [3, 4]})

gcs_uri = gcs_client.upload_pandas_dataframe(
"my-bucket", dataframe, "my-file.csv"
)
gcs_uri = gcs_client.upload_pandas_dataframe(dataframe, "my-file.csv")

gcs_client.client.get_bucket.assert_called_with("my-bucket")
mock_bucket.blob.assert_called_with("my-file.csv")
Expand All @@ -80,19 +124,29 @@ def test_upload_pandas_dataframe(self):
def test_upload_pandas_dataframe_no_csv_name(self):
mock_blob = mock.Mock()
mock_bucket = mock.Mock(**{"blob.return_value": mock_blob})
gcs_client = self.gcs_client({"get_bucket.return_value": mock_bucket})
gcs_client = self.gcs_client(
bucket_name="my-bucket",
client_attrs={"get_bucket.return_value": mock_bucket},
)
dataframe = pandas.DataFrame({"col1": [1, 2], "col2": [3, 4]})

gcs_uri = gcs_client.upload_pandas_dataframe("my-bucket", dataframe)
gcs_uri = gcs_client.upload_pandas_dataframe(dataframe)
generated_csv_name = gcs_uri.split("/")[-1]

gcs_client.client.get_bucket.assert_called_with("my-bucket")
mock_bucket.blob.assert_called_with(generated_csv_name)
mock_blob.upload_from_string.assert_called_with(",col1,col2\n0,1,3\n1,2,4\n")
assert re.match("gs://my-bucket/automl-tables-dataframe-([0-9]*).csv", gcs_uri)
assert re.match("^gs://my-bucket/automl-tables-dataframe-[0-9]*.csv$", gcs_uri)

def test_upload_pandas_dataframe_not_type_dataframe(self):
gcs_client = self.gcs_client()
with pytest.raises(ValueError):
gcs_client.upload_pandas_dataframe("my-bucket", "my-dataframe")
gcs_client.upload_pandas_dataframe("my-dataframe")
gcs_client.client.upload_pandas_dataframe.assert_not_called()

def test_upload_pandas_dataframe_bucket_not_exist(self):
gcs_client = self.gcs_client()
dataframe = pandas.DataFrame({})
with pytest.raises(ValueError):
gcs_client.upload_pandas_dataframe(dataframe)
gcs_client.client.upload_pandas_dataframe.assert_not_called()
12 changes: 4 additions & 8 deletions automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_import_not_found(self):
def test_import_pandas_dataframe(self):
client = self.tables_client(
gcs_client_attrs={
"ensure_bucket_exists.return_value": "my_bucket",
"bucket_name": "my_bucket",
"upload_pandas_dataframe.return_value": "uri",
}
)
Expand All @@ -209,9 +209,7 @@ def test_import_pandas_dataframe(self):
pandas_dataframe=dataframe,
)
client.gcs_client.ensure_bucket_exists.assert_called_with(PROJECT, REGION)
client.gcs_client.upload_pandas_dataframe.assert_called_with(
"my_bucket", dataframe
)
client.gcs_client.upload_pandas_dataframe.assert_called_with(dataframe)
client.auto_ml_client.import_data.assert_called_with(
"name", {"gcs_source": {"input_uris": ["uri"]}}
)
Expand Down Expand Up @@ -1200,7 +1198,7 @@ def test_predict_from_array_missing(self):
def test_batch_predict_pandas_dataframe(self):
client = self.tables_client(
gcs_client_attrs={
"ensure_bucket_exists.return_value": "my_bucket",
"bucket_name": "my_bucket",
"upload_pandas_dataframe.return_value": "gs://input",
}
)
Expand All @@ -1214,9 +1212,7 @@ def test_batch_predict_pandas_dataframe(self):
)

client.gcs_client.ensure_bucket_exists.assert_called_with(PROJECT, REGION)
client.gcs_client.upload_pandas_dataframe.assert_called_with(
"my_bucket", dataframe
)
client.gcs_client.upload_pandas_dataframe.assert_called_with(dataframe)

client.prediction_client.batch_predict.assert_called_with(
"my_model",
Expand Down