From 36a7953142ccf3f9debf9305e3cbeb3bfe956ee3 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 22 Dec 2023 14:24:04 -0500
Subject: [PATCH] Greatly improve lowvram sampling speed by getting rid of
 accelerate.

Let me know if this breaks anything.
---
 comfy/controlnet.py       |  2 +-
 comfy/model_base.py       |  6 +--
 comfy/model_management.py | 52 ++++++++++++----------
 comfy/ops.py              | 92 ++++++++++++++++++++++++++++++---------
 requirements.txt          |  1 -
 5 files changed, 103 insertions(+), 50 deletions(-)

diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 110b5c7c..8404054f 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -283,7 +283,7 @@ class ControlLora(ControlNet):
         cm = self.control_model.state_dict()
 
         for k in sd:
-            weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
+            weight = sd[k]
             try:
                 comfy.utils.set_attr(self.control_model, k, weight)
             except:
diff --git a/comfy/model_base.py b/comfy/model_base.py
index f2a6f984..b3a1fcd5 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -162,11 +162,7 @@ class BaseModel(torch.nn.Module):
 
     def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
         clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
-        unet_sd = self.diffusion_model.state_dict()
-        unet_state_dict = {}
-        for k in unet_sd:
-            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
-
+        unet_state_dict = self.diffusion_model.state_dict()
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
         vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
         if self.get_dtype() == torch.float16:
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 23f39c98..61c967f6 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -218,15 +218,8 @@ if args.force_fp16:
     FORCE_FP16 = True
 
 if lowvram_available:
-    try:
-        import accelerate
-        if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
-            vram_state = set_vram_to
-    except Exception as e:
-        import traceback
-        print(traceback.format_exc())
-        print("ERROR: LOW VRAM MODE NEEDS accelerate.")
-        lowvram_available = False
+    if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+        vram_state = set_vram_to
 
 
 if cpu_state != CPUState.GPU:
@@ -298,8 +291,20 @@ class LoadedModel:
 
         if lowvram_model_memory > 0:
             print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
-            device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
-            accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
+            mem_counter = 0
+            for m in self.real_model.modules():
+                if hasattr(m, "comfy_cast_weights"):
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    module_mem = 0
+                    sd = m.state_dict()
+                    for k in sd:
+                        t = sd[k]
+                        module_mem += t.nelement() * t.element_size()
+                    if mem_counter + module_mem < lowvram_model_memory:
+                        m.to(self.device)
+                        mem_counter += module_mem
+
             self.model_accelerated = True
 
         if is_intel_xpu() and not args.disable_ipex_optimize:
@@ -309,7 +314,11 @@ class LoadedModel:
 
     def model_unload(self):
         if self.model_accelerated:
-            accelerate.hooks.remove_hook_from_submodules(self.real_model)
+            for m in self.real_model.modules():
+                if hasattr(m, "prev_comfy_cast_weights"):
+                    m.comfy_cast_weights = m.prev_comfy_cast_weights
+                    del m.prev_comfy_cast_weights
+
             self.model_accelerated = False
 
         self.model.unpatch_model(self.model.offload_device)
@@ -402,14 +411,14 @@ def load_models_gpu(models, memory_required=0):
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
+            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
             if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
                 vram_set_state = VRAMState.LOW_VRAM
             else:
                 lowvram_model_memory = 0
 
         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 256 * 1024 * 1024
+            lowvram_model_memory = 64 * 1024 * 1024
 
         cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
         current_loaded_models.insert(0, loaded_model)
@@ -566,6 +575,11 @@ def supports_dtype(device, dtype): #TODO
         return True
     return False
 
+def device_supports_non_blocking(device):
+    if is_device_mps(device):
+        return False #pytorch bug? mps doesn't support non blocking
+    return True
+
 def cast_to_device(tensor, device, dtype, copy=False):
     device_supports_cast = False
     if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@@ -576,9 +590,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
     elif is_intel_xpu():
         device_supports_cast = True
 
-    non_blocking = True
-    if is_device_mps(device):
-        non_blocking = False #pytorch bug? mps doesn't support non blocking
+    non_blocking = device_supports_non_blocking(device)
 
     if device_supports_cast:
         if copy:
@@ -742,11 +754,7 @@ def soft_empty_cache(force=False):
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
 
-def resolve_lowvram_weight(weight, model, key):
-    if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
-        key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
-        op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
-        weight = op._hf_hook.weights_map[key_split[-1]]
+def resolve_lowvram_weight(weight, model, key): #TODO: remove
     return weight
 
 #TODO: might be cleaner to put this somewhere else
diff --git a/comfy/ops.py b/comfy/ops.py
index 08c63384..f6f85de6 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -1,27 +1,93 @@
 import torch
 from contextlib import contextmanager
+import comfy.model_management
+
+def cast_bias_weight(s, input):
+    bias = None
+    non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    return weight, bias
+
 
 class disable_weight_init:
     class Linear(torch.nn.Linear):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class Conv2d(torch.nn.Conv2d):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class Conv3d(torch.nn.Conv3d):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class GroupNorm(torch.nn.GroupNorm):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
+
     class LayerNorm(torch.nn.LayerNorm):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     @classmethod
     def conv_nd(s, dims, *args, **kwargs):
         if dims == 2:
@@ -31,35 +97,19 @@ class disable_weight_init:
         else:
             raise ValueError(f"unsupported dimensions: {dims}")
 
-def cast_bias_weight(s, input):
-    bias = None
-    if s.bias is not None:
-        bias = s.bias.to(device=input.device, dtype=input.dtype)
-    weight = s.weight.to(device=input.device, dtype=input.dtype)
-    return weight, bias
 
 class manual_cast(disable_weight_init):
     class Linear(disable_weight_init.Linear):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+        comfy_cast_weights = True
 
     class Conv2d(disable_weight_init.Conv2d):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+        comfy_cast_weights = True
 
     class Conv3d(disable_weight_init.Conv3d):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+        comfy_cast_weights = True
 
     class GroupNorm(disable_weight_init.GroupNorm):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+        comfy_cast_weights = True
 
     class LayerNorm(disable_weight_init.LayerNorm):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+        comfy_cast_weights = True
diff --git a/requirements.txt b/requirements.txt
index 14524485..da1fbb27 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,6 @@ einops
 transformers>=4.25.1
 safetensors>=0.3.0
 aiohttp
-accelerate
 pyyaml
 Pillow
 scipy
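Note for reviewers: the core mechanism the patch introduces is (1) every op in comfy.ops gains a comfy_cast_weights flag plus a forward path that copies its weight/bias to the input's device and dtype on each call, and (2) model_load() walks the module tree and moves whole modules onto the GPU until the lowvram_model_memory byte budget is spent; whatever does not fit stays on the offload device and relies on the cast-at-forward path. Below is a minimal, self-contained sketch of that idea in plain PyTorch. It is not ComfyUI code: CastLinear, partial_load and budget_bytes are invented names used only for illustration.

# Illustrative sketch only, not part of the patch.
import torch

class CastLinear(torch.nn.Linear):
    # Rough analogue of comfy_cast_weights: when enabled, parameters are
    # copied to the input's device/dtype for the duration of each forward().
    cast_weights = False

    def forward(self, input):
        if not self.cast_weights:
            return super().forward(input)
        weight = self.weight.to(device=input.device, dtype=input.dtype, non_blocking=True)
        bias = None
        if self.bias is not None:
            bias = self.bias.to(device=input.device, dtype=input.dtype, non_blocking=True)
        return torch.nn.functional.linear(input, weight, bias)

def partial_load(model, device, budget_bytes):
    # Greedily move whole modules to `device` until the byte budget is spent,
    # using the same nelement() * element_size() accounting as model_load().
    used = 0
    for m in model.modules():
        if isinstance(m, CastLinear):
            m.cast_weights = True
            size = sum(t.nelement() * t.element_size() for t in m.state_dict().values())
            if used + size < budget_bytes:
                m.to(device)
                used += size
    return used

if __name__ == "__main__":
    net = torch.nn.Sequential(CastLinear(512, 512), CastLinear(512, 512))
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if dev == "cuda" else torch.float32
    partial_load(net, dev, budget_bytes=1200 * 1024)  # budget fits only the first layer
    x = torch.randn(4, 512, device=dev, dtype=dtype)
    print(net(x).shape)  # the layer left behind casts its weights on the fly

In the actual patch the same pattern covers Linear, Conv2d, Conv3d, GroupNorm and LayerNorm, non-blocking transfers are skipped on MPS via device_supports_non_blocking(), and model_unload() restores the saved prev_comfy_cast_weights flags when a model is evicted.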