Skip to content

Commit 012374c

Browse files
Merge pull request #132 from lambda-feedback/tr102-rtol-slow
Tr102 rtol slow
2 parents 4bfb65c + c73908f commit 012374c

File tree

1 file changed

+37
-36
lines changed

1 file changed

+37
-36
lines changed

app/symbolic_comparison_evaluation.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -270,41 +270,42 @@ def symbolic_comparison(response, answer, params, eval_response) -> dict:
270270
error_below_atol = None
271271
error_below_rtol = None
272272

273-
if params.get("numerical", False) or params.get("rtol", False) or params.get("atol", False):
274-
# REMARK: 'pi' should be a reserved symbol but it is sometimes not treated as one, possibly because of input symbols.
275-
# The two lines below this comments fixes the issue but a more robust solution should be found for cases where there
276-
# are other reserved symbols.
277-
def replace_pi(expr):
278-
pi_symbol = pi
279-
for s in expr.free_symbols:
280-
if str(s) == 'pi':
281-
pi_symbol = s
282-
return expr.subs(pi_symbol, float(pi))
283-
ans = replace_pi(ans)
284-
res = replace_pi(res)
285-
if "atol" in params.keys():
286-
try:
287-
absolute_error = abs(float(ans-res))
288-
error_below_atol = bool(absolute_error < float(params["atol"]))
289-
except TypeError:
290-
error_below_atol = None
291-
else:
292-
error_below_atol = True
293-
if "rtol" in params.keys():
294-
try:
295-
relative_error = abs(float((ans-res)/ans)) # TODO: capture error here and see if you can rewrite this in a faster way
296-
error_below_rtol = bool(relative_error < float(params["rtol"]))
297-
except TypeError:
298-
error_below_rtol = None
299-
else:
300-
error_below_rtol = True
301-
if error_below_atol is None or error_below_rtol is None:
302-
eval_response.is_correct = False
303-
tag = "NOT_NUMERICAL"
304-
eval_response.add_feedback((tag, symbolic_comparison_internal_messages[tag]))
305-
elif error_below_atol is True and error_below_rtol is True:
306-
eval_response.is_correct = True
307-
tag = "WITHIN_TOLERANCE"
308-
eval_response.add_feedback((tag, symbolic_comparison_internal_messages[tag]))
273+
if eval_response.is_correct is False:
274+
if params.get("numerical", False) or params.get("rtol", False) or params.get("atol", False):
275+
# REMARK: 'pi' should be a reserved symbol but it is sometimes not treated as one, possibly because of input symbols.
276+
# The two lines below this comments fixes the issue but a more robust solution should be found for cases where there
277+
# are other reserved symbols.
278+
def replace_pi(expr):
279+
pi_symbol = pi
280+
for s in expr.free_symbols:
281+
if str(s) == 'pi':
282+
pi_symbol = s
283+
return expr.subs(pi_symbol, float(pi))
284+
ans = replace_pi(ans)
285+
res = replace_pi(res)
286+
if "atol" in params.keys():
287+
try:
288+
absolute_error = abs(float(ans-res))
289+
error_below_atol = bool(absolute_error < float(params["atol"]))
290+
except TypeError:
291+
error_below_atol = None
292+
else:
293+
error_below_atol = True
294+
if "rtol" in params.keys():
295+
try:
296+
relative_error = abs(float((ans-res)/ans)) # TODO: capture error here and see if you can rewrite this in a faster way
297+
error_below_rtol = bool(relative_error < float(params["rtol"]))
298+
except TypeError:
299+
error_below_rtol = None
300+
else:
301+
error_below_rtol = True
302+
if error_below_atol is None or error_below_rtol is None:
303+
eval_response.is_correct = False
304+
tag = "NOT_NUMERICAL"
305+
eval_response.add_feedback((tag, symbolic_comparison_internal_messages[tag]))
306+
elif error_below_atol is True and error_below_rtol is True:
307+
eval_response.is_correct = True
308+
tag = "WITHIN_TOLERANCE"
309+
eval_response.add_feedback((tag, symbolic_comparison_internal_messages[tag]))
309310

310311
return eval_response

0 commit comments

Comments
 (0)