diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 87ed0995..f49cef95 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -415,13 +415,15 @@ class LTXVModel(torch.nn.Module): self.patchifier = SymmetricPatchifier(1) - def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, **kwargs): + def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs): + patches_replace = transformer_options.get("patches_replace", {}) + indices_grid = self.patchifier.get_grid( orig_num_frames=x.shape[2], orig_height=x.shape[3], orig_width=x.shape[4], batch_size=x.shape[0], - scale_grid=((1 / frame_rate) * 8, 32, 32), #TODO: controlable frame rate + scale_grid=((1 / frame_rate) * 8, 32, 32), device=x.device, ) @@ -468,14 +470,24 @@ class LTXVModel(torch.nn.Module): batch_size, -1, x.shape[-1] ) - for block in self.transformer_blocks: - x = block( - x, - context=context, - attention_mask=attention_mask, - timestep=timestep, - pe=pe - ) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) + return out + + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block( + x, + context=context, + attention_mask=attention_mask, + timestep=timestep, + pe=pe + ) # 3. Output scale_shift_values = (