Support stable zero 123 model.
To use it, load the checkpoint with the ImageOnlyCheckpointLoader and use the new Stable_Zero123 node.
This commit is contained in:
parent
2f9d6a97ec
commit
2258f85159
|
@ -328,3 +328,33 @@ class SVD_img2vid(BaseModel):
|
||||||
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
|
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
|
||||||
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
|
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class Stable_Zero123(BaseModel):
    """Stable Zero123 image-conditioned model.

    Wraps the base diffusion model and adds the checkpoint's ``cc_projection``
    linear layer, which maps the concatenated CLIP-vision + camera embedding
    back down to the 768-dim cross-attention context.
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
        super().__init__(model_config, model_type, device=device)
        # Build the projection with exactly the shape stored in the checkpoint
        # (weight is [out_features, in_features]) and load its tensors in place.
        self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
        self.cc_projection.weight.copy_(cc_projection_weight)
        self.cc_projection.bias.copy_(cc_projection_bias)

    def extra_conds(self, **kwargs):
        """Assemble the extra conditioning dict for sampling.

        Produces:
          - ``c_concat``: the concat latent image, resized/batch-matched to the noise.
          - ``c_crossattn``: the cross-attention embedding, run through
            ``cc_projection`` when it is not already 768-dim.
        """
        conds = {}

        noise = kwargs.get("noise", None)
        concat = kwargs.get("concat_latent_image", None)
        if concat is None:
            # No conditioning image supplied: fall back to an all-zero latent.
            concat = torch.zeros_like(noise)

        # Match spatial dimensions first, then the batch size of the noise.
        if concat.shape[1:] != noise.shape[1:]:
            concat = utils.common_upscale(concat, noise.shape[-1], noise.shape[-2], "bilinear", "center")
        concat = utils.resize_to_batch_size(concat, noise.shape[0])
        conds['c_concat'] = comfy.conds.CONDNoiseShape(concat)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            if cross_attn.shape[-1] != 768:
                # Wider-than-768 embeddings carry the extra camera dims —
                # project them down with the checkpoint's linear layer.
                cross_attn = self.cc_projection(cross_attn)
            conds['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)

        return conds
|
|
@ -47,7 +47,8 @@ def convert_cond(cond):
|
||||||
temp = c[1].copy()
|
temp = c[1].copy()
|
||||||
model_conds = temp.get("model_conds", {})
|
model_conds = temp.get("model_conds", {})
|
||||||
if c[0] is not None:
|
if c[0] is not None:
|
||||||
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0])
|
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
|
||||||
|
temp["cross_attn"] = c[0]
|
||||||
temp["model_conds"] = model_conds
|
temp["model_conds"] = model_conds
|
||||||
out.append(temp)
|
out.append(temp)
|
||||||
return out
|
return out
|
||||||
|
|
|
@ -252,5 +252,32 @@ class SVD_img2vid(supported_models_base.BASE):
|
||||||
def clip_target(self):
|
def clip_target(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega]
|
class Stable_Zero123(supported_models_base.BASE):
    """Supported-model config for Stable Zero123 checkpoints."""

    # UNet hyperparameters that identify a Stable Zero123 checkpoint.
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    # Extra UNet settings not present in the checkpoint itself.
    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    # Prefix of the CLIP-vision weights inside the checkpoint state dict.
    clip_vision_prefix = "cond_stage_model.model.visual."

    latent_format = latent_formats.SD15

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the model, forwarding the checkpoint's cc_projection tensors."""
        return model_base.Stable_Zero123(
            self,
            device=device,
            cc_projection_weight=state_dict["cc_projection.weight"],
            cc_projection_bias=state_dict["cc_projection.bias"],
        )

    def clip_target(self):
        # No text encoder: conditioning comes from CLIP vision + camera pose.
        return None
|
|
||||||
|
# Registry of supported model configs.
# NOTE(review): Stable_Zero123 is listed first — presumably so its more
# specific unet_config is matched before SD15's; confirm detection order.
models = [
    Stable_Zero123,
    SD15,
    SD20,
    SD21UnclipL,
    SD21UnclipH,
    SDXLRefiner,
    SDXL,
    SSD1B,
    Segmind_Vega,
]
models.append(SVD_img2vid)
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
import torch
|
||||||
|
import nodes
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
def camera_embeddings(elevation, azimuth):
    """Return the Zero123 camera-pose embedding for one view.

    Args:
        elevation: elevation angle in degrees.
        azimuth: azimuth angle in degrees.

    Returns:
        A tensor of shape (1, 1, 4) holding
        [polar offset (rad), sin(azimuth), cos(azimuth), reference polar (rad)].
    """
    elev = torch.as_tensor([elevation])
    azim = torch.as_tensor([azimuth])

    azim_rad = torch.deg2rad(azim)
    # Zero123 uses polar = 90 - elevation; the embedding stores the offset
    # from the reference view's polar angle (also 90 - elevation_ref, with
    # elevation_ref = 0 here). Expressions kept in their original form.
    polar_offset = torch.deg2rad((90 - elev) - (90))
    reference_polar = torch.deg2rad(90 - torch.full_like(elev, 0))

    components = [
        polar_offset,
        torch.sin(azim_rad),
        torch.cos(azim_rad),
        reference_polar,
    ]
    return torch.stack(components, dim=-1).unsqueeze(1)
||||||
|
|
||||||
|
|
||||||
|
class Zero123_Conditioning:
    """ComfyUI node that builds Stable Zero123 conditioning from an input image
    and a camera pose (elevation/azimuth), returning positive/negative
    conditioning plus an empty latent of the requested size."""

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "clip_vision": ("CLIP_VISION",),
            "init_image": ("IMAGE",),
            "vae": ("VAE",),
            "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
            "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
            "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
            "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
        }
        return {"required": required}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    FUNCTION = "encode"

    CATEGORY = "conditioning/3d_models"

    def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
        """Encode the image with CLIP-vision and the VAE, append the camera
        embedding to the pooled CLIP output, and return the conditioning pair."""
        clip_out = clip_vision.encode_image(init_image)
        pooled = clip_out.image_embeds.unsqueeze(0)

        # Resize to the target resolution (channels-first for upscale, back after)
        # and drop any alpha channel before VAE encoding.
        resized = comfy.utils.common_upscale(init_image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        rgb = resized[:, :, :, :3]
        concat_latent = vae.encode(rgb)

        # Concatenate the 4-dim camera embedding onto the pooled CLIP embedding.
        cam = camera_embeddings(elevation, azimuth)
        cond = torch.cat([pooled, cam.repeat((pooled.shape[0], 1, 1))], dim=-1)

        positive = [[cond, {"concat_latent_image": concat_latent}]]
        # Negative: zeroed embedding and zeroed concat latent of the same shapes.
        negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(concat_latent)}]]
        empty_latent = torch.zeros([batch_size, 4, height // 8, width // 8])
        return (positive, negative, {"samples": empty_latent})
|
||||||
|
# Registry consumed by ComfyUI to expose this file's nodes.
NODE_CLASS_MAPPINGS = {"Zero123_Conditioning": Zero123_Conditioning}
Loading…
Reference in New Issue