Skip to content

Commit c414e8f

Browse files
committed
Add ipv6 support to should_bypass_proxies
Add support to should_bypass_proxies to support IPv6 ipaddresses and CIDRs in no_proxy. Includes adding IPv6 support to various other helper functions.
1 parent 94f9f73 commit c414e8f

File tree

2 files changed

+121
-19
lines changed

2 files changed

+121
-19
lines changed

requests/utils.py

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -648,18 +648,42 @@ def requote_uri(uri):
648648
return quote(uri, safe=safe_without_percent)
649649

650650

651+
def _get_mask_bits(mask, totalbits=32):
652+
"""Converts a mask from /xx format to a int
653+
to be used as a mask for IP's in int format
654+
655+
Example: if mask is 24 function returns 0xFFFFFF00
656+
if mask is 24 and totalbits=128 function
657+
returns 0xFFFFFF00000000000000000000000000
658+
659+
:rtype: int
660+
"""
661+
bits = ((1 << mask) - 1) << (totalbits - mask)
662+
return bits
663+
664+
651665
def address_in_network(ip, net):
652666
"""This function allows you to check if an IP belongs to a network subnet
653667
654668
Example: returns True if ip = 192.168.1.1 and net = 192.168.1.0/24
655669
returns False if ip = 192.168.1.1 and net = 192.168.100.0/24
670+
returns True if ip = 1:2:3:4::1 and net = 1:2:3:4::/64
656671
657672
:rtype: bool
658673
"""
659-
ipaddr = struct.unpack('=L', socket.inet_aton(ip))[0]
660674
netaddr, bits = net.split('/')
661-
netmask = struct.unpack('=L', socket.inet_aton(dotted_netmask(int(bits))))[0]
662-
network = struct.unpack('=L', socket.inet_aton(netaddr))[0] & netmask
675+
if is_ipv4_address(ip) and is_ipv4_address(netaddr):
676+
ipaddr = struct.unpack('>L', socket.inet_aton(ip))[0]
677+
netmask = _get_mask_bits(int(bits))
678+
network = struct.unpack('>L', socket.inet_aton(netaddr))[0]
679+
elif is_ipv6_address(ip) and is_ipv6_address(netaddr):
680+
ipaddr_msb, ipaddr_lsb = struct.unpack('>QQ', socket.inet_pton(socket.AF_INET6, ip))
681+
ipaddr = (ipaddr_msb << 64) ^ ipaddr_lsb
682+
netmask = _get_mask_bits(int(bits), 128)
683+
network_msb, network_lsb = struct.unpack('>QQ', socket.inet_pton(socket.AF_INET6, netaddr))
684+
network = (network_msb << 64) ^ network_lsb
685+
else:
686+
return False
663687
return (ipaddr & netmask) == (network & netmask)
664688

665689

@@ -679,31 +703,58 @@ def is_ipv4_address(string_ip):
679703
:rtype: bool
680704
"""
681705
try:
682-
socket.inet_aton(string_ip)
706+
socket.inet_pton(socket.AF_INET, string_ip)
707+
except socket.error:
708+
return False
709+
return True
710+
711+
712+
def is_ipv6_address(string_ip):
713+
"""
714+
:rtype: bool
715+
"""
716+
try:
717+
socket.inet_pton(socket.AF_INET6, string_ip)
683718
except socket.error:
684719
return False
685720
return True
686721

687722

723+
def compare_ips(a, b):
724+
"""
725+
Compare 2 IP's, uses socket.inet_pton to normalize IPv6 IPs
726+
727+
:rtype: bool
728+
"""
729+
if a == b:
730+
return True
731+
try:
732+
return socket.inet_pton(socket.AF_INET6, a) == socket.inet_pton(socket.AF_INET6, b)
733+
except socket.error:
734+
return False
735+
736+
688737
def is_valid_cidr(string_network):
689738
"""
690739
Very simple check of the cidr format in no_proxy variable.
691740
692741
:rtype: bool
693742
"""
694743
if string_network.count('/') == 1:
744+
address, mask = string_network.split('/')
695745
try:
696-
mask = int(string_network.split('/')[1])
746+
mask = int(mask)
697747
except ValueError:
698748
return False
699749

700-
if mask < 1 or mask > 32:
701-
return False
702-
703-
try:
704-
socket.inet_aton(string_network.split('/')[0])
705-
except socket.error:
706-
return False
750+
if is_ipv4_address(address):
751+
if mask < 1 or mask > 32:
752+
return False
753+
elif is_ipv6_address(address):
754+
if mask < 1 or mask > 128:
755+
return False
756+
else:
757+
return False
707758
else:
708759
return False
709760
return True
@@ -759,12 +810,12 @@ def should_bypass_proxies(url, no_proxy):
759810
host for host in no_proxy.replace(' ', '').split(',') if host
760811
)
761812

762-
if is_ipv4_address(parsed.hostname):
813+
if is_ipv4_address(parsed.hostname) or is_ipv6_address(parsed.hostname):
763814
for proxy_ip in no_proxy:
764815
if is_valid_cidr(proxy_ip):
765816
if address_in_network(parsed.hostname, proxy_ip):
766817
return True
767-
elif parsed.hostname == proxy_ip:
818+
elif compare_ips(parsed.hostname, proxy_ip):
768819
# If no_proxy ip was defined in plain IP notation instead of cidr notation &
769820
# matches the IP of the index
770821
return True

tests/test_utils.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
requote_uri, select_proxy, should_bypass_proxies, super_len,
2222
to_key_val_list, to_native_string,
2323
unquote_header_value, unquote_unreserved,
24-
urldefragauth, add_dict_to_cookiejar, set_environ)
24+
urldefragauth, add_dict_to_cookiejar, set_environ, _get_mask_bits,
25+
compare_ips)
2526
from requests._internal_utils import unicode_is_ascii
2627

2728
from .compat import StringIO, cStringIO
@@ -216,8 +217,13 @@ def test_invalid(self, value):
216217

217218
class TestIsValidCIDR:
218219

219-
def test_valid(self):
220-
assert is_valid_cidr('192.168.1.0/24')
220+
@pytest.mark.parametrize(
221+
'value', (
222+
'192.168.1.0/24',
223+
'1:2:3:4::/64',
224+
))
225+
def test_valid(self, value):
226+
assert is_valid_cidr(value)
221227

222228
@pytest.mark.parametrize(
223229
'value', (
@@ -226,6 +232,11 @@ def test_valid(self):
226232
'192.168.1.0/128',
227233
'192.168.1.0/-1',
228234
'192.168.1.999/24',
235+
'1:2:3:4::1',
236+
'1:2:3:4::/a',
237+
'1:2:3:4::0/321',
238+
'1:2:3:4::/-1',
239+
'1:2:3:4::12211/64',
229240
))
230241
def test_invalid(self, value):
231242
assert not is_valid_cidr(value)
@@ -239,6 +250,12 @@ def test_valid(self):
239250
def test_invalid(self):
240251
assert not address_in_network('172.16.0.1', '192.168.1.0/24')
241252

253+
def test_valid_v6(self):
254+
assert address_in_network('1:2:3:4::1111', '1:2:3:4::/64')
255+
256+
def test_invalid_v6(self):
257+
assert not address_in_network('1:2:3:4:1111', '1:2:3:4::/124')
258+
242259

243260
class TestGuessFilename:
244261

@@ -628,13 +645,18 @@ def test_urldefragauth(url, expected):
628645
('http://172.16.1.12:5000/', False),
629646
('http://google.com:5000/v1.0/', False),
630647
('file:///some/path/on/disk', True),
648+
('http://[1:2:3:4:5:6:7:8]:5000/', True),
649+
('http://[1:2:3:4::1]/', True),
650+
('http://[1:2:3:9::1]/', True),
651+
('http://[1:2:3:9:0:0:0:1]/', True),
652+
('http://[1:2:3:9::2]/', False),
631653
))
632654
def test_should_bypass_proxies(url, expected, monkeypatch):
633655
"""Tests for function should_bypass_proxies to check if proxy
634656
can be bypassed or not
635657
"""
636-
monkeypatch.setenv('no_proxy', '192.168.0.0/24,127.0.0.1,localhost.localdomain,172.16.1.1, google.com:6000')
637-
monkeypatch.setenv('NO_PROXY', '192.168.0.0/24,127.0.0.1,localhost.localdomain,172.16.1.1, google.com:6000')
658+
monkeypatch.setenv('no_proxy', '192.168.0.0/24,127.0.0.1,localhost.localdomain,1:2:3:4::/64,1:2:3:9::1,172.16.1.1, google.com:6000')
659+
monkeypatch.setenv('NO_PROXY', '192.168.0.0/24,127.0.0.1,localhost.localdomain,1:2:3:4::/64,1:2:3:9::1,172.16.1.1, google.com:6000')
638660
assert should_bypass_proxies(url, no_proxy=None) == expected
639661

640662

@@ -785,3 +807,32 @@ def test_set_environ_raises_exception():
785807
raise Exception('Expected exception')
786808

787809
assert 'Expected exception' in str(exception.value)
810+
811+
812+
@pytest.mark.parametrize(
813+
'mask, totalbits, maskbits', (
814+
(24, None, 0xFFFFFF00),
815+
(31, None, 0xFFFFFFFE),
816+
(0, None, 0x0),
817+
(4, 4, 0xF),
818+
(24, 128, 0xFFFFFF00000000000000000000000000),
819+
))
820+
def test__get_mask_bits(mask, totalbits, maskbits):
821+
args = {"mask": mask}
822+
if totalbits:
823+
args["totalbits"] = totalbits
824+
assert _get_mask_bits(**args) == maskbits
825+
826+
827+
@pytest.mark.parametrize(
828+
'a, b, expected', (
829+
('1.2.3.4', '1.2.3.4', True),
830+
('1.2.3.4', '2.2.3.4', False),
831+
('1::4', '1.2.3.4', False),
832+
('1::4', '1::4', True),
833+
('1::4', '1:0:0:0:0:0:0:4', True),
834+
('1::4', '1:0:0:0:0:0::4', True),
835+
('1::4', '1:0:0:0:0:0:1:4', False),
836+
))
837+
def test_compare_ips(a, b, expected):
838+
assert compare_ips(a, b) == expected

0 commit comments

Comments
 (0)