Smarter memory management.

Try to keep models on the vram when possible.

Better lowvram mode for controlnets.
This commit is contained in:
comfyanonymous 2023-08-17 01:06:34 -04:00
parent 2c97c30256
commit 89a0767abf
6 changed files with 230 additions and 168 deletions

View File

@ -244,30 +244,15 @@ class Gligen(nn.Module):
self.position_net = position_net self.position_net = position_net
self.key_dim = key_dim self.key_dim = key_dim
self.max_objs = 30 self.max_objs = 30
self.lowvram = False self.current_device = torch.device("cpu")
def _set_position(self, boxes, masks, positive_embeddings): def _set_position(self, boxes, masks, positive_embeddings):
if self.lowvram == True:
self.position_net.to(boxes.device)
objs = self.position_net(boxes, masks, positive_embeddings) objs = self.position_net(boxes, masks, positive_embeddings)
def func(x, extra_options):
if self.lowvram == True: key = extra_options["transformer_index"]
self.position_net.cpu() module = self.module_list[key]
def func_lowvram(x, extra_options): return module(x, objs)
key = extra_options["transformer_index"] return func
module = self.module_list[key]
module.to(x.device)
r = module(x, objs)
module.cpu()
return r
return func_lowvram
else:
def func(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
return module(x, objs)
return func
def set_position(self, latent_image_shape, position_params, device): def set_position(self, latent_image_shape, position_params, device):
batch, c, h, w = latent_image_shape batch, c, h, w = latent_image_shape
@ -312,14 +297,6 @@ class Gligen(nn.Module):
masks.to(device), masks.to(device),
conds.to(device)) conds.to(device))
def set_lowvram(self, value=True):
self.lowvram = value
def cleanup(self):
self.lowvram = False
def get_models(self):
return [self]
def load_gligen(sd): def load_gligen(sd):
sd_k = sd.keys() sd_k = sd.keys()

View File

@ -2,6 +2,7 @@ import psutil
from enum import Enum from enum import Enum
from comfy.cli_args import args from comfy.cli_args import args
import torch import torch
import sys
class VRAMState(Enum): class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram DISABLED = 0 #No vram present: no need to move models to vram
@ -221,132 +222,161 @@ except:
print("Could not pick default device.") print("Could not pick default device.")
current_loaded_model = None current_loaded_models = []
current_gpu_controlnets = []
model_accelerated = False class LoadedModel:
def __init__(self, model):
self.model = model
self.model_accelerated = False
self.device = model.load_device
def model_memory(self):
return self.model.model_size()
def unload_model(): def model_memory_required(self, device):
global current_loaded_model if device == self.model.current_device:
global model_accelerated return 0
global current_gpu_controlnets else:
global vram_state return self.model_memory()
if current_loaded_model is not None: def model_load(self, lowvram_model_memory=0):
if model_accelerated: patch_model_to = None
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model) if lowvram_model_memory == 0:
model_accelerated = False patch_model_to = self.device
current_loaded_model.unpatch_model() self.model.model_patches_to(self.device)
current_loaded_model.model.to(current_loaded_model.offload_device) self.model.model_patches_to(self.model.model_dtype())
current_loaded_model.model_patches_to(current_loaded_model.offload_device)
current_loaded_model = None
if vram_state != VRAMState.HIGH_VRAM:
soft_empty_cache()
if vram_state != VRAMState.HIGH_VRAM: try:
if len(current_gpu_controlnets) > 0: self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
for n in current_gpu_controlnets: except Exception as e:
n.cpu() self.model.unpatch_model(self.model.offload_device)
current_gpu_controlnets = [] self.model_unload()
raise e
if lowvram_model_memory > 0:
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
self.model_accelerated = True
return self.real_model
def model_unload(self):
if self.model_accelerated:
accelerate.hooks.remove_hook_from_submodules(self.real_model)
self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device)
self.model.model_patches_to(self.model.offload_device)
def __eq__(self, other):
return self.model is other.model
def minimum_inference_memory(): def minimum_inference_memory():
return (768 * 1024 * 1024) return (1024 * 1024 * 1024)
def unload_model_clones(model):
to_unload = []
for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
for i in to_unload:
print("unload clone", i)
current_loaded_models.pop(i).model_unload()
def free_memory(memory_required, device, keep_loaded=[]):
unloaded_model = False
for i in range(len(current_loaded_models) -1, -1, -1):
current_free_mem = get_free_memory(device)
if current_free_mem > memory_required:
break
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
current_loaded_models.pop(i).model_unload()
unloaded_model = True
if unloaded_model:
soft_empty_cache()
def load_models_gpu(models, memory_required=0):
global vram_state
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required)
models_to_load = []
models_already_loaded = []
for x in models:
loaded_model = LoadedModel(x)
if loaded_model in current_loaded_models:
index = current_loaded_models.index(loaded_model)
current_loaded_models.insert(0, current_loaded_models.pop(index))
models_already_loaded.append(loaded_model)
else:
models_to_load.append(loaded_model)
if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_memory(extra_mem, d, models_already_loaded)
return
print("loading new")
total_memory_required = {}
for loaded_model in models_to_load:
unload_model_clones(loaded_model.model)
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load:
model = loaded_model.model
torch_dev = model.load_device
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
else:
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
else:
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 256 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model)
return
def load_model_gpu(model): def load_model_gpu(model):
global current_loaded_model return load_models_gpu([model])
global vram_state
global model_accelerated
if model is current_loaded_model: def cleanup_models():
return to_delete = []
unload_model() for i in range(len(current_loaded_models)):
print(sys.getrefcount(current_loaded_models[i].model))
if sys.getrefcount(current_loaded_models[i].model) <= 2:
to_delete = [i] + to_delete
torch_dev = model.load_device for i in to_delete:
model.model_patches_to(torch_dev) x = current_loaded_models.pop(i)
model.model_patches_to(model.model_dtype()) x.model_unload()
current_loaded_model = model del x
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
else:
vram_set_state = vram_state
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = model.model_size()
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
real_model = model.model
patch_model_to = None
if vram_set_state == VRAMState.DISABLED:
pass
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
model_accelerated = False
patch_model_to = torch_dev
try:
real_model = model.patch_model(device_to=patch_model_to)
except Exception as e:
model.unpatch_model()
unload_model()
raise e
if patch_model_to is not None:
real_model.to(torch_dev)
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
return current_loaded_model
def load_controlnet_gpu(control_models):
global current_gpu_controlnets
global vram_state
if vram_state == VRAMState.DISABLED:
return
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
for m in control_models:
if hasattr(m, 'set_lowvram'):
m.set_lowvram(True)
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return
models = []
for m in control_models:
models += m.get_models()
for m in current_gpu_controlnets:
if m not in models:
m.cpu()
device = get_torch_device()
current_gpu_controlnets = []
for m in models:
current_gpu_controlnets.append(m.to(device))
def load_if_low_vram(model):
global vram_state
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.to(get_torch_device())
return model
def unload_if_low_vram(model):
global vram_state
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.cpu()
return model
def unet_offload_device(): def unet_offload_device():
if vram_state == VRAMState.HIGH_VRAM: if vram_state == VRAMState.HIGH_VRAM:
@ -354,6 +384,21 @@ def unet_offload_device():
else: else:
return torch.device("cpu") return torch.device("cpu")
def unet_inital_load_device(parameters, dtype):
torch_dev = get_torch_device()
if vram_state == VRAMState.HIGH_VRAM:
return torch_dev
cpu_dev = torch.device("cpu")
model_size = dtype.itemsize * parameters
mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev:
return torch_dev
else:
return cpu_dev
def text_encoder_offload_device(): def text_encoder_offload_device():
if args.gpu_only: if args.gpu_only:
return get_torch_device() return get_torch_device()
@ -456,6 +501,13 @@ def get_free_memory(dev=None, torch_free_too=False):
else: else:
return mem_free_total return mem_free_total
def batch_area_memory(area):
if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: these formulas are copied from maximum_batch_area below
return (area / 20) * (1024 * 1024)
else:
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def maximum_batch_area(): def maximum_batch_area():
global vram_state global vram_state
if vram_state == VRAMState.NO_VRAM: if vram_state == VRAMState.NO_VRAM:

View File

@ -51,19 +51,24 @@ def get_models_from_cond(cond, model_type):
models += [c[1][model_type]] models += [c[1][model_type]]
return models return models
def load_additional_models(positive, negative, dtype): def get_additional_models(positive, negative):
"""loads additional models in positive and negative conditioning""" """loads additional models in positive and negative conditioning"""
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control") control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
control_models = []
for m in control_nets:
control_models += m.get_models()
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
gligen = [x[1].to(dtype) for x in gligen] gligen = [x[1] for x in gligen]
models = control_nets + gligen models = control_models + gligen
comfy.model_management.load_controlnet_gpu(models)
return models return models
def cleanup_additional_models(models): def cleanup_additional_models(models):
"""cleanup additional models that were loaded""" """cleanup additional models that were loaded"""
for m in models: for m in models:
m.cleanup() if hasattr(m, 'cleanup'):
m.cleanup()
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
@ -72,7 +77,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
noise_mask = prepare_mask(noise_mask, noise.shape, device) noise_mask = prepare_mask(noise_mask, noise.shape, device)
real_model = None real_model = None
comfy.model_management.load_model_gpu(model) models = get_additional_models(positive, negative)
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[2] * noise.shape[3]))
real_model = model.model real_model = model.model
noise = noise.to(device) noise = noise.to(device)
@ -81,7 +87,6 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
positive_copy = broadcast_cond(positive, noise.shape[0], device) positive_copy = broadcast_cond(positive, noise.shape[0], device)
negative_copy = broadcast_cond(negative, noise.shape[0], device) negative_copy = broadcast_cond(negative, noise.shape[0], device)
models = load_additional_models(positive, negative, model.model_dtype())
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)

View File

@ -88,9 +88,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
gligen_type = gligen[0] gligen_type = gligen[0]
gligen_model = gligen[1] gligen_model = gligen[1]
if gligen_type == "position": if gligen_type == "position":
gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device) gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
else: else:
gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device) gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch] patches['middle_patch'] = [gligen_patch]

View File

@ -244,7 +244,7 @@ def set_attr(obj, attr, value):
del prev del prev
class ModelPatcher: class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0): def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size self.size = size
self.model = model self.model = model
self.patches = {} self.patches = {}
@ -253,6 +253,10 @@ class ModelPatcher:
self.model_size() self.model_size()
self.load_device = load_device self.load_device = load_device
self.offload_device = offload_device self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
def model_size(self): def model_size(self):
if self.size > 0: if self.size > 0:
@ -267,7 +271,7 @@ class ModelPatcher:
return size return size
def clone(self): def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size) n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {} n.patches = {}
for k in self.patches: for k in self.patches:
n.patches[k] = self.patches[k][:] n.patches[k] = self.patches[k][:]
@ -276,6 +280,11 @@ class ModelPatcher:
n.model_keys = self.model_keys n.model_keys = self.model_keys
return n return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def set_model_sampler_cfg_function(self, sampler_cfg_function): def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3: if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
@ -390,6 +399,11 @@ class ModelPatcher:
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
set_attr(self.model, key, out_weight) set_attr(self.model, key, out_weight)
del temp_weight del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model return self.model
def calculate_weight(self, patches, weight, key): def calculate_weight(self, patches, weight, key):
@ -482,7 +496,7 @@ class ModelPatcher:
return weight return weight
def unpatch_model(self): def unpatch_model(self, device_to=None):
keys = list(self.backup.keys()) keys = list(self.backup.keys())
for k in keys: for k in keys:
@ -490,6 +504,11 @@ class ModelPatcher:
self.backup = {} self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
def load_lora_for_models(model, clip, lora, strength_model, strength_clip): def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = model_lora_keys_unet(model.model) key_map = model_lora_keys_unet(model.model)
key_map = model_lora_keys_clip(clip.cond_stage_model, key_map) key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
@ -630,11 +649,12 @@ class VAE:
return samples return samples
def decode(self, samples_in): def decode(self, samples_in):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
try: try:
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.4
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int((free_memory * 0.7) / (2562 * samples_in.shape[2] * samples_in.shape[3] * 64)) batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
@ -650,19 +670,19 @@ class VAE:
return pixel_samples return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap) output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device) self.first_stage_model = self.first_stage_model.to(self.offload_device)
return output.movedim(1,-1) return output.movedim(1,-1)
def encode(self, pixel_samples): def encode(self, pixel_samples):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
try: try:
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.4 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
@ -677,7 +697,6 @@ class VAE:
return samples return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
@ -757,6 +776,7 @@ class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None): def __init__(self, control_model, global_average_pooling=False, device=None):
super().__init__(device) super().__init__(device)
self.control_model = control_model self.control_model = control_model
self.control_model_wrapped = ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):
@ -786,11 +806,9 @@ class ControlNet(ControlBase):
precision_scope = contextlib.nullcontext precision_scope = contextlib.nullcontext
with precision_scope(model_management.get_autocast_device(self.device)): with precision_scope(model_management.get_autocast_device(self.device)):
self.control_model = model_management.load_if_low_vram(self.control_model)
context = torch.cat(cond['c_crossattn'], 1) context = torch.cat(cond['c_crossattn'], 1)
y = cond.get('c_adm', None) y = cond.get('c_adm', None)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y) control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
self.control_model = model_management.unload_if_low_vram(self.control_model)
out = {'middle':[], 'output': []} out = {'middle':[], 'output': []}
autocast_enabled = torch.is_autocast_enabled() autocast_enabled = torch.is_autocast_enabled()
@ -825,7 +843,7 @@ class ControlNet(ControlBase):
def get_models(self): def get_models(self):
out = super().get_models() out = super().get_models()
out.append(self.control_model) out.append(self.control_model_wrapped)
return out return out
@ -1004,7 +1022,6 @@ class T2IAdapter(ControlBase):
self.copy_to(c) self.copy_to(c)
return c return c
def load_t2i_adapter(t2i_data): def load_t2i_adapter(t2i_data):
keys = t2i_data.keys() keys = t2i_data.keys()
if 'adapter' in keys: if 'adapter' in keys:
@ -1090,7 +1107,7 @@ def load_gligen(ckpt_path):
model = gligen.load_gligen(data) model = gligen.load_gligen(data)
if model_management.should_use_fp16(): if model_management.should_use_fp16():
model = model.half() model = model.half()
return model return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
#TODO: this function is a mess and should be removed eventually #TODO: this function is a mess and should be removed eventually
@ -1202,8 +1219,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_clipvision: if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
dtype = torch.float32
if fp16:
dtype = torch.float16
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
offload_device = model_management.unet_offload_device() offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device) model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
if output_vae: if output_vae:
@ -1224,7 +1246,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if len(left_over) > 0: if len(left_over) > 0:
print("left over keys:", left_over) print("left over keys:", left_over)
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision) model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
return (model_patcher, clip, vae, clipvision)
def load_unet(unet_path): #load unet in diffusers format def load_unet(unet_path): #load unet in diffusers format

View File

@ -354,6 +354,7 @@ class PromptExecutor:
d = self.outputs_ui.pop(x) d = self.outputs_ui.pop(x)
del d del d
comfy.model_management.cleanup_models()
if self.server.client_id is not None: if self.server.client_id is not None:
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
executed = set() executed = set()