-   Notifications  
You must be signed in to change notification settings  - Fork 408
 
feat: Add cloud profiler to training_utils #828
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 66 commits
99b3519 535a0b3 9c84c23 1640a96 8c51461 b0594a4 fadcf9a 83a803e f71c41f 49b2727 17c3f8b 833bbbb 75048d3 b7ba730 400de49 4bc3912 097fa34 17be71a 9cd47b6 f3c7032 9907578 89fcce0 0107c54 5a4cd9a 7dfd269 6049df1 e469013 05aaec1 35ecd1d 78f6c5b e68cc5c 4d2ca8b 134713d 919c957 415d528 5c62724 6534781 df86310 5459c4d 8490990 df6c7ba 5930cb4 0d5b89d ad926c6 76a4acb 1e92260 6ea3718 281c9c4 38453cd cba4bdd ed42226 940e3e1 0bc0251 17a2698 50b2593 c957a1b 6ea028f 20f1268 348c0de b7469ce 983dbd3 d2e04ac 7ff3bad 58dd84f e6d7ded 2fc8eeb e3808a4 1c211a6 dba42ff File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| Cloud Profiler | ||
| ================================= | ||
|   |  ||
| Cloud Profiler allows you to profile your remote Vertex AI Training jobs on demand and visualize the results in Vertex Tensorboard. | ||
|   |  ||
| Quick Start | ||
| ------------ | ||
|   |  ||
| To start using the profiler with TensorFlow, update your training script to include the following: | ||
|   |  ||
| .. code-block:: Python | ||
|   |  ||
| from google.cloud.aiplatform.training_utils import cloud_profiler | ||
| ... | ||
| cloud_profiler.init() | ||
|   |  ||
|   |  ||
| Next, run the job with a Vertex TensorBoard instance. For full details on how to do this, visit https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-overview | ||
|   |  ||
| Finally, visit your TensorBoard in your Google Cloud Console, navigate to the "Profile" tab, and click the `Capture Profile` button. This will allow users to capture profiling statistics for the running jobs. | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| # -*- coding: utf-8 -*- | ||
|   |  ||
| # Copyright 2021 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|   |  ||
"""Initialize the cloud profiler for tensorflow.

Usage:
    from google.cloud.aiplatform.training_utils import cloud_profiler

    cloud_profiler.init(plugin='tensorflow')
"""

try:
    import google.cloud.aiplatform.training_utils.cloud_profiler.initializer as initializer
except ImportError as err:
    # Re-raise with an actionable message while keeping the original error
    # chained as the cause.
    raise ImportError(
        "Could not load the cloud profiler. To use the profiler, "
        'install the SDK using "pip install google-cloud-aiplatform[cloud-profiler]"'
    ) from err

# Public alias: ``cloud_profiler.init(...)`` forwards to ``initializer.initialize``.
init = initializer.initialize
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| # -*- coding: utf-8 -*- | ||
|   |  ||
| # Copyright 2021 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|   |  ||
| import logging | ||
| import threading | ||
| from typing import Optional, Type | ||
| from werkzeug import serving | ||
    sasha-gitg marked this conversation as resolved.    Show resolved Hide resolved  |  ||
|   |  ||
| from google.cloud.aiplatform.training_utils import environment_variables | ||
| from google.cloud.aiplatform.training_utils.cloud_profiler import webserver | ||
| from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin | ||
| from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import ( | ||
| tf_profiler, | ||
| ) | ||
|   |  ||
| # Mapping of available plugins to use | ||
| _AVAILABLE_PLUGINS = {"tensorflow": tf_profiler.TFProfiler} | ||
|   |  ||
|   |  ||
class MissingEnvironmentVariableException(Exception):
    """Raised when an environment variable that is needed is not set."""
|   |  ||
|   |  ||
def _build_plugin(
    plugin: Type[base_plugin.BasePlugin],
) -> Optional[base_plugin.BasePlugin]:
    """Builds the plugin given the plugin class.

    Args:
        plugin (Type[base_plugin.BasePlugin]):
            Required. An uninitialized plugin class.

    Returns:
        An initialized plugin, or None if the plugin cannot be
        initialized or its post-setup check does not pass.
    """
    if not plugin.can_initialize():
        logging.warning("Cannot initialize the plugin")
        return None

    # setup() always runs before the post-setup check; a failed check only
    # means that no plugin instance is handed back to the caller.
    plugin.setup()

    if not plugin.post_setup_check():
        return None

    return plugin()
|   |  ||
|   |  ||
def _run_app_thread(server: webserver.WebServer, port: int) -> None:
    """Run the webserver in a separate daemon thread.

    Args:
        server (webserver.WebServer):
            Required. A webserver to accept requests.
        port (int):
            Required. The port to run the webserver on.
    """
    # daemon=True is passed to the constructor because Thread.setDaemon() is
    # deprecated since Python 3.10. A daemon thread does not keep the
    # interpreter alive once the main thread has finished.
    profile_thread = threading.Thread(
        name="profile_server",
        target=serving.run_simple,
        args=("0.0.0.0", port, server),
        daemon=True,
    )
    profile_thread.start()
|   |  ||
|   |  ||
def initialize(plugin: str = "tensorflow"):
    """Initializes the profiling SDK.

    Looks up the requested plugin, builds it, and serves it from a
    webserver running in a background thread.

    Args:
        plugin (str):
            Optional. Name of the plugin to initialize. Defaults to
            "tensorflow". Current options are ["tensorflow"]

    Raises:
        ValueError:
            The plugin does not exist.
        MissingEnvironmentVariableException:
            An environment variable that is needed is not set.
    """
    plugin_cls = _AVAILABLE_PLUGINS.get(plugin)

    if plugin_cls is None:
        # list(...) renders the choices as ['tensorflow'] rather than the
        # dict_keys([...]) wrapper.
        raise ValueError(
            "Plugin {} not available, must choose from {}".format(
                plugin, list(_AVAILABLE_PLUGINS)
            )
        )

    prof_plugin = _build_plugin(plugin_cls)

    # No plugin instance was built, so there is nothing to serve.
    if prof_plugin is None:
        return

    server = webserver.WebServer([prof_plugin])

    # NOTE(review): per the PR discussion this variable is set by the service,
    # not by the user - confirm.
    if not environment_variables.http_handler_port:
        raise MissingEnvironmentVariableException(
            "'AIP_HTTP_HANDLER_PORT' must be set."
        )

    port = int(environment_variables.http_handler_port)

    _run_app_thread(server, port)
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| # -*- coding: utf-8 -*- | ||
|   |  ||
| # Copyright 2021 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|   |  ||
| import abc | ||
| from typing import Callable, Dict | ||
| from werkzeug import Response | ||
|   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some request to wrap with informative exception. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See #828 (comment)  |  ||
|   |  ||
|   |  ||
class BasePlugin(abc.ABC):
    """Abstract base for plugins exposing cloud training tools endpoints.

    A plugin registers the http handlers that are used for AI Platform
    training jobs.
    """

    @staticmethod
    @abc.abstractmethod
    def setup() -> None:
        """Execute plugin setup code ahead of the webserver launch."""
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def can_initialize() -> bool:
        """Report whether the plugin is able to be initialized.

        Intended for verifying that the correct dependencies are installed,
        that system requirements are met, etc.

        Returns:
            True if the plugin can be initialized, False otherwise.
        """
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def post_setup_check() -> bool:
        """Report whether the plugin should be used once setup has run.

        Example: only the main training node needs to run the web server;
        the other nodes only need 'setup()' to have run so that the rpc
        server is started.

        Returns:
            True if the post setup checks pass, False otherwise.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_routes(self) -> Dict[str, Callable[..., Response]]:
        """Return the mapping from path to handler.

        Plugins implement this method to assign different routes to
        different handlers.

        Returns:
            A mapping from a route to a handler.
        """
        raise NotImplementedError
Uh oh!
There was an error while loading. Please reload this page.