@@ -791,8 +791,9 @@ def batch_async_request(
791
791
self ,
792
792
bundle_name : str ,
793
793
urls : List [str ],
794
- batch_url_file_location : str = None ,
794
+ batch_url_file_location : Optional [ str ] = None ,
795
795
serialization_format : str = "json" ,
796
+ batch_task_options : Optional [Dict [str , Any ]] = None ,
796
797
):
797
798
"""
798
799
Sends a batch inference request to the Model Endpoint at endpoint_id, returns a key that can be used to retrieve
@@ -806,10 +807,33 @@ def batch_async_request(
806
807
Must be accessible by Scale Launch, hence urls need to either be public or signedURLs.
807
808
batch_url_file_location: In self-hosted mode, the input to the batch job will be uploaded
808
809
to this location if provided. Otherwise, one will be determined from bundle_location_fn()
810
+ batch_task_options: A Dict of optional endpoint/batch task settings, i.e. certain endpoint settings
811
+ like cpus, memory, gpus, gpu_type, max_workers, as well as under-the-hood batch job settings, like
812
+ pyspark_partition_size, pyspark_max_executors.
809
813
810
814
Returns:
811
815
An id/key that can be used to fetch inference results at a later time
812
816
"""
817
+
818
+ if batch_task_options is None :
819
+ batch_task_options = {}
820
+ allowed_batch_task_options = {
821
+ "cpus" ,
822
+ "memory" ,
823
+ "gpus" ,
824
+ "gpu_type" ,
825
+ "max_workers" ,
826
+ "pyspark_partition_size" ,
827
+ "pyspark_max_executors" ,
828
+ }
829
+ if (
830
+ len (set (batch_task_options .keys ()) - allowed_batch_task_options )
831
+ > 0
832
+ ):
833
+ raise ValueError (
834
+ f"Disallowed options { set (batch_task_options .keys ()) - allowed_batch_task_options } for batch task"
835
+ )
836
+
813
837
f = StringIO ()
814
838
make_batch_input_file (urls , f )
815
839
f .seek (0 )
@@ -837,6 +861,7 @@ def batch_async_request(
837
861
input_path = file_location ,
838
862
serialization_format = serialization_format ,
839
863
)
864
+ payload .update (batch_task_options )
840
865
payload = self .endpoint_auth_decorator_fn (payload )
841
866
resp = self .connection .post (
842
867
route = f"{ BATCH_TASK_PATH } /{ bundle_name } " ,
0 commit comments