Skip to content

Commit 4753a9d

Browse files
committed
Check for send to return 0
When it does, close the socket and raise an Exception. This prevents leaking sockets when send fails.
1 parent 9aaf781 commit 4753a9d

File tree

8 files changed

+251
-93
lines changed

8 files changed

+251
-93
lines changed

adafruit_requests.py

Lines changed: 68 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -377,26 +377,28 @@ def __init__(self, socket_pool, ssl_context=None):
377377
self._last_response = None
378378

379379
def _free_socket(self, socket):
380-
381380
if socket not in self._open_sockets.values():
382381
raise RuntimeError("Socket not from session")
383382
self._socket_free[socket] = True
384383

384+
def _close_socket(self, sock):
385+
sock.close()
386+
del self._socket_free[sock]
387+
key = None
388+
for k in self._open_sockets:
389+
if self._open_sockets[k] == sock:
390+
key = k
391+
break
392+
if key:
393+
del self._open_sockets[key]
394+
385395
def _free_sockets(self):
386396
free_sockets = []
387397
for sock in self._socket_free:
388398
if self._socket_free[sock]:
389-
sock.close()
390399
free_sockets.append(sock)
391400
for sock in free_sockets:
392-
del self._socket_free[sock]
393-
key = None
394-
for k in self._open_sockets:
395-
if self._open_sockets[k] == sock:
396-
key = k
397-
break
398-
if key:
399-
del self._open_sockets[key]
401+
self._close_socket(sock)
400402

401403
def _get_socket(self, host, port, proto, *, timeout=1):
402404
key = (host, port, proto)
@@ -440,6 +442,56 @@ def _get_socket(self, host, port, proto, *, timeout=1):
440442
self._socket_free[sock] = False
441443
return sock
442444

445+
def _send(self, socket, data):
446+
total_sent = 0
447+
while total_sent < len(data):
448+
sent = socket.send(data[total_sent:])
449+
if sent == 0:
450+
raise RuntimeError("Connection closed")
451+
total_sent += sent
452+
453+
def _send_request(self, socket, host, method, path, headers, data, json):
454+
self._send(socket, bytes(method, "utf-8"))
455+
self._send(socket, b" /")
456+
self._send(socket, bytes(path, "utf-8"))
457+
self._send(socket, b" HTTP/1.1\r\n")
458+
if "Host" not in headers:
459+
self._send(socket, b"Host: ")
460+
self._send(socket, bytes(host, "utf-8"))
461+
self._send(socket, b"\r\n")
462+
if "User-Agent" not in headers:
463+
self._send(socket, b"User-Agent: Adafruit CircuitPython\r\n")
464+
# Iterate over keys to avoid tuple alloc
465+
for k in headers:
466+
self._send(socket, k.encode())
467+
self._send(socket, b": ")
468+
self._send(socket, headers[k].encode())
469+
self._send(socket, b"\r\n")
470+
if json is not None:
471+
assert data is None
472+
# pylint: disable=import-outside-toplevel
473+
try:
474+
import json as json_module
475+
except ImportError:
476+
import ujson as json_module
477+
data = json_module.dumps(json)
478+
self._send(socket, b"Content-Type: application/json\r\n")
479+
if data:
480+
if isinstance(data, dict):
481+
self._send(socket, b"Content-Type: application/x-www-form-urlencoded\r\n")
482+
_post_data = ""
483+
for k in data:
484+
_post_data = "{}&{}={}".format(_post_data, k, data[k])
485+
data = _post_data[1:]
486+
self._send(socket, b"Content-Length: %d\r\n" % len(data))
487+
self._send(socket, b"\r\n")
488+
if data:
489+
if isinstance(data, bytearray):
490+
self._send(socket, bytes(data))
491+
else:
492+
self._send(socket, bytes(data, "utf-8"))
493+
494+
443495
# pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals
444496
def request(
445497
self, method, url, data=None, json=None, headers=None, stream=False, timeout=60
@@ -476,42 +528,11 @@ def request(
476528
self._last_response = None
477529

478530
socket = self._get_socket(host, port, proto, timeout=timeout)
479-
socket.send(
480-
b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8"))
481-
)
482-
if "Host" not in headers:
483-
socket.send(b"Host: %s\r\n" % bytes(host, "utf-8"))
484-
if "User-Agent" not in headers:
485-
socket.send(b"User-Agent: Adafruit CircuitPython\r\n")
486-
# Iterate over keys to avoid tuple alloc
487-
for k in headers:
488-
socket.send(k.encode())
489-
socket.send(b": ")
490-
socket.send(headers[k].encode())
491-
socket.send(b"\r\n")
492-
if json is not None:
493-
assert data is None
494-
# pylint: disable=import-outside-toplevel
495-
try:
496-
import json as json_module
497-
except ImportError:
498-
import ujson as json_module
499-
data = json_module.dumps(json)
500-
socket.send(b"Content-Type: application/json\r\n")
501-
if data:
502-
if isinstance(data, dict):
503-
socket.send(b"Content-Type: application/x-www-form-urlencoded\r\n")
504-
_post_data = ""
505-
for k in data:
506-
_post_data = "{}&{}={}".format(_post_data, k, data[k])
507-
data = _post_data[1:]
508-
socket.send(b"Content-Length: %d\r\n" % len(data))
509-
socket.send(b"\r\n")
510-
if data:
511-
if isinstance(data, bytearray):
512-
socket.send(bytes(data))
513-
else:
514-
socket.send(bytes(data, "utf-8"))
531+
try:
532+
self._send_request(socket, host, method, path, headers, data, json)
533+
except:
534+
self._close_socket(socket)
535+
raise
515536

516537
resp = Response(socket, self) # our response
517538
if "location" in resp.headers and 300 <= resp.status_code <= 399:
@@ -557,6 +578,7 @@ def __init__(self, socket, tls_mode):
557578
self.settimeout = socket.settimeout
558579
self.send = socket.send
559580
self.recv = socket.recv
581+
self.close = socket.close
560582

561583
def connect(self, address):
562584
"""connect wrapper to add non-standard mode parameter"""

tests/chunk_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,19 @@ def test_get_text():
3939
r = s.get("http://" + host + path)
4040

4141
sock.connect.assert_called_once_with((ip, 80))
42+
43+
sock.send.assert_has_calls(
44+
[
45+
mock.call(b"GET"),
46+
mock.call(b" /"),
47+
mock.call(b"testwifi/index.html"),
48+
mock.call(b" HTTP/1.1\r\n"),
49+
]
50+
)
4251
sock.send.assert_has_calls(
4352
[
44-
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
45-
mock.call(b"Host: wifitest.adafruit.com\r\n"),
53+
mock.call(b"Host: "),
54+
mock.call(b"wifitest.adafruit.com"),
4655
]
4756
)
4857
assert r.text == str(text, "utf-8")

tests/header_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ def test_json():
1414
sock = mocket.Mocket(response_headers)
1515
pool.socket.return_value = sock
1616
sent = []
17-
sock.send.side_effect = sent.append
17+
def _send(data):
18+
sent.append(data)
19+
return len(data)
20+
sock.send.side_effect = _send
1821

1922
s = adafruit_requests.Session(pool)
2023
headers = {"user-agent": "blinka/1.0.0"}

tests/legacy_mocket.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@ def __init__(self, response):
1313
self.settimeout = mock.Mock()
1414
self.close = mock.Mock()
1515
self.connect = mock.Mock()
16-
self.send = mock.Mock()
16+
self.send = mock.Mock(side_effect=self._send)
1717
self.readline = mock.Mock(side_effect=self._readline)
1818
self.recv = mock.Mock(side_effect=self._recv)
1919
self._response = response
2020
self._position = 0
2121

22+
def _send(self, data):
23+
return len(data)
24+
2225
def _readline(self):
2326
i = self._response.find(b"\r\n", self._position)
2427
r = self._response[self._position : i + 2]

tests/mocket.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@ def __init__(self, response):
1414
self.settimeout = mock.Mock()
1515
self.close = mock.Mock()
1616
self.connect = mock.Mock()
17-
self.send = mock.Mock()
17+
self.send = mock.Mock(side_effect=self._send)
1818
self.readline = mock.Mock(side_effect=self._readline)
1919
self.recv = mock.Mock(side_effect=self._recv)
2020
self.recv_into = mock.Mock(side_effect=self._recv_into)
2121
self._response = response
2222
self._position = 0
2323

24+
def _send(self, data):
25+
return len(data)
26+
2427
def _readline(self):
2528
i = self._response.find(b"\r\n", self._position)
2629
r = self._response[self._position : i + 2]

tests/post_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,20 @@ def test_method():
2121
s = adafruit_requests.Session(pool)
2222
r = s.post("http://" + host + "/post")
2323
sock.connect.assert_called_once_with((ip, 80))
24+
25+
sock.send.assert_has_calls(
26+
[
27+
mock.call(b"POST"),
28+
mock.call(b" /"),
29+
mock.call(b"post"),
30+
mock.call(b" HTTP/1.1\r\n"),
31+
]
32+
)
2433
sock.send.assert_has_calls(
25-
[mock.call(b"POST /post HTTP/1.1\r\n"), mock.call(b"Host: httpbin.org\r\n")]
34+
[
35+
mock.call(b"Host: "),
36+
mock.call(b"httpbin.org"),
37+
]
2638
)
2739

2840

tests/protocol_test.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,19 @@ def test_get_https_text():
3232
r = s.get("https://" + host + path)
3333

3434
sock.connect.assert_called_once_with((host, 443))
35+
3536
sock.send.assert_has_calls(
3637
[
37-
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
38-
mock.call(b"Host: wifitest.adafruit.com\r\n"),
38+
mock.call(b"GET"),
39+
mock.call(b" /"),
40+
mock.call(b"testwifi/index.html"),
41+
mock.call(b" HTTP/1.1\r\n"),
42+
]
43+
)
44+
sock.send.assert_has_calls(
45+
[
46+
mock.call(b"Host: "),
47+
mock.call(b"wifitest.adafruit.com"),
3948
]
4049
)
4150
assert r.text == str(text, "utf-8")
@@ -54,10 +63,19 @@ def test_get_http_text():
5463
r = s.get("http://" + host + path)
5564

5665
sock.connect.assert_called_once_with((ip, 80))
66+
67+
sock.send.assert_has_calls(
68+
[
69+
mock.call(b"GET"),
70+
mock.call(b" /"),
71+
mock.call(b"testwifi/index.html"),
72+
mock.call(b" HTTP/1.1\r\n"),
73+
]
74+
)
5775
sock.send.assert_has_calls(
5876
[
59-
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
60-
mock.call(b"Host: wifitest.adafruit.com\r\n"),
77+
mock.call(b"Host: "),
78+
mock.call(b"wifitest.adafruit.com"),
6179
]
6280
)
6381
assert r.text == str(text, "utf-8")

0 commit comments

Comments
 (0)