Support block replace patches in auraflow.

This commit is contained in:
comfyanonymous 2024-11-17 08:19:59 -05:00
parent 41886af138
commit d9f90965c8
1 changed files with 27 additions and 5 deletions

View File

@ -437,7 +437,8 @@ class MMDiT(nn.Module):
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w] pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1]) return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, **kwargs): def forward(self, x, timestep, context, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
# patchify x, add PE # patchify x, add PE
b, c, h, w = x.shape b, c, h, w = x.shape
@ -458,14 +459,35 @@ class MMDiT(nn.Module):
global_cond = self.t_embedder(t, x.dtype) # B, D global_cond = self.t_embedder(t, x.dtype) # B, D
blocks_replace = patches_replace.get("dit", {})
if len(self.double_layers) > 0: if len(self.double_layers) > 0:
for layer in self.double_layers: for i, layer in enumerate(self.double_layers):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
c, x = layer(c, x, global_cond, **kwargs) c, x = layer(c, x, global_cond, **kwargs)
if len(self.single_layers) > 0: if len(self.single_layers) > 0:
c_len = c.size(1) c_len = c.size(1)
cx = torch.cat([c, x], dim=1) cx = torch.cat([c, x], dim=1)
for layer in self.single_layers: for i, layer in enumerate(self.single_layers):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], args["vec"])
return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
cx = out["img"]
else:
cx = layer(cx, global_cond, **kwargs) cx = layer(cx, global_cond, **kwargs)
x = cx[:, c_len:] x = cx[:, c_len:]