diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index 535eece3..c3542a7a 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -37,13 +37,13 @@ class FreeU: CATEGORY = "_for_testing" def patch(self, model, b1, b2, s1, s2): + model_channels = model.model.model_config.unet_config["model_channels"] + scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} def output_block_patch(h, hsp, transformer_options): - if h.shape[1] == 1280: - h[:,:640] = h[:,:640] * b1 - hsp = Fourier_filter(hsp, threshold=1, scale=s1) - if h.shape[1] == 640: - h[:,:320] = h[:,:320] * b2 - hsp = Fourier_filter(hsp, threshold=1, scale=s2) + scale = scale_dict.get(h.shape[1], None) + if scale is not None: + h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0] + hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) return h, hsp m = model.clone()