Skip to content

Commit bd9e666

Browse files
author
wenfeng.wf
committed
Add tool to support execute sql in db.
1 parent bb3c091 commit bd9e666

File tree

4 files changed

+263
-0
lines changed

4 files changed

+263
-0
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ dependencies = [
1111
"alibabacloud-vpc20160428>=6.11.4",
1212
"httpx>=0.28.1",
1313
"mcp[cli]>=1.6.0",
14+
"psycopg2>=2.9.10",
15+
"pymysql>=1.1.1",
16+
"pyodbc>=5.2.0",
1417
]
1518

1619
license = "Apache-2.0"

src/alibabacloud_rds_openapi_mcp_server/db_driver/__init__.py

Whitespace-only changes.
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import json
2+
import random
3+
import socket
4+
import string
5+
6+
import psycopg2
7+
import pymysql
8+
from alibabacloud_rds20140815 import models as rds_20140815_models
9+
10+
from utils import get_rds_client
11+
12+
13+
def random_str(length=8):
14+
chars = string.ascii_lowercase + string.digits
15+
return ''.join(random.choice(chars) for _ in range(length))
16+
17+
18+
def random_password(length=32):
19+
U = string.ascii_uppercase
20+
L = string.ascii_lowercase
21+
D = string.digits
22+
S = '_!@#$%^&*()-+='
23+
pool = U + L + D + S
24+
for _ in range(1000):
25+
# 确保至少三类
26+
chosen = [
27+
random.choice(U),
28+
random.choice(L),
29+
random.choice(D),
30+
random.choice(S)
31+
]
32+
rest = [random.choice(pool) for _ in range(length - len(chosen))]
33+
pw = ''.join(random.sample(chosen + rest, length))
34+
return pw
35+
36+
37+
def test_connect(host, port, timeout=1):
38+
try:
39+
with socket.create_connection((host, int(port)), timeout):
40+
return True
41+
except Exception:
42+
return False
43+
44+
45+
class DBService:
46+
"""
47+
Create a read-only account, execute the SQL statements, and automatically delete the account afterward.
48+
"""
49+
def __init__(self,
50+
region_id,
51+
instance_id,
52+
database=None, ):
53+
self.instance_id = instance_id
54+
self.database = database
55+
self.region_id = region_id
56+
57+
self.__db_type = None
58+
self.__account_name = None
59+
self.__account_password = None
60+
self.__host = None
61+
self.__port = None
62+
self.__client = get_rds_client(region_id)
63+
self.__db_conn = None
64+
65+
def __enter__(self):
66+
self._get_db_instance_info()
67+
self._create_temp_account()
68+
if self.database:
69+
self._grant_privilege()
70+
self.__db_conn = DBConn(self)
71+
self.__db_conn.connect()
72+
return self
73+
74+
def __exit__(self, exc_type, exc_val, exc_tb):
75+
if self.__db_conn is not None:
76+
self.__db_conn.close()
77+
self._delete_account()
78+
self.__client = None
79+
80+
def _get_db_instance_info(self):
81+
req = rds_20140815_models.DescribeDBInstanceAttributeRequest(
82+
dbinstance_id=self.instance_id,
83+
)
84+
self.__client.describe_dbinstance_attribute(req)
85+
resp = self.__client.describe_dbinstance_attribute(req)
86+
self.db_type = resp.body.items.dbinstance_attribute[0].engine.lower()
87+
88+
req = rds_20140815_models.DescribeDBInstanceNetInfoRequest(
89+
dbinstance_id=self.instance_id,
90+
)
91+
resp = self.__client.describe_dbinstance_net_info(req)
92+
93+
# 取支持的地址:
94+
vpc_host, vpc_port, public_host, public_port, dbtype = None, None, None, None, None
95+
net_infos = resp.body.dbinstance_net_infos.dbinstance_net_info
96+
for item in net_infos:
97+
if 'Private' == item.iptype:
98+
vpc_host = item.connection_string
99+
vpc_port = int(item.port)
100+
elif 'Public' in item.iptype:
101+
public_host = item.connection_string
102+
public_port = int(item.port)
103+
104+
if vpc_host and test_connect(vpc_host, vpc_port):
105+
self.host = vpc_host
106+
self.port = vpc_port
107+
elif public_host and test_connect(public_host, public_port):
108+
self.host = public_host
109+
self.port = public_port
110+
else:
111+
raise Exception('connection db failed.')
112+
113+
def _create_temp_account(self):
114+
self.account_name = 'mcp_' + random_str(10)
115+
self.account_password = random_password(32)
116+
request = rds_20140815_models.CreateAccountRequest(
117+
dbinstance_id=self.instance_id,
118+
account_name=self.account_name,
119+
account_password=self.account_password,
120+
account_description="Created by mcp for execute sql."
121+
)
122+
self.__client.create_account(request)
123+
124+
def _grant_privilege(self):
125+
req = rds_20140815_models.GrantAccountPrivilegeRequest(
126+
dbinstance_id=self.instance_id,
127+
account_name=self.account_name,
128+
dbname=self.database,
129+
account_privilege="ReadOnly" if self.db_type.lower() in ('mysql', 'postgresql') else "DBOwner"
130+
)
131+
self.__client.grant_account_privilege(req)
132+
133+
def _delete_account(self):
134+
if not self.account_name:
135+
return
136+
req = rds_20140815_models.DeleteAccountRequest(
137+
dbinstance_id=self.instance_id,
138+
account_name=self.account_name
139+
)
140+
self.__client.delete_account(req)
141+
142+
def execute_sql(self, sql):
143+
return self.__db_conn.execute_sql(sql)
144+
145+
@property
146+
def user(self):
147+
return self.account_name
148+
149+
@property
150+
def password(self):
151+
return self.account_password
152+
153+
154+
class DBConn:
155+
def __init__(self, db_service: DBService):
156+
self.dbtype = db_service.db_type
157+
self.host = db_service.host
158+
self.port = db_service.port
159+
self.user = db_service.user
160+
self.password = db_service.password
161+
self.database = db_service.database
162+
self.conn = None
163+
164+
def connect(self):
165+
if self.conn is not None:
166+
return
167+
168+
if self.dbtype == 'mysql':
169+
self.conn = pymysql.connect(
170+
host=self.host, port=self.port,
171+
user=self.user, password=self.password,
172+
db=self.database, charset='utf8mb4',
173+
cursorclass=pymysql.cursors.DictCursor
174+
)
175+
elif self.dbtype == 'postgresql' or self.dbtype == 'pg':
176+
self.conn = psycopg2.connect(
177+
host=self.host, port=self.port,
178+
user=self.user, password=self.password,
179+
dbname=self.database
180+
)
181+
elif self.dbtype == 'sqlserver':
182+
import pyodbc
183+
driver = 'ODBC Driver 17 for SQL Server'
184+
conn_str = (
185+
f'DRIVER={{{driver}}};SERVER={self.host},{self.port};'
186+
f'UID={self.user};PWD={self.password};DATABASE={self.database}'
187+
)
188+
self.conn = pyodbc.connect(conn_str)
189+
else:
190+
raise ValueError('Unsupported dbtype')
191+
192+
def close(self):
193+
if self.conn is not None:
194+
try:
195+
self.conn.close()
196+
except Exception as e:
197+
print(e)
198+
self.conn = None
199+
200+
def execute_sql(self, sql):
201+
cursor = self.conn.cursor()
202+
cursor.execute(sql)
203+
columns = [desc[0] for desc in cursor.description]
204+
rows = cursor.fetchall()
205+
if self.dbtype == 'mysql':
206+
result = [dict(row) for row in rows]
207+
elif self.dbtype == 'postgresql' or self.dbtype == 'pg':
208+
result = [dict(zip(columns, row)) for row in rows]
209+
elif self.dbtype == 'sqlserver':
210+
result = [dict(zip(columns, row)) for row in rows]
211+
else:
212+
result = []
213+
return json.dumps(result, ensure_ascii=False)

src/alibabacloud_rds_openapi_mcp_server/server.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from alibabacloud_vpc20160428 import models as vpc_20160428_models
1515
from mcp.server.fastmcp import FastMCP
1616

17+
from db_driver.db_service import DBService
18+
1719
current_dir = os.path.dirname(os.path.abspath(__file__))
1820
sys.path.append(current_dir)
1921
from utils import (transform_to_iso_8601,
@@ -1435,6 +1437,51 @@ def _descirbe(order_by: str):
14351437
}
14361438

14371439

1440+
@mcp.tool()
1441+
async def show_engine_innodb_status(
1442+
dbinstance_id: str,
1443+
region_id: str
1444+
) -> str:
1445+
"""
1446+
show engine innodb status in db.
1447+
Args:
1448+
dbinstance_id (str): The ID of the RDS instance.
1449+
region_id(str): the region id of instance.
1450+
Returns:
1451+
the sql result.
1452+
"""
1453+
try:
1454+
with DBService(region_id, dbinstance_id) as service:
1455+
return service.execute_sql("show engine innodb status")
1456+
except Exception as e:
1457+
logger.error(f"Error occurred: {str(e)}")
1458+
raise e
1459+
1460+
@mcp.tool()
1461+
async def show_create_table(
1462+
region_id: str,
1463+
dbinstance_id: str,
1464+
db_name: str,
1465+
table_name: str
1466+
) -> str:
1467+
"""
1468+
show create table db_name.table_name
1469+
Args:
1470+
dbinstance_id (str): The ID of the RDS instance.
1471+
region_id(str): the region id of instance.
1472+
db_name(str): the db name for table.
1473+
table_name(str): the table name.
1474+
Returns:
1475+
the sql result.
1476+
"""
1477+
try:
1478+
with DBService(region_id, dbinstance_id, db_name) as service:
1479+
return service.execute_sql(f"show create table {db_name}.{table_name}")
1480+
except Exception as e:
1481+
logger.error(f"Error occurred: {str(e)}")
1482+
raise e
1483+
1484+
14381485
def main():
14391486
mcp.run(transport=os.getenv('SERVER_TRANSPORT', 'stdio'))
14401487

0 commit comments

Comments
 (0)