diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index d6074c7d..ba04b981 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -3,7 +3,7 @@ import os import yaml import folder_paths -from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint +from comfy.sd import load_checkpoint import os.path as osp import re import torch diff --git a/comfy/model_management.py b/comfy/model_management.py index f10d1ca8..0babdc13 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -216,11 +216,6 @@ current_gpu_controlnets = [] model_accelerated = False -def unet_offload_device(): - if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: - return get_torch_device() - else: - return torch.device("cpu") def unload_model(): global current_loaded_model @@ -234,8 +229,8 @@ def unload_model(): model_accelerated = False - current_loaded_model.model.to(unet_offload_device()) - current_loaded_model.model_patches_to(unet_offload_device()) + current_loaded_model.model.to(current_loaded_model.offload_device) + current_loaded_model.model_patches_to(current_loaded_model.offload_device) current_loaded_model.unpatch_model() current_loaded_model = None @@ -260,10 +255,14 @@ def load_model_gpu(model): model.unpatch_model() raise e - torch_dev = get_torch_device() + torch_dev = model.load_device model.model_patches_to(torch_dev) - vram_set_state = vram_state + if is_device_cpu(torch_dev): + vram_set_state = VRAMState.DISABLED + else: + vram_set_state = vram_state + if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = model.model_size() current_free_mem = get_free_memory(torch_dev) @@ -277,14 +276,14 @@ def load_model_gpu(model): pass elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: model_accelerated = False - real_model.to(get_torch_device()) + real_model.to(torch_dev) else: if vram_set_state == VRAMState.NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) elif vram_set_state == VRAMState.LOW_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) - accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device()) + accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) model_accelerated = True return current_loaded_model @@ -327,6 +326,12 @@ def unload_if_low_vram(model): return model.cpu() return model +def unet_offload_device(): + if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: + return get_torch_device() + else: + return torch.device("cpu") + def text_encoder_offload_device(): if args.gpu_only: return get_torch_device() @@ -428,14 +433,19 @@ def mps_mode(): global cpu_state return cpu_state == CPUState.MPS +def is_device_cpu(device): + if hasattr(device, 'type'): + if (device.type == 'cpu' or device.type == 'mps'): + return True + return False + def should_use_fp16(device=None): global xpu_available global directml_enabled if device is not None: #TODO - if hasattr(device, 'type'): - if (device.type == 'cpu' or device.type == 'mps'): - return False + if is_device_cpu(device): + return False if FORCE_FP32: return False diff --git a/comfy/sd.py b/comfy/sd.py index 320b0fb7..5eef51b3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -308,13 +308,15 @@ def model_lora_keys(model, key_map={}): class ModelPatcher: - def __init__(self, model, size=0): + def __init__(self, model, load_device, offload_device, size=0): self.size = size self.model = model self.patches = [] self.backup = {} self.model_options = {"transformer_options":{}} self.model_size() + self.load_device = load_device + self.offload_device = offload_device def model_size(self): if self.size > 0: @@ -329,7 +331,7 @@ class ModelPatcher: return size def clone(self): - n = ModelPatcher(self.model, self.size) + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size) n.patches = self.patches[:] n.model_options = copy.deepcopy(self.model_options) n.model_keys = self.model_keys @@ -341,6 +343,9 @@ class ModelPatcher: else: self.model_options["sampler_cfg_function"] = sampler_cfg_function + def set_model_unet_function_wrapper(self, unet_wrapper_function): + self.model_options["model_function_wrapper"] = unet_wrapper_function + def set_model_patch(self, patch, name): to = self.model_options["transformer_options"] if "patches" not in to: @@ -525,14 +530,16 @@ class CLIP: clip = target.clip tokenizer = target.tokenizer - self.device = model_management.text_encoder_device() + load_device = model_management.text_encoder_device() + offload_device = model_management.text_encoder_offload_device() self.cond_stage_model = clip(**(params)) - if model_management.should_use_fp16(self.device): + if model_management.should_use_fp16(load_device): self.cond_stage_model.half() - self.cond_stage_model = self.cond_stage_model.to(model_management.text_encoder_offload_device()) + + self.cond_stage_model = self.cond_stage_model.to() self.tokenizer = tokenizer(embedding_directory=embedding_directory) - self.patcher = ModelPatcher(self.cond_stage_model) + self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.layer_idx = None def clone(self): @@ -541,7 +548,6 @@ class CLIP: n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx - n.device = self.device return n def load_from_state_dict(self, sd): @@ -559,21 +565,12 @@ class CLIP: def encode_from_tokens(self, tokens, return_pooled=False): if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) - try: - self.cond_stage_model.to(self.device) - self.patch_model() - cond, pooled = self.cond_stage_model.encode_token_weights(tokens) - self.unpatch_model() - self.cond_stage_model.to(model_management.text_encoder_offload_device()) - except Exception as e: - self.unpatch_model() - self.cond_stage_model.to(model_management.text_encoder_offload_device()) - raise e - cond_out = cond + model_management.load_model_gpu(self.patcher) + cond, pooled = self.cond_stage_model.encode_token_weights(tokens) if return_pooled: - return cond_out, pooled - return cond_out + return cond, pooled + return cond def encode(self, text): tokens = self.tokenize(text) @@ -1097,6 +1094,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl if fp16: model = model.half() + offload_device = model_management.unet_offload_device() + model = model.to(offload_device) model.load_model_weights(state_dict, "model.diffusion_model.") if output_vae: @@ -1119,7 +1118,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl w.cond_stage_model = clip.cond_stage_model load_clip_weights(w, state_dict) - return (ModelPatcher(model), clip, vae) + return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None): @@ -1144,8 +1143,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_clipvision: clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) + offload_device = model_management.unet_offload_device() model = model_config.get_model(sd) - model = model.to(model_management.unet_offload_device()) + model = model.to(offload_device) model.load_model_weights(sd, "model.diffusion_model.") if output_vae: @@ -1166,7 +1166,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if len(left_over) > 0: print("left over keys:", left_over) - return (ModelPatcher(model), clip, vae, clipvision) + return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision) def save_checkpoint(output_path, model, clip, vae, metadata=None): try: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 5c627cb8..ffcb849d 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -112,11 +112,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = torch.LongTensor(tokens).to(device) if backup_embeds.weight.dtype != torch.float32: - print("autocast clip") precision_scope = torch.autocast else: precision_scope = contextlib.nullcontext - print("no autocast clip") with precision_scope(model_management.get_autocast_device(device)): outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")