Use .itemsize to get dtype size for fp8.

This commit is contained in:
comfyanonymous 2023-12-04 11:52:06 -05:00
parent 31b0f6f3d8
commit ca82ade765
1 changed files with 7 additions and 0 deletions

View File

@ -430,6 +430,13 @@ def dtype_size(dtype):
dtype_size = 4 dtype_size = 4
if dtype == torch.float16 or dtype == torch.bfloat16: if dtype == torch.float16 or dtype == torch.bfloat16:
dtype_size = 2 dtype_size = 2
elif dtype == torch.float32:
dtype_size = 4
else:
try:
dtype_size = dtype.itemsize
except: #Old pytorch doesn't have .itemsize
pass
return dtype_size return dtype_size
def unet_offload_device(): def unet_offload_device():