Fix control loras breaking.
This commit is contained in:
parent
db8b59ecff
commit
448d9263a2
|
@ -201,7 +201,7 @@ class ControlNet(ControlBase):
|
|||
super().cleanup()
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module):
|
||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
|
@ -220,7 +220,7 @@ class ControlLoraOps:
|
|||
else:
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
class Conv2d(torch.nn.Module):
|
||||
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
|
|
40
comfy/ops.py
40
comfy/ops.py
|
@ -31,13 +31,13 @@ def cast_bias_weight(s, input):
|
|||
weight = s.weight_function(weight)
|
||||
return weight, bias
|
||||
|
||||
class CastWeightBiasOp:
|
||||
comfy_cast_weights = False
|
||||
weight_function = None
|
||||
bias_function = None
|
||||
|
||||
class disable_weight_init:
|
||||
class Linear(torch.nn.Linear):
|
||||
comfy_cast_weights = False
|
||||
weight_function = None
|
||||
bias_function = None
|
||||
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
|
@ -51,11 +51,7 @@ class disable_weight_init:
|
|||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
comfy_cast_weights = False
|
||||
weight_function = None
|
||||
bias_function = None
|
||||
|
||||
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
|
@ -69,11 +65,7 @@ class disable_weight_init:
|
|||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
comfy_cast_weights = False
|
||||
weight_function = None
|
||||
bias_function = None
|
||||
|
||||
class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
|
@ -87,11 +79,7 @@ class disable_weight_init:
|
|||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class GroupNorm(torch.nn.GroupNorm):
|
||||
comfy_cast_weights = False
|
||||
weight_function = None
|
||||
bias_function = None
|
||||
|
||||
class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
|
@ -106,11 +94,7 @@ class disable_weight_init:
|
|||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
comfy_cast_weights = False
|
||||
weight_function = None
|
||||
bias_function = None
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
|
@ -128,11 +112,7 @@ class disable_weight_init:
|
|||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class ConvTranspose2d(torch.nn.ConvTranspose2d):
|
||||
comfy_cast_weights = False
|
||||
weight_function = None
|
||||
bias_function = None
|
||||
|
||||
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
|
|
Loading…
Reference in New Issue