AuraFlow model implementation.

This commit is contained in:
comfyanonymous 2024-07-11 16:51:06 -04:00
parent f45157e3ac
commit 9f291d75b3
12 changed files with 1744 additions and 2 deletions

479
comfy/ldm/aura/mmdit.py Normal file
View File

@ -0,0 +1,479 @@
#AuraFlow MMDiT
#Originally written by the AuraFlow Authors
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
n_hidden = int(2 * hidden_dim / 3)
n_hidden = find_multiple(n_hidden, 256)
self.c_fc1 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
self.c_fc2 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
self.c_proj = operations.Linear(n_hidden, dim, bias=False, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
x = self.c_proj(x)
return x
class MultiHeadLayerNorm(nn.Module):
def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
# Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78
super().__init__()
self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states - mean) * torch.rsqrt(
variance + self.variance_epsilon
)
hidden_states = self.weight.to(torch.float32) * hidden_states
return hidden_states.to(input_dtype)
class SingleAttention(nn.Module):
def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
super().__init__()
self.n_heads = n_heads
self.head_dim = dim // n_heads
# this is for cond
self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.q_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.k_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
#@torch.compile()
def forward(self, c):
bsz, seqlen1, _ = c.shape
q, k, v = self.w1q(c), self.w1k(c), self.w1v(c)
q = q.view(bsz, seqlen1, self.n_heads, self.head_dim)
k = k.view(bsz, seqlen1, self.n_heads, self.head_dim)
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
c = self.w1o(output)
return c
class DoubleAttention(nn.Module):
def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
super().__init__()
self.n_heads = n_heads
self.head_dim = dim // n_heads
# this is for cond
self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
# this is for x
self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.q_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.k_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.q_norm2 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.k_norm2 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
#@torch.compile()
def forward(self, c, x):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
seqlen = seqlen1 + seqlen2
cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
cq, ck = self.q_norm1(cq), self.k_norm1(ck)
xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
xq, xk = self.q_norm2(xq), self.k_norm2(xk)
# concat all
q, k, v = (
torch.cat([cq, xq], dim=1),
torch.cat([ck, xk], dim=1),
torch.cat([cv, xv], dim=1),
)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
x = self.w2o(x)
return c, x
class MMDiTBlock(nn.Module):
def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
super().__init__()
self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
if not is_last:
self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
self.modC = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
)
else:
self.modC = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
)
self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
self.modX = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
)
self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
self.is_last = is_last
#@torch.compile()
def forward(self, c, x, global_cond, **kwargs):
cres, xres = c, x
cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
self.modC(global_cond).chunk(6, dim=1)
)
c = modulate(self.normC1(c), cshift_msa, cscale_msa)
# xpath
xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
self.modX(global_cond).chunk(6, dim=1)
)
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
# attention
c, x = self.attn(c, x)
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
c = cres + c
x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
x = xres + x
return c, x
class DiTBlock(nn.Module):
# like MMDiTBlock, but it only has X
def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.modCX = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
)
self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
#@torch.compile()
def forward(self, cx, global_cond, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
cx = self.attn(cx)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
cx = cxres + cx
return cx
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = 1000 * torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half) / half
).to(t.device)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
#@torch.compile()
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class MMDiT(nn.Module):
def __init__(
self,
in_channels=4,
out_channels=4,
patch_size=2,
dim=3072,
n_layers=36,
n_double_layers=4,
n_heads=12,
global_conddim=3072,
cond_seq_dim=2048,
max_seq=32 * 32,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)
self.cond_seq_linear = operations.Linear(
cond_seq_dim, dim, bias=False, dtype=dtype, device=device
) # linear for something like text sequence.
self.init_x_linear = operations.Linear(
patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
) # init linear for patchified image.
self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))
self.double_layers = nn.ModuleList([])
self.single_layers = nn.ModuleList([])
for idx in range(n_double_layers):
self.double_layers.append(
MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
)
for idx in range(n_double_layers, n_layers):
self.single_layers.append(
DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
)
self.final_linear = operations.Linear(
dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
)
self.modF = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
)
self.out_channels = out_channels
self.patch_size = patch_size
self.n_double_layers = n_double_layers
self.n_layers = n_layers
self.h_max = round(max_seq**0.5)
self.w_max = round(max_seq**0.5)
@torch.no_grad()
def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
# extend pe
pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]
pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)
# now we need to extend this to target_dim. for this we will use interpolation.
# we will use torch.nn.functional.interpolate
pe_as_2d = F.interpolate(
pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
)
pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
self.h_max, self.w_max = target_dim
print("PE extended to", target_dim)
def pe_selection_index_based_on_dim(self, h, w):
h_p, w_p = h // self.patch_size, w // self.patch_size
original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
starth = self.h_max // 2 - h_p // 2
endh =starth + h_p
startw = self.w_max // 2 - w_p // 2
endw = startw + w_p
original_pe_indexes = original_pe_indexes[
starth:endh, startw:endw
]
return original_pe_indexes.flatten()
def unpatchify(self, x, h, w):
c = self.out_channels
p = self.patch_size
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
def patchify(self, x):
B, C, H, W = x.size()
pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
x = x.view(
B,
C,
(H + 1) // self.patch_size,
self.patch_size,
(W + 1) // self.patch_size,
self.patch_size,
)
x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
return x
def apply_pos_embeds(self, x, h, w):
h = (h + 1) // self.patch_size
w = (w + 1) // self.patch_size
max_dim = max(h, w)
cur_dim = self.h_max
pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
if max_dim > cur_dim:
pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
cur_dim = max_dim
from_h = (cur_dim - h) // 2
from_w = (cur_dim - w) // 2
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, **kwargs):
# patchify x, add PE
b, c, h, w = x.shape
# pe_indexes = self.pe_selection_index_based_on_dim(h, w)
# print(pe_indexes, pe_indexes.shape)
x = self.init_x_linear(self.patchify(x)) # B, T_x, D
x = self.apply_pos_embeds(x, h, w)
# x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
# x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)
# process conditions for MMDiT Blocks
c_seq = context # B, T_c, D_c
t = timestep
c = self.cond_seq_linear(c_seq) # B, T_c, D
c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
global_cond = self.t_embedder(t, x.dtype) # B, D
if len(self.double_layers) > 0:
for layer in self.double_layers:
c, x = layer(c, x, global_cond, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
cx = torch.cat([c, x], dim=1)
for layer in self.single_layers:
cx = layer(cx, global_cond, **kwargs)
x = cx[:, c_len:]
fshift, fscale = self.modF(global_cond).chunk(2, dim=1)
x = modulate(x, fshift, fscale)
x = self.final_linear(x)
x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
return x

View File

@ -6,6 +6,7 @@ from comfy.ldm.cascade.stage_b import StageB
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.ldm.aura.mmdit
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.model_management
@ -598,6 +599,17 @@ class SD3(BaseModel):
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * 0.3) * (1024 * 1024)
class AuraFlow(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.aura.mmdit.MMDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class StableAudio1(BaseModel):
def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):

View File

@ -105,6 +105,12 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["audio_model"] = "dit1.0"
return unet_config
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
unet_config = {}
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
return unet_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@ -253,6 +259,8 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
def unet_prefix_from_state_dict(state_dict):
if "model.model.postprocess_conv.weight" in state_dict: #audio models
unet_key_prefix = "model.model."
elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
unet_key_prefix = "model."
else:
unet_key_prefix = "model.diffusion_model."
return unet_key_prefix

View File

@ -21,6 +21,7 @@ from . import sd2_clip
from . import sdxl_clip
from . import sd3_clip
from . import sa_t5
import comfy.text_encoders.aura_t5
import comfy.model_patcher
import comfy.lora
@ -415,6 +416,9 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
if weight.shape[-1] == 4096:
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
elif weight.shape[-1] == 2048:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
clip_target.clip = sa_t5.SAT5Model
clip_target.tokenizer = sa_t5.SAT5Tokenizer

View File

@ -7,6 +7,7 @@ from . import sd2_clip
from . import sdxl_clip
from . import sd3_clip
from . import sa_t5
import comfy.text_encoders.aura_t5
from . import supported_models_base
from . import latent_formats
@ -556,7 +557,28 @@ class StableAudio(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
class AuraFlow(supported_models_base.BASE):
unet_config = {
"cond_seq_dim": 2048,
}
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
sampling_settings = {
"multiplier": 1.0,
}
unet_extra_config = {}
latent_format = latent_formats.SDXL
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.AuraFlow(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]
models += [SVD_img2vid]

View File

@ -0,0 +1,22 @@
from comfy import sd1_clip
from transformers import LlamaTokenizerFast
import comfy.t5
import os
class PT5XlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LlamaTokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
class AuraT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)

View File

@ -0,0 +1,22 @@
{
"d_ff": 5120,
"d_kv": 64,
"d_model": 2048,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 2,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "umt5",
"num_decoder_layers": 24,
"num_heads": 32,
"num_layers": 24,
"output_past": true,
"pad_token_id": 1,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 32128
}

View File

@ -0,0 +1,102 @@
{
"<extra_id_0>": 32099,
"<extra_id_10>": 32089,
"<extra_id_11>": 32088,
"<extra_id_12>": 32087,
"<extra_id_13>": 32086,
"<extra_id_14>": 32085,
"<extra_id_15>": 32084,
"<extra_id_16>": 32083,
"<extra_id_17>": 32082,
"<extra_id_18>": 32081,
"<extra_id_19>": 32080,
"<extra_id_1>": 32098,
"<extra_id_20>": 32079,
"<extra_id_21>": 32078,
"<extra_id_22>": 32077,
"<extra_id_23>": 32076,
"<extra_id_24>": 32075,
"<extra_id_25>": 32074,
"<extra_id_26>": 32073,
"<extra_id_27>": 32072,
"<extra_id_28>": 32071,
"<extra_id_29>": 32070,
"<extra_id_2>": 32097,
"<extra_id_30>": 32069,
"<extra_id_31>": 32068,
"<extra_id_32>": 32067,
"<extra_id_33>": 32066,
"<extra_id_34>": 32065,
"<extra_id_35>": 32064,
"<extra_id_36>": 32063,
"<extra_id_37>": 32062,
"<extra_id_38>": 32061,
"<extra_id_39>": 32060,
"<extra_id_3>": 32096,
"<extra_id_40>": 32059,
"<extra_id_41>": 32058,
"<extra_id_42>": 32057,
"<extra_id_43>": 32056,
"<extra_id_44>": 32055,
"<extra_id_45>": 32054,
"<extra_id_46>": 32053,
"<extra_id_47>": 32052,
"<extra_id_48>": 32051,
"<extra_id_49>": 32050,
"<extra_id_4>": 32095,
"<extra_id_50>": 32049,
"<extra_id_51>": 32048,
"<extra_id_52>": 32047,
"<extra_id_53>": 32046,
"<extra_id_54>": 32045,
"<extra_id_55>": 32044,
"<extra_id_56>": 32043,
"<extra_id_57>": 32042,
"<extra_id_58>": 32041,
"<extra_id_59>": 32040,
"<extra_id_5>": 32094,
"<extra_id_60>": 32039,
"<extra_id_61>": 32038,
"<extra_id_62>": 32037,
"<extra_id_63>": 32036,
"<extra_id_64>": 32035,
"<extra_id_65>": 32034,
"<extra_id_66>": 32033,
"<extra_id_67>": 32032,
"<extra_id_68>": 32031,
"<extra_id_69>": 32030,
"<extra_id_6>": 32093,
"<extra_id_70>": 32029,
"<extra_id_71>": 32028,
"<extra_id_72>": 32027,
"<extra_id_73>": 32026,
"<extra_id_74>": 32025,
"<extra_id_75>": 32024,
"<extra_id_76>": 32023,
"<extra_id_77>": 32022,
"<extra_id_78>": 32021,
"<extra_id_79>": 32020,
"<extra_id_7>": 32092,
"<extra_id_80>": 32019,
"<extra_id_81>": 32018,
"<extra_id_82>": 32017,
"<extra_id_83>": 32016,
"<extra_id_84>": 32015,
"<extra_id_85>": 32014,
"<extra_id_86>": 32013,
"<extra_id_87>": 32012,
"<extra_id_88>": 32011,
"<extra_id_89>": 32010,
"<extra_id_8>": 32091,
"<extra_id_90>": 32009,
"<extra_id_91>": 32008,
"<extra_id_92>": 32007,
"<extra_id_93>": 32006,
"<extra_id_94>": 32005,
"<extra_id_95>": 32004,
"<extra_id_96>": 32003,
"<extra_id_97>": 32002,
"<extra_id_98>": 32001,
"<extra_id_99>": 32000,
"<extra_id_9>": 32090
}

View File

@ -0,0 +1,125 @@
{
"additional_special_tokens": [
"<extra_id_99>",
"<extra_id_98>",
"<extra_id_97>",
"<extra_id_96>",
"<extra_id_95>",
"<extra_id_94>",
"<extra_id_93>",
"<extra_id_92>",
"<extra_id_91>",
"<extra_id_90>",
"<extra_id_89>",
"<extra_id_88>",
"<extra_id_87>",
"<extra_id_86>",
"<extra_id_85>",
"<extra_id_84>",
"<extra_id_83>",
"<extra_id_82>",
"<extra_id_81>",
"<extra_id_80>",
"<extra_id_79>",
"<extra_id_78>",
"<extra_id_77>",
"<extra_id_76>",
"<extra_id_75>",
"<extra_id_74>",
"<extra_id_73>",
"<extra_id_72>",
"<extra_id_71>",
"<extra_id_70>",
"<extra_id_69>",
"<extra_id_68>",
"<extra_id_67>",
"<extra_id_66>",
"<extra_id_65>",
"<extra_id_64>",
"<extra_id_63>",
"<extra_id_62>",
"<extra_id_61>",
"<extra_id_60>",
"<extra_id_59>",
"<extra_id_58>",
"<extra_id_57>",
"<extra_id_56>",
"<extra_id_55>",
"<extra_id_54>",
"<extra_id_53>",
"<extra_id_52>",
"<extra_id_51>",
"<extra_id_50>",
"<extra_id_49>",
"<extra_id_48>",
"<extra_id_47>",
"<extra_id_46>",
"<extra_id_45>",
"<extra_id_44>",
"<extra_id_43>",
"<extra_id_42>",
"<extra_id_41>",
"<extra_id_40>",
"<extra_id_39>",
"<extra_id_38>",
"<extra_id_37>",
"<extra_id_36>",
"<extra_id_35>",
"<extra_id_34>",
"<extra_id_33>",
"<extra_id_32>",
"<extra_id_31>",
"<extra_id_30>",
"<extra_id_29>",
"<extra_id_28>",
"<extra_id_27>",
"<extra_id_26>",
"<extra_id_25>",
"<extra_id_24>",
"<extra_id_23>",
"<extra_id_22>",
"<extra_id_21>",
"<extra_id_20>",
"<extra_id_19>",
"<extra_id_18>",
"<extra_id_17>",
"<extra_id_16>",
"<extra_id_15>",
"<extra_id_14>",
"<extra_id_13>",
"<extra_id_12>",
"<extra_id_11>",
"<extra_id_10>",
"<extra_id_9>",
"<extra_id_8>",
"<extra_id_7>",
"<extra_id_6>",
"<extra_id_5>",
"<extra_id_4>",
"<extra_id_3>",
"<extra_id_2>",
"<extra_id_1>",
"<extra_id_0>"
],
"bos_token": {
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

Binary file not shown.

View File

@ -0,0 +1,945 @@
{
"add_bos_token": false,
"add_eos_token": true,
"add_prefix_space": true,
"added_tokens_decoder": {
"0": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32000": {
"content": "<extra_id_99>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32001": {
"content": "<extra_id_98>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32002": {
"content": "<extra_id_97>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32003": {
"content": "<extra_id_96>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32004": {
"content": "<extra_id_95>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32005": {
"content": "<extra_id_94>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32006": {
"content": "<extra_id_93>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32007": {
"content": "<extra_id_92>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32008": {
"content": "<extra_id_91>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32009": {
"content": "<extra_id_90>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32010": {
"content": "<extra_id_89>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32011": {
"content": "<extra_id_88>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32012": {
"content": "<extra_id_87>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32013": {
"content": "<extra_id_86>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32014": {
"content": "<extra_id_85>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32015": {
"content": "<extra_id_84>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32016": {
"content": "<extra_id_83>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32017": {
"content": "<extra_id_82>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32018": {
"content": "<extra_id_81>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32019": {
"content": "<extra_id_80>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32020": {
"content": "<extra_id_79>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32021": {
"content": "<extra_id_78>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32022": {
"content": "<extra_id_77>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32023": {
"content": "<extra_id_76>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32024": {
"content": "<extra_id_75>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32025": {
"content": "<extra_id_74>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32026": {
"content": "<extra_id_73>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32027": {
"content": "<extra_id_72>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32028": {
"content": "<extra_id_71>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32029": {
"content": "<extra_id_70>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32030": {
"content": "<extra_id_69>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32031": {
"content": "<extra_id_68>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32032": {
"content": "<extra_id_67>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32033": {
"content": "<extra_id_66>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32034": {
"content": "<extra_id_65>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32035": {
"content": "<extra_id_64>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32036": {
"content": "<extra_id_63>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32037": {
"content": "<extra_id_62>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32038": {
"content": "<extra_id_61>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32039": {
"content": "<extra_id_60>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32040": {
"content": "<extra_id_59>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32041": {
"content": "<extra_id_58>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32042": {
"content": "<extra_id_57>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32043": {
"content": "<extra_id_56>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32044": {
"content": "<extra_id_55>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32045": {
"content": "<extra_id_54>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32046": {
"content": "<extra_id_53>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32047": {
"content": "<extra_id_52>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32048": {
"content": "<extra_id_51>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32049": {
"content": "<extra_id_50>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32050": {
"content": "<extra_id_49>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32051": {
"content": "<extra_id_48>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32052": {
"content": "<extra_id_47>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32053": {
"content": "<extra_id_46>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32054": {
"content": "<extra_id_45>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32055": {
"content": "<extra_id_44>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32056": {
"content": "<extra_id_43>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32057": {
"content": "<extra_id_42>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32058": {
"content": "<extra_id_41>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32059": {
"content": "<extra_id_40>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32060": {
"content": "<extra_id_39>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32061": {
"content": "<extra_id_38>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32062": {
"content": "<extra_id_37>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32063": {
"content": "<extra_id_36>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32064": {
"content": "<extra_id_35>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32065": {
"content": "<extra_id_34>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32066": {
"content": "<extra_id_33>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32067": {
"content": "<extra_id_32>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32068": {
"content": "<extra_id_31>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32069": {
"content": "<extra_id_30>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32070": {
"content": "<extra_id_29>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32071": {
"content": "<extra_id_28>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32072": {
"content": "<extra_id_27>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32073": {
"content": "<extra_id_26>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32074": {
"content": "<extra_id_25>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32075": {
"content": "<extra_id_24>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32076": {
"content": "<extra_id_23>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32077": {
"content": "<extra_id_22>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32078": {
"content": "<extra_id_21>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32079": {
"content": "<extra_id_20>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32080": {
"content": "<extra_id_19>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32081": {
"content": "<extra_id_18>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32082": {
"content": "<extra_id_17>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32083": {
"content": "<extra_id_16>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32084": {
"content": "<extra_id_15>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32085": {
"content": "<extra_id_14>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32086": {
"content": "<extra_id_13>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32087": {
"content": "<extra_id_12>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32088": {
"content": "<extra_id_11>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32089": {
"content": "<extra_id_10>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32090": {
"content": "<extra_id_9>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32091": {
"content": "<extra_id_8>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32092": {
"content": "<extra_id_7>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32093": {
"content": "<extra_id_6>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32094": {
"content": "<extra_id_5>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32095": {
"content": "<extra_id_4>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32096": {
"content": "<extra_id_3>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32097": {
"content": "<extra_id_2>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32098": {
"content": "<extra_id_1>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32099": {
"content": "<extra_id_0>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<extra_id_99>",
"<extra_id_98>",
"<extra_id_97>",
"<extra_id_96>",
"<extra_id_95>",
"<extra_id_94>",
"<extra_id_93>",
"<extra_id_92>",
"<extra_id_91>",
"<extra_id_90>",
"<extra_id_89>",
"<extra_id_88>",
"<extra_id_87>",
"<extra_id_86>",
"<extra_id_85>",
"<extra_id_84>",
"<extra_id_83>",
"<extra_id_82>",
"<extra_id_81>",
"<extra_id_80>",
"<extra_id_79>",
"<extra_id_78>",
"<extra_id_77>",
"<extra_id_76>",
"<extra_id_75>",
"<extra_id_74>",
"<extra_id_73>",
"<extra_id_72>",
"<extra_id_71>",
"<extra_id_70>",
"<extra_id_69>",
"<extra_id_68>",
"<extra_id_67>",
"<extra_id_66>",
"<extra_id_65>",
"<extra_id_64>",
"<extra_id_63>",
"<extra_id_62>",
"<extra_id_61>",
"<extra_id_60>",
"<extra_id_59>",
"<extra_id_58>",
"<extra_id_57>",
"<extra_id_56>",
"<extra_id_55>",
"<extra_id_54>",
"<extra_id_53>",
"<extra_id_52>",
"<extra_id_51>",
"<extra_id_50>",
"<extra_id_49>",
"<extra_id_48>",
"<extra_id_47>",
"<extra_id_46>",
"<extra_id_45>",
"<extra_id_44>",
"<extra_id_43>",
"<extra_id_42>",
"<extra_id_41>",
"<extra_id_40>",
"<extra_id_39>",
"<extra_id_38>",
"<extra_id_37>",
"<extra_id_36>",
"<extra_id_35>",
"<extra_id_34>",
"<extra_id_33>",
"<extra_id_32>",
"<extra_id_31>",
"<extra_id_30>",
"<extra_id_29>",
"<extra_id_28>",
"<extra_id_27>",
"<extra_id_26>",
"<extra_id_25>",
"<extra_id_24>",
"<extra_id_23>",
"<extra_id_22>",
"<extra_id_21>",
"<extra_id_20>",
"<extra_id_19>",
"<extra_id_18>",
"<extra_id_17>",
"<extra_id_16>",
"<extra_id_15>",
"<extra_id_14>",
"<extra_id_13>",
"<extra_id_12>",
"<extra_id_11>",
"<extra_id_10>",
"<extra_id_9>",
"<extra_id_8>",
"<extra_id_7>",
"<extra_id_6>",
"<extra_id_5>",
"<extra_id_4>",
"<extra_id_3>",
"<extra_id_2>",
"<extra_id_1>",
"<extra_id_0>"
],
"bos_token": "<s>",
"clean_up_tokenization_spaces": false,
"eos_token": "</s>",
"legacy": false,
"model_max_length": 512,
"pad_token": null,
"padding_side": "right",
"sp_model_kwargs": {},
"spaces_between_special_tokens": false,
"tokenizer_class": "LlamaTokenizer",
"unk_token": "<unk>",
"use_default_system_prompt": false
}

View File

@ -3,7 +3,8 @@ torchsde
torchvision
torchaudio
einops
transformers>=4.25.1
transformers>=4.28.1
tokenizers>=0.13.3
safetensors>=0.4.2
aiohttp
pyyaml