FreeU now works with the refiner.

This commit is contained in:
comfyanonymous 2023-09-23 12:19:08 -04:00
parent ae87543653
commit 05e661e5ef
1 changed files with 6 additions and 6 deletions

View File

@ -37,13 +37,13 @@ class FreeU:
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
def patch(self, model, b1, b2, s1, s2): def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
def output_block_patch(h, hsp, transformer_options): def output_block_patch(h, hsp, transformer_options):
if h.shape[1] == 1280: scale = scale_dict.get(h.shape[1], None)
h[:,:640] = h[:,:640] * b1 if scale is not None:
hsp = Fourier_filter(hsp, threshold=1, scale=s1) h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
if h.shape[1] == 640: hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
h[:,:320] = h[:,:320] * b2
hsp = Fourier_filter(hsp, threshold=1, scale=s2)
return h, hsp return h, hsp
m = model.clone() m = model.clone()