 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationConfig
+from transformers.generation.continuous_batching.requests import logger
 
 
 # MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
 SLIDING_WINDOW = 0
-MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "Qwen/Qwen3-4B-Instruct-2507"
+MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
 FORCE_MAX_LENGTH = False  # should be False unless you are debugging sliding window features
+SKIP_SPECIAL_TOKENS = False
 
 
 def generate_simple(
     attn_impl: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
 ) -> dict[str, str]:
     attn_impl = {
-        "sdpa_paged": "sdpa",
-        "eager_paged": "eager",
+        "sdpa": "sdpa",
+        "eager": "eager",
         "paged_attention": "eager",  # TODO: this does not work on AMD docker
         "flash_paged": "flash_attention_2",  # TODO: this does not work on AMD docker
+        "kernels-community/flash-attn": "eager",
     }[attn_impl]
 
     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
@@ -56,7 +59,7 @@ def generate_simple(
         # attention_mask = torch.ones_like(input_ids)
         outputs = model.generate(input_ids, generation_config=generation_config, use_model_defaults=False)
         generated_tokens = outputs[0][input_ids.shape[1] :]
-        decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS)
         decoded_outputs[key] = decoded_output
     return decoded_outputs
 
@@ -99,7 +102,6 @@ def batch_generate(
     displayed_samples: int = 0,  # -1: no display, 0: display stats, >0: display inputs and some outputs
     output_file: Optional[str] = None,
     expected_outputs: Optional[list[str]] = None,
-    slice_inputs: bool = True,
 ) -> tuple[float, float]:
     # Actual batch generation
     if displayed_samples >= 0:
@@ -108,7 +110,6 @@ def batch_generate(
     batch_outputs = model.generate_batch(
         inputs=simple_batch_inputs,
         generation_config=generation_config,
-        slice_inputs=slice_inputs,  # TODO: move this to the generation config
     )
     end_time_simple = time.time()
     if displayed_samples >= 0:
@@ -118,19 +119,21 @@ def batch_generate(
     token_count = 0
     data = []
     for i, request in enumerate(batch_outputs):
-        input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=True)
+        input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=SKIP_SPECIAL_TOKENS)
         # The key is used to tie back to the output of unbatched generation
         key = " ".join(map(str, batch_outputs[request].prompt_ids))
         data.append({"input": input_text, "key": key})
 
         # Try to decode the output
         try:
-            output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=True)
+            output_text = tokenizer.decode(
+                batch_outputs[request].generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS
+            )
             token_count += len(batch_outputs[request].generated_tokens[1:])
-            data[-1]["output"] = output_text
+            data[-1]["cb_outputs"] = output_text
         except Exception as e:
             print(f"Decoding failed for request {request}: {e}")
-            data[-1]["output"] = "__ERROR__"
+            data[-1]["cb_outputs"] = "__ERROR__"
             continue
 
         # Display sample if asked
@@ -148,7 +151,7 @@ def batch_generate(
         if expected_outputs is not None:
             expected_output = expected_outputs.pop(key)
             matches = output_text == expected_output  # TODO: rework this for a better distance metric
-            data[-1]["ref"] = expected_output
+            data[-1]["without_cb"] = expected_output
             data[-1]["matches"] = matches
             data[-1].pop("key")
             print(f"Request {i} matches" if matches else f"Request {i} does NOT match!")
@@ -186,19 +189,20 @@ def batch_generate(
 
     parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
     parser.add_argument("--matmul-precision", "-mp", type=str, default="high")  # set to "none" to disable
-    parser.add_argument("--no-slice-inputs", action="store_true")  # slicing is enabled by default because much faster
-    parser.add_argument("--use-cuda-graph", "-cg", action="store_true")
-    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None)
+    parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
 
-    parser.add_argument("--samples", type=int, default=500)
+    parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
     parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
+    parser.add_argument("--log-level", type=str, default="INFO")
     parser.add_argument("--output-file", type=str, default=None)
     parser.add_argument("--compare", action="store_true")
     parser.add_argument("--metrics", action="store_true")
     parser.add_argument("--profile", type=str, default=None)
     args = parser.parse_args()
 
-    args.slice_inputs = not args.no_slice_inputs
+    # Set log level
+    logger.setLevel(args.log_level.upper())
 
     # If turned on, we setup metrics
     if args.metrics:
@@ -207,6 +211,15 @@ def batch_generate(
     # Set matmul precision if not none
     if args.matmul_precision != "none":
         torch.set_float32_matmul_precision(args.matmul_precision)
+    # Parse cuda graph argument
+    if args.cuda_graph is not None:
+        use_cuda_graph = {
+            "none": None,
+            "yes": True, "y": True, "true": True, "t": True, "1": True,
+            "no": False, "n": False, "false": False, "f": False, "0": False,
+        }[args.cuda_graph.lower()]  # fmt: skip
+    else:
+        use_cuda_graph = None
 
     # Prepare model
     model = AutoModelForCausalLM.from_pretrained(
@@ -222,9 +235,6 @@ def batch_generate(
     # If turned on, we compile the model
     if args.compile:
         model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
-    if args.slice_inputs:
-        assert not args.compile, "Slicing inputs requires is not the model to be compiled"
-        assert not args.use_cuda_graph, "Slicing inputs is not compatible with cuda graphs"
 
     # Prepare tokenizer and dataset
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
@@ -237,10 +247,10 @@ def batch_generate(
     # Prepare generation config
     generation_config = GenerationConfig(
         max_new_tokens=512,
-        use_cuda_graph=args.use_cuda_graph,
+        use_cuda_graph=use_cuda_graph,
         eos_token_id=tokenizer.pad_token_id if FORCE_MAX_LENGTH else tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
-        do_sample=True,
+        do_sample=not args.compare,
         temperature=0.8,
         top_p=0.9,
         num_blocks=args.num_blocks,
@@ -265,7 +275,6 @@ def batch_generate(
         generation_config,
         tokenizer,
         displayed_samples=-1,
-        slice_inputs=args.slice_inputs,
     )
 
     if args.profile is not None:
@@ -282,12 +291,11 @@ def batch_generate(
         displayed_samples=args.displayed,
         output_file=args.output_file,
         expected_outputs=expected_outputs,
-        slice_inputs=args.slice_inputs,
     )
     if args.profile is not None:
         filename = args.profile if args.profile.endswith(".json") else args.profile + ".json"
         prof.export_chrome_trace(filename)
 
 # Example usage:
-# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --slice-inputs --samples 3 --compare
+# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --samples 3 --compare
 # python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json
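For context on the new --cuda-graph flag in this diff: it replaces the old --use-cuda-graph boolean switch with a tri-state string, so CUDA graphs can be explicitly enabled, explicitly disabled, or left to the library default when the flag is omitted. Below is a minimal sketch of that mapping in isolation; the helper name parse_cuda_graph is only for illustration and does not exist in the script.

def parse_cuda_graph(value):
    # "none" (or an absent flag) keeps the default behaviour (None);
    # truthy strings enable CUDA graphs, falsy strings disable them,
    # mirroring the dict lookup used in the diff above.
    if value is None:
        return None
    return {
        "none": None,
        "yes": True, "y": True, "true": True, "t": True, "1": True,
        "no": False, "n": False, "false": False, "f": False, "0": False,
    }[value.lower()]

assert parse_cuda_graph("yes") is True
assert parse_cuda_graph("0") is False
assert parse_cuda_graph(None) is None

As in the script itself, an unrecognised value simply raises a KeyError from the dict lookup.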