diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 433381df..6dd99afd 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -5,6 +5,7 @@
 import comfy.utils
 import comfy.model_management
 import comfy.model_detection
 import comfy.model_patcher
+import comfy.ops
 import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
@@ -248,6 +249,15 @@ class ControlLoraOps:
         else:
             raise ValueError(f"unsupported dimensions: {dims}")
 
+    class Conv3d(comfy.ops.Conv3d):
+        pass
+
+    class GroupNorm(comfy.ops.GroupNorm):
+        pass
+
+    class LayerNorm(comfy.ops.LayerNorm):
+        pass
+
 
 class ControlLora(ControlNet):
     def __init__(self, control_weights, global_average_pooling=False, device=None):
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index f6845238..c2b85a69 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -83,16 +83,6 @@ class FeedForward(nn.Module):
     def forward(self, x):
         return self.net(x)
 
-
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
-
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
@@ -414,10 +404,10 @@ class BasicTransformerBlock(nn.Module):
 
             self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
                                 heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
-            self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+            self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
 
-        self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
-        self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
+        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
         self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head
@@ -559,7 +549,7 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim] * depth
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels, dtype=dtype, device=device)
+        self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
         if not use_linear:
             self.proj_in = operations.Conv2d(in_channels,
                                   inner_dim,
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 48264892..855c3d1f 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
             padding = kernel_size // 2
 
         self.in_layers = nn.Sequential(
-            nn.GroupNorm(32, channels, dtype=dtype, device=device),
+            operations.GroupNorm(32, channels, dtype=dtype, device=device),
             nn.SiLU(),
             operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
         )
@@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
             ),
         )
         self.out_layers = nn.Sequential(
-            nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
+            operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
             nn.SiLU(),
             nn.Dropout(p=dropout),
-            zero_module(
-                operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
-            ),
+            operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
+            ,
         )
 
         if self.out_channels == channels:
@@ -810,13 +809,13 @@ class UNetModel(nn.Module):
                 self._feature_size += ch
 
         self.out = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+            operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
             nn.SiLU(),
             zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+            operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
             operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
             #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
         )
diff --git a/comfy/ops.py b/comfy/ops.py
index 0bfb698a..deb849d6 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -13,6 +13,14 @@ class Conv3d(torch.nn.Conv3d):
     def reset_parameters(self):
         return None
 
+class GroupNorm(torch.nn.GroupNorm):
+    def reset_parameters(self):
+        return None
+
+class LayerNorm(torch.nn.LayerNorm):
+    def reset_parameters(self):
+        return None
+
 def conv_nd(dims, *args, **kwargs):
     if dims == 2:
         return Conv2d(*args, **kwargs)
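Note: the comfy.ops classes added in this patch all follow the same pattern as the existing Conv2d/Conv3d there: subclass the corresponding torch.nn layer and override reset_parameters() with a no-op, so the module is constructed without running PyTorch's default weight initialization; the real weights are filled in later when a checkpoint is loaded. A minimal standalone sketch of that pattern (the 320-channel shapes are purely illustrative, not taken from the patch):

import torch

# Same trick as comfy.ops: make reset_parameters() a no-op so construction
# skips default weight init (weights are expected to come from a state_dict).
class GroupNorm(torch.nn.GroupNorm):
    def reset_parameters(self):
        return None

# The weight/bias parameters exist but hold uninitialized memory until loaded.
norm = GroupNorm(num_groups=32, num_channels=320, eps=1e-6, affine=True)
norm.load_state_dict({"weight": torch.ones(320), "bias": torch.zeros(320)})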