Add support for GLIGEN textbox model.

parent 472b1cc0d8
commit 3696d1699a
comfy/gligen.py
@@ -0,0 +1,343 @@
import math

import torch
from torch import nn
from ldm.modules.attention import CrossAttention
from inspect import isfunction


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * torch.nn.functional.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)

class GatedCrossAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha); for example, when it is set to 0, the entire model
        # behaves exactly like the original one
        self.scale = 1

    def forward(self, x, objs):

        x = x + self.scale * \
            torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


class GatedSelfAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we concatenate the visual
        # feature and the obj feature
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha); for example, when it is set to 0, the entire model
        # behaves exactly like the original one
        self.scale = 1

    def forward(self, x, objs):

        N_visual = x.shape[1]
        objs = self.linear(objs)

        x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
            self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x

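Both gated blocks initialize `alpha_attn` and `alpha_dense` to zero, so `tanh(alpha) == 0` and a freshly constructed block is an exact identity; the grounding branch only contributes once the pretrained gate values are loaded. A minimal sanity check of that property (illustrative; the import path is assumed):

import torch
from comfy.gligen import GatedSelfAttentionDense  # import path assumed

block = GatedSelfAttentionDense(query_dim=320, context_dim=768, n_heads=8, d_head=40)
x = torch.randn(1, 64, 320)      # 64 visual tokens (an 8x8 latent grid)
objs = torch.randn(1, 30, 768)   # up to 30 grounding tokens
with torch.no_grad():
    out = block(x, objs)
# tanh(0) == 0, so both residual branches vanish at initialization
assert torch.allclose(out, x)
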
class GatedSelfAttentionDense2(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we concatenate the visual
        # feature and the obj feature
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha); for example, when it is set to 0, the entire model
        # behaves exactly like the original one
        self.scale = 1

    def forward(self, x, objs):

        B, N_visual, _ = x.shape
        B, N_ground, _ = objs.shape

        objs = self.linear(objs)

        # sanity check: both token sets must form square grids
        size_v = math.sqrt(N_visual)
        size_g = math.sqrt(N_ground)
        assert int(size_v) == size_v, "Visual tokens must form a square grid"
        assert int(size_g) == size_g, "Grounding tokens must form a square grid"
        size_v = int(size_v)
        size_g = int(size_g)

        # select the grounding tokens and resize them to the visual token
        # size to use as a residual
        out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
            :, N_visual:, :]
        out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
        out = torch.nn.functional.interpolate(
            out, (size_v, size_v), mode='bicubic')
        residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

        # add the residual to the visual feature
        x = x + self.scale * torch.tanh(self.alpha_attn) * residual
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x

class FourierEmbedder:
    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        "x: tensor of arbitrary shape. cat_dim: dimension to concatenate on"
        out = []
        for freq in self.freq_bands:
            out.append(torch.sin(freq * x))
            out.append(torch.cos(freq * x))
        return torch.cat(out, cat_dim)

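Each scalar coordinate is expanded into `num_freqs` sin/cos pairs along the cat dimension, so a batch of normalized `xyxy` boxes of shape `B x N x 4` becomes `B x N x (num_freqs * 2 * 4)`. A quick shape check with the `fourier_freqs=8` default that `PositionNet` uses below (illustrative; import path assumed):

import torch
from comfy.gligen import FourierEmbedder  # import path assumed

embedder = FourierEmbedder(num_freqs=8)
boxes = torch.rand(2, 30, 4)   # normalized xyxy boxes in [0, 1]
emb = embedder(boxes)
print(emb.shape)               # torch.Size([2, 30, 64]) -> 8 freqs * 2 (sin, cos) * 4 coords
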
class PositionNet(nn.Module):
    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 4  # 2 is sin & cos, 4 is xyxy

        self.linears = nn.Sequential(
            nn.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

        self.null_positive_feature = torch.nn.Parameter(
            torch.zeros([self.in_dim]))
        self.null_position_feature = torch.nn.Parameter(
            torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        B, N, _ = boxes.shape
        masks = masks.unsqueeze(-1)

        # embed the box positions (may include padding as placeholder)
        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 --> B*N*C

        # learnable null embeddings
        positive_null = self.null_positive_feature.view(1, 1, -1)
        xyxy_null = self.null_position_feature.view(1, 1, -1)

        # replace padding with the learnable null embeddings
        positive_embeddings = positive_embeddings * \
            masks + (1 - masks) * positive_null
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        objs = self.linears(
            torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
        assert objs.shape == torch.Size([B, N, self.out_dim])
        return objs

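`PositionNet` concatenates each slot's text embedding with its Fourier position code and projects the pair through a small MLP to the grounding-token width; masked-out slots are swapped for the learned null features, so padded slots still produce well-defined tokens. A shape sketch, assuming SD1.x dimensions (`in_dim=768`, `out_dim=768`):

import torch
from comfy.gligen import PositionNet  # import path assumed

net = PositionNet(in_dim=768, out_dim=768)
boxes = torch.rand(1, 30, 4)      # normalized xyxy, padded out to 30 slots
masks = torch.zeros(1, 30)
masks[0, 0] = 1.0                 # only the first slot holds a real object
embeds = torch.randn(1, 30, 768)  # one pooled CLIP embedding per slot
objs = net(boxes, masks, embeds)
print(objs.shape)                 # torch.Size([1, 30, 768])
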
class Gligen(nn.Module):
    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        self.max_objs = 30

    def _set_position(self, boxes, masks, positive_embeddings):
        objs = self.position_net(boxes, masks, positive_embeddings)

        def func(key, x):
            module = self.module_list[key]
            return module(x, objs)
        return func

    def set_position(self, latent_image_shape, position_params, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        positive_embeddings = []
        # each p is (embedding, height, width, y, x) in latent coordinates
        for p in position_params:
            x1 = (p[4]) / w
            y1 = (p[3]) / h
            x2 = (p[4] + p[2]) / w
            y2 = (p[3] + p[1]) / h
            masks[len(boxes)] = 1.0
            boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
            positive_embeddings += [p[0]]
        append_boxes = []
        append_conds = []
        if len(boxes) < self.max_objs:
            append_boxes = [torch.zeros(
                [self.max_objs - len(boxes), 4], device="cpu")]
            append_conds = [torch.zeros(
                [self.max_objs - len(boxes), self.key_dim], device="cpu")]

        box_out = torch.cat(
            boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(positive_embeddings +
                          append_conds).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def set_empty(self, latent_image_shape, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
        box_out = torch.zeros([self.max_objs, 4],
                              device="cpu").repeat(batch, 1, 1)
        conds = torch.zeros([self.max_objs, self.key_dim],
                            device="cpu").repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def cleanup(self):
        pass

    def get_models(self):
        return [self]

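`set_position` expects each entry of `position_params` as a `(embedding, height, width, y, x)` tuple in latent units and converts it to a normalized `xyxy` box: `p[1]`/`p[2]` are the box size and `p[3]`/`p[4]` its top-left corner, matching what `GLIGENTextBoxApply` below packs. A worked example (values illustrative):

import torch

cond_pooled = torch.randn(1, 768)  # stand-in for the pooled text embedding
# one 256x192-pixel box at pixel offset (x=64, y=128); every pixel value is
# divided by 8 before it reaches set_position (see GLIGENTextBoxApply below)
p = (cond_pooled, 192 // 8, 256 // 8, 128 // 8, 64 // 8)  # (emb, h=24, w=32, y=16, x=8)
# with a 64x48 latent (w=64, h=48), set_position normalizes this to
# x1 = 8/64 = 0.125, y1 = 16/48 ~= 0.333, x2 = 40/64 = 0.625, y2 = 40/48 ~= 0.833
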
def load_gligen(sd):
    sd_k = sd.keys()
    output_list = []
    key_dim = 768
    for a in ["input_blocks", "middle_block", "output_blocks"]:
        for b in range(20):
            k_temp = filter(lambda k: "{}.{}.".format(a, b)
                            in k and ".fuser." in k, sd_k)
            k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)

            n_sd = {}
            for k in k_temp:
                n_sd[k[1]] = sd[k[0]]
            if len(n_sd) > 0:
                query_dim = n_sd["linear.weight"].shape[0]
                key_dim = n_sd["linear.weight"].shape[1]

                if key_dim == 768:  # SD1.x
                    n_heads = 8
                    d_head = query_dim // n_heads
                else:
                    d_head = 64
                    n_heads = query_dim // d_head

                gated = GatedSelfAttentionDense(
                    query_dim, key_dim, n_heads, d_head)
                gated.load_state_dict(n_sd, strict=False)
                output_list.append(gated)

    if "position_net.null_positive_feature" in sd_k:
        in_dim = sd["position_net.null_positive_feature"].shape[0]
        out_dim = sd["position_net.linears.4.weight"].shape[0]

        class WeightsLoader(torch.nn.Module):
            pass

        w = WeightsLoader()
        w.position_net = PositionNet(in_dim, out_dim)
        w.load_state_dict(sd, strict=False)

    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen
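The loader walks the UNet block names and collects every `.fuser.` entry per block, stripping the prefix so the sub-dict can be loaded straight into a `GatedSelfAttentionDense`; the head count is inferred from `linear.weight` (SD1.x checkpoints have `key_dim == 768`). A toy run of the key-grouping step (key names illustrative):

sd_keys = ["input_blocks.1.1.fuser.linear.weight",
           "input_blocks.1.1.fuser.alpha_attn",
           "middle_block.1.fuser.linear.weight"]
picked = [k for k in sd_keys if "input_blocks.1." in k and ".fuser." in k]
print([k.split(".fuser.")[-1] for k in picked])  # ['linear.weight', 'alpha_attn']
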
comfy/ldm/modules/attention.py
@@ -510,6 +510,14 @@ class BasicTransformerBlock(nn.Module):
         return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None, transformer_options={}):
+        current_index = None
+        if "current_index" in transformer_options:
+            current_index = transformer_options["current_index"]
+        if "patches" in transformer_options:
+            transformer_patches = transformer_options["patches"]
+        else:
+            transformer_patches = {}
+
         n = self.norm1(x)
         if "tomesd" in transformer_options:
             m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
@@ -518,11 +526,19 @@ class BasicTransformerBlock(nn.Module):
         n = self.attn1(n, context=context if self.disable_self_attn else None)
 
         x += n
+        if "middle_patch" in transformer_patches:
+            patch = transformer_patches["middle_patch"]
+            for p in patch:
+                x = p(current_index, x)
+
         n = self.norm2(x)
         n = self.attn2(n, context=context)
 
         x += n
         x = self.ff(self.norm3(x)) + x
+
+        if current_index is not None:
+            transformer_options["current_index"] += 1
         return x
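Every transformer block now bumps a shared `current_index`, and each `middle_patch` callable is invoked between self-attention and cross-attention with that index, which is how the closure returned by `Gligen._set_position` above picks the right entry of its `module_list`. A minimal patch showing the calling convention (hypothetical, for demonstration only):

def debug_patch(current_index, x):
    # runs once per BasicTransformerBlock, right after self-attention
    print("block", current_index, "tokens", tuple(x.shape))
    return x

transformer_options = {"patches": {"middle_patch": [debug_patch]}}
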
comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -782,6 +782,8 @@ class UNetModel(nn.Module):
         :return: an [N x C x ...] Tensor of outputs.
         """
         transformer_options["original_shape"] = list(x.shape)
+        transformer_options["current_index"] = 0
+
         assert (y is not None) == (
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
comfy/model_management.py
@@ -176,7 +176,7 @@ def load_model_gpu(model):
         model_accelerated = True
     return current_loaded_model
 
-def load_controlnet_gpu(models):
+def load_controlnet_gpu(control_models):
    global current_gpu_controlnets
    global vram_state
    if vram_state == VRAMState.CPU:
@@ -186,6 +186,10 @@ def load_controlnet_gpu(models):
         #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
         return
 
+    models = []
+    for m in control_models:
+        models += m.get_models()
+
     for m in current_gpu_controlnets:
         if m not in models:
             m.cpu()
comfy/samplers.py
@@ -70,7 +70,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             control = None
             if 'control' in cond[1]:
                 control = cond[1]['control']
-            return (input_x, mult, conditionning, area, control)
+
+            patches = None
+            if 'gligen' in cond[1]:
+                gligen = cond[1]['gligen']
+                patches = {}
+                gligen_type = gligen[0]
+                gligen_model = gligen[1]
+                if gligen_type == "position":
+                    gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
+                else:
+                    gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
+
+                patches['middle_patch'] = [gligen_patch]
+
+            return (input_x, mult, conditionning, area, control, patches)
 
         def cond_equal_size(c1, c2):
             if c1 is c2:
@@ -91,12 +105,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         def can_concat_cond(c1, c2):
             if c1[0].shape != c2[0].shape:
                 return False
 
+            #control
             if (c1[4] is None) != (c2[4] is None):
                 return False
             if c1[4] is not None:
                 if c1[4] is not c2[4]:
                     return False
 
+            #patches
+            if (c1[5] is None) != (c2[5] is None):
+                return False
+            if (c1[5] is not None):
+                if c1[5] is not c2[5]:
+                    return False
+
             return cond_equal_size(c1[2], c2[2])
 
         def cond_cat(c_list):
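Note that the patches comparison is an identity check, not equality: each cond gets its own `patches` dict from `get_area_and_mult`, so two conds are only batched together when they literally share the same object. A two-line illustration:

p1 = {"middle_patch": []}
p2 = {"middle_patch": []}
print(p1 == p2, p1 is p2)  # True False -> equal dicts still do not batch together
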
@@ -166,6 +189,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 cond_or_uncond = []
                 area = []
                 control = None
+                patches = None
                 for x in to_batch:
                     o = to_run.pop(x)
                     p = o[0]
@@ -175,6 +199,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                     area += [p[3]]
                     cond_or_uncond += [o[1]]
                     control = p[4]
+                    patches = p[5]
 
                 batch_chunks = len(cond_or_uncond)
                 input_x = torch.cat(input_x)
@@ -184,8 +209,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 if control is not None:
                     c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
 
+                transformer_options = {}
                 if 'transformer_options' in model_options:
-                    c['transformer_options'] = model_options['transformer_options']
+                    transformer_options = model_options['transformer_options'].copy()
+
+                if patches is not None:
+                    transformer_options["patches"] = patches
+
+                c['transformer_options'] = transformer_options
 
                 output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
                 del input_x
@@ -309,8 +340,7 @@ def create_cond_with_same_area_if_none(conds, c):
         n = c[1].copy()
     conds += [[smallest[0], n]]
 
-
-def apply_control_net_to_equal_area(conds, uncond):
+def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
     cond_cnets = []
     cond_other = []
     uncond_cnets = []
@@ -318,15 +348,15 @@ def apply_control_net_to_equal_area(conds, uncond):
     for t in range(len(conds)):
         x = conds[t]
         if 'area' not in x[1]:
-            if 'control' in x[1] and x[1]['control'] is not None:
-                cond_cnets.append(x[1]['control'])
+            if name in x[1] and x[1][name] is not None:
+                cond_cnets.append(x[1][name])
             else:
                 cond_other.append((x, t))
     for t in range(len(uncond)):
         x = uncond[t]
         if 'area' not in x[1]:
-            if 'control' in x[1] and x[1]['control'] is not None:
-                uncond_cnets.append(x[1]['control'])
+            if name in x[1] and x[1][name] is not None:
+                uncond_cnets.append(x[1][name])
             else:
                 uncond_other.append((x, t))
@@ -336,15 +366,16 @@ def apply_control_net_to_equal_area(conds, uncond):
     for x in range(len(cond_cnets)):
         temp = uncond_other[x % len(uncond_other)]
         o = temp[0]
-        if 'control' in o[1] and o[1]['control'] is not None:
+        if name in o[1] and o[1][name] is not None:
             n = o[1].copy()
-            n['control'] = cond_cnets[x]
+            n[name] = uncond_fill_func(cond_cnets, x)
             uncond += [[o[0], n]]
         else:
             n = o[1].copy()
-            n['control'] = cond_cnets[x]
+            n[name] = uncond_fill_func(cond_cnets, x)
             uncond[temp[1]] = [o[0], n]
 
+
 def encode_adm(noise_augmentor, conds, batch_size, device):
     for t in range(len(conds)):
         x = conds[t]
@@ -378,6 +409,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
 
     return conds
 
+
 class KSampler:
     SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
     SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
@@ -466,7 +498,8 @@ class KSampler:
         for c in negative:
             create_cond_with_same_area_if_none(positive, c)
 
-        apply_control_net_to_equal_area(positive, negative)
+        apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+        apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
         if self.model.model.diffusion_model.dtype == torch.float16:
             precision_scope = torch.autocast
comfy/sd.py
@@ -13,6 +13,7 @@ from .t2i_adapter import adapter
 
 from . import utils
 from . import clip_vision
+from . import gligen
 
 def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
     m, u = model.load_state_dict(sd, strict=False)
@@ -378,7 +379,7 @@ class CLIP:
     def tokenize(self, text, return_word_ids=False):
         return self.tokenizer.tokenize_with_weights(text, return_word_ids)
 
-    def encode_from_tokens(self, tokens):
+    def encode_from_tokens(self, tokens, return_pooled=False):
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
@@ -388,6 +389,10 @@ class CLIP:
         except Exception as e:
             self.patcher.unpatch_model()
             raise e
+        if return_pooled:
+            eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__)
+            pooled = cond[:, eos_token_index]
+            return cond, pooled
         return cond
 
     def encode(self, text):
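The pooled vector is taken at the end-of-text position, found by an argmax over the token ids: CLIP's end-of-text token has the largest id in the vocabulary (49407), so `max(..., key=...)` lands on the first EOS slot. Illustrative:

tokens = [49406, 320, 1125, 49407, 0, 0]  # BOS, "a", "photo", EOS, padding
eos_token_index = max(range(len(tokens)), key=tokens.__getitem__)
print(eos_token_index)  # 3
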
@@ -564,10 +569,10 @@ class ControlNet:
         c.strength = self.strength
         return c
 
-    def get_control_models(self):
+    def get_models(self):
         out = []
         if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_control_models()
+            out += self.previous_controlnet.get_models()
         out.append(self.control_model)
         return out
 
@@ -737,10 +742,10 @@ class T2IAdapter:
         del self.cond_hint
         self.cond_hint = None
 
-    def get_control_models(self):
+    def get_models(self):
         out = []
         if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_control_models()
+            out += self.previous_controlnet.get_models()
         return out
 
 def load_t2i_adapter(t2i_data):
@@ -787,6 +792,13 @@ def load_clip(ckpt_path, embedding_directory=None):
     clip.load_from_state_dict(clip_data)
     return clip
 
+def load_gligen(ckpt_path):
+    data = utils.load_torch_file(ckpt_path)
+    model = gligen.load_gligen(data)
+    if model_management.should_use_fp16():
+        model = model.half()
+    return model
+
 def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
     with open(config_path, 'r') as stream:
         config = yaml.safe_load(stream)
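Loading a GLIGEN checkpoint from user code is then a one-liner (the path below is illustrative, not a filename shipped by the project):

import comfy.sd

gligen_model = comfy.sd.load_gligen("models/gligen/gligen_textbox.safetensors")  # path illustrative
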
folder_paths.py
@@ -26,6 +26,8 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")]
 folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
 
 folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
+folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
+
 folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
 
 folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
nodes.py
@@ -490,6 +490,51 @@ class unCLIPConditioning:
             c.append(n)
         return (c, )
 
+class GLIGENLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
+
+    RETURN_TYPES = ("GLIGEN",)
+    FUNCTION = "load_gligen"
+
+    CATEGORY = "_for_testing/gligen"
+
+    def load_gligen(self, gligen_name):
+        gligen_path = folder_paths.get_full_path("gligen", gligen_name)
+        gligen = comfy.sd.load_gligen(gligen_path)
+        return (gligen,)
+
+class GLIGENTextBoxApply:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning_to": ("CONDITIONING", ),
+                             "clip": ("CLIP", ),
+                             "gligen_textbox_model": ("GLIGEN", ),
+                             "text": ("STRING", {"multiline": True}),
+                             "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
+                             "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
+                             "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                             "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                             }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "append"
+
+    CATEGORY = "_for_testing/gligen"
+
+    def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
+        c = []
+        cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
+        for t in conditioning_to:
+            n = [t[0], t[1].copy()]
+            position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
+            prev = []
+            if "gligen" in n[1]:
+                prev = n[1]['gligen'][2]
+
+            n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
+            c.append(n)
+        return (c, )
+
 class EmptyLatentImage:
     def __init__(self, device="cpu"):
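On the graph side, `GLIGENTextBoxApply` nodes chain: each application copies the conditioning, appends one box to the existing `gligen` tuple, and keeps the model reference, so multiple boxes are stacked by wiring nodes in series. All pixel inputs are divided by 8 to latent units before being stored. A sketch of the resulting conditioning entry (values illustrative):

# after two chained GLIGENTextBoxApply nodes, each conditioning entry carries:
# n[1]['gligen'] == ("position", gligen_model, [
#     (pooled_emb_1, 32, 32, 0, 0),    # 256x256 px box at (x=0, y=0), latent units
#     (pooled_emb_2, 16, 24, 32, 40),  # 128x192 px box at (x=320, y=256)
# ])
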
@@ -731,27 +776,30 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     negative_copy = []
 
-    control_nets = []
+    def get_models(cond):
+        models = []
+        for c in cond:
+            if 'control' in c[1]:
+                models += [c[1]['control']]
+            if 'gligen' in c[1]:
+                models += [c[1]['gligen'][1]]
+        return models
+
     for p in positive:
         t = p[0]
         if t.shape[0] < noise.shape[0]:
             t = torch.cat([t] * noise.shape[0])
         t = t.to(device)
-        if 'control' in p[1]:
-            control_nets += [p[1]['control']]
         positive_copy += [[t] + p[1:]]
     for n in negative:
         t = n[0]
         if t.shape[0] < noise.shape[0]:
             t = torch.cat([t] * noise.shape[0])
         t = t.to(device)
-        if 'control' in n[1]:
-            control_nets += [n[1]['control']]
         negative_copy += [[t] + n[1:]]
 
-    control_net_models = []
-    for x in control_nets:
-        control_net_models += x.get_control_models()
-    comfy.model_management.load_controlnet_gpu(control_net_models)
+    models = get_models(positive) + get_models(negative)
+    comfy.model_management.load_controlnet_gpu(models)
 
     if sampler_name in comfy.samplers.KSampler.SAMPLERS:
         sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
@@ -761,8 +809,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
 
     samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
     samples = samples.cpu()
-    for c in control_nets:
-        c.cleanup()
+    for m in models:
+        m.cleanup()
 
     out = latent.copy()
     out["samples"] = samples
@@ -1128,6 +1176,9 @@ NODE_CLASS_MAPPINGS = {
     "VAEEncodeTiled": VAEEncodeTiled,
     "TomePatchModel": TomePatchModel,
     "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
+    "GLIGENLoader": GLIGENLoader,
+    "GLIGENTextBoxApply": GLIGENTextBoxApply,
+
     "CheckpointLoader": CheckpointLoader,
     "DiffusersLoader": DiffusersLoader,
 }