
Commit cbf6ae0

[Fix][Bitmask] Mask dummy padded tokens for grammar (#2651)
1 parent 64d8dc6 commit cbf6ae0

File tree

4 files changed: +53 additions, -14 deletions


cpp/grammar/grammar_state_matcher.cc

Lines changed: 15 additions & 5 deletions

@@ -176,7 +176,7 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm
   bool AcceptStopToken();
 
   friend IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose);
-  friend NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher);
+  friend NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher, int full_vocab_size);
 
   std::shared_ptr<GrammarStateInitContext> init_ctx_;
   int max_rollback_steps_;
@@ -362,6 +362,16 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm
   // Finally update the rejected_ids bitset
   bool can_reach_end = CanReachEnd();
   SetTokenBitmask(next_token_bitmask, tmp_accepted_bitset_, tmp_rejected_indices_, can_reach_end);
+
+  // Up till now, we use vocab_size from `GetVocabSize()`, while `next_token_bitmask` is of
+  // vocab_size read from `config.json`. For models like QWen2 and Phi3, the latter can be larger.
+  // So we further mask out the dummy padded tokens.
+  CHECK(next_token_bitmask->ndim == 1);
+  DynamicBitset next_token_bitset(next_token_bitmask->shape[0] * 32,
+                                  reinterpret_cast<uint32_t*>(next_token_bitmask->data));
+  for (int i = init_ctx_->vocab_size; i < next_token_bitmask->shape[0] * 32; i++) {
+    next_token_bitset.Set(i, false);
+  }
 }
 
 std::string GrammarStateMatcherNodeImpl::FindJumpForwardString() {
@@ -719,12 +729,12 @@ TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherFindNextRejectedTokens")
 
 /*!
  * \brief Find the bitmask for the next token as an NDArray.
+ * \param full_vocab_size Different from `tokenizer->GetVocabSize()` or `init_ctx_->vocab_size`,
+ * this is the vocab_size read from `config.json` that can be potentially larger.
  * \returns An NDArray of the bitmask for the next token of shape (bitmask_size,).
  */
-NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher) {
-  auto init_ctx = matcher.as<GrammarStateMatcherNodeImpl>()->init_ctx_;
-  auto vocab_size = init_ctx->vocab_size;
-  auto bitset_size = DynamicBitset::CalculateBufferSize(vocab_size);
+NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher, int full_vocab_size) {
+  auto bitset_size = DynamicBitset::CalculateBufferSize(full_vocab_size);
   auto bitmask = NDArray::Empty(ShapeTuple{static_cast<long>(bitset_size)},
                                 DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0});
   auto dltensor = const_cast<DLTensor*>(bitmask.operator->());
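
The new masking step operates on the raw bitmask buffer: it holds ceil(full_vocab_size / 32) uint32 words, bit i standing for token id i, and every bit at an id at or above the tokenizer's real vocab_size is cleared so the dummy padded tokens can never be sampled. Below is a minimal Python sketch of the same idea, not the MLC implementation; the vocabulary sizes are illustrative only.

import numpy as np

def mask_dummy_padded_tokens(bitmask: np.ndarray, real_vocab_size: int) -> None:
    # `bitmask` holds ceil(full_vocab_size / 32) uint32 words; bit i marks whether
    # token id i is allowed. Clear every bit at id >= real_vocab_size so the dummy
    # padded ids (present only to match config.json's logits length) stay rejected.
    total_bits = bitmask.shape[0] * 32
    for i in range(real_vocab_size, total_bits):
        bitmask[i // 32] &= ~np.uint32(1 << (i % 32))

# Illustrative sizes: the tokenizer has 151_643 real tokens, config.json declares 151_936.
full_vocab_size = 151_936
bitmask = np.full((full_vocab_size + 31) // 32, 0xFFFFFFFF, dtype=np.uint32)
mask_dummy_padded_tokens(bitmask, real_vocab_size=151_643)
assert (bitmask[151_643 // 32] & (1 << (151_643 % 32))) == 0  # first dummy id is masked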

cpp/tokenizers/tokenizers.h

Lines changed: 3 additions & 1 deletion

@@ -87,7 +87,9 @@ class TokenizerObj : public Object {
   const DynamicBitset& GetPrefixTokenMask();
 
   /*!
-   * \brief Returns the vocabulary size. Special tokens are considered.
+   * \brief Returns the vocabulary size. Special tokens are considered. This may be smaller than the
+   * `vocab_size` in config.json (length of logits), see https://github.com/QwenLM/Qwen2/issues/147
+   * and https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/discussions/47.
    */
   size_t GetVocabSize() const;
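
The mismatch this comment refers to can be observed directly by comparing the two numbers. A hedged sketch using the `transformers` library follows; the model id is only an example of an affected model, and running it requires downloading the model files.

from transformers import AutoConfig, AutoTokenizer

model_id = "Qwen/Qwen2-7B-Instruct"  # example of an affected model
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

print("config.json vocab_size (logits length):", config.vocab_size)
print("tokenizer vocabulary size:             ", len(tokenizer))
# For Qwen2 / Phi-3 style models the first number can be larger; the extra ids
# are dummy padded tokens that never appear in tokenizer output.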

python/mlc_llm/grammar/grammar.py

Lines changed: 11 additions & 5 deletions

@@ -310,16 +310,22 @@ def find_next_rejected_tokens(self, verbose: bool = False) -> List[int]:
 
         return _ffi_api.GrammarStateMatcherFindNextRejectedTokens(self, verbose)  # type: ignore # pylint: disable=no-member
 
-    def find_next_token_bitmask_as_ndarray(self) -> tvm.nd.array:
-        """Find the ids of the rejected tokens for the next step.
+    def find_next_token_bitmask_as_ndarray(self, full_vocab_size: int) -> tvm.nd.array:
+        """Find the bitmask for the next step.
+
+        Parameters
+        ----------
+        full_vocab_size: int
+            Different from `tokenizer->GetVocabSize()` or `init_ctx_->vocab_size`, this is the
+            vocab_size read from `config.json` that can be potentially larger.
 
         Returns
        -------
-        rejected_token_ids : List[int]
-            A list of rejected token ids.
+        bitmask_ndarray : tvm.nd.array
+            Bitmask for the next step.
         """
 
-        return _ffi_api.GrammarStateMatcherFindNextTokenBitmaskAsNDArray(self)  # type: ignore # pylint: disable=no-member
+        return _ffi_api.GrammarStateMatcherFindNextTokenBitmaskAsNDArray(self, full_vocab_size)  # type: ignore # pylint: disable=no-member
 
     def find_jump_forward_string(self) -> str:
         """Find the jump-forward string for jump-forward decoding. This is the longest string that

python/mlc_llm/interface/gen_config.py

Lines changed: 24 additions & 3 deletions

@@ -186,11 +186,11 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
             fast_tokenizer = AutoTokenizer.from_pretrained(str(config.parent), use_fast=True)
             fast_tokenizer.backend_tokenizer.save(str(tokenizer_json_save_dest))
             mlc_chat_config.tokenizer_files.append("tokenizer.json")
-            logger.info("Succesfully converted `tokenizer.model` to: %s", tokenizer_json_save_dest)
+            logger.info("Successfully converted `tokenizer.model` to: %s", tokenizer_json_save_dest)
         except Exception:  # pylint: disable=broad-exception-caught
             logger.warning(
-                "Convertion to `tokenizer.json` %s with the exception below. "
-                "Skipping the conversion. Tokenizer will only use `tokenizer.model`",
+                "Converting to `tokenizer.json` %s with the exception below. "
+                "Skipping the conversion.",
                 FAILED,
                 exc_info=True,
             )
@@ -216,6 +216,27 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
     mlc_chat_config.tokenizer_info = asdict(Tokenizer.detect_tokenizer_info(str(output)))
     logger.info("Detected tokenizer info: %s", mlc_chat_config.tokenizer_info)
 
+    # 3.5. Ensure added_tokens do not have duplicated added_tokens, a mistake from model releaser
+    # that affects correctness of huggingface tokenizer.
+    # See https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/discussions/15.
+    if tokenizer_json_file.exists():
+        with open(tokenizer_json_file, "r") as f:
+            tokenizer_json = json.load(f)
+        if "added_tokens" in tokenizer_json:
+            appeared_content = set()
+            for added_token in tokenizer_json["added_tokens"]:
+                content = added_token["content"]
+                if content in appeared_content:
+                    logger.exception(
+                        "%s with incorrect tokenizer.json which has duplicated token %s. "
+                        "This affects correctness of huggingface tokenizer during runtime, "
+                        "please check your tokenizer.json to remove duplication manually.",
+                        FAILED,
+                        content,
+                    )
+                    raise ValueError("Duplicated vocab in tokenizer.json")
+                appeared_content.add(content)
+
     # Step 4. Load system default value
     apply_system_defaults_for_missing_fields(mlc_chat_config)
     # Step 5. Dump the configuration file to output directory
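
The same duplicate check can be run by hand against any tokenizer.json before conversion. Below is a small standalone sketch of that idea, separate from the gen_config code above; the file path in the usage comment is only an example.

import json
from collections import Counter

def find_duplicated_added_tokens(tokenizer_json_path):
    # Return the `content` strings that appear more than once in added_tokens.
    with open(tokenizer_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    counts = Counter(tok["content"] for tok in data.get("added_tokens", []))
    return [content for content, n in counts.items() if n > 1]

# Example (path is illustrative):
# print(find_duplicated_added_tokens("Hermes-2-Pro-Llama-3-8B/tokenizer.json"))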
