Fix an issue when alphas_comprod are half floats.

This commit is contained in:
comfyanonymous 2023-06-16 17:16:51 -04:00
parent ae43f09ef7
commit e6e50ab2dd
1 changed files with 2 additions and 2 deletions

View File

@ -134,7 +134,7 @@ class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for CompVis diffusion models."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)
@ -173,7 +173,7 @@ class CompVisVDenoiser(DiscreteVDDPMDenoiser):
"""A wrapper for CompVis diffusion models that output v."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond)