Changed line endings to LF
This commit is contained in:
parent
23507882d2
commit
e053184a54
402
server.py
402
server.py
|
@ -1,202 +1,202 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import asyncio
|
import asyncio
|
||||||
import nodes
|
import nodes
|
||||||
import main
|
import main
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Module 'aiohttp' not installed. Please install it via:")
|
print("Module 'aiohttp' not installed. Please install it via:")
|
||||||
print("pip install aiohttp")
|
print("pip install aiohttp")
|
||||||
print("or")
|
print("or")
|
||||||
print("pip install -r requirements.txt")
|
print("pip install -r requirements.txt")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
class PromptServer():
|
class PromptServer():
|
||||||
def __init__(self, loop):
|
def __init__(self, loop):
|
||||||
self.prompt_queue = None
|
self.prompt_queue = None
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.messages = asyncio.Queue()
|
self.messages = asyncio.Queue()
|
||||||
self.number = 0
|
self.number = 0
|
||||||
self.app = web.Application()
|
self.app = web.Application()
|
||||||
self.sockets = dict()
|
self.sockets = dict()
|
||||||
self.web_root = os.path.join(os.path.dirname(
|
self.web_root = os.path.join(os.path.dirname(
|
||||||
os.path.realpath(__file__)), "webshit")
|
os.path.realpath(__file__)), "webshit")
|
||||||
routes = web.RouteTableDef()
|
routes = web.RouteTableDef()
|
||||||
|
|
||||||
@routes.get('/ws')
|
@routes.get('/ws')
|
||||||
async def websocket_handler(request):
|
async def websocket_handler(request):
|
||||||
ws = web.WebSocketResponse()
|
ws = web.WebSocketResponse()
|
||||||
await ws.prepare(request)
|
await ws.prepare(request)
|
||||||
sid = uuid.uuid4().hex
|
sid = uuid.uuid4().hex
|
||||||
self.sockets[sid] = ws
|
self.sockets[sid] = ws
|
||||||
try:
|
try:
|
||||||
# Send initial state to the new client
|
# Send initial state to the new client
|
||||||
await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
|
await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
|
||||||
async for msg in ws:
|
async for msg in ws:
|
||||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||||
print('ws connection closed with exception %s' % ws.exception())
|
print('ws connection closed with exception %s' % ws.exception())
|
||||||
finally:
|
finally:
|
||||||
self.sockets.pop(sid)
|
self.sockets.pop(sid)
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
@routes.get("/")
|
@routes.get("/")
|
||||||
async def get_root(request):
|
async def get_root(request):
|
||||||
return web.FileResponse(os.path.join(self.web_root, "index.html"))
|
return web.FileResponse(os.path.join(self.web_root, "index.html"))
|
||||||
|
|
||||||
@routes.get("/view/{file}")
|
@routes.get("/view/{file}")
|
||||||
async def view_image(request):
|
async def view_image(request):
|
||||||
if "file" in request.match_info:
|
if "file" in request.match_info:
|
||||||
output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
|
output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
|
||||||
file = request.match_info["file"]
|
file = request.match_info["file"]
|
||||||
file = os.path.splitext(os.path.basename(file))[0] + ".png"
|
file = os.path.splitext(os.path.basename(file))[0] + ".png"
|
||||||
file = os.path.join(output_dir, file)
|
file = os.path.join(output_dir, file)
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
return web.FileResponse(file)
|
return web.FileResponse(file)
|
||||||
|
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
@routes.get("/prompt")
|
@routes.get("/prompt")
|
||||||
async def get_prompt(request):
|
async def get_prompt(request):
|
||||||
return web.json_response(self.get_queue_info())
|
return web.json_response(self.get_queue_info())
|
||||||
|
|
||||||
@routes.get("/object_info")
|
@routes.get("/object_info")
|
||||||
async def get_object_info(request):
|
async def get_object_info(request):
|
||||||
out = {}
|
out = {}
|
||||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
|
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
|
||||||
info = {}
|
info = {}
|
||||||
info['input'] = obj_class.INPUT_TYPES()
|
info['input'] = obj_class.INPUT_TYPES()
|
||||||
info['output'] = obj_class.RETURN_TYPES
|
info['output'] = obj_class.RETURN_TYPES
|
||||||
info['name'] = x #TODO
|
info['name'] = x #TODO
|
||||||
info['description'] = ''
|
info['description'] = ''
|
||||||
info['category'] = 'sd'
|
info['category'] = 'sd'
|
||||||
if hasattr(obj_class, 'CATEGORY'):
|
if hasattr(obj_class, 'CATEGORY'):
|
||||||
info['category'] = obj_class.CATEGORY
|
info['category'] = obj_class.CATEGORY
|
||||||
out[x] = info
|
out[x] = info
|
||||||
return web.json_response(out)
|
return web.json_response(out)
|
||||||
|
|
||||||
@routes.get("/history")
|
@routes.get("/history")
|
||||||
async def get_history(request):
|
async def get_history(request):
|
||||||
return web.json_response(self.prompt_queue.history)
|
return web.json_response(self.prompt_queue.history)
|
||||||
|
|
||||||
@routes.get("/queue")
|
@routes.get("/queue")
|
||||||
async def get_queue(request):
|
async def get_queue(request):
|
||||||
queue_info = {}
|
queue_info = {}
|
||||||
current_queue = self.prompt_queue.get_current_queue()
|
current_queue = self.prompt_queue.get_current_queue()
|
||||||
queue_info['queue_running'] = current_queue[0]
|
queue_info['queue_running'] = current_queue[0]
|
||||||
queue_info['queue_pending'] = current_queue[1]
|
queue_info['queue_pending'] = current_queue[1]
|
||||||
return web.json_response(queue_info)
|
return web.json_response(queue_info)
|
||||||
|
|
||||||
@routes.post("/prompt")
|
@routes.post("/prompt")
|
||||||
async def post_prompt(request):
|
async def post_prompt(request):
|
||||||
print("got prompt")
|
print("got prompt")
|
||||||
resp_code = 200
|
resp_code = 200
|
||||||
out_string = ""
|
out_string = ""
|
||||||
json_data = await request.json()
|
json_data = await request.json()
|
||||||
|
|
||||||
if "number" in json_data:
|
if "number" in json_data:
|
||||||
number = float(json_data['number'])
|
number = float(json_data['number'])
|
||||||
else:
|
else:
|
||||||
number = self.number
|
number = self.number
|
||||||
if "front" in json_data:
|
if "front" in json_data:
|
||||||
if json_data['front']:
|
if json_data['front']:
|
||||||
number = -number
|
number = -number
|
||||||
|
|
||||||
self.number += 1
|
self.number += 1
|
||||||
|
|
||||||
if "prompt" in json_data:
|
if "prompt" in json_data:
|
||||||
prompt = json_data["prompt"]
|
prompt = json_data["prompt"]
|
||||||
valid = main.validate_prompt(prompt)
|
valid = main.validate_prompt(prompt)
|
||||||
extra_data = {}
|
extra_data = {}
|
||||||
if "extra_data" in json_data:
|
if "extra_data" in json_data:
|
||||||
extra_data = json_data["extra_data"]
|
extra_data = json_data["extra_data"]
|
||||||
|
|
||||||
if "client_id" in json_data:
|
if "client_id" in json_data:
|
||||||
extra_data["client_id"] = json_data["client_id"]
|
extra_data["client_id"] = json_data["client_id"]
|
||||||
if valid[0]:
|
if valid[0]:
|
||||||
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
|
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
|
||||||
else:
|
else:
|
||||||
resp_code = 400
|
resp_code = 400
|
||||||
out_string = valid[1]
|
out_string = valid[1]
|
||||||
print("invalid prompt:", valid[1])
|
print("invalid prompt:", valid[1])
|
||||||
|
|
||||||
return web.Response(body=out_string, status=resp_code)
|
return web.Response(body=out_string, status=resp_code)
|
||||||
|
|
||||||
@routes.post("/queue")
|
@routes.post("/queue")
|
||||||
async def post_queue(request):
|
async def post_queue(request):
|
||||||
json_data = await request.json()
|
json_data = await request.json()
|
||||||
if "clear" in json_data:
|
if "clear" in json_data:
|
||||||
if json_data["clear"]:
|
if json_data["clear"]:
|
||||||
self.prompt_queue.wipe_queue()
|
self.prompt_queue.wipe_queue()
|
||||||
if "delete" in json_data:
|
if "delete" in json_data:
|
||||||
to_delete = json_data['delete']
|
to_delete = json_data['delete']
|
||||||
for id_to_delete in to_delete:
|
for id_to_delete in to_delete:
|
||||||
delete_func = lambda a: a[1] == int(id_to_delete)
|
delete_func = lambda a: a[1] == int(id_to_delete)
|
||||||
self.prompt_queue.delete_queue_item(delete_func)
|
self.prompt_queue.delete_queue_item(delete_func)
|
||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
@routes.post("/history")
|
@routes.post("/history")
|
||||||
async def post_history(request):
|
async def post_history(request):
|
||||||
json_data = await request.json()
|
json_data = await request.json()
|
||||||
if "clear" in json_data:
|
if "clear" in json_data:
|
||||||
if json_data["clear"]:
|
if json_data["clear"]:
|
||||||
self.prompt_queue.history = {}
|
self.prompt_queue.history = {}
|
||||||
if "delete" in json_data:
|
if "delete" in json_data:
|
||||||
to_delete = json_data['delete']
|
to_delete = json_data['delete']
|
||||||
for id_to_delete in to_delete:
|
for id_to_delete in to_delete:
|
||||||
self.prompt_queue.history.pop(id_to_delete, None)
|
self.prompt_queue.history.pop(id_to_delete, None)
|
||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
self.app.add_routes(routes)
|
self.app.add_routes(routes)
|
||||||
self.app.add_routes([
|
self.app.add_routes([
|
||||||
web.static('/', self.web_root),
|
web.static('/', self.web_root),
|
||||||
])
|
])
|
||||||
|
|
||||||
def get_queue_info(self):
|
def get_queue_info(self):
|
||||||
prompt_info = {}
|
prompt_info = {}
|
||||||
exec_info = {}
|
exec_info = {}
|
||||||
exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
|
exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
|
||||||
prompt_info['exec_info'] = exec_info
|
prompt_info['exec_info'] = exec_info
|
||||||
return prompt_info
|
return prompt_info
|
||||||
|
|
||||||
async def send(self, event, data, sid=None):
|
async def send(self, event, data, sid=None):
|
||||||
message = {"type": event, "data": data}
|
message = {"type": event, "data": data}
|
||||||
|
|
||||||
if isinstance(message, str) == False:
|
if isinstance(message, str) == False:
|
||||||
message = json.dumps(message)
|
message = json.dumps(message)
|
||||||
|
|
||||||
if sid is None:
|
if sid is None:
|
||||||
for ws in self.sockets.values():
|
for ws in self.sockets.values():
|
||||||
await ws.send_str(message)
|
await ws.send_str(message)
|
||||||
elif sid in self.sockets:
|
elif sid in self.sockets:
|
||||||
await self.sockets[sid].send_str(message)
|
await self.sockets[sid].send_str(message)
|
||||||
|
|
||||||
def send_sync(self, event, data, sid=None):
|
def send_sync(self, event, data, sid=None):
|
||||||
self.loop.call_soon_threadsafe(
|
self.loop.call_soon_threadsafe(
|
||||||
self.messages.put_nowait, (event, data, sid))
|
self.messages.put_nowait, (event, data, sid))
|
||||||
|
|
||||||
def queue_updated(self):
|
def queue_updated(self):
|
||||||
self.send_sync("status", { "status": self.get_queue_info() })
|
self.send_sync("status", { "status": self.get_queue_info() })
|
||||||
|
|
||||||
async def publish_loop(self):
|
async def publish_loop(self):
|
||||||
while True:
|
while True:
|
||||||
msg = await self.messages.get()
|
msg = await self.messages.get()
|
||||||
await self.send(*msg)
|
await self.send(*msg)
|
||||||
|
|
||||||
async def start(self, address, port):
|
async def start(self, address, port):
|
||||||
runner = web.AppRunner(self.app)
|
runner = web.AppRunner(self.app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(runner, address, port)
|
site = web.TCPSite(runner, address, port)
|
||||||
await site.start()
|
await site.start()
|
||||||
|
|
||||||
if address == '':
|
if address == '':
|
||||||
address = '0.0.0.0'
|
address = '0.0.0.0'
|
||||||
print("Starting server\n")
|
print("Starting server\n")
|
||||||
print("To see the GUI go to: http://{}:{}".format(address, port))
|
print("To see the GUI go to: http://{}:{}".format(address, port))
|
Loading…
Reference in New Issue