Add support for GLIGEN textbox model.
This commit is contained in:
parent
472b1cc0d8
commit
3696d1699a
|
@ -0,0 +1,343 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn, einsum
|
||||||
|
from ldm.modules.attention import CrossAttention
|
||||||
|
from inspect import isfunction
|
||||||
|
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def uniq(arr):
|
||||||
|
return{el: True for el in arr}.keys()
|
||||||
|
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
if exists(val):
|
||||||
|
return val
|
||||||
|
return d() if isfunction(d) else d
|
||||||
|
|
||||||
|
|
||||||
|
# feedforward
|
||||||
|
class GEGLU(nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||||
|
return x * torch.nn.functional.gelu(gate)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
dim_out = default(dim_out, dim)
|
||||||
|
project_in = nn.Sequential(
|
||||||
|
nn.Linear(dim, inner_dim),
|
||||||
|
nn.GELU()
|
||||||
|
) if not glu else GEGLU(dim, inner_dim)
|
||||||
|
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
project_in,
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(inner_dim, dim_out)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GatedCrossAttentionDense(nn.Module):
|
||||||
|
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn = CrossAttention(
|
||||||
|
query_dim=query_dim,
|
||||||
|
context_dim=context_dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head)
|
||||||
|
self.ff = FeedForward(query_dim, glu=True)
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(query_dim)
|
||||||
|
self.norm2 = nn.LayerNorm(query_dim)
|
||||||
|
|
||||||
|
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||||
|
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||||
|
|
||||||
|
# this can be useful: we can externally change magnitude of tanh(alpha)
|
||||||
|
# for example, when it is set to 0, then the entire model is same as
|
||||||
|
# original one
|
||||||
|
self.scale = 1
|
||||||
|
|
||||||
|
def forward(self, x, objs):
|
||||||
|
|
||||||
|
x = x + self.scale * \
|
||||||
|
torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
|
||||||
|
x = x + self.scale * \
|
||||||
|
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GatedSelfAttentionDense(nn.Module):
|
||||||
|
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# we need a linear projection since we need cat visual feature and obj
|
||||||
|
# feature
|
||||||
|
self.linear = nn.Linear(context_dim, query_dim)
|
||||||
|
|
||||||
|
self.attn = CrossAttention(
|
||||||
|
query_dim=query_dim,
|
||||||
|
context_dim=query_dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head)
|
||||||
|
self.ff = FeedForward(query_dim, glu=True)
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(query_dim)
|
||||||
|
self.norm2 = nn.LayerNorm(query_dim)
|
||||||
|
|
||||||
|
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||||
|
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||||
|
|
||||||
|
# this can be useful: we can externally change magnitude of tanh(alpha)
|
||||||
|
# for example, when it is set to 0, then the entire model is same as
|
||||||
|
# original one
|
||||||
|
self.scale = 1
|
||||||
|
|
||||||
|
def forward(self, x, objs):
|
||||||
|
|
||||||
|
N_visual = x.shape[1]
|
||||||
|
objs = self.linear(objs)
|
||||||
|
|
||||||
|
x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
|
||||||
|
self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
|
||||||
|
x = x + self.scale * \
|
||||||
|
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GatedSelfAttentionDense2(nn.Module):
|
||||||
|
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# we need a linear projection since we need cat visual feature and obj
|
||||||
|
# feature
|
||||||
|
self.linear = nn.Linear(context_dim, query_dim)
|
||||||
|
|
||||||
|
self.attn = CrossAttention(
|
||||||
|
query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
|
||||||
|
self.ff = FeedForward(query_dim, glu=True)
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(query_dim)
|
||||||
|
self.norm2 = nn.LayerNorm(query_dim)
|
||||||
|
|
||||||
|
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||||
|
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||||
|
|
||||||
|
# this can be useful: we can externally change magnitude of tanh(alpha)
|
||||||
|
# for example, when it is set to 0, then the entire model is same as
|
||||||
|
# original one
|
||||||
|
self.scale = 1
|
||||||
|
|
||||||
|
def forward(self, x, objs):
|
||||||
|
|
||||||
|
B, N_visual, _ = x.shape
|
||||||
|
B, N_ground, _ = objs.shape
|
||||||
|
|
||||||
|
objs = self.linear(objs)
|
||||||
|
|
||||||
|
# sanity check
|
||||||
|
size_v = math.sqrt(N_visual)
|
||||||
|
size_g = math.sqrt(N_ground)
|
||||||
|
assert int(size_v) == size_v, "Visual tokens must be square rootable"
|
||||||
|
assert int(size_g) == size_g, "Grounding tokens must be square rootable"
|
||||||
|
size_v = int(size_v)
|
||||||
|
size_g = int(size_g)
|
||||||
|
|
||||||
|
# select grounding token and resize it to visual token size as residual
|
||||||
|
out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
|
||||||
|
:, N_visual:, :]
|
||||||
|
out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
|
||||||
|
out = torch.nn.functional.interpolate(
|
||||||
|
out, (size_v, size_v), mode='bicubic')
|
||||||
|
residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)
|
||||||
|
|
||||||
|
# add residual to visual feature
|
||||||
|
x = x + self.scale * torch.tanh(self.alpha_attn) * residual
|
||||||
|
x = x + self.scale * \
|
||||||
|
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FourierEmbedder():
|
||||||
|
def __init__(self, num_freqs=64, temperature=100):
|
||||||
|
|
||||||
|
self.num_freqs = num_freqs
|
||||||
|
self.temperature = temperature
|
||||||
|
self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, x, cat_dim=-1):
|
||||||
|
"x: arbitrary shape of tensor. dim: cat dim"
|
||||||
|
out = []
|
||||||
|
for freq in self.freq_bands:
|
||||||
|
out.append(torch.sin(freq * x))
|
||||||
|
out.append(torch.cos(freq * x))
|
||||||
|
return torch.cat(out, cat_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class PositionNet(nn.Module):
|
||||||
|
def __init__(self, in_dim, out_dim, fourier_freqs=8):
|
||||||
|
super().__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
|
||||||
|
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
|
||||||
|
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
|
||||||
|
|
||||||
|
self.linears = nn.Sequential(
|
||||||
|
nn.Linear(self.in_dim + self.position_dim, 512),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(512, 512),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(512, out_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.null_positive_feature = torch.nn.Parameter(
|
||||||
|
torch.zeros([self.in_dim]))
|
||||||
|
self.null_position_feature = torch.nn.Parameter(
|
||||||
|
torch.zeros([self.position_dim]))
|
||||||
|
|
||||||
|
def forward(self, boxes, masks, positive_embeddings):
|
||||||
|
B, N, _ = boxes.shape
|
||||||
|
masks = masks.unsqueeze(-1)
|
||||||
|
|
||||||
|
# embedding position (it may includes padding as placeholder)
|
||||||
|
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
|
||||||
|
|
||||||
|
# learnable null embedding
|
||||||
|
positive_null = self.null_positive_feature.view(1, 1, -1)
|
||||||
|
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
||||||
|
|
||||||
|
# replace padding with learnable null embedding
|
||||||
|
positive_embeddings = positive_embeddings * \
|
||||||
|
masks + (1 - masks) * positive_null
|
||||||
|
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
|
||||||
|
|
||||||
|
objs = self.linears(
|
||||||
|
torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
|
||||||
|
assert objs.shape == torch.Size([B, N, self.out_dim])
|
||||||
|
return objs
|
||||||
|
|
||||||
|
|
||||||
|
class Gligen(nn.Module):
|
||||||
|
def __init__(self, modules, position_net, key_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.module_list = nn.ModuleList(modules)
|
||||||
|
self.position_net = position_net
|
||||||
|
self.key_dim = key_dim
|
||||||
|
self.max_objs = 30
|
||||||
|
|
||||||
|
def _set_position(self, boxes, masks, positive_embeddings):
|
||||||
|
objs = self.position_net(boxes, masks, positive_embeddings)
|
||||||
|
|
||||||
|
def func(key, x):
|
||||||
|
module = self.module_list[key]
|
||||||
|
return module(x, objs)
|
||||||
|
return func
|
||||||
|
|
||||||
|
def set_position(self, latent_image_shape, position_params, device):
|
||||||
|
batch, c, h, w = latent_image_shape
|
||||||
|
masks = torch.zeros([self.max_objs], device="cpu")
|
||||||
|
boxes = []
|
||||||
|
positive_embeddings = []
|
||||||
|
for p in position_params:
|
||||||
|
x1 = (p[4]) / w
|
||||||
|
y1 = (p[3]) / h
|
||||||
|
x2 = (p[4] + p[2]) / w
|
||||||
|
y2 = (p[3] + p[1]) / h
|
||||||
|
masks[len(boxes)] = 1.0
|
||||||
|
boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
|
||||||
|
positive_embeddings += [p[0]]
|
||||||
|
append_boxes = []
|
||||||
|
append_conds = []
|
||||||
|
if len(boxes) < self.max_objs:
|
||||||
|
append_boxes = [torch.zeros(
|
||||||
|
[self.max_objs - len(boxes), 4], device="cpu")]
|
||||||
|
append_conds = [torch.zeros(
|
||||||
|
[self.max_objs - len(boxes), self.key_dim], device="cpu")]
|
||||||
|
|
||||||
|
box_out = torch.cat(
|
||||||
|
boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
|
||||||
|
masks = masks.unsqueeze(0).repeat(batch, 1)
|
||||||
|
conds = torch.cat(positive_embeddings +
|
||||||
|
append_conds).unsqueeze(0).repeat(batch, 1, 1)
|
||||||
|
return self._set_position(
|
||||||
|
box_out.to(device),
|
||||||
|
masks.to(device),
|
||||||
|
conds.to(device))
|
||||||
|
|
||||||
|
def set_empty(self, latent_image_shape, device):
|
||||||
|
batch, c, h, w = latent_image_shape
|
||||||
|
masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
|
||||||
|
box_out = torch.zeros([self.max_objs, 4],
|
||||||
|
device="cpu").repeat(batch, 1, 1)
|
||||||
|
conds = torch.zeros([self.max_objs, self.key_dim],
|
||||||
|
device="cpu").repeat(batch, 1, 1)
|
||||||
|
return self._set_position(
|
||||||
|
box_out.to(device),
|
||||||
|
masks.to(device),
|
||||||
|
conds.to(device))
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_models(self):
|
||||||
|
return [self]
|
||||||
|
|
||||||
|
def load_gligen(sd):
|
||||||
|
sd_k = sd.keys()
|
||||||
|
output_list = []
|
||||||
|
key_dim = 768
|
||||||
|
for a in ["input_blocks", "middle_block", "output_blocks"]:
|
||||||
|
for b in range(20):
|
||||||
|
k_temp = filter(lambda k: "{}.{}.".format(a, b)
|
||||||
|
in k and ".fuser." in k, sd_k)
|
||||||
|
k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)
|
||||||
|
|
||||||
|
n_sd = {}
|
||||||
|
for k in k_temp:
|
||||||
|
n_sd[k[1]] = sd[k[0]]
|
||||||
|
if len(n_sd) > 0:
|
||||||
|
query_dim = n_sd["linear.weight"].shape[0]
|
||||||
|
key_dim = n_sd["linear.weight"].shape[1]
|
||||||
|
|
||||||
|
if key_dim == 768: # SD1.x
|
||||||
|
n_heads = 8
|
||||||
|
d_head = query_dim // n_heads
|
||||||
|
else:
|
||||||
|
d_head = 64
|
||||||
|
n_heads = query_dim // d_head
|
||||||
|
|
||||||
|
gated = GatedSelfAttentionDense(
|
||||||
|
query_dim, key_dim, n_heads, d_head)
|
||||||
|
gated.load_state_dict(n_sd, strict=False)
|
||||||
|
output_list.append(gated)
|
||||||
|
|
||||||
|
if "position_net.null_positive_feature" in sd_k:
|
||||||
|
in_dim = sd["position_net.null_positive_feature"].shape[0]
|
||||||
|
out_dim = sd["position_net.linears.4.weight"].shape[0]
|
||||||
|
|
||||||
|
class WeightsLoader(torch.nn.Module):
|
||||||
|
pass
|
||||||
|
w = WeightsLoader()
|
||||||
|
w.position_net = PositionNet(in_dim, out_dim)
|
||||||
|
w.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
|
gligen = Gligen(output_list, w.position_net, key_dim)
|
||||||
|
return gligen
|
|
@ -510,6 +510,14 @@ class BasicTransformerBlock(nn.Module):
|
||||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||||
|
|
||||||
def _forward(self, x, context=None, transformer_options={}):
|
def _forward(self, x, context=None, transformer_options={}):
|
||||||
|
current_index = None
|
||||||
|
if "current_index" in transformer_options:
|
||||||
|
current_index = transformer_options["current_index"]
|
||||||
|
if "patches" in transformer_options:
|
||||||
|
transformer_patches = transformer_options["patches"]
|
||||||
|
else:
|
||||||
|
transformer_patches = {}
|
||||||
|
|
||||||
n = self.norm1(x)
|
n = self.norm1(x)
|
||||||
if "tomesd" in transformer_options:
|
if "tomesd" in transformer_options:
|
||||||
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
||||||
|
@ -518,11 +526,19 @@ class BasicTransformerBlock(nn.Module):
|
||||||
n = self.attn1(n, context=context if self.disable_self_attn else None)
|
n = self.attn1(n, context=context if self.disable_self_attn else None)
|
||||||
|
|
||||||
x += n
|
x += n
|
||||||
|
if "middle_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["middle_patch"]
|
||||||
|
for p in patch:
|
||||||
|
x = p(current_index, x)
|
||||||
|
|
||||||
n = self.norm2(x)
|
n = self.norm2(x)
|
||||||
n = self.attn2(n, context=context)
|
n = self.attn2(n, context=context)
|
||||||
|
|
||||||
x += n
|
x += n
|
||||||
x = self.ff(self.norm3(x)) + x
|
x = self.ff(self.norm3(x)) + x
|
||||||
|
|
||||||
|
if current_index is not None:
|
||||||
|
transformer_options["current_index"] += 1
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -782,6 +782,8 @@ class UNetModel(nn.Module):
|
||||||
:return: an [N x C x ...] Tensor of outputs.
|
:return: an [N x C x ...] Tensor of outputs.
|
||||||
"""
|
"""
|
||||||
transformer_options["original_shape"] = list(x.shape)
|
transformer_options["original_shape"] = list(x.shape)
|
||||||
|
transformer_options["current_index"] = 0
|
||||||
|
|
||||||
assert (y is not None) == (
|
assert (y is not None) == (
|
||||||
self.num_classes is not None
|
self.num_classes is not None
|
||||||
), "must specify y if and only if the model is class-conditional"
|
), "must specify y if and only if the model is class-conditional"
|
||||||
|
|
|
@ -176,7 +176,7 @@ def load_model_gpu(model):
|
||||||
model_accelerated = True
|
model_accelerated = True
|
||||||
return current_loaded_model
|
return current_loaded_model
|
||||||
|
|
||||||
def load_controlnet_gpu(models):
|
def load_controlnet_gpu(control_models):
|
||||||
global current_gpu_controlnets
|
global current_gpu_controlnets
|
||||||
global vram_state
|
global vram_state
|
||||||
if vram_state == VRAMState.CPU:
|
if vram_state == VRAMState.CPU:
|
||||||
|
@ -186,6 +186,10 @@ def load_controlnet_gpu(models):
|
||||||
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
|
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
|
||||||
return
|
return
|
||||||
|
|
||||||
|
models = []
|
||||||
|
for m in control_models:
|
||||||
|
models += m.get_models()
|
||||||
|
|
||||||
for m in current_gpu_controlnets:
|
for m in current_gpu_controlnets:
|
||||||
if m not in models:
|
if m not in models:
|
||||||
m.cpu()
|
m.cpu()
|
||||||
|
|
|
@ -70,7 +70,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||||
control = None
|
control = None
|
||||||
if 'control' in cond[1]:
|
if 'control' in cond[1]:
|
||||||
control = cond[1]['control']
|
control = cond[1]['control']
|
||||||
return (input_x, mult, conditionning, area, control)
|
|
||||||
|
patches = None
|
||||||
|
if 'gligen' in cond[1]:
|
||||||
|
gligen = cond[1]['gligen']
|
||||||
|
patches = {}
|
||||||
|
gligen_type = gligen[0]
|
||||||
|
gligen_model = gligen[1]
|
||||||
|
if gligen_type == "position":
|
||||||
|
gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
|
||||||
|
else:
|
||||||
|
gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
|
||||||
|
|
||||||
|
patches['middle_patch'] = [gligen_patch]
|
||||||
|
|
||||||
|
return (input_x, mult, conditionning, area, control, patches)
|
||||||
|
|
||||||
def cond_equal_size(c1, c2):
|
def cond_equal_size(c1, c2):
|
||||||
if c1 is c2:
|
if c1 is c2:
|
||||||
|
@ -91,12 +105,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||||
def can_concat_cond(c1, c2):
|
def can_concat_cond(c1, c2):
|
||||||
if c1[0].shape != c2[0].shape:
|
if c1[0].shape != c2[0].shape:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
#control
|
||||||
if (c1[4] is None) != (c2[4] is None):
|
if (c1[4] is None) != (c2[4] is None):
|
||||||
return False
|
return False
|
||||||
if c1[4] is not None:
|
if c1[4] is not None:
|
||||||
if c1[4] is not c2[4]:
|
if c1[4] is not c2[4]:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
#patches
|
||||||
|
if (c1[5] is None) != (c2[5] is None):
|
||||||
|
return False
|
||||||
|
if (c1[5] is not None):
|
||||||
|
if c1[5] is not c2[5]:
|
||||||
|
return False
|
||||||
|
|
||||||
return cond_equal_size(c1[2], c2[2])
|
return cond_equal_size(c1[2], c2[2])
|
||||||
|
|
||||||
def cond_cat(c_list):
|
def cond_cat(c_list):
|
||||||
|
@ -166,6 +189,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||||
cond_or_uncond = []
|
cond_or_uncond = []
|
||||||
area = []
|
area = []
|
||||||
control = None
|
control = None
|
||||||
|
patches = None
|
||||||
for x in to_batch:
|
for x in to_batch:
|
||||||
o = to_run.pop(x)
|
o = to_run.pop(x)
|
||||||
p = o[0]
|
p = o[0]
|
||||||
|
@ -175,6 +199,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||||
area += [p[3]]
|
area += [p[3]]
|
||||||
cond_or_uncond += [o[1]]
|
cond_or_uncond += [o[1]]
|
||||||
control = p[4]
|
control = p[4]
|
||||||
|
patches = p[5]
|
||||||
|
|
||||||
batch_chunks = len(cond_or_uncond)
|
batch_chunks = len(cond_or_uncond)
|
||||||
input_x = torch.cat(input_x)
|
input_x = torch.cat(input_x)
|
||||||
|
@ -184,8 +209,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||||
if control is not None:
|
if control is not None:
|
||||||
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
|
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
|
||||||
|
|
||||||
|
transformer_options = {}
|
||||||
if 'transformer_options' in model_options:
|
if 'transformer_options' in model_options:
|
||||||
c['transformer_options'] = model_options['transformer_options']
|
transformer_options = model_options['transformer_options'].copy()
|
||||||
|
|
||||||
|
if patches is not None:
|
||||||
|
transformer_options["patches"] = patches
|
||||||
|
|
||||||
|
c['transformer_options'] = transformer_options
|
||||||
|
|
||||||
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
||||||
del input_x
|
del input_x
|
||||||
|
@ -309,8 +340,7 @@ def create_cond_with_same_area_if_none(conds, c):
|
||||||
n = c[1].copy()
|
n = c[1].copy()
|
||||||
conds += [[smallest[0], n]]
|
conds += [[smallest[0], n]]
|
||||||
|
|
||||||
|
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||||
def apply_control_net_to_equal_area(conds, uncond):
|
|
||||||
cond_cnets = []
|
cond_cnets = []
|
||||||
cond_other = []
|
cond_other = []
|
||||||
uncond_cnets = []
|
uncond_cnets = []
|
||||||
|
@ -318,15 +348,15 @@ def apply_control_net_to_equal_area(conds, uncond):
|
||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
if 'area' not in x[1]:
|
if 'area' not in x[1]:
|
||||||
if 'control' in x[1] and x[1]['control'] is not None:
|
if name in x[1] and x[1][name] is not None:
|
||||||
cond_cnets.append(x[1]['control'])
|
cond_cnets.append(x[1][name])
|
||||||
else:
|
else:
|
||||||
cond_other.append((x, t))
|
cond_other.append((x, t))
|
||||||
for t in range(len(uncond)):
|
for t in range(len(uncond)):
|
||||||
x = uncond[t]
|
x = uncond[t]
|
||||||
if 'area' not in x[1]:
|
if 'area' not in x[1]:
|
||||||
if 'control' in x[1] and x[1]['control'] is not None:
|
if name in x[1] and x[1][name] is not None:
|
||||||
uncond_cnets.append(x[1]['control'])
|
uncond_cnets.append(x[1][name])
|
||||||
else:
|
else:
|
||||||
uncond_other.append((x, t))
|
uncond_other.append((x, t))
|
||||||
|
|
||||||
|
@ -336,15 +366,16 @@ def apply_control_net_to_equal_area(conds, uncond):
|
||||||
for x in range(len(cond_cnets)):
|
for x in range(len(cond_cnets)):
|
||||||
temp = uncond_other[x % len(uncond_other)]
|
temp = uncond_other[x % len(uncond_other)]
|
||||||
o = temp[0]
|
o = temp[0]
|
||||||
if 'control' in o[1] and o[1]['control'] is not None:
|
if name in o[1] and o[1][name] is not None:
|
||||||
n = o[1].copy()
|
n = o[1].copy()
|
||||||
n['control'] = cond_cnets[x]
|
n[name] = uncond_fill_func(cond_cnets, x)
|
||||||
uncond += [[o[0], n]]
|
uncond += [[o[0], n]]
|
||||||
else:
|
else:
|
||||||
n = o[1].copy()
|
n = o[1].copy()
|
||||||
n['control'] = cond_cnets[x]
|
n[name] = uncond_fill_func(cond_cnets, x)
|
||||||
uncond[temp[1]] = [o[0], n]
|
uncond[temp[1]] = [o[0], n]
|
||||||
|
|
||||||
|
|
||||||
def encode_adm(noise_augmentor, conds, batch_size, device):
|
def encode_adm(noise_augmentor, conds, batch_size, device):
|
||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
|
@ -378,6 +409,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
|
||||||
|
|
||||||
return conds
|
return conds
|
||||||
|
|
||||||
|
|
||||||
class KSampler:
|
class KSampler:
|
||||||
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
||||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||||
|
@ -466,7 +498,8 @@ class KSampler:
|
||||||
for c in negative:
|
for c in negative:
|
||||||
create_cond_with_same_area_if_none(positive, c)
|
create_cond_with_same_area_if_none(positive, c)
|
||||||
|
|
||||||
apply_control_net_to_equal_area(positive, negative)
|
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
|
||||||
if self.model.model.diffusion_model.dtype == torch.float16:
|
if self.model.model.diffusion_model.dtype == torch.float16:
|
||||||
precision_scope = torch.autocast
|
precision_scope = torch.autocast
|
||||||
|
|
22
comfy/sd.py
22
comfy/sd.py
|
@ -13,6 +13,7 @@ from .t2i_adapter import adapter
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
from . import clip_vision
|
from . import clip_vision
|
||||||
|
from . import gligen
|
||||||
|
|
||||||
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
|
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
|
@ -378,7 +379,7 @@ class CLIP:
|
||||||
def tokenize(self, text, return_word_ids=False):
|
def tokenize(self, text, return_word_ids=False):
|
||||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||||
|
|
||||||
def encode_from_tokens(self, tokens):
|
def encode_from_tokens(self, tokens, return_pooled=False):
|
||||||
if self.layer_idx is not None:
|
if self.layer_idx is not None:
|
||||||
self.cond_stage_model.clip_layer(self.layer_idx)
|
self.cond_stage_model.clip_layer(self.layer_idx)
|
||||||
try:
|
try:
|
||||||
|
@ -388,6 +389,10 @@ class CLIP:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.patcher.unpatch_model()
|
self.patcher.unpatch_model()
|
||||||
raise e
|
raise e
|
||||||
|
if return_pooled:
|
||||||
|
eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__)
|
||||||
|
pooled = cond[:, eos_token_index]
|
||||||
|
return cond, pooled
|
||||||
return cond
|
return cond
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
|
@ -564,10 +569,10 @@ class ControlNet:
|
||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def get_control_models(self):
|
def get_models(self):
|
||||||
out = []
|
out = []
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
out += self.previous_controlnet.get_control_models()
|
out += self.previous_controlnet.get_models()
|
||||||
out.append(self.control_model)
|
out.append(self.control_model)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -737,10 +742,10 @@ class T2IAdapter:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
|
|
||||||
def get_control_models(self):
|
def get_models(self):
|
||||||
out = []
|
out = []
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
out += self.previous_controlnet.get_control_models()
|
out += self.previous_controlnet.get_models()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data):
|
def load_t2i_adapter(t2i_data):
|
||||||
|
@ -787,6 +792,13 @@ def load_clip(ckpt_path, embedding_directory=None):
|
||||||
clip.load_from_state_dict(clip_data)
|
clip.load_from_state_dict(clip_data)
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
|
def load_gligen(ckpt_path):
|
||||||
|
data = utils.load_torch_file(ckpt_path)
|
||||||
|
model = gligen.load_gligen(data)
|
||||||
|
if model_management.should_use_fp16():
|
||||||
|
model = model.half()
|
||||||
|
return model
|
||||||
|
|
||||||
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
|
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
|
||||||
with open(config_path, 'r') as stream:
|
with open(config_path, 'r') as stream:
|
||||||
config = yaml.safe_load(stream)
|
config = yaml.safe_load(stream)
|
||||||
|
|
|
@ -26,6 +26,8 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")]
|
||||||
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
|
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
|
||||||
|
|
||||||
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
|
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
|
||||||
|
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
|
||||||
|
|
||||||
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
|
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
|
||||||
|
|
||||||
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
|
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
|
||||||
|
|
71
nodes.py
71
nodes.py
|
@ -490,6 +490,51 @@ class unCLIPConditioning:
|
||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
|
class GLIGENLoader:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("GLIGEN",)
|
||||||
|
FUNCTION = "load_gligen"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/gligen"
|
||||||
|
|
||||||
|
def load_gligen(self, gligen_name):
|
||||||
|
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
|
||||||
|
gligen = comfy.sd.load_gligen(gligen_path)
|
||||||
|
return (gligen,)
|
||||||
|
|
||||||
|
class GLIGENTextBoxApply:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"conditioning_to": ("CONDITIONING", ),
|
||||||
|
"clip": ("CLIP", ),
|
||||||
|
"gligen_textbox_model": ("GLIGEN", ),
|
||||||
|
"text": ("STRING", {"multiline": True}),
|
||||||
|
"width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
|
||||||
|
"height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
|
||||||
|
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||||
|
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
FUNCTION = "append"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/gligen"
|
||||||
|
|
||||||
|
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
|
||||||
|
c = []
|
||||||
|
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
|
||||||
|
for t in conditioning_to:
|
||||||
|
n = [t[0], t[1].copy()]
|
||||||
|
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
|
||||||
|
prev = []
|
||||||
|
if "gligen" in n[1]:
|
||||||
|
prev = n[1]['gligen'][2]
|
||||||
|
|
||||||
|
n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
|
||||||
|
c.append(n)
|
||||||
|
return (c, )
|
||||||
|
|
||||||
class EmptyLatentImage:
|
class EmptyLatentImage:
|
||||||
def __init__(self, device="cpu"):
|
def __init__(self, device="cpu"):
|
||||||
|
@ -731,27 +776,30 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||||
negative_copy = []
|
negative_copy = []
|
||||||
|
|
||||||
control_nets = []
|
control_nets = []
|
||||||
|
def get_models(cond):
|
||||||
|
models = []
|
||||||
|
for c in cond:
|
||||||
|
if 'control' in c[1]:
|
||||||
|
models += [c[1]['control']]
|
||||||
|
if 'gligen' in c[1]:
|
||||||
|
models += [c[1]['gligen'][1]]
|
||||||
|
return models
|
||||||
|
|
||||||
for p in positive:
|
for p in positive:
|
||||||
t = p[0]
|
t = p[0]
|
||||||
if t.shape[0] < noise.shape[0]:
|
if t.shape[0] < noise.shape[0]:
|
||||||
t = torch.cat([t] * noise.shape[0])
|
t = torch.cat([t] * noise.shape[0])
|
||||||
t = t.to(device)
|
t = t.to(device)
|
||||||
if 'control' in p[1]:
|
|
||||||
control_nets += [p[1]['control']]
|
|
||||||
positive_copy += [[t] + p[1:]]
|
positive_copy += [[t] + p[1:]]
|
||||||
for n in negative:
|
for n in negative:
|
||||||
t = n[0]
|
t = n[0]
|
||||||
if t.shape[0] < noise.shape[0]:
|
if t.shape[0] < noise.shape[0]:
|
||||||
t = torch.cat([t] * noise.shape[0])
|
t = torch.cat([t] * noise.shape[0])
|
||||||
t = t.to(device)
|
t = t.to(device)
|
||||||
if 'control' in n[1]:
|
|
||||||
control_nets += [n[1]['control']]
|
|
||||||
negative_copy += [[t] + n[1:]]
|
negative_copy += [[t] + n[1:]]
|
||||||
|
|
||||||
control_net_models = []
|
models = get_models(positive) + get_models(negative)
|
||||||
for x in control_nets:
|
comfy.model_management.load_controlnet_gpu(models)
|
||||||
control_net_models += x.get_control_models()
|
|
||||||
comfy.model_management.load_controlnet_gpu(control_net_models)
|
|
||||||
|
|
||||||
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
||||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||||
|
@ -761,8 +809,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||||
|
|
||||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
|
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
|
||||||
samples = samples.cpu()
|
samples = samples.cpu()
|
||||||
for c in control_nets:
|
for m in models:
|
||||||
c.cleanup()
|
m.cleanup()
|
||||||
|
|
||||||
out = latent.copy()
|
out = latent.copy()
|
||||||
out["samples"] = samples
|
out["samples"] = samples
|
||||||
|
@ -1128,6 +1176,9 @@ NODE_CLASS_MAPPINGS = {
|
||||||
"VAEEncodeTiled": VAEEncodeTiled,
|
"VAEEncodeTiled": VAEEncodeTiled,
|
||||||
"TomePatchModel": TomePatchModel,
|
"TomePatchModel": TomePatchModel,
|
||||||
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
|
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
|
||||||
|
"GLIGENLoader": GLIGENLoader,
|
||||||
|
"GLIGENTextBoxApply": GLIGENTextBoxApply,
|
||||||
|
|
||||||
"CheckpointLoader": CheckpointLoader,
|
"CheckpointLoader": CheckpointLoader,
|
||||||
"DiffusersLoader": DiffusersLoader,
|
"DiffusersLoader": DiffusersLoader,
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue