Support block replace patches in auraflow.
This commit is contained in:
parent
41886af138
commit
d9f90965c8
|
@ -437,7 +437,8 @@ class MMDiT(nn.Module):
|
||||||
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
|
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
|
||||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||||
|
|
||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
# patchify x, add PE
|
# patchify x, add PE
|
||||||
b, c, h, w = x.shape
|
b, c, h, w = x.shape
|
||||||
|
|
||||||
|
@ -458,15 +459,36 @@ class MMDiT(nn.Module):
|
||||||
|
|
||||||
global_cond = self.t_embedder(t, x.dtype) # B, D
|
global_cond = self.t_embedder(t, x.dtype) # B, D
|
||||||
|
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
if len(self.double_layers) > 0:
|
if len(self.double_layers) > 0:
|
||||||
for layer in self.double_layers:
|
for i, layer in enumerate(self.double_layers):
|
||||||
c, x = layer(c, x, global_cond, **kwargs)
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["txt"], out["img"] = layer(args["txt"],
|
||||||
|
args["img"],
|
||||||
|
args["vec"])
|
||||||
|
return out
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
|
||||||
|
c = out["txt"]
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
c, x = layer(c, x, global_cond, **kwargs)
|
||||||
|
|
||||||
if len(self.single_layers) > 0:
|
if len(self.single_layers) > 0:
|
||||||
c_len = c.size(1)
|
c_len = c.size(1)
|
||||||
cx = torch.cat([c, x], dim=1)
|
cx = torch.cat([c, x], dim=1)
|
||||||
for layer in self.single_layers:
|
for i, layer in enumerate(self.single_layers):
|
||||||
cx = layer(cx, global_cond, **kwargs)
|
if ("single_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = layer(args["img"], args["vec"])
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
|
||||||
|
cx = out["img"]
|
||||||
|
else:
|
||||||
|
cx = layer(cx, global_cond, **kwargs)
|
||||||
|
|
||||||
x = cx[:, c_len:]
|
x = cx[:, c_len:]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue