@@ -388,8 +388,6 @@ def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_l
 
 def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False):
     with torch.no_grad():
-        batch_size = x.size(0)
-        initial_seq_len = x.size(1)
         x = torch.clip(x, -clip, clip)
 
         device = x.device
@@ -398,28 +396,42 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context
         y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)
 
         x_token = tokenizer.encode(x, half=True)
+
+        initial_seq_len = x.size(1)
+        batch_size = x_token[0].size(0)
+        total_seq_len = initial_seq_len + pred_len
+        full_stamp = torch.cat([x_stamp, y_stamp], dim=1)
 
-        def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
+        generated_pre = x_token[0].new_empty(batch_size, pred_len)
+        generated_post = x_token[1].new_empty(batch_size, pred_len)
 
-            if current_seq_len <= max_context - pred_step:
-                return torch.cat([x_stamp, y_stamp[:, :pred_step, :]], dim=1)
-            else:
-                start_idx = max_context - pred_step
-                return torch.cat([x_stamp[:, -start_idx:, :], y_stamp[:, :pred_step, :]], dim=1)
+        pre_buffer = x_token[0].new_zeros(batch_size, max_context)
+        post_buffer = x_token[1].new_zeros(batch_size, max_context)
+        buffer_len = min(initial_seq_len, max_context)
+        if buffer_len > 0:
+            start_idx = max(0, initial_seq_len - max_context)
+            pre_buffer[:, :buffer_len] = x_token[0][:, start_idx:start_idx + buffer_len]
+            post_buffer[:, :buffer_len] = x_token[1][:, start_idx:start_idx + buffer_len]
 
         if verbose:
             ran = trange
         else:
             ran = range
         for i in ran(pred_len):
             current_seq_len = initial_seq_len + i
+            window_len = min(current_seq_len, max_context)
 
             if current_seq_len <= max_context:
-                input_tokens = x_token
+                input_tokens = [
+                    pre_buffer[:, :window_len],
+                    post_buffer[:, :window_len]
+                ]
             else:
-                input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
+                input_tokens = [pre_buffer, post_buffer]
 
-            current_stamp = get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, i)
+            context_end = current_seq_len
+            context_start = max(0, context_end - max_context)
+            current_stamp = full_stamp[:, context_start:context_end, :].contiguous()
 
             s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp)
             s1_logits = s1_logits[:, -1, :]
@@ -429,12 +441,28 @@ def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
             s2_logits = s2_logits[:, -1, :]
             sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
 
-            x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
-            x_token[1] = torch.cat([x_token[1], sample_post], dim=1)
+            generated_pre[:, i] = sample_pre.squeeze(-1)
+            generated_post[:, i] = sample_post.squeeze(-1)
 
-        input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
+            if current_seq_len < max_context:
+                pre_buffer[:, current_seq_len] = sample_pre.squeeze(-1)
+                post_buffer[:, current_seq_len] = sample_post.squeeze(-1)
+            else:
+                pre_buffer.copy_(torch.roll(pre_buffer, shifts=-1, dims=1))
+                post_buffer.copy_(torch.roll(post_buffer, shifts=-1, dims=1))
+                pre_buffer[:, -1] = sample_pre.squeeze(-1)
+                post_buffer[:, -1] = sample_post.squeeze(-1)
+
+        full_pre = torch.cat([x_token[0], generated_pre], dim=1)
+        full_post = torch.cat([x_token[1], generated_post], dim=1)
+
+        context_start = max(0, total_seq_len - max_context)
+        input_tokens = [
+            full_pre[:, context_start:total_seq_len].contiguous(),
+            full_post[:, context_start:total_seq_len].contiguous()
+        ]
         z = tokenizer.decode(input_tokens, half=True)
-        z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
+        z = z.reshape(-1, sample_count, z.size(1), z.size(2))
         preds = z.cpu().numpy()
         preds = np.mean(preds, axis=1)
 
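Below is a minimal standalone sketch of the rolling-buffer update the added lines rely on: rather than `torch.cat`-ing every sampled token onto a growing `x_token`, the newest `max_context` tokens are kept in a fixed-size buffer that is shifted in place once full. The helper name `append_to_window` and the toy shapes are illustrative assumptions, not part of this commit.

```python
import torch

def append_to_window(buffer: torch.Tensor, filled: int, new_token: torch.Tensor) -> int:
    """Write `new_token` (shape [batch]) into `buffer` (shape [batch, max_context]);
    return the updated count of valid positions."""
    max_context = buffer.size(1)
    if filled < max_context:
        # Still room: write the new token into the next free slot.
        buffer[:, filled] = new_token
        return filled + 1
    # Window is full: drop the oldest token, append the newest at the end.
    buffer.copy_(torch.roll(buffer, shifts=-1, dims=1))
    buffer[:, -1] = new_token
    return filled

# Usage: the window never exceeds max_context, so each decode step sees a
# bounded input instead of an ever-growing concatenation.
buf = torch.zeros(2, 4, dtype=torch.long)   # batch=2, max_context=4
filled = 0
for step in range(6):
    filled = append_to_window(buf, filled, torch.full((2,), step, dtype=torch.long))
print(buf)  # tensor([[2, 3, 4, 5], [2, 3, 4, 5]])
```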