@@ -42,7 +42,7 @@ def get_num_transfer_tokens(mask_index, steps):
4242
4343@ torch .no_grad ()
4444def generate (model , prompt , steps = 128 , gen_length = 128 , block_length = 128 , temperature = 0. ,
45- cfg_scale = 0. , remasking = 'low_confidence' , mask_id = 126336 ):
45+ cfg_scale = 0. , remasking = 'low_confidence' , mask_id = 126336 , verbose = False , tokenizer = None ):
4646 '''
4747 Args:
4848 model: Mask predictor.
@@ -66,6 +66,9 @@ def generate(model, prompt, steps=128, gen_length=128, block_length=128, tempera
6666 assert steps % num_blocks == 0
6767 steps = steps // num_blocks
6868
69+ if verbose and tokenizer :
70+ print (tokenizer .decode (* x ))
71+
6972 for num_block in range (num_blocks ):
7073 block_mask_index = (x [:, prompt .shape [1 ] + num_block * block_length : prompt .shape [1 ] + (num_block + 1 ) * block_length :] == mask_id )
7174 num_transfer_tokens = get_num_transfer_tokens (block_mask_index , steps )
@@ -103,6 +106,9 @@ def generate(model, prompt, steps=128, gen_length=128, block_length=128, tempera
103106 _ , select_index = torch .topk (confidence [j ], k = num_transfer_tokens [j , i ])
104107 transfer_index [j , select_index ] = True
105108 x [transfer_index ] = x0 [transfer_index ]
109+
110+ if verbose and tokenizer :
111+ print (tokenizer .decode (* x ))
106112
107113 return x
108114
0 commit comments