@@ -1,10 +1,12 @@
-from typing import List
+from typing import List, Set, Tuple, Any, Dict
+
+from cassandra.cluster import ResponseFuture
 
 from .colbert_embedding import ColbertTokenEmbeddings
 
 from .cassandra_store import CassandraColBERTVectorStore
 import logging
-from torch import tensor
+from torch import tensor, Tensor
 import torch
 import math
 
@@ -43,7 +45,9 @@ def max_similarity_numpy_based(query_vector, embedding_list):
 
 
 # This torch-based max similarity has the best performance.
 # It is at least 20 times faster than the dot product operator and the numpy-based implementation, on both CUDA and CPU.
-def max_similarity_torch(query_vector, embedding_list, is_cuda: bool = False):
+def max_similarity_torch(
+    query_vector: Tensor, embedding_list: List[Tensor], is_cuda: bool = False
+) -> Tensor:
     """
     Calculate the maximum similarity (dot product) between a query vector and a list of embedding vectors,
     optimized for performance using PyTorch for GPU acceleration.
@@ -59,12 +63,12 @@ def max_similarity_torch(query_vector, embedding_list, is_cuda: bool = False):
     # stacks the list of embedding tensors into a single tensor
     if is_cuda:
         query_vector = query_vector.to("cuda")
-        embedding_list = torch.stack(embedding_list).to("cuda")
+        _embedding_list = torch.stack(embedding_list).to("cuda")
     else:
-        embedding_list = torch.stack(embedding_list)
+        _embedding_list = torch.stack(embedding_list)
 
     # Calculate the dot products in a vectorized manner on the GPU
-    sims = torch.matmul(embedding_list, query_vector)
+    sims = torch.matmul(_embedding_list, query_vector)
 
     # Find the maximum similarity (dot product) value
     max_sim = torch.max(sims)
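As a minimal sketch of how the newly annotated max_similarity_torch could be exercised (the 128-dimensional vectors and the list size are illustrative assumptions, and the function is assumed to be imported from this module):

    import torch

    # Illustrative shapes only; ColBERT-style token embeddings are assumed to be 1-D vectors.
    query_vector = torch.rand(128)
    embedding_list = [torch.rand(128) for _ in range(32)]

    # Returns a 0-dim Tensor: the largest dot product between the query vector and
    # any embedding in the list; .item() converts it to a plain Python float.
    max_sim = max_similarity_torch(query_vector, embedding_list, is_cuda=torch.cuda.is_available())
    print(max_sim.item())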
@@ -90,11 +94,11 @@ def __init__(
         self.colbert_embeddings = colbert_embeddings
         self.is_cuda = torch.cuda.is_available()
 
-    def close(self):
+    def close(self) -> None:
         pass
 
     def retrieve(
-        self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs
+        self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs: Any
     ) -> List[Document]:
         #
         # if the query has fewer than a predefined number of tokens Nq,
@@ -109,7 +113,7 @@ def retrieve(
         logging.debug(f"query length {len(query)} embeddings top_k: {top_k}")
 
         # find the most relevant documents
-        docparts = set()
+        docparts: Set[Tuple[Any, Any]] = set()
         doc_futures = []
         for qv in query_encodings:
             # per token based retrieval
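For readers tracing the retrieve flow, a hedged, self-contained sketch of the per-token fan-out this hunk sets up: one asynchronous ANN lookup per query-token embedding, with the ResponseFutures resolved afterwards to collect candidate (title, part) keys. The statement name ann_stmt, its parameter shape, and the row fields title/part are assumptions for illustration, not confirmed by this diff:

    from typing import Any, List, Set, Tuple

    from cassandra.cluster import ResponseFuture, Session
    from torch import Tensor

    def fan_out_ann_queries(
        session: Session, ann_stmt: Any, query_encodings: List[Tensor], top_k: int
    ) -> Set[Tuple[Any, Any]]:
        # One async ANN query per query-token embedding; execute_async returns a ResponseFuture.
        futures: List[ResponseFuture] = [
            session.execute_async(ann_stmt, [qv.tolist(), top_k]) for qv in query_encodings
        ]
        # Resolve every future and union the candidate (title, part) keys across tokens.
        docparts: Set[Tuple[Any, Any]] = set()
        for future in futures:
            docparts.update((row.title, row.part) for row in future.result())
        return docparts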
@@ -146,17 +150,17 @@ def retrieve(
         docs_by_score = sorted(scores, key=scores.get, reverse=True)[:k]
 
         # query the doc body
-        doc_futures = {}
+        doc_futures2: Dict[Tuple[Any, Any], ResponseFuture] = {}
         for title, part in docs_by_score:
             future = self.vector_store.session.execute_async(
                 self.vector_store.query_part_by_pk_stmt, [title, part]
             )
-            doc_futures[(title, part)] = future
+            doc_futures2[(title, part)] = future
 
         answers: List[Document] = []
         rank = 1
         for title, part in docs_by_score:
-            rs = doc_futures[(title, part)].result()
+            rs = doc_futures2[(title, part)].result()
             score = scores[(title, part)]
             answers.append(
                 Document(title=title, score=score.item(), rank=rank, body=rs.one().body)
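The hunks above only add annotations and rename doc_futures, but the scores dict they consume embodies ColBERT's MaxSim scoring: each candidate (title, part) is scored by summing, over the query-token embeddings, the maximum dot product against that part's token embeddings. A minimal sketch of that aggregation, with the doc_embeddings mapping and embedding shapes as illustrative assumptions:

    from typing import Dict, List, Tuple

    import torch
    from torch import Tensor

    def score_documents(
        query_encodings: List[Tensor],
        doc_embeddings: Dict[Tuple[str, int], List[Tensor]],
    ) -> Dict[Tuple[str, int], Tensor]:
        # Sum of per-query-token maximum similarities (MaxSim) for each candidate part.
        scores: Dict[Tuple[str, int], Tensor] = {}
        for key, embeddings in doc_embeddings.items():
            stacked = torch.stack(embeddings)            # (num_doc_tokens, dim)
            total = torch.tensor(0.0)
            for qv in query_encodings:
                total = total + torch.max(stacked @ qv)  # best match for this query token
            scores[key] = total
        return scores

Keeping the score values as 0-dim tensors is consistent with the score.item() call used when building the Document answers above.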