Skip to content
This repository was archived by the owner on Aug 11, 2020. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions paperspace/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from paperspace import constants, client, config
from paperspace.cli.common import api_key_option, del_if_value_is_none
from paperspace.cli.jobs import jobs_group
from paperspace.cli.models import models_group
from paperspace.cli.projects import projects_group
from paperspace.cli.types import ChoiceType, json_string
from paperspace.cli.validators import validate_mutually_exclusive, validate_email
Expand Down Expand Up @@ -400,12 +401,22 @@ def create_deployment(api_key=None, **kwargs):
type=ChoiceType(DEPLOYMENT_STATES_MAP, case_sensitive=False),
help="Filter by deployment state",
)
@click.option(
"--projectId",
"projectId",
help="Use to filter by project ID",
)
@click.option(
"--modelId",
"modelId",
help="Use to filter by project ID",
)
@api_key_option
def get_deployments_list(api_key=None, **kwargs):
del_if_value_is_none(kwargs)
def get_deployments_list(api_key=None, **filters):
del_if_value_is_none(filters)
deployments_api = client.API(config.CONFIG_HOST, api_key=api_key)
command = deployments_commands.ListDeploymentsCommand(api=deployments_api)
command.execute(kwargs)
command.execute(filters)


@deployments.command("update", help="Update deployment properties")
Expand Down Expand Up @@ -445,7 +456,7 @@ def get_deployments_list(api_key=None, **kwargs):
def update_deployment_model(id_, api_key, **kwargs):
del_if_value_is_none(kwargs)
deployments_api = client.API(config.CONFIG_HOST, api_key=api_key)
command = deployments_commands.UpdateModelCommand(api=deployments_api)
command = deployments_commands.UpdateDeploymentCommand(api=deployments_api)
command.execute(id_, kwargs)


Expand Down Expand Up @@ -1054,3 +1065,4 @@ def version():

cli.add_command(jobs_group)
cli.add_command(projects_group)
cli.add_command(models_group)
29 changes: 29 additions & 0 deletions paperspace/cli/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import click

from paperspace import client, config
from paperspace.cli import common
from paperspace.commands import models as models_commands


@click.group("models", help="Manage models")
def models_group():
pass


@models_group.command("list", help="List models with optional filtering")
@click.option(
"--experimentId",
"experimentId",
help="Use to filter by experiment ID",
)
@click.option(
"--projectId",
"projectId",
help="Use to filter by project ID",
)
@common.api_key_option
def list_models(api_key, **filters):
common.del_if_value_is_none(filters)
models_api = client.API(config.CONFIG_HOST, api_key=api_key)
command = models_commands.ListModelsCommand(api=models_api)
command.execute(filters)
31 changes: 17 additions & 14 deletions paperspace/commands/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,28 @@ def execute(self, kwargs):


class ListDeploymentsCommand(_DeploymentCommandBase):
def execute(self, kwargs):
json_ = self._get_request_json(kwargs)
def execute(self, filters=None):
json_ = self._get_request_json(filters)
response = self.api.get("/deployments/getDeploymentList/", json=json_)

try:
data = response.json()
if not response.ok:
self.logger.log_error_response(data)
return
deployments = self._get_deployments_list(response)
except (ValueError, KeyError) as e:
self.logger.log("Error while parsing response data: {}".format(e))
else:
self._log_deployments_list(deployments)

@staticmethod
def _get_request_json(kwargs):
state = kwargs.get("state")
if not state:
def _get_request_json(filters):
if not filters:
return None

params = {"filter": {"where": {"and": [{"state": state}]}}}
return params
json_ = {"filter": {"where": {"and": [filters]}}}
return json_

@staticmethod
def _get_deployments_list(response):
Expand Down Expand Up @@ -95,13 +98,13 @@ def _make_deployments_list_table(deployments):
return table_string


class UpdateModelCommand(_DeploymentCommandBase):
def execute(self, model_id, kwargs):
class UpdateDeploymentCommand(_DeploymentCommandBase):
def execute(self, deployment_id, kwargs):
if not kwargs:
self.logger.log("No parameters to update were given. Use --help for more information.")
return

json_ = {"id": model_id,
json_ = {"id": deployment_id,
"upd": kwargs}
response = self.api.post("/deployments/updateDeployment/", json=json_)
self._log_message(response,
Expand All @@ -110,8 +113,8 @@ def execute(self, model_id, kwargs):


class StartDeploymentCommand(_DeploymentCommandBase):
def execute(self, model_id):
json_ = {"id": model_id,
def execute(self, deployment_id):
json_ = {"id": deployment_id,
"isRunning": True}
response = self.api.post("/deployments/updateDeployment/", json=json_)
self._log_message(response,
Expand All @@ -120,8 +123,8 @@ def execute(self, model_id):


class DeleteDeploymentCommand(_DeploymentCommandBase):
def execute(self, model_id):
json_ = {"id": model_id,
def execute(self, deployment_id):
json_ = {"id": deployment_id,
"upd": {"isDeleted": True}}
response = self.api.post("/deployments/updateDeployment/", json=json_)
self._log_message(response,
Expand Down
5 changes: 2 additions & 3 deletions paperspace/commands/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, api=experiments_api, logger_=logger):
self.logger = logger_

def execute(self, project_handles=None):
project_handles = project_handles or []
params = self._get_query_params(project_handles)
response = self.api.get("/experiments/", params=params)

Expand All @@ -68,9 +69,7 @@ def execute(self, project_handles=None):

@staticmethod
def _get_query_params(project_handles):
# TODO: change to limit: -1 when PS-9535 is deployed to production
# to list all experiments
params = {"limit": 1000000}
params = {"limit": -1} # so the API sends back full list without pagination
for i, handle in enumerate(project_handles):
key = "projectHandle[{}]".format(i)
params[key] = handle
Expand Down
64 changes: 64 additions & 0 deletions paperspace/commands/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pydoc

import terminaltables

from paperspace.utils import get_terminal_lines

from paperspace.commands import CommandBase


class ListModelsCommand(CommandBase):
def execute(self, filters):
json_ = self._get_request_json(filters)
params = {"limit": -1} # so the api returns full list without pagination
response = self.api.get("/mlModels/getModelList/", json=json_, params=params)

try:
data = response.json()
if not response.ok:
self.logger.log_error_response(data)
return
models = self._get_objects_list(response)
except (ValueError, KeyError) as e:
self.logger.log("Error while parsing response data: {}".format(e))
else:
self._log_models_list(models)

@staticmethod
def _get_request_json(filters):
if not filters:
return None

json_ = {"filter": {"where": {"and": [filters]}}}
return json_

@staticmethod
def _get_objects_list(response):
data = response.json()["modelList"]
return data

def _log_models_list(self, models):
if not models:
self.logger.log("No models found")
else:
table_str = self._make_models_list_table(models)
if len(table_str.splitlines()) > get_terminal_lines():
pydoc.pager(table_str)
else:
self.logger.log(table_str)

@staticmethod
def _make_models_list_table(models):
data = [("Name", "ID", "Model Type", "Project ID", "Experiment ID")]
for model in models:
name = model.get("name")
id_ = model.get("id")
model_type = model.get("modelType")
project_id = model.get("projectId")
experiment_id = model.get("experimentId")
data.append((name, id_, model_type, project_id, experiment_id))

ascii_table = terminaltables.AsciiTable(data)
table_string = ascii_table.table
return table_string

2 changes: 1 addition & 1 deletion paperspace/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_api_key(config_dir_path, config_file_name):

_DEFAULT_CONFIG_HOST = "https://api.paperspace.io"
_DEFAULT_CONFIG_LOG_HOST = "https://logs.paperspace.io"
_DEFAULT_CONFIG_EXPERIMENTS_HOST = "https://services.paperspace.io/experiments/v1/" # TODO: validate this
_DEFAULT_CONFIG_EXPERIMENTS_HOST = "https://services.paperspace.io/experiments/v1/"
_DEFAULT_CONFIG_DIR_PATH = "~/.paperspace"
_DEFAULT_CONFIG_FILE_NAME = os.path.expanduser("config.json")

Expand Down
2 changes: 1 addition & 1 deletion paperspace/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def main():
if len(sys.argv) >= 2 and sys.argv[1] in ('experiments', 'deployments', 'machines', 'login', 'logout', 'version',
'projects', 'jobs'):
'projects', 'jobs', 'models'):
cli(sys.argv[1:])

args = sys.argv[:]
Expand Down
Loading