578 lines
15 KiB
Python
578 lines
15 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
#############################################################
|
|
# File: OSA.py
|
|
# Created Date: Tuesday April 28th 2022
|
|
# Author: Chen Xuanhong
|
|
# Email: chenxuanhongzju@outlook.com
|
|
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
|
|
# Modified By: Chen Xuanhong
|
|
# Copyright (c) 2020 Shanghai Jiao Tong University
|
|
#############################################################
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from einops import rearrange, repeat
|
|
from einops.layers.torch import Rearrange, Reduce
|
|
from torch import einsum, nn
|
|
|
|
from .layernorm import LayerNorm2d
|
|
|
|
# helpers
|
|
|
|
|
|
def exists(val):
    """Return True when ``val`` is anything other than ``None``.

    Falsy-but-present values such as ``0``, ``""`` or ``[]`` still count as
    existing; only ``None`` does not.
    """
    if val is None:
        return False
    return True
|
|
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``.

    Presence is decided by ``exists``, so falsy values like ``0`` are kept
    rather than replaced by the fallback.
    """
    if exists(val):
        return val
    return d
|
|
|
|
|
|
def cast_tuple(val, length=1):
    """Return ``val`` unchanged if it is a tuple, else ``(val,)`` repeated ``length`` times.

    An existing tuple is returned as-is even when its length differs from
    ``length``; only non-tuple values are broadcast.
    """
    if isinstance(val, tuple):
        return val
    return (val,) * length
|
|
|
|
|
|
# helper classes
|
|
|
|
|
|
class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper computing ``fn(LayerNorm(x)) + x``.

    ``nn.LayerNorm(dim)`` normalizes over the last axis, so the input is
    expected to be channels-last with ``dim`` as its final size.

    Args:
        dim: Size of the input's last axis (passed to ``nn.LayerNorm``).
        fn: Module applied to the normalized input; its output must be
            addable to ``x``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        shortcut = x
        transformed = self.fn(self.norm(x))
        return transformed + shortcut
|
|
|
|
|
|
class Conv_PreNormResidual(nn.Module):
    """Pre-norm residual wrapper computing ``fn(LayerNorm2d(x)) + x``.

    ``LayerNorm2d`` is imported from the local ``.layernorm`` module;
    presumably it normalizes over the channel axis of a channels-first
    ``(B, C, H, W)`` tensor — NOTE(review): confirm against that module.

    Args:
        dim: Channel count passed to ``LayerNorm2d``.
        fn: Module applied to the normalized input; its output must be
            addable to ``x``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm2d(dim)
        self.fn = fn

    def forward(self, x):
        shortcut = x
        transformed = self.fn(self.norm(x))
        return transformed + shortcut
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP over the last axis.

    The layers are Linear(dim -> hidden), GELU, Dropout, Linear(hidden -> dim)
    and Dropout, with ``hidden = int(dim * mult)``.

    Args:
        dim: Input/output feature size (last axis).
        mult: Hidden-width multiplier.
        dropout: Dropout probability used after the activation and after
            the output projection.
    """

    def __init__(self, dim, mult=2, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
|
|
|
|
|
class Conv_FeedForward(nn.Module):
    """Channel MLP for channels-first maps, built from 1x1 convolutions.

    The layers are Conv1x1(dim -> hidden), GELU, Dropout, Conv1x1(hidden -> dim)
    and Dropout, with ``hidden = int(dim * mult)``. Spatial size is unchanged
    (stride 1, padding 0).

    Args:
        dim: Input/output channel count.
        mult: Hidden-width multiplier.
        dropout: Dropout probability after the activation and after the
            output projection.
    """

    def __init__(self, dim, mult=2, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            nn.Conv2d(dim, hidden, kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden, dim, kernel_size=1, stride=1, padding=0),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
|
|
|
|
|
class Gated_Conv_FeedForward(nn.Module):
    """GELU-gated feed-forward with a depthwise 3x3 conv for channels-first maps.

    A 1x1 conv expands to ``2 * hidden`` channels and a depthwise 3x3 conv
    mixes them spatially. The result is split in half along channels, the
    first half is passed through GELU and multiplied by the second half, and
    a 1x1 conv projects back to ``dim``. Here ``hidden = int(dim * mult)``.

    Args:
        dim: Input/output channel count.
        mult: Hidden-width multiplier.
        bias: Whether the three convolutions have bias terms.
        dropout: Accepted for interface compatibility; not used by this module.
    """

    def __init__(self, dim, mult=1, bias=False, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        doubled = 2 * hidden
        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(
            doubled,
            doubled,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=doubled,  # depthwise: one filter per channel
            bias=bias,
        )
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        expanded = self.dwconv(self.project_in(x))
        gate, value = expanded.chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)
|
|
|
|
|
|
# MBConv
|
|
|
|
|
|
class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation channel gating for ``(b, c, h, w)`` feature maps.

    Features are averaged over ``h`` and ``w`` into one vector per sample.
    The vector passes through a bias-free bottleneck MLP
    (``dim -> int(dim * shrinkage_rate) -> dim`` with SiLU in between) and is
    squashed to ``(0, 1)`` by a sigmoid. It is then broadcast back as a
    per-channel multiplier on the input.

    Args:
        dim: Channel count ``c``.
        shrinkage_rate: Bottleneck width as a fraction of ``dim``.
    """

    def __init__(self, dim, shrinkage_rate=0.25):
        super().__init__()
        bottleneck = int(dim * shrinkage_rate)
        squeeze = Reduce("b c h w -> b c", "mean")
        reduce_fc = nn.Linear(dim, bottleneck, bias=False)
        expand_fc = nn.Linear(bottleneck, dim, bias=False)
        self.gate = nn.Sequential(
            squeeze,
            reduce_fc,
            nn.SiLU(),
            expand_fc,
            nn.Sigmoid(),
            Rearrange("b c -> b c 1 1"),
        )

    def forward(self, x):
        channel_weights = self.gate(x)
        return x * channel_weights
|
|
|
|
|
|
class MBConvResidual(nn.Module):
    """Residual wrapper computing ``Dropsample(fn(x)) + x``.

    The branch output goes through ``Dropsample`` (active only in training
    mode) before the identity shortcut is added.

    Args:
        fn: Branch module; its output must be addable to ``x``.
        dropout: Drop probability handed to ``Dropsample``.
    """

    def __init__(self, fn, dropout=0.0):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        branch = self.dropsample(self.fn(x))
        return branch + x
|
|
|
|
|
|
class Dropsample(nn.Module):
    """Drop whole samples of a residual branch during training (stochastic depth).

    In training mode, each sample along the leading (batch) axis is kept
    with probability ``1 - prob`` and zeroed otherwise. Survivors are
    rescaled by ``1 / (1 - prob)`` so the expected value is unchanged. In
    eval mode, or when ``prob == 0``, the input is returned untouched.

    Args:
        prob: Probability of dropping a sample. Must be < 1, because
            ``prob == 1`` divides by zero.
    """

    def __init__(self, prob=0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        if self.prob == 0.0 or not self.training:
            return x

        # One uniform draw per sample, shaped (B, 1, ..., 1) so it broadcasts
        # over every non-batch axis. The previous
        # `torch.FloatTensor((B, 1, 1, 1), ...)` treated the tuple as *data*
        # (a 4-element vector), which broadcast along the last axis instead
        # of per sample and could not be created on non-CPU devices.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = torch.rand(mask_shape, device=x.device) > self.prob
        return x * keep_mask / (1 - self.prob)
|
|
|
|
|
|
def MBConv(
    dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
):
    """Build a normalization-free mobile inverted-bottleneck conv block.

    The layers are Conv1x1(dim_in -> hidden), GELU, depthwise Conv3x3
    (stride 2 when ``downsample`` else 1), GELU, SqueezeExcitation and
    Conv1x1(hidden -> dim_out), with ``hidden = int(expansion_rate * dim_out)``.

    When the block preserves shape (``dim_in == dim_out`` and no
    downsampling), it is wrapped in ``MBConvResidual`` with ``dropout`` as
    its drop probability. Otherwise the plain ``nn.Sequential`` is returned
    and ``dropout`` is unused.
    """
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    layers = [
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.GELU(),
        nn.Conv2d(
            hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
        ),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
    ]
    body = nn.Sequential(*layers)

    shape_preserving = dim_in == dim_out and not downsample
    if not shape_preserving:
        return body
    return MBConvResidual(body, dropout=dropout)
|
|
|
|
|
|
# attention related classes
|
|
class Attention(nn.Module):
    """Multi-head self-attention within windows, with optional relative position bias.

    ``forward`` expects a 6-D input indexed ``(b, x, y, w1, w2, d)``, as fixed
    by its ``rearrange`` pattern. Attention runs independently among the
    ``w1 * w2`` tokens of each ``(b, x, y)`` window, and the output has the
    same shape as the input.

    Args:
        dim: Channel size ``d``; must be divisible by ``dim_head``.
        dim_head: Channels per head; the head count is ``dim // dim_head``.
        dropout: Dropout applied to the attention weights and after the
            output projection.
        window_size: Window side length. The relative-position table has
            ``(2 * window_size - 1) ** 2`` entries and the index buffer
            covers ``window_size ** 2`` tokens, so with ``with_pe`` the
            input presumably needs ``w1 == w2 == window_size``.
        with_pe: Whether to add a learned relative position bias per head.
    """

    def __init__(
        self,
        dim,
        dim_head=32,
        dropout=0.0,
        window_size=7,
        with_pe=True,
    ):
        super().__init__()
        assert (
            dim % dim_head
        ) == 0, "dimension should be divisible by dimension per head"

        self.heads = dim // dim_head
        self.scale = dim_head**-0.5
        self.with_pe = with_pe

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
        )

        # relative positional bias
        if self.with_pe:
            self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
            self.register_buffer(
                "rel_pos_indices",
                self._relative_position_indices(window_size),
                persistent=False,
            )

    @staticmethod
    def _relative_position_indices(window_size):
        """Return a ``(ws*ws, ws*ws)`` tensor indexing ``rel_pos_bias`` for each token pair."""
        pos = torch.arange(window_size)
        # Explicit "ij" grid: identical to torch.meshgrid(pos, pos, indexing="ij")
        # but avoids the deprecated implicit-indexing call (UserWarning on
        # torch >= 1.10) without needing the `indexing` kwarg (absent < 1.10).
        rows = pos[:, None].expand(window_size, window_size)
        cols = pos[None, :].expand(window_size, window_size)
        grid = torch.stack((rows, cols))
        grid = rearrange(grid, "c i j -> (i j) c")
        rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
            grid, "j ... -> 1 j ..."
        )
        # Shift offsets from [-(ws-1), ws-1] to [0, 2*ws-2], then flatten the
        # 2-D offset into a single row-major table index.
        rel_pos += window_size - 1
        return (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim=-1)

    def forward(self, x):
        # Only the window-grid and in-window extents are needed to undo the flattening.
        _, height, width, window_height, window_width, _ = x.shape
        h = self.heads

        # flatten every window into its own "batch" row of tokens
        x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")

        # project for queries, keys, values and split heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b n (h d ) -> b h n d", h=h) for t in (q, k, v))

        q = q * self.scale
        sim = einsum("b h i d, b h j d -> b h i j", q, k)

        if self.with_pe:
            bias = self.rel_pos_bias(self.rel_pos_indices)
            sim = sim + rearrange(bias, "i j h -> h i j")

        attn = self.attend(sim)
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        # merge heads, project, and restore the window grid
        out = rearrange(
            out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
        )
        out = self.to_out(out)
        return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
|
|
|
|
|
|
class Block_Attention(nn.Module):
    """Windowed multi-head self-attention on channels-first maps.

    Queries, keys and values come from a 1x1 conv followed by a depthwise
    3x3 conv. Each is split into ``ps x ps`` non-overlapping windows
    (``ps = window_size``), and the pixels of each window attend to one
    another per head. The windows are then stitched back together and
    projected by a 1x1 conv. The ``rearrange`` patterns require H and W to
    be divisible by ``window_size``.

    Args:
        dim: Channel count; must be divisible by ``dim_head``.
        dim_head: Channels per head; the head count is ``dim // dim_head``.
        bias: Whether the convolutions have bias terms.
        dropout: Dropout applied to the attention weights.
        window_size: Window side length ``ps``.
        with_pe: Stored on the module but not used by ``forward``.
    """

    def __init__(
        self,
        dim,
        dim_head=32,
        bias=False,
        dropout=0.0,
        window_size=7,
        with_pe=True,
    ):
        super().__init__()
        assert (
            dim % dim_head
        ) == 0, "dimension should be divisible by dimension per head"

        self.heads = dim // dim_head
        self.ps = window_size
        self.scale = dim_head**-0.5
        self.with_pe = with_pe

        triple = dim * 3
        self.qkv = nn.Conv2d(dim, triple, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            triple,
            triple,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=triple,
            bias=bias,
        )
        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
        self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, height, width = x.shape
        ps = self.ps
        heads = self.heads

        def to_windows(t):
            # (b, heads*d, H, W) -> (b*nWin, heads, ps*ps, d)
            return rearrange(
                t,
                "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
                h=heads,
                w1=ps,
                w2=ps,
            )

        projected = self.qkv_dwconv(self.qkv(x))
        q, k, v = [to_windows(t) for t in projected.chunk(3, dim=1)]

        scores = einsum("b h i d, b h j d -> b h i j", q * self.scale, k)
        weights = self.attend(scores)
        mixed = einsum("b h i j, b h j d -> b h i d", weights, v)

        merged = rearrange(
            mixed,
            "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
            x=height // ps,
            y=width // ps,
            head=heads,
            w1=ps,
            w2=ps,
        )
        return self.to_out(merged)
|
|
|
|
|
|
class Channel_Attention(nn.Module):
    """Channel-wise (transposed) attention inside non-overlapping windows.

    The input is split into ``ps x ps`` windows (``ps = window_size``). Within
    each window and head, the ``d`` channels attend to one another. Their
    similarity is the cosine similarity of the ``ps * ps`` pixel values
    (q and k are L2-normalized over that axis), scaled by a learned per-head
    ``temperature``. The ``rearrange`` patterns require H and W to be
    divisible by ``window_size`` and ``dim`` to be divisible by ``heads``.

    Args:
        dim: Channel count.
        heads: Number of attention heads.
        bias: Whether the convolutions have bias terms.
        dropout: Accepted for interface compatibility; not used here.
        window_size: Window side length ``ps``.
    """

    def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
        super().__init__()
        self.heads = heads
        # learned per-head scale for the cosine-similarity logits
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
        self.ps = window_size

        triple = dim * 3
        self.qkv = nn.Conv2d(dim, triple, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            triple,
            triple,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=triple,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, height, width = x.shape
        ps, heads = self.ps, self.heads

        projected = self.qkv_dwconv(self.qkv(x))
        q, k, v = (
            rearrange(
                t,
                "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
                ph=ps,
                pw=ps,
                head=heads,
            )
            for t in projected.chunk(3, dim=1)
        )

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        logits = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        out = torch.matmul(logits.softmax(dim=-1), v)

        out = rearrange(
            out,
            "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
            h=height // ps,
            w=width // ps,
            ph=ps,
            pw=ps,
            head=heads,
        )
        return self.project_out(out)
|
|
|
|
|
|
class Channel_Attention_grid(nn.Module):
    """Channel-wise (transposed) attention over strided (grid) pixel groups.

    Pixels are grouped by their offset ``(ph, pw)`` inside the ``ps x ps``
    tiling (``ps = window_size``), so each group gathers one pixel from every
    tile. Within each group and head, the ``d`` channels attend to one
    another. Their similarity is the cosine similarity over the group's
    ``(H/ps) * (W/ps)`` pixels, scaled by a learned per-head ``temperature``.
    The ``rearrange`` patterns require H and W to be divisible by
    ``window_size`` and ``dim`` to be divisible by ``heads``.

    Args:
        dim: Channel count.
        heads: Number of attention heads.
        bias: Whether the convolutions have bias terms.
        dropout: Accepted for interface compatibility; not used here.
        window_size: Tile side length ``ps``.
    """

    def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
        super().__init__()
        self.heads = heads
        # learned per-head scale for the cosine-similarity logits
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
        self.ps = window_size

        triple = dim * 3
        self.qkv = nn.Conv2d(dim, triple, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            triple,
            triple,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=triple,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, height, width = x.shape
        ps, heads = self.ps, self.heads

        projected = self.qkv_dwconv(self.qkv(x))
        q, k, v = (
            rearrange(
                t,
                "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
                ph=ps,
                pw=ps,
                head=heads,
            )
            for t in projected.chunk(3, dim=1)
        )

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        logits = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        out = torch.matmul(logits.softmax(dim=-1), v)

        out = rearrange(
            out,
            "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
            h=height // ps,
            w=width // ps,
            ph=ps,
            pw=ps,
            head=heads,
        )
        return self.project_out(out)
|
|
|
|
|
|
class OSA_Block(nn.Module):
    """Omni self-attention block operating on ``(b, channel_num, H, W)`` maps.

    The stages of ``self.layer``, in order (indices are stable for
    state_dict keys), are:

    * 0: MBConv with expansion 1 and a residual shortcut.
    * 1-6: block-like window attention (``x w1`` split), a gated FFN,
      windowed channel attention and another gated FFN.
    * 7-12: grid-like attention (``w1 x`` split), a gated FFN, grid channel
      attention and another gated FFN.

    The window ``Rearrange`` patterns require H and W to be divisible by
    ``window_size``.

    Args:
        channel_num: Channel count (attention uses 4 heads,
            ``dim_head = channel_num // 4``).
        bias: Accepted but not used by this constructor.
        ffn_bias: Accepted but not used by this constructor.
        window_size: Window/tile side length.
        with_pe: Whether the spatial ``Attention`` layers add relative
            position bias.
        dropout: Dropout forwarded to the attention and FFN sub-modules.
    """

    def __init__(
        self,
        channel_num=64,
        bias=True,
        ffn_bias=True,
        window_size=8,
        with_pe=False,
        dropout=0.0,
    ):
        super().__init__()
        w = window_size

        def spatial_attention():
            return PreNormResidual(
                channel_num,
                Attention(
                    dim=channel_num,
                    dim_head=channel_num // 4,
                    dropout=dropout,
                    window_size=window_size,
                    with_pe=with_pe,
                ),
            )

        def gated_ffn():
            return Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            )

        def channel_attention(attention_cls):
            return Conv_PreNormResidual(
                channel_num,
                attention_cls(
                    dim=channel_num, heads=4, dropout=dropout, window_size=window_size
                ),
            )

        # Sub-modules are created in list order, matching the original
        # construction order (and hence parameter initialization order).
        stages = [
            MBConv(
                channel_num,
                channel_num,
                downsample=False,
                expansion_rate=1,
                shrinkage_rate=0.25,
            ),
            # block-like attention
            Rearrange("b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w),
            spatial_attention(),
            Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
            gated_ffn(),
            # channel-like attention
            channel_attention(Channel_Attention),
            gated_ffn(),
            # grid-like attention
            Rearrange("b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w),
            spatial_attention(),
            Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
            gated_ffn(),
            # channel-like attention
            channel_attention(Channel_Attention_grid),
            gated_ffn(),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        return self.layer(x)
|