5757from websockets .asyncio .client import ClientConnection
5858
5959from agents .handoffs import Handoff
60+ from agents .realtime ._default_tracker import ModelAudioTracker
6061from agents .tool import FunctionTool , Tool
6162from agents .util ._types import MaybeAwaitable
6263
7273 RealtimeModel ,
7374 RealtimeModelConfig ,
7475 RealtimeModelListener ,
76+ RealtimePlaybackState ,
77+ RealtimePlaybackTracker ,
7578)
7679from .model_events import (
7780 RealtimeModelAudioDoneEvent ,
@@ -133,11 +136,10 @@ def __init__(self) -> None:
133136 self ._websocket_task : asyncio .Task [None ] | None = None
134137 self ._listeners : list [RealtimeModelListener ] = []
135138 self ._current_item_id : str | None = None
136- self ._audio_start_time : datetime | None = None
137- self ._audio_length_ms : float = 0.0
139+ self ._audio_state_tracker : ModelAudioTracker = ModelAudioTracker ()
138140 self ._ongoing_response : bool = False
139- self ._current_audio_content_index : int | None = None
140141 self ._tracing_config : RealtimeModelTracingConfig | Literal ["auto" ] | None = None
142+ self ._playback_tracker : RealtimePlaybackTracker | None = None
141143
142144 async def connect (self , options : RealtimeModelConfig ) -> None :
143145 """Establish a connection to the model and keep it alive."""
@@ -146,6 +148,8 @@ async def connect(self, options: RealtimeModelConfig) -> None:
146148
147149 model_settings : RealtimeSessionModelSettings = options .get ("initial_model_settings" , {})
148150
151+ self ._playback_tracker = options .get ("playback_tracker" , RealtimePlaybackTracker ())
152+
149153 self .model = model_settings .get ("model_name" , self .model )
150154 api_key = await get_api_key (options .get ("api_key" ))
151155
@@ -294,47 +298,75 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
294298 if event .start_response :
295299 await self ._send_raw_message (OpenAIResponseCreateEvent (type = "response.create" ))
296300
301+ def _get_playback_state (self ) -> RealtimePlaybackState :
302+ if self ._playback_tracker :
303+ return self ._playback_tracker .get_state ()
304+
305+ if last_audio_item_id := self ._audio_state_tracker .get_last_audio_item ():
306+ item_id , item_content_index = last_audio_item_id
307+ audio_state = self ._audio_state_tracker .get_state (item_id , item_content_index )
308+ if audio_state :
309+ elapsed_ms = (
310+ datetime .now () - audio_state .initial_received_time
311+ ).total_seconds () * 1000
312+ return {
313+ "current_item_id" : item_id ,
314+ "current_item_content_index" : item_content_index ,
315+ "elapsed_ms" : elapsed_ms ,
316+ }
317+
318+ return {
319+ "current_item_id" : None ,
320+ "current_item_content_index" : None ,
321+ "elapsed_ms" : None ,
322+ }
323+
297324 async def _send_interrupt (self , event : RealtimeModelSendInterrupt ) -> None :
298- if not self ._current_item_id or not self ._audio_start_time :
325+ playback_state = self ._get_playback_state ()
326+ current_item_id = playback_state .get ("current_item_id" )
327+ current_item_content_index = playback_state .get ("current_item_content_index" )
328+ elapsed_ms = playback_state .get ("elapsed_ms" )
329+ if current_item_id is None or elapsed_ms is None :
330+ logger .info (
331+ "Skipping interrupt. "
332+ f"Item id: { current_item_id } , "
333+ f"elapsed ms: { elapsed_ms } , "
334+ f"content index: { current_item_content_index } "
335+ )
299336 return
300337
301- await self ._cancel_response ()
302-
303- elapsed_time_ms = (datetime .now () - self ._audio_start_time ).total_seconds () * 1000
304- if elapsed_time_ms > 0 and elapsed_time_ms < self ._audio_length_ms :
338+ current_item_content_index = current_item_content_index or 0
339+ if elapsed_ms > 0 :
305340 await self ._emit_event (
306341 RealtimeModelAudioInterruptedEvent (
307- item_id = self . _current_item_id ,
308- content_index = self . _current_audio_content_index or 0 ,
342+ item_id = current_item_id ,
343+ content_index = current_item_content_index ,
309344 )
310345 )
311346 converted = _ConversionHelper .convert_interrupt (
312- self . _current_item_id ,
313- self . _current_audio_content_index or 0 ,
314- int (elapsed_time_ms ),
347+ current_item_id ,
348+ current_item_content_index ,
349+ int (elapsed_ms ),
315350 )
316351 await self ._send_raw_message (converted )
352+ await self ._cancel_response ()
317353
318- self ._current_item_id = None
319- self ._audio_start_time = None
320- self ._audio_length_ms = 0.0
321- self ._current_audio_content_index = None
354+ self ._audio_state_tracker .on_interrupted ()
355+ if self ._playback_tracker :
356+ self ._playback_tracker .on_interrupted ()
322357
323358 async def _send_session_update (self , event : RealtimeModelSendSessionUpdate ) -> None :
324359 """Send a session update to the model."""
325360 await self ._update_session_config (event .session_settings )
326361
327362 async def _handle_audio_delta (self , parsed : ResponseAudioDeltaEvent ) -> None :
328363 """Handle audio delta events and update audio tracking state."""
329- self ._current_audio_content_index = parsed .content_index
330364 self ._current_item_id = parsed .item_id
331- if self ._audio_start_time is None :
332- self ._audio_start_time = datetime .now ()
333- self ._audio_length_ms = 0.0
334365
335366 audio_bytes = base64 .b64decode (parsed .delta )
336- # Calculate audio length in ms using 24KHz pcm16le
337- self ._audio_length_ms += self ._calculate_audio_length_ms (audio_bytes )
367+
368+ self ._audio_state_tracker .on_audio_delta (parsed .item_id , parsed .content_index , audio_bytes )
369+
338370 await self ._emit_event (
339371 RealtimeModelAudioEvent (
340372 data = audio_bytes ,
@@ -344,10 +376,6 @@ async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
344376 )
345377 )
346378
347- def _calculate_audio_length_ms (self , audio_bytes : bytes ) -> float :
348- """Calculate audio length in milliseconds for 24KHz PCM16LE format."""
349- return len (audio_bytes ) / 24 / 2
350-
351379 async def _handle_output_item (self , item : ConversationItem ) -> None :
352380 """Handle response output item events (function calls and messages)."""
353381 if item .type == "function_call" and item .status == "completed" :
0 commit comments