Fix bug.

parent 2d880fec3a
commit fcef47f06e
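The one-line message undersells the change: both hunks below reorder the scalar conditioning embeddings built for the SDXL refiner and base models. Each pair was previously appended as (width, height) and (crop_w, crop_h); the fix flips them to (height, width) and (crop_h, crop_w), with (target_height, target_width) likewise flipped in the base model, matching the height-first ordering of SDXL's size and crop conditioning.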
@@ -156,10 +156,10 @@ class SDXLRefiner(BaseModel):
 
         print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
         out = []
-        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([height])))
-        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([crop_h])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
         out.append(self.embedder(torch.Tensor([aesthetic_score])))
         flat = torch.flatten(torch.cat(out))[None, ]
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
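For reference, a runnable sketch of the refiner path after the fix. Only the append order and the tensor plumbing come from the diff above; the standalone wrapper, the stub embedder, and the 256-dim embedding size are illustrative assumptions.

import torch

def refiner_adm(clip_pooled, embedder, width, height, crop_w, crop_h, aesthetic_score):
    # Corrected order: height before width, crop_h before crop_w.
    out = []
    out.append(embedder(torch.Tensor([height])))
    out.append(embedder(torch.Tensor([width])))
    out.append(embedder(torch.Tensor([crop_h])))
    out.append(embedder(torch.Tensor([crop_w])))
    out.append(embedder(torch.Tensor([aesthetic_score])))
    # Concatenate the five embeddings into one row vector and prepend the
    # pooled CLIP embedding along the feature dimension.
    flat = torch.flatten(torch.cat(out))[None, ]
    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

# Stand-in embedder: broadcasts one scalar to a 256-dim vector (the real
# model uses a sinusoidal timestep-style embedding of similar width).
embedder = lambda t: t.repeat(256)
adm = refiner_adm(torch.zeros(1, 1280), embedder, 1024, 1024, 0, 0, 6.0)
print(adm.shape)  # torch.Size([1, 2560]) = 1280 pooled + 5 * 256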
@@ -180,11 +180,11 @@ class SDXL(BaseModel):
 
         print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
         out = []
-        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([height])))
-        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([crop_h])))
-        out.append(self.embedder(torch.Tensor([target_width])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
         out.append(self.embedder(torch.Tensor([target_height])))
+        out.append(self.embedder(torch.Tensor([target_width])))
         flat = torch.flatten(torch.cat(out))[None, ]
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
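The base-model hunk applies the same height-first convention to all three pairs, including the target size. A matching sketch, under the same assumptions as above:

import torch

def sdxl_adm(clip_pooled, embedder, width, height, crop_w, crop_h,
             target_width, target_height):
    # Corrected order: (height, width), (crop_h, crop_w), (target_height, target_width).
    out = []
    out.append(embedder(torch.Tensor([height])))
    out.append(embedder(torch.Tensor([width])))
    out.append(embedder(torch.Tensor([crop_h])))
    out.append(embedder(torch.Tensor([crop_w])))
    out.append(embedder(torch.Tensor([target_height])))
    out.append(embedder(torch.Tensor([target_width])))
    flat = torch.flatten(torch.cat(out))[None, ]
    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

embedder = lambda t: t.repeat(256)  # stand-in, assumed 256-dim
adm = sdxl_adm(torch.zeros(1, 1280), embedder, 1024, 1024, 0, 0, 1024, 1024)
print(adm.shape)  # torch.Size([1, 2816]) = 1280 pooled + 6 * 256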