From a84cd0d1ad7a641e56ab899b206f3e40e841b2a3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 8 Feb 2023 03:17:54 -0500 Subject: [PATCH] Don't unload/reload model from CPU uselessly. --- comfy/model_management.py | 26 +++++++++++++++++ comfy/sd.py | 3 ++ nodes.py | 61 +++++++++++++++++++-------------------- 3 files changed, 58 insertions(+), 32 deletions(-) create mode 100644 comfy/model_management.py diff --git a/comfy/model_management.py b/comfy/model_management.py new file mode 100644 index 00000000..3e098124 --- /dev/null +++ b/comfy/model_management.py @@ -0,0 +1,26 @@ + + +current_loaded_model = None + + +def unload_model(): + global current_loaded_model + if current_loaded_model is not None: + current_loaded_model.model.cpu() + current_loaded_model.unpatch_model() + current_loaded_model = None + + +def load_model_gpu(model): + global current_loaded_model + if model is current_loaded_model: + return + unload_model() + try: + real_model = model.patch_model() + except Exception as e: + model.unpatch_model() + raise e + current_loaded_model = model + real_model.cuda() + return current_loaded_model diff --git a/comfy/sd.py b/comfy/sd.py index 300ceaef..a3c0066d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -2,6 +2,7 @@ import torch import sd1_clip import sd2_clip +import model_management from ldm.util import instantiate_from_config from ldm.models.autoencoder import AutoencoderKL from omegaconf import OmegaConf @@ -304,6 +305,7 @@ class VAE: self.device = device def decode(self, samples): + model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) samples = samples.to(self.device) pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples) @@ -313,6 +315,7 @@ class VAE: return pixel_samples def encode(self, pixel_samples): + model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1).to(self.device) samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor diff --git a/nodes.py b/nodes.py index eea66b45..698c06ce 100644 --- a/nodes.py +++ b/nodes.py @@ -15,6 +15,7 @@ sys.path.append(os.path.join(sys.path[0], "comfy")) import comfy.samplers import comfy.sd +import model_management supported_ckpt_extensions = ['.ckpt'] supported_pt_extensions = ['.ckpt', '.pt', '.bin'] @@ -353,43 +354,39 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu") real_model = None - try: + if device != "cpu": + model_management.load_model_gpu(model) + real_model = model.model + else: + #TODO: cpu support real_model = model.patch_model() - real_model.to(device) - noise = noise.to(device) - latent_image = latent_image.to(device) + noise = noise.to(device) + latent_image = latent_image.to(device) - positive_copy = [] - negative_copy = [] + positive_copy = [] + negative_copy = [] - for p in positive: - t = p[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) - t = t.to(device) - positive_copy += [[t] + p[1:]] - for n in negative: - t = n[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) - t = t.to(device) - negative_copy += [[t] + n[1:]] + for p in positive: + t = p[0] + if t.shape[0] < noise.shape[0]: + t = torch.cat([t] * noise.shape[0]) + t = t.to(device) + positive_copy += [[t] + p[1:]] + for n in negative: + t = n[0] + if t.shape[0] < noise.shape[0]: + t = torch.cat([t] * noise.shape[0]) + t = t.to(device) + negative_copy += [[t] + n[1:]] - if sampler_name in comfy.samplers.KSampler.SAMPLERS: - sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) - else: - #other samplers - pass + if sampler_name in comfy.samplers.KSampler.SAMPLERS: + sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) + else: + #other samplers + pass - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise) - samples = samples.cpu() - real_model.cpu() - model.unpatch_model() - except Exception as e: - if real_model is not None: - real_model.cpu() - model.unpatch_model() - raise e + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise) + samples = samples.cpu() return (samples, )