From f081017c1a20a5d9cfae9005fd0898502e3356be Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 23 Aug 2023 01:07:57 -0400
Subject: [PATCH] Save memory by storing text encoder weights in fp16 in most
 situations.

Do inference in fp32 to make sure quality stays the exact same.
---
 comfy/model_management.py          | 2 +-
 comfy/sd.py                        | 7 ++-----
 comfy/sd1_clip.py                  | 4 ++--
 web/extensions/core/uploadImage.js | 1 -
 4 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index fc0cb901..9c100144 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -433,7 +433,7 @@ def text_encoder_device():
         return get_torch_device()
     elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
         #NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU
-        if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
+        if should_use_fp16() or torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
             return get_torch_device()
         else:
             return torch.device("cpu")
diff --git a/comfy/sd.py b/comfy/sd.py
index 5920ddde..7de72d37 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -546,11 +546,8 @@ class CLIP:
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = load_device
         self.cond_stage_model = clip(**(params))
-        #TODO: make sure this doesn't have a quality loss before enabling.
-        # if model_management.should_use_fp16(load_device):
-        #     self.cond_stage_model.half()
-
-        self.cond_stage_model = self.cond_stage_model.to()
+        if model_management.should_use_fp16(load_device):
+            self.cond_stage_model.half()
 
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index feca4188..c699214a 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -137,9 +137,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if backup_embeds.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
-            precision_scope = contextlib.nullcontext
+            precision_scope = lambda a, b: contextlib.nullcontext(a)
 
-        with precision_scope(model_management.get_autocast_device(device)):
+        with precision_scope(model_management.get_autocast_device(device), torch.float32):
             outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
             self.transformer.set_input_embeddings(backup_embeds)
 
diff --git a/web/extensions/core/uploadImage.js b/web/extensions/core/uploadImage.js
index fda83f8c..530c4599 100644
--- a/web/extensions/core/uploadImage.js
+++ b/web/extensions/core/uploadImage.js
@@ -5,7 +5,6 @@ import { app } from "../../scripts/app.js";
 app.registerExtension({
 	name: "Comfy.UploadImage",
 	async beforeRegisterNodeDef(nodeType, nodeData, app) {
-		console.log(nodeData);
 		if (nodeData?.input?.required?.image?.[1]?.image_upload === true) {
 			nodeData.input.required.upload = ["IMAGEUPLOAD"];
 		}
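
The pattern the patch applies is: call .half() on the text encoder so its weights are stored in fp16 (roughly halving their memory), then wrap the forward pass in an autocast region targeting fp32 so the actual computation, and therefore the conditioning output, matches the full-precision result. Below is a minimal standalone sketch of that idea, not ComfyUI's own API: it assumes a CUDA device and uses the Hugging Face CLIPTextModel as a stand-in for cond_stage_model; the model name and tokenizer call are illustrative assumptions, not part of the patch.

    # Sketch: store CLIP text encoder weights in fp16, compute in fp32.
    # Assumes CUDA is available and transformers is installed.
    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    model.half()       # weights stored in fp16, as in CLIP.__init__ when should_use_fp16() is True
    model.to("cuda")

    tokens = tokenizer(["a photo of a cat"], return_tensors="pt").input_ids.to("cuda")

    # Mirrors the sd1_clip.py hunk: weights are not fp32, so run the forward pass
    # under autocast with torch.float32 so the ops are computed at full precision.
    with torch.autocast(device_type="cuda", dtype=torch.float32):
        outputs = model(input_ids=tokens, output_hidden_states=True)

    print(outputs.last_hidden_state.dtype)  # expected: torch.float32

This is why the model_management.py hunk also routes the text encoder to the GPU whenever should_use_fp16() is true: the fp16 weight storage only pays off on the device where the autocast-to-fp32 inference runs.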