Skip to content

Commit b93bde9

Browse files
feat: add MSSQLChatMessageHistory class (#9)
1 parent 3aef1c5 commit b93bde9

File tree

4 files changed

+247
-1
lines changed

4 files changed

+247
-1
lines changed

src/langchain_google_cloud_sql_mssql/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from langchain_google_cloud_sql_mssql.mssql_chat_message_history import (
16+
MSSQLChatMessageHistory,
17+
)
1518
from langchain_google_cloud_sql_mssql.mssql_engine import MSSQLEngine
1619
from langchain_google_cloud_sql_mssql.mssql_loader import MSSQLLoader
1720

18-
__all__ = ["MSSQLEngine", "MSSQLLoader"]
21+
__all__ = ["MSSQLChatMessageHistory", "MSSQLEngine", "MSSQLLoader"]
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import json
15+
from typing import List
16+
17+
import sqlalchemy
18+
from langchain_core.chat_history import BaseChatMessageHistory
19+
from langchain_core.messages import BaseMessage, messages_from_dict
20+
21+
from langchain_google_cloud_sql_mssql.mssql_engine import MSSQLEngine
22+
23+
24+
class MSSQLChatMessageHistory(BaseChatMessageHistory):
25+
"""Chat message history stored in a Cloud SQL MSSQL database.
26+
27+
Args:
28+
engine (MSSQLEngine): SQLAlchemy connection pool engine for managing
29+
connections to Cloud SQL for SQL Server.
30+
session_id (str): Arbitrary key that is used to store the messages
31+
of a single chat session.
32+
table_name (str): The name of the table to use for storing/retrieving
33+
the chat message history.
34+
"""
35+
36+
def __init__(
37+
self,
38+
engine: MSSQLEngine,
39+
session_id: str,
40+
table_name: str,
41+
) -> None:
42+
self.engine = engine
43+
self.session_id = session_id
44+
self.table_name = table_name
45+
self._verify_schema()
46+
47+
def _verify_schema(self) -> None:
48+
"""Verify table exists with required schema for MSSQLChatMessageHistory class.
49+
50+
Use helper method MSSQLEngine.create_chat_history_table(...) to create
51+
table with valid schema.
52+
"""
53+
insp = sqlalchemy.inspect(self.engine.engine)
54+
# check table exists
55+
if insp.has_table(self.table_name):
56+
# check that all required columns are present
57+
required_columns = ["id", "session_id", "data", "type"]
58+
column_names = [
59+
c["name"] for c in insp.get_columns(table_name=self.table_name)
60+
]
61+
if not (all(x in column_names for x in required_columns)):
62+
raise IndexError(
63+
f"Table '{self.table_name}' has incorrect schema. Got "
64+
f"column names '{column_names}' but required column names "
65+
f"'{required_columns}'.\nPlease create table with following schema:"
66+
f"\nCREATE TABLE {self.table_name} ("
67+
"\n id INT IDENTITY(1,1) PRIMARY KEY,"
68+
"\n session_id NVARCHAR(MAX) NOT NULL,"
69+
"\n data NVARCHAR(MAX) NOT NULL,"
70+
"\n type NVARCHAR(MAX) NOT NULL"
71+
"\n);"
72+
)
73+
else:
74+
raise AttributeError(
75+
f"Table '{self.table_name}' does not exist. Please create "
76+
"it before initializing MSSQLChatMessageHistory. See "
77+
"MSSQLEngine.create_chat_history_table() for a helper method."
78+
)
79+
80+
@property
81+
def messages(self) -> List[BaseMessage]: # type: ignore
82+
"""Retrieve the messages from Cloud SQL"""
83+
query = f'SELECT data, type FROM "{self.table_name}" WHERE session_id = :session_id ORDER BY id;'
84+
with self.engine.connect() as conn:
85+
results = conn.execute(
86+
sqlalchemy.text(query), {"session_id": self.session_id}
87+
).fetchall()
88+
# load SQLAlchemy row objects into dicts
89+
items = [{"data": json.loads(r[0]), "type": r[1]} for r in results]
90+
messages = messages_from_dict(items)
91+
return messages
92+
93+
def add_message(self, message: BaseMessage) -> None:
94+
"""Append the message to the record in Cloud SQL"""
95+
query = f'INSERT INTO "{self.table_name}" (session_id, data, type) VALUES (:session_id, :data, :type);'
96+
with self.engine.connect() as conn:
97+
conn.execute(
98+
sqlalchemy.text(query),
99+
{
100+
"session_id": self.session_id,
101+
"data": json.dumps(message.dict()),
102+
"type": message.type,
103+
},
104+
)
105+
conn.commit()
106+
107+
def clear(self) -> None:
108+
"""Clear session memory from Cloud SQL"""
109+
query = f'DELETE FROM "{self.table_name}" WHERE session_id = :session_id;'
110+
with self.engine.connect() as conn:
111+
conn.execute(sqlalchemy.text(query), {"session_id": self.session_id})
112+
conn.commit()

src/langchain_google_cloud_sql_mssql/mssql_engine.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,33 @@ def connect(self) -> sqlalchemy.engine.Connection:
118118
out from the connection pool.
119119
"""
120120
return self.engine.connect()
121+
122+
def create_chat_history_table(self, table_name: str) -> None:
123+
"""Create table with schema required for MSSQLChatMessageHistory class.
124+
125+
Required schema is as follows:
126+
127+
CREATE TABLE {table_name} (
128+
id INT IDENTITY(1,1) PRIMARY KEY,
129+
session_id NVARCHAR(MAX) NOT NULL,
130+
data NVARCHAR(MAX) NOT NULL,
131+
type NVARCHAR(MAX) NOT NULL
132+
)
133+
134+
Args:
135+
table_name (str): Name of database table to create for storing chat
136+
message history.
137+
"""
138+
create_table_query = f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES
139+
WHERE TABLE_NAME = '{table_name}')
140+
BEGIN
141+
CREATE TABLE {table_name} (
142+
id INT IDENTITY(1,1) PRIMARY KEY,
143+
session_id NVARCHAR(MAX) NOT NULL,
144+
data NVARCHAR(MAX) NOT NULL,
145+
type NVARCHAR(MAX) NOT NULL
146+
)
147+
END;"""
148+
with self.engine.connect() as conn:
149+
conn.execute(sqlalchemy.text(create_table_query))
150+
conn.commit()
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
from typing import Generator
16+
17+
import pytest
18+
import sqlalchemy
19+
from langchain_core.messages.ai import AIMessage
20+
from langchain_core.messages.human import HumanMessage
21+
22+
from langchain_google_cloud_sql_mssql import MSSQLChatMessageHistory, MSSQLEngine
23+
24+
project_id = os.environ["PROJECT_ID"]
25+
region = os.environ["REGION"]
26+
instance_id = os.environ["INSTANCE_ID"]
27+
db_name = os.environ["DB_NAME"]
28+
db_user = os.environ["DB_USER"]
29+
db_password = os.environ["DB_PASSWORD"]
30+
table_name = "message_store"
31+
32+
33+
@pytest.fixture(name="memory_engine")
34+
def setup() -> Generator:
35+
engine = MSSQLEngine.from_instance(
36+
project_id=project_id,
37+
region=region,
38+
instance=instance_id,
39+
database=db_name,
40+
user=db_user,
41+
password=db_password,
42+
)
43+
44+
# create table with malformed schema (missing 'type')
45+
query = """CREATE TABLE malformed_table (
46+
id INT IDENTITY(1,1) PRIMARY KEY,
47+
session_id NVARCHAR(MAX) NOT NULL,
48+
data NVARCHAR(MAX) NOT NULL,
49+
);"""
50+
with engine.connect() as conn:
51+
conn.execute(sqlalchemy.text(query))
52+
conn.commit()
53+
yield engine
54+
# cleanup tables
55+
with engine.connect() as conn:
56+
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS {table_name}"))
57+
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS malformed_table"))
58+
conn.commit()
59+
60+
61+
def test_chat_message_history(memory_engine: MSSQLEngine) -> None:
62+
memory_engine.create_chat_history_table(table_name)
63+
history = MSSQLChatMessageHistory(
64+
engine=memory_engine, session_id="test", table_name=table_name
65+
)
66+
history.add_user_message("hi!")
67+
history.add_ai_message("whats up?")
68+
messages = history.messages
69+
70+
# verify messages are correct
71+
assert messages[0].content == "hi!"
72+
assert type(messages[0]) is HumanMessage
73+
assert messages[1].content == "whats up?"
74+
assert type(messages[1]) is AIMessage
75+
76+
# verify clear() clears message history
77+
history.clear()
78+
assert len(history.messages) == 0
79+
80+
81+
def test_chat_message_history_table_does_not_exist(memory_engine: MSSQLEngine) -> None:
82+
"""Test that MSSQLChatMessageHistory fails if table does not exist."""
83+
with pytest.raises(AttributeError) as exc_info:
84+
MSSQLChatMessageHistory(
85+
engine=memory_engine, session_id="test", table_name="missing_table"
86+
)
87+
# assert custom error message for missing table
88+
assert (
89+
exc_info.value.args[0]
90+
== f"Table 'missing_table' does not exist. Please create it before initializing MSSQLChatMessageHistory. See MSSQLEngine.create_chat_history_table() for a helper method."
91+
)
92+
93+
94+
def test_chat_message_history_table_malformed_schema(
95+
memory_engine: MSSQLEngine,
96+
) -> None:
97+
"""Test that MSSQLChatMessageHistory fails if schema is malformed."""
98+
with pytest.raises(IndexError):
99+
MSSQLChatMessageHistory(
100+
engine=memory_engine, session_id="test", table_name="malformed_table"
101+
)

0 commit comments

Comments
 (0)