From d9f90965c8601b671faca9c8784b529e389c49aa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 17 Nov 2024 08:19:59 -0500 Subject: [PATCH] Support block replace patches in auraflow. --- comfy/ldm/aura/mmdit.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index cd9a4218..77090372 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -437,7 +437,8 @@ class MMDiT(nn.Module): pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w] return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1]) - def forward(self, x, timestep, context, **kwargs): + def forward(self, x, timestep, context, transformer_options={}, **kwargs): + patches_replace = transformer_options.get("patches_replace", {}) # patchify x, add PE b, c, h, w = x.shape @@ -458,15 +459,36 @@ class MMDiT(nn.Module): global_cond = self.t_embedder(t, x.dtype) # B, D + blocks_replace = patches_replace.get("dit", {}) if len(self.double_layers) > 0: - for layer in self.double_layers: - c, x = layer(c, x, global_cond, **kwargs) + for i, layer in enumerate(self.double_layers): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["txt"], out["img"] = layer(args["txt"], + args["img"], + args["vec"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap}) + c = out["txt"] + x = out["img"] + else: + c, x = layer(c, x, global_cond, **kwargs) if len(self.single_layers) > 0: c_len = c.size(1) cx = torch.cat([c, x], dim=1) - for layer in self.single_layers: - cx = layer(cx, global_cond, **kwargs) + for i, layer in enumerate(self.single_layers): + if ("single_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = layer(args["img"], args["vec"]) + return out + + out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap}) + cx = out["img"] + else: + cx = layer(cx, global_cond, **kwargs) x = cx[:, c_len:]