Skip to content
This repository was archived by the owner on Mar 20, 2023. It is now read-only.

Commit 98e8d10

Browse files
ciscornfxdgear
authored andcommitted
Fix ClientSession.close() was never awaited (#39)
1 parent d368364 commit 98e8d10

File tree

10 files changed

+80
-12
lines changed

10 files changed

+80
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ dist
77
build
88
*.egg
99
coverage.xml
10+
.pytest_cache
1011
junit.xml
1112
test_elasticsearch_async/htmlcov
1213
docs/_build

README

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ Example::
2121

2222
loop = asyncio.get_event_loop()
2323
loop.run_until_complete(print_info())
24+
loop.run_until_complete(client.transport.close())
2425
loop.close()
25-
client.transport.close()
2626

2727

2828
Example with SSL Context::
@@ -46,8 +46,8 @@ Example with SSL Context::
4646

4747
loop = asyncio.get_event_loop()
4848
loop.run_until_complete(print_info())
49+
loop.run_until_complete(client.transport.close())
4950
loop.close()
50-
client.transport.close()
5151

5252

5353

elasticsearch_async/connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ def __init__(self, host='localhost', port=9200, http_auth=None,
8080
host, port, self.url_prefix
8181
)
8282

83+
@asyncio.coroutine
8384
def close(self):
84-
return self.session.close()
85+
yield from self.session.close()
8586

8687
@asyncio.coroutine
8788
def perform_request(self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None):
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import asyncio
2+
3+
from elasticsearch.connection_pool import ConnectionPool, DummyConnectionPool
4+
5+
6+
class AsyncConnectionPool(ConnectionPool):
7+
def __init__(self, connections, loop, **kwargs):
8+
self.loop = loop
9+
super().__init__(connections, **kwargs)
10+
11+
async def close(self):
12+
await asyncio.gather(*[conn.close() for conn in self.orig_connections],
13+
loop=self.loop)
14+
15+
16+
class AsyncDummyConnectionPool(DummyConnectionPool):
17+
async def close(self):
18+
await self.connection.close()

elasticsearch_async/transport.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,23 @@
44
from itertools import chain
55

66
from elasticsearch import Transport, TransportError, ConnectionTimeout, ConnectionError, SerializationError
7+
from elasticsearch.connection_pool import DummyConnectionPool
78

89
from .connection import AIOHttpConnection
10+
from .connection_pool import AsyncConnectionPool, AsyncDummyConnectionPool
911
from .helpers import ensure_future
1012

1113
logger = logging.getLogger('elasticsearch')
1214

1315
class AsyncTransport(Transport):
1416
def __init__(self, hosts, connection_class=AIOHttpConnection, loop=None,
17+
connection_pool_class=AsyncConnectionPool,
1518
sniff_on_start=False, raise_on_sniff_error=True, **kwargs):
1619
self.raise_on_sniff_error = raise_on_sniff_error
1720
self.loop = asyncio.get_event_loop() if loop is None else loop
1821
kwargs['loop'] = self.loop
19-
super().__init__(hosts, connection_class=connection_class, sniff_on_start=False, **kwargs)
22+
super().__init__(hosts, connection_class=connection_class, sniff_on_start=False,
23+
connection_pool_class=connection_pool_class, **kwargs)
2024

2125
self.sniffing_task = None
2226
if sniff_on_start:
@@ -42,10 +46,16 @@ def initiate_sniff(self, initial=False):
4246
if self.sniffing_task is None:
4347
self.sniffing_task = ensure_future(self.sniff_hosts(initial), loop=self.loop)
4448

49+
@asyncio.coroutine
4550
def close(self):
4651
if self.sniffing_task:
4752
self.sniffing_task.cancel()
48-
super().close()
53+
yield from self.connection_pool.close()
54+
55+
def set_connections(self, hosts):
56+
super().set_connections(hosts)
57+
if isinstance(self.connection_pool, DummyConnectionPool):
58+
self.connection_pool = AsyncDummyConnectionPool(self.connection_pool.connection_opts)
4959

5060
def get_connection(self):
5161
if self.sniffer_timeout:

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
'pytest',
2121
'pytest-asyncio',
2222
'pytest-cov',
23-
'pytest-catchlog',
2423
]
2524

2625
setup(

test_elasticsearch_async/conftest.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
import logging
34

45
from pytest import yield_fixture, fixture
56

@@ -11,7 +12,7 @@
1112
def connection(event_loop, server, port):
1213
connection = AIOHttpConnection(port=port, loop=event_loop)
1314
yield connection
14-
connection.close()
15+
event_loop.run_until_complete(connection.close())
1516

1617

1718
class DummyElasticsearch(aiohttp.web.Server):
@@ -61,17 +62,19 @@ def port():
6162

6263
@yield_fixture
6364
def server(event_loop, port):
64-
server = DummyElasticsearch(debug=True, keep_alive=75)
65+
server = DummyElasticsearch(debug=True, loop=event_loop)
6566
f = event_loop.create_server(server, '127.0.0.1', port)
6667
event_loop.run_until_complete(f)
6768
yield server
6869
event_loop.run_until_complete(server.shutdown(timeout=.5))
6970

7071
@yield_fixture
7172
def client(event_loop, server, port):
73+
logger = logging.getLogger('elasticsearch')
74+
logger.setLevel(logging.DEBUG)
7275
c = AsyncElasticsearch([{'host': '127.0.0.1','port': port}], loop=event_loop)
7376
yield c
74-
c.transport.close()
77+
event_loop.run_until_complete(c.transport.close())
7578

7679
@fixture
7780
def sniff_data():

test_elasticsearch_async/test_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_auth_is_set_correctly(event_loop):
3030
def test_ssl_context_is_correctly(event_loop):
3131
context = create_ssl_context(cafile="test_elasticsearch_async/ca.crt")
3232
connection = AIOHttpConnection(ssl_context=context, http_auth=('user', 'secret'), loop=event_loop)
33-
assert connection.session.connector.ssl_context.get_ca_certs() == [{
33+
assert connection.session.connector._ssl.get_ca_certs() == [{
3434
'subject': ((('commonName', 'Elastic Certificate Tool Autogenerated CA'),),),
3535
'issuer': ((('commonName', 'Elastic Certificate Tool Autogenerated CA'),),),
3636
'version': 3,
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from pytest import mark
2+
3+
from elasticsearch_async import AsyncElasticsearch
4+
from elasticsearch_async.connection_pool import \
5+
AsyncConnectionPool, AsyncDummyConnectionPool
6+
7+
8+
@mark.asyncio
9+
def test_single_host_makes_async_dummy_pool(server, client, event_loop, port):
10+
client = AsyncElasticsearch(port=port, loop=event_loop)
11+
assert isinstance(client.transport.connection_pool, AsyncDummyConnectionPool)
12+
yield from client.transport.close()
13+
14+
@mark.asyncio
15+
def test_multiple_hosts_make_async_pool(server, event_loop, port):
16+
client = AsyncElasticsearch(
17+
hosts=['localhost', 'localhost'], port=port, loop=event_loop)
18+
assert isinstance(client.transport.connection_pool, AsyncConnectionPool)
19+
assert len(client.transport.connection_pool.orig_connections) == 2
20+
yield from client.transport.close()
21+
22+
@mark.asyncio
23+
def test_async_dummy_pool_is_closed_properly(server, event_loop, port):
24+
client = AsyncElasticsearch(port=port, loop=event_loop)
25+
assert isinstance(client.transport.connection_pool, AsyncDummyConnectionPool)
26+
yield from client.transport.close()
27+
assert client.transport.connection_pool.connection.session.closed
28+
29+
@mark.asyncio
30+
def test_async_pool_is_closed_properly(server, event_loop, port):
31+
client = AsyncElasticsearch(
32+
hosts=['localhost', 'localhost'], port=port, loop=event_loop)
33+
assert isinstance(client.transport.connection_pool, AsyncConnectionPool)
34+
yield from client.transport.close()
35+
for conn in client.transport.connection_pool.orig_connections:
36+
assert conn.session.closed

test_elasticsearch_async/test_transport.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ def test_sniff_on_start_sniffs(server, event_loop, port, sniff_data):
1919

2020
assert 1 == len(connections)
2121
assert 'http://node1:9200' == connections[0].host
22-
client.transport.close()
22+
yield from client.transport.close()
2323

2424
@mark.asyncio
2525
def test_retry_will_work(port, server, event_loop):
2626
client = AsyncElasticsearch(hosts=['not-an-es-host', 'localhost'], port=port, loop=event_loop, randomize_hosts=False)
2727

2828
data = yield from client.info()
2929
assert {'body': '', 'method': 'GET', 'params': {}, 'path': '/'} == data
30-
client.transport.close()
30+
yield from client.transport.close()

0 commit comments

Comments
 (0)