
Commit d8bba93

launch-python-client completions stream (#127)
Parent: 5c8cb36

1 file changed: +48 −1

launch/client.py

Lines changed: 48 additions & 1 deletion
@@ -5,11 +5,22 @@
 import shutil
 import tempfile
 from io import StringIO
-from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+)
 from zipfile import ZipFile
 
 import cloudpickle
 import requests
+import sseclient
 import yaml
 from deprecation import deprecated
 from frozendict import frozendict
@@ -28,6 +39,9 @@
 from launch.api_client.model.cloudpickle_artifact_flavor import (
     CloudpickleArtifactFlavor,
 )
+from launch.api_client.model.completion_stream_v1_response import (
+    CompletionStreamV1Response,
+)
 from launch.api_client.model.completion_sync_v1_request import (
     CompletionSyncV1Request,
 )
@@ -2893,6 +2907,39 @@ def completions_sync(
         resp = json.loads(response.response.data)
         return resp
 
+    def completions_stream(
+        self,
+        endpoint_name: str,
+        prompt: str,
+        max_new_tokens: int,
+        temperature: float,
+    ) -> Iterable[CompletionStreamV1Response]:
+        """
+        Run prompt completion on an LLM endpoint in a streaming fashion. Will fail if the endpoint does not support streaming.
+
+        Parameters:
+            endpoint_name: The name of the LLM endpoint to make the request to
+
+            prompt: The prompt to send to the endpoint
+
+            max_new_tokens: The maximum number of tokens to generate for each prompt
+
+            temperature: The temperature to use for sampling
+
+        Returns:
+            Iterable responses for prompt completion
+        """
+        request = {"max_new_tokens": max_new_tokens, "prompt": prompt, "temperature": temperature}
+        response = requests.post(
+            url=f"{self.configuration.host}/v1/llm/completions-stream?model_endpoint_name={endpoint_name}",
+            json=request,
+            auth=(self.configuration.username, self.configuration.password),
+        )
+        sse_client = sseclient.SSEClient(response)
+        events = sse_client.events()
+        for event in events:
+            yield json.loads(event.data)
+
 
 def _zip_directory(zipf: ZipFile, path: str) -> None:
     for root, _, files in os.walk(path):
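
For reference, a minimal usage sketch of the new method (not part of the commit). The API key, endpoint name, and prompt below are placeholders, and the LaunchClient(api_key=...) construction is an assumption about how the client is created rather than anything shown in this diff; check the package docs for the actual constructor.

from launch import LaunchClient

# Placeholder credentials -- substitute your own.
client = LaunchClient(api_key="YOUR_API_KEY")

# completions_stream is a generator: each iteration yields one parsed
# server-sent event, i.e. the JSON-decoded payload of a partial completion.
for event in client.completions_stream(
    endpoint_name="my-llm-endpoint",  # hypothetical endpoint name
    prompt="Why is the sky blue?",
    max_new_tokens=64,
    temperature=0.7,
):
    print(event)

One caveat worth noting: the requests.post call in the diff does not pass stream=True, so requests will read the whole response body before sseclient.SSEClient begins iterating; truly incremental, event-by-event delivery would likely require requests.post(..., stream=True).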
