Skip to content

Commit 17a919b

Browse files
authored
Merge pull request andrewyng#147 from andrewyng/add-cohere-provider
Add cohere provider
2 parents 7639961 + c469423 commit 17a919b

File tree

8 files changed

+242
-7
lines changed

8 files changed

+242
-7
lines changed

.github/workflows/run_pytest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
run: |
1919
python -m pip install --upgrade pip
2020
pip install poetry
21-
poetry install --with test
21+
poetry install --all-extras --with test
2222
- name: Test with pytest
2323
run: poetry run pytest
2424

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,7 @@ env/
99
.coverage
1010

1111
# pyenv
12-
.python-version
12+
.python-version
13+
14+
.DS_Store
15+
**/.DS_Store
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import os
2+
import cohere
3+
4+
from aisuite.framework import ChatCompletionResponse
5+
from aisuite.provider import Provider
6+
7+
8+
class CohereProvider(Provider):
9+
def __init__(self, **config):
10+
"""
11+
Initialize the Cohere provider with the given configuration.
12+
Pass the entire configuration dictionary to the Cohere client constructor.
13+
"""
14+
# Ensure API key is provided either in config or via environment variable
15+
config.setdefault("api_key", os.getenv("CO_API_KEY"))
16+
if not config["api_key"]:
17+
raise ValueError(
18+
" API key is missing. Please provide it in the config or set the CO_API_KEY environment variable."
19+
)
20+
self.client = cohere.ClientV2(**config)
21+
22+
def chat_completions_create(self, model, messages, **kwargs):
23+
response = self.client.chat(
24+
model=model,
25+
messages=messages,
26+
**kwargs # Pass any additional arguments to the Cohere API
27+
)
28+
29+
return self.normalize_response(response)
30+
31+
def normalize_response(self, response):
32+
"""Normalize the reponse from Cohere API to match OpenAI's response format."""
33+
normalized_response = ChatCompletionResponse()
34+
normalized_response.choices[0].message.content = response.message.content[
35+
0
36+
].text
37+
return normalized_response

guides/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
These guides give directions for obtaining API keys from different providers.
44

5-
Here're the instructions for:
5+
Here are the instructions for:
66
- [Anthropic](anthropic.md)
77
- [AWS](aws.md)
88
- [Azure](azure.md)
9+
- [Cohere](cohere.md)
910
- [Google](google.md)
1011
- [Hugging Face](huggingface.md)
1112
- [OpenAI](openai.md)

guides/cohere.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Cohere
2+
3+
To use Cohere with `aisuite`, you’ll need an [Cohere account](https://cohere.com/). After logging in, go to the [API Keys](https://dashboard.cohere.com/api-keys) section in your account settings, agree to the terms of service, connect your card, and generate a new key. Once you have your key, add it to your environment as follows:
4+
5+
```shell
6+
export CO_API_KEY="your-cohere-api-key"
7+
```
8+
9+
## Create a Chat Completion
10+
11+
Install the `cohere` Python client:
12+
13+
Example with pip:
14+
```shell
15+
pip install cohere
16+
```
17+
18+
Example with poetry:
19+
```shell
20+
poetry add cohere
21+
```
22+
23+
In your code:
24+
```python
25+
import aisuite as ai
26+
27+
client = ai.Client()
28+
29+
provider = "cohere"
30+
model_id = "command-r-plus-08-2024"
31+
32+
messages = [
33+
{"role": "user", "content": "Hi, how are you?"}
34+
]
35+
36+
response = client.chat.completions.create(
37+
model=f"{provider}:{model_id}",
38+
messages=messages,
39+
)
40+
41+
print(response.choices[0].message.content)
42+
```
43+
44+
Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).

poetry.lock

Lines changed: 105 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ readme = "README.md"
99
python = "^3.10"
1010
anthropic = { version = "^0.30.1", optional = true }
1111
boto3 = { version = "^1.34.144", optional = true }
12+
cohere = { version = "^5.12.0", optional = true }
1213
vertexai = { version = "^1.63.0", optional = true }
1314
groq = { version = "^0.9.0", optional = true }
1415
mistralai = { version = "^1.0.3", optional = true }
@@ -21,14 +22,15 @@ httpx = "~0.27.0"
2122
anthropic = ["anthropic"]
2223
aws = ["boto3"]
2324
azure = []
25+
cohere = ["cohere"]
2426
google = ["vertexai"]
2527
groq = ["groq"]
2628
huggingface = []
2729
mistral = ["mistralai"]
2830
ollama = []
2931
openai = ["openai"]
3032
watsonx = ["ibm-watsonx-ai"]
31-
all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "watsonx"] # To install all providers
33+
all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "cohere", "watsonx"] # To install all providers
3234

3335
[tool.poetry.group.dev.dependencies]
3436
pre-commit = "^3.7.1"
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
import pytest
4+
5+
from aisuite.providers.cohere_provider import CohereProvider
6+
7+
8+
@pytest.fixture(autouse=True)
9+
def set_api_key_env_var(monkeypatch):
10+
"""Fixture to set environment variables for tests."""
11+
monkeypatch.setenv("CO_API_KEY", "test-api-key")
12+
13+
14+
def test_cohere_provider():
15+
"""High-level test that the provider is initialized and chat completions are requested successfully."""
16+
17+
user_greeting = "Hello!"
18+
message_history = [{"role": "user", "content": user_greeting}]
19+
selected_model = "our-favorite-model"
20+
chosen_temperature = 0.75
21+
response_text_content = "mocked-text-response-from-model"
22+
23+
provider = CohereProvider()
24+
mock_response = MagicMock()
25+
mock_response.message = MagicMock()
26+
mock_response.message.content = [MagicMock()]
27+
mock_response.message.content[0].text = response_text_content
28+
29+
with patch.object(
30+
provider.client,
31+
"chat",
32+
return_value=mock_response,
33+
) as mock_create:
34+
response = provider.chat_completions_create(
35+
messages=message_history,
36+
model=selected_model,
37+
temperature=chosen_temperature,
38+
)
39+
40+
mock_create.assert_called_with(
41+
messages=message_history,
42+
model=selected_model,
43+
temperature=chosen_temperature,
44+
)
45+
46+
assert response.choices[0].message.content == response_text_content

0 commit comments

Comments
 (0)