Skip to content

Commit 4b495ad

Browse files
authored
Merge pull request #186 from adrianbn/133_while_test_with_func
Visit functions in while test (#133)
2 parents ce56a20 + 9cb0b56 commit 4b495ad

File tree

5 files changed

+135
-6
lines changed

5 files changed

+135
-6
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def foo():
2+
return True
3+
4+
while foo():
5+
print(x)
6+
x += 1
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def foo():
2+
return 6
3+
4+
while foo() > x:
5+
print(x)
6+
x += 1
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def foo():
2+
return 6
3+
4+
while x < foo():
5+
print(x)
6+
x += 1

pyt/cfg/stmt_visitor.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -555,23 +555,44 @@ def visit_For(self, node):
555555
path=self.filenames[-1]
556556
))
557557

558-
if isinstance(node.iter, ast.Call) and get_call_names_as_string(node.iter.func) in self.function_names:
559-
last_node = self.visit(node.iter)
560-
last_node.connect(for_node)
558+
self.process_loop_funcs(node.iter, for_node)
561559

562560
return self.loop_node_skeleton(for_node, node)
563561

562+
def process_loop_funcs(self, comp_n, loop_node):
563+
"""
564+
If the loop test node contains function calls, it connects the loop node to the nodes of
565+
those function calls.
566+
567+
:param comp_n: The test node of a loop that may contain functions.
568+
:param loop_node: The loop node itself to connect to the new function nodes if any
569+
:return: None
570+
"""
571+
if isinstance(comp_n, ast.Call) and get_call_names_as_string(comp_n.func) in self.function_names:
572+
last_node = self.visit(comp_n)
573+
last_node.connect(loop_node)
574+
564575
def visit_While(self, node):
565576
label_visitor = LabelVisitor()
566-
label_visitor.visit(node.test)
577+
test = node.test # the test condition of the while loop
578+
label_visitor.visit(test)
567579

568-
test = self.append_node(Node(
580+
while_node = self.append_node(Node(
569581
'while ' + label_visitor.result + ':',
570582
node,
571583
path=self.filenames[-1]
572584
))
573585

574-
return self.loop_node_skeleton(test, node)
586+
if isinstance(test, ast.Compare):
587+
# quirk. See https://greentreesnakes.readthedocs.io/en/latest/nodes.html#Compare
588+
self.process_loop_funcs(test.left, while_node)
589+
590+
for comp in test.comparators:
591+
self.process_loop_funcs(comp, while_node)
592+
else: # while foo():
593+
self.process_loop_funcs(test, while_node)
594+
595+
return self.loop_node_skeleton(while_node, node)
575596

576597
def add_blackbox_or_builtin_call(self, node, blackbox): # noqa: C901
577598
"""Processes a blackbox or builtin function when it is called.

tests/cfg/cfg_test.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,96 @@ def test_while_line_numbers(self):
684684
self.assertLineNumber(else_body_2, 6)
685685
self.assertLineNumber(next_stmt, 7)
686686

687+
def test_while_func_comparator(self):
688+
self.cfg_create_from_file('examples/example_inputs/while_func_comparator.py')
689+
690+
self.assert_length(self.cfg.nodes, expected_length=9)
691+
692+
entry = 0
693+
test = 1
694+
entry_foo = 2
695+
ret_foo = 3
696+
exit_foo = 4
697+
call_foo = 5
698+
_print = 6
699+
body_1 = 7
700+
_exit = 8
701+
702+
self.assertEqual(self.cfg.nodes[test].label, 'while foo():')
703+
704+
self.assertInCfg([
705+
(test, entry),
706+
(entry_foo, test),
707+
(_print, test),
708+
(_exit, test),
709+
(body_1, _print),
710+
(test, body_1),
711+
(test, call_foo),
712+
(ret_foo, entry_foo),
713+
(exit_foo, ret_foo),
714+
(call_foo, exit_foo)
715+
])
716+
717+
def test_while_func_comparator_rhs(self):
718+
self.cfg_create_from_file('examples/example_inputs/while_func_comparator_rhs.py')
719+
720+
self.assert_length(self.cfg.nodes, expected_length=9)
721+
722+
entry = 0
723+
test = 1
724+
entry_foo = 2
725+
ret_foo = 3
726+
exit_foo = 4
727+
call_foo = 5
728+
_print = 6
729+
body_1 = 7
730+
_exit = 8
731+
732+
self.assertEqual(self.cfg.nodes[test].label, 'while x < foo():')
733+
734+
self.assertInCfg([
735+
(test, entry),
736+
(entry_foo, test),
737+
(_print, test),
738+
(_exit, test),
739+
(body_1, _print),
740+
(test, body_1),
741+
(test, call_foo),
742+
(ret_foo, entry_foo),
743+
(exit_foo, ret_foo),
744+
(call_foo, exit_foo)
745+
])
746+
747+
def test_while_func_comparator_lhs(self):
748+
self.cfg_create_from_file('examples/example_inputs/while_func_comparator_lhs.py')
749+
750+
self.assert_length(self.cfg.nodes, expected_length=9)
751+
752+
entry = 0
753+
test = 1
754+
entry_foo = 2
755+
ret_foo = 3
756+
exit_foo = 4
757+
call_foo = 5
758+
_print = 6
759+
body_1 = 7
760+
_exit = 8
761+
762+
self.assertEqual(self.cfg.nodes[test].label, 'while foo() > x:')
763+
764+
self.assertInCfg([
765+
(test, entry),
766+
(entry_foo, test),
767+
(_print, test),
768+
(_exit, test),
769+
(body_1, _print),
770+
(test, body_1),
771+
(test, call_foo),
772+
(ret_foo, entry_foo),
773+
(exit_foo, ret_foo),
774+
(call_foo, exit_foo)
775+
])
776+
687777

688778
class CFGAssignmentMultiTest(CFGBaseTestCase):
689779
def test_assignment_multi_target(self):

0 commit comments

Comments
 (0)