Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions launch/cli/bundles.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


@click.group("bundles")
@click.pass_context
def bundles(ctx: click.Context):
"""
Bundles is a wrapper around model bundles in Scale Launch
Expand Down
55 changes: 40 additions & 15 deletions launch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def _model_bundle_to_name(model_bundle: Union[ModelBundle, str]) -> str:
raise TypeError("model_bundle should be type ModelBundle or str")


def _model_endpoint_to_name(model_endpoint: Union[ModelEndpoint, str]) -> str:
    """
    Resolve a model endpoint reference to the endpoint's name.

    Parameters:
        model_endpoint: Either a ``ModelEndpoint`` object or a string that
            already is the endpoint name.

    Returns:
        ``model_endpoint.name`` for a ``ModelEndpoint``; the string itself,
        unchanged, for a ``str``.

    Raises:
        TypeError: If ``model_endpoint`` is neither a ``ModelEndpoint`` nor a ``str``.
    """
    # The ModelEndpoint check comes first, matching the original branch order.
    if isinstance(model_endpoint, ModelEndpoint):
        return model_endpoint.name
    if isinstance(model_endpoint, str):
        return model_endpoint
    raise TypeError("model_endpoint should be type ModelEndpoint or str")


def _add_app_config_to_bundle_create_payload(
payload: Dict[str, Any], app_config: Optional[Union[Dict[str, Any], str]]
):
Expand Down Expand Up @@ -591,7 +600,7 @@ def create_model_endpoint(
"""
if update_if_exists and self.model_endpoint_exists(endpoint_name):
self.edit_model_endpoint(
endpoint_name=endpoint_name,
model_endpoint=endpoint_name,
model_bundle=model_bundle,
cpus=cpus,
memory=memory,
Expand Down Expand Up @@ -650,7 +659,7 @@ def create_model_endpoint(

def edit_model_endpoint(
self,
endpoint_name: str,
model_endpoint: Union[ModelEndpoint, str],
model_bundle: Optional[Union[ModelBundle, str]] = None,
cpus: Optional[float] = None,
memory: Optional[str] = None,
Expand All @@ -668,7 +677,7 @@ def edit_model_endpoint(
- The endpoint's type (i.e. you cannot go from a ``SyncEndpoint`` to an ``AsyncEndpoint`` or vice versa).

Parameters:
endpoint_name: The name of the model endpoint you want to create. The name must be unique across
model_endpoint: The model endpoint (or its name) you want to edit. The name must be unique across
all endpoints that you own.

model_bundle: The ``ModelBundle`` that the endpoint should serve.
Expand Down Expand Up @@ -709,6 +718,7 @@ def edit_model_endpoint(
bundle_name = (
_model_bundle_to_name(model_bundle) if model_bundle else None
)
endpoint_name = _model_endpoint_to_name(model_endpoint)
payload = dict(
bundle_name=bundle_name,
cpus=cpus,
Expand Down Expand Up @@ -779,16 +789,19 @@ def list_model_bundles(self) -> List[ModelBundle]:
]
return model_bundles

def get_model_bundle(self, bundle_name: str) -> ModelBundle:
def get_model_bundle(
self, model_bundle: Union[ModelBundle, str]
) -> ModelBundle:
"""
Returns the model bundle specified by ``model_bundle`` (a bundle or its name) that the user owns.

Parameters:
bundle_name: The name of the bundle.
model_bundle: The bundle or its name.

Returns:
A ``ModelBundle`` object
"""
bundle_name = _model_bundle_to_name(model_bundle)
resp = self.connection.get(f"model_bundle/{bundle_name}")
assert (
len(resp["bundles"]) == 1
Expand Down Expand Up @@ -820,36 +833,41 @@ def list_model_endpoints(self) -> List[Endpoint]:
]
return async_endpoints + sync_endpoints

def delete_model_bundle(self, model_bundle: ModelBundle):
def delete_model_bundle(self, model_bundle: Union[ModelBundle, str]):
"""
Deletes the model bundle.

Parameters:
model_bundle: A ``ModelBundle`` object.
model_bundle: A ``ModelBundle`` object or the name of a model bundle.

"""
route = f"model_bundle/{model_bundle.name}"
bundle_name = _model_bundle_to_name(model_bundle)
route = f"model_bundle/{bundle_name}"
resp = self.connection.delete(route)
return resp["deleted"]

def delete_model_endpoint(self, model_endpoint: ModelEndpoint):
def delete_model_endpoint(self, model_endpoint: Union[ModelEndpoint, str]):
"""
Deletes a model endpoint.

Parameters:
model_endpoint: A ``ModelEndpoint`` object or the name of a model endpoint.
"""
route = f"{ENDPOINT_PATH}/{model_endpoint.name}"
endpoint_name = _model_endpoint_to_name(model_endpoint)
route = f"{ENDPOINT_PATH}/{endpoint_name}"
resp = self.connection.delete(route)
return resp["deleted"]

def read_endpoint_creation_logs(self, endpoint_name: str):
def read_endpoint_creation_logs(
self, model_endpoint: Union[ModelEndpoint, str]
):
"""
Retrieves the logs for the creation of the endpoint.

Parameters:
endpoint_name: The name of the endpoint.
model_endpoint: The endpoint or its name.
"""
endpoint_name = _model_endpoint_to_name(model_endpoint)
route = f"{ENDPOINT_PATH}/creation_logs/{endpoint_name}"
resp = self.connection.get(route)
return resp["content"]
Expand Down Expand Up @@ -1076,7 +1094,7 @@ def _get_async_endpoint_response(

def batch_async_request(
self,
bundle_name: str,
model_bundle: Union[ModelBundle, str],
urls: List[str] = None,
inputs: Optional[List[Dict[str, Any]]] = None,
batch_url_file_location: Optional[str] = None,
Expand All @@ -1090,7 +1108,7 @@ def batch_async_request(
Must have exactly one of urls or inputs passed in.

Parameters:
bundle_name: The name of the bundle to use for inference.
model_bundle: The bundle or the name of the bundle to use for inference.

urls: A list of urls, each pointing to a file containing model input.
Must be accessible by Scale Launch, hence urls need to either be public or signedURLs.
Expand All @@ -1111,6 +1129,8 @@ def batch_async_request(
An id/key that can be used to fetch inference results at a later time
"""

bundle_name = _model_bundle_to_name(model_bundle)

if batch_task_options is None:
batch_task_options = {}
allowed_batch_task_options = {
Expand Down Expand Up @@ -1144,7 +1164,12 @@ def batch_async_request(

if self.self_hosted:
# TODO make this not use bundle_location_fn()
file_location = batch_url_file_location or self.batch_csv_location_fn() or self.bundle_location_fn() # type: ignore
location_fn = self.batch_csv_location_fn or self.bundle_location_fn
if location_fn is None and batch_url_file_location is None:
raise ValueError(
"Must register batch_csv_location_fn if csv file location not passed in"
)
file_location = batch_url_file_location or location_fn() # type: ignore
self.upload_batch_csv_fn( # type: ignore
f.getvalue(), file_location
)
Expand Down