
Commit 5a485b0

Authored by ankursharmas, committed by copybara-github

feat: Adds Rubric based final response evaluator

The evaluator uses a set of rubrics to assess the quality of the agent's final response.

PiperOrigin-RevId: 811154498

1 parent 01923a9 commit 5a485b0

16 files changed: +1969 -24 lines
src/google/adk/evaluation/app_details.py

Lines changed: 14 additions & 0 deletions
@@ -47,3 +47,17 @@ class AppDetails(EvalBaseModel):
       default_factory=dict,
   )
   """A mapping from the agent name to the details of that agent."""
+
+  def get_developer_instructions(self, agent_name: str) -> str:
+    """Returns a string containing the developer instructions."""
+    if agent_name not in self.agent_details:
+      raise ValueError(f"`{agent_name}` not found in the agentic system.")
+
+    return self.agent_details[agent_name].instructions
+
+  def get_tools_by_agent_name(self) -> dict[str, genai_types.ToolListUnion]:
+    """Returns a dictionary of tools available to an agent in the App, keyed to the name of the Agent."""
+    return {
+        name: details.tool_declarations
+        for name, details in self.agent_details.items()
+    }
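
For a concrete picture of what get_tools_by_agent_name returns, here is a self-contained sketch using a stand-in for the per-agent details model; only the instructions and tool_declarations attributes read above are modeled, and the real AgentDetails class is not shown in this diff:

from dataclasses import dataclass, field

from google.genai import types as genai_types


@dataclass
class FakeAgentDetails:
  """Stand-in exposing only the attributes the new helpers read."""
  instructions: str = ""
  tool_declarations: list[genai_types.Tool] = field(default_factory=list)


agent_details = {
    "root_agent": FakeAgentDetails(
        instructions="Answer the user's question politely.",
        tool_declarations=[
            genai_types.Tool(
                function_declarations=[
                    genai_types.FunctionDeclaration(name="get_weather")
                ]
            )
        ],
    )
}

# Mirrors AppDetails.get_tools_by_agent_name(): agent name -> declared tools.
tools_by_agent = {
    name: details.tool_declarations for name, details in agent_details.items()
}
print(tools_by_agent["root_agent"][0].function_declarations[0].name)  # get_weather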

src/google/adk/evaluation/eval_case.py

Lines changed: 51 additions & 0 deletions
@@ -168,3 +168,54 @@ def get_all_tool_calls(
       )

   return tool_calls
+
+
+def get_all_tool_responses(
+    intermediate_data: Optional[IntermediateDataType],
+) -> list[genai_types.FunctionResponse]:
+  """A utility method to retrieve tool responses from intermediate data."""
+  if not intermediate_data:
+    return []
+
+  tool_responses = []
+  if isinstance(intermediate_data, IntermediateData):
+    tool_responses = intermediate_data.tool_responses
+  elif isinstance(intermediate_data, InvocationEvents):
+    # Go over each event in the list of events.
+    for invocation_event in intermediate_data.invocation_events:
+      # Check if the event has content and some parts.
+      if invocation_event.content and invocation_event.content.parts:
+        for p in invocation_event.content.parts:
+          # For each part, check whether it is a function response.
+          if p.function_response:
+            tool_responses.append(p.function_response)
+  else:
+    raise ValueError(
+        f"Unsupported type for intermediate_data `{intermediate_data}`"
+    )
+
+  return tool_responses
+
+
+ToolCallAndResponse: TypeAlias = tuple[
+    genai_types.FunctionCall, Optional[genai_types.FunctionResponse]
+]
+"""A tuple representing a function call and the corresponding optional function response."""
+
+
+def get_all_tool_calls_with_responses(
+    intermediate_data: Optional[IntermediateDataType],
+) -> list[ToolCallAndResponse]:
+  """Returns tool calls with the corresponding responses, if available."""
+  tool_responses_by_call_id: dict[str, genai_types.FunctionResponse] = {
+      tool_response.id: tool_response
+      for tool_response in get_all_tool_responses(intermediate_data)
+  }
+
+  tool_call_and_responses: list[ToolCallAndResponse] = []
+
+  for tool_call in get_all_tool_calls(intermediate_data):
+    response = tool_responses_by_call_id.get(tool_call.id, None)
+    tool_call_and_responses.append((tool_call, response))
+
+  return tool_call_and_responses
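
To make the id-based pairing in get_all_tool_calls_with_responses concrete, here is a small, self-contained sketch of the same join performed on plain genai_types objects; it is illustrative only and bypasses ADK's IntermediateData types:

from typing import Optional

from google.genai import types as genai_types

calls = [
    genai_types.FunctionCall(id="call-1", name="get_weather", args={"city": "Paris"}),
    genai_types.FunctionCall(id="call-2", name="get_time", args={"tz": "UTC"}),
]
responses = [
    # Only the first call got a response; the second stays unmatched.
    genai_types.FunctionResponse(id="call-1", name="get_weather", response={"temp_c": 21}),
]

# Same join strategy as get_all_tool_calls_with_responses: index responses by
# call id, then look each call up, defaulting to None when nothing was recorded.
responses_by_id = {r.id: r for r in responses}
paired: list[tuple[genai_types.FunctionCall, Optional[genai_types.FunctionResponse]]] = [
    (call, responses_by_id.get(call.id)) for call in calls
]

for call, response in paired:
  print(call.name, "->", response.response if response else None)
# get_weather -> {'temp_c': 21}
# get_time -> None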

src/google/adk/evaluation/eval_metrics.py

Lines changed: 4 additions & 0 deletions
@@ -48,6 +48,10 @@ class PrebuiltMetrics(Enum):

   FINAL_RESPONSE_MATCH_V2 = "final_response_match_v2"

+  RUBRIC_BASED_FINAL_RESPONSE_QUALITY_V1 = (
+      "rubric_based_final_response_quality_v1"
+  )
+

 MetricName: TypeAlias = Union[str, PrebuiltMetrics]
 Threshold: TypeAlias = float
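
The new enum value is the name a user would reference when configuring the metric. A minimal sketch, assuming the EvalMetric(metric_name=..., threshold=...) shape used elsewhere in this diff and that any other EvalMetric fields are optional:

from google.adk.evaluation.eval_metrics import EvalMetric, PrebuiltMetrics

# Request the rubric-based final response quality metric with a passing threshold of 0.5.
# Threshold semantics are assumed to match the other LLM-judged metrics (score >= threshold passes).
eval_metric = EvalMetric(
    metric_name=PrebuiltMetrics.RUBRIC_BASED_FINAL_RESPONSE_QUALITY_V1.value,
    threshold=0.5,
)
print(eval_metric.metric_name)  # rubric_based_final_response_quality_v1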

src/google/adk/evaluation/evaluator.py

Lines changed: 6 additions & 0 deletions
@@ -23,6 +23,7 @@
 from .eval_case import Invocation
 from .eval_metrics import BaseCriterion
 from .eval_metrics import EvalStatus
+from .eval_rubrics import RubricScore

 # Redefining the type here for backward compatibility.
 EvalStatus: TypeAlias = EvalStatus
@@ -35,6 +36,7 @@ class PerInvocationResult(BaseModel):
   expected_invocation: Invocation
   score: Optional[float] = None
   eval_status: EvalStatus = EvalStatus.NOT_EVALUATED
+  rubric_scores: Optional[list[RubricScore]] = None


 class EvaluationResult(BaseModel):
@@ -45,6 +47,10 @@ class EvaluationResult(BaseModel):
   """Overall status, based on each invocation."""

   per_invocation_results: list[PerInvocationResult] = []
+  """Detailed results per invocation."""
+
+  overall_rubric_scores: Optional[list[RubricScore]] = None
+  """Overall rubric scores, based on each invocation."""


 class Evaluator(ABC):

src/google/adk/evaluation/final_response_match_v2.py

Lines changed: 6 additions & 5 deletions
@@ -33,6 +33,7 @@
 from .eval_metrics import PrebuiltMetrics
 from .evaluator import EvaluationResult
 from .evaluator import PerInvocationResult
+from .llm_as_judge import AutoRaterScore
 from .llm_as_judge import LlmAsJudge
 from .llm_as_judge_utils import get_eval_status
 from .llm_as_judge_utils import get_text_from_content
@@ -179,17 +180,17 @@ def format_auto_rater_prompt(
   @override
   def convert_auto_rater_response_to_score(
       self, llm_response: LlmResponse
-  ) -> Optional[float]:
+  ) -> AutoRaterScore:
     response_text = get_text_from_content(llm_response.content)
     if response_text is None:
-      return None
+      return AutoRaterScore()
     label = _parse_critique(response_text)
     if label == Label.VALID:
-      return 1.0
+      return AutoRaterScore(score=1.0)
     elif label == Label.INVALID:
-      return 0.0
+      return AutoRaterScore(score=0.0)
     else:
-      return None
+      return AutoRaterScore()

   @override
   def aggregate_per_invocation_samples(

src/google/adk/evaluation/llm_as_judge.py

Lines changed: 14 additions & 4 deletions
@@ -26,15 +26,22 @@
 from ..models.llm_response import LlmResponse
 from ..models.registry import LLMRegistry
 from ..utils.context_utils import Aclosing
+from .common import EvalBaseModel
 from .eval_case import Invocation
 from .eval_metrics import BaseCriterion
 from .eval_metrics import EvalMetric
+from .eval_metrics import RubricScore
 from .evaluator import EvaluationResult
 from .evaluator import Evaluator
 from .evaluator import PerInvocationResult
 from .llm_as_judge_utils import get_eval_status


+class AutoRaterScore(EvalBaseModel):
+  score: Optional[float] = None
+  rubric_scores: Optional[list[RubricScore]] = None
+
+
 class LlmAsJudge(Evaluator):
   """Evaluator based on a LLM.

@@ -82,7 +89,7 @@ def format_auto_rater_prompt(
   @abstractmethod
   def convert_auto_rater_response_to_score(
       self, auto_rater_response: LlmResponse
-  ) -> Optional[float]:
+  ) -> AutoRaterScore:
     """Parses auto_rater_response and returns the corresponding score, or None if the score cannot be determined."""

   @abstractmethod
@@ -126,15 +133,18 @@ async def evaluate_invocations(
         ) as agen:
           async for llm_response in agen:
             # Non-streaming call, so there is only one response content.
-            score = self.convert_auto_rater_response_to_score(llm_response)
+            auto_rater_score = self.convert_auto_rater_response_to_score(
+                llm_response
+            )
             invocation_result_samples.append(
                 PerInvocationResult(
                     actual_invocation=actual,
                     expected_invocation=expected,
-                    score=score,
+                    score=auto_rater_score.score,
                     eval_status=get_eval_status(
-                        score, self._criterion.threshold
+                        auto_rater_score.score, self._eval_metric.threshold
                    ),
+                    rubric_scores=auto_rater_score.rubric_scores,
                )
            )
       if not invocation_result_samples:
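
A short sketch of how the new wrapper type flows into the status computation above; the module paths follow the files in this diff, and the unparsed case surfaces as NOT_EVALUATED because get_eval_status receives a None score:

from google.adk.evaluation.evaluator import EvalStatus
from google.adk.evaluation.llm_as_judge import AutoRaterScore
from google.adk.evaluation.llm_as_judge_utils import get_eval_status

# A judge response that could not be parsed: no score, no rubric scores.
unparsed = AutoRaterScore()
print(get_eval_status(unparsed.score, threshold=0.5))  # EvalStatus.NOT_EVALUATED

# A parsed response with an overall score; rubric_scores stays None for
# evaluators (like final_response_match_v2) that emit no per-rubric detail.
parsed = AutoRaterScore(score=1.0)
print(get_eval_status(parsed.score, threshold=0.5))  # EvalStatus.PASSED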

src/google/adk/evaluation/llm_as_judge_utils.py

Lines changed: 101 additions & 0 deletions
@@ -15,10 +15,17 @@
 from __future__ import annotations

 import enum
+import statistics
 from typing import Optional
+from typing import Union

 from google.genai import types as genai_types

+from .app_details import AppDetails
+from .common import EvalBaseModel
+from .eval_case import get_all_tool_calls_with_responses
+from .eval_case import IntermediateDataType
+from .eval_metrics import RubricScore
 from .evaluator import EvalStatus


@@ -46,3 +53,97 @@ def get_eval_status(score: Optional[float], threshold: float) -> EvalStatus:
   if score is None:
     return EvalStatus.NOT_EVALUATED
   return EvalStatus.PASSED if score >= threshold else EvalStatus.FAILED
+
+
+def get_average_rubric_score(
+    rubric_scores: list[RubricScore],
+) -> Optional[float]:
+  """Returns a single score value from the given list of rubric scores.
+
+  It is possible that none of the rubric scores actually contains a score
+  value; if that happens, None is returned.
+
+  If score values are present, their mean is returned as the aggregated
+  value.
+  """
+  rubric_scores = [
+      rubric_score.score
+      for rubric_score in rubric_scores
+      if rubric_score.score is not None
+  ]
+
+  return statistics.mean(rubric_scores) if rubric_scores else None
+
+
+class _ToolDeclarations(EvalBaseModel):
+  """Internal data model used for serializing tool declarations."""
+
+  tool_declarations: dict[str, genai_types.ToolListUnion]
+
+
+def get_tool_declarations_as_json_str(
+    app_details: AppDetails,
+) -> str:
+  """Returns a JSON string representation of tool declarations.
+
+  The output of this method is usually intended to be sent to the LLM.
+  """
+  tool_declarations = _ToolDeclarations(
+      tool_declarations=app_details.get_tools_by_agent_name()
+  )
+  return tool_declarations.model_dump_json(
+      indent=2,
+      exclude_unset=True,
+      exclude_defaults=True,
+      exclude_none=True,
+  )
+
+
+class _ToolCallAndResponse(EvalBaseModel):
+  """Internal data model to capture a single tool call and response."""
+
+  step: int
+  tool_call: genai_types.FunctionCall
+  tool_response: Union[genai_types.FunctionResponse, str]
+
+
+class _ToolCallsAndResponses(EvalBaseModel):
+  """Internal data model used for serializing tool calls and responses."""
+
+  tool_calls_and_response: list[_ToolCallAndResponse]
+
+
+def get_tool_calls_and_responses_as_json_str(
+    intermediate_data: Optional[IntermediateDataType],
+) -> str:
+  """Returns a JSON string representation of tool calls and corresponding responses.
+
+  The output of this method is usually intended to be sent to the LLM.
+  """
+  raw_tool_calls_and_response = get_all_tool_calls_with_responses(
+      intermediate_data
+  )
+
+  if not raw_tool_calls_and_response:
+    return "No intermediate steps were taken."
+
+  tool_calls_and_responses = []
+  for idx, (tool_call, tool_response) in enumerate(raw_tool_calls_and_response):
+    tool_calls_and_responses.append(
+        _ToolCallAndResponse(
+            step=idx,
+            tool_call=tool_call,
+            tool_response=tool_response if tool_response else "None",
+        )
+    )
+
+  internal_tool_calls_and_responses = _ToolCallsAndResponses(
+      tool_calls_and_response=tool_calls_and_responses
+  )
+
+  return internal_tool_calls_and_responses.model_dump_json(
+      indent=2,
+      exclude_unset=True,
+      exclude_defaults=True,
+      exclude_none=True,
+  )
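
The aggregation rule in get_average_rubric_score (drop un-scored rubrics, then take the mean, else None) can be shown standalone; SimpleRubricScore below is a stand-in for the real RubricScore model, of which this diff only shows an optional score field:

import statistics
from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleRubricScore:
  """Stand-in with only the field get_average_rubric_score reads."""
  score: Optional[float] = None


def average_score(rubric_scores: list[SimpleRubricScore]) -> Optional[float]:
  # Mirrors get_average_rubric_score: ignore rubrics without a score,
  # return None when nothing is scored, otherwise the mean.
  values = [rs.score for rs in rubric_scores if rs.score is not None]
  return statistics.mean(values) if values else None


print(average_score([SimpleRubricScore(1.0), SimpleRubricScore(0.0), SimpleRubricScore()]))  # 0.5
print(average_score([SimpleRubricScore(), SimpleRubricScore()]))  # None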

src/google/adk/evaluation/local_eval_service.py

Lines changed: 11 additions & 4 deletions
@@ -40,6 +40,7 @@
 from .eval_case import Invocation
 from .eval_metrics import EvalMetric
 from .eval_metrics import EvalMetricResult
+from .eval_metrics import EvalMetricResultDetails
 from .eval_metrics import EvalMetricResultPerInvocation
 from .eval_result import EvalCaseResult
 from .eval_set import EvalCase
@@ -239,12 +240,15 @@ async def _evaluate_single_inference_result(
     )

     # Track overall score across all invocations.
+    eval_metric_result_details = EvalMetricResultDetails(
+        rubric_scores=evaluation_result.overall_rubric_scores
+    )
     overall_eval_metric_results.append(
         EvalMetricResult(
-            metric_name=eval_metric.metric_name,
-            threshold=eval_metric.threshold,
             score=evaluation_result.overall_score,
             eval_status=evaluation_result.overall_eval_status,
+            details=eval_metric_result_details,
+            **eval_metric.model_dump(),
         )
     )

@@ -262,12 +266,15 @@ async def _evaluate_single_inference_result(
         evaluation_result.per_invocation_results,
         eval_metric_result_per_invocation,
     ):
+      eval_metric_result_details = EvalMetricResultDetails(
+          rubric_scores=invocation_result.rubric_scores
+      )
       invocation.eval_metric_results.append(
           EvalMetricResult(
-              metric_name=eval_metric.metric_name,
-              threshold=eval_metric.threshold,
              score=invocation_result.score,
              eval_status=invocation_result.eval_status,
+              details=eval_metric_result_details,
+              **eval_metric.model_dump(),
          )
      )
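
The switch from passing metric_name/threshold explicitly to spreading **eval_metric.model_dump() forwards every field of the metric config into the result object in one step. A simplified pydantic sketch of that pattern, using stand-in models (the real EvalMetric/EvalMetricResult carry more fields):

from typing import Optional

from pydantic import BaseModel


class MiniEvalMetric(BaseModel):
  metric_name: str
  threshold: float


class MiniEvalMetricResult(MiniEvalMetric):
  score: Optional[float] = None
  eval_status: str = "NOT_EVALUATED"


metric = MiniEvalMetric(
    metric_name="rubric_based_final_response_quality_v1", threshold=0.5
)

# Every field of the metric config (name, threshold, ...) is carried over in one go.
result = MiniEvalMetricResult(score=0.75, eval_status="PASSED", **metric.model_dump())
print(result.metric_name, result.threshold)  # rubric_based_final_response_quality_v1 0.5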

src/google/adk/evaluation/metric_evaluator_registry.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
 from .evaluator import Evaluator
 from .final_response_match_v2 import FinalResponseMatchV2Evaluator
 from .response_evaluator import ResponseEvaluator
+from .rubric_based_final_response_quality_v1 import RubricBasedFinalResponseQualityV1Evaluator
 from .safety_evaluator import SafetyEvaluatorV1
 from .trajectory_evaluator import TrajectoryEvaluator

@@ -111,6 +112,10 @@ def _get_default_metric_evaluator_registry() -> MetricEvaluatorRegistry:
       metric_info=FinalResponseMatchV2Evaluator.get_metric_info(),
       evaluator=FinalResponseMatchV2Evaluator,
   )
+  metric_evaluator_registry.register_evaluator(
+      metric_info=RubricBasedFinalResponseQualityV1Evaluator.get_metric_info(),
+      evaluator=RubricBasedFinalResponseQualityV1Evaluator,
+  )

   return metric_evaluator_registry
