
fix: align sigmas device in UniPCMultistepScheduler.set_timesteps#13450

Open
conanna wants to merge 1 commit into huggingface:main from conanna:fix/unipc-scheduler-device-mismatch

Conversation

@conanna conanna commented Apr 12, 2026

What does this PR do?

UniPCMultistepScheduler.set_timesteps hardcodes self.sigmas = self.sigmas.to("cpu") at the end of the method. This causes device mismatch errors in multistep_uni_p_bh_update and multistep_uni_c_bh_update when used with flow-matching models (e.g. Wan 2.2 I2V-A14B).

Root cause: self.sigmas stays on CPU → all derived tensors (alpha_t, sigma_t, lambda_t, h, rk, B_h, b) remain on CPU → torch.stack(rks) and torch.linalg.solve(R, b) fail because sample and model_output are on GPU.

This only affects use_flow_sigmas=True + flow_prediction path. The older DDPM/epsilon path doesn't hit this because intermediate computations get broadcast to GPU earlier.
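The device-propagation chain above can be illustrated with a minimal, self-contained sketch. The names (`sigmas`, `alpha_t`, `lambda_t`, `rks`) mirror the scheduler's internals, but the arithmetic here is illustrative only, not the real UniPC update:

```python
import torch

# Minimal sketch of how derived tensors inherit the device of self.sigmas.
# The math is illustrative; names mirror scheduling_unipc_multistep.py.
sigmas = torch.linspace(1.0, 0.0, 10)  # stays on CPU after .to("cpu")
alpha_t = 1.0 - sigmas                 # derived from sigmas -> same device
lambda_t = torch.log(alpha_t.clamp_min(1e-6)) - torch.log(sigmas.clamp_min(1e-6))
rks = [lambda_t[i] / lambda_t[1] for i in range(2, 5)]

# Every derived tensor follows sigmas' device, so torch.stack(rks) lands on
# CPU even when sample/model_output live on the GPU -> mixed-device stack.
stacked = torch.stack(rks)
print(stacked.device.type)  # cpu
```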

Fix: Replace self.sigmas.to("cpu") with self.sigmas.to(device=device) in set_timesteps, using the device parameter already passed by the caller.
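The change is one line. A simplified stand-in showing the pattern (this is not the full `set_timesteps` method, and `TinyScheduler` is a hypothetical class for illustration; the real method already receives `device` from the caller):

```python
import torch

class TinyScheduler:
    """Minimal stand-in for UniPCMultistepScheduler; only the relevant tail."""

    def set_timesteps(self, num_inference_steps, device=None):
        sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1)
        self.timesteps = (sigmas[:-1] * 1000).round().long().to(device)
        # was: self.sigmas = sigmas.to("cpu")  # hardcoded, causes the mismatch
        self.sigmas = sigmas.to(device=device)  # follow the caller's device

sched = TinyScheduler()
sched.set_timesteps(10, device="cpu")
assert sched.sigmas.device == sched.timesteps.device
```

With this, every tensor later derived from `self.sigmas` ends up on the same device as `sample` and `model_output`, so `torch.stack(rks)` and `torch.linalg.solve(R, b)` no longer mix devices.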

Error traceback

RuntimeError: Expected all tensors to be on the same device, but got tensors is on cuda:0,
different from other tensors on cpu (when checking argument in method wrapper_CUDA_cat)

  File "scheduling_unipc_multistep.py", line 907, in multistep_uni_p_bh_update
    rks = torch.stack(rks)

Reproduction

import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import load_image

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG")

output = pipe(
    image=image,
    prompt="A cat sitting on a surfboard",
    num_frames=17,
    num_inference_steps=10,
)
# RuntimeError at scheduler.step() → multistep_uni_p_bh_update → torch.stack(rks)

Before submitting

Who can review?

@yiyixuxu (Schedulers)

@github-actions github-actions bot added schedulers size/S PR with diff < 50 LOC labels Apr 12, 2026
