Greatly improve lowvram sampling speed by getting rid of accelerate.
Let me know if this breaks anything.
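In short: instead of letting accelerate dispatch the model (weights parked on the meta device and pulled through _hf_hook on every access), each patched module now carries a comfy_cast_weights flag; model_load moves as many whole modules onto the GPU as the VRAM budget allows, and the rest cast their weights to the compute device on the fly inside forward.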
This commit is contained in:
parent 261bcbb0d9
commit 36a7953142
@@ -283,7 +283,7 @@ class ControlLora(ControlNet):
         cm = self.control_model.state_dict()

         for k in sd:
-            weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
+            weight = sd[k]
             try:
                 comfy.utils.set_attr(self.control_model, k, weight)
             except:
@@ -162,11 +162,7 @@ class BaseModel(torch.nn.Module):

    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
-        unet_sd = self.diffusion_model.state_dict()
-        unet_state_dict = {}
-        for k in unet_sd:
-            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
-
+        unet_state_dict = self.diffusion_model.state_dict()
        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
        if self.get_dtype() == torch.float16:
@@ -218,15 +218,8 @@ if args.force_fp16:
     FORCE_FP16 = True

 if lowvram_available:
-    try:
-        import accelerate
-        if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
-            vram_state = set_vram_to
-    except Exception as e:
-        import traceback
-        print(traceback.format_exc())
-        print("ERROR: LOW VRAM MODE NEEDS accelerate.")
-        lowvram_available = False
+    if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+        vram_state = set_vram_to


 if cpu_state != CPUState.GPU:
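Note: the low-VRAM path no longer needs a guarded accelerate import; the requested VRAM state is applied directly, and the old failure branch (which disabled lowvram entirely when accelerate was missing) disappears with it.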
@@ -298,8 +291,20 @@ class LoadedModel:

         if lowvram_model_memory > 0:
             print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
-            device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
-            accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
+            mem_counter = 0
+            for m in self.real_model.modules():
+                if hasattr(m, "comfy_cast_weights"):
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    module_mem = 0
+                    sd = m.state_dict()
+                    for k in sd:
+                        t = sd[k]
+                        module_mem += t.nelement() * t.element_size()
+                    if mem_counter + module_mem < lowvram_model_memory:
+                        m.to(self.device)
+                        mem_counter += module_mem
+
             self.model_accelerated = True

         if is_intel_xpu() and not args.disable_ipex_optimize:
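The loop above replaces accelerate's infer_auto_device_map/dispatch_model pair with a single first-fit pass over the module tree: every op that defines comfy_cast_weights is switched to cast-on-forward, and whole modules are moved on-device until the budget runs out. A minimal standalone sketch of the same accounting (the 512 MiB budget and the toy model are assumptions for illustration):

    import torch

    def module_size_bytes(m: torch.nn.Module) -> int:
        # Same accounting as the diff: total bytes of every tensor in the state dict.
        return sum(t.nelement() * t.element_size() for t in m.state_dict().values())

    budget = 512 * 1024 * 1024  # hypothetical 512 MiB lowvram budget
    model = torch.nn.Sequential(torch.nn.Linear(4096, 4096), torch.nn.Linear(4096, 4096))
    used = 0
    for m in model.children():  # children() here avoids double-counting the container
        size = module_size_bytes(m)
        if used + size < budget:
            # m.to("cuda")  # on a CUDA machine this would place the module on the GPU
            used += size
    print(used // (1024 * 1024), "MiB of weights fit in the budget")

The real loop walks modules() but only touches those carrying comfy_cast_weights, i.e. the leaf ops patched in comfy/ops.py, so containers are skipped rather than double-counted.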
@@ -309,7 +314,11 @@ class LoadedModel:

     def model_unload(self):
         if self.model_accelerated:
-            accelerate.hooks.remove_hook_from_submodules(self.real_model)
+            for m in self.real_model.modules():
+                if hasattr(m, "prev_comfy_cast_weights"):
+                    m.comfy_cast_weights = m.prev_comfy_cast_weights
+                    del m.prev_comfy_cast_weights
+
             self.model_accelerated = False

         self.model.unpatch_model(self.model.offload_device)
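The unload path mirrors the load path: each module's saved prev_comfy_cast_weights is restored (and the marker attribute deleted) before unpatch_model moves the weights to the offload device, replacing the old accelerate.hooks.remove_hook_from_submodules call.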
@@ -402,14 +411,14 @@ def load_models_gpu(models, memory_required=0):
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
+            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
             if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
                 vram_set_state = VRAMState.LOW_VRAM
             else:
                 lowvram_model_memory = 0

         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 256 * 1024 * 1024
+            lowvram_model_memory = 64 * 1024 * 1024

         cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
         current_loaded_models.insert(0, loaded_model)
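For scale, a worked instance of the budget line above (the free-memory figure is an assumption): the formula subtracts 1 GiB of headroom, divides by a 1.3 safety factor, and floors the result at 64 MiB, down from the previous 256 MiB floor.

    current_free_mem = 4 * 1024**3  # suppose 4 GiB free on the GPU
    lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
    print(lowvram_model_memory // (1024 * 1024))  # -> 2363 (MiB of weights allowed on-device)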
@@ -566,6 +575,11 @@ def supports_dtype(device, dtype): #TODO
         return True
     return False

+def device_supports_non_blocking(device):
+    if is_device_mps(device):
+        return False #pytorch bug? mps doesn't support non blocking
+    return True
+
 def cast_to_device(tensor, device, dtype, copy=False):
     device_supports_cast = False
     if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
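non_blocking=True lets host-to-device copies overlap with host work when the source tensor is in pinned memory, which plausibly accounts for part of the speedup; MPS is excluded because of the PyTorch issue noted in the comment. A small CUDA-only illustration (device and sizes assumed):

    import torch

    src = torch.randn(1024, 1024).pin_memory()  # pinned host memory enables async copies
    dst = src.to("cuda", non_blocking=True)     # returns immediately; copy runs on a CUDA stream
    torch.cuda.synchronize()                    # wait before relying on dst's contents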
@@ -576,9 +590,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
     elif is_intel_xpu():
         device_supports_cast = True

-    non_blocking = True
-    if is_device_mps(device):
-        non_blocking = False #pytorch bug? mps doesn't support non blocking
+    non_blocking = device_supports_non_blocking(device)

     if device_supports_cast:
         if copy:
@@ -742,11 +754,7 @@ def soft_empty_cache(force=False):
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()

-def resolve_lowvram_weight(weight, model, key):
-    if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
-        key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
-        op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
-        weight = op._hf_hook.weights_map[key_split[-1]]
+def resolve_lowvram_weight(weight, model, key): #TODO: remove
     return weight

 #TODO: might be cleaner to put this somewhere else
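With accelerate gone, weights are never parked on the meta device, so resolve_lowvram_weight collapses to an identity function; the stub presumably survives (marked #TODO: remove) only so out-of-tree callers keep working.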
comfy/ops.py (92 lines changed)
@@ -1,27 +1,93 @@
 import torch
 from contextlib import contextmanager
+import comfy.model_management
+
+def cast_bias_weight(s, input):
+    bias = None
+    non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    return weight, bias
+

 class disable_weight_init:
     class Linear(torch.nn.Linear):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None

+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class Conv2d(torch.nn.Conv2d):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None

+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class Conv3d(torch.nn.Conv3d):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None

+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class GroupNorm(torch.nn.GroupNorm):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None

+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+

     class LayerNorm(torch.nn.LayerNorm):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None

+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     @classmethod
     def conv_nd(s, dims, *args, **kwargs):
         if dims == 2:
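The pattern above keeps every op a plain torch.nn subclass until the lowvram loader flips its flag, at which point forward reroutes through cast_bias_weight. A self-contained sketch of the same idea (class and names local to this example, runs on CPU):

    import torch

    class CastLinear(torch.nn.Linear):
        comfy_cast_weights = False  # class default; the loader flips it per instance

        def forward(self, input):
            if self.comfy_cast_weights:
                # Stream the parameters to the input's device/dtype at call time.
                weight = self.weight.to(device=input.device, dtype=input.dtype)
                bias = None if self.bias is None else self.bias.to(device=input.device, dtype=input.dtype)
                return torch.nn.functional.linear(input, weight, bias)
            return super().forward(input)

    layer = CastLinear(8, 8)         # parameters stay wherever they were created
    layer.comfy_cast_weights = True  # what model_load does for offloaded modules
    out = layer(torch.randn(2, 8))   # weights are cast to match the input on the fly
    print(out.shape)                 # torch.Size([2, 8])

A module that fits fully on the GPU keeps comfy_cast_weights = False and pays only one attribute check per call, whereas accelerate's hooks wrapped every dispatched module's forward.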
@@ -31,35 +97,19 @@ class disable_weight_init:
         else:
             raise ValueError(f"unsupported dimensions: {dims}")

-def cast_bias_weight(s, input):
-    bias = None
-    if s.bias is not None:
-        bias = s.bias.to(device=input.device, dtype=input.dtype)
-    weight = s.weight.to(device=input.device, dtype=input.dtype)
-    return weight, bias

 class manual_cast(disable_weight_init):
     class Linear(disable_weight_init.Linear):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+        comfy_cast_weights = True

     class Conv2d(disable_weight_init.Conv2d):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+        comfy_cast_weights = True

     class Conv3d(disable_weight_init.Conv3d):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+        comfy_cast_weights = True

     class GroupNorm(disable_weight_init.GroupNorm):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+        comfy_cast_weights = True

     class LayerNorm(disable_weight_init.LayerNorm):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+        comfy_cast_weights = True
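manual_cast now only flips comfy_cast_weights to True on each subclass: the casting forward lives once in disable_weight_init, and the per-class forward overrides, along with the duplicated cast_bias_weight (which in its new home also gained the non_blocking argument), are gone.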
@@ -4,7 +4,6 @@ einops
 transformers>=4.25.1
 safetensors>=0.3.0
 aiohttp
-accelerate
 pyyaml
 Pillow
 scipy