Only append zero to noise schedule if last sigma isn't zero.
This commit is contained in:
parent
11b74147ee
commit
95fa9545f1
|
@ -313,13 +313,18 @@ def simple_scheduler(model_sampling, steps):
|
||||||
def ddim_scheduler(model_sampling, steps):
|
def ddim_scheduler(model_sampling, steps):
|
||||||
s = model_sampling
|
s = model_sampling
|
||||||
sigs = []
|
sigs = []
|
||||||
ss = max(len(s.sigmas) // steps, 1)
|
|
||||||
x = 1
|
x = 1
|
||||||
|
if math.isclose(float(s.sigmas[x]), 0, abs_tol=0.00001):
|
||||||
|
steps += 1
|
||||||
|
sigs = []
|
||||||
|
else:
|
||||||
|
sigs = [0.0]
|
||||||
|
|
||||||
|
ss = max(len(s.sigmas) // steps, 1)
|
||||||
while x < len(s.sigmas):
|
while x < len(s.sigmas):
|
||||||
sigs += [float(s.sigmas[x])]
|
sigs += [float(s.sigmas[x])]
|
||||||
x += ss
|
x += ss
|
||||||
sigs = sigs[::-1]
|
sigs = sigs[::-1]
|
||||||
sigs += [0.0]
|
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
||||||
|
@ -327,16 +332,23 @@ def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
||||||
start = s.timestep(s.sigma_max)
|
start = s.timestep(s.sigma_max)
|
||||||
end = s.timestep(s.sigma_min)
|
end = s.timestep(s.sigma_min)
|
||||||
|
|
||||||
|
append_zero = True
|
||||||
if sgm:
|
if sgm:
|
||||||
timesteps = torch.linspace(start, end, steps + 1)[:-1]
|
timesteps = torch.linspace(start, end, steps + 1)[:-1]
|
||||||
else:
|
else:
|
||||||
|
if math.isclose(float(s.sigma(end)), 0, abs_tol=0.00001):
|
||||||
|
steps += 1
|
||||||
|
append_zero = False
|
||||||
timesteps = torch.linspace(start, end, steps)
|
timesteps = torch.linspace(start, end, steps)
|
||||||
|
|
||||||
sigs = []
|
sigs = []
|
||||||
for x in range(len(timesteps)):
|
for x in range(len(timesteps)):
|
||||||
ts = timesteps[x]
|
ts = timesteps[x]
|
||||||
sigs.append(s.sigma(ts))
|
sigs.append(float(s.sigma(ts)))
|
||||||
sigs += [0.0]
|
|
||||||
|
if append_zero:
|
||||||
|
sigs += [0.0]
|
||||||
|
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
# Implemented based on: https://arxiv.org/abs/2407.12173
|
# Implemented based on: https://arxiv.org/abs/2407.12173
|
||||||
|
|
Loading…
Reference in New Issue