1818
1919from hypothesis .configuration import mkdir_p , storage_directory
2020from hypothesis .errors import InvalidArgument
21+ from hypothesis .internal .intervalsets import IntervalSet
2122
2223intervals = Tuple [Tuple [int , int ], ...]
2324cache_type = Dict [Tuple [Tuple [str , ...], int , int , intervals ], intervals ]
@@ -146,126 +147,6 @@ def as_general_categories(cats, name="cats"):
146147 return tuple (c for c in cs if c in out )
147148
148149
149- def _union_intervals (x , y ):
150- """Merge two sequences of intervals into a single tuple of intervals.
151-
152- Any integer bounded by `x` or `y` is also bounded by the result.
153-
154- >>> _union_intervals([(3, 10)], [(1, 2), (5, 17)])
155- ((1, 17),)
156- """
157- if not x :
158- return tuple ((u , v ) for u , v in y )
159- if not y :
160- return tuple ((u , v ) for u , v in x )
161- intervals = sorted (x + y , reverse = True )
162- result = [intervals .pop ()]
163- while intervals :
164- # 1. intervals is in descending order
165- # 2. pop() takes from the RHS.
166- # 3. (a, b) was popped 1st, then (u, v) was popped 2nd
167- # 4. Therefore: a <= u
168- # 5. We assume that u <= v and a <= b
169- # 6. So we need to handle 2 cases of overlap, and one disjoint case
170- # | u--v | u----v | u--v |
171- # | a----b | a--b | a--b |
172- u , v = intervals .pop ()
173- a , b = result [- 1 ]
174- if u <= b + 1 :
175- # Overlap cases
176- result [- 1 ] = (a , max (v , b ))
177- else :
178- # Disjoint case
179- result .append ((u , v ))
180- return tuple (result )
181-
182-
183- def _subtract_intervals (x , y ):
184- """Set difference for lists of intervals. That is, returns a list of
185- intervals that bounds all values bounded by x that are not also bounded by
186- y. x and y are expected to be in sorted order.
187-
188- For example _subtract_intervals([(1, 10)], [(2, 3), (9, 15)]) would
189- return [(1, 1), (4, 8)], removing the values 2, 3, 9 and 10 from the
190- interval.
191- """
192- if not y :
193- return tuple (x )
194- x = list (map (list , x ))
195- i = 0
196- j = 0
197- result = []
198- while i < len (x ) and j < len (y ):
199- # Iterate in parallel over x and y. j stays pointing at the smallest
200- # interval in the left hand side that could still overlap with some
201- # element of x at index >= i.
202- # Similarly, i is not incremented until we know that it does not
203- # overlap with any element of y at index >= j.
204-
205- xl , xr = x [i ]
206- assert xl <= xr
207- yl , yr = y [j ]
208- assert yl <= yr
209-
210- if yr < xl :
211- # The interval at y[j] is strictly to the left of the interval at
212- # x[i], so will not overlap with it or any later interval of x.
213- j += 1
214- elif yl > xr :
215- # The interval at y[j] is strictly to the right of the interval at
216- # x[i], so all of x[i] goes into the result as no further intervals
217- # in y will intersect it.
218- result .append (x [i ])
219- i += 1
220- elif yl <= xl :
221- if yr >= xr :
222- # x[i] is contained entirely in y[j], so we just skip over it
223- # without adding it to the result.
224- i += 1
225- else :
226- # The beginning of x[i] is contained in y[j], so we update the
227- # left endpoint of x[i] to remove this, and increment j as we
228- # now have moved past it. Note that this is not added to the
229- # result as is, as more intervals from y may intersect it so it
230- # may need updating further.
231- x [i ][0 ] = yr + 1
232- j += 1
233- else :
234- # yl > xl, so the left hand part of x[i] is not contained in y[j],
235- # so there are some values we should add to the result.
236- result .append ((xl , yl - 1 ))
237-
238- if yr + 1 <= xr :
239- # If y[j] finishes before x[i] does, there may be some values
240- # in x[i] left that should go in the result (or they may be
241- # removed by a later interval in y), so we update x[i] to
242- # reflect that and increment j because it no longer overlaps
243- # with any remaining element of x.
244- x [i ][0 ] = yr + 1
245- j += 1
246- else :
247- # Every element of x[i] other than the initial part we have
248- # already added is contained in y[j], so we move to the next
249- # interval.
250- i += 1
251- # Any remaining intervals in x do not overlap with any of y, as if they did
252- # we would not have incremented j to the end, so can be added to the result
253- # as they are.
254- result .extend (x [i :])
255- return tuple (map (tuple , result ))
256-
257-
258- def _intervals (s ):
259- """Return a tuple of intervals, covering the codepoints of characters in
260- `s`.
261-
262- >>> _intervals('abcdef0123456789')
263- ((48, 57), (97, 102))
264- """
265- intervals = tuple ((ord (c ), ord (c )) for c in sorted (s ))
266- return _union_intervals (intervals , intervals )
267-
268-
269150category_index_cache = {(): ()}
270151
271152
@@ -306,11 +187,14 @@ def _query_for_key(key):
306187 pass
307188 assert key
308189 if set (key ) == set (categories ()):
309- result = ( (0 , sys .maxunicode ), )
190+ result = IntervalSet ([ (0 , sys .maxunicode )] )
310191 else :
311- result = _union_intervals (_query_for_key (key [:- 1 ]), charmap ()[key [- 1 ]])
312- category_index_cache [key ] = result
313- return result
192+ result = IntervalSet (_query_for_key (key [:- 1 ])).union (
193+ IntervalSet (charmap ()[key [- 1 ]])
194+ )
195+ assert isinstance (result , IntervalSet )
196+ category_index_cache [key ] = result .intervals
197+ return result .intervals
314198
315199
316200limited_category_index_cache : cache_type = {}
@@ -344,14 +228,14 @@ def query(
344228 if max_codepoint is None :
345229 max_codepoint = sys .maxunicode
346230 catkey = _category_key (exclude_categories , include_categories )
347- character_intervals = _intervals (include_characters or "" )
348- exclude_intervals = _intervals (exclude_characters or "" )
231+ character_intervals = IntervalSet . from_string (include_characters or "" )
232+ exclude_intervals = IntervalSet . from_string (exclude_characters or "" )
349233 qkey = (
350234 catkey ,
351235 min_codepoint ,
352236 max_codepoint ,
353- character_intervals ,
354- exclude_intervals ,
237+ character_intervals . intervals ,
238+ exclude_intervals . intervals ,
355239 )
356240 try :
357241 return limited_category_index_cache [qkey ]
@@ -362,8 +246,6 @@ def query(
362246 for u , v in base :
363247 if v >= min_codepoint and u <= max_codepoint :
364248 result .append ((max (u , min_codepoint ), min (v , max_codepoint )))
365- result = tuple (result )
366- result = _union_intervals (result , character_intervals )
367- result = _subtract_intervals (result , exclude_intervals )
249+ result = (IntervalSet (result ) | character_intervals ) - exclude_intervals
368250 limited_category_index_cache [qkey ] = result
369251 return result
0 commit comments