Skip to content

Commit 0ef1fa5

Browse files
feat: Add chat store init methods (#39)
Co-authored-by: Averi Kitsch <akitsch@google.com>
1 parent 6ce6ba1 commit 0ef1fa5

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed

src/llama_index_cloud_sql_pg/engine.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,91 @@ def init_index_store_table(
756756
)
757757
)
758758

759+
async def _ainit_chat_store_table(
760+
self,
761+
table_name: str,
762+
schema_name: str = "public",
763+
overwrite_existing: bool = False,
764+
) -> None:
765+
"""
766+
Create an table to save chat store.
767+
Args:
768+
table_name (str): The table name to store chat history.
769+
schema_name (str): The schema name to store the chat store table.
770+
Default: "public".
771+
overwrite_existing (bool): Whether to drop existing table.
772+
Default: False.
773+
Returns:
774+
None
775+
"""
776+
if overwrite_existing:
777+
async with self._pool.connect() as conn:
778+
await conn.execute(
779+
text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"')
780+
)
781+
await conn.commit()
782+
783+
create_table_query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
784+
id SERIAL PRIMARY KEY,
785+
key VARCHAR NOT NULL,
786+
message JSON NOT NULL
787+
);"""
788+
create_index_query = f"""CREATE INDEX "{table_name}_idx_key" ON "{schema_name}"."{table_name}" (key);"""
789+
async with self._pool.connect() as conn:
790+
await conn.execute(text(create_table_query))
791+
await conn.execute(text(create_index_query))
792+
await conn.commit()
793+
794+
async def ainit_chat_store_table(
795+
self,
796+
table_name: str,
797+
schema_name: str = "public",
798+
overwrite_existing: bool = False,
799+
) -> None:
800+
"""
801+
Create an table to save chat store.
802+
Args:
803+
table_name (str): The table name to store chat store.
804+
schema_name (str): The schema name to store the chat store table.
805+
Default: "public".
806+
overwrite_existing (bool): Whether to drop existing table.
807+
Default: False.
808+
Returns:
809+
None
810+
"""
811+
await self._run_as_async(
812+
self._ainit_chat_store_table(
813+
table_name,
814+
schema_name,
815+
overwrite_existing,
816+
)
817+
)
818+
819+
def init_chat_store_table(
820+
self,
821+
table_name: str,
822+
schema_name: str = "public",
823+
overwrite_existing: bool = False,
824+
) -> None:
825+
"""
826+
Create an table to save chat store.
827+
Args:
828+
table_name (str): The table name to store chat store.
829+
schema_name (str): The schema name to store the chat store table.
830+
Default: "public".
831+
overwrite_existing (bool): Whether to drop existing table.
832+
Default: False.
833+
Returns:
834+
None
835+
"""
836+
self._run_as_sync(
837+
self._ainit_chat_store_table(
838+
table_name,
839+
schema_name,
840+
overwrite_existing,
841+
)
842+
)
843+
759844
async def _aload_table_schema(
760845
self, table_name: str, schema_name: str = "public"
761846
) -> Table:

tests/test_engine.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
DEFAULT_IS_TABLE_SYNC = "index_store_" + str(uuid.uuid4())
3535
DEFAULT_VS_TABLE = "vector_store_" + str(uuid.uuid4())
3636
DEFAULT_VS_TABLE_SYNC = "vector_store_" + str(uuid.uuid4())
37+
DEFAULT_CS_TABLE = "chat_store_" + str(uuid.uuid4())
38+
DEFAULT_CS_TABLE_SYNC = "chat_store_" + str(uuid.uuid4())
3739
VECTOR_SIZE = 768
3840

3941

@@ -113,6 +115,7 @@ async def engine(self, db_project, db_region, db_instance, db_name):
113115
await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE}"')
114116
await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE}"')
115117
await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE}"')
118+
await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE}"')
116119
await engine.close()
117120

118121
async def test_password(
@@ -296,6 +299,22 @@ async def test_init_index_store(self, engine):
296299
for row in results:
297300
assert row in expected
298301

302+
async def test_init_chat_store(self, engine):
303+
await engine.ainit_chat_store_table(
304+
table_name=DEFAULT_CS_TABLE,
305+
schema_name="public",
306+
overwrite_existing=True,
307+
)
308+
stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE}';"
309+
results = await afetch(engine, stmt)
310+
expected = [
311+
{"column_name": "id", "data_type": "integer"},
312+
{"column_name": "key", "data_type": "character varying"},
313+
{"column_name": "message", "data_type": "json"},
314+
]
315+
for row in results:
316+
assert row in expected
317+
299318

300319
@pytest.mark.asyncio
301320
class TestEngineSync:
@@ -343,6 +362,7 @@ async def engine(self, db_project, db_region, db_instance, db_name):
343362
await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE_SYNC}"')
344363
await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE_SYNC}"')
345364
await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE_SYNC}"')
365+
await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE_SYNC}"')
346366
await engine.close()
347367

348368
async def test_password(
@@ -461,3 +481,19 @@ async def test_init_index_store(self, engine):
461481
]
462482
for row in results:
463483
assert row in expected
484+
485+
async def test_init_chat_store(self, engine):
486+
engine.init_chat_store_table(
487+
table_name=DEFAULT_CS_TABLE_SYNC,
488+
schema_name="public",
489+
overwrite_existing=True,
490+
)
491+
stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE_SYNC}';"
492+
results = await afetch(engine, stmt)
493+
expected = [
494+
{"column_name": "id", "data_type": "integer"},
495+
{"column_name": "key", "data_type": "character varying"},
496+
{"column_name": "message", "data_type": "json"},
497+
]
498+
for row in results:
499+
assert row in expected

0 commit comments

Comments
 (0)