Merge branch 'master' into patch_hooks_improved_memory

commit 26ccd3b5f9

 README.md | 58
@@ -75,37 +75,37 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
 | Keybind | Explanation |
 |------------------------------------|--------------------------------------------------------------------------------------------------------------------|
-| Ctrl + Enter | Queue up current graph for generation |
-| Ctrl + Shift + Enter | Queue up current graph as first for generation |
-| Ctrl + Alt + Enter | Cancel current generation |
-| Ctrl + Z/Ctrl + Y | Undo/Redo |
-| Ctrl + S | Save workflow |
-| Ctrl + O | Load workflow |
-| Ctrl + A | Select all nodes |
-| Alt + C | Collapse/uncollapse selected nodes |
-| Ctrl + M | Mute/unmute selected nodes |
-| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
-| Delete/Backspace | Delete selected nodes |
-| Ctrl + Backspace | Delete the current graph |
-| Space | Move the canvas around when held and moving the cursor |
-| Ctrl/Shift + Click | Add clicked node to selection |
-| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
-| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
-| Shift + Drag | Move multiple selected nodes at the same time |
-| Ctrl + D | Load default graph |
-| Alt + `+` | Canvas Zoom in |
-| Alt + `-` | Canvas Zoom out |
-| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
-| P | Pin/Unpin selected nodes |
-| Ctrl + G | Group selected nodes |
-| Q | Toggle visibility of the queue |
-| H | Toggle visibility of history |
-| R | Refresh graph |
+| `Ctrl` + `Enter` | Queue up current graph for generation |
+| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
+| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
+| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
+| `Ctrl` + `S` | Save workflow |
+| `Ctrl` + `O` | Load workflow |
+| `Ctrl` + `A` | Select all nodes |
+| `Alt` + `C` | Collapse/uncollapse selected nodes |
+| `Ctrl` + `M` | Mute/unmute selected nodes |
+| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
+| `Delete`/`Backspace` | Delete selected nodes |
+| `Ctrl` + `Backspace` | Delete the current graph |
+| `Space` | Move the canvas around when held and moving the cursor |
+| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
+| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
+| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
+| `Shift` + `Drag` | Move multiple selected nodes at the same time |
+| `Ctrl` + `D` | Load default graph |
+| `Alt` + `+` | Canvas Zoom in |
+| `Alt` + `-` | Canvas Zoom out |
+| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
+| `P` | Pin/Unpin selected nodes |
+| `Ctrl` + `G` | Group selected nodes |
+| `Q` | Toggle visibility of the queue |
+| `H` | Toggle visibility of history |
+| `R` | Refresh graph |
 | Double-Click LMB | Open node quick search palette |
-| Shift + Drag | Move multiple wires at once |
-| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
+| `Shift` + Drag | Move multiple wires at once |
+| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |

-Ctrl can also be replaced with Cmd instead for macOS users
+`Ctrl` can also be replaced with `Cmd` instead for macOS users

 # Installing
@@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
         style=None,
         return_dict=False,
         control=None,
-        transformer_options=None,
+        transformer_options={},
     ):
         """
         Forward pass of the encoder.
@@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module):
         return_dict: bool
             Whether to return a dictionary.
         """
-        #import pdb
-        #pdb.set_trace()
+        patches_replace = transformer_options.get("patches_replace", {})
         encoder_hidden_states = context
         text_states = encoder_hidden_states # 2,77,1024
         text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module):
         # Concatenate all extra vectors
         c = t + self.extra_embedder(extra_vec)  # [B, D]

+        blocks_replace = patches_replace.get("dit", {})
+
         controls = None
         if control:
             controls = control.get("output", None)
@@ -375,9 +376,20 @@ class HunYuanDiT(nn.Module):
                     skip = skips.pop() + controls.pop().to(dtype=x.dtype)
                 else:
                     skip = skips.pop()
-                x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
             else:
-                x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
+                skip = None
+
+            if ("double_block", layer) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
+                    return out
+
+                out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
+

             if layer < (self.depth // 2 - 1):
                 skips.append(x)
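Note: the `blocks_replace` mechanism above gives patches a hook around each transformer block. The calling convention is visible in the hunk: the registered callable receives a dict of tensors plus `{"original_block": block_wrap}` and returns a dict whose `"img"` entry replaces the block output. A minimal sketch of a patch written against that convention (the function name and the scaling tweak are hypothetical):

    # Hypothetical patch registered under transformer_options["patches_replace"]["dit"],
    # keyed by ("double_block", layer_index) to match the lookup in the forward pass.
    def my_double_block_patch(args, extra):
        out = extra["original_block"](args)  # run the unpatched block via block_wrap
        out["img"] = out["img"] * 1.05       # hypothetical post-processing of the output
        return out

    patches_replace = {"dit": {("double_block", 3): my_double_block_patch}}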
@@ -304,7 +304,7 @@ class BasicTransformerBlock(nn.Module):
         self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))

     def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None] + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

         x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
@@ -415,13 +415,15 @@ class LTXVModel(torch.nn.Module):

         self.patchifier = SymmetricPatchifier(1)

-    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, **kwargs):
+    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
+        patches_replace = transformer_options.get("patches_replace", {})
+
         indices_grid = self.patchifier.get_grid(
             orig_num_frames=x.shape[2],
             orig_height=x.shape[3],
             orig_width=x.shape[4],
             batch_size=x.shape[0],
-            scale_grid=((1 / frame_rate) * 8, 32, 32), #TODO: controlable frame rate
+            scale_grid=((1 / frame_rate) * 8, 32, 32),
             device=x.device,
         )
@@ -468,18 +470,28 @@ class LTXVModel(torch.nn.Module):
             batch_size, -1, x.shape[-1]
         )

-        for block in self.transformer_blocks:
-            x = block(
-                x,
-                context=context,
-                attention_mask=attention_mask,
-                timestep=timestep,
-                pe=pe
-            )
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.transformer_blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(
+                    x,
+                    context=context,
+                    attention_mask=attention_mask,
+                    timestep=timestep,
+                    pe=pe
+                )

         # 3. Output
         scale_shift_values = (
-            self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
         )
         shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
         x = self.norm_out(x)
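Note: the `.to(device=x.device, dtype=x.dtype)` casts added here and in BasicTransformerBlock follow a common offloading pattern: a parameter or buffer can still sit in fp32, possibly on another device, while activations run elsewhere in a lower precision, so it must be cast to match before use. A minimal standalone illustration (not ComfyUI code):

    import torch
    import torch.nn as nn

    table = nn.Parameter(torch.zeros(6, 16))            # fp32, wherever the module lives
    x = torch.randn(2, 128, 16, dtype=torch.bfloat16)   # activations may differ in dtype/device

    # Cast the parameter to match the activation before broadcasting against it.
    shifted = table[None, None].to(device=x.device, dtype=x.dtype) + x[:, :, None]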
@@ -2,6 +2,8 @@ from typing import Tuple, Union

 import torch
 import torch.nn as nn
+import comfy.ops
+ops = comfy.ops.disable_weight_init


 class CausalConv3d(nn.Module):
@@ -29,7 +31,7 @@ class CausalConv3d(nn.Module):
         width_pad = kernel_size[2] // 2
         padding = (0, height_pad, width_pad)

-        self.conv = nn.Conv3d(
+        self.conv = ops.Conv3d(
             in_channels,
             out_channels,
             kernel_size,
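Note: swapping `nn.Conv3d` for `ops.Conv3d` routes layer creation through `comfy.ops.disable_weight_init`, whose layers are presumably drop-in replacements that skip the default weight initialization, the usual motivation being that the weights are overwritten when the checkpoint loads. A sketch of the pattern, assuming only the API already used in this diff:

    import comfy.ops
    ops = comfy.ops.disable_weight_init

    # Same constructor signature as torch.nn.Conv3d, but without the
    # (wasted) default initialization pass over the weights.
    conv = ops.Conv3d(in_channels=4, out_channels=8, kernel_size=3)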
@@ -628,10 +628,10 @@ class processor(nn.Module):
         self.register_buffer("channel", torch.empty(128))

     def un_normalize(self, x):
-        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)
+        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)

     def normalize(self, x):
-        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)
+        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)

 class VideoVAE(nn.Module):
     def __init__(self):
@@ -4,7 +4,8 @@ import torch

 from .dual_conv3d import DualConv3d
 from .causal_conv3d import CausalConv3d
-
+import comfy.ops
+ops = comfy.ops.disable_weight_init

 def make_conv_nd(
     dims: Union[int, Tuple[int, int]],
@@ -19,7 +20,7 @@ def make_conv_nd(
     causal=False,
 ):
     if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -41,7 +42,7 @@ def make_conv_nd(
             groups=groups,
             bias=bias,
         )
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -71,11 +72,11 @@ def make_linear_nd(
     bias=True,
 ):
     if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
         )
     elif dims == 3 or dims == (2, 1):
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
         )
     else:
 comfy/sd.py | 12
@@ -342,7 +342,7 @@ class VAE:
             self.latent_dim = 3
             self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
             self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
-            self.upscale_ratio = 8
+            self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
             self.working_dtypes = [torch.bfloat16, torch.float32]
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
@@ -443,7 +443,9 @@ class VAE:
         elif dims == 2:
             pixel_samples = self.decode_tiled_(samples_in)
         elif dims == 3:
-            pixel_samples = self.decode_tiled_3d(samples_in)
+            tile = 256 // self.spacial_compression_decode()
+            overlap = tile // 4
+            pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))

         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
@@ -507,6 +509,12 @@ class VAE:
     def get_sd(self):
         return self.first_stage_model.state_dict()

+    def spacial_compression_decode(self):
+        try:
+            return self.upscale_ratio[-1]
+        except:
+            return self.upscale_ratio
+
 class StyleModel:
     def __init__(self, model, device="cpu"):
         self.model = model
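Note: with this change `upscale_ratio` is either a plain int (image VAEs) or a tuple whose trailing entries are the spatial factors; the LTXV video VAE above uses `(lambda a: max(0, a * 8 - 7), 32, 32)`, and `spacial_compression_decode()` picks the spatial factor off the end. A worked example of the tile sizing this yields:

    # For the LTXV video VAE above: upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
    compression = 32              # upscale_ratio[-1]
    tile = 256 // compression     # -> 8 latent cells per tile (256 output pixels)
    overlap = tile // 4           # -> 2 latent cells of overlap between tiles

    # For an image VAE where upscale_ratio is the int 8, indexing raises TypeError,
    # so spacial_compression_decode() falls back to returning 8 itself.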
@@ -10,7 +10,7 @@ class EmptyLTXVLatentVideo:
     def INPUT_TYPES(s):
         return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                               "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
-                              "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                              "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
                               "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"
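Note: lowering `min` from 9 to 1 permits single-frame (still image) latents. This matches the temporal mapping `pixels = 8 * latent - 7` implied by the `upscale_ratio` lambda added in comfy/sd.py. A sketch of the arithmetic (the helper is illustrative, not part of the diff):

    def latent_frames(pixel_length: int) -> int:
        # Inverse of pixels = 8 * latent - 7 (from the upscale_ratio lambda above).
        return (pixel_length - 1) // 8 + 1

    assert latent_frames(97) == 13   # the node's default length
    assert latent_frames(1) == 1     # now allowed with min=1: a single frame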
 nodes.py | 3
@@ -304,7 +304,8 @@ class VAEDecodeTiled:
     def decode(self, vae, samples, tile_size, overlap=64):
         if tile_size < overlap * 4:
             overlap = tile_size // 4
-        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
+        compression = vae.spacial_compression_decode()
+        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
         if len(images.shape) == 5: #Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
         return (images, )
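Note: the node now asks the VAE for its own spatial compression instead of hard-coding the SD factor of 8, so the same pixel-space `tile_size` maps to different latent tile sizes per VAE. Illustrative values:

    # tile_size and overlap are in output pixels; tiles are sized in latent cells.
    tile_size, overlap = 512, 64
    for compression in (8, 32):          # SD-style image VAE vs. LTXV video VAE
        print(tile_size // compression,  # tile_x / tile_y: 64 vs. 16
              overlap // compression)    # overlap:          8 vs.  2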