Skip to content

Commit 384b5a5

Browse files
authored
colbert: add lint to colbert package and tests (#340)
1 parent d375d61 commit 384b5a5

15 files changed

+317
-348
lines changed

.github/workflows/ci-unit-tests.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ jobs:
4646
poetry install --no-root -E colbert
4747
poetry build
4848
49+
- name: "Lint"
50+
run: |
51+
tox -e lint
52+
4953
- name: Run ragstack-ai unit and integration tests
50-
env:
51-
COLBERT_ASTRA_TOKEN: ${{ secrets.COLBERT_ASTRA_TOKEN }}
52-
COLBERT_ASTRA_SCB: ${{ secrets.COLBERT_ASTRA_SCB }}
5354
run: |
54-
tox
55+
tox -e tests

pyproject.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,23 @@ torch = { version = "2.2.1", optional = true }
3131
[tool.poetry.extras]
3232
langchain-google = ["langchain-google-genai", "langchain-google-vertexai"]
3333
langchain-nvidia = ["langchain-nvidia-ai-endpoints"]
34-
colbert = ["colbert-ai", "pyarrow", "torch", "cassio"]
34+
colbert = ["colbert-ai", "pyarrow", "torch"]
3535

3636
[tool.poetry.group.test.dependencies]
3737
pytest = "*"
38+
black = "*"
39+
ruff = "*"
3840
nbmake = "*"
3941
testcontainers = "^3.7.1"
4042
tox = "^4"
4143

44+
[tool.pytest.ini_options]
45+
log_cli = true
46+
log_cli_level = "INFO"
47+
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
48+
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
49+
50+
4251
[build-system]
4352
requires = ["poetry-core"]
4453
build-backend = "poetry.core.masonry.api"

ragstack/colbert/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from .colbert_embedding import ColbertTokenEmbeddings, calculate_query_maxlen
2-
from .cassandra_db import CassandraDB
1+
from .colbert_embedding import ColbertTokenEmbeddings
2+
from .cassandra_store import CassandraColBERTVectorStore
33
from .cassandra_retriever import ColbertCassandraRetriever, max_similarity_torch
44
from .token_embedding import PerTokenEmbeddings, PassageEmbeddings, TokenEmbeddings
55
from .vector_store import ColBERTVectorStore
66
from .constant import DEFAULT_COLBERT_MODEL, DEFAULT_COLBERT_DIM
77

88
__all__ = (
99
ColbertTokenEmbeddings,
10-
CassandraDB,
10+
CassandraColBERTVectorStore,
1111
ColbertCassandraRetriever,
1212
max_similarity_torch,
1313
PerTokenEmbeddings,

ragstack/colbert/cassandra_db.py

Lines changed: 0 additions & 179 deletions
This file was deleted.

ragstack/colbert/cassandra_retriever.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
from typing import List
2+
13
from .colbert_embedding import ColbertTokenEmbeddings
24

3-
from .cassandra_db import CassandraDB
5+
from .cassandra_store import CassandraColBERTVectorStore
46
import logging
57
from torch import tensor
68
import torch
79
import math
810

11+
from .vector_store import ColBERTVectorStoreRetriever, Document
12+
913
# max similarity between a query vector and a list of embeddings
1014
# The function returns the highest similarity score (i.e., the maximum dot product value)
1115
# between the query vector and any of the embedding vectors in the list.
@@ -69,31 +73,34 @@ def max_similarity_torch(query_vector, embedding_list, is_cuda: bool = False):
6973
return max_sim
7074

7175

72-
class ColbertCassandraRetriever:
73-
db: CassandraDB
74-
colbertEmbeddings: ColbertTokenEmbeddings
76+
class ColbertCassandraRetriever(ColBERTVectorStoreRetriever):
77+
vector_store: CassandraColBERTVectorStore
78+
colbert_embeddings: ColbertTokenEmbeddings
7579
is_cuda: bool = False
7680

7781
class Config:
7882
arbitrary_types_allowed = True
7983

8084
def __init__(
8185
self,
82-
db: CassandraDB,
83-
colbertEmbeddings: ColbertTokenEmbeddings,
84-
**kwargs,
86+
vector_store: CassandraColBERTVectorStore,
87+
colbert_embeddings: ColbertTokenEmbeddings,
8588
):
86-
# initialize pydantic base model
87-
self.db = db
88-
self.colbertEmbeddings = colbertEmbeddings
89+
self.vector_store = vector_store
90+
self.colbert_embeddings = colbert_embeddings
8991
self.is_cuda = torch.cuda.is_available()
9092

91-
def retrieve(self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs):
93+
def close(self):
94+
pass
95+
96+
def retrieve(
97+
self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs
98+
) -> List[Document]:
9299
#
93-
# if the query has fewer than a predefined number of of tokens Nq,
94-
# colbertEmbeddings will pad it with BERT special [mast] token up to length Nq.
100+
# if the query has fewer than a predefined number of tokens Nq,
101+
# colbert_embeddings will pad it with BERT special [mast] token up to length Nq.
95102
#
96-
query_encodings = self.colbertEmbeddings.encode_query(
103+
query_encodings = self.colbert_embeddings.encode_query(
97104
query, query_maxlen=query_maxlen
98105
)
99106

@@ -106,8 +113,8 @@ def retrieve(self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs):
106113
doc_futures = []
107114
for qv in query_encodings:
108115
# per token based retrieval
109-
doc_future = self.db.session.execute_async(
110-
self.db.query_colbert_ann_stmt, [list(qv), top_k]
116+
doc_future = self.vector_store.session.execute_async(
117+
self.vector_store.query_colbert_ann_stmt, [list(qv), top_k]
111118
)
112119
doc_futures.append(doc_future)
113120

@@ -119,8 +126,8 @@ def retrieve(self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs):
119126
scores = {}
120127
futures = []
121128
for title, part in docparts:
122-
future = self.db.session.execute_async(
123-
self.db.query_colbert_parts_stmt, [title, part]
129+
future = self.vector_store.session.execute_async(
130+
self.vector_store.query_colbert_parts_stmt, [title, part]
124131
)
125132
futures.append((future, title, part))
126133

@@ -141,23 +148,18 @@ def retrieve(self, query: str, k: int = 10, query_maxlen: int = 64, **kwargs):
141148
# query the doc body
142149
doc_futures = {}
143150
for title, part in docs_by_score:
144-
future = self.db.session.execute_async(
145-
self.db.query_part_by_pk_stmt, [title, part]
151+
future = self.vector_store.session.execute_async(
152+
self.vector_store.query_part_by_pk_stmt, [title, part]
146153
)
147154
doc_futures[(title, part)] = future
148155

149-
answers = []
156+
answers: List[Document] = []
150157
rank = 1
151158
for title, part in docs_by_score:
152159
rs = doc_futures[(title, part)].result()
153160
score = scores[(title, part)]
154161
answers.append(
155-
{
156-
"title": title,
157-
"score": score.item(),
158-
"rank": rank,
159-
"body": rs.one().body,
160-
}
162+
Document(title=title, score=score.item(), rank=rank, body=rs.one().body)
161163
)
162164
rank = rank + 1
163165
# clean up on tensor memory on GPU

0 commit comments

Comments
 (0)