297 lines
10 KiB
Python
297 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import functools
|
|
import math
|
|
import re
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from . import block as B
|
|
|
|
|
|
# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py
|
|
# (VSGAN's version in turn enhanced the implementation that was already here.)
|
|
class RRDBNet(nn.Module):
    """ESRGAN RRDB generator whose topology is inferred from a state dict.

    Channel counts, number of RRDB blocks, scale and the ESRGAN+ / 2c2 /
    pixel-unshuffle variants are all read from the keys and tensor shapes of
    the supplied state dict, which is then loaded into the freshly built model.
    """

    # Key of ``_STATE_MAP`` holding the per-RDB-conv regex patterns.
    _RDB_KEY = r"model.1.sub.\1.RDB\2.conv\3.0.\4"

    # Maps each old-arch key (or ``re.sub`` replacement template) to the
    # new-arch key names (or regex patterns) it may be produced from.
    # Currently supports old, new, and newer RRDBNet arch models:
    # ESRGAN, BSRGAN/RealSR, Real-ESRGAN.
    # "/NB/" is a placeholder for the trunk-conv index, which is resolved to
    # the number of RRDB blocks in ``new_to_old_arch``.
    _STATE_MAP = {
        "model.0.weight": ("conv_first.weight",),
        "model.0.bias": ("conv_first.bias",),
        "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
        "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
        _RDB_KEY: (
            r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
            r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
        ),
    }

    # Wrapper keys some checkpoints nest their weights under, checked in order.
    _WRAPPER_KEYS = ("params_ema", "params")

    def __init__(
        self,
        state_dict,
        norm=None,
        act: str = "leakyrelu",
        upsampler: str = "upconv",
        mode: "B.ConvMode" = "CNA",
    ) -> None:
        """
        ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
        By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
        and Chen Change Loy.
        This is old-arch Residual in Residual Dense Block Network and is not
        the newest revision that's available at github.com/xinntao/ESRGAN.
        This is on purpose, the newest Network has severely limited the
        potential use of the Network with no benefits.
        This network supports model files from both new and old-arch.

        Args:
            state_dict: Model weights in old-arch or new-arch layout, optionally
                nested under a "params_ema" or "params" key.
            norm: Normalization layer
            act: Activation layer
            upsampler: Upsample layer. upconv, pixel_shuffle
            mode: Convolution mode

        Raises:
            NotImplementedError: if ``upsampler`` is not a supported mode.
            ValueError: if no RRDB blocks can be found in ``state_dict``.
        """
        super().__init__()
        self.model_arch = "ESRGAN"
        self.sub_type = "SR"

        self.state = state_dict
        self.norm = norm
        self.act = act
        self.upsampler = upsampler
        self.mode = mode

        # Per-instance copy; neither it nor the class-level table is mutated.
        self.state_map = dict(self._STATE_MAP)

        self.state = self._unwrap_state(self.state)
        # Must be computed before conversion: it fixes the trunk-conv index.
        self.num_blocks = self.get_num_blocks()
        self.plus = any("conv1x1" in k for k in self.state.keys())
        if self.plus:
            self.model_arch = "ESRGAN+"

        self.state = self.new_to_old_arch(self.state)

        self.key_arr = list(self.state.keys())
        # Conv weights are laid out (out_channels, in_channels, ...): the first
        # tensor gives input channels and filter count, the last tensor (final
        # conv weight or bias) gives output channels.
        self.in_nc: int = self.state[self.key_arr[0]].shape[1]
        self.out_nc: int = self.state[self.key_arr[-1]].shape[0]

        self.scale: int = self.get_scale()
        self.num_filters: int = self.state[self.key_arr[0]].shape[0]

        # A 2x2 first-conv kernel marks the "2c2" variant, whose scale is
        # derived differently from the layer count.
        c2x2 = self.state["model.0.weight"].shape[-2] == 2
        if c2x2:
            self.scale = round(math.sqrt(self.scale / 4))
            self.model_arch = "ESRGAN-2c2"

        self.supports_fp16 = True
        self.supports_bfp16 = True
        self.min_size_restriction = None

        # Detect if pixel-unshuffle was used (Real-ESRGAN): the input has 4x or
        # 16x the output channels, i.e. a shuffle factor of 2 or 4.
        if self.in_nc in (self.out_nc * 4, self.out_nc * 16):
            self.shuffle_factor = int(math.sqrt(self.in_nc // self.out_nc))
        else:
            self.shuffle_factor = None

        upsample_block = {
            "upconv": B.upconv_block,
            "pixel_shuffle": B.pixelshuffle_block,
        }.get(self.upsampler)
        if upsample_block is None:
            raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")

        if self.scale == 3:
            # A single stage with an explicit upscale factor of 3.
            upsample_blocks = upsample_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                upscale_factor=3,
                act_type=self.act,
                c2x2=c2x2,
            )
        else:
            # log2(scale) stages at the block's default upscale factor
            # (presumably 2). math.log2 is exact for powers of two.
            upsample_blocks = [
                upsample_block(
                    in_nc=self.num_filters,
                    out_nc=self.num_filters,
                    act_type=self.act,
                    c2x2=c2x2,
                )
                for _ in range(int(math.log2(self.scale)))
            ]

        self.model = B.sequential(
            # fea conv
            B.conv_block(
                in_nc=self.in_nc,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
            B.ShortcutBlock(
                B.sequential(
                    # rrdb blocks
                    *[
                        B.RRDB(
                            nf=self.num_filters,
                            kernel_size=3,
                            gc=32,
                            stride=1,
                            bias=True,
                            pad_type="zero",
                            norm_type=self.norm,
                            act_type=self.act,
                            mode="CNA",
                            plus=self.plus,
                            c2x2=c2x2,
                        )
                        for _ in range(self.num_blocks)
                    ],
                    # lr conv
                    B.conv_block(
                        in_nc=self.num_filters,
                        out_nc=self.num_filters,
                        kernel_size=3,
                        norm_type=self.norm,
                        act_type=None,
                        mode=self.mode,
                        c2x2=c2x2,
                    ),
                )
            ),
            *upsample_blocks,
            # hr_conv0
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=self.act,
                c2x2=c2x2,
            ),
            # hr_conv1
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.out_nc,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
        )

        # Adjust these properties for calculations outside of the model:
        # callers see the pre-unshuffle input channels and the net scale.
        if self.shuffle_factor:
            self.in_nc //= self.shuffle_factor**2
            self.scale //= self.shuffle_factor

        # strict=False: missing/unexpected keys are tolerated instead of raising.
        self.load_state_dict(self.state, strict=False)

    @classmethod
    def _unwrap_state(cls, state):
        """Return the weights nested under a known wrapper key, else ``state``."""
        for key in cls._WRAPPER_KEYS:
            if key in state:
                return state[key]
        return state

    def new_to_old_arch(self, state):
        """Convert a new-arch model state dictionary to an old-arch dictionary.

        Requires ``self.num_blocks`` (index of the trunk conv) and
        ``self.state_map``. Does not mutate either, so it is safe to call
        repeatedly.

        Args:
            state: Model weights, optionally nested under "params_ema" or "params".

        Returns:
            The (unwrapped) ``state`` itself when it has no "conv_first.weight"
            key (treated as already old-arch); otherwise a new OrderedDict with
            old-arch ``model.N...`` keys ordered by layer index N.
        """
        state = self._unwrap_state(state)

        if "conv_first.weight" not in state:
            # model is already old arch, this is a loose check, but should be sufficient
            return state

        # Resolve the "/NB/" placeholder on a local copy. Trunk entries go
        # last, matching the ordering the previous in-place edit produced.
        state_map = {}
        trunk_map = {}
        for old_key, new_keys in self.state_map.items():
            if "/NB/" in old_key:
                trunk_map[old_key.replace("/NB/", str(self.num_blocks))] = new_keys
            else:
                state_map[old_key] = new_keys
        state_map.update(trunk_map)

        old_state = OrderedDict()
        for old_key, new_keys in state_map.items():
            for new_key in new_keys:
                if r"\1" in old_key:
                    # Regex template: rewrite every key the pattern matches.
                    for k, v in state.items():
                        sub = re.sub(new_key, old_key, k)
                        if sub != k:
                            old_state[sub] = v
                elif new_key in state:
                    old_state[old_key] = state[new_key]

        # upconv layers: upconv N is stored at old-arch index 3 * N.
        max_upconv = 0
        for key, value in state.items():
            match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
            if match is None:
                continue
            _, key_num, key_type = match.groups()
            index = int(key_num) * 3
            old_state[f"model.{index}.{key_type}"] = value
            max_upconv = max(max_upconv, index)

        # final layers, positioned relative to the last upconv index
        final_offsets = {"HRconv": 2, "conv_hr": 2, "conv_last": 4}
        for key, value in state.items():
            prefix, _, kind = key.partition(".")
            offset = final_offsets.get(prefix)
            if offset is not None and kind in ("weight", "bias"):
                old_state[f"model.{max_upconv + offset}.{kind}"] = value

        # Sort by the layer index; the sort is stable, so keys sharing an
        # index keep their insertion order.
        return OrderedDict(
            sorted(old_state.items(), key=lambda item: int(item[0].split(".")[1]))
        )

    def get_scale(self, min_part: int = 6) -> int:
        """Infer the upscale factor from old-arch keys in ``self.state``.

        Counts top-level ``model.N.weight`` keys with N > ``min_part`` and
        returns ``2 ** count``. Keys whose N is not numeric are ignored.
        """
        n = 0
        for key in self.state:
            parts = key.split(".")[1:]
            if (
                len(parts) == 2
                and parts[0].isdecimal()
                and int(parts[0]) > min_part
                and parts[1] == "weight"
            ):
                n += 1
        return 2**n

    def get_num_blocks(self) -> int:
        """Return the number of RRDB blocks found in ``self.state``.

        Tries the new-arch patterns, then the old-arch pattern. The first
        pattern with any match decides the result: highest block index + 1.

        Raises:
            ValueError: if no RRDB conv key matches any pattern.
        """
        patterns = self.state_map[self._RDB_KEY] + (
            r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
        )
        for pattern in patterns:
            indices = [
                int(m.group(1))
                for m in (re.search(pattern, k) for k in self.state)
                if m
            ]
            if indices:
                # max(indices), not max(*indices): the latter fails on one item.
                return max(indices) + 1
        raise ValueError("Could not find any RRDB blocks in the state dict")

    def forward(self, x):
        """Run the network on ``x``, unpacked below as (N, C, H, W).

        When a pixel-unshuffle factor was detected, H and W are reflect-padded
        to a multiple of it, the input is unshuffled, and the output is cropped
        back to ``H * scale`` x ``W * scale``.
        """
        if not self.shuffle_factor:
            return self.model(x)

        factor = self.shuffle_factor
        _, _, h, w = x.size()
        mod_pad_h = (factor - h % factor) % factor
        mod_pad_w = (factor - w % factor) % factor
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
        x = torch.pixel_unshuffle(x, downscale_factor=factor)
        x = self.model(x)
        # Drop the (upscaled) padding.
        return x[:, :, : h * self.scale, : w * self.scale]
|