diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py
new file mode 100644
index 00000000..c465619b
--- /dev/null
+++ b/comfy/ldm/aura/mmdit.py
@@ -0,0 +1,479 @@
+#AuraFlow MMDiT
+#Originally written by the AuraFlow Authors
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from comfy.ldm.modules.attention import optimized_attention
+
+
+def modulate(x, shift, scale):
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
+        super().__init__()
+        if hidden_dim is None:
+            hidden_dim = 4 * dim
+
+        n_hidden = int(2 * hidden_dim / 3)
+        n_hidden = find_multiple(n_hidden, 256)
+
+        self.c_fc1 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
+        self.c_fc2 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
+        self.c_proj = operations.Linear(n_hidden, dim, bias=False, dtype=dtype, device=device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # SwiGLU-style gated MLP: SiLU(W1 x) * W2 x, then project back down.
+        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
+        x = self.c_proj(x)
+        return x
+
+
+class MultiHeadLayerNorm(nn.Module):
+    def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
+        # Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        mean = hidden_states.mean(-1, keepdim=True)
+        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+        hidden_states = (hidden_states - mean) * torch.rsqrt(
+            variance + self.variance_epsilon
+        )
+        hidden_states = self.weight.to(torch.float32) * hidden_states
+        return hidden_states.to(input_dtype)
+
+
+class SingleAttention(nn.Module):
+    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
+        super().__init__()
+
+        self.n_heads = n_heads
+        self.head_dim = dim // n_heads
+
+        # this is for cond
+        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+
+        self.q_norm1 = (
+            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
+            if mh_qknorm
+            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
+        )
+        self.k_norm1 = (
+            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
+            if mh_qknorm
+            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
+        )
+
+    #@torch.compile()
+    def forward(self, c):
+        bsz, seqlen1, _ = c.shape
+
+        q, k, v = self.w1q(c), self.w1k(c), self.w1v(c)
+        q = q.view(bsz, seqlen1, self.n_heads, self.head_dim)
+        k = k.view(bsz, seqlen1, self.n_heads, self.head_dim)
+        v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
+        q, k = self.q_norm1(q), self.k_norm1(k)
+
+        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+        c = self.w1o(output)
+        return c
+
+
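+# DoubleAttention below is the MMDiT-style joint attention block: the cond
+# stream (c) and the image stream (x) keep separate qkv/out projections, but
+# their token sequences are concatenated into a single attention call so the
+# two modalities can attend to each other.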
+class DoubleAttention(nn.Module):
+    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
+        super().__init__()
+
+        self.n_heads = n_heads
+        self.head_dim = dim // n_heads
+
+        # this is for cond
+        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+
+        # this is for x
+        self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+        self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+        self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+        self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
+
+        self.q_norm1 = (
+            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
+            if mh_qknorm
+            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
+        )
+        self.k_norm1 = (
+            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
+            if mh_qknorm
+            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
+        )
+
+        self.q_norm2 = (
+            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
+            if mh_qknorm
+            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
+        )
+        self.k_norm2 = (
+            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
+            if mh_qknorm
+            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
+        )
+
+    #@torch.compile()
+    def forward(self, c, x):
+        bsz, seqlen1, _ = c.shape
+        bsz, seqlen2, _ = x.shape
+        seqlen = seqlen1 + seqlen2
+
+        cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
+        cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
+        ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
+        cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
+        cq, ck = self.q_norm1(cq), self.k_norm1(ck)
+
+        xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
+        xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
+        xq, xk = self.q_norm2(xq), self.k_norm2(xk)
+
+        # concat all
+        q, k, v = (
+            torch.cat([cq, xq], dim=1),
+            torch.cat([ck, xk], dim=1),
+            torch.cat([cv, xv], dim=1),
+        )
+
+        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+
+        c, x = output.split([seqlen1, seqlen2], dim=1)
+        c = self.w1o(c)
+        x = self.w2o(x)
+
+        return c, x
+
+
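+# MMDiTBlock pairs DoubleAttention with adaLN-style conditioning: modC/modX
+# project the global timestep embedding to shift/scale/gate values that
+# modulate each stream's attention and MLP sub-blocks via modulate() above.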
+class MMDiTBlock(nn.Module):
+    def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
+        super().__init__()
+
+        self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
+        self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
+        # note: with the layer counts used by this model is_last never ends up
+        # True, so the 2 * dim branch below is effectively unused.
+        if not is_last:
+            self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
+            self.modC = nn.Sequential(
+                nn.SiLU(),
+                operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
+            )
+        else:
+            self.modC = nn.Sequential(
+                nn.SiLU(),
+                operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
+            )
+
+        self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
+        self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
+        self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
+        self.modX = nn.Sequential(
+            nn.SiLU(),
+            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
+        )
+
+        self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
+        self.is_last = is_last
+
+    #@torch.compile()
+    def forward(self, c, x, global_cond, **kwargs):
+        cres, xres = c, x
+
+        cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
+            self.modC(global_cond).chunk(6, dim=1)
+        )
+
+        c = modulate(self.normC1(c), cshift_msa, cscale_msa)
+
+        # xpath
+        xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
+            self.modX(global_cond).chunk(6, dim=1)
+        )
+
+        x = modulate(self.normX1(x), xshift_msa, xscale_msa)
+
+        # attention
+        c, x = self.attn(c, x)
+
+        c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
+        c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
+        c = cres + c
+
+        x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
+        x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
+        x = xres + x
+
+        return c, x
+
+
+class DiTBlock(nn.Module):
+    # like MMDiTBlock, but it only has X
+    def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
+        super().__init__()
+
+        self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
+        self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
+
+        self.modCX = nn.Sequential(
+            nn.SiLU(),
+            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
+        )
+
+        self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
+        self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
+
+    #@torch.compile()
+    def forward(self, cx, global_cond, **kwargs):
+        cxres = cx
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
+            global_cond
+        ).chunk(6, dim=1)
+        cx = modulate(self.norm1(cx), shift_msa, scale_msa)
+        cx = self.attn(cx)
+        cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
+        mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
+        cx = gate_mlp.unsqueeze(1) * mlpout
+
+        cx = cxres + cx
+
+        return cx
+
+
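+# DiT-style sinusoidal timestep embedding followed by a two-layer MLP. The
+# 1000x factor on the frequencies is equivalent to embedding 1000 * t, which
+# appears intended to map flow timesteps in [0, 1] to a range comparable to
+# discrete diffusion step indices.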
+class TimestepEmbedder(nn.Module):
+    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    @staticmethod
+    def timestep_embedding(t, dim, max_period=10000):
+        half = dim // 2
+        freqs = 1000 * torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half) / half
+        ).to(t.device)
+        args = t[:, None] * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat(
+                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+            )
+        return embedding
+
+    #@torch.compile()
+    def forward(self, t, dtype):
+        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+class MMDiT(nn.Module):
+    def __init__(
+        self,
+        in_channels=4,
+        out_channels=4,
+        patch_size=2,
+        dim=3072,
+        n_layers=36,
+        n_double_layers=4,
+        n_heads=12,
+        global_conddim=3072,
+        cond_seq_dim=2048,
+        max_seq=32 * 32,
+        device=None,
+        dtype=None,
+        operations=None,
+    ):
+        super().__init__()
+        self.dtype = dtype
+
+        self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)
+
+        self.cond_seq_linear = operations.Linear(
+            cond_seq_dim, dim, bias=False, dtype=dtype, device=device
+        )  # linear for something like text sequence.
+        self.init_x_linear = operations.Linear(
+            patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
+        )  # init linear for patchified image.
+
+        self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
+        self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))
+
+        self.double_layers = nn.ModuleList([])
+        self.single_layers = nn.ModuleList([])
+
+        for idx in range(n_double_layers):
+            self.double_layers.append(
+                MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
+            )
+
+        for idx in range(n_double_layers, n_layers):
+            self.single_layers.append(
+                DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
+            )
+
+        self.final_linear = operations.Linear(
+            dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
+        )
+
+        self.modF = nn.Sequential(
+            nn.SiLU(),
+            operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
+        )
+
+        self.out_channels = out_channels
+        self.patch_size = patch_size
+        self.n_double_layers = n_double_layers
+        self.n_layers = n_layers
+
+        self.h_max = round(max_seq**0.5)
+        self.w_max = round(max_seq**0.5)
+
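+    # The positional encoding is a learned table for an h_max x w_max token
+    # grid. extend_pe() permanently enlarges the table by bilinear
+    # interpolation; apply_pos_embeds() resizes and center-crops it on the
+    # fly to match the current latent resolution.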
+    @torch.no_grad()
+    def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
+        # extend pe
+        pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]
+
+        pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)
+
+        # now we need to extend this to target_dim. for this we will use interpolation.
+        # we will use torch.nn.functional.interpolate
+        pe_as_2d = F.interpolate(
+            pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
+        )
+        pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
+        self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
+        self.h_max, self.w_max = target_dim
+        print("PE extended to", target_dim)
+
+    def pe_selection_index_based_on_dim(self, h, w):
+        h_p, w_p = h // self.patch_size, w // self.patch_size
+        original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
+        original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
+        starth = self.h_max // 2 - h_p // 2
+        endh = starth + h_p
+        startw = self.w_max // 2 - w_p // 2
+        endw = startw + w_p
+        original_pe_indexes = original_pe_indexes[
+            starth:endh, startw:endw
+        ]
+        return original_pe_indexes.flatten()
+
+    def unpatchify(self, x, h, w):
+        c = self.out_channels
+        p = self.patch_size
+
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+        x = torch.einsum("nhwpqc->nchpwq", x)
+        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
+        return imgs
+
+    def patchify(self, x):
+        B, C, H, W = x.size()
+        pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
+        pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
+
+        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
+        # note: (H + 1) // 2 == ceil(H / 2), so this view assumes the default
+        # patch_size of 2 when H or W is odd.
+        x = x.view(
+            B,
+            C,
+            (H + 1) // self.patch_size,
+            self.patch_size,
+            (W + 1) // self.patch_size,
+            self.patch_size,
+        )
+        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
+        return x
+
+    def apply_pos_embeds(self, x, h, w):
+        h = (h + 1) // self.patch_size
+        w = (w + 1) // self.patch_size
+        max_dim = max(h, w)
+
+        cur_dim = self.h_max
+        pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
+
+        if max_dim > cur_dim:
+            pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
+            cur_dim = max_dim
+
+        from_h = (cur_dim - h) // 2
+        from_w = (cur_dim - w) // 2
+        pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w]
+        return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
+
+    def forward(self, x, timestep, context, **kwargs):
+        # patchify x, add PE
+        b, c, h, w = x.shape
+
+        # pe_indexes = self.pe_selection_index_based_on_dim(h, w)
+        # print(pe_indexes, pe_indexes.shape)
+
+        x = self.init_x_linear(self.patchify(x))  # B, T_x, D
+        x = self.apply_pos_embeds(x, h, w)
+        # x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
+        # x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)
+
+        # process conditions for MMDiT Blocks
+        c_seq = context  # B, T_c, D_c
+        t = timestep
+
+        c = self.cond_seq_linear(c_seq)  # B, T_c, D
+        c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
+
+        global_cond = self.t_embedder(t, x.dtype)  # B, D
+
+        if len(self.double_layers) > 0:
+            for layer in self.double_layers:
+                c, x = layer(c, x, global_cond, **kwargs)
+
+        if len(self.single_layers) > 0:
+            c_len = c.size(1)
+            cx = torch.cat([c, x], dim=1)
+            for layer in self.single_layers:
+                cx = layer(cx, global_cond, **kwargs)
+
+            x = cx[:, c_len:]
+
+        fshift, fscale = self.modF(global_cond).chunk(2, dim=1)
+
+        x = modulate(x, fshift, fscale)
+        x = self.final_linear(x)
+        x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:, :, :h, :w]
+        return x
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 80f6667e..0e0e69d3 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -6,6 +6,7 @@ from comfy.ldm.cascade.stage_b import StageB
 from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
+import comfy.ldm.aura.mmdit
 import comfy.ldm.audio.dit
 import comfy.ldm.audio.embedders
 import comfy.model_management
@@ -598,6 +599,17 @@ class SD3(BaseModel):
         area = input_shape[0] * input_shape[2] * input_shape[3]
         return (area * 0.3) * (1024 * 1024)
 
+class AuraFlow(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.aura.mmdit.MMDiT)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        return out
+
 class StableAudio1(BaseModel):
     def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 783a03eb..dd5bff82 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -105,6 +105,12 @@ def detect_unet_config(state_dict, key_prefix):
         unet_config["audio_model"] = "dit1.0"
         return unet_config
 
+    if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
+        unet_config = {}
+        unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
+        unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
+        return unet_config
+
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
         return None
 
@@ -253,6 +259,8 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
 def unet_prefix_from_state_dict(state_dict):
     if "model.model.postprocess_conv.weight" in state_dict: #audio models
         unet_key_prefix = "model.model."
+    elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
+        unet_key_prefix = "model."
     else:
         unet_key_prefix = "model.diffusion_model."
     return unet_key_prefix
diff --git a/comfy/sd.py b/comfy/sd.py
index eeffa423..6028029d 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -21,6 +21,7 @@ from . import sd2_clip
 from . import sdxl_clip
 from . import sd3_clip
 from . import sa_t5
+import comfy.text_encoders.aura_t5
 
 import comfy.model_patcher
 import comfy.lora
@@ -415,6 +416,9 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
             if weight.shape[-1] == 4096:
                 clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
                 clip_target.tokenizer = sd3_clip.SD3Tokenizer
+            elif weight.shape[-1] == 2048:
+                clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
+                clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
         elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
             clip_target.clip = sa_t5.SAT5Model
             clip_target.tokenizer = sa_t5.SAT5Tokenizer
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 21fdb7ec..ccf8c333 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -7,6 +7,7 @@ from . import sd2_clip
 from . import sdxl_clip
 from . import sd3_clip
 from . import sa_t5
+import comfy.text_encoders.aura_t5
 
 from . import supported_models_base
 from . import latent_formats
@@ -556,7 +557,28 @@ class StableAudio(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
 
+class AuraFlow(supported_models_base.BASE):
+    unet_config = {
+        "cond_seq_dim": 2048,
+    }
 
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
+    sampling_settings = {
+        "multiplier": 1.0,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.SDXL
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.AuraFlow(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]
 
 models += [SVD_img2vid]
diff --git a/comfy/text_encoders/aura_t5.py b/comfy/text_encoders/aura_t5.py
new file mode 100644
index 00000000..0e84189a
--- /dev/null
+++ b/comfy/text_encoders/aura_t5.py
@@ -0,0 +1,22 @@
+from comfy import sd1_clip
+from transformers import LlamaTokenizerFast
+import comfy.t5
+import os
+
+class PT5XlModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
+        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
+
+class PT5XlTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LlamaTokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
+
+class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None):
+        super().__init__(embedding_directory=embedding_directory, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
+
+class AuraT5Model(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, **kwargs):
+        super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
diff --git a/comfy/text_encoders/t5_pile_config_xl.json b/comfy/text_encoders/t5_pile_config_xl.json
new file mode 100644
index 00000000..ee4e03f9
--- /dev/null
+++ b/comfy/text_encoders/t5_pile_config_xl.json
@@ -0,0 +1,22 @@
+{
+    "d_ff": 5120,
+    "d_kv": 64,
+    "d_model": 2048,
+    "decoder_start_token_id": 0,
+    "dropout_rate": 0.1,
+    "eos_token_id": 2,
+    "dense_act_fn": "gelu_pytorch_tanh",
+    "initializer_factor": 1.0,
+    "is_encoder_decoder": true,
+    "is_gated_act": true,
+    "layer_norm_epsilon": 1e-06,
+    "model_type": "umt5",
+    "num_decoder_layers": 24,
+    "num_heads": 32,
"num_layers": 24, + "output_past": true, + "pad_token_id": 1, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "vocab_size": 32128 +} diff --git a/comfy/text_encoders/t5_pile_tokenizer/added_tokens.json b/comfy/text_encoders/t5_pile_tokenizer/added_tokens.json new file mode 100644 index 00000000..3f513200 --- /dev/null +++ b/comfy/text_encoders/t5_pile_tokenizer/added_tokens.json @@ -0,0 +1,102 @@ +{ + "": 32099, + "": 32089, + "": 32088, + "": 32087, + "": 32086, + "": 32085, + "": 32084, + "": 32083, + "": 32082, + "": 32081, + "": 32080, + "": 32098, + "": 32079, + "": 32078, + "": 32077, + "": 32076, + "": 32075, + "": 32074, + "": 32073, + "": 32072, + "": 32071, + "": 32070, + "": 32097, + "": 32069, + "": 32068, + "": 32067, + "": 32066, + "": 32065, + "": 32064, + "": 32063, + "": 32062, + "": 32061, + "": 32060, + "": 32096, + "": 32059, + "": 32058, + "": 32057, + "": 32056, + "": 32055, + "": 32054, + "": 32053, + "": 32052, + "": 32051, + "": 32050, + "": 32095, + "": 32049, + "": 32048, + "": 32047, + "": 32046, + "": 32045, + "": 32044, + "": 32043, + "": 32042, + "": 32041, + "": 32040, + "": 32094, + "": 32039, + "": 32038, + "": 32037, + "": 32036, + "": 32035, + "": 32034, + "": 32033, + "": 32032, + "": 32031, + "": 32030, + "": 32093, + "": 32029, + "": 32028, + "": 32027, + "": 32026, + "": 32025, + "": 32024, + "": 32023, + "": 32022, + "": 32021, + "": 32020, + "": 32092, + "": 32019, + "": 32018, + "": 32017, + "": 32016, + "": 32015, + "": 32014, + "": 32013, + "": 32012, + "": 32011, + "": 32010, + "": 32091, + "": 32009, + "": 32008, + "": 32007, + "": 32006, + "": 32005, + "": 32004, + "": 32003, + "": 32002, + "": 32001, + "": 32000, + "": 32090 +} diff --git a/comfy/text_encoders/t5_pile_tokenizer/special_tokens_map.json b/comfy/text_encoders/t5_pile_tokenizer/special_tokens_map.json new file mode 100644 index 00000000..19fb1d5f --- /dev/null +++ b/comfy/text_encoders/t5_pile_tokenizer/special_tokens_map.json @@ -0,0 +1,125 @@ +{ + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "bos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/comfy/text_encoders/t5_pile_tokenizer/tokenizer.model b/comfy/text_encoders/t5_pile_tokenizer/tokenizer.model new file mode 100644 index 00000000..22bccbcb Binary files /dev/null and b/comfy/text_encoders/t5_pile_tokenizer/tokenizer.model differ diff --git a/comfy/text_encoders/t5_pile_tokenizer/tokenizer_config.json b/comfy/text_encoders/t5_pile_tokenizer/tokenizer_config.json new file mode 100644 index 00000000..81f8e11e --- /dev/null +++ b/comfy/text_encoders/t5_pile_tokenizer/tokenizer_config.json @@ -0,0 
+1,945 @@ +{ + "add_bos_token": false, + "add_eos_token": true, + "add_prefix_space": true, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32000": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32001": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32002": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32003": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32004": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32005": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32006": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32007": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32008": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32009": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32010": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32011": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32012": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32013": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32014": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32015": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32016": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32017": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32018": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32019": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32020": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32021": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + 
}, + "32022": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32023": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32024": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32025": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32026": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32027": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32028": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32029": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32030": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32031": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32032": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32033": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32034": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32035": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32036": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32037": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32038": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32039": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32040": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32041": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32042": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32043": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32044": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32045": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32046": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32047": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + 
"single_word": false, + "special": true + }, + "32048": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32049": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32050": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32051": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32052": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32053": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32054": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32055": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32056": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32057": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32058": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32059": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32060": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32061": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32062": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32063": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32064": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32065": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32066": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32067": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32068": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32069": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32070": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32071": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32072": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32073": { + "content": "", + "lstrip": false, + 
"normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32074": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32075": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32076": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32077": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32078": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32079": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32080": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32081": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32082": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32083": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32084": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32085": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32086": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32087": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32088": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32089": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32090": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32091": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32092": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32093": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32094": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32095": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32096": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32097": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32098": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32099": { + 
"content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "bos_token": "", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "legacy": false, + "model_max_length": 512, + "pad_token": null, + "padding_side": "right", + "sp_model_kwargs": {}, + "spaces_between_special_tokens": false, + "tokenizer_class": "LlamaTokenizer", + "unk_token": "", + "use_default_system_prompt": false +} diff --git a/requirements.txt b/requirements.txt index 108958d2..6febc3d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,8 @@ torchsde torchvision torchaudio einops -transformers>=4.25.1 +transformers>=4.28.1 +tokenizers>=0.13.3 safetensors>=0.4.2 aiohttp pyyaml