diff --git a/comfy/model_management.py b/comfy/model_management.py
index bcd86a03..07c13727 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -661,6 +661,17 @@ def supports_cast(device, dtype): #TODO
         return True
     return False
 
+def pick_weight_dtype(dtype, fallback_dtype, device=None):
+    if dtype is None:
+        dtype = fallback_dtype
+    elif dtype_size(dtype) > dtype_size(fallback_dtype):
+        dtype = fallback_dtype
+
+    if not supports_cast(device, dtype):
+        dtype = fallback_dtype
+
+    return dtype
+
 def device_supports_non_blocking(device):
     if is_device_mps(device):
         return False #pytorch bug? mps doesn't support non blocking
diff --git a/comfy/sd.py b/comfy/sd.py
index c9bc1639..2be8edef 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -441,7 +441,13 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
             clip_target.clip = comfy.text_encoders.hydit.HyditModel
             clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
         elif clip_type == CLIPType.FLUX:
-            clip_target.clip = comfy.text_encoders.flux.FluxClipModel
+            weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
+            weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
+            dtype_t5 = None
+            if weight is not None:
+                dtype_t5 = weight.dtype
+
+            clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
             clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLClipModel
diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py
index 2759a38a..849214ce 100644
--- a/comfy/text_encoders/flux.py
+++ b/comfy/text_encoders/flux.py
@@ -1,5 +1,6 @@
 from comfy import sd1_clip
 import comfy.text_encoders.t5
+import comfy.model_management
 from transformers import T5TokenizerFast
 import torch
 import os
@@ -34,11 +35,12 @@ class FluxTokenizer:
 
 
 class FluxClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None):
+    def __init__(self, dtype_t5=None, device="cpu", dtype=None):
         super().__init__()
+        dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
         self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
-        self.t5xxl = T5XXLModel(device=device, dtype=dtype)
-        self.dtypes = set([dtype])
+        self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
+        self.dtypes = set([dtype, dtype_t5])
 
     def set_clip_options(self, options):
         self.clip_l.set_clip_options(options)
@@ -62,3 +64,8 @@ class FluxClipModel(torch.nn.Module):
     else:
         return self.t5xxl.load_sd(sd)
 
+def flux_clip(dtype_t5=None):
+    class FluxClipModel_(FluxClipModel):
+        def __init__(self, device="cpu", dtype=None):
+            super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
+    return FluxClipModel_
diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py
index b01fad22..143d884c 100644
--- a/comfy/text_encoders/sd3_clip.py
+++ b/comfy/text_encoders/sd3_clip.py
@@ -54,14 +54,7 @@ class SD3ClipModel(torch.nn.Module):
             self.clip_g = None
 
         if t5:
-            if dtype_t5 is None:
-                dtype_t5 = dtype
-            elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
-                dtype_t5 = dtype
-
-            if not comfy.model_management.supports_cast(device, dtype_t5):
-                dtype_t5 = dtype
-
+            dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
             self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
             self.dtypes.add(dtype_t5)
         else:
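
Not part of the patch: a minimal usage sketch of the dtype-selection rule the new comfy.model_management.pick_weight_dtype helper centralizes, assuming this patch is applied, a ComfyUI checkout is importable, and PyTorch >= 2.1 (for the fp8 dtype). text_encoder_device() is the existing helper ComfyUI uses to pick the text-encoder device.

import torch
import comfy.model_management as mm

fallback = torch.float16       # dtype the rest of the text encoders run in
stored = torch.float8_e4m3fn   # dtype detected from the T5-XXL weight in the checkpoint

# Keeps the T5 weights in their stored fp8 dtype when the device can cast it,
# otherwise falls back to the regular text-encoder dtype.
picked = mm.pick_weight_dtype(stored, fallback, device=mm.text_encoder_device())
print(picked)

# A detected dtype larger than the fallback is clamped to the fallback:
print(mm.pick_weight_dtype(torch.float32, torch.float16))  # torch.float16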