Use .itemsize to get dtype size for fp8.
This commit is contained in:
parent
31b0f6f3d8
commit
ca82ade765
|
@ -430,6 +430,13 @@ def dtype_size(dtype):
|
||||||
dtype_size = 4
|
dtype_size = 4
|
||||||
if dtype == torch.float16 or dtype == torch.bfloat16:
|
if dtype == torch.float16 or dtype == torch.bfloat16:
|
||||||
dtype_size = 2
|
dtype_size = 2
|
||||||
|
elif dtype == torch.float32:
|
||||||
|
dtype_size = 4
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
dtype_size = dtype.itemsize
|
||||||
|
except: #Old pytorch doesn't have .itemsize
|
||||||
|
pass
|
||||||
return dtype_size
|
return dtype_size
|
||||||
|
|
||||||
def unet_offload_device():
|
def unet_offload_device():
|
||||||
|
|
Loading…
Reference in New Issue