Tweak memory inference calculations a bit.
This commit is contained in:
parent
96c2deeefb
commit
be71bb5e13
|
@ -164,12 +164,13 @@ class BaseModel(torch.nn.Module):
|
|||
self.inpaint_model = True
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
#TODO: this needs to be tweaked
|
||||
return (area / (comfy.model_management.dtype_size(self.get_dtype()) * 10)) * (1024 * 1024)
|
||||
area = max(input_shape[0], 3) * input_shape[2] * input_shape[3]
|
||||
return (area * comfy.model_management.dtype_size(self.get_dtype()) / 60) * (1024 * 1024)
|
||||
else:
|
||||
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue