Commit c408129

Refactor auto_regressive_inference to reduce memory allocations and CPU-GPU syncs.
1 parent 6456913 commit c408129
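
The refactor replaces the per-step torch.cat growth of the token history with fixed-size buffers that are allocated once and updated in place. Below is a minimal standalone sketch of that pattern, not the Kronos code itself; batch_size, max_context, and the loop length are illustrative values.

    import torch

    batch_size, max_context, pred_len = 4, 8, 20

    # Before: the token history grows by one column per step, so every iteration
    # allocates a new tensor and copies the whole history.
    tokens = torch.zeros(batch_size, 3, dtype=torch.long)
    for _ in range(pred_len):
        new_tok = torch.randint(0, 100, (batch_size, 1))
        tokens = torch.cat([tokens, new_tok], dim=1)       # fresh allocation every step
        window = tokens[:, -max_context:].contiguous()     # another copy once over the limit

    # After: a single fixed-size buffer is allocated up front; each step writes into
    # the next free slot, or slides the window left by one when it is already full.
    buffer = torch.zeros(batch_size, max_context, dtype=torch.long)
    filled = 3
    for _ in range(pred_len):
        new_tok = torch.randint(0, 100, (batch_size,))
        if filled < max_context:
            buffer[:, filled] = new_tok                    # in-place write, no new allocation
            filled += 1
        else:
            # torch.roll still makes a temporary, but the buffer itself never grows.
            buffer.copy_(torch.roll(buffer, shifts=-1, dims=1))
            buffer[:, -1] = new_tok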

1 file changed: model/kronos.py (44 additions, 15 deletions)

@@ -388,38 +388,51 @@ def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_l
 
 def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False):
     with torch.no_grad():
-        batch_size = x.size(0)
-        initial_seq_len = x.size(1)
         x = torch.clip(x, -clip, clip)
 
         device = x.device
         x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device)
         x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device)
         y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)
 
+        batch_size = x.size(0)
+        initial_seq_len = x.size(1)
+
         x_token = tokenizer.encode(x, half=True)
+        batch_size = x_token[0].size(0)
+        total_seq_len = initial_seq_len + pred_len
+        full_stamp = torch.cat([x_stamp, y_stamp], dim=1)
 
-        def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
+        generated_pre = x_token[0].new_empty(batch_size, pred_len)
+        generated_post = x_token[1].new_empty(batch_size, pred_len)
 
-            if current_seq_len <= max_context - pred_step:
-                return torch.cat([x_stamp, y_stamp[:, :pred_step, :]], dim=1)
-            else:
-                start_idx = max_context - pred_step
-                return torch.cat([x_stamp[:, -start_idx:, :], y_stamp[:, :pred_step, :]], dim=1)
+        pre_buffer = x_token[0].new_zeros(batch_size, max_context)
+        post_buffer = x_token[1].new_zeros(batch_size, max_context)
+        buffer_len = min(initial_seq_len, max_context)
+        if buffer_len > 0:
+            start_idx = max(0, initial_seq_len - max_context)
+            pre_buffer[:, :buffer_len] = x_token[0][:, start_idx:start_idx + buffer_len]
+            post_buffer[:, :buffer_len] = x_token[1][:, start_idx:start_idx + buffer_len]
 
         if verbose:
             ran = trange
         else:
             ran = range
         for i in ran(pred_len):
             current_seq_len = initial_seq_len + i
+            window_len = min(current_seq_len, max_context)
 
             if current_seq_len <= max_context:
-                input_tokens = x_token
+                input_tokens = [
+                    pre_buffer[:, :window_len],
+                    post_buffer[:, :window_len]
+                ]
             else:
-                input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
+                input_tokens = [pre_buffer, post_buffer]
 
-            current_stamp = get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, i)
+            context_end = current_seq_len
+            context_start = max(0, context_end - max_context)
+            current_stamp = full_stamp[:, context_start:context_end, :].contiguous()
 
             s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp)
             s1_logits = s1_logits[:, -1, :]

@@ -429,12 +442,28 @@ def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
             s2_logits = s2_logits[:, -1, :]
             sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
 
-            x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
-            x_token[1] = torch.cat([x_token[1], sample_post], dim=1)
+            generated_pre[:, i] = sample_pre.squeeze(-1)
+            generated_post[:, i] = sample_post.squeeze(-1)
 
-            input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
+            if current_seq_len < max_context:
+                pre_buffer[:, current_seq_len] = sample_pre.squeeze(-1)
+                post_buffer[:, current_seq_len] = sample_post.squeeze(-1)
+            else:
+                pre_buffer.copy_(torch.roll(pre_buffer, shifts=-1, dims=1))
+                post_buffer.copy_(torch.roll(post_buffer, shifts=-1, dims=1))
+                pre_buffer[:, -1] = sample_pre.squeeze(-1)
+                post_buffer[:, -1] = sample_post.squeeze(-1)
+
+        full_pre = torch.cat([x_token[0], generated_pre], dim=1)
+        full_post = torch.cat([x_token[1], generated_post], dim=1)
+
+        context_start = max(0, total_seq_len - max_context)
+        input_tokens = [
+            full_pre[:, context_start:total_seq_len].contiguous(),
+            full_post[:, context_start:total_seq_len].contiguous()
+        ]
         z = tokenizer.decode(input_tokens, half=True)
-        z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
+        z = z.reshape(-1, sample_count, z.size(1), z.size(2))
         preds = z.cpu().numpy()
         preds = np.mean(preds, axis=1)
 
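The same idea applies to the time stamps: the removed get_dynamic_stamp helper concatenated x_stamp and y_stamp on every decoding step, while the new code builds full_stamp once and slices a window from it. A rough sketch of that slicing, with hypothetical shapes:

    import torch

    batch_size, hist_len, pred_len, stamp_dim, max_context = 2, 16, 8, 4, 12

    x_stamp = torch.randn(batch_size, hist_len, stamp_dim)   # stamps for the observed history
    y_stamp = torch.randn(batch_size, pred_len, stamp_dim)   # stamps for the horizon to predict

    # One concatenation for the whole horizon, instead of one per decoding step.
    full_stamp = torch.cat([x_stamp, y_stamp], dim=1)

    for i in range(pred_len):
        current_seq_len = hist_len + i
        context_start = max(0, current_seq_len - max_context)
        # The slice is a view; .contiguous() still copies the window, but the
        # repeated cat of x_stamp and y_stamp is gone.
        current_stamp = full_stamp[:, context_start:current_seq_len, :].contiguous()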