Skip to content

Commit 00c7868

Browse files
Improved the performance of the regular languages compiler: generate fewer and better regexes.
1 parent cdaa3e1 commit 00c7868

File tree

1 file changed

+110
-26
lines changed

1 file changed

+110
-26
lines changed

prompt_toolkit/contrib/regular_languages/compiler.py

Lines changed: 110 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
m.variables().get('operator2') # Returns "add"
3939
4040
"""
41+
import itertools
4142
import re
4243
from typing import Callable, Dict, Iterable, Iterator, List
4344
from typing import Match as RegexMatch
@@ -175,10 +176,20 @@ def transform(node: Node) -> str:
175176

176177
# `Repeat`.
177178
elif isinstance(node, Repeat):
178-
return "(?:%s){%i,%s}%s" % (
179+
if node.max_repeat is None:
180+
if node.min_repeat == 0:
181+
repeat_sign = "*"
182+
elif node.min_repeat == 1:
183+
repeat_sign = "+"
184+
else:
185+
repeat_sign = "{%i,%s}" % (
186+
node.min_repeat,
187+
("" if node.max_repeat is None else str(node.max_repeat)),
188+
)
189+
190+
return "(?:%s)%s%s" % (
179191
transform(node.childnode),
180-
node.min_repeat,
181-
("" if node.max_repeat is None else str(node.max_repeat)),
192+
repeat_sign,
182193
("" if node.greedy else "?"),
183194
)
184195
else:
@@ -207,26 +218,96 @@ def _transform_prefix(
207218
:param create_group_func: A callable which takes a `Node` and returns the next
208219
free name for this node.
209220
"""
221+
def contains_variable(node: Node) -> bool:
222+
if isinstance(node, Regex):
223+
return False
224+
elif isinstance(node, Variable):
225+
return True
226+
elif isinstance(node, (Lookahead, Repeat)):
227+
return contains_variable(node.childnode)
228+
elif isinstance(node, (NodeSequence, AnyNode)):
229+
return any(contains_variable(child) for child in node.children)
230+
231+
return False
210232

211233
def transform(node: Node) -> Iterable[str]:
212234
# Generate regexes for all permutations of this OR. Each node
213235
# should be in front once.
214236
if isinstance(node, AnyNode):
237+
# If we have a definition like:
238+
# (?P<name> .*) | (?P<city> .*)
239+
# Then we want to be able to generate completions for both the
240+
# name as well as the city. We do this by yielding two
241+
# different regular expressions, because the engine won't
242+
# follow multiple paths, if multiple are possible.
243+
children_with_variable = []
244+
children_without_variable = []
215245
for c in node.children:
216-
for r in transform(c):
217-
yield "(?:%s)?" % r
246+
if contains_variable(c):
247+
children_with_variable.append(c)
248+
else:
249+
children_without_variable.append(c)
250+
251+
for c in children_with_variable:
252+
yield from transform(c)
218253

219-
# For a sequence. We can either have a match for the sequence
220-
# of all the children, or for an exact match of the first X
221-
# children, followed by a partial match of the next children.
254+
# Merge options without variable together.
255+
if children_without_variable:
256+
yield "|".join(r for c in children_without_variable for r in transform(c))
257+
258+
# For a sequence. We can either have a match for the sequence of
259+
# all the children, or for an exact match of the first X children,
260+
# followed by a partial match of the next children.
261+
# It is important here to yield a separate regex for each child
262+
# that contains a variable. (With the variable at the end.)
222263
elif isinstance(node, NodeSequence):
264+
# For all components in the sequence, compute prefix patterns,
265+
# as well as full patterns.
266+
complete = [
267+
cls._transform(c, create_group_func) for c in node.children
268+
]
269+
prefixes = [list(transform(c)) for c in node.children]
270+
variable_nodes = [contains_variable(c) for c in node.children]
271+
272+
# However, if any child contains a variable, we should
273+
# yield an expression up to that point, so that we are sure
274+
# this will be matched.
275+
# (Otherwise, 'set (\s+ (?P<var> ..)) \s' won't complete "var".)
223276
for i in range(len(node.children)):
224-
a = [
225-
cls._transform(c, create_group_func) for c in node.children[:i]
226-
]
227-
228-
for c_str in transform(node.children[i]):
229-
yield "(?:%s)" % ("".join(a) + c_str)
277+
if variable_nodes[i]:
278+
for c_str in prefixes[i]:
279+
yield "".join(complete[:i]) + c_str
280+
281+
# If there are non-variable nodes:
282+
if not all(variable_nodes):
283+
# If the input is: "[part1] [part2] [part3]", then this
284+
# gets compiled into:
285+
# (complete1 + (complete2 + (complete3 | partial3) | partial2) | partial1 )
286+
# The partial matches in here can possibly contain many
287+
# patterns (if it contained an `AnyNode`). We take the
288+
# product of all combinations.
289+
# For nodes that contain a variable, we skip the "|partial" part here.
290+
result = []
291+
292+
# Start with complete patterns.
293+
for i, child in enumerate(node.children):
294+
result.append(["(?:"])
295+
result.append([complete[i]])
296+
297+
# Add prefix patterns.
298+
for i, child in reversed(list(enumerate(node.children))):
299+
if variable_nodes[i]:
300+
# No need to yield a prefix for this one, we did
301+
# the variable prefixes earlier.
302+
result.append(')')
303+
else:
304+
result.append(["|(?:"])
305+
# If this yields multiple, we should yield all combinations.
306+
result.append(prefixes[i])
307+
result.append(["))"])
308+
309+
for comb in itertools.product(*result):
310+
yield "".join(comb)
230311

231312
elif isinstance(node, Regex):
232313
yield "(?:%s)?" % node.regex
@@ -251,17 +332,20 @@ def transform(node: Node) -> Iterable[str]:
251332
# match, followed by a partial match.
252333
prefix = cls._transform(node.childnode, create_group_func)
253334

254-
for c_str in transform(node.childnode):
255-
if node.max_repeat:
256-
repeat_sign = "{,%i}" % (node.max_repeat - 1)
257-
else:
258-
repeat_sign = "*"
259-
yield "(?:%s)%s%s(?:%s)?" % (
260-
prefix,
261-
repeat_sign,
262-
("" if node.greedy else "?"),
263-
c_str,
264-
)
335+
if node.max_repeat == 1:
336+
yield from transform(node.childnode)
337+
else:
338+
for c_str in transform(node.childnode):
339+
if node.max_repeat:
340+
repeat_sign = "{,%i}" % (node.max_repeat - 1)
341+
else:
342+
repeat_sign = "*"
343+
yield "(?:%s)%s%s%s" % (
344+
prefix,
345+
repeat_sign,
346+
("" if node.greedy else "?"),
347+
c_str,
348+
)
265349

266350
else:
267351
raise TypeError("Got %r" % node)
@@ -343,7 +427,7 @@ def get_tuples() -> Iterable[Tuple[str, Tuple[int, int]]]:
343427

344428
def _nodes_to_values(self) -> List[Tuple[str, str, Tuple[int, int]]]:
345429
"""
346-
Returns list of list of (Node, string_value) tuples.
430+
Returns list of (Node, string_value) tuples.
347431
"""
348432

349433
def is_none(sl: Tuple[int, int]) -> bool:

0 commit comments

Comments
 (0)