Add a LoadImageMask node to load one colour channel in an image as a mask.
This commit is contained in:
parent
d75003001a
commit
e87a8669b6
48
nodes.py
48
nodes.py
|
@ -410,11 +410,8 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
|
|||
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = latent['noise_mask']
|
||||
print(noise_mask.shape, noise.shape)
|
||||
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
|
||||
noise_mask = noise_mask.floor()
|
||||
noise_mask = torch.ones_like(noise_mask) - noise_mask
|
||||
noise_mask = noise_mask.round()
|
||||
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
|
||||
noise_mask = torch.cat([noise_mask] * noise.shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
|
@ -581,10 +578,11 @@ class LoadImage:
|
|||
FUNCTION = "load_image"
|
||||
def load_image(self, image):
|
||||
image_path = os.path.join(self.input_dir, image)
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
i = Image.open(image_path)
|
||||
image = i.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image[None])[None,]
|
||||
return image
|
||||
image = torch.from_numpy(image)[None,]
|
||||
return (image,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image):
|
||||
|
@ -594,6 +592,41 @@ class LoadImage:
|
|||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
class LoadImageMask:
|
||||
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"image": (os.listdir(s.input_dir), ),
|
||||
"channel": (["alpha", "red", "green", "blue"], ),}
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image, channel):
|
||||
image_path = os.path.join(self.input_dir, image)
|
||||
i = Image.open(image_path)
|
||||
mask = None
|
||||
c = channel[0].upper()
|
||||
if c in i.getbands():
|
||||
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
|
||||
mask = torch.from_numpy(mask)
|
||||
if c == 'A':
|
||||
mask = 1. - mask
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
return (mask,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image, channel):
|
||||
image_path = os.path.join(s.input_dir, image)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
class ImageScale:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||
crop_methods = ["disabled", "center"]
|
||||
|
@ -626,6 +659,7 @@ NODE_CLASS_MAPPINGS = {
|
|||
"LatentUpscale": LatentUpscale,
|
||||
"SaveImage": SaveImage,
|
||||
"LoadImage": LoadImage,
|
||||
"LoadImageMask": LoadImageMask,
|
||||
"ImageScale": ImageScale,
|
||||
"ConditioningCombine": ConditioningCombine,
|
||||
"ConditioningSetArea": ConditioningSetArea,
|
||||
|
|
Loading…
Reference in New Issue