Skip to content

Commit 4bfb65c

Browse files
Merge pull request #131 from lambda-feedback/tr102-rtol-slow
Tr102 rtol slow
2 parents f9e5915 + 8c47f23 commit 4bfb65c

File tree

5 files changed

+48
-29
lines changed

5 files changed

+48
-29
lines changed

app/evaluation_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@ def test_eval_function_can_handle_latex_input(self):
4242
assert result["is_correct"] is True
4343

4444
if __name__ == "__main__":
45-
pytest.main(['-xsk not slow', "--tb=line", os.path.abspath(__file__)])
45+
pytest.main(['-xsk not slow', '--tb=line', '--durations=10', os.path.abspath(__file__)])

app/expression_utilities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def parse_expression(expr, parsing_params):
578578
substitutions.sort(key=lambda x: -len(x[0]))
579579
expr = substitute(expr, substitutions)
580580
can_split = lambda x: False if x in unsplittable_symbols else _token_splittable(x)
581-
if strict_syntax:
581+
if strict_syntax is True:
582582
transformations = parser_transformations[0:4]+extra_transformations
583583
else:
584584
transformations = parser_transformations[0:5, 6]+extra_transformations+(split_symbols_custom(can_split),)+parser_transformations[8]

app/feedback/symbolic_comparison.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@
3636
"EXPRESSION_NOT_EQUALITY": "The response was an expression but was expected to be an equality.",
3737
"EQUALITY_NOT_EXPRESSION": "The response was an equality but was expected to be an expression.",
3838
"WITHIN_TOLERANCE": "", # "The difference between the response and the answer is within the specified error tolerance.",
39-
"SYMBOLICALLY_EQUAL": "The response and answer are symbolically equal.",
39+
"NOT_NUMERICAL": "The expression cannot be evaluated numerically.",
40+
# "SYMBOLICALLY_EQUAL": "The response and answer are symbolically equal.",
4041
}

app/symbolic_comparison_evaluation.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,20 @@ def symbolic_comparison(response, answer, params, eval_response) -> dict:
255255
eval_response.is_correct = ((res.args[0]-res.args[1])/(ans.args[0]-ans.args[1])).simplify().is_constant()
256256
return eval_response
257257

258-
error_below_atol = False
259-
error_below_rtol = False
258+
is_correct = True
259+
parameters_dict = {
260+
"parsing_params": parsing_params,
261+
"reserved_expressions": reserved_expressions,
262+
"reference_criteria_strings": reference_criteria_strings,
263+
"symbolic_comparison_criteria": symbolic_comparison_criteria,
264+
"eval_response": eval_response,
265+
}
266+
for criterion in criteria_parsed:
267+
is_correct = is_correct and check_criterion(criterion, parameters_dict)
268+
eval_response.is_correct = is_correct
269+
270+
error_below_atol = None
271+
error_below_rtol = None
260272

261273
if params.get("numerical", False) or params.get("rtol", False) or params.get("atol", False):
262274
# REMARK: 'pi' should be a reserved symbol but it is sometimes not treated as one, possibly because of input symbols.
@@ -271,32 +283,28 @@ def replace_pi(expr):
271283
ans = replace_pi(ans)
272284
res = replace_pi(res)
273285
if "atol" in params.keys():
274-
absolute_error = abs(ans-res)
275-
if isinstance(absolute_error, float) or absolute_error.is_constant():
276-
error_below_atol = bool(float(absolute_error) < float(params["atol"]))
286+
try:
287+
absolute_error = abs(float(ans-res))
288+
error_below_atol = bool(absolute_error < float(params["atol"]))
289+
except TypeError:
290+
error_below_atol = None
277291
else:
278292
error_below_atol = True
279293
if "rtol" in params.keys():
280-
relative_error = abs(((ans-res)/ans).simplify())
281-
if isinstance(relative_error, float) or relative_error.is_constant():
282-
error_below_rtol = bool(float(relative_error) < float(params["rtol"]))
294+
try:
295+
relative_error = abs(float((ans-res)/ans)) # TODO: capture error here and see if you can rewrite this in a faster way
296+
error_below_rtol = bool(relative_error < float(params["rtol"]))
297+
except TypeError:
298+
error_below_rtol = None
283299
else:
284300
error_below_rtol = True
285-
if error_below_atol and error_below_rtol:
301+
if error_below_atol is None or error_below_rtol is None:
302+
eval_response.is_correct = False
303+
tag = "NOT_NUMERICAL"
304+
eval_response.add_feedback((tag, symbolic_comparison_internal_messages[tag]))
305+
elif error_below_atol is True and error_below_rtol is True:
286306
eval_response.is_correct = True
287307
tag = "WITHIN_TOLERANCE"
288308
eval_response.add_feedback((tag, symbolic_comparison_internal_messages[tag]))
289-
return eval_response
290309

291-
is_correct = True
292-
parameters_dict = {
293-
"parsing_params": parsing_params,
294-
"reserved_expressions": reserved_expressions,
295-
"reference_criteria_strings": reference_criteria_strings,
296-
"symbolic_comparison_criteria": symbolic_comparison_criteria,
297-
"eval_response": eval_response,
298-
}
299-
for criterion in criteria_parsed:
300-
is_correct = is_correct and check_criterion(criterion, parameters_dict)
301-
eval_response.is_correct = is_correct
302310
return eval_response

app/symbolic_comparison_evaluation_tests.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,71 +548,81 @@ def test_empty_input_symbols_codes_and_alternatives(self):
548548
assert result["is_correct"] is True
549549

550550
@pytest.mark.parametrize(
551-
"description,response,answer,tolerance,outcome",
551+
"description,response,answer,tolerance,tags,outcome",
552552
[
553553
(
554554
"Correct response, tolerance specified with atol",
555555
"6.73",
556556
"sqrt(3)+5",
557557
{"atol": 0.005},
558+
["WITHIN_TOLERANCE"],
558559
True
559560
),
560561
(
561562
"Incorrect response, tolerance specified with atol",
562563
"6.7",
563564
"sqrt(3)+5",
564565
{"atol": 0.005},
566+
[],
565567
False
566568
),
567569
(
568570
"Correct response, tolerance specified with rtol",
569571
"6.73",
570572
"sqrt(3)+5",
571573
{"rtol": 0.0005},
574+
["WITHIN_TOLERANCE"],
572575
True
573576
),
574577
(
575578
"Incorrect response, tolerance specified with rtol",
576579
"6.7",
577580
"sqrt(3)+5",
578581
{"rtol": 0.0005},
582+
[],
579583
False
580584
),
581585
(
582586
"Response is not constant, tolerance specified with atol",
583587
"6.7+x",
584588
"sqrt(3)+5",
585589
{"atol": 0.005},
590+
["NOT_NUMERICAL"],
586591
False
587592
),
588593
(
589594
"Answer is not constant, tolerance specified with atol",
590595
"6.73",
591596
"sqrt(3)+x",
592597
{"atol": 0.005},
598+
["NOT_NUMERICAL"],
593599
False
594600
),
595601
(
596602
"Response is not constant, tolerance specified with rtol",
597603
"6.7+x",
598604
"sqrt(3)+5",
599605
{"rtol": 0.0005},
606+
["NOT_NUMERICAL"],
600607
False
601608
),
602609
(
603610
"Answer is not constant, tolerance specified with rtol",
604611
"6.73",
605612
"sqrt(3)+x",
606613
{"rtol": 0.0005},
614+
["NOT_NUMERICAL"],
607615
False
608616
),
609617
]
610618
)
611-
def test_numerical_comparison_problem(self, description, response, answer, tolerance, outcome):
619+
def test_numerical_comparison_problem(self, description, response, answer, tolerance, tags, outcome):
612620
params = {"numerical": True}
613621
params.update(tolerance)
614-
result = evaluation_function(response, answer, params)
622+
result = evaluation_function(response, answer, params, include_test_data=True)
615623
assert result["is_correct"] is outcome
624+
for tag in tags:
625+
tag in result["tags"]
616626

617627
@pytest.mark.parametrize(
618628
"description,response,answer,tolerance,outcome",
@@ -633,7 +643,7 @@ def test_numerical_comparison_problem(self, description, response, answer, toler
633643
),
634644
]
635645
)
636-
def test_numerical_comparison(self, description, response, answer, tolerance, outcome):
646+
def test_numerical_comparison_AERO4007(self, description, response, answer, tolerance, outcome):
637647
params = {
638648
"strict_syntax": False,
639649
"elementary_functions": True,
@@ -1052,4 +1062,4 @@ def test_exclamation_mark_for_factorial(self):
10521062
assert result["is_correct"] is True
10531063

10541064
if __name__ == "__main__":
1055-
pytest.main(['-xsk not slow', "--tb=line", os.path.abspath(__file__)])
1065+
pytest.main(['-xsk not slow', "--tb=line", '--durations=10', os.path.abspath(__file__)])

0 commit comments

Comments
 (0)