diff --git a/README.md b/README.md
index 9a1eb194..3d8e339b 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
 - [Area Composition](https://comfyanonymous.github.io/ComfyUI_examples/area_composition/)
 - [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
-- [ControlNet](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
+- [ControlNet](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/) and T2I-Adapter
 - Starts up very fast.
 - Works fully offline: will never download anything.
diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py
index c75830ae..8d14a690 100644
--- a/comfy/cldm/cldm.py
+++ b/comfy/cldm/cldm.py
@@ -1,7 +1,6 @@
 #taken from: https://github.com/lllyasviel/ControlNet
 #and modified
 
-import einops
 import torch
 import torch as th
 import torch.nn as nn
@@ -13,8 +12,6 @@ from ldm.modules.diffusionmodules.util import (
     timestep_embedding,
 )
 
-from einops import rearrange, repeat
-from torchvision.utils import make_grid
 from ldm.modules.attention import SpatialTransformer
 from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
 from ldm.models.diffusion.ddpm import LatentDiffusion
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 9054a1c2..9a652c29 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -774,17 +774,23 @@ class UNetModel(nn.Module):
             emb = emb + self.label_emb(y)
 
         h = x.type(self.dtype)
-        for module in self.input_blocks:
+        for id, module in enumerate(self.input_blocks):
             h = module(h, emb, context)
+            if control is not None and 'input' in control and len(control['input']) > 0:
+                ctrl = control['input'].pop()
+                if ctrl is not None:
+                    h += ctrl
             hs.append(h)
         h = self.middle_block(h, emb, context)
-        if control is not None:
-            h += control.pop()
+        if control is not None and 'middle' in control and len(control['middle']) > 0:
+            h += control['middle'].pop()
 
         for module in self.output_blocks:
             hsp = hs.pop()
-            if control is not None:
-                hsp += control.pop()
+            if control is not None and 'output' in control and len(control['output']) > 0:
+                ctrl = control['output'].pop()
+                if ctrl is not None:
+                    hsp += ctrl
             h = th.cat([h, hsp], dim=1)
             del hsp
             h = module(h, emb, context)
diff --git a/comfy/sd.py b/comfy/sd.py
index fe60205d..42860710 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -8,6 +8,7 @@ from ldm.util import instantiate_from_config
 from ldm.models.autoencoder import AutoencoderKL
 from omegaconf import OmegaConf
 from .cldm import cldm
+from .t2i_adapter import adapter
 from . import utils
@@ -318,6 +319,37 @@ class VAE:
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples
 
+    def decode_tiled(self, samples):
+        tile_x = tile_y = 64
+        overlap = 8
+        model_management.unload_model()
+        output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
+        self.first_stage_model = self.first_stage_model.to(self.device)
+        for b in range(samples.shape[0]):
+            s = samples[b:b+1]
+            out = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
+            out_div = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
+            for y in range(0, s.shape[2], tile_y - overlap):
+                for x in range(0, s.shape[3], tile_x - overlap):
+                    s_in = s[:,:,y:y+tile_y,x:x+tile_x]
+
+                    pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * s_in.to(self.device))
+                    pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+                    ps = pixel_samples.cpu()
+                    mask = torch.ones_like(ps)
+                    feather = overlap * 8
+                    for t in range(feather):
+                        mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
+                        mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
+                        mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
+                        mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
+                    out[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += ps * mask
+                    out_div[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += mask
+
+            output[b:b+1] = out/out_div
+        self.first_stage_model = self.first_stage_model.cpu()
+        return output.movedim(1,-1)
+
     def encode(self, pixel_samples):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
@@ -357,18 +389,28 @@ class ControlNet:
             self.control_model = model_management.load_if_low_vram(self.control_model)
             control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
             self.control_model = model_management.unload_if_low_vram(self.control_model)
 
-        out = []
+        out = {'middle':[], 'output': []}
         autocast_enabled = torch.is_autocast_enabled()
         for i in range(len(control)):
+            if i == (len(control) - 1):
+                key = 'middle'
+                index = 0
+            else:
+                key = 'output'
+                index = i
             x = control[i]
             x *= self.strength
             if x.dtype != output_dtype and not autocast_enabled:
                 x = x.to(output_dtype)
 
-            if control_prev is not None:
-                x += control_prev[i]
-            out.append(x)
+            if control_prev is not None and key in control_prev:
+                prev = control_prev[key][index]
+                if prev is not None:
+                    x += prev
+            out[key].append(x)
+        if control_prev is not None and 'input' in control_prev:
+            out['input'] = control_prev['input']
         return out
 
     def set_cond_hint(self, cond_hint, strength=1.0):
@@ -463,6 +505,95 @@ def load_controlnet(ckpt_path, model=None):
     control = ControlNet(control_model)
     return control
 
+class T2IAdapter:
+    def __init__(self, t2i_model, channels_in, device="cuda"):
+        self.t2i_model = t2i_model
+        self.channels_in = channels_in
+        self.strength = 1.0
+        self.device = device
+        self.previous_controlnet = None
+        self.control_input = None
+        self.cond_hint_original = None
+        self.cond_hint = None
+
+    def get_control(self, x_noisy, t, cond_txt):
+        control_prev = None
+        if self.previous_controlnet is not None:
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt)
+
+        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
+            if self.cond_hint is not None:
+                del self.cond_hint
+                self.cond_hint = None
+            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
+            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
+                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
+            self.t2i_model.to(self.device)
+            self.control_input = self.t2i_model(self.cond_hint)
+            self.t2i_model.cpu()
+
+        output_dtype = x_noisy.dtype
+        autocast_enabled = torch.is_autocast_enabled()
+        out = {'input':[]}
+
+        for i in range(len(self.control_input)):
+            key = 'input'
+            x = self.control_input[i] * self.strength
+            if x.dtype != output_dtype and not autocast_enabled:
+                x = x.to(output_dtype)
+
+            if control_prev is not None and key in control_prev:
+                index = len(control_prev[key]) - i * 3 - 3
+                prev = control_prev[key][index]
+                if prev is not None:
+                    x += prev
+            out[key].insert(0, None)
+            out[key].insert(0, None)
+            out[key].insert(0, x)
+
+        if control_prev is not None and 'input' in control_prev:
+            for i in range(len(out['input'])):
+                if out['input'][i] is None:
+                    out['input'][i] = control_prev['input'][i]
+        if control_prev is not None and 'middle' in control_prev:
+            out['middle'] = control_prev['middle']
+        if control_prev is not None and 'output' in control_prev:
+            out['output'] = control_prev['output']
+        return out
+
+    def set_cond_hint(self, cond_hint, strength=1.0):
+        self.cond_hint_original = cond_hint
+        self.strength = strength
+        return self
+
+    def set_previous_controlnet(self, controlnet):
+        self.previous_controlnet = controlnet
+        return self
+
+    def copy(self):
+        c = T2IAdapter(self.t2i_model, self.channels_in)
+        c.cond_hint_original = self.cond_hint_original
+        c.strength = self.strength
+        return c
+
+    def cleanup(self):
+        if self.previous_controlnet is not None:
+            self.previous_controlnet.cleanup()
+        if self.cond_hint is not None:
+            del self.cond_hint
+            self.cond_hint = None
+
+    def get_control_models(self):
+        out = []
+        if self.previous_controlnet is not None:
+            out += self.previous_controlnet.get_control_models()
+        return out
+
+def load_t2i_adapter(ckpt_path, model=None):
+    t2i_data = load_torch_file(ckpt_path)
+    cin = t2i_data['conv_in.weight'].shape[1]
+    model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
+    model_ad.load_state_dict(t2i_data)
+    return T2IAdapter(model_ad, cin // 64)
 
 def load_clip(ckpt_path, embedding_directory=None):
     clip_data = load_torch_file(ckpt_path)
diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py
new file mode 100644
index 00000000..d059ba91
--- /dev/null
+++ b/comfy/t2i_adapter/adapter.py
@@ -0,0 +1,125 @@
+#taken from https://github.com/TencentARC/T2I-Adapter
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+class Downsample(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 downsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        stride = 2 if dims != 3 else (1, 2, 2)
+        if use_conv:
+            self.op = conv_nd(
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+            )
+        else:
+            assert self.channels == self.out_channels
+            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.op(x)
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
+        super().__init__()
+        ps = ksize//2
+        if in_c != out_c or sk==False:
+            self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
+        else:
+            # print('n_in')
+            self.in_conv = None
+        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
+        self.act = nn.ReLU()
+        self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
+        if sk==False:
+            self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
+        else:
+            self.skep = None
+
+        self.down = down
+        if self.down == True:
+            self.down_opt = Downsample(in_c, use_conv=use_conv)
+
+    def forward(self, x):
+        if self.down == True:
+            x = self.down_opt(x)
+        if self.in_conv is not None:  # edit
+            x = self.in_conv(x)
+
+        h = self.block1(x)
+        h = self.act(h)
+        h = self.block2(h)
+        if self.skep is not None:
+            return h + self.skep(x)
+        else:
+            return h + x
+
+
+class Adapter(nn.Module):
+    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
+        super(Adapter, self).__init__()
+        self.unshuffle = nn.PixelUnshuffle(8)
+        self.channels = channels
+        self.nums_rb = nums_rb
+        self.body = []
+        for i in range(len(channels)):
+            for j in range(nums_rb):
+                if (i!=0) and (j==0):
+                    self.body.append(ResnetBlock(channels[i-1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
+                else:
+                    self.body.append(ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
+        self.body = nn.ModuleList(self.body)
+        self.conv_in = nn.Conv2d(cin,channels[0], 3, 1, 1)
+
+    def forward(self, x):
+        # unshuffle
+        x = self.unshuffle(x)
+        # extract features
+        features = []
+        x = self.conv_in(x)
+        for i in range(len(self.channels)):
+            for j in range(self.nums_rb):
+                idx = i*self.nums_rb +j
+                x = self.body[idx](x)
+            features.append(x)
+
+        return features
diff --git a/models/t2i_adapter/put_t2i_adapter_models_here b/models/t2i_adapter/put_t2i_adapter_models_here
new file mode 100644
index 00000000..e69de29b
diff --git a/nodes.py b/nodes.py
index 48549745..1a567e01 100644
--- a/nodes.py
+++ b/nodes.py
@@ -106,6 +106,21 @@ class VAEDecode:
     def decode(self, vae, samples):
         return (vae.decode(samples["samples"]), )
 
+class VAEDecodeTiled:
+    def __init__(self, device="cpu"):
+        self.device = device
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "decode"
+
+    CATEGORY = "_for_testing"
+
+    def decode(self, vae, samples):
+        return (vae.decode_tiled(samples["samples"]), )
+
 class VAEEncode:
     def __init__(self, device="cpu"):
         self.device = device
@@ -277,6 +292,22 @@ class ControlNetApply:
             c.append(n)
         return (c, )
 
+class T2IAdapterLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}}
+
+    RETURN_TYPES = ("CONTROL_NET",)
+    FUNCTION = "load_t2i_adapter"
+
+    CATEGORY = "loaders"
+
+    def load_t2i_adapter(self, t2i_adapter_name):
+        t2i_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name)
+        t2i_adapter = comfy.sd.load_t2i_adapter(t2i_path)
+        return (t2i_adapter,)
 
 class CLIPLoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
@@ -794,6 +825,8 @@ NODE_CLASS_MAPPINGS = {
     "ControlNetApply": ControlNetApply,
     "ControlNetLoader": ControlNetLoader,
     "DiffControlNetLoader": DiffControlNetLoader,
+    "T2IAdapterLoader": T2IAdapterLoader,
+    "VAEDecodeTiled": VAEDecodeTiled,
 }
 
 CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
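Note on the control-signal layout (not part of the patch above, a minimal sketch only): after this change, `get_control` returns a dict of per-stage lists instead of a flat list. A ControlNet fills `'middle'` and `'output'`, a T2I-Adapter fills `'input'` (padded with `None` placeholders for blocks it does not touch), and `UNetModel.forward` pops entries off the end of each list as it walks its blocks. The tensor shapes below are made up for illustration; only the dict layout and the pop-and-add pattern come from the patch.

```python
import torch

# Hypothetical shapes; only the structure matters.
control = {
    'input':  [None, None, torch.zeros(1, 320, 64, 64)],  # popped from the end by successive input blocks
    'middle': [torch.zeros(1, 1280, 8, 8)],               # added once after the middle block
    'output': [torch.zeros(1, 1280, 8, 8)],               # popped from the end by successive output blocks
}

# Mirrors the per-block addition added to openaimodel.py for the input blocks.
h = torch.zeros(1, 320, 64, 64)
if control is not None and 'input' in control and len(control['input']) > 0:
    ctrl = control['input'].pop()
    if ctrl is not None:
        h += ctrl
```

Chaining works by each `get_control` call merging `control_prev` into its own dict, which is why ControlNet passes any `'input'` list through unchanged and T2IAdapter copies `'middle'`/`'output'` from the previous controlnet.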