Skip to content

Commit 1c0ba67

Browse files
committed
openai: Support toolChoice
1 parent 8ac453c commit 1c0ba67

File tree

1 file changed

+21
-7
lines changed
  • services/openai-dialog/src

1 file changed

+21
-7
lines changed

services/openai-dialog/src/lib.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ use openai_api_rs::realtime::{
1515
api::RealtimeClient,
1616
client_event::{self, ClientEvent},
1717
server_event::{self, ServerEvent},
18-
types::{self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus},
18+
types::{
19+
self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus,
20+
ToolChoice,
21+
},
1922
};
2023
use serde::{Deserialize, Serialize};
2124
use tokio::{net::TcpStream, select};
@@ -42,6 +45,7 @@ pub struct Params {
4245
pub temperature: Option<f32>,
4346
#[serde(default)]
4447
pub tools: Vec<types::ToolDefinition>,
48+
tool_choice: Option<ToolChoice>,
4549
}
4650

4751
impl Params {
@@ -113,6 +117,7 @@ pub enum ServiceInputEvent {
113117
Prompt {
114118
text: String,
115119
},
120+
#[serde(rename_all = "camelCase")]
116121
SessionUpdate {
117122
#[serde(skip_serializing_if = "Option::is_none")]
118123
instructions: Option<String>,
@@ -122,6 +127,8 @@ pub enum ServiceInputEvent {
122127
temperature: Option<f32>,
123128
#[serde(skip_serializing_if = "Option::is_none")]
124129
tools: Option<Vec<types::ToolDefinition>>,
130+
#[serde(skip_serializing_if = "Option::is_none")]
131+
tool_choice: Option<ToolChoice>,
125132
},
126133
}
127134

@@ -257,11 +264,6 @@ impl Client {
257264
send_update = true;
258265
};
259266

260-
if !params.tools.is_empty() {
261-
session.tools = Some(params.tools);
262-
send_update = true;
263-
}
264-
265267
if let Some(voice) = params.voice {
266268
session.voice = Some(voice);
267269
send_update = true;
@@ -272,6 +274,16 @@ impl Client {
272274
send_update = true;
273275
}
274276

277+
if !params.tools.is_empty() {
278+
session.tools = Some(params.tools);
279+
send_update = true;
280+
}
281+
282+
if let Some(tool_choice) = params.tool_choice {
283+
session.tool_choice = Some(tool_choice);
284+
send_update = true;
285+
}
286+
275287
if send_update {
276288
self.send_client_event(ClientEvent::SessionUpdate(client_event::SessionUpdate {
277289
event_id: None,
@@ -426,13 +438,15 @@ impl Client {
426438
voice,
427439
temperature,
428440
tools,
441+
tool_choice,
429442
} => {
430443
let event = ClientEvent::SessionUpdate(client_event::SessionUpdate {
431444
session: types::Session {
432-
tools,
433445
instructions,
434446
voice,
435447
temperature,
448+
tools,
449+
tool_choice,
436450
..Default::default()
437451
},
438452
..Default::default()

0 commit comments

Comments
 (0)