Changed line endings to LF

This commit is contained in:
pythongosssss 2023-02-25 20:57:40 +00:00 committed by GitHub
parent 23507882d2
commit e053184a54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 201 additions and 201 deletions

402
server.py
View File

@ -1,202 +1,202 @@
import os import os
import sys import sys
import asyncio import asyncio
import nodes import nodes
import main import main
import uuid import uuid
import json import json
try: try:
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
except ImportError: except ImportError:
print("Module 'aiohttp' not installed. Please install it via:") print("Module 'aiohttp' not installed. Please install it via:")
print("pip install aiohttp") print("pip install aiohttp")
print("or") print("or")
print("pip install -r requirements.txt") print("pip install -r requirements.txt")
sys.exit() sys.exit()
class PromptServer(): class PromptServer():
def __init__(self, loop): def __init__(self, loop):
self.prompt_queue = None self.prompt_queue = None
self.loop = loop self.loop = loop
self.messages = asyncio.Queue() self.messages = asyncio.Queue()
self.number = 0 self.number = 0
self.app = web.Application() self.app = web.Application()
self.sockets = dict() self.sockets = dict()
self.web_root = os.path.join(os.path.dirname( self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "webshit") os.path.realpath(__file__)), "webshit")
routes = web.RouteTableDef() routes = web.RouteTableDef()
@routes.get('/ws') @routes.get('/ws')
async def websocket_handler(request): async def websocket_handler(request):
ws = web.WebSocketResponse() ws = web.WebSocketResponse()
await ws.prepare(request) await ws.prepare(request)
sid = uuid.uuid4().hex sid = uuid.uuid4().hex
self.sockets[sid] = ws self.sockets[sid] = ws
try: try:
# Send initial state to the new client # Send initial state to the new client
await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid) await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
async for msg in ws: async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR: if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception()) print('ws connection closed with exception %s' % ws.exception())
finally: finally:
self.sockets.pop(sid) self.sockets.pop(sid)
return ws return ws
@routes.get("/") @routes.get("/")
async def get_root(request): async def get_root(request):
return web.FileResponse(os.path.join(self.web_root, "index.html")) return web.FileResponse(os.path.join(self.web_root, "index.html"))
@routes.get("/view/{file}") @routes.get("/view/{file}")
async def view_image(request): async def view_image(request):
if "file" in request.match_info: if "file" in request.match_info:
output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
file = request.match_info["file"] file = request.match_info["file"]
file = os.path.splitext(os.path.basename(file))[0] + ".png" file = os.path.splitext(os.path.basename(file))[0] + ".png"
file = os.path.join(output_dir, file) file = os.path.join(output_dir, file)
if os.path.isfile(file): if os.path.isfile(file):
return web.FileResponse(file) return web.FileResponse(file)
return web.Response(status=404) return web.Response(status=404)
@routes.get("/prompt") @routes.get("/prompt")
async def get_prompt(request): async def get_prompt(request):
return web.json_response(self.get_queue_info()) return web.json_response(self.get_queue_info())
@routes.get("/object_info") @routes.get("/object_info")
async def get_object_info(request): async def get_object_info(request):
out = {} out = {}
for x in nodes.NODE_CLASS_MAPPINGS: for x in nodes.NODE_CLASS_MAPPINGS:
obj_class = nodes.NODE_CLASS_MAPPINGS[x] obj_class = nodes.NODE_CLASS_MAPPINGS[x]
info = {} info = {}
info['input'] = obj_class.INPUT_TYPES() info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES info['output'] = obj_class.RETURN_TYPES
info['name'] = x #TODO info['name'] = x #TODO
info['description'] = '' info['description'] = ''
info['category'] = 'sd' info['category'] = 'sd'
if hasattr(obj_class, 'CATEGORY'): if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY info['category'] = obj_class.CATEGORY
out[x] = info out[x] = info
return web.json_response(out) return web.json_response(out)
@routes.get("/history") @routes.get("/history")
async def get_history(request): async def get_history(request):
return web.json_response(self.prompt_queue.history) return web.json_response(self.prompt_queue.history)
@routes.get("/queue") @routes.get("/queue")
async def get_queue(request): async def get_queue(request):
queue_info = {} queue_info = {}
current_queue = self.prompt_queue.get_current_queue() current_queue = self.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0] queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1] queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info) return web.json_response(queue_info)
@routes.post("/prompt") @routes.post("/prompt")
async def post_prompt(request): async def post_prompt(request):
print("got prompt") print("got prompt")
resp_code = 200 resp_code = 200
out_string = "" out_string = ""
json_data = await request.json() json_data = await request.json()
if "number" in json_data: if "number" in json_data:
number = float(json_data['number']) number = float(json_data['number'])
else: else:
number = self.number number = self.number
if "front" in json_data: if "front" in json_data:
if json_data['front']: if json_data['front']:
number = -number number = -number
self.number += 1 self.number += 1
if "prompt" in json_data: if "prompt" in json_data:
prompt = json_data["prompt"] prompt = json_data["prompt"]
valid = main.validate_prompt(prompt) valid = main.validate_prompt(prompt)
extra_data = {} extra_data = {}
if "extra_data" in json_data: if "extra_data" in json_data:
extra_data = json_data["extra_data"] extra_data = json_data["extra_data"]
if "client_id" in json_data: if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"] extra_data["client_id"] = json_data["client_id"]
if valid[0]: if valid[0]:
self.prompt_queue.put((number, id(prompt), prompt, extra_data)) self.prompt_queue.put((number, id(prompt), prompt, extra_data))
else: else:
resp_code = 400 resp_code = 400
out_string = valid[1] out_string = valid[1]
print("invalid prompt:", valid[1]) print("invalid prompt:", valid[1])
return web.Response(body=out_string, status=resp_code) return web.Response(body=out_string, status=resp_code)
@routes.post("/queue") @routes.post("/queue")
async def post_queue(request): async def post_queue(request):
json_data = await request.json() json_data = await request.json()
if "clear" in json_data: if "clear" in json_data:
if json_data["clear"]: if json_data["clear"]:
self.prompt_queue.wipe_queue() self.prompt_queue.wipe_queue()
if "delete" in json_data: if "delete" in json_data:
to_delete = json_data['delete'] to_delete = json_data['delete']
for id_to_delete in to_delete: for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete) delete_func = lambda a: a[1] == int(id_to_delete)
self.prompt_queue.delete_queue_item(delete_func) self.prompt_queue.delete_queue_item(delete_func)
return web.Response(status=200) return web.Response(status=200)
@routes.post("/history") @routes.post("/history")
async def post_history(request): async def post_history(request):
json_data = await request.json() json_data = await request.json()
if "clear" in json_data: if "clear" in json_data:
if json_data["clear"]: if json_data["clear"]:
self.prompt_queue.history = {} self.prompt_queue.history = {}
if "delete" in json_data: if "delete" in json_data:
to_delete = json_data['delete'] to_delete = json_data['delete']
for id_to_delete in to_delete: for id_to_delete in to_delete:
self.prompt_queue.history.pop(id_to_delete, None) self.prompt_queue.history.pop(id_to_delete, None)
return web.Response(status=200) return web.Response(status=200)
self.app.add_routes(routes) self.app.add_routes(routes)
self.app.add_routes([ self.app.add_routes([
web.static('/', self.web_root), web.static('/', self.web_root),
]) ])
def get_queue_info(self): def get_queue_info(self):
prompt_info = {} prompt_info = {}
exec_info = {} exec_info = {}
exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining() exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info prompt_info['exec_info'] = exec_info
return prompt_info return prompt_info
async def send(self, event, data, sid=None): async def send(self, event, data, sid=None):
message = {"type": event, "data": data} message = {"type": event, "data": data}
if isinstance(message, str) == False: if isinstance(message, str) == False:
message = json.dumps(message) message = json.dumps(message)
if sid is None: if sid is None:
for ws in self.sockets.values(): for ws in self.sockets.values():
await ws.send_str(message) await ws.send_str(message)
elif sid in self.sockets: elif sid in self.sockets:
await self.sockets[sid].send_str(message) await self.sockets[sid].send_str(message)
def send_sync(self, event, data, sid=None): def send_sync(self, event, data, sid=None):
self.loop.call_soon_threadsafe( self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid)) self.messages.put_nowait, (event, data, sid))
def queue_updated(self): def queue_updated(self):
self.send_sync("status", { "status": self.get_queue_info() }) self.send_sync("status", { "status": self.get_queue_info() })
async def publish_loop(self): async def publish_loop(self):
while True: while True:
msg = await self.messages.get() msg = await self.messages.get()
await self.send(*msg) await self.send(*msg)
async def start(self, address, port): async def start(self, address, port):
runner = web.AppRunner(self.app) runner = web.AppRunner(self.app)
await runner.setup() await runner.setup()
site = web.TCPSite(runner, address, port) site = web.TCPSite(runner, address, port)
await site.start() await site.start()
if address == '': if address == '':
address = '0.0.0.0' address = '0.0.0.0'
print("Starting server\n") print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port)) print("To see the GUI go to: http://{}:{}".format(address, port))