Cleanup: Remove a bunch of useless files.
parent 74297f5f9d
commit f0a2b81cd0
@@ -14,8 +14,7 @@ from ..ldm.modules.diffusionmodules.util import (
 from ..ldm.modules.attention import SpatialTransformer
 from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
-from ..ldm.models.diffusion.ddpm import LatentDiffusion
-from ..ldm.util import log_txt_as_img, exists, instantiate_from_config
+from ..ldm.util import exists


 class ControlledUnetModel(UNetModel):
@@ -3,7 +3,6 @@ import os
 import yaml

 import folder_paths
-from comfy.ldm.util import instantiate_from_config
 from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
 import os.path as osp
 import re
@@ -1,24 +0,0 @@
import torch

from ldm.modules.midas.api import load_midas_transform


class AddMiDaS(object):
    def __init__(self, model_type):
        super().__init__()
        self.transform = load_midas_transform(model_type)

    def pt2np(self, x):
        x = ((x + 1.0) * .5).detach().cpu().numpy()
        return x

    def np2pt(self, x):
        x = torch.from_numpy(x) * 2 - 1.
        return x

    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        x = self.pt2np(sample['jpg'])
        x = self.transform({"image": x})["image"]
        sample['midas_in'] = x
        return sample
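For reference, a minimal sketch of how the removed AddMiDaS transform was typically driven during dataloading (this assumes the pre-cleanup ldm.modules.midas.api module and its MiDaS dependencies are still importable; the model_type value is illustrative):

import torch

midas_aug = AddMiDaS(model_type="dpt_hybrid")
sample = {"jpg": torch.rand(384, 384, 3) * 2 - 1}  # HWC tensor in [-1, 1], as the class expects
sample = midas_aug(sample)
net_input = sample["midas_in"]  # CHW float32 numpy array, resized and normalized for MiDaS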
File diff suppressed because it is too large
@@ -1,59 +0,0 @@
from typing import List, Tuple, Union

import torch
import torch.nn as nn

#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py

def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    r"""Normalize an image/video tensor with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        data: Image tensor of size :math:`(B, C, *)`.
        mean: Mean for each channel.
        std: Standard deviations for each channel.

    Return:
        Normalised tensor with same size as input :math:`(B, C, *)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = normalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3])
    """
    shape = data.shape
    if len(mean.shape) == 0 or mean.shape[0] == 1:
        mean = mean.expand(shape[1])
    if len(std.shape) == 0 or std.shape[0] == 1:
        std = std.expand(shape[1])

    # Allow broadcast on channel dimension
    if mean.shape and mean.shape[0] != 1:
        if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
            raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

    # Allow broadcast on channel dimension
    if std.shape and std.shape[0] != 1:
        if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
            raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")

    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    if mean.shape:
        mean = mean[..., :, None]
    if std.shape:
        std = std[..., :, None]

    out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std

    return out.view(shape)
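For reference, a minimal usage sketch of the removed enhance_normalize helper; the mean/std values below are the standard CLIP statistics that the encoders in the next file register as buffers:

import torch

x = torch.rand(2, 3, 224, 224)  # batch of images already scaled to [0, 1]
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])
std = torch.tensor([0.26862954, 0.26130258, 0.27577711])
out = enhance_normalize(x, mean, std)
assert out.shape == x.shape  # per-channel (x - mean) / std, shape preserved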
@@ -1,314 +0,0 @@
import torch
import torch.nn as nn
from . import kornia_functions
from torch.utils.checkpoint import checkpoint

from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel

import open_clip
from ldm.util import default, count_params


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):

    def encode(self, x):
        return x


class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
            c = c.long()
        c = self.embedding(c)
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        uc_class = self.n_classes - 1  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc = torch.ones((bs,), device=device) * uc_class
        uc = {self.key: uc}
        return uc


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
                 freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)


class ClipImageEmbedder(nn.Module):
    def __init__(
            self,
            model,
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=True,
            ucg_rate=0.
    ):
        super().__init__()
        from clip import load as load_clip
        self.model, _ = load_clip(name=model, device=device, jit=jit)

        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        # normalize to [0,1]
        # x = kornia_functions.geometry_resize(x, (224, 224),
        #                       interpolation='bicubic', align_corners=True,
        #                       antialias=self.antialias)
        x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
        x = (x + 1.) / 2.
        # re-normalize according to clip
        x = kornia_functions.enhance_normalize(x, self.mean, self.std)
        return x

    def forward(self, x, no_dropout=False):
        # x is assumed to be in range [-1,1]
        out = self.model.encode_image(self.preprocess(x))
        out = out.to(x.dtype)
        if self.ucg_rate > 0. and not no_dropout:
            out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
        return out


class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = [
        # "pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)


class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
                                                            pretrained=version, )
        del model.transformer
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()
            self.layer_idx = 1

        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        # normalize to [0,1]
        # x = kornia.geometry.resize(x, (224, 224),
        #                            interpolation='bicubic', align_corners=True,
        #                            antialias=self.antialias)
        x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia_functions.enhance_normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image, no_dropout=False):
        z = self.encode_with_vision_transformer(image)
        if self.ucg_rate > 0. and not no_dropout:
            z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
        return z

    def encode_with_vision_transformer(self, img):
        img = self.preprocess(img)
        x = self.model.visual(img)
        return x

    def encode(self, text):
        return self(text)


class FrozenCLIPT5Encoder(AbstractEncoder):
    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                 clip_max_length=77, t5_max_length=77):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")

    def encode(self, text):
        return self(text)

    def forward(self, text):
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]
@@ -1,170 +0,0 @@
# based on https://github.com/isl-org/MiDaS

import cv2
import torch
import torch.nn as nn
from torchvision.transforms import Compose

from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
from ldm.modules.midas.midas.midas_net import MidasNet
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet


ISL_PATHS = {
    "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
    "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
    "midas_v21": "",
    "midas_v21_small": "",
}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def load_midas_transform(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load transform only
    if model_type == "dpt_large":  # DPT-Large
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    elif model_type == "midas_v21_small":
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    else:
        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return transform


def load_model(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load network
    model_path = ISL_PATHS[model_type]
    if model_type == "dpt_large":  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "midas_v21_small":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
                               non_negative=True, blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    else:
        print(f"model_type '{model_type}' not implemented, use: --model_type large")
        assert False

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return model.eval(), transform


class MiDaSInference(nn.Module):
    MODEL_TYPES_TORCH_HUB = [
        "DPT_Large",
        "DPT_Hybrid",
        "MiDaS_small"
    ]
    MODEL_TYPES_ISL = [
        "dpt_large",
        "dpt_hybrid",
        "midas_v21",
        "midas_v21_small",
    ]

    def __init__(self, model_type):
        super().__init__()
        assert (model_type in self.MODEL_TYPES_ISL)
        model, _ = load_model(model_type)
        self.model = model
        self.model.train = disabled_train

    def forward(self, x):
        # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
        # NOTE: we expect that the correct transform has been called during dataloading.
        with torch.no_grad():
            prediction = self.model(x)
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=x.shape[2:],
                mode="bicubic",
                align_corners=False,
            )
            assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
        return prediction
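For reference, a sketch of how the removed load_midas_transform and MiDaSInference pieces fit together (this assumes the checkpoint referenced in ISL_PATHS is present on disk and the MiDaS submodules below are importable):

import numpy as np
import torch

transform = load_midas_transform("dpt_hybrid")
depth_net = MiDaSInference("dpt_hybrid")
img = np.random.rand(384, 512, 3)  # HWC image scaled to [0, 1]
net_in = transform({"image": img})["image"]  # CHW float32, resized to a multiple of 32
with torch.no_grad():
    depth = depth_net(torch.from_numpy(net_in)[None])  # (batch, 1, H, W) depth prediction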
@@ -1,16 +0,0 @@
import torch


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device('cpu'))

        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)
@@ -1,342 +0,0 @@
import torch
import torch.nn as nn

from .vit import (
    _make_pretrained_vitb_rn50_384,
    _make_pretrained_vitl16_384,
    _make_pretrained_vitb16_384,
    forward_vit,
)

def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
    if backbone == "vitl16_384":
        pretrained = _make_pretrained_vitl16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # ViT-L/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb_rn50_384":
        pretrained = _make_pretrained_vitb_rn50_384(
            use_pretrained,
            hooks=hooks,
            use_vit_only=use_vit_only,
            use_readout=use_readout,
        )
        scratch = _make_scratch(
            [256, 512, 768, 768], features, groups=groups, expand=expand
        )  # ViT-H/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb16_384":
        pretrained = _make_pretrained_vitb16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # ViT-B/16 - 84.6% Top1 (backbone)
    elif backbone == "resnext101_wsl":
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # efficientnet_lite3
    elif backbone == "efficientnet_lite3":
        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3
    else:
        print(f"Backbone '{backbone}' not implemented")
        assert False

    return pretrained, scratch


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand==True:
        out_shape1 = out_shape
        out_shape2 = out_shape*2
        out_shape3 = out_shape*4
        out_shape4 = out_shape*8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer4_rn = nn.Conv2d(
        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )

    return scratch


def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    efficientnet = torch.hub.load(
        "rwightman/gen-efficientnet-pytorch",
        "tf_efficientnet_lite3",
        pretrained=use_pretrained,
        exportable=exportable
    )
    return _make_efficientnet_backbone(efficientnet)


def _make_efficientnet_backbone(effnet):
    pretrained = nn.Module()

    pretrained.layer1 = nn.Sequential(
        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
    )
    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])

    return pretrained


def _make_resnet_backbone(resnet):
    pretrained = nn.Module()
    pretrained.layer1 = nn.Sequential(
        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
    )

    pretrained.layer2 = resnet.layer2
    pretrained.layer3 = resnet.layer3
    pretrained.layer4 = resnet.layer4

    return pretrained


def _make_pretrained_resnext101_wsl(use_pretrained):
    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
    return _make_resnet_backbone(resnet)


class Interpolate(nn.Module):
    """Interpolation module.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """

        x = self.interp(
            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
        )

        return x


class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups=1

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        if self.bn==True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn==True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn==True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)

        # return out + x


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups=1

        self.expand = expand
        out_features = features
        if self.expand==True:
            out_features = features//2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)
            # output += res

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output
@@ -1,109 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_model import BaseModel
from .blocks import (
    FeatureFusionBlock,
    FeatureFusionBlock_custom,
    Interpolate,
    _make_encoder,
    forward_vit,
)


def _make_fusion_block(features, use_bn):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )


class DPT(BaseModel):
    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):

        super(DPT, self).__init__()

        self.channels_last = channels_last

        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # Set to true of you want to train from scratch, uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        if self.channels_last == True:
            x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out


class DPTDepthModel(DPT):
    def __init__(self, path=None, non_negative=True, **kwargs):
        features = kwargs["features"] if "features" in kwargs else 256

        head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        return super().forward(x).squeeze(dim=1)
@@ -1,76 +0,0 @@
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder


class MidasNet(BaseModel):
    """Network for monocular depth estimation.
    """

    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            backbone (str, optional): Backbone network for encoder. Defaults to resnet50
        """
        print("Loading weights: ", path)

        super(MidasNet, self).__init__()

        use_pretrained = False if path is None else True

        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)

        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)
@@ -1,128 +0,0 @@
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder


class MidasNet_small(BaseModel):
    """Network for monocular depth estimation.
    """

    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
                 blocks={'expand': True}):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            backbone (str, optional): Backbone network for encoder. Defaults to resnet50
        """
        print("Loading weights: ", path)

        super(MidasNet_small, self).__init__()

        use_pretrained = False if path else True

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        features1=features
        features2=features
        features3=features
        features4=features
        self.expand = False
        if "expand" in self.blocks and self.blocks['expand'] == True:
            self.expand = True
            features1=features
            features2=features*2
            features3=features*4
            features4=features*8

        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)

        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last==True:
            print("self.channels_last = ", self.channels_last)
            x.contiguous(memory_format=torch.channels_last)

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)


def fuse_model(m):
    prev_previous_type = nn.Identity()
    prev_previous_name = ''
    previous_type = nn.Identity()
    previous_name = ''
    for name, module in m.named_modules():
        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
            # print("FUSED ", prev_previous_name, previous_name, name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
            # print("FUSED ", prev_previous_name, previous_name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
        #    print("FUSED ", previous_name, name)
        #    torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)

        prev_previous_type = previous_type
        prev_previous_name = previous_name
        previous_type = type(module)
        previous_name = name
@@ -1,234 +0,0 @@
import numpy as np
import cv2
import math


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Rezise the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(
        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
    )

    sample["disparity"] = cv2.resize(
        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
    )
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height).
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as least as possbile
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            sample["mask"] = cv2.resize(
                sample["mask"].astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            sample["mask"] = sample["mask"].astype(bool)

        return sample


class NormalizeImage(object):
    """Normlize image by given mean and std.
    """

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "disparity" in sample:
            disparity = sample["disparity"].astype(np.float32)
            sample["disparity"] = np.ascontiguousarray(disparity)

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        return sample
@ -1,491 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import timm
|
|
||||||
import types
|
|
||||||
import math
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class Slice(nn.Module):
    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index :]


class AddReadout(nn.Module):
    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index :] + readout.unsqueeze(1)


class ProjectReadout(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = torch.cat((x[:, self.start_index :], readout), -1)

        return self.project(features)


class Transpose(nn.Module):
    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1)
        return x

def forward_vit(pretrained, x):
    b, c, h, w = x.shape

    glob = pretrained.model.forward_flex(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)

    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4

def _resize_pos_embed(self, posemb, gs_h, gs_w):
    posemb_tok, posemb_grid = (
        posemb[:, : self.start_index],
        posemb[0, self.start_index :],
    )

    gs_old = int(math.sqrt(len(posemb_grid)))

    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

    return posemb

def forward_flex(self, x):
    b, c, h, w = x.shape

    pos_embed = self._resize_pos_embed(
        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
    )

    B = x.shape[0]

    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    if getattr(self, "dist_token", None) is not None:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x

activations = {}


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output

    return hook


def get_readout_oper(vit_features, features, use_readout, start_index=1):
    if use_readout == "ignore":
        readout_oper = [Slice(start_index)] * len(features)
    elif use_readout == "add":
        readout_oper = [AddReadout(start_index)] * len(features)
    elif use_readout == "project":
        readout_oper = [
            ProjectReadout(vit_features, start_index) for out_feat in features
        ]
    else:
        assert (
            False
        ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"

    return readout_oper

def _make_vit_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained

def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    hooks = [5, 11, 17, 23] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )


def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )


def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )


def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model(
        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
    )

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        start_index=2,
    )

def _make_vit_b_rn50_backbone(
    model,
    features=[256, 512, 768, 768],
    size=[384, 384],
    hooks=[0, 1, 8, 11],
    vit_features=768,
    use_vit_only=False,
    use_readout="ignore",
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model

    if use_vit_only == True:
        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    else:
        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
            get_activation("1")
        )
        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
            get_activation("2")
        )

    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    if use_vit_only == True:
        pretrained.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[0],
                out_channels=features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        pretrained.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[1],
                out_channels=features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )
    else:
        pretrained.act_postprocess1 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )
        pretrained.act_postprocess2 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained

def _make_pretrained_vitb_rn50_384(
    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

    hooks = [0, 1, 8, 11] if hooks == None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )

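For reference, a minimal sketch of how one of these backbone constructors is combined with forward_vit() to obtain the four intermediate feature maps. This sketch is not part of the removed file; the timm model name is the one used by _make_pretrained_vitb_rn50_384 above, pretrained=False avoids a weight download, and the 384x384 input matches the default size the act_postprocess blocks are built for.

# Hypothetical usage sketch of the deleted vit.py helpers.
import torch

backbone = _make_pretrained_vitb_rn50_384(
    pretrained=False,        # True would pull timm weights for "vit_base_resnet50_384"
    use_readout="project",
    use_vit_only=False,
)

x = torch.randn(1, 3, 384, 384)  # act_postprocess blocks assume a 384x384 input
with torch.no_grad():
    layer_1, layer_2, layer_3, layer_4 = forward_vit(backbone, x)

# Four feature maps at decreasing spatial resolution, as consumed by a MiDaS/DPT-style decoder.
print([t.shape for t in (layer_1, layer_2, layer_3, layer_4)])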
@@ -1,189 +0,0 @@
"""Utils for monoDepth."""
|
|
||||||
import sys
|
|
||||||
import re
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def read_pfm(path):
    """Read pfm file.

    Args:
        path (str): path to file

    Returns:
        tuple: (data, scale)
    """
    with open(path, "rb") as file:

        color = None
        width = None
        height = None
        scale = None
        endian = None

        header = file.readline().rstrip()
        if header.decode("ascii") == "PF":
            color = True
        elif header.decode("ascii") == "Pf":
            color = False
        else:
            raise Exception("Not a PFM file: " + path)

        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception("Malformed PFM header.")

        scale = float(file.readline().decode("ascii").rstrip())
        if scale < 0:
            # little-endian
            endian = "<"
            scale = -scale
        else:
            # big-endian
            endian = ">"

        data = np.fromfile(file, endian + "f")
        shape = (height, width, 3) if color else (height, width)

        data = np.reshape(data, shape)
        data = np.flipud(data)

        return data, scale

def write_pfm(path, image, scale=1):
    """Write pfm file.

    Args:
        path (str): path to file
        image (array): data
        scale (int, optional): Scale. Defaults to 1.
    """

    with open(path, "wb") as file:
        color = None

        if image.dtype.name != "float32":
            raise Exception("Image dtype must be float32.")

        image = np.flipud(image)

        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
            color = True
        elif (
            len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
        ):  # greyscale
            color = False
        else:
            raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")

        # encode the whole header line; the file is opened in binary mode
        file.write(("PF\n" if color else "Pf\n").encode())
        file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))

        endian = image.dtype.byteorder

        if endian == "<" or endian == "=" and sys.byteorder == "little":
            scale = -scale

        file.write("%f\n".encode() % scale)

        image.tofile(file)

def read_image(path):
    """Read image and output RGB image (0-1).

    Args:
        path (str): path to file

    Returns:
        array: RGB image (0-1)
    """
    img = cv2.imread(path)

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

    return img

def resize_image(img):
    """Resize image and make it fit for network.

    Args:
        img (array): image

    Returns:
        tensor: data ready for network
    """
    height_orig = img.shape[0]
    width_orig = img.shape[1]

    if width_orig > height_orig:
        scale = width_orig / 384
    else:
        scale = height_orig / 384

    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)

    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    img_resized = (
        torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
    )
    img_resized = img_resized.unsqueeze(0)

    return img_resized

def resize_depth(depth, width, height):
    """Resize depth map and bring to CPU (numpy).

    Args:
        depth (tensor): depth
        width (int): image width
        height (int): image height

    Returns:
        array: processed depth
    """
    depth = torch.squeeze(depth[0, :, :, :]).to("cpu")

    depth_resized = cv2.resize(
        depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
    )

    return depth_resized

def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    Args:
        path (str): filepath without extension
        depth (array): depth
        bits (int, optional): bytes per png channel (1 = uint8, 2 = uint16). Defaults to 1.
    """
    write_pfm(path + ".pfm", depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()

    max_val = (2 ** (8 * bits)) - 1

    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        out = np.zeros(depth.shape, dtype=depth.dtype)

    if bits == 1:
        cv2.imwrite(path + ".png", out.astype("uint8"))
    elif bits == 2:
        cv2.imwrite(path + ".png", out.astype("uint16"))

    return

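For reference, a minimal end-to-end sketch of how the I/O helpers above are typically strung together around a depth network. This sketch is not part of the removed file: "model" is a placeholder for any network returning a 1x1xHxW depth prediction, and the file names are illustrative only.

# Hypothetical pipeline sketch; "model", "input.jpg" and "output" are placeholders.
import torch

img = read_image("input.jpg")          # HxWx3 RGB array in [0, 1]
sample = resize_image(img)             # 1x3xH'xW' float tensor, sides multiples of 32

with torch.no_grad():
    prediction = model(sample)         # assumed: 1x1xH'xW' depth prediction

depth = resize_depth(prediction, img.shape[1], img.shape[0])
write_depth("output", depth, bits=2)   # writes output.pfm and a 16-bit output.png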
@@ -1111,7 +1111,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]

sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}

unclip_model = False
inpaint_model = False

@@ -1121,11 +1120,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
sd_config["embedding_dropout"] = 0.25
sd_config["conditioning_key"] = 'crossattn-adm'
unclip_model = True
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
elif unet_config["in_channels"] > 4: #inpainting model
sd_config["conditioning_key"] = "hybrid"
sd_config["finetune_keys"] = None
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
inpaint_model = True
else:
sd_config["conditioning_key"] = "crossattn"

@@ -2,7 +2,6 @@ torch
torchdiffeq
torchsde
einops
open-clip-torch
transformers>=4.25.1
safetensors>=0.3.0
pytorch_lightning
