SD3 Support.
parent a5e6a632f9
commit 8c4a9befa7
@@ -25,8 +25,9 @@ class SD15(LatentFormat):
         self.taesd_decoder_name = "taesd_decoder"
 
 class SDXL(LatentFormat):
+    scale_factor = 0.13025
+
     def __init__(self):
-        self.scale_factor = 0.13025
         self.latent_rgb_factors = [
            #   R        G        B
            [ 0.3920,  0.4054,  0.4549],
|
@@ -104,3 +105,33 @@ class SC_B(LatentFormat):
            [-0.3087, -0.1535,  0.0366],
            [ 0.0290, -0.1574, -0.4078]
         ]
+
+class SD3(LatentFormat):
+    latent_channels = 16
+    def __init__(self):
+        self.scale_factor = 1.5305
+        self.shift_factor = 0.0609
+        self.latent_rgb_factors = [
+            [-0.0645,  0.0177,  0.1052],
+            [ 0.0028,  0.0312,  0.0650],
+            [ 0.1848,  0.0762,  0.0360],
+            [ 0.0944,  0.0360,  0.0889],
+            [ 0.0897,  0.0506, -0.0364],
+            [-0.0020,  0.1203,  0.0284],
+            [ 0.0855,  0.0118,  0.0283],
+            [-0.0539,  0.0658,  0.1047],
+            [-0.0057,  0.0116,  0.0700],
+            [-0.0412,  0.0281, -0.0039],
+            [ 0.1106,  0.1171,  0.1220],
+            [-0.0248,  0.0682, -0.0481],
+            [ 0.0815,  0.0846,  0.1207],
+            [-0.0120, -0.0055, -0.0867],
+            [-0.0749, -0.0634, -0.0456],
+            [-0.1418, -0.1457, -0.1259]
+        ]
+
+    def process_in(self, latent):
+        return (latent - self.shift_factor) * self.scale_factor
+
+    def process_out(self, latent):
+        return (latent / self.scale_factor) + self.shift_factor
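The SD3 latent format above is the first one in this file to pair a shift_factor with the scale_factor: latents are re-centered before scaling on the way in, and the mapping is inverted on the way out. A minimal standalone sketch of that round trip (the two constants are restated from the diff; everything else is illustrative, with no ComfyUI imports):

import torch

# Restated from the SD3 class above: process_in subtracts the shift then
# scales, process_out inverts it exactly.
scale_factor, shift_factor = 1.5305, 0.0609

def process_in(latent):
    return (latent - shift_factor) * scale_factor

def process_out(latent):
    return (latent / scale_factor) + shift_factor

latent = torch.randn(1, 16, 128, 128)  # SD3 latents have 16 channels
assert torch.allclose(process_out(process_in(latent)), latent, atol=1e-6)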
File diff suppressed because it is too large
@@ -5,11 +5,13 @@ from comfy.ldm.cascade.stage_c import StageC
 from comfy.ldm.cascade.stage_b import StageB
 from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
+from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
 import comfy.model_management
 import comfy.conds
 import comfy.ops
 from enum import Enum
 from . import utils
+import comfy.latent_formats
 
 class ModelType(Enum):
     EPS = 1
@@ -17,6 +19,7 @@ class ModelType(Enum):
     V_PREDICTION_EDM = 3
     STABLE_CASCADE = 4
     EDM = 5
+    FLOW = 6
 
 
 from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
@@ -32,6 +35,9 @@ def model_sampling(model_config, model_type):
     elif model_type == ModelType.V_PREDICTION_EDM:
         c = V_PREDICTION
         s = ModelSamplingContinuousEDM
+    elif model_type == ModelType.FLOW:
+        c = comfy.model_sampling.CONST
+        s = comfy.model_sampling.ModelSamplingDiscreteFlow
     elif model_type == ModelType.STABLE_CASCADE:
         c = EPS
         s = StableCascadeSampling
@@ -557,3 +563,23 @@ class StableCascade_B(BaseModel):
         out["effnet"] = comfy.conds.CONDRegular(prior)
         out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
         return out
+
+
+class SD3(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper)
+
+    def encode_adm(self, **kwargs):
+        return kwargs["pooled_output"]
+
+    def extra_conds(self, **kwargs):
+        out = {}
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['y'] = comfy.conds.CONDRegular(adm)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        return out
+
@@ -1,5 +1,6 @@
 import comfy.supported_models
 import comfy.supported_models_base
+import math
 import logging
 
 def count_blocks(state_dict_keys, prefix_string):
@@ -26,12 +27,47 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
            context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
            use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
            time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
-            return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
+            time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
+            return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
     return None
 
 def detect_unet_config(state_dict, key_prefix):
     state_dict_keys = list(state_dict.keys())
 
+    if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
+        unet_config = {}
+        unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
+        patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
+        unet_config["patch_size"] = patch_size
+        unet_config["out_channels"] = state_dict['{}final_layer.linear.weight'.format(key_prefix)].shape[0] // (patch_size * patch_size)
+
+        unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
+        unet_config["input_size"] = None
+        y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
+        if y_key in state_dict_keys:
+            unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
+
+        context_key = '{}context_embedder.weight'.format(key_prefix)
+        if context_key in state_dict_keys:
+            in_features = state_dict[context_key].shape[1]
+            out_features = state_dict[context_key].shape[0]
+            unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
+        num_patches_key = '{}pos_embed'.format(key_prefix)
+        if num_patches_key in state_dict_keys:
+            num_patches = state_dict[num_patches_key].shape[1]
+            unet_config["num_patches"] = num_patches
+            unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
+
+        rms_qk = '{}joint_blocks.0.context_block.attn.ln_q.weight'.format(key_prefix)
+        if rms_qk in state_dict_keys:
+            unet_config["qk_norm"] = "rms"
+
+        unet_config["pos_embed_scaling_factor"] = None #unused for inference
+        context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
+        if context_processor in state_dict_keys:
+            unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
+        return unet_config
+
     if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
         unet_config = {}
         text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
@@ -58,7 +94,6 @@ def detect_unet_config(state_dict, key_prefix):
         unet_config['nhead'] = [-1, 9, 18, 18]
         unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
         unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
-
         return unet_config
 
     unet_config = {
@@ -93,6 +128,7 @@ def detect_unet_config(state_dict, key_prefix):
     use_linear_in_transformer = False
 
     video_model = False
+    video_model_cross = False
 
     current_res = 1
     count = 0
@@ -136,6 +172,7 @@ def detect_unet_config(state_dict, key_prefix):
                 context_dim = out[1]
                 use_linear_in_transformer = out[2]
                 video_model = out[3]
+                video_model_cross = out[4]
             else:
                 transformer_depth.append(0)
 
@@ -176,6 +213,7 @@ def detect_unet_config(state_dict, key_prefix):
         unet_config["video_kernel_size"] = [3, 1, 1]
         unet_config["use_temporal_resblock"] = True
         unet_config["use_temporal_attention"] = True
+        unet_config["disable_temporal_crossattention"] = not video_model_cross
     else:
         unet_config["use_temporal_resblock"] = False
         unet_config["use_temporal_attention"] = False
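The new mmdit branch in detect_unet_config infers the whole model configuration from tensor shapes alone; notably, depth is the x_embedder output width divided by 64, since the MMDiT hidden size is 64 times the block depth. A hedged sketch with a toy state dict (key names follow the diff; the shapes are made up for illustration and do not come from a real checkpoint):

import torch

# Toy state dict mimicking the keys the mmdit branch inspects.
prefix = "model.diffusion_model."
sd = {
    prefix + "joint_blocks.0.context_block.attn.qkv.weight": torch.zeros(3 * 1536, 1536),
    prefix + "x_embedder.proj.weight": torch.zeros(1536, 16, 2, 2),  # (hidden, in_ch, patch, patch)
    prefix + "final_layer.linear.weight": torch.zeros(2 * 2 * 16, 1536),
}

w = sd[prefix + "x_embedder.proj.weight"]
in_channels = w.shape[1]   # 16
patch_size = w.shape[2]    # 2
depth = w.shape[0] // 64   # 1536 // 64 = 24 joint blocks
out_channels = sd[prefix + "final_layer.linear.weight"].shape[0] // (patch_size * patch_size)  # 16
print(in_channels, patch_size, depth, out_channels)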
@@ -33,6 +33,19 @@ class EDM(V_PREDICTION):
         sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
         return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
 
+
+class CONST:
+    def calculate_input(self, sigma, noise):
+        return noise
+
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        return model_input - model_output * sigma
+
+    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+        return sigma * noise + (1.0 - sigma) * latent_image
+
+    def inverse_noise_scaling(self, sigma, latent):
+        return latent / (1.0 - sigma)
 
 class ModelSamplingDiscrete(torch.nn.Module):
     def __init__(self, model_config=None):
@@ -104,6 +117,12 @@ class ModelSamplingDiscrete(torch.nn.Module):
         percent = 1.0 - percent
         return self.sigma(torch.tensor(percent * 999.0)).item()
 
+class ModelSamplingDiscreteEDM(ModelSamplingDiscrete):
+    def timestep(self, sigma):
+        return 0.25 * sigma.log()
+
+    def sigma(self, timestep):
+        return (timestep / 0.25).exp()
 
 class ModelSamplingContinuousEDM(torch.nn.Module):
     def __init__(self, model_config=None):
@@ -149,6 +168,48 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
         log_sigma_min = math.log(self.sigma_min)
         return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
 
+
+def time_snr_shift(alpha, t):
+    if alpha == 1.0:
+        return t
+    return alpha * t / (1 + (alpha - 1) * t)
+
+class ModelSamplingDiscreteFlow(torch.nn.Module):
+    def __init__(self, model_config=None):
+        super().__init__()
+        if model_config is not None:
+            sampling_settings = model_config.sampling_settings
+        else:
+            sampling_settings = {}
+
+        self.set_parameters(shift=sampling_settings.get("shift", 1.0))
+
+    def set_parameters(self, shift=1.0, timesteps=1000):
+        self.shift = shift
+        ts = self.sigma(torch.arange(1, timesteps + 1, 1))
+        self.register_buffer('sigmas', ts)
+
+    @property
+    def sigma_min(self):
+        return self.sigmas[0]
+
+    @property
+    def sigma_max(self):
+        return self.sigmas[-1]
+
+    def timestep(self, sigma):
+        return sigma * 1000
+
+    def sigma(self, timestep):
+        return time_snr_shift(self.shift, timestep / 1000)
+
+    def percent_to_sigma(self, percent):
+        if percent <= 0.0:
+            return 1.0
+        if percent >= 1.0:
+            return 0.0
+        return 1.0 - percent
+
 class StableCascadeSampling(ModelSamplingDiscrete):
     def __init__(self, model_config=None):
         super().__init__()
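The FLOW additions are compact enough to trace numerically: CONST.noise_scaling is a straight interpolation x_t = sigma * noise + (1 - sigma) * x_0, and time_snr_shift warps the uniform timestep grid so that shift > 1 keeps sigmas higher for longer. A self-contained sketch (the two formulas are restated from the diff; no ComfyUI imports needed):

import torch

def time_snr_shift(alpha, t):
    # restated from the diff above; shift > 1 pushes sigmas toward 1
    if alpha == 1.0:
        return t
    return alpha * t / (1 + (alpha - 1) * t)

t = torch.arange(1, 1001) / 1000.0
sigmas = time_snr_shift(3.0, t)  # shift=3.0 is what the SD3 entry below uses
print(sigmas[0].item(), sigmas[-1].item())  # ~0.003 at t=1/1000, exactly 1.0 at t=1

# CONST.noise_scaling: plain interpolation between noise and image
x0 = torch.randn(2, 16, 8, 8)
noise = torch.randn_like(x0)
sigma = torch.tensor(0.25)
x_t = sigma * noise + (1.0 - sigma) * x0  # at sigma=1 this is pure noise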
@@ -19,6 +19,7 @@ from . import model_detection
 from . import sd1_clip
 from . import sd2_clip
 from . import sdxl_clip
+from . import sd3_clip
 
 import comfy.model_patcher
 import comfy.lora
@@ -395,9 +396,12 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
         else:
             clip_target.clip = sd1_clip.SD1ClipModel
             clip_target.tokenizer = sd1_clip.SD1Tokenizer
-    else:
+    elif len(clip_data) == 2:
         clip_target.clip = sdxl_clip.SDXLClipModel
         clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+    elif len(clip_data) == 3:
+        clip_target.clip = sd3_clip.SD3ClipModel
+        clip_target.tokenizer = sd3_clip.SD3Tokenizer
 
     clip = CLIP(clip_target, embedding_directory=embedding_directory)
     for c in clip_data:
@@ -0,0 +1,91 @@
+from comfy import sd1_clip
+from comfy import sdxl_clip
+from transformers import T5TokenizerFast
+import comfy.t5
+import torch
+import os
+import comfy.model_management
+
+class T5XXLModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
+        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
+
+class T5XXLTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
+
+class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None):
+        super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
+
+class SDT5XXLModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, **kwargs):
+        super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)
+
+
+class SD3Tokenizer:
+    def __init__(self, embedding_directory=None):
+        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
+        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
+
+    def tokenize_with_weights(self, text:str, return_word_ids=False):
+        out = {}
+        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
+        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        return out
+
+    def untokenize(self, token_weight_pair):
+        return self.clip_g.untokenize(token_weight_pair)
+
+class SD3ClipModel(torch.nn.Module):
+    def __init__(self, device="cpu", dtype=None):
+        super().__init__()
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
+        self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
+        self.t5xxl = T5XXLModel(device=device, dtype=dtype)
+
+    def set_clip_options(self, options):
+        self.clip_l.set_clip_options(options)
+        self.clip_g.set_clip_options(options)
+        self.t5xxl.set_clip_options(options)
+
+    def reset_clip_options(self):
+        self.clip_g.reset_clip_options()
+        self.clip_l.reset_clip_options()
+        self.t5xxl.reset_clip_options()
+
+    def encode_token_weights(self, token_weight_pairs):
+        token_weight_pairs_l = token_weight_pairs["l"]
+        token_weight_pairs_g = token_weight_pairs["g"]
+        token_weight_pars_t5 = token_weight_pairs["t5xxl"]
+        lg_out = None
+        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
+            l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
+            g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
+            lg_out = torch.cat([l_out, g_out], dim=-1)
+            lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
+            out = lg_out
+            pooled = torch.cat((l_pooled, g_pooled), dim=-1)
+        else:
+            pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device())
+
+        t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
+        if lg_out is not None:
+            out = torch.cat([lg_out, t5_out], dim=-2)
+        else:
+            out = t5_out
+
+        return out, pooled
+
+    def load_sd(self, sd):
+        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
+            return self.clip_g.load_sd(sd)
+        elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
+            return self.clip_l.load_sd(sd)
+        else:
+            return self.t5xxl.load_sd(sd)
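A sketch of the tensor bookkeeping in SD3ClipModel.encode_token_weights, using random placeholders instead of the real encoders (the widths 768/1280/4096 and the pad-then-concat layout come from the code above; batch size and token counts are illustrative):

import torch

# Placeholder encoder outputs; real shapes come from the three models above.
l_out, l_pooled = torch.randn(1, 77, 768), torch.randn(1, 768)    # CLIP-L
g_out, g_pooled = torch.randn(1, 77, 1280), torch.randn(1, 1280)  # CLIP-G
t5_out = torch.randn(1, 77, 4096)                                 # T5-XXL

lg_out = torch.cat([l_out, g_out], dim=-1)                               # (1, 77, 2048)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))   # (1, 77, 4096)
out = torch.cat([lg_out, t5_out], dim=-2)                                # (1, 154, 4096)
pooled = torch.cat((l_pooled, g_pooled), dim=-1)                         # (1, 2048), fed as "y"
print(out.shape, pooled.shape)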
@@ -5,6 +5,7 @@ from . import utils
 from . import sd1_clip
 from . import sd2_clip
 from . import sdxl_clip
+from . import sd3_clip
 
 from . import supported_models_base
 from . import latent_formats
@@ -488,6 +489,28 @@ class SDXL_instructpix2pix(SDXL):
     def get_model(self, state_dict, prefix="", device=None):
         return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
 
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
+class SD3(supported_models_base.BASE):
+    unet_config = {
+        "in_channels": 16,
+        "pos_embed_scaling_factor": None,
+    }
+
+    sampling_settings = {
+        "shift": 3.0,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.SD3
+    text_encoder_key_prefix = ["text_encoders."] #TODO?
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.SD3(self, device=device)
+        return out
+
+    def clip_target(self):
+        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO?
+
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
 
 models += [SVD_img2vid]
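For context on how the new SD3 entry gets picked: the config detected by model_detection is compared against each class's unet_config when a checkpoint loads. A rough approximation of that matching (the real logic lives in supported_models_base.BASE and may differ in detail):

# Hypothetical sketch: each key declared by the class must equal the
# detected value; extra detected keys are ignored.
def matches(class_unet_config, detected):
    return all(detected.get(k) == v for k, v in class_unet_config.items())

sd3_unet_config = {"in_channels": 16, "pos_embed_scaling_factor": None}
detected = {"in_channels": 16, "pos_embed_scaling_factor": None, "patch_size": 2, "depth": 24}
print(matches(sd3_unet_config, detected))  # True -> SD3 is selected from `models`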
@@ -0,0 +1,231 @@
+import torch
+import math
+from comfy.ldm.modules.attention import optimized_attention_for_device
+
+class T5LayerNorm(torch.nn.Module):
+    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
+        self.variance_epsilon = eps
+
+    def forward(self, x):
+        variance = x.pow(2).mean(-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight.to(device=x.device, dtype=x.dtype) * x
+
+class T5DenseActDense(torch.nn.Module):
+    def __init__(self, model_dim, ff_dim, dtype, device, operations):
+        super().__init__()
+        self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+        self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
+        # self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, x):
+        x = torch.nn.functional.relu(self.wi(x))
+        # x = self.dropout(x)
+        x = self.wo(x)
+        return x
+
+class T5DenseGatedActDense(torch.nn.Module):
+    def __init__(self, model_dim, ff_dim, dtype, device, operations):
+        super().__init__()
+        self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+        self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+        self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
+        # self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, x):
+        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
+        hidden_linear = self.wi_1(x)
+        x = hidden_gelu * hidden_linear
+        # x = self.dropout(x)
+        x = self.wo(x)
+        return x
+
+class T5LayerFF(torch.nn.Module):
+    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
+        super().__init__()
+        if ff_activation == "gelu_pytorch_tanh":
+            self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device, operations)
+        elif ff_activation == "relu":
+            self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, dtype, device, operations)
+
+        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
+        # self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, x):
+        forwarded_states = self.layer_norm(x)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        # x = x + self.dropout(forwarded_states)
+        x += forwarded_states
+        return x
+
+class T5Attention(torch.nn.Module):
+    def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations):
+        super().__init__()
+
+        # Mesh TensorFlow initialization to avoid scaling before softmax
+        self.q = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.k = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.v = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.o = operations.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
+        self.num_heads = num_heads
+
+        self.relative_attention_bias = None
+        if relative_attention_bias:
+            self.relative_attention_num_buckets = 32
+            self.relative_attention_max_distance = 128
+            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
+
+    @staticmethod
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Translate relative position to a bucket number for relative attention. The relative position is defined as
+        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+        This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+        Args:
+            relative_position: an int32 Tensor
+            bidirectional: a boolean - whether the attention is bidirectional
+            num_buckets: an integer
+            max_distance: an integer
+
+        Returns:
+            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+        """
+        relative_buckets = 0
+        if bidirectional:
+            num_buckets //= 2
+            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+            relative_position = torch.abs(relative_position)
+        else:
+            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+        # now relative_position is in the range [0, inf)
+
+        # half of the buckets are for exact increments in positions
+        max_exact = num_buckets // 2
+        is_small = relative_position < max_exact
+
+        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+        relative_position_if_large = max_exact + (
+            torch.log(relative_position.float() / max_exact)
+            / math.log(max_distance / max_exact)
+            * (num_buckets - max_exact)
+        ).to(torch.long)
+        relative_position_if_large = torch.min(
+            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+        return relative_buckets
+
+    def compute_bias(self, query_length, key_length, device):
+        """Compute binned relative position bias"""
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,  # shape (query_length, key_length)
+            bidirectional=True,
+            num_buckets=self.relative_attention_num_buckets,
+            max_distance=self.relative_attention_max_distance,
+        )
+        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+        return values
+
+    def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
+        q = self.q(x)
+        k = self.k(x)
+        v = self.v(x)
+        if self.relative_attention_bias is not None:
+            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
+
+        if past_bias is not None:
+            if mask is not None:
+                mask = mask + past_bias
+            else:
+                mask = past_bias
+
+        out = optimized_attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
+        return self.o(out), past_bias
+
+class T5LayerSelfAttention(torch.nn.Module):
+    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations):
+        super().__init__()
+        self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations)
+        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
+        # self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
+        normed_hidden_states = self.layer_norm(x)
+        output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention)
+        # x = x + self.dropout(attention_output)
+        x += output
+        return x, past_bias
+
+class T5Block(torch.nn.Module):
+    def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias, dtype, device, operations):
+        super().__init__()
+        self.layer = torch.nn.ModuleList()
+        self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
+        self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, dtype, device, operations))
+
+    def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
+        x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
+        x = self.layer[-1](x)
+        return x, past_bias
+
+class T5Stack(torch.nn.Module):
+    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, num_heads, dtype, device, operations):
+        super().__init__()
+
+        self.block = torch.nn.ModuleList(
+            [T5Block(model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
+        )
+        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
+        # self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
+        mask = None
+        if attention_mask is not None:
+            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+
+        intermediate = None
+        optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
+        past_bias = None
+        for i, l in enumerate(self.block):
+            x, past_bias = l(x, mask, past_bias, optimized_attention)
+            if i == intermediate_output:
+                intermediate = x.clone()
+        x = self.final_layer_norm(x)
+        if intermediate is not None and final_layer_norm_intermediate:
+            intermediate = self.final_layer_norm(intermediate)
+        return x, intermediate
+
+class T5(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        self.num_layers = config_dict["num_layers"]
+        model_dim = config_dict["d_model"]
+
+        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["num_heads"], dtype, device, operations)
+        self.dtype = dtype
+        self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, embeddings):
+        self.shared = embeddings
+
+    def forward(self, input_ids, *args, **kwargs):
+        x = self.shared(input_ids)
+        return self.encoder(x, *args, **kwargs)
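The relative-attention bucketing above is easiest to see by calling the staticmethod directly: small offsets get their own exact buckets, larger ones fall into logarithmic bins, and the sign of the offset selects which half of the bucket range is used. A small demo (assuming comfy.t5 from this commit is importable; the offsets are arbitrary):

import torch
from comfy.t5 import T5Attention  # the module added in this commit

rel = torch.tensor([[-128, -16, -4, -1, 0, 1, 4, 16, 128]])  # memory_position - query_position
buckets = T5Attention._relative_position_bucket(rel)  # defaults: bidirectional, 32 buckets, max_distance 128
print(buckets)
# |offset| < num_buckets // 4 maps one-to-one to buckets; beyond that the bins
# grow logarithmically, and |offset| >= max_distance shares the last bucket.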
@@ -0,0 +1,21 @@
+{
+    "d_ff": 3072,
+    "d_kv": 64,
+    "d_model": 768,
+    "decoder_start_token_id": 0,
+    "dropout_rate": 0.1,
+    "eos_token_id": 1,
+    "dense_act_fn": "relu",
+    "initializer_factor": 1.0,
+    "is_encoder_decoder": true,
+    "layer_norm_epsilon": 1e-06,
+    "model_type": "t5",
+    "num_decoder_layers": 12,
+    "num_heads": 12,
+    "num_layers": 12,
+    "output_past": true,
+    "pad_token_id": 0,
+    "relative_attention_num_buckets": 32,
+    "tie_word_embeddings": false,
+    "vocab_size": 32128
+}
@@ -0,0 +1,21 @@
+{
+    "d_ff": 10240,
+    "d_kv": 64,
+    "d_model": 4096,
+    "decoder_start_token_id": 0,
+    "dropout_rate": 0.1,
+    "eos_token_id": 1,
+    "dense_act_fn": "gelu_pytorch_tanh",
+    "initializer_factor": 1.0,
+    "is_encoder_decoder": true,
+    "layer_norm_epsilon": 1e-06,
+    "model_type": "t5",
+    "num_decoder_layers": 24,
+    "num_heads": 64,
+    "num_layers": 24,
+    "output_past": true,
+    "pad_token_id": 0,
+    "relative_attention_num_buckets": 32,
+    "tie_word_embeddings": false,
+    "vocab_size": 32128
+}
@@ -0,0 +1,125 @@
+{
+  "additional_special_tokens": [
+    "<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>",
+    "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>",
+    "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>",
+    "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>",
+    "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>",
+    "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>",
+    "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>",
+    "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>",
+    "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>",
+    "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>",
+    "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>",
+    "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>",
+    "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>",
+    "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>",
+    "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>",
+    "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>",
+    "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>",
+    "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>",
+    "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>",
+    "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"
+  ],
+  "eos_token": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
+  "pad_token": {"content": "<pad>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
+  "unk_token": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false}
+}
File diff suppressed because one or more lines are too long
@@ -0,0 +1,939 @@
+{
+  "added_tokens_decoder": {
+    "0": {"content": "<pad>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "1": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "2": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32000": {"content": "<extra_id_99>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32001": {"content": "<extra_id_98>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32002": {"content": "<extra_id_97>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32003": {"content": "<extra_id_96>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32004": {"content": "<extra_id_95>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32005": {"content": "<extra_id_94>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32006": {"content": "<extra_id_93>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32007": {"content": "<extra_id_92>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32008": {"content": "<extra_id_91>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32009": {"content": "<extra_id_90>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32010": {"content": "<extra_id_89>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32011": {"content": "<extra_id_88>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32012": {"content": "<extra_id_87>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32013": {"content": "<extra_id_86>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32014": {"content": "<extra_id_85>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32015": {"content": "<extra_id_84>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32016": {"content": "<extra_id_83>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32017": {"content": "<extra_id_82>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32018": {"content": "<extra_id_81>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32019": {"content": "<extra_id_80>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32020": {"content": "<extra_id_79>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32021": {"content": "<extra_id_78>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32022": {"content": "<extra_id_77>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32023": {"content": "<extra_id_76>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32024": {"content": "<extra_id_75>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32025": {"content": "<extra_id_74>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32026": {"content": "<extra_id_73>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32027": {"content": "<extra_id_72>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32028": {"content": "<extra_id_71>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32029": {"content": "<extra_id_70>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32030": {"content": "<extra_id_69>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32031": {"content": "<extra_id_68>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32032": {"content": "<extra_id_67>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32033": {"content": "<extra_id_66>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32034": {"content": "<extra_id_65>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32035": {"content": "<extra_id_64>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32036": {"content": "<extra_id_63>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32037": {"content": "<extra_id_62>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32038": {"content": "<extra_id_61>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32039": {"content": "<extra_id_60>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32040": {"content": "<extra_id_59>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32041": {"content": "<extra_id_58>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32042": {"content": "<extra_id_57>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32043": {"content": "<extra_id_56>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32044": {"content": "<extra_id_55>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32045": {"content": "<extra_id_54>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32046": {"content": "<extra_id_53>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32047": {"content": "<extra_id_52>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32048": {"content": "<extra_id_51>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32049": {"content": "<extra_id_50>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32050": {"content": "<extra_id_49>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32051": {"content": "<extra_id_48>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32052": {"content": "<extra_id_47>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32053": {"content": "<extra_id_46>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32054": {"content": "<extra_id_45>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32055": {"content": "<extra_id_44>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32056": {"content": "<extra_id_43>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32057": {"content": "<extra_id_42>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32058": {"content": "<extra_id_41>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32059": {"content": "<extra_id_40>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32060": {"content": "<extra_id_39>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32061": {"content": "<extra_id_38>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32062": {"content": "<extra_id_37>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32063": {"content": "<extra_id_36>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32064": {"content": "<extra_id_35>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32065": {"content": "<extra_id_34>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32066": {"content": "<extra_id_33>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32067": {"content": "<extra_id_32>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32068": {"content": "<extra_id_31>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32069": {"content": "<extra_id_30>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32070": {"content": "<extra_id_29>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32071": {"content": "<extra_id_28>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32072": {"content": "<extra_id_27>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32073": {"content": "<extra_id_26>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32074": {"content": "<extra_id_25>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32075": {"content": "<extra_id_24>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32076": {"content": "<extra_id_23>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32077": {"content": "<extra_id_22>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32078": {"content": "<extra_id_21>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32079": {"content": "<extra_id_20>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32080": {"content": "<extra_id_19>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32081": {"content": "<extra_id_18>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
+    "32082": {"content": "<extra_id_17>", "lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32083": {
|
||||||
|
"content": "<extra_id_16>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32084": {
|
||||||
|
"content": "<extra_id_15>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32085": {
|
||||||
|
"content": "<extra_id_14>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32086": {
|
||||||
|
"content": "<extra_id_13>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32087": {
|
||||||
|
"content": "<extra_id_12>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32088": {
|
||||||
|
"content": "<extra_id_11>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32089": {
|
||||||
|
"content": "<extra_id_10>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32090": {
|
||||||
|
"content": "<extra_id_9>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32091": {
|
||||||
|
"content": "<extra_id_8>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32092": {
|
||||||
|
"content": "<extra_id_7>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32093": {
|
||||||
|
"content": "<extra_id_6>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32094": {
|
||||||
|
"content": "<extra_id_5>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32095": {
|
||||||
|
"content": "<extra_id_4>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32096": {
|
||||||
|
"content": "<extra_id_3>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32097": {
|
||||||
|
"content": "<extra_id_2>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32098": {
|
||||||
|
"content": "<extra_id_1>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32099": {
|
||||||
|
"content": "<extra_id_0>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<extra_id_0>",
|
||||||
|
"<extra_id_1>",
|
||||||
|
"<extra_id_2>",
|
||||||
|
"<extra_id_3>",
|
||||||
|
"<extra_id_4>",
|
||||||
|
"<extra_id_5>",
|
||||||
|
"<extra_id_6>",
|
||||||
|
"<extra_id_7>",
|
||||||
|
"<extra_id_8>",
|
||||||
|
"<extra_id_9>",
|
||||||
|
"<extra_id_10>",
|
||||||
|
"<extra_id_11>",
|
||||||
|
"<extra_id_12>",
|
||||||
|
"<extra_id_13>",
|
||||||
|
"<extra_id_14>",
|
||||||
|
"<extra_id_15>",
|
||||||
|
"<extra_id_16>",
|
||||||
|
"<extra_id_17>",
|
||||||
|
"<extra_id_18>",
|
||||||
|
"<extra_id_19>",
|
||||||
|
"<extra_id_20>",
|
||||||
|
"<extra_id_21>",
|
||||||
|
"<extra_id_22>",
|
||||||
|
"<extra_id_23>",
|
||||||
|
"<extra_id_24>",
|
||||||
|
"<extra_id_25>",
|
||||||
|
"<extra_id_26>",
|
||||||
|
"<extra_id_27>",
|
||||||
|
"<extra_id_28>",
|
||||||
|
"<extra_id_29>",
|
||||||
|
"<extra_id_30>",
|
||||||
|
"<extra_id_31>",
|
||||||
|
"<extra_id_32>",
|
||||||
|
"<extra_id_33>",
|
||||||
|
"<extra_id_34>",
|
||||||
|
"<extra_id_35>",
|
||||||
|
"<extra_id_36>",
|
||||||
|
"<extra_id_37>",
|
||||||
|
"<extra_id_38>",
|
||||||
|
"<extra_id_39>",
|
||||||
|
"<extra_id_40>",
|
||||||
|
"<extra_id_41>",
|
||||||
|
"<extra_id_42>",
|
||||||
|
"<extra_id_43>",
|
||||||
|
"<extra_id_44>",
|
||||||
|
"<extra_id_45>",
|
||||||
|
"<extra_id_46>",
|
||||||
|
"<extra_id_47>",
|
||||||
|
"<extra_id_48>",
|
||||||
|
"<extra_id_49>",
|
||||||
|
"<extra_id_50>",
|
||||||
|
"<extra_id_51>",
|
||||||
|
"<extra_id_52>",
|
||||||
|
"<extra_id_53>",
|
||||||
|
"<extra_id_54>",
|
||||||
|
"<extra_id_55>",
|
||||||
|
"<extra_id_56>",
|
||||||
|
"<extra_id_57>",
|
||||||
|
"<extra_id_58>",
|
||||||
|
"<extra_id_59>",
|
||||||
|
"<extra_id_60>",
|
||||||
|
"<extra_id_61>",
|
||||||
|
"<extra_id_62>",
|
||||||
|
"<extra_id_63>",
|
||||||
|
"<extra_id_64>",
|
||||||
|
"<extra_id_65>",
|
||||||
|
"<extra_id_66>",
|
||||||
|
"<extra_id_67>",
|
||||||
|
"<extra_id_68>",
|
||||||
|
"<extra_id_69>",
|
||||||
|
"<extra_id_70>",
|
||||||
|
"<extra_id_71>",
|
||||||
|
"<extra_id_72>",
|
||||||
|
"<extra_id_73>",
|
||||||
|
"<extra_id_74>",
|
||||||
|
"<extra_id_75>",
|
||||||
|
"<extra_id_76>",
|
||||||
|
"<extra_id_77>",
|
||||||
|
"<extra_id_78>",
|
||||||
|
"<extra_id_79>",
|
||||||
|
"<extra_id_80>",
|
||||||
|
"<extra_id_81>",
|
||||||
|
"<extra_id_82>",
|
||||||
|
"<extra_id_83>",
|
||||||
|
"<extra_id_84>",
|
||||||
|
"<extra_id_85>",
|
||||||
|
"<extra_id_86>",
|
||||||
|
"<extra_id_87>",
|
||||||
|
"<extra_id_88>",
|
||||||
|
"<extra_id_89>",
|
||||||
|
"<extra_id_90>",
|
||||||
|
"<extra_id_91>",
|
||||||
|
"<extra_id_92>",
|
||||||
|
"<extra_id_93>",
|
||||||
|
"<extra_id_94>",
|
||||||
|
"<extra_id_95>",
|
||||||
|
"<extra_id_96>",
|
||||||
|
"<extra_id_97>",
|
||||||
|
"<extra_id_98>",
|
||||||
|
"<extra_id_99>"
|
||||||
|
],
|
||||||
|
"clean_up_tokenization_spaces": true,
|
||||||
|
"eos_token": "</s>",
|
||||||
|
"extra_ids": 100,
|
||||||
|
"legacy": false,
|
||||||
|
"model_max_length": 512,
|
||||||
|
"pad_token": "<pad>",
|
||||||
|
"sp_model_kwargs": {},
|
||||||
|
"tokenizer_class": "T5Tokenizer",
|
||||||
|
"unk_token": "<unk>"
|
||||||
|
}
|
|
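The JSON above is the stock T5 tokenizer configuration: the "<extra_id_*>" entries (id 32051 maps to "<extra_id_48>", counting down to id 32099 for "<extra_id_0>") are T5's reserved sentinel tokens, registered both per id and in additional_special_tokens. A minimal loading sketch, assuming the file is saved as tokenizer_config.json next to the matching SentencePiece model; the directory name is a placeholder:

from transformers import T5Tokenizer  # needs the sentencepiece package installed

# "t5_tokenizer/" is a hypothetical directory holding this config plus spiece.model.
tok = T5Tokenizer.from_pretrained("t5_tokenizer")
print(tok.convert_tokens_to_ids("<extra_id_0>"))  # 32099, per the mapping above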
@@ -132,6 +132,32 @@ class ModelSamplingStableCascade:
         m.add_object_patch("model_sampling", model_sampling)
         return (m, )
 
+class ModelSamplingSD3:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01}),
+                              }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, shift):
+        m = model.clone()
+
+        sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
+        sampling_type = comfy.model_sampling.CONST
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(shift=shift)
+        m.add_object_patch("model_sampling", model_sampling)
+        return (m, )
+
 class ModelSamplingContinuousEDM:
     @classmethod
     def INPUT_TYPES(s):

@@ -213,5 +239,6 @@ NODE_CLASS_MAPPINGS = {
     "ModelSamplingDiscrete": ModelSamplingDiscrete,
     "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
     "ModelSamplingStableCascade": ModelSamplingStableCascade,
+    "ModelSamplingSD3": ModelSamplingSD3,
     "RescaleCFG": RescaleCFG,
 }
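ModelSamplingSD3 patches the model with a discrete-flow sampling object whose only knob is the shift. The schedule itself lives in comfy/model_sampling.py rather than in this hunk; the following is a plausible sketch of the usual rectified-flow timestep shift, offered as an assumption rather than a transcription of that file:

def flow_shift_sigma(t, shift=1.0):
    # Assumed form of the discrete-flow noise schedule: t in (0, 1],
    # shift=1.0 is the identity, and larger shifts keep more of the
    # sampler's steps at high noise levels (useful at high resolutions).
    return shift * t / (1.0 + (shift - 1.0) * t)

print(flow_shift_sigma(0.5, shift=3.0))  # 0.75: the schedule midpoint moves toward the noisy end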
@@ -0,0 +1,87 @@
+import folder_paths
+import comfy.sd
+import comfy.model_management
+import nodes
+import torch
+
+class TripleCLIPLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), "clip_name3": (folder_paths.get_filename_list("clip"), )
+                             }}
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "load_clip"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_clip(self, clip_name1, clip_name2, clip_name3):
+        clip_path1 = folder_paths.get_full_path("clip", clip_name1)
+        clip_path2 = folder_paths.get_full_path("clip", clip_name2)
+        clip_path3 = folder_paths.get_full_path("clip", clip_name3)
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return (clip,)
+
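TripleCLIPLoader exists because SD3 conditions on three text encoders at once (CLIP-L, CLIP-G and T5-XXL); comfy.sd.load_clip is handed all three checkpoints and returns one combined CLIP object. A hypothetical invocation, with file names standing in for whatever sits in the clip models folder:

loader = TripleCLIPLoader()
# All three names below are placeholders, not files shipped by this commit.
clip, = loader.load_clip("clip_l.safetensors", "clip_g.safetensors", "t5xxl.safetensors")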
+class EmptySD3LatentImage:
+    def __init__(self):
+        self.device = comfy.model_management.intermediate_device()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                              "height": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "generate"
+
+    CATEGORY = "latent/sd3"
+
+    def generate(self, width, height, batch_size=1):
+        latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
+        return ({"samples":latent}, )
+
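Note that generate() uses 16 latent channels and fills the latent with 0.0609 rather than zeros. That constant matches the shift factor of the SD3 latent format introduced elsewhere in this commit, so after the model-side normalization (z - shift) * scale the "empty" latent is exactly zero. A quick check; the scale value 1.5305 is SD3's published companion constant and is assumed here, not defined in this file:

import torch

scale, shift = 1.5305, 0.0609                 # SD3 latent normalization constants
empty = torch.full([1, 16, 64, 64], shift)    # what generate() builds for 512x512
assert torch.equal((empty - shift) * scale, torch.zeros_like(empty))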
+class CLIPTextEncodeSD3:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+            "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+            "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+            "empty_padding": (["none", "empty_prompt"], )
+            }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning"
+
+    def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
+        no_padding = empty_padding == "none"
+
+        tokens = clip.tokenize(clip_g)
+        if len(clip_g) == 0 and no_padding:
+            tokens["g"] = []
+
+        if len(clip_l) == 0 and no_padding:
+            tokens["l"] = []
+        else:
+            tokens["l"] = clip.tokenize(clip_l)["l"]
+
+        if len(t5xxl) == 0 and no_padding:
+            tokens["t5xxl"] = []
+        else:
+            tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
+
+        if len(tokens["l"]) != len(tokens["g"]):
+            empty = clip.tokenize("")
+            while len(tokens["l"]) < len(tokens["g"]):
+                tokens["l"] += empty["l"]
+            while len(tokens["l"]) > len(tokens["g"]):
+                tokens["g"] += empty["g"]
+
+        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
+        return ([[cond, {"pooled_output": pooled}]], )
+
+
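The while-loops in encode() balance the clip_l and clip_g token batches: whichever prompt tokenized to fewer chunks is padded with empty-prompt chunks, since the two CLIP streams are consumed side by side and must line up. A toy illustration of the same balancing, with plain lists standing in for token batches:

l_batches = [["a"]]                      # clip_l tokenized to 1 chunk
g_batches = [["b"], ["c"]]               # clip_g tokenized to 2 chunks
empty_l, empty_g = [["<l>"]], [["<g>"]]  # stand-ins for clip.tokenize("")

while len(l_batches) < len(g_batches):
    l_batches += empty_l
while len(l_batches) > len(g_batches):
    g_batches += empty_g
assert len(l_batches) == len(g_batches) == 2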
+NODE_CLASS_MAPPINGS = {
+    "TripleCLIPLoader": TripleCLIPLoader,
+    "EmptySD3LatentImage": EmptySD3LatentImage,
+    "CLIPTextEncodeSD3": CLIPTextEncodeSD3,
+}
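Taken together, a minimal SD3 graph wires TripleCLIPLoader into CLIPTextEncodeSD3 for conditioning and EmptySD3LatentImage for the starting latent, with ModelSamplingSD3 patching the sampling schedule. Continuing the loader sketch above ("clip" already loaded; prompts are placeholders):

cond, = CLIPTextEncodeSD3().encode(clip, clip_l="a photo of a cat",
                                   clip_g="a photo of a cat",
                                   t5xxl="a photo of a cat",
                                   empty_padding="none")
latent, = EmptySD3LatentImage().generate(1024, 1024, batch_size=1)
# cond and latent then feed a regular KSampler on a model patched by ModelSamplingSD3.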