Remove pytorch_lightning dependency.

This commit is contained in:
comfyanonymous 2023-06-13 10:11:33 -04:00
parent cb180b9998
commit 735ac4cf81
3 changed files with 15 additions and 2 deletions

View File

@ -0,0 +1,13 @@
import pickle
load = pickle.load
class Empty:
pass
class Unpickler(pickle.Unpickler):
def find_class(self, module, name):
#TODO: safe unpickle
if module.startswith("pytorch_lightning"):
return Empty
return super().find_class(module, name)

View File

@ -1,6 +1,7 @@
import torch
import math
import struct
import comfy.checkpoint_pickle
def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
@ -14,7 +15,7 @@ def load_torch_file(ckpt, safe_load=False):
if safe_load:
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location="cpu")
pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:

View File

@ -4,7 +4,6 @@ torchsde
einops
transformers>=4.25.1
safetensors>=0.3.0
pytorch_lightning
aiohttp
accelerate
pyyaml