There was an error while loading. Please reload this page.
1 parent 0566b6f commit 0798797Copy full SHA for 0798797
src/transformers/quantizers/base.py
@@ -32,6 +32,9 @@
32
33
34
def _assign_original_dtype(module, original_dtype):
35
+ # not very nice in a recursive function but it avoids a circular import
36
+ from ..modeling_utils import PreTrainedModel
37
+
38
for child in module.children():
39
if isinstance(child, PreTrainedModel):
40
child.config._pre_quantization_dtype = original_dtype
0 commit comments