Skip to content

Commit 0798797

Browse files
authored
Fix an import error with PreTrainModel (#41571)
1 parent 0566b6f commit 0798797

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/transformers/quantizers/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232

3333

3434
def _assign_original_dtype(module, original_dtype):
35+
# not very nice in a recursive function but it avoids a circular import
36+
from ..modeling_utils import PreTrainedModel
37+
3538
for child in module.children():
3639
if isinstance(child, PreTrainedModel):
3740
child.config._pre_quantization_dtype = original_dtype

0 commit comments

Comments
 (0)