Print the torch device that is used on startup.
This commit is contained in:
parent
b0505eb7ab
commit
3a1f47764d
|
@ -127,6 +127,32 @@ if args.cpu:
|
||||||
|
|
||||||
print(f"Set vram state to: {vram_state.name}")
|
print(f"Set vram state to: {vram_state.name}")
|
||||||
|
|
||||||
|
def get_torch_device():
|
||||||
|
global xpu_available
|
||||||
|
global directml_enabled
|
||||||
|
if directml_enabled:
|
||||||
|
global directml_device
|
||||||
|
return directml_device
|
||||||
|
if vram_state == VRAMState.MPS:
|
||||||
|
return torch.device("mps")
|
||||||
|
if vram_state == VRAMState.CPU:
|
||||||
|
return torch.device("cpu")
|
||||||
|
else:
|
||||||
|
if xpu_available:
|
||||||
|
return torch.device("xpu")
|
||||||
|
else:
|
||||||
|
return torch.cuda.current_device()
|
||||||
|
|
||||||
|
def get_torch_device_name(device):
|
||||||
|
if hasattr(device, 'type'):
|
||||||
|
return "{}".format(device.type)
|
||||||
|
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("Using device:", get_torch_device_name(get_torch_device()))
|
||||||
|
except:
|
||||||
|
print("Could not pick default device.")
|
||||||
|
|
||||||
|
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
current_gpu_controlnets = []
|
current_gpu_controlnets = []
|
||||||
|
@ -233,22 +259,6 @@ def unload_if_low_vram(model):
|
||||||
return model.cpu()
|
return model.cpu()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_torch_device():
|
|
||||||
global xpu_available
|
|
||||||
global directml_enabled
|
|
||||||
if directml_enabled:
|
|
||||||
global directml_device
|
|
||||||
return directml_device
|
|
||||||
if vram_state == VRAMState.MPS:
|
|
||||||
return torch.device("mps")
|
|
||||||
if vram_state == VRAMState.CPU:
|
|
||||||
return torch.device("cpu")
|
|
||||||
else:
|
|
||||||
if xpu_available:
|
|
||||||
return torch.device("xpu")
|
|
||||||
else:
|
|
||||||
return torch.cuda.current_device()
|
|
||||||
|
|
||||||
def get_autocast_device(dev):
|
def get_autocast_device(dev):
|
||||||
if hasattr(dev, 'type'):
|
if hasattr(dev, 'type'):
|
||||||
return dev.type
|
return dev.type
|
||||||
|
|
Loading…
Reference in New Issue