-
Notifications
You must be signed in to change notification settings - Fork 584
feat: add er_sde sampler #1403
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
rmatif
wants to merge
1
commit into
leejet:master
Choose a base branch
from
rmatif:er-sde
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
feat: add er_sde sampler #1403
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -818,6 +818,33 @@ static std::tuple<float, float, float> get_ancestral_step_flow(float sigma_from, | |
| return {sigma_down, sigma_up, alpha_scale}; | ||
| } | ||
|
|
||
| static float er_sde_flow_sigma(float sigma) { | ||
| sigma = std::max(sigma, 1e-6f); | ||
| sigma = std::min(sigma, 1.0f - 1e-4f); | ||
| return sigma; | ||
| } | ||
|
|
||
| static float sigma_to_er_sde_lambda(float sigma, bool is_flow_denoiser) { | ||
| if (is_flow_denoiser) { | ||
| sigma = er_sde_flow_sigma(sigma); | ||
| return sigma / std::max(1.0f - sigma, 1e-6f); | ||
| } | ||
| return std::max(sigma, 1e-6f); | ||
| } | ||
|
|
||
| static float sigma_to_er_sde_alpha(float sigma, bool is_flow_denoiser) { | ||
| if (is_flow_denoiser) { | ||
| sigma = er_sde_flow_sigma(sigma); | ||
| return 1.0f - sigma; | ||
| } | ||
| return 1.0f; | ||
| } | ||
|
|
||
| static float er_sde_noise_scaler(float x) { | ||
| x = std::max(x, 0.0f); | ||
| return x * (std::exp(std::pow(x, 0.3f)) + 10.0f); | ||
| } | ||
|
|
||
| static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model, | ||
| sd::Tensor<float> x, | ||
| const std::vector<float>& sigmas, | ||
|
|
@@ -1295,6 +1322,112 @@ static sd::Tensor<float> sample_res_2s(denoise_cb_t model, | |
| return x; | ||
| } | ||
|
|
||
| static sd::Tensor<float> sample_er_sde(denoise_cb_t model, | ||
| sd::Tensor<float> x, | ||
| std::vector<float> sigmas, | ||
| std::shared_ptr<RNG> rng, | ||
| bool is_flow_denoiser) { | ||
| constexpr int max_stage = 3; | ||
| constexpr int num_integration_points = 200; | ||
| constexpr float num_integration_points_f = 200.0f; | ||
| constexpr float s_noise = 1.0f; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we don't have a separate |
||
|
|
||
| if (is_flow_denoiser) { | ||
| for (size_t i = 0; i + 1 < sigmas.size(); ++i) { | ||
| if (sigmas[i] > 1.0f) { | ||
| sigmas[i] = er_sde_flow_sigma(sigmas[i]); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| std::vector<float> er_lambdas(sigmas.size(), 0.0f); | ||
| for (size_t i = 0; i < sigmas.size(); ++i) { | ||
| er_lambdas[i] = sigma_to_er_sde_lambda(sigmas[i], is_flow_denoiser); | ||
| } | ||
|
|
||
| sd::Tensor<float> old_denoised = x; | ||
| sd::Tensor<float> old_denoised_d = x; | ||
| bool have_old_denoised = false; | ||
| bool have_old_denoised_d = false; | ||
|
|
||
| int steps = static_cast<int>(sigmas.size()) - 1; | ||
| for (int i = 0; i < steps; i++) { | ||
| sd::Tensor<float> denoised = model(x, sigmas[i], i + 1); | ||
| if (denoised.empty()) { | ||
| return {}; | ||
| } | ||
|
|
||
| int stage_used = std::min(max_stage, i + 1); | ||
|
|
||
| if (sigmas[i + 1] == 0.0f) { | ||
| x = denoised; | ||
| } else { | ||
| float er_lambda_s = er_lambdas[i]; | ||
| float er_lambda_t = er_lambdas[i + 1]; | ||
| float alpha_s = sigma_to_er_sde_alpha(sigmas[i], is_flow_denoiser); | ||
| float alpha_t = sigma_to_er_sde_alpha(sigmas[i + 1], is_flow_denoiser); | ||
| float scaled_s = er_sde_noise_scaler(er_lambda_s); | ||
| float scaled_t = er_sde_noise_scaler(er_lambda_t); | ||
| float r_alpha = alpha_s > 0.0f ? alpha_t / alpha_s : 0.0f; | ||
| float r = scaled_s > 0.0f ? scaled_t / scaled_s : 0.0f; | ||
|
|
||
| x = r_alpha * r * x + alpha_t * (1.0f - r) * denoised; | ||
|
|
||
| if (stage_used >= 2 && have_old_denoised) { | ||
| float dt = er_lambda_t - er_lambda_s; | ||
| float lambda_step_size = -dt / num_integration_points_f; | ||
| float s = 0.0f; | ||
| float s_u = 0.0f; | ||
|
|
||
| for (int p = 0; p < num_integration_points; ++p) { | ||
| float lambda_pos = er_lambda_t + p * lambda_step_size; | ||
| float scaled_pos = er_sde_noise_scaler(lambda_pos); | ||
| if (scaled_pos <= 0.0f) { | ||
| continue; | ||
| } | ||
|
|
||
| s += 1.0f / scaled_pos; | ||
| if (stage_used >= 3 && have_old_denoised_d) { | ||
| s_u += (lambda_pos - er_lambda_s) / scaled_pos; | ||
| } | ||
| } | ||
|
|
||
| s *= lambda_step_size; | ||
|
|
||
| float denom_d = er_lambda_s - er_lambdas[i - 1]; | ||
| if (std::fabs(denom_d) > 1e-12f) { | ||
| float coeff_d = alpha_t * (dt + s * scaled_t); | ||
| sd::Tensor<float> denoised_d = (denoised - old_denoised) / denom_d; | ||
| x += coeff_d * denoised_d; | ||
|
|
||
| if (stage_used >= 3 && have_old_denoised_d) { | ||
| float denom_u = (er_lambda_s - er_lambdas[i - 2]) * 0.5f; | ||
| if (std::fabs(denom_u) > 1e-12f) { | ||
| s_u *= lambda_step_size; | ||
| float coeff_u = alpha_t * (0.5f * dt * dt + s_u * scaled_t); | ||
| sd::Tensor<float> denoised_u = (denoised_d - old_denoised_d) / denom_u; | ||
| x += coeff_u * denoised_u; | ||
| } | ||
| } | ||
|
|
||
| old_denoised_d = denoised_d; | ||
| have_old_denoised_d = true; | ||
| } | ||
| } | ||
|
|
||
| float noise_scale_sq = er_lambda_t * er_lambda_t - er_lambda_s * er_lambda_s * r * r; | ||
| if (s_noise > 0.0f && noise_scale_sq > 0.0f) { | ||
| float noise_scale = alpha_t * std::sqrt(std::max(noise_scale_sq, 0.0f)); | ||
| x += sd::Tensor<float>::randn_like(x, rng) * noise_scale; | ||
| } | ||
| } | ||
|
|
||
| old_denoised = denoised; | ||
| have_old_denoised = true; | ||
| } | ||
| return x; | ||
| } | ||
|
|
||
| static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model, | ||
| sd::Tensor<float> x, | ||
| const std::vector<float>& sigmas, | ||
|
|
@@ -1456,6 +1589,8 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method, | |
| return sample_res_multistep(model, std::move(x), sigmas, rng, eta); | ||
| case RES_2S_SAMPLE_METHOD: | ||
| return sample_res_2s(model, std::move(x), sigmas, rng, eta); | ||
| case ER_SDE_SAMPLE_METHOD: | ||
| return sample_er_sde(model, std::move(x), sigmas, rng, is_flow_denoiser); | ||
| case DDIM_TRAILING_SAMPLE_METHOD: | ||
| return sample_ddim_trailing(model, std::move(x), sigmas, rng, eta); | ||
| case TCD_SAMPLE_METHOD: | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If these won't be reused for other samplers, I'd suggest keeping them as lambdas inside
sample_er_sde, or at least close to that function.