
Can't torch.compile transformer models that load GGUF via from_single_file #10795

@AstraliteHeart

Description


Describe the bug

A transformer model loaded via GGUF can't be compiled with torch.compile. It raises torch._dynamo.exc.Unsupported: call_method SetVariable() __setitem__ (UserDefinedObjectVariable(GGUFParameter), ConstantVariable(NoneType: None)) {}
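The arguments in that message, `(GGUFParameter, None)`, suggest that Dynamo traces a Python `set.add(x)` as a dict-style `__setitem__(x, None)`. The `dicts.py` frames in the log hint that Dynamo's `SetVariable` is modeled on top of its dict variable, so the failure is really "this object can't be used as a key". The sketch below is a pure-Python illustration of a set backed by a dict; `DictBackedSet` is a hypothetical name, not a torch class, and this is not Dynamo's actual implementation.

```python
class DictBackedSet:
    """Illustrative set whose members are stored as dict keys mapped to None."""

    def __init__(self):
        self._items = {}

    def add(self, item):
        # Adding a member is a dict store with a None value; this mirrors
        # the (item, None) argument pair seen in the Dynamo error message.
        self._items.__setitem__(item, None)

    def __contains__(self, item):
        return item in self._items

    def __len__(self):
        return len(self._items)


memo = DictBackedSet()
memo.add("weight")
memo.add("weight")  # a duplicate collapses onto the existing key
memo.add("bias")
print(len(memo), "weight" in memo)
```

Running it prints `2 True`: the repeated `add` does not grow the set, which is the property `nn.Module` relies on when it deduplicates parameters.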

A 'normal' model loaded from HF for the same pipeline compiles with torch.compile just fine.

Reproduction

If I load the whole pipeline from the HF model, i.e.:

```python
import torch
from diffusers import AuraFlowPipeline

torch.set_float32_matmul_precision("high")

# Inductor tuning flags
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

# Load every component, including the transformer, from the HF repo
pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch.bfloat16,
).to("cuda")

pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

pipeline("A cute pony", width=512, height=512, num_inference_steps=5)
```

I can torch.compile it (and observe better performance).

If I instead load the transformer from a GGUF file:

```python
import torch
from diffusers import (
    AuraFlowPipeline,
    GGUFQuantizationConfig,
    AuraFlowTransformer2DModel,
)

torch.set_float32_matmul_precision("high")

# Inductor tuning flags (same as above)
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

# Load only the transformer from a GGUF checkpoint (Q2_K quantization)
transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# Load the remaining components from the HF repo, reusing the GGUF transformer
pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch.bfloat16,
    transformer=transformer,
).to("cuda")

pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

pipeline("A cute pony", width=512, height=512, num_inference_steps=5)
```

it raises an exception (see log).

I am still learning how torch.compile/Dynamo works, so it's unclear to me whether this is just some basic confusion caused by GGUFParameter wrapping torch.nn.Parameter, whether diffusers needs to do anything special, or whether this is something torch should handle better. I've only tested AuraFlow, but this should be the same for any code that uses GGUF loading. I'm happy to continue debugging or raise an issue with the torch devs, but I would appreciate it if someone more knowledgeable could take a look at this.

Logs

```
Traceback (most recent call last):
  File "test_diff_gguf.py", line 27, in <module>
    pipeline("A cute pony", width=512, height=512, num_inference_steps=5)
  File "/env/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/env/diffusers/pipelines/aura_flow/pipeline_aura_flow.py", line 555, in __call__
    noise_pred = self.transformer(
  File "/env/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/env/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/env/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/env/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/env/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/env/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/env/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/env/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/env/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/env/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/env/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/env/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/env/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/env/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/env/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/env/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 1748, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/env/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 1748, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/env/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/env/torch/_dynamo/variables/misc.py", line 1022, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/env/torch/_dynamo/variables/dicts.py", line 566, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/env/torch/_dynamo/variables/dicts.py", line 396, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/env/torch/_dynamo/variables/base.py", line 414, in call_method
    unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/env/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: call_method SetVariable() __setitem__ (UserDefinedObjectVariable(GGUFParameter), ConstantVariable(NoneType: None)) {}

from user code:
  File "/env/diffusers/models/transformers/auraflow_transformer_2d.py", line 458, in forward
    temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
  File "/env/torch/nn/modules/module.py", line 2630, in parameters
    for _name, param in self.named_parameters(recurse=recurse):
  File "/env/torch/nn/modules/module.py", line 2657, in named_parameters
    gen = self._named_members(
  File "/env/torch/nn/modules/module.py", line 2604, in _named_members
    memo.add(v)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
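The "from user code" section shows the failure is triggered by `next(self.parameters())` in the AuraFlow forward, which goes through `named_parameters()` and `_named_members()` and records every parameter already seen in a `memo` set. The torch-free sketch below approximates that deduplication loop for illustration only; it is not PyTorch's actual source, and `FakeModule` and `named_members` are hypothetical stand-ins. It shows why each parameter object (here a `GGUFParameter`) has to be hashed and added to a set while Dynamo traces the forward.

```python
class FakeModule:
    """Hypothetical stand-in for nn.Module: own params plus child modules."""

    def __init__(self, params=None, children=None):
        self.params = params or {}
        self.children = children or {}

    def named_modules(self, prefix=""):
        # Yield this module, then recurse into children with dotted prefixes.
        yield prefix, self
        for name, child in self.children.items():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix)


def named_members(module):
    memo = set()  # parameter objects already yielded; shared weights appear once
    for mod_prefix, mod in module.named_modules():
        for name, value in mod.params.items():
            if value is None or value in memo:
                continue
            memo.add(value)  # the step Dynamo could not trace for GGUFParameter
            full_name = f"{mod_prefix}.{name}" if mod_prefix else name
            yield full_name, value


shared = object()  # one parameter object reused by two submodules
model = FakeModule(
    params={"scale": object()},
    children={"a": FakeModule({"weight": shared}), "b": FakeModule({"weight": shared})},
)
names = [n for n, _ in named_members(model)]
print(names)
```

This prints `['scale', 'a.weight']`: the second reference to the shared weight is skipped because it is already in `memo`.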

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.10.16
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.48.2
  • Accelerate version: 1.3.0
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.2
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@DN6 @hlky @stevhliu

Metadata


Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
