diff --git a/README.md b/README.md
index 2e1cbb72..d26f0fe9 100644
--- a/README.md
+++ b/README.md
@@ -75,37 +75,37 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
 | Keybind | Explanation |
 |------------------------------------|--------------------------------------------------------------------------------------------------------------------|
-| Ctrl + Enter | Queue up current graph for generation |
-| Ctrl + Shift + Enter | Queue up current graph as first for generation |
-| Ctrl + Alt + Enter | Cancel current generation |
-| Ctrl + Z/Ctrl + Y | Undo/Redo |
-| Ctrl + S | Save workflow |
-| Ctrl + O | Load workflow |
-| Ctrl + A | Select all nodes |
-| Alt + C | Collapse/uncollapse selected nodes |
-| Ctrl + M | Mute/unmute selected nodes |
-| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
-| Delete/Backspace | Delete selected nodes |
-| Ctrl + Backspace | Delete the current graph |
-| Space | Move the canvas around when held and moving the cursor |
-| Ctrl/Shift + Click | Add clicked node to selection |
-| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
-| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
-| Shift + Drag | Move multiple selected nodes at the same time |
-| Ctrl + D | Load default graph |
-| Alt + `+` | Canvas Zoom in |
-| Alt + `-` | Canvas Zoom out |
-| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
-| P | Pin/Unpin selected nodes |
-| Ctrl + G | Group selected nodes |
-| Q | Toggle visibility of the queue |
-| H | Toggle visibility of history |
-| R | Refresh graph |
+| `Ctrl` + `Enter` | Queue up current graph for generation |
+| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
+| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
+| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
+| `Ctrl` + `S` | Save workflow |
+| `Ctrl` + `O` | Load workflow |
+| `Ctrl` + `A` | Select all nodes |
+| `Alt` + `C` | Collapse/uncollapse selected nodes |
+| `Ctrl` + `M` | Mute/unmute selected nodes |
+| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
+| `Delete`/`Backspace` | Delete selected nodes |
+| `Ctrl` + `Backspace` | Delete the current graph |
+| `Space` | Move the canvas around when held and moving the cursor |
+| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
+| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
+| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
+| `Shift` + `Drag` | Move multiple selected nodes at the same time |
+| `Ctrl` + `D` | Load default graph |
+| `Alt` + `+` | Canvas Zoom in |
+| `Alt` + `-` | Canvas Zoom out |
+| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
+| `P` | Pin/Unpin selected nodes |
+| `Ctrl` + `G` | Group selected nodes |
+| `Q` | Toggle visibility of the queue |
+| `H` | Toggle visibility of history |
+| `R` | Refresh graph |
 | Double-Click LMB | Open node quick search palette |
-| Shift + Drag | Move multiple wires at once |
-| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
+| `Shift` + Drag | Move multiple wires at once |
+| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
 
-Ctrl can also be replaced with Cmd instead for macOS users
+`Ctrl` can also be replaced with `Cmd` for macOS users
 
 # Installing
diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py
index 44e806cb..88459457 100644
--- a/comfy/ldm/hydit/models.py
+++ b/comfy/ldm/hydit/models.py
@@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
         style=None,
         return_dict=False,
         control=None,
-        transformer_options=None,
+        transformer_options={},
     ):
         """
         Forward pass of the encoder.
@@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module):
         return_dict: bool
             Whether to return a dictionary.
         """
-        #import pdb
-        #pdb.set_trace()
+        patches_replace = transformer_options.get("patches_replace", {})
         encoder_hidden_states = context
         text_states = encoder_hidden_states                     # 2,77,1024
         text_states_t5 = encoder_hidden_states_t5               # 2,256,2048
@@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module):
         # Concatenate all extra vectors
         c = t + self.extra_embedder(extra_vec)  # [B, D]
 
+        blocks_replace = patches_replace.get("dit", {})
+
         controls = None
         if control:
             controls = control.get("output", None)
@@ -375,9 +376,20 @@ class HunYuanDiT(nn.Module):
                     skip = skips.pop() + controls.pop().to(dtype=x.dtype)
                 else:
                     skip = skips.pop()
-                x = block(x, c, text_states, freqs_cis_img, skip)   # (N, L, D)
             else:
-                x = block(x, c, text_states, freqs_cis_img)         # (N, L, D)
+                skip = None
+
+            if ("double_block", layer) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
+                    return out
+
+                out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, c, text_states, freqs_cis_img, skip)   # (N, L, D)
+
             if layer < (self.depth // 2 - 1):
                 skips.append(x)
diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index 2792384d..f49cef95 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -304,7 +304,7 @@ class BasicTransformerBlock(nn.Module):
         self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
 
     def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None] + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
 
         x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
 
@@ -415,13 +415,15 @@ class LTXVModel(torch.nn.Module):
 
         self.patchifier = SymmetricPatchifier(1)
 
-    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, **kwargs):
+    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
+        patches_replace = transformer_options.get("patches_replace", {})
+
         indices_grid = self.patchifier.get_grid(
             orig_num_frames=x.shape[2],
             orig_height=x.shape[3],
             orig_width=x.shape[4],
             batch_size=x.shape[0],
-            scale_grid=((1 / frame_rate) * 8, 32, 32), #TODO: controlable frame rate
+            scale_grid=((1 / frame_rate) * 8, 32, 32),
             device=x.device,
         )
 
@@ -468,18 +470,28 @@ class LTXVModel(torch.nn.Module):
             batch_size, -1, x.shape[-1]
         )
 
-        for block in self.transformer_blocks:
-            x = block(
-                x,
-                context=context,
-                attention_mask=attention_mask,
-                timestep=timestep,
-                pe=pe
-            )
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.transformer_blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(
+                    x,
+                    context=context,
+                    attention_mask=attention_mask,
+                    timestep=timestep,
+                    pe=pe
+                )
 
         # 3. Output
         scale_shift_values = (
-            self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
         )
         shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
         x = self.norm_out(x)
diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py
index 146dea19..c572e7e8 100644
--- a/comfy/ldm/lightricks/vae/causal_conv3d.py
+++ b/comfy/ldm/lightricks/vae/causal_conv3d.py
@@ -2,6 +2,8 @@ from typing import Tuple, Union
 
 import torch
 import torch.nn as nn
+import comfy.ops
+ops = comfy.ops.disable_weight_init
 
 
 class CausalConv3d(nn.Module):
@@ -29,7 +31,7 @@ class CausalConv3d(nn.Module):
         width_pad = kernel_size[2] // 2
         padding = (0, height_pad, width_pad)
 
-        self.conv = nn.Conv3d(
+        self.conv = ops.Conv3d(
             in_channels,
             out_channels,
             kernel_size,
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 4138fdf3..33b2c2d4 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -628,10 +628,10 @@ class processor(nn.Module):
         self.register_buffer("channel", torch.empty(128))
 
     def un_normalize(self, x):
-        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)
+        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
 
     def normalize(self, x):
-        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)
+        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
 
 class VideoVAE(nn.Module):
     def __init__(self):
diff --git a/comfy/ldm/lightricks/vae/conv_nd_factory.py b/comfy/ldm/lightricks/vae/conv_nd_factory.py
index 389f8165..c5f067bf 100644
--- a/comfy/ldm/lightricks/vae/conv_nd_factory.py
+++ b/comfy/ldm/lightricks/vae/conv_nd_factory.py
@@ -4,7 +4,8 @@ import torch
 
 from .dual_conv3d import DualConv3d
 from .causal_conv3d import CausalConv3d
-
+import comfy.ops
+ops = comfy.ops.disable_weight_init
 
 def make_conv_nd(
     dims: Union[int, Tuple[int, int]],
@@ -19,7 +20,7 @@ def make_conv_nd(
     causal=False,
 ):
     if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -41,7 +42,7 @@ def make_conv_nd(
             groups=groups,
             bias=bias,
         )
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -71,11 +72,11 @@ def make_linear_nd(
     bias=True,
 ):
     if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
         )
     elif dims == 3 or dims == (2, 1):
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
         )
     else:
diff --git a/comfy/sd.py b/comfy/sd.py
index 35ac754b..8db83ce8 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -342,7 +342,7 @@ class VAE:
                 self.latent_dim = 3
                 self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
-                self.upscale_ratio = 8
+                self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
@@ -443,7 +443,9 @@ class VAE:
             elif dims == 2:
                 pixel_samples = self.decode_tiled_(samples_in)
             elif dims == 3:
-                pixel_samples = self.decode_tiled_3d(samples_in)
+                tile = 256 // self.spacial_compression_decode()
+                overlap = tile // 4
+                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
 
         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
@@ -507,6 +509,12 @@ class VAE:
     def get_sd(self):
         return self.first_stage_model.state_dict()
 
+    def spacial_compression_decode(self):
+        try:
+            return self.upscale_ratio[-1]
+        except:
+            return self.upscale_ratio
+
 class StyleModel:
     def __init__(self, model, device="cpu"):
         self.model = model
diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py
index 9d063937..e6a48fc4 100644
--- a/comfy_extras/nodes_lt.py
+++ b/comfy_extras/nodes_lt.py
@@ -10,7 +10,7 @@ class EmptyLTXVLatentVideo:
     def INPUT_TYPES(s):
         return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                               "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
-                              "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                              "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
                               "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"
diff --git a/nodes.py b/nodes.py
index d3a15df2..26f64e30 100644
--- a/nodes.py
+++ b/nodes.py
@@ -304,7 +304,8 @@ class VAEDecodeTiled:
     def decode(self, vae, samples, tile_size, overlap=64):
         if tile_size < overlap * 4:
             overlap = tile_size // 4
-        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
+        compression = vae.spacial_compression_decode()
+        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
         if len(images.shape) == 5: #Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
         return (images, )
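A note on the `blocks_replace` hook this patch wires into HunYuanDiT and LTXV: a patch registered under the key `("double_block", i)` in `transformer_options["patches_replace"]["dit"]` receives the packed block inputs plus a dict carrying `original_block`, and must return a dict with the new `"img"` tensor. Below is a minimal sketch of registering such a patch, assuming `model` is a `comfy.model_patcher.ModelPatcher` (e.g. the MODEL output of a loader node); the pass-through patch body is hypothetical and only illustrates the calling convention:

```python
def my_double_block_patch(args, extra):
    # args packs the block inputs ("img", "txt", "vec", "pe", and for
    # HunYuanDiT also "skip"); extra["original_block"] runs the unpatched block.
    out = extra["original_block"](args)
    # A real patch would modify out["img"] here; this one passes it through.
    return out

m = model.clone()  # patch a copy so the loaded model stays untouched
to = m.model_options["transformer_options"]
to.setdefault("patches_replace", {}).setdefault("dit", {})[("double_block", 3)] = my_double_block_patch
```

This mirrors the convention Flux-style models already use for their `("double_block", i)` patches, which is presumably why the generic `"dit"` namespace is reused here.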
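The `.to(x)` and `.to(device=x.device, dtype=x.dtype)` casts sprinkled through the LTXV model and VAE guard against buffers (e.g. `std-of-means`) and the `scale_shift_table` sitting in a different dtype or on a different device than the activations once weights are cast or offloaded. `Tensor.to(tensor)` copies both attributes from another tensor in one call; a self-contained illustration:

```python
import torch

buf = torch.full((128,), 2.0, dtype=torch.float32)      # stand-in for a registered fp32 buffer
x = torch.randn(1, 128, 2, 4, 4, dtype=torch.bfloat16)  # activations in a working dtype

# Without the cast, bf16 * fp32 silently promotes the result to fp32 (and
# mismatched devices would raise); .to(x) keeps everything in x's dtype/device.
scaled = x * buf.view(1, -1, 1, 1, 1).to(x)
assert scaled.dtype == torch.bfloat16
```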
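How the tiled-decode changes fit together: the LTXV VAE's `upscale_ratio` becomes a tuple whose last two entries are the 32x spatial factor (the leading lambda maps latent frames to pixel frames), the new `VAE.spacial_compression_decode()` returns that factor (or the plain integer ratio for image VAEs), and `VAEDecodeTiled` divides its pixel-space `tile_size` by it instead of a hard-coded 8. A small sketch of the dispatch and the resulting numbers, mirroring the diff (with the bare `except` narrowed to `TypeError`):

```python
def spacial_compression(upscale_ratio):
    # Tuple-valued upscale_ratio (video VAE) exposes its last entry; a plain
    # int (image VAE) is not subscriptable, so it is returned unchanged.
    try:
        return upscale_ratio[-1]
    except TypeError:
        return upscale_ratio

assert spacial_compression((lambda a: max(0, a * 8 - 7), 32, 32)) == 32  # LTXV video VAE
assert spacial_compression(8) == 8                                       # image VAE

# VAEDecodeTiled then works in latent pixels: with tile_size=512 and overlap=64,
# an image VAE keeps the old 512 // 8 = 64 tiles, while the LTXV VAE decodes
# 512 // 32 = 16-latent-pixel tiles with overlap 64 // 32 = 2. The OOM fallback
# in VAE.decode similarly derives tile = 256 // 32 = 8 with overlap = 2.
```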