Merge branch 'master' into patch_hooks_improved_memory
This commit is contained in:
commit
76b9ed1f2d
|
@ -28,7 +28,7 @@
|
|||
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
|
||||
![ComfyUI Screenshot](comfyui_screenshot.png)
|
||||
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe)
|
||||
</div>
|
||||
|
||||
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||
|
@ -140,7 +140,7 @@ Put your VAE in: models/vae
|
|||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from app.logger import on_flush
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class TerminalService:
|
||||
|
@ -10,15 +11,27 @@ class TerminalService:
|
|||
self.subscriptions = set()
|
||||
on_flush(self.send_messages)
|
||||
|
||||
def get_terminal_size(self):
|
||||
try:
|
||||
size = os.get_terminal_size()
|
||||
return (size.columns, size.lines)
|
||||
except OSError:
|
||||
try:
|
||||
size = shutil.get_terminal_size()
|
||||
return (size.columns, size.lines)
|
||||
except OSError:
|
||||
return (80, 24) # fallback to 80x24
|
||||
|
||||
def update_size(self):
|
||||
sz = os.get_terminal_size()
|
||||
columns, lines = self.get_terminal_size()
|
||||
changed = False
|
||||
if sz.columns != self.cols:
|
||||
self.cols = sz.columns
|
||||
|
||||
if columns != self.cols:
|
||||
self.cols = columns
|
||||
changed = True
|
||||
|
||||
if sz.lines != self.rows:
|
||||
self.rows = sz.lines
|
||||
if lines != self.rows:
|
||||
self.rows = lines
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
|
|
|
@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
|
|||
|
||||
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
}
|
||||
|
||||
class CLIPMLP(torch.nn.Module):
|
||||
|
@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
|
|||
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
if model_type == "siglip_vision_model":
|
||||
self.class_embedding = None
|
||||
patch_bias = True
|
||||
else:
|
||||
num_patches = num_patches + 1
|
||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||
patch_bias = False
|
||||
|
||||
self.patch_embedding = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False,
|
||||
bias=patch_bias,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
num_positions = num_patches + 1
|
||||
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
|
||||
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
||||
if self.class_embedding is not None:
|
||||
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
|
||||
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
||||
|
||||
|
||||
class CLIPVision(torch.nn.Module):
|
||||
|
@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
|
|||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.output_layernorm = False
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.post_layernorm = operations.LayerNorm(embed_dim)
|
||||
|
||||
|
@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
|
|||
x = self.pre_layrnorm(x)
|
||||
#TODO: attention_mask?
|
||||
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
if self.output_layernorm:
|
||||
x = self.post_layernorm(x)
|
||||
pooled_output = x
|
||||
else:
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPVisionModelProjection(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
if "projection_dim" in config_dict:
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
else:
|
||||
self.visual_projection = lambda a: a
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.vision_model(*args, **kwargs)
|
||||
|
|
|
@ -16,9 +16,9 @@ class Output:
|
|||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224):
|
||||
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]):
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
|
@ -35,6 +35,8 @@ class ClipVisionModel():
|
|||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
|
@ -51,7 +53,7 @@ class ClipVisionModel():
|
|||
|
||||
def encode_image(self, image):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
|
||||
outputs = Output()
|
||||
|
@ -94,7 +96,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
|||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": 384,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 14,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
|
@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module):
|
|||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
context = kwargs["context"]
|
||||
|
||||
info = {
|
||||
"hidden_states": [],
|
||||
|
@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module):
|
|||
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||
x = x + self.pos_emb(x)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
# Iterate over the transformer layers
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.layers):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
|
@ -874,7 +886,6 @@ class AudioDiffusionTransformer(nn.Module):
|
|||
mask=None,
|
||||
return_info=False,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
**kwargs):
|
||||
return self._forward(
|
||||
x,
|
||||
|
|
|
@ -20,6 +20,7 @@ import comfy.ldm.common_dit
|
|||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
|
@ -29,6 +30,7 @@ class FluxParams:
|
|||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
@ -43,8 +45,9 @@ class Flux(nn.Module):
|
|||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels * 2 * 2
|
||||
self.out_channels = self.in_channels
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels * params.patch_size * params.patch_size
|
||||
self.out_channels = params.out_channels * params.patch_size * params.patch_size
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
|
@ -165,7 +168,7 @@ class Flux(nn.Module):
|
|||
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
import torch
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
class ReduxImageEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
redux_dim: int = 1152,
|
||||
txt_in_features: int = 4096,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.redux_dim = redux_dim
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
|
||||
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
|
||||
|
||||
def forward(self, sigclip_embeds) -> torch.Tensor:
|
||||
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
|
||||
return projected_x
|
|
@ -372,10 +372,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
|||
)
|
||||
|
||||
if mask is not None:
|
||||
pad = 8 - q.shape[1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[:, :, :mask.shape[-1]] = mask
|
||||
mask = mask_out[:, :, :mask.shape[-1]]
|
||||
pad = 8 - mask.shape[-1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[..., :mask.shape[-1]] = mask
|
||||
mask = mask_out[..., :mask.shape[-1]]
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
|
||||
|
|
|
@ -49,6 +49,15 @@ def load_lora(lora, to_load, log_missing=True):
|
|||
dora_scale = lora[dora_scale_name]
|
||||
loaded_keys.add(dora_scale_name)
|
||||
|
||||
reshape_name = "{}.reshape_weight".format(x)
|
||||
reshape = None
|
||||
if reshape_name in lora.keys():
|
||||
try:
|
||||
reshape = lora[reshape_name].tolist()
|
||||
loaded_keys.add(reshape_name)
|
||||
except:
|
||||
pass
|
||||
|
||||
regular_lora = "{}.lora_up.weight".format(x)
|
||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||
|
@ -82,7 +91,7 @@ def load_lora(lora, to_load, log_missing=True):
|
|||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
|
@ -193,6 +202,12 @@ def load_lora(lora, to_load, log_missing=True):
|
|||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
set_weight_name = "{}.set_weight".format(x)
|
||||
set_weight = lora.get(set_weight_name, None)
|
||||
if set_weight is not None:
|
||||
patch_dict[to_load[x]] = ("set", (set_weight,))
|
||||
loaded_keys.add(set_weight_name)
|
||||
|
||||
if log_missing:
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
|
@ -283,11 +298,14 @@ def model_lora_keys_unet(model, key_map={}):
|
|||
sdk = sd.keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
if k.startswith("diffusion_model."):
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
else:
|
||||
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
|
||||
for k in diffusers_keys:
|
||||
|
@ -441,6 +459,8 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
|||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
|
||||
else:
|
||||
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||
elif patch_type == "set":
|
||||
weight.copy_(v[0])
|
||||
elif patch_type == "model_as_lora":
|
||||
target_weight: torch.Tensor = v[0]
|
||||
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
|
||||
|
@ -450,6 +470,11 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
|||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||
dora_scale = v[4]
|
||||
reshape = v[5]
|
||||
|
||||
if reshape is not None:
|
||||
weight = pad_tensor_to_shape(weight, reshape)
|
||||
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
import torch
|
||||
|
||||
|
||||
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
sd_out = {}
|
||||
for k in sd:
|
||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
|
||||
sd_out[k_to] = sd[k]
|
||||
|
||||
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
|
||||
return sd_out
|
||||
|
||||
|
||||
def convert_lora(sd):
|
||||
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
|
||||
return convert_lora_bfl_control(sd)
|
||||
return sd
|
|
@ -165,8 +165,7 @@ class BaseModel(torch.nn.Module):
|
|||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
def concat_cond(self, **kwargs):
|
||||
if len(self.concat_keys) > 0:
|
||||
cond_concat = []
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
|
@ -205,7 +204,14 @@ class BaseModel(torch.nn.Module):
|
|||
elif ck == "masked_image":
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
||||
return data
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
if concat_cond is not None:
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond)
|
||||
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
|
@ -535,9 +541,7 @@ class SD_X4Upscaler(BaseModel):
|
|||
return out
|
||||
|
||||
class IP2P:
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
@ -549,18 +553,15 @@ class IP2P:
|
|||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
return self.process_ip2p_image_in(image)
|
||||
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = comfy.conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
class SD15_instructpix2pix(IP2P, BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.process_ip2p_image_in = lambda image: image
|
||||
|
||||
|
||||
class SDXL_instructpix2pix(IP2P, SDXL):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
@ -721,6 +722,38 @@ class Flux(BaseModel):
|
|||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
|
||||
out_channels = self.model_config.unet_config["out_channels"]
|
||||
|
||||
if num_channels <= out_channels:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
image = self.process_latent_in(image)
|
||||
if num_channels <= out_channels * 2:
|
||||
return image
|
||||
|
||||
#inpaint model
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.ones_like(noise)[:, :1]
|
||||
|
||||
mask = torch.mean(mask, dim=1, keepdim=True)
|
||||
print(mask.shape)
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
|
||||
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
|
|
|
@ -137,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
|
|||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 16
|
||||
patch_size = 2
|
||||
dit_config["patch_size"] = patch_size
|
||||
in_key = "{}img_in.weight".format(key_prefix)
|
||||
if in_key in state_dict_keys:
|
||||
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
|
|
|
@ -531,14 +531,18 @@ class ModelPatcher:
|
|||
lowvram_counter = 0
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
|
||||
loading.append((comfy.model_management.module_size(m), n, m))
|
||||
params = []
|
||||
for name, param in m.named_parameters(recurse=False):
|
||||
params.append(name)
|
||||
if hasattr(m, "comfy_cast_weights") or len(params) > 0:
|
||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
||||
|
||||
load_completely = []
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
params = x[3]
|
||||
module_mem = x[0]
|
||||
|
||||
lowvram_weight = False
|
||||
|
@ -574,22 +578,21 @@ class ModelPatcher:
|
|||
if m.comfy_cast_weights:
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if hasattr(m, "weight"):
|
||||
mem_counter += module_mem
|
||||
load_completely.append((module_mem, n, m))
|
||||
mem_counter += module_mem
|
||||
load_completely.append((module_mem, n, m, params))
|
||||
|
||||
load_completely.sort(reverse=True)
|
||||
for x in load_completely:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
params = x[3]
|
||||
if hasattr(m, "comfy_patched_weights"):
|
||||
if m.comfy_patched_weights == True:
|
||||
continue
|
||||
|
||||
self.patch_weight_to_device(weight_key, device_to=device_to)
|
||||
self.patch_weight_to_device(bias_key, device_to=device_to)
|
||||
for param in params:
|
||||
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
|
||||
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
|
||||
|
|
|
@ -32,10 +32,13 @@ import comfy.text_encoders.genmo
|
|||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
import comfy.lora_convert
|
||||
import comfy.hooks
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.taesd.taesd
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
|
@ -43,6 +46,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
|||
if clip is not None:
|
||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
|
||||
lora = comfy.lora_convert.convert_lora(lora)
|
||||
loaded = comfy.lora.load_lora(lora, key_map)
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
|
@ -506,6 +510,8 @@ def load_style_model(ckpt_path):
|
|||
keys = model_data.keys()
|
||||
if "style_embedding" in keys:
|
||||
model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
|
||||
elif "redux_down.weight" in keys:
|
||||
model = comfy.ldm.flux.redux.ReduxImageEncoder()
|
||||
else:
|
||||
raise Exception("invalid style model {}".format(ckpt_path))
|
||||
model.load_state_dict(model_data)
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 116 KiB |
6
nodes.py
6
nodes.py
|
@ -385,6 +385,7 @@ class InpaintModelConditioning:
|
|||
"vae": ("VAE", ),
|
||||
"pixels": ("IMAGE", ),
|
||||
"mask": ("MASK", ),
|
||||
"noise_mask": ("BOOLEAN", {"default": True, "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
|
||||
|
@ -393,7 +394,7 @@ class InpaintModelConditioning:
|
|||
|
||||
CATEGORY = "conditioning/inpaint"
|
||||
|
||||
def encode(self, positive, negative, pixels, vae, mask):
|
||||
def encode(self, positive, negative, pixels, vae, mask, noise_mask):
|
||||
x = (pixels.shape[1] // 8) * 8
|
||||
y = (pixels.shape[2] // 8) * 8
|
||||
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
|
||||
|
@ -417,7 +418,8 @@ class InpaintModelConditioning:
|
|||
out_latent = {}
|
||||
|
||||
out_latent["samples"] = orig_latent
|
||||
out_latent["noise_mask"] = mask
|
||||
if noise_mask:
|
||||
out_latent["noise_mask"] = mask
|
||||
|
||||
out = []
|
||||
for conditioning in [positive, negative]:
|
||||
|
|
Loading…
Reference in New Issue