Add a way to set output directory with --output-directory

This commit is contained in:
comfyanonymous 2023-04-05 14:01:01 -04:00
parent 30f274bf48
commit f816964847
4 changed files with 60 additions and 18 deletions

View File

@ -27,6 +27,40 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")]
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
if not os.path.exists(input_directory):
os.makedirs(input_directory)
def set_output_directory(output_dir):
global output_directory
output_directory = output_dir
def get_output_directory():
global output_directory
return output_directory
def get_temp_directory():
global temp_directory
return temp_directory
def get_input_directory():
global input_directory
return input_directory
#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name):
if type_name == "output":
return get_output_directory()
if type_name == "temp":
return get_temp_directory()
if type_name == "input":
return get_input_directory()
return None
def add_model_folder_path(folder_name, full_folder_path): def add_model_folder_path(folder_name, full_folder_path):
global folder_names_and_paths global folder_names_and_paths

View File

@ -17,6 +17,7 @@ if __name__ == "__main__":
print("\t--port 8188\t\t\tSet the listen port.") print("\t--port 8188\t\t\tSet the listen port.")
print() print()
print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.") print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.")
print("\t--output-directory path/to/output\tSet the ComfyUI output directory.")
print() print()
print() print()
print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n") print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
@ -134,6 +135,14 @@ if __name__ == "__main__":
for i in indices: for i in indices:
load_extra_path_config(sys.argv[i]) load_extra_path_config(sys.argv[i])
try:
output_dir = sys.argv[sys.argv.index('--output-directory') + 1]
output_dir = os.path.abspath(output_dir)
print("setting output directory to:", output_dir)
folder_paths.set_output_directory(output_dir)
except:
pass
port = 8188 port = 8188
try: try:
p_index = sys.argv.index('--port') p_index = sys.argv.index('--port')

View File

@ -777,7 +777,7 @@ class KSamplerAdvanced:
class SaveImage: class SaveImage:
def __init__(self): def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") self.output_dir = folder_paths.get_output_directory()
self.type = "output" self.type = "output"
@classmethod @classmethod
@ -829,9 +829,6 @@ class SaveImage:
os.makedirs(full_output_folder, exist_ok=True) os.makedirs(full_output_folder, exist_ok=True)
counter = 1 counter = 1
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
results = list() results = list()
for image in images: for image in images:
i = 255. * image.cpu().numpy() i = 255. * image.cpu().numpy()
@ -856,7 +853,7 @@ class SaveImage:
class PreviewImage(SaveImage): class PreviewImage(SaveImage):
def __init__(self): def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") self.output_dir = folder_paths.get_temp_directory()
self.type = "temp" self.type = "temp"
@classmethod @classmethod
@ -867,13 +864,11 @@ class PreviewImage(SaveImage):
} }
class LoadImage: class LoadImage:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
if not os.path.exists(s.input_dir): input_dir = folder_paths.get_input_directory()
os.makedirs(s.input_dir)
return {"required": return {"required":
{"image": (sorted(os.listdir(s.input_dir)), )}, {"image": (sorted(os.listdir(input_dir)), )},
} }
CATEGORY = "image" CATEGORY = "image"
@ -881,7 +876,8 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK") RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image" FUNCTION = "load_image"
def load_image(self, image): def load_image(self, image):
image_path = os.path.join(self.input_dir, image) input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
i = Image.open(image_path) i = Image.open(image_path)
image = i.convert("RGB") image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
@ -895,18 +891,19 @@ class LoadImage:
@classmethod @classmethod
def IS_CHANGED(s, image): def IS_CHANGED(s, image):
image_path = os.path.join(s.input_dir, image) input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
m = hashlib.sha256() m = hashlib.sha256()
with open(image_path, 'rb') as f: with open(image_path, 'rb') as f:
m.update(f.read()) m.update(f.read())
return m.digest().hex() return m.digest().hex()
class LoadImageMask: class LoadImageMask:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
return {"required": return {"required":
{"image": (sorted(os.listdir(s.input_dir)), ), {"image": (sorted(os.listdir(input_dir)), ),
"channel": (["alpha", "red", "green", "blue"], ),} "channel": (["alpha", "red", "green", "blue"], ),}
} }
@ -915,7 +912,8 @@ class LoadImageMask:
RETURN_TYPES = ("MASK",) RETURN_TYPES = ("MASK",)
FUNCTION = "load_image" FUNCTION = "load_image"
def load_image(self, image, channel): def load_image(self, image, channel):
image_path = os.path.join(self.input_dir, image) input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
i = Image.open(image_path) i = Image.open(image_path)
mask = None mask = None
c = channel[0].upper() c = channel[0].upper()
@ -930,7 +928,8 @@ class LoadImageMask:
@classmethod @classmethod
def IS_CHANGED(s, image, channel): def IS_CHANGED(s, image, channel):
image_path = os.path.join(s.input_dir, image) input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
m = hashlib.sha256() m = hashlib.sha256()
with open(image_path, 'rb') as f: with open(image_path, 'rb') as f:
m.update(f.read()) m.update(f.read())

View File

@ -89,7 +89,7 @@ class PromptServer():
@routes.post("/upload/image") @routes.post("/upload/image")
async def upload_image(request): async def upload_image(request):
upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") upload_dir = folder_paths.get_input_directory()
if not os.path.exists(upload_dir): if not os.path.exists(upload_dir):
os.makedirs(upload_dir) os.makedirs(upload_dir)
@ -122,10 +122,10 @@ class PromptServer():
async def view_image(request): async def view_image(request):
if "filename" in request.rel_url.query: if "filename" in request.rel_url.query:
type = request.rel_url.query.get("type", "output") type = request.rel_url.query.get("type", "output")
if type not in ["output", "input", "temp"]: output_dir = folder_paths.get_directory_by_type(type)
if output_dir is None:
return web.Response(status=400) return web.Response(status=400)
output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type)
if "subfolder" in request.rel_url.query: if "subfolder" in request.rel_url.query:
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"]) full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir: if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir: