Skip to content

Commit 8156688

Browse files
committed
PYTHON-1189 Fix race condition in poll.poll
Each Pool guards its poll.poll with a mutex to prevent concurrent calls. poll and select are retried if they are interrupted with EINTR. Also fixes PYTHON-1179.
1 parent f7dc88c commit 8156688

File tree

3 files changed

+85
-26
lines changed

3 files changed

+85
-26
lines changed

pymongo/network.py

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,22 @@
1818
import errno
1919
import select
2020
import struct
21+
import threading
2122

2223
_HAS_POLL = True
23-
_poller = None
2424
_EVENT_MASK = 0
2525
try:
2626
from select import poll
27-
_poller = poll()
2827
_EVENT_MASK = (select.POLLIN | select.POLLPRI | select.POLLERR |
2928
select.POLLHUP | select.POLLNVAL)
3029
except ImportError:
3130
_HAS_POLL = False
3231

32+
try:
33+
from select import error as _SELECT_ERROR
34+
except ImportError:
35+
_SELECT_ERROR = OSError
36+
3337
from pymongo import helpers, message
3438
from pymongo.common import MAX_MESSAGE_SIZE
3539
from pymongo.errors import (AutoReconnect,
@@ -159,12 +163,7 @@ def _receive_data_on_socket(sock, length):
159163
try:
160164
chunk = sock.recv(length)
161165
except (IOError, OSError) as exc:
162-
err = None
163-
if hasattr(exc, 'errno'):
164-
err = exc.errno
165-
elif exc.args:
166-
err = exc.args[0]
167-
if err == errno.EINTR:
166+
if _errno_from_exception(exc) == errno.EINTR:
168167
continue
169168
raise
170169
if chunk == b"":
@@ -176,17 +175,56 @@ def _receive_data_on_socket(sock, length):
176175
return msg
177176

178177

179-
def socket_closed(sock):
180-
"""Return True if we know socket has been closed, False otherwise.
181-
"""
182-
try:
178+
def _errno_from_exception(exc):
179+
if hasattr(exc, 'errno'):
180+
return exc.errno
181+
elif exc.args:
182+
return exc.args[0]
183+
else:
184+
return None
185+
186+
187+
class SocketChecker(object):
188+
189+
def __init__(self):
183190
if _HAS_POLL:
184-
_poller.register(sock, _EVENT_MASK)
185-
rd = _poller.poll(0)
186-
_poller.unregister(sock)
191+
self._lock = threading.Lock()
192+
self._poller = poll()
187193
else:
188-
rd, _, _ = select.select([sock], [], [], 0)
189-
# Any exception here is equally bad (select.error, ValueError, etc.).
190-
except:
191-
return True
192-
return len(rd) > 0
194+
self._lock = None
195+
self._poller = None
196+
197+
def socket_closed(self, sock):
198+
"""Return True if we know socket has been closed, False otherwise.
199+
"""
200+
while True:
201+
try:
202+
if self._poller:
203+
with self._lock:
204+
self._poller.register(sock, _EVENT_MASK)
205+
try:
206+
rd = self._poller.poll(0)
207+
finally:
208+
self._poller.unregister(sock)
209+
else:
210+
rd, _, _ = select.select([sock], [], [], 0)
211+
except (RuntimeError, KeyError):
212+
# RuntimeError is raised during a concurrent poll. KeyError
213+
# is raised by unregister if the socket is not in the poller.
214+
# These errors should not be possible since we protect the
215+
# poller with a mutex.
216+
raise
217+
except ValueError:
218+
# ValueError is raised by register/unregister/select if the
219+
# socket file descriptor is negative or outside the range for
220+
# select (> 1023).
221+
return True
222+
except (_SELECT_ERROR, IOError) as exc:
223+
if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN):
224+
continue
225+
return True
226+
except:
227+
# Any other exceptions should be attributed to a closed
228+
# or invalid socket.
229+
return True
230+
return len(rd) > 0

pymongo/pool.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class SSLError(socket.error):
4545
from pymongo.monotonic import time as _time
4646
from pymongo.network import (command,
4747
receive_message,
48-
socket_closed)
48+
SocketChecker)
4949
from pymongo.read_concern import DEFAULT_READ_CONCERN
5050
from pymongo.read_preferences import ReadPreference
5151
from pymongo.server_type import SERVER_TYPE
@@ -703,6 +703,7 @@ def __init__(self, address, options, handshake=True):
703703

704704
self._socket_semaphore = thread_util.create_semaphore(
705705
self.opts.max_pool_size, max_waiters)
706+
self.socket_checker = SocketChecker()
706707

707708
def reset(self):
708709
with self.lock:
@@ -879,7 +880,7 @@ def _check(self, sock_info):
879880
and (
880881
0 == self._check_interval_seconds
881882
or age > self._check_interval_seconds)):
882-
if socket_closed(sock_info.sock):
883+
if self.socket_checker.socket_closed(sock_info.sock):
883884
sock_info.close()
884885
error = True
885886

test/test_pooling.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
sys.path[0:0] = [""]
3131

32-
from pymongo.network import socket_closed
32+
from pymongo.network import SocketChecker
3333
from pymongo.pool import Pool, PoolOptions
3434
from test import client_context, unittest
3535
from test.utils import (get_pool,
@@ -236,7 +236,8 @@ def test_pool_removes_dead_socket(self):
236236
# Simulate a closed socket without telling the SocketInfo it's
237237
# closed.
238238
sock_info.sock.close()
239-
self.assertTrue(socket_closed(sock_info.sock))
239+
self.assertTrue(
240+
cx_pool.socket_checker.socket_closed(sock_info.sock))
240241

241242
with cx_pool.get_socket({}) as new_sock_info:
242243
self.assertEqual(0, len(cx_pool.sockets))
@@ -251,9 +252,28 @@ def test_pool_removes_dead_socket(self):
251252
def test_socket_closed(self):
252253
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
253254
s.connect((client_context.host, client_context.port))
254-
self.assertFalse(socket_closed(s))
255+
socket_checker = SocketChecker()
256+
self.assertFalse(socket_checker.socket_closed(s))
255257
s.close()
256-
self.assertTrue(socket_closed(s))
258+
self.assertTrue(socket_checker.socket_closed(s))
259+
260+
def test_socket_closed_thread_safe(self):
261+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
262+
s.connect((client_context.host, client_context.port))
263+
socket_checker = SocketChecker()
264+
265+
def check_socket():
266+
for _ in range(1000):
267+
self.assertFalse(socket_checker.socket_closed(s))
268+
269+
threads = []
270+
for i in range(3):
271+
thread = threading.Thread(target=check_socket)
272+
thread.start()
273+
threads.append(thread)
274+
275+
for thread in threads:
276+
thread.join()
257277

258278
def test_return_socket_after_reset(self):
259279
pool = self.create_pool()

0 commit comments

Comments
 (0)