diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index d8c65c2b..a7d164bf 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -6,6 +6,7 @@ import comfy.utils from nodes import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): + source = source.to(destination.device) if resize_source: source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") @@ -20,7 +21,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou if mask is None: mask = torch.ones_like(source) else: - mask = mask.clone() + mask = mask.to(destination.device, copy=True) mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])