All the unet ops with weights are now handled by comfy.ops
This commit is contained in:
parent
6efe561c2a
commit
af365e4dd1
|
@ -5,6 +5,7 @@ import comfy.utils
|
|||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
import comfy.model_patcher
|
||||
import comfy.ops
|
||||
|
||||
import comfy.cldm.cldm
|
||||
import comfy.t2i_adapter.adapter
|
||||
|
@ -248,6 +249,15 @@ class ControlLoraOps:
|
|||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
class Conv3d(comfy.ops.Conv3d):
|
||||
pass
|
||||
|
||||
class GroupNorm(comfy.ops.GroupNorm):
|
||||
pass
|
||||
|
||||
class LayerNorm(comfy.ops.LayerNorm):
|
||||
pass
|
||||
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
||||
|
|
|
@ -83,16 +83,6 @@ class FeedForward(nn.Module):
|
|||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels, dtype=None, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
|
@ -414,10 +404,10 @@ class BasicTransformerBlock(nn.Module):
|
|||
|
||||
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.checkpoint = checkpoint
|
||||
self.n_heads = n_heads
|
||||
self.d_head = d_head
|
||||
|
@ -559,7 +549,7 @@ class SpatialTransformer(nn.Module):
|
|||
context_dim = [context_dim] * depth
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
if not use_linear:
|
||||
self.proj_in = operations.Conv2d(in_channels,
|
||||
inner_dim,
|
||||
|
|
|
@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
|
|||
padding = kernel_size // 2
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, channels, dtype=dtype, device=device),
|
||||
operations.GroupNorm(32, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
|
||||
)
|
||||
|
@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
|
|||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
||||
operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
|
||||
),
|
||||
,
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
|
@ -810,13 +809,13 @@ class UNetModel(nn.Module):
|
|||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
|
|
@ -13,6 +13,14 @@ class Conv3d(torch.nn.Conv3d):
|
|||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class GroupNorm(torch.nn.GroupNorm):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
return Conv2d(*args, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue