66# Don't subclass these. Stuff will break.
77
88import re
9+ from abc import ABC
10+ from dataclasses import dataclass , field
11+ from typing import Any , cast , Dict , List , Tuple , Union
912
10- from . import _headers
1113from ._abnf import request_target
14+ from ._headers import Headers , normalize_and_validate
1215from ._util import bytesify , LocalProtocolError , validate
1316
1417# Everything in __all__ gets re-exported as part of the h11 public API.
1518__all__ = [
19+ "Event" ,
1620 "Request" ,
1721 "InformationalResponse" ,
1822 "Response" ,
2428request_target_re = re .compile (request_target .encode ("ascii" ))
2529
2630
27- class _EventBundle :
28- _fields = []
29- _defaults = {}
30-
31- def __init__ (self , ** kwargs ):
32- _parsed = kwargs .pop ("_parsed" , False )
33- allowed = set (self ._fields )
34- for kwarg in kwargs :
35- if kwarg not in allowed :
36- raise TypeError (
37- "unrecognized kwarg {} for {}" .format (
38- kwarg , self .__class__ .__name__
39- )
40- )
41- required = allowed .difference (self ._defaults )
42- for field in required :
43- if field not in kwargs :
44- raise TypeError (
45- "missing required kwarg {} for {}" .format (
46- field , self .__class__ .__name__
47- )
48- )
49- self .__dict__ .update (self ._defaults )
50- self .__dict__ .update (kwargs )
51-
52- # Special handling for some fields
53-
54- if "headers" in self .__dict__ :
55- self .headers = _headers .normalize_and_validate (
56- self .headers , _parsed = _parsed
57- )
58-
59- if not _parsed :
60- for field in ["method" , "target" , "http_version" , "reason" ]:
61- if field in self .__dict__ :
62- self .__dict__ [field ] = bytesify (self .__dict__ [field ])
63-
64- if "status_code" in self .__dict__ :
65- if not isinstance (self .status_code , int ):
66- raise LocalProtocolError ("status code must be integer" )
67- # Because IntEnum objects are instances of int, but aren't
68- # duck-compatible (sigh), see gh-72.
69- self .status_code = int (self .status_code )
70-
71- self ._validate ()
72-
73- def _validate (self ):
74- pass
75-
76- def __repr__ (self ):
77- name = self .__class__ .__name__
78- kwarg_strs = [
79- "{}={}" .format (field , self .__dict__ [field ]) for field in self ._fields
80- ]
81- kwarg_str = ", " .join (kwarg_strs )
82- return "{}({})" .format (name , kwarg_str )
83-
84- # Useful for tests
85- def __eq__ (self , other ):
86- return self .__class__ == other .__class__ and self .__dict__ == other .__dict__
31+ class Event (ABC ):
32+ """
33+ Base class for h11 events.
34+ """
8735
88- # This is an unhashable type.
89- __hash__ = None
36+ __slots__ = ()
9037
9138
92- class Request (_EventBundle ):
39+ @dataclass (init = False , frozen = True )
40+ class Request (Event ):
9341 """The beginning of an HTTP request.
9442
9543 Fields:
@@ -123,10 +71,38 @@ class Request(_EventBundle):
12371
12472 """
12573
126- _fields = ["method" , "target" , "headers" , "http_version" ]
127- _defaults = {"http_version" : b"1.1" }
74+ __slots__ = ("method" , "headers" , "target" , "http_version" )
75+
76+ method : bytes
77+ headers : Headers
78+ target : bytes
79+ http_version : bytes
80+
81+ def __init__ (
82+ self ,
83+ * ,
84+ method : Union [bytes , str ],
85+ headers : Union [Headers , List [Tuple [bytes , bytes ]], List [Tuple [str , str ]]],
86+ target : Union [bytes , str ],
87+ http_version : Union [bytes , str ] = b"1.1" ,
88+ _parsed : bool = False ,
89+ ) -> None :
90+ super ().__init__ ()
91+ if isinstance (headers , Headers ):
92+ object .__setattr__ (self , "headers" , headers )
93+ else :
94+ object .__setattr__ (
95+ self , "headers" , normalize_and_validate (headers , _parsed = _parsed )
96+ )
97+ if not _parsed :
98+ object .__setattr__ (self , "method" , bytesify (method ))
99+ object .__setattr__ (self , "target" , bytesify (target ))
100+ object .__setattr__ (self , "http_version" , bytesify (http_version ))
101+ else :
102+ object .__setattr__ (self , "method" , method )
103+ object .__setattr__ (self , "target" , target )
104+ object .__setattr__ (self , "http_version" , http_version )
128105
129- def _validate (self ):
130106 # "A server MUST respond with a 400 (Bad Request) status code to any
131107 # HTTP/1.1 request message that lacks a Host header field and to any
132108 # request message that contains more than one Host header field or a
@@ -143,12 +119,58 @@ def _validate(self):
143119
144120 validate (request_target_re , self .target , "Illegal target characters" )
145121
122+ # This is an unhashable type.
123+ __hash__ = None # type: ignore
124+
125+
126+ @dataclass (init = False , frozen = True )
127+ class _ResponseBase (Event ):
128+ __slots__ = ("headers" , "http_version" , "reason" , "status_code" )
129+
130+ headers : Headers
131+ http_version : bytes
132+ reason : bytes
133+ status_code : int
134+
135+ def __init__ (
136+ self ,
137+ * ,
138+ headers : Union [Headers , List [Tuple [bytes , bytes ]], List [Tuple [str , str ]]],
139+ status_code : int ,
140+ http_version : Union [bytes , str ] = b"1.1" ,
141+ reason : Union [bytes , str ] = b"" ,
142+ _parsed : bool = False ,
143+ ) -> None :
144+ super ().__init__ ()
145+ if isinstance (headers , Headers ):
146+ object .__setattr__ (self , "headers" , headers )
147+ else :
148+ object .__setattr__ (
149+ self , "headers" , normalize_and_validate (headers , _parsed = _parsed )
150+ )
151+ if not _parsed :
152+ object .__setattr__ (self , "reason" , bytesify (reason ))
153+ object .__setattr__ (self , "http_version" , bytesify (http_version ))
154+ if not isinstance (status_code , int ):
155+ raise LocalProtocolError ("status code must be integer" )
156+ # Because IntEnum objects are instances of int, but aren't
157+ # duck-compatible (sigh), see gh-72.
158+ object .__setattr__ (self , "status_code" , int (status_code ))
159+ else :
160+ object .__setattr__ (self , "reason" , reason )
161+ object .__setattr__ (self , "http_version" , http_version )
162+ object .__setattr__ (self , "status_code" , status_code )
163+
164+ self .__post_init__ ()
165+
166+ def __post_init__ (self ) -> None :
167+ pass
146168
147- class _ResponseBase (_EventBundle ):
148- _fields = ["status_code" , "headers" , "http_version" , "reason" ]
149- _defaults = {"http_version" : b"1.1" , "reason" : b"" }
169+ # This is an unhashable type.
170+ __hash__ = None # type: ignore
150171
151172
173+ @dataclass (init = False , frozen = True )
152174class InformationalResponse (_ResponseBase ):
153175 """An HTTP informational response.
154176
@@ -179,14 +201,18 @@ class InformationalResponse(_ResponseBase):
179201
180202 """
181203
182- def _validate (self ):
204+ def __post_init__ (self ) -> None :
183205 if not (100 <= self .status_code < 200 ):
184206 raise LocalProtocolError (
185207 "InformationalResponse status_code should be in range "
186208 "[100, 200), not {}" .format (self .status_code )
187209 )
188210
211+ # This is an unhashable type.
212+ __hash__ = None # type: ignore
213+
189214
215+ @dataclass (init = False , frozen = True )
190216class Response (_ResponseBase ):
191217 """The beginning of an HTTP response.
192218
@@ -216,16 +242,20 @@ class Response(_ResponseBase):
216242
217243 """
218244
219- def _validate (self ):
245+ def __post_init__ (self ) -> None :
220246 if not (200 <= self .status_code < 600 ):
221247 raise LocalProtocolError (
222248 "Response status_code should be in range [200, 600), not {}" .format (
223249 self .status_code
224250 )
225251 )
226252
253+ # This is an unhashable type.
254+ __hash__ = None # type: ignore
255+
227256
228- class Data (_EventBundle ):
257+ @dataclass (init = False , frozen = True )
258+ class Data (Event ):
229259 """Part of an HTTP message body.
230260
231261 Fields:
@@ -258,16 +288,30 @@ class Data(_EventBundle):
258288
259289 """
260290
261- _fields = ["data" , "chunk_start" , "chunk_end" ]
262- _defaults = {"chunk_start" : False , "chunk_end" : False }
291+ __slots__ = ("data" , "chunk_start" , "chunk_end" )
292+
293+ data : bytes
294+ chunk_start : bool
295+ chunk_end : bool
296+
297+ def __init__ (
298+ self , data : bytes , chunk_start : bool = False , chunk_end : bool = False
299+ ) -> None :
300+ object .__setattr__ (self , "data" , data )
301+ object .__setattr__ (self , "chunk_start" , chunk_start )
302+ object .__setattr__ (self , "chunk_end" , chunk_end )
303+
304+ # This is an unhashable type.
305+ __hash__ = None # type: ignore
263306
264307
265308# XX FIXME: "A recipient MUST ignore (or consider as an error) any fields that
266309# are forbidden to be sent in a trailer, since processing them as if they were
267310# present in the header section might bypass external security filters."
268311# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#chunked.trailer.part
269312# Unfortunately, the list of forbidden fields is long and vague :-/
270- class EndOfMessage (_EventBundle ):
313+ @dataclass (init = False , frozen = True )
314+ class EndOfMessage (Event ):
271315 """The end of an HTTP message.
272316
273317 Fields:
@@ -284,11 +328,32 @@ class EndOfMessage(_EventBundle):
284328
285329 """
286330
287- _fields = ["headers" ]
288- _defaults = {"headers" : []}
331+ __slots__ = ("headers" ,)
332+
333+ headers : Headers
334+
335+ def __init__ (
336+ self ,
337+ * ,
338+ headers : Union [
339+ Headers , List [Tuple [bytes , bytes ]], List [Tuple [str , str ]], None
340+ ] = None ,
341+ _parsed : bool = False ,
342+ ) -> None :
343+ super ().__init__ ()
344+ if headers is None :
345+ headers = Headers ([])
346+ elif not isinstance (headers , Headers ):
347+ headers = normalize_and_validate (headers , _parsed = _parsed )
348+
349+ object .__setattr__ (self , "headers" , headers )
350+
351+ # This is an unhashable type.
352+ __hash__ = None # type: ignore
289353
290354
291- class ConnectionClosed (_EventBundle ):
355+ @dataclass (frozen = True )
356+ class ConnectionClosed (Event ):
292357 """This event indicates that the sender has closed their outgoing
293358 connection.
294359
0 commit comments