-
- Notifications
You must be signed in to change notification settings - Fork 12.1k
[Feature] Generic Model Support via TrainableAttention and ModelRegistry parallelism constructor callback #28685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| Documentation preview: https://vllm--28685.org.readthedocs.build/en/28685/ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces significant new functionality to support generic models in vLLM, including a trainable flash attention layer and a mechanism for registering models with a parallelism-aware constructor. The changes are well-structured and include comprehensive examples and tests. My review focuses on potential race conditions and robustness improvements in the new code and examples. I've identified a thread-safety issue in the TrainableFlashAttention layer and some hardcoded values in the new example that could cause issues in multi-user environments. Overall, this is a great addition to vLLM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
94fb433 to 4ba8efb Compare 257a8bd to 2421b47 Compare There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For an RL framework to use such vLLM-version of torchtitan models, where should such models eventually be put?
A related question would be, for batch invariance implementation we can use the "canonical" version of modules (Attention, RoPE, etc.) showcased here, then for other versions (e.g. FlexAttention with arbitrary masks), how should we configure them? Is the idea that we configure the RL framework to just not call into these compat functions including replace_with_trainable_attention?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where should such models eventually be put
that's up to the RL framework
batch invariance implementation
out of scope for this PR. every kernel needs to be audited in a batch invariant impl - user code has no guarantees about that. we'll need some type of test that can highlight issues and allow users to pick which tools are needed (overriding pytorch, torch.compile, replace_with_trainable_attn)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From what I read (mainly TrainableMLA, TrainableAttention, model_excutor/custom_models), this PR doesn't change vLLM's interface of supporting generic model that much (basically following instructions here: https://docs.vllm.ai/en/latest/contributing/model/registration/) , but provide an practical example that how user could replace a customized model's attention with vllm's attention (so it has kv cache capability). This PR also provide a model wrapper (DeepSeekV3TorchTitanForCausalLM) of Torchtitan model which implemented APIs needed by vLLM. Please correct me if it's wrong.
There are 2 questions to be discussed:
-
Where does the customize model wrapper should be placed?
I'd prefer vLLM could have an example of how to plug in customized model - I was struggling with the doc to run customized model. However, the model wrapper is closed related to torchtitan model definition, even the TrainableMLA. I would prefer to put it in torchtitan, and showcase "vLLM+custom model" with a simpler nn.Module model. -
How does it work with batch-invariant model in the future? (This might be a very long-term goal)
We want the "canonical" model definition between training and inference, because we want to run batch-invariant mode and achieve (almost) bit-wise identity between training and inference. Currently TrainableAttention has 2 separate path for training and inference. Do you have some idea how we should achieve bitwise-identity here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
5b43931 to 98c3640 Compare - Fix bug in TrainableFlashAttention where batch_size/seq_len were overwritten in training path, causing incorrect output shapes - Use original_batch_size/original_seq_len to preserve input dimensions - Remove tests/test_generic_model_support.py (tests moved to tests/models/test_generic_models.py) Signed-off-by: Bram Wasti <bwasti@meta.com>
Implements Multi-Head Latent Attention (MLA) for DeepSeek V3 using vLLM's native attention backend instead of manual implementations. Key features: - TrainableMLA layer that delegates to vLLM's MLAAttention for KV cache - Position-aware RoPE indexing for batched/chunked prefill - Proper KV cache spec registration (MLAAttentionSpec) - Converts TorchTitan's complex freqs_cis to real format before dtype conversion Architecture pattern: 1. Import TorchTitan model as-is 2. Replace attention layers with TrainableMLA (creates vLLM MLAAttention internally) 3. Implement vLLM model interface (get_input_embeddings, forward, compute_logits, load_weights, get_kv_cache_spec) 4. Register with ModelRegistry Files: - vllm/model_executor/layers/trainable_mla_attention.py: TrainableMLA layer - examples/offline_inference/deepseek_v3_torchtitan.py: DeepSeek V3 model wrapper Signed-off-by: Bram Wasti <bwasti@meta.com>
- Add qwen3_torchtitan.py: Example of Qwen3 + TorchTitan integration - Auto-format deepseek_v3_torchtitan.py and trainable_mla_attention.py (ruff) Signed-off-by: Bram Wasti <bwasti@meta.com>
Signed-off-by: Bram Wasti <bwasti@meta.com>
Major cleanup of TorchTitan integration to create reusable, generic components for integrating any custom model implementation with vLLM. Key changes: 1. Created generic base class VLLMModelForCausalLM (was TorchTitanModelWrapper) - vllm/model_executor/models/custom_model_wrapper.py - Enforces vLLM interface for any external model implementation 2. Created reusable attention replacement utilities - vllm/model_executor/layers/attention_replacement.py - replace_with_trainable_attention() unified API for both TrainableFlashAttention and TrainableMLA 3. Created custom model integration utilities - vllm/model_executor/custom_models/utils.py - load_external_weights() for generic weight loading - convert_freqs_cis_to_real() for RoPE conversion - create_mla_kv_cache_spec() for MLA KV cache - store_positions_in_context() for position management 4. Simplified examples (removed verbose prints, conditional logic) - deepseek_v3_torchtitan.py: 516 → 289 lines (44% reduction) - qwen3_torchtitan.py: 471 → 266 lines (44% reduction) - Both inherit from VLLMModelForCausalLM - Both use reusable vLLM utilities 5. Fixed standalone testing (added AssertionError handling) - trainable_mla_attention.py: Handle vLLM context gracefully - trainable_attention.py: Handle forward context gracefully All examples tested and working. Signed-off-by: Bram Wasti <bwasti@meta.com>
- Integrate TorchTitan's apply_non_moe_tp() for proper weight sharding - TP logic is explicitly inlined in build functions for clarity - Remove data parallelism (broadcast) from Qwen3 in favor of TP - TP happens after model creation but before dtype conversion - Add test_deepseek_v3_tp.py for TP validation This enables multi-GPU inference with actual tensor parallelism, not just data parallelism. Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit reorganizes the custom model integration code for better discoverability and cleaner API surface: **vLLM core changes:** - Move trainable attention files to vllm/model_executor/custom_models/ - trainable_attention.py (TrainableFlashAttention) - trainable_mla_attention.py (TrainableMLA) - attention_replacement.py (replace_with_trainable_attention) - custom_model_wrapper.py (VLLMModelForCausalLM base class) - Update vllm/model_executor/custom_models/__init__.py to export all API - Update all import paths throughout codebase: - Example files (deepseek_v3, qwen3, megatron example) - Test files (test_generic_models.py) - Documentation (basic.md) - Internal references (layers/__init__.py) **Example reorganization:** - Create examples/custom_models/ directory - Move deepseek_v3_torchtitan.py and qwen3_torchtitan.py to examples/custom_models/ - Add examples/custom_models/README.md with usage instructions **Test reorganization:** - Create tests/custom_models/ directory - Move test_deepseek_v3_tp.py to tests/custom_models/ - Convert to proper pytest test with fixtures and assertions **Documentation:** - Add docs/contributing/model/custom.md with comprehensive guide - Quick start example - Core components reference - Advanced topics (TP integration, multi-GPU) - Complete examples and troubleshooting **Benefits:** - Single unified import: from vllm.model_executor.custom_models import ... - Clear separation of custom model code from vLLM internals - Better discoverability for users - Proper test organization - Comprehensive documentation All imports updated, tests pass, documentation complete. Signed-off-by: Bram Wasti <bwasti@meta.com>
Add simple benchmark script to compare DeepSeek V3 TorchTitan custom model against vLLM's built-in implementation. Features: - 100 requests with configurable batch size (default 32) - Measures throughput, latency (P50/P90/P99), and requests/sec - Configurable TP size, prompt length, output tokens - Warmup phase for stable measurements - Side-by-side comparison when both models available Usage: python examples/custom_models/benchmark_deepseek_v3.py \ --model deepseek-ai/DeepSeek-V3-Base \ --tp 8 \ --num-requests 100 \ --max-batch-size 32 Update README with benchmarking instructions and metrics explanation. Signed-off-by: Bram Wasti <bwasti@meta.com>
Add sys.path manipulation to allow running the benchmark script directly from anywhere without installing vLLM as a package. The script now: 1. Finds the vLLM root directory (../../ from script location) 2. Adds it to sys.path 3. Imports the custom model registration module This allows users to run: python examples/custom_models/benchmark_deepseek_v3.py --help Without needing to install vLLM or set PYTHONPATH. Signed-off-by: Bram Wasti <bwasti@meta.com>
Add max_model_len parameter (default 8192) to limit KV cache allocation. Without this, DeepSeek V3 tries to allocate cache for 163K tokens which causes OOM on most GPUs. Changes: - Add --max-model-len argument (default 8192) - Pass max_model_len to LLM() to limit KV cache allocation - Display max_model_len in benchmark output This allows the benchmark to run on GPUs with limited memory by controlling the KV cache size. Signed-off-by: Bram Wasti <bwasti@meta.com>
Add --use-builtin and --run-both flags to enable comparing the custom TorchTitan model against vLLM's built-in DeepSeek V3 implementation. New flags: - --use-builtin: Run only built-in vLLM model (don't import custom) - --run-both: Run both custom and built-in for side-by-side comparison - --skip-custom: Skip custom model benchmark The custom model is now imported conditionally, allowing users to benchmark vLLM's native implementation without the TorchTitan overlay. Usage: # Custom only (default) python examples/custom_models/benchmark_deepseek_v3.py --num-requests 10 # Built-in only python examples/custom_models/benchmark_deepseek_v3.py --use-builtin --num-requests 10 # Both for comparison python examples/custom_models/benchmark_deepseek_v3.py --run-both --num-requests 10 Signed-off-by: Bram Wasti <bwasti@meta.com>
Add clear instructions for comparing custom TorchTitan model vs vLLM's built-in DeepSeek V3 implementation. Show examples of: - Running custom model only (default) - Running built-in only (--use-builtin) - Comparing both side-by-side (--run-both) - Expected comparison table output Include PYTHONPATH prefix in all examples since the benchmark needs to import vLLM from source. Signed-off-by: Bram Wasti <bwasti@meta.com>
vLLM's LLM API doesn't expose per-request timing metrics on output objects, causing P50/P90/P99 latencies to show as 0.00. Fix: - For small benchmarks (≤20 requests): Generate requests one at a time and manually track start/end times for accurate per-request latency - For large benchmarks (>20 requests): Use batched generation for speed, but P50/P90/P99 will show as N/A with helpful message This provides accurate latency percentiles for small tests while keeping large benchmarks fast. Example output: P50/P90/P99: N/A (use --num-requests ≤20 for per-request latency) Signed-off-by: Bram Wasti <bwasti@meta.com>
Previously, per-request latency percentiles (P50/P90/P99) were only available for ≤20 requests due to the overhead of one-at-a-time generation. This commit implements batched timing tracking: - Generate requests in small batches (batch_size=8) - Track timing for each batch - Estimate per-request latency as batch_time / batch_size This approach provides latency percentiles for benchmarks of any size (10, 100, or more requests) while maintaining reasonable performance. Signed-off-by: Bram Wasti <bwasti@meta.com>
Changes: 1. Add --track-latency flag to control fine-grained latency tracking - Default: Use full max_batch_size for maximum throughput - With flag: Use small batches (8) for P50/P90/P99 metrics - This fixes the throughput degradation caused by always using small batches 2. Improve naming throughout benchmark: - "CUSTOM MODEL" → "TorchTitan DeepSeek (Custom Implementation)" - "BUILT-IN vLLM MODEL" → "Built-in DeepSeek (vLLM Native Implementation)" - Column header "Custom" → "TorchTitan" - Mode descriptions updated for clarity 3. Update help text: - P50/P90/P99 message now says "use --track-latency for percentiles" This allows users to choose between maximum throughput (default) and detailed latency metrics (--track-latency). Signed-off-by: Bram Wasti <bwasti@meta.com>
- Fix line length violations (E501) in custom model files - Simplify nested if statements and use ternary operators (SIM102, SIM108) - Remove unused variable in trainable_mla_attention.py (F841) - Apply ruff formatting to test file Signed-off-by: Bram Wasti <bwasti@meta.com>
Signed-off-by: Bram Wasti <bwasti@meta.com>
- Add type annotations for **kwargs parameters (Any) - Add type annotations for weights_iter parameters (Iterator[tuple[str, torch.Tensor]]) - Add type annotations for sampling_metadata parameters (Any) - Add return type annotations for process group methods (dist.ProcessGroup | None) - Add return type annotation for create_mla_kv_cache_spec (Any) - Update imports to use collections.abc.Iterator (modern Python) Signed-off-by: Bram Wasti <bwasti@meta.com>
The code was incorrectly splitting freqs_cis at head_dim instead of at half of the last dimension. When using convert_freqs_cis_to_real, the tensor has shape [max_seq_len, head_dim] with cos and sin concatenated, so splitting at head_dim would make sin an empty tensor and cause broadcast failures. Changes: - Split at half of last dimension instead of head_dim - Add repeat_interleave to expand cos/sin to full head_dim - Update comments to reflect correct tensor shapes This fixes runtime errors when using real-format RoPE frequencies from TorchTitan integration. Signed-off-by: Bram Wasti <bwasti@meta.com>
- Remove multiple consecutive blank lines (MD012) - Add blank lines before lists (MD032) - Fix ordered list item indentation (MD029) - Fix unordered list indentation (MD007) - Add blank line before fenced code block (MD031) - Add language specifier to fenced code block (MD040) Signed-off-by: Bram Wasti <bwasti@meta.com>
Remove blank lines between ordered list items to maintain list continuity for markdownlint (MD029). Signed-off-by: Bram Wasti <bwasti@meta.com>
Indent code blocks within ordered list items to maintain list continuity. This prevents markdownlint from treating the code blocks as list terminators (MD029). Signed-off-by: Bram Wasti <bwasti@meta.com>
Skip all DeepSeek V3 tensor parallelism tests while the custom model examples are still under development. These can be re-enabled once the implementation is stable. Signed-off-by: Bram Wasti <bwasti@meta.com>
Make examples/ a proper Python package so that pytest can import from examples.custom_models during test collection. Signed-off-by: Bram Wasti <bwasti@meta.com>
This reverts commit fb5e6c2. Signed-off-by: Bram Wasti <bwasti@meta.com>
98c3640 to 95a402b Compare | print("Built-in DeepSeek (vLLM Native Implementation)") | ||
| print("=" * 70) | ||
| results["builtin"] = benchmark_model( | ||
| model_name=args.model, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @bwasti Thanks for the great PR, it helps me a lot of understanding vllm model registering... I have a question about running torchtitan model here:
-
The model_name= "deepseek-ai/DeepSeek-V3-Base", so its running the following code:
llm = LLM( model="deepseek-ai/DeepSeek-V3-Base", trust_remote_code=True, tensor_parallel_size=8, ), which will downloadconfig.jsonfrom HF, and the architecture isDeepseekV3ForCausalLM -
When registering model in
examples/custom_models/deepseek_torchtitan.py, the registered model architecture is "DeepSeekV3TorchTitan"
How does vllm find registered customized model, as these 2 model architectures are different? I run the following command python examples/custom_models/benchmark_deepseek_v3.py --run-both --num-requests 10 . Did I missed something here? Thank you!
| | ||
| # Get RoPE cache | ||
| seqlen = h.shape[1] if h.dim() == 3 else h.shape[0] | ||
| rope_cache = self.model.rope_cache[:seqlen] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
qq: should this be self.model.rope_cache[positions]?


Purpose
Implement RFC #28326 to enable users to easily integrate custom models with vLLM for both training and inference, with support for external parallelism libraries (e.g., Megatron-LM, FSDP, DeepSpeed).
Enables use cases like:
Test Plan
Test Result
Both pass
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.