Switch some more prints to logging.

This commit is contained in:
comfyanonymous 2024-03-11 16:24:47 -04:00
parent 0ed72befe1
commit 2a813c3b09
10 changed files with 40 additions and 34 deletions

View File

@ -4,6 +4,7 @@ import torch.nn.functional as F
from torch import nn, einsum from torch import nn, einsum
from einops import rearrange, repeat from einops import rearrange, repeat
from typing import Optional, Any from typing import Optional, Any
import logging
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention from .sub_quadratic_attention import efficient_dot_product_attention
@ -20,7 +21,7 @@ ops = comfy.ops.disable_weight_init
# CrossAttn precision handling # CrossAttn precision handling
if args.dont_upcast_attention: if args.dont_upcast_attention:
print("disabling upcasting of attention") logging.info("disabling upcasting of attention")
_ATTN_PRECISION = "fp16" _ATTN_PRECISION = "fp16"
else: else:
_ATTN_PRECISION = "fp32" _ATTN_PRECISION = "fp32"
@ -274,12 +275,12 @@ def attention_split(q, k, v, heads, mask=None):
model_management.soft_empty_cache(True) model_management.soft_empty_cache(True)
if cleared_cache == False: if cleared_cache == False:
cleared_cache = True cleared_cache = True
print("out of memory error, emptying cache and trying again") logging.warning("out of memory error, emptying cache and trying again")
continue continue
steps *= 2 steps *= 2
if steps > 64: if steps > 64:
raise e raise e
print("out of memory error, increasing steps and trying again", steps) logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
else: else:
raise e raise e
@ -351,17 +352,17 @@ def attention_pytorch(q, k, v, heads, mask=None):
optimized_attention = attention_basic optimized_attention = attention_basic
if model_management.xformers_enabled(): if model_management.xformers_enabled():
print("Using xformers cross attention") logging.info("Using xformers cross attention")
optimized_attention = attention_xformers optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled(): elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention") logging.info("Using pytorch cross attention")
optimized_attention = attention_pytorch optimized_attention = attention_pytorch
else: else:
if args.use_split_cross_attention: if args.use_split_cross_attention:
print("Using split optimization for cross attention") logging.info("Using split optimization for cross attention")
optimized_attention = attention_split optimized_attention = attention_split
else: else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad optimized_attention = attention_sub_quad
optimized_attention_masked = optimized_attention optimized_attention_masked = optimized_attention

View File

@ -5,6 +5,7 @@ import torch.nn as nn
import numpy as np import numpy as np
from einops import rearrange from einops import rearrange
from typing import Optional, Any from typing import Optional, Any
import logging
from comfy import model_management from comfy import model_management
import comfy.ops import comfy.ops
@ -190,7 +191,7 @@ def slice_attention(q, k, v):
steps *= 2 steps *= 2
if steps > 128: if steps > 128:
raise e raise e
print("out of memory error, increasing steps and trying again", steps) logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
return r1 return r1
@ -235,7 +236,7 @@ def pytorch_attention(q, k, v):
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W) out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention") logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out return out
@ -268,13 +269,13 @@ class AttnBlock(nn.Module):
padding=0) padding=0)
if model_management.xformers_enabled_vae(): if model_management.xformers_enabled_vae():
print("Using xformers attention in VAE") logging.info("Using xformers attention in VAE")
self.optimized_attention = xformers_attention self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled(): elif model_management.pytorch_attention_enabled():
print("Using pytorch attention in VAE") logging.info("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention self.optimized_attention = pytorch_attention
else: else:
print("Using split attention in VAE") logging.info("Using split attention in VAE")
self.optimized_attention = normal_attention self.optimized_attention = normal_attention
def forward(self, x): def forward(self, x):
@ -562,7 +563,7 @@ class Decoder(nn.Module):
block_in = ch*ch_mult[self.num_resolutions-1] block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1) curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res) self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format( logging.debug("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape))) self.z_shape, np.prod(self.z_shape)))
# z to block_in # z to block_in

View File

@ -4,6 +4,7 @@ import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
import logging
from .util import ( from .util import (
checkpoint, checkpoint,
@ -359,7 +360,7 @@ def apply_control(h, control, name):
try: try:
h += ctrl h += ctrl
except: except:
print("warning control could not be applied", h.shape, ctrl.shape) logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
return h return h
class UNetModel(nn.Module): class UNetModel(nn.Module):
@ -496,7 +497,7 @@ class UNetModel(nn.Module):
if isinstance(self.num_classes, int): if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device) self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
elif self.num_classes == "continuous": elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer") logging.debug("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim) self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential": elif self.num_classes == "sequential":
assert adm_in_channels is not None assert adm_in_channels is not None

View File

@ -14,6 +14,7 @@ import torch
from torch import Tensor from torch import Tensor
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
import math import math
import logging
try: try:
from typing import Optional, NamedTuple, List, Protocol from typing import Optional, NamedTuple, List, Protocol
@ -170,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
attn_probs = attn_scores.softmax(dim=-1) attn_probs = attn_scores.softmax(dim=-1)
del attn_scores del attn_scores
except model_management.OOM_EXCEPTION: except model_management.OOM_EXCEPTION:
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
torch.exp(attn_scores, out=attn_scores) torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True) summed = torch.sum(attn_scores, dim=-1, keepdim=True)

View File

@ -4,6 +4,7 @@ import torch
import collections import collections
from comfy import model_management from comfy import model_management
import math import math
import logging
def get_area_and_mult(conds, x_in, timestep_in): def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
@ -625,7 +626,7 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
elif scheduler_name == "sgm_uniform": elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model, steps, sgm=True) sigmas = normal_scheduler(model, steps, sgm=True)
else: else:
print("error invalid scheduler", scheduler_name) logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas return sigmas
def sampler_object(name): def sampler_object(name):

View File

@ -1,7 +1,7 @@
#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License) #code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
import torch import torch
import logging
def Fourier_filter(x, threshold, scale): def Fourier_filter(x, threshold, scale):
# FFT # FFT
@ -49,7 +49,7 @@ class FreeU:
try: try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except: except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
on_cpu_devices[hsp.device] = True on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else: else:
@ -95,7 +95,7 @@ class FreeU_V2:
try: try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except: except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
on_cpu_devices[hsp.device] = True on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else: else:

View File

@ -1,6 +1,7 @@
import comfy.utils import comfy.utils
import folder_paths import folder_paths
import torch import torch
import logging
def load_hypernetwork_patch(path, strength): def load_hypernetwork_patch(path, strength):
sd = comfy.utils.load_torch_file(path, safe_load=True) sd = comfy.utils.load_torch_file(path, safe_load=True)
@ -23,7 +24,7 @@ def load_hypernetwork_patch(path, strength):
} }
if activation_func not in valid_activation: if activation_func not in valid_activation:
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
return None return None
out = {} out = {}

18
main.py
View File

@ -54,15 +54,15 @@ import threading
import gc import gc
from comfy.cli_args import args from comfy.cli_args import args
import logging
if os.name == "nt": if os.name == "nt":
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
if __name__ == "__main__": if __name__ == "__main__":
if args.cuda_device is not None: if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device) logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.deterministic: if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
@ -88,7 +88,7 @@ def cuda_malloc_warning():
if b in device_name: if b in device_name:
cuda_malloc_warning = True cuda_malloc_warning = True
if cuda_malloc_warning: if cuda_malloc_warning:
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server): def prompt_worker(q, server):
e = execution.PromptExecutor(server) e = execution.PromptExecutor(server)
@ -121,7 +121,7 @@ def prompt_worker(q, server):
current_time = time.perf_counter() current_time = time.perf_counter()
execution_time = current_time - execution_start_time execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time)) logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags() flags = q.get_flags()
free_memory = flags.get("free_memory", False) free_memory = flags.get("free_memory", False)
@ -182,14 +182,14 @@ def load_extra_path_config(yaml_path):
full_path = y full_path = y
if base_path is not None: if base_path is not None:
full_path = os.path.join(base_path, full_path) full_path = os.path.join(base_path, full_path)
print("Adding extra search path", x, full_path) logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path) folder_paths.add_model_folder_path(x, full_path)
if __name__ == "__main__": if __name__ == "__main__":
if args.temp_directory: if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
print(f"Setting temp directory to: {temp_dir}") logging.info(f"Setting temp directory to: {temp_dir}")
folder_paths.set_temp_directory(temp_dir) folder_paths.set_temp_directory(temp_dir)
cleanup_temp() cleanup_temp()
@ -224,7 +224,7 @@ if __name__ == "__main__":
if args.output_directory: if args.output_directory:
output_dir = os.path.abspath(args.output_directory) output_dir = os.path.abspath(args.output_directory)
print(f"Setting output directory to: {output_dir}") logging.info(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir) folder_paths.set_output_directory(output_dir)
#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
@ -234,7 +234,7 @@ if __name__ == "__main__":
if args.input_directory: if args.input_directory:
input_dir = os.path.abspath(args.input_directory) input_dir = os.path.abspath(args.input_directory)
print(f"Setting input directory to: {input_dir}") logging.info(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir) folder_paths.set_input_directory(input_dir)
if args.quick_test_for_ci: if args.quick_test_for_ci:
@ -252,6 +252,6 @@ if __name__ == "__main__":
try: try:
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nStopped server") logging.info("\nStopped server")
cleanup_temp() cleanup_temp()

View File

@ -1904,7 +1904,7 @@ def load_custom_node(module_path, ignore=set()):
return False return False
except Exception as e: except Exception as e:
logging.warning(traceback.format_exc()) logging.warning(traceback.format_exc())
logging.warning(f"Cannot import {module_path} module for custom nodes:", e) logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
return False return False
def load_custom_nodes(): def load_custom_nodes():

View File

@ -413,8 +413,8 @@ class PromptServer():
try: try:
out[x] = node_info(x) out[x] = node_info(x)
except Exception as e: except Exception as e:
print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", file=sys.stderr) logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
traceback.print_exc() logging.error(traceback.format_exc())
return web.json_response(out) return web.json_response(out)
@routes.get("/object_info/{node_class}") @routes.get("/object_info/{node_class}")
@ -641,6 +641,6 @@ class PromptServer():
json_data = handler(json_data) json_data = handler(json_data)
except Exception as e: except Exception as e:
logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing") logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
traceback.print_exc() logging.warning(traceback.format_exc())
return json_data return json_data