
Commit 0cca2a9

More completions api update (#148)
* update client
* lint
1 parent 5de9f8c commit 0cca2a9

File tree

1 file changed (+26, -2 lines)

launch/client.py

Lines changed: 26 additions & 2 deletions
@@ -2898,6 +2898,8 @@ def completions_sync(
         prompt: str,
         max_new_tokens: int,
         temperature: float,
+        stop_sequences: Optional[List[str]] = None,
+        return_token_log_probs: Optional[bool] = False,
     ) -> CompletionSyncV1Response:
         """
         Run prompt completion on a sync LLM endpoint. Will fail if the endpoint is not sync.
@@ -2911,12 +2913,22 @@ def completions_sync(
 
             temperature: The temperature to use for sampling
 
+            stop_sequences: List of sequences to stop the completion at
+
+            return_token_log_probs: Whether to return the log probabilities of the tokens
+
         Returns:
             Response for prompt completion
         """
         with ApiClient(self.configuration) as api_client:
             api_instance = DefaultApi(api_client)
-            request = CompletionSyncV1Request(max_new_tokens=max_new_tokens, prompt=prompt, temperature=temperature)
+            request = CompletionSyncV1Request(
+                max_new_tokens=max_new_tokens,
+                prompt=prompt,
+                temperature=temperature,
+                stop_sequences=stop_sequences if stop_sequences is not None else [],
+                return_token_log_probs=return_token_log_probs,
+            )
             query_params = frozendict({"model_endpoint_name": endpoint_name})
             response = api_instance.create_completion_sync_task_v1_llm_completions_sync_post(  # type: ignore
                 body=request,
@@ -2932,6 +2944,8 @@ def completions_stream(
         prompt: str,
         max_new_tokens: int,
         temperature: float,
+        stop_sequences: Optional[List[str]] = None,
+        return_token_log_probs: Optional[bool] = False,
     ) -> Iterable[CompletionStreamV1Response]:
         """
         Run prompt completion on an LLM endpoint in streaming fashion. Will fail if endpoint does not support streaming.
@@ -2945,10 +2959,20 @@ def completions_stream(
 
             temperature: The temperature to use for sampling
 
+            stop_sequences: List of sequences to stop the completion at
+
+            return_token_log_probs: Whether to return the log probabilities of the tokens
+
         Returns:
             Iterable responses for prompt completion
         """
-        request = {"max_new_tokens": max_new_tokens, "prompt": prompt, "temperature": temperature}
+        request = {
+            "max_new_tokens": max_new_tokens,
+            "prompt": prompt,
+            "temperature": temperature,
+            "stop_sequences": stop_sequences,
+            "return_token_log_probs": return_token_log_probs,
+        }
         response = requests.post(
             url=f"{self.configuration.host}/v1/llm/completions-stream?model_endpoint_name={endpoint_name}",
             json=request,
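
Below is a minimal usage sketch exercising the two new parameters (stop_sequences and return_token_log_probs) through the updated completions_sync and completions_stream methods. The LaunchClient class name, the constructor arguments, and the endpoint name are illustrative assumptions, not part of this diff.

# Hedged usage sketch of the updated client methods. The client class name,
# constructor arguments, and endpoint name below are assumptions for illustration.
from launch import LaunchClient  # assumed import path

client = LaunchClient(api_key="YOUR_API_KEY")  # assumed constructor

# Synchronous completion: stop generating at a blank line and return per-token log probs.
sync_response = client.completions_sync(
    endpoint_name="my-llm-endpoint",  # hypothetical endpoint name
    prompt="Write a haiku about the sea:",
    max_new_tokens=64,
    temperature=0.7,
    stop_sequences=["\n\n"],
    return_token_log_probs=True,
)
print(sync_response)

# Streaming completion: the same two parameters are forwarded in the JSON request body.
for chunk in client.completions_stream(
    endpoint_name="my-llm-endpoint",
    prompt="Write a haiku about the sea:",
    max_new_tokens=64,
    temperature=0.7,
    stop_sequences=["\n\n"],
    return_token_log_probs=True,
):
    print(chunk)

Note that, per the diff, completions_sync substitutes an empty list when stop_sequences is None before building CompletionSyncV1Request, while completions_stream passes stop_sequences through unchanged in the JSON body.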

0 commit comments
