Skip to content

Commit 8909dfd

Browse files
Tweakable batch inference params (#21)
1 parent 15a5532 commit 8909dfd

File tree

4 files changed

+37
-1
lines changed

4 files changed

+37
-1
lines changed

.flake8

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ max-line-length = 79
44
max-complexity = 18
55
select = B,C,E,F,W,T4,B9
66
exclude =
7+
# clientlib is documentation only, not runnable code
8+
launch/clientlib
79
# All of these excludes should mirror something in .gitignore
810
.git,
911
__pycache__,

.pylintrc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ reports=no
2424

2525
[tool.pylint.FORMAT]
2626
max-line-length=79
27+
28+
[MASTER]
29+
# Ignore anything inside launch/clientlib (since it's documentation)
30+
ignore=clientlib

launch/client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -791,8 +791,9 @@ def batch_async_request(
791791
self,
792792
bundle_name: str,
793793
urls: List[str],
794-
batch_url_file_location: str = None,
794+
batch_url_file_location: Optional[str] = None,
795795
serialization_format: str = "json",
796+
batch_task_options: Optional[Dict[str, Any]] = None,
796797
):
797798
"""
798799
Sends a batch inference request to the Model Endpoint at endpoint_id, returns a key that can be used to retrieve
@@ -806,10 +807,33 @@ def batch_async_request(
806807
Must be accessible by Scale Launch, hence URLs need to either be public or signed URLs.
807808
batch_url_file_location: In self-hosted mode, the input to the batch job will be uploaded
808809
to this location if provided. Otherwise, one will be determined from bundle_location_fn()
810+
batch_task_options: A Dict of optional endpoint/batch task settings, e.g. certain endpoint settings
811+
like cpus, memory, gpus, gpu_type, max_workers, as well as under-the-hood batch job settings, like
812+
pyspark_partition_size, pyspark_max_executors.
809813
810814
Returns:
811815
An id/key that can be used to fetch inference results at a later time
812816
"""
817+
818+
if batch_task_options is None:
819+
batch_task_options = {}
820+
allowed_batch_task_options = {
821+
"cpus",
822+
"memory",
823+
"gpus",
824+
"gpu_type",
825+
"max_workers",
826+
"pyspark_partition_size",
827+
"pyspark_max_executors",
828+
}
829+
if (
830+
len(set(batch_task_options.keys()) - allowed_batch_task_options)
831+
> 0
832+
):
833+
raise ValueError(
834+
f"Disallowed options {set(batch_task_options.keys()) - allowed_batch_task_options} for batch task"
835+
)
836+
813837
f = StringIO()
814838
make_batch_input_file(urls, f)
815839
f.seek(0)
@@ -837,6 +861,7 @@ def batch_async_request(
837861
input_path=file_location,
838862
serialization_format=serialization_format,
839863
)
864+
payload.update(batch_task_options)
840865
payload = self.endpoint_auth_decorator_fn(payload)
841866
resp = self.connection.post(
842867
route=f"{BATCH_TASK_PATH}/{bundle_name}",

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ exclude = '''
1919
)
2020
'''
2121

22+
[tool.mypy]
23+
exclude = [
24+
'^launch/clientlib/'
25+
]
26+
2227
[tool.poetry]
2328
name = "scale-launch"
2429
version = "0.1.0"

0 commit comments

Comments
 (0)