From 89a0767abf019817a73ad9c7a693a2efcff75b12 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Aug 2023 01:06:34 -0400 Subject: [PATCH] Smarter memory management. Try to keep models on the vram when possible. Better lowvram mode for controlnets. --- comfy/gligen.py | 35 +---- comfy/model_management.py | 280 ++++++++++++++++++++++---------------- comfy/sample.py | 19 ++- comfy/samplers.py | 4 +- comfy/sd.py | 59 +++++--- execution.py | 1 + 6 files changed, 230 insertions(+), 168 deletions(-) diff --git a/comfy/gligen.py b/comfy/gligen.py index 90558785..8d182839 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -244,30 +244,15 @@ class Gligen(nn.Module): self.position_net = position_net self.key_dim = key_dim self.max_objs = 30 - self.lowvram = False + self.current_device = torch.device("cpu") def _set_position(self, boxes, masks, positive_embeddings): - if self.lowvram == True: - self.position_net.to(boxes.device) - objs = self.position_net(boxes, masks, positive_embeddings) - - if self.lowvram == True: - self.position_net.cpu() - def func_lowvram(x, extra_options): - key = extra_options["transformer_index"] - module = self.module_list[key] - module.to(x.device) - r = module(x, objs) - module.cpu() - return r - return func_lowvram - else: - def func(x, extra_options): - key = extra_options["transformer_index"] - module = self.module_list[key] - return module(x, objs) - return func + def func(x, extra_options): + key = extra_options["transformer_index"] + module = self.module_list[key] + return module(x, objs) + return func def set_position(self, latent_image_shape, position_params, device): batch, c, h, w = latent_image_shape @@ -312,14 +297,6 @@ class Gligen(nn.Module): masks.to(device), conds.to(device)) - def set_lowvram(self, value=True): - self.lowvram = value - - def cleanup(self): - self.lowvram = False - - def get_models(self): - return [self] def load_gligen(sd): sd_k = sd.keys() diff --git a/comfy/model_management.py b/comfy/model_management.py index 4dd15b41..3736b57a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -2,6 +2,7 @@ import psutil from enum import Enum from comfy.cli_args import args import torch +import sys class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -221,132 +222,161 @@ except: print("Could not pick default device.") -current_loaded_model = None -current_gpu_controlnets = [] +current_loaded_models = [] -model_accelerated = False +class LoadedModel: + def __init__(self, model): + self.model = model + self.model_accelerated = False + self.device = model.load_device + def model_memory(self): + return self.model.model_size() -def unload_model(): - global current_loaded_model - global model_accelerated - global current_gpu_controlnets - global vram_state + def model_memory_required(self, device): + if device == self.model.current_device: + return 0 + else: + return self.model_memory() - if current_loaded_model is not None: - if model_accelerated: - accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model) - model_accelerated = False + def model_load(self, lowvram_model_memory=0): + patch_model_to = None + if lowvram_model_memory == 0: + patch_model_to = self.device - current_loaded_model.unpatch_model() - current_loaded_model.model.to(current_loaded_model.offload_device) - current_loaded_model.model_patches_to(current_loaded_model.offload_device) - current_loaded_model = None - if vram_state != VRAMState.HIGH_VRAM: - soft_empty_cache() + self.model.model_patches_to(self.device) + self.model.model_patches_to(self.model.model_dtype()) - if vram_state != VRAMState.HIGH_VRAM: - if len(current_gpu_controlnets) > 0: - for n in current_gpu_controlnets: - n.cpu() - current_gpu_controlnets = [] + try: + self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU + except Exception as e: + self.model.unpatch_model(self.model.offload_device) + self.model_unload() + raise e + + if lowvram_model_memory > 0: + print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024)) + device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) + accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device) + self.model_accelerated = True + + return self.real_model + + def model_unload(self): + if self.model_accelerated: + accelerate.hooks.remove_hook_from_submodules(self.real_model) + self.model_accelerated = False + + self.model.unpatch_model(self.model.offload_device) + self.model.model_patches_to(self.model.offload_device) + + def __eq__(self, other): + return self.model is other.model def minimum_inference_memory(): - return (768 * 1024 * 1024) + return (1024 * 1024 * 1024) + +def unload_model_clones(model): + to_unload = [] + for i in range(len(current_loaded_models)): + if model.is_clone(current_loaded_models[i].model): + to_unload = [i] + to_unload + + for i in to_unload: + print("unload clone", i) + current_loaded_models.pop(i).model_unload() + +def free_memory(memory_required, device, keep_loaded=[]): + unloaded_model = False + for i in range(len(current_loaded_models) -1, -1, -1): + current_free_mem = get_free_memory(device) + if current_free_mem > memory_required: + break + shift_model = current_loaded_models[i] + if shift_model.device == device: + if shift_model not in keep_loaded: + current_loaded_models.pop(i).model_unload() + unloaded_model = True + + if unloaded_model: + soft_empty_cache() + + +def load_models_gpu(models, memory_required=0): + global vram_state + + inference_memory = minimum_inference_memory() + extra_mem = max(inference_memory, memory_required) + + models_to_load = [] + models_already_loaded = [] + for x in models: + loaded_model = LoadedModel(x) + + if loaded_model in current_loaded_models: + index = current_loaded_models.index(loaded_model) + current_loaded_models.insert(0, current_loaded_models.pop(index)) + models_already_loaded.append(loaded_model) + else: + models_to_load.append(loaded_model) + + if len(models_to_load) == 0: + devs = set(map(lambda a: a.device, models_already_loaded)) + for d in devs: + if d != torch.device("cpu"): + free_memory(extra_mem, d, models_already_loaded) + return + + print("loading new") + + total_memory_required = {} + for loaded_model in models_to_load: + unload_model_clones(loaded_model.model) + total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + + for device in total_memory_required: + if device != torch.device("cpu"): + free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) + + for loaded_model in models_to_load: + model = loaded_model.model + torch_dev = model.load_device + if is_device_cpu(torch_dev): + vram_set_state = VRAMState.DISABLED + else: + vram_set_state = vram_state + lowvram_model_memory = 0 + if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): + model_size = loaded_model.model_memory_required(torch_dev) + current_free_mem = get_free_memory(torch_dev) + lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) + if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary + vram_set_state = VRAMState.LOW_VRAM + else: + lowvram_model_memory = 0 + + if vram_set_state == VRAMState.NO_VRAM: + lowvram_model_memory = 256 * 1024 * 1024 + + cur_loaded_model = loaded_model.model_load(lowvram_model_memory) + current_loaded_models.insert(0, loaded_model) + return + def load_model_gpu(model): - global current_loaded_model - global vram_state - global model_accelerated + return load_models_gpu([model]) - if model is current_loaded_model: - return - unload_model() +def cleanup_models(): + to_delete = [] + for i in range(len(current_loaded_models)): + print(sys.getrefcount(current_loaded_models[i].model)) + if sys.getrefcount(current_loaded_models[i].model) <= 2: + to_delete = [i] + to_delete - torch_dev = model.load_device - model.model_patches_to(torch_dev) - model.model_patches_to(model.model_dtype()) - current_loaded_model = model - - if is_device_cpu(torch_dev): - vram_set_state = VRAMState.DISABLED - else: - vram_set_state = vram_state - - if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): - model_size = model.model_size() - current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) - if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary - vram_set_state = VRAMState.LOW_VRAM - - real_model = model.model - patch_model_to = None - if vram_set_state == VRAMState.DISABLED: - pass - elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: - model_accelerated = False - patch_model_to = torch_dev - - try: - real_model = model.patch_model(device_to=patch_model_to) - except Exception as e: - model.unpatch_model() - unload_model() - raise e - - if patch_model_to is not None: - real_model.to(torch_dev) - - if vram_set_state == VRAMState.NO_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) - model_accelerated = True - elif vram_set_state == VRAMState.LOW_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) - accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) - model_accelerated = True - - return current_loaded_model - -def load_controlnet_gpu(control_models): - global current_gpu_controlnets - global vram_state - if vram_state == VRAMState.DISABLED: - return - - if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: - for m in control_models: - if hasattr(m, 'set_lowvram'): - m.set_lowvram(True) - #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after - return - - models = [] - for m in control_models: - models += m.get_models() - - for m in current_gpu_controlnets: - if m not in models: - m.cpu() - - device = get_torch_device() - current_gpu_controlnets = [] - for m in models: - current_gpu_controlnets.append(m.to(device)) - - -def load_if_low_vram(model): - global vram_state - if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: - return model.to(get_torch_device()) - return model - -def unload_if_low_vram(model): - global vram_state - if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: - return model.cpu() - return model + for i in to_delete: + x = current_loaded_models.pop(i) + x.model_unload() + del x def unet_offload_device(): if vram_state == VRAMState.HIGH_VRAM: @@ -354,6 +384,21 @@ def unet_offload_device(): else: return torch.device("cpu") +def unet_inital_load_device(parameters, dtype): + torch_dev = get_torch_device() + if vram_state == VRAMState.HIGH_VRAM: + return torch_dev + + cpu_dev = torch.device("cpu") + model_size = dtype.itemsize * parameters + + mem_dev = get_free_memory(torch_dev) + mem_cpu = get_free_memory(cpu_dev) + if mem_dev > mem_cpu and model_size < mem_dev: + return torch_dev + else: + return cpu_dev + def text_encoder_offload_device(): if args.gpu_only: return get_torch_device() @@ -456,6 +501,13 @@ def get_free_memory(dev=None, torch_free_too=False): else: return mem_free_total +def batch_area_memory(area): + if xformers_enabled() or pytorch_attention_flash_attention(): + #TODO: these formulas are copied from maximum_batch_area below + return (area / 20) * (1024 * 1024) + else: + return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) + def maximum_batch_area(): global vram_state if vram_state == VRAMState.NO_VRAM: diff --git a/comfy/sample.py b/comfy/sample.py index 48530f13..1dfca420 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -51,19 +51,24 @@ def get_models_from_cond(cond, model_type): models += [c[1][model_type]] return models -def load_additional_models(positive, negative, dtype): +def get_additional_models(positive, negative): """loads additional models in positive and negative conditioning""" control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control") + + control_models = [] + for m in control_nets: + control_models += m.get_models() + gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") - gligen = [x[1].to(dtype) for x in gligen] - models = control_nets + gligen - comfy.model_management.load_controlnet_gpu(models) + gligen = [x[1] for x in gligen] + models = control_models + gligen return models def cleanup_additional_models(models): """cleanup additional models that were loaded""" for m in models: - m.cleanup() + if hasattr(m, 'cleanup'): + m.cleanup() def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): device = comfy.model_management.get_torch_device() @@ -72,7 +77,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative noise_mask = prepare_mask(noise_mask, noise.shape, device) real_model = None - comfy.model_management.load_model_gpu(model) + models = get_additional_models(positive, negative) + comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[2] * noise.shape[3])) real_model = model.model noise = noise.to(device) @@ -81,7 +87,6 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative positive_copy = broadcast_cond(positive, noise.shape[0], device) negative_copy = broadcast_cond(negative, noise.shape[0], device) - models = load_additional_models(positive, negative, model.model_dtype()) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) diff --git a/comfy/samplers.py b/comfy/samplers.py index 28cd4666..ee37913e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -88,9 +88,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con gligen_type = gligen[0] gligen_model = gligen[1] if gligen_type == "position": - gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device) + gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) else: - gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device) + gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) patches['middle_patch'] = [gligen_patch] diff --git a/comfy/sd.py b/comfy/sd.py index 06b64096..8d8c8ee3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -244,7 +244,7 @@ def set_attr(obj, attr, value): del prev class ModelPatcher: - def __init__(self, model, load_device, offload_device, size=0): + def __init__(self, model, load_device, offload_device, size=0, current_device=None): self.size = size self.model = model self.patches = {} @@ -253,6 +253,10 @@ class ModelPatcher: self.model_size() self.load_device = load_device self.offload_device = offload_device + if current_device is None: + self.current_device = self.offload_device + else: + self.current_device = current_device def model_size(self): if self.size > 0: @@ -267,7 +271,7 @@ class ModelPatcher: return size def clone(self): - n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size) + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -276,6 +280,11 @@ class ModelPatcher: n.model_keys = self.model_keys return n + def is_clone(self, other): + if hasattr(other, 'model') and self.model is other.model: + return True + return False + def set_model_sampler_cfg_function(self, sampler_cfg_function): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way @@ -390,6 +399,11 @@ class ModelPatcher: out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) set_attr(self.model, key, out_weight) del temp_weight + + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to + return self.model def calculate_weight(self, patches, weight, key): @@ -482,7 +496,7 @@ class ModelPatcher: return weight - def unpatch_model(self): + def unpatch_model(self, device_to=None): keys = list(self.backup.keys()) for k in keys: @@ -490,6 +504,11 @@ class ModelPatcher: self.backup = {} + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to + + def load_lora_for_models(model, clip, lora, strength_model, strength_clip): key_map = model_lora_keys_unet(model.model) key_map = model_lora_keys_clip(clip.cond_stage_model, key_map) @@ -630,11 +649,12 @@ class VAE: return samples def decode(self, samples_in): - model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) try: + memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.4 + model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) - batch_number = int((free_memory * 0.7) / (2562 * samples_in.shape[2] * samples_in.shape[3] * 64)) + batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") @@ -650,19 +670,19 @@ class VAE: return pixel_samples def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): - model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) output = self.decode_tiled_(samples, tile_x, tile_y, overlap) self.first_stage_model = self.first_stage_model.to(self.offload_device) return output.movedim(1,-1) def encode(self, pixel_samples): - model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) try: + memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.4 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. + model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) - batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. + batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") for x in range(0, pixel_samples.shape[0], batch_number): @@ -677,7 +697,6 @@ class VAE: return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): - model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) @@ -757,6 +776,7 @@ class ControlNet(ControlBase): def __init__(self, control_model, global_average_pooling=False, device=None): super().__init__(device) self.control_model = control_model + self.control_model_wrapped = ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling def get_control(self, x_noisy, t, cond, batched_number): @@ -786,11 +806,9 @@ class ControlNet(ControlBase): precision_scope = contextlib.nullcontext with precision_scope(model_management.get_autocast_device(self.device)): - self.control_model = model_management.load_if_low_vram(self.control_model) context = torch.cat(cond['c_crossattn'], 1) y = cond.get('c_adm', None) control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y) - self.control_model = model_management.unload_if_low_vram(self.control_model) out = {'middle':[], 'output': []} autocast_enabled = torch.is_autocast_enabled() @@ -825,7 +843,7 @@ class ControlNet(ControlBase): def get_models(self): out = super().get_models() - out.append(self.control_model) + out.append(self.control_model_wrapped) return out @@ -1004,7 +1022,6 @@ class T2IAdapter(ControlBase): self.copy_to(c) return c - def load_t2i_adapter(t2i_data): keys = t2i_data.keys() if 'adapter' in keys: @@ -1090,7 +1107,7 @@ def load_gligen(ckpt_path): model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() - return model + return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): #TODO: this function is a mess and should be removed eventually @@ -1202,8 +1219,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_clipvision: clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) + dtype = torch.float32 + if fp16: + dtype = torch.float16 + + inital_load_device = model_management.unet_inital_load_device(parameters, dtype) offload_device = model_management.unet_offload_device() - model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device) + model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device) model.load_model_weights(sd, "model.diffusion_model.") if output_vae: @@ -1224,7 +1246,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if len(left_over) > 0: print("left over keys:", left_over) - return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision) + model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + if inital_load_device != torch.device("cpu"): + print("loaded straight to GPU") + model_management.load_model_gpu(model_patcher) + + return (model_patcher, clip, vae, clipvision) def load_unet(unet_path): #load unet in diffusers format diff --git a/execution.py b/execution.py index a1a7c75c..e10fdbb6 100644 --- a/execution.py +++ b/execution.py @@ -354,6 +354,7 @@ class PromptExecutor: d = self.outputs_ui.pop(x) del d + comfy.model_management.cleanup_models() if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set()