diff --git a/execution.py b/execution.py index 79c9a3ac..9d9ca5f6 100644 --- a/execution.py +++ b/execution.py @@ -10,6 +10,8 @@ import gc import torch import nodes +from model_management import xpu_available + def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} @@ -206,6 +208,8 @@ class PromptExecutor: if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda torch.cuda.empty_cache() torch.cuda.ipc_collect() + elif xpu_available: + torch.xpu.empty_cache() def validate_inputs(prompt, item):