Skip to content

Commit 165f757

Browse files
Seanshi/20220701 bugbash (#41)
Fix cli bug when running scale-launch ... bundles Fix None isn't callable bug in determining csv name when creating batch job Allow passing in strings to functions that ask for ModelBundle/ModelEndpoint (where it makes sense) Log exceptions and return None on get_model_{bundle, endpoint}, get_batch_async_response on a nonexistent name/id
1 parent a1ee1b3 commit 165f757

File tree

2 files changed

+41
-15
lines changed

2 files changed

+41
-15
lines changed

launch/cli/bundles.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88

99
@click.group("bundles")
10+
@click.pass_context
1011
def bundles(ctx: click.Context):
1112
"""
1213
Bundles is a wrapper around model bundles in Scale Launch

launch/client.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ def _model_bundle_to_name(model_bundle: Union[ModelBundle, str]) -> str:
5858
raise TypeError("model_bundle should be type ModelBundle or str")
5959

6060

61+
def _model_endpoint_to_name(model_endpoint: Union[ModelEndpoint, str]) -> str:
62+
if isinstance(model_endpoint, ModelEndpoint):
63+
return model_endpoint.name
64+
elif isinstance(model_endpoint, str):
65+
return model_endpoint
66+
else:
67+
raise TypeError("model_endpoint should be type ModelEndpoint or str")
68+
69+
6170
def _add_app_config_to_bundle_create_payload(
6271
payload: Dict[str, Any], app_config: Optional[Union[Dict[str, Any], str]]
6372
):
@@ -591,7 +600,7 @@ def create_model_endpoint(
591600
"""
592601
if update_if_exists and self.model_endpoint_exists(endpoint_name):
593602
self.edit_model_endpoint(
594-
endpoint_name=endpoint_name,
603+
model_endpoint=endpoint_name,
595604
model_bundle=model_bundle,
596605
cpus=cpus,
597606
memory=memory,
@@ -650,7 +659,7 @@ def create_model_endpoint(
650659

651660
def edit_model_endpoint(
652661
self,
653-
endpoint_name: str,
662+
model_endpoint: Union[ModelEndpoint, str],
654663
model_bundle: Optional[Union[ModelBundle, str]] = None,
655664
cpus: Optional[float] = None,
656665
memory: Optional[str] = None,
@@ -668,7 +677,7 @@ def edit_model_endpoint(
668677
- The endpoint's type (i.e. you cannot go from a ``SyncEnpdoint`` to an ``AsyncEndpoint`` or vice versa.
669678
670679
Parameters:
671-
endpoint_name: The name of the model endpoint you want to create. The name must be unique across
680+
model_endpoint: The model endpoint (or its name) you want to edit. The name must be unique across
672681
all endpoints that you own.
673682
674683
model_bundle: The ``ModelBundle`` that the endpoint should serve.
@@ -709,6 +718,7 @@ def edit_model_endpoint(
709718
bundle_name = (
710719
_model_bundle_to_name(model_bundle) if model_bundle else None
711720
)
721+
endpoint_name = _model_endpoint_to_name(model_endpoint)
712722
payload = dict(
713723
bundle_name=bundle_name,
714724
cpus=cpus,
@@ -779,16 +789,19 @@ def list_model_bundles(self) -> List[ModelBundle]:
779789
]
780790
return model_bundles
781791

782-
def get_model_bundle(self, bundle_name: str) -> ModelBundle:
792+
def get_model_bundle(
793+
self, model_bundle: Union[ModelBundle, str]
794+
) -> ModelBundle:
783795
"""
784796
Returns a model bundle specified by ``bundle_name`` that the user owns.
785797
786798
Parameters:
787-
bundle_name: The name of the bundle.
799+
model_bundle: The bundle or its name.
788800
789801
Returns:
790802
A ``ModelBundle`` object
791803
"""
804+
bundle_name = _model_bundle_to_name(model_bundle)
792805
resp = self.connection.get(f"model_bundle/{bundle_name}")
793806
assert (
794807
len(resp["bundles"]) == 1
@@ -820,36 +833,41 @@ def list_model_endpoints(self) -> List[Endpoint]:
820833
]
821834
return async_endpoints + sync_endpoints
822835

823-
def delete_model_bundle(self, model_bundle: ModelBundle):
836+
def delete_model_bundle(self, model_bundle: Union[ModelBundle, str]):
824837
"""
825838
Deletes the model bundle.
826839
827840
Parameters:
828-
model_bundle: A ``ModelBundle`` object.
841+
model_bundle: A ``ModelBundle`` object or the name of a model bundle.
829842
830843
"""
831-
route = f"model_bundle/{model_bundle.name}"
844+
bundle_name = _model_bundle_to_name(model_bundle)
845+
route = f"model_bundle/{bundle_name}"
832846
resp = self.connection.delete(route)
833847
return resp["deleted"]
834848

835-
def delete_model_endpoint(self, model_endpoint: ModelEndpoint):
849+
def delete_model_endpoint(self, model_endpoint: Union[ModelEndpoint, str]):
836850
"""
837851
Deletes a model endpoint.
838852
839853
Parameters:
840854
model_endpoint: A ``ModelEndpoint`` object.
841855
"""
842-
route = f"{ENDPOINT_PATH}/{model_endpoint.name}"
856+
endpoint_name = _model_endpoint_to_name(model_endpoint)
857+
route = f"{ENDPOINT_PATH}/{endpoint_name}"
843858
resp = self.connection.delete(route)
844859
return resp["deleted"]
845860

846-
def read_endpoint_creation_logs(self, endpoint_name: str):
861+
def read_endpoint_creation_logs(
862+
self, model_endpoint: Union[ModelEndpoint, str]
863+
):
847864
"""
848865
Retrieves the logs for the creation of the endpoint.
849866
850867
Parameters:
851-
endpoint_name: The name of the endpoint.
868+
model_endpoint: The endpoint or its name.
852869
"""
870+
endpoint_name = _model_endpoint_to_name(model_endpoint)
853871
route = f"{ENDPOINT_PATH}/creation_logs/{endpoint_name}"
854872
resp = self.connection.get(route)
855873
return resp["content"]
@@ -1076,7 +1094,7 @@ def _get_async_endpoint_response(
10761094

10771095
def batch_async_request(
10781096
self,
1079-
bundle_name: str,
1097+
model_bundle: Union[ModelBundle, str],
10801098
urls: List[str] = None,
10811099
inputs: Optional[List[Dict[str, Any]]] = None,
10821100
batch_url_file_location: Optional[str] = None,
@@ -1090,7 +1108,7 @@ def batch_async_request(
10901108
Must have exactly one of urls or inputs passed in.
10911109
10921110
Parameters:
1093-
bundle_name: The name of the bundle to use for inference.
1111+
model_bundle: The bundle or the name of a the bundle to use for inference.
10941112
10951113
urls: A list of urls, each pointing to a file containing model input.
10961114
Must be accessible by Scale Launch, hence urls need to either be public or signedURLs.
@@ -1111,6 +1129,8 @@ def batch_async_request(
11111129
An id/key that can be used to fetch inference results at a later time
11121130
"""
11131131

1132+
bundle_name = _model_bundle_to_name(model_bundle)
1133+
11141134
if batch_task_options is None:
11151135
batch_task_options = {}
11161136
allowed_batch_task_options = {
@@ -1144,7 +1164,12 @@ def batch_async_request(
11441164

11451165
if self.self_hosted:
11461166
# TODO make this not use bundle_location_fn()
1147-
file_location = batch_url_file_location or self.batch_csv_location_fn() or self.bundle_location_fn() # type: ignore
1167+
location_fn = self.batch_csv_location_fn or self.bundle_location_fn
1168+
if location_fn is None and batch_url_file_location is None:
1169+
raise ValueError(
1170+
"Must register batch_csv_location_fn if csv file location not passed in"
1171+
)
1172+
file_location = batch_url_file_location or location_fn() # type: ignore
11481173
self.upload_batch_csv_fn( # type: ignore
11491174
f.getvalue(), file_location
11501175
)

0 commit comments

Comments
 (0)