@@ -56,15 +56,22 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
5656 )
5757
5858 expected_lora_output = [
59- "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])" , # noqa: E501
60- "give_opinion(name[SpellForce 3], developer[Grimlore Games], release_year[2017], rating[poor])" , # noqa: E501
61- "inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])" , # noqa: E501
59+ [
60+ "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])" # noqa: E501
61+ ],
62+ [
63+ "give_opinion(name[SpellForce 3], developer[Grimlore Games], release_year[2017], rating[poor])" , # noqa: E501
64+ "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])" , # noqa: E501
65+ ],
66+ [
67+ "inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])" # noqa: E501
68+ ],
6269 ]
63- assert (
64- do_sample ( llm , mixtral_lora_files , lora_id = 1 , prompts = prompts )
65- == expected_lora_output
66- )
67- assert (
68- do_sample ( llm , mixtral_lora_files , lora_id = 2 , prompts = prompts )
69- == expected_lora_output
70- )
70+
71+ def check_outputs ( generated : list [ str ]):
72+ assert len ( generated ) == len ( expected_lora_output )
73+ for gen , gt_choices in zip ( generated , expected_lora_output ):
74+ assert gen in gt_choices
75+
76+ check_outputs ( do_sample ( llm , mixtral_lora_files , lora_id = 1 , prompts = prompts ))
77+ check_outputs ( do_sample ( llm , mixtral_lora_files , lora_id = 2 , prompts = prompts ) )
0 commit comments