Add route to get safetensors metadata:

/view_metadata/loras?filename=lora.safetensors
This commit is contained in:
comfyanonymous 2023-05-29 02:48:50 -04:00
parent 23ffafeb5d
commit b9818eb910
3 changed files with 35 additions and 1 deletions

View File

@ -1,5 +1,6 @@
import torch import torch
import math import math
import struct
def load_torch_file(ckpt, safe_load=False): def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"): if ckpt.lower().endswith(".safetensors"):
@ -50,6 +51,14 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd return sd
def safetensors_header(safetensors_path, max_size=100*1024*1024):
with open(safetensors_path, "rb") as f:
header = f.read(8)
length_of_header = struct.unpack('<Q', header)[0]
if length_of_header > max_size:
return None
return f.read(length_of_header)
def bislerp(samples, width, height): def bislerp(samples, width, height):
def slerp(b1, b2, r): def slerp(b1, b2, r):
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''

View File

@ -126,11 +126,13 @@ def filter_files_extensions(files, extensions):
def get_full_path(folder_name, filename): def get_full_path(folder_name, filename):
global folder_names_and_paths global folder_names_and_paths
folders = folder_names_and_paths[folder_name] folders = folder_names_and_paths[folder_name]
filename = os.path.relpath(os.path.join("/", filename), "/")
for x in folders[0]: for x in folders[0]:
full_path = os.path.join(x, filename) full_path = os.path.join(x, filename)
if os.path.isfile(full_path): if os.path.isfile(full_path):
return full_path return full_path
return None
def get_filename_list(folder_name): def get_filename_list(folder_name):
global folder_names_and_paths global folder_names_and_paths

View File

@ -22,7 +22,7 @@ except ImportError:
import mimetypes import mimetypes
from comfy.cli_args import args from comfy.cli_args import args
import comfy.utils
@web.middleware @web.middleware
async def cache_control(request: web.Request, handler): async def cache_control(request: web.Request, handler):
@ -257,6 +257,29 @@ class PromptServer():
return web.Response(status=404) return web.Response(status=404)
@routes.get("/view_metadata/{folder_name}")
async def view_metadata(request):
folder_name = request.match_info.get("folder_name", None)
if folder_name is None:
return web.Response(status=404)
if not "filename" in request.rel_url.query:
return web.Response(status=404)
filename = request.rel_url.query["filename"]
if not filename.endswith(".safetensors"):
return web.Response(status=404)
safetensors_path = folder_paths.get_full_path(folder_name, filename)
if safetensors_path is None:
return web.Response(status=404)
out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
if out is None:
return web.Response(status=404)
dt = json.loads(out)
if not "__metadata__" in dt:
return web.Response(status=404)
return web.json_response(dt["__metadata__"])
@routes.get("/prompt") @routes.get("/prompt")
async def get_prompt(request): async def get_prompt(request):
return web.json_response(self.get_queue_info()) return web.json_response(self.get_queue_info())