3
3
import os
4
4
import shutil
5
5
import tempfile
6
+ from io import StringIO
6
7
from typing import Any , Callable , Dict , List , Optional , TypeVar , Union
7
8
8
9
import cloudpickle
13
14
from launch .constants import (
14
15
ASYNC_TASK_PATH ,
15
16
ASYNC_TASK_RESULT_PATH ,
17
+ BATCH_TASK_INPUT_SIGNED_URL_PATH ,
18
+ BATCH_TASK_PATH ,
19
+ BATCH_TASK_RESULTS_PATH ,
16
20
ENDPOINT_PATH ,
17
21
MODEL_BUNDLE_SIGNED_URL_PATH ,
18
22
SCALE_LAUNCH_ENDPOINT ,
19
23
SYNC_TASK_PATH ,
20
24
)
21
25
from launch .errors import APIError
22
26
from launch .find_packages import find_packages_from_imports , get_imports
27
+ from launch .make_batch_file import make_batch_input_file
23
28
from launch .model_bundle import ModelBundle
24
29
from launch .model_endpoint import (
25
30
AsyncEndpoint ,
@@ -84,6 +89,7 @@ def __init__(
84
89
self .connection = Connection (api_key , endpoint )
85
90
self .self_hosted = self_hosted
86
91
self .upload_bundle_fn : Optional [Callable [[str , str ], None ]] = None
92
+ self .upload_batch_csv_fn : Optional [Callable [[str , str ], None ]] = None
87
93
self .endpoint_auth_decorator_fn : Callable [
88
94
[Dict [str , Any ]], Dict [str , Any ]
89
95
] = lambda x : x
@@ -108,11 +114,27 @@ def register_upload_bundle_fn(
108
114
See register_bundle_location_fn for more notes on the signature of upload_bundle_fn
109
115
110
116
Parameters:
111
- upload_bundle_fn: Function that takes in a serialized bundle, and uploads that bundle to an appropriate
117
+ upload_bundle_fn: Function that takes in a serialized bundle (bytes type) , and uploads that bundle to an appropriate
112
118
location. Only needed for self-hosted mode.
113
119
"""
114
120
self .upload_bundle_fn = upload_bundle_fn
115
121
122
def register_upload_batch_csv_fn(
    self, upload_batch_csv_fn: Callable[[str, str], None]
):
    """
    For self-hosted mode only. Registers a function that handles batch csv
    upload. This function is called as

        upload_batch_csv_fn(csv_text, csv_url)

    This function should directly write the contents of csv_text as a text
    string into csv_url.

    Parameters:
        upload_batch_csv_fn: Function that takes in a csv text (string type),
            and uploads that csv text to an appropriate location. Only needed
            for self-hosted mode.
    """
    # Stored for later use by batch_async_request() in self-hosted mode.
    self.upload_batch_csv_fn = upload_batch_csv_fn
137
+
116
138
def register_bundle_location_fn (
117
139
self , bundle_location_fn : Callable [[], str ]
118
140
):
@@ -744,20 +766,62 @@ def get_async_response(self, async_task_id: str) -> Dict[str, Any]:
744
766
)
745
767
return resp
746
768
747
def batch_async_request(
    self,
    bundle_name: str,
    urls: List[str],
    batch_url_file_location: Optional[str] = None,
    serialization_format: str = "json",
):
    """
    Sends a batch inference request against the bundle named ``bundle_name``,
    and returns a job id that can be used to retrieve the results of
    inference at a later time.

    Parameters:
        bundle_name: The name of the bundle to make the request to.
        urls: A list of urls, each pointing to a file containing model input.
            Must be accessible by Scale Launch, hence urls need to either be
            public or signedURLs.
        batch_url_file_location: In self-hosted mode, the input to the batch
            job will be uploaded to this location if provided. Otherwise, one
            will be determined from bundle_location_fn().
        serialization_format: Serialization format of output, either 'pickle'
            or 'json'. 'pickle' corresponds to pickling results + returning.

    Returns:
        An id/key that can be used to fetch inference results at a later time.
    """
    # Serialize the input urls into an in-memory csv.
    f = StringIO()
    make_batch_input_file(urls, f)
    f.seek(0)

    if self.self_hosted:
        # TODO make this not use bundle_location_fn()
        if batch_url_file_location is None:
            file_location = self.bundle_location_fn()  # type: ignore
        else:
            file_location = batch_url_file_location
        # NOTE(review): assumes register_upload_batch_csv_fn() was called
        # beforehand in self-hosted mode — otherwise this is None.
        self.upload_batch_csv_fn(  # type: ignore
            f.getvalue(), file_location
        )
    else:
        # Hosted mode: ask the server for a signed url, then PUT the csv there.
        model_bundle_s3_url = self.connection.post(
            {}, BATCH_TASK_INPUT_SIGNED_URL_PATH
        )
        s3_path = model_bundle_s3_url["signedUrl"]
        requests.put(s3_path, data=f.getvalue())
        file_location = f"s3://{model_bundle_s3_url['bucket']}/{model_bundle_s3_url['key']}"

    logger.info("Writing batch task csv to %s", file_location)

    payload = dict(
        input_path=file_location,
        serialization_format=serialization_format,
    )
    # Attach any registered auth fields (identity by default).
    payload = self.endpoint_auth_decorator_fn(payload)
    resp = self.connection.post(
        route=f"{BATCH_TASK_PATH}/{bundle_name}",
        payload=payload,
    )
    return resp["job_id"]
761
825
762
826
def get_batch_async_response(self, batch_async_task_id: str):
    """
    Fetches the response for a previously submitted batch inference job.

    Parameters:
        batch_async_task_id: The job id returned by batch_async_request().

    Returns:
        TODO Something similar to a list of signed s3URLs
    """
    # Single GET against the batch-results route for this job id.
    results_route = f"{BATCH_TASK_RESULTS_PATH}/{batch_async_task_id}"
    return self.connection.get(route=results_route)
0 commit comments