Merge branch 'master' into patch_hooks_improved_memory

Jedrzej Kosinski 2024-11-24 15:46:25 -06:00
commit 26ccd3b5f9
9 changed files with 94 additions and 58 deletions

View File

@@ -75,37 +75,37 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Keybind | Explanation |
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
-| Ctrl + Enter | Queue up current graph for generation |
-| Ctrl + Shift + Enter | Queue up current graph as first for generation |
-| Ctrl + Alt + Enter | Cancel current generation |
-| Ctrl + Z/Ctrl + Y | Undo/Redo |
-| Ctrl + S | Save workflow |
-| Ctrl + O | Load workflow |
-| Ctrl + A | Select all nodes |
-| Alt + C | Collapse/uncollapse selected nodes |
-| Ctrl + M | Mute/unmute selected nodes |
-| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
-| Delete/Backspace | Delete selected nodes |
-| Ctrl + Backspace | Delete the current graph |
-| Space | Move the canvas around when held and moving the cursor |
-| Ctrl/Shift + Click | Add clicked node to selection |
-| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
-| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
-| Shift + Drag | Move multiple selected nodes at the same time |
-| Ctrl + D | Load default graph |
-| Alt + `+` | Canvas Zoom in |
-| Alt + `-` | Canvas Zoom out |
-| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
-| P | Pin/Unpin selected nodes |
-| Ctrl + G | Group selected nodes |
-| Q | Toggle visibility of the queue |
-| H | Toggle visibility of history |
-| R | Refresh graph |
+| `Ctrl` + `Enter` | Queue up current graph for generation |
+| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
+| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
+| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
+| `Ctrl` + `S` | Save workflow |
+| `Ctrl` + `O` | Load workflow |
+| `Ctrl` + `A` | Select all nodes |
+| `Alt `+ `C` | Collapse/uncollapse selected nodes |
+| `Ctrl` + `M` | Mute/unmute selected nodes |
+| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
+| `Delete`/`Backspace` | Delete selected nodes |
+| `Ctrl` + `Backspace` | Delete the current graph |
+| `Space` | Move the canvas around when held and moving the cursor |
+| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
+| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
+| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
+| `Shift` + `Drag` | Move multiple selected nodes at the same time |
+| `Ctrl` + `D` | Load default graph |
+| `Alt` + `+` | Canvas Zoom in |
+| `Alt` + `-` | Canvas Zoom out |
+| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
+| `P` | Pin/Unpin selected nodes |
+| `Ctrl` + `G` | Group selected nodes |
+| `Q` | Toggle visibility of the queue |
+| `H` | Toggle visibility of history |
+| `R` | Refresh graph |
| Double-Click LMB | Open node quick search palette |
-| Shift + Drag | Move multiple wires at once |
-| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
+| `Shift` + Drag | Move multiple wires at once |
+| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |

-Ctrl can also be replaced with Cmd instead for macOS users
+`Ctrl` can also be replaced with `Cmd` instead for macOS users

# Installing

View File

@@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
                style=None,
                return_dict=False,
                control=None,
-                transformer_options=None,
+                transformer_options={},
                ):
        """
        Forward pass of the encoder.
@@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module):
        return_dict: bool
            Whether to return a dictionary.
        """
-        #import pdb
-        #pdb.set_trace()
+        patches_replace = transformer_options.get("patches_replace", {})
        encoder_hidden_states = context
        text_states = encoder_hidden_states # 2,77,1024
        text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@@ -364,6 +363,8 @@
        # Concatenate all extra vectors
        c = t + self.extra_embedder(extra_vec) # [B, D]

+        blocks_replace = patches_replace.get("dit", {})

        controls = None
        if control:
            controls = control.get("output", None)
@@ -375,9 +376,20 @@
                    skip = skips.pop() + controls.pop().to(dtype=x.dtype)
                else:
                    skip = skips.pop()
-                x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
            else:
-                x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
+                skip = None
+
+            if ("double_block", layer) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
+                    return out
+
+                out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)

            if layer < (self.depth // 2 - 1):
                skips.append(x)
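
The `patches_replace["dit"]` lookup added here (and in the LTXV model below) lets external code wrap or replace individual transformer blocks through `transformer_options`. Below is a minimal sketch of what such a patch entry could look like, assuming only the calling convention visible in this hunk (an args dict plus an `original_block` wrapper go in, a dict with an `"img"` entry comes out); the function name and the commented registration path are illustrative, not part of this commit:

```python
# Sketch of a "double_block" patch entry, assuming the calling convention shown above.
def my_double_block_patch(args, extra):
    original_block = extra["original_block"]   # wrapper that runs the unpatched block
    out = original_block(args)                 # args carries "img", "txt", "vec", "pe" (and "skip" for HunYuanDiT)
    out["img"] = out["img"] * 1.0              # hypothetical spot to post-process the hidden states
    return out

# Hypothetical registration, mirroring the lookup patches_replace["dit"][("double_block", layer)]:
# transformer_options["patches_replace"]["dit"][("double_block", 0)] = my_double_block_patch
```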

View File

@@ -304,7 +304,7 @@ class BasicTransformerBlock(nn.Module):
        self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))

    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None] + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
@@ -415,13 +415,15 @@ class LTXVModel(torch.nn.Module):
        self.patchifier = SymmetricPatchifier(1)

-    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, **kwargs):
+    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
+        patches_replace = transformer_options.get("patches_replace", {})

        indices_grid = self.patchifier.get_grid(
            orig_num_frames=x.shape[2],
            orig_height=x.shape[3],
            orig_width=x.shape[4],
            batch_size=x.shape[0],
-            scale_grid=((1 / frame_rate) * 8, 32, 32), #TODO: controlable frame rate
+            scale_grid=((1 / frame_rate) * 8, 32, 32),
            device=x.device,
        )
@@ -468,7 +470,17 @@ class LTXVModel(torch.nn.Module):
            batch_size, -1, x.shape[-1]
        )

-        for block in self.transformer_blocks:
-            x = block(
-                x,
-                context=context,
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.transformer_blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(
+                    x,
+                    context=context,
@@ -479,7 +491,7 @@ class LTXVModel(torch.nn.Module):
        # 3. Output
        scale_shift_values = (
-            self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        x = self.norm_out(x)

View File

@@ -2,6 +2,8 @@ from typing import Tuple, Union
import torch
import torch.nn as nn

+import comfy.ops
+ops = comfy.ops.disable_weight_init


class CausalConv3d(nn.Module):

@@ -29,7 +31,7 @@ class CausalConv3d(nn.Module):
        width_pad = kernel_size[2] // 2
        padding = (0, height_pad, width_pad)

-        self.conv = nn.Conv3d(
+        self.conv = ops.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
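
For context on the `nn.Conv3d` → `ops.Conv3d` swap here (the conv factory below gets the same treatment): `comfy.ops.disable_weight_init` supplies drop-in layer classes whose point is to skip the usual random parameter initialization, since the real weights are loaded from the checkpoint immediately afterwards. A rough sketch of that idea, stated as an assumption about the mechanism rather than code from `comfy.ops`:

```python
import torch.nn as nn

# Assumed mechanism: subclass the layer and make init a no-op, because
# load_state_dict() will overwrite the parameters anyway.
class Conv3d(nn.Conv3d):
    def reset_parameters(self):
        return None  # skip unneeded random initialization
```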

View File

@@ -628,10 +628,10 @@ class processor(nn.Module):
        self.register_buffer("channel", torch.empty(128))

    def un_normalize(self, x):
-        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)
+        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)

    def normalize(self, x):
-        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)
+        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)


class VideoVAE(nn.Module):
    def __init__(self):

View File

@@ -4,7 +4,8 @@ import torch
from .dual_conv3d import DualConv3d
from .causal_conv3d import CausalConv3d

+import comfy.ops
+ops = comfy.ops.disable_weight_init


def make_conv_nd(
    dims: Union[int, Tuple[int, int]],

@@ -19,7 +20,7 @@ def make_conv_nd(
    causal=False,
):
    if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,

@@ -41,7 +42,7 @@ def make_conv_nd(
            groups=groups,
            bias=bias,
        )
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,

@@ -71,11 +72,11 @@ def make_linear_nd(
    bias=True,
):
    if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    elif dims == 3 or dims == (2, 1):
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    else:

View File

@@ -342,7 +342,7 @@ class VAE:
            self.latent_dim = 3
            self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
            self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
-            self.upscale_ratio = 8
+            self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
            self.working_dtypes = [torch.bfloat16, torch.float32]
        else:
            logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
@@ -443,7 +443,9 @@
            elif dims == 2:
                pixel_samples = self.decode_tiled_(samples_in)
            elif dims == 3:
-                pixel_samples = self.decode_tiled_3d(samples_in)
+                tile = 256 // self.spacial_compression_decode()
+                overlap = tile // 4
+                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))

        pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
        return pixel_samples
@@ -507,6 +509,12 @@
    def get_sd(self):
        return self.first_stage_model.state_dict()

+    def spacial_compression_decode(self):
+        try:
+            return self.upscale_ratio[-1]
+        except:
+            return self.upscale_ratio


class StyleModel:
    def __init__(self, model, device="cpu"):
        self.model = model
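
Taken together with the `upscale_ratio` change above: this video VAE now stores a per-axis tuple `(temporal_fn, 32, 32)`, so `spacial_compression_decode()` returns the last entry (32), while VAEs that keep a plain integer ratio fall through the `except` branch unchanged. A small worked example of the tiled-decode fallback math from the earlier hunk, under that reading:

```python
# Worked example (sketch, not code from the commit) of the decode_tiled_3d fallback values.
upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
spatial_compression = upscale_ratio[-1]   # 32, what spacial_compression_decode() returns
tile = 256 // spatial_compression         # 8 latent cells per spatial side
overlap = tile // 4                       # 2 latent cells of overlap between tiles
print(tile, overlap)                      # 8 2
```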

View File

@@ -10,7 +10,7 @@ class EmptyLTXVLatentVideo:
    def INPUT_TYPES(s):
        return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                              "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
-                              "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                              "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

View File

@@ -304,7 +304,8 @@ class VAEDecodeTiled:
    def decode(self, vae, samples, tile_size, overlap=64):
        if tile_size < overlap * 4:
            overlap = tile_size // 4
-        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
+        compression = vae.spacial_compression_decode()
+        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
        if len(images.shape) == 5: #Combine batches
            images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
        return (images, )
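
The node now derives its latent tile size from the VAE's own spatial compression instead of the hard-coded 8, so the same pixel-space `tile_size` works for ordinary image VAEs (compression 8) and for the 32x video VAE above. A quick sketch of the resulting numbers; `tile_size=512` is only an illustrative value, `overlap=64` is the default from the signature:

```python
# Sketch of the tile math in VAEDecodeTiled.decode for two compression factors.
tile_size, overlap = 512, 64
for compression in (8, 32):               # image VAE vs. the video VAE above
    print(compression, tile_size // compression, overlap // compression)
# 8  -> tile 64 latent cells per side, overlap 8
# 32 -> tile 16 latent cells per side, overlap 2
```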