Skip to content

Commit 815c5c5

Browse files
authored
Merge pull request #219 from fantix/t14
Fixed #14, implemented prepared statement
2 parents f7bd271 + 3868ea5 commit 815c5c5

File tree

6 files changed

+222
-12
lines changed

6 files changed

+222
-12
lines changed

gino/crud.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,8 @@ def _init_table(cls, sub_cls):
413413
return rv
414414

415415
@classmethod
416-
async def _create_without_instance(cls, bind=None, timeout=DEFAULT, **values):
416+
async def _create_without_instance(cls, bind=None, timeout=DEFAULT,
417+
**values):
417418
return await cls(**values)._create(bind=bind, timeout=timeout)
418419

419420
async def _create(self, bind=None, timeout=DEFAULT):

gino/dialects/asyncpg.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ async def forward(self, n, *, timeout=base.DEFAULT):
9999

100100

101101
class PreparedStatement(base.PreparedStatement):
102-
def __init__(self, prepared):
103-
super().__init__()
102+
def __init__(self, prepared, clause=None):
103+
super().__init__(clause)
104104
self._prepared = prepared
105105

106106
def _get_iterator(self, *params, **kwargs):
@@ -111,27 +111,41 @@ async def _get_cursor(self, *params, **kwargs):
111111
iterator = await self._prepared.cursor(*params, **kwargs)
112112
return AsyncpgCursor(self.context, iterator)
113113

114+
async def _execute(self, params, one):
115+
if one:
116+
rv = await self._prepared.fetchrow(*params)
117+
if rv is None:
118+
rv = []
119+
else:
120+
rv = [rv]
121+
else:
122+
rv = await self._prepared.fetch(*params)
123+
return self._prepared.get_statusmsg(), rv
124+
114125

115126
class DBAPICursor(base.DBAPICursor):
116127
def __init__(self, dbapi_conn):
117128
self._conn = dbapi_conn
118129
self._attributes = None
119130
self._status = None
120131

121-
async def prepare(self, query, timeout):
132+
async def prepare(self, context, clause=None):
133+
timeout = context.timeout
122134
if timeout is None:
123135
conn = await self._conn.acquire(timeout=timeout)
124136
else:
125137
before = time.monotonic()
126138
conn = await self._conn.acquire(timeout=timeout)
127139
after = time.monotonic()
128140
timeout -= after - before
129-
prepared = await conn.prepare(query, timeout=timeout)
141+
prepared = await conn.prepare(context.statement, timeout=timeout)
130142
try:
131143
self._attributes = prepared.get_attributes()
132144
except TypeError: # asyncpg <= 0.12.0
133145
self._attributes = []
134-
return PreparedStatement(prepared)
146+
rv = PreparedStatement(prepared, clause)
147+
rv.context = context
148+
return rv
135149

136150
async def async_execute(self, query, timeout, args, limit=0, many=False):
137151
if timeout is None:

gino/dialects/base.py

Lines changed: 132 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def executemany(self, statement, parameters):
2626
def description(self):
2727
raise NotImplementedError
2828

29-
async def prepare(self, query, timeout):
29+
async def prepare(self, context, clause=None):
3030
raise NotImplementedError
3131

3232
async def async_execute(self, query, timeout, args, limit=0, many=False):
@@ -67,18 +67,62 @@ async def rollback(self):
6767

6868

6969
class PreparedStatement:
70-
def __init__(self):
70+
def __init__(self, clause=None):
7171
self.context = None
72+
self.clause = clause
7273

7374
def iterate(self, *params, **kwargs):
7475
return _PreparedIterableCursor(self, params, kwargs)
7576

77+
async def _do_execute(self, multiparams, params, one=False,
78+
return_model=True, status=False):
79+
ctx = self.context.connection.execute(
80+
self.clause, *multiparams, **params).context
81+
if ctx.executemany:
82+
raise ValueError('PreparedStatement does not support multiple '
83+
'parameters.')
84+
assert ctx.statement == self.context.statement, (
85+
'Prepared statement generated different SQL with parameters')
86+
params = []
87+
for val in ctx.parameters[0]:
88+
params.append(val)
89+
msg, rows = await self._execute(params, one)
90+
if status:
91+
return msg
92+
item = self.context.process_rows(rows, return_model=return_model)
93+
if one:
94+
if item:
95+
item = item[0]
96+
else:
97+
item = None
98+
return item
99+
100+
async def all(self, *multiparams, **params):
101+
return await self._do_execute(multiparams, params)
102+
103+
async def first(self, *multiparams, **params):
104+
return await self._do_execute(multiparams, params, one=True)
105+
106+
async def scalar(self, *multiparams, **params):
107+
rv = await self._do_execute(multiparams, params, one=True,
108+
return_model=False)
109+
if rv:
110+
return rv[0]
111+
else:
112+
return None
113+
114+
async def status(self, *multiparams, **params):
115+
return await self._do_execute(multiparams, params, status=True)
116+
76117
def _get_iterator(self, *params, **kwargs):
77118
raise NotImplementedError
78119

79120
async def _get_cursor(self, *params, **kwargs):
80121
raise NotImplementedError
81122

123+
async def _execute(self, params, one):
124+
raise NotImplementedError
125+
82126

83127
class _PreparedIterableCursor:
84128
def __init__(self, prepared, params, kwargs):
@@ -100,9 +144,7 @@ def __init__(self, context):
100144
self._context = context
101145

102146
async def _iterate(self):
103-
prepared = await self._context.cursor.prepare(self._context.statement,
104-
self._context.timeout)
105-
prepared.context = self._context
147+
prepared = await self._context.cursor.prepare(self._context)
106148
return prepared.iterate(*self._context.parameters[0],
107149
timeout=self._context.timeout)
108150

@@ -173,6 +215,9 @@ def iterate(self):
173215
raise ValueError('too many multiparams')
174216
return _IterableCursor(self._context)
175217

218+
async def prepare(self, clause):
219+
return await self._context.cursor.prepare(self._context, clause)
220+
176221
def _soft_close(self):
177222
pass
178223

@@ -237,6 +282,88 @@ def process_rows(self, rows, return_model=True):
237282
def get_result_proxy(self):
238283
return _ResultProxy(self)
239284

285+
@classmethod
286+
def _init_compiled_prepared(cls, dialect, connection, dbapi_connection,
287+
compiled, parameters):
288+
self = cls.__new__(cls)
289+
self.root_connection = connection
290+
self._dbapi_connection = dbapi_connection
291+
self.dialect = connection.dialect
292+
293+
self.compiled = compiled
294+
295+
# this should be caught in the engine before
296+
# we get here
297+
assert compiled.can_execute
298+
299+
self.execution_options = compiled.execution_options.union(
300+
connection._execution_options)
301+
302+
self.result_column_struct = (
303+
compiled._result_columns, compiled._ordered_columns,
304+
compiled._textual_ordered_columns)
305+
306+
self.unicode_statement = util.text_type(compiled)
307+
if not dialect.supports_unicode_statements:
308+
self.statement = self.unicode_statement.encode(
309+
self.dialect.encoding)
310+
else:
311+
self.statement = self.unicode_statement
312+
313+
self.isinsert = compiled.isinsert
314+
self.isupdate = compiled.isupdate
315+
self.isdelete = compiled.isdelete
316+
self.is_text = compiled.isplaintext
317+
318+
self.executemany = False
319+
320+
self.cursor = self.create_cursor()
321+
322+
if self.isinsert or self.isupdate or self.isdelete:
323+
self.is_crud = True
324+
self._is_explicit_returning = bool(compiled.statement._returning)
325+
self._is_implicit_returning = bool(
326+
compiled.returning and not compiled.statement._returning)
327+
328+
if self.dialect.positional:
329+
self.parameters = [dialect.execute_sequence_format()]
330+
else:
331+
self.parameters = [{}]
332+
self.compiled_parameters = [{}]
333+
334+
return self
335+
336+
@classmethod
337+
def _init_statement_prepared(cls, dialect, connection, dbapi_connection,
338+
statement, parameters):
339+
"""Initialize execution context for a string SQL statement."""
340+
341+
self = cls.__new__(cls)
342+
self.root_connection = connection
343+
self._dbapi_connection = dbapi_connection
344+
self.dialect = connection.dialect
345+
self.is_text = True
346+
347+
# plain text statement
348+
self.execution_options = connection._execution_options
349+
350+
if self.dialect.positional:
351+
self.parameters = [dialect.execute_sequence_format()]
352+
else:
353+
self.parameters = [{}]
354+
355+
self.executemany = False
356+
357+
if not dialect.supports_unicode_statements and \
358+
isinstance(statement, util.text_type):
359+
self.unicode_statement = statement
360+
self.statement = dialect._encoder(statement)[0]
361+
else:
362+
self.statement = self.unicode_statement = statement
363+
364+
self.cursor = self.create_cursor()
365+
return self
366+
240367

241368
class AsyncDialectMixin:
242369
cursor_cls = DBAPICursor

gino/engine.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,26 @@ async def _release(self):
129129
pass
130130

131131

132+
# noinspection PyPep8Naming,PyMethodMayBeStatic
133+
class _bypass_no_param:
134+
def keys(self):
135+
return []
136+
137+
138+
_bypass_no_param = _bypass_no_param()
139+
140+
132141
# noinspection PyAbstractClass
133142
class _SAConnection(Connection):
134-
pass
143+
def _execute_context(self, dialect, constructor,
144+
statement, parameters,
145+
*args):
146+
if parameters == [_bypass_no_param]:
147+
constructor = getattr(self.dialect.execution_ctx_cls,
148+
constructor.__name__ + '_prepared',
149+
constructor)
150+
return super()._execute_context(dialect, constructor, statement,
151+
parameters, *args)
135152

136153

137154
# noinspection PyAbstractClass
@@ -503,6 +520,10 @@ async def _run_visitor(self, visitorcallable, element, **kwargs):
503520
await visitorcallable(self.dialect, self,
504521
**kwargs).traverse_single(element)
505522

523+
async def prepare(self, clause):
524+
return await self._execute(
525+
clause, (_bypass_no_param,), {}).prepare(clause)
526+
506527

507528
class GinoEngine:
508529
"""

tests/test_engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ async def test_scalar_return_none(bind):
170170

171171

172172
async def test_asyncpg_0120(bind, mocker):
173+
# for asyncpg 0.12.0
173174
assert await bind.first('rollback') is None
174175

175176
orig = getattr(asyncpg.Connection, '_do_execute')

tests/test_prepared_stmt.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from datetime import datetime
2+
3+
import pytest
4+
from .models import db, User
5+
6+
pytestmark = pytest.mark.asyncio
7+
8+
9+
async def test_compiled_and_bindparam(bind):
10+
async with db.acquire() as conn:
11+
# noinspection PyArgumentList
12+
ins = await conn.prepare(User.insert().returning(
13+
*User).execution_options(loader=User))
14+
users = {}
15+
for name in '12345':
16+
u = await ins.first(nickname=name)
17+
assert u.nickname == name
18+
users[u.id] = u
19+
get = await conn.prepare(
20+
User.query.where(User.id == db.bindparam('uid')))
21+
for key in users:
22+
u = await get.first(uid=key)
23+
assert u.nickname == users[key].nickname
24+
assert (await get.all(uid=key))[0].nickname == u.nickname
25+
26+
assert await get.scalar(uid=-1) is None
27+
28+
with pytest.raises(ValueError, match='does not support multiple'):
29+
await get.all([dict(uid=1), dict(uid=2)])
30+
31+
delete = await conn.prepare(
32+
User.delete.where(User.nickname == db.bindparam('name')))
33+
for name in '12345':
34+
msg = await delete.status(name=name)
35+
assert msg == 'DELETE 1'
36+
37+
38+
async def test_statement(engine):
39+
async with engine.acquire() as conn:
40+
stmt = await conn.prepare('SELECT now()')
41+
last = None
42+
for i in range(5):
43+
now = await stmt.scalar()
44+
assert isinstance(now, datetime)
45+
assert last != now
46+
last = now

0 commit comments

Comments
 (0)