@@ -388,38 +388,51 @@ def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_l
 
 def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False):
     with torch.no_grad():
-        batch_size = x.size(0)
-        initial_seq_len = x.size(1)
         x = torch.clip(x, -clip, clip)
 
         device = x.device
         x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device)
         x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device)
         y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)
 
+        batch_size = x.size(0)
+        initial_seq_len = x.size(1)
+
         x_token = tokenizer.encode(x, half=True)
+        batch_size = x_token[0].size(0)
+        total_seq_len = initial_seq_len + pred_len
+        full_stamp = torch.cat([x_stamp, y_stamp], dim=1)
 
-        def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
+        generated_pre = x_token[0].new_empty(batch_size, pred_len)
+        generated_post = x_token[1].new_empty(batch_size, pred_len)
 
-            if current_seq_len <= max_context - pred_step:
-                return torch.cat([x_stamp, y_stamp[:, :pred_step, :]], dim=1)
-            else:
-                start_idx = max_context - pred_step
-                return torch.cat([x_stamp[:, -start_idx:, :], y_stamp[:, :pred_step, :]], dim=1)
+        pre_buffer = x_token[0].new_zeros(batch_size, max_context)
+        post_buffer = x_token[1].new_zeros(batch_size, max_context)
+        buffer_len = min(initial_seq_len, max_context)
+        if buffer_len > 0:
+            start_idx = max(0, initial_seq_len - max_context)
+            pre_buffer[:, :buffer_len] = x_token[0][:, start_idx:start_idx + buffer_len]
+            post_buffer[:, :buffer_len] = x_token[1][:, start_idx:start_idx + buffer_len]
 
         if verbose:
             ran = trange
         else:
             ran = range
         for i in ran(pred_len):
             current_seq_len = initial_seq_len + i
+            window_len = min(current_seq_len, max_context)
 
             if current_seq_len <= max_context:
-                input_tokens = x_token
+                input_tokens = [
+                    pre_buffer[:, :window_len],
+                    post_buffer[:, :window_len]
+                ]
             else:
-                input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
+                input_tokens = [pre_buffer, post_buffer]
 
-            current_stamp = get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, i)
+            context_end = current_seq_len
+            context_start = max(0, context_end - max_context)
+            current_stamp = full_stamp[:, context_start:context_end, :].contiguous()
 
             s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp)
             s1_logits = s1_logits[:, -1, :]
@@ -429,12 +442,28 @@ def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
             s2_logits = s2_logits[:, -1, :]
             sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
 
-            x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
-            x_token[1] = torch.cat([x_token[1], sample_post], dim=1)
+            generated_pre[:, i] = sample_pre.squeeze(-1)
+            generated_post[:, i] = sample_post.squeeze(-1)
 
-        input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
+            if current_seq_len < max_context:
+                pre_buffer[:, current_seq_len] = sample_pre.squeeze(-1)
+                post_buffer[:, current_seq_len] = sample_post.squeeze(-1)
+            else:
+                pre_buffer.copy_(torch.roll(pre_buffer, shifts=-1, dims=1))
+                post_buffer.copy_(torch.roll(post_buffer, shifts=-1, dims=1))
+                pre_buffer[:, -1] = sample_pre.squeeze(-1)
+                post_buffer[:, -1] = sample_post.squeeze(-1)
+
+        full_pre = torch.cat([x_token[0], generated_pre], dim=1)
+        full_post = torch.cat([x_token[1], generated_post], dim=1)
+
+        context_start = max(0, total_seq_len - max_context)
+        input_tokens = [
+            full_pre[:, context_start:total_seq_len].contiguous(),
+            full_post[:, context_start:total_seq_len].contiguous()
+        ]
         z = tokenizer.decode(input_tokens, half=True)
-        z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
+        z = z.reshape(-1, sample_count, z.size(1), z.size(2))
         preds = z.cpu().numpy()
         preds = np.mean(preds, axis=1)
 
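For reference, a minimal, standalone sketch of the rolling-buffer pattern the new pre_buffer/post_buffer code follows: generated tokens are appended while the fixed-size buffer still has room, and once it is full the buffer is shifted left with torch.roll() before the newest token is written at the end, so per-step memory stays bounded by max_context instead of growing with every torch.cat(). The names here (buffer, new_token, filled) are illustrative only and do not exist in the repository.

# Minimal sketch of a fixed-size rolling token buffer (illustrative names only).
import torch

max_context = 4
initial_tokens = torch.tensor([[1, 2]])            # (batch=1, seq=2) seed tokens
buffer = torch.zeros(1, max_context, dtype=torch.long)
filled = initial_tokens.size(1)
buffer[:, :filled] = initial_tokens               # preload the buffer with the context

for step, new_token in enumerate([3, 4, 5, 6]):
    seq_len = filled + step                       # tokens seen before this step
    if seq_len < max_context:                     # room left: append in place
        buffer[:, seq_len] = new_token
    else:                                         # full: drop oldest, append newest
        buffer.copy_(torch.roll(buffer, shifts=-1, dims=1))
        buffer[:, -1] = new_token
    window = buffer[:, :min(seq_len + 1, max_context)]
    print(window.tolist())                        # the model would consume this window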