1 | | -from elasticsearch import Elasticsearch |
2 | | -from lib.elasticsearch_chat_message_history import ElasticsearchChatMessageHistory |
3 | 1 | from flask import Flask, jsonify, request, Response |
4 | 2 | from flask_cors import CORS |
5 | | -from langchain.callbacks.base import BaseCallbackHandler |
6 | | -from langchain.chains import ConversationalRetrievalChain |
7 | | -from langchain.chat_models import ChatOpenAI |
8 | | -from langchain.prompts.chat import ( |
9 | | - HumanMessagePromptTemplate, |
10 | | - SystemMessagePromptTemplate, |
11 | | - ChatPromptTemplate, |
12 | | -) |
13 | | -from langchain.prompts.prompt import PromptTemplate |
14 | | -from langchain.vectorstores import ElasticsearchStore |
15 | 3 | from queue import Queue |
16 | 4 | from uuid import uuid4 |
17 | | -import json |
18 | | -import os |
| 5 | +from chat import chat, ask_question, parse_stream_message |
19 | 6 | import threading |
20 | 7 |
21 | | -INDEX = "workplace-app-docs" |
22 | | -INDEX_CHAT_HISTORY = "workplace-app-docs-chat-history" |
23 | | -ELASTIC_CLOUD_ID = os.getenv("ELASTIC_CLOUD_ID") |
24 | | -ELASTIC_USERNAME = os.getenv("ELASTIC_USERNAME") |
25 | | -ELASTIC_PASSWORD = os.getenv("ELASTIC_PASSWORD") |
26 | | -OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") |
27 | | - |
28 | | -POISON_MESSAGE = "~~~END~~~" |
29 | | -SESSION_ID_TAG = "[SESSION_ID]" |
30 | | -SOURCE_TAG = "[SOURCE]" |
31 | | -DONE_TAG = "[DONE]" |
32 | | - |
33 | | - |
34 | | -class QueueCallbackHandler(BaseCallbackHandler): |
35 | | - def __init__( |
36 | | - self, |
37 | | - queue: Queue, |
38 | | - ): |
39 | | - self.queue = queue |
40 | | - self.in_human_prompt = True |
41 | | - |
42 | | - def on_retriever_end(self, documents, *, run_id, parent_run_id=None, **kwargs): |
43 | | - if len(documents) > 0: |
44 | | - for doc in documents: |
45 | | - source = { |
46 | | - "name": doc.metadata["name"], |
47 | | - "page_content": doc.page_content, |
48 | | - "url": doc.metadata["url"], |
49 | | - "icon": doc.metadata["category"], |
50 | | - "updated_at": doc.metadata.get("updated_at", None) |
51 | | - } |
52 | | - self.queue.put(f"{SOURCE_TAG} {json.dumps(source)}") |
53 | | - |
54 | | - def on_llm_new_token(self, token, **kwargs): |
55 | | - if not self.in_human_prompt: |
56 | | - self.queue.put(token) |
57 | | - |
58 | | - def on_llm_start( |
59 | | - self, |
60 | | - serialized, |
61 | | - prompts, |
62 | | - *, |
63 | | - run_id, |
64 | | - parent_run_id=None, |
65 | | - tags=None, |
66 | | - metadata=None, |
67 | | - **kwargs, |
68 | | - ): |
69 | | - self.in_human_prompt = prompts[0].startswith("Human:") |
70 | | - |
71 | | - def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs): |
72 | | - if not self.in_human_prompt: |
73 | | - self.queue.put(POISON_MESSAGE) |
74 | | - |
75 | | - |
76 | | -elasticsearch_client = Elasticsearch( |
77 | | - cloud_id=ELASTIC_CLOUD_ID, basic_auth=(ELASTIC_USERNAME, ELASTIC_PASSWORD) |
78 | | -) |
79 | | - |
80 | | -store = ElasticsearchStore( |
81 | | - es_connection=elasticsearch_client, |
82 | | - index_name=INDEX, |
83 | | - strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(), |
84 | | -) |
85 | | - |
86 | | -retriever = store.as_retriever() |
87 | | - |
88 | | -llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True, temperature=0.2) |
89 | | - |
90 | | -general_system_template = """ |
91 | | -Use the following passages to answer the user's question. |
92 | | -Each passage has a SOURCE which is the title of the document. When answering, give the source name of the passages you are answering from, put them as an array of strings in here <script>[sources]</script>. |
93 | | -If you don't know the answer, just say that you don't know, don't try to make up an answer. |
94 | | -
95 | | ----- |
96 | | -{context} |
97 | | ----- |
98 | | -
99 | | -""" |
100 | | -general_user_template = "Question: {question}" |
101 | | -qa_prompt = ChatPromptTemplate.from_messages( |
102 | | - [ |
103 | | - SystemMessagePromptTemplate.from_template(general_system_template), |
104 | | - HumanMessagePromptTemplate.from_template(general_user_template), |
105 | | - ] |
106 | | -) |
107 | | - |
108 | | -document_prompt = PromptTemplate( |
109 | | - input_variables=["page_content", "name"], |
110 | | - template=""" |
111 | | ---- |
112 | | -NAME: "{name}" |
113 | | -PASSAGE: |
114 | | -{page_content} |
115 | | ---- |
116 | | -""", |
117 | | -) |
118 | | - |
119 | | -chat = ConversationalRetrievalChain.from_llm( |
120 | | - llm=llm, |
121 | | - retriever=store.as_retriever(), |
122 | | - return_source_documents=True, |
123 | | - combine_docs_chain_kwargs={"prompt": qa_prompt, "document_prompt": document_prompt}, |
124 | | - verbose=True, |
125 | | -) |
126 | | - |
127 | 8 | app = Flask(__name__, static_folder="../frontend/public") |
128 | 9 | CORS(app) |
129 | 10 |
130 | | - |
131 | 11 | @app.route("/") |
132 | 12 | def api_index(): |
133 | 13 | return app.send_static_file("index.html") |
134 | 14 |
135 | | - |
136 | | -def ask_question(question, queue, chat_history): |
137 | | - result = chat( |
138 | | - {"question": question, "chat_history": chat_history.messages}, |
139 | | - callbacks=[QueueCallbackHandler(queue)], |
140 | | - ) |
141 | | - |
142 | | - chat_history.add_user_message(result["question"]) |
143 | | - chat_history.add_ai_message(result["answer"]) |
144 | | - |
145 | | - |
146 | 15 | @app.route("/api/chat", methods=["POST"]) |
147 | 16 | def api_chat(): |
148 | | - stream_queue = Queue() |
149 | 17 | request_json = request.get_json() |
150 | 18 | question = request_json.get("question") |
151 | 19 | if question is None: |
152 | 20 | return jsonify({"msg": "Missing question from request JSON"}), 400 |
153 | 21 |
| 22 | + stream_queue = Queue() |
154 | 23 | session_id = request.args.get("session_id", str(uuid4())) |
155 | 24 |
156 | 25 | print("Chat session ID: ", session_id) |
157 | | - chat_history = ElasticsearchChatMessageHistory( |
158 | | - client=elasticsearch_client, index=INDEX_CHAT_HISTORY, session_id=session_id |
159 | | - ) |
160 | | - |
161 | | - def generate(queue: Queue): |
162 | | - yield f"data: {SESSION_ID_TAG} {session_id}\n\n" |
163 | | - |
164 | | - message = None |
165 | | - while True: |
166 | | - message = queue.get() |
167 | | - |
168 | | - if message == POISON_MESSAGE: # Poison message |
169 | | - break |
170 | | - yield f"data: {message}\n\n" |
171 | | - |
172 | | - yield f"data: {DONE_TAG}\n\n" |
173 | 26 |
174 | 27 | threading.Thread( |
175 | | - target=ask_question, args=(question, stream_queue, chat_history) |
| 28 | + target=ask_question, args=(question, stream_queue, session_id) |
176 | 29 | ).start() |
177 | 30 |
178 | | - return Response(generate(stream_queue), mimetype="text/event-stream") |
| 31 | + return Response(parse_stream_message(session_id, stream_queue), mimetype="text/event-stream") |
179 | 32 |
180 | 33 |
181 | 34 | if __name__ == "__main__": |
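
The new import in app.py (`from chat import chat, ask_question, parse_stream_message`) implies the deleted logic now lives in a `chat` module. Below is a minimal sketch of what that module might contain, reassembled from the removed lines above. The `chat.py` filename and exact signatures are assumptions (only the three imported names are confirmed by the diff), and the qa/document prompt templates are elided for brevity.

# chat.py -- sketch, assuming the deleted code moved here largely unchanged
import json
import os
from queue import Queue

from elasticsearch import Elasticsearch
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores import ElasticsearchStore
from lib.elasticsearch_chat_message_history import ElasticsearchChatMessageHistory

INDEX = "workplace-app-docs"
INDEX_CHAT_HISTORY = "workplace-app-docs-chat-history"
POISON_MESSAGE = "~~~END~~~"
SESSION_ID_TAG = "[SESSION_ID]"
SOURCE_TAG = "[SOURCE]"
DONE_TAG = "[DONE]"

elasticsearch_client = Elasticsearch(
    cloud_id=os.getenv("ELASTIC_CLOUD_ID"),
    basic_auth=(os.getenv("ELASTIC_USERNAME"), os.getenv("ELASTIC_PASSWORD")),
)

store = ElasticsearchStore(
    es_connection=elasticsearch_client,
    index_name=INDEX,
    strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(),
)

llm = ChatOpenAI(
    openai_api_key=os.getenv("OPENAI_API_KEY"), streaming=True, temperature=0.2
)


class QueueCallbackHandler(BaseCallbackHandler):
    """Pushes retrieved sources and streamed LLM tokens onto a queue."""

    def __init__(self, queue: Queue):
        self.queue = queue
        self.in_human_prompt = True

    def on_retriever_end(self, documents, **kwargs):
        for doc in documents:
            source = {
                "name": doc.metadata["name"],
                "page_content": doc.page_content,
                "url": doc.metadata["url"],
                "icon": doc.metadata["category"],
                "updated_at": doc.metadata.get("updated_at", None),
            }
            self.queue.put(f"{SOURCE_TAG} {json.dumps(source)}")

    def on_llm_start(self, serialized, prompts, **kwargs):
        # The chain's condense-question pass starts with "Human:";
        # its tokens must not be streamed to the client.
        self.in_human_prompt = prompts[0].startswith("Human:")

    def on_llm_new_token(self, token, **kwargs):
        if not self.in_human_prompt:
            self.queue.put(token)

    def on_llm_end(self, response, **kwargs):
        if not self.in_human_prompt:
            self.queue.put(POISON_MESSAGE)


# qa_prompt / document_prompt from the deleted lines would be rebuilt here and
# passed via combine_docs_chain_kwargs, as before; elided to keep this short.
chat = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=store.as_retriever(),
    return_source_documents=True,
)


def ask_question(question, queue, session_id):
    # app.py now passes only the session id across the thread boundary, so
    # the chat history is reconstructed here instead of in the request handler.
    chat_history = ElasticsearchChatMessageHistory(
        client=elasticsearch_client, index=INDEX_CHAT_HISTORY, session_id=session_id
    )
    result = chat(
        {"question": question, "chat_history": chat_history.messages},
        callbacks=[QueueCallbackHandler(queue)],
    )
    chat_history.add_user_message(result["question"])
    chat_history.add_ai_message(result["answer"])


def parse_stream_message(session_id, queue: Queue):
    # The old inline generate() from app.py, renamed; SSE framing is unchanged.
    yield f"data: {SESSION_ID_TAG} {session_id}\n\n"
    while True:
        message = queue.get()
        if message == POISON_MESSAGE:  # poison message ends the stream
            break
        yield f"data: {message}\n\n"
    yield f"data: {DONE_TAG}\n\n"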
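
On the wire, nothing changes: the endpoint still emits the tagged SSE frames the old inline generate() produced. A hypothetical client for illustration; the host, port, question, and use of the requests library are assumptions, not part of this diff.

# client_example.py -- hypothetical consumer of the /api/chat SSE stream
import json
import requests  # assumption: available in the client environment

resp = requests.post(
    "http://localhost:5000/api/chat",  # assumed default Flask host/port
    json={"question": "What is our vacation policy?"},
    stream=True,
)

for raw in resp.iter_lines(decode_unicode=True):
    if not raw or not raw.startswith("data: "):
        continue  # skip blank separator lines between SSE frames
    payload = raw[len("data: "):]
    if payload.startswith("[SESSION_ID]"):
        print("session:", payload.split(" ", 1)[1])
    elif payload.startswith("[SOURCE]"):
        source = json.loads(payload.split(" ", 1)[1])
        print("source:", source["name"])
    elif payload == "[DONE]":
        break  # server signals end of the answer
    else:
        print(payload, end="")  # streamed answer tokens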