Skip to content

Commit af7f20f

Browse files
authored
Harrison/elastic search (langchain-ai#2419)
1 parent 659c67e commit af7f20f

File tree

3 files changed

+292
-0
lines changed

3 files changed

+292
-0
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "ab66dd43",
6+
"metadata": {},
7+
"source": [
8+
"# ElasticSearch BM25\n",
9+
"\n",
10+
"This notebook goes over how to use a retriever that under the hood uses ElasticSearcha and BM25.\n",
11+
"\n",
12+
"For more information on the details of BM25 see [this blog post](https://www.elastic.co/blog/practical-bm25-part-2-the-bm25-algorithm-and-its-variables)."
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": 2,
18+
"id": "393ac030",
19+
"metadata": {},
20+
"outputs": [],
21+
"source": [
22+
"from langchain.retrievers import ElasticSearchBM25Retriever"
23+
]
24+
},
25+
{
26+
"cell_type": "markdown",
27+
"id": "aaf80e7f",
28+
"metadata": {},
29+
"source": [
30+
"## Create New Retriever"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": 12,
36+
"id": "bcb3c8c2",
37+
"metadata": {},
38+
"outputs": [],
39+
"source": [
40+
"elasticsearch_url=\"http://localhost:9200\"\n",
41+
"retriever = ElasticSearchBM25Retriever.create(elasticsearch_url, \"langchain-index-3\")"
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": 13,
47+
"id": "b605284d",
48+
"metadata": {},
49+
"outputs": [],
50+
"source": [
51+
"# Alternatively, you can load an existing index\n",
52+
"# import elasticsearch\n",
53+
"# elasticsearch_url=\"http://localhost:9200\"\n",
54+
"# retriever = ElasticSearchBM25Retriever(elasticsearch.Elasticsearch(elasticsearch_url), \"langchain-index\")"
55+
]
56+
},
57+
{
58+
"cell_type": "markdown",
59+
"id": "1c518c42",
60+
"metadata": {},
61+
"source": [
62+
"## Add texts (if necessary)\n",
63+
"\n",
64+
"We can optionally add texts to the retriever (if they aren't already in there)"
65+
]
66+
},
67+
{
68+
"cell_type": "code",
69+
"execution_count": 14,
70+
"id": "98b1c017",
71+
"metadata": {},
72+
"outputs": [
73+
{
74+
"data": {
75+
"text/plain": [
76+
"['386c76c9-4355-4c12-aaeb-7b80054caf93',\n",
77+
" 'fffd279c-a0c9-4158-a904-6e242c517c99',\n",
78+
" '7f5528a3-18d0-43b0-894d-f6770a002219',\n",
79+
" 'e2ef5e32-d5bd-44e2-b045-cfc5a8e0a0a1',\n",
80+
" 'cce8ba48-e473-4235-bca2-2c8d65e73ccf']"
81+
]
82+
},
83+
"execution_count": 14,
84+
"metadata": {},
85+
"output_type": "execute_result"
86+
}
87+
],
88+
"source": [
89+
"retriever.add_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"])"
90+
]
91+
},
92+
{
93+
"cell_type": "markdown",
94+
"id": "08437fa2",
95+
"metadata": {},
96+
"source": [
97+
"## Use Retriever\n",
98+
"\n",
99+
"We can now use the retriever!"
100+
]
101+
},
102+
{
103+
"cell_type": "code",
104+
"execution_count": 15,
105+
"id": "c0455218",
106+
"metadata": {},
107+
"outputs": [],
108+
"source": [
109+
"result = retriever.get_relevant_documents(\"foo\")"
110+
]
111+
},
112+
{
113+
"cell_type": "code",
114+
"execution_count": 16,
115+
"id": "7dfa5c29",
116+
"metadata": {},
117+
"outputs": [
118+
{
119+
"data": {
120+
"text/plain": [
121+
"[Document(page_content='foo', metadata={}),\n",
122+
" Document(page_content='foo bar', metadata={})]"
123+
]
124+
},
125+
"execution_count": 16,
126+
"metadata": {},
127+
"output_type": "execute_result"
128+
}
129+
],
130+
"source": [
131+
"result"
132+
]
133+
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": null,
137+
"id": "74bd9256",
138+
"metadata": {},
139+
"outputs": [],
140+
"source": []
141+
}
142+
],
143+
"metadata": {
144+
"kernelspec": {
145+
"display_name": "Python 3 (ipykernel)",
146+
"language": "python",
147+
"name": "python3"
148+
},
149+
"language_info": {
150+
"codemirror_mode": {
151+
"name": "ipython",
152+
"version": 3
153+
},
154+
"file_extension": ".py",
155+
"mimetype": "text/x-python",
156+
"name": "python",
157+
"nbconvert_exporter": "python",
158+
"pygments_lexer": "ipython3",
159+
"version": "3.9.1"
160+
}
161+
},
162+
"nbformat": 4,
163+
"nbformat_minor": 5
164+
}

langchain/retrievers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
2+
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
23
from langchain.retrievers.metal import MetalRetriever
34
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
45
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
@@ -8,4 +9,5 @@
89
"RemoteLangChainRetriever",
910
"PineconeHybridSearchRetriever",
1011
"MetalRetriever",
12+
"ElasticSearchBM25Retriever",
1113
]
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""Wrapper around Elasticsearch vector database."""
2+
from __future__ import annotations
3+
4+
import uuid
5+
from typing import Any, Iterable, List
6+
7+
from langchain.docstore.document import Document
8+
from langchain.schema import BaseRetriever
9+
10+
11+
class ElasticSearchBM25Retriever(BaseRetriever):
12+
"""Wrapper around Elasticsearch using BM25 as a retrieval method.
13+
14+
15+
To connect to an Elasticsearch instance that requires login credentials,
16+
including Elastic Cloud, use the Elasticsearch URL format
17+
https://username:password@es_host:9243. For example, to connect to Elastic
18+
Cloud, create the Elasticsearch URL with the required authentication details and
19+
pass it to the ElasticVectorSearch constructor as the named parameter
20+
elasticsearch_url.
21+
22+
You can obtain your Elastic Cloud URL and login credentials by logging in to the
23+
Elastic Cloud console at https://cloud.elastic.co, selecting your deployment, and
24+
navigating to the "Deployments" page.
25+
26+
To obtain your Elastic Cloud password for the default "elastic" user:
27+
28+
1. Log in to the Elastic Cloud console at https://cloud.elastic.co
29+
2. Go to "Security" > "Users"
30+
3. Locate the "elastic" user and click "Edit"
31+
4. Click "Reset password"
32+
5. Follow the prompts to reset the password
33+
34+
The format for Elastic Cloud URLs is
35+
https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243.
36+
"""
37+
38+
def __init__(self, client: Any, index_name: str):
39+
self.client = client
40+
self.index_name = index_name
41+
42+
@classmethod
43+
def create(
44+
cls, elasticsearch_url: str, index_name: str, k1: float = 2.0, b: float = 0.75
45+
) -> ElasticSearchBM25Retriever:
46+
from elasticsearch import Elasticsearch
47+
48+
# Create an Elasticsearch client instance
49+
es = Elasticsearch(elasticsearch_url)
50+
51+
# Define the index settings and mappings
52+
index_settings = {
53+
"settings": {
54+
"analysis": {"analyzer": {"default": {"type": "standard"}}},
55+
"similarity": {
56+
"custom_bm25": {
57+
"type": "BM25",
58+
"k1": k1,
59+
"b": b,
60+
}
61+
},
62+
},
63+
"mappings": {
64+
"properties": {
65+
"content": {
66+
"type": "text",
67+
"similarity": "custom_bm25", # Use the custom BM25 similarity
68+
}
69+
}
70+
},
71+
}
72+
73+
# Create the index with the specified settings and mappings
74+
es.indices.create(index=index_name, body=index_settings)
75+
return cls(es, index_name)
76+
77+
def add_texts(
78+
self,
79+
texts: Iterable[str],
80+
refresh_indices: bool = True,
81+
) -> List[str]:
82+
"""Run more texts through the embeddings and add to the retriver.
83+
84+
Args:
85+
texts: Iterable of strings to add to the retriever.
86+
refresh_indices: bool to refresh ElasticSearch indices
87+
88+
Returns:
89+
List of ids from adding the texts into the retriever.
90+
"""
91+
try:
92+
from elasticsearch.helpers import bulk
93+
except ImportError:
94+
raise ValueError(
95+
"Could not import elasticsearch python package. "
96+
"Please install it with `pip install elasticsearch`."
97+
)
98+
requests = []
99+
ids = []
100+
for i, text in enumerate(texts):
101+
_id = str(uuid.uuid4())
102+
request = {
103+
"_op_type": "index",
104+
"_index": self.index_name,
105+
"content": text,
106+
"_id": _id,
107+
}
108+
ids.append(_id)
109+
requests.append(request)
110+
bulk(self.client, requests)
111+
112+
if refresh_indices:
113+
self.client.indices.refresh(index=self.index_name)
114+
return ids
115+
116+
def get_relevant_documents(self, query: str) -> List[Document]:
117+
query_dict = {"query": {"match": {"content": query}}}
118+
res = self.client.search(index=self.index_name, body=query_dict)
119+
120+
docs = []
121+
for r in res["hits"]["hits"]:
122+
docs.append(Document(page_content=r["_source"]["content"]))
123+
return docs
124+
125+
async def aget_relevant_documents(self, query: str) -> List[Document]:
126+
raise NotImplementedError

0 commit comments

Comments
 (0)