Skip to content

Commit effd872

Browse files
committed
133: Support for LHS functions and no comparison while tests
1 parent e704c21 commit effd872

File tree

5 files changed

+76
-9
lines changed

5 files changed

+76
-9
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
def foo():
2-
return 6
2+
return True
33

4-
while x < foo():
4+
while foo():
55
print(x)
66
x += 1
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def foo():
2+
return 6
3+
4+
while foo() > x:
5+
print(x)
6+
x += 1
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def foo():
2+
return 6
3+
4+
while x < foo():
5+
print(x)
6+
x += 1

pyt/cfg/stmt_visitor.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -563,19 +563,29 @@ def visit_For(self, node):
563563

564564
def visit_While(self, node):
565565
label_visitor = LabelVisitor()
566-
label_visitor.visit(node.test)
566+
test = node.test # the test condition of the while loop
567+
label_visitor.visit(test)
567568

568569
while_node = self.append_node(Node(
569570
'while ' + label_visitor.result + ':',
570571
node,
571572
path=self.filenames[-1]
572573
))
573574

574-
for comp in node.test.comparators:
575-
if isinstance(comp, ast.Call) and get_call_names_as_string(comp.func) in self.function_names:
576-
last_node = self.visit(comp)
575+
def process_comparator(comp_n):
576+
if isinstance(comp_n, ast.Call) and get_call_names_as_string(comp_n.func) in self.function_names:
577+
last_node = self.visit(comp_n)
577578
last_node.connect(while_node)
578579

580+
if isinstance(test, ast.Compare):
581+
comparators = test.comparators
582+
comparators.append(test.left) # quirk. See https://greentreesnakes.readthedocs.io/en/latest/nodes.html#Compare
583+
584+
for comp in comparators:
585+
process_comparator(comp)
586+
else: # while foo():
587+
process_comparator(test)
588+
579589
return self.loop_node_skeleton(while_node, node)
580590

581591
def add_blackbox_or_builtin_call(self, node, blackbox): # noqa: C901

tests/cfg/cfg_test.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def test_while_line_numbers(self):
684684
self.assertLineNumber(else_body_2, 6)
685685
self.assertLineNumber(next_stmt, 7)
686686

687-
def test_while_func_iterator(self):
687+
def test_while_func_comparator(self):
688688
self.cfg_create_from_file('examples/example_inputs/while_func_comparator.py')
689689

690690
self.assert_length(self.cfg.nodes, expected_length=9)
@@ -699,6 +699,23 @@ def test_while_func_iterator(self):
699699
body_1 = 7
700700
_exit = 8
701701

702+
self.assertEqual(self.cfg.nodes[test].label, 'while foo():')
703+
704+
def test_while_func_comparator_rhs(self):
705+
self.cfg_create_from_file('examples/example_inputs/while_func_comparator_rhs.py')
706+
707+
self.assert_length(self.cfg.nodes, expected_length=9)
708+
709+
entry = 0
710+
test = 1
711+
entry_foo = 2
712+
ret_foo = 3
713+
exit_foo = 4
714+
call_foo = 5
715+
_print = 6
716+
body_1 = 7
717+
_exit = 8
718+
702719
self.assertEqual(self.cfg.nodes[test].label, 'while x < foo():')
703720

704721
self.assertInCfg([
@@ -707,13 +724,41 @@ def test_while_func_iterator(self):
707724
(_print, test),
708725
(_exit, test),
709726
(body_1, _print),
710-
711727
(test, body_1),
712728
(test, call_foo),
713729
(ret_foo, entry_foo),
714730
(exit_foo, ret_foo),
715-
(call_foo, exit_foo),
731+
(call_foo, exit_foo)
732+
])
733+
734+
def test_while_func_comparator_lhs(self):
735+
self.cfg_create_from_file('examples/example_inputs/while_func_comparator_lhs.py')
716736

737+
self.assert_length(self.cfg.nodes, expected_length=9)
738+
739+
entry = 0
740+
test = 1
741+
entry_foo = 2
742+
ret_foo = 3
743+
exit_foo = 4
744+
call_foo = 5
745+
_print = 6
746+
body_1 = 7
747+
_exit = 8
748+
749+
self.assertEqual(self.cfg.nodes[test].label, 'while foo() > x:')
750+
751+
self.assertInCfg([
752+
(test, entry),
753+
(entry_foo, test),
754+
(_print, test),
755+
(_exit, test),
756+
(body_1, _print),
757+
(test, body_1),
758+
(test, call_foo),
759+
(ret_foo, entry_foo),
760+
(exit_foo, ret_foo),
761+
(call_foo, exit_foo)
717762
])
718763

719764

0 commit comments

Comments
 (0)