Commit b62f780

Refactor auto_regressive_inference to reduce memory allocations and CPU-GPU syncs.
1 parent 6456913 commit b62f780
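The old loop grew the token tensors with `torch.cat` on every prediction step; the refactored version allocates its buffers once and writes each sampled token in place. Below is a minimal sketch of the two allocation patterns, using illustrative names and sizes rather than the commit's actual tensors:

```python
import torch

batch, max_context, pred_len = 4, 512, 120           # illustrative sizes

# Old pattern: torch.cat reallocates and copies the whole history each step.
tokens = torch.zeros(batch, max_context, dtype=torch.long)
for _ in range(pred_len):
    new_tok = torch.randint(0, 100, (batch, 1))
    tokens = torch.cat([tokens, new_tok], dim=1)      # fresh allocation every iteration

# New pattern: preallocate once, then write in place.
generated = torch.empty(batch, pred_len, dtype=torch.long)
for i in range(pred_len):
    new_tok = torch.randint(0, 100, (batch,))
    generated[:, i] = new_tok                         # in-place write, no reallocation
```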


model/kronos.py

Lines changed: 43 additions & 15 deletions
@@ -388,8 +388,6 @@ def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_l
 
 def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False):
     with torch.no_grad():
-        batch_size = x.size(0)
-        initial_seq_len = x.size(1)
         x = torch.clip(x, -clip, clip)
 
         device = x.device
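A note on the hunk above: `initial_seq_len` and `batch_size` are no longer taken from `x` before it is expanded by `sample_count`; they are now derived after encoding, so `batch_size` is the expanded batch-times-samples value used to size the new buffers, and the final reshape can use `-1` instead (see the last hunk). A small shape sketch under that assumption, with illustrative sizes, assuming `x` itself is repeated the same way as `y_stamp` in the unchanged lines between the hunks:

```python
import torch

B, S, L, D = 2, 5, 64, 6          # batch, sample_count, seq_len, features (illustrative)
x = torch.randn(B, L, D)

# Each input series is repeated sample_count times along the batch axis
# (the same unsqueeze/repeat/reshape pattern visible for y_stamp below),
# so tensors derived after that point have a batch dimension of B * S.
x_rep = x.unsqueeze(1).repeat(1, S, 1, 1).reshape(-1, L, D)
assert x_rep.size(0) == B * S

# Grouping back by sample_count with -1 recovers the per-series layout
# without needing the pre-repeat batch size on the host.
z = x_rep.reshape(-1, S, L, D)
assert z.shape == (B, S, L, D)
```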
@@ -398,28 +396,42 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context
         y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)
 
         x_token = tokenizer.encode(x, half=True)
+
+        initial_seq_len = x.size(1)
+        batch_size = x_token[0].size(0)
+        total_seq_len = initial_seq_len + pred_len
+        full_stamp = torch.cat([x_stamp, y_stamp], dim=1)
 
-        def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
+        generated_pre = x_token[0].new_empty(batch_size, pred_len)
+        generated_post = x_token[1].new_empty(batch_size, pred_len)
 
-            if current_seq_len <= max_context - pred_step:
-                return torch.cat([x_stamp, y_stamp[:, :pred_step, :]], dim=1)
-            else:
-                start_idx = max_context - pred_step
-                return torch.cat([x_stamp[:, -start_idx:, :], y_stamp[:, :pred_step, :]], dim=1)
+        pre_buffer = x_token[0].new_zeros(batch_size, max_context)
+        post_buffer = x_token[1].new_zeros(batch_size, max_context)
+        buffer_len = min(initial_seq_len, max_context)
+        if buffer_len > 0:
+            start_idx = max(0, initial_seq_len - max_context)
+            pre_buffer[:, :buffer_len] = x_token[0][:, start_idx:start_idx + buffer_len]
+            post_buffer[:, :buffer_len] = x_token[1][:, start_idx:start_idx + buffer_len]
 
         if verbose:
             ran = trange
         else:
             ran = range
         for i in ran(pred_len):
             current_seq_len = initial_seq_len + i
+            window_len = min(current_seq_len, max_context)
 
             if current_seq_len <= max_context:
-                input_tokens = x_token
+                input_tokens = [
+                    pre_buffer[:, :window_len],
+                    post_buffer[:, :window_len]
+                ]
             else:
-                input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
+                input_tokens = [pre_buffer, post_buffer]
 
-            current_stamp = get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, i)
+            context_end = current_seq_len
+            context_start = max(0, context_end - max_context)
+            current_stamp = full_stamp[:, context_start:context_end, :].contiguous()
 
             s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp)
             s1_logits = s1_logits[:, -1, :]
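The hunk above replaces the per-step `torch.cat` inside `get_dynamic_stamp` with a single up-front concatenation that is only sliced each iteration, and keeps the visible context in the fixed-size `pre_buffer`/`post_buffer` tensors. A minimal sketch of the stamp slicing, with illustrative shapes rather than the commit's values:

```python
import torch

batch, max_context, initial_len, pred_len, stamp_dim = 2, 8, 6, 5, 4   # illustrative

x_stamp = torch.randn(batch, initial_len, stamp_dim)
y_stamp = torch.randn(batch, pred_len, stamp_dim)

# Concatenate the time stamps once; every step then takes a view of the last
# max_context positions instead of rebuilding the window with torch.cat.
full_stamp = torch.cat([x_stamp, y_stamp], dim=1)

for i in range(pred_len):
    current_seq_len = initial_len + i
    context_end = current_seq_len
    context_start = max(0, context_end - max_context)
    current_stamp = full_stamp[:, context_start:context_end, :]
    assert current_stamp.size(1) == min(current_seq_len, max_context)
```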
@@ -429,12 +441,28 @@ def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
             s2_logits = s2_logits[:, -1, :]
             sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
 
-            x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
-            x_token[1] = torch.cat([x_token[1], sample_post], dim=1)
+            generated_pre[:, i] = sample_pre.squeeze(-1)
+            generated_post[:, i] = sample_post.squeeze(-1)
 
-        input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
+            if current_seq_len < max_context:
+                pre_buffer[:, current_seq_len] = sample_pre.squeeze(-1)
+                post_buffer[:, current_seq_len] = sample_post.squeeze(-1)
+            else:
+                pre_buffer.copy_(torch.roll(pre_buffer, shifts=-1, dims=1))
+                post_buffer.copy_(torch.roll(post_buffer, shifts=-1, dims=1))
+                pre_buffer[:, -1] = sample_pre.squeeze(-1)
+                post_buffer[:, -1] = sample_post.squeeze(-1)
+
+        full_pre = torch.cat([x_token[0], generated_pre], dim=1)
+        full_post = torch.cat([x_token[1], generated_post], dim=1)
+
+        context_start = max(0, total_seq_len - max_context)
+        input_tokens = [
+            full_pre[:, context_start:total_seq_len].contiguous(),
+            full_post[:, context_start:total_seq_len].contiguous()
+        ]
         z = tokenizer.decode(input_tokens, half=True)
-        z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
+        z = z.reshape(-1, sample_count, z.size(1), z.size(2))
         preds = z.cpu().numpy()
         preds = np.mean(preds, axis=1)
 
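Two details of the sampling loop above are worth calling out. Each sampled token is written into `generated_pre`/`generated_post` and into the fixed-size context buffers in place; once the context is full, the window slides via `torch.roll` plus `copy_`, which keeps the buffer's storage stable even though `torch.roll` itself returns a new tensor. A small sketch of that sliding-window update, with made-up values:

```python
import torch

batch, max_context = 2, 4                      # illustrative sizes
buffer = torch.arange(batch * max_context).reshape(batch, max_context)
new_token = torch.tensor([100, 200])           # newest sample for each batch row

# Shift the window left by one and drop the oldest token.  torch.roll returns a
# new tensor, so copy_ writes the rolled result back into the existing buffer
# instead of rebinding the name to a fresh allocation.
buffer.copy_(torch.roll(buffer, shifts=-1, dims=1))
buffer[:, -1] = new_token                      # newest token goes into the last slot

print(buffer)
# tensor([[  1,   2,   3, 100],
#         [  5,   6,   7, 200]])
```

Because the buffer object never changes, the `input_tokens` handed to the model each step reference the same memory rather than a freshly concatenated tensor.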