Fix an issue when alphas_comprod are half floats.
This commit is contained in:
parent
ae43f09ef7
commit
e6e50ab2dd
|
@ -134,7 +134,7 @@ class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
|||
"""A wrapper for CompVis diffusion models."""
|
||||
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model.apply_model(*args, **kwargs)
|
||||
|
@ -173,7 +173,7 @@ class CompVisVDenoiser(DiscreteVDDPMDenoiser):
|
|||
"""A wrapper for CompVis diffusion models that output v."""
|
||||
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
|
||||
|
||||
def get_v(self, x, t, cond, **kwargs):
|
||||
return self.inner_model.apply_model(x, t, cond)
|
||||
|
|
Loading…
Reference in New Issue