Skip to content

Commit bc8624b

Browse files
committed
Add ipv6 support to should_bypass_proxies
Add support to should_bypass_proxies to support IPv6 ipaddresses and CIDRs in no_proxy. Includes adding IPv6 support to various other helper functions.
1 parent 9a40d12 commit bc8624b

File tree

2 files changed

+132
-17
lines changed

2 files changed

+132
-17
lines changed

src/requests/utils.py

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -679,18 +679,46 @@ def requote_uri(uri):
679679
return quote(uri, safe=safe_without_percent)
680680

681681

682+
def _get_mask_bits(mask, totalbits=32):
683+
"""Converts a mask from /xx format to a int
684+
to be used as a mask for IP's in int format
685+
686+
Example: if mask is 24 function returns 0xFFFFFF00
687+
if mask is 24 and totalbits=128 function
688+
returns 0xFFFFFF00000000000000000000000000
689+
690+
:rtype: int
691+
"""
692+
bits = ((1 << mask) - 1) << (totalbits - mask)
693+
return bits
694+
695+
682696
def address_in_network(ip, net):
683697
"""This function allows you to check if an IP belongs to a network subnet
684698
685699
Example: returns True if ip = 192.168.1.1 and net = 192.168.1.0/24
686700
returns False if ip = 192.168.1.1 and net = 192.168.100.0/24
701+
returns True if ip = 1:2:3:4::1 and net = 1:2:3:4::/64
687702
688703
:rtype: bool
689704
"""
690-
ipaddr = struct.unpack("=L", socket.inet_aton(ip))[0]
691705
netaddr, bits = net.split("/")
692-
netmask = struct.unpack("=L", socket.inet_aton(dotted_netmask(int(bits))))[0]
693-
network = struct.unpack("=L", socket.inet_aton(netaddr))[0] & netmask
706+
if is_ipv4_address(ip) and is_ipv4_address(netaddr):
707+
ipaddr = struct.unpack(">L", socket.inet_aton(ip))[0]
708+
netmask = _get_mask_bits(int(bits))
709+
network = struct.unpack(">L", socket.inet_aton(netaddr))[0]
710+
elif is_ipv6_address(ip) and is_ipv6_address(netaddr):
711+
ipaddr_msb, ipaddr_lsb = struct.unpack(
712+
">QQ", socket.inet_pton(socket.AF_INET6, ip)
713+
)
714+
ipaddr = (ipaddr_msb << 64) ^ ipaddr_lsb
715+
netmask = _get_mask_bits(int(bits), 128)
716+
network_msb, network_lsb = struct.unpack(
717+
">QQ", socket.inet_pton(socket.AF_INET6, netaddr)
718+
)
719+
network = (network_msb << 64) ^ network_lsb
720+
else:
721+
return False
694722
return (ipaddr & netmask) == (network & netmask)
695723

696724

@@ -710,30 +738,59 @@ def is_ipv4_address(string_ip):
710738
:rtype: bool
711739
"""
712740
try:
713-
socket.inet_aton(string_ip)
741+
socket.inet_pton(socket.AF_INET, string_ip)
742+
except OSError:
743+
return False
744+
return True
745+
746+
747+
def is_ipv6_address(string_ip):
748+
"""
749+
:rtype: bool
750+
"""
751+
try:
752+
socket.inet_pton(socket.AF_INET6, string_ip)
714753
except OSError:
715754
return False
716755
return True
717756

718757

758+
def compare_ips(a, b):
759+
"""
760+
Compare 2 IP's, uses socket.inet_pton to normalize IPv6 IPs
761+
762+
:rtype: bool
763+
"""
764+
if a == b:
765+
return True
766+
try:
767+
return socket.inet_pton(socket.AF_INET6, a) == socket.inet_pton(
768+
socket.AF_INET6, b
769+
)
770+
except OSError:
771+
return False
772+
773+
719774
def is_valid_cidr(string_network):
720775
"""
721776
Very simple check of the cidr format in no_proxy variable.
722777
723778
:rtype: bool
724779
"""
725780
if string_network.count("/") == 1:
781+
address, mask = string_network.split("/")
726782
try:
727-
mask = int(string_network.split("/")[1])
783+
mask = int(mask)
728784
except ValueError:
729785
return False
730786

731-
if mask < 1 or mask > 32:
732-
return False
733-
734-
try:
735-
socket.inet_aton(string_network.split("/")[0])
736-
except OSError:
787+
if is_ipv4_address(address):
788+
if mask < 1 or mask > 32:
789+
return False
790+
elif is_ipv6_address(address):
791+
if mask < 1 or mask > 128:
792+
return False
793+
else:
737794
return False
738795
else:
739796
return False
@@ -790,12 +847,12 @@ def get_proxy(key):
790847
# the end of the hostname, both with and without the port.
791848
no_proxy = (host for host in no_proxy.replace(" ", "").split(",") if host)
792849

793-
if is_ipv4_address(parsed.hostname):
850+
if is_ipv4_address(parsed.hostname) or is_ipv6_address(parsed.hostname):
794851
for proxy_ip in no_proxy:
795852
if is_valid_cidr(proxy_ip):
796853
if address_in_network(parsed.hostname, proxy_ip):
797854
return True
798-
elif parsed.hostname == proxy_ip:
855+
elif compare_ips(parsed.hostname, proxy_ip):
799856
# If no_proxy ip was defined in plain IP notation instead of cidr notation &
800857
# matches the IP of the index
801858
return True

tests/test_utils.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
from requests.cookies import RequestsCookieJar
1515
from requests.structures import CaseInsensitiveDict
1616
from requests.utils import (
17+
_get_mask_bits,
1718
_parse_content_type_header,
1819
add_dict_to_cookiejar,
1920
address_in_network,
21+
compare_ips,
2022
dotted_netmask,
2123
extract_zipped_paths,
2224
get_auth_from_url,
@@ -263,8 +265,15 @@ def test_invalid(self, value):
263265

264266

265267
class TestIsValidCIDR:
266-
def test_valid(self):
267-
assert is_valid_cidr("192.168.1.0/24")
268+
@pytest.mark.parametrize(
269+
"value",
270+
(
271+
"192.168.1.0/24",
272+
"1:2:3:4::/64",
273+
),
274+
)
275+
def test_valid(self, value):
276+
assert is_valid_cidr(value)
268277

269278
@pytest.mark.parametrize(
270279
"value",
@@ -274,6 +283,11 @@ def test_valid(self):
274283
"192.168.1.0/128",
275284
"192.168.1.0/-1",
276285
"192.168.1.999/24",
286+
"1:2:3:4::1",
287+
"1:2:3:4::/a",
288+
"1:2:3:4::0/321",
289+
"1:2:3:4::/-1",
290+
"1:2:3:4::12211/64",
277291
),
278292
)
279293
def test_invalid(self, value):
@@ -287,6 +301,12 @@ def test_valid(self):
287301
def test_invalid(self):
288302
assert not address_in_network("172.16.0.1", "192.168.1.0/24")
289303

304+
def test_valid_v6(self):
305+
assert address_in_network("1:2:3:4::1111", "1:2:3:4::/64")
306+
307+
def test_invalid_v6(self):
308+
assert not address_in_network("1:2:3:4:1111", "1:2:3:4::/124")
309+
290310

291311
class TestGuessFilename:
292312
@pytest.mark.parametrize(
@@ -722,6 +742,11 @@ def test_urldefragauth(url, expected):
722742
("http://172.16.1.12:5000/", False),
723743
("http://google.com:5000/v1.0/", False),
724744
("file:///some/path/on/disk", True),
745+
("http://[1:2:3:4:5:6:7:8]:5000/", True),
746+
("http://[1:2:3:4::1]/", True),
747+
("http://[1:2:3:9::1]/", True),
748+
("http://[1:2:3:9:0:0:0:1]/", True),
749+
("http://[1:2:3:9::2]/", False),
725750
),
726751
)
727752
def test_should_bypass_proxies(url, expected, monkeypatch):
@@ -730,11 +755,11 @@ def test_should_bypass_proxies(url, expected, monkeypatch):
730755
"""
731756
monkeypatch.setenv(
732757
"no_proxy",
733-
"192.168.0.0/24,127.0.0.1,localhost.localdomain,172.16.1.1, google.com:6000",
758+
"192.168.0.0/24,127.0.0.1,localhost.localdomain,1:2:3:4::/64,1:2:3:9::1,172.16.1.1, google.com:6000",
734759
)
735760
monkeypatch.setenv(
736761
"NO_PROXY",
737-
"192.168.0.0/24,127.0.0.1,localhost.localdomain,172.16.1.1, google.com:6000",
762+
"192.168.0.0/24,127.0.0.1,localhost.localdomain,1:2:3:4::/64,1:2:3:9::1,172.16.1.1, google.com:6000",
738763
)
739764
assert should_bypass_proxies(url, no_proxy=None) == expected
740765

@@ -956,3 +981,36 @@ def QueryValueEx(key, value_name):
956981
monkeypatch.setattr(winreg, "OpenKey", OpenKey)
957982
monkeypatch.setattr(winreg, "QueryValueEx", QueryValueEx)
958983
assert should_bypass_proxies("http://example.com/", None) is False
984+
985+
986+
@pytest.mark.parametrize(
987+
"mask, totalbits, maskbits",
988+
(
989+
(24, None, 0xFFFFFF00),
990+
(31, None, 0xFFFFFFFE),
991+
(0, None, 0x0),
992+
(4, 4, 0xF),
993+
(24, 128, 0xFFFFFF00000000000000000000000000),
994+
),
995+
)
996+
def test__get_mask_bits(mask, totalbits, maskbits):
997+
args = {"mask": mask}
998+
if totalbits:
999+
args["totalbits"] = totalbits
1000+
assert _get_mask_bits(**args) == maskbits
1001+
1002+
1003+
@pytest.mark.parametrize(
1004+
"a, b, expected",
1005+
(
1006+
("1.2.3.4", "1.2.3.4", True),
1007+
("1.2.3.4", "2.2.3.4", False),
1008+
("1::4", "1.2.3.4", False),
1009+
("1::4", "1::4", True),
1010+
("1::4", "1:0:0:0:0:0:0:4", True),
1011+
("1::4", "1:0:0:0:0:0::4", True),
1012+
("1::4", "1:0:0:0:0:0:1:4", False),
1013+
),
1014+
)
1015+
def test_compare_ips(a, b, expected):
1016+
assert compare_ips(a, b) == expected

0 commit comments

Comments
 (0)