Add support for the stable diffusion x4 upscaling model.
This is an old model. Load the checkpoint like a regular one and use the new SD_4XUpscale_Conditioning node.
This commit is contained in:
parent
2c4e92a98b
commit
a7874d1a8b
|
@ -33,3 +33,7 @@ class SDXL(LatentFormat):
|
|||
[-0.3112, -0.2359, -0.2076]
|
||||
]
|
||||
self.taesd_decoder_name = "taesdxl_decoder"
|
||||
|
||||
class SD_X4(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.08333
|
||||
|
|
|
@ -364,3 +364,24 @@ class Stable_Zero123(BaseModel):
|
|||
cross_attn = self.cc_projection(cross_attn)
|
||||
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
|
||||
return out
|
||||
|
||||
class SD_X4Upscaler(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
image = kwargs.get("concat_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros_like(noise)[:,:3]
|
||||
|
||||
if image.shape[1:] != noise.shape[1:]:
|
||||
image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
||||
return out
|
||||
|
|
|
@ -174,6 +174,11 @@ class VAE:
|
|||
else:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
|
||||
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
|
||||
ddconfig['ch_mult'] = [1, 2, 4]
|
||||
self.downscale_ratio = 4
|
||||
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
else:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
|
|
|
@ -278,6 +278,32 @@ class Stable_Zero123(supported_models_base.BASE):
|
|||
def clip_target(self):
|
||||
return None
|
||||
|
||||
class SD_X4Upscaler(SD20):
|
||||
unet_config = {
|
||||
"context_dim": 1024,
|
||||
"model_channels": 256,
|
||||
'in_channels': 7,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega]
|
||||
unet_extra_config = {
|
||||
"disable_self_attentions": [True, True, True, False],
|
||||
"num_heads": 8,
|
||||
"num_head_channels": -1,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SD_X4
|
||||
|
||||
sampling_settings = {
|
||||
"linear_start": 0.0001,
|
||||
"linear_end": 0.02,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SD_X4Upscaler(self, device=device)
|
||||
return out
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
|
||||
models += [SVD_img2vid]
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
import torch
|
||||
import nodes
|
||||
import comfy.utils
|
||||
|
||||
class SD_4XUpscale_Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "images": ("IMAGE",),
|
||||
"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
# "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), #TODO
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/upscale_diffusion"
|
||||
|
||||
def encode(self, images, positive, negative, scale_ratio):
|
||||
width = max(1, round(images.shape[-2] * scale_ratio))
|
||||
height = max(1, round(images.shape[-3] * scale_ratio))
|
||||
|
||||
pixels = comfy.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
|
||||
|
||||
out_cp = []
|
||||
out_cn = []
|
||||
|
||||
for t in positive:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['concat_image'] = pixels
|
||||
out_cp.append(n)
|
||||
|
||||
for t in negative:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['concat_image'] = pixels
|
||||
out_cn.append(n)
|
||||
|
||||
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
|
||||
return (out_cp, out_cn, {"samples":latent})
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
|
||||
}
|
Loading…
Reference in New Issue