FreeU now works with the refiner.
This commit is contained in:
parent
ae87543653
commit
05e661e5ef
|
@ -37,13 +37,13 @@ class FreeU:
|
|||
CATEGORY = "_for_testing"
|
||||
|
||||
def patch(self, model, b1, b2, s1, s2):
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||
def output_block_patch(h, hsp, transformer_options):
|
||||
if h.shape[1] == 1280:
|
||||
h[:,:640] = h[:,:640] * b1
|
||||
hsp = Fourier_filter(hsp, threshold=1, scale=s1)
|
||||
if h.shape[1] == 640:
|
||||
h[:,:320] = h[:,:320] * b2
|
||||
hsp = Fourier_filter(hsp, threshold=1, scale=s2)
|
||||
scale = scale_dict.get(h.shape[1], None)
|
||||
if scale is not None:
|
||||
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
|
||||
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
||||
return h, hsp
|
||||
|
||||
m = model.clone()
|
||||
|
|
Loading…
Reference in New Issue