"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Comfy

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import torch
from . import model_base
from . import utils
from . import latent_formats


class ClipTarget:
    """Bundle a tokenizer with a clip object plus an initially empty ``params`` dict."""

    def __init__(self, tokenizer, clip):
        self.tokenizer = tokenizer
        self.clip = clip
        self.params = {}


class BASE:
    """Base description of a supported model.

    The class attributes below are defaults. ``__init__`` gives each
    instance its own copy of ``unet_config`` (taken from the constructor
    argument), ``sampling_settings`` and ``optimizations``, and replaces
    ``latent_format`` with an instance of itself.
    """

    # Entries that ``matches`` requires to be present, with equal values,
    # in a candidate unet config.
    unet_config = {}
    # Entries written into the instance's ``unet_config`` by ``__init__``,
    # overriding whatever the constructor argument held for those keys.
    unet_extra_config = {"num_heads": -1, "num_head_channels": 64}

    # Keys that ``matches`` requires in a state dict, when one is supplied.
    required_keys = {}

    clip_prefix = []
    clip_vision_prefix = None
    # When not None, ``get_model`` builds a ``model_base.SD21UNCLIP``.
    noise_aug_config = None
    sampling_settings = {}
    latent_format = latent_formats.LatentFormat
    vae_key_prefix = ["first_stage_model."]
    text_encoder_key_prefix = ["cond_stage_model."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    memory_usage_factor = 2.0

    manual_cast_dtype = None
    custom_operations = None
    scaled_fp8 = None
    optimizations = {"fp8": False}

    @classmethod
    def matches(cls, unet_config, state_dict=None):
        """Tell whether ``unet_config`` (and optionally ``state_dict``) fits this class.

        Returns False as soon as a key of the class-level ``unet_config``
        is missing from the argument or maps to a different value. When
        ``state_dict`` is given, every entry of ``required_keys`` must
        also be contained in it.
        """
        for key, expected in cls.unet_config.items():
            if key not in unet_config:
                return False
            if expected != unet_config[key]:
                return False

        if state_dict is None:
            return True
        return all(key in state_dict for key in cls.required_keys)

    def model_type(self, state_dict, prefix=""):
        """Return ``model_base.ModelType.EPS``; both arguments are unused here."""
        return model_base.ModelType.EPS

    def inpaint_model(self):
        """Return True when ``unet_config["in_channels"]`` is greater than 4."""
        in_channels = self.unet_config["in_channels"]
        return in_channels > 4

    def __init__(self, unet_config):
        # Copy before merging so neither the caller's dict nor the
        # class-level defaults are mutated through this instance.
        self.unet_config = unet_config.copy()
        self.unet_config.update(self.unet_extra_config)
        self.sampling_settings = self.sampling_settings.copy()
        # The attribute holds a class until here; swap in an instance.
        self.latent_format = self.latent_format()
        self.optimizations = self.optimizations.copy()

    def get_model(self, state_dict, prefix="", device=None):
        """Build the model object for this configuration.

        Uses ``model_base.BaseModel`` when ``noise_aug_config`` is None
        and ``model_base.SD21UNCLIP`` otherwise, passing the result of
        ``self.model_type(state_dict, prefix)`` as ``model_type``. If
        ``inpaint_model()`` is true, ``set_inpaint()`` is called on the
        model before it is returned.
        """
        if self.noise_aug_config is None:
            model = model_base.BaseModel(
                self,
                model_type=self.model_type(state_dict, prefix),
                device=device,
            )
        else:
            model = model_base.SD21UNCLIP(
                self,
                self.noise_aug_config,
                model_type=self.model_type(state_dict, prefix),
                device=device,
            )

        if self.inpaint_model():
            model.set_inpaint()
        return model

    def process_clip_state_dict(self, state_dict):
        """Run ``state_dict`` through ``utils.state_dict_prefix_replace``.

        Every prefix in ``text_encoder_key_prefix`` is mapped to the
        empty string and ``filter_keys=True`` is passed.
        NOTE(review): presumably this strips those prefixes and drops
        non-matching keys -- confirm against ``utils``.
        """
        prefix_map = dict.fromkeys(self.text_encoder_key_prefix, "")
        return utils.state_dict_prefix_replace(state_dict, prefix_map, filter_keys=True)

    def process_unet_state_dict(self, state_dict):
        """Return ``state_dict`` unchanged."""
        return state_dict

    def process_vae_state_dict(self, state_dict):
        """Return ``state_dict`` unchanged."""
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Call the prefix-replace helper mapping "" to the first text encoder prefix."""
        return utils.state_dict_prefix_replace(
            state_dict, {"": self.text_encoder_key_prefix[0]}
        )

    def process_clip_vision_state_dict_for_saving(self, state_dict):
        """Call the prefix-replace helper mapping "" to ``clip_vision_prefix``.

        When ``clip_vision_prefix`` is None an empty mapping is passed.
        """
        prefix_map = {} if self.clip_vision_prefix is None else {"": self.clip_vision_prefix}
        return utils.state_dict_prefix_replace(state_dict, prefix_map)

    def process_unet_state_dict_for_saving(self, state_dict):
        """Call the prefix-replace helper mapping "" to ``"model.diffusion_model."``."""
        return utils.state_dict_prefix_replace(state_dict, {"": "model.diffusion_model."})

    def process_vae_state_dict_for_saving(self, state_dict):
        """Call the prefix-replace helper mapping "" to the first VAE key prefix."""
        return utils.state_dict_prefix_replace(state_dict, {"": self.vae_key_prefix[0]})

    def set_inference_dtype(self, dtype, manual_cast_dtype):
        """Store ``dtype`` as ``unet_config['dtype']`` and record ``manual_cast_dtype``."""
        self.manual_cast_dtype = manual_cast_dtype
        self.unet_config['dtype'] = dtype