
Commit 7352641

Merge pull request #73 from elastic/add-new-chat-models
[Search Experience] Add new chat models
2 parents f03b868 + 11a3df5 commit 7352641

File tree

6 files changed: +264 -156 lines changed


example-apps/workplace-search/README.md

Lines changed: 44 additions & 2 deletions
````diff
@@ -21,7 +21,6 @@ This app requires the following environment variables to be set:
 export ELASTIC_CLOUD_ID=...
 export ELASTIC_USERNAME=...
 export ELASTIC_PASSWORD=...
-export OPENAI_API_KEY=...
 ```
 
 Note:
@@ -31,7 +30,50 @@ Note:
 1. Go to the [Create deployment](https://cloud.elastic.co/deployments/create) page
 2. Select **Create deployment** and follow the instructions
 
-- you can get your OpenAI key from the [OpenAI dashboard](https://platform.openai.com/account/api-keys).
+
+To use an LLM other than OpenAI, set the LLM_TYPE environment variable to one of the following values:
+```sh
+# azure|openai|vertex|bedrock
+export LLM_TYPE=azure
+```
+
+### 2.1. OpenAI LLM
+
+To use the OpenAI LLM, you only need to set the OPENAI_API_KEY environment variable:
+
+```sh
+export OPENAI_API_KEY=...
+```
+You can get your OpenAI key from the [OpenAI dashboard](https://platform.openai.com/account/api-keys).
+### 2.2. Azure OpenAI LLM
+
+If you are using the Azure OpenAI LLM, you will need to set the following environment variables:
+
+```sh
+export OPENAI_VERSION=... # e.g. 2023-05-15
+export OPENAI_BASE_URL=...
+export OPENAI_API_KEY=...
+export OPENAI_ENGINE=... # deployment name in Azure
+```
+
+### 2.3. Bedrock LLM
+
+To use the Bedrock LLM, you need to set the following environment variables:
+
+```sh
+export AWS_ACCESS_KEY=...
+export AWS_SECRET_KEY=...
+export AWS_REGION=... # e.g. us-east-1
+```
+Alternatively, you can create a config file at `~/.aws/config`, as described here:
+https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials
+
+```
+[default]
+aws_access_key_id=...
+aws_secret_access_key=...
+region=...
+```
 
 ## 3. Index Data
 
````
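The `llm_integrations` module that `chat.py` imports below is one of the six changed files, but its diff is not shown on this page. As a rough, hypothetical sketch of the dispatch described above (the real module's names, defaults, and credential wiring may differ), LLM_TYPE could select a LangChain model like this:

```python
# Hypothetical sketch of llm_integrations.get_llm() -- the actual file is part
# of this commit but its diff is not shown here, so everything below is an
# assumption based on the README and on `llm = get_llm()` in chat.py.
import os

from langchain.chat_models import AzureChatOpenAI, ChatOpenAI, ChatVertexAI
from langchain.llms import Bedrock


def get_llm():
    llm_type = os.getenv("LLM_TYPE", "openai")
    if llm_type == "openai":
        return ChatOpenAI(
            openai_api_key=os.getenv("OPENAI_API_KEY"), streaming=True, temperature=0.2
        )
    if llm_type == "azure":
        return AzureChatOpenAI(
            openai_api_version=os.getenv("OPENAI_VERSION"),
            openai_api_base=os.getenv("OPENAI_BASE_URL"),
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            deployment_name=os.getenv("OPENAI_ENGINE"),
            streaming=True,
            temperature=0.2,
        )
    if llm_type == "vertex":
        # Uses Google application-default credentials.
        return ChatVertexAI(temperature=0.2)
    if llm_type == "bedrock":
        # boto3 picks up credentials from the environment or ~/.aws/config;
        # the model_id is an assumed example.
        return Bedrock(model_id="anthropic.claude-v2", region_name=os.getenv("AWS_REGION"))
    raise ValueError(f"Unknown LLM_TYPE: {llm_type}")
```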
example-apps/workplace-search/api/app.py

Lines changed: 4 additions & 151 deletions
```diff
@@ -1,181 +1,34 @@
-from elasticsearch import Elasticsearch
-from lib.elasticsearch_chat_message_history import ElasticsearchChatMessageHistory
 from flask import Flask, jsonify, request, Response
 from flask_cors import CORS
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.chains import ConversationalRetrievalChain
-from langchain.chat_models import ChatOpenAI
-from langchain.prompts.chat import (
-    HumanMessagePromptTemplate,
-    SystemMessagePromptTemplate,
-    ChatPromptTemplate,
-)
-from langchain.prompts.prompt import PromptTemplate
-from langchain.vectorstores import ElasticsearchStore
 from queue import Queue
 from uuid import uuid4
-import json
-import os
+from chat import chat, ask_question, parse_stream_message
 import threading
 
-INDEX = "workplace-app-docs"
-INDEX_CHAT_HISTORY = "workplace-app-docs-chat-history"
-ELASTIC_CLOUD_ID = os.getenv("ELASTIC_CLOUD_ID")
-ELASTIC_USERNAME = os.getenv("ELASTIC_USERNAME")
-ELASTIC_PASSWORD = os.getenv("ELASTIC_PASSWORD")
-OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
-
-POISON_MESSAGE = "~~~END~~~"
-SESSION_ID_TAG = "[SESSION_ID]"
-SOURCE_TAG = "[SOURCE]"
-DONE_TAG = "[DONE]"
-
-
-class QueueCallbackHandler(BaseCallbackHandler):
-    def __init__(
-        self,
-        queue: Queue,
-    ):
-        self.queue = queue
-        self.in_human_prompt = True
-
-    def on_retriever_end(self, documents, *, run_id, parent_run_id=None, **kwargs):
-        if len(documents) > 0:
-            for doc in documents:
-                source = {
-                    "name": doc.metadata["name"],
-                    "page_content": doc.page_content,
-                    "url": doc.metadata["url"],
-                    "icon": doc.metadata["category"],
-                    "updated_at": doc.metadata.get("updated_at", None)
-                }
-                self.queue.put(f"{SOURCE_TAG} {json.dumps(source)}")
-
-    def on_llm_new_token(self, token, **kwargs):
-        if not self.in_human_prompt:
-            self.queue.put(token)
-
-    def on_llm_start(
-        self,
-        serialized,
-        prompts,
-        *,
-        run_id,
-        parent_run_id=None,
-        tags=None,
-        metadata=None,
-        **kwargs,
-    ):
-        self.in_human_prompt = prompts[0].startswith("Human:")
-
-    def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs):
-        if not self.in_human_prompt:
-            self.queue.put(POISON_MESSAGE)
-
-
-elasticsearch_client = Elasticsearch(
-    cloud_id=ELASTIC_CLOUD_ID, basic_auth=(ELASTIC_USERNAME, ELASTIC_PASSWORD)
-)
-
-store = ElasticsearchStore(
-    es_connection=elasticsearch_client,
-    index_name=INDEX,
-    strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(),
-)
-
-retriever = store.as_retriever()
-
-llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True, temperature=0.2)
-
-general_system_template = """
-Use the following passages to answer the user's question.
-Each passage has a SOURCE which is the title of the document. When answering, give the source name of the passages you are answering from, put them as an array of strings in here <script>[sources]</script>.
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
-----
-{context}
-----
-
-"""
-general_user_template = "Question: {question}"
-qa_prompt = ChatPromptTemplate.from_messages(
-    [
-        SystemMessagePromptTemplate.from_template(general_system_template),
-        HumanMessagePromptTemplate.from_template(general_user_template),
-    ]
-)
-
-document_prompt = PromptTemplate(
-    input_variables=["page_content", "name"],
-    template="""
----
-NAME: "{name}"
-PASSAGE:
-{page_content}
----
-""",
-)
-
-chat = ConversationalRetrievalChain.from_llm(
-    llm=llm,
-    retriever=store.as_retriever(),
-    return_source_documents=True,
-    combine_docs_chain_kwargs={"prompt": qa_prompt, "document_prompt": document_prompt},
-    verbose=True,
-)
-
 app = Flask(__name__, static_folder="../frontend/public")
 CORS(app)
 
-
 @app.route("/")
 def api_index():
     return app.send_static_file("index.html")
 
-
-def ask_question(question, queue, chat_history):
-    result = chat(
-        {"question": question, "chat_history": chat_history.messages},
-        callbacks=[QueueCallbackHandler(queue)],
-    )
-
-    chat_history.add_user_message(result["question"])
-    chat_history.add_ai_message(result["answer"])
-
-
 @app.route("/api/chat", methods=["POST"])
 def api_chat():
-    stream_queue = Queue()
     request_json = request.get_json()
     question = request_json.get("question")
     if question is None:
         return jsonify({"msg": "Missing question from request JSON"}), 400
 
+    stream_queue = Queue()
     session_id = request.args.get("session_id", str(uuid4()))
 
     print("Chat session ID: ", session_id)
-    chat_history = ElasticsearchChatMessageHistory(
-        client=elasticsearch_client, index=INDEX_CHAT_HISTORY, session_id=session_id
-    )
-
-    def generate(queue: Queue):
-        yield f"data: {SESSION_ID_TAG} {session_id}\n\n"
-
-        message = None
-        while True:
-            message = queue.get()
-
-            if message == POISON_MESSAGE:  # Poison message
-                break
-            yield f"data: {message}\n\n"
-
-        yield f"data: {DONE_TAG}\n\n"
 
     threading.Thread(
-        target=ask_question, args=(question, stream_queue, chat_history)
+        target=ask_question, args=(question, stream_queue, session_id)
    ).start()
 
-    return Response(generate(stream_queue), mimetype="text/event-stream")
+    return Response(parse_stream_message(session_id, stream_queue), mimetype="text/event-stream")
 
 
 if __name__ == "__main__":
```
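The refactored endpoint streams server-sent events framed by the [SESSION_ID], [SOURCE], and [DONE] tags that `parse_stream_message` emits (see `chat.py` below). A minimal client sketch, assuming the app is served locally on Flask's default port 5000 (the actual `app.run` arguments are outside this hunk):

```python
# Client sketch for the /api/chat SSE stream; the port and the question text
# are assumptions for illustration.
import json
import requests

resp = requests.post(
    "http://localhost:5000/api/chat",
    json={"question": "What is the vacation policy?"},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue  # skip the blank lines that separate SSE events
    payload = line[len("data: "):]
    if payload.startswith("[SESSION_ID] "):
        print("session:", payload[len("[SESSION_ID] "):])
    elif payload.startswith("[SOURCE] "):
        print("source:", json.loads(payload[len("[SOURCE] "):])["name"])
    elif payload == "[DONE]":
        break
    else:
        print(payload, end="")  # streamed answer tokens
```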
example-apps/workplace-search/api/chat.py

Lines changed: 129 additions & 0 deletions

```diff
@@ -0,0 +1,129 @@
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.chains import ConversationalRetrievalChain
+from langchain.prompts.chat import (
+    HumanMessagePromptTemplate,
+    SystemMessagePromptTemplate,
+    ChatPromptTemplate,
+)
+from langchain.prompts.prompt import PromptTemplate
+from langchain.vectorstores import ElasticsearchStore
+from queue import Queue
+from llm_integrations import get_llm
+from elasticsearch_client import elasticsearch_client, get_elasticsearch_chat_message_history
+import json
+
+INDEX = "workplace-app-docs"
+INDEX_CHAT_HISTORY = "workplace-app-docs-chat-history"
+POISON_MESSAGE = "~~~END~~~"
+SESSION_ID_TAG = "[SESSION_ID]"
+SOURCE_TAG = "[SOURCE]"
+DONE_TAG = "[DONE]"
+
+class QueueCallbackHandler(BaseCallbackHandler):
+    def __init__(
+        self,
+        queue: Queue,
+    ):
+        self.queue = queue
+        self.in_human_prompt = True
+
+    def on_retriever_end(self, documents, *, run_id, parent_run_id=None, **kwargs):
+        if len(documents) > 0:
+            for doc in documents:
+                source = {
+                    "name": doc.metadata["name"],
+                    "page_content": doc.page_content,
+                    "url": doc.metadata["url"],
+                    "icon": doc.metadata["category"],
+                    "updated_at": doc.metadata.get("updated_at", None)
+                }
+                self.queue.put(f"{SOURCE_TAG} {json.dumps(source)}")
+
+    def on_llm_new_token(self, token, **kwargs):
+        if not self.in_human_prompt:
+            self.queue.put(token)
+
+    def on_llm_start(
+        self,
+        serialized,
+        prompts,
+        *,
+        run_id,
+        parent_run_id=None,
+        tags=None,
+        metadata=None,
+        **kwargs,
+    ):
+        self.in_human_prompt = prompts[0].startswith("Human:")
+
+    def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs):
+        if not self.in_human_prompt:
+            self.queue.put(POISON_MESSAGE)
+
+store = ElasticsearchStore(
+    es_connection=elasticsearch_client,
+    index_name=INDEX,
+    strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(),
+)
+
+general_system_template = """
+Use the following passages to answer the user's question.
+Each passage has a SOURCE which is the title of the document. When answering, give the source name of the passages you are answering from, put them as an array of strings in here <script>[sources]</script>.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+----
+{context}
+----
+
+"""
+general_user_template = "Question: {question}"
+qa_prompt = ChatPromptTemplate.from_messages(
+    [
+        SystemMessagePromptTemplate.from_template(general_system_template),
+        HumanMessagePromptTemplate.from_template(general_user_template),
+    ]
+)
+
+document_prompt = PromptTemplate(
+    input_variables=["page_content", "name"],
+    template="""
+---
+NAME: "{name}"
+PASSAGE:
+{page_content}
+---
+""",
+)
+
+retriever = store.as_retriever()
+llm = get_llm()
+chat = ConversationalRetrievalChain.from_llm(
+    llm=llm,
+    retriever=store.as_retriever(),
+    return_source_documents=True,
+    combine_docs_chain_kwargs={"prompt": qa_prompt, "document_prompt": document_prompt},
+    verbose=True,
+)
+
+def parse_stream_message(session_id, queue: Queue):
+    yield f"data: {SESSION_ID_TAG} {session_id}\n\n"
+
+    message = None
+    while True:
+        message = queue.get()
+
+        if message == POISON_MESSAGE:
+            break
+        yield f"data: {message}\n\n"
+
+    yield f"data: {DONE_TAG}\n\n"
+
+def ask_question(question, queue, session_id):
+    chat_history = get_elasticsearch_chat_message_history(INDEX_CHAT_HISTORY, session_id)
+    result = chat(
+        {"question": question, "chat_history": chat_history.messages},
+        callbacks=[QueueCallbackHandler(queue)],
+    )
+
+    chat_history.add_user_message(result["question"])
+    chat_history.add_ai_message(result["answer"])
```
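`ask_question` and `parse_stream_message` form a producer/consumer pair around a shared `Queue`: the chain thread pushes source documents, tokens, and finally the `POISON_MESSAGE` sentinel, while the generator drains the queue into SSE frames. A minimal sketch of driving the pair directly, assuming the indices above exist and the LLM environment variables are set (importing `chat` connects to Elasticsearch and builds the chain at import time):

```python
# Driving chat.py's producer/consumer pair without Flask; assumes a reachable
# Elasticsearch deployment and a configured LLM (see the README section above).
import threading
from queue import Queue
from uuid import uuid4

from chat import ask_question, parse_stream_message

queue = Queue()
session_id = str(uuid4())

# Producer: runs the retrieval chain and feeds tokens into the queue.
threading.Thread(target=ask_question, args=("Who is on the HR team?", queue, session_id)).start()

# Consumer: yields SSE frames until the poison message triggers the [DONE] tag.
for frame in parse_stream_message(session_id, queue):
    print(frame, end="")
```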
example-apps/workplace-search/api/elasticsearch_client.py

Lines changed: 16 additions & 0 deletions

```diff
@@ -0,0 +1,16 @@
+from elasticsearch import Elasticsearch
+from lib.elasticsearch_chat_message_history import ElasticsearchChatMessageHistory
+import os
+
+ELASTIC_CLOUD_ID = os.getenv("ELASTIC_CLOUD_ID")
+ELASTIC_USERNAME = os.getenv("ELASTIC_USERNAME")
+ELASTIC_PASSWORD = os.getenv("ELASTIC_PASSWORD")
+
+elasticsearch_client = Elasticsearch(
+    cloud_id=ELASTIC_CLOUD_ID, basic_auth=(ELASTIC_USERNAME, ELASTIC_PASSWORD)
+)
+
+def get_elasticsearch_chat_message_history(index, session_id):
+    return ElasticsearchChatMessageHistory(
+        client=elasticsearch_client, index=index, session_id=session_id
+    )
```
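Since both `app.py` and `chat.py` now depend on this shared client, a quick connectivity check before starting the app can save a confusing stack trace later; a sketch, assuming the same environment variables are exported:

```python
# Smoke test for the shared client; assumes ELASTIC_CLOUD_ID, ELASTIC_USERNAME,
# and ELASTIC_PASSWORD are exported as described in the README.
from elasticsearch_client import elasticsearch_client

info = elasticsearch_client.info()  # raises on bad credentials or cloud_id
print("Connected to Elasticsearch", info["version"]["number"])
```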
