Skip to content

Commit e704c21

Browse files
committed
133: Visit functions in while test
1 parent ce56a20 commit e704c21

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def foo():
2+
return 6
3+
4+
while x < foo():
5+
print(x)
6+
x += 1

pyt/cfg/stmt_visitor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,13 +565,18 @@ def visit_While(self, node):
565565
label_visitor = LabelVisitor()
566566
label_visitor.visit(node.test)
567567

568-
test = self.append_node(Node(
568+
while_node = self.append_node(Node(
569569
'while ' + label_visitor.result + ':',
570570
node,
571571
path=self.filenames[-1]
572572
))
573573

574-
return self.loop_node_skeleton(test, node)
574+
for comp in node.test.comparators:
575+
if isinstance(comp, ast.Call) and get_call_names_as_string(comp.func) in self.function_names:
576+
last_node = self.visit(comp)
577+
last_node.connect(while_node)
578+
579+
return self.loop_node_skeleton(while_node, node)
575580

576581
def add_blackbox_or_builtin_call(self, node, blackbox): # noqa: C901
577582
"""Processes a blackbox or builtin function when it is called.

tests/cfg/cfg_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,38 @@ def test_while_line_numbers(self):
684684
self.assertLineNumber(else_body_2, 6)
685685
self.assertLineNumber(next_stmt, 7)
686686

687+
def test_while_func_iterator(self):
688+
self.cfg_create_from_file('examples/example_inputs/while_func_comparator.py')
689+
690+
self.assert_length(self.cfg.nodes, expected_length=9)
691+
692+
entry = 0
693+
test = 1
694+
entry_foo = 2
695+
ret_foo = 3
696+
exit_foo = 4
697+
call_foo = 5
698+
_print = 6
699+
body_1 = 7
700+
_exit = 8
701+
702+
self.assertEqual(self.cfg.nodes[test].label, 'while x < foo():')
703+
704+
self.assertInCfg([
705+
(test, entry),
706+
(entry_foo, test),
707+
(_print, test),
708+
(_exit, test),
709+
(body_1, _print),
710+
711+
(test, body_1),
712+
(test, call_foo),
713+
(ret_foo, entry_foo),
714+
(exit_foo, ret_foo),
715+
(call_foo, exit_foo),
716+
717+
])
718+
687719

688720
class CFGAssignmentMultiTest(CFGBaseTestCase):
689721
def test_assignment_multi_target(self):

0 commit comments

Comments
 (0)