
Fix SDXL Refiner with Higher Order Schedulers#13453

Open
Beinsezii wants to merge 1 commit intohuggingface:mainfrom
Beinsezii:beinsezii/fix_denoise_order

Conversation

Contributor

@Beinsezii Beinsezii commented Apr 13, 2026

What does this PR do?

Fixes the SDXL refiner with higher order schedulers by replacing the strange hardcoded `order == 2` check with a simple tensor stride.

Standalone script using a scheduler with `order=15`:

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "diffusers",
#     "skrample==0.6.*",
#     "torch>=2.11.0",
#     "transformers>=5.5.3",
# ]
#
# [tool.uv.sources]
# diffusers = { git = "https://github.com/Beinsezii/diffusers.git", rev = "beinsezii/fix_denoise_order" }
# ///

import torch
from skrample.diffusers import RKUltraWrapperScheduler

from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import (
    StableDiffusionXLImg2ImgPipeline,
)


with torch.inference_mode():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
    BASE = "stabilityai/stable-diffusion-xl-base-1.0"
    REFINER = "stabilityai/stable-diffusion-xl-refiner-1.0"

    PROMPT = "bright high resolution dslr photograph of a kitten in a field of lavender flowers"
    CFG: float = 8
    RATIO: float = 0.6
    ORDER: int = 15

    base = StableDiffusionXLPipeline.from_pretrained(BASE).to(device=DEVICE, dtype=DTYPE)
    refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        REFINER,
        vae=base.vae,
        text_encoder_2=base.text_encoder_2,
    ).to(device=DEVICE, dtype=DTYPE)
    base.scheduler = RKUltraWrapperScheduler.from_diffusers_config(base.scheduler.config, sampler_order=ORDER)
    refiner.scheduler = RKUltraWrapperScheduler.from_diffusers_config(refiner.scheduler.config, sampler_order=ORDER)

    for steps in 5, 10:
        partial = base(
            prompt=PROMPT,
            guidance_scale=CFG,
            num_inference_steps=steps,
            denoising_end=RATIO,
            output_type="latent",
        ).images[0]
        refiner(
            image=partial,
            prompt=PROMPT,
            guidance_scale=CFG,
            num_inference_steps=steps,
            denoising_start=RATIO,
        ).images[0].save(f"order_{refiner.scheduler.order}_{steps}_steps.png")

Before submitting

Who can review?

@yiyixuxu @sayakpaul

@github-actions github-actions bot added pipelines size/S PR with diff < 50 LOC labels Apr 13, 2026
Comment on lines -670 to -677
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
    # if the scheduler is a 2nd order scheduler we might have to do +1
    # because `num_inference_steps` might be even given that every timestep
    # (except the highest one) is duplicated. If `num_inference_steps` is even it would
    # mean that we cut the timesteps in the middle of the denoising step
    # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
    # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
    num_inference_steps = num_inference_steps + 1
Contributor Author
Based on this comment, it was previously hardcoded specifically for Heun's method, and anything else is 100% broken. The thing is, Heun appears to be the only higher-order singlestep solver in Diffusers, so I guess we can't add tests for this yet?

Comment on lines -1178 to +1181
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
num_inference_steps = (
    (torch.as_tensor(timesteps)[:: self.scheduler.order] >= discrete_timestep_cutoff).sum().item()
)
timesteps = timesteps[: num_inference_steps * self.scheduler.order]
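For reference, a before/after sketch of the two counting strategies on a hypothetical Heun-style (`order=2`) timestep list, where every timestep except the highest is duplicated (values invented for illustration):

```python
import torch

order = 2  # Heun-style: every timestep except the highest appears twice
timesteps = [999, 749, 749, 499, 499, 249, 249]
cutoff = 600

# Old behavior: count every entry above the cutoff, bump even counts by one.
old_count = len([t for t in timesteps if t >= cutoff])
if order == 2 and old_count % 2 == 0:
    old_count += 1
print(timesteps[:old_count])  # [999, 749, 749]

# New behavior: count whole steps by striding over each step's first entry,
# then keep `order` entries per counted step.
new_count = (torch.as_tensor(timesteps)[::order] >= cutoff).sum().item()
print(timesteps[: new_count * order])  # [999, 749, 749, 499]
```

As the author notes below, the two can keep a slightly different number of entries for Heun, which is where the "might change the current results" caveat comes from.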
Contributor Author
Technically this might change the current results with Heun, but it's necessary because otherwise it'll split wrong on a Butcher tableau with non-sequential coefficients like

                          RKZ.Butcher6
+0.0    | 
+0.2764 | +0.2764
+0.7236 | -0.2236 +0.9472
+0.2764 | +0.0326 +0.309  -0.0652
+0.7236 | +0.0461 +0.0    +0.1667 +0.5109
+0.2764 | +0.1206 +0.0    -0.1817 +0.1667 +0.1708
+1.0    | +0.1667 +0.0    +0.0751 -3.3877 +0.5279 +3.618 
-----------------------------------------------------------------
        | +0.0833 +0.0    +0.0    +0.0    +0.4167 +0.4167 +0.0833

Where it could split on stage 3, but the following stages contain smaller timestep values, and since the refiner is not trained on earlier timesteps this will lead to worse results.
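To make the failure mode concrete, here is a minimal counting sketch (timestep values invented for illustration) contrasting the flat per-timestep filter with the stride-based count on an order-3 schedule whose inner stages revisit smaller timestep values, as with a non-sequential tableau:

```python
import torch

# Hypothetical order-3 schedule: each denoising step spans `order`
# consecutive entries, and a step's inner stages can dip below the
# step's starting timestep.
order = 3
timesteps = [900, 700, 800, 600, 400, 500, 300, 100, 200]
cutoff = 550

# Flat filter: counts every entry >= cutoff, so it cuts step 2 after its
# first stage, mid-step.
flat = len([t for t in timesteps if t >= cutoff])
print(timesteps[:flat])  # [900, 700, 800, 600]

# Stride-based count: tests only each step's first stage, then keeps
# whole steps of `order` entries.
steps = (torch.as_tensor(timesteps)[::order] >= cutoff).sum().item()
print(timesteps[: steps * order])  # [900, 700, 800, 600, 400, 500]
```

The stride version never truncates inside a step, so the refiner always resumes on a step boundary.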

