All the unet ops with weights are now handled by comfy.ops

This commit is contained in:
comfyanonymous 2023-12-04 03:12:18 -05:00
parent 6efe561c2a
commit af365e4dd1
4 changed files with 28 additions and 21 deletions

View File

@ -5,6 +5,7 @@ import comfy.utils
import comfy.model_management
import comfy.model_detection
import comfy.model_patcher
import comfy.ops
import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
@ -248,6 +249,15 @@ class ControlLoraOps:
else:
raise ValueError(f"unsupported dimensions: {dims}")
class Conv3d(comfy.ops.Conv3d):
pass
class GroupNorm(comfy.ops.GroupNorm):
pass
class LayerNorm(comfy.ops.LayerNorm):
pass
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None):

View File

@ -83,16 +83,6 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
@ -414,10 +404,10 @@ class BasicTransformerBlock(nn.Module):
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
@ -559,7 +549,7 @@ class SpatialTransformer(nn.Module):
context_dim = [context_dim] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels, dtype=dtype, device=device)
self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
if not use_linear:
self.proj_in = operations.Conv2d(in_channels,
inner_dim,

View File

@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
padding = kernel_size // 2
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device),
operations.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
)
@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
),
)
self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
),
,
)
if self.out_channels == channels:
@ -810,13 +809,13 @@ class UNetModel(nn.Module):
self._feature_size += ch
self.out = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(),
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)

View File

@ -13,6 +13,14 @@ class Conv3d(torch.nn.Conv3d):
def reset_parameters(self):
return None
class GroupNorm(torch.nn.GroupNorm):
def reset_parameters(self):
return None
class LayerNorm(torch.nn.LayerNorm):
def reset_parameters(self):
return None
def conv_nd(dims, *args, **kwargs):
if dims == 2:
return Conv2d(*args, **kwargs)