@@ -58,6 +58,15 @@ def _model_bundle_to_name(model_bundle: Union[ModelBundle, str]) -> str:
58
58
raise TypeError ("model_bundle should be type ModelBundle or str" )
59
59
60
60
61
def _model_endpoint_to_name(model_endpoint: Union[ModelEndpoint, str]) -> str:
    """
    Resolves an endpoint reference to the endpoint's name.

    Parameters:
        model_endpoint: A ``ModelEndpoint`` object or the name of an endpoint.

    Returns:
        The ``name`` attribute when given a ``ModelEndpoint``; the string
        itself when given a ``str``.

    Raises:
        TypeError: If ``model_endpoint`` is neither a ``ModelEndpoint`` nor a ``str``.
    """
    # Guard clauses, checked in the same order as before: object first, then str.
    if isinstance(model_endpoint, ModelEndpoint):
        return model_endpoint.name
    if isinstance(model_endpoint, str):
        return model_endpoint
    raise TypeError("model_endpoint should be type ModelEndpoint or str")
68
+
69
+
61
70
def _add_app_config_to_bundle_create_payload (
62
71
payload : Dict [str , Any ], app_config : Optional [Union [Dict [str , Any ], str ]]
63
72
):
@@ -591,7 +600,7 @@ def create_model_endpoint(
591
600
"""
592
601
if update_if_exists and self .model_endpoint_exists (endpoint_name ):
593
602
self .edit_model_endpoint (
594
- endpoint_name = endpoint_name ,
603
+ model_endpoint = endpoint_name ,
595
604
model_bundle = model_bundle ,
596
605
cpus = cpus ,
597
606
memory = memory ,
@@ -650,7 +659,7 @@ def create_model_endpoint(
650
659
651
660
def edit_model_endpoint (
652
661
self ,
653
- endpoint_name : str ,
662
+ model_endpoint : Union [ ModelEndpoint , str ] ,
654
663
model_bundle : Optional [Union [ModelBundle , str ]] = None ,
655
664
cpus : Optional [float ] = None ,
656
665
memory : Optional [str ] = None ,
@@ -668,7 +677,7 @@ def edit_model_endpoint(
668
677
- The endpoint's type (i.e. you cannot go from a ``SyncEndpoint`` to an ``AsyncEndpoint`` or vice versa).
669
678
670
679
Parameters:
671
- endpoint_name : The name of the model endpoint you want to create . The name must be unique across
680
+ model_endpoint : The model endpoint (or its name) you want to edit . The name must be unique across
672
681
all endpoints that you own.
673
682
674
683
model_bundle: The ``ModelBundle`` that the endpoint should serve.
@@ -709,6 +718,7 @@ def edit_model_endpoint(
709
718
bundle_name = (
710
719
_model_bundle_to_name (model_bundle ) if model_bundle else None
711
720
)
721
+ endpoint_name = _model_endpoint_to_name (model_endpoint )
712
722
payload = dict (
713
723
bundle_name = bundle_name ,
714
724
cpus = cpus ,
@@ -779,16 +789,19 @@ def list_model_bundles(self) -> List[ModelBundle]:
779
789
]
780
790
return model_bundles
781
791
782
- def get_model_bundle (self , bundle_name : str ) -> ModelBundle :
792
+ def get_model_bundle (
793
+ self , model_bundle : Union [ModelBundle , str ]
794
+ ) -> ModelBundle :
783
795
"""
784
796
Returns the model bundle specified by ``model_bundle`` (a bundle or its name) that the user owns.
785
797
786
798
Parameters:
787
- bundle_name : The name of the bundle .
799
+ model_bundle : The bundle or its name .
788
800
789
801
Returns:
790
802
A ``ModelBundle`` object
791
803
"""
804
+ bundle_name = _model_bundle_to_name (model_bundle )
792
805
resp = self .connection .get (f"model_bundle/{ bundle_name } " )
793
806
assert (
794
807
len (resp ["bundles" ]) == 1
@@ -820,36 +833,41 @@ def list_model_endpoints(self) -> List[Endpoint]:
820
833
]
821
834
return async_endpoints + sync_endpoints
822
835
823
def delete_model_bundle(self, model_bundle: Union[ModelBundle, str]):
    """
    Deletes a model bundle.

    Parameters:
        model_bundle: A ``ModelBundle`` object or the name of a model bundle.

    Returns:
        The ``deleted`` field of the response to the delete request.
    """
    # Accept either the object or its name; the route is keyed by name.
    response = self.connection.delete(
        f"model_bundle/{_model_bundle_to_name(model_bundle)}"
    )
    return response["deleted"]
834
848
835
def delete_model_endpoint(self, model_endpoint: Union[ModelEndpoint, str]):
    """
    Deletes a model endpoint.

    Parameters:
        model_endpoint: A ``ModelEndpoint`` object or the name of a model endpoint.

    Returns:
        The ``deleted`` field of the response to the delete request.
    """
    # Accept either the object or its name; the route is keyed by name.
    response = self.connection.delete(
        f"{ENDPOINT_PATH}/{_model_endpoint_to_name(model_endpoint)}"
    )
    return response["deleted"]
845
860
846
def read_endpoint_creation_logs(
    self, model_endpoint: Union[ModelEndpoint, str]
):
    """
    Retrieves the logs for the creation of the endpoint.

    Parameters:
        model_endpoint: The endpoint or its name.

    Returns:
        The ``content`` field of the response to the creation-logs request.
    """
    # Accept either the object or its name; the route is keyed by name.
    response = self.connection.get(
        f"{ENDPOINT_PATH}/creation_logs/{_model_endpoint_to_name(model_endpoint)}"
    )
    return response["content"]
@@ -1076,7 +1094,7 @@ def _get_async_endpoint_response(
1076
1094
1077
1095
def batch_async_request (
1078
1096
self ,
1079
- bundle_name : str ,
1097
+ model_bundle : Union [ ModelBundle , str ] ,
1080
1098
urls : List [str ] = None ,
1081
1099
inputs : Optional [List [Dict [str , Any ]]] = None ,
1082
1100
batch_url_file_location : Optional [str ] = None ,
@@ -1090,7 +1108,7 @@ def batch_async_request(
1090
1108
Must have exactly one of urls or inputs passed in.
1091
1109
1092
1110
Parameters:
1093
- bundle_name : The name of the bundle to use for inference.
1111
+ model_bundle : The bundle or the name of the bundle to use for inference.
1094
1112
1095
1113
urls: A list of urls, each pointing to a file containing model input.
1096
1114
Must be accessible by Scale Launch, hence urls need to either be public or signedURLs.
@@ -1111,6 +1129,8 @@ def batch_async_request(
1111
1129
An id/key that can be used to fetch inference results at a later time
1112
1130
"""
1113
1131
1132
+ bundle_name = _model_bundle_to_name (model_bundle )
1133
+
1114
1134
if batch_task_options is None :
1115
1135
batch_task_options = {}
1116
1136
allowed_batch_task_options = {
@@ -1144,7 +1164,12 @@ def batch_async_request(
1144
1164
1145
1165
if self .self_hosted :
1146
1166
# TODO make this not use bundle_location_fn()
1147
- file_location = batch_url_file_location or self .batch_csv_location_fn () or self .bundle_location_fn () # type: ignore
1167
+ location_fn = self .batch_csv_location_fn or self .bundle_location_fn
1168
+ if location_fn is None and batch_url_file_location is None :
1169
+ raise ValueError (
1170
+ "Must register batch_csv_location_fn if csv file location not passed in"
1171
+ )
1172
+ file_location = batch_url_file_location or location_fn () # type: ignore
1148
1173
self .upload_batch_csv_fn ( # type: ignore
1149
1174
f .getvalue (), file_location
1150
1175
)
0 commit comments