Fix hunyuan dit text encoder weights always being in fp32.

This commit is contained in:
comfyanonymous 2024-07-31 01:34:57 -04:00
parent 2c038ccef0
commit a5991a7aa6
1 changed files with 2 additions and 2 deletions

View File

@ -52,8 +52,8 @@ class HyditTokenizer:
class HyditModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.hydit_clip = HyditBertModel()
self.mt5xl = MT5XLModel()
self.hydit_clip = HyditBertModel(dtype=dtype)
self.mt5xl = MT5XLModel(dtype=dtype)
self.dtypes = set()
if dtype is not None: