diff --git a/comfy/hooks.py b/comfy/hooks.py index 3fc8c833..86ade0c4 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -19,6 +19,7 @@ class EnumHookMode(enum.Enum): class EnumHookType(enum.Enum): Weight = "weight" Patch = "patch" + AddModel = "addmodel" class EnumWeightTarget(enum.Enum): Model = "model" @@ -121,10 +122,22 @@ class PatchHook(Hook): def clone(self, subtype: Callable=None): if subtype is None: subtype = type(self) - c: PatchHook = super().clone(type(self)) + c: PatchHook = super().clone(subtype) c.patches = self.patches return c +class AddModelHook(Hook): + def __init__(self, model: 'ModelPatcher'): + super().__init__(hook_type=EnumHookType.AddModel) + self.model = model + + def clone(self, subtype: Callable=None): + if subtype is None: + subtype = type(self) + c: AddModelHook = super().clone(subtype) + c.model = self.model + return c + class HookGroup: def __init__(self): self.hooks: List[Hook] = [] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 640da31f..7563d8cd 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -108,6 +108,8 @@ class CallbacksMP: ON_PREPARE_STATE = "on_prepare_state" ON_APPLY_HOOKS = "on_apply_hooks" ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches" + ON_INJECT_MODEL = "on_inject_model" + ON_EJECT_MODEL = "on_eject_model" @classmethod def init_callbacks(cls): @@ -119,8 +121,37 @@ class CallbacksMP: cls.ON_PREPARE_STATE: [], cls.ON_APPLY_HOOKS: [], cls.ON_REGISTER_ALL_HOOK_PATCHES: [], + cls.ON_INJECT_MODEL: [], + cls.ON_EJECT_MODEL: [], } +class AutoPatcherEjector: + def __init__(self, model: 'ModelPatcher', skip_until_exit=False): + self.model = model + self.was_injected = False + self.prev_skip_injection = False + self.skip_until_exit = skip_until_exit + + def __enter__(self): + self.was_injected = False + self.prev_skip_injection = self.model.skip_injection + if self.skip_until_exit: + self.model.skip_injection = True + if self.model.is_injected: + self.model.eject_model() + self.was_injected = True + + def __exit__(self, *args): + if self.was_injected: + if self.skip_until_exit or not self.model.skip_injection: + self.model.inject_model() + self.model.skip_injection = self.prev_skip_injection + +class PatcherInjection: + def __init__(self, inject: Callable, eject: Callable): + self.inject = inject + self.eject = eject + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -143,9 +174,13 @@ class ModelPatcher: self.patches_uuid = uuid.uuid4() self.attachments: Dict[str] = {} - self.additional_models: list[ModelPatcher] = [] + self.additional_models: Dict[str, List[ModelPatcher]] = {} self.callbacks: Dict[str, List[Callable]] = CallbacksMP.init_callbacks() + self.is_injected = False + self.skip_injection = False + self.injections: Dict[str, List[PatcherInjection]] = {} + self.hook_patches: Dict[comfy.hooks._HookRef] = {} self.hook_patches_backup: Dict[comfy.hooks._HookRef] = {} self.hook_backup: Dict[str, Tuple[torch.Tensor, torch.device]] = {} @@ -196,11 +231,16 @@ class ModelPatcher: else: n.attachments[k] = self.attachments[k] # additional models - for m in self.additional_models: - n.additional_models.append(m.clone()) + for k, c in self.additional_models.items(): + n.additional_models[k] = [x.clone() for x in c] # callbacks for k, c in self.callbacks.items(): n.callbacks[k] = c.copy() + # injection + n.is_injected = self.is_injected + n.skip_injection = self.skip_injection + for k, i in self.injections.items(): + n.injections[k] = i.copy() # hooks n.hook_patches = create_hook_patches_clone(self.hook_patches) n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) @@ -342,27 +382,28 @@ class ModelPatcher: return self.model.get_dtype() def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): - p = set() - model_sd = self.model.state_dict() - for k in patches: - offset = None - function = None - if isinstance(k, str): - key = k - else: - offset = k[1] - key = k[0] - if len(k) > 2: - function = k[2] + with self.use_ejected(): + p = set() + model_sd = self.model.state_dict() + for k in patches: + offset = None + function = None + if isinstance(k, str): + key = k + else: + offset = k[1] + key = k[0] + if len(k) > 2: + function = k[2] - if key in model_sd: - p.add(k) - current_patches = self.patches.get(key, []) - current_patches.append((strength_patch, patches[k], strength_model, offset, function)) - self.patches[key] = current_patches + if key in model_sd: + p.add(k) + current_patches = self.patches.get(key, []) + current_patches.append((strength_patch, patches[k], strength_model, offset, function)) + self.patches[key] = current_patches - self.patches_uuid = uuid.uuid4() - return list(p) + self.patches_uuid = uuid.uuid4() + return list(p) def get_key_patches(self, filter_prefix=None): model_sd = self.model_state_dict() @@ -386,13 +427,14 @@ class ModelPatcher: return p def model_state_dict(self, filter_prefix=None): - sd = self.model.state_dict() - keys = list(sd.keys()) - if filter_prefix is not None: - for k in keys: - if not k.startswith(filter_prefix): - sd.pop(k) - return sd + with self.use_ejected(): + sd = self.model.state_dict() + keys = list(sd.keys()) + if filter_prefix is not None: + for k in keys: + if not k.startswith(filter_prefix): + sd.pop(k) + return sd def patch_weight_to_device(self, key, device_to=None, inplace_update=False): if key not in self.patches: @@ -417,108 +459,116 @@ class ModelPatcher: comfy.utils.set_attr_param(self.model, key, out_weight) def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): - self.unpatch_hooks() - mem_counter = 0 - patch_counter = 0 - lowvram_counter = 0 - loading = [] - for n, m in self.model.named_modules(): - if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"): - loading.append((comfy.model_management.module_size(m), n, m)) + with self.use_ejected(): + self.unpatch_hooks() + mem_counter = 0 + patch_counter = 0 + lowvram_counter = 0 + loading = [] + for n, m in self.model.named_modules(): + if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"): + loading.append((comfy.model_management.module_size(m), n, m)) - load_completely = [] - loading.sort(reverse=True) - for x in loading: - n = x[1] - m = x[2] - module_mem = x[0] + load_completely = [] + loading.sort(reverse=True) + for x in loading: + n = x[1] + m = x[2] + module_mem = x[0] - lowvram_weight = False + lowvram_weight = False - if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: - lowvram_weight = True - lowvram_counter += 1 - if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed + if not full_load and hasattr(m, "comfy_cast_weights"): + if mem_counter + module_mem >= lowvram_model_memory: + lowvram_weight = True + lowvram_counter += 1 + if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed + continue + + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + + if lowvram_weight: + if weight_key in self.patches: + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + m.weight_function = LowVramPatch(weight_key, self.patches) + patch_counter += 1 + if bias_key in self.patches: + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + m.bias_function = LowVramPatch(bias_key, self.patches) + patch_counter += 1 + + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + else: + if hasattr(m, "comfy_cast_weights"): + if m.comfy_cast_weights: + wipe_lowvram_weight(m) + + if hasattr(m, "weight"): + mem_counter += module_mem + load_completely.append((module_mem, n, m)) + + load_completely.sort(reverse=True) + for x in load_completely: + n = x[1] + m = x[2] + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if hasattr(m, "comfy_patched_weights"): + if m.comfy_patched_weights == True: continue - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) + self.patch_weight_to_device(weight_key, device_to=device_to) + self.patch_weight_to_device(bias_key, device_to=device_to) + logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) + m.comfy_patched_weights = True - if lowvram_weight: - if weight_key in self.patches: - if force_patch_weights: - self.patch_weight_to_device(weight_key) - else: - m.weight_function = LowVramPatch(weight_key, self.patches) - patch_counter += 1 - if bias_key in self.patches: - if force_patch_weights: - self.patch_weight_to_device(bias_key) - else: - m.bias_function = LowVramPatch(bias_key, self.patches) - patch_counter += 1 + for x in load_completely: + x[2].to(device_to) - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True + if lowvram_counter > 0: + logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) + self.model.model_lowvram = True else: - if hasattr(m, "comfy_cast_weights"): - if m.comfy_cast_weights: - wipe_lowvram_weight(m) + logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + self.model.model_lowvram = False + if full_load: + self.model.to(device_to) + mem_counter = self.model_size() - if hasattr(m, "weight"): - mem_counter += module_mem - load_completely.append((module_mem, n, m)) + self.model.lowvram_patch_counter += patch_counter + self.model.device = device_to + self.model.model_loaded_weight_memory = mem_counter - load_completely.sort(reverse=True) - for x in load_completely: - n = x[1] - m = x[2] - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) - if hasattr(m, "comfy_patched_weights"): - if m.comfy_patched_weights == True: - continue + for callback in self.callbacks[CallbacksMP.ON_LOAD]: + callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) - self.patch_weight_to_device(weight_key, device_to=device_to) - self.patch_weight_to_device(bias_key, device_to=device_to) - logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) - m.comfy_patched_weights = True - - for x in load_completely: - x[2].to(device_to) - - if lowvram_counter > 0: - logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) - self.model.model_lowvram = True - else: - logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) - self.model.model_lowvram = False - if full_load: - self.model.to(device_to) - mem_counter = self.model_size() - - self.model.lowvram_patch_counter += patch_counter - self.model.device = device_to - self.model.model_loaded_weight_memory = mem_counter - self.apply_hooks(self.forced_hooks) + self.apply_hooks(self.forced_hooks) def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): - for k in self.object_patches: - old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) - if k not in self.object_patches_backup: - self.object_patches_backup[k] = old + with self.use_ejected(): + for k in self.object_patches: + old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) + if k not in self.object_patches_backup: + self.object_patches_backup[k] = old - if lowvram_model_memory == 0: - full_load = True - else: - full_load = False + if lowvram_model_memory == 0: + full_load = True + else: + full_load = False - if load_weights: - self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) + if load_weights: + self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) + self.inject_model() return self.model def unpatch_model(self, device_to=None, unpatch_weights=True): + self.eject_model() if unpatch_weights: self.unpatch_hooks() if self.model.model_lowvram: @@ -555,66 +605,68 @@ class ModelPatcher: self.object_patches_backup.clear() def partially_unload(self, device_to, memory_to_free=0): - memory_freed = 0 - patch_counter = 0 - unload_list = [] + with self.use_ejected(): + memory_freed = 0 + patch_counter = 0 + unload_list = [] - for n, m in self.model.named_modules(): - shift_lowvram = False - if hasattr(m, "comfy_cast_weights"): - module_mem = comfy.model_management.module_size(m) - unload_list.append((module_mem, n, m)) + for n, m in self.model.named_modules(): + shift_lowvram = False + if hasattr(m, "comfy_cast_weights"): + module_mem = comfy.model_management.module_size(m) + unload_list.append((module_mem, n, m)) - unload_list.sort() - for unload in unload_list: - if memory_to_free < memory_freed: - break - module_mem = unload[0] - n = unload[1] - m = unload[2] - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) + unload_list.sort() + for unload in unload_list: + if memory_to_free < memory_freed: + break + module_mem = unload[0] + n = unload[1] + m = unload[2] + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) - if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: - for key in [weight_key, bias_key]: - bk = self.backup.get(key, None) - if bk is not None: - if bk.inplace_update: - comfy.utils.copy_to_param(self.model, key, bk.weight) - else: - comfy.utils.set_attr_param(self.model, key, bk.weight) - self.backup.pop(key) + if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: + for key in [weight_key, bias_key]: + bk = self.backup.get(key, None) + if bk is not None: + if bk.inplace_update: + comfy.utils.copy_to_param(self.model, key, bk.weight) + else: + comfy.utils.set_attr_param(self.model, key, bk.weight) + self.backup.pop(key) - m.to(device_to) - if weight_key in self.patches: - m.weight_function = LowVramPatch(weight_key, self.patches) - patch_counter += 1 - if bias_key in self.patches: - m.bias_function = LowVramPatch(bias_key, self.patches) - patch_counter += 1 + m.to(device_to) + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self.patches) + patch_counter += 1 + if bias_key in self.patches: + m.bias_function = LowVramPatch(bias_key, self.patches) + patch_counter += 1 - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - m.comfy_patched_weights = False - memory_freed += module_mem - logging.debug("freed {}".format(n)) + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + m.comfy_patched_weights = False + memory_freed += module_mem + logging.debug("freed {}".format(n)) - self.model.model_lowvram = True - self.model.lowvram_patch_counter += patch_counter - self.model.model_loaded_weight_memory -= memory_freed - return memory_freed + self.model.model_lowvram = True + self.model.lowvram_patch_counter += patch_counter + self.model.model_loaded_weight_memory -= memory_freed + return memory_freed def partially_load(self, device_to, extra_memory=0): - self.unpatch_model(unpatch_weights=False) - self.patch_model(load_weights=False) - full_load = False - if self.model.model_lowvram == False: - return 0 - if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): - full_load = True - current_used = self.model.model_loaded_weight_memory - self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load) - return self.model.model_loaded_weight_memory - current_used + with self.use_ejected(skip_injection=True): + self.unpatch_model(unpatch_weights=False) + self.patch_model(load_weights=False) + full_load = False + if self.model.model_lowvram == False: + return 0 + if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): + full_load = True + current_used = self.model.model_loaded_weight_memory + self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load) + return self.model.model_loaded_weight_memory - current_used def current_loaded_device(self): return self.model.device @@ -629,13 +681,49 @@ class ModelPatcher: for callback in self.callbacks[CallbacksMP.ON_CLEANUP]: callback(self) - def add_callback(self, key, callback: Callable): + def get_all_additional_models(self): + all_models = [] + for models in self.additional_models.values(): + all_models.extend(models) + return all_models + + def add_callback(self, key: str, callback: Callable): if key not in self.callbacks: raise Exception(f"Callback '{key}' is not recognized.") self.callbacks[key].append(callback) - def add_attachment(self, attachment): - self.attachments.append(attachment) + def set_attachments(self, key: str, attachment): + self.attachments[key] = attachment + + def set_injections(self, key: str, injections: List[PatcherInjection]): + self.injections[key] = injections + + def set_additional_models(self, key: str, models: List['ModelPatcher']): + self.additional_models[key] = models + + def use_ejected(self, skip_injection=False): + return AutoPatcherEjector(self, skip_until_exit=skip_injection) + + def inject_model(self): + if self.is_injected or self.skip_injection: + return + for injections in self.injections.values(): + for inj in injections: + inj.inject(self) + self.is_injected = True + if self.is_injected: + for callback in self.callbacks[CallbacksMP.ON_INJECT_MODEL]: + callback(self) + + def eject_model(self): + if not self.is_injected: + return + for injections in self.injections.values(): + for inj in injections: + inj.eject(self) + self.is_injected = False + for callback in self.callbacks[CallbacksMP.ON_EJECT_MODEL]: + callback(self) def pre_run(self): for callback in self.callbacks[CallbacksMP.ON_PRE_RUN]: @@ -685,58 +773,60 @@ class ModelPatcher: callback(self, hooks_dict, target) def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0, is_diff=False): - # NOTE: this mirrors behavior of add_patches func - if is_diff: - comfy.model_management.unload_model_clones(self) - current_hook_patches: Dict[str,List] = self.hook_patches.get(hook.hook_ref, {}) - p = set() - model_sd = self.model.state_dict() - for k in patches: - offset = None - function = None - if isinstance(k, str): - key = k - else: - offset = k[1] - key = k[0] - if len(k) > 2: - function = k[2] - - if key in model_sd: - p.add(k) - current_patches: List[Tuple] = current_hook_patches.get(key, []) - if is_diff: - # take difference between desired weight and existing weight to get diff - # TODO: try to implement diff via strength_path/strength_model diff - model_dtype = comfy.utils.get_attr(self.model, key).dtype - if model_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]: - diff_weight = (patches[k].to(torch.float32)-comfy.utils.get_attr(self.model, key).to(torch.float32)).to(model_dtype) - else: - diff_weight = patches[k]-comfy.utils.get_attr(self.model, key) - current_patches.append((strength_patch, (diff_weight,), strength_model, offset, function)) + with self.use_ejected(): + # NOTE: this mirrors behavior of add_patches func + if is_diff: + comfy.model_management.unload_model_clones(self) + current_hook_patches: Dict[str,List] = self.hook_patches.get(hook.hook_ref, {}) + p = set() + model_sd = self.model.state_dict() + for k in patches: + offset = None + function = None + if isinstance(k, str): + key = k else: - current_patches.append((strength_patch, patches[k], strength_model, offset, function)) - current_hook_patches[key] = current_patches - self.hook_patches[hook.hook_ref] = current_hook_patches - # since should care about these patches too to determine if same model, reroll patches_uuid - self.patches_uuid = uuid.uuid4() - return list(p) + offset = k[1] + key = k[0] + if len(k) > 2: + function = k[2] + + if key in model_sd: + p.add(k) + current_patches: List[Tuple] = current_hook_patches.get(key, []) + if is_diff: + # take difference between desired weight and existing weight to get diff + # TODO: try to implement diff via strength_path/strength_model diff + model_dtype = comfy.utils.get_attr(self.model, key).dtype + if model_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]: + diff_weight = (patches[k].to(torch.float32)-comfy.utils.get_attr(self.model, key).to(torch.float32)).to(model_dtype) + else: + diff_weight = patches[k]-comfy.utils.get_attr(self.model, key) + current_patches.append((strength_patch, (diff_weight,), strength_model, offset, function)) + else: + current_patches.append((strength_patch, patches[k], strength_model, offset, function)) + current_hook_patches[key] = current_patches + self.hook_patches[hook.hook_ref] = current_hook_patches + # since should care about these patches too to determine if same model, reroll patches_uuid + self.patches_uuid = uuid.uuid4() + return list(p) def get_weight_diffs(self, patches): - comfy.model_management.unload_model_clones(self) - weights: Dict[str, Tuple] = {} - p = set() - model_sd = self.model.state_dict() - for k in patches: - if k in model_sd: - p.add(k) - model_dtype = comfy.utils.get_attr(self.model, k).dtype - if model_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]: - diff_weight = (patches[k].to(torch.float32)-comfy.utils.get_attr(self.model, k).to(torch.float32)).to(model_dtype) - else: - diff_weight = patches[k]-comfy.utils.get_attr(self.model, k) - weights[k] = (diff_weight,) - return weights, p + with self.use_ejected(): + comfy.model_management.unload_model_clones(self) + weights: Dict[str, Tuple] = {} + p = set() + model_sd = self.model.state_dict() + for k in patches: + if k in model_sd: + p.add(k) + model_dtype = comfy.utils.get_attr(self.model, k).dtype + if model_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]: + diff_weight = (patches[k].to(torch.float32)-comfy.utils.get_attr(self.model, k).to(torch.float32)).to(model_dtype) + else: + diff_weight = patches[k]-comfy.utils.get_attr(self.model, k) + weights[k] = (diff_weight,) + return weights, p def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup): # combined_patches will contain weights of all relevant hooks, per key @@ -765,27 +855,28 @@ class ModelPatcher: callback(self, hooks) def patch_hooks(self, hooks: comfy.hooks.HookGroup): - self.unpatch_hooks() - model_sd = self.model_state_dict() - # if have cached weights for hooks, use it - cached_weights = self.cached_hook_patches.get(hooks, None) - if cached_weights is not None: - for key in cached_weights: - if key not in model_sd: - print(f"WARNING cached hook could not patch. key does not exist in model: {key}") - continue - self.patch_cached_hook_weights(cached_weights=cached_weights, key=key) - else: - relevant_patches = self.get_combined_hook_patches(hooks=hooks) - original_weights = None - if len(relevant_patches) > 0: - original_weights = self.get_key_patches() - for key in relevant_patches: - if key not in model_sd: - print(f"WARNING cached hook would not patch. key does not exist in model: {key}") - continue - self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights) - self.current_hooks = hooks + with self.use_ejected(): + self.unpatch_hooks() + model_sd = self.model_state_dict() + # if have cached weights for hooks, use it + cached_weights = self.cached_hook_patches.get(hooks, None) + if cached_weights is not None: + for key in cached_weights: + if key not in model_sd: + print(f"WARNING cached hook could not patch. key does not exist in model: {key}") + continue + self.patch_cached_hook_weights(cached_weights=cached_weights, key=key) + else: + relevant_patches = self.get_combined_hook_patches(hooks=hooks) + original_weights = None + if len(relevant_patches) > 0: + original_weights = self.get_key_patches() + for key in relevant_patches: + if key not in model_sd: + print(f"WARNING cached hook would not patch. key does not exist in model: {key}") + continue + self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights) + self.current_hooks = hooks def patch_cached_hook_weights(self, cached_weights: Dict, key: str): if key not in self.hook_backup: @@ -825,25 +916,26 @@ class ModelPatcher: comfy.utils.set_attr_param(self.model, key, out_weight) def unpatch_hooks(self) -> None: - if len(self.hook_backup) == 0: + with self.use_ejected(): + if len(self.hook_backup) == 0: + self.current_hooks = None + return + keys = list(self.hook_backup.keys()) + if self.weight_inplace_update: + for k in keys: + if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # does not need to be cast; device already matches + comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0]) + else: + comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) + else: + for k in keys: + if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: + comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0]) + else: + comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) + + self.hook_backup.clear() self.current_hooks = None - return - keys = list(self.hook_backup.keys()) - if self.weight_inplace_update: - for k in keys: - if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # does not need to be cast; device already matches - comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0]) - else: - comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) - else: - for k in keys: - if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: - comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0]) - else: - comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) - - self.hook_backup.clear() - self.current_hooks = None def clean_hooks(self): self.unpatch_hooks() diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 49214847..af89a715 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -2,6 +2,11 @@ import torch import comfy.model_management import comfy.conds import comfy.hooks +from typing import TYPE_CHECKING, Dict, List +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.model_base import BaseModel + from comfy.controlnet import ControlBase def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" @@ -15,9 +20,22 @@ def get_models_from_cond(cond, model_type): models = [] for c in cond: if model_type in c: - models += [c[model_type]] + if isinstance(c[model_type], list): + models += c[model_type] + else: + models += [c[model_type]] return models +def get_hooks_from_cond(cond, filter_types: List[comfy.hooks.EnumHookType]=None): + hooks: Dict[comfy.hooks.Hook, None] = {} + for c in cond: + if 'hooks' in c: + for hook in c['hooks'].hooks: + hook: comfy.hooks.Hook + if not filter_types or hook.hook_type in filter_types: + hooks[hook] = None + return hooks + def convert_cond(cond): out = [] for c in cond: @@ -32,12 +50,16 @@ def convert_cond(cond): def get_additional_models(conds, dtype): """loads additional models in conditioning""" - cnets = [] + cnets: List[ControlBase] = [] gligen = [] + add_models = [] + hooks: Dict[comfy.hooks.AddModelHook, None] = {} for k in conds: cnets += get_models_from_cond(conds[k], "control") gligen += get_models_from_cond(conds[k], "gligen") + add_models += get_models_from_cond(conds[k], "additional_models") + hooks.update(get_hooks_from_cond(conds[k], [comfy.hooks.EnumHookType.AddModel])) control_nets = set(cnets) @@ -48,7 +70,9 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - models = control_models + gligen + hook_models = [x.model for x in hooks] + models = control_models + gligen + add_models + hook_models + return models, inference_memory def cleanup_additional_models(models): @@ -58,10 +82,11 @@ def cleanup_additional_models(models): m.cleanup() -def prepare_sampling(model, noise_shape, conds): +def prepare_sampling(model: 'ModelPatcher', noise_shape, conds): device = model.load_device - real_model = None + real_model: 'BaseModel' = None models, inference_memory = get_additional_models(conds, model.model_dtype()) + models += model.get_all_additional_models() # TODO: does this require inference_memory update? memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required) @@ -79,12 +104,9 @@ def cleanup_models(conds, models): cleanup_additional_models(set(control_cleanup)) -def prepare_model_patcher(model, conds): +def prepare_model_patcher(model: 'ModelPatcher', conds): # check for hooks in conds - if not registered, see if can be applied hooks = {} for k in conds: - for cond in conds[k]: - if 'hooks' in cond: - for hook in cond['hooks'].hooks: - hooks[hook] = None + hooks.update(get_hooks_from_cond(conds[k])) model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model)