Basic tiled decoding for audio VAE.
This commit is contained in:
parent
379ff92e9e
commit
a45df69570
16
comfy/sd.py
16
comfy/sd.py
|
@ -298,6 +298,17 @@ class VAE:
|
||||||
/ 3.0)
|
/ 3.0)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def decode_tiled_1d(self, samples, tile_x=128, overlap=64):
|
||||||
|
output = torch.empty((samples.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples.shape[2:])), device=self.output_device)
|
||||||
|
|
||||||
|
for j in range(samples.shape[0]):
|
||||||
|
for i in range(0, samples.shape[-1], tile_x - overlap):
|
||||||
|
f = i
|
||||||
|
t = i + tile_x
|
||||||
|
output[j:j+1,:,f * self.upscale_ratio:t * self.upscale_ratio] = self.first_stage_model.decode(samples[j:j+1,:,f:t].to(self.vae_dtype).to(self.device)).float()
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||||
|
@ -325,7 +336,10 @@ class VAE:
|
||||||
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||||
pixel_samples = self.decode_tiled_(samples_in)
|
if len(samples_in.shape) == 3:
|
||||||
|
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||||
|
else:
|
||||||
|
pixel_samples = self.decode_tiled_(samples_in)
|
||||||
|
|
||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
|
|
Loading…
Reference in New Issue