Remove pytorch_lightning dependency.
This commit is contained in:
parent
cb180b9998
commit
735ac4cf81
|
@ -0,0 +1,13 @@
|
|||
import pickle
|
||||
|
||||
load = pickle.load
|
||||
|
||||
class Empty:
|
||||
pass
|
||||
|
||||
class Unpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
#TODO: safe unpickle
|
||||
if module.startswith("pytorch_lightning"):
|
||||
return Empty
|
||||
return super().find_class(module, name)
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
import math
|
||||
import struct
|
||||
import comfy.checkpoint_pickle
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False):
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
|
@ -14,7 +15,7 @@ def load_torch_file(ckpt, safe_load=False):
|
|||
if safe_load:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
if "state_dict" in pl_sd:
|
||||
|
|
|
@ -4,7 +4,6 @@ torchsde
|
|||
einops
|
||||
transformers>=4.25.1
|
||||
safetensors>=0.3.0
|
||||
pytorch_lightning
|
||||
aiohttp
|
||||
accelerate
|
||||
pyyaml
|
||||
|
|
Loading…
Reference in New Issue