Commit cf1e983
Restore cuda graphs to continuous batching (#41421)
* Type hints and small fixes
* Remove unused params
* Made slice inputs the default
* ruffed
* Updated some var names and moved index slicing
* Logging arg in example
* Added some padding debug var and reformat out cg
* First working CG, fixed size
* Working flexible CG
* CG are compatible with all implementations
* Fixed CG API
* Update example
* Documentation
* Fix padding tokens in FA
* Review compliance
* Better doc around weird bug
* Style
* Fix for sliding with CG
1 parent 6c901bd commit cf1e983

File tree: 8 files changed, +380 / -233 lines changed


examples/pytorch/continuous_batching.py

Lines changed: 32 additions & 24 deletions
@@ -26,22 +26,25 @@
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationConfig
+from transformers.generation.continuous_batching.requests import logger
 
 
 # MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
 SLIDING_WINDOW = 0
-MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "Qwen/Qwen3-4B-Instruct-2507"
+MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
 FORCE_MAX_LENGTH = False  # should be False unless you are debugging sliding window features
+SKIP_SPECIAL_TOKENS = False
 
 
 def generate_simple(
     attn_impl: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
 ) -> dict[str, str]:
     attn_impl = {
-        "sdpa_paged": "sdpa",
-        "eager_paged": "eager",
+        "sdpa": "sdpa",
+        "eager": "eager",
         "paged_attention": "eager",  # TODO: this does not work on AMD docker
         "flash_paged": "flash_attention_2",  # TODO: this does not work on AMD docker
+        "kernels-community/flash-attn": "eager",
     }[attn_impl]
 
     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
@@ -56,7 +59,7 @@ def generate_simple(
         # attention_mask = torch.ones_like(input_ids)
         outputs = model.generate(input_ids, generation_config=generation_config, use_model_defaults=False)
         generated_tokens = outputs[0][input_ids.shape[1] :]
-        decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS)
         decoded_outputs[key] = decoded_output
     return decoded_outputs
 
@@ -99,7 +102,6 @@ def batch_generate(
     displayed_samples: int = 0,  # -1: no display, 0: display stats, >0: display inputs and some outputs
     output_file: Optional[str] = None,
     expected_outputs: Optional[list[str]] = None,
-    slice_inputs: bool = True,
 ) -> tuple[float, float]:
     # Actual batch generation
     if displayed_samples >= 0:
@@ -108,7 +110,6 @@
     batch_outputs = model.generate_batch(
         inputs=simple_batch_inputs,
         generation_config=generation_config,
-        slice_inputs=slice_inputs,  # TODO: move this to the generation config
     )
     end_time_simple = time.time()
     if displayed_samples >= 0:
@@ -118,19 +119,21 @@
     token_count = 0
     data = []
     for i, request in enumerate(batch_outputs):
-        input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=True)
+        input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=SKIP_SPECIAL_TOKENS)
         # The key is used to tie back to the output of unbatched generation
         key = " ".join(map(str, batch_outputs[request].prompt_ids))
         data.append({"input": input_text, "key": key})
 
         # Try to decode the output
         try:
-            output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=True)
+            output_text = tokenizer.decode(
+                batch_outputs[request].generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS
+            )
             token_count += len(batch_outputs[request].generated_tokens[1:])
-            data[-1]["output"] = output_text
+            data[-1]["cb_outputs"] = output_text
         except Exception as e:
             print(f"Decoding failed for request {request}: {e}")
-            data[-1]["output"] = "__ERROR__"
+            data[-1]["cb_outputs"] = "__ERROR__"
             continue
 
         # Display sample if asked
@@ -148,7 +151,7 @@ def batch_generate(
         if expected_outputs is not None:
            expected_output = expected_outputs.pop(key)
            matches = output_text == expected_output  # TODO: rework this for a better distance metric
-            data[-1]["ref"] = expected_output
+            data[-1]["without_cb"] = expected_output
            data[-1]["matches"] = matches
            data[-1].pop("key")
            print(f"Request {i} matches" if matches else f"Request {i} does NOT match!")
@@ -186,19 +189,20 @@ def batch_generate(
 
     parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
     parser.add_argument("--matmul-precision", "-mp", type=str, default="high")  # set to "none" to disable
-    parser.add_argument("--no-slice-inputs", action="store_true")  # slicing is enabled by default because much faster
-    parser.add_argument("--use-cuda-graph", "-cg", action="store_true")
-    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None)
+    parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
 
-    parser.add_argument("--samples", type=int, default=500)
+    parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
     parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
+    parser.add_argument("--log-level", type=str, default="INFO")
     parser.add_argument("--output-file", type=str, default=None)
     parser.add_argument("--compare", action="store_true")
     parser.add_argument("--metrics", action="store_true")
     parser.add_argument("--profile", type=str, default=None)
     args = parser.parse_args()
 
-    args.slice_inputs = not args.no_slice_inputs
+    # Set log level
+    logger.setLevel(args.log_level.upper())
 
     # If turned on, we setup metrics
     if args.metrics:
@@ -207,6 +211,15 @@ def batch_generate(
     # Set matmul precision if not none
     if args.matmul_precision != "none":
         torch.set_float32_matmul_precision(args.matmul_precision)
+    # Parse cuda graph argument
+    if args.cuda_graph is not None:
+        use_cuda_graph = {
+            "none": None,
+            "yes": True, "y": True, "true": True, "t": True, "1": True,
+            "no": False, "n": False, "false": False, "f": False, "0": False,
+        }[args.cuda_graph.lower()]  # fmt: skip
+    else:
+        use_cuda_graph = None
 
     # Prepare model
     model = AutoModelForCausalLM.from_pretrained(
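The --cuda-graph option is deliberately a string rather than a store_true switch, so it can express three states: unset/"none" (None, presumably falling back to the default behavior), explicitly on, or explicitly off; the dict lookup above raises a KeyError for anything else. A hypothetical standalone helper with the same behavior, shown only to spell out the mapping (not part of the commit):

from typing import Optional


def parse_cuda_graph_flag(value: Optional[str]) -> Optional[bool]:
    """Map the --cuda-graph CLI string to None (auto), True, or False."""
    if value is None or value.lower() == "none":
        return None  # leave the decision to the generation config defaults
    lowered = value.lower()
    if lowered in {"yes", "y", "true", "t", "1"}:
        return True
    if lowered in {"no", "n", "false", "f", "0"}:
        return False
    raise ValueError(f"Unrecognized --cuda-graph value: {value!r}")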
@@ -222,9 +235,6 @@ def batch_generate(
     # If turned on, we compile the model
     if args.compile:
         model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
-    if args.slice_inputs:
-        assert not args.compile, "Slicing inputs requires is not the model to be compiled"
-        assert not args.use_cuda_graph, "Slicing inputs is not compatible with cuda graphs"
 
     # Prepare tokenizer and dataset
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
@@ -237,10 +247,10 @@ def batch_generate(
     # Prepare generation config
     generation_config = GenerationConfig(
         max_new_tokens=512,
-        use_cuda_graph=args.use_cuda_graph,
+        use_cuda_graph=use_cuda_graph,
         eos_token_id=tokenizer.pad_token_id if FORCE_MAX_LENGTH else tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
-        do_sample=True,
+        do_sample=not args.compare,
         temperature=0.8,
         top_p=0.9,
         num_blocks=args.num_blocks,
@@ -265,7 +275,6 @@ def batch_generate(
         generation_config,
         tokenizer,
         displayed_samples=-1,
-        slice_inputs=args.slice_inputs,
     )
 
     if args.profile is not None:
@@ -282,12 +291,11 @@ def batch_generate(
         displayed_samples=args.displayed,
         output_file=args.output_file,
         expected_outputs=expected_outputs,
-        slice_inputs=args.slice_inputs,
     )
     if args.profile is not None:
         filename = args.profile if args.profile.endswith(".json") else args.profile + ".json"
         prof.export_chrome_trace(filename)
 
 # Example usage:
-# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --slice-inputs --samples 3 --compare
+# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --samples 3 --compare
 # python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json
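Putting the pieces together, here is a minimal sketch of continuous batching with CUDA graphs turned on, following the updated example. The generate_batch call, the use_cuda_graph generation-config field, and the per-request generated_tokens accessor come from the diff above; the model id, prompts, device placement, and attention implementation are placeholder assumptions.

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

MODEL_ID = "meta-llama/Meta-Llama-3-8B"  # default used by the updated example

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    attn_implementation="kernels-community/flash-attn",  # the example's CLI default; assumes the kernel fits your setup
    device_map="cuda",  # assumption: single-GPU placement
)

# Continuous batching consumes lists of token ids, one list per request.
prompts = ["What is continuous batching?", "Write a haiku about KV caches."]
batch_inputs = [tokenizer(p).input_ids for p in prompts]

generation_config = GenerationConfig(
    max_new_tokens=128,
    use_cuda_graph=True,  # the knob this commit restores; None keeps the default behavior
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
)

batch_outputs = model.generate_batch(inputs=batch_inputs, generation_config=generation_config)
for request_id in batch_outputs:
    print(tokenizer.decode(batch_outputs[request_id].generated_tokens, skip_special_tokens=True))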

examples/pytorch/continuous_batching_simple.py

Lines changed: 0 additions & 2 deletions
@@ -68,7 +68,6 @@
 _ = model.generate_batch(
     inputs=simple_batch_inputs[: min(5, args.samples)],
     generation_config=generation_config,
-    slice_inputs=True,
 )
 
 # Actual batch generation
@@ -77,7 +76,6 @@
 batch_outputs = model.generate_batch(
     inputs=simple_batch_inputs,
     generation_config=generation_config,
-    slice_inputs=True,
 )
 end_time = time.time()
 print("Done with batch generation.")

src/transformers/generation/continuous_batching/cache.py

Lines changed: 5 additions & 6 deletions
@@ -204,8 +204,8 @@ def __init__(
         # Initialize the cache
         self.key_cache: list[torch.Tensor] = []
         self.value_cache: list[torch.Tensor] = []
-        # We add one extra token to the cache to handle padding and generally discard unwanted tokens
-        self.cache_shape = (num_blocks * self.block_size + 1, self.num_key_value_heads, self.head_dim)
+        # We add two extra tokens to the cache to handle padding and generally discard unwanted tokens
+        self.cache_shape = (num_blocks * self.block_size + 2, self.num_key_value_heads, self.head_dim)
         for _ in range(group_size):
             new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
             new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
@@ -290,7 +290,6 @@ def update(
         layer_idx: int,
         read_index: list[torch.Tensor],  # shape [num_layer_groups, seqlen_kv + past_length]
         write_index: list[torch.Tensor],  # shape [num_layer_groups, seqlen_q]
-        **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor]:  # shape [seqlen_kv + past_length, num_kv_heads, head_dim]
         """Update the cache with new key-value states for a specific layer. This method writes new KV states to the
         appropriate cache locations. The behavior differs based on the layer's attention type:
@@ -324,11 +323,11 @@
         # the only case where you may write over cache you need to use
         else:
             # Add the cache to the key and value states
-            mask = layer_read_index == -1  # TODO: can this can be efficiently precomputed?
+            mask = (layer_read_index == -1).unsqueeze(-1).unsqueeze(-1)  # TODO: should this be precomputed?
             key_states_with_cache = k_cache[layer_read_index, :, :]
-            key_states_with_cache[mask] = key_states
+            key_states_with_cache.masked_scatter_(mask, key_states)
             value_states_with_cache = v_cache[layer_read_index, :, :]
-            value_states_with_cache[mask] = value_states
+            value_states_with_cache.masked_scatter_(mask, value_states)
             # Write new KV values to the cache
             k_cache[layer_write_index, :, :] = key_states
             v_cache[layer_write_index, :, :] = value_states
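Two things are worth spelling out about the cache changes. First, the cache now reserves two scratch slots past the last block (previously one), used as write targets for padding and otherwise discarded tokens. Second, the update path replaces boolean-mask assignment with an in-place masked_scatter_ driven by a mask broadcast over the head and dim axes; unlike boolean indexing, which needs the data-dependent count of selected elements, this keeps tensor shapes static, which is presumably what makes the path CUDA-graph capturable. A small self-contained check of the equivalence, with made-up shapes:

import torch

# Hypothetical sizes: 6 cache slots, 2 KV heads, head_dim 4, and 3 new tokens to merge in.
num_slots, num_heads, head_dim, num_new = 6, 2, 4, 3

read_index = torch.tensor([0, -1, 2, -1, 4, -1])        # -1 marks positions to take from the new states
gathered = torch.randn(num_slots, num_heads, head_dim)  # stands in for k_cache[layer_read_index, :, :]
new_states = torch.randn(num_new, num_heads, head_dim)  # stands in for key_states

# Old form: boolean indexing, where the assignment target's shape depends on the mask contents.
ref = gathered.clone()
ref[read_index == -1] = new_states

# New form: broadcastable mask + in-place masked_scatter_, shapes stay static.
out = gathered.clone()
mask = (read_index == -1).unsqueeze(-1).unsqueeze(-1)   # [num_slots, 1, 1], broadcast over heads and dim
out.masked_scatter_(mask, new_states)

assert torch.equal(ref, out)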
