Commit c8fa08d

nv-guomingz and kaiyux authored
doc: update cuda_graph_config usage part in DS R1 docs (#5796)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
1 parent 5203a0f commit c8fa08d

3 files changed: +41 −32 lines changed

docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md

Lines changed: 27 additions & 27 deletions
@@ -195,20 +195,20 @@ We are seeing meaningful speedup using FP8 KV cache, thus refreshing the numbers
 #### Benchmark
 ```bash
 cat >./extra-llm-api-config.yml <<EOF
-use_cuda_graph: true
-cuda_graph_padding_enabled: true
-cuda_graph_batch_sizes:
-- 896
-- 512
-- 256
-- 128
-- 64
-- 32
-- 16
-- 8
-- 4
-- 2
-- 1
+cuda_graph_config:
+  padding_enabled: true
+  batch_sizes:
+  - 896
+  - 512
+  - 256
+  - 128
+  - 64
+  - 32
+  - 16
+  - 8
+  - 4
+  - 2
+  - 1
 print_iter_log: true
 kv_cache_dtype: fp8
 enable_attention_dp: true
@@ -262,19 +262,19 @@ python ${YOUR_WORK_PATH}/benchmarks/cpp/prepare_dataset.py \
 YOUR_DATA_PATH=./dataset.txt
 
 cat >./extra-llm-api-config.yml <<EOF
-use_cuda_graph: true
-cuda_graph_padding_enabled: true
-cuda_graph_batch_sizes:
-- 1
-- 2
-- 4
-- 8
-- 16
-- 32
-- 64
-- 128
-- 256
-- 384
+cuda_graph_config:
+  padding_enabled: true
+  batch_sizes:
+  - 1
+  - 2
+  - 4
+  - 8
+  - 16
+  - 32
+  - 64
+  - 128
+  - 256
+  - 384
 print_iter_log: ${PRINT_ITER_LOG}
 enable_attention_dp: true
 EOF
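
Both hunks in this file apply the same migration: the flat `use_cuda_graph`, `cuda_graph_padding_enabled`, and `cuda_graph_batch_sizes` keys are replaced by a nested `cuda_graph_config` section. For reference, the heredoc from the second hunk now generates the following file (reconstructed from the hunk above; any settings outside the hunk are unchanged and not shown):

```bash
# extra-llm-api-config.yml after this change, using the nested cuda_graph_config layout
cat >./extra-llm-api-config.yml <<EOF
cuda_graph_config:
  padding_enabled: true
  batch_sizes:
  - 1
  - 2
  - 4
  - 8
  - 16
  - 32
  - 64
  - 128
  - 256
  - 384
print_iter_log: ${PRINT_ITER_LOG}
enable_attention_dp: true
EOF
```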

docs/source/blogs/tech_blog/blog3_Optimizing_DeepSeek_R1_Throughput_on_NVIDIA_Blackwell_GPUs.md

Lines changed: 7 additions & 1 deletion
@@ -151,7 +151,13 @@ These optimizations target the overall execution flow, scheduling, and resource
 
 * CUDA Graph
 
-This had a significant **22% E2E performance impact** for throughput scenarios. CUDA Graphs allow capturing a sequence of CUDA operations and launching them as a single unit, drastically reducing kernel launch overheads. This is particularly beneficial for models with many small kernels, and particularly on the PyTorch flow, because the python host code normally executes slower than C++. Since the CUDA Graph freezes the kernel launch parameters, which is normally associated with the tensor shapes, it can only be safely used with static shape, meaning that different CUDA graphs need to be captured for different batch sizes. Each graph will have some cost of memory usage, and capturing time, thus we cannot capture every possible CUDA graph for all possible batches. For the non-captured batch sizes, PyTorch eager mode code will be executed. There is a feature called CUDA Graph padding in TensorRT-LLM, which is a good trade-off between the number of CUDA Graphs and the CUDA Graph hit ratio; it tries to pad a batch to the nearest one with a captured CUDA Graph. Normally you should enable the CUDA Graph padding feature to increase the CUDA Graph hit rate, but the padding itself has some overhead due to wasted tokens computation. Users can opt-out the CUDA Graph padding feature to see the perf benefits, by setting the `cuda_graph_padding_enabled` to false, see API here [Pytorch backend config](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/config.py#L41)
+This had a significant **22% E2E performance impact** for throughput scenarios.
+
+CUDA Graphs allow capturing a sequence of CUDA operations and launching them as a single unit, drastically reducing kernel launch overheads. This is particularly beneficial for models with many small kernels, especially on the PyTorch flow, because the Python host code normally executes more slowly than C++. Since a CUDA Graph freezes the kernel launch parameters, which are normally tied to the tensor shapes, it can only be used safely with static shapes, meaning that different CUDA Graphs need to be captured for different batch sizes. Each graph costs some memory and capture time, so we cannot capture a CUDA Graph for every possible batch size. For non-captured batch sizes, PyTorch eager mode code is executed.
+
+CUDA Graph padding in TensorRT-LLM is a good trade-off between the number of CUDA Graphs and the CUDA Graph hit ratio; it pads a batch to the nearest batch size with a captured CUDA Graph. Normally you should enable CUDA Graph padding to increase the hit rate, but the padding itself has some overhead due to wasted token computation.
+
+Users can opt out of the CUDA Graph padding feature to see the perf benefits by setting `cuda_graph_config:\n  padding_enabled: False`; see the API here: [Pytorch backend config](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/config.py#L41)
 
 * Overlap Scheduler:
 
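Following the heredoc convention used elsewhere in this commit, a minimal sketch of the padding opt-out described above could look like the snippet below; the file name and the heredoc command are illustrative, and only the `cuda_graph_config` / `padding_enabled` key comes from the paragraph above:

```bash
# Minimal sketch: disable CUDA Graph padding via the extra options file.
# Only the padding_enabled key is taken from the text above; any other settings
# you need (batch sizes, KV cache dtype, ...) would be added alongside it.
cat >./extra-llm-api-config.yml <<EOF
cuda_graph_config:
  padding_enabled: false
EOF
```

The resulting file would then be consumed in the same way as the `extra-llm-api-config.yml` snippets shown in the first file of this commit.
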
tests/integration/defs/perf/pytorch_model_config.py

Lines changed: 7 additions & 4 deletions
@@ -65,9 +65,10 @@ def get_model_yaml_config(model_label: str,
         ],
         'config': {
             'enable_attention_dp': True,
-            'cuda_graph_padding_enabled': True,
-            'cuda_graph_batch_sizes':
-            [1, 2, 4, 8, 16, 32, 64, 128, 256, 384]
+            'cuda_graph_config': {
+                'padding_enabled': True,
+                'batch_sizes': [1, 2, 4, 8, 16, 32, 64, 128, 256, 384]
+            }
         }
     },
     # DeepSeek R1 model with specific batch size 128
@@ -76,7 +77,9 @@ def get_model_yaml_config(model_label: str,
         'deepseek_r1-bench-pytorch-float16-maxbs:128-maxnt:1127-input_output_len:1000,2000-quant:fp8-reqs:5120-con:1024-ep:8-gpus:8',
         'config': {
             'enable_attention_dp': True,
-            'cuda_graph_batch_sizes': [128]
+            'cuda_graph_config': {
+                'batch_sizes': [128]
+            }
         }
     },
     # Deepseek_v3_lite_cases
