From bc6be6c11e48114889a368e8c3597df8aac64ae3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Nov 2024 16:40:04 -0500 Subject: [PATCH 1/8] Some fixes to the lowvram system. --- comfy/model_patcher.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 22de7eea..fc232954 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -374,9 +374,14 @@ class ModelPatcher: loading = [] for n, m in self.model.named_modules(): params = [] + skip = False for name, param in m.named_parameters(recurse=False): params.append(name) - if hasattr(m, "comfy_cast_weights") or len(params) > 0: + for name, param in m.named_parameters(recurse=True): + if name not in params: + skip = True # skip random weights in non leaf modules + break + if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): loading.append((comfy.model_management.module_size(m), n, m, params)) load_completely = [] @@ -420,8 +425,9 @@ class ModelPatcher: if m.comfy_cast_weights: wipe_lowvram_weight(m) - mem_counter += module_mem - load_completely.append((module_mem, n, m, params)) + if full_load or mem_counter + module_mem < lowvram_model_memory: + mem_counter += module_mem + load_completely.append((module_mem, n, m, params)) load_completely.sort(reverse=True) for x in load_completely: From e5c3f4b87febd790b316f82813ba8d89d275fee4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Nov 2024 17:17:11 -0500 Subject: [PATCH 2/8] LTXV lowvram fixes. 
--- comfy/ldm/lightricks/model.py | 4 ++-- comfy/ldm/lightricks/vae/causal_conv3d.py | 4 +++- comfy/ldm/lightricks/vae/causal_video_autoencoder.py | 4 ++-- comfy/ldm/lightricks/vae/conv_nd_factory.py | 11 ++++++----- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 2792384d..87ed0995 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -304,7 +304,7 @@ class BasicTransformerBlock(nn.Module): self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None] + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa @@ -479,7 +479,7 @@ class LTXVModel(torch.nn.Module): # 3. 
Output scale_shift_values = ( - self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] x = self.norm_out(x) diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index 146dea19..c572e7e8 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -2,6 +2,8 @@ from typing import Tuple, Union import torch import torch.nn as nn +import comfy.ops +ops = comfy.ops.disable_weight_init class CausalConv3d(nn.Module): @@ -29,7 +31,7 @@ class CausalConv3d(nn.Module): width_pad = kernel_size[2] // 2 padding = (0, height_pad, width_pad) - self.conv = nn.Conv3d( + self.conv = ops.Conv3d( in_channels, out_channels, kernel_size, diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 4138fdf3..33b2c2d4 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -628,10 +628,10 @@ class processor(nn.Module): self.register_buffer("channel", torch.empty(128)) def un_normalize(self, x): - return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1) + return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x) def normalize(self, x): - return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1) + return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) class VideoVAE(nn.Module): def __init__(self): diff --git a/comfy/ldm/lightricks/vae/conv_nd_factory.py b/comfy/ldm/lightricks/vae/conv_nd_factory.py 
index 389f8165..c5f067bf 100644 --- a/comfy/ldm/lightricks/vae/conv_nd_factory.py +++ b/comfy/ldm/lightricks/vae/conv_nd_factory.py @@ -4,7 +4,8 @@ import torch from .dual_conv3d import DualConv3d from .causal_conv3d import CausalConv3d - +import comfy.ops +ops = comfy.ops.disable_weight_init def make_conv_nd( dims: Union[int, Tuple[int, int]], @@ -19,7 +20,7 @@ def make_conv_nd( causal=False, ): if dims == 2: - return torch.nn.Conv2d( + return ops.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -41,7 +42,7 @@ def make_conv_nd( groups=groups, bias=bias, ) - return torch.nn.Conv3d( + return ops.Conv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -71,11 +72,11 @@ def make_linear_nd( bias=True, ): if dims == 2: - return torch.nn.Conv2d( + return ops.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias ) elif dims == 3 or dims == (2, 1): - return torch.nn.Conv3d( + return ops.Conv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias ) else: From 6e8cdcd3cb542ba9eb5a5e5a420eff06f59dd268 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Nov 2024 18:00:34 -0500 Subject: [PATCH 3/8] Fix some tiled VAE decoding issues with LTX-Video. 
--- comfy/sd.py | 12 ++++++++++-- nodes.py | 3 ++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index b07b5fe3..e2af7078 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -269,7 +269,7 @@ class VAE: self.latent_dim = 3 self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) - self.upscale_ratio = 8 + self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") @@ -370,7 +370,9 @@ class VAE: elif dims == 2: pixel_samples = self.decode_tiled_(samples_in) elif dims == 3: - pixel_samples = self.decode_tiled_3d(samples_in) + tile = 256 // self.spacial_compression_decode() + overlap = tile // 4 + pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples @@ -434,6 +436,12 @@ class VAE: def get_sd(self): return self.first_stage_model.state_dict() + def spacial_compression_decode(self): + try: + return self.upscale_ratio[-1] + except: + return self.upscale_ratio + class StyleModel: def __init__(self, model, device="cpu"): self.model = model diff --git a/nodes.py b/nodes.py index 01af6c68..3a68d43c 100644 --- a/nodes.py +++ b/nodes.py @@ -301,7 +301,8 @@ class VAEDecodeTiled: def decode(self, vae, samples, tile_size, overlap=64): if tile_size < overlap * 4: overlap = tile_size // 4 - images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8) + compression = vae.spacial_compression_decode() + images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, 
overlap=overlap // compression) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) From 839ed3368efd0f61a2b986f57fe9e0698fd08e9f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Nov 2024 20:59:15 -0500 Subject: [PATCH 4/8] Some improvements to the lowvram unloading. --- comfy/model_patcher.py | 61 +++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index fc232954..f53f1074 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -367,10 +367,7 @@ class ModelPatcher: else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) - def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): - mem_counter = 0 - patch_counter = 0 - lowvram_counter = 0 + def _load_list(self): loading = [] for n, m in self.model.named_modules(): params = [] @@ -383,6 +380,13 @@ class ModelPatcher: break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): loading.append((comfy.model_management.module_size(m), n, m, params)) + return loading + + def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): + mem_counter = 0 + patch_counter = 0 + lowvram_counter = 0 + loading = self._load_list() load_completely = [] loading.sort(reverse=True) @@ -514,14 +518,7 @@ class ModelPatcher: def partially_unload(self, device_to, memory_to_free=0): memory_freed = 0 patch_counter = 0 - unload_list = [] - - for n, m in self.model.named_modules(): - shift_lowvram = False - if hasattr(m, "comfy_cast_weights"): - module_mem = comfy.model_management.module_size(m) - unload_list.append((module_mem, n, m)) - + unload_list = self._load_list() unload_list.sort() for unload in unload_list: if memory_to_free < memory_freed: @@ -529,32 +526,42 @@ class ModelPatcher: module_mem = 
unload[0] n = unload[1] m = unload[2] - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) + params = unload[3] + lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: - for key in [weight_key, bias_key]: + move_weight = True + for param in params: + key = "{}.{}".format(n, param) bk = self.backup.get(key, None) if bk is not None: + if not lowvram_possible: + move_weight = False + break + if bk.inplace_update: comfy.utils.copy_to_param(self.model, key, bk.weight) else: comfy.utils.set_attr_param(self.model, key, bk.weight) self.backup.pop(key) - m.to(device_to) - if weight_key in self.patches: - m.weight_function = LowVramPatch(weight_key, self.patches) - patch_counter += 1 - if bias_key in self.patches: - m.bias_function = LowVramPatch(bias_key, self.patches) - patch_counter += 1 + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if move_weight: + m.to(device_to) + if lowvram_possible: + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self.patches) + patch_counter += 1 + if bias_key in self.patches: + m.bias_function = LowVramPatch(bias_key, self.patches) + patch_counter += 1 - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - m.comfy_patched_weights = False - memory_freed += module_mem - logging.debug("freed {}".format(n)) + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + m.comfy_patched_weights = False + memory_freed += module_mem + logging.debug("freed {}".format(n)) self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter From ab885b33ba78509d5e2a3f2cba5fc62dab907b37 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 23 Nov 2024 10:33:05 -0500 Subject: [PATCH 5/8] Skip layer guidance node now works on LTX-Video. 
--- comfy/ldm/lightricks/model.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 87ed0995..f49cef95 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -415,13 +415,15 @@ class LTXVModel(torch.nn.Module): self.patchifier = SymmetricPatchifier(1) - def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, **kwargs): + def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs): + patches_replace = transformer_options.get("patches_replace", {}) + indices_grid = self.patchifier.get_grid( orig_num_frames=x.shape[2], orig_height=x.shape[3], orig_width=x.shape[4], batch_size=x.shape[0], - scale_grid=((1 / frame_rate) * 8, 32, 32), #TODO: controlable frame rate + scale_grid=((1 / frame_rate) * 8, 32, 32), device=x.device, ) @@ -468,14 +470,24 @@ class LTXVModel(torch.nn.Module): batch_size, -1, x.shape[-1] ) - for block in self.transformer_blocks: - x = block( - x, - context=context, - attention_mask=attention_mask, - timestep=timestep, - pe=pe - ) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) + return out + + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block( + x, + context=context, + attention_mask=attention_mask, + timestep=timestep, + pe=pe + ) # 3. 
Output scale_shift_values = ( From 7126ecffde66036070c6feef5e4c92d1ec2de025 Mon Sep 17 00:00:00 2001 From: spacepxl <143970342+spacepxl@users.noreply.github.com> Date: Sat, 23 Nov 2024 21:33:08 -0500 Subject: [PATCH 6/8] set LTX min length to 1 for t2i (#5750) At length=1, the LTX model can do txt2img and img2img with no other changes required. --- comfy_extras/nodes_lt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 9d063937..e6a48fc4 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -10,7 +10,7 @@ class EmptyLTXVLatentVideo: def INPUT_TYPES(s): return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), + "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} RETURN_TYPES = ("LATENT",) FUNCTION = "generate" From 3d802710e7b3b86221eec5d4a0aababdd9b81ba1 Mon Sep 17 00:00:00 2001 From: 40476 <63472353+40476@users.noreply.github.com> Date: Sun, 24 Nov 2024 04:12:07 -0500 Subject: [PATCH 7/8] Update README.md (#5707) --- README.md | 58 +++++++++++++++++++++++++++---------------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 2e1cbb72..d26f0fe9 100644 --- a/README.md +++ b/README.md @@ -75,37 +75,37 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git | Keybind | Explanation | |------------------------------------|--------------------------------------------------------------------------------------------------------------------| -| Ctrl + Enter | Queue up current graph for generation | -| Ctrl + Shift + Enter | Queue up current graph as first for generation | -| Ctrl + 
Alt + Enter | Cancel current generation | -| Ctrl + Z/Ctrl + Y | Undo/Redo | -| Ctrl + S | Save workflow | -| Ctrl + O | Load workflow | -| Ctrl + A | Select all nodes | -| Alt + C | Collapse/uncollapse selected nodes | -| Ctrl + M | Mute/unmute selected nodes | -| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) | -| Delete/Backspace | Delete selected nodes | -| Ctrl + Backspace | Delete the current graph | -| Space | Move the canvas around when held and moving the cursor | -| Ctrl/Shift + Click | Add clicked node to selection | -| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) | -| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) | -| Shift + Drag | Move multiple selected nodes at the same time | -| Ctrl + D | Load default graph | -| Alt + `+` | Canvas Zoom in | -| Alt + `-` | Canvas Zoom out | -| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out | -| P | Pin/Unpin selected nodes | -| Ctrl + G | Group selected nodes | -| Q | Toggle visibility of the queue | -| H | Toggle visibility of history | -| R | Refresh graph | +| `Ctrl` + `Enter` | Queue up current graph for generation | +| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation | +| `Ctrl` + `Alt` + `Enter` | Cancel current generation | +| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo | +| `Ctrl` + `S` | Save workflow | +| `Ctrl` + `O` | Load workflow | +| `Ctrl` + `A` | Select all nodes | +| `Alt` + `C` | Collapse/uncollapse selected nodes | +| `Ctrl` + `M` | Mute/unmute selected nodes | +| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) | +| `Delete`/`Backspace` | Delete selected nodes | +| `Ctrl` + `Backspace` | Delete the current graph | +| `Space` | Move the canvas around when held and 
moving the cursor | +| `Ctrl`/`Shift` + `Click` | Add clicked node to selection | +| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) | +| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) | +| `Shift` + `Drag` | Move multiple selected nodes at the same time | +| `Ctrl` + `D` | Load default graph | +| `Alt` + `+` | Canvas Zoom in | +| `Alt` + `-` | Canvas Zoom out | +| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out | +| `P` | Pin/Unpin selected nodes | +| `Ctrl` + `G` | Group selected nodes | +| `Q` | Toggle visibility of the queue | +| `H` | Toggle visibility of history | +| `R` | Refresh graph | | Double-Click LMB | Open node quick search palette | -| Shift + Drag | Move multiple wires at once | -| Ctrl + Alt + LMB | Disconnect all wires from clicked slot | +| `Shift` + Drag | Move multiple wires at once | +| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot | -Ctrl can also be replaced with Cmd instead for macOS users +`Ctrl` can also be replaced with `Cmd` instead for macOS users # Installing From b4526d3fc3c4b6d42ea27d16e38cfecb6c54ef7e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 24 Nov 2024 05:54:30 -0500 Subject: [PATCH 8/8] Skip layer guidance now works on hydit model. --- comfy/ldm/hydit/models.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py index 44e806cb..88459457 100644 --- a/comfy/ldm/hydit/models.py +++ b/comfy/ldm/hydit/models.py @@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module): style=None, return_dict=False, control=None, - transformer_options=None, + transformer_options={}, ): """ Forward pass of the encoder. @@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module): return_dict: bool Whether to return a dictionary. 
""" - #import pdb - #pdb.set_trace() + patches_replace = transformer_options.get("patches_replace", {}) encoder_hidden_states = context text_states = encoder_hidden_states # 2,77,1024 text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 @@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module): # Concatenate all extra vectors c = t + self.extra_embedder(extra_vec) # [B, D] + blocks_replace = patches_replace.get("dit", {}) + controls = None if control: controls = control.get("output", None) @@ -375,9 +376,20 @@ class HunYuanDiT(nn.Module): skip = skips.pop() + controls.pop().to(dtype=x.dtype) else: skip = skips.pop() - x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D) else: - x = block(x, c, text_states, freqs_cis_img) # (N, L, D) + skip = None + + if ("double_block", layer) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"]) + return out + + out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D) + if layer < (self.depth // 2 - 1): skips.append(x)