Try to keep text encoders loaded and patched to increase speed.

load_model_gpu() is now used with the text encoder models instead of just
the unet.
comfyanonymous 2023-07-01 13:22:51 -04:00
parent 97ee230682
commit b6a60fa696
4 changed files with 48 additions and 40 deletions
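
In practice, the CLIP wrapper no longer moves its text encoder between devices and patches/unpatches it on every encode; it hands a ModelPatcher that records its own load and offload devices to the same load_model_gpu() path the unet already uses. A minimal sketch of the resulting flow, assembled from the hunks below (`clip` and `tokens` stand in for objects ComfyUI already creates):

from comfy import model_management

def encode_tokens(clip, tokens):
    # The patcher carries its own load_device/offload_device, so the generic
    # loader can keep the text encoder resident and patched between prompts
    # instead of unpatching it and pushing it back to the CPU after each call.
    model_management.load_model_gpu(clip.patcher)
    return clip.cond_stage_model.encode_token_weights(tokens)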

View File

@@ -3,7 +3,7 @@ import os
 import yaml
 import folder_paths
-from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
+from comfy.sd import load_checkpoint
 import os.path as osp
 import re
 import torch

View File

@@ -216,11 +216,6 @@ current_gpu_controlnets = []
 model_accelerated = False
 
-def unet_offload_device():
-    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
-        return get_torch_device()
-    else:
-        return torch.device("cpu")
 
 def unload_model():
     global current_loaded_model
@@ -234,8 +229,8 @@ def unload_model():
             model_accelerated = False
 
-        current_loaded_model.model.to(unet_offload_device())
-        current_loaded_model.model_patches_to(unet_offload_device())
+        current_loaded_model.model.to(current_loaded_model.offload_device)
+        current_loaded_model.model_patches_to(current_loaded_model.offload_device)
         current_loaded_model.unpatch_model()
         current_loaded_model = None
@@ -260,10 +255,14 @@ def load_model_gpu(model):
         model.unpatch_model()
         raise e
 
-    torch_dev = get_torch_device()
+    torch_dev = model.load_device
     model.model_patches_to(torch_dev)
 
-    vram_set_state = vram_state
+    if is_device_cpu(torch_dev):
+        vram_set_state = VRAMState.DISABLED
+    else:
+        vram_set_state = vram_state
+
     if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
         model_size = model.model_size()
         current_free_mem = get_free_memory(torch_dev)
@@ -277,14 +276,14 @@ def load_model_gpu(model):
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
-        real_model.to(get_torch_device())
+        real_model.to(torch_dev)
     else:
         if vram_set_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_set_state == VRAMState.LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
         model_accelerated = True
 
     return current_loaded_model
@@ -327,6 +326,12 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
+def unet_offload_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
@@ -428,13 +433,18 @@ def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS
 
+def is_device_cpu(device):
+    if hasattr(device, 'type'):
+        if (device.type == 'cpu' or device.type == 'mps'):
+            return True
+    return False
+
 def should_use_fp16(device=None):
     global xpu_available
     global directml_enabled
 
     if device is not None: #TODO
-        if hasattr(device, 'type'):
-            if (device.type == 'cpu' or device.type == 'mps'):
-                return False
+        if is_device_cpu(device):
+            return False
 
     if FORCE_FP32:
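
The new is_device_cpu() helper is what lets load_model_gpu() also handle models whose load device is the CPU (or MPS): in that case the VRAM state is forced to DISABLED and the model is simply left where it is. It also replaces the inline cpu/mps check in should_use_fp16(). A quick illustrative check, not part of the commit:

import torch

print(is_device_cpu(torch.device("cpu")))      # True  -> load_model_gpu forces VRAMState.DISABLED
print(is_device_cpu(torch.device("mps")))      # True  -> MPS treated like CPU; should_use_fp16 returns False
print(is_device_cpu(torch.device("cuda", 0)))  # False -> normal VRAM-state handling applies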

View File

@@ -308,13 +308,15 @@ def model_lora_keys(model, key_map={}):
 
 
 class ModelPatcher:
-    def __init__(self, model, size=0):
+    def __init__(self, model, load_device, offload_device, size=0):
         self.size = size
         self.model = model
         self.patches = []
         self.backup = {}
         self.model_options = {"transformer_options":{}}
         self.model_size()
+        self.load_device = load_device
+        self.offload_device = offload_device
 
     def model_size(self):
         if self.size > 0:
@@ -329,7 +331,7 @@ class ModelPatcher:
         return size
 
     def clone(self):
-        n = ModelPatcher(self.model, self.size)
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
         n.patches = self.patches[:]
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
@@ -341,6 +343,9 @@ class ModelPatcher:
         else:
             self.model_options["sampler_cfg_function"] = sampler_cfg_function
 
+    def set_model_unet_function_wrapper(self, unet_wrapper_function):
+        self.model_options["model_function_wrapper"] = unet_wrapper_function
+
     def set_model_patch(self, patch, name):
         to = self.model_options["transformer_options"]
         if "patches" not in to:
@@ -525,14 +530,16 @@ class CLIP:
         clip = target.clip
         tokenizer = target.tokenizer
 
-        self.device = model_management.text_encoder_device()
+        load_device = model_management.text_encoder_device()
+        offload_device = model_management.text_encoder_offload_device()
         self.cond_stage_model = clip(**(params))
-        if model_management.should_use_fp16(self.device):
+        if model_management.should_use_fp16(load_device):
             self.cond_stage_model.half()
-        self.cond_stage_model = self.cond_stage_model.to(model_management.text_encoder_offload_device())
+
+        self.cond_stage_model = self.cond_stage_model.to()
 
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
-        self.patcher = ModelPatcher(self.cond_stage_model)
+        self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
 
     def clone(self):
@@ -541,7 +548,6 @@ class CLIP:
         n.cond_stage_model = self.cond_stage_model
         n.tokenizer = self.tokenizer
         n.layer_idx = self.layer_idx
-        n.device = self.device
         return n
 
     def load_from_state_dict(self, sd):
@@ -559,21 +565,12 @@ class CLIP:
     def encode_from_tokens(self, tokens, return_pooled=False):
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
-        try:
-            self.cond_stage_model.to(self.device)
-            self.patch_model()
-            cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
-            self.unpatch_model()
-            self.cond_stage_model.to(model_management.text_encoder_offload_device())
-        except Exception as e:
-            self.unpatch_model()
-            self.cond_stage_model.to(model_management.text_encoder_offload_device())
-            raise e
 
-        cond_out = cond
+        model_management.load_model_gpu(self.patcher)
+        cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
         if return_pooled:
-            return cond_out, pooled
-        return cond_out
+            return cond, pooled
+        return cond
 
     def encode(self, text):
         tokens = self.tokenize(text)
@@ -1097,6 +1094,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     if fp16:
         model = model.half()
 
+    offload_device = model_management.unet_offload_device()
+    model = model.to(offload_device)
     model.load_model_weights(state_dict, "model.diffusion_model.")
 
     if output_vae:
@@ -1119,7 +1118,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
             w.cond_stage_model = clip.cond_stage_model
             load_clip_weights(w, state_dict)
 
-    return (ModelPatcher(model), clip, vae)
+    return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
 
 
 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
@@ -1144,8 +1143,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if output_clipvision:
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
 
+    offload_device = model_management.unet_offload_device()
     model = model_config.get_model(sd)
-    model = model.to(model_management.unet_offload_device())
+    model = model.to(offload_device)
     model.load_model_weights(sd, "model.diffusion_model.")
 
     if output_vae:
@@ -1166,7 +1166,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if len(left_over) > 0:
         print("left over keys:", left_over)
 
-    return (ModelPatcher(model), clip, vae, clipvision)
+    return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)
 
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     try:
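
ModelPatcher construction sites now pass both devices explicitly, and clone() propagates them, so LoRA/patch copies keep the same placement policy as the original. An illustrative sketch under the new signature (assumes ComfyUI's modules are importable; `unet_model` is a placeholder for an already-built diffusion model):

from comfy import model_management
from comfy.sd import ModelPatcher

patcher = ModelPatcher(unet_model,
                       load_device=model_management.get_torch_device(),
                       offload_device=model_management.unet_offload_device())

# clone() forwards self.load_device and self.offload_device to the copy.
patched = patcher.clone()
assert patched.load_device is patcher.load_device
assert patched.offload_device is patcher.offload_device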

View File

@@ -112,11 +112,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = torch.LongTensor(tokens).to(device)
 
         if backup_embeds.weight.dtype != torch.float32:
-            print("autocast clip")
             precision_scope = torch.autocast
         else:
             precision_scope = contextlib.nullcontext
-            print("no autocast clip")
 
         with precision_scope(model_management.get_autocast_device(device)):
             outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
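
The only change here is dropping the two debug prints; the dtype-based choice between torch.autocast and contextlib.nullcontext stays as it was. The equivalent standalone pattern, for reference (illustrative, not part of the commit):

import contextlib
import torch

def autocast_scope(weights_dtype):
    # fp16/bf16 weights run the transformer under autocast; fp32 uses a no-op context.
    return torch.autocast if weights_dtype != torch.float32 else contextlib.nullcontext

with autocast_scope(torch.float16)("cuda"):
    pass  # run the transformer here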