Skip to content

Commit a97ac3d

Browse files
feat: [vertexai] add AutomaticFunctionCallingResponder class (#10896)
PiperOrigin-RevId: 638319753 Co-authored-by: Jaycee Li <jayceeli@google.com>
1 parent ca905ed commit a97ac3d

File tree

2 files changed

+289
-0
lines changed

2 files changed

+289
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.google.cloud.vertexai.generativeai;
17+
18+
import com.google.common.collect.ImmutableList;
19+
import java.lang.reflect.Method;
20+
import java.lang.reflect.Modifier;
21+
import java.lang.reflect.Parameter;
22+
import java.util.HashMap;
23+
import java.util.Map;
24+
25+
/** A responder that automatically calls functions when requested by the GenAI model. */
26+
public final class AutomaticFunctionCallingResponder {
27+
private int maxFunctionCalls = 1;
28+
private int remainingFunctionCalls;
29+
private final Map<String, CallableFunction> callableFunctions = new HashMap<>();
30+
31+
/** Constructs an AutomaticFunctionCallingResponder instance. */
32+
public AutomaticFunctionCallingResponder() {
33+
this.remainingFunctionCalls = this.maxFunctionCalls;
34+
}
35+
36+
/**
37+
* Constructs an AutomaticFunctionCallingResponder instance.
38+
*
39+
* @param maxFunctionCalls the maximum number of function calls to make in a row
40+
*/
41+
public AutomaticFunctionCallingResponder(int maxFunctionCalls) {
42+
this.maxFunctionCalls = maxFunctionCalls;
43+
this.remainingFunctionCalls = maxFunctionCalls;
44+
}
45+
46+
/** Sets the maximum number of function calls to make in a row. */
47+
public void setMaxFunctionCalls(int maxFunctionCalls) {
48+
this.maxFunctionCalls = maxFunctionCalls;
49+
this.remainingFunctionCalls = this.maxFunctionCalls;
50+
}
51+
52+
/** Gets the maximum number of function calls to make in a row. */
53+
public int getMaxFunctionCalls() {
54+
return maxFunctionCalls;
55+
}
56+
57+
/** Resets the remaining function calls to the maximum number of function calls. */
58+
void resetRemainingFunctionCalls() {
59+
this.remainingFunctionCalls = this.maxFunctionCalls;
60+
}
61+
62+
/**
63+
* Adds a callable function to the AutomaticFunctionCallingResponder.
64+
*
65+
* <p><b>Note:</b>: If you don't want to manually provide parameter names, you can ignore
66+
* `orderedParameterNames` and compile your code with the "-parameters" flag. In this case, the
67+
* parameter names can be auto retrieved from reflection.
68+
*
69+
* @param functionName the name of the function
70+
* @param callableFunction the method to call when the functionName is requested
71+
* @param orderedParameterNames the names of the parameters in the order they are passed to the
72+
* function
73+
* @throws IllegalArgumentException if the functionName is already in the responder
74+
* @throws IllegalStateException if the parameter names are not provided and cannot be retrieved
75+
* from reflection
76+
*/
77+
public void addCallableFunction(
78+
String functionName, Method callableFunction, String... orderedParameterNames) {
79+
if (callableFunctions.containsKey(functionName)) {
80+
throw new IllegalArgumentException("Duplicate function name: " + functionName);
81+
} else {
82+
callableFunctions.put(
83+
functionName, new CallableFunction(callableFunction, orderedParameterNames));
84+
}
85+
}
86+
87+
/** A class that represents a function that can be called automatically. */
88+
static class CallableFunction {
89+
private final Method callableFunction;
90+
private final ImmutableList<String> orderedParameterNames;
91+
92+
/**
93+
* Constructs a CallableFunction instance.
94+
*
95+
* <p><b>Note:</b>: If you don't want to manually provide parameter names, you can ignore
96+
* `orderedParameterNames` and compile your code with the "-parameters" flag. In this case, the
97+
* parameter names can be auto retrieved from reflection.
98+
*
99+
* @param callableFunction the method to call
100+
* @param orderedParameterNames the names of the parameters in the order they are passed to the
101+
* function
102+
* @throws IllegalArgumentException if the given method is not a static method or the number of
103+
* provided parameter names doesn't match the number of parameters in the callable function
104+
* @throws IllegalStateException if the parameter names are not provided and cannot be retrieved
105+
* from reflection
106+
*/
107+
CallableFunction(Method callableFunction, String... orderedParameterNames) {
108+
validateFunction(callableFunction);
109+
this.callableFunction = callableFunction;
110+
111+
if (orderedParameterNames.length == 0) {
112+
ImmutableList.Builder<String> builder = ImmutableList.builder();
113+
for (Parameter parameter : callableFunction.getParameters()) {
114+
if (parameter.isNamePresent()) {
115+
builder.add(parameter.getName());
116+
} else {
117+
throw new IllegalStateException(
118+
"Failed to retrieve the parameter name from reflection. Please compile your code"
119+
+ " with \"-parameters\" flag or use `addCallableFunction(String, Method,"
120+
+ " String...)` to manually enter parameter names");
121+
}
122+
}
123+
this.orderedParameterNames = builder.build();
124+
} else if (orderedParameterNames.length == callableFunction.getParameters().length) {
125+
this.orderedParameterNames = ImmutableList.copyOf(orderedParameterNames);
126+
} else {
127+
throw new IllegalArgumentException(
128+
"The number of provided parameter names doesn't match the number of parameters in the"
129+
+ " callable function.");
130+
}
131+
}
132+
133+
/** Validates that the given method is a static method. */
134+
private void validateFunction(Method method) {
135+
if (!Modifier.isStatic(method.getModifiers())) {
136+
throw new IllegalArgumentException("Function calling only supports static methods.");
137+
}
138+
}
139+
}
140+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.vertexai.generativeai;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static org.junit.Assert.assertThrows;
21+
22+
import java.lang.reflect.Method;
23+
import org.junit.Test;
24+
import org.junit.runner.RunWith;
25+
import org.junit.runners.JUnit4;
26+
27+
@RunWith(JUnit4.class)
28+
public final class AutomaticFunctionCallingResponderTest {
29+
private static final int MAX_FUNCTION_CALLS = 5;
30+
private static final int DEFAULT_MAX_FUNCTION_CALLS = 1;
31+
private static final String FUNCTION_NAME_1 = "getCurrentWeather";
32+
private static final String FUNCTION_NAME_2 = "getCurrentTemperature";
33+
private static final String PARAMETER_NAME = "location";
34+
35+
public static String getCurrentWeather(String location) {
36+
if (location.equals("Boston")) {
37+
return "snowing";
38+
} else if (location.equals("Vancouver")) {
39+
return "raining";
40+
} else {
41+
return "sunny";
42+
}
43+
}
44+
45+
public static int getCurrentTemperature(String location) {
46+
if (location.equals("Boston")) {
47+
return 32;
48+
} else if (location.equals("Vancouver")) {
49+
return 45;
50+
} else {
51+
return 75;
52+
}
53+
}
54+
55+
public boolean nonStaticMethod() {
56+
return true;
57+
}
58+
59+
@Test
60+
public void testInitAutomaticFunctionCallingResponder_containsRightFields() {
61+
AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder();
62+
63+
assertThat(responder.getMaxFunctionCalls()).isEqualTo(DEFAULT_MAX_FUNCTION_CALLS);
64+
}
65+
66+
@Test
67+
public void testInitAutomaticFunctionCallingResponderWithMaxFunctionCalls_containsRightFields() {
68+
AutomaticFunctionCallingResponder responder =
69+
new AutomaticFunctionCallingResponder(MAX_FUNCTION_CALLS);
70+
71+
assertThat(responder.getMaxFunctionCalls()).isEqualTo(MAX_FUNCTION_CALLS);
72+
}
73+
74+
@Test
75+
public void testSetMaxFunctionCalls_containsRightFields() {
76+
AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder();
77+
responder.setMaxFunctionCalls(MAX_FUNCTION_CALLS);
78+
79+
assertThat(responder.getMaxFunctionCalls()).isEqualTo(MAX_FUNCTION_CALLS);
80+
}
81+
82+
@Test
83+
public void testAddCallableFunctionWithoutOrderedParameterNames_throwsIllegalArgumentException()
84+
throws NoSuchMethodException {
85+
AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder();
86+
Method callableFunction =
87+
AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_1, String.class);
88+
89+
IllegalStateException thrown =
90+
assertThrows(
91+
IllegalStateException.class,
92+
() -> responder.addCallableFunction(FUNCTION_NAME_1, callableFunction));
93+
assertThat(thrown)
94+
.hasMessageThat()
95+
.isEqualTo(
96+
"Failed to retrieve the parameter name from reflection. Please compile your code with"
97+
+ " \"-parameters\" flag or use `addCallableFunction(String, Method, String...)`"
98+
+ " to manually enter parameter names");
99+
}
100+
101+
@Test
102+
public void testAddNonStaticCallableFunction_throwsIllegalArgumentException()
103+
throws NoSuchMethodException {
104+
AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder();
105+
Method nonStaticMethod =
106+
AutomaticFunctionCallingResponderTest.class.getMethod("nonStaticMethod");
107+
108+
IllegalArgumentException thrown =
109+
assertThrows(
110+
IllegalArgumentException.class,
111+
() -> responder.addCallableFunction(FUNCTION_NAME_1, nonStaticMethod, PARAMETER_NAME));
112+
assertThat(thrown).hasMessageThat().isEqualTo("Function calling only supports static methods.");
113+
}
114+
115+
@Test
116+
public void testAddRepeatedCallableFunction_throwsIllegalArgumentException()
117+
throws NoSuchMethodException {
118+
AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder();
119+
Method callableFunction =
120+
AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_1, String.class);
121+
responder.addCallableFunction(FUNCTION_NAME_1, callableFunction, PARAMETER_NAME);
122+
123+
IllegalArgumentException thrown =
124+
assertThrows(
125+
IllegalArgumentException.class,
126+
() -> responder.addCallableFunction(FUNCTION_NAME_1, callableFunction, PARAMETER_NAME));
127+
assertThat(thrown).hasMessageThat().isEqualTo("Duplicate function name: " + FUNCTION_NAME_1);
128+
}
129+
130+
@Test
131+
public void testAddCallableFunctionWithWrongParameterNames_throwsIllegalArgumentException()
132+
throws NoSuchMethodException {
133+
AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder();
134+
Method callableFunction =
135+
AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_1, String.class);
136+
137+
IllegalArgumentException thrown =
138+
assertThrows(
139+
IllegalArgumentException.class,
140+
() ->
141+
responder.addCallableFunction(
142+
FUNCTION_NAME_1, callableFunction, PARAMETER_NAME, "anotherParameter"));
143+
assertThat(thrown)
144+
.hasMessageThat()
145+
.isEqualTo(
146+
"The number of provided parameter names doesn't match the number of parameters in the"
147+
+ " callable function.");
148+
}
149+
}

0 commit comments

Comments
 (0)