From 6c6a39251fe313c56a88c90d820009073b623dfe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 Apr 2024 11:46:34 -0400 Subject: [PATCH] Fix saving text encoder in fp8. --- comfy/diffusers_convert.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index 08018c54..ed2a45fe 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys())) # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp code2idx = {"q": 0, "k": 1, "v": 2} +# This function exists because at the time of writing torch.cat can't do fp8 with cuda +def cat_tensors(tensors): + x = 0 + for t in tensors: + x += t.shape[0] + + shape = [x] + list(tensors[0].shape)[1:] + out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype) + + x = 0 + for t in tensors: + out[x:x + t.shape[0]] = t + x += t.shape[0] + + return out def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): new_state_dict = {} @@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): if None in tensors: raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) - new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors) for k_pre, tensors in capture_qkv_bias.items(): if None in tensors: raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) - new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors) return new_state_dict