Skip layer guidance now works on stable audio model.
This commit is contained in:
parent
898615122f
commit
22535d0589
|
@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module):
|
||||||
return_info = False,
|
return_info = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||||
batch, seq, device = *x.shape[:2], x.device
|
batch, seq, device = *x.shape[:2], x.device
|
||||||
|
context = kwargs["context"]
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
"hidden_states": [],
|
"hidden_states": [],
|
||||||
|
@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module):
|
||||||
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||||
x = x + self.pos_emb(x)
|
x = x + self.pos_emb(x)
|
||||||
|
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
# Iterate over the transformer layers
|
# Iterate over the transformer layers
|
||||||
for layer in self.layers:
|
for i, layer in enumerate(self.layers):
|
||||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||||
|
|
||||||
if return_info:
|
if return_info:
|
||||||
|
@ -874,7 +886,6 @@ class AudioDiffusionTransformer(nn.Module):
|
||||||
mask=None,
|
mask=None,
|
||||||
return_info=False,
|
return_info=False,
|
||||||
control=None,
|
control=None,
|
||||||
transformer_options={},
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
return self._forward(
|
return self._forward(
|
||||||
x,
|
x,
|
||||||
|
|
Loading…
Reference in New Issue