Skip to content

Commit ca78b68

Browse files
authored
Merge pull request #163 from bcaller/curry
Fix VarsVisitor RuntimeError on code like f(g(a)(b)(c))
2 parents b1da929 + f258639 commit ca78b68

File tree

3 files changed

+43
-7
lines changed

3 files changed

+43
-7
lines changed

pyt/cfg/stmt_visitor.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,8 @@ def add_blackbox_or_builtin_call(self, node, blackbox):
619619
rhs_vars = list()
620620
last_return_value_of_nested_call = None
621621

622-
for arg in itertools.chain(node.args, node.keywords):
622+
for arg_node in itertools.chain(node.args, node.keywords):
623+
arg = arg_node.value if isinstance(arg_node, ast.keyword) else arg_node
623624
if isinstance(arg, ast.Call):
624625
return_value_of_nested_call = self.visit(arg)
625626

@@ -638,15 +639,18 @@ def add_blackbox_or_builtin_call(self, node, blackbox):
638639
call_node.inner_most_call = return_value_of_nested_call
639640
last_return_value_of_nested_call = return_value_of_nested_call
640641

641-
visual_args.append(return_value_of_nested_call.left_hand_side)
642+
if isinstance(arg_node, ast.keyword) and arg_node.arg is not None:
643+
visual_args.append(arg_node.arg + '=' + return_value_of_nested_call.left_hand_side)
644+
else:
645+
visual_args.append(return_value_of_nested_call.left_hand_side)
642646
rhs_vars.append(return_value_of_nested_call.left_hand_side)
643647
else:
644648
label = LabelVisitor()
645-
label.visit(arg)
649+
label.visit(arg_node)
646650
visual_args.append(label.result)
647651

648652
vv = VarsVisitor()
649-
vv.visit(arg)
653+
vv.visit(arg_node)
650654
rhs_vars.extend(vv.result)
651655
if last_return_value_of_nested_call:
652656
# connect other_inner to outer in e.g.

pyt/helper_visitors/vars_visitor.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def visit_Call(self, node):
8484
# This will not visit Flask in Flask(__name__) but it will visit request in `request.args.get()
8585
if not isinstance(node.func, ast.Name):
8686
self.visit(node.func)
87-
for arg in itertools.chain(node.args, node.keywords):
87+
for arg_node in itertools.chain(node.args, node.keywords):
88+
arg = arg_node.value if isinstance(arg_node, ast.keyword) else arg_node
8889
if isinstance(arg, ast.Call):
8990
if isinstance(arg.func, ast.Name):
9091
# We can't just visit because we need to add 'ret_'
@@ -95,12 +96,32 @@ def visit_Call(self, node):
9596
# func.value.id is html
9697
# We want replace
9798
self.result.append('ret_' + arg.func.attr)
99+
elif isinstance(arg.func, ast.Call):
100+
self.visit_curried_call_inside_call_args(arg)
98101
else:
99-
# Deal with it when we have code that triggers it.
100-
raise
102+
raise Exception('Cannot visit vars of ' + ast.dump(arg))
101103
else:
102104
self.visit(arg)
103105

106+
def visit_curried_call_inside_call_args(self, inner_call):
107+
# Curried functions aren't supported really, but we now at least have a defined behaviour.
108+
# In f(g(a)(b)(c)), inner_call is the Call node with argument c
109+
# Try to get the name of curried function g
110+
curried_func = inner_call.func.func
111+
while isinstance(curried_func, ast.Call):
112+
curried_func = curried_func.func
113+
if isinstance(curried_func, ast.Name):
114+
self.result.append('ret_' + curried_func.id)
115+
elif isinstance(curried_func, ast.Attribute):
116+
self.result.append('ret_' + curried_func.attr)
117+
118+
# Visit all arguments except a (ignore the curried function g)
119+
not_curried = inner_call
120+
while not_curried.func is not curried_func:
121+
for arg in itertools.chain(not_curried.args, not_curried.keywords):
122+
self.visit(arg.value if isinstance(arg, ast.keyword) else arg)
123+
not_curried = not_curried.func
124+
104125
def visit_Attribute(self, node):
105126
if not isinstance(node.value, ast.Name):
106127
self.visit(node.value)

tests/helper_visitors/vars_visitor_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,20 @@ def test_call5(self):
4545
self.assertEqual(vars.result, ['resp', 'ret_replace'])
4646

4747
def test_call6(self):
48+
vars = self.perform_vars_on_expression("resp = f(kw=g(a, b))")
49+
self.assertEqual(vars.result, ['resp', 'ret_g'])
50+
51+
def test_call7(self):
4852
vars = self.perform_vars_on_expression("resp = make_response(html.replace.bar('{{ param }}', param))")
4953
self.assertEqual(vars.result, ['resp', 'ret_bar'])
5054

55+
def test_curried_function(self):
56+
# Curried functions aren't supported really, but we now at least have a defined behaviour.
57+
vars = self.perform_vars_on_expression('f(g.h(a)(b))')
58+
self.assertCountEqual(vars.result, ['ret_h', 'b'])
59+
vars = self.perform_vars_on_expression('f(g(a)(b)(c)(d, e=f))')
60+
self.assertCountEqual(vars.result, ['ret_g', 'b', 'c', 'd', 'f'])
61+
5162
def test_keyword_vararg(self):
5263
vars = self.perform_vars_on_expression('print(arg = x)')
5364
self.assertEqual(vars.result, ['x'])

0 commit comments

Comments
 (0)