2024-08-13 01:22:22 +00:00
|
|
|
# Original code can be found at: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
|
|
|
|
|
|
|
import torch
|
2024-08-28 22:56:33 +00:00
|
|
|
import math
|
2024-08-13 01:22:22 +00:00
|
|
|
from torch import Tensor, nn
|
|
|
|
from einops import rearrange, repeat
|
|
|
|
|
|
|
|
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|
|
|
MLPEmbedder, SingleStreamBlock,
|
|
|
|
timestep_embedding)
|
|
|
|
|
|
|
|
from .model import Flux
|
|
|
|
import comfy.ldm.common_dit
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNetFlux(Flux):
    """ControlNet built on the Flux transformer.

    Runs its own stack of Flux double/single stream blocks over the patchified
    latent plus an embedded control hint, projects the hidden state after every
    block through a per-block Linear layer, and returns those projections as
    residuals stretched to the block counts of the main model
    (``main_model_double`` / ``main_model_single``, hard-coded below).
    """

    def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
        """Build the ControlNet.

        Args:
            latent_input: if True the hint is expected to already be a latent and is
                only patchified; otherwise an image hint is encoded by
                ``input_hint_block`` first.
            image_model: accepted for interface compatibility; not used here.
            dtype, device, operations: forwarded to the layer constructors.
            **kwargs: forwarded to ``Flux.__init__`` (which supplies ``params``,
                ``hidden_size``, ``in_channels`` and the transformer blocks).
        """
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        # Number of double/single blocks in the main model that consumes the residuals.
        self.main_model_double = 19
        self.main_model_single = 38

        # One Linear projection per ControlNet double-stream block.
        self.controlnet_blocks = nn.ModuleList([
            operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            for _ in range(self.params.depth)
        ])

        # One Linear projection per ControlNet single-stream block.
        self.controlnet_single_blocks = nn.ModuleList([
            operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            for _ in range(self.params.depth_single_blocks)
        ])

        self.gradient_checkpointing = False
        self.latent_input = latent_input
        # Projects the (patchified) hint tokens into the transformer width.
        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        if not self.latent_input:
            # Image-hint encoder: 3 -> 16 channels with three stride-2 convs
            # (spatial /8). Layer order must stay fixed: it defines checkpoint keys.
            self.input_hint_block = nn.Sequential(
                operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
            )

    def _expand_to_main_model(self, residuals, target_count):
        """Stretch ``residuals`` to ``target_count`` entries, one per main-model block.

        Latent-input ControlNets repeat each residual consecutively (a, a, b, b, ...);
        image-input ControlNets tile the whole sequence (a, b, a, b, ...). The result
        is truncated to ``target_count``. An empty ``residuals`` yields ``()`` instead
        of dividing by zero.

        Returns:
            A tuple of residuals.
        """
        if len(residuals) == 0:
            return ()
        repeat_count = math.ceil(target_count / len(residuals))
        if self.latent_input:
            expanded = tuple(r for r in residuals for _ in range(repeat_count))
        else:
            expanded = tuple(residuals) * repeat_count
        return expanded[:target_count]

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
    ) -> dict:
        """Run the ControlNet on already-patchified tokens.

        Args:
            img: patchified latent tokens; must be 3-D.
            img_ids, txt_ids: positional ids, concatenated (txt first) for ``pe_embedder``.
            controlnet_cond: the hint; an image when ``latent_input`` is False,
                otherwise already patchified tokens.
            txt: text tokens; must be 3-D.
            timesteps, y, guidance: conditioning embedded into ``vec``.

        Returns:
            ``{"input": <double residuals>}``, plus ``"output": <single residuals>``
            when the ControlNet has single blocks. Both are tuples stretched to the
            main model's block counts.

        Raises:
            ValueError: if ``img``/``txt`` are not 3-D, or if the model embeds
                guidance but ``guidance`` is None.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)

        if not self.latent_input:
            # Encode the image hint, then patchify 2x2 into tokens.
            controlnet_cond = self.input_hint_block(controlnet_cond)
            controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

        controlnet_cond = self.pos_embed_input(controlnet_cond)
        img = img + controlnet_cond

        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        controlnet_double = []
        for i, block in enumerate(self.double_blocks):
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
            controlnet_double.append(self.controlnet_blocks[i](img))

        img = torch.cat((txt, img), 1)
        txt_len = txt.shape[1]

        controlnet_single = []
        for i, block in enumerate(self.single_blocks):
            img = block(img, vec=vec, pe=pe)
            # Only the image tokens (after the txt prefix) feed the residual.
            controlnet_single.append(self.controlnet_single_blocks[i](img[:, txt_len:, ...]))

        out = {"input": self._expand_to_main_model(controlnet_double, self.main_model_double)}
        if len(controlnet_single) > 0:
            out["output"] = self._expand_to_main_model(controlnet_single, self.main_model_single)
        return out

    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
        """Patchify the inputs, build positional ids, and run ``forward_orig``.

        Args:
            x: latent with 4 dims (batch, channels, height, width), padded to a
                multiple of the patch size.
            timesteps, y, guidance: forwarded to ``forward_orig``.
            context: text tokens; ``context.shape[1]`` is used as the text length.
            hint: control hint. When ``latent_input`` is False it is rescaled with
                ``hint * 2.0 - 1.0`` (presumably from [0, 1] to [-1, 1] -- confirm
                against the caller).
            **kwargs: accepted and ignored.

        Returns:
            The residual dict produced by ``forward_orig``.

        Raises:
            ValueError: if ``hint`` is None.
        """
        if hint is None:
            raise ValueError("ControlNetFlux requires a hint tensor.")

        patch_size = 2
        if self.latent_input:
            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
            hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
        else:
            hint = hint * 2.0 - 1.0

        bs, _, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        # Token-grid size after padding (ceil division by the patch size).
        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        # Channel 1 holds the row index, channel 2 the column index; channel 0 stays zero.
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
|