Xformers is now properly disabled when --cpu used.
Added --windows-standalone-build option, currently it only opens makes the code open up comfyui in the browser.
This commit is contained in:
parent
6d6758e9e4
commit
0f3ba7482f
|
@ -14,9 +14,8 @@ import model_management
|
||||||
try:
|
try:
|
||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
XFORMERS_IS_AVAILBLE = True
|
|
||||||
except:
|
except:
|
||||||
XFORMERS_IS_AVAILBLE = False
|
pass
|
||||||
|
|
||||||
# CrossAttn precision handling
|
# CrossAttn precision handling
|
||||||
import os
|
import os
|
||||||
|
@ -481,7 +480,7 @@ class CrossAttentionPytorch(nn.Module):
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
|
if model_management.xformers_enabled() == False:
|
||||||
if "--use-split-cross-attention" in sys.argv:
|
if "--use-split-cross-attention" in sys.argv:
|
||||||
print("Using split optimization for cross attention")
|
print("Using split optimization for cross attention")
|
||||||
CrossAttention = CrossAttentionDoggettx
|
CrossAttention = CrossAttentionDoggettx
|
||||||
|
|
|
@ -12,10 +12,8 @@ import model_management
|
||||||
try:
|
try:
|
||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
XFORMERS_IS_AVAILBLE = True
|
|
||||||
except:
|
except:
|
||||||
XFORMERS_IS_AVAILBLE = False
|
pass
|
||||||
print("No module 'xformers'. Proceeding without it.")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
||||||
|
@ -315,7 +313,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
||||||
|
|
||||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||||
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||||
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
|
if model_management.xformers_enabled() and attn_type == "vanilla":
|
||||||
attn_type = "vanilla-xformers"
|
attn_type = "vanilla-xformers"
|
||||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||||
if attn_type == "vanilla":
|
if attn_type == "vanilla":
|
||||||
|
|
|
@ -31,6 +31,16 @@ try:
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
import xformers
|
||||||
|
import xformers.ops
|
||||||
|
XFORMERS_IS_AVAILBLE = True
|
||||||
|
except:
|
||||||
|
XFORMERS_IS_AVAILBLE = False
|
||||||
|
|
||||||
|
if "--disable-xformers" in sys.argv:
|
||||||
|
XFORMERS_IS_AVAILBLE = False
|
||||||
|
|
||||||
if "--cpu" in sys.argv:
|
if "--cpu" in sys.argv:
|
||||||
vram_state = CPU
|
vram_state = CPU
|
||||||
if "--lowvram" in sys.argv:
|
if "--lowvram" in sys.argv:
|
||||||
|
@ -159,6 +169,11 @@ def get_autocast_device(dev):
|
||||||
return dev.type
|
return dev.type
|
||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
|
def xformers_enabled():
|
||||||
|
if vram_state == CPU:
|
||||||
|
return False
|
||||||
|
return XFORMERS_IS_AVAILBLE
|
||||||
|
|
||||||
def get_free_memory(dev=None, torch_free_too=False):
|
def get_free_memory(dev=None, torch_free_too=False):
|
||||||
if dev is None:
|
if dev is None:
|
||||||
dev = get_torch_device()
|
dev = get_torch_device()
|
||||||
|
|
15
main.py
15
main.py
|
@ -38,8 +38,8 @@ def prompt_worker(q, server):
|
||||||
e.execute(item[-2], item[-1])
|
e.execute(item[-2], item[-1])
|
||||||
q.task_done(item_id, e.outputs)
|
q.task_done(item_id, e.outputs)
|
||||||
|
|
||||||
async def run(server, address='', port=8188, verbose=True):
|
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
await asyncio.gather(server.start(address, port, verbose), server.publish_loop())
|
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||||
|
|
||||||
def hijack_progress(server):
|
def hijack_progress(server):
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
@ -76,11 +76,18 @@ if __name__ == "__main__":
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
call_on_start = None
|
||||||
|
if "--windows-standalone-build" in sys.argv:
|
||||||
|
def startup_server(address, port):
|
||||||
|
import webbrowser
|
||||||
|
webbrowser.open("http://{}:{}".format(address, port))
|
||||||
|
call_on_start = startup_server
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print))
|
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print))
|
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
|
||||||
|
|
||||||
|
|
|
@ -260,7 +260,7 @@ class PromptServer():
|
||||||
msg = await self.messages.get()
|
msg = await self.messages.get()
|
||||||
await self.send(*msg)
|
await self.send(*msg)
|
||||||
|
|
||||||
async def start(self, address, port, verbose=True):
|
async def start(self, address, port, verbose=True, call_on_start=None):
|
||||||
runner = web.AppRunner(self.app)
|
runner = web.AppRunner(self.app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(runner, address, port)
|
site = web.TCPSite(runner, address, port)
|
||||||
|
@ -271,3 +271,6 @@ class PromptServer():
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Starting server\n")
|
print("Starting server\n")
|
||||||
print("To see the GUI go to: http://{}:{}".format(address, port))
|
print("To see the GUI go to: http://{}:{}".format(address, port))
|
||||||
|
if call_on_start is not None:
|
||||||
|
call_on_start(address, port)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue