import torch
import sd1_clip
import sd2_clip
from ldm.util import instantiate_from_config
from ldm.models.autoencoder import AutoencoderKL
from omegaconf import OmegaConf
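

# This module loads Stable Diffusion checkpoints (.ckpt or .safetensors) and
# exposes the pieces used downstream as separate objects: the diffusion model,
# the CLIP text encoder (CLIP) and the image autoencoder (VAE).
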
def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
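    """Instantiate the model described by `config` and load weights from `ckpt`.

    `ckpt` can be a regular torch checkpoint or a .safetensors file.  The same
    (non-strict) state dict is also loaded into every module listed in
    `load_state_dict_to`, which is how the separately constructed CLIP and VAE
    objects receive their weights.
    """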
    print(f"Loading model from {ckpt}")

    if ckpt.lower().endswith(".safetensors"):
        import safetensors.torch
        sd = safetensors.torch.load_file(ckpt, device="cpu")
    else:
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd

    model = instantiate_from_config(config.model)

    m, u = model.load_state_dict(sd, strict=False)

    # Some checkpoints store the CLIP weights without the "text_model." prefix;
    # rename those keys to the layout the current model expects.
    k = list(sd.keys())
    for x in k:
        # print(x)
        if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
            y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
            sd[y] = sd.pop(x)

    # Some checkpoints ship position_ids as float32; round them so they map to
    # exact token position indices.
    if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd:
        ids = sd['cond_stage_model.transformer.text_model.embeddings.position_ids']
        if ids.dtype == torch.float32:
            sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

    for x in load_state_dict_to:
        x.load_state_dict(sd, strict=False)

    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.eval()
    return model


class CLIP:
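    """Wraps the text encoder and its tokenizer behind a single encode() call."""
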
    def __init__(self, config, embedding_directory=None):
        self.target_clip = config["target"]
        if "params" in config:
            params = config["params"]
        else:
            params = {}

        tokenizer_params = {}

        # Pick the SD2.x (OpenCLIP) or SD1.x (CLIP) implementation based on the
        # cond_stage_config target from the model config.
        if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
            clip = sd2_clip.SD2ClipModel
            tokenizer = sd2_clip.SD2Tokenizer
        elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
            clip = sd1_clip.SD1ClipModel
            tokenizer = sd1_clip.SD1Tokenizer
        tokenizer_params['embedding_directory'] = embedding_directory

        self.cond_stage_model = clip(**(params))
        self.tokenizer = tokenizer(**(tokenizer_params))

    def encode(self, text):
        tokens = self.tokenizer.tokenize_with_weights(text)
        cond = self.cond_stage_model.encode_token_weights(tokens)
        return cond


class VAE:
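    """Wraps the AutoencoderKL, moving it on and off the device as images are
    converted to and from latent space."""
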
    def __init__(self, ckpt_path=None, scale_factor=0.18215, device="cuda", config=None):
        if config is None:
            # default SD1.x/SD2.x VAE parameters
            ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
            self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path)
        else:
            self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path)
        self.first_stage_model = self.first_stage_model.eval()
        self.scale_factor = scale_factor
        self.device = device

    def decode(self, samples):
        # Decode latents to pixels in [0, 1], keeping the VAE on the target
        # device only while it is needed.
        self.first_stage_model = self.first_stage_model.to(self.device)
        samples = samples.to(self.device)
        pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
        pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
        self.first_stage_model = self.first_stage_model.cpu()
        pixel_samples = pixel_samples.cpu().movedim(1, -1)
        return pixel_samples

    def encode(self, pixel_samples):
        # Encode pixels (values in [0, 1], channels last) into scaled latents.
        self.first_stage_model = self.first_stage_model.to(self.device)
        pixel_samples = pixel_samples.movedim(-1, 1).to(self.device)
        samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor
        self.first_stage_model = self.first_stage_model.cpu()
        samples = samples.cpu()
        return samples


def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
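    """Load a checkpoint plus its OmegaConf config and return (model, clip, vae).

    `clip` and `vae` are None when `output_clip` / `output_vae` are False.
    """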
    config = OmegaConf.load(config_path)
    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']
    scale_factor = model_config_params['scale_factor']
    vae_config = model_config_params['first_stage_config']

    clip = None
    vae = None

    # Dummy module whose submodules (first_stage_model, cond_stage_model) let the
    # checkpoint's state dict be loaded directly into the VAE and CLIP objects.
    class WeightsLoader(torch.nn.Module):
        pass

    w = WeightsLoader()
    load_state_dict_to = []
    if output_vae:
        vae = VAE(scale_factor=scale_factor, config=vae_config)
        w.first_stage_model = vae.first_stage_model
        load_state_dict_to = [w]

    if output_clip:
        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model
        load_state_dict_to = [w]

    model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to)
    return (model, clip, vae)
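
# A minimal usage sketch (the config and checkpoint file names below are only
# illustrative placeholders; point them at your own SD1.x/SD2.x files):
#
#   model, clip, vae = load_checkpoint("v1-inference.yaml", "model.ckpt")
#   cond = clip.encode("a photograph of an astronaut riding a horse")
#   # run a sampler with `model` and `cond` to get latents, then:
#   # images = vae.decode(latents)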