very simple strong-cache on model list (#4969)

* very simple strong-cache on model list

* store the cache after validation too

* only cache object_info for now

* use a 'with' context
This commit is contained in:
Alex "mcmonkey" Goodwin 2024-09-19 17:40:14 +09:00 committed by GitHub
parent 0bfc7cc998
commit a1e71cfad1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 8 deletions

View File

@ -46,6 +46,36 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
class CacheHelper:
"""
Helper class for managing file list cache data.
"""
def __init__(self):
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
self.active = False
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
if not self.active:
return default
return self.cache.get(key, default)
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
if self.active:
self.cache[key] = value
def clear(self):
self.cache.clear()
def __enter__(self):
self.active = True
return self
def __exit__(self, exc_type, exc_value, traceback):
self.active = False
self.clear()
cache_helper = CacheHelper()
extension_mimetypes_cache = {
"webp" : "image",
}
@ -257,6 +287,10 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
return sorted(list(output_list)), output_folders, time.perf_counter()
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)
@ -285,6 +319,7 @@ def get_filename_list(folder_name: str) -> list[str]:
out = get_filename_list_(folder_name)
global filename_list_cache
filename_list_cache[folder_name] = out
cache_helper.set(folder_name, out)
return list(out[0])
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:

View File

@ -552,14 +552,15 @@ class PromptServer():
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
with folder_paths.cache_helper:
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
@routes.get("/object_info/{node_class}")
async def get_object_info_node(request):