Fix hunyuan dit text encoder weights always being in fp32.
This commit is contained in:
parent
2c038ccef0
commit
a5991a7aa6
|
@ -52,8 +52,8 @@ class HyditTokenizer:
|
|||
class HyditModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.hydit_clip = HyditBertModel()
|
||||
self.mt5xl = MT5XLModel()
|
||||
self.hydit_clip = HyditBertModel(dtype=dtype)
|
||||
self.mt5xl = MT5XLModel(dtype=dtype)
|
||||
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
|
|
Loading…
Reference in New Issue