 #
 # pylint: disable=protected-access,bad-continuation,missing-function-docstring

-import os

 from tests.unit.vertexai.genai.replays import pytest_helper
 from vertexai._genai import types
-import pytest
+import pandas as pd


-IS_KOKORO = os.getenv("KOKORO_BUILD_NUMBER") is not None
+def test_bleu_metric(client):
+    test_bleu_input = types.BleuInput(
+        instances=[
+            types.BleuInstance(
+                reference="The quick brown fox jumps over the lazy dog.",
+                prediction="A fast brown fox leaps over a lazy dog.",
+            )
+        ],
+        metric_spec=types.BleuSpec(),
+    )
+    response = client.evals._evaluate_instances(bleu_input=test_bleu_input)
+    assert len(response.bleu_results.bleu_metric_values) == 1


-@pytest.mark.skipif(IS_KOKORO, reason="This test is only run in google3 env.")
-class TestEvaluateInstances:
-    """Tests for evaluate instances."""
+def test_run_inference_with_string_model(client):
+    test_df = pd.DataFrame({"prompt": ["test prompt"]})

-    def test_bleu_metric(self, client):
-        test_bleu_input = types.BleuInput(
-            instances=[
-                types.BleuInstance(
-                    reference="The quick brown fox jumps over the lazy dog.",
-                    prediction="A fast brown fox leaps over a lazy dog.",
-                )
-            ],
-            metric_spec=types.BleuSpec(),
-        )
-        response = client.evals._evaluate_instances(bleu_input=test_bleu_input)
-        assert len(response.bleu_results.bleu_metric_values) == 1
+    inference_result = client.evals.run_inference(
+        model="gemini-pro",
+        src=test_df,
+    )
+    assert inference_result.candidate_name == "gemini-pro"
+    assert inference_result.gcs_source is None
+
+
+def test_run_inference_with_callable_model_sets_candidate_name(client):
+    test_df = pd.DataFrame({"prompt": ["test prompt"]})
+
+    def my_model_fn(contents):
+        return "callable response"
+
+    inference_result = client.evals.run_inference(
+        model=my_model_fn,
+        src=test_df,
+    )
+    assert inference_result.candidate_name == "my_model_fn"
+    assert inference_result.gcs_source is None
+
+
+def test_inference_with_prompt_template(client):
+    test_df = pd.DataFrame({"text_input": ["world"]})
+    config = types.EvalRunInferenceConfig(prompt_template="Hello {text_input}")
+    inference_result = client.evals.run_inference(
+        model="gemini-2.0-flash-exp", src=test_df, config=config
+    )
+    assert inference_result.candidate_name == "gemini-2.0-flash-exp"
+    assert inference_result.gcs_source is None


 pytestmark = pytest_helper.setup(
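
For context, the calls exercised by the new function-style tests can also be driven directly against a client object. The sketch below is a minimal illustration assembled only from the calls shown in this diff: compute_bleu and hello_world_inference are hypothetical helper names, the client argument is assumed to be the same Vertex AI GenAI client the replay fixture supplies to the tests, and the score field read from each BLEU metric value is an assumption about the response shape rather than something shown here.

import pandas as pd

from vertexai._genai import types


def compute_bleu(client, reference: str, prediction: str) -> float:
    """Minimal sketch: score one prediction against a reference with BLEU."""
    bleu_input = types.BleuInput(
        instances=[
            types.BleuInstance(reference=reference, prediction=prediction)
        ],
        metric_spec=types.BleuSpec(),
    )
    # Same private entry point exercised by test_bleu_metric above.
    response = client.evals._evaluate_instances(bleu_input=bleu_input)
    # Assumption: each BLEU metric value exposes its result as a `score` field.
    return response.bleu_results.bleu_metric_values[0].score


def hello_world_inference(client):
    """Minimal sketch: run inference over a one-row DataFrame with a prompt template."""
    prompts = pd.DataFrame({"text_input": ["world"]})
    config = types.EvalRunInferenceConfig(prompt_template="Hello {text_input}")
    # Mirrors test_inference_with_prompt_template above.
    return client.evals.run_inference(
        model="gemini-2.0-flash-exp", src=prompts, config=config
    )

In the tests themselves, the client fixture is presumably provided by the replay harness configured through the pytest_helper.setup call at the bottom of the file.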