Reduce floating point rounding errors in LoRAs.
parent 91ed2815d5
commit 6fb084f39d
@@ -342,7 +342,9 @@ class ModelPatcher:
             if key not in self.backup:
                 self.backup[key] = weight.clone()
 
-            weight[:] = self.calculate_weight(self.patches[key], weight.clone(), key)
+            temp_weight = weight.to(torch.float32, copy=True)
+            weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            del temp_weight
         return self.model
 
     def calculate_weight(self, patches, weight, key):
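For context, here is a minimal standalone sketch (illustrative only, not ComfyUI code; the weight and delta values below are made up) of the rounding problem the upcast avoids: when small LoRA-style contributions are accumulated directly in the weight's half-precision dtype, each step is rounded to that dtype and very small contributions can be lost entirely, whereas doing the arithmetic in float32 and casting back to the original dtype once at the end preserves them.

import torch

# Hypothetical tensors: a float16 weight and many small LoRA-style deltas.
weight = torch.full((1024,), 0.1, dtype=torch.float16)
deltas = [torch.full((1024,), 1e-5, dtype=torch.float16) for _ in range(200)]

# Accumulate directly in float16: each delta is smaller than half a float16
# ULP near 0.1, so every addition rounds back to the original value.
patched_fp16 = weight.clone()
for d in deltas:
    patched_fp16 += d

# Upcast to float32 first, accumulate there, then cast back to float16 once.
temp_weight = weight.to(torch.float32, copy=True)
for d in deltas:
    temp_weight += d.to(torch.float32)
patched_fp32 = temp_weight.to(weight.dtype)
del temp_weight

# Compare both results against a float64 reference.
reference = weight.double() + sum(d.double() for d in deltas)
print("error, float16 accumulation:", (patched_fp16.double() - reference).abs().max().item())
print("error, float32 accumulation:", (patched_fp32.double() - reference).abs().max().item())

In the actual patch the float32 work happens inside calculate_weight and the result is written back with weight[:] = ... .to(weight.dtype), which performs the single downcast; the del temp_weight presumably just releases the temporary float32 copy as soon as it is no longer needed.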