fix dpm2 in img2img as well

parent 180fdf78
......@@ -454,6 +454,9 @@ class KDiffusionSampler:
else:
sigmas = self.model_wrap.get_sigmas(steps)
if self.funcname in ['sample_dpm_2_ancestral', 'sample_dpm_2']:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment