20 changes: 15 additions & 5 deletions cpp/grammar/grammar_state_matcher.cc
@@ -176,7 +176,7 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm
bool AcceptStopToken();

friend IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose);
-  friend NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher);
+  friend NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher, int full_vocab_size);

std::shared_ptr<GrammarStateInitContext> init_ctx_;
int max_rollback_steps_;
@@ -362,6 +362,16 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm
// Finally update the rejected_ids bitset
bool can_reach_end = CanReachEnd();
SetTokenBitmask(next_token_bitmask, tmp_accepted_bitset_, tmp_rejected_indices_, can_reach_end);

+  // Up to this point, we have used the vocab_size from `GetVocabSize()`, while
+  // `next_token_bitmask` is sized according to the vocab_size read from `config.json`. For
+  // models like Qwen2 and Phi3, the latter can be larger, so we further mask out the dummy
+  // padded tokens.
+  CHECK(next_token_bitmask->ndim == 1);
+  DynamicBitset next_token_bitset(next_token_bitmask->shape[0] * 32,
+                                  reinterpret_cast<uint32_t*>(next_token_bitmask->data));
+  for (int i = init_ctx_->vocab_size; i < next_token_bitmask->shape[0] * 32; i++) {
+    next_token_bitset.Set(i, false);
+  }
}

std::string GrammarStateMatcherNodeImpl::FindJumpForwardString() {
@@ -719,12 +729,12 @@ TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherFindNextRejectedTokens")

/*!
* \brief Find the bitmask for the next token as an NDArray.
+ * \param full_vocab_size Different from `tokenizer->GetVocabSize()` or `init_ctx_->vocab_size`,
+ * this is the vocab_size read from `config.json`, which may be larger.
* \returns An NDArray of the bitmask for the next token of shape (bitmask_size,).
*/
-NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher) {
-  auto init_ctx = matcher.as<GrammarStateMatcherNodeImpl>()->init_ctx_;
-  auto vocab_size = init_ctx->vocab_size;
-  auto bitset_size = DynamicBitset::CalculateBufferSize(vocab_size);
+NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher, int full_vocab_size) {
+  auto bitset_size = DynamicBitset::CalculateBufferSize(full_vocab_size);
auto bitmask = NDArray::Empty(ShapeTuple{static_cast<long>(bitset_size)},
DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0});
auto dltensor = const_cast<DLTensor*>(bitmask.operator->());
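To make the new masking step above concrete, here is a minimal NumPy sketch of the same bit-clearing logic. It assumes the layout the C++ code implies (an array of uint32 words, with bit `i % 32` of word `i // 32` marking token `i` as allowed); the vocab sizes are illustrative, not taken from this PR.

```python
import numpy as np

def mask_padded_tokens(bitmask_words: np.ndarray, tokenizer_vocab_size: int) -> None:
    """Clear the 'allowed' bit of every dummy padded token id."""
    for i in range(tokenizer_vocab_size, 32 * len(bitmask_words)):
        bitmask_words[i // 32] &= ~np.uint32(1 << (i % 32))

full_vocab_size = 152_064       # illustrative: vocab_size from config.json
tokenizer_vocab_size = 151_646  # illustrative: tokenizer->GetVocabSize()
words = np.full((full_vocab_size + 31) // 32, 0xFFFFFFFF, dtype=np.uint32)

mask_padded_tokens(words, tokenizer_vocab_size)
# Bits for token ids in [tokenizer_vocab_size, 32 * len(words)) are now 0,
# so the dummy padded tokens are always rejected.
```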
4 changes: 3 additions & 1 deletion cpp/tokenizers/tokenizers.h
@@ -87,7 +87,9 @@ class TokenizerObj : public Object {
const DynamicBitset& GetPrefixTokenMask();

/*!
-   * \brief Returns the vocabulary size. Special tokens are considered.
+   * \brief Returns the vocabulary size. Special tokens are considered. This may be smaller than
+   * the `vocab_size` in config.json (the length of logits), see
+   * https://github.com/QwenLM/Qwen2/issues/147 and
+   * https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/discussions/47.
*/
size_t GetVocabSize() const;

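The mismatch documented above can be observed directly. A hedged sketch, assuming a Qwen2 checkpoint whose `config.json` is available locally; the model ID and numbers are illustrative and not part of this PR:

```python
import json
from transformers import AutoTokenizer

# Illustrative model ID; any Qwen2 or Phi-3 checkpoint shows the same effect.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
with open("config.json") as f:  # config.json of the same checkpoint
    full_vocab_size = json.load(f)["vocab_size"]

print(len(tok))         # tokenizer vocabulary, special tokens included
print(full_vocab_size)  # logits width; can be larger due to dummy padded tokens
```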
16 changes: 11 additions & 5 deletions python/mlc_llm/grammar/grammar.py
@@ -310,16 +310,22 @@ def find_next_rejected_tokens(self, verbose: bool = False) -> List[int]:

return _ffi_api.GrammarStateMatcherFindNextRejectedTokens(self, verbose) # type: ignore # pylint: disable=no-member

-    def find_next_token_bitmask_as_ndarray(self) -> tvm.nd.array:
-        """Find the ids of the rejected tokens for the next step.
+    def find_next_token_bitmask_as_ndarray(self, full_vocab_size: int) -> tvm.nd.array:
+        """Find the bitmask for the next step.
+
+        Parameters
+        ----------
+        full_vocab_size: int
+            Different from `tokenizer->GetVocabSize()` or `init_ctx_->vocab_size`, this is the
+            vocab_size read from `config.json`, which may be larger.

         Returns
         -------
-        rejected_token_ids : List[int]
-            A list of rejected token ids.
+        bitmask_ndarray : tvm.nd.array
+            Bitmask for the next step.
         """

-        return _ffi_api.GrammarStateMatcherFindNextTokenBitmaskAsNDArray(self)  # type: ignore # pylint: disable=no-member
+        return _ffi_api.GrammarStateMatcherFindNextTokenBitmaskAsNDArray(self, full_vocab_size)  # type: ignore # pylint: disable=no-member

def find_jump_forward_string(self) -> str:
"""Find the jump-forward string for jump-forward decoding. This is the longest string that
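A hypothetical usage sketch of the updated binding, assuming a `GrammarStateMatcher` named `matcher` constructed elsewhere, an illustrative `full_vocab_size`, and that a set bit marks an accepted token (the same bit layout as above):

```python
import numpy as np

full_vocab_size = 152_064  # illustrative: the vocab_size from config.json
bitmask = matcher.find_next_token_bitmask_as_ndarray(full_vocab_size)

# Decode the words back into token ids: bit i % 32 of word i // 32
# corresponds to token i.
words = bitmask.numpy()  # uint32 words, shape: (ceil(full_vocab_size / 32),)
accepted = [
    token_id
    for token_id in range(full_vocab_size)
    if (words[token_id // 32] >> (token_id % 32)) & 1
]
```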
27 changes: 24 additions & 3 deletions python/mlc_llm/interface/gen_config.py
@@ -186,11 +186,11 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
fast_tokenizer = AutoTokenizer.from_pretrained(str(config.parent), use_fast=True)
fast_tokenizer.backend_tokenizer.save(str(tokenizer_json_save_dest))
mlc_chat_config.tokenizer_files.append("tokenizer.json")
logger.info("Succesfully converted `tokenizer.model` to: %s", tokenizer_json_save_dest)
logger.info("Successfully converted `tokenizer.model` to: %s", tokenizer_json_save_dest)
except Exception: # pylint: disable=broad-exception-caught
logger.warning(
"Convertion to `tokenizer.json` %s with the exception below. "
"Skipping the conversion. Tokenizer will only use `tokenizer.model`",
"Converting to `tokenizer.json` %s with the exception below. "
"Skipping the conversion.",
FAILED,
exc_info=True,
)
@@ -216,6 +216,27 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
mlc_chat_config.tokenizer_info = asdict(Tokenizer.detect_tokenizer_info(str(output)))
logger.info("Detected tokenizer info: %s", mlc_chat_config.tokenizer_info)

+    # 3.5. Ensure `added_tokens` does not contain duplicated tokens, a mistake from the model
+    # releaser that affects the correctness of the huggingface tokenizer.
+    # See https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/discussions/15.
+    if tokenizer_json_file.exists():
+        with open(tokenizer_json_file, "r") as f:
+            tokenizer_json = json.load(f)
+        if "added_tokens" in tokenizer_json:
+            appeared_content = set()
+            for added_token in tokenizer_json["added_tokens"]:
+                content = added_token["content"]
+                if content in appeared_content:
+                    logger.exception(
+                        "%s with an incorrect tokenizer.json that has duplicated token %s. "
+                        "This affects the correctness of the huggingface tokenizer at runtime; "
+                        "please check your tokenizer.json and remove the duplication manually.",
+                        FAILED,
+                        content,
+                    )
+                    raise ValueError("Duplicated vocab in tokenizer.json")
+                appeared_content.add(content)

# Step 4. Load system default value
apply_system_defaults_for_missing_fields(mlc_chat_config)
# Step 5. Dump the configuration file to output directory
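A self-contained sketch of the failure mode the new check guards against, mirroring the loop added above; the token ids and contents are invented for illustration:

```python
# tokenizer.json fragment with a duplicated added token, as in the linked
# Hermes-2-Pro-Llama-3-8B discussion (ids and contents here are made up).
tokenizer_json = {
    "added_tokens": [
        {"id": 128000, "content": "<|special|>"},
        {"id": 128001, "content": "<|special|>"},  # duplicated "content"
    ]
}

appeared_content = set()
for added_token in tokenizer_json.get("added_tokens", []):
    content = added_token["content"]
    if content in appeared_content:
        raise ValueError("Duplicated vocab in tokenizer.json")
    appeared_content.add(content)
```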