4545 {"name" : "ReJSON" , "ver" : 20404 }
4646]
4747REDIS_DEFAULT_ESCAPED_CHARS = re .compile (r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]" )
48- REDIS_SEARCH_SCHEMA = {
49- "document_id" : TagField ("$.document_id" , as_name = "document_id" ),
50- "metadata" : {
51- # "source_id": TagField("$.metadata.source_id", as_name="source_id"),
52- "source" : TagField ("$.metadata.source" , as_name = "source" ),
53- # "author": TextField("$.metadata.author", as_name="author"),
54- # "created_at": NumericField("$.metadata.created_at", as_name="created_at"),
55- },
56- "embedding" : VectorField (
57- "$.embedding" ,
58- REDIS_INDEX_TYPE ,
59- {
60- "TYPE" : "FLOAT64" ,
61- "DIM" : VECTOR_DIMENSION ,
62- "DISTANCE_METRIC" : REDIS_DISTANCE_METRIC ,
63- },
64- as_name = "embedding" ,
65- ),
66- }
6748
6849# Helper functions
6950def unpack_schema (d : dict ):
@@ -82,22 +63,23 @@ async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]):
8263 error_message = "You must add the RediSearch (>= 2.6) and ReJSON (>= 2.4) modules from Redis Stack. " \
8364 "Please refer to Redis Stack docs: https://redis.io/docs/stack/"
8465 logging .error (error_message )
85- raise ValueError (error_message )
66+ raise AttributeError (error_message )
8667
8768
8869
8970class RedisDataStore (DataStore ):
90- def __init__ (self , client : redis .Redis ):
71+ def __init__ (self , client : redis .Redis , redisearch_schema ):
9172 self .client = client
73+ self ._schema = redisearch_schema
9274 # Init default metadata with sentinel values in case the document written has no metadata
9375 self ._default_metadata = {
94- field : "_null_" for field in REDIS_SEARCH_SCHEMA ["metadata" ]
76+ field : "_null_" for field in redisearch_schema ["metadata" ]
9577 }
9678
9779 ### Redis Helper Methods ###
9880
9981 @classmethod
100- async def init (cls ):
82+ async def init (cls , ** kwargs ):
10183 """
10284 Set up the index if it does not exist.
10385 """
@@ -112,7 +94,27 @@ async def init(cls):
11294 raise e
11395
11496 await _check_redis_module_exist (client , modules = REDIS_REQUIRED_MODULES )
115-
97+
98+ dim = kwargs .get ("dim" , VECTOR_DIMENSION )
99+ redisearch_schema = {
100+ "document_id" : TagField ("$.document_id" , as_name = "document_id" ),
101+ "metadata" : {
102+ "source_id" : TagField ("$.metadata.source_id" , as_name = "source_id" ),
103+ "source" : TagField ("$.metadata.source" , as_name = "source" ),
104+ "author" : TextField ("$.metadata.author" , as_name = "author" ),
105+ "created_at" : NumericField ("$.metadata.created_at" , as_name = "created_at" ),
106+ },
107+ "embedding" : VectorField (
108+ "$.embedding" ,
109+ REDIS_INDEX_TYPE ,
110+ {
111+ "TYPE" : "FLOAT64" ,
112+ "DIM" : dim ,
113+ "DISTANCE_METRIC" : REDIS_DISTANCE_METRIC ,
114+ },
115+ as_name = "embedding" ,
116+ ),
117+ }
116118 try :
117119 # Check for existence of RediSearch Index
118120 await client .ft (REDIS_INDEX_NAME ).info ()
@@ -123,11 +125,12 @@ async def init(cls):
123125 definition = IndexDefinition (
124126 prefix = [REDIS_DOC_PREFIX ], index_type = IndexType .JSON
125127 )
126- fields = list (unpack_schema (REDIS_SEARCH_SCHEMA ))
128+ fields = list (unpack_schema (redisearch_schema ))
129+ logging .info (f"Creating index with fields: { fields } " )
127130 await client .ft (REDIS_INDEX_NAME ).create_index (
128131 fields = fields , definition = definition
129132 )
130- return cls (client )
133+ return cls (client , redisearch_schema )
131134
132135 @staticmethod
133136 def _redis_key (document_id : str , chunk_id : str ) -> str :
@@ -217,20 +220,21 @@ def _typ_to_str(typ, field, value) -> str: # type: ignore
217220
218221 # Build filter
219222 if query .filter :
223+ redisearch_schema = self ._schema
220224 for field , value in query .filter .__dict__ .items ():
221225 if not value :
222226 continue
223- if field in REDIS_SEARCH_SCHEMA :
224- filter_str += _typ_to_str (REDIS_SEARCH_SCHEMA [field ], field , value )
225- elif field in REDIS_SEARCH_SCHEMA ["metadata" ]:
227+ if field in redisearch_schema :
228+ filter_str += _typ_to_str (redisearch_schema [field ], field , value )
229+ elif field in redisearch_schema ["metadata" ]:
226230 if field == "source" : # handle the enum
227231 value = value .value
228232 filter_str += _typ_to_str (
229- REDIS_SEARCH_SCHEMA ["metadata" ][field ], field , value
233+ redisearch_schema ["metadata" ][field ], field , value
230234 )
231235 elif field in ["start_date" , "end_date" ]:
232236 filter_str += _typ_to_str (
233- REDIS_SEARCH_SCHEMA ["metadata" ]["created_at" ], field , value
237+ redisearch_schema ["metadata" ]["created_at" ], field , value
234238 )
235239
236240 # Postprocess filter string
0 commit comments