@@ -377,26 +377,28 @@ def __init__(self, socket_pool, ssl_context=None):
377377 self ._last_response = None
378378
379379 def _free_socket (self , socket ):
380-
381380 if socket not in self ._open_sockets .values ():
382381 raise RuntimeError ("Socket not from session" )
383382 self ._socket_free [socket ] = True
384383
384+ def _close_socket (self , sock ):
385+ sock .close ()
386+ del self ._socket_free [sock ]
387+ key = None
388+ for k in self ._open_sockets :
389+ if self ._open_sockets [k ] == sock :
390+ key = k
391+ break
392+ if key :
393+ del self ._open_sockets [key ]
394+
385395 def _free_sockets (self ):
386396 free_sockets = []
387397 for sock in self ._socket_free :
388398 if self ._socket_free [sock ]:
389- sock .close ()
390399 free_sockets .append (sock )
391400 for sock in free_sockets :
392- del self ._socket_free [sock ]
393- key = None
394- for k in self ._open_sockets :
395- if self ._open_sockets [k ] == sock :
396- key = k
397- break
398- if key :
399- del self ._open_sockets [key ]
401+ self ._close_socket (sock )
400402
401403 def _get_socket (self , host , port , proto , * , timeout = 1 ):
402404 key = (host , port , proto )
@@ -440,6 +442,61 @@ def _get_socket(self, host, port, proto, *, timeout=1):
440442 self ._socket_free [sock ] = False
441443 return sock
442444
445+ @staticmethod
446+ def _send (socket , data ):
447+ total_sent = 0
448+ while total_sent < len (data ):
449+ sent = socket .send (data [total_sent :])
450+ if sent is None :
451+ sent = len (data )
452+ if sent == 0 :
453+ raise RuntimeError ("Connection closed" )
454+ total_sent += sent
455+
456+ def _send_request (self , socket , host , method , path , headers , data , json ):
457+ # pylint: disable=too-many-arguments
458+ self ._send (socket , bytes (method , "utf-8" ))
459+ self ._send (socket , b" /" )
460+ self ._send (socket , bytes (path , "utf-8" ))
461+ self ._send (socket , b" HTTP/1.1\r \n " )
462+ if "Host" not in headers :
463+ self ._send (socket , b"Host: " )
464+ self ._send (socket , bytes (host , "utf-8" ))
465+ self ._send (socket , b"\r \n " )
466+ if "User-Agent" not in headers :
467+ self ._send (socket , b"User-Agent: Adafruit CircuitPython\r \n " )
468+ # Iterate over keys to avoid tuple alloc
469+ for k in headers :
470+ self ._send (socket , k .encode ())
471+ self ._send (socket , b": " )
472+ self ._send (socket , headers [k ].encode ())
473+ self ._send (socket , b"\r \n " )
474+ if json is not None :
475+ assert data is None
476+ # pylint: disable=import-outside-toplevel
477+ try :
478+ import json as json_module
479+ except ImportError :
480+ import ujson as json_module
481+ data = json_module .dumps (json )
482+ self ._send (socket , b"Content-Type: application/json\r \n " )
483+ if data :
484+ if isinstance (data , dict ):
485+ self ._send (
486+ socket , b"Content-Type: application/x-www-form-urlencoded\r \n "
487+ )
488+ _post_data = ""
489+ for k in data :
490+ _post_data = "{}&{}={}" .format (_post_data , k , data [k ])
491+ data = _post_data [1 :]
492+ self ._send (socket , b"Content-Length: %d\r \n " % len (data ))
493+ self ._send (socket , b"\r \n " )
494+ if data :
495+ if isinstance (data , bytearray ):
496+ self ._send (socket , bytes (data ))
497+ else :
498+ self ._send (socket , bytes (data , "utf-8" ))
499+
443500 # pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals
444501 def request (
445502 self , method , url , data = None , json = None , headers = None , stream = False , timeout = 60
@@ -476,42 +533,11 @@ def request(
476533 self ._last_response = None
477534
478535 socket = self ._get_socket (host , port , proto , timeout = timeout )
479- socket .send (
480- b"%s /%s HTTP/1.1\r \n " % (bytes (method , "utf-8" ), bytes (path , "utf-8" ))
481- )
482- if "Host" not in headers :
483- socket .send (b"Host: %s\r \n " % bytes (host , "utf-8" ))
484- if "User-Agent" not in headers :
485- socket .send (b"User-Agent: Adafruit CircuitPython\r \n " )
486- # Iterate over keys to avoid tuple alloc
487- for k in headers :
488- socket .send (k .encode ())
489- socket .send (b": " )
490- socket .send (headers [k ].encode ())
491- socket .send (b"\r \n " )
492- if json is not None :
493- assert data is None
494- # pylint: disable=import-outside-toplevel
495- try :
496- import json as json_module
497- except ImportError :
498- import ujson as json_module
499- data = json_module .dumps (json )
500- socket .send (b"Content-Type: application/json\r \n " )
501- if data :
502- if isinstance (data , dict ):
503- socket .send (b"Content-Type: application/x-www-form-urlencoded\r \n " )
504- _post_data = ""
505- for k in data :
506- _post_data = "{}&{}={}" .format (_post_data , k , data [k ])
507- data = _post_data [1 :]
508- socket .send (b"Content-Length: %d\r \n " % len (data ))
509- socket .send (b"\r \n " )
510- if data :
511- if isinstance (data , bytearray ):
512- socket .send (bytes (data ))
513- else :
514- socket .send (bytes (data , "utf-8" ))
536+ try :
537+ self ._send_request (socket , host , method , path , headers , data , json )
538+ except :
539+ self ._close_socket (socket )
540+ raise
515541
516542 resp = Response (socket , self ) # our response
517543 if "location" in resp .headers and 300 <= resp .status_code <= 399 :
@@ -557,6 +583,7 @@ def __init__(self, socket, tls_mode):
557583 self .settimeout = socket .settimeout
558584 self .send = socket .send
559585 self .recv = socket .recv
586+ self .close = socket .close
560587
561588 def connect (self , address ):
562589 """connect wrapper to add non-standard mode parameter"""
0 commit comments