@wwl2755 wwl2755 commented Oct 7, 2025

Fix some of #23888

Enable audio in video in Qwen2.5-Omni in V1 engine.

Same purpose as #26156, but using a different and simpler method from @ywang96. The basic idea is to create two placeholders, one for video and one for audio, with the same start_idx, and use "is_embed" to differentiate them.

Basic flow

<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>Describe the content of the video<|im_end|>
# no audio placeholder in the prompt

-> "video": [
       PlaceholderFeaturesInfo(
           start_idx=4,
           tokens=[151659, 151655, 151655, 151654, 151654, 151660],
           is_embed=[False, True, True, False, False, False]
       )
   ]
-> "audio": [
       PlaceholderFeaturesInfo(
           start_idx=4,
           tokens=[151659, 151655, 151655, 151654, 151654, 151660],
           is_embed=[False, False, False, True, True, False]
       )
   ]

-> <|im_start|>user\n<|vision_bos|><|audio_bos|><|VIDEO|>*2<|AUDIO|>*2<|audio_eos|><|vision_eos|>Describe the content of the video<|im_end|>
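A minimal sketch of how two placeholders can share one token span while claiming disjoint embedding slots. `PlaceholderSketch` and `embed_positions` are hypothetical stand-ins written for illustration, not vLLM's actual `PlaceholderFeaturesInfo`; the start index, token IDs, and masks are copied from the flow above.

```python
from dataclasses import dataclass


@dataclass
class PlaceholderSketch:
    """Hypothetical stand-in for a placeholder record (not the vLLM class)."""

    start_idx: int
    tokens: list[int]
    is_embed: list[bool]

    def embed_positions(self) -> list[int]:
        # Absolute prompt positions that this modality's embeddings fill.
        return [self.start_idx + i for i, flag in enumerate(self.is_embed) if flag]


# Both modalities point at the same token span and the same start_idx.
tokens = [151659, 151655, 151655, 151654, 151654, 151660]
video = PlaceholderSketch(4, tokens, [False, True, True, False, False, False])
audio = PlaceholderSketch(4, tokens, [False, False, False, True, True, False])

video_positions = video.embed_positions()
audio_positions = audio.embed_positions()

# The masks never overlap, so each position receives at most one embedding.
overlap = set(video_positions) & set(audio_positions)
```

With these inputs the video mask selects positions 5 and 6 and the audio mask selects positions 7 and 8, so the two modalities can be merged into the same span without colliding.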

Known limitation

This PR assumes that the number of video items and the number of audio items match exactly when use_audio_in_video is enabled, as in the example above.

Test

python examples/offline_inference/qwen2_5_omni/only_thinker.py -q use_audio_in_video

INFO 10-09 04:02:38 [llm.py:340] Supported_tasks: ['generate']
Adding requests: 100%|██████████| 1/1 [00:06<00:00, 6.42s/it]
Processed prompts: 100%|██████████| 1/1 [00:01<00:00, 1.14s/it, est. speed input: 2370.76 toks/s, output: 80.69 toks/s]

The video shows a baby sitting on a bed, wearing glasses, and holding a book. The baby seems to be looking at the book and turning the pages. I'm not sure what the baby says, but it could be something like "book" or "read". So, the text of what the baby says is "book" or "read". If you have any other questions about the video or anything else, feel free to let me know.
@mergify mergify bot added documentation Improvements or additions to documentation qwen Related to Qwen models v1 labels Oct 7, 2025

mergify bot commented Oct 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wwl2755.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>

mergify bot commented Oct 8, 2025

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Comment on lines 388 to 390
use_audio_in_video = all(
    item["use_audio_in_video"].data for item in video_items
)
Contributor Author

This existing code seems to assume all video inputs should have a paired audio to enable use_audio_in_video.
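A small sketch of what the `all(...)` expression implies. Plain booleans stand in for `item["use_audio_in_video"].data`, which is an assumption made to keep the example self-contained.

```python
# Per-video flags; plain bools stand in for item["use_audio_in_video"].data.
both_enabled = [True, True]
mixed = [True, False]
no_videos = []

flag_both = all(both_enabled)   # every video opts in
flag_mixed = all(mixed)         # a single opt-out disables the feature
flag_empty = all(no_videos)     # all() over an empty iterable is True
```

A single video without paired audio turns the flag off for the whole request, and an empty list evaluates to `True`. Whether the surrounding code guards against the empty case is not visible in this excerpt.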

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Comment on lines 853 to +859
      second_per_grid_ts.append(t)
  if (t := mm_input.get("audio_feature_lengths")) is not None:
      audio_feature_lengths.append(t)
- if mm_input.get("use_audio_in_video") is True:
-     use_audio_in_video = True
+ # Check for use_audio_in_video
+ use_audio_in_video_value = mm_input.get("use_audio_in_video")
+ if use_audio_in_video_value is not None:
+     use_audio_in_video = bool(use_audio_in_video_value.item())


P1: Preserve any use_audio_in_video flag across batch

The new loop in _init_mrope_positions overwrites use_audio_in_video on every multimodal item (use_audio_in_video = bool(use_audio_in_video_value.item())). When a batch mixes requests that require audio-in-video with ones that do not, the last item processed can reset the flag to False, so get_mrope_input_positions is called without audio-in-video handling even though earlier requests needed it. This yields incorrect rotary positions for those prompts. The flag should be accumulated (e.g., OR’ed) instead of overwritten so that any request enabling audio-in-video keeps the global flag true.
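The difference between overwriting and accumulating the flag can be sketched as follows. The helpers `last_wins` and `any_wins` are hypothetical names, and plain `bool`/`None` values stand in for the tensor-valued `mm_input.get("use_audio_in_video")` entries.

```python
def last_wins(flags):
    """Overwrite on every item, as in the loop under review."""
    result = False
    for value in flags:
        if value is not None:
            result = bool(value)
    return result


def any_wins(flags):
    """Accumulate with OR, as the review comment suggests."""
    result = False
    for value in flags:
        if value is not None:
            result = result or bool(value)
    return result


# One item enables audio-in-video, one omits the field, one disables it.
batch = [True, None, False]
overwritten = last_wins(batch)
accumulated = any_wins(batch)
```

For this input the overwriting version ends up `False` because the last item resets it, while the OR-accumulating version stays `True`.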

Contributor Author

How to handle use_audio_in_video and non-use_audio_in_video items mixed within one request is an open problem. The scope of this PR is limited to the case where all video items share the same value for this field.
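One way to make that scoping assumption explicit is to reject mixed values up front. This is only a sketch: `resolve_use_audio_in_video` is a hypothetical helper that does not appear in the PR, and it takes plain booleans rather than the real multimodal items.

```python
def resolve_use_audio_in_video(flags: list[bool]) -> bool:
    """Return the shared flag, or fail if video items disagree (hypothetical)."""
    unique = set(flags)
    if len(unique) > 1:
        raise ValueError(
            f"mixed use_audio_in_video values are not supported, got {flags}"
        )
    # No video items at all: treat the feature as disabled in this sketch.
    return unique.pop() if unique else False


shared = resolve_use_audio_in_video([True, True])
empty = resolve_use_audio_in_video([])

try:
    resolve_use_audio_in_video([True, False])
    mixed_rejected = False
except ValueError:
    mixed_rejected = True
```

Failing loudly on mixed input keeps the limitation visible to callers instead of silently picking one value.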


wwl2755 commented Oct 9, 2025

This should be ready for review. Please feel free to take a look when you are free~ @DarkLight1337 @ywang96 @Isotr0py

Comment on lines 409 to 412
(
    prompt_ids,
    mm_placeholders,
) = self._apply_prompt_updates(
Member

@DarkLight1337 DarkLight1337 Oct 9, 2025

Suggested change
- (
-     prompt_ids,
-     mm_placeholders,
- ) = self._apply_prompt_updates(
+ prompt_ids, mm_placeholders = self._apply_prompt_updates(

Nit: Avoid unnecessary lines. Same below, and can also do the same for self._validate_mm_placeholders

if num_audios != num_videos:
    raise ValueError(
        f"use_audio_in_video requires equal number of audio and video items, "
        f"got audio={num_audios}, video={num_videos}"
Member

Suggested change
- f"got audio={num_audios}, video={num_videos}"
+ f"got {num_audios=}, {num_videos=}"
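For reference, a short sketch of what the two f-string forms produce. The `=` specifier is standard Python 3.8+ syntax; the counts below are made-up sample values.

```python
# Sample counts chosen only for illustration.
num_audios, num_videos = 1, 2

old_style = f"got audio={num_audios}, video={num_videos}"
new_style = f"got {num_audios=}, {num_videos=}"
```

The suggested form prints the variable names themselves (`num_audios=1, num_videos=2`), so the message stays in sync with the code if a variable is renamed.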
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>

mergify bot commented Oct 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wwl2755.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 14, 2025
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>