SV3D support.
This commit is contained in:
parent
0b78213bda
commit
40e124c6be
|
@ -380,6 +380,36 @@ class SVD_img2vid(BaseModel):
|
|||
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
|
||||
return out
|
||||
|
||||
class SV3D_u(SVD_img2vid):
|
||||
def encode_adm(self, **kwargs):
|
||||
augmentation = kwargs.get("augmentation_level", 0)
|
||||
|
||||
out = []
|
||||
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
|
||||
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
|
||||
return flat
|
||||
|
||||
class SV3D_p(SVD_img2vid):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.embedder_512 = Timestep(512)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
augmentation = kwargs.get("augmentation_level", 0)
|
||||
elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here
|
||||
azimuth = kwargs.get("azimuth", 0)
|
||||
noise = kwargs.get("noise", None)
|
||||
|
||||
out = []
|
||||
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
|
||||
out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0))))
|
||||
out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0))))
|
||||
|
||||
out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out))
|
||||
return torch.cat(out, dim=1)
|
||||
|
||||
|
||||
class Stable_Zero123(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
|
|
@ -284,6 +284,41 @@ class SVD_img2vid(supported_models_base.BASE):
|
|||
def clip_target(self):
|
||||
return None
|
||||
|
||||
class SV3D_u(SVD_img2vid):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"in_channels": 8,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
"context_dim": 1024,
|
||||
"adm_in_channels": 256,
|
||||
"use_temporal_attention": True,
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
vae_key_prefix = ["conditioner.embedders.1.encoder."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SV3D_u(self, device=device)
|
||||
return out
|
||||
|
||||
class SV3D_p(SV3D_u):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"in_channels": 8,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
"context_dim": 1024,
|
||||
"adm_in_channels": 1280,
|
||||
"use_temporal_attention": True,
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SV3D_p(self, device=device)
|
||||
return out
|
||||
|
||||
class Stable_Zero123(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
|
@ -405,5 +440,5 @@ class Stable_Cascade_B(Stable_Cascade_C):
|
|||
return out
|
||||
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
|
||||
models += [SVD_img2vid]
|
||||
|
|
|
@ -29,8 +29,8 @@ class StableZero123_Conditioning:
|
|||
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
|
||||
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
@ -62,10 +62,10 @@ class StableZero123_Conditioning_Batched:
|
|||
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
|
||||
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
|
||||
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
|
||||
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
@ -95,8 +95,49 @@ class StableZero123_Conditioning_Batched:
|
|||
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
|
||||
return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
|
||||
|
||||
class SV3D_Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_vision": ("CLIP_VISION",),
|
||||
"init_image": ("IMAGE",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}),
|
||||
"elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/3d_models"
|
||||
|
||||
def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation):
|
||||
output = clip_vision.encode_image(init_image)
|
||||
pooled = output.image_embeds.unsqueeze(0)
|
||||
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
|
||||
encode_pixels = pixels[:,:,:,:3]
|
||||
t = vae.encode(encode_pixels)
|
||||
|
||||
azimuth = 0
|
||||
azimuth_increment = 360 / (max(video_frames, 2) - 1)
|
||||
|
||||
elevations = []
|
||||
azimuths = []
|
||||
for i in range(video_frames):
|
||||
elevations.append(elevation)
|
||||
azimuths.append(azimuth)
|
||||
azimuth += azimuth_increment
|
||||
|
||||
positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
|
||||
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
|
||||
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
|
||||
return (positive, negative, {"samples":latent})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StableZero123_Conditioning": StableZero123_Conditioning,
|
||||
"StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
|
||||
"SV3D_Conditioning": SV3D_Conditioning,
|
||||
}
|
||||
|
|
|
@ -79,6 +79,33 @@ class VideoLinearCFGGuidance:
|
|||
m.set_model_sampler_cfg_function(linear_cfg)
|
||||
return (m, )
|
||||
|
||||
class VideoTriangleCFGGuidance:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "sampling/video_models"
|
||||
|
||||
def patch(self, model, min_cfg):
|
||||
def linear_cfg(args):
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
period = 1.0
|
||||
values = torch.linspace(0, 1, cond.shape[0], device=cond.device)
|
||||
values = 2 * (values / period - torch.floor(values / period + 0.5)).abs()
|
||||
scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1))
|
||||
|
||||
return uncond + scale * (cond - uncond)
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_cfg_function(linear_cfg)
|
||||
return (m, )
|
||||
|
||||
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
|
@ -98,6 +125,7 @@ NODE_CLASS_MAPPINGS = {
|
|||
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
|
||||
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
|
||||
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
|
||||
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
|
||||
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue