Controlnet refactor.
This commit is contained in:
parent
97b409cd48
commit
66aaa14001
|
@ -289,7 +289,8 @@ class ControlNet(nn.Module):
|
|||
|
||||
guided_hint = self.input_hint_block(hint, emb, context)
|
||||
|
||||
outs = []
|
||||
out_output = []
|
||||
out_middle = []
|
||||
|
||||
hs = []
|
||||
if self.num_classes is not None:
|
||||
|
@ -304,10 +305,10 @@ class ControlNet(nn.Module):
|
|||
guided_hint = None
|
||||
else:
|
||||
h = module(h, emb, context)
|
||||
outs.append(zero_conv(h, emb, context))
|
||||
out_output.append(zero_conv(h, emb, context))
|
||||
|
||||
h = self.middle_block(h, emb, context)
|
||||
outs.append(self.middle_block_out(h, emb, context))
|
||||
out_middle.append(self.middle_block_out(h, emb, context))
|
||||
|
||||
return outs
|
||||
return {"middle": out_middle, "output": out_output}
|
||||
|
||||
|
|
|
@ -89,27 +89,12 @@ class ControlBase:
|
|||
return self.previous_controlnet.inference_memory_requirements(dtype)
|
||||
return 0
|
||||
|
||||
def control_merge(self, control_input, control_output, control_prev, output_dtype):
|
||||
def control_merge(self, control, control_prev, output_dtype):
|
||||
out = {'input':[], 'middle':[], 'output': []}
|
||||
|
||||
if control_input is not None:
|
||||
for i in range(len(control_input)):
|
||||
key = 'input'
|
||||
x = control_input[i]
|
||||
if x is not None:
|
||||
x *= self.strength
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
out[key].insert(0, x)
|
||||
|
||||
if control_output is not None:
|
||||
for key in control:
|
||||
control_output = control[key]
|
||||
for i in range(len(control_output)):
|
||||
if i == (len(control_output) - 1):
|
||||
key = 'middle'
|
||||
index = 0
|
||||
else:
|
||||
key = 'output'
|
||||
index = i
|
||||
x = control_output[i]
|
||||
if x is not None:
|
||||
if self.global_average_pooling:
|
||||
|
@ -120,6 +105,7 @@ class ControlBase:
|
|||
x = x.to(output_dtype)
|
||||
|
||||
out[key].append(x)
|
||||
|
||||
if control_prev is not None:
|
||||
for x in ['input', 'middle', 'output']:
|
||||
o = out[x]
|
||||
|
@ -182,7 +168,7 @@ class ControlNet(ControlBase):
|
|||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
return self.control_merge(control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
|
@ -490,12 +476,11 @@ class T2IAdapter(ControlBase):
|
|||
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
|
||||
self.t2i_model.cpu()
|
||||
|
||||
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
|
||||
mid = None
|
||||
if self.t2i_model.xl == True:
|
||||
mid = control_input[-1:]
|
||||
control_input = control_input[:-1]
|
||||
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
|
||||
control_input = {}
|
||||
for k in self.control_input:
|
||||
control_input[k] = list(map(lambda a: None if a is None else a.clone(), self.control_input[k]))
|
||||
|
||||
return self.control_merge(control_input, control_prev, x_noisy.dtype)
|
||||
|
||||
def copy(self):
|
||||
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
|
||||
|
|
|
@ -90,4 +90,4 @@ class ControlNet(nn.Module):
|
|||
proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
|
||||
for i, idx in enumerate(self.proj_blocks):
|
||||
proj_outputs[idx] = self.projections[i](x)
|
||||
return proj_outputs
|
||||
return {"input": proj_outputs[::-1]}
|
||||
|
|
|
@ -153,7 +153,13 @@ class Adapter(nn.Module):
|
|||
features.append(None)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
features = features[::-1]
|
||||
|
||||
if self.xl:
|
||||
return {"input": features[1:], "middle": features[:1]}
|
||||
else:
|
||||
return {"input": features}
|
||||
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
|
@ -290,4 +296,4 @@ class Adapter_light(nn.Module):
|
|||
features.append(None)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
return {"input": features[::-1]}
|
||||
|
|
Loading…
Reference in New Issue