32 lines
904 B
Python
32 lines
904 B
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
#############################################################
|
|
# File: pixelshuffle.py
|
|
# Created Date: Friday July 1st 2022
|
|
# Author: Chen Xuanhong
|
|
# Email: chenxuanhongzju@outlook.com
|
|
# Last Modified: Friday, 1st July 2022 10:18:39 am
|
|
# Modified By: Chen Xuanhong
|
|
# Copyright (c) 2022 Shanghai Jiao Tong University
|
|
#############################################################
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
def pixelshuffle_block(
    in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
):
    """
    Upsample features according to `upscale_factor`.

    Builds a ``Conv2d -> PixelShuffle`` stack. The convolution expands the
    channel count to ``out_channels * upscale_factor**2``. ``nn.PixelShuffle``
    then rearranges those channels into an ``upscale_factor``-times larger
    spatial grid with ``out_channels`` channels.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels after upsampling.
        upscale_factor: Spatial upsampling factor applied by PixelShuffle.
        kernel_size: Convolution kernel size. Padding is ``kernel_size // 2``,
            so odd kernel sizes keep the pre-shuffle spatial size unchanged.
            Even kernel sizes produce one extra row/column.
        bias: Whether the convolution has a learnable bias.

    Returns:
        nn.Sequential: ``(conv, pixel_shuffle)`` at indices 0 and 1.
    """
    # Derive "same" padding from the kernel size. A hard-coded value is only
    # correct for kernel_size == 3 and would silently change output sizes.
    padding = kernel_size // 2
    conv = nn.Conv2d(
        in_channels,
        # PixelShuffle(r) consumes r**2 channels per output channel.
        out_channels * (upscale_factor**2),
        kernel_size,
        padding=padding,
        bias=bias,
    )
    pixel_shuffle = nn.PixelShuffle(upscale_factor)
    return nn.Sequential(conv, pixel_shuffle)
|