Controlnet refactor.

This commit is contained in:
comfyanonymous 2024-06-25 17:02:05 -04:00
parent 97b409cd48
commit 66aaa14001
4 changed files with 24 additions and 32 deletions

View File

@ -289,7 +289,8 @@ class ControlNet(nn.Module):
guided_hint = self.input_hint_block(hint, emb, context)
outs = []
out_output = []
out_middle = []
hs = []
if self.num_classes is not None:
@ -304,10 +305,10 @@ class ControlNet(nn.Module):
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context))
out_output.append(zero_conv(h, emb, context))
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
out_middle.append(self.middle_block_out(h, emb, context))
return outs
return {"middle": out_middle, "output": out_output}

View File

@ -89,27 +89,12 @@ class ControlBase:
return self.previous_controlnet.inference_memory_requirements(dtype)
return 0
def control_merge(self, control_input, control_output, control_prev, output_dtype):
def control_merge(self, control, control_prev, output_dtype):
out = {'input':[], 'middle':[], 'output': []}
if control_input is not None:
for i in range(len(control_input)):
key = 'input'
x = control_input[i]
if x is not None:
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].insert(0, x)
if control_output is not None:
for key in control:
control_output = control[key]
for i in range(len(control_output)):
if i == (len(control_output) - 1):
key = 'middle'
index = 0
else:
key = 'output'
index = i
x = control_output[i]
if x is not None:
if self.global_average_pooling:
@ -120,6 +105,7 @@ class ControlBase:
x = x.to(output_dtype)
out[key].append(x)
if control_prev is not None:
for x in ['input', 'middle', 'output']:
o = out[x]
@ -182,7 +168,7 @@ class ControlNet(ControlBase):
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
return self.control_merge(control, control_prev, output_dtype)
def copy(self):
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@ -490,12 +476,11 @@ class T2IAdapter(ControlBase):
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
self.t2i_model.cpu()
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
mid = None
if self.t2i_model.xl == True:
mid = control_input[-1:]
control_input = control_input[:-1]
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
control_input = {}
for k in self.control_input:
control_input[k] = list(map(lambda a: None if a is None else a.clone(), self.control_input[k]))
return self.control_merge(control_input, control_prev, x_noisy.dtype)
def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)

View File

@ -90,4 +90,4 @@ class ControlNet(nn.Module):
proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
for i, idx in enumerate(self.proj_blocks):
proj_outputs[idx] = self.projections[i](x)
return proj_outputs
return {"input": proj_outputs[::-1]}

View File

@ -153,7 +153,13 @@ class Adapter(nn.Module):
features.append(None)
features.append(x)
return features
features = features[::-1]
if self.xl:
return {"input": features[1:], "middle": features[:1]}
else:
return {"input": features}
class LayerNorm(nn.LayerNorm):
@ -290,4 +296,4 @@ class Adapter_light(nn.Module):
features.append(None)
features.append(x)
return features
return {"input": features[::-1]}