Split the batch in VAEEncode if there's not enough memory.

This commit is contained in:
comfyanonymous 2023-06-12 00:21:50 -04:00
parent c069fc0730
commit f8c5931053
1 changed files with 4 additions and 1 deletions

View File

@ -589,11 +589,14 @@ class VAE:
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
try:
batch_number = 1
free_memory = model_management.get_free_memory(self.device)
batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples)