Xformers is now properly disabled when --cpu is used.
Added the --windows-standalone-build option; currently it only makes ComfyUI open in the browser once the server starts.
parent 6d6758e9e4
commit 0f3ba7482f
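For quick reference, the two behaviors above map to these invocations (the flag names are taken from the diff below; the exact launch command may vary by install):

    python main.py --cpu                        # the xformers path is now skipped entirely, even if the package is installed
    python main.py --windows-standalone-build   # once the server is up, the web UI is opened in the default browser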
@@ -14,9 +14,8 @@ import model_management
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except:
    XFORMERS_IS_AVAILBLE = False
    pass

# CrossAttn precision handling
import os
@@ -481,7 +480,7 @@ class CrossAttentionPytorch(nn.Module):
        return self.to_out(out)

import sys
-if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
+if model_management.xformers_enabled() == False:
    if "--use-split-cross-attention" in sys.argv:
        print("Using split optimization for cross attention")
        CrossAttention = CrossAttentionDoggettx
@@ -12,10 +12,8 @@ import model_management
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except:
    XFORMERS_IS_AVAILBLE = False
    print("No module 'xformers'. Proceeding without it.")
    pass

try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
@@ -315,7 +313,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):

def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+    if model_management.xformers_enabled() and attn_type == "vanilla":
        attn_type = "vanilla-xformers"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
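A hedged usage sketch of the make_attn change above; the 512 channel count is only an illustration, everything else is taken from the hunk:

    attn = make_attn(512, attn_type="vanilla")
    # prints "making attention of type 'vanilla-xformers' with 512 in_channels"
    # when model_management.xformers_enabled() is True; with --cpu it now stays "vanilla"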
@@ -31,6 +31,16 @@ try:
except:
    pass

+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except:
+    XFORMERS_IS_AVAILBLE = False
+
+if "--disable-xformers" in sys.argv:
+    XFORMERS_IS_AVAILBLE = False
+
if "--cpu" in sys.argv:
    vram_state = CPU
if "--lowvram" in sys.argv:
@@ -159,6 +169,11 @@ def get_autocast_device(dev):
        return dev.type
    return "cuda"

+def xformers_enabled():
+    if vram_state == CPU:
+        return False
+    return XFORMERS_IS_AVAILBLE
+
def get_free_memory(dev=None, torch_free_too=False):
    if dev is None:
        dev = get_torch_device()
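A minimal sketch of the new gate, assuming it is run from the ComfyUI source tree so that model_management imports; the sys.argv manipulation only simulates what the launcher would pass:

    import sys
    sys.argv.append("--cpu")                      # simulate launching with --cpu

    import model_management                       # this version reads sys.argv at import time
    print(model_management.xformers_enabled())    # expected: False, even if xformers imports fine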
main.py (15 changed lines)
@@ -38,8 +38,8 @@ def prompt_worker(q, server):
        e.execute(item[-2], item[-1])
        q.task_done(item_id, e.outputs)

-async def run(server, address='', port=8188, verbose=True):
-    await asyncio.gather(server.start(address, port, verbose), server.publish_loop())
+async def run(server, address='', port=8188, verbose=True, call_on_start=None):
+    await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())

def hijack_progress(server):
    from tqdm.auto import tqdm
@@ -76,11 +76,18 @@ if __name__ == "__main__":
    except:
        pass

+    call_on_start = None
+    if "--windows-standalone-build" in sys.argv:
+        def startup_server(address, port):
+            import webbrowser
+            webbrowser.open("http://{}:{}".format(address, port))
+        call_on_start = startup_server
+
    if os.name == "nt":
        try:
-            loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print))
+            loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
        except KeyboardInterrupt:
            pass
    else:
-        loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print))
+        loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
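call_on_start is just a callable taking (address, port); startup_server above is the built-in case. A hypothetical alternative callback wired the same way (the announce name is mine, the run() signature comes from the diff):

    def announce(address, port):
        print("ComfyUI is listening on http://{}:{}".format(address, port))

    # loop.run_until_complete(run(server, address=address, port=port,
    #                             verbose=not dont_print, call_on_start=announce))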
@@ -260,7 +260,7 @@ class PromptServer():
            msg = await self.messages.get()
            await self.send(*msg)

-    async def start(self, address, port, verbose=True):
+    async def start(self, address, port, verbose=True, call_on_start=None):
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, address, port)
@@ -271,3 +271,6 @@ class PromptServer():
        if verbose:
            print("Starting server\n")
            print("To see the GUI go to: http://{}:{}".format(address, port))
+
+        if call_on_start is not None:
+            call_on_start(address, port)
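Ordering note: call_on_start is invoked at the end of PromptServer.start(), after the startup messages, so a callback that opens a browser (as --windows-standalone-build does) runs once the server address is known. A hedged sketch of invoking the new signature directly (the lambda is illustrative):

    # await server.start("127.0.0.1", 8188, verbose=True,
    #                    call_on_start=lambda addr, port: print("UI ready on", addr, port))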