Skip to content

Commit 5b80694

Browse files
authored
feat: adding search functions and tests (#56)
1 parent 2e30b48 commit 5b80694

File tree

10 files changed

+684
-25
lines changed

10 files changed

+684
-25
lines changed

integration.cloudbuild.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
steps:
1616
- id: Install dependencies
1717
name: python:${_VERSION}
18-
entrypoint: pip
19-
args: ["install", "--user", "-r", "requirements.txt"]
18+
entrypoint: /bin/bash
19+
args:
20+
- -c
21+
- |
22+
if [[ $_VERSION == "3.8" ]]; then version="-3.8"; fi
23+
pip install --user -r requirements${version}.txt
2024
2125
- id: Install module (and test requirements)
2226
name: python:${_VERSION}

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ authors = [
1111
dependencies = [
1212
"langchain-core>=0.1.1, <1.0.0",
1313
"langchain-community>=0.0.18, <1.0.0",
14+
"numpy>=1.24.4, <2.0.0",
1415
"SQLAlchemy>=2.0.7, <3.0.0",
1516
"cloud-sql-python-connector[pymysql]>=1.7.0, <2.0.0"
1617
]

requirements-3.8.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
langchain==0.1.12
2+
langchain-community==0.0.28
3+
numpy==1.24.4
4+
SQLAlchemy==2.0.28
5+
cloud-sql-python-connector[pymysql]==1.8.0

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
langchain==0.1.12
22
langchain-community==0.0.28
3+
numpy==1.26.4
34
SQLAlchemy==2.0.28
45
cloud-sql-python-connector[pymysql]==1.8.0
56

src/langchain_google_cloud_sql_mysql/engine.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,13 @@ def _fetch(self, query: str, params: Optional[dict] = None):
241241
result_fetch = result_map.fetchall()
242242
return result_fetch
243243

244+
def _fetch_rows(self, query: str, params: Optional[dict] = None):
245+
"""Fetch results from a SQL query as rows."""
246+
with self.engine.connect() as conn:
247+
result = conn.execute(sqlalchemy.text(query), params)
248+
result_fetch = result.fetchall() # Directly fetch rows
249+
return result_fetch
250+
244251
def init_chat_history_table(self, table_name: str) -> None:
245252
"""Create table with schema required for MySQLChatMessageHistory class.
246253

src/langchain_google_cloud_sql_mysql/indexes.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,33 @@ class SearchType(Enum):
2929
ANN = "ANN"
3030

3131

32+
class DistanceMeasure(Enum):
33+
"""Enumerates the types of distance measures that can be used in searches.
34+
35+
Attributes:
36+
COSINE: Cosine similarity measure.
37+
L2_SQUARED: Squared L2 norm (Euclidean) distance.
38+
DOT_PRODUCT: Dot product similarity.
39+
"""
40+
41+
COSINE = "cosine"
42+
L2_SQUARED = "l2_squared"
43+
DOT_PRODUCT = "dot_product"
44+
45+
3246
@dataclass
3347
class QueryOptions:
3448
"""Holds configuration options for executing a search query.
3549
3650
Attributes:
3751
num_partitions (Optional[int]): The number of partitions to divide the search space into. None means default partitioning.
38-
num_neighbors (Optional[int]): The number of nearest neighbors to retrieve. None means use the default.
52+
num_neighbors (int): The number of nearest neighbors to retrieve. Defaults to 10.
3953
search_type (SearchType): The type of search algorithm to use. Defaults to KNN.
4054
"""
4155

4256
num_partitions: Optional[int] = None
43-
num_neighbors: Optional[int] = None
57+
num_neighbors: int = 10
58+
distance_measure: DistanceMeasure = DistanceMeasure.L2_SQUARED
4459
search_type: SearchType = SearchType.KNN
4560

4661

@@ -61,20 +76,6 @@ class IndexType(Enum):
6176
TREE_SQ = "TREE_SQ"
6277

6378

64-
class DistanceMeasure(Enum):
65-
"""Enumerates the types of distance measures that can be used in searches.
66-
67-
Attributes:
68-
COSINE: Cosine similarity measure.
69-
SQUARED_L2: Squared L2 norm (Euclidean) distance.
70-
DOT_PRODUCT: Dot product similarity.
71-
"""
72-
73-
COSINE = "cosine"
74-
SQUARED_L2 = "squared_l2"
75-
DOT_PRODUCT = "dot_product"
76-
77-
7879
class VectorIndex:
7980
"""Represents a vector index for storing and querying vectors.
8081

src/langchain_google_cloud_sql_mysql/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _parse_doc_from_row(
2929
content_columns: Iterable[str],
3030
metadata_columns: Iterable[str],
3131
row: Dict,
32-
metadata_json_column: str = DEFAULT_METADATA_COL,
32+
metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
3333
) -> Document:
3434
page_content = " ".join(
3535
str(row[column]) for column in content_columns if column in row

0 commit comments

Comments
 (0)