import torchaudio
import torch
import comfy.model_management
import folder_paths
import os
import io
import json
import struct
import hashlib
from comfy.cli_args import args

# Sample rate used on both sides of the audio VAE in this module: input audio is
# resampled to it before encoding and decoded audio is tagged with it.
_VAE_SAMPLE_RATE = 44100


class EmptyLatentAudio:
    """Node that produces an all-zero audio latent of shape [1, 64, 1024]."""

    def __init__(self):
        self.device = comfy.model_management.intermediate_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "_for_testing/audio"

    def generate(self):
        """Return a LATENT dict holding a zero tensor tagged with type "audio"."""
        # The node takes no inputs, so batch size and latent length are fixed.
        batch_size = 1
        latent = torch.zeros([batch_size, 64, 1024], device=self.device)
        return ({"samples": latent, "type": "audio"}, )


class VAEEncodeAudio:
    """Node that encodes an AUDIO dict into a LATENT using the given VAE."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"audio": ("AUDIO", ), "vae": ("VAE", )}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    CATEGORY = "_for_testing/audio"

    def encode(self, vae, audio):
        """Resample *audio* to 44100 Hz if necessary and encode it with *vae*.

        Args:
            vae: VAE object exposing ``encode``.
            audio: dict with ``"waveform"`` (tensor) and ``"sample_rate"`` keys.

        Returns:
            A one-element tuple containing ``{"samples": latent}``.
        """
        sample_rate = audio["sample_rate"]
        waveform = audio["waveform"]
        if sample_rate != _VAE_SAMPLE_RATE:
            waveform = torchaudio.functional.resample(waveform, sample_rate, _VAE_SAMPLE_RATE)

        # Axis 1 is moved to the end before encoding; presumably the VAE expects
        # channels-last input — TODO confirm against the VAE implementation.
        t = vae.encode(waveform.movedim(1, -1))
        return ({"samples": t}, )


class VAEDecodeAudio:
    """Node that decodes a LATENT into an AUDIO dict using the given VAE."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT", ), "vae": ("VAE", )}}

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "decode"

    CATEGORY = "_for_testing/audio"

    def decode(self, vae, samples):
        """Decode ``samples["samples"]`` and return it as 44100 Hz AUDIO."""
        # Inverse of the movedim in VAEEncodeAudio.encode: last axis back to 1.
        audio = vae.decode(samples["samples"]).movedim(-1, 1)
        return ({"waveform": audio, "sample_rate": _VAE_SAMPLE_RATE}, )


# FLAC metadata block lengths are stored in 24 bits.
_FLAC_MAX_BLOCK_LENGTH = (1 << 24) - 1


def create_vorbis_comment_block(comment_dict, last_block):
    """Build a complete FLAC VORBIS_COMMENT metadata block (header + payload).

    Args:
        comment_dict: mapping whose items become ``key=value`` UTF-8 comments.
        last_block: if true, set the "last metadata block" flag in the header.

    Returns:
        The encoded block as ``bytes``.

    Raises:
        ValueError: if the payload does not fit the 24-bit FLAC length field.
    """
    vendor_string = b'ComfyUI'
    vendor_length = len(vendor_string)

    # Vorbis comment lengths are little-endian 32-bit integers.
    comments = []
    for key, value in comment_dict.items():
        comment = f"{key}={value}".encode('utf-8')
        comments.append(struct.pack('<I', len(comment)) + comment)

    comment_data = (
        struct.pack('<I', vendor_length)
        + vendor_string
        + struct.pack('<I', len(comments))
        + b''.join(comments)
    )

    if len(comment_data) > _FLAC_MAX_BLOCK_LENGTH:
        # Packing a larger length into 3 bytes would silently truncate it and
        # produce a corrupt file, so refuse instead.
        raise ValueError(
            f"Vorbis comment block too large for FLAC ({len(comment_data)} bytes)")

    # Header byte: high bit = last-metadata-block flag, low 7 bits = type 4.
    block_id = b'\x84' if last_block else b'\x04'
    # The block length is a 24-bit big-endian integer: drop the high byte of a
    # 32-bit big-endian pack.
    return block_id + struct.pack('>I', len(comment_data))[1:] + comment_data


def insert_or_replace_vorbis_comment(flac_io, comment_dict):
    """Return a FLAC stream whose metadata carries *comment_dict* as comments.

    Existing VORBIS_COMMENT (type 4) and PADDING (type 1) blocks are dropped.
    All other metadata blocks are kept in order, and a fresh VORBIS_COMMENT
    block is appended as the last metadata block. If *comment_dict* is empty,
    *flac_io* is returned unchanged.

    Args:
        flac_io: seekable binary stream containing a complete FLAC file.
        comment_dict: mapping of comment keys to values.

    Returns:
        A new ``io.BytesIO`` (or *flac_io* itself when there is nothing to add).

    Raises:
        ValueError: if the stream is not FLAC or its metadata is truncated.
    """
    if not comment_dict:
        return flac_io

    flac_io.seek(0)
    if flac_io.read(4) != b'fLaC':
        raise ValueError("Not a FLAC stream: missing 'fLaC' marker")

    blocks = []
    last_block = False
    while not last_block:
        header = flac_io.read(4)
        if len(header) < 4:
            raise ValueError("Truncated FLAC metadata: unexpected end of stream")
        last_block = (header[0] & 0x80) != 0
        block_type = header[0] & 0x7F
        block_length = struct.unpack('>I', b'\x00' + header[1:])[0]
        block_data = flac_io.read(block_length)

        if block_type in (1, 4):
            # Old padding/comment blocks are discarded; the new comment block
            # replaces them below.
            continue
        # Clear the last-block flag: the new comment block becomes the last one.
        header = bytes([header[0] & 0x7F]) + header[1:]
        blocks.append(header + block_data)

    blocks.append(create_vorbis_comment_block(comment_dict, last_block=True))

    new_flac_io = io.BytesIO()
    new_flac_io.write(b'fLaC')
    for block in blocks:
        new_flac_io.write(block)

    # Everything after the metadata (the audio frames) is copied verbatim.
    new_flac_io.write(flac_io.read())
    return new_flac_io


class SaveAudio:
    """Output node that writes each batch item of an AUDIO dict as a FLAC file."""

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""
        self.compress_level = 4

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"audio": ("AUDIO", ),
                             "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

    RETURN_TYPES = ()
    FUNCTION = "save_audio"

    OUTPUT_NODE = True

    CATEGORY = "_for_testing/audio"

    def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
        """Save every waveform in ``audio["waveform"]`` as a FLAC file.

        Unless metadata is disabled via ``args.disable_metadata``, *prompt* and
        each entry of *extra_pnginfo* are JSON-encoded and embedded as Vorbis
        comments. Files are named ``{filename}_{counter:05}_.flac``, with
        ``%batch_num%`` in the filename replaced by the batch index.

        Returns:
            A UI dict listing the written files (filename, subfolder, type).
        """
        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = \
            folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        results = []

        metadata = {}
        if not args.disable_metadata:
            if prompt is not None:
                metadata["prompt"] = json.dumps(prompt)
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata[x] = json.dumps(extra_pnginfo[x])

        for (batch_number, waveform) in enumerate(audio["waveform"]):
            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
            file = f"{filename_with_batch_num}_{counter:05}_.flac"

            # Encode in memory first so the metadata can be spliced in before
            # anything is written to disk.
            buff = io.BytesIO()
            torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC")
            buff = insert_or_replace_vorbis_comment(buff, metadata)

            with open(os.path.join(full_output_folder, file), 'wb') as f:
                f.write(buff.getbuffer())

            results.append({
                "filename": file,
                "subfolder": subfolder,
                "type": self.type
            })
            counter += 1

        return {"ui": {"audio": results}}


class LoadAudio:
    """Node that loads an audio file from the input directory as AUDIO."""

    # Matched case-insensitively against file names.
    SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')

    @classmethod
    def INPUT_TYPES(s):
        input_dir = folder_paths.get_input_directory()
        files = [
            f for f in os.listdir(input_dir)
            if os.path.isfile(os.path.join(input_dir, f))
            and f.lower().endswith(LoadAudio.SUPPORTED_FORMATS)
        ]
        return {"required": {"audio": (sorted(files), {"audio_upload": True})}}

    CATEGORY = "_for_testing/audio"

    RETURN_TYPES = ("AUDIO", )
    FUNCTION = "load"

    def load(self, audio):
        """Load *audio* with torchaudio and add a leading batch dimension."""
        audio_path = folder_paths.get_annotated_filepath(audio)
        waveform, sample_rate = torchaudio.load(audio_path)
        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
        return (audio, )

    @classmethod
    def IS_CHANGED(s, audio):
        """Return the SHA-256 hex digest of the audio file's contents."""
        audio_path = folder_paths.get_annotated_filepath(audio)
        m = hashlib.sha256()
        with open(audio_path, 'rb') as f:
            # Hash in chunks so large audio files are not read into memory at once.
            for chunk in iter(lambda: f.read(1 << 16), b''):
                m.update(chunk)
        return m.hexdigest()

    @classmethod
    def VALIDATE_INPUTS(s, audio):
        if not folder_paths.exists_annotated_filepath(audio):
            return "Invalid audio file: {}".format(audio)
        return True


NODE_CLASS_MAPPINGS = {
    "EmptyLatentAudio": EmptyLatentAudio,
    "VAEEncodeAudio": VAEEncodeAudio,
    "VAEDecodeAudio": VAEDecodeAudio,
    "SaveAudio": SaveAudio,
    "LoadAudio": LoadAudio,
}