Use .itemsize to get dtype size for fp8.
This commit is contained in:
parent
31b0f6f3d8
commit
ca82ade765
|
@ -430,6 +430,13 @@ def dtype_size(dtype):
|
|||
dtype_size = 4
|
||||
if dtype == torch.float16 or dtype == torch.bfloat16:
|
||||
dtype_size = 2
|
||||
elif dtype == torch.float32:
|
||||
dtype_size = 4
|
||||
else:
|
||||
try:
|
||||
dtype_size = dtype.itemsize
|
||||
except: #Old pytorch doesn't have .itemsize
|
||||
pass
|
||||
return dtype_size
|
||||
|
||||
def unet_offload_device():
|
||||
|
|
Loading…
Reference in New Issue