@@ -176,7 +176,7 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm
176176 bool AcceptStopToken ();
177177
178178 friend IntTuple FindNextRejectedTokens (GrammarStateMatcher matcher, bool verbose);
179- friend NDArray FindNextTokenBitmaskAsNDArray (GrammarStateMatcher matcher);
179+ friend NDArray FindNextTokenBitmaskAsNDArray (GrammarStateMatcher matcher, int full_vocab_size );
180180
181181 std::shared_ptr<GrammarStateInitContext> init_ctx_;
182182 int max_rollback_steps_;
@@ -362,6 +362,16 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm
362362 // Finally update the rejected_ids bitset
363363 bool can_reach_end = CanReachEnd ();
364364 SetTokenBitmask (next_token_bitmask, tmp_accepted_bitset_, tmp_rejected_indices_, can_reach_end);
365+
366+ // Up to this point we have used the vocab_size from `GetVocabSize()`, whereas
367+ // `next_token_bitmask` is sized by the vocab_size read from `config.json`. For models like QWen2
368+ // and Phi3 the latter can be larger, so we additionally mask out the dummy padded tokens.
369+ CHECK (next_token_bitmask->ndim == 1 );
370+ DynamicBitset next_token_bitset (next_token_bitmask->shape [0 ] * 32 ,
371+ reinterpret_cast <uint32_t *>(next_token_bitmask->data ));
372+ for (int i = init_ctx_->vocab_size ; i < next_token_bitmask->shape [0 ] * 32 ; i++) {
373+ next_token_bitset.Set (i, false );
374+ }
365375}
366376
367377std::string GrammarStateMatcherNodeImpl::FindJumpForwardString () {
@@ -719,12 +729,12 @@ TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherFindNextRejectedTokens")
719729
720730/* !
721731 * \brief Find the bitmask for the next token as an NDArray.
732+ * \param full_vocab_size The vocab_size read from `config.json`. Unlike
733+ * `tokenizer->GetVocabSize()` or `init_ctx_->vocab_size`, it can be larger.
722734 * \returns An NDArray of the bitmask for the next token of shape (bitmask_size,).
723735 */
724- NDArray FindNextTokenBitmaskAsNDArray (GrammarStateMatcher matcher) {
725- auto init_ctx = matcher.as <GrammarStateMatcherNodeImpl>()->init_ctx_ ;
726- auto vocab_size = init_ctx->vocab_size ;
727- auto bitset_size = DynamicBitset::CalculateBufferSize (vocab_size);
736+ NDArray FindNextTokenBitmaskAsNDArray (GrammarStateMatcher matcher, int full_vocab_size) {
737+ auto bitset_size = DynamicBitset::CalculateBufferSize (full_vocab_size);
728738 auto bitmask = NDArray::Empty (ShapeTuple{static_cast <long >(bitset_size)},
729739 DLDataType{kDLUInt , 32 , 1 }, DLDevice{kDLCPU , 0 });
730740 auto dltensor = const_cast <DLTensor*>(bitmask.operator ->());
0 commit comments