diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 692952f3..a6d40e89 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -14,9 +14,8 @@ import model_management try: import xformers import xformers.ops - XFORMERS_IS_AVAILBLE = True except: - XFORMERS_IS_AVAILBLE = False + pass # CrossAttn precision handling import os @@ -481,7 +480,7 @@ class CrossAttentionPytorch(nn.Module): return self.to_out(out) import sys -if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv: +if model_management.xformers_enabled() == False: if "--use-split-cross-attention" in sys.argv: print("Using split optimization for cross attention") CrossAttention = CrossAttentionDoggettx diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 18f7a8b0..15f35b91 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -12,10 +12,8 @@ import model_management try: import xformers import xformers.ops - XFORMERS_IS_AVAILBLE = True except: - XFORMERS_IS_AVAILBLE = False - print("No module 'xformers'. Proceeding without it.") + pass try: OOM_EXCEPTION = torch.cuda.OutOfMemoryError @@ -315,7 +313,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' - if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + if model_management.xformers_enabled() and attn_type == "vanilla": attn_type = "vanilla-xformers" print(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": diff --git a/comfy/model_management.py b/comfy/model_management.py index 4b061c32..c1a8f5a2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,16 @@ try: except: pass +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + +if "--disable-xformers" in sys.argv: + XFORMERS_IS_AVAILBLE = False + if "--cpu" in sys.argv: vram_state = CPU if "--lowvram" in sys.argv: @@ -159,6 +169,11 @@ def get_autocast_device(dev): return dev.type return "cuda" +def xformers_enabled(): + if vram_state == CPU: + return False + return XFORMERS_IS_AVAILBLE + def get_free_memory(dev=None, torch_free_too=False): if dev is None: dev = get_torch_device() diff --git a/main.py b/main.py index ca8674b5..c3d96039 100644 --- a/main.py +++ b/main.py @@ -38,8 +38,8 @@ def prompt_worker(q, server): e.execute(item[-2], item[-1]) q.task_done(item_id, e.outputs) -async def run(server, address='', port=8188, verbose=True): - await asyncio.gather(server.start(address, port, verbose), server.publish_loop()) +async def run(server, address='', port=8188, verbose=True, call_on_start=None): + await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) def hijack_progress(server): from tqdm.auto import tqdm @@ -76,11 +76,18 @@ if __name__ == "__main__": except: pass + call_on_start = None + if "--windows-standalone-build" in sys.argv: + def startup_server(address, port): + import webbrowser + webbrowser.open("http://{}:{}".format(address, port)) + call_on_start = startup_server + if os.name == "nt": try: - loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print)) + loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start)) except KeyboardInterrupt: pass else: - loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print)) + loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start)) diff --git a/server.py b/server.py index 5aba5761..a29d8597 100644 --- a/server.py +++ b/server.py @@ -260,7 +260,7 @@ class PromptServer(): msg = await self.messages.get() await self.send(*msg) - async def start(self, address, port, verbose=True): + async def start(self, address, port, verbose=True, call_on_start=None): runner = web.AppRunner(self.app) await runner.setup() site = web.TCPSite(runner, address, port) @@ -271,3 +271,6 @@ class PromptServer(): if verbose: print("Starting server\n") print("To see the GUI go to: http://{}:{}".format(address, port)) + if call_on_start is not None: + call_on_start(address, port) +