@@ -270,41 +270,42 @@ def symbolic_comparison(response, answer, params, eval_response) -> dict:
270270 error_below_atol = None
271271 error_below_rtol = None
272272
273- if params .get ("numerical" , False ) or params .get ("rtol" , False ) or params .get ("atol" , False ):
274- # REMARK: 'pi' should be a reserved symbol but it is sometimes not treated as one, possibly because of input symbols.
275- # The two lines below this comments fixes the issue but a more robust solution should be found for cases where there
276- # are other reserved symbols.
277- def replace_pi (expr ):
278- pi_symbol = pi
279- for s in expr .free_symbols :
280- if str (s ) == 'pi' :
281- pi_symbol = s
282- return expr .subs (pi_symbol , float (pi ))
283- ans = replace_pi (ans )
284- res = replace_pi (res )
285- if "atol" in params .keys ():
286- try :
287- absolute_error = abs (float (ans - res ))
288- error_below_atol = bool (absolute_error < float (params ["atol" ]))
289- except TypeError :
290- error_below_atol = None
291- else :
292- error_below_atol = True
293- if "rtol" in params .keys ():
294- try :
295- relative_error = abs (float ((ans - res )/ ans )) # TODO: capture error here and see if you can rewrite this in a faster way
296- error_below_rtol = bool (relative_error < float (params ["rtol" ]))
297- except TypeError :
298- error_below_rtol = None
299- else :
300- error_below_rtol = True
301- if error_below_atol is None or error_below_rtol is None :
302- eval_response .is_correct = False
303- tag = "NOT_NUMERICAL"
304- eval_response .add_feedback ((tag , symbolic_comparison_internal_messages [tag ]))
305- elif error_below_atol is True and error_below_rtol is True :
306- eval_response .is_correct = True
307- tag = "WITHIN_TOLERANCE"
308- eval_response .add_feedback ((tag , symbolic_comparison_internal_messages [tag ]))
273+ if eval_response .is_correct is False :
274+ if params .get ("numerical" , False ) or params .get ("rtol" , False ) or params .get ("atol" , False ):
275+ # REMARK: 'pi' should be a reserved symbol but it is sometimes not treated as one, possibly because of input symbols.
276+ # The two lines below this comments fixes the issue but a more robust solution should be found for cases where there
277+ # are other reserved symbols.
278+ def replace_pi (expr ):
279+ pi_symbol = pi
280+ for s in expr .free_symbols :
281+ if str (s ) == 'pi' :
282+ pi_symbol = s
283+ return expr .subs (pi_symbol , float (pi ))
284+ ans = replace_pi (ans )
285+ res = replace_pi (res )
286+ if "atol" in params .keys ():
287+ try :
288+ absolute_error = abs (float (ans - res ))
289+ error_below_atol = bool (absolute_error < float (params ["atol" ]))
290+ except TypeError :
291+ error_below_atol = None
292+ else :
293+ error_below_atol = True
294+ if "rtol" in params .keys ():
295+ try :
296+ relative_error = abs (float ((ans - res )/ ans )) # TODO: capture error here and see if you can rewrite this in a faster way
297+ error_below_rtol = bool (relative_error < float (params ["rtol" ]))
298+ except TypeError :
299+ error_below_rtol = None
300+ else :
301+ error_below_rtol = True
302+ if error_below_atol is None or error_below_rtol is None :
303+ eval_response .is_correct = False
304+ tag = "NOT_NUMERICAL"
305+ eval_response .add_feedback ((tag , symbolic_comparison_internal_messages [tag ]))
306+ elif error_below_atol is True and error_below_rtol is True :
307+ eval_response .is_correct = True
308+ tag = "WITHIN_TOLERANCE"
309+ eval_response .add_feedback ((tag , symbolic_comparison_internal_messages [tag ]))
309310
310311 return eval_response
0 commit comments