Added an experimental VAEDecodeTiled.
This decodes the image with the VAE in tiles, which should be faster and use less VRAM. It's in the _for_testing section, so I might change or remove it, or even add the functionality to the regular VAEDecode node, depending on how well it performs. In other words: don't depend on it too much.
This commit is contained in:
parent
5796705cc6
commit
87b00b37f6
31
comfy/sd.py
31
comfy/sd.py
|
@ -318,6 +318,37 @@ class VAE:
|
|||
pixel_samples = pixel_samples.cpu().movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=8):
    """Decode latents tile by tile and blend the tiles into full images.

    Each batch item is cut into overlapping latent tiles, every tile is
    decoded on ``self.device`` and the decoded tiles are accumulated on the
    CPU with a linear feather ramp on each tile edge. The accumulated image
    is finally divided by the accumulated weights, so overlapping regions
    become a weighted average of the tiles covering them.

    Args:
        samples: 4-D latent tensor indexed as (batch, channels, height,
            width). The decoder is assumed to upscale height and width by
            a factor of 8 and to emit 3 channels -- TODO confirm for
            non-default VAEs.
        tile_x: Tile width in latent units.
        tile_y: Tile height in latent units.
        overlap: Number of latent rows/columns shared by neighbouring
            tiles; the feather ramp is ``overlap * 8`` pixels wide.

    Returns:
        CPU tensor of shape (batch, height * 8, width * 8, 3) with values
        clamped to [0, 1].

    Raises:
        ValueError: If ``overlap`` is not smaller than both tile sizes
            (the tile stride would be zero or negative).
    """
    stride_x = tile_x - overlap
    stride_y = tile_y - overlap
    if stride_x <= 0 or stride_y <= 0:
        raise ValueError("overlap must be smaller than both tile_x and tile_y")

    feather = overlap * 8
    model_management.unload_model()
    output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
    self.first_stage_model = self.first_stage_model.to(self.device)
    try:
        for b in range(samples.shape[0]):
            s = samples[b:b+1]
            # Weighted pixel sum and the matching sum of weights.
            out = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
            out_div = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
            for y in range(0, s.shape[2], stride_y):
                for x in range(0, s.shape[3], stride_x):
                    # Slicing clamps at the border, so edge tiles may be smaller.
                    s_in = s[:, :, y:y+tile_y, x:x+tile_x]

                    pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * s_in.to(self.device))
                    pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
                    ps = pixel_samples.cpu()

                    mask = torch.ones_like(ps)
                    rows = mask.shape[2]
                    cols = mask.shape[3]
                    for t in range(feather):
                        weight = (1.0 / feather) * (t + 1)
                        # Guard on the tile size: for tiles narrower than the
                        # feather width the trailing-edge slice would get
                        # negative indices and wrap around, ramping the same
                        # rows/columns several times.
                        if t < rows:
                            mask[:, :, t:t+1, :] *= weight
                            mask[:, :, rows-1-t:rows-t, :] *= weight
                        if t < cols:
                            mask[:, :, :, t:t+1] *= weight
                            mask[:, :, :, cols-1-t:cols-t] *= weight

                    out[:, :, y*8:(y+tile_y)*8, x*8:(x+tile_x)*8] += ps * mask
                    out_div[:, :, y*8:(y+tile_y)*8, x*8:(x+tile_x)*8] += mask

            # Mask weights are strictly positive, so every pixel has a non-zero divisor.
            output[b:b+1] = out / out_div
    finally:
        # Always move the VAE back to the CPU, even if decoding failed
        # (e.g. out of memory), so it does not stay resident on the device.
        self.first_stage_model = self.first_stage_model.cpu()
    return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
model_management.unload_model()
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
|
|
16
nodes.py
16
nodes.py
|
@ -106,6 +106,21 @@ class VAEDecode:
|
|||
def decode(self, vae, samples):
    """Decode the tensor stored under ``samples["samples"]`` with ``vae``.

    Returns a 1-tuple holding whatever ``vae.decode`` returns.
    """
    latent = samples["samples"]
    decoded = vae.decode(latent)
    return (decoded,)
|
||||
|
||||
class VAEDecodeTiled:
    """Experimental node that decodes a latent through ``vae.decode_tiled``."""

    # Node metadata read from the class by the surrounding framework.
    CATEGORY = "_for_testing"
    FUNCTION = "decode"
    RETURN_TYPES = ("IMAGE",)

    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the required inputs: a LATENT ("samples") and a VAE ("vae")."""
        required = {
            "samples": ("LATENT", ),
            "vae": ("VAE", ),
        }
        return {"required": required}

    def decode(self, vae, samples):
        """Return a 1-tuple with the tiled decode of ``samples["samples"]``."""
        latent = samples["samples"]
        return (vae.decode_tiled(latent), )
|
||||
|
||||
class VAEEncode:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
@ -789,6 +804,7 @@ NODE_CLASS_MAPPINGS = {
|
|||
"ControlNetApply": ControlNetApply,
|
||||
"ControlNetLoader": ControlNetLoader,
|
||||
"DiffControlNetLoader": DiffControlNetLoader,
|
||||
"VAEDecodeTiled": VAEDecodeTiled,
|
||||
}
|
||||
|
||||
# Path of the "custom_nodes" directory located next to this file; realpath
# resolves symlinks in this file's path before the directory name is taken.
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
|
||||
|
|
Loading…
Reference in New Issue