
Conversation

@samadwar

What does this PR do?

Adds support for loading checkpoints from a single file. Some modifications to the convert_wan_transformer_to_diffusers method were required for it to work with WanAnimateTransformer3DModel.
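
For reference, a minimal sketch of the intended usage (the GGUF checkpoint URL and config repo below are just the ones used for testing later in this thread, not requirements of the PR):

```python
import torch

from diffusers import GGUFQuantizationConfig, WanAnimatePipeline, WanAnimateTransformer3DModel

# Example single-file checkpoint (a community GGUF conversion); the config still
# comes from the official Diffusers repo, only the weights come from the single file.
ckpt = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q8_0.gguf"

transformer = WanAnimateTransformer3DModel.from_single_file(
    ckpt,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    config="Wan-AI/Wan2.2-Animate-14B-Diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

pipe = WanAnimatePipeline.from_pretrained(
    "Wan-AI/Wan2.2-Animate-14B-Diffusers",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
```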

best regards,
Sam

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6
Collaborator

DN6 commented Nov 21, 2025

Hi @samadwar, do you have a single-file version of Wan Animate that we can use to test this PR?

@samadwar
Author

samadwar commented Nov 21, 2025

@dg845
Collaborator

dg845 commented Nov 22, 2025

Hi @samadwar, thanks for the PR! Would you be able to share an example of a code snippet which uses WanAnimateTransformer3DModel.from_single_file? I tried to test the PR using the following script:

```python
import os

import torch

from diffusers import GGUFQuantizationConfig, WanAnimatePipeline, WanAnimateTransformer3DModel
from diffusers.utils import export_to_video, load_image, load_video

single_file_ckpt = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q4_K_M.gguf"
# single_file_ckpt = "https://huggingface.co/Kijai/WanVideo_comfy_fp8_scaled/blob/main/Wan22Animate/Wan2_2-Animate-14B_fp8_scaled_e4m3fn_KJ_v2.safetensors"
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
device = "cuda:0"
dtype = torch.bfloat16
seed = 42

transformer_kwargs = {}
_, single_file_ext = os.path.splitext(single_file_ckpt)
if single_file_ext == ".gguf":
    quantization_config = GGUFQuantizationConfig(compute_dtype=dtype)
    transformer_kwargs["quantization_config"] = quantization_config

transformer = WanAnimateTransformer3DModel.from_single_file(
    single_file_ckpt,
    config=model_id,
    subfolder="transformer",
    **transformer_kwargs,
)

pipe = WanAnimatePipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to(device)

image = load_image("/path/to/reference_image.png")
pose_video = load_video("/path/to/pose_video.mp4")
face_video = load_video("/path/to/face_video.mp4")

video = pipe(
    image=image,
    pose_video=pose_video,
    face_video=face_video,
    prompt="People in the video are doing actions.",
    height=720,
    width=1280,
    mode="animate",
    guidance_scale=1.0,
    num_inference_steps=20,
    generator=torch.Generator(device=device).manual_seed(seed),
    output_type="np",
).frames[0]
export_to_video(video, "wan_animate_single_file.mp4", fps=30)
```

Using a checkpoint from QuantStack/Wan2.2-Animate-14B-GGUF doesn't produce any errors, but the generated samples seem to be just noise:

wan_animate_single_file_gguf_20_step.mp4

If I instead try a checkpoint from Kijai/WanVideo_comfy_fp8_scaled, I get an OOM error on an A100 (80 GB VRAM), and a lot of checkpoint keys don't seem to be used (they mainly end in .scale_weight, so they might be the FP8 scaling parameters?).
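
A quick way to confirm which keys those are (a rough sketch, assuming a local copy of the checkpoint at a hypothetical path):

```python
# Sketch for inspecting the unused keys; the local path is hypothetical.
from safetensors import safe_open

path = "/path/to/Wan2_2-Animate-14B_fp8_scaled_e4m3fn_KJ_v2.safetensors"
with safe_open(path, framework="pt", device="cpu") as f:
    scale_keys = [k for k in f.keys() if k.endswith(".scale_weight")]

print(f"{len(scale_keys)} keys end in .scale_weight")
print(scale_keys[:5])
```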

@samadwar
Author

> Hi @samadwar, thanks for the PR! Would you be able to share an example of a code snippet which uses WanAnimateTransformer3DModel.from_single_file? [...] Using a checkpoint from QuantStack/Wan2.2-Animate-14B-GGUF doesn't produce any errors, but the generated samples seem to be just noise. [...] If I instead try a checkpoint from Kijai/WanVideo_comfy_fp8_scaled, I get an OOM error on an A100 (80 GB VRAM) and a lot of checkpoint keys don't seem to be used (they mainly end in .scale_weight, so they might be the FP8 scaling parameters?).

Yeah, I am experiencing the same issue today. I had it working before; I will check and get back to you.

For the GGUF checkpoint I am using an AWS ml.g6e.4xlarge instance, which comes with 45 GB of VRAM, and I don't have access to more GPU VRAM to test fp8. One way to check, though, is to load the file with the safetensors package and compare the actual weight values to see whether they match.
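
Something along these lines is what I mean (a rough sketch, assuming a local safetensors copy; the key names are hypothetical, the real mapping is whatever convert_wan_transformer_to_diffusers produces):

```python
# Rough sketch: compare a raw tensor from the single-file checkpoint with the
# corresponding weight in the converted model. Path and key names are hypothetical.
import torch
from safetensors.torch import load_file

from diffusers import WanAnimateTransformer3DModel

single_file_path = "/path/to/Wan2.2-Animate-14B.safetensors"

raw_state_dict = load_file(single_file_path)
model = WanAnimateTransformer3DModel.from_single_file(
    single_file_path,
    config="Wan-AI/Wan2.2-Animate-14B-Diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
converted = model.state_dict()

raw = raw_state_dict["patch_embedding.weight"].to(torch.float32)   # key in the original checkpoint (hypothetical)
conv = converted["patch_embedding.weight"].to(torch.float32)       # key after conversion (hypothetical)
print("max abs diff:", (raw - conv).abs().max().item())
```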

@samadwar
Author

samadwar commented Nov 22, 2025

@dg845 I fixed the issue, can you try now?

@samadwar
Author

samadwar commented Nov 22, 2025

Code I am using:

```python
import os

import numpy as np
import torch
from safetensors.torch import load_file

from diffusers import AutoencoderKLWan, GGUFQuantizationConfig, WanAnimatePipeline, WanAnimateTransformer3DModel
from diffusers.utils import export_to_video, load_image, load_video, logging

LoRA = True

device_cpu = torch.device("cpu")
device_gpu = torch.device("cuda")

original_model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
lora_model_id = "Kijai/WanVideo_comfy"
lora_model_path = "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank64_bf16.safetensors"

print("Loading transformer ....")
transformer = WanAnimateTransformer3DModel.from_single_file(
    "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q8_0.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    config=original_model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    offload_device="cpu",
    device=device_gpu,
)
print("Transformer loaded successfully ....")

print("Loading pipeline ....")
pipe = WanAnimatePipeline.from_pretrained(
    original_model_id,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
if LoRA:
    pipe.load_lora_weights(
        lora_model_id,
        weight_name=lora_model_path,
        adapter_name="lightning",
        offload_device="cpu",
        device=device_gpu,
    )
pipe.enable_model_cpu_offload()
print("Pipeline loaded successfully ....")

# Load the character image
image = load_image("Wan2.2/examples/wan_animate/animate/image.jpeg")

# Load pose and face videos (preprocessed from reference video)
# Note: Videos should be preprocessed to extract pose keypoints and face features
# Refer to the Wan-Animate preprocessing documentation for details
pose_video = load_video("Wan2.2/examples/wan_animate/animate/process_results/src_pose.mp4")
face_video = load_video("Wan2.2/examples/wan_animate/animate/process_results/src_face.mp4")

# Calculate optimal dimensions based on VAE constraints
max_area = 1280 * 720
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = "People in the video are doing actions."

# Animation mode: Animate the character with the motion from pose/face videos
print("Generating animation ....")
output = pipe(
    image=image,
    pose_video=pose_video,
    face_video=face_video,
    prompt=prompt,
    # negative_prompt=negative_prompt,
    height=height,
    width=width,
    segment_frame_length=77,
    guidance_scale=1.0,
    prev_segment_conditioning_frames=1,  # refert_num in original code
    num_inference_steps=4 if LoRA else 20,  # 4 steps with the Lightning LoRA, 20 without
    mode="animate",
).frames[0]

print("Exporting animation ....")
export_to_video(output, "output_animation__.mp4", fps=30)
```
output_animation__.mp4
```diff
 time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
-if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+if timestep.dtype != time_embedder_dtype and time_embedder_dtype not in [torch.int8, torch.uint8]:
     timestep = timestep.to(time_embedder_dtype)
```
Author


@dg845 Do you know why this line exists? It seems to cause the white-noise issue when the time_embedder weights are in uint8, and line 811 would have an issue if the timestep dtype does not match encoder_hidden_states. Maybe we need to remove lines 807 and 808?
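
For illustration (not code from this PR): casting the float timestep down to a quantized parameter dtype like uint8 destroys the values before they ever reach the embedding.

```python
# Tiny illustration of why the guard matters: GGUF-quantized time_embedder weights
# are stored as uint8, and casting the float timestep to that dtype drops the
# fractional part and mangles out-of-range values.
import torch

timestep = torch.tensor([999.0, 500.5, 0.25])

print(timestep.to(torch.bfloat16))  # values preserved (up to bf16 precision)
print(timestep.to(torch.uint8))     # fractional parts dropped; out-of-range values become garbage
```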

Collaborator


I'm not sure why this line is the way it is. @yiyixuxu, do you know why it was added?

Collaborator


Could we turn these checks into something like

```python
timestep = self.timesteps_proj(timestep)
if timestep_seq_len is not None:
    timestep = timestep.unflatten(0, (-1, timestep_seq_len))
timestep = timestep.to(encoder_hidden_states.dtype)
temb = self.time_embedder(timestep)
timestep_proj = self.time_proj(self.act_fn(temb))
```

Since the last cast to encoder_hidden_states.dtype is the one that would get applied anyway? Casting based on layer params tends to give weird results, as pointed out by @samadwar.

samadwar force-pushed the WanAnimate_from_single_file branch from 1043ed6 to 3c3755c on November 26, 2025 at 22:55
samadwar force-pushed the WanAnimate_from_single_file branch from eaab598 to 028bc19 on December 5, 2025 at 08:15
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6
Collaborator

DN6 commented Dec 7, 2025

@bot /style

@github-actions
Contributor

github-actions bot commented Dec 7, 2025

Style bot fixed some files and pushed the changes.

samadwar force-pushed the WanAnimate_from_single_file branch from ed79a1a to bb21f2b on December 11, 2025 at 04:15
