
Commit f86e66c

Batch inference uses bundles, not endpoints (and also functions) (#18)
* does batch task work? let's find out!
* typo
* make batch file
* kinda wip: pass in a list of urls instead of a file to batch inference
* batch file location
* comment the deficiencies
* mypy
* rename route to match TDD
* use the batch task input signedURL endpoint
* make the change in the right code place, oops
1 parent dbb0e77 commit f86e66c

4 files changed: +105 -5 lines changed

launch/client.py

Lines changed: 72 additions & 5 deletions
@@ -3,6 +3,7 @@
 import os
 import shutil
 import tempfile
+from io import StringIO
 from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

 import cloudpickle
@@ -13,13 +14,17 @@
 from launch.constants import (
     ASYNC_TASK_PATH,
     ASYNC_TASK_RESULT_PATH,
+    BATCH_TASK_INPUT_SIGNED_URL_PATH,
+    BATCH_TASK_PATH,
+    BATCH_TASK_RESULTS_PATH,
     ENDPOINT_PATH,
     MODEL_BUNDLE_SIGNED_URL_PATH,
     SCALE_LAUNCH_ENDPOINT,
     SYNC_TASK_PATH,
 )
 from launch.errors import APIError
 from launch.find_packages import find_packages_from_imports, get_imports
+from launch.make_batch_file import make_batch_input_file
 from launch.model_bundle import ModelBundle
 from launch.model_endpoint import (
     AsyncEndpoint,
@@ -84,6 +89,7 @@ def __init__(
         self.connection = Connection(api_key, endpoint)
         self.self_hosted = self_hosted
         self.upload_bundle_fn: Optional[Callable[[str, str], None]] = None
+        self.upload_batch_csv_fn: Optional[Callable[[str, str], None]] = None
         self.endpoint_auth_decorator_fn: Callable[
             [Dict[str, Any]], Dict[str, Any]
         ] = lambda x: x
@@ -108,11 +114,27 @@ def register_upload_bundle_fn(
         See register_bundle_location_fn for more notes on the signature of upload_bundle_fn

         Parameters:
-            upload_bundle_fn: Function that takes in a serialized bundle, and uploads that bundle to an appropriate
+            upload_bundle_fn: Function that takes in a serialized bundle (bytes type) and uploads that bundle to an appropriate
                 location. Only needed for self-hosted mode.
         """
         self.upload_bundle_fn = upload_bundle_fn

+    def register_upload_batch_csv_fn(
+        self, upload_batch_csv_fn: Callable[[str, str], None]
+    ):
+        """
+        For self-hosted mode only. Registers a function that handles batch text upload. This function is called as
+
+            upload_batch_csv_fn(csv_text, csv_url)
+
+        This function should directly write the contents of csv_text as a text string into csv_url.
+
+        Parameters:
+            upload_batch_csv_fn: Function that takes in CSV text (string type) and uploads that text to an appropriate
+                location. Only needed for self-hosted mode.
+        """
+        self.upload_batch_csv_fn = upload_batch_csv_fn
+
     def register_bundle_location_fn(
         self, bundle_location_fn: Callable[[], str]
     ):
@@ -744,20 +766,62 @@ def get_async_response(self, async_task_id: str) -> Dict[str, Any]:
         )
         return resp

-    def batch_async_request(self, endpoint_id: str, urls: List[str]):
+    def batch_async_request(
+        self,
+        bundle_name: str,
+        urls: List[str],
+        batch_url_file_location: Optional[str] = None,
+        serialization_format: str = "json",
+    ):
         """
         Sends a batch inference request to the Model Endpoint at endpoint_id, returns a key that can be used to retrieve
         the results of inference at a later time.

         Parameters:
-            endpoint_id: The id of the endpoint to make the request to
+            bundle_name: The name of the bundle to make the request to
+            serialization_format: Serialization format of output, either 'pickle' or 'json'.
+                'pickle' corresponds to pickling results + returning
             urls: A list of urls, each pointing to a file containing model input.
                 Must be accessible by Scale Launch, hence urls need to either be public or signedURLs.
+            batch_url_file_location: In self-hosted mode, the input to the batch job will be uploaded
+                to this location if provided. Otherwise, one will be determined from bundle_location_fn()

         Returns:
             An id/key that can be used to fetch inference results at a later time
         """
-        raise NotImplementedError
+        f = StringIO()
+        make_batch_input_file(urls, f)
+        f.seek(0)
+
+        if self.self_hosted:
+            # TODO make this not use bundle_location_fn()
+            if batch_url_file_location is None:
+                file_location = self.bundle_location_fn()  # type: ignore
+            else:
+                file_location = batch_url_file_location
+            self.upload_batch_csv_fn(  # type: ignore
+                f.getvalue(), file_location
+            )
+        else:
+            model_bundle_s3_url = self.connection.post(
+                {}, BATCH_TASK_INPUT_SIGNED_URL_PATH
+            )
+            s3_path = model_bundle_s3_url["signedUrl"]
+            requests.put(s3_path, data=f.getvalue())
+            file_location = f"s3://{model_bundle_s3_url['bucket']}/{model_bundle_s3_url['key']}"
+
+        logger.info("Writing batch task csv to %s", file_location)
+
+        payload = dict(
+            input_path=file_location,
+            serialization_format=serialization_format,
+        )
+        payload = self.endpoint_auth_decorator_fn(payload)
+        resp = self.connection.post(
+            route=f"{BATCH_TASK_PATH}/{bundle_name}",
+            payload=payload,
+        )
+        return resp["job_id"]

     def get_batch_async_response(self, batch_async_task_id: str):
         """
@@ -770,4 +834,7 @@ def get_batch_async_response(self, batch_async_task_id: str):
         Returns:
             TODO Something similar to a list of signed s3URLs
         """
-        raise NotImplementedError
+        resp = self.connection.get(
+            route=f"{BATCH_TASK_RESULTS_PATH}/{batch_async_task_id}"
+        )
+        return resp
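
Taken together, batch_async_request now covers both modes: hosted users get a signed URL from BATCH_TASK_INPUT_SIGNED_URL_PATH and the client PUTs the CSV there itself, while self-hosted users supply their own uploader via register_upload_batch_csv_fn. A minimal usage sketch of the self-hosted path follows; the LaunchClient class name, the boto3 uploader, and all bucket/bundle names are illustrative assumptions, not part of this commit.

import boto3

from launch.client import LaunchClient  # assumed import path for the client class

client = LaunchClient(api_key="fake-key", self_hosted=True)

def upload_batch_csv_fn(csv_text: str, csv_url: str):
    # Hypothetical uploader matching the (csv_text, csv_url) signature:
    # writes the CSV text verbatim to an s3://bucket/key location.
    bucket, key = csv_url[len("s3://"):].split("/", 1)
    boto3.client("s3").put_object(Bucket=bucket, Key=key, Body=csv_text)

client.register_upload_batch_csv_fn(upload_batch_csv_fn)

job_id = client.batch_async_request(
    bundle_name="my-bundle",  # assumed bundle name
    urls=["s3://my-bucket/in0.json", "s3://my-bucket/in1.json"],
    batch_url_file_location="s3://my-bucket/batch-input.csv",
    serialization_format="json",
)
results = client.get_batch_async_response(job_id)  # poll later with the returned job id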

launch/constants.py

Lines changed: 3 additions & 0 deletions
@@ -1,8 +1,11 @@
 ENDPOINT_PATH = "endpoints"
 MODEL_BUNDLE_SIGNED_URL_PATH = "model_bundle_upload"
+BATCH_TASK_INPUT_SIGNED_URL_PATH = "batch_task_input_upload"
 ASYNC_TASK_PATH = "task_async"
 ASYNC_TASK_RESULT_PATH = "task/result"
 SYNC_TASK_PATH = "task_sync"
+BATCH_TASK_PATH = "batch_job"
+BATCH_TASK_RESULTS_PATH = "batch_job"
 SCALE_LAUNCH_ENDPOINT = "https://api.scale.com/v1/hosted_inference"

 DEFAULT_NETWORK_TIMEOUT_SEC = 120
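
Note that BATCH_TASK_PATH and BATCH_TASK_RESULTS_PATH currently resolve to the same "batch_job" route; job submission and result polling differ only in HTTP verb and path suffix. A small sketch of the routes the client code above composes (the bundle name and job id here are made up):

from launch.constants import BATCH_TASK_PATH, BATCH_TASK_RESULTS_PATH

# POST here to submit a batch job for a bundle...
submit_route = f"{BATCH_TASK_PATH}/my-bundle"        # -> "batch_job/my-bundle"
# ...and GET here to poll its results by job id.
results_route = f"{BATCH_TASK_RESULTS_PATH}/abc123"  # -> "batch_job/abc123"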

launch/make_batch_file.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+import csv
+from typing import IO, List
+
+
+def make_batch_input_file(urls: List[str], file: IO[str]):
+    writer = csv.DictWriter(file, fieldnames=["id", "url"])
+    writer.writeheader()
+    for i, url in enumerate(urls):
+        writer.writerow({"id": i, "url": url})
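
The helper simply emits an id,url CSV with one row per input URL. A quick sketch of its output (the URLs are made up; note that csv.DictWriter terminates rows with \r\n by default):

from io import StringIO

from launch.make_batch_file import make_batch_input_file

f = StringIO()
make_batch_input_file(["s3://bucket/a.json", "s3://bucket/b.json"], f)
print(f.getvalue())
# id,url
# 0,s3://bucket/a.json
# 1,s3://bucket/b.json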

tests/test_make_batch_file.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import csv
+from io import StringIO
+
+from launch.make_batch_file import make_batch_input_file
+
+
+def test_make_batch_file():
+    f = StringIO()
+    urls = ["one_url.count", "two_urls.count", "three_urls.count"]
+    make_batch_input_file(urls, f)
+    f.seek(0)
+
+    reader = csv.DictReader(f)
+    rows = [row for row in reader]
+    print(f.getvalue())
+    print(rows)
+    for tup in zip(enumerate(urls), rows):
+        print(tup)
+        (i, expected_row), actual_row = tup
+        assert str(i) == actual_row["id"]
+        assert expected_row == actual_row["url"]
