|
5 | 5 | import shutil
|
6 | 6 | import tempfile
|
7 | 7 | from io import StringIO
|
8 |
| -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union |
| 8 | +from typing import ( |
| 9 | + Any, |
| 10 | + Callable, |
| 11 | + Dict, |
| 12 | + Iterable, |
| 13 | + List, |
| 14 | + Optional, |
| 15 | + Type, |
| 16 | + TypeVar, |
| 17 | + Union, |
| 18 | +) |
9 | 19 | from zipfile import ZipFile
|
10 | 20 |
|
11 | 21 | import cloudpickle
|
12 | 22 | import requests
|
| 23 | +import sseclient |
13 | 24 | import yaml
|
14 | 25 | from deprecation import deprecated
|
15 | 26 | from frozendict import frozendict
|
|
28 | 39 | from launch.api_client.model.cloudpickle_artifact_flavor import (
|
29 | 40 | CloudpickleArtifactFlavor,
|
30 | 41 | )
|
| 42 | +from launch.api_client.model.completion_stream_v1_response import ( |
| 43 | + CompletionStreamV1Response, |
| 44 | +) |
31 | 45 | from launch.api_client.model.completion_sync_v1_request import (
|
32 | 46 | CompletionSyncV1Request,
|
33 | 47 | )
|
@@ -2893,6 +2907,39 @@ def completions_sync(
|
2893 | 2907 | resp = json.loads(response.response.data)
|
2894 | 2908 | return resp
|
2895 | 2909 |
|
def completions_stream(
    self,
    endpoint_name: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
) -> Iterable[CompletionStreamV1Response]:
    """
    Run prompt completion on an LLM endpoint in streaming fashion. Will fail if endpoint does not support streaming.

    This is a generator: the HTTP request is only sent once iteration starts, and the
    underlying connection is released when the generator is exhausted or closed.

    Parameters:
        endpoint_name: The name of the LLM endpoint to make the request to

        prompt: The prompt to send to the endpoint

        max_new_tokens: The maximum number of tokens to generate for each prompt

        temperature: The temperature to use for sampling

    Returns:
        Iterable responses for prompt completion. Each item is the JSON-decoded ``data``
        payload of one server-sent event.
        NOTE(review): items are whatever ``json.loads`` produces, not
        ``CompletionStreamV1Response`` instances -- confirm the intended item type.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status code.
    """
    request = {"max_new_tokens": max_new_tokens, "prompt": prompt, "temperature": temperature}
    # stream=True is required for server-sent events: without it, requests downloads the
    # whole body before returning, so no event is delivered until generation has finished.
    # The context manager releases the connection even if the caller stops iterating early.
    with requests.post(
        url=f"{self.configuration.host}/v1/llm/completions-stream",
        # Let requests build the query string so the endpoint name is URL-encoded.
        params={"model_endpoint_name": endpoint_name},
        json=request,
        auth=(self.configuration.username, self.configuration.password),
        stream=True,
    ) as response:
        # Surface HTTP errors instead of trying to parse an error body as an event stream.
        response.raise_for_status()
        sse_client = sseclient.SSEClient(response)
        for event in sse_client.events():
            yield json.loads(event.data)
| 2942 | + |
2896 | 2943 |
|
2897 | 2944 | def _zip_directory(zipf: ZipFile, path: str) -> None:
|
2898 | 2945 | for root, _, files in os.walk(path):
|
|
0 commit comments