Skip to content

Commit e53d73a

Browse files
committed
feat: support async response
- Re-align the CompressionMiddleware code with GzipMiddleware from Django 5.1. This adds support for async responses and implements Heal The Breach (HTB), a mitigation for the BREACH attack against HTTPS.
- Apply some checks to ensure backward compatibility with older Django versions.
1 parent 7b53068 commit e53d73a

File tree

2 files changed

+78
-51
lines changed

2 files changed

+78
-51
lines changed

compression_middleware/middleware.py

Lines changed: 63 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,17 @@
1515
__all__ = ["CompressionMiddleware"]
1616

1717

18+
from django import VERSION as django_version
19+
from django.middleware.gzip import compress_sequence as gzip_compress_stream
20+
from django.middleware.gzip import compress_string as gzip_compress
21+
from django.utils.cache import patch_vary_headers
22+
1823
from .br import brotli_compress, brotli_compress_stream
1924
from .zstd import zstd_compress, zstd_compress_stream
2025

21-
from django.utils.text import (
22-
compress_string as gzip_compress,
23-
compress_sequence as gzip_compress_stream,
24-
)
25-
from django.utils.cache import patch_vary_headers
26-
2726
try:
2827
from django.utils.deprecation import MiddlewareMixin
29-
except ImportError: # pragma: no cover
28+
except ImportError: # pragma: no cover
3029
MiddlewareMixin = object
3130

3231

@@ -52,9 +51,9 @@
5251
# supported encodings in order of preference
5352
# (encoding, bulk_compressor, stream_compressor)
5453
compressors = (
55-
("zstd", zstd_compress, zstd_compress_stream),
56-
("br", brotli_compress, brotli_compress_stream),
57-
("gzip", gzip_compress, gzip_compress_stream),
54+
("zstd", zstd_compress, zstd_compress_stream),
55+
("br", brotli_compress, brotli_compress_stream),
56+
("gzip", gzip_compress, gzip_compress_stream),
5857
)
5958

6059

@@ -76,65 +75,91 @@ def encoding_name(s):
7675
return s.strip()
7776

7877

79-
def compressor(accept_encoding):
78+
def select_compressor(accept_encoding):
8079
# We don't want to process extremely long headers. It might be an attack:
8180
accept_encoding = accept_encoding[:200]
8281
client_encodings = set(encoding_name(e) for e in accept_encoding.split(","))
8382
if "*" in client_encodings:
8483
# Our first choice:
8584
return compressors[0]
86-
for encoding, compress_func, stream_func in compressors:
87-
if encoding in client_encodings:
88-
return (encoding, compress_func, stream_func)
85+
for compressor in compressors:
86+
if compressor[0] in client_encodings:
87+
return compressor
8988
return (None, None, None)
9089

9190

9291
class CompressionMiddleware(MiddlewareMixin):
9392
"""
94-
This middleware compresses content based on the Accept-Encoding header.
95-
96-
The Vary header is set for the sake of downstream caches.
93+
Compress content based on the Accept-Encoding header, and
94+
set the Vary header accordingly.
9795
"""
9896

97+
max_random_bytes = 100
98+
9999
def process_response(self, request, response):
100-
# Test a few things before we even try:
101-
# - content is already encoded
102-
# - really short responses are not worth it
103-
if response.has_header("Content-Encoding") or (
104-
not response.streaming and len(response.content) < MIN_LEN
105-
):
100+
# It's not worth attempting to compress really short responses.
101+
if not response.streaming and len(response.content) < MIN_LEN:
102+
return response
103+
104+
# Avoid compression if we've already got a content-encoding.
105+
if response.has_header("Content-Encoding"):
106106
return response
107107

108108
patch_vary_headers(response, ("Accept-Encoding",))
109+
109110
ae = request.META.get("HTTP_ACCEPT_ENCODING", "")
110-
encoding, compress_func, stream_func = compressor(ae)
111-
if not encoding:
111+
encoding, compress_string, compress_sequence = select_compressor(ae)
112+
if encoding is None:
112113
# No compression in common with client (the client probably didn't
113114
# indicate support for anything).
114115
return response
115116

117+
compress_kwargs = {}
118+
if encoding == "gzip" and django_version >= (4, 2):
119+
compress_kwargs["max_random_bytes"] = self.max_random_bytes
120+
116121
if response.streaming:
122+
if getattr(response, "is_async", False):
123+
124+
# forward args explicitly to capture fixed references in
125+
# case they are set again later.
126+
async def compress_wrapper(streaming_content, **compress_kwargs):
127+
async for chunk in streaming_content:
128+
yield compress_string(
129+
chunk,
130+
**compress_kwargs,
131+
)
132+
133+
response.streaming_content = compress_wrapper(
134+
response.streaming_content,
135+
**compress_kwargs,
136+
)
137+
else:
138+
response.streaming_content = compress_sequence(
139+
response.streaming_content,
140+
**compress_kwargs,
141+
)
142+
117143
# Delete the `Content-Length` header for streaming content, because
118144
# we won't know the compressed size until we stream it.
119-
response.streaming_content = stream_func(response.streaming_content)
120-
del response["Content-Length"]
145+
del response.headers["Content-Length"]
121146
else:
122-
#TODO: protect against excessive response size
123-
compressed_content = compress_func(response.content)
124-
# Return the compressed content only if compression is worth it
125-
if len(compressed_content) >= len(response.content) - MIN_IMPROVEMENT:
147+
# Return the compressed content only if it's actually shorter.
148+
compressed_content = compress_string(
149+
response.content,
150+
**compress_kwargs,
151+
)
152+
if len(response.content) - len(compressed_content) < MIN_IMPROVEMENT:
126153
return response
127-
128154
response.content = compressed_content
129-
response["Content-Length"] = str(len(response.content))
155+
response.headers["Content-Length"] = str(len(response.content))
130156

131157
# If there is a strong ETag, make it weak to fulfill the requirements
132-
# of RFC 7232 section-2.1 while also allowing conditional request
158+
# of RFC 9110 Section 8.8.1 while also allowing conditional request
133159
# matches on ETags.
134-
# Django's ConditionalGetMiddleware relies upon this etag behaviour.
135-
etag = response.get("ETag")
160+
etag = response.headers.get("ETag")
136161
if etag and etag.startswith('"'):
137-
response["ETag"] = "W/" + etag
138-
response["Content-Encoding"] = encoding
162+
response.headers["ETag"] = "W/" + etag
163+
response.headers["Content-Encoding"] = encoding
139164

140165
return response

tests/test_middleware.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import struct
2727
int2byte = struct.Struct(">B").pack
2828

29-
from compression_middleware.middleware import CompressionMiddleware, compressor
29+
from compression_middleware.middleware import CompressionMiddleware, select_compressor
3030
from .utils import UTF8_LOREM_IPSUM_IN_CZECH
3131

3232

@@ -231,18 +231,20 @@ def test_middleware_wont_compress_if_response_is_already_compressed(self):
231231

232232

233233
def test_content_encoding_parsing(self):
234-
self.assertEqual(compressor("")[0], None)
235-
self.assertEqual(compressor("gzip")[0], "gzip")
236-
self.assertEqual(compressor("br")[0], "br")
237-
self.assertEqual(compressor("gzip, br")[0], "br")
238-
self.assertEqual(compressor("br;q=1.0, gzip;q=0.8")[0], "br")
239-
self.assertEqual(compressor("br;q=0, gzip;q=0.8")[0], "gzip")
240-
self.assertEqual(compressor("bla;bla;gzip")[0], None)
241-
self.assertEqual(compressor("text/plain,*/*; charset=utf-8")[0], None) # PR #12
242-
self.assertEqual(compressor("gzip;q==1")[0], "gzip") # questionable
243-
self.assertEqual(compressor("br;gzip")[0], "br") # questionable
244-
# self.assertEqual(compressor("br;q=0, gzip;q=0.8, *;q=0.1")[0], "gzip")
245-
self.assertEqual(compressor("*")[0], "zstd")
234+
self.assertEqual(select_compressor("")[0], None)
235+
self.assertEqual(select_compressor("gzip")[0], "gzip")
236+
self.assertEqual(select_compressor("br")[0], "br")
237+
self.assertEqual(select_compressor("gzip, br")[0], "br")
238+
self.assertEqual(select_compressor("br;q=1.0, gzip;q=0.8")[0], "br")
239+
self.assertEqual(select_compressor("br;q=0, gzip;q=0.8")[0], "gzip")
240+
self.assertEqual(select_compressor("bla;bla;gzip")[0], None)
241+
self.assertEqual(
242+
select_compressor("text/plain,*/*; charset=utf-8")[0], None
243+
) # PR #12
244+
self.assertEqual(select_compressor("gzip;q==1")[0], "gzip") # questionable
245+
self.assertEqual(select_compressor("br;gzip")[0], "br") # questionable
246+
# self.assertEqual(select_compressor("br;q=0, gzip;q=0.8, *;q=0.1")[0], "gzip")
247+
self.assertEqual(select_compressor("*")[0], "zstd")
246248

247249

248250
class StreamingTest(SimpleTestCase):

0 commit comments

Comments
 (0)