Fix RescaleCFG for batch size > 1.
This commit is contained in:
parent
58d5d71a93
commit
ca2812bae0
|
@ -140,6 +140,7 @@ class RescaleCFG:
|
|||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
sigma = args["sigma"]
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
|
||||
x_orig = args["input"]
|
||||
|
||||
#rescale cfg has to be done on v-pred model output
|
||||
|
|
Loading…
Reference in New Issue