Support loading the Stable Cascade effnet and previewer as a VAE.
The effnet can be used to encode images for img2img with Stage C.
parent 2e4628ac8d
commit 3b2e579926
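
For context, a minimal sketch of how the effnet-as-VAE path might be used for img2img from Python. This is not part of the commit; the checkpoint filename and image size are assumptions, while comfy.utils.load_torch_file and comfy.sd.VAE are the existing APIs the change plugs into:

import torch
import comfy.utils
import comfy.sd

# Hypothetical effnet checkpoint path; its state dict is detected as the effnet
# encoder by its keys (see the sd.py hunks below) and wrapped in StageC_coder.
sd = comfy.utils.load_torch_file("effnet_encoder.safetensors")
vae = comfy.sd.VAE(sd=sd)

pixels = torch.rand(1, 1024, 1024, 3)   # [batch, height, width, channel] in 0..1
stage_c_latent = vae.encode(pixels)     # 16-channel latent, 32x smaller spatially
print(stage_c_latent.shape)             # torch.Size([1, 16, 32, 32])
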
comfy/sd.py | 42
@@ -4,6 +4,7 @@ from enum import Enum
 from comfy import model_management
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 from .ldm.cascade.stage_a import StageA
+from .ldm.cascade.stage_c_coder import StageC_coder

 import yaml
@@ -158,6 +159,7 @@ class VAE:
         self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
         self.downscale_ratio = 8
+        self.upscale_ratio = 8
         self.latent_channels = 4
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
@@ -176,11 +178,31 @@ class VAE:
             elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
                 self.first_stage_model = StageA()
                 self.downscale_ratio = 4
+                self.upscale_ratio = 4
                 #TODO
                 #self.memory_used_encode
                 #self.memory_used_decode
                 self.process_input = lambda image: image
                 self.process_output = lambda image: image
+            elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
+                self.first_stage_model = StageC_coder()
+                self.downscale_ratio = 32
+                self.latent_channels = 16
+                new_sd = {}
+                for k in sd:
+                    new_sd["encoder.{}".format(k)] = sd[k]
+                sd = new_sd
+            elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
+                self.first_stage_model = StageC_coder()
+                self.latent_channels = 16
+                new_sd = {}
+                for k in sd:
+                    new_sd["previewer.{}".format(k)] = sd[k]
+                sd = new_sd
+            elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
+                self.first_stage_model = StageC_coder()
+                self.downscale_ratio = 32
+                self.latent_channels = 16
             else:
                 #default SD1.x/SD2.x VAE parameters
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
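
As an aside, the elif chain above decides which Stable Cascade component was loaded purely from tell-tale state dict keys, then re-prefixes the keys so a bare effnet or previewer checkpoint fits the combined StageC_coder module. A hedged restatement of that idea as a standalone helper (detect_cascade_coder is a made-up name; the key strings are the ones the diff checks for):

def detect_cascade_coder(sd):
    # Bare effnet encoder checkpoint: keys start at "backbone.", so prefix with "encoder."
    if "backbone.1.0.block.0.1.num_batches_tracked" in sd:
        return "encoder", {"encoder.{}".format(k): v for k, v in sd.items()}
    # Bare previewer decoder checkpoint: keys start at "blocks.", so prefix with "previewer."
    if "blocks.11.num_batches_tracked" in sd:
        return "previewer", {"previewer.{}".format(k): v for k, v in sd.items()}
    # Already-combined checkpoint: keys are prefixed, load as-is
    if "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd:
        return "combined", sd
    return None, sd
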
@@ -188,6 +210,7 @@ class VAE:
                 if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
                     ddconfig['ch_mult'] = [1, 2, 4]
                     self.downscale_ratio = 4
+                    self.upscale_ratio = 4

                 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
         else:
@@ -213,6 +236,15 @@ class VAE:

         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)

+    def vae_encode_crop_pixels(self, pixels):
+        x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
+        y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
+        if pixels.shape[1] != x or pixels.shape[2] != y:
+            x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
+            y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
+            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+        return pixels
+
     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
         steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
         steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
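
The vae_encode_crop_pixels method added above rounds the incoming pixel height and width down to a multiple of downscale_ratio and center-crops before encoding, which matters for the effnet's ratio of 32 (the old node-level helper hardcoded 8). A worked example with assumed input sizes:

ratio = 32                    # effnet downscale_ratio; classic SD VAEs use 8
h, w = 518, 770               # example pixel height/width
x = (h // ratio) * ratio      # 512
y = (w // ratio) * ratio      # 768
x_offset = (h % ratio) // 2   # 3  -> rows 3:515 are kept
y_offset = (w % ratio) // 2   # 1  -> cols 1:769 are kept
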
@@ -221,9 +253,9 @@ class VAE:

         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
         output = self.process_output(
-            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
+            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
             / 3.0)
         return output
@@ -248,7 +280,7 @@ class VAE:
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)

-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                 pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
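
The two decode hunks above swap downscale_ratio for the new upscale_ratio when sizing decoder output: encoding shrinks pixels by downscale_ratio while decoding enlarges latents by upscale_ratio, and keeping them as separate attributes keeps the decode path correct for models where the two need not match. Illustrative shape bookkeeping (numbers assumed, mirroring the pre-allocation in decode):

import torch

upscale_ratio = 8                                   # AutoencoderKL default; StageA uses 4
samples_in = torch.zeros(1, 4, 64, 64)              # latent batch
pixel_samples = torch.empty(
    (samples_in.shape[0], 3,
     round(samples_in.shape[2] * upscale_ratio),    # 512
     round(samples_in.shape[3] * upscale_ratio)))   # 512
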
@@ -265,6 +297,7 @@ class VAE:
         return output.movedim(1,-1)

     def encode(self, pixel_samples):
+        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1,1)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
@@ -284,6 +317,7 @@ class VAE:
         return samples

     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
+        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         model_management.load_model_gpu(self.patcher)
         pixel_samples = pixel_samples.movedim(-1,1)
         samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
nodes.py | 20
@@ -309,18 +309,7 @@ class VAEEncode:
     CATEGORY = "latent"

-    @staticmethod
-    def vae_encode_crop_pixels(pixels):
-        x = (pixels.shape[1] // 8) * 8
-        y = (pixels.shape[2] // 8) * 8
-        if pixels.shape[1] != x or pixels.shape[2] != y:
-            x_offset = (pixels.shape[1] % 8) // 2
-            y_offset = (pixels.shape[2] % 8) // 2
-            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
-        return pixels
-
     def encode(self, vae, pixels):
-        pixels = self.vae_encode_crop_pixels(pixels)
         t = vae.encode(pixels[:,:,:,:3])
         return ({"samples":t}, )
@@ -336,7 +325,6 @@ class VAEEncodeTiled:
     CATEGORY = "_for_testing"

     def encode(self, vae, pixels, tile_size):
-        pixels = VAEEncode.vae_encode_crop_pixels(pixels)
         t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
         return ({"samples":t}, )
@@ -350,14 +338,14 @@ class VAEEncodeForInpaint:
     CATEGORY = "latent/inpaint"

     def encode(self, vae, pixels, mask, grow_mask_by=6):
-        x = (pixels.shape[1] // 8) * 8
-        y = (pixels.shape[2] // 8) * 8
+        x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
+        y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")

         pixels = pixels.clone()
         if pixels.shape[1] != x or pixels.shape[2] != y:
-            x_offset = (pixels.shape[1] % 8) // 2
-            y_offset = (pixels.shape[2] % 8) // 2
+            x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
+            y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
             pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
             mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]