All the unet ops with weights are now handled by comfy.ops
This commit is contained in:
parent
6efe561c2a
commit
af365e4dd1
|
@ -5,6 +5,7 @@ import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_detection
|
import comfy.model_detection
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
import comfy.cldm.cldm
|
import comfy.cldm.cldm
|
||||||
import comfy.t2i_adapter.adapter
|
import comfy.t2i_adapter.adapter
|
||||||
|
@ -248,6 +249,15 @@ class ControlLoraOps:
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unsupported dimensions: {dims}")
|
raise ValueError(f"unsupported dimensions: {dims}")
|
||||||
|
|
||||||
|
class Conv3d(comfy.ops.Conv3d):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class GroupNorm(comfy.ops.GroupNorm):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class LayerNorm(comfy.ops.LayerNorm):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ControlLora(ControlNet):
|
class ControlLora(ControlNet):
|
||||||
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
||||||
|
|
|
@ -83,16 +83,6 @@ class FeedForward(nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
def zero_module(module):
|
|
||||||
"""
|
|
||||||
Zero out the parameters of a module and return it.
|
|
||||||
"""
|
|
||||||
for p in module.parameters():
|
|
||||||
p.detach().zero_()
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def Normalize(in_channels, dtype=None, device=None):
|
def Normalize(in_channels, dtype=None, device=None):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
@ -414,10 +404,10 @@ class BasicTransformerBlock(nn.Module):
|
||||||
|
|
||||||
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
|
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
|
||||||
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||||
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||||
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||||
self.checkpoint = checkpoint
|
self.checkpoint = checkpoint
|
||||||
self.n_heads = n_heads
|
self.n_heads = n_heads
|
||||||
self.d_head = d_head
|
self.d_head = d_head
|
||||||
|
@ -559,7 +549,7 @@ class SpatialTransformer(nn.Module):
|
||||||
context_dim = [context_dim] * depth
|
context_dim = [context_dim] * depth
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
inner_dim = n_heads * d_head
|
inner_dim = n_heads * d_head
|
||||||
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||||
if not use_linear:
|
if not use_linear:
|
||||||
self.proj_in = operations.Conv2d(in_channels,
|
self.proj_in = operations.Conv2d(in_channels,
|
||||||
inner_dim,
|
inner_dim,
|
||||||
|
|
|
@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
|
||||||
padding = kernel_size // 2
|
padding = kernel_size // 2
|
||||||
|
|
||||||
self.in_layers = nn.Sequential(
|
self.in_layers = nn.Sequential(
|
||||||
nn.GroupNorm(32, channels, dtype=dtype, device=device),
|
operations.GroupNorm(32, channels, dtype=dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
|
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
|
||||||
)
|
)
|
||||||
|
@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.out_layers = nn.Sequential(
|
self.out_layers = nn.Sequential(
|
||||||
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Dropout(p=dropout),
|
nn.Dropout(p=dropout),
|
||||||
zero_module(
|
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
|
||||||
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
|
,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.out_channels == channels:
|
if self.out_channels == channels:
|
||||||
|
@ -810,13 +809,13 @@ class UNetModel(nn.Module):
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
self.out = nn.Sequential(
|
self.out = nn.Sequential(
|
||||||
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
|
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
|
||||||
)
|
)
|
||||||
if self.predict_codebook_ids:
|
if self.predict_codebook_ids:
|
||||||
self.id_predictor = nn.Sequential(
|
self.id_predictor = nn.Sequential(
|
||||||
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||||
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
|
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
|
||||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,6 +13,14 @@ class Conv3d(torch.nn.Conv3d):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
class GroupNorm(torch.nn.GroupNorm):
|
||||||
|
def reset_parameters(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class LayerNorm(torch.nn.LayerNorm):
|
||||||
|
def reset_parameters(self):
|
||||||
|
return None
|
||||||
|
|
||||||
def conv_nd(dims, *args, **kwargs):
|
def conv_nd(dims, *args, **kwargs):
|
||||||
if dims == 2:
|
if dims == 2:
|
||||||
return Conv2d(*args, **kwargs)
|
return Conv2d(*args, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue