FreeU now works with the refiner.
This commit is contained in:
parent
ae87543653
commit
05e661e5ef
|
@ -37,13 +37,13 @@ class FreeU:
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
def patch(self, model, b1, b2, s1, s2):
|
def patch(self, model, b1, b2, s1, s2):
|
||||||
|
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||||
|
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||||
def output_block_patch(h, hsp, transformer_options):
|
def output_block_patch(h, hsp, transformer_options):
|
||||||
if h.shape[1] == 1280:
|
scale = scale_dict.get(h.shape[1], None)
|
||||||
h[:,:640] = h[:,:640] * b1
|
if scale is not None:
|
||||||
hsp = Fourier_filter(hsp, threshold=1, scale=s1)
|
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
|
||||||
if h.shape[1] == 640:
|
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
||||||
h[:,:320] = h[:,:320] * b2
|
|
||||||
hsp = Fourier_filter(hsp, threshold=1, scale=s2)
|
|
||||||
return h, hsp
|
return h, hsp
|
||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|
Loading…
Reference in New Issue