
Commit cabbf02

Trim to minimal files and add verbose option
1 parent d5634da commit cabbf02

File tree

11 files changed (+8, -1032 lines)


chat.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def chat():
         else:
             prompt = torch.cat([prompt, input_ids[:, 1:]], dim=1)
 
-        out = generate(model, prompt, steps=steps, gen_length=gen_length, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
+        out = generate(model, prompt, steps=steps, gen_length=gen_length, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence', verbose=True, tokenizer=tokenizer)
 
         answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
         print(f"Bot's reply: {answer}")

evaluation/eval.sh

Lines changed: 0 additions & 18 deletions
This file was deleted.

evaluation/eval_llada.py

Lines changed: 0 additions & 250 deletions
This file was deleted.

generate.py

Lines changed: 7 additions & 1 deletion
@@ -42,7 +42,7 @@ def get_num_transfer_tokens(mask_index, steps):
 
 @ torch.no_grad()
 def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
-             cfg_scale=0., remasking='low_confidence', mask_id=126336):
+             cfg_scale=0., remasking='low_confidence', mask_id=126336, verbose=False, tokenizer=None):
     '''
     Args:
         model: Mask predictor.
@@ -66,6 +66,9 @@ def generate(model, prompt, steps=128, gen_length=128, block_length=128, tempera
     assert steps % num_blocks == 0
     steps = steps // num_blocks
 
+    if verbose and tokenizer:
+        print(tokenizer.decode(*x))
+
     for num_block in range(num_blocks):
         block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id)
         num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
@@ -103,6 +106,9 @@ def generate(model, prompt, steps=128, gen_length=128, block_length=128, tempera
                 _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                 transfer_index[j, select_index] = True
             x[transfer_index] = x0[transfer_index]
+
+            if verbose and tokenizer:
+                print(tokenizer.decode(*x))
 
     return x