Commit 5cd518e

Add Watsonx provider with tests using the python SDK

1 parent 1b5da0e · commit 5cd518e

8 files changed: +366 −64 lines changed

.env.sample

Lines changed: 6 additions & 0 deletions

```diff
@@ -25,3 +25,9 @@ FIREWORKS_API_KEY=
 
 # Together AI
 TOGETHER_API_KEY=
+
+# WatsonX
+WATSONX_SERVICE_URL=
+WATSONX_API_KEY=
+WATSONX_PROJECT_ID=
+
```
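These variables mirror the provider's configuration keys: each value can also be passed programmatically, with the environment as a fallback. A minimal sketch of that config-or-env pattern (the `resolve` helper is hypothetical; the variable names match the ones above):

```python
import os


def resolve(config: dict, key: str, env_var: str):
    """Prefer an explicit config value; fall back to the environment."""
    return config.get(key) or os.getenv(env_var)


# Environment value is used when no explicit config is given.
os.environ["WATSONX_API_KEY"] = "key-from-env"
print(resolve({}, "api_key", "WATSONX_API_KEY"))                       # key-from-env
# An explicit config value wins over the environment.
print(resolve({"api_key": "explicit"}, "api_key", "WATSONX_API_KEY"))  # explicit
```

Note that `or` also treats empty strings as missing, so a blank line in `.env` falls through to the next source.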

README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -7,7 +7,7 @@ Simple, unified interface to multiple Generative AI providers.
 `aisuite` makes it easy for developers to use multiple LLM through a standardized interface. Using an interface similar to OpenAI's, `aisuite` makes it easy to interact with the most popular LLMs and compare the results. It is a thin wrapper around python client libraries, and allows creators to seamlessly swap out and test responses from different LLM providers without changing their code. Today, the library is primarily focussed on chat completions. We will expand it cover more use cases in near future.
 
 Currently supported providers are -
-OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace and Ollama.
+OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace, Ollama and Watsonx.
 To maximize stability, `aisuite` uses either the HTTP endpoint or the SDK for making calls to the provider.
 
 ## Installation
```
aisuite/providers/watsonx_provider.py

Lines changed: 33 additions & 0 deletions

```python
from aisuite.provider import Provider
import os
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams

DEFAULT_TEMPERATURE = 0.7


class WatsonxProvider(Provider):
    def __init__(self, **config):
        self.service_url = config.get("service_url") or os.getenv("WATSONX_SERVICE_URL")
        self.api_key = config.get("api_key") or os.getenv("WATSONX_API_KEY")
        self.project_id = config.get("project_id") or os.getenv("WATSONX_PROJECT_ID")

        if not self.service_url or not self.api_key or not self.project_id:
            raise EnvironmentError(
                "Missing one or more required WatsonX environment variables: "
                "WATSONX_SERVICE_URL, WATSONX_API_KEY, WATSONX_PROJECT_ID. "
                "Please refer to the setup guide: /guides/watsonx.md."
            )

    def chat_completions_create(self, model, messages, **kwargs):
        model = ModelInference(
            model_id=model,
            params={
                GenParams.TEMPERATURE: kwargs.get("temperature", DEFAULT_TEMPERATURE),
            },
            credentials=Credentials(api_key=self.api_key, url=self.service_url),
            project_id=self.project_id,
        )

        return model.chat(prompt=messages, **kwargs)
```
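Callers reach this provider through `aisuite`'s `provider:model` naming convention (e.g. `watsonx:meta-llama/llama-3-70b-instruct`). A sketch of how such an identifier can be split, assuming a single colon separates the two parts and the model id itself may contain `/` (the helper name is hypothetical, not `aisuite`'s internal routing code):

```python
def split_model_id(model: str):
    """Split 'provider:model' into its two parts; the model id may contain '/'."""
    provider, _, model_id = model.partition(":")
    if not model_id:
        raise ValueError(f"Expected 'provider:model', got {model!r}")
    return provider, model_id


print(split_model_id("watsonx:meta-llama/llama-3-70b-instruct"))
# ('watsonx', 'meta-llama/llama-3-70b-instruct')
```

Using `str.partition` rather than `split(":")` keeps any later colons or slashes inside the model id intact.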

guides/watsonx.md

Lines changed: 83 additions & 0 deletions

````markdown
# Watsonx with `aisuite`

A step-by-step guide to set up Watsonx with the `aisuite` library, enabling you to use IBM Watsonx's powerful AI models for various tasks.

## Setup Instructions

### Step 1: Create a Watsonx Account

1. Visit [IBM Watsonx](https://www.ibm.com/watsonx).
2. Sign up for a new account or log in with your existing IBM credentials.
3. Once logged in, navigate to the **Watsonx Dashboard**.

---

### Step 2: Obtain API Credentials

1. **Generate an API Key**:
   - Go to the **API Keys** section in your Watsonx account settings.
   - Click on **Create API Key**.
   - Provide a name for your API key (e.g., `MyWatsonxKey`).
   - Click **Generate**, then download or copy the API key. **Keep this key secure!**

2. **Locate the Service URL**:
   - Go to the **Endpoints** section in the Watsonx dashboard.
   - Find the URL corresponding to your service and note it. This is your `WATSONX_SERVICE_URL`.

3. **Get the Project ID**:
   - Navigate to the **Projects** tab in the dashboard.
   - Select the project you want to use.
   - Copy the **Project ID**. This will serve as your `WATSONX_PROJECT_ID`.

---

### Step 3: Set Environment Variables

To simplify authentication, set the following environment variables by running these commands in your terminal:

```bash
export WATSONX_API_KEY="your-watsonx-api-key"
export WATSONX_SERVICE_URL="your-watsonx-service-url"
export WATSONX_PROJECT_ID="your-watsonx-project-id"
```

## Create a Chat Completion

Install the `ibm-watsonx-ai` Python client.

Example with pip:

```shell
pip install ibm-watsonx-ai
```

Example with poetry:

```shell
poetry add ibm-watsonx-ai
```

In your code:

```python
import aisuite as ai
client = ai.Client()

provider = "watsonx"
model_id = "meta-llama/llama-3-70b-instruct"

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Tell me a joke."},
]

response = client.chat.completions.create(
    model=f"{provider}:{model_id}",
    messages=messages,
)

print(response.choices[0].message.content)
```
````
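The provider raises an `EnvironmentError` at construction time when any of the three variables is missing, so it can help to run the same check up front. A small pre-flight sketch (the helper name is hypothetical):

```python
import os

# The three variables the Watsonx provider requires.
REQUIRED_VARS = ("WATSONX_SERVICE_URL", "WATSONX_API_KEY", "WATSONX_PROJECT_ID")


def missing_watsonx_vars():
    """Return the names of any unset (or empty) required variables."""
    return [name for name in REQUIRED_VARS if not os.getenv(name)]


missing = missing_watsonx_vars()
if missing:
    print("Set these before creating the client:", ", ".join(missing))
```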

poetry.lock

Lines changed: 159 additions & 62 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 4 additions & 1 deletion

```diff
@@ -13,6 +13,7 @@ vertexai = { version = "^1.63.0", optional = true }
 groq = { version = "^0.9.0", optional = true }
 mistralai = { version = "^1.0.3", optional = true }
 openai = { version = "^1.35.8", optional = true }
+ibm-watsonx-ai = { version = "^1.1.16", optional = true }
 
 # Optional dependencies for different providers
 [tool.poetry.extras]
@@ -25,7 +26,8 @@ huggingface = []
 mistral = ["mistralai"]
 ollama = []
 openai = ["openai"]
-all = ["anthropic", "aws", "google", "groq", "mistral", "openai"] # To install all providers
+watsonx = ["ibm-watsonx-ai"]
+all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "watsonx"] # To install all providers
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.2.2"
@@ -44,6 +46,7 @@ chromadb = "^0.5.4"
 sentence-transformers = "^3.0.1"
 datasets = "^2.20.0"
 vertexai = "^1.63.0"
+ibm-watsonx-ai = "^1.1.16"
 
 [build-system]
 requires = ["poetry-core"]
```
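With the `watsonx` extra registered, downstream users can pull in the optional dependency in one step. A sketch, assuming the published `aisuite` package exposes the same extras as this `pyproject.toml`:

```shell
# Install aisuite together with the optional ibm-watsonx-ai dependency
pip install "aisuite[watsonx]"

# Or install every provider at once, matching the "all" extra
pip install "aisuite[all]"
```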

tests/client/test_client.py

Lines changed: 17 additions & 0 deletions

```diff
@@ -4,6 +4,7 @@
 
 
 class TestClient(unittest.TestCase):
+
     @patch("aisuite.providers.mistral_provider.MistralProvider.chat_completions_create")
     @patch("aisuite.providers.groq_provider.GroqProvider.chat_completions_create")
     @patch("aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create")
@@ -16,6 +17,7 @@ class TestClient(unittest.TestCase):
     @patch(
         "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create"
     )
+    @patch("aisuite.providers.watsonx_provider.WatsonxProvider.chat_completions_create")
     def test_client_chat_completions(
         self,
         mock_fireworks,
@@ -26,6 +28,7 @@ def test_client_chat_completions(
         mock_openai,
         mock_groq,
         mock_mistral,
+        mock_watsonx,
     ):
         # Mock responses from providers
         mock_openai.return_value = "OpenAI Response"
@@ -36,6 +39,7 @@ def test_client_chat_completions(
         mock_mistral.return_value = "Mistral Response"
         mock_google.return_value = "Google Response"
         mock_fireworks.return_value = "Fireworks Response"
+        mock_watsonx.return_value = "Watsonx Response"
 
         # Provider configurations
         provider_configs = {
@@ -64,6 +68,11 @@ def test_client_chat_completions(
             "fireworks": {
                 "api_key": "fireworks-api-key",
             },
+            "watsonx": {
+                "service_url": "https://watsonx-service-url.com",
+                "api_key": "watsonx-api-key",
+                "project_id": "watsonx-project-id",
+            },
         }
 
         # Initialize the client
@@ -134,6 +143,14 @@ def test_client_chat_completions(
         self.assertEqual(fireworks_response, "Fireworks Response")
         mock_fireworks.assert_called_once()
 
+        # Test Watsonx model
+        watsonx_model = "watsonx" + ":" + "watsonx-model"
+        watsonx_response = client.chat.completions.create(
+            watsonx_model, messages=messages
+        )
+        self.assertEqual(watsonx_response, "Watsonx Response")
+        mock_watsonx.assert_called_once()
+
         # Test that new instances of Completion are not created each time we make an inference call.
         compl_instance = client.chat.completions
         next_compl_instance = client.chat.completions
```
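One subtlety when appending a new `@patch` to a stack like this: decorators are applied bottom-up, so the `@patch` closest to the function supplies the first mock argument, and the positional order of the mock parameters runs in reverse decorator order. A small self-contained sketch of that rule (the `provider_a`/`provider_b` names are illustrative, not from the codebase):

```python
from unittest.mock import patch


def provider_a():
    return "real A"


def provider_b():
    return "real B"


# Decorators apply bottom-up: the @patch nearest the function
# is applied first and supplies the FIRST mock argument.
@patch(f"{__name__}.provider_a")  # top decorator    -> second arg
@patch(f"{__name__}.provider_b")  # bottom decorator -> first arg
def check_order(mock_b, mock_a):
    mock_a.return_value = "mocked A"
    mock_b.return_value = "mocked B"
    return provider_a(), provider_b()


print(check_order())  # ('mocked A', 'mocked B')
```

Getting this mapping wrong silently binds each mock name to the wrong patched target, so return values and `assert_called_once` checks end up pointing at the wrong provider.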
Lines changed: 63 additions & 0 deletions

```python
from unittest.mock import MagicMock, patch

import pytest
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams

from aisuite.providers.watsonx_provider import WatsonxProvider


@pytest.fixture(autouse=True)
def set_api_key_env_var(monkeypatch):
    """Fixture to set environment variables for tests."""
    monkeypatch.setenv("WATSONX_SERVICE_URL", "https://watsonx-service-url.com")
    monkeypatch.setenv("WATSONX_API_KEY", "test-api-key")
    monkeypatch.setenv("WATSONX_PROJECT_ID", "test-project-id")


def test_watsonx_provider():
    """High-level test that the provider is initialized and chat completions are requested successfully."""

    user_greeting = "Hello!"
    message_history = [{"role": "user", "content": user_greeting}]
    selected_model = "our-favorite-model"
    chosen_temperature = 0.7
    response_text_content = "mocked-text-response-from-model"

    provider = WatsonxProvider()
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message = MagicMock()
    mock_response.choices[0].message.content = response_text_content

    with patch(
        "aisuite.providers.watsonx_provider.ModelInference"
    ) as mock_model_inference:
        mock_model = MagicMock()
        mock_model_inference.return_value = mock_model
        mock_model.chat.return_value = mock_response

        response = provider.chat_completions_create(
            messages=message_history,
            model=selected_model,
            temperature=chosen_temperature,
        )

        # Assert that ModelInference was called with correct arguments.
        mock_model_inference.assert_called_once()
        args, kwargs = mock_model_inference.call_args
        assert kwargs["model_id"] == selected_model
        assert kwargs["params"] == {GenParams.TEMPERATURE: chosen_temperature}

        # Assert that the credentials have the correct API key and service URL.
        credentials = kwargs["credentials"]
        assert credentials.api_key == provider.api_key
        assert credentials.url == provider.service_url

        # Assert that chat was called with correct history and temperature.
        mock_model.chat.assert_called_once_with(
            prompt=message_history,
            temperature=chosen_temperature,
        )

        assert response.choices[0].message.content == response_text_content
```
