- Notifications
You must be signed in to change notification settings - Fork 343
Firebase ML Kit Publish and Unpublish Implementation #345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -277,7 +277,11 @@ | |
| PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ | ||
| '1 and {0}'.format(mlkit._MAX_PAGE_SIZE) | ||
| INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()] | ||
| | ||
| PUBLISH_AND_UNPUBLISH_VALUE_ARGS = [ | ||
| ||
| (mlkit.publish_model, True), | ||
| (mlkit.unpublish_model, False) | ||
| ] | ||
| PUBLISH_AND_UNPUBLISH_ARGS = [mlkit.publish_model, mlkit.unpublish_model] | ||
| ||
| | ||
| # For validation type errors | ||
| def check_error(excinfo, err_type, msg=None): | ||
| | @@ -657,7 +661,7 @@ def test_operation_error(self): | |
| instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) | ||
| with pytest.raises(Exception) as excinfo: | ||
| mlkit.update_model(MODEL_1) | ||
| # The http request succeeded, the operation returned contains a create failure | ||
| # The http request succeeded, the operation returned contains an update failure | ||
| check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) | ||
| | ||
| def test_malformed_operation(self): | ||
| | @@ -673,7 +677,7 @@ def test_malformed_operation(self): | |
| assert recorder[1].method == 'GET' | ||
| assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) | ||
| | ||
| def test_rpc_error_create(self): | ||
| def test_rpc_error(self): | ||
| create_recorder = instrument_mlkit_service( | ||
| status=400, payload=ERROR_RESPONSE_BAD_REQUEST) | ||
| with pytest.raises(Exception) as excinfo: | ||
| | @@ -712,6 +716,90 @@ def test_invalid_op_name(self, op_name): | |
| check_error(excinfo, ValueError, 'Operation name format is invalid.') | ||
| | ||
| | ||
| class TestPublishUnpublish(object): | ||
| """Tests mlkit.publish_model and mlkit.unpublish_model.""" | ||
| @classmethod | ||
| def setup_class(cls): | ||
| cred = testutils.MockCredential() | ||
| firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) | ||
| mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test | ||
| | ||
| @classmethod | ||
| def teardown_class(cls): | ||
| testutils.cleanup_apps() | ||
| | ||
| @staticmethod | ||
| def _url(project_id, model_id): | ||
| return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) | ||
| | ||
| @staticmethod | ||
| def _op_url(project_id, model_id): | ||
| return BASE_URL + \ | ||
| 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) | ||
| | ||
| @pytest.mark.parametrize('publish_function, published', PUBLISH_AND_UNPUBLISH_VALUE_ARGS) | ||
| def test_immediate_done(self, publish_function, published): | ||
| recorder = instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) | ||
| model = publish_function(MODEL_ID_1) | ||
| assert model == CREATED_UPDATED_MODEL_1 | ||
| assert len(recorder) == 1 | ||
| assert recorder[0].method == 'PATCH' | ||
| assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) | ||
| body = json.loads(recorder[0].body.decode()) | ||
| assert body.get('model', {}).get('state', {}).get('published', None) is published | ||
| assert body.get('updateMask', {}) == 'state.published' | ||
| | ||
| @pytest.mark.parametrize('publish_function', PUBLISH_AND_UNPUBLISH_ARGS) | ||
| def test_returns_locked(self, publish_function): | ||
| recorder = instrument_mlkit_service( | ||
| status=[200, 200], | ||
| payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) | ||
| expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) | ||
| model = publish_function(MODEL_ID_1) | ||
| | ||
| assert model == expected_model | ||
| assert len(recorder) == 2 | ||
| assert recorder[0].method == 'PATCH' | ||
| assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) | ||
| assert recorder[1].method == 'GET' | ||
| assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) | ||
| | ||
| @pytest.mark.parametrize('publish_function', PUBLISH_AND_UNPUBLISH_ARGS) | ||
| def test_operation_error(self, publish_function): | ||
| instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) | ||
| with pytest.raises(Exception) as excinfo: | ||
| publish_function(MODEL_ID_1) | ||
| # The http request succeeded, the operation returned contains an update failure | ||
| check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) | ||
| | ||
| @pytest.mark.parametrize('publish_function', PUBLISH_AND_UNPUBLISH_ARGS) | ||
| def test_malformed_operation(self, publish_function): | ||
| recorder = instrument_mlkit_service( | ||
| status=[200, 200], | ||
| payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) | ||
| expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) | ||
| model = publish_function(MODEL_ID_1) | ||
| assert model == expected_model | ||
| assert len(recorder) == 2 | ||
| assert recorder[0].method == 'PATCH' | ||
| assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) | ||
| assert recorder[1].method == 'GET' | ||
| assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) | ||
| | ||
| @pytest.mark.parametrize('publish_function', PUBLISH_AND_UNPUBLISH_ARGS) | ||
| def test_rpc_error(self, publish_function): | ||
| create_recorder = instrument_mlkit_service( | ||
| status=400, payload=ERROR_RESPONSE_BAD_REQUEST) | ||
| with pytest.raises(Exception) as excinfo: | ||
| publish_function(MODEL_ID_1) | ||
| check_firebase_error( | ||
| excinfo, | ||
| ERROR_STATUS_BAD_REQUEST, | ||
| ERROR_CODE_BAD_REQUEST, | ||
| ERROR_MSG_BAD_REQUEST | ||
| ) | ||
| assert len(create_recorder) == 1 | ||
| | ||
| class TestGetModel(object): | ||
| """Tests mlkit.get_model.""" | ||
| @classmethod | ||
| | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May be call this
set_published()There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done