ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/OSA.py

578 lines
15 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OSA.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
from torch import einsum, nn
from .layernorm import LayerNorm2d
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length=1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x)) + x
class Conv_PreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm2d(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x)) + x
class FeedForward(nn.Module):
def __init__(self, dim, mult=2, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Conv_FeedForward(nn.Module):
def __init__(self, dim, mult=2, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1, 1, 0),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim, 1, 1, 0),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Gated_Conv_FeedForward(nn.Module):
def __init__(self, dim, mult=1, bias=False, dropout=0.0):
super().__init__()
hidden_features = int(dim * mult)
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(
hidden_features * 2,
hidden_features * 2,
kernel_size=3,
stride=1,
padding=1,
groups=hidden_features * 2,
bias=bias,
)
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
# MBConv
class SqueezeExcitation(nn.Module):
def __init__(self, dim, shrinkage_rate=0.25):
super().__init__()
hidden_dim = int(dim * shrinkage_rate)
self.gate = nn.Sequential(
Reduce("b c h w -> b c", "mean"),
nn.Linear(dim, hidden_dim, bias=False),
nn.SiLU(),
nn.Linear(hidden_dim, dim, bias=False),
nn.Sigmoid(),
Rearrange("b c -> b c 1 1"),
)
def forward(self, x):
return x * self.gate(x)
class MBConvResidual(nn.Module):
def __init__(self, fn, dropout=0.0):
super().__init__()
self.fn = fn
self.dropsample = Dropsample(dropout)
def forward(self, x):
out = self.fn(x)
out = self.dropsample(out)
return out + x
class Dropsample(nn.Module):
def __init__(self, prob=0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device
if self.prob == 0.0 or (not self.training):
return x
keep_mask = (
torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_()
> self.prob
)
return x * keep_mask / (1 - self.prob)
def MBConv(
dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
):
hidden_dim = int(expansion_rate * dim_out)
stride = 2 if downsample else 1
net = nn.Sequential(
nn.Conv2d(dim_in, hidden_dim, 1),
# nn.BatchNorm2d(hidden_dim),
nn.GELU(),
nn.Conv2d(
hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
),
# nn.BatchNorm2d(hidden_dim),
nn.GELU(),
SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
nn.Conv2d(hidden_dim, dim_out, 1),
# nn.BatchNorm2d(dim_out)
)
if dim_in == dim_out and not downsample:
net = MBConvResidual(net, dropout=dropout)
return net
# attention related classes
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head=32,
dropout=0.0,
window_size=7,
with_pe=True,
):
super().__init__()
assert (
dim % dim_head
) == 0, "dimension should be divisible by dimension per head"
self.heads = dim // dim_head
self.scale = dim_head**-0.5
self.with_pe = with_pe
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
self.to_out = nn.Sequential(
nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
)
# relative positional bias
if self.with_pe:
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos))
grid = rearrange(grid, "c i j -> (i j) c")
rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
grid, "j ... -> 1 j ..."
)
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
dim=-1
)
self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)
def forward(self, x):
batch, height, width, window_height, window_width, _, device, h = (
*x.shape,
x.device,
self.heads,
)
# flatten
x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")
# project for queries, keys, values
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
# split heads
q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v))
# scale
q = q * self.scale
# sim
sim = einsum("b h i d, b h j d -> b h i j", q, k)
# add positional bias
if self.with_pe:
bias = self.rel_pos_bias(self.rel_pos_indices)
sim = sim + rearrange(bias, "i j h -> h i j")
# attention
attn = self.attend(sim)
# aggregate
out = einsum("b h i j, b h j d -> b h i d", attn, v)
# merge heads
out = rearrange(
out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
)
# combine heads out
out = self.to_out(out)
return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
class Block_Attention(nn.Module):
def __init__(
self,
dim,
dim_head=32,
bias=False,
dropout=0.0,
window_size=7,
with_pe=True,
):
super().__init__()
assert (
dim % dim_head
) == 0, "dimension should be divisible by dimension per head"
self.heads = dim // dim_head
self.ps = window_size
self.scale = dim_head**-0.5
self.with_pe = with_pe
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(
dim * 3,
dim * 3,
kernel_size=3,
stride=1,
padding=1,
groups=dim * 3,
bias=bias,
)
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
# project for queries, keys, values
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
# split heads
q, k, v = map(
lambda t: rearrange(
t,
"b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
h=self.heads,
w1=self.ps,
w2=self.ps,
),
(q, k, v),
)
# scale
q = q * self.scale
# sim
sim = einsum("b h i d, b h j d -> b h i j", q, k)
# attention
attn = self.attend(sim)
# aggregate
out = einsum("b h i j, b h j d -> b h i d", attn, v)
# merge heads
out = rearrange(
out,
"(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
x=h // self.ps,
y=w // self.ps,
head=self.heads,
w1=self.ps,
w2=self.ps,
)
out = self.to_out(out)
return out
class Channel_Attention(nn.Module):
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
super(Channel_Attention, self).__init__()
self.heads = heads
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.ps = window_size
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(
dim * 3,
dim * 3,
kernel_size=3,
stride=1,
padding=1,
groups=dim * 3,
bias=bias,
)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
qkv = qkv.chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(
t,
"b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
ph=self.ps,
pw=self.ps,
head=self.heads,
),
qkv,
)
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = attn @ v
out = rearrange(
out,
"b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
h=h // self.ps,
w=w // self.ps,
ph=self.ps,
pw=self.ps,
head=self.heads,
)
out = self.project_out(out)
return out
class Channel_Attention_grid(nn.Module):
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
super(Channel_Attention_grid, self).__init__()
self.heads = heads
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.ps = window_size
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(
dim * 3,
dim * 3,
kernel_size=3,
stride=1,
padding=1,
groups=dim * 3,
bias=bias,
)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
qkv = qkv.chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(
t,
"b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
ph=self.ps,
pw=self.ps,
head=self.heads,
),
qkv,
)
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = attn @ v
out = rearrange(
out,
"b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
h=h // self.ps,
w=w // self.ps,
ph=self.ps,
pw=self.ps,
head=self.heads,
)
out = self.project_out(out)
return out
class OSA_Block(nn.Module):
def __init__(
self,
channel_num=64,
bias=True,
ffn_bias=True,
window_size=8,
with_pe=False,
dropout=0.0,
):
super(OSA_Block, self).__init__()
w = window_size
self.layer = nn.Sequential(
MBConv(
channel_num,
channel_num,
downsample=False,
expansion_rate=1,
shrinkage_rate=0.25,
),
Rearrange(
"b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
), # block-like attention
PreNormResidual(
channel_num,
Attention(
dim=channel_num,
dim_head=channel_num // 4,
dropout=dropout,
window_size=window_size,
with_pe=with_pe,
),
),
Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
Conv_PreNormResidual(
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
),
# channel-like attention
Conv_PreNormResidual(
channel_num,
Channel_Attention(
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
),
),
Conv_PreNormResidual(
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
),
Rearrange(
"b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
), # grid-like attention
PreNormResidual(
channel_num,
Attention(
dim=channel_num,
dim_head=channel_num // 4,
dropout=dropout,
window_size=window_size,
with_pe=with_pe,
),
),
Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
Conv_PreNormResidual(
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
),
# channel-like attention
Conv_PreNormResidual(
channel_num,
Channel_Attention_grid(
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
),
),
Conv_PreNormResidual(
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
),
)
def forward(self, x):
out = self.layer(x)
return out