Clean up the extra_options dict for the transformer patches.

Now everything in transformer_options gets put in extra_options.
comfyanonymous 2023-11-26 03:13:56 -05:00
parent 5b37270d3a
commit 50dc39d6ec
2 changed files with 16 additions and 26 deletions
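
In short: instead of copying a fixed set of keys ("current_index", "block_index", "original_shape", ...) into extra_options one by one, BasicTransformerBlock now forwards every key of transformer_options except "patches" and "patches_replace". A minimal standalone sketch of that dispatch (the option values here are made up for illustration):

    transformer_options = {
        "patches": {},                      # handled separately
        "patches_replace": {},              # handled separately
        "original_shape": [2, 4, 64, 64],   # passed through to extra_options
        "cond_or_uncond": [1, 0],           # passed through to extra_options
    }

    extra_options = {}
    transformer_patches = {}
    transformer_patches_replace = {}
    for k in transformer_options:
        if k == "patches":
            transformer_patches = transformer_options[k]
        elif k == "patches_replace":
            transformer_patches_replace = transformer_options[k]
        else:
            extra_options[k] = transformer_options[k]

    print(sorted(extra_options))  # ['cond_or_uncond', 'original_shape']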

comfy/ldm/modules/attention.py

@@ -430,31 +430,20 @@ class BasicTransformerBlock(nn.Module):
         extra_options = {}
         block = None
         block_index = 0
-        if "current_index" in transformer_options:
-            extra_options["transformer_index"] = transformer_options["current_index"]
-        if "block_index" in transformer_options:
-            block_index = transformer_options["block_index"]
-            extra_options["block_index"] = block_index
-        if "original_shape" in transformer_options:
-            extra_options["original_shape"] = transformer_options["original_shape"]
-        if "block" in transformer_options:
-            block = transformer_options["block"]
-            extra_options["block"] = block
-        if "cond_or_uncond" in transformer_options:
-            extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"]
-        if "patches" in transformer_options:
-            transformer_patches = transformer_options["patches"]
-        else:
-            transformer_patches = {}
+        transformer_patches = {}
+        transformer_patches_replace = {}
+
+        for k in transformer_options:
+            if k == "patches":
+                transformer_patches = transformer_options[k]
+            elif k == "patches_replace":
+                transformer_patches_replace = transformer_options[k]
+            else:
+                extra_options[k] = transformer_options[k]
 
         extra_options["n_heads"] = self.n_heads
         extra_options["dim_head"] = self.d_head
 
-        if "patches_replace" in transformer_options:
-            transformer_patches_replace = transformer_options["patches_replace"]
-        else:
-            transformer_patches_replace = {}
-
         if self.ff_in:
             x_skip = x
             x = self.ff_in(self.norm_in(x))
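
One consequence worth noting: patch callbacks registered under "patches" can now read any option a caller puts into transformer_options straight from extra_options, alongside the always-present "n_heads" and "dim_head" set above. A hypothetical consumer (the name example_attn1_patch and its exact signature are illustrative, not a fixed API):

    # Hypothetical attn1 patch: inspects extra_options instead of relying on
    # BasicTransformerBlock to copy specific keys over one by one.
    def example_attn1_patch(n, context_attn1, value_attn1, extra_options):
        block = extra_options.get("block")   # e.g. ("input", 4), if the caller set it
        heads = extra_options["n_heads"]     # set unconditionally by the block
        return n, context_attn1, value_attn1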

comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -31,7 +31,7 @@ class TimestepBlock(nn.Module):
         Apply the module to `x` given `emb` timestep embeddings.
         """
 
-#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
+#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
 def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
     for layer in ts:
         if isinstance(layer, VideoResBlock):
@@ -40,11 +40,12 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
             x = layer(x, emb)
         elif isinstance(layer, SpatialVideoTransformer):
             x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
-            transformer_options["current_index"] += 1
+            if "transformer_index" in transformer_options:
+                transformer_options["transformer_index"] += 1
         elif isinstance(layer, SpatialTransformer):
             x = layer(x, context, transformer_options)
-            if "current_index" in transformer_options:
-                transformer_options["current_index"] += 1
+            if "transformer_index" in transformer_options:
+                transformer_options["transformer_index"] += 1
         elif isinstance(layer, Upsample):
             x = layer(x, output_shape=output_shape)
         else:
@@ -830,7 +831,7 @@ class UNetModel(nn.Module):
         :return: an [N x C x ...] Tensor of outputs.
         """
         transformer_options["original_shape"] = list(x.shape)
-        transformer_options["current_index"] = 0
+        transformer_options["transformer_index"] = 0
         transformer_patches = transformer_options.get("patches", {})
 
         num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
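
The `in` guard on the renamed counter is deliberate: per the comment above, accelerate can hand a layer a copy of transformer_options, and a bare `+= 1` on a dict that lacks the key raises KeyError. A standalone sketch of the defensive pattern (the helper name is ours, for illustration):

    # Guarded increment: a no-op when the counter key is absent,
    # instead of raising KeyError on a stale/copied options dict.
    def bump_transformer_index(transformer_options):
        if "transformer_index" in transformer_options:
            transformer_options["transformer_index"] += 1

    opts = {"transformer_index": 0}
    bump_transformer_index(opts)   # opts["transformer_index"] == 1
    bump_transformer_index({})     # silently does nothing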