From 1cdfb3dba4e7af11e2e05dc6a6276ba84eb1adf2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Sep 2023 17:52:41 -0400 Subject: [PATCH] Only do the cast on the device if the device supports it. --- comfy/model_management.py | 17 ++++++++++++++++ comfy/model_patcher.py | 43 ++++++++++++++++++++++++++------------- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d8bc3bfe..1050c13a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -481,6 +481,23 @@ def get_autocast_device(dev): return dev.type return "cuda" +def cast_to_device(tensor, device, dtype, copy=False): + device_supports_cast = False + if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: + device_supports_cast = True + elif tensor.dtype == torch.bfloat16: + if hasattr(device, 'type') and device.type.startswith("cuda"): + device_supports_cast = True + + if device_supports_cast: + if copy: + if tensor.device == device: + return tensor.to(dtype, copy=copy) + return tensor.to(device, copy=copy).to(dtype) + else: + return tensor.to(device).to(dtype) + else: + return tensor.to(dtype).to(device, copy=copy) def xformers_enabled(): global directml_enabled diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 85bf5bd2..10551656 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -3,6 +3,7 @@ import copy import inspect import comfy.utils +import comfy.model_management class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None): @@ -154,7 +155,7 @@ class ModelPatcher: self.backup[key] = weight.to(self.offload_device) if device_to is not None: - temp_weight = weight.float().to(device_to, copy=True) + temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) else: temp_weight = weight.to(torch.float32, copy=True) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) @@ -185,15 +186,15 @@ class ModelPatcher: if w1.shape != weight.shape: print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: - weight += alpha * w1.type(weight.dtype).to(weight.device) + weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) elif len(v) == 4: #lora/locon - mat1 = v[0].to(weight.device).float() - mat2 = v[1].to(weight.device).float() + mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) + mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) if v[2] is not None: alpha *= v[2] / mat2.shape[0] if v[3] is not None: #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = v[3].to(weight.device).float() + mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32) final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: @@ -212,18 +213,23 @@ class ModelPatcher: if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(w1_a.to(weight.device).float(), w1_b.to(weight.device).float()) + w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32)) else: - w1 = w1.to(weight.device).float() + w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(w2_a.to(weight.device).float(), w2_b.to(weight.device).float()) + w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32)) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2_b.to(weight.device).float(), w2_a.to(weight.device).float()) + w2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32)) else: - w2 = w2.to(weight.device).float() + w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -244,11 +250,20 @@ class ModelPatcher: if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.to(weight.device).float(), w1b.to(weight.device).float(), w1a.to(weight.device).float()) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2b.to(weight.device).float(), w2a.to(weight.device).float()) + m1 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t1, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1a, weight.device, torch.float32)) + + m2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2a, weight.device, torch.float32)) else: - m1 = torch.mm(w1a.to(weight.device).float(), w1b.to(weight.device).float()) - m2 = torch.mm(w2a.to(weight.device).float(), w2b.to(weight.device).float()) + m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1b, weight.device, torch.float32)) + m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2b, weight.device, torch.float32)) try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)