Load flux t5 in fp8 if weights are in fp8.

This commit is contained in:
comfyanonymous 2024-08-01 11:05:56 -04:00
parent 8d34211a7a
commit 5f98de7697
4 changed files with 29 additions and 12 deletions

View File

@ -661,6 +661,17 @@ def supports_cast(device, dtype): #TODO
return True
return False
def pick_weight_dtype(dtype, fallback_dtype, device=None):
if dtype is None:
dtype = fallback_dtype
elif dtype_size(dtype) > dtype_size(fallback_dtype):
dtype = fallback_dtype
if not supports_cast(device, dtype):
dtype = fallback_dtype
return dtype
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking

View File

@ -441,7 +441,13 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.clip = comfy.text_encoders.hydit.HyditModel
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
elif clip_type == CLIPType.FLUX:
clip_target.clip = comfy.text_encoders.flux.FluxClipModel
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
dtype_t5 = None
if weight is not None:
dtype_t5 = weight.dtype
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel

View File

@ -1,5 +1,6 @@
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.model_management
from transformers import T5TokenizerFast
import torch
import os
@ -34,11 +35,12 @@ class FluxTokenizer:
class FluxClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
self.t5xxl = T5XXLModel(device=device, dtype=dtype)
self.dtypes = set([dtype])
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.dtypes = set([dtype, dtype_t5])
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
@ -62,3 +64,8 @@ class FluxClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
def flux_clip(dtype_t5=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
return FluxClipModel_

View File

@ -54,14 +54,7 @@ class SD3ClipModel(torch.nn.Module):
self.clip_g = None
if t5:
if dtype_t5 is None:
dtype_t5 = dtype
elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
dtype_t5 = dtype
if not comfy.model_management.supports_cast(device, dtype_t5):
dtype_t5 = dtype
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.dtypes.add(dtype_t5)
else: