Xformers is now properly disabled when --cpu used.

Added --windows-standalone-build option, currently it only opens
makes the code open up comfyui in the browser.
This commit is contained in:
comfyanonymous 2023-03-12 15:44:16 -04:00
parent 6d6758e9e4
commit 0f3ba7482f
5 changed files with 34 additions and 12 deletions

View File

@ -14,9 +14,8 @@ import model_management
try: try:
import xformers import xformers
import xformers.ops import xformers.ops
XFORMERS_IS_AVAILBLE = True
except: except:
XFORMERS_IS_AVAILBLE = False pass
# CrossAttn precision handling # CrossAttn precision handling
import os import os
@ -481,7 +480,7 @@ class CrossAttentionPytorch(nn.Module):
return self.to_out(out) return self.to_out(out)
import sys import sys
if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv: if model_management.xformers_enabled() == False:
if "--use-split-cross-attention" in sys.argv: if "--use-split-cross-attention" in sys.argv:
print("Using split optimization for cross attention") print("Using split optimization for cross attention")
CrossAttention = CrossAttentionDoggettx CrossAttention = CrossAttentionDoggettx

View File

@ -12,10 +12,8 @@ import model_management
try: try:
import xformers import xformers
import xformers.ops import xformers.ops
XFORMERS_IS_AVAILBLE = True
except: except:
XFORMERS_IS_AVAILBLE = False pass
print("No module 'xformers'. Proceeding without it.")
try: try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError OOM_EXCEPTION = torch.cuda.OutOfMemoryError
@ -315,7 +313,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": if model_management.xformers_enabled() and attn_type == "vanilla":
attn_type = "vanilla-xformers" attn_type = "vanilla-xformers"
print(f"making attention of type '{attn_type}' with {in_channels} in_channels") print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla": if attn_type == "vanilla":

View File

@ -31,6 +31,16 @@ try:
except: except:
pass pass
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
if "--disable-xformers" in sys.argv:
XFORMERS_IS_AVAILBLE = False
if "--cpu" in sys.argv: if "--cpu" in sys.argv:
vram_state = CPU vram_state = CPU
if "--lowvram" in sys.argv: if "--lowvram" in sys.argv:
@ -159,6 +169,11 @@ def get_autocast_device(dev):
return dev.type return dev.type
return "cuda" return "cuda"
def xformers_enabled():
if vram_state == CPU:
return False
return XFORMERS_IS_AVAILBLE
def get_free_memory(dev=None, torch_free_too=False): def get_free_memory(dev=None, torch_free_too=False):
if dev is None: if dev is None:
dev = get_torch_device() dev = get_torch_device()

15
main.py
View File

@ -38,8 +38,8 @@ def prompt_worker(q, server):
e.execute(item[-2], item[-1]) e.execute(item[-2], item[-1])
q.task_done(item_id, e.outputs) q.task_done(item_id, e.outputs)
async def run(server, address='', port=8188, verbose=True): async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose), server.publish_loop()) await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server): def hijack_progress(server):
from tqdm.auto import tqdm from tqdm.auto import tqdm
@ -76,11 +76,18 @@ if __name__ == "__main__":
except: except:
pass pass
call_on_start = None
if "--windows-standalone-build" in sys.argv:
def startup_server(address, port):
import webbrowser
webbrowser.open("http://{}:{}".format(address, port))
call_on_start = startup_server
if os.name == "nt": if os.name == "nt":
try: try:
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print)) loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
else: else:
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print)) loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))

View File

@ -260,7 +260,7 @@ class PromptServer():
msg = await self.messages.get() msg = await self.messages.get()
await self.send(*msg) await self.send(*msg)
async def start(self, address, port, verbose=True): async def start(self, address, port, verbose=True, call_on_start=None):
runner = web.AppRunner(self.app) runner = web.AppRunner(self.app)
await runner.setup() await runner.setup()
site = web.TCPSite(runner, address, port) site = web.TCPSite(runner, address, port)
@ -271,3 +271,6 @@ class PromptServer():
if verbose: if verbose:
print("Starting server\n") print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port)) print("To see the GUI go to: http://{}:{}".format(address, port))
if call_on_start is not None:
call_on_start(address, port)