@@ -1,11 +1,15 @@
+from typing import List
+
 from .colbert_embedding import ColbertTokenEmbeddings

-from .cassandra_db import CassandraDB
+from .cassandra_store import CassandraColBERTVectorStore
 import logging
 from torch import tensor
 import torch
 import math

+from .vector_store import ColBERTVectorStoreRetriever, Document
+
 # max similarity between a query vector and a list of embeddings
 # The function returns the highest similarity score (i.e., the maximum dot product value)
 # between the query vector and any of the embedding vectors in the list.
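
The body of max_similarity_torch is folded out of the diff; only its final return survives in the next hunk. Here is a minimal sketch of a function matching this comment and the signature in the hunk header below, assuming the query vector and each embedding are 1-D torch tensors:

    import torch

    def max_similarity_sketch(query_vector, embedding_list, is_cuda: bool = False):
        # Stack the candidate embeddings into an (n, dim) matrix.
        embeddings = torch.stack(list(embedding_list))
        if is_cuda:
            query_vector = query_vector.cuda()
            embeddings = embeddings.cuda()
        # One matrix-vector product yields all dot products at once;
        # the largest one is the per-token MaxSim score.
        return torch.max(embeddings @ query_vector)
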
@@ -69,31 +73,34 @@ def max_similarity_torch(query_vector, embedding_list, is_cuda: bool = False):
     return max_sim


-class ColbertCassandraRetriever:
-    db: CassandraDB
-    colbertEmbeddings: ColbertTokenEmbeddings
+class ColbertCassandraRetriever(ColBERTVectorStoreRetriever):
+    vector_store: CassandraColBERTVectorStore
+    colbert_embeddings: ColbertTokenEmbeddings
     is_cuda: bool = False

     class Config:
         arbitrary_types_allowed = True

     def __init__(
         self,
-        db: CassandraDB,
-        colbertEmbeddings: ColbertTokenEmbeddings,
-        **kwargs,
+        vector_store: CassandraColBERTVectorStore,
+        colbert_embeddings: ColbertTokenEmbeddings,
     ):
-        # initialize pydantic base model
-        self.db = db
-        self.colbertEmbeddings = colbertEmbeddings
+        self.vector_store = vector_store
+        self.colbert_embeddings = colbert_embeddings
         self.is_cuda = torch.cuda.is_available()

-    def retrieve(self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs):
+    def close(self):
+        pass
+
+    def retrieve(
+        self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs
+    ) -> List[Document]:
         #
-        # if the query has fewer than a predefined number of of tokens Nq,
-        # colbertEmbeddings will pad it with BERT special [mast] token up to length Nq.
+        # if the query has fewer than a predefined number of tokens Nq,
+        # colbert_embeddings will pad it with BERT special [MASK] token up to length Nq.
         #
-        query_encodings = self.colbertEmbeddings.encode_query(
+        query_encodings = self.colbert_embeddings.encode_query(
             query, query_maxlen=query_maxlen
         )

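
For illustration only: the padding described in the comment above can be mimicked with a plain Hugging Face BERT tokenizer (the PR's ColbertTokenEmbeddings internals are not shown in this diff, so this is an assumption about its behavior, not its code):

    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    ids = tokenizer("what is colbert", add_special_tokens=True)["input_ids"]
    query_maxlen = 16
    # Short queries are topped up with [MASK] tokens, which serve as
    # query-augmentation slots during ColBERT encoding.
    ids = ids + [tokenizer.mask_token_id] * max(0, query_maxlen - len(ids))
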
@@ -106,8 +113,8 @@ def retrieve(self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs):
         doc_futures = []
         for qv in query_encodings:
             # per token based retrieval
-            doc_future = self.db.session.execute_async(
-                self.db.query_colbert_ann_stmt, [list(qv), top_k]
+            doc_future = self.vector_store.session.execute_async(
+                self.vector_store.query_colbert_ann_stmt, [list(qv), top_k]
             )
             doc_futures.append(doc_future)

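
The step that turns these per-token futures into the docparts set iterated below is elided between hunks. A hedged reconstruction, assuming each ANN result row exposes title and part columns:

    # Collect the distinct (title, part) keys returned by the
    # per-token ANN queries; the row attribute names are assumptions.
    docparts = set()
    for future in doc_futures:
        for row in future.result():
            docparts.add((row.title, row.part))
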
@@ -119,8 +126,8 @@ def retrieve(self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs):
         scores = {}
         futures = []
         for title, part in docparts:
-            future = self.db.session.execute_async(
-                self.db.query_colbert_parts_stmt, [title, part]
+            future = self.vector_store.session.execute_async(
+                self.vector_store.query_colbert_parts_stmt, [title, part]
             )
             futures.append((future, title, part))

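
The scoring loop that consumes these futures is also elided. A sketch of ColBERT late interaction under the same assumptions, with bert_embedding as an assumed column name holding one token vector per row:

    # For each candidate part, sum over query tokens the best dot product
    # against that part's token embeddings, then keep the top-k parts.
    for future, title, part in futures:
        embeddings = [tensor(row.bert_embedding) for row in future.result()]
        scores[(title, part)] = sum(
            max_similarity_torch(qv, embeddings, self.is_cuda) for qv in query_encodings
        )
    docs_by_score = sorted(scores, key=scores.get, reverse=True)[:k]
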
@@ -141,23 +148,18 @@ def retrieve(self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs):
         # query the doc body
         doc_futures = {}
         for title, part in docs_by_score:
-            future = self.db.session.execute_async(
-                self.db.query_part_by_pk_stmt, [title, part]
+            future = self.vector_store.session.execute_async(
+                self.vector_store.query_part_by_pk_stmt, [title, part]
             )
             doc_futures[(title, part)] = future

-        answers = []
+        answers: List[Document] = []
         rank = 1
         for title, part in docs_by_score:
             rs = doc_futures[(title, part)].result()
             score = scores[(title, part)]
             answers.append(
-                {
-                    "title": title,
-                    "score": score.item(),
-                    "rank": rank,
-                    "body": rs.one().body,
-                }
+                Document(title=title, score=score.item(), rank=rank, body=rs.one().body)
             )
             rank = rank + 1
         # clean up on tensor memory on GPU
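
A possible end-to-end use of the new retriever follows; the constructor arguments for the store and the embeddings are assumptions, since neither signature appears in this diff:

    embeddings = ColbertTokenEmbeddings()           # assumed default constructor
    store = CassandraColBERTVectorStore(session)    # assumed to wrap a Cassandra session
    retriever = ColbertCassandraRetriever(store, embeddings)
    for doc in retriever.retrieve("who created colbert?", k=5):
        # Document fields come straight from the diff: title, score, rank, body.
        print(doc.rank, round(doc.score, 3), doc.title)
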