Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,34 @@ def update_model(model, app=None):
return Model.from_dict(mlkit_service.update_model(model), app=app)


def publish_model(model_id, app=None):
"""Publishes a model in Firebase ML Kit.

Args:
model_id: The id of the model to publish.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The published model.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.publish_model(model_id, publish=True), app=app)


def unpublish_model(model_id, app=None):
"""Unpublishes a model in Firebase ML Kit.

Args:
model_id: The id of the model to unpublish.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The unpublished model.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.publish_model(model_id, publish=False), app=app)


def get_model(model_id, app=None):
"""Gets a model from Firebase ML Kit.

Expand Down Expand Up @@ -562,12 +590,12 @@ class _MLKitService(object):
POLL_BASE_WAIT_TIME_SECONDS = 3

def __init__(self, app):
project_id = app.project_id
if not project_id:
self._project_id = app.project_id
if not self._project_id:
raise ValueError(
'Project ID is required to access MLKit service. Either set the '
'projectId option, or use service account credentials.')
self._project_url = _MLKitService.PROJECT_URL.format(project_id)
self._project_url = _MLKitService.PROJECT_URL.format(self._project_id)
self._client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=self._project_url)
Expand Down Expand Up @@ -595,7 +623,6 @@ def _exponential_backoff(self, current_attempt, stop_time):
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
time.sleep(wait_time_seconds)


def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None):
"""Handles long running operations.

Expand Down Expand Up @@ -659,6 +686,17 @@ def update_model(self, model, update_mask=None):
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def publish_model(self, model_id, publish):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be call this set_published()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

_validate_model_id(model_id)
model_name = 'projects/{0}/models/{1}'.format(self._project_id, model_id)
model = Model.from_dict({
'name': model_name,
'state': {
'published': publish
}
})
return self.update_model(model, update_mask='state.published')

def get_model(self, model_id):
_validate_model_id(model_id)
try:
Expand Down
94 changes: 91 additions & 3 deletions tests/test_mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,11 @@
PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \
'1 and {0}'.format(mlkit._MAX_PAGE_SIZE)
INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()]

PUBLISH_AND_UNPUBLISH_VALUE_ARGS = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PUBLISH_AND_UNPUBLISH_FUNCS_WITH_ARGS

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pylint doesn't like more than 30 characters. Changed to PUBLISH_UNPUBLISH_WITH_ARGS

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of putting these in the module namespace, you might want to put them in the TestPublishUnpublish class.

Same goes for most other constants in this file. But you can do that in a future PR if you want.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done for these 2. I can move the rest in a future PR

(mlkit.publish_model, True),
(mlkit.unpublish_model, False)
]
PUBLISH_AND_UNPUBLISH_ARGS = [mlkit.publish_model, mlkit.unpublish_model]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PUBLISH_AND_UNPUBLISH_FUNCS = [item[0] for item in PUBLISH_AND_UNPUBLISH_FUNCS_WITH_ARGS] 
Copy link
Contributor Author

@ifielker ifielker Sep 11, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


# For validation type errors
def check_error(excinfo, err_type, msg=None):
Expand Down Expand Up @@ -657,7 +661,7 @@ def test_operation_error(self):
instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
# The http request succeeded, the operation returned contains a create failure
# The http request succeeded, the operation returned contains an update failure
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)

def test_malformed_operation(self):
Expand All @@ -673,7 +677,7 @@ def test_malformed_operation(self):
assert recorder[1].method == 'GET'
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)

def test_rpc_error_create(self):
def test_rpc_error(self):
create_recorder = instrument_mlkit_service(
status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
with pytest.raises(Exception) as excinfo:
Expand Down Expand Up @@ -712,6 +716,90 @@ def test_invalid_op_name(self, op_name):
check_error(excinfo, ValueError, 'Operation name format is invalid.')


class TestPublishUnpublish(object):
"""Tests mlkit.publish_model and mlkit.unpublish_model."""
@classmethod
def setup_class(cls):
cred = testutils.MockCredential()
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test

@classmethod
def teardown_class(cls):
testutils.cleanup_apps()

@staticmethod
def _url(project_id, model_id):
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)

@staticmethod
def _op_url(project_id, model_id):
return BASE_URL + \
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)

@pytest.mark.parametrize('publish_function, published', PUBLISH_AND_UNPUBLISH_VALUE_ARGS)
def test_immediate_done(self, publish_function, published):
recorder = instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE)
model = publish_function(MODEL_ID_1)
assert model == CREATED_UPDATED_MODEL_1
assert len(recorder) == 1
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
body = json.loads(recorder[0].body.decode())
assert body.get('model', {}).get('state', {}).get('published', None) is published
assert body.get('updateMask', {}) == 'state.published'

@pytest.mark.parametrize('publish_function', PUBLISH_AND_UNPUBLISH_ARGS)
def test_returns_locked(self, publish_function):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = publish_function(MODEL_ID_1)

assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)

@pytest.mark.parametrize('publish_function', PUBLISH_AND_UNPUBLISH_ARGS)
def test_operation_error(self, publish_function):
instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE)
with pytest.raises(Exception) as excinfo:
publish_function(MODEL_ID_1)
# The http request succeeded, the operation returned contains an update failure
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)

@pytest.mark.parametrize('publish_function', PUBLISH_AND_UNPUBLISH_ARGS)
def test_malformed_operation(self, publish_function):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = publish_function(MODEL_ID_1)
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)

@pytest.mark.parametrize('publish_function', PUBLISH_AND_UNPUBLISH_ARGS)
def test_rpc_error(self, publish_function):
create_recorder = instrument_mlkit_service(
status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
with pytest.raises(Exception) as excinfo:
publish_function(MODEL_ID_1)
check_firebase_error(
excinfo,
ERROR_STATUS_BAD_REQUEST,
ERROR_CODE_BAD_REQUEST,
ERROR_MSG_BAD_REQUEST
)
assert len(create_recorder) == 1

class TestGetModel(object):
"""Tests mlkit.get_model."""
@classmethod
Expand Down