Fix ControlLora on lowvram.
This commit is contained in:
parent
d08e53de2e
commit
199d73364a
23
comfy/sd.py
23
comfy/sd.py
|
@ -243,6 +243,13 @@ def set_attr(obj, attr, value):
|
|||
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
||||
del prev
|
||||
|
||||
def get_attr(obj, attr):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs:
|
||||
obj = getattr(obj, name)
|
||||
return obj
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
||||
self.size = size
|
||||
|
@ -856,9 +863,9 @@ class ControlLoraOps:
|
|||
|
||||
def forward(self, input):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias)
|
||||
return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input, self.weight, self.bias)
|
||||
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
|
||||
|
||||
class Conv2d(torch.nn.Module):
|
||||
def __init__(
|
||||
|
@ -895,9 +902,9 @@ class ControlLoraOps:
|
|||
|
||||
def forward(self, input):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return torch.nn.functional.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
def conv_nd(self, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
|
@ -927,8 +934,14 @@ class ControlLora(ControlNet):
|
|||
cm = self.control_model.state_dict()
|
||||
|
||||
for k in sd:
|
||||
weight = sd[k]
|
||||
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
|
||||
key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
|
||||
op = get_attr(diffusion_model, '.'.join(key_split[:-1]))
|
||||
weight = op._hf_hook.weights_map[key_split[-1]]
|
||||
|
||||
try:
|
||||
set_attr(self.control_model, k, sd[k])
|
||||
set_attr(self.control_model, k, weight)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in New Issue