Skip to content
169 changes: 156 additions & 13 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
deleting, publishing and unpublishing Firebase ML Kit models.
"""


import datetime
import numbers
import re
Expand All @@ -30,13 +31,27 @@
from firebase_admin import _utils
from firebase_admin import exceptions

# pylint: disable=import-error,no-name-in-module
try:
from firebase_admin import storage
_GCS_ENABLED = True
except ImportError:
_GCS_ENABLED = False

# pylint: disable=import-error,no-name-in-module
try:
import tensorflow as tf
_TF_ENABLED = True
except ImportError:
_TF_ENABLED = False

_MLKIT_ATTRIBUTE = '_mlkit'
_MAX_PAGE_SIZE = 100
_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
_GCS_TFLITE_URI_PATTERN = re.compile(
r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
_OPERATION_NAME_PATTERN = re.compile(
Expand Down Expand Up @@ -301,16 +316,16 @@ def model_format(self, model_format):
self._model_format = model_format #Can be None
return self

def as_dict(self, for_upload=False):
    """Returns a serializable dict representation of this object.

    Args:
        for_upload: Forwarded unchanged to the model format's ``as_dict``.
            Defaults to False.

    Returns:
        dict: A shallow copy of the stored data, merged with the model format's
        dict when a model format is set.
    """
    copy = dict(self._data)
    if self._model_format:
        # Pass the flag down so the nested format serializes for the same purpose.
        copy.update(self._model_format.as_dict(for_upload=for_upload))
    return copy


class ModelFormat(object):
    """Abstract base class representing a Model Format such as TFLite."""

    def as_dict(self, for_upload=False):
        """Returns a serializable dict representation of the format.

        Args:
            for_upload: Accepted so every format shares one signature; unused here.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError


Expand Down Expand Up @@ -364,22 +379,70 @@ def model_source(self, model_source):
def size_bytes(self):
return self._data.get('sizeBytes')

def as_dict(self, for_upload=False):
    """Returns a serializable dict representation wrapped under ``tfliteModel``.

    Args:
        for_upload: Forwarded unchanged to the model source's ``as_dict``.
            Defaults to False.

    Returns:
        dict: ``{'tfliteModel': ...}`` holding a shallow copy of the stored data,
        merged with the model source's dict when a model source is set.
    """
    copy = dict(self._data)
    if self._model_source:
        # Pass the flag down so the source serializes for the same purpose.
        copy.update(self._model_source.as_dict(for_upload=for_upload))
    return {'tfliteModel': copy}


class TFLiteModelSource(object):
    """Abstract base class representing a model source for TFLite format models."""

    def as_dict(self, for_upload=False):
        """Returns a serializable dict representation of the model source.

        Args:
            for_upload: Accepted so every source shares one signature; unused here.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError


class _CloudStorageClient(object):
    """Cloud Storage helper class"""

    # Templates for the gs:// URI of an uploaded model and for its blob name.
    GCS_URI = 'gs://{0}/{1}'
    BLOB_NAME = 'Firebase/MLKit/Models/{0}'

    @staticmethod
    def _assert_gcs_enabled():
        """Raises ImportError if the optional Cloud Storage import failed at load time."""
        if not _GCS_ENABLED:
            raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
                              'to install the "google-cloud-storage" module.')

    @staticmethod
    def _parse_gcs_tflite_uri(uri):
        """Splits a ``gs://bucket/blob`` URI into ``(bucket_name, blob_name)``.

        Raises:
            ValueError: If the URI does not match the expected pattern.
        """
        # GCS Bucket naming rules are complex. The regex is not comprehensive.
        # See https://cloud.google.com/storage/docs/naming for full details.
        matcher = _GCS_TFLITE_URI_PATTERN.match(uri)
        if not matcher:
            raise ValueError('GCS TFLite URI format is invalid.')
        return matcher.group('bucket_name'), matcher.group('blob_name')

    @staticmethod
    def upload(bucket_name, model_file_name, app):
        """Uploads a local model file to the bucket and returns its ``gs://`` URI.

        Args:
            bucket_name: Name passed to ``storage.bucket``.
            model_file_name: Path of the local file to upload.
            app: Firebase app instance passed to ``storage.bucket``.

        Returns:
            str: The ``gs://`` URI of the uploaded blob.

        Raises:
            ImportError: If the Cloud Storage library has not been installed.
        """
        _CloudStorageClient._assert_gcs_enabled()
        # Imported locally so this block does not depend on the module's import list.
        import os

        bucket = storage.bucket(bucket_name, app=app)
        # Name the blob after the file only; using the full path would copy local
        # directory components into the object name.
        blob_name = _CloudStorageClient.BLOB_NAME.format(os.path.basename(model_file_name))
        blob = bucket.blob(blob_name)
        blob.upload_from_filename(model_file_name)
        return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name)

    @staticmethod
    def sign_uri(gcs_tflite_uri, app):
        """Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri."""
        _CloudStorageClient._assert_gcs_enabled()
        bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri)
        bucket = storage.bucket(bucket_name, app=app)
        blob = bucket.blob(blob_name)
        return blob.generate_signed_url(
            version='v4',
            expiration=datetime.timedelta(minutes=10),
            method='GET'
        )


class TFLiteGCSModelSource(TFLiteModelSource):
"""TFLite model source representing a tflite model file stored in GCS."""
# Shared storage helper; a class attribute so it can be swapped out on the class.
_STORAGE_CLIENT = _CloudStorageClient()

def __init__(self, gcs_tflite_uri, app=None):
    """Stores the app and the validated GCS URI.

    Args:
        gcs_tflite_uri: URI of the model file; passed through
            ``_validate_gcs_tflite_uri`` before being stored.
        app: Optional Firebase app instance, stored for later storage calls.
    """
    self._app = app
    self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)

def __eq__(self, other):
Expand All @@ -391,6 +454,81 @@ def __eq__(self, other):
def __ne__(self, other):
    """Returns the logical negation of ``self.__eq__(other)``."""
    is_equal = self.__eq__(other)
    return not is_equal

@classmethod
def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
    """Uploads the model file to an existing Google Cloud Storage bucket.

    Args:
        model_file_name: The name of the model file.
        bucket_name: The name of an existing bucket. None to use the default bucket configured
            in the app.
        app: A Firebase app instance (or None to use the default app).

    Returns:
        TFLiteGCSModelSource: The source created from the model_file

    Raises:
        ImportError: If the Cloud Storage Library has not been installed.
    """
    # Go through cls rather than the hard-coded class name so subclasses get
    # their own storage client and are returned as instances of themselves.
    gcs_uri = cls._STORAGE_CLIENT.upload(bucket_name, model_file_name, app)
    return cls(gcs_tflite_uri=gcs_uri, app=app)

@staticmethod
def _assert_tf_version_1_enabled():
    """Checks that tensorflow was imported and that it is a 1.x release.

    Raises:
        ImportError: If tensorflow failed to import, or its version string does
            not start with ``1.``.
    """
    if not _TF_ENABLED:
        raise ImportError('Failed to import the tensorflow library for Python. Make sure '
                          'to install the tensorflow module.')
    # tf.VERSION is only defined by TensorFlow 1.x, so reading it on 2.x would raise
    # AttributeError before the version check could report the intended ImportError.
    tf_version = tf.__version__
    if not tf_version.startswith('1.'):
        raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf_version))

@classmethod
def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
    """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS.

    Args:
        saved_model_dir: The saved model directory.
        bucket_name: The name of an existing bucket. None to use the default bucket configured
            in the app.
        app: Optional. A Firebase app instance (or None to use the default app)

    Returns:
        TFLiteGCSModelSource: The source created from the saved_model_dir

    Raises:
        ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
    """
    cls._assert_tf_version_1_enabled()
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    model_file_name = 'firebase_mlkit_model.tflite'
    # Close (and thereby flush) the file before the upload reads it back.
    with open(model_file_name, 'wb') as model_file:
        model_file.write(tflite_model)
    return cls.from_tflite_model_file(model_file_name, bucket_name, app)

@classmethod
def from_keras_model(cls, keras_model, bucket_name=None, app=None):
    """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS.

    Args:
        keras_model: A tf.keras model.
        bucket_name: The name of an existing bucket. None to use the default bucket configured
            in the app.
        app: Optional. A Firebase app instance (or None to use the default app)

    Returns:
        TFLiteGCSModelSource: The source created from the keras_model

    Raises:
        ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
    """
    cls._assert_tf_version_1_enabled()
    # The converter used here loads the model from a file, so save it first.
    # NOTE(review): this intermediate .h5 file is left in the working directory.
    keras_file = 'keras_model.h5'
    tf.keras.models.save_model(keras_model, keras_file)
    converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
    tflite_model = converter.convert()
    model_file_name = 'firebase_mlkit_model.tflite'
    # Close (and thereby flush) the file before the upload reads it back.
    with open(model_file_name, 'wb') as model_file:
        model_file.write(tflite_model)
    return cls.from_tflite_model_file(model_file_name, bucket_name, app)

@property
def gcs_tflite_uri(self):
    """The GCS TFLite URI currently stored on this model source."""
    stored_uri = self._gcs_tflite_uri
    return stored_uri
Expand All @@ -399,10 +537,15 @@ def gcs_tflite_uri(self):
def gcs_tflite_uri(self, gcs_tflite_uri):
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)

def _get_signed_gcs_tflite_uri(self):
    """Signs the GCS uri, so the model file can be uploaded to Firebase ML Kit and verified."""
    return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app)

def as_dict(self, for_upload=False):
    """Returns a serializable dict representation of the model source.

    Args:
        for_upload: When true, the value is the signed URI from
            ``_get_signed_gcs_tflite_uri`` instead of the stored ``gs://`` URI.

    Returns:
        dict: A single ``gcsTfliteUri`` entry.
    """
    if for_upload:
        return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()}

    return {'gcsTfliteUri': self._gcs_tflite_uri}


class ListModelsPage(object):
Expand Down Expand Up @@ -671,13 +814,13 @@ def create_model(self, model):
_validate_model(model)
try:
return self.handle_operation(
self._client.body('post', url='models', json=model.as_dict()))
self._client.body('post', url='models', json=model.as_dict(for_upload=True)))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def update_model(self, model, update_mask=None):
_validate_model(model, update_mask)
data = {'model': model.as_dict()}
data = {'model': model.as_dict(for_upload=True)}
if update_mask is not None:
data['updateMask'] = update_mask
try:
Expand Down
50 changes: 49 additions & 1 deletion tests/test_mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@
}
}

GCS_TFLITE_URI = 'gs://my_bucket/mymodel.tflite'
GCS_BUCKET_NAME = 'my_bucket'
GCS_BLOB_NAME = 'mymodel.tflite'
GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME)
GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI}
GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI)
TFLITE_FORMAT_JSON = {
Expand All @@ -112,6 +114,10 @@
}
TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON)

GCS_TFLITE_SIGNED_URI_PATTERN = (
'https://storage.googleapis.com/{0}/{1}?X-Goog-Algorithm=GOOG4-RSA-SHA256&foo')
GCS_TFLITE_SIGNED_URI = GCS_TFLITE_SIGNED_URI_PATTERN.format(GCS_BUCKET_NAME, GCS_BLOB_NAME)

GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite'
GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2}
GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2)
Expand Down Expand Up @@ -325,6 +331,18 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non
session_url, adapter(payload, status, recorder))
return recorder

class _TestStorageClient(object):
    """Stand-in storage client that only formats URIs from its arguments."""

    @staticmethod
    def upload(bucket_name, model_file_name, app):
        """Returns the gs:// URI built from the mlkit blob-name and URI templates."""
        del app # unused variable
        uploaded_blob = mlkit._CloudStorageClient.BLOB_NAME.format(model_file_name)
        return mlkit._CloudStorageClient.GCS_URI.format(bucket_name, uploaded_blob)

    @staticmethod
    def sign_uri(gcs_tflite_uri, app):
        """Returns the fake signed URL for the bucket and blob parsed from the URI."""
        del app # unused variable
        bucket_and_blob = mlkit._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri)
        return GCS_TFLITE_SIGNED_URI_PATTERN.format(*bucket_and_blob)

class TestModel(object):
"""Tests mlkit.Model class."""
Expand All @@ -333,6 +351,7 @@ def setup_class(cls):
cred = testutils.MockCredential()
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test
mlkit.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient()

@classmethod
def teardown_class(cls):
Expand Down Expand Up @@ -404,6 +423,13 @@ def test_model_format_source_creation(self):
}
}

def test_source_creation_from_tflite_file(self):
    """A source built from a model file serializes to the uploaded blob's gs:// URI."""
    expected_uri = 'gs://my_bucket/Firebase/MLKit/Models/my_model.tflite'
    created_source = mlkit.TFLiteGCSModelSource.from_tflite_model_file(
        "my_model.tflite", "my_bucket")
    serialized = created_source.as_dict()
    assert serialized == {'gcsTfliteUri': expected_uri}

def test_model_source_setters(self):
model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI)
model_source.gcs_tflite_uri = GCS_TFLITE_URI_2
Expand All @@ -420,6 +446,27 @@ def test_model_format_setters(self):
}
}

def test_model_as_dict_for_upload(self):
    """Serializing a model for upload replaces the gs:// URI with the signed URI."""
    expected = {
        'displayName': DISPLAY_NAME_1,
        'tfliteModel': {
            'gcsTfliteUri': GCS_TFLITE_SIGNED_URI
        }
    }
    source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
    tflite_format = mlkit.TFLiteFormat(model_source=source)
    model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=tflite_format)
    assert model.as_dict(for_upload=True) == expected

@pytest.mark.parametrize('helper_func', [
    mlkit.TFLiteGCSModelSource.from_keras_model,
    mlkit.TFLiteGCSModelSource.from_saved_model
])
def test_tf_not_enabled(self, helper_func):
    """Both TF-based helpers raise ImportError when tensorflow is flagged unavailable."""
    original_tf_enabled = mlkit._TF_ENABLED
    mlkit._TF_ENABLED = False # for reliability
    try:
        with pytest.raises(ImportError) as excinfo:
            helper_func(None)
        check_error(excinfo, ImportError)
    finally:
        # Restore the module flag so the override does not leak into other tests.
        mlkit._TF_ENABLED = original_tf_enabled

@pytest.mark.parametrize('display_name, exc_type', [
('', ValueError),
('&_*#@:/?', ValueError),
Expand Down Expand Up @@ -803,6 +850,7 @@ def test_rpc_error(self, publish_function):
)
assert len(create_recorder) == 1


class TestGetModel(object):
"""Tests mlkit.get_model."""
@classmethod
Expand Down