diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index bb5e02b6..99f49810 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -8,9 +8,8 @@ from torch import Tensor, nn from .math import attention, rope import comfy.ops - class EmbedND(nn.Module): - def __init__(self, dim: int, theta: int, axes_dim: list[int]): + def __init__(self, dim: int, theta: int, axes_dim: list): super().__init__() self.dim = dim self.theta = theta @@ -79,7 +78,7 @@ class QKNorm(torch.nn.Module): self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) - def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: q = self.query_norm(q) k = self.key_norm(k) return q.to(v), k.to(v) @@ -118,7 +117,7 @@ class Modulation(nn.Module): self.multiplier = 6 if double else 3 self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) - def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + def forward(self, vec: Tensor) -> tuple: out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) return ( @@ -156,7 +155,7 @@ class DoubleStreamBlock(nn.Module): operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -203,7 +202,7 @@ class SingleStreamBlock(nn.Module): hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, - qk_scale: float | None = None, + qk_scale: float = None, dtype=None, device=None, operations=None diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index e4ef624e..d9bb568a 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -21,7 +21,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: return out.float() -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index f77834c1..ae34052d 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -26,7 +26,7 @@ class FluxParams: num_heads: int depth: int depth_single_blocks: int - axes_dim: list[int] + axes_dim: list theta: int qkv_bias: bool guidance_embed: bool @@ -92,7 +92,7 @@ class Flux(nn.Module): txt_ids: Tensor, timesteps: Tensor, y: Tensor, - guidance: Tensor | None = None, + guidance: Tensor = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.")