Only do the cast on the device if the device supports it.
This commit is contained in:
parent
b92a86d737
commit
1cdfb3dba4
|
@ -481,6 +481,23 @@ def get_autocast_device(dev):
|
||||||
return dev.type
|
return dev.type
|
||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
|
def cast_to_device(tensor, device, dtype, copy=False):
    """Move *tensor* to *device* and convert it to *dtype*.

    The dtype conversion is performed on the target device only when the
    tensor's current dtype is treated as castable there: float32 and float16
    always are, bfloat16 only when the device type starts with "cuda".
    Otherwise the tensor is converted to *dtype* first, on the device it
    currently lives on, and moved afterwards.

    Args:
        tensor: the source ``torch.Tensor``.
        device: target device; a ``torch.device`` (or device-like object
            exposing ``.type``) or a device string such as ``"cpu"``.
        dtype: target ``torch.dtype``.
        copy: when true, ``copy=True`` is passed to the ``Tensor.to`` call
            that moves the tensor (or, if it is already on *device*, to the
            dtype conversion), so the result does not alias the input.

    Returns:
        The tensor on *device* with dtype *dtype*.
    """
    # A plain string has no ``.type`` attribute and never compares equal to
    # ``tensor.device``; normalise it so the bfloat16 check and the
    # same-device shortcut below also work for string devices.
    if isinstance(device, str):
        device = torch.device(device)

    if tensor.dtype in (torch.float32, torch.float16):
        device_supports_cast = True
    elif tensor.dtype == torch.bfloat16:
        device_supports_cast = getattr(device, "type", "").startswith("cuda")
    else:
        device_supports_cast = False

    if not device_supports_cast:
        # Convert where the tensor currently is, then transfer the result.
        return tensor.to(dtype).to(device, copy=copy)

    if not copy:
        return tensor.to(device).to(dtype)

    if tensor.device == device:
        # Already on the target device: a single copying cast is enough.
        return tensor.to(dtype, copy=True)
    return tensor.to(device, copy=True).to(dtype)
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
|
|
|
@ -3,6 +3,7 @@ import copy
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
||||||
|
@ -154,7 +155,7 @@ class ModelPatcher:
|
||||||
self.backup[key] = weight.to(self.offload_device)
|
self.backup[key] = weight.to(self.offload_device)
|
||||||
|
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
temp_weight = weight.float().to(device_to, copy=True)
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||||
else:
|
else:
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||||
|
@ -185,15 +186,15 @@ class ModelPatcher:
|
||||||
if w1.shape != weight.shape:
|
if w1.shape != weight.shape:
|
||||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||||
else:
|
else:
|
||||||
weight += alpha * w1.type(weight.dtype).to(weight.device)
|
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||||
elif len(v) == 4: #lora/locon
|
elif len(v) == 4: #lora/locon
|
||||||
mat1 = v[0].to(weight.device).float()
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||||
mat2 = v[1].to(weight.device).float()
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||||
if v[2] is not None:
|
if v[2] is not None:
|
||||||
alpha *= v[2] / mat2.shape[0]
|
alpha *= v[2] / mat2.shape[0]
|
||||||
if v[3] is not None:
|
if v[3] is not None:
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
mat3 = v[3].to(weight.device).float()
|
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||||
try:
|
try:
|
||||||
|
@ -212,18 +213,23 @@ class ModelPatcher:
|
||||||
|
|
||||||
if w1 is None:
|
if w1 is None:
|
||||||
dim = w1_b.shape[0]
|
dim = w1_b.shape[0]
|
||||||
w1 = torch.mm(w1_a.to(weight.device).float(), w1_b.to(weight.device).float())
|
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
|
||||||
else:
|
else:
|
||||||
w1 = w1.to(weight.device).float()
|
w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
|
||||||
|
|
||||||
if w2 is None:
|
if w2 is None:
|
||||||
dim = w2_b.shape[0]
|
dim = w2_b.shape[0]
|
||||||
if t2 is None:
|
if t2 is None:
|
||||||
w2 = torch.mm(w2_a.to(weight.device).float(), w2_b.to(weight.device).float())
|
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
|
||||||
else:
|
else:
|
||||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2_b.to(weight.device).float(), w2_a.to(weight.device).float())
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
|
||||||
else:
|
else:
|
||||||
w2 = w2.to(weight.device).float()
|
w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
if len(w2.shape) == 4:
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
|
@ -244,11 +250,20 @@ class ModelPatcher:
|
||||||
if v[5] is not None: #cp decomposition
|
if v[5] is not None: #cp decomposition
|
||||||
t1 = v[5]
|
t1 = v[5]
|
||||||
t2 = v[6]
|
t2 = v[6]
|
||||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.to(weight.device).float(), w1b.to(weight.device).float(), w1a.to(weight.device).float())
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2b.to(weight.device).float(), w2a.to(weight.device).float())
|
comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
|
||||||
|
|
||||||
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
|
||||||
else:
|
else:
|
||||||
m1 = torch.mm(w1a.to(weight.device).float(), w1b.to(weight.device).float())
|
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
|
||||||
m2 = torch.mm(w2a.to(weight.device).float(), w2b.to(weight.device).float())
|
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
|
||||||
|
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||||
|
|
Loading…
Reference in New Issue