CLIPVisionEncode can now encode multiple images.

This commit is contained in:
comfyanonymous 2023-08-14 16:54:05 -04:00
parent 0cb6dac943
commit 9cc12c833d
3 changed files with 12 additions and 12 deletions

View File

@ -24,8 +24,8 @@ class ClipVisionModel():
return self.model.load_state_dict(sd, strict=False)
def encode_image(self, image):
img = torch.clip((255. * image[0]), 0, 255).round().int()
inputs = self.processor(images=[img], return_tensors="pt")
img = torch.clip((255. * image), 0, 255).round().int()
inputs = self.processor(images=img, return_tensors="pt")
outputs = self.model(**inputs)
return outputs

View File

@ -120,15 +120,15 @@ class SD21UNCLIP(BaseModel):
weights = []
noise_aug = []
for unclip_cond in unclip_conditioning:
adm_cond = unclip_cond["clip_vision_output"].image_embeds
weight = unclip_cond["strength"]
noise_augment = unclip_cond["noise_augmentation"]
noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
weight = unclip_cond["strength"]
noise_augment = unclip_cond["noise_augmentation"]
noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)

View File

@ -771,7 +771,7 @@ class StyleModelApply:
CATEGORY = "conditioning/style_model"
def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
cond = style_model.get_cond(clip_vision_output)
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
c = []
for t in conditioning:
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]