
Commit b7dc91b

Add mypy for type checking
1 parent c90dd27 commit b7dc91b

9 files changed: +74 additions, −64 deletions

pyproject.toml

Lines changed: 16 additions & 3 deletions
@@ -47,13 +47,26 @@ log_cli_level = "INFO"
 log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
 log_cli_date_format = "%Y-%m-%d %H:%M:%S"

+[tool.mypy]
+disallow_any_generics = true
+disallow_incomplete_defs = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+follow_imports = "normal"
+ignore_missing_imports = true
+no_implicit_reexport = true
+show_error_codes = true
+show_error_context = true
+strict_equality = true
+strict_optional = true
+warn_redundant_casts = true
+warn_return_any = true
+warn_unused_ignores = true

 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"

-
-
-
 [tool.poetry.group.dev.dependencies]
 yamllint = "^1.34.0"
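In practice these settings put mypy into a near-strict mode. As a minimal sketch of what two of the flags enforce (the functions below are hypothetical, not code from this repository): disallow_untyped_defs rejects any function that lacks parameter and return annotations, and disallow_any_generics rejects bare generic types such as dict used without type parameters.

from typing import Dict, List


# Rejected under disallow_untyped_defs: no parameter or return annotations.
def count_tokens(texts):
    return sum(len(t.split()) for t in texts)


# Rejected under disallow_any_generics: bare `dict` has no type parameters.
def index_titles_bad(titles: List[str]) -> dict:
    return {t: i for i, t in enumerate(titles)}


# Accepted: every argument, return type, and generic is fully parameterized.
def index_titles(titles: List[str]) -> Dict[str, int]:
    return {t: i for i, t in enumerate(titles)}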

ragstack/colbert/__init__.py

Lines changed: 12 additions & 12 deletions
@@ -5,15 +5,15 @@
 from .vector_store import ColBERTVectorStore
 from .constant import DEFAULT_COLBERT_MODEL, DEFAULT_COLBERT_DIM

-__all__ = (
-    ColbertTokenEmbeddings,
-    CassandraColBERTVectorStore,
-    ColbertCassandraRetriever,
-    max_similarity_torch,
-    PerTokenEmbeddings,
-    PassageEmbeddings,
-    TokenEmbeddings,
-    ColBERTVectorStore,
-    DEFAULT_COLBERT_MODEL,
-    DEFAULT_COLBERT_DIM,
-)
+__all__ = [
+    "ColbertTokenEmbeddings",
+    "CassandraColBERTVectorStore",
+    "ColbertCassandraRetriever",
+    "max_similarity_torch",
+    "PerTokenEmbeddings",
+    "PassageEmbeddings",
+    "TokenEmbeddings",
+    "ColBERTVectorStore",
+    "DEFAULT_COLBERT_MODEL",
+    "DEFAULT_COLBERT_DIM",
+]
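One reason for the switch to string literals (a sketch; the importing module below is hypothetical): with no_implicit_reexport = true, names imported into a package __init__.py are only treated as re-exported when they are listed in __all__, and mypy only recognizes __all__ entries that are string literals.

# Hypothetical downstream code, shown only to illustrate the re-export rule.
from ragstack.colbert import ColbertTokenEmbeddings, max_similarity_torch

# Under no_implicit_reexport these imports type-check because both names are
# listed as strings in the package's __all__; the old tuple of bare objects
# would not satisfy mypy, nor `from ragstack.colbert import *` at runtime.
print(ColbertTokenEmbeddings, max_similarity_torch)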

ragstack/colbert/cassandra_retriever.py

Lines changed: 16 additions & 12 deletions
@@ -1,10 +1,12 @@
-from typing import List
+from typing import List, Set, Tuple, Any, Dict
+
+from cassandra.cluster import ResponseFuture

 from .colbert_embedding import ColbertTokenEmbeddings

 from .cassandra_store import CassandraColBERTVectorStore
 import logging
-from torch import tensor
+from torch import tensor, Tensor
 import torch
 import math

@@ -43,7 +45,9 @@ def max_similarity_numpy_based(query_vector, embedding_list):

 # this torch based max similary has the best performance.
 # it is at least 20 times faster than dot product operator and numpy based implementation CuDA and CPU
-def max_similarity_torch(query_vector, embedding_list, is_cuda: bool = False):
+def max_similarity_torch(
+    query_vector: Tensor, embedding_list: List[Tensor], is_cuda: bool = False
+) -> Tensor:
     """
     Calculate the maximum similarity (dot product) between a query vector and a list of embedding vectors,
     optimized for performance using PyTorch for GPU acceleration.
@@ -59,12 +63,12 @@ def max_similarity_torch(query_vector, embedding_list, is_cuda: bool = False):
     # stacks the list of embedding tensors into a single tensor
     if is_cuda:
         query_vector = query_vector.to("cuda")
-        embedding_list = torch.stack(embedding_list).to("cuda")
+        _embedding_list = torch.stack(embedding_list).to("cuda")
     else:
-        embedding_list = torch.stack(embedding_list)
+        _embedding_list = torch.stack(embedding_list)

     # Calculate the dot products in a vectorized manner on the GPU
-    sims = torch.matmul(embedding_list, query_vector)
+    sims = torch.matmul(_embedding_list, query_vector)

     # Find the maximum similarity (dot product) value
     max_sim = torch.max(sims)
@@ -90,11 +94,11 @@ def __init__(
         self.colbert_embeddings = colbert_embeddings
         self.is_cuda = torch.cuda.is_available()

-    def close(self):
+    def close(self) -> None:
         pass

     def retrieve(
-        self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs
+        self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs: Any
     ) -> List[Document]:
         #
         # if the query has fewer than a predefined number of tokens Nq,
@@ -109,7 +113,7 @@ def retrieve(
         logging.debug(f"query length {len(query)} embeddings top_k: {top_k}")

         # find the most relevant documents
-        docparts = set()
+        docparts: Set[Tuple[Any, Any]] = set()
         doc_futures = []
         for qv in query_encodings:
             # per token based retrieval
@@ -146,17 +150,17 @@ def retrieve(
         docs_by_score = sorted(scores, key=scores.get, reverse=True)[:k]

         # query the doc body
-        doc_futures = {}
+        doc_futures2: Dict[Tuple[Any, Any], ResponseFuture] = {}
         for title, part in docs_by_score:
             future = self.vector_store.session.execute_async(
                 self.vector_store.query_part_by_pk_stmt, [title, part]
             )
-            doc_futures[(title, part)] = future
+            doc_futures2[(title, part)] = future

         answers: List[Document] = []
         rank = 1
         for title, part in docs_by_score:
-            rs = doc_futures[(title, part)].result()
+            rs = doc_futures2[(title, part)].result()
             score = scores[(title, part)]
             answers.append(
                 Document(title=title, score=score.item(), rank=rank, body=rs.one().body)
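As a quick usage sketch of the newly typed helper (the tensors below are made up; 128 is assumed as the ColBERT embedding dimension, matching DEFAULT_COLBERT_DIM): max_similarity_torch takes a query Tensor and a List[Tensor] of token embeddings and returns a 0-dimensional Tensor holding the best dot product.

import torch

from ragstack.colbert import max_similarity_torch

# Stand-in vectors; real callers pass ColBERT per-token embeddings.
query_vector = torch.rand(128)
embedding_list = [torch.rand(128) for _ in range(8)]

# Moves to CUDA only when available, mirroring the retriever's is_cuda flag.
max_sim = max_similarity_torch(
    query_vector, embedding_list, is_cuda=torch.cuda.is_available()
)
print(float(max_sim))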

ragstack/colbert/cassandra_store.py

Lines changed: 3 additions & 3 deletions
@@ -65,7 +65,7 @@ def __init__(self, session: Session, keyspace: str, table_name: str):
            """
        )

-    def __create_tables(self):
+    def __create_tables(self) -> None:
        self.session.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.full_table_name} (
@@ -112,10 +112,10 @@ def put_document(
    ) -> None:
        return self.insert_colbert_embeddings_chunks(embeddings, delete_existed_passage)

-    def delete_documents(self, titles: List[str]):
+    def delete_documents(self, titles: List[str]) -> None:
        execute_concurrent_with_args(
            self.session, self.delete_part_by_title_stmt, [(t,) for t in titles]
        )

-    def close(self):
+    def close(self) -> None:
        pass

ragstack/colbert/colbert_embedding.py

Lines changed: 2 additions & 6 deletions
@@ -106,10 +106,6 @@ def __init__(
            query_maxlen=query_maxlen,
            gpus=total_visible_gpus,
        )
-        self.__doc_maxlen = doc_maxlen
-        self.__nbits = nbits
-        self.__kmeans_niters = kmeans_niters
-        self.__nranks = nranks
        logging.info("creating checkpoint")
        self.checkpoint = Checkpoint(
            self.colbert_config.checkpoint, colbert_config=self.colbert_config
@@ -148,7 +144,7 @@ def encode_queries(
        # the length does not grow or shrink despite the number of tokens in the query
        # we continue to use the same term to align with ColBERT documentation/library
        query_maxlen: int = -1,
-    ):
+    ) -> Tensor:
        queries = query if isinstance(query, list) else [query]
        bsize = 128 if len(queries) > 128 else None

@@ -179,7 +175,7 @@ def encode_query(
        query: str,
        full_length_search: bool = False,
        query_maxlen: int = 32,
-    ):
+    ) -> Tensor:
        queries = self.encode_queries(
            query, full_length_search, query_maxlen=query_maxlen
        )
ragstack/colbert/langchain/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 from .retriever import ColBERTVectorStoreLangChainRetriever

-__all__ = (ColBERTVectorStoreLangChainRetriever,)
+__all__ = ["ColBERTVectorStoreLangChainRetriever"]

ragstack/colbert/langchain/retriever.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ class ColBERTVectorStoreLangChainRetriever(BaseRetriever):
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
    qa.run("what happened on June 4th?")
    """
+
    retriever: ColBERTVectorStoreRetriever = Field(default=None)
    kwargs: dict = {}
    k: int = 10

ragstack/colbert/token_embedding.py

Lines changed: 17 additions & 17 deletions
@@ -2,7 +2,7 @@
 # this is a base class for ColBERT per token based embedding

 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional
 from .constant import DEFAULT_COLBERT_DIM, DEFAULT_COLBERT_MODEL
 import uuid

@@ -14,7 +14,7 @@ def __init__(
        self,
        id: int,
        part: int,
-        parent_id: uuid.UUID = None,
+        parent_id: Optional[uuid.UUID] = None,
        title: str = "",
    ):
        self.id = id
@@ -23,19 +23,19 @@ def __init__(
        self.title = title
        self.part = part

-    def add_embeddings(self, embeddings: List[float]):
+    def add_embeddings(self, embeddings: List[float]) -> None:
        self.__embeddings = embeddings

    def get_embeddings(self) -> List[float]:
        return self.__embeddings

-    def id(self):
+    def id(self) -> int:
        return self.id

-    def parent_id(self):
+    def parent_id(self) -> Optional[uuid.UUID]:
        return self.parent_id

-    def part(self):
+    def part(self) -> int:
        return self.part


@@ -50,7 +50,7 @@ def __init__(
        text: str,
        title: str = "",
        part: int = 0,
-        id: uuid.UUID = None,
+        id: Optional[uuid.UUID] = None,
        model: str = DEFAULT_COLBERT_MODEL,
        dim: int = DEFAULT_COLBERT_DIM,
    ):
@@ -65,31 +65,31 @@ def __init__(
        self.__title = title
        self.__part = part

-    def model(self):
+    def model(self) -> str:
        return self.__model

-    def dim(self):
+    def dim(self) -> int:
        return self.__dim

-    def token_size(self):
+    def token_size(self) -> int:
        return len(self.token_ids)

-    def title(self):
+    def title(self) -> str:
        return self.__title

-    def __len__(self):
+    def __len__(self) -> int:
        return len(self.embeddings)

-    def id(self):
+    def id(self) -> uuid.UUID:
        return self.__id

-    def part(self):
+    def part(self) -> int:
        return self.__part

-    def add_token_embeddings(self, token_embeddings: PerTokenEmbeddings):
+    def add_token_embeddings(self, token_embeddings: PerTokenEmbeddings) -> None:
        self.__token_embeddings.append(token_embeddings)

-    def get_token_embeddings(self, token_id: int) -> PerTokenEmbeddings:
+    def get_token_embeddings(self, token_id: int) -> Optional[PerTokenEmbeddings]:
        for token in self.__token_embeddings:
            if token.token_id == token_id:
                return token
@@ -98,7 +98,7 @@ def get_token_embeddings(self, token_id: int) -> PerTokenEmbeddings:
    def get_all_token_embeddings(self) -> List[PerTokenEmbeddings]:
        return self.__token_embeddings

-    def get_text(self):
+    def get_text(self) -> str:
        return self.__text

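The Optional changes above follow directly from the stricter settings; a minimal sketch with a hypothetical function (not code from this commit) shows the pattern: under strict_optional, and with implicit Optional disabled by default in current mypy, a parameter annotated uuid.UUID may not default to None unless the annotation is widened to Optional[uuid.UUID].

import uuid
from typing import Optional


# Rejected: mypy reports an incompatible default, because None is not a uuid.UUID.
def tag_chunk_bad(chunk_id: int, parent_id: uuid.UUID = None) -> str:
    return f"{parent_id}:{chunk_id}"


# Accepted: the None default is covered by Optional[uuid.UUID].
def tag_chunk(chunk_id: int, parent_id: Optional[uuid.UUID] = None) -> str:
    return f"{parent_id}:{chunk_id}"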

ragstack/colbert/vector_store.py

Lines changed: 6 additions & 10 deletions
@@ -5,24 +5,22 @@
 import dataclasses
 from abc import ABC, abstractmethod
 from numbers import Number
-from typing import List, Optional
+from typing import List, Optional, Any


 class ColBERTVectorStore(ABC):
     """Interface for a vector store."""

     @abstractmethod
-    def close(self):
+    def close(self) -> None:
         """Close the store."""
-        pass

     @abstractmethod
-    def put_document(self, document: str, metadata: dict):
+    def put_document(self, document: str, metadata: dict) -> None:
         """Put a document into the store."""
-        pass

     @abstractmethod
-    def delete_documents(self, titles: List[str]):
+    def delete_documents(self, titles: List[str]) -> None:
         """Delete a document from the store."""
         pass

@@ -37,13 +35,11 @@ class Document:

 class ColBERTVectorStoreRetriever(ABC):
     @abstractmethod
-    def close(self):
+    def close(self) -> None:
         """Close the store."""
-        pass

     @abstractmethod
     def retrieve(
-        self, query: str, k: Optional[int], query_maxlen: Optional[int], **kwargs
+        self, query: str, k: Optional[int], query_maxlen: Optional[int], **kwargs: Any
     ) -> List[Document]:
         """Retrieve documents from the store"""
-        pass
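For reference, a sketch (not part of this commit; the class and its storage are hypothetical) of what a subclass of the retriever interface looks like once the abstract methods carry -> None, -> List[Document], and typed **kwargs: Any:

from typing import Any, List, Optional

from ragstack.colbert.vector_store import ColBERTVectorStoreRetriever, Document


class InMemoryRetriever(ColBERTVectorStoreRetriever):
    """Toy retriever kept only to illustrate the typed interface."""

    def __init__(self, docs: List[Document]) -> None:
        self._docs = docs

    def close(self) -> None:
        self._docs = []

    def retrieve(
        self, query: str, k: Optional[int], query_maxlen: Optional[int], **kwargs: Any
    ) -> List[Document]:
        # Ignores query and query_maxlen; simply returns the first k stored documents.
        return self._docs[: (k or 10)]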
