Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit f6b9e32

Browse files
authored
Enable Qdrant vectorstore (#1076)
* enable qdrant vectorstore Signed-off-by: yuwenzho <yuwen.zhou@intel.com> Co-authored-by: XuhuiRen
1 parent 1d84fd8 commit f6b9e32

File tree

11 files changed

+356
-5
lines changed

11 files changed

+356
-5
lines changed

intel_extension_for_transformers/langchain/vectorstores/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
# limitations under the License.
1717

1818
from .chroma import Chroma
19+
from .qdrant import Qdrant
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
# !/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2023 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
import os
19+
import logging
20+
from typing import Any, Type, List, Optional, TYPE_CHECKING
21+
22+
from langchain_core.documents import Document
23+
from langchain_core.embeddings import Embeddings
24+
from langchain.vectorstores.qdrant import Qdrant as Qdrant_origin
25+
from intel_extension_for_transformers.transformers.utils.utility import LazyImport
26+
27+
logging.basicConfig(
28+
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
29+
datefmt="%d-%M-%Y %H:%M:%S",
30+
level=logging.INFO
31+
)
32+
33+
if TYPE_CHECKING:
34+
from qdrant_client.conversions import common_types
35+
36+
_DEFAULT_PERSIST_DIR = './output'
37+
38+
qdrant_client = LazyImport("qdrant_client")
39+
40+
class Qdrant(Qdrant_origin):
41+
42+
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
43+
44+
@classmethod
45+
def from_documents(
46+
cls,
47+
documents: List[Document],
48+
embedding: Embeddings,
49+
sign: Optional[str] = None,
50+
location: Optional[str] = None,
51+
url: Optional[str] = None,
52+
api_key: Optional[str] = None,
53+
host: Optional[str]= None,
54+
persist_directory: Optional[str] = None,
55+
collection_name: Optional[str] = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
56+
force_recreate: Optional[bool] = False,
57+
**kwargs: Any,
58+
):
59+
"""Create a Qdrant vectorstore from a list of documents.
60+
61+
Args:
62+
documents (List[Document]): List of documents to add to the vectorstore.
63+
embedding (Optional[Embeddings]): A subclass of `Embeddings`, responsible for text vectorization.
64+
sign (Optional[str], optional): sign for retrieval_type of 'child_parent'. Defaults to None.
65+
location (Optional[str], optional):
66+
If `:memory:` - use in-memory Qdrant instance.
67+
If `str` - use it as a `url` parameter.
68+
If `None` - fallback to relying on `host` and `port` parameters.
69+
Defaults to None.
70+
url (Optional[str], optional): either host or str of "Optional[scheme], host, Optional[port],
71+
Optional[prefix]". Defaults to None.
72+
api_key (Optional[str], optional): API key for authentication in Qdrant Cloud. Defaults to None.
73+
host (Optional[str], optional): Host name of Qdrant service. If url and host are None, set to
74+
'localhost'. Defaults to None.
75+
persist_directory (Optional[str], optional): Path in which the vectors will be stored while using
76+
local mode. Defaults to None.
77+
collection_name (Optional[str], optional): Name of the Qdrant collection to be used.
78+
Defaults to _LANGCHAIN_DEFAULT_COLLECTION_NAME.
79+
force_recreate (bool, optional): _description_. Defaults to False.
80+
"""
81+
if sum([param is not None for param in (location, url, host, persist_directory)]) == 0:
82+
# One of 'location', 'url', 'host' or 'persist_directory' should be specified.
83+
persist_directory = _DEFAULT_PERSIST_DIR
84+
if sign == "child":
85+
persist_directory = persist_directory + "_child"
86+
texts = [d.page_content for d in documents]
87+
metadatas = [d.metadata for d in documents]
88+
return cls.from_texts(
89+
texts,
90+
embedding,
91+
metadatas=metadatas,
92+
location=location,
93+
url=url,
94+
api_key=api_key,
95+
host=host,
96+
path=persist_directory,
97+
collection_name=collection_name,
98+
force_recreate=force_recreate,
99+
**kwargs)
100+
101+
@classmethod
102+
def build(
103+
cls,
104+
documents: List[Document],
105+
embedding: Optional[Embeddings],
106+
sign: Optional[str] = None,
107+
location: Optional[str] = None,
108+
url: Optional[str] = None,
109+
api_key: Optional[str] = None,
110+
host: Optional[str]= None,
111+
persist_directory: Optional[str] = None,
112+
collection_name: Optional[str] = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
113+
force_recreate: Optional[bool] = False,
114+
**kwargs: Any,
115+
):
116+
"""Build a Qdrant vectorstore.
117+
118+
Args:
119+
documents (List[Document]): List of documents to add to the vectorstore.
120+
embedding (Optional[Embeddings]): A subclass of `Embeddings`, responsible for text vectorization.
121+
sign (Optional[str], optional): sign for retrieval_type of 'child_parent'. Defaults to None.
122+
location (Optional[str], optional):
123+
If `:memory:` - use in-memory Qdrant instance.
124+
If `str` - use it as a `url` parameter.
125+
If `None` - fallback to relying on `host` and `port` parameters.
126+
Defaults to None.
127+
url (Optional[str], optional): either host or str of "Optional[scheme], host, Optional[port],
128+
Optional[prefix]". Defaults to None.
129+
api_key (Optional[str], optional): API key for authentication in Qdrant Cloud. Defaults to None.
130+
host (Optional[str], optional): Host name of Qdrant service. If url and host are None, set to
131+
'localhost'. Defaults to None.
132+
persist_directory (Optional[str], optional): Path in which the vectors will be stored while using
133+
local mode. Defaults to None.
134+
collection_name (Optional[str], optional): Name of the Qdrant collection to be used.
135+
Defaults to _LANGCHAIN_DEFAULT_COLLECTION_NAME.
136+
force_recreate (bool, optional): _description_. Defaults to False.
137+
kwargs:
138+
Current used:
139+
port (Optional[int], optional): Port of the REST API interface. Defaults to 6333.
140+
grpc_port (int, optional): Port of the gRPC interface. Defaults to 6334.
141+
prefer_grpc (bool, optional): If true - use gPRC interface whenever possible in custom methods.
142+
Defaults to False.
143+
https (Optional[bool], optional): If true - use HTTPS(SSL) protocol.
144+
prefix (Optional[str], optional):
145+
If not None - add prefix to the REST URL path.
146+
Example: service/v1 will result in
147+
http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.
148+
timeout (Optional[float], optional):
149+
Timeout for REST and gRPC API requests.
150+
151+
distance_func (str, optional): Distance function. One of: "Cosine" / "Euclid" / "Dot".
152+
Defaults to "Cosine".
153+
content_payload_key (str, optional): A payload key used to store the content of the document.
154+
Defaults to CONTENT_KEY.
155+
metadata_payload_key (str, optional): A payload key used to store the metadata of the document.
156+
Defaults to METADATA_KEY.
157+
vector_name (Optional[str], optional): Name of the vector to be used internally in Qdrant.
158+
Defaults to VECTOR_NAME.
159+
shard_number (Optional[int], optional): Number of shards in collection.
160+
replication_factor (Optional[int], optional):
161+
Replication factor for collection.
162+
Defines how many copies of each shard will be created.
163+
Have effect only in distributed mode.
164+
write_consistency_factor (Optional[int], optional):
165+
Write consistency factor for collection.
166+
Defines how many replicas should apply the operation for us to consider
167+
it successful. Increasing this number will make the collection more
168+
resilient to inconsistencies, but will also make it fail if not enough
169+
replicas are available.
170+
Does not have any performance impact.
171+
Have effect only in distributed mode.
172+
on_disk_payload (Optional[bool], optional):
173+
If true - point`s payload will not be stored in memory.
174+
It will be read from the disk every time it is requested.
175+
This setting saves RAM by (slightly) increasing the response time.
176+
Note: those payload values that are involved in filtering and are
177+
indexed - remain in RAM.
178+
hnsw_config (Optional[common_types.HnswConfigDiff], optional): Params for HNSW index.
179+
optimizers_config (Optional[common_types.OptimizersConfigDiff], optional): Params for optimizer.
180+
wal_config (Optional[common_types.WalConfigDiff], optional): Params for Write-Ahead-Log.
181+
quantization_config (Optional[common_types.QuantizationConfig], optional):
182+
Params for quantization, if None - quantization will be disable.
183+
init_from (Optional[common_types.InitFrom], optional):
184+
Use data stored in another collection to initialize this collection.
185+
on_disk (Optional[bool], optional): if True, vectors will be stored on disk.
186+
If None, default value will be used.
187+
"""
188+
if sum([param is not None for param in (location, url, host, persist_directory)]) == 0:
189+
# One of 'location', 'url', 'host' or 'persist_directory' should be specified.
190+
persist_directory = _DEFAULT_PERSIST_DIR
191+
if sign == "child":
192+
persist_directory = persist_directory + "_child"
193+
if persist_directory and os.path.exists(persist_directory):
194+
if bool(os.listdir(persist_directory)):
195+
logging.info("Load the existing database!")
196+
texts = [d.page_content for d in documents]
197+
qdrant_collection = cls.construct_instance(
198+
texts=texts,
199+
embedding=embedding,
200+
location=location,
201+
url=url,
202+
api_key=api_key,
203+
host=host,
204+
path=persist_directory,
205+
collection_name=collection_name,
206+
force_recreate=force_recreate,
207+
**kwargs
208+
)
209+
return qdrant_collection
210+
else:
211+
logging.info("Create a new knowledge base...")
212+
qdrant_collection = cls.from_documents(
213+
documents=documents,
214+
embedding=embedding,
215+
location=location,
216+
url=url,
217+
api_key=api_key,
218+
host=host,
219+
persist_directory=persist_directory,
220+
collection_name=collection_name,
221+
force_recreate=force_recreate,
222+
**kwargs,
223+
)
224+
return qdrant_collection
225+
226+
227+
@classmethod
228+
def reload(
229+
cls,
230+
embedding: Optional[Embeddings],
231+
location: Optional[str] = None,
232+
url: Optional[str] = None,
233+
api_key: Optional[str] = None,
234+
host: Optional[str]= None,
235+
persist_directory: Optional[str] = None,
236+
collection_name: Optional[str] = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
237+
force_recreate: bool = False,
238+
**kwargs: Any,
239+
):
240+
"""Reload a Qdrant vectorstore.
241+
242+
Args:
243+
embedding (Optional[Embeddings]): A subclass of `Embeddings`, responsible for text vectorization.
244+
location (Optional[str], optional):
245+
If `:memory:` - use in-memory Qdrant instance.
246+
If `str` - use it as a `url` parameter.
247+
If `None` - fallback to relying on `host` and `port` parameters.
248+
Defaults to None.
249+
url (Optional[str], optional): either host or str of "Optional[scheme], host, Optional[port],
250+
Optional[prefix]". Defaults to None.
251+
api_key (Optional[str], optional): API key for authentication in Qdrant Cloud. Defaults to None.
252+
host (Optional[str], optional): Host name of Qdrant service. If url and host are None, set to
253+
'localhost'. Defaults to None.
254+
persist_directory (Optional[str], optional): Path in which the vectors will be stored while using
255+
local mode. Defaults to None.
256+
collection_name (Optional[str], optional): Name of the Qdrant collection to be used.
257+
Defaults to _LANGCHAIN_DEFAULT_COLLECTION_NAME.
258+
force_recreate (bool, optional): _description_. Defaults to False.
259+
"""
260+
if sum([param is not None for param in (location, url, host, persist_directory)]) == 0:
261+
# One of 'location', 'url', 'host' or 'persist_directory' should be specified.
262+
persist_directory = _DEFAULT_PERSIST_DIR
263+
264+
# for a single quick embedding to get vector size
265+
tmp_texts = ["foo"]
266+
267+
qdrant_collection = cls.construct_instance(
268+
texts=tmp_texts,
269+
embedding=embedding,
270+
location=location,
271+
url=url,
272+
api_key=api_key,
273+
host=host,
274+
path=persist_directory,
275+
collection_name=collection_name,
276+
force_recreate=force_recreate,
277+
**kwargs
278+
)
279+
return qdrant_collection
280+
281+
282+
def is_local(
283+
self,
284+
):
285+
"""Determine whether a client is local."""
286+
if hasattr(self.client, "_client") and \
287+
isinstance(self.client._client, qdrant_client.local.qdrant_local.QdrantLocal):
288+
return True
289+
else:
290+
return False

intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_agent.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
HuggingFaceInstructEmbeddings, HuggingFaceBgeEmbeddings
2727
from langchain.embeddings import GooglePalmEmbeddings
2828
from langchain.text_splitter import RecursiveCharacterTextSplitter
29-
from intel_extension_for_transformers.langchain.vectorstores import Chroma
29+
from intel_extension_for_transformers.langchain.vectorstores import Chroma, Qdrant
3030
import uuid
3131
from langchain_core.documents import Document
3232
import logging
@@ -133,10 +133,12 @@ def __init__(self,
133133
logging.info("The format of parsed documents is transferred.")
134134

135135
if self.vector_database == "Chroma":
136-
self.database = Chroma()
136+
self.database = Chroma
137+
elif self.vector_database == "Qdrant":
138+
self.database = Qdrant
137139
# elif self.vector_database == "PGVector":
138140
# self.database = PGVector()
139-
141+
140142
if self.retrieval_type == 'default': # Using vector store retriever
141143
if append:
142144
knowledge_base = self.database.from_documents(documents=langchain_documents, embedding=self.embeddings,
@@ -145,6 +147,9 @@ def __init__(self,
145147
knowledge_base = self.database.build(documents=langchain_documents, embedding=self.embeddings, **kwargs)
146148
self.retriever = RetrieverAdapter(retrieval_type=self.retrieval_type, document_store=knowledge_base, \
147149
**kwargs).retriever
150+
if self.vector_database == "Qdrant" and knowledge_base.is_local():
151+
# one local storage folder cannot be accessed by multiple instances of Qdrant client simultaneously.
152+
knowledge_base.client.close()
148153
elif self.retrieval_type == "child_parent": # Using child-parent store retriever
149154
child_documents = self.splitter.split_documents(langchain_documents)
150155
if append:
@@ -158,6 +163,12 @@ def __init__(self,
158163
sign='child', **kwargs)
159164
self.retriever = RetrieverAdapter(retrieval_type=self.retrieval_type, document_store=knowledge_base, \
160165
child_document_store=child_knowledge_base, **kwargs).retriever
166+
if self.vector_database == "Qdrant" :
167+
# one local storage folder cannot be accessed by multiple instances of Qdrant client simultaneously.
168+
if knowledge_base.is_local():
169+
knowledge_base.client.close()
170+
if child_knowledge_base.is_local():
171+
child_knowledge_base.client.close()
161172
logging.info("The retriever is successfully built.")
162173

163174
def reload_localdb(self, local_persist_dir, **kwargs):

intel_extension_for_transformers/neural_chat/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,4 @@ urllib3
7171
langid
7272
diffusers==0.12.1
7373
transformers_stream_generator
74+
qdrant-client

intel_extension_for_transformers/neural_chat/requirements_cpu.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,5 @@ einops
4848
cchardet
4949
zhconv
5050
urllib3
51-
langid
51+
langid
52+
qdrant-client

intel_extension_for_transformers/neural_chat/requirements_hpu.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,4 @@ einops
4343
zhconv
4444
urllib3
4545
langid
46+
qdrant-client

intel_extension_for_transformers/neural_chat/requirements_pc.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,4 @@ langid
4646
pymysql
4747
deepface
4848
exifread
49+
qdrant-client

intel_extension_for_transformers/neural_chat/requirements_xpu.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@ exifread
3939
zhconv
4040
urllib3
4141
langid
42+
qdrant-client

intel_extension_for_transformers/neural_chat/tests/ci/api/test_chatbot_build_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def test_build_chatbot_with_retrieval_plugin_bge_int8(self):
150150
def test_build_chatbot_with_retrieval_plugin_using_local_file(self):
151151

152152
def _run_retrieval(local_dir):
153+
plugins.tts.enable = False
153154
plugins.retrieval.enable = True
154155
plugins.retrieval.args["input_path"] = "../../../README.md"
155156
plugins.retrieval.args["embedding_model"] = local_dir

0 commit comments

Comments
 (0)