Print the torch device that is used on startup.

This commit is contained in:
comfyanonymous 2023-05-13 17:11:27 -04:00
parent b0505eb7ab
commit 3a1f47764d
1 changed files with 26 additions and 16 deletions

View File

@ -127,6 +127,32 @@ if args.cpu:
print(f"Set vram state to: {vram_state.name}") print(f"Set vram state to: {vram_state.name}")
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_torch_device_name(device):
if hasattr(device, 'type'):
return "{}".format(device.type)
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try:
print("Using device:", get_torch_device_name(get_torch_device()))
except:
print("Could not pick default device.")
current_loaded_model = None current_loaded_model = None
current_gpu_controlnets = [] current_gpu_controlnets = []
@ -233,22 +259,6 @@ def unload_if_low_vram(model):
return model.cpu() return model.cpu()
return model return model
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_autocast_device(dev): def get_autocast_device(dev):
if hasattr(dev, 'type'): if hasattr(dev, 'type'):
return dev.type return dev.type