Remove pytorch_lightning dependency.
This commit is contained in:
parent
cb180b9998
commit
735ac4cf81
|
@ -0,0 +1,13 @@
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
load = pickle.load
|
||||||
|
|
||||||
|
class Empty:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Unpickler(pickle.Unpickler):
|
||||||
|
def find_class(self, module, name):
|
||||||
|
#TODO: safe unpickle
|
||||||
|
if module.startswith("pytorch_lightning"):
|
||||||
|
return Empty
|
||||||
|
return super().find_class(module, name)
|
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import struct
|
import struct
|
||||||
|
import comfy.checkpoint_pickle
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False):
|
def load_torch_file(ckpt, safe_load=False):
|
||||||
if ckpt.lower().endswith(".safetensors"):
|
if ckpt.lower().endswith(".safetensors"):
|
||||||
|
@ -14,7 +15,7 @@ def load_torch_file(ckpt, safe_load=False):
|
||||||
if safe_load:
|
if safe_load:
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
|
||||||
if "global_step" in pl_sd:
|
if "global_step" in pl_sd:
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
if "state_dict" in pl_sd:
|
if "state_dict" in pl_sd:
|
||||||
|
|
|
@ -4,7 +4,6 @@ torchsde
|
||||||
einops
|
einops
|
||||||
transformers>=4.25.1
|
transformers>=4.25.1
|
||||||
safetensors>=0.3.0
|
safetensors>=0.3.0
|
||||||
pytorch_lightning
|
|
||||||
aiohttp
|
aiohttp
|
||||||
accelerate
|
accelerate
|
||||||
pyyaml
|
pyyaml
|
||||||
|
|
Loading…
Reference in New Issue