Better Mac support on flux model.

This commit is contained in:
comfyanonymous 2024-08-01 12:55:28 -04:00
parent 1aa9cf3292
commit f2b80f95d2
2 changed files with 8 additions and 6 deletions

View File

@ -13,9 +13,9 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=pos.device)
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.einsum("...n,d->...nd", pos.float(), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()

View File

@ -126,11 +126,13 @@ class Flux(nn.Module):
bs, c, h, w = x.shape
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
img_ids = torch.zeros((h // 2, w // 2, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=x.device, dtype=x.dtype)[None, :]
h_len = (h // 2)
w_len = (w // 2)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h // 2, w=w // 2, ph=2, pw=2)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)