Skip to content

Commit b173bcd

Browse files
LLM API reference (#123)
* update llm docs * add missing file * fix tests
1 parent 13f8a7c commit b173bcd

File tree

4 files changed

+68
-2
lines changed

4 files changed

+68
-2
lines changed

docs/api/llms.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# LLM APIs
2+
3+
We provide some APIs to conveniently create, list, and run inference with LLMs. Under the hood, they are Launch model endpoints.
4+
5+
## Example
6+
7+
```py title="LLM APIs Usage"
8+
import os
9+
10+
from rich import print
11+
12+
from launch import LaunchClient
13+
from launch.api_client.model.llm_inference_framework import (
14+
LLMInferenceFramework,
15+
)
16+
from launch.api_client.model.llm_source import LLMSource
17+
18+
client = LaunchClient(api_key=os.getenv("LAUNCH_API_KEY"), endpoint=os.getenv("LAUNCH_ENDPOINT"))
19+
20+
endpoints = client.list_llm_model_endpoints()
21+
22+
print(endpoints)
23+
24+
endpoint_name = "test-flan-t5-xxl"
25+
client.create_llm_model_endpoint(
26+
endpoint_name=endpoint_name,
27+
model_name="flan-t5-xxl",
28+
source=LLMSource.HUGGING_FACE,
29+
inference_framework=LLMInferenceFramework.DEEPSPEED,
30+
inference_framework_image_tag=os.getenv("INFERENCE_FRAMEWORK_IMAGE_TAG"),
31+
num_shards=4,
32+
min_workers=1,
33+
max_workers=1,
34+
gpus=4,
35+
endpoint_type="sync",
36+
)
37+
38+
# Wait for the endpoint to be ready
39+
40+
output = client.completion_sync(endpoint_name, prompts=["What is Deep Learning?"], max_new_tokens=10, temperature=0)
41+
print(output)
42+
```

launch/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2699,7 +2699,7 @@ def create_llm_model_endpoint(
26992699
labels: An optional dictionary of key/value pairs to associate with this endpoint.
27002700
27012701
Returns:
2702-
A Endpoint object that can be used to make requests to the endpoint.
2702+
An Endpoint object that can be used to make requests to the endpoint.
27032703
27042704
"""
27052705
existing_endpoint = self.get_model_endpoint(endpoint_name)

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ nav:
4444
- concepts/callbacks.md
4545
- 'API Documentation':
4646
- api/client.md
47+
- api/llms.md
4748
- api/model_bundles.md
4849
- api/model_endpoints.md
4950
- api/endpoint_predictions.md

tests/test_docs.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
import pytest
88
from _pytest.assertion.rewrite import AssertionRewritingHook
99

10+
from launch.api_client.model.completion_sync_v1_response import (
11+
CompletionSyncV1Response,
12+
)
1013
from launch.model_bundle import ModelBundle
11-
from launch.model_endpoint import AsyncEndpoint, ModelEndpoint
14+
from launch.model_endpoint import AsyncEndpoint, ModelEndpoint, SyncEndpoint
1215

1316
ROOT_DIR = Path(__file__).parent.parent
1417

@@ -98,6 +101,20 @@ def mock_batch_job():
98101
return {"job_id": "test-batch-job", "status": "SUCCESS"}
99102

100103

104+
@pytest.fixture
105+
def mock_list_llm_model_endpoints():
106+
mock = Mock(spec=SyncEndpoint)
107+
mock.model_endpoint = Mock(spec=ModelEndpoint)
108+
mock.model_endpoint.id = "test-endpoint"
109+
mock.status = Mock(return_value="READY")
110+
return [mock]
111+
112+
113+
@pytest.fixture
114+
def mock_completion_sync_response():
115+
return CompletionSyncV1Response(status="SUCCESS", outputs=["Deep learning is a subnet of machine learning."])
116+
117+
101118
@pytest.mark.parametrize("module_name,source_code", generate_code_chunks("launch", "docs"))
102119
def test_docs_examples(
103120
module_name,
@@ -108,6 +125,7 @@ def test_docs_examples(
108125
mock_model_bundle,
109126
mock_async_endpoint,
110127
mock_batch_job,
128+
mock_list_llm_model_endpoints,
111129
):
112130
mocker.patch("launch.connection.Connection", MagicMock())
113131
mocker.patch("launch.client.DefaultApi", MagicMock())
@@ -120,6 +138,11 @@ def test_docs_examples(
120138
mocker.patch("launch.client.LaunchClient.create_model_bundle", MagicMock(return_value=mock_model_bundle))
121139
mocker.patch("launch.client.LaunchClient.create_model_endpoint", MagicMock(return_value=mock_async_endpoint))
122140
mocker.patch("launch.client.LaunchClient.get_batch_async_response", MagicMock(return_value=mock_batch_job))
141+
mocker.patch(
142+
"launch.client.LaunchClient.list_llm_model_endpoints", MagicMock(return_value=mock_list_llm_model_endpoints)
143+
)
144+
mocker.patch("launch.client.LaunchClient.create_llm_model_endpoint", MagicMock(return_value=mock_async_endpoint))
145+
mocker.patch("launch.client.LaunchClient.completion_sync", MagicMock(return_value=mock_batch_job))
123146
mocker.patch("launch.client.Connection.make_request", MagicMock(return_value=mock_dictionary))
124147
mocker.patch("launch.client.requests", MagicMock())
125148
mocker.patch("pydantic.BaseModel.parse_raw", MagicMock())

0 commit comments

Comments
 (0)