Add manual cast to controlnet.
This commit is contained in:
parent
3152023fbc
commit
32b7e7e769
|
@ -141,24 +141,24 @@ class ControlNet(nn.Module):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
|
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
|
||||||
|
|
||||||
self.input_hint_block = TimestepEmbedSequential(
|
self.input_hint_block = TimestepEmbedSequential(
|
||||||
operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
|
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 16, 16, 3, padding=1),
|
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 32, 32, 3, padding=1),
|
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 96, 96, 3, padding=1),
|
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
|
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._feature_size = model_channels
|
self._feature_size = model_channels
|
||||||
|
@ -206,7 +206,7 @@ class ControlNet(nn.Module):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
|
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
input_block_chans.append(ch)
|
input_block_chans.append(ch)
|
||||||
if level != len(channel_mult) - 1:
|
if level != len(channel_mult) - 1:
|
||||||
|
@ -234,7 +234,7 @@ class ControlNet(nn.Module):
|
||||||
)
|
)
|
||||||
ch = out_ch
|
ch = out_ch
|
||||||
input_block_chans.append(ch)
|
input_block_chans.append(ch)
|
||||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
|
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
||||||
ds *= 2
|
ds *= 2
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
|
@ -276,11 +276,11 @@ class ControlNet(nn.Module):
|
||||||
operations=operations
|
operations=operations
|
||||||
)]
|
)]
|
||||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
|
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
def make_zero_conv(self, channels, operations=None):
|
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
|
||||||
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
|
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
|
||||||
|
|
||||||
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||||
|
|
|
@ -36,13 +36,13 @@ class ControlBase:
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.strength = 1.0
|
self.strength = 1.0
|
||||||
self.timestep_percent_range = (0.0, 1.0)
|
self.timestep_percent_range = (0.0, 1.0)
|
||||||
|
self.global_average_pooling = False
|
||||||
self.timestep_range = None
|
self.timestep_range = None
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
self.global_average_pooling = False
|
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
|
@ -77,6 +77,7 @@ class ControlBase:
|
||||||
c.cond_hint_original = self.cond_hint_original
|
c.cond_hint_original = self.cond_hint_original
|
||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
c.timestep_percent_range = self.timestep_percent_range
|
c.timestep_percent_range = self.timestep_percent_range
|
||||||
|
c.global_average_pooling = self.global_average_pooling
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
|
@ -129,12 +130,14 @@ class ControlBase:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
self.load_device = load_device
|
||||||
|
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
self.model_sampling_current = None
|
self.model_sampling_current = None
|
||||||
|
self.manual_cast_dtype = manual_cast_dtype
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
|
@ -149,11 +152,8 @@ class ControlNet(ControlBase):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
dtype = self.control_model.dtype
|
dtype = self.control_model.dtype
|
||||||
if comfy.model_management.supports_dtype(self.device, dtype):
|
if self.manual_cast_dtype is not None:
|
||||||
precision_scope = lambda a: contextlib.nullcontext(a)
|
dtype = self.manual_cast_dtype
|
||||||
else:
|
|
||||||
precision_scope = torch.autocast
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
output_dtype = x_noisy.dtype
|
output_dtype = x_noisy.dtype
|
||||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||||
|
@ -171,12 +171,11 @@ class ControlNet(ControlBase):
|
||||||
timestep = self.model_sampling_current.timestep(t)
|
timestep = self.model_sampling_current.timestep(t)
|
||||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||||
|
|
||||||
with precision_scope(comfy.model_management.get_autocast_device(self.device)):
|
|
||||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||||
return self.control_merge(None, control, control_prev, output_dtype)
|
return self.control_merge(None, control, control_prev, output_dtype)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
@ -207,10 +206,11 @@ class ControlLoraOps:
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||||
if self.up is not None:
|
if self.up is not None:
|
||||||
return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
|
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
class Conv2d(torch.nn.Module):
|
class Conv2d(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -246,10 +246,11 @@ class ControlLoraOps:
|
||||||
|
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||||
if self.up is not None:
|
if self.up is not None:
|
||||||
return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
|
|
||||||
class ControlLora(ControlNet):
|
class ControlLora(ControlNet):
|
||||||
|
@ -263,12 +264,19 @@ class ControlLora(ControlNet):
|
||||||
controlnet_config = model.model_config.unet_config.copy()
|
controlnet_config = model.model_config.unet_config.copy()
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
||||||
|
self.manual_cast_dtype = model.manual_cast_dtype
|
||||||
|
dtype = model.get_dtype()
|
||||||
|
if self.manual_cast_dtype is None:
|
||||||
class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
|
class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
|
||||||
pass
|
pass
|
||||||
|
else:
|
||||||
|
class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
|
||||||
|
pass
|
||||||
|
dtype = self.manual_cast_dtype
|
||||||
|
|
||||||
controlnet_config["operations"] = control_lora_ops
|
controlnet_config["operations"] = control_lora_ops
|
||||||
|
controlnet_config["dtype"] = dtype
|
||||||
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||||
dtype = model.get_dtype()
|
|
||||||
self.control_model.to(dtype)
|
|
||||||
self.control_model.to(comfy.model_management.get_torch_device())
|
self.control_model.to(comfy.model_management.get_torch_device())
|
||||||
diffusion_model = model.diffusion_model
|
diffusion_model = model.diffusion_model
|
||||||
sd = diffusion_model.state_dict()
|
sd = diffusion_model.state_dict()
|
||||||
|
@ -372,6 +380,10 @@ def load_controlnet(ckpt_path, model=None):
|
||||||
if controlnet_config is None:
|
if controlnet_config is None:
|
||||||
unet_dtype = comfy.model_management.unet_dtype()
|
unet_dtype = comfy.model_management.unet_dtype()
|
||||||
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
||||||
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
|
if manual_cast_dtype is not None:
|
||||||
|
controlnet_config["operations"] = comfy.ops.manual_cast
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||||
|
@ -400,14 +412,12 @@ def load_controlnet(ckpt_path, model=None):
|
||||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||||
print(missing, unexpected)
|
print(missing, unexpected)
|
||||||
|
|
||||||
control_model = control_model.to(unet_dtype)
|
|
||||||
|
|
||||||
global_average_pooling = False
|
global_average_pooling = False
|
||||||
filename = os.path.splitext(ckpt_path)[0]
|
filename = os.path.splitext(ckpt_path)[0]
|
||||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||||
global_average_pooling = True
|
global_average_pooling = True
|
||||||
|
|
||||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
class T2IAdapter(ControlBase):
|
class T2IAdapter(ControlBase):
|
||||||
|
|
Loading…
Reference in New Issue