made sample functions more explicit
This commit is contained in:
parent
5818539743
commit
d9b1595f85
|
@ -2,30 +2,25 @@ import torch
|
|||
import comfy.model_management
|
||||
|
||||
|
||||
def prepare_noise(latent, seed):
|
||||
"""creates random noise given a LATENT and a seed"""
|
||||
latent_image = latent["samples"]
|
||||
batch_index = 0
|
||||
if "batch_index" in latent:
|
||||
batch_index = latent["batch_index"]
|
||||
|
||||
def prepare_noise(latent_image, seed, skip=0):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
for i in range(batch_index):
|
||||
for _ in range(skip):
|
||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
return noise
|
||||
|
||||
def create_mask(latent, noise):
|
||||
"""creates a mask for a given LATENT and noise"""
|
||||
noise_mask = None
|
||||
def prepare_mask(noise_mask, noise):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
device = comfy.model_management.get_torch_device()
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = latent['noise_mask']
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
|
||||
noise_mask = noise_mask.round()
|
||||
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
|
||||
noise_mask = torch.cat([noise_mask] * noise.shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
|
||||
noise_mask = noise_mask.round()
|
||||
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
|
||||
noise_mask = torch.cat([noise_mask] * noise.shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
|
||||
def broadcast_cond(cond, noise):
|
||||
|
@ -40,22 +35,20 @@ def broadcast_cond(cond, noise):
|
|||
copy += [[t] + p[1:]]
|
||||
return copy
|
||||
|
||||
def load_c_nets(positive, negative):
|
||||
"""loads control nets in positive and negative conditioning"""
|
||||
def get_models(cond):
|
||||
models = []
|
||||
for c in cond:
|
||||
if 'control' in c[1]:
|
||||
models += [c[1]['control']]
|
||||
if 'gligen' in c[1]:
|
||||
models += [c[1]['gligen'][1]]
|
||||
return models
|
||||
|
||||
return get_models(positive) + get_models(negative)
|
||||
def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
for c in cond:
|
||||
if model_type in c[1]:
|
||||
models += [c[1][model_type]]
|
||||
return models
|
||||
|
||||
def load_additional_models(positive, negative):
|
||||
"""loads additional models in positive and negative conditioning"""
|
||||
models = load_c_nets(positive, negative)
|
||||
models = []
|
||||
models += get_models_from_cond(positive, "control")
|
||||
models += get_models_from_cond(negative, "control")
|
||||
models += get_models_from_cond(positive, "gligen")
|
||||
models += get_models_from_cond(negative, "gligen")
|
||||
comfy.model_management.load_controlnet_gpu(models)
|
||||
return models
|
||||
|
||||
|
|
7
nodes.py
7
nodes.py
|
@ -747,9 +747,12 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|||
if disable_noise:
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
else:
|
||||
noise = comfy.sample.prepare_noise(latent, seed)
|
||||
skip = latent["batch_index"] if "batch_index" in latent else 0
|
||||
noise = comfy.sample.prepare_noise(latent_image, seed, skip)
|
||||
|
||||
noise_mask = comfy.sample.create_mask(latent, noise)
|
||||
noise_mask = None
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = comfy.sample.prepare_mask(latent["noise_mask"], noise)
|
||||
|
||||
real_model = None
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
|
|
Loading…
Reference in New Issue