Skip to content

Commit 8c84144

Browse files
authored
[Model] Add PaddleOCR-VL Model Support (#42178)
* init * refactor * update * update * fix unresolved problems * fix how position_ids work with flash_attn_2 * add tests and fix code * add model_doc * update model_doc * fix ci * update docstring * add tests * update * add **kwargs * update * update * update * reduce max_position_embeddings in tests * update
1 parent 78b2992 commit 8c84144

20 files changed

+5006
-0
lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,8 @@
11191119
title: OWL-ViT
11201120
- local: model_doc/owlv2
11211121
title: OWLv2
1122+
- local: model_doc/paddleocr_vl
1123+
title: PaddleOCRVL
11221124
- local: model_doc/paligemma
11231125
title: PaliGemma
11241126
- local: model_doc/perceiver
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
12+
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
13+
rendered properly in your Markdown viewer.
14+
15+
-->
16+
*This model was released on 2025.10.16 and added to Hugging Face Transformers on 2025.12.10*
17+
18+
# PaddleOCR-VL
19+
20+
<div class="flex flex-wrap space-x-1">
21+
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
22+
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
23+
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
24+
</div>
25+
26+
## Overview
27+
28+
**Huggingface Hub**: [PaddleOCR-VL](https://huggingface.co/collections/PaddlePaddle/paddleocr-vl) | **Github Repo**: [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
29+
30+
**Official Website**: [Baidu AI Studio](https://aistudio.baidu.com/paddleocr) | **arXiv**: [Technical Report](https://arxiv.org/pdf/2510.14528)
31+
32+
**PaddleOCR-VL** is a SOTA and resource-efficient model tailored for document parsing. Its core component is PaddleOCR-VL-0.9B, a compact yet powerful vision-language model (VLM) that integrates a NaViT-style dynamic resolution visual encoder with the ERNIE-4.5-0.3B language model to enable accurate element recognition. This innovative model efficiently supports 109 languages and excels in recognizing complex elements (e.g., text, tables, formulas, and charts), while maintaining minimal resource consumption. Through comprehensive evaluations on widely used public benchmarks and in-house benchmarks, PaddleOCR-VL achieves SOTA performance in both page-level document parsing and element-level recognition. It significantly outperforms existing solutions, exhibits strong competitiveness against top-tier VLMs, and delivers fast inference speeds. These strengths make it highly suitable for practical deployment in real-world scenarios.
33+
34+
<div align="center">
35+
<img src="https://huggingface.co/datasets/PaddlePaddle/PaddleOCR-VL_demo/resolve/main/imgs/allmetric.png" width="800"/>
36+
</div>
37+
38+
### **Core Features**
39+
40+
1. **Compact yet Powerful VLM Architecture:** We present a novel vision-language model that is specifically designed for resource-efficient inference, achieving outstanding performance in element recognition. By integrating a NaViT-style dynamic high-resolution visual encoder with the lightweight ERNIE-4.5-0.3B language model, we significantly enhance the model’s recognition capabilities and decoding efficiency. This integration maintains high accuracy while reducing computational demands, making it well-suited for efficient and practical document processing applications.
41+
42+
2. **SOTA Performance on Document Parsing:** PaddleOCR-VL achieves state-of-the-art performance in both page-level document parsing and element-level recognition. It significantly outperforms existing pipeline-based solutions and exhibiting strong competitiveness against leading vision-language models (VLMs) in document parsing. Moreover, it excels in recognizing complex document elements, such as text, tables, formulas, and charts, making it suitable for a wide range of challenging content types, including handwritten text and historical documents. This makes it highly versatile and suitable for a wide range of document types and scenarios.
43+
44+
3. **Multilingual Support:** PaddleOCR-VL Supports 109 languages, covering major global languages, including but not limited to Chinese, English, Japanese, Latin, and Korean, as well as languages with different scripts and structures, such as Russian (Cyrillic script), Arabic, Hindi (Devanagari script), and Thai. This broad language coverage substantially enhances the applicability of our system to multilingual and globalized document processing scenarios.
45+
46+
### **Model Architecture**
47+
48+
<div align="center">
49+
<img src="https://huggingface.co/datasets/PaddlePaddle/PaddleOCR-VL_demo/resolve/main/imgs/paddleocrvl.png" width="800"/>
50+
</div>
51+
52+
## Usage
53+
54+
### Usage tips
55+
56+
> [!IMPORTANT]
57+
> We currently recommend using the [PaddleOCR official method for inference](https://www.paddleocr.ai/latest/en/version3.x/pipeline_usage/PaddleOCR-VL.html), as it is faster and supports page-level document parsing.
58+
> The example code below only supports element-level recognition.
59+
60+
We have four types of element-level recognition:
61+
62+
- Text recognition, indicated by the prompt `OCR:`.
63+
- Formula recognition, indicated by the prompt `Formula Recognition:`.
64+
- Table recognition, indicated by the prompt `Table Recognition:`.
65+
- Chart recognition, indicated by the prompt `Chart Recognition:`.
66+
67+
The following examples are all based on text recognition, with the prompt `OCR:`.
68+
69+
### Single input inference
70+
71+
The example below demonstrates how to generate text with PaddleOCRVL using [`Pipeline`] or the [`AutoModel`].
72+
73+
<hfoptions id="usage">
74+
<hfoption id="Pipeline">
75+
76+
```py
77+
from transformers import pipeline
78+
79+
pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
80+
messages = [
81+
{
82+
"role": "user",
83+
"content": [
84+
{"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
85+
{"type": "text", "text": "OCR:"},
86+
]
87+
}
88+
]
89+
result = pipe(text=messages)
90+
print(result[0]["generated_text"])
91+
```
92+
93+
</hfoption>
94+
95+
<hfoption id="AutoModel">
96+
97+
```py
98+
from transformers import AutoProcessor, AutoModelForImageTextToText
99+
100+
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
101+
processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
102+
messages = [
103+
{
104+
"role": "user",
105+
"content": [
106+
{"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
107+
{"type": "text", "text": "OCR:"},
108+
]
109+
}
110+
]
111+
inputs = processor.apply_chat_template(
112+
messages,
113+
add_generation_prompt=True,
114+
tokenize=True,
115+
return_dict=True,
116+
return_tensors="pt",
117+
).to(model.device)
118+
119+
outputs = model.generate(**inputs, max_new_tokens=100)
120+
result = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:-1])
121+
print(result)
122+
```
123+
124+
</hfoption>
125+
</hfoptions>
126+
127+
### Batched inference
128+
129+
PaddleOCRVL also supports batched inference. We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Here is how you can do it with PaddleOCRVL using [`Pipeline`] or the [`AutoModel`]:
130+
131+
<hfoptions id="usage">
132+
<hfoption id="Pipeline">
133+
134+
```py
135+
from transformers import pipeline
136+
137+
pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
138+
messages = [
139+
{
140+
"role": "user",
141+
"content": [
142+
{"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
143+
{"type": "text", "text": "OCR:"},
144+
]
145+
}
146+
]
147+
result = pipe(text=[messages, messages])
148+
print(result[0][0]["generated_text"])
149+
print(result[1][0]["generated_text"])
150+
```
151+
152+
</hfoption>
153+
154+
<hfoption id="AutoModel">
155+
156+
```py
157+
from transformers import AutoProcessor, AutoModelForImageTextToText
158+
159+
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
160+
processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
161+
messages = [
162+
{
163+
"role": "user",
164+
"content": [
165+
{"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
166+
{"type": "text", "text": "OCR:"},
167+
]
168+
}
169+
]
170+
batch_messages = [messages, messages]
171+
inputs = processor.apply_chat_template(
172+
batch_messages,
173+
add_generation_prompt=True,
174+
tokenize=True,
175+
return_dict=True,
176+
return_tensors="pt",
177+
padding=True,
178+
padding_side='left',
179+
).to(model.device)
180+
181+
generated_ids = model.generate(**inputs, max_new_tokens=100)
182+
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
183+
result = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
184+
print(result)
185+
```
186+
187+
</hfoption>
188+
</hfoptions>
189+
190+
### Using Flash Attention 2
191+
192+
Flash Attention 2 is an even faster, optimized version of the previous optimization, please refer to the [FlashAttention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention).
193+
194+
For example:
195+
196+
```shell
197+
pip install flash-attn --no-build-isolation
198+
```
199+
200+
```python
201+
from transformers import AutoModelForImageTextToText
202+
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16", attn_implementation="flash_attention_2")
203+
```
204+
205+
## PaddleOCRVLForConditionalGeneration
206+
207+
[[autodoc]] PaddleOCRVLForConditionalGeneration
208+
- forward
209+
210+
## PaddleOCRVLConfig
211+
212+
[[autodoc]] PaddleOCRVLConfig
213+
214+
## PaddleOCRVisionConfig
215+
216+
[[autodoc]] PaddleOCRVisionConfig
217+
218+
## PaddleOCRTextConfig
219+
220+
[[autodoc]] PaddleOCRTextConfig
221+
222+
## PaddleOCRTextModel
223+
224+
[[autodoc]] PaddleOCRTextModel
225+
226+
## PaddleOCRVisionModel
227+
228+
[[autodoc]] PaddleOCRVisionModel
229+
230+
## PaddleOCRVLImageProcessor
231+
232+
[[autodoc]] PaddleOCRVLImageProcessor
233+
234+
## PaddleOCRVLImageProcessorFast
235+
236+
[[autodoc]] PaddleOCRVLImageProcessorFast
237+
238+
## PaddleOCRVLModel
239+
240+
[[autodoc]] PaddleOCRVLModel
241+
242+
## PaddleOCRVLProcessor
243+
244+
[[autodoc]] PaddleOCRVLProcessor
245+
246+
## PaddleOCRVisionTransformer
247+
248+
[[autodoc]] PaddleOCRVisionTransformer

src/transformers/conversion_mapping.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def register_checkpoint_conversion_mapping(
225225
"sam3",
226226
"sam3_tracker",
227227
"sam3_tracker_video",
228+
"paddleocrvl",
228229
]
229230

230231

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@
265265
from .ovis2 import *
266266
from .owlv2 import *
267267
from .owlvit import *
268+
from .paddleocr_vl import *
268269
from .paligemma import *
269270
from .parakeet import *
270271
from .patchtsmixer import *

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@
300300
("ovis2", "Ovis2Config"),
301301
("owlv2", "Owlv2Config"),
302302
("owlvit", "OwlViTConfig"),
303+
("paddleocr_vl", "PaddleOCRVLConfig"),
303304
("paligemma", "PaliGemmaConfig"),
304305
("parakeet_ctc", "ParakeetCTCConfig"),
305306
("parakeet_encoder", "ParakeetEncoderConfig"),
@@ -754,6 +755,7 @@
754755
("ovis2", "Ovis2"),
755756
("owlv2", "OWLv2"),
756757
("owlvit", "OWL-ViT"),
758+
("paddleocr_vl", "PaddleOCRVL"),
757759
("paligemma", "PaliGemma"),
758760
("parakeet", "Parakeet"),
759761
("parakeet_ctc", "Parakeet"),

src/transformers/models/auto/image_processing_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@
153153
("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
154154
("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
155155
("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
156+
("paddleocr_vl", ("PaddleOCRVLImageProcessor", "PaddleOCRVLImageProcessorFast")),
156157
("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
157158
("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
158159
("perception_lm", (None, "PerceptionLMImageProcessorFast")),

src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
10261026
("mistral3", "Mistral3ForConditionalGeneration"),
10271027
("mllama", "MllamaForConditionalGeneration"),
10281028
("ovis2", "Ovis2ForConditionalGeneration"),
1029+
("paddleocr_vl", "PaddleOCRVLForConditionalGeneration"),
10291030
("paligemma", "PaliGemmaForConditionalGeneration"),
10301031
("perception_lm", "PerceptionLMForConditionalGeneration"),
10311032
("pix2struct", "Pix2StructForConditionalGeneration"),

src/transformers/models/auto/processing_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
("ovis2", "Ovis2Processor"),
115115
("owlv2", "Owlv2Processor"),
116116
("owlvit", "OwlViTProcessor"),
117+
("paddleocr_vl", "PaddleOCRVLProcessor"),
117118
("paligemma", "PaliGemmaProcessor"),
118119
("perception_lm", "PerceptionLMProcessor"),
119120
("phi4_multimodal", "Phi4MultimodalProcessor"),

src/transformers/models/auto/tokenization_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@
273273
("ovis2", "Qwen2TokenizerFast" if is_tokenizers_available() else None),
274274
("owlv2", "CLIPTokenizerFast" if is_tokenizers_available() else None),
275275
("owlvit", "CLIPTokenizerFast" if is_tokenizers_available() else None),
276+
("paddleocr_vl", "TokenizersBackend" if is_tokenizers_available() else None),
276277
("paligemma", "LlamaTokenizer" if is_tokenizers_available() else None),
277278
("pegasus", "PegasusTokenizer" if is_tokenizers_available() else None),
278279
("pegasus_x", "PegasusTokenizer" if is_tokenizers_available() else None),
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# coding=utf-8
2+
# Copyright 2025 the HuggingFace Team. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import TYPE_CHECKING
17+
18+
from ...utils import _LazyModule
19+
from ...utils.import_utils import define_import_structure
20+
21+
22+
if TYPE_CHECKING:
23+
from .configuration_paddleocr_vl import *
24+
from .image_processing_paddleocr_vl import *
25+
from .image_processing_paddleocr_vl_fast import *
26+
from .modeling_paddleocr_vl import *
27+
from .processing_paddleocr_vl import *
28+
else:
29+
import sys
30+
31+
_file = globals()["__file__"]
32+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

0 commit comments

Comments
 (0)