OpenVINO Tokenizers adds text processing operations to OpenVINO.
- Perform tokenization and detokenization without third-party dependencies
- Convert a HuggingFace tokenizer into OpenVINO tokenizer and detokenizer models
- Combine OpenVINO models into a single model
- Add a greedy decoding pipeline to a text generation model
(Recommended) Create and activate a virtual environment:
```bash
python3 -m venv venv
source venv/bin/activate
# or
conda create --name openvino_tokenizers
conda activate openvino_tokenizers
```

Use the minimal installation when you already have a converted OpenVINO tokenizer:
```bash
pip install openvino-tokenizers
# or
conda install -c conda-forge openvino openvino-tokenizers
```

If you want to convert HuggingFace tokenizers into OpenVINO tokenizers:
```bash
pip install openvino-tokenizers[transformers]
# or
conda install -c conda-forge openvino openvino-tokenizers && pip install transformers[sentencepiece] tiktoken
```

Use `openvino-tokenizers[transformers]` to install the tokenizer conversion dependencies. To install a pre-release version:
```bash
pip install --pre -U openvino openvino-tokenizers --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
```

To build and install from source, note that the openvino-tokenizers build depends on the openvino package, which is automatically installed from PyPI during the build process. To install unreleased versions, install the openvino package from the nightly distribution channel using `--extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly`:
```bash
git clone https://github.com/openvinotoolkit/openvino_tokenizers.git
cd openvino_tokenizers
pip install . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
```

This command is the equivalent of the minimal installation. Install the tokenizer conversion dependencies if needed:
```bash
pip install transformers[sentencepiece] tiktoken
```

To build against an installed OpenVINO archive distribution, install the archive first. Use `--no-deps` to avoid installing OpenVINO from PyPI into your current environment. `--extra-index-url` is needed to resolve build dependencies only:
```bash
source path/to/installed/openvino/setupvars.sh
git clone https://github.com/openvinotoolkit/openvino_tokenizers.git
cd openvino_tokenizers
pip install --no-deps . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
```

This command is the equivalent of the minimal installation. Install the tokenizer conversion dependencies if needed:
```bash
pip install transformers[sentencepiece] tiktoken
```

To install in editable mode for development:

```bash
git clone https://github.com/openvinotoolkit/openvino_tokenizers.git
cd openvino_tokenizers
pip install -e .[all] --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
# verify installation by running tests
cd tests/
pytest .
```

For an editable installation with an OpenVINO archive distribution, install the archive first. Use `--no-deps` to avoid installing OpenVINO from PyPI into your current environment. `--extra-index-url` is needed to resolve build dependencies only:
```bash
source path/to/installed/openvino/setupvars.sh
git clone https://github.com/openvinotoolkit/openvino_tokenizers.git
cd openvino_tokenizers
pip install -e .[all] --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
# verify installation by running tests
cd tests/
pytest .
```

You can use converted tokenizers in C++ pipelines with prebuilt binaries:
- Download the OpenVINO archive distribution for your OS from here and extract the archive.
- Download the OpenVINO Tokenizers prebuilt libraries from here. To ensure compatibility, the first three numbers of the OpenVINO Tokenizers version should match the OpenVINO version and OS.
- Extract the OpenVINO Tokenizers archive into the OpenVINO installation directory. The OpenVINO Tokenizers archive keeps the same directory structure as the OpenVINO archive:
  - Windows: `<openvino_dir>\runtime\bin\intel64\Release\`
  - MacOS_x86: `<openvino_dir>/runtime/lib/intel64/Release`
  - MacOS_arm64: `<openvino_dir>/runtime/lib/arm64/Release/`
  - Linux_x86: `<openvino_dir>/runtime/lib/intel64/`
  - Linux_arm64: `<openvino_dir>/runtime/lib/aarch64/`
After that, you can add the binary extension in your code with:
core.add_extension("openvino_tokenizers.dll")for Windowscore.add_extension("libopenvino_tokenizers.dylib")for MacOScore.add_extension("libopenvino_tokenizers.so")for Linux
and read/compile converted (de)tokenizer models. If you use version 2023.3.0.0, the binary extension file is called `(lib)user_ov_extension.(dll/dylib/so)`.
To build OpenVINO Tokenizers binaries locally, use this command:
```bash
source path/to/installed/openvino/setupvars.sh
git clone https://github.com/openvinotoolkit/openvino_tokenizers.git
cd openvino_tokenizers
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make
```

After that, you can transfer all binaries from `build/src` to `<openvino_dir>` as described in the C++ installation instructions above.
OpenVINO Tokenizers can be inferred on a CPU device only.
OpenVINO Tokenizers ships with a CLI tool that can convert tokenizers from the Huggingface Hub or Huggingface tokenizers saved on disk:
```bash
convert_tokenizer codellama/CodeLlama-7b-hf --with-detokenizer -o output_dir
```

There is also a `convert_tokenizer` function that can convert a tokenizer Python object:
```python
import numpy as np
from transformers import AutoTokenizer
from openvino import compile_model, save_model
from openvino_tokenizers import convert_tokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
ov_tokenizer = convert_tokenizer(hf_tokenizer)

compiled_tokenizer = compile_model(ov_tokenizer)
text_input = ["Test string"]

hf_output = hf_tokenizer(text_input, return_tensors="np")
ov_output = compiled_tokenizer(text_input)

for output_name in hf_output:
    print(f"OpenVINO {output_name} = {ov_output[output_name]}")
    print(f"HuggingFace {output_name} = {hf_output[output_name]}")
# OpenVINO input_ids = [[ 101 3231 5164 102]]
# HuggingFace input_ids = [[ 101 3231 5164 102]]
# OpenVINO token_type_ids = [[0 0 0 0]]
# HuggingFace token_type_ids = [[0 0 0 0]]
# OpenVINO attention_mask = [[1 1 1 1]]
# HuggingFace attention_mask = [[1 1 1 1]]

# save tokenizer for later use
save_model(ov_tokenizer, "openvino_tokenizer.xml")

loaded_tokenizer = compile_model("openvino_tokenizer.xml")
loaded_ov_output = loaded_tokenizer(text_input)
for output_name in hf_output:
    assert np.all(loaded_ov_output[output_name] == ov_output[output_name])
```

To infer and convert the original model, install torch or torch-cpu in the virtual environment.
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from openvino import compile_model, convert_model
from openvino_tokenizers import convert_tokenizer, connect_models

checkpoint = "mrm8488/bert-tiny-finetuned-sms-spam-detection"
hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
hf_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

text_input = ["Free money!!!"]
hf_input = hf_tokenizer(text_input, return_tensors="pt")
hf_output = hf_model(**hf_input)

ov_tokenizer = convert_tokenizer(hf_tokenizer)
ov_model = convert_model(hf_model, example_input=hf_input.data)
combined_model = connect_models(ov_tokenizer, ov_model)

compiled_combined_model = compile_model(combined_model)
openvino_output = compiled_combined_model(text_input)

print(f"OpenVINO logits: {openvino_output['logits']}")
# OpenVINO logits: [[ 1.2007061 -1.4698029]]
print(f"HuggingFace logits {hf_output.logits}")
# HuggingFace logits tensor([[ 1.2007, -1.4698]], grad_fn=<AddmmBackward0>)
```

Importing `openvino_tokenizers` registers the tokenizer-related operations with OpenVINO, after which you can work with saved tokenizers and detokenizers:
```python
import numpy as np
import openvino_tokenizers
from openvino import Core

core = Core()

# detokenizer from codellama sentencepiece model
compiled_detokenizer = core.compile_model("detokenizer.xml")

token_ids = np.random.randint(100, 1000, size=(3, 5))
openvino_output = compiled_detokenizer(token_ids)

print(openvino_output["string_output"])
# ['sc�ouition�', 'intvenord hasient', 'g shouldwer M more']
```

The following example adds a greedy decoding pipeline to a text generation model and compares the generated output with Huggingface generation:

```python
import numpy as np
from openvino import compile_model, convert_model
from openvino_tokenizers import add_greedy_decoding, convert_tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer

model_checkpoint = "JackFram/llama-68m"
hf_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
hf_model = AutoModelForCausalLM.from_pretrained(model_checkpoint, use_cache=False)

# convert hf tokenizer
text_input = ["Quick brown fox jumped "]
ov_tokenizer, ov_detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)
compiled_tokenizer = compile_model(ov_tokenizer)

# transform input text into tokens
ov_input = compiled_tokenizer(text_input)
hf_input = hf_tokenizer(text_input, return_tensors="pt")

# convert Pytorch model to OpenVINO IR and add greedy decoding pipeline to it
ov_model = convert_model(hf_model, example_input=hf_input.data)
ov_model_with_greedy_decoding = add_greedy_decoding(ov_model)
compiled_model = compile_model(ov_model_with_greedy_decoding)

# generate new tokens
new_tokens_size = 10
prompt_size = ov_input["input_ids"].shape[-1]
input_dict = {
    output.any_name: np.hstack([tensor, np.zeros(shape=(1, new_tokens_size), dtype=np.int_)])
    for output, tensor in ov_input.items()
}
for idx in range(prompt_size, prompt_size + new_tokens_size):
    output = compiled_model(input_dict)["token_ids"]
    input_dict["input_ids"][:, idx] = output[:, idx - 1]
    input_dict["attention_mask"][:, idx] = 1
ov_token_ids = input_dict["input_ids"]

hf_token_ids = hf_model.generate(
    **hf_input,
    min_new_tokens=new_tokens_size,
    max_new_tokens=new_tokens_size,
    temperature=0,  # greedy decoding
)

# decode model output
compiled_detokenizer = compile_model(ov_detokenizer)
ov_output = compiled_detokenizer(ov_token_ids)["string_output"]
hf_output = hf_tokenizer.batch_decode(hf_token_ids, skip_special_tokens=True)
print(f"OpenVINO output string: `{ov_output}`")
# OpenVINO output string: `['Quick brown fox was walking through the forest. He was looking for something']`
print(f"HuggingFace output string: `{hf_output}`")
# HuggingFace output string: `['Quick brown fox was walking through the forest. He was looking for something']`
```

OpenVINO Tokenizers include converters for certain TensorFlow Text operations. Currently, only the MUSE model is supported. Here is an example of model conversion and inference:
```python
import numpy as np
import tensorflow_hub as hub
import tensorflow_text  # register tf text ops
from openvino import convert_model, compile_model
import openvino_tokenizers  # register ov tokenizer ops and translators

sentences = ["dog", "I cuccioli sono carini.", "私は犬と一緒にビーチを散歩するのが好きです"]
tf_embed = hub.load(
    "https://www.kaggle.com/models/google/universal-sentence-encoder/frameworks/"
    "TensorFlow2/variations/multilingual/versions/2"
)
# convert model that uses Sentencepiece tokenizer op from TF Text
ov_model = convert_model(tf_embed)
ov_embed = compile_model(ov_model, "CPU")

ov_result = ov_embed(sentences)[ov_embed.output()]
tf_result = tf_embed(sentences)

assert np.all(np.isclose(ov_result, tf_result, atol=1e-4))
```

You can also build an RWKV tokenizer and detokenizer directly from a vocabulary file:

```python
from urllib.request import urlopen

from openvino import compile_model
from openvino_tokenizers import build_rwkv_tokenizer

rwkv_vocab_url = (
    "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/tokenizer/rwkv_vocab_v20230424.txt"
)

with urlopen(rwkv_vocab_url) as vocab_file:
    vocab = map(bytes.decode, vocab_file)
    tokenizer, detokenizer = build_rwkv_tokenizer(vocab)

tokenizer, detokenizer = compile_model(tokenizer), compile_model(detokenizer)

print(tokenized := tokenizer(["Test string"])["input_ids"])  # [[24235 47429]]
print(detokenizer(tokenized)["string_output"])  # ['Test string']
```

Tokenizers from GGUF models can be converted through the corresponding Huggingface tokenizer object:

```python
from transformers import AutoTokenizer
import openvino as ov
from openvino_tokenizers import convert_tokenizer

model_id = "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-GGUF"
filename = "DeepSeek-R1-Distill-Qwen-1.5B-Q2_K.gguf"

hf_tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)
ov_tokenizer, ov_detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)

ov_tokenizer, ov_detokenizer = ov.compile_model(ov_tokenizer), ov.compile_model(ov_detokenizer)

print(ov_res := ov_tokenizer(["Test string"])["input_ids"])  # [[2271 914]]
print(ov_detokenizer(ov_res)["string_output"])  # ['Test string']
```

The following example shows how to run inference with C++ on a text-classification model from Hugging Face. It expects the path to a model directory as a parameter and prints the logits returned by the model inference.
Export an example model by running the following command after `pip install optimum[openvino]`:
```bash
optimum-cli export openvino microsoft/deberta-base-mnli deberta-base-mnli-ov
```

```cpp
#include <openvino/openvino.hpp>
#include <iostream>
#include <filesystem>

int main(int argc, char* argv[]) {
    std::string dirname = argv[1];
    std::filesystem::path dir_path(dirname);
    std::filesystem::path model_xml = dir_path / "openvino_model.xml";
    std::filesystem::path tokenizer_xml = dir_path / "openvino_tokenizer.xml";

    ov::Core core;
    // use "openvino_tokenizers.dll" on Windows, "libopenvino_tokenizers.dylib" on macOS
    core.add_extension("libopenvino_tokenizers.so");
    ov::InferRequest tokenizer_request = core.compile_model(tokenizer_xml, "CPU").create_infer_request();
    std::string prompt = "Hello world!";
    tokenizer_request.set_input_tensor(ov::Tensor{ov::element::string, {1}, &prompt});
    tokenizer_request.infer();
    ov::Tensor input_ids = tokenizer_request.get_tensor("input_ids");
    ov::Tensor attention_mask = tokenizer_request.get_tensor("attention_mask");

    ov::InferRequest infer_request = core.compile_model(model_xml, "CPU").create_infer_request();
    infer_request.set_tensor("input_ids", input_ids);
    infer_request.set_tensor("attention_mask", attention_mask);
    infer_request.infer();

    auto output = infer_request.get_tensor("logits");
    const float* output_buffer = output.data<const float>();
    size_t num_elements = output.get_size();
    for (size_t i = 0; i < num_elements; i++) {
        std::cout << output_buffer[i] << " ";
    }
    std::cout << std::endl;
    return 0;
}
```

Notes on string encoding:
- OpenVINO Tokenizers support UTF-8 encoded inputs.
- The internal tokenizer vocabulary is stored in UTF-8 encoding:
  - Providing a tokenizer model with non-UTF-8 input may lead to unexpected outputs or errors.
- Detokenizer output is UTF-8 encoded; if your terminal does not expect UTF-8, you might see garbage characters.
- By default, a detokenizer replaces invalid UTF-8 output with the � character. You can change this behavior during conversion (see the sketch after this list).
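A minimal sketch of changing that behavior at conversion time, assuming your installed `openvino_tokenizers` exposes a `utf8_replace_mode` parameter on `convert_tokenizer` (the parameter name and accepted value shown here are assumptions; check the signature of your installed version):

```python
from transformers import AutoTokenizer
from openvino_tokenizers import convert_tokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Assumption: `utf8_replace_mode` controls how invalid UTF-8 bytes in detokenizer
# output are handled, e.g. replaced with the � character or dropped entirely.
ov_tokenizer, ov_detokenizer = convert_tokenizer(
    hf_tokenizer,
    with_detokenizer=True,
    utf8_replace_mode="ignore",  # hypothetical value; verify against your installed version
)
```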
| Huggingface Tokenizer Type | Tokenizer Model Type | Tokenizer | Detokenizer |
|---|---|---|---|
| Fast | WordPiece | ✅ | ✅ |
| | BPE | ✅ | ✅ |
| | Unigram | ✅ | ✅ |
| | WordLevel* | ✅ | ✅ |
| Legacy | SentencePiece .model | ✅ | ✅ |
| Custom | tiktoken | ✅ | ✅ |
| RWKV | Trie | ✅ | ✅ |
This report is autogenerated and includes tokenizer and detokenizer tests. The "Output Matched, %" column shows the percentage of test strings for which the OpenVINO and Huggingface Tokenizers results are the same. To update the report, run `pytest --update_readme tokenizers_test.py` in the `tests` directory.
| Tokenizer Type | Output Matched, % | Number of Tests |
|---|---|---|
| BPE | 99.45 | 4397 |
| SentencePiece | 88.37 | 5279 |
| Tiktoken | 96.64 | 536 |
| Unigram | 95.35 | 1506 |
| WordLevel | 98.99 | 198 |
| WordPiece | 99.09 | 1319 |
| Tokenizer Type | Model | Output Matched, % | Number of Tests |
|---|---|---|---|
| BPE | LiquidAI/LFM2-350M | 100.00 | 253 |
| BPE | NousResearch/Llama-2-13b-hf | 100.00 | 251 |
| BPE | NousResearch/Meta-Llama-3-8B-Instruct | 100.00 | 253 |
| BPE | Qwen/Qwen3-Reranker-0.6B | 100.00 | 269 |
| BPE | Xenova/gpt-4o | 100.00 | 267 |
| BPE | answerdotai/ModernBERT-base | 100.00 | 267 |
| BPE | bigscience/bloom | 97.61 | 251 |
| BPE | deepseek-ai/DeepSeek-V3-0324 | 99.26 | 269 |
| BPE | deepseek-ai/deepseek-coder-6.7b-instruct | 100.00 | 269 |
| BPE | facebook/galactica-120b | 100.00 | 251 |
| BPE | koalajun/Gemma-2-9b-it-Ko-Crypto-Translate | 100.00 | 253 |
| BPE | llava-hf/LLaVA-NeXT-Video-7B-hf | 100.00 | 251 |
| BPE | microsoft/Phi-3-mini-128k-instruct | 100.00 | 253 |
| BPE | microsoft/deberta-base | 100.00 | 251 |
| BPE | mlx-community/quantized-gemma-7b-it | 97.63 | 253 |
| BPE | roberta-base | 100.00 | 267 |
| BPE | tiiuae/Falcon3-7B-Instruct | 96.28 | 269 |
| SentencePiece | BAAI/bge-reranker-v2-m3 | 96.81 | 251 |
| SentencePiece | BAAI/bge-reranker-v2-m3_legacy | 96.81 | 251 |
| SentencePiece | NousResearch/Llama-2-13b-hf | 96.02 | 251 |
| SentencePiece | NousResearch/Llama-2-13b-hf_legacy | 99.20 | 251 |
| SentencePiece | camembert-base | 56.18 | 251 |
| SentencePiece | camembert-base_legacy | 78.88 | 251 |
| SentencePiece | facebook/musicgen-small | 82.07 | 251 |
| SentencePiece | facebook/musicgen-small_legacy | 76.10 | 251 |
| SentencePiece | google/flan-t5-xxl | 75.70 | 251 |
| SentencePiece | google/flan-t5-xxl_legacy | 74.50 | 251 |
| SentencePiece | llava-hf/LLaVA-NeXT-Video-7B-hf | 95.22 | 251 |
| SentencePiece | llava-hf/LLaVA-NeXT-Video-7B-hf_legacy | 98.41 | 251 |
| SentencePiece | microsoft/Phi-3-mini-128k-instruct | 99.21 | 253 |
| SentencePiece | microsoft/Phi-3-mini-128k-instruct_legacy | 97.63 | 253 |
| SentencePiece | microsoft/deberta-v3-base | 95.22 | 251 |
| SentencePiece | microsoft/deberta-v3-base_legacy | 98.41 | 251 |
| SentencePiece | microsoft/speecht5_tts_legacy | 71.71 | 251 |
| SentencePiece | mlx-community/quantized-gemma-7b-it | 96.84 | 253 |
| SentencePiece | mlx-community/quantized-gemma-7b-it_legacy | 97.63 | 253 |
| SentencePiece | rinna/bilingual-gpt-neox-4b | 83.27 | 251 |
| SentencePiece | rinna/bilingual-gpt-neox-4b_legacy | 89.64 | 251 |
| Tiktoken | Qwen/Qwen-14B-Chat | 100.00 | 267 |
| Tiktoken | THUDM/glm-4-9b-chat | 93.31 | 269 |
| Unigram | BAAI/bge-reranker-v2-m3 | 98.41 | 251 |
| Unigram | camembert-base | 84.86 | 251 |
| Unigram | facebook/musicgen-small | 98.41 | 251 |
| Unigram | google/flan-t5-xxl | 92.03 | 251 |
| Unigram | microsoft/deberta-v3-base | 98.41 | 251 |
| Unigram | rinna/bilingual-gpt-neox-4b | 100.00 | 251 |
| WordLevel | cisco-ai/mini-bart-g2p | 98.99 | 198 |
| WordPiece | bert-base-multilingual-cased | 100.00 | 267 |
| WordPiece | cointegrated/rubert-tiny2 | 100.00 | 267 |
| WordPiece | google/mobilebert-uncased | 100.00 | 251 |
| WordPiece | rasa/LaBSE | 95.51 | 267 |
| WordPiece | sentence-transformers/all-MiniLM-L6-v2 | 100.00 | 267 |
For some tokenizers, you need to select certain settings so that their output matches the Huggingface tokenizers more closely (a sketch of passing these settings follows the list):
- The `THUDM/chatglm3-6b` detokenizer does not skip special tokens. Use `skip_special_tokens=False` during conversion.
- All tested tiktoken-based detokenizers leave extra spaces. Use `clean_up_tokenization_spaces=False` during conversion.
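A minimal sketch of passing these settings, assuming they are accepted as keyword arguments by `convert_tokenizer` (verify against the signature of your installed `openvino_tokenizers`; loading `THUDM/chatglm3-6b` is also assumed to require `trust_remote_code=True`):

```python
from transformers import AutoTokenizer
from openvino_tokenizers import convert_tokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)

# Keep special tokens in the detokenizer output and disable extra-space cleanup
# so the converted (de)tokenizer matches the Huggingface output more closely.
ov_tokenizer, ov_detokenizer = convert_tokenizer(
    hf_tokenizer,
    with_detokenizer=True,
    skip_special_tokens=False,
    clean_up_tokenization_spaces=False,
)
```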