Fix lowvram edge case.

This commit is contained in:
comfyanonymous 2024-10-22 16:34:50 -04:00
parent 5a8a48931a
commit 915fdb5745
1 changed files with 4 additions and 0 deletions

View File

@ -264,10 +264,14 @@ def fp8_linear(self, input):
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
else:
scale_weight = scale_weight.to(input.device)
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = input.reshape(-1, input.shape[2]).to(dtype)
else:
scale_input = scale_input.to(input.device)
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
if bias is not None: