Do FreeU fft on CPU if the device doesn't support fft functions.

This commit is contained in:
comfyanonymous 2023-09-24 18:09:44 -04:00
parent 77c124c5a1
commit f00471cdc8
1 changed files with 12 additions and 1 deletions

View File

@ -39,11 +39,22 @@ class FreeU:
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
def output_block_patch(h, hsp, transformer_options):
scale = scale_dict.get(h.shape[1], None)
if scale is not None:
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
if hsp.device not in on_cpu_devices:
try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else:
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
return h, hsp
m = model.clone()