Reduce floating point rounding errors in loras.

This commit is contained in:
comfyanonymous 2023-07-15 00:45:38 -04:00
parent 91ed2815d5
commit 6fb084f39d
1 changed file with 3 additions and 1 deletion


@@ -342,7 +342,9 @@ class ModelPatcher:
             if key not in self.backup:
                 self.backup[key] = weight.clone()

-            weight[:] = self.calculate_weight(self.patches[key], weight.clone(), key)
+            temp_weight = weight.to(torch.float32, copy=True)
+            weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            del temp_weight
         return self.model

     def calculate_weight(self, patches, weight, key):
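
For context, a minimal sketch (not part of the commit) of why the float32 upcast reduces rounding error: when patches are accumulated directly in a low-precision dtype such as float16, a small LoRA delta can be below the spacing between representable values and gets rounded away on every addition, whereas accumulating in float32 and casting back once at the end preserves the combined change as closely as the target dtype allows. The tensor values and the number of patches below are illustrative assumptions, not taken from the repository.

import torch

# Illustrative only: a float16 "model weight" and several small LoRA-style deltas.
weight = torch.full((4,), 1.0, dtype=torch.float16)
deltas = [torch.full((4,), 1e-4, dtype=torch.float16) for _ in range(16)]

# Accumulating directly in float16: each delta is smaller than the float16
# spacing near 1.0, so every addition rounds back to 1.0 and the patches vanish.
patched_fp16 = weight.clone()
for d in deltas:
    patched_fp16 += d

# Accumulating in float32 and casting back at the end, as the commit does:
temp_weight = weight.to(torch.float32, copy=True)
for d in deltas:
    temp_weight += d.to(torch.float32)
patched_fp32 = temp_weight.to(weight.dtype)
del temp_weight

print(patched_fp16[0].item())  # 1.0      -- all deltas lost to per-step rounding
print(patched_fp32[0].item())  # ~1.002   -- accumulated sum survives the single final cast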