Pull in latest upscale model code from chaiNNer.

parent c00bb1a0b7
commit 7310290f17
@@ -0,0 +1,110 @@
import math

import torch.nn as nn


class CA_layer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CA_layer, self).__init__()
        # global average pooling
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False),
            nn.GELU(),
            nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False),
            # nn.Sigmoid()
        )

    def forward(self, x):
        y = self.fc(self.gap(x))
        return x * y.expand_as(x)


class Simple_CA_layer(nn.Module):
    def __init__(self, channel):
        super(Simple_CA_layer, self).__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(
            in_channels=channel,
            out_channels=channel,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )

    def forward(self, x):
        return x * self.fc(self.gap(x))


class ECA_layer(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """

    def __init__(self, channel):
        super(ECA_layer, self).__init__()

        b = 1
        gamma = 2
        k_size = int(abs(math.log(channel, 2) + b) / gamma)
        k_size = k_size if k_size % 2 else k_size + 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        # self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        # b, c, h, w = x.size()

        # feature descriptor on the global spatial information
        y = self.avg_pool(x)

        # Two different branches of ECA module
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # Multi-scale information fusion
        # y = self.sigmoid(y)

        return x * y.expand_as(x)


class ECA_MaxPool_layer(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """

    def __init__(self, channel):
        super(ECA_MaxPool_layer, self).__init__()

        b = 1
        gamma = 2
        k_size = int(abs(math.log(channel, 2) + b) / gamma)
        k_size = k_size if k_size % 2 else k_size + 1
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        # self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        # b, c, h, w = x.size()

        # feature descriptor on the global spatial information
        y = self.max_pool(x)

        # Two different branches of ECA module
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # Multi-scale information fusion
        # y = self.sigmoid(y)

        return x * y.expand_as(x)
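A small aside, not part of the commit: how the adaptive kernel size in ECA_layer works out for a couple of common channel counts, plus a shape check. The import path below is hypothetical.

import math

import torch

from channel_attention import ECA_layer  # hypothetical module path for the file above


def eca_kernel_size(channel, b=1, gamma=2):
    # mirrors ECA_layer.__init__: k = |log2(C) + b| / gamma, then forced to be odd
    k = int(abs(math.log(channel, 2) + b) / gamma)
    return k if k % 2 else k + 1


print(eca_kernel_size(64))   # (6 + 1) / 2 = 3.5 -> 3 (already odd)
print(eca_kernel_size(256))  # (8 + 1) / 2 = 4.5 -> 4 -> 5 (bumped to odd)

layer = ECA_layer(channel=64)
x = torch.rand(2, 64, 32, 32)
print(layer(x).shape)  # torch.Size([2, 64, 32, 32]); channels reweighted, shape unchanged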
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
@@ -0,0 +1,577 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OSA.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
from torch import einsum, nn

from .layernorm import LayerNorm2d

# helpers


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cast_tuple(val, length=1):
    return val if isinstance(val, tuple) else ((val,) * length)


# helper classes


class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x)) + x


class Conv_PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm2d(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x)) + x


class FeedForward(nn.Module):
    def __init__(self, dim, mult=2, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Conv_FeedForward(nn.Module):
    def __init__(self, dim, mult=2, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Conv2d(dim, inner_dim, 1, 1, 0),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(inner_dim, dim, 1, 1, 0),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Gated_Conv_FeedForward(nn.Module):
    def __init__(self, dim, mult=1, bias=False, dropout=0.0):
        super().__init__()

        hidden_features = int(dim * mult)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features * 2,
            bias=bias,
        )

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x


# MBConv


class SqueezeExcitation(nn.Module):
    def __init__(self, dim, shrinkage_rate=0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        self.gate = nn.Sequential(
            Reduce("b c h w -> b c", "mean"),
            nn.Linear(dim, hidden_dim, bias=False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias=False),
            nn.Sigmoid(),
            Rearrange("b c -> b c 1 1"),
        )

    def forward(self, x):
        return x * self.gate(x)


class MBConvResidual(nn.Module):
    def __init__(self, fn, dropout=0.0):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out + x


class Dropsample(nn.Module):
    def __init__(self, prob=0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        device = x.device

        if self.prob == 0.0 or (not self.training):
            return x

        # random keep mask with shape (b, 1, 1, 1), built on the input's device
        keep_mask = (
            torch.empty(x.shape[0], 1, 1, 1, device=device).uniform_() > self.prob
        )
        return x * keep_mask / (1 - self.prob)


def MBConv(
    dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
):
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    net = nn.Sequential(
        nn.Conv2d(dim_in, hidden_dim, 1),
        # nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        nn.Conv2d(
            hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
        ),
        # nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        # nn.BatchNorm2d(dim_out)
    )

    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout=dropout)

    return net


# attention related classes
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=32,
        dropout=0.0,
        window_size=7,
        with_pe=True,
    ):
        super().__init__()
        assert (
            dim % dim_head
        ) == 0, "dimension should be divisible by dimension per head"

        self.heads = dim // dim_head
        self.scale = dim_head**-0.5
        self.with_pe = with_pe

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
        )

        # relative positional bias
        if self.with_pe:
            self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

            pos = torch.arange(window_size)
            grid = torch.stack(torch.meshgrid(pos, pos))
            grid = rearrange(grid, "c i j -> (i j) c")
            rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
                grid, "j ... -> 1 j ..."
            )
            rel_pos += window_size - 1
            rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
                dim=-1
            )

            self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)

    def forward(self, x):
        batch, height, width, window_height, window_width, _, device, h = (
            *x.shape,
            x.device,
            self.heads,
        )

        # flatten

        x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")

        # project for queries, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        # scale

        q = q * self.scale

        # sim

        sim = einsum("b h i d, b h j d -> b h i j", q, k)

        # add positional bias
        if self.with_pe:
            bias = self.rel_pos_bias(self.rel_pos_indices)
            sim = sim + rearrange(bias, "i j h -> h i j")

        # attention

        attn = self.attend(sim)

        # aggregate

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        # merge heads

        out = rearrange(
            out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
        )

        # combine heads out

        out = self.to_out(out)
        return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)


class Block_Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=32,
        bias=False,
        dropout=0.0,
        window_size=7,
        with_pe=True,
    ):
        super().__init__()
        assert (
            dim % dim_head
        ) == 0, "dimension should be divisible by dimension per head"

        self.heads = dim // dim_head
        self.ps = window_size
        self.scale = dim_head**-0.5
        self.with_pe = with_pe

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3,
            dim * 3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 3,
            bias=bias,
        )

        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))

        self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        # project for queries, keys, values
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # split heads

        q, k, v = map(
            lambda t: rearrange(
                t,
                "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
                h=self.heads,
                w1=self.ps,
                w2=self.ps,
            ),
            (q, k, v),
        )

        # scale

        q = q * self.scale

        # sim

        sim = einsum("b h i d, b h j d -> b h i j", q, k)

        # attention
        attn = self.attend(sim)

        # aggregate

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        # merge heads
        out = rearrange(
            out,
            "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
            x=h // self.ps,
            y=w // self.ps,
            head=self.heads,
            w1=self.ps,
            w2=self.ps,
        )

        out = self.to_out(out)
        return out


class Channel_Attention(nn.Module):
    def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
        super(Channel_Attention, self).__init__()
        self.heads = heads

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.ps = window_size

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3,
            dim * 3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 3,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        qkv = qkv.chunk(3, dim=1)

        q, k, v = map(
            lambda t: rearrange(
                t,
                "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
                ph=self.ps,
                pw=self.ps,
                head=self.heads,
            ),
            qkv,
        )

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        out = attn @ v

        out = rearrange(
            out,
            "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
            h=h // self.ps,
            w=w // self.ps,
            ph=self.ps,
            pw=self.ps,
            head=self.heads,
        )

        out = self.project_out(out)

        return out


class Channel_Attention_grid(nn.Module):
    def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
        super(Channel_Attention_grid, self).__init__()
        self.heads = heads

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.ps = window_size

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3,
            dim * 3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 3,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        qkv = qkv.chunk(3, dim=1)

        q, k, v = map(
            lambda t: rearrange(
                t,
                "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
                ph=self.ps,
                pw=self.ps,
                head=self.heads,
            ),
            qkv,
        )

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        out = attn @ v

        out = rearrange(
            out,
            "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
            h=h // self.ps,
            w=w // self.ps,
            ph=self.ps,
            pw=self.ps,
            head=self.heads,
        )

        out = self.project_out(out)

        return out


class OSA_Block(nn.Module):
    def __init__(
        self,
        channel_num=64,
        bias=True,
        ffn_bias=True,
        window_size=8,
        with_pe=False,
        dropout=0.0,
    ):
        super(OSA_Block, self).__init__()

        w = window_size

        self.layer = nn.Sequential(
            MBConv(
                channel_num,
                channel_num,
                downsample=False,
                expansion_rate=1,
                shrinkage_rate=0.25,
            ),
            Rearrange(
                "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
            ),  # block-like attention
            PreNormResidual(
                channel_num,
                Attention(
                    dim=channel_num,
                    dim_head=channel_num // 4,
                    dropout=dropout,
                    window_size=window_size,
                    with_pe=with_pe,
                ),
            ),
            Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
            # channel-like attention
            Conv_PreNormResidual(
                channel_num,
                Channel_Attention(
                    dim=channel_num, heads=4, dropout=dropout, window_size=window_size
                ),
            ),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
            Rearrange(
                "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
            ),  # grid-like attention
            PreNormResidual(
                channel_num,
                Attention(
                    dim=channel_num,
                    dim_head=channel_num // 4,
                    dropout=dropout,
                    window_size=window_size,
                    with_pe=with_pe,
                ),
            ),
            Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
            # channel-like attention
            Conv_PreNormResidual(
                channel_num,
                Channel_Attention_grid(
                    dim=channel_num, heads=4, dropout=dropout, window_size=window_size
                ),
            ),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
        )

    def forward(self, x):
        out = self.layer(x)
        return out
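A sanity check, not from the commit, of the two windowings OSA_Block alternates between: block-like windows of neighbouring pixels and grid-like windows of pixels spaced a stride apart, each undone by its matching inverse Rearrange.

import torch
from einops import rearrange

x = torch.randn(2, 64, 16, 16)  # (b, d, H, W); H and W must be divisible by the window size
w = 8

blocks = rearrange(x, "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w)  # contiguous w x w patches
grid = rearrange(x, "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w)  # pixels sampled H // w apart
print(blocks.shape, grid.shape)  # both torch.Size([2, 2, 2, 8, 8, 64])

# the matching Rearrange layers in OSA_Block restore the original tensor exactly
assert torch.equal(rearrange(blocks, "b x y w1 w2 d -> b d (x w1) (y w2)"), x)
assert torch.equal(rearrange(grid, "b x y w1 w2 d -> b d (w1 x) (w2 y)"), x)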
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OSAG.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 23rd April 2023 3:08:49 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################


import torch.nn as nn

from .esa import ESA
from .OSA import OSA_Block


class OSAG(nn.Module):
    def __init__(
        self,
        channel_num=64,
        bias=True,
        block_num=4,
        ffn_bias=False,
        window_size=0,
        pe=False,
    ):
        super(OSAG, self).__init__()

        # print("window_size: %d" % (window_size))
        # print("with_pe", pe)
        # print("ffn_bias: %d" % (ffn_bias))

        # block_script_name = kwargs.get("block_script_name", "OSA")
        # block_class_name = kwargs.get("block_class_name", "OSA_Block")

        # script_name = "." + block_script_name
        # package = __import__(script_name, fromlist=True)
        block_class = OSA_Block  # getattr(package, block_class_name)
        group_list = []
        for _ in range(block_num):
            temp_res = block_class(
                channel_num,
                bias,
                ffn_bias=ffn_bias,
                window_size=window_size,
                with_pe=pe,
            )
            group_list.append(temp_res)
        group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
        self.residual_layer = nn.Sequential(*group_list)
        esa_channel = max(channel_num // 4, 16)
        self.esa = ESA(esa_channel, channel_num)

    def forward(self, x):
        out = self.residual_layer(x)
        out = out + x
        return self.esa(out)
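A minimal smoke test, not part of the commit, assuming these files sit together in a package; the archs.omnisr path below is made up.

import torch

from archs.omnisr.OSAG import OSAG  # hypothetical package path

group = OSAG(channel_num=64, block_num=1, window_size=8, pe=False).eval()
x = torch.rand(1, 64, 32, 32)  # spatial dims must be multiples of window_size
with torch.no_grad():
    y = group(x)
print(y.shape)  # torch.Size([1, 64, 32, 32]); OSAG keeps the resolution and channel count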
@@ -0,0 +1,133 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OmniSR.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 23rd April 2023 3:06:36 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .OSAG import OSAG
from .pixelshuffle import pixelshuffle_block


class OmniSR(nn.Module):
    def __init__(
        self,
        state_dict,
        **kwargs,
    ):
        super(OmniSR, self).__init__()
        self.state = state_dict

        bias = True  # Fine to assume this for now
        block_num = 1  # Fine to assume this for now
        ffn_bias = True
        pe = True

        num_feat = state_dict["input.weight"].shape[0] or 64
        num_in_ch = state_dict["input.weight"].shape[1] or 3
        num_out_ch = num_in_ch  # we can just assume this for now. pixelshuffle smh

        pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
        up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
        if up_scale - int(up_scale) > 0:
            print(
                "out_nc is probably different than in_nc, scale calculation might be wrong"
            )
        up_scale = int(up_scale)
        res_num = 0
        for key in state_dict.keys():
            if "residual_layer" in key:
                temp_res_num = int(key.split(".")[1])
                if temp_res_num > res_num:
                    res_num = temp_res_num
        res_num = res_num + 1  # zero-indexed

        residual_layer = []
        self.res_num = res_num

        self.window_size = 8  # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer)
        self.up_scale = up_scale

        for _ in range(res_num):
            temp_res = OSAG(
                channel_num=num_feat,
                bias=bias,
                block_num=block_num,
                ffn_bias=ffn_bias,
                window_size=self.window_size,
                pe=pe,
            )
            residual_layer.append(temp_res)
        self.residual_layer = nn.Sequential(*residual_layer)
        self.input = nn.Conv2d(
            in_channels=num_in_ch,
            out_channels=num_feat,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        self.output = nn.Conv2d(
            in_channels=num_feat,
            out_channels=num_feat,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)

        # self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias)

        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        #         m.weight.data.normal_(0, sqrt(2. / n))

        # chaiNNer specific stuff
        self.model_arch = "OmniSR"
        self.sub_type = "SR"
        self.in_nc = num_in_ch
        self.out_nc = num_out_ch
        self.num_feat = num_feat
        self.scale = up_scale

        self.supports_fp16 = True  # TODO: Test this
        self.supports_bfp16 = True
        self.min_size_restriction = 16

        self.load_state_dict(state_dict, strict=False)

    def check_image_size(self, x):
        _, _, h, w = x.size()
        # import pdb; pdb.set_trace()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        # x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)
        return x

    def forward(self, x):
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        residual = self.input(x)
        out = self.residual_layer(residual)

        # origin
        out = torch.add(self.output(out), residual)
        out = self.up(out)

        out = out[:, :, : H * self.up_scale, : W * self.up_scale]
        return out
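A hedged usage sketch, not part of the commit, of how the new OmniSR class consumes a checkpoint; the package path and file name are hypothetical, and the file is assumed to hold a plain state dict. The scale is read off the pixelshuffle conv: for a 3-channel x4 model, up.0.weight has 3 * 4**2 = 48 output channels, so up_scale = sqrt(48 / 3) = 4.

import torch

from archs.omnisr.OmniSR import OmniSR  # hypothetical package path

state = torch.load("omnisr_x4.pth", map_location="cpu")  # hypothetical checkpoint file
model = OmniSR(state).eval()
print(model.in_nc, model.num_feat, model.scale)  # e.g. 3 64 4

lr = torch.rand(1, model.in_nc, 60, 50)  # check_image_size pads both dims up to multiples of 8
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # torch.Size([1, 3, 240, 200]) for a x4 model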
@@ -0,0 +1,294 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: esa.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 20th April 2023 9:28:06 am
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################

import torch
import torch.nn as nn
import torch.nn.functional as F

from .layernorm import LayerNorm2d


def moment(x, dim=(2, 3), k=2):
    assert len(x.size()) == 4
    mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1)
    mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim)
    return mk


class ESA(nn.Module):
    """
    Modification of Enhanced Spatial Attention (ESA), which is proposed by
    `Residual Feature Aggregation Network for Image Super-Resolution`
    Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes
    are deleted.
    """

    def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
        super(ESA, self).__init__()
        f = esa_channels
        self.conv1 = conv(n_feats, f, kernel_size=1)
        self.conv_f = conv(f, f, kernel_size=1)
        self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
        self.conv3 = conv(f, f, kernel_size=3, padding=1)
        self.conv4 = conv(f, n_feats, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        c1_ = self.conv1(x)
        c1 = self.conv2(c1_)
        v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
        c3 = self.conv3(v_max)
        c3 = F.interpolate(
            c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
        )
        cf = self.conv_f(c1_)
        c4 = self.conv4(c3 + cf)
        m = self.sigmoid(c4)
        return x * m


class LK_ESA(nn.Module):
    def __init__(
        self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
    ):
        super(LK_ESA, self).__init__()
        f = esa_channels
        self.conv1 = conv(n_feats, f, kernel_size=1)
        self.conv_f = conv(f, f, kernel_size=1)

        kernel_size = 17
        kernel_expand = kernel_expand
        padding = kernel_size // 2

        self.vec_conv = nn.Conv2d(
            in_channels=f * kernel_expand,
            out_channels=f * kernel_expand,
            kernel_size=(1, kernel_size),
            padding=(0, padding),
            groups=2,
            bias=bias,
        )
        self.vec_conv3x1 = nn.Conv2d(
            in_channels=f * kernel_expand,
            out_channels=f * kernel_expand,
            kernel_size=(1, 3),
            padding=(0, 1),
            groups=2,
            bias=bias,
        )

        self.hor_conv = nn.Conv2d(
            in_channels=f * kernel_expand,
            out_channels=f * kernel_expand,
            kernel_size=(kernel_size, 1),
            padding=(padding, 0),
            groups=2,
            bias=bias,
        )
        self.hor_conv1x3 = nn.Conv2d(
            in_channels=f * kernel_expand,
            out_channels=f * kernel_expand,
            kernel_size=(3, 1),
            padding=(1, 0),
            groups=2,
            bias=bias,
        )

        self.conv4 = conv(f, n_feats, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        c1_ = self.conv1(x)

        res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
        res = self.hor_conv(res) + self.hor_conv1x3(res)

        cf = self.conv_f(c1_)
        c4 = self.conv4(res + cf)
        m = self.sigmoid(c4)
        return x * m


class LK_ESA_LN(nn.Module):
    def __init__(
        self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
    ):
        super(LK_ESA_LN, self).__init__()
        f = esa_channels
        self.conv1 = conv(n_feats, f, kernel_size=1)
        self.conv_f = conv(f, f, kernel_size=1)

        kernel_size = 17
        kernel_expand = kernel_expand
        padding = kernel_size // 2

        self.norm = LayerNorm2d(n_feats)

        self.vec_conv = nn.Conv2d(
            in_channels=f * kernel_expand,
            out_channels=f * kernel_expand,
            kernel_size=(1, kernel_size),
            padding=(0, padding),
            groups=2,
            bias=bias,
        )
        self.vec_conv3x1 = nn.Conv2d(
            in_channels=f * kernel_expand,
            out_channels=f * kernel_expand,
            kernel_size=(1, 3),
            padding=(0, 1),
            groups=2,
            bias=bias,
        )

        self.hor_conv = nn.Conv2d(
            in_channels=f * kernel_expand,
            out_channels=f * kernel_expand,
            kernel_size=(kernel_size, 1),
            padding=(padding, 0),
            groups=2,
            bias=bias,
        )
        self.hor_conv1x3 = nn.Conv2d(
            in_channels=f * kernel_expand,
            out_channels=f * kernel_expand,
            kernel_size=(3, 1),
            padding=(1, 0),
            groups=2,
            bias=bias,
        )

        self.conv4 = conv(f, n_feats, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        c1_ = self.norm(x)
        c1_ = self.conv1(c1_)

        res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
        res = self.hor_conv(res) + self.hor_conv1x3(res)

        cf = self.conv_f(c1_)
        c4 = self.conv4(res + cf)
        m = self.sigmoid(c4)
        return x * m


class AdaGuidedFilter(nn.Module):
    def __init__(
        self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
    ):
        super(AdaGuidedFilter, self).__init__()

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(
            in_channels=n_feats,
            out_channels=1,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )

        self.r = 5

    def box_filter(self, x, r):
        channel = x.shape[1]
        kernel_size = 2 * r + 1
        weight = 1.0 / (kernel_size**2)
        box_kernel = weight * torch.ones(
            (channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device
        )
        output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)
        return output

    def forward(self, x):
        _, _, H, W = x.shape
        N = self.box_filter(
            torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r
        )

        # epsilon = self.fc(self.gap(x))
        # epsilon = torch.pow(epsilon, 2)
        epsilon = 1e-2

        mean_x = self.box_filter(x, self.r) / N
        var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x

        A = var_x / (var_x + epsilon)
        b = (1 - A) * mean_x
        m = A * x + b

        # mean_A = self.box_filter(A, self.r) / N
        # mean_b = self.box_filter(b, self.r) / N
        # m = mean_A * x + mean_b
        return x * m


class AdaConvGuidedFilter(nn.Module):
    def __init__(
        self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
    ):
        super(AdaConvGuidedFilter, self).__init__()
        f = esa_channels

        self.conv_f = conv(f, f, kernel_size=1)

        kernel_size = 17
        kernel_expand = kernel_expand
        padding = kernel_size // 2

        self.vec_conv = nn.Conv2d(
            in_channels=f,
            out_channels=f,
            kernel_size=(1, kernel_size),
            padding=(0, padding),
            groups=f,
            bias=bias,
        )

        self.hor_conv = nn.Conv2d(
            in_channels=f,
            out_channels=f,
            kernel_size=(kernel_size, 1),
            padding=(padding, 0),
            groups=f,
            bias=bias,
        )

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(
            in_channels=f,
            out_channels=f,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )

    def forward(self, x):
        y = self.vec_conv(x)
        y = self.hor_conv(y)

        sigma = torch.pow(y, 2)
        epsilon = self.fc(self.gap(y))

        weight = sigma / (sigma + epsilon)

        m = weight * x + (1 - weight)

        return x * m
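A quick check, not from the commit, that the ESA branch (strided conv, max pool, then bilinear upsample) always lands back on the input resolution, so the sigmoid mask can gate x elementwise; the import path is hypothetical.

import torch

from archs.omnisr.esa import ESA  # hypothetical package path

esa = ESA(esa_channels=16, n_feats=64).eval()  # 16 = max(64 // 4, 16), as OSAG computes it
x = torch.rand(1, 64, 40, 40)
with torch.no_grad():
    out = esa(x)
print(out.shape)  # torch.Size([1, 64, 40, 40]); a per-pixel gate applied to x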
@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: layernorm.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 20th April 2023 9:28:20 am
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################

import torch
import torch.nn as nn


class LayerNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_variables
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return (
            gx,
            (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
            grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
            None,
        )


class LayerNorm2d(nn.Module):
    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
        self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)


class GRN(nn.Module):
    """GRN (Global Response Normalization) layer"""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x
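A quick equivalence check, not part of the commit: LayerNorm2d normalizes over the channel dimension at every spatial position, so it should match F.layer_norm applied to a channels-last view of the same tensor; the import path is hypothetical.

import torch
import torch.nn.functional as F

from archs.omnisr.layernorm import LayerNorm2d  # hypothetical package path

x = torch.rand(2, 8, 5, 5)
ln = LayerNorm2d(8, eps=1e-6)

ref = F.layer_norm(
    x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias, eps=1e-6
).permute(0, 3, 1, 2)
print(torch.allclose(ln(x), ref, atol=1e-5))  # True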
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: pixelshuffle.py
# Created Date: Friday July 1st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 1st July 2022 10:18:39 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################

import torch.nn as nn


def pixelshuffle_block(
    in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
):
    """
    Upsample features according to `upscale_factor`.
    """
    padding = kernel_size // 2
    conv = nn.Conv2d(
        in_channels,
        out_channels * (upscale_factor**2),
        kernel_size,
        padding=1,
        bias=bias,
    )
    pixel_shuffle = nn.PixelShuffle(upscale_factor)
    return nn.Sequential(*[conv, pixel_shuffle])
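The shape arithmetic behind pixelshuffle_block, as a short check rather than anything from the commit: the conv expands channels to out_channels * upscale_factor**2 and nn.PixelShuffle folds that factor back into the spatial dims; the import path is hypothetical.

import torch

from archs.omnisr.pixelshuffle import pixelshuffle_block  # hypothetical package path

up = pixelshuffle_block(in_channels=64, out_channels=3, upscale_factor=4)
x = torch.rand(1, 64, 32, 32)
print(up(x).shape)  # torch.Size([1, 3, 128, 128]): 64 -> 3 * 4**2 = 48 channels, then shuffled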
@@ -79,6 +79,12 @@ class RRDBNet(nn.Module):
         self.scale: int = self.get_scale()
         self.num_filters: int = self.state[self.key_arr[0]].shape[0]
 
+        c2x2 = False
+        if self.state["model.0.weight"].shape[-2] == 2:
+            c2x2 = True
+            self.scale = round(math.sqrt(self.scale / 4))
+            self.model_arch = "ESRGAN-2c2"
+
         self.supports_fp16 = True
         self.supports_bfp16 = True
         self.min_size_restriction = None
@@ -105,11 +111,15 @@ class RRDBNet(nn.Module):
                 out_nc=self.num_filters,
                 upscale_factor=3,
                 act_type=self.act,
+                c2x2=c2x2,
             )
         else:
             upsample_blocks = [
                 upsample_block(
-                    in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act
+                    in_nc=self.num_filters,
+                    out_nc=self.num_filters,
+                    act_type=self.act,
+                    c2x2=c2x2,
                 )
                 for _ in range(int(math.log(self.scale, 2)))
             ]
@@ -122,6 +132,7 @@ class RRDBNet(nn.Module):
                 kernel_size=3,
                 norm_type=None,
                 act_type=None,
+                c2x2=c2x2,
             ),
             B.ShortcutBlock(
                 B.sequential(
@@ -138,6 +149,7 @@ class RRDBNet(nn.Module):
                             act_type=self.act,
                             mode="CNA",
                             plus=self.plus,
+                            c2x2=c2x2,
                         )
                         for _ in range(self.num_blocks)
                     ],
@@ -149,6 +161,7 @@ class RRDBNet(nn.Module):
                         norm_type=self.norm,
                         act_type=None,
                         mode=self.mode,
+                        c2x2=c2x2,
                     ),
                 )
             ),
@@ -160,6 +173,7 @@ class RRDBNet(nn.Module):
                 kernel_size=3,
                 norm_type=None,
                 act_type=self.act,
+                c2x2=c2x2,
             ),
             # hr_conv1
             B.conv_block(
@@ -168,6 +182,7 @@ class RRDBNet(nn.Module):
                 kernel_size=3,
                 norm_type=None,
                 act_type=None,
+                c2x2=c2x2,
             ),
         )
@@ -141,6 +141,19 @@ def sequential(*args):
 ConvMode = Literal["CNA", "NAC", "CNAC"]
 
 
+# 2x2x2 Conv Block
+def conv_block_2c2(
+    in_nc,
+    out_nc,
+    act_type="relu",
+):
+    return sequential(
+        nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
+        nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
+        act(act_type) if act_type else None,
+    )
+
+
 def conv_block(
     in_nc: int,
     out_nc: int,
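The two 2x2 convolutions in conv_block_2c2 (padding 1, then padding 0) leave the spatial size unchanged, matching a single 3x3 convolution with padding 1. A quick standalone shape check (illustration only; assumes PyTorch):

import torch
import torch.nn as nn

x = torch.randn(1, 8, 16, 16)

# padding 1 then 0, as in conv_block_2c2 above: 16 -> 17 -> 16
pair_2x2 = nn.Sequential(
    nn.Conv2d(8, 8, kernel_size=2, padding=1),
    nn.Conv2d(8, 8, kernel_size=2, padding=0),
)
single_3x3 = nn.Conv2d(8, 8, kernel_size=3, padding=1)

print(pair_2x2(x).shape, single_3x3(x).shape)  # both torch.Size([1, 8, 16, 16])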
@@ -153,12 +166,17 @@ def conv_block(
     norm_type: str | None = None,
     act_type: str | None = "relu",
     mode: ConvMode = "CNA",
+    c2x2=False,
 ):
     """
     Conv layer with padding, normalization, activation
     mode: CNA --> Conv -> Norm -> Act
         NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
     """
+
+    if c2x2:
+        return conv_block_2c2(in_nc, out_nc, act_type=act_type)
+
     assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
     padding = get_valid_padding(kernel_size, dilation)
     p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
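With that early return, any caller that passes c2x2=True to conv_block gets the two-conv variant and the remaining kernel/padding/norm arguments are simply not used. A simplified standalone sketch of the dispatch (illustration only; the real conv_block also handles padding modes, normalization, and CNA/NAC ordering):

import torch.nn as nn


def conv_block_2c2(in_nc, out_nc, act_type="relu"):
    # simplified: two 2x2 convs plus an optional activation
    layers = [
        nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
        nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
    ]
    if act_type:
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)


def conv_block(in_nc, out_nc, kernel_size=3, act_type="relu", c2x2=False):
    if c2x2:
        # new early return: ignore kernel_size and build the 2x2 pair instead
        return conv_block_2c2(in_nc, out_nc, act_type=act_type)
    layers = [nn.Conv2d(in_nc, out_nc, kernel_size, padding=kernel_size // 2)]
    if act_type:
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)


print(conv_block(3, 64, c2x2=True))   # two 2x2 convs + ReLU
print(conv_block(3, 64, c2x2=False))  # one 3x3 conv + ReLU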
@@ -285,6 +303,7 @@ class RRDB(nn.Module):
         _convtype="Conv2D",
         _spectral_norm=False,
         plus=False,
+        c2x2=False,
     ):
         super(RRDB, self).__init__()
         self.RDB1 = ResidualDenseBlock_5C(
@@ -298,6 +317,7 @@ class RRDB(nn.Module):
             act_type,
             mode,
             plus=plus,
+            c2x2=c2x2,
         )
         self.RDB2 = ResidualDenseBlock_5C(
             nf,
@@ -310,6 +330,7 @@ class RRDB(nn.Module):
             act_type,
             mode,
             plus=plus,
+            c2x2=c2x2,
         )
         self.RDB3 = ResidualDenseBlock_5C(
             nf,
@@ -322,6 +343,7 @@ class RRDB(nn.Module):
             act_type,
             mode,
             plus=plus,
+            c2x2=c2x2,
         )
 
     def forward(self, x):
@@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module):
         act_type="leakyrelu",
         mode: ConvMode = "CNA",
         plus=False,
+        c2x2=False,
     ):
         super(ResidualDenseBlock_5C, self).__init__()
 
@@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module):
             norm_type=norm_type,
             act_type=act_type,
             mode=mode,
+            c2x2=c2x2,
         )
         self.conv2 = conv_block(
             nf + gc,
@@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module):
             norm_type=norm_type,
             act_type=act_type,
             mode=mode,
+            c2x2=c2x2,
         )
         self.conv3 = conv_block(
             nf + 2 * gc,
@@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module):
             norm_type=norm_type,
             act_type=act_type,
             mode=mode,
+            c2x2=c2x2,
         )
         self.conv4 = conv_block(
             nf + 3 * gc,
@@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module):
             norm_type=norm_type,
             act_type=act_type,
             mode=mode,
+            c2x2=c2x2,
         )
         if mode == "CNA":
             last_act = None
@@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module):
             norm_type=norm_type,
             act_type=last_act,
             mode=mode,
+            c2x2=c2x2,
         )
 
     def forward(self, x):
@@ -499,6 +527,7 @@ def upconv_block(
     norm_type: str | None = None,
     act_type="relu",
     mode="nearest",
+    c2x2=False,
 ):
     # Up conv
     # described in https://distill.pub/2016/deconv-checkerboard/
@@ -512,5 +541,6 @@ def upconv_block(
         pad_type=pad_type,
         norm_type=norm_type,
         act_type=act_type,
+        c2x2=c2x2,
     )
     return sequential(upsample, conv)
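As the distill.pub comment in the context notes, upconv_block resizes first and convolves afterwards to avoid checkerboard artifacts; the c2x2 flag only changes which conv block follows the resize. A standalone sketch of the underlying pattern (illustration only; assumes PyTorch):

import torch
import torch.nn as nn

# nearest-neighbour resize followed by a conv, the pattern upconv_block wraps
up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="nearest"),
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
)
print(up(torch.randn(1, 64, 24, 24)).shape)  # torch.Size([1, 64, 48, 48])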
@@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
 from .architecture.HAT import HAT
 from .architecture.LaMa import LaMa
 from .architecture.MAT import MAT
+from .architecture.OmniSR.OmniSR import OmniSR
 from .architecture.RRDB import RRDBNet as ESRGAN
 from .architecture.SPSR import SPSRNet as SPSR
 from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
@@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel:
         state_dict = state_dict["params"]
+
     state_dict_keys = list(state_dict.keys())
 
     # SRVGGNet Real-ESRGAN (v2)
     if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
         model = RealESRGANv2(state_dict)
@@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel:
     # MAT
     elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys:
         model = MAT(state_dict)
+    # Omni-SR
+    elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
+        model = OmniSR(state_dict)
     # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
     else:
         try:
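Detection here is purely key-based: the new branch fires when a parameter name characteristic of OmniSR checkpoints is present, before falling through to the generic ESRGAN path. A standalone illustration with a hypothetical helper name (not part of this commit):

def detect_arch(state_dict_keys):
    # hypothetical helper; mirrors the key test added above
    if "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
        return "Omni-SR"
    return "ESRGAN (fallback)"


print(detect_arch(["residual_layer.0.residual_layer.0.layer.0.fn.0.weight"]))  # Omni-SR
print(detect_arch(["model.0.weight"]))                                         # ESRGAN (fallback)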
@@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
 from .architecture.HAT import HAT
 from .architecture.LaMa import LaMa
 from .architecture.MAT import MAT
+from .architecture.OmniSR.OmniSR import OmniSR
 from .architecture.RRDB import RRDBNet as ESRGAN
 from .architecture.SPSR import SPSRNet as SPSR
 from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
@@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
 from .architecture.Swin2SR import Swin2SR
 from .architecture.SwinIR import SwinIR
 
-PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT)
+PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR)
 PyTorchSRModel = Union[
     RealESRGANv2,
     SPSR,
@@ -22,6 +23,7 @@ PyTorchSRModel = Union[
     SwinIR,
     Swin2SR,
     HAT,
+    OmniSR,
 ]
 
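Adding OmniSR to both PyTorchSRModels (a tuple of classes) and the PyTorchSRModel Union keeps runtime isinstance checks and static type hints in agreement. A hedged standalone sketch of why both are needed (stand-in classes and a hypothetical helper; the real classes come from .architecture.*):

from typing import Union


class RealESRGANv2:  # stand-in for the sketch only
    pass


class OmniSR:  # stand-in for the sketch only
    pass


PyTorchSRModels = (RealESRGANv2, OmniSR)      # usable with isinstance(...)
PyTorchSRModel = Union[RealESRGANv2, OmniSR]  # usable as a type annotation


def is_sr_model(model: object) -> bool:
    # a tuple of classes is what isinstance accepts, hence the parallel definitions
    return isinstance(model, PyTorchSRModels)


print(is_sr_model(OmniSR()), is_sr_model(object()))  # True False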