Skip to content

Commit 41c99e0

Browse files
Some bug fixes/improvements.
1 parent 4e8b9d1 commit 41c99e0

File tree

1 file changed

+80
-43
lines changed

1 file changed

+80
-43
lines changed

prompt_toolkit/contrib/regular_languages/compiler.py

Lines changed: 80 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -218,16 +218,16 @@ def _transform_prefix(
218218
:param create_group_func: A callable which takes a `Node` and returns the next
219219
free name for this node.
220220
"""
221-
222221
def contains_variable(node: Node) -> bool:
223222
if isinstance(node, Regex):
224223
return False
225-
if isinstance(node, Variable):
224+
elif isinstance(node, Variable):
226225
return True
227-
if isinstance(node, (Lookahead, Repeat)):
226+
elif isinstance(node, (Lookahead, Repeat)):
228227
return contains_variable(node.childnode)
229-
if isinstance(node, (NodeSequence, AnyNode)):
228+
elif isinstance(node, (NodeSequence, AnyNode)):
230229
return any(contains_variable(child) for child in node.children)
230+
231231
return False
232232

233233
def transform(node: Node) -> Iterable[str]:
@@ -240,40 +240,74 @@ def transform(node: Node) -> Iterable[str]:
240240
# name as well as the city. We do this by yielding two
241241
# different regular expressions, because the engine won't
242242
# follow multiple paths, if multiple are possible.
243+
children_with_variable = []
244+
children_without_variable = []
245+
for c in node.children:
246+
if contains_variable(c):
247+
children_with_variable.append(c)
248+
else:
249+
children_without_variable.append(c)
243250

244-
# If it doesn't contain a variable, then we merge them together.
245-
if contains_variable(node):
246-
for c in node.children:
247-
for r in transform(c):
248-
yield r
249-
else:
250-
yield "|".join(r for c in node.children for r in transform(c))
251+
for c in children_with_variable:
252+
yield from transform(c)
253+
254+
# Merge options without variable together.
255+
if children_without_variable:
256+
yield "|".join(r for c in children_without_variable for r in transform(c))
251257

252258
# For a sequence. We can either have a match for the sequence of
253259
# all the children, or for an exact match of the first X children,
254260
# followed by a partial match of the next children.
261+
# It is important here to yield a separate regex for each child
262+
# that contains a variable. (With the variable at the end.)
255263
elif isinstance(node, NodeSequence):
256-
# If the input is: "[part1] [part2] [part3]", then this gets compiled into:
257-
# (complete1 + (complete2 + (complete3 | partial3) | partial2) | partial1 )
258-
# The partial matches in here can possibly contain many pattern
259-
# (if it contained an `AnyNode`. We take the product of all
260-
# combinations.)
261-
result = []
262-
263-
# Start with complete patterns.
264-
for child in node.children:
265-
result.append(["(?:"])
266-
result.append([cls._transform(child, create_group_func)])
267-
268-
# Add partial patterns.
269-
for child in reversed(node.children):
270-
result.append(["|(?:"])
271-
# If this yields multiple, we should yield all combinations.
272-
result.append(list(transform(child)))
273-
result.append(["))"])
274-
275-
for comb in itertools.product(*result):
276-
yield "".join(comb)
264+
# For all components in the sequence, compute prefix patterns,
265+
# as well as full patterns.
266+
complete = [
267+
cls._transform(c, create_group_func) for c in node.children
268+
]
269+
prefixes = [list(transform(c)) for c in node.children]
270+
variable_nodes = [contains_variable(c) for c in node.children]
271+
272+
# However, if any child contains a variable, we should
273+
# yield an expression up to that point, so that we are sure
274+
# this will be matched.
275+
# (Otherwise, 'set (\s+ (?P<var> ..)) \s' won't complete "var".)
276+
for i in range(len(node.children)):
277+
if variable_nodes[i]:
278+
for c_str in prefixes[i]:
279+
yield "".join(complete[:i]) + c_str
280+
281+
# If there are non-variable nodes:
282+
if not all(variable_nodes):
283+
# If the input is: "[part1] [part2] [part3]", then this
284+
# gets compiled into:
285+
# (complete1 + (complete2 + (complete3 | partial3) | partial2) | partial1 )
286+
# The partial matches in here can possibly contain many
287+
# patterns (if it contained an `AnyNode`). We take the
288+
# product of all combinations.
289+
# For nodes that contain a variable, we skip the "|partial" part here.
290+
result = []
291+
292+
# Start with complete patterns.
293+
for i, child in enumerate(node.children):
294+
result.append(["(?:"])
295+
result.append([complete[i]])
296+
297+
# Add prefix patterns.
298+
for i, child in reversed(list(enumerate(node.children))):
299+
if variable_nodes[i]:
300+
# No need to yield a prefix for this one, we did
301+
# the variable prefixes earlier.
302+
result.append(')')
303+
else:
304+
result.append(["|(?:"])
305+
# If this yields multiple, we should yield all combinations.
306+
result.append(prefixes[i])
307+
result.append(["))"])
308+
309+
for comb in itertools.product(*result):
310+
yield "".join(comb)
277311

278312
elif isinstance(node, Regex):
279313
yield "(?:%s)?" % node.regex
@@ -298,17 +332,20 @@ def transform(node: Node) -> Iterable[str]:
298332
# match, followed by a partial match.
299333
prefix = cls._transform(node.childnode, create_group_func)
300334

301-
for c_str in transform(node.childnode):
302-
if node.max_repeat:
303-
repeat_sign = "{,%i}" % (node.max_repeat - 1)
304-
else:
305-
repeat_sign = "*"
306-
yield "(?:%s)%s%s%s" % (
307-
prefix,
308-
repeat_sign,
309-
("" if node.greedy else "?"),
310-
c_str,
311-
)
335+
if node.max_repeat == 1:
336+
yield from transform(node.childnode)
337+
else:
338+
for c_str in transform(node.childnode):
339+
if node.max_repeat:
340+
repeat_sign = "{,%i}" % (node.max_repeat - 1)
341+
else:
342+
repeat_sign = "*"
343+
yield "(?:%s)%s%s%s" % (
344+
prefix,
345+
repeat_sign,
346+
("" if node.greedy else "?"),
347+
c_str,
348+
)
312349

313350
else:
314351
raise TypeError("Got %r" % node)

0 commit comments

Comments
 (0)