minor changes for tiled sampler
This commit is contained in:
parent
8ea165dd1e
commit
d9e088ddfd
|
@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
||||||
"""
|
"""
|
||||||
B, N, _ = metric.shape
|
B, N, _ = metric.shape
|
||||||
|
|
||||||
if r <= 0:
|
if r <= 0 or w == 1 or h == 1:
|
||||||
return do_nothing, do_nothing
|
return do_nothing, do_nothing
|
||||||
|
|
||||||
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
||||||
|
|
15
comfy/sd.py
15
comfy/sd.py
|
@ -581,10 +581,7 @@ class VAE:
|
||||||
samples = samples.cpu()
|
samples = samples.cpu()
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def resize_image_to(tensor, target_latent_tensor, batched_number):
|
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||||
tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center")
|
|
||||||
target_batch_size = target_latent_tensor.shape[0]
|
|
||||||
|
|
||||||
current_batch_size = tensor.shape[0]
|
current_batch_size = tensor.shape[0]
|
||||||
print(current_batch_size, target_batch_size)
|
print(current_batch_size, target_batch_size)
|
||||||
if current_batch_size == 1:
|
if current_batch_size == 1:
|
||||||
|
@ -623,7 +620,9 @@ class ControlNet:
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device)
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
|
||||||
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
|
||||||
if self.control_model.dtype == torch.float16:
|
if self.control_model.dtype == torch.float16:
|
||||||
precision_scope = torch.autocast
|
precision_scope = torch.autocast
|
||||||
|
@ -794,10 +793,14 @@ class T2IAdapter:
|
||||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
|
self.control_input = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device)
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
|
||||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
||||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
||||||
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
if self.control_input is None:
|
||||||
self.t2i_model.to(self.device)
|
self.t2i_model.to(self.device)
|
||||||
self.control_input = self.t2i_model(self.cond_hint)
|
self.control_input = self.t2i_model(self.cond_hint)
|
||||||
self.t2i_model.cpu()
|
self.t2i_model.cpu()
|
||||||
|
|
Loading…
Reference in New Issue