Support SVD img2vid model.

This commit is contained in:
comfyanonymous 2023-11-23 19:41:33 -05:00
parent 022033a0e7
commit 871cc20e13
11 changed files with 1030 additions and 100 deletions

View File

@ -54,6 +54,7 @@ class ControlNet(nn.Module):
transformer_depth_output=None,
device=None,
operations=comfy.ops,
**kwargs,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"

View File

@ -5,8 +5,10 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from functools import partial
from .diffusionmodules.util import checkpoint
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
@ -370,21 +372,45 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
super().__init__()
self.ff_in = ff_in or inner_dim is not None
if inner_dim is None:
inner_dim = dim
self.is_res = inner_dim == dim
if self.ff_in:
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention:
if switch_temporal_ca_to_sa:
raise ValueError
else:
self.attn2 = None
else:
context_dim_attn2 = None
if not switch_temporal_ca_to_sa:
context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
@ -418,6 +444,12 @@ class BasicTransformerBlock(nn.Module):
else:
transformer_patches_replace = {}
if self.ff_in:
x_skip = x
x = self.ff_in(self.norm_in(x))
if self.is_res:
x += x_skip
n = self.norm1(x)
if self.disable_self_attn:
context_attn1 = context
@ -465,31 +497,34 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
x = p(x, extra_options)
n = self.norm2(x)
context_attn2 = context
value_attn2 = None
if "attn2_patch" in transformer_patches:
patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2
for p in patch:
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
block_attn2 = transformer_block
if block_attn2 not in attn2_replace_patch:
block_attn2 = block
if block_attn2 in attn2_replace_patch:
if value_attn2 is None:
if self.attn2 is not None:
n = self.norm2(x)
if self.switch_temporal_ca_to_sa:
context_attn2 = n
else:
context_attn2 = context
value_attn2 = None
if "attn2_patch" in transformer_patches:
patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2
n = self.attn2.to_q(n)
context_attn2 = self.attn2.to_k(context_attn2)
value_attn2 = self.attn2.to_v(value_attn2)
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)
for p in patch:
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
block_attn2 = transformer_block
if block_attn2 not in attn2_replace_patch:
block_attn2 = block
if block_attn2 in attn2_replace_patch:
if value_attn2 is None:
value_attn2 = context_attn2
n = self.attn2.to_q(n)
context_attn2 = self.attn2.to_k(context_attn2)
value_attn2 = self.attn2.to_v(value_attn2)
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)
if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
@ -497,7 +532,12 @@ class BasicTransformerBlock(nn.Module):
n = p(n, extra_options)
x += n
x = self.ff(self.norm3(x)) + x
if self.is_res:
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
return x
@ -565,3 +605,164 @@ class SpatialTransformer(nn.Module):
x = self.proj_out(x)
return x + x_in
class SpatialVideoTransformer(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
dtype=None, device=None, operations=comfy.ops
):
super().__init__(
in_channels,
n_heads,
d_head,
depth=depth,
dropout=dropout,
use_checkpoint=checkpoint,
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
dtype=dtype, device=device, operations=operations
)
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
time_mix_d_head = d_head
n_time_mix_heads = n_heads
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
inner_dim = n_heads * d_head
if use_spatial_context:
time_context_dim = context_dim
self.time_stack = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=time_context_dim,
# timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
dtype=dtype, device=device, operations=operations
)
for _ in range(self.depth)
]
)
assert len(self.time_stack) == len(self.transformer_blocks)
self.use_spatial_context = use_spatial_context
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
self.time_pos_embed = nn.Sequential(
operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
)
self.time_mixer = AlphaBlender(
alpha=merge_factor, merge_strategy=merge_strategy
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
spatial_context = context
if self.use_spatial_context:
assert (
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
if time_context is None:
time_context = context
time_context_first_timestep = time_context[::timesteps]
time_context = repeat(
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
)
elif time_context is not None and not self.use_spatial_context:
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
if time_context.ndim == 2:
time_context = rearrange(time_context, "b c -> b 1 c")
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
if self.use_linear:
x = self.proj_in(x)
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
):
transformer_options["block_index"] = it_
x = block(
x,
context=spatial_context,
transformer_options=transformer_options,
)
x_mix = x
x_mix = x_mix + emb
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out

View File

@ -5,6 +5,8 @@ import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from functools import partial
from .util import (
checkpoint,
@ -12,8 +14,9 @@ from .util import (
zero_module,
normalization,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops
@ -29,10 +32,15 @@ class TimestepBlock(nn.Module):
"""
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts:
if isinstance(layer, TimestepBlock):
if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialVideoTransformer):
x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
transformer_options["current_index"] += 1
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
if "current_index" in transformer_options:
@ -145,6 +153,9 @@ class ResBlock(TimestepBlock):
use_checkpoint=False,
up=False,
down=False,
kernel_size=3,
exchange_temb_dims=False,
skip_t_emb=False,
dtype=None,
device=None,
operations=comfy.ops
@ -157,11 +168,17 @@ class ResBlock(TimestepBlock):
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.exchange_temb_dims = exchange_temb_dims
if isinstance(kernel_size, list):
padding = [k // 2 for k in kernel_size]
else:
padding = kernel_size // 2
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
)
self.updown = up or down
@ -175,19 +192,24 @@ class ResBlock(TimestepBlock):
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
operations.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
),
)
self.skip_t_emb = skip_t_emb
if self.skip_t_emb:
self.emb_layers = None
self.exchange_temb_dims = False
else:
self.emb_layers = nn.Sequential(
nn.SiLU(),
operations.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
),
)
self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
),
)
@ -195,7 +217,7 @@ class ResBlock(TimestepBlock):
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = operations.conv_nd(
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
)
else:
self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
@ -221,19 +243,110 @@ class ResBlock(TimestepBlock):
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
emb_out = None
if not self.skip_t_emb:
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_norm(h)
if emb_out is not None:
scale, shift = th.chunk(emb_out, 2, dim=1)
h *= (1 + scale)
h += shift
h = out_rest(h)
else:
h = h + emb_out
if emb_out is not None:
if self.exchange_temb_dims:
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class VideoResBlock(ResBlock):
def __init__(
self,
channels: int,
emb_channels: int,
dropout: float,
video_kernel_size=3,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
out_channels=None,
use_conv: bool = False,
use_scale_shift_norm: bool = False,
dims: int = 2,
use_checkpoint: bool = False,
up: bool = False,
down: bool = False,
dtype=None,
device=None,
operations=comfy.ops
):
super().__init__(
channels,
emb_channels,
dropout,
out_channels=out_channels,
use_conv=use_conv,
use_scale_shift_norm=use_scale_shift_norm,
dims=dims,
use_checkpoint=use_checkpoint,
up=up,
down=down,
dtype=dtype,
device=device,
operations=operations
)
self.time_stack = ResBlock(
default(out_channels, channels),
emb_channels,
dropout=dropout,
dims=3,
out_channels=default(out_channels, channels),
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=use_checkpoint,
exchange_temb_dims=True,
dtype=dtype,
device=device,
operations=operations
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
rearrange_pattern="b t -> b 1 t 1 1",
)
def forward(
self,
x: th.Tensor,
emb: th.Tensor,
num_video_frames: int,
image_only_indicator = None,
) -> th.Tensor:
x = super().forward(x, emb)
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = self.time_stack(
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
)
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class Timestep(nn.Module):
def __init__(self, dim):
super().__init__()
@ -310,6 +423,16 @@ class UNetModel(nn.Module):
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
use_temporal_resblock=False,
use_temporal_attention=False,
time_context_dim=None,
extra_ff_mix_layer=False,
use_spatial_context=False,
merge_strategy=None,
merge_factor=0.0,
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
device=None,
operations=comfy.ops,
):
@ -364,8 +487,12 @@ class UNetModel(nn.Module):
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.use_temporal_resblocks = use_temporal_resblock
self.predict_codebook_ids = n_embed is not None
self.default_num_video_frames = None
self.default_image_only_indicator = None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
@ -402,13 +529,104 @@ class UNetModel(nn.Module):
input_block_chans = [model_channels]
ch = model_channels
ds = 1
def get_attention_layer(
ch,
num_heads,
dim_head,
depth=1,
context_dim=None,
use_checkpoint=False,
disable_self_attn=False,
):
if use_temporal_attention:
return SpatialVideoTransformer(
ch,
num_heads,
dim_head,
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
merge_strategy=merge_strategy,
merge_factor=merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
dtype=self.dtype, device=device, operations=operations
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
def get_resblock(
merge_factor,
merge_strategy,
video_kernel_size,
ch,
time_embed_dim,
dropout,
out_channels,
dims,
use_checkpoint,
use_scale_shift_norm,
down=False,
up=False,
dtype=None,
device=None,
operations=comfy.ops
):
if self.use_temporal_resblocks:
return VideoResBlock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
else:
return ResBlock(
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
use_checkpoint=use_checkpoint,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
@ -435,11 +653,9 @@ class UNetModel(nn.Module):
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(SpatialTransformer(
layers.append(get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
@ -448,10 +664,13 @@ class UNetModel(nn.Module):
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
@ -481,10 +700,14 @@ class UNetModel(nn.Module):
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
mid_block = [
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
@ -493,15 +716,18 @@ class UNetModel(nn.Module):
operations=operations
)]
if transformer_depth_middle >= 0:
mid_block += [SpatialTransformer( # always uses a self-attn
mid_block += [get_attention_layer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
),
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
@ -517,10 +743,13 @@ class UNetModel(nn.Module):
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch + ich,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
@ -548,19 +777,21 @@ class UNetModel(nn.Module):
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
)
)
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
@ -602,6 +833,10 @@ class UNetModel(nn.Module):
transformer_options["current_index"] = 0
transformer_patches = transformer_options.get("patches", {})
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
time_context = kwargs.get("time_context", None)
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
@ -616,7 +851,7 @@ class UNetModel(nn.Module):
h = x.type(self.dtype)
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'input')
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
@ -630,9 +865,10 @@ class UNetModel(nn.Module):
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
hsp = hs.pop()
@ -649,7 +885,7 @@ class UNetModel(nn.Module):
output_shape = hs[-1].shape
else:
output_shape = None
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)

View File

@ -13,11 +13,78 @@ import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from einops import repeat, rearrange
from comfy.ldm.util import instantiate_from_config
import comfy.ops
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
rearrange_pattern: str = "b t -> (b t) 1 1",
):
super().__init__()
self.merge_strategy = merge_strategy
self.rearrange_pattern = rearrange_pattern
assert (
merge_strategy in self.strategies
), f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif (
self.merge_strategy == "learned"
or self.merge_strategy == "learned_with_images"
):
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
if self.merge_strategy == "fixed":
# make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
else:
raise NotImplementedError()
return alpha
def forward(
self,
x_spatial,
x_temporal,
image_only_indicator=None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
)
return x
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (

View File

@ -0,0 +1,244 @@
import functools
from typing import Callable, Iterable, Union
import torch
from einops import rearrange, repeat
import comfy.ops
from .diffusionmodules.model import (
AttnBlock,
Decoder,
ResnetBlock,
)
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
from .attention import BasicTransformerBlock
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
return NewCls
class VideoResBlock(ResnetBlock):
def __init__(
self,
out_channels,
*args,
dropout=0.0,
video_kernel_size=3,
alpha=0.0,
merge_strategy="learned",
**kwargs,
):
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
if video_kernel_size is None:
video_kernel_size = [3, 1, 1]
self.time_stack = ResBlock(
channels=out_channels,
emb_channels=0,
dropout=dropout,
dims=3,
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=False,
skip_t_emb=True,
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, bs):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError()
def forward(self, x, temb, skip_video=False, timesteps=None):
b, c, h, w = x.shape
if timesteps is None:
timesteps = b
x = super().forward(x, temb)
if not skip_video:
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class AE3DConv(torch.nn.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
padding = [int(k // 2) for k in video_kernel_size]
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
padding=padding,
)
def forward(self, input, timesteps=None, skip_video=False):
if timesteps is None:
timesteps = input.shape[0]
x = super().forward(input)
if skip_video:
return x
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_mix_conv(x)
return rearrange(x, "b c t h w -> (b t) c h w")
class AttnVideoBlock(AttnBlock):
def __init__(
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = BasicTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
comfy.ops.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
comfy.ops.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps=None, skip_time_block=False):
if skip_time_block:
return super().forward(x)
if timesteps is None:
timesteps = x.shape[0]
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
def make_time_attn(
in_channels,
attn_type="vanilla",
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
):
return partialclass(
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
)
class Conv2DWrapper(torch.nn.Conv2d):
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
return super().forward(input)
class VideoDecoder(Decoder):
available_time_modes = ["all", "conv-only", "attn-only"]
def __init__(
self,
*args,
video_kernel_size: Union[int, list] = 3,
alpha: float = 0.0,
merge_strategy: str = "learned",
time_mode: str = "conv-only",
**kwargs,
):
self.video_kernel_size = video_kernel_size
self.alpha = alpha
self.merge_strategy = merge_strategy
self.time_mode = time_mode
assert (
self.time_mode in self.available_time_modes
), f"time_mode parameter has to be in {self.available_time_modes}"
if self.time_mode != "attn-only":
kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
if self.time_mode not in ["conv-only", "only-last-conv"]:
kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
if self.time_mode not in ["attn-only", "only-last-conv"]:
kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)
super().__init__(*args, **kwargs)
def get_last_layer(self, skip_time_mix=False, **kwargs):
if self.time_mode == "attn-only":
raise NotImplementedError("TODO")
else:
return (
self.conv_out.time_mix_conv.weight
if not skip_time_mix
else self.conv_out.weight
)

View File

@ -10,17 +10,22 @@ from . import utils
class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
V_PREDICTION_EDM = 3
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
def model_sampling(model_config, model_type):
s = ModelSamplingDiscrete
if model_type == ModelType.EPS:
c = EPS
elif model_type == ModelType.V_PREDICTION:
c = V_PREDICTION
s = ModelSamplingDiscrete
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
class ModelSampling(s, c):
pass
@ -262,3 +267,48 @@ class SDXL(BaseModel):
out.append(self.embedder(torch.Tensor([target_width])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
fps_id = kwargs.get("fps", 6) - 1
motion_bucket_id = kwargs.get("motion_bucket_id", 127)
augmentation = kwargs.get("augmentation_level", 0)
out = []
out.append(self.embedder(torch.Tensor([fps_id])))
out.append(self.embedder(torch.Tensor([motion_bucket_id])))
out.append(self.embedder(torch.Tensor([augmentation])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
def extra_conds(self, **kwargs):
out = {}
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if latent_image is None:
latent_image = torch.zeros_like(noise)
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
if "time_conditioning" in kwargs:
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
return out

View File

@ -24,7 +24,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
return last_transformer_depth, context_dim, use_linear_in_transformer
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None
def detect_unet_config(state_dict, key_prefix, dtype):
@ -57,6 +58,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
context_dim = None
use_linear_in_transformer = False
video_model = False
current_res = 1
count = 0
@ -99,6 +101,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
if context_dim is None:
context_dim = out[1]
use_linear_in_transformer = out[2]
video_model = out[3]
else:
transformer_depth.append(0)
@ -127,6 +130,19 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config["transformer_depth_middle"] = transformer_depth_middle
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
unet_config["context_dim"] = context_dim
if video_model:
unet_config["extra_ff_mix_layer"] = True
unet_config["use_spatial_context"] = True
unet_config["merge_strategy"] = "learned_with_images"
unet_config["merge_factor"] = 0.0
unet_config["video_kernel_size"] = [3, 1, 1]
unet_config["use_temporal_resblock"] = True
unet_config["use_temporal_attention"] = True
else:
unet_config["use_temporal_resblock"] = False
unet_config["use_temporal_attention"] = False
return unet_config
def model_config_from_unet_config(unet_config):

View File

@ -1,7 +1,7 @@
import torch
import numpy as np
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math
class EPS:
def calculate_input(self, sigma, noise):
@ -83,3 +83,47 @@ class ModelSamplingDiscrete(torch.nn.Module):
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
self.sigma_data = 1.0
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max)
def set_sigma_range(self, sigma_min, sigma_max):
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return 0.25 * sigma.log()
def sigma(self, timestep):
return (timestep / 0.25).exp()
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)

View File

@ -159,7 +159,15 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if config is None:
if "taesd_decoder.1.weight" in sd:
if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
decoder_config = encoder_config.copy()
decoder_config["video_kernel_size"] = [3, 1, 1]
decoder_config["alpha"] = 0.0
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD()
else:
#default SD1.x/SD2.x VAE parameters

View File

@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
}
unet_extra_config = {
@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": None,
"use_temporal_attention": False,
}
latent_format = latent_formats.SD15
@ -88,6 +90,7 @@ class SD21UnclipL(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 1536,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@ -100,6 +103,7 @@ class SD21UnclipH(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 2048,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@ -112,6 +116,7 @@ class SDXLRefiner(supported_models_base.BASE):
"context_dim": 1280,
"adm_in_channels": 2560,
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@ -148,7 +153,8 @@ class SDXL(supported_models_base.BASE):
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 10, 10],
"context_dim": 2048,
"adm_in_channels": 2816
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@ -203,8 +209,34 @@ class SSD1B(SDXL):
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 4, 4],
"context_dim": 2048,
"adm_in_channels": 2816
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class SVD_img2vid(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 768,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SVD_img2vid(self, device=device)
return out
def clip_target(self):
return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
models += [SVD_img2vid]

View File

@ -128,6 +128,36 @@ class ModelSamplingDiscrete:
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
model_sampling.set_sigma_range(sigma_min, sigma_max)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class RescaleCFG:
@classmethod
def INPUT_TYPES(s):
@ -169,5 +199,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
"RescaleCFG": RescaleCFG,
}