pop clip vision keys after loading them.
This commit is contained in:
parent
c9e4a8c9e5
commit
cd930d4e7f
|
@ -21,7 +21,7 @@ class ClipVisionModel():
|
|||
size=224)
|
||||
|
||||
def load_sd(self, sd):
|
||||
self.model.load_state_dict(sd, strict=False)
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
|
||||
def encode_image(self, image):
|
||||
img = torch.clip((255. * image[0]), 0, 255).round().int()
|
||||
|
@ -59,7 +59,13 @@ def load_clipvision_from_sd(sd):
|
|||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
clip = ClipVisionModel(json_config)
|
||||
clip.load_sd(sd)
|
||||
m, u = clip.load_sd(sd)
|
||||
u = set(u)
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
if k not in u:
|
||||
t = sd.pop(k)
|
||||
del t
|
||||
return clip
|
||||
|
||||
def load(ckpt_path):
|
||||
|
|
Loading…
Reference in New Issue