
Commit 0ebb015

chris-lindstrom, isafulf, and Derek Legenzoff authored
Make plugin work with AzureOAI (#221)
* Make this work better with Azure Open AI
* Fix typo
* Update README.md

Co-authored-by: Derek Legenzoff <delegenz@microsoft.com>

---------

Co-authored-by: isafulf <51974293+isafulf@users.noreply.github.com>
Co-authored-by: Derek Legenzoff <delegenz@microsoft.com>
1 parent d619172 commit 0ebb015

6 files changed: +54 -13 lines changed


README.md

Lines changed: 18 additions & 0 deletions
@@ -72,6 +72,13 @@ Follow these steps to quickly set up and run the ChatGPT Retrieval Plugin:
 export BEARER_TOKEN=<your_bearer_token>
 export OPENAI_API_KEY=<your_openai_api_key>

+# Optional environment variables used when running Azure OpenAI
+export OPENAI_API_BASE=https://<AzureOpenAIName>.openai.azure.com/
+export OPENAI_API_TYPE=azure
+export OPENAI_EMBEDDINGMODEL_DEPLOYMENTID=<Name of text-embedding-ada-002 model deployment>
+export OPENAI_METADATA_EXTRACTIONMODEL_DEPLOYMENTID=<Name of deployment of model for metadata>
+export OPENAI_COMPLETIONMODEL_DEPLOYMENTID=<Name of general model deployment used for completion>
+
 # Add the environment variables for your chosen vector DB.
 # Some of these are optional; read the provider's setup docs in /docs/providers for more information.

@@ -237,6 +244,17 @@ The API requires the following environment variables to work:
 | `BEARER_TOKEN` | Yes | This is a secret token that you need to authenticate your requests to the API. You can generate one using any tool or method you prefer, such as [jwt.io](https://jwt.io/). |
 | `OPENAI_API_KEY` | Yes | This is your OpenAI API key that you need to generate embeddings using the `text-embedding-ada-002` model. You can get an API key by creating an account on [OpenAI](https://openai.com/). |

+
+### Using the plugin with Azure OpenAI
+
+Azure OpenAI uses URLs that are specific to your resource and references models by deployment ID rather than by model name. As a result, you need to set additional environment variables for this case.
+
+In addition to OPENAI_API_BASE (your resource-specific URL) and OPENAI_API_TYPE (azure), you should also set OPENAI_EMBEDDINGMODEL_DEPLOYMENTID, which specifies the deployment used to get embeddings on upsert and query. For this, we recommend deploying the text-embedding-ada-002 model and using its deployment name here.
+
+If you wish to use the data preparation scripts, you will also need to set OPENAI_METADATA_EXTRACTIONMODEL_DEPLOYMENTID, used for metadata extraction, and
+OPENAI_COMPLETIONMODEL_DEPLOYMENTID, used for PII handling.
+
+
 ### Choosing a Vector Database

 The plugin supports several vector database providers, each with different features, performance, and pricing. Depending on which one you choose, you will need to use a different Dockerfile and set different environment variables. The following sections provide brief introductions to each vector database provider.
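
Review note: the README change above names five Azure-related variables but leaves it to the user to notice when one is missing. A minimal sketch of a startup check follows, assuming the variable names from this commit. The helper `missing_azure_settings` and the two name lists are hypothetical and are not part of the plugin.

```python
import os
from typing import List, Mapping

# Variable names are taken from the README change in this commit.
AZURE_RECOMMENDED = [
    "OPENAI_API_BASE",
    "OPENAI_EMBEDDINGMODEL_DEPLOYMENTID",
]
# Only needed for the data preparation scripts, per the README text.
AZURE_DATA_PREP = [
    "OPENAI_METADATA_EXTRACTIONMODEL_DEPLOYMENTID",
    "OPENAI_COMPLETIONMODEL_DEPLOYMENTID",
]


def missing_azure_settings(env: Mapping[str, str], data_prep: bool = False) -> List[str]:
    """Hypothetical helper: list Azure variables that are unset or empty."""
    if env.get("OPENAI_API_TYPE") != "azure":
        return []  # not targeting Azure OpenAI, so there is nothing to check
    names = AZURE_RECOMMENDED + (AZURE_DATA_PREP if data_prep else [])
    return [name for name in names if not env.get(name)]


if __name__ == "__main__":
    missing = missing_azure_settings(os.environ, data_prep=True)
    if missing:
        print("Missing Azure OpenAI settings:", ", ".join(missing))
```

The check only reports names; whether a missing variable should be a warning or a hard failure is a design choice this commit does not make.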

poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ packages = [{include = "server"}]
 python = "^3.10"
 fastapi = "^0.92.0"
 uvicorn = "^0.20.0"
-openai = "^0.27.2"
+openai = "^0.27.5"
 python-dotenv = "^0.21.1"
 pydantic = "^1.10.5"
 tenacity = "^8.2.1"

services/extract_metadata.py

Lines changed: 6 additions & 2 deletions
@@ -2,7 +2,7 @@
 from services.openai import get_chat_completion
 import json
 from typing import Dict
-
+import os

 def extract_metadata_from_document(text: str) -> Dict[str, str]:
     sources = Source.__members__.keys()
@@ -24,8 +24,12 @@ def extract_metadata_from_document(text: str) -> Dict[str, str]:
         {"role": "user", "content": text},
     ]

+    # NOTE: Azure Open AI requires deployment id
+    # Read environment variable - if not set - not used
     completion = get_chat_completion(
-        messages, "gpt-4"
+        messages,
+        "gpt-4",
+        os.environ.get("OPENAI_METADATA_EXTRACTIONMODEL_DEPLOYMENTID")
     )  # TODO: change to your preferred model name

     print(f"completion: {completion}")

services/openai.py

Lines changed: 23 additions & 6 deletions
@@ -1,6 +1,6 @@
 from typing import List
 import openai
-
+import os

 from tenacity import retry, wait_random_exponential, stop_after_attempt

@@ -20,8 +20,15 @@ def get_embeddings(texts: List[str]) -> List[List[float]]:
         Exception: If the OpenAI API call fails.
     """
     # Call the OpenAI API to get the embeddings
-    response = openai.Embedding.create(input=texts, model="text-embedding-ada-002")
+    # NOTE: Azure Open AI requires deployment id
+    deployment = os.environ.get("OPENAI_EMBEDDINGMODEL_DEPLOYMENTID")

+    response = {}
+    if deployment == None:
+        response = openai.Embedding.create(input=texts, model="text-embedding-ada-002")
+    else:
+        response = openai.Embedding.create(input=texts, deployment_id=deployment)
+
     # Extract the embedding data from the response
     data = response["data"]  # type: ignore

@@ -33,6 +40,7 @@ def get_embeddings(texts: List[str]) -> List[List[float]]:
 def get_chat_completion(
     messages,
     model="gpt-3.5-turbo",  # use "gpt-4" for better results
+    deployment_id = None
 ):
     """
     Generate a chat completion using OpenAI's chat completion API.
@@ -48,10 +56,19 @@ def get_chat_completion(
         Exception: If the OpenAI API call fails.
     """
     # call the OpenAI chat completion API with the given messages
-    response = openai.ChatCompletion.create(
-        model=model,
-        messages=messages,
-    )
+    # Note: Azure Open AI requires deployment id
+    response = {}
+    if deployment_id == None:
+        response = openai.ChatCompletion.create(
+            model=model,
+            messages=messages,
+        )
+    else:
+        response = openai.ChatCompletion.create(
+            deployment_id = deployment_id,
+            messages=messages,
+        )
+

     choices = response["choices"]  # type: ignore
     completion = choices[0].message.content.strip()
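
Review note: both functions in this file now repeat the same branch, calling the API with `model` when no deployment id is set and with `deployment_id` otherwise. One possible way to avoid duplicating the call is sketched below. The helper name `engine_kwargs` is hypothetical and is not part of this commit. It uses `is None`, which has the same effect here as the `== None` comparison in the diff.

```python
from typing import Dict, Optional


def engine_kwargs(model: str, deployment_id: Optional[str] = None) -> Dict[str, str]:
    """Hypothetical helper: pick the keyword argument that identifies the model.

    Mirrors the branch in the commit: a deployment id, when present,
    replaces the model name instead of being sent alongside it.
    """
    if deployment_id is None:
        return {"model": model}
    return {"deployment_id": deployment_id}


# Usage sketch (not executed here; needs the openai package and credentials):
#   response = openai.ChatCompletion.create(
#       messages=messages, **engine_kwargs(model, deployment_id)
#   )
```

As in the commit, an empty-string deployment id is still passed through as a deployment id; treating it as unset would be a separate behavioral change.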

services/pii_detection.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+import os
 from services.openai import get_chat_completion


@@ -22,6 +23,7 @@ def screen_text_for_pii(text: str) -> bool:

     completion = get_chat_completion(
         messages,
+        deployment_id=os.environ.get("OPENAI_COMPLETIONMODEL_DEPLOYMENTID")
     )

     if completion.startswith("True"):
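
Review note: the change above relies on `os.environ.get` returning `None` when the variable is unset, so that non-Azure users keep the default model. A minimal sketch of how that wiring could be checked without calling the API follows. `fake_get_chat_completion` and `screen` are hypothetical stand-ins, and `my-gpt35` is a made-up deployment name.

```python
import os
from unittest import mock


def fake_get_chat_completion(messages, model="gpt-3.5-turbo", deployment_id=None):
    """Hypothetical stand-in: report which target the real call would use."""
    if deployment_id is None:
        return ("model", model)
    return ("deployment", deployment_id)


def screen(messages):
    # Same argument wiring as the changed call in screen_text_for_pii.
    return fake_get_chat_completion(
        messages,
        deployment_id=os.environ.get("OPENAI_COMPLETIONMODEL_DEPLOYMENTID"),
    )


# With the variable set, the deployment id is forwarded.
with mock.patch.dict(os.environ, {"OPENAI_COMPLETIONMODEL_DEPLOYMENTID": "my-gpt35"}):
    with_azure = screen([])

# With a cleared environment, the default model name is used.
with mock.patch.dict(os.environ, {}, clear=True):
    without_azure = screen([])
```

`mock.patch.dict` restores the original environment when each block exits, so the check does not leak settings into other tests.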
