From 22535d05896cf78d84924c492c8cfc17b8786c05 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 20 Nov 2024 07:33:06 -0500
Subject: [PATCH] Skip layer guidance now works on stable audio model.

---
 comfy/ldm/audio/dit.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py
index 4d2185be..5b3f498f 100644
--- a/comfy/ldm/audio/dit.py
+++ b/comfy/ldm/audio/dit.py
@@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module):
         return_info = False,
         **kwargs
     ):
+        patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
         batch, seq, device = *x.shape[:2], x.device
+        context = kwargs["context"]
 
         info = {
             "hidden_states": [],
@@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module):
         if self.use_sinusoidal_emb or self.use_abs_pos_emb:
             x = x + self.pos_emb(x)
 
+        blocks_replace = patches_replace.get("dit", {})
         # Iterate over the transformer layers
-        for layer in self.layers:
-            x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
+        for i, layer in enumerate(self.layers):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
             # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
 
             if return_info:
@@ -874,7 +886,6 @@ class AudioDiffusionTransformer(nn.Module):
                 mask=None,
                 return_info=False,
                 control=None,
-                transformer_options={},
                 **kwargs):
             return self._forward(
                 x,