Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
6bd6154a
Unverified
Commit
6bd6154a
authored
Oct 23, 2022
by
AUTOMATIC1111
Committed by
GitHub
Oct 23, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #2067 from victorca25/esrgan_mod
update ESRGAN architecture and model to support all ESRGAN models
parents
696cb33e
53154ba1
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
563 additions
and
292 deletions
+563
-292
bsrgan_model.py
modules/bsrgan_model.py
+0
-76
bsrgan_model_arch.py
modules/bsrgan_model_arch.py
+0
-102
esrgan_model.py
modules/esrgan_model.py
+128
-62
esrgan_model_arch.py
modules/esrgan_model_arch.py
+435
-52
No files found.
modules/bsrgan_model.py
deleted
100644 → 0
View file @
696cb33e
import
os.path
import
sys
import
traceback
import
PIL.Image
import
numpy
as
np
import
torch
from
basicsr.utils.download_util
import
load_file_from_url
import
modules.upscaler
from
modules
import
devices
,
modelloader
from
modules.bsrgan_model_arch
import
RRDBNet
class
UpscalerBSRGAN
(
modules
.
upscaler
.
Upscaler
):
def
__init__
(
self
,
dirname
):
self
.
name
=
"BSRGAN"
self
.
model_name
=
"BSRGAN 4x"
self
.
model_url
=
"https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
self
.
user_path
=
dirname
super
()
.
__init__
()
model_paths
=
self
.
find_models
(
ext_filter
=
[
".pt"
,
".pth"
])
scalers
=
[]
if
len
(
model_paths
)
==
0
:
scaler_data
=
modules
.
upscaler
.
UpscalerData
(
self
.
model_name
,
self
.
model_url
,
self
,
4
)
scalers
.
append
(
scaler_data
)
for
file
in
model_paths
:
if
"http"
in
file
:
name
=
self
.
model_name
else
:
name
=
modelloader
.
friendly_name
(
file
)
try
:
scaler_data
=
modules
.
upscaler
.
UpscalerData
(
name
,
file
,
self
,
4
)
scalers
.
append
(
scaler_data
)
except
Exception
:
print
(
f
"Error loading BSRGAN model: {file}"
,
file
=
sys
.
stderr
)
print
(
traceback
.
format_exc
(),
file
=
sys
.
stderr
)
self
.
scalers
=
scalers
def
do_upscale
(
self
,
img
:
PIL
.
Image
,
selected_file
):
torch
.
cuda
.
empty_cache
()
model
=
self
.
load_model
(
selected_file
)
if
model
is
None
:
return
img
model
.
to
(
devices
.
device_bsrgan
)
torch
.
cuda
.
empty_cache
()
img
=
np
.
array
(
img
)
img
=
img
[:,
:,
::
-
1
]
img
=
np
.
moveaxis
(
img
,
2
,
0
)
/
255
img
=
torch
.
from_numpy
(
img
)
.
float
()
img
=
img
.
unsqueeze
(
0
)
.
to
(
devices
.
device_bsrgan
)
with
torch
.
no_grad
():
output
=
model
(
img
)
output
=
output
.
squeeze
()
.
float
()
.
cpu
()
.
clamp_
(
0
,
1
)
.
numpy
()
output
=
255.
*
np
.
moveaxis
(
output
,
0
,
2
)
output
=
output
.
astype
(
np
.
uint8
)
output
=
output
[:,
:,
::
-
1
]
torch
.
cuda
.
empty_cache
()
return
PIL
.
Image
.
fromarray
(
output
,
'RGB'
)
def
load_model
(
self
,
path
:
str
):
if
"http"
in
path
:
filename
=
load_file_from_url
(
url
=
self
.
model_url
,
model_dir
=
self
.
model_path
,
file_name
=
"
%
s.pth"
%
self
.
name
,
progress
=
True
)
else
:
filename
=
path
if
not
os
.
path
.
exists
(
filename
)
or
filename
is
None
:
print
(
f
"BSRGAN: Unable to load model from {filename}"
,
file
=
sys
.
stderr
)
return
None
model
=
RRDBNet
(
in_nc
=
3
,
out_nc
=
3
,
nf
=
64
,
nb
=
23
,
gc
=
32
,
sf
=
4
)
# define network
model
.
load_state_dict
(
torch
.
load
(
filename
),
strict
=
True
)
model
.
eval
()
for
k
,
v
in
model
.
named_parameters
():
v
.
requires_grad
=
False
return
model
modules/bsrgan_model_arch.py
deleted
100644 → 0
View file @
696cb33e
import
functools
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.init
as
init
def
initialize_weights
(
net_l
,
scale
=
1
):
if
not
isinstance
(
net_l
,
list
):
net_l
=
[
net_l
]
for
net
in
net_l
:
for
m
in
net
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
init
.
kaiming_normal_
(
m
.
weight
,
a
=
0
,
mode
=
'fan_in'
)
m
.
weight
.
data
*=
scale
# for residual block
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
elif
isinstance
(
m
,
nn
.
Linear
):
init
.
kaiming_normal_
(
m
.
weight
,
a
=
0
,
mode
=
'fan_in'
)
m
.
weight
.
data
*=
scale
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
init
.
constant_
(
m
.
weight
,
1
)
init
.
constant_
(
m
.
bias
.
data
,
0.0
)
def
make_layer
(
block
,
n_layers
):
layers
=
[]
for
_
in
range
(
n_layers
):
layers
.
append
(
block
())
return
nn
.
Sequential
(
*
layers
)
class
ResidualDenseBlock_5C
(
nn
.
Module
):
def
__init__
(
self
,
nf
=
64
,
gc
=
32
,
bias
=
True
):
super
(
ResidualDenseBlock_5C
,
self
)
.
__init__
()
# gc: growth channel, i.e. intermediate channels
self
.
conv1
=
nn
.
Conv2d
(
nf
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv2
=
nn
.
Conv2d
(
nf
+
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv3
=
nn
.
Conv2d
(
nf
+
2
*
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv4
=
nn
.
Conv2d
(
nf
+
3
*
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv5
=
nn
.
Conv2d
(
nf
+
4
*
gc
,
nf
,
3
,
1
,
1
,
bias
=
bias
)
self
.
lrelu
=
nn
.
LeakyReLU
(
negative_slope
=
0.2
,
inplace
=
True
)
# initialization
initialize_weights
([
self
.
conv1
,
self
.
conv2
,
self
.
conv3
,
self
.
conv4
,
self
.
conv5
],
0.1
)
def
forward
(
self
,
x
):
x1
=
self
.
lrelu
(
self
.
conv1
(
x
))
x2
=
self
.
lrelu
(
self
.
conv2
(
torch
.
cat
((
x
,
x1
),
1
)))
x3
=
self
.
lrelu
(
self
.
conv3
(
torch
.
cat
((
x
,
x1
,
x2
),
1
)))
x4
=
self
.
lrelu
(
self
.
conv4
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
),
1
)))
x5
=
self
.
conv5
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
,
x4
),
1
))
return
x5
*
0.2
+
x
class
RRDB
(
nn
.
Module
):
'''Residual in Residual Dense Block'''
def
__init__
(
self
,
nf
,
gc
=
32
):
super
(
RRDB
,
self
)
.
__init__
()
self
.
RDB1
=
ResidualDenseBlock_5C
(
nf
,
gc
)
self
.
RDB2
=
ResidualDenseBlock_5C
(
nf
,
gc
)
self
.
RDB3
=
ResidualDenseBlock_5C
(
nf
,
gc
)
def
forward
(
self
,
x
):
out
=
self
.
RDB1
(
x
)
out
=
self
.
RDB2
(
out
)
out
=
self
.
RDB3
(
out
)
return
out
*
0.2
+
x
class
RRDBNet
(
nn
.
Module
):
def
__init__
(
self
,
in_nc
=
3
,
out_nc
=
3
,
nf
=
64
,
nb
=
23
,
gc
=
32
,
sf
=
4
):
super
(
RRDBNet
,
self
)
.
__init__
()
RRDB_block_f
=
functools
.
partial
(
RRDB
,
nf
=
nf
,
gc
=
gc
)
self
.
sf
=
sf
self
.
conv_first
=
nn
.
Conv2d
(
in_nc
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
RRDB_trunk
=
make_layer
(
RRDB_block_f
,
nb
)
self
.
trunk_conv
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
#### upsampling
self
.
upconv1
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
if
self
.
sf
==
4
:
self
.
upconv2
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
HRconv
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
conv_last
=
nn
.
Conv2d
(
nf
,
out_nc
,
3
,
1
,
1
,
bias
=
True
)
self
.
lrelu
=
nn
.
LeakyReLU
(
negative_slope
=
0.2
,
inplace
=
True
)
def
forward
(
self
,
x
):
fea
=
self
.
conv_first
(
x
)
trunk
=
self
.
trunk_conv
(
self
.
RRDB_trunk
(
fea
))
fea
=
fea
+
trunk
fea
=
self
.
lrelu
(
self
.
upconv1
(
F
.
interpolate
(
fea
,
scale_factor
=
2
,
mode
=
'nearest'
)))
if
self
.
sf
==
4
:
fea
=
self
.
lrelu
(
self
.
upconv2
(
F
.
interpolate
(
fea
,
scale_factor
=
2
,
mode
=
'nearest'
)))
out
=
self
.
conv_last
(
self
.
lrelu
(
self
.
HRconv
(
fea
)))
return
out
\ No newline at end of file
modules/esrgan_model.py
View file @
6bd6154a
...
@@ -11,62 +11,109 @@ from modules.upscaler import Upscaler, UpscalerData
...
@@ -11,62 +11,109 @@ from modules.upscaler import Upscaler, UpscalerData
from
modules.shared
import
opts
from
modules.shared
import
opts
def
fix_model_layers
(
crt_model
,
pretrained_net
):
# this code is adapted from https://github.com/xinntao/ESRGAN
if
'conv_first.weight'
in
pretrained_net
:
return
pretrained_net
if
'model.0.weight'
not
in
pretrained_net
:
is_realesrgan
=
"params_ema"
in
pretrained_net
and
'body.0.rdb1.conv1.weight'
in
pretrained_net
[
"params_ema"
]
if
is_realesrgan
:
raise
Exception
(
"The file is a RealESRGAN model, it can't be used as a ESRGAN model."
)
else
:
raise
Exception
(
"The file is not a ESRGAN model."
)
crt_net
=
crt_model
.
state_dict
()
load_net_clean
=
{}
for
k
,
v
in
pretrained_net
.
items
():
if
k
.
startswith
(
'module.'
):
load_net_clean
[
k
[
7
:]]
=
v
else
:
load_net_clean
[
k
]
=
v
pretrained_net
=
load_net_clean
tbd
=
[]
def
mod2normal
(
state_dict
):
for
k
,
v
in
crt_net
.
items
():
# this code is copied from https://github.com/victorca25/iNNfer
tbd
.
append
(
k
)
if
'conv_first.weight'
in
state_dict
:
crt_net
=
{}
items
=
[]
for
k
,
v
in
state_dict
.
items
():
items
.
append
(
k
)
# directly copy
crt_net
[
'model.0.weight'
]
=
state_dict
[
'conv_first.weight'
]
for
k
,
v
in
crt_net
.
items
():
crt_net
[
'model.0.bias'
]
=
state_dict
[
'conv_first.bias'
]
if
k
in
pretrained_net
and
pretrained_net
[
k
]
.
size
()
==
v
.
size
():
crt_net
[
k
]
=
pretrained_net
[
k
]
tbd
.
remove
(
k
)
crt_net
[
'conv_first.weight'
]
=
pretrained_net
[
'model.0.weight'
]
for
k
in
items
.
copy
():
crt_net
[
'conv_first.bias'
]
=
pretrained_net
[
'model.0.bias'
]
for
k
in
tbd
.
copy
():
if
'RDB'
in
k
:
if
'RDB'
in
k
:
ori_k
=
k
.
replace
(
'RRDB_trunk.'
,
'model.1.sub.'
)
ori_k
=
k
.
replace
(
'RRDB_trunk.'
,
'model.1.sub.'
)
if
'.weight'
in
k
:
if
'.weight'
in
k
:
ori_k
=
ori_k
.
replace
(
'.weight'
,
'.0.weight'
)
ori_k
=
ori_k
.
replace
(
'.weight'
,
'.0.weight'
)
elif
'.bias'
in
k
:
elif
'.bias'
in
k
:
ori_k
=
ori_k
.
replace
(
'.bias'
,
'.0.bias'
)
ori_k
=
ori_k
.
replace
(
'.bias'
,
'.0.bias'
)
crt_net
[
k
]
=
pretrained_net
[
ori_k
]
crt_net
[
ori_k
]
=
state_dict
[
k
]
tbd
.
remove
(
k
)
items
.
remove
(
k
)
crt_net
[
'trunk_conv.weight'
]
=
pretrained_net
[
'model.1.sub.23.weight'
]
crt_net
[
'model.1.sub.23.weight'
]
=
state_dict
[
'trunk_conv.weight'
]
crt_net
[
'trunk_conv.bias'
]
=
pretrained_net
[
'model.1.sub.23.bias'
]
crt_net
[
'model.1.sub.23.bias'
]
=
state_dict
[
'trunk_conv.bias'
]
crt_net
[
'upconv1.weight'
]
=
pretrained_net
[
'model.3.weight'
]
crt_net
[
'model.3.weight'
]
=
state_dict
[
'upconv1.weight'
]
crt_net
[
'upconv1.bias'
]
=
pretrained_net
[
'model.3.bias'
]
crt_net
[
'model.3.bias'
]
=
state_dict
[
'upconv1.bias'
]
crt_net
[
'upconv2.weight'
]
=
pretrained_net
[
'model.6.weight'
]
crt_net
[
'model.6.weight'
]
=
state_dict
[
'upconv2.weight'
]
crt_net
[
'upconv2.bias'
]
=
pretrained_net
[
'model.6.bias'
]
crt_net
[
'model.6.bias'
]
=
state_dict
[
'upconv2.bias'
]
crt_net
[
'HRconv.weight'
]
=
pretrained_net
[
'model.8.weight'
]
crt_net
[
'model.8.weight'
]
=
state_dict
[
'HRconv.weight'
]
crt_net
[
'HRconv.bias'
]
=
pretrained_net
[
'model.8.bias'
]
crt_net
[
'model.8.bias'
]
=
state_dict
[
'HRconv.bias'
]
crt_net
[
'conv_last.weight'
]
=
pretrained_net
[
'model.10.weight'
]
crt_net
[
'model.10.weight'
]
=
state_dict
[
'conv_last.weight'
]
crt_net
[
'conv_last.bias'
]
=
pretrained_net
[
'model.10.bias'
]
crt_net
[
'model.10.bias'
]
=
state_dict
[
'conv_last.bias'
]
state_dict
=
crt_net
return
crt_net
return
state_dict
def
resrgan2normal
(
state_dict
,
nb
=
23
):
# this code is copied from https://github.com/victorca25/iNNfer
if
"conv_first.weight"
in
state_dict
and
"body.0.rdb1.conv1.weight"
in
state_dict
:
crt_net
=
{}
items
=
[]
for
k
,
v
in
state_dict
.
items
():
items
.
append
(
k
)
crt_net
[
'model.0.weight'
]
=
state_dict
[
'conv_first.weight'
]
crt_net
[
'model.0.bias'
]
=
state_dict
[
'conv_first.bias'
]
for
k
in
items
.
copy
():
if
"rdb"
in
k
:
ori_k
=
k
.
replace
(
'body.'
,
'model.1.sub.'
)
ori_k
=
ori_k
.
replace
(
'.rdb'
,
'.RDB'
)
if
'.weight'
in
k
:
ori_k
=
ori_k
.
replace
(
'.weight'
,
'.0.weight'
)
elif
'.bias'
in
k
:
ori_k
=
ori_k
.
replace
(
'.bias'
,
'.0.bias'
)
crt_net
[
ori_k
]
=
state_dict
[
k
]
items
.
remove
(
k
)
crt_net
[
f
'model.1.sub.{nb}.weight'
]
=
state_dict
[
'conv_body.weight'
]
crt_net
[
f
'model.1.sub.{nb}.bias'
]
=
state_dict
[
'conv_body.bias'
]
crt_net
[
'model.3.weight'
]
=
state_dict
[
'conv_up1.weight'
]
crt_net
[
'model.3.bias'
]
=
state_dict
[
'conv_up1.bias'
]
crt_net
[
'model.6.weight'
]
=
state_dict
[
'conv_up2.weight'
]
crt_net
[
'model.6.bias'
]
=
state_dict
[
'conv_up2.bias'
]
crt_net
[
'model.8.weight'
]
=
state_dict
[
'conv_hr.weight'
]
crt_net
[
'model.8.bias'
]
=
state_dict
[
'conv_hr.bias'
]
crt_net
[
'model.10.weight'
]
=
state_dict
[
'conv_last.weight'
]
crt_net
[
'model.10.bias'
]
=
state_dict
[
'conv_last.bias'
]
state_dict
=
crt_net
return
state_dict
def
infer_params
(
state_dict
):
# this code is copied from https://github.com/victorca25/iNNfer
scale2x
=
0
scalemin
=
6
n_uplayer
=
0
plus
=
False
for
block
in
list
(
state_dict
):
parts
=
block
.
split
(
"."
)
n_parts
=
len
(
parts
)
if
n_parts
==
5
and
parts
[
2
]
==
"sub"
:
nb
=
int
(
parts
[
3
])
elif
n_parts
==
3
:
part_num
=
int
(
parts
[
1
])
if
(
part_num
>
scalemin
and
parts
[
0
]
==
"model"
and
parts
[
2
]
==
"weight"
):
scale2x
+=
1
if
part_num
>
n_uplayer
:
n_uplayer
=
part_num
out_nc
=
state_dict
[
block
]
.
shape
[
0
]
if
not
plus
and
"conv1x1"
in
block
:
plus
=
True
nf
=
state_dict
[
"model.0.weight"
]
.
shape
[
0
]
in_nc
=
state_dict
[
"model.0.weight"
]
.
shape
[
1
]
out_nc
=
out_nc
scale
=
2
**
scale2x
return
in_nc
,
out_nc
,
nf
,
nb
,
plus
,
scale
class
UpscalerESRGAN
(
Upscaler
):
class
UpscalerESRGAN
(
Upscaler
):
def
__init__
(
self
,
dirname
):
def
__init__
(
self
,
dirname
):
...
@@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler):
...
@@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler):
print
(
"Unable to load
%
s from
%
s"
%
(
self
.
model_path
,
filename
))
print
(
"Unable to load
%
s from
%
s"
%
(
self
.
model_path
,
filename
))
return
None
return
None
pretrained_net
=
torch
.
load
(
filename
,
map_location
=
'cpu'
if
devices
.
device_esrgan
.
type
==
'mps'
else
None
)
state_dict
=
torch
.
load
(
filename
,
map_location
=
'cpu'
if
devices
.
device_esrgan
.
type
==
'mps'
else
None
)
crt_model
=
arch
.
RRDBNet
(
3
,
3
,
64
,
23
,
gc
=
32
)
if
"params_ema"
in
state_dict
:
state_dict
=
state_dict
[
"params_ema"
]
elif
"params"
in
state_dict
:
state_dict
=
state_dict
[
"params"
]
num_conv
=
16
if
"realesr-animevideov3"
in
filename
else
32
model
=
arch
.
SRVGGNetCompact
(
num_in_ch
=
3
,
num_out_ch
=
3
,
num_feat
=
64
,
num_conv
=
num_conv
,
upscale
=
4
,
act_type
=
'prelu'
)
model
.
load_state_dict
(
state_dict
)
model
.
eval
()
return
model
if
"body.0.rdb1.conv1.weight"
in
state_dict
and
"conv_first.weight"
in
state_dict
:
nb
=
6
if
"RealESRGAN_x4plus_anime_6B"
in
filename
else
23
state_dict
=
resrgan2normal
(
state_dict
,
nb
)
elif
"conv_first.weight"
in
state_dict
:
state_dict
=
mod2normal
(
state_dict
)
elif
"model.0.weight"
not
in
state_dict
:
raise
Exception
(
"The file is not a recognized ESRGAN model."
)
in_nc
,
out_nc
,
nf
,
nb
,
plus
,
mscale
=
infer_params
(
state_dict
)
pretrained_net
=
fix_model_layers
(
crt_model
,
pretrained_net
)
model
=
arch
.
RRDBNet
(
in_nc
=
in_nc
,
out_nc
=
out_nc
,
nf
=
nf
,
nb
=
nb
,
upscale
=
mscale
,
plus
=
plus
)
crt_model
.
load_state_dict
(
pretrained_ne
t
)
model
.
load_state_dict
(
state_dic
t
)
crt_
model
.
eval
()
model
.
eval
()
return
crt_
model
return
model
def
upscale_without_tiling
(
model
,
img
):
def
upscale_without_tiling
(
model
,
img
):
img
=
np
.
array
(
img
)
img
=
np
.
array
(
img
)
img
=
img
[:,
:,
::
-
1
]
img
=
img
[:,
:,
::
-
1
]
img
=
np
.
moveaxis
(
img
,
2
,
0
)
/
255
img
=
np
.
ascontiguousarray
(
np
.
transpose
(
img
,
(
2
,
0
,
1
))
)
/
255
img
=
torch
.
from_numpy
(
img
)
.
float
()
img
=
torch
.
from_numpy
(
img
)
.
float
()
img
=
img
.
unsqueeze
(
0
)
.
to
(
devices
.
device_esrgan
)
img
=
img
.
unsqueeze
(
0
)
.
to
(
devices
.
device_esrgan
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
...
modules/esrgan_model_arch.py
View file @
6bd6154a
# this file is
taken from https://github.com/xinntao/ESRGAN
# this file is
adapted from https://github.com/victorca25/iNNfer
import
math
import
functools
import
functools
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
def
make_layer
(
block
,
n_layers
):
####################
layers
=
[]
# RRDBNet Generator
for
_
in
range
(
n_layers
):
####################
layers
.
append
(
block
())
return
nn
.
Sequential
(
*
layers
)
class
RRDBNet
(
nn
.
Module
):
def
__init__
(
self
,
in_nc
,
out_nc
,
nf
,
nb
,
nr
=
3
,
gc
=
32
,
upscale
=
4
,
norm_type
=
None
,
act_type
=
'leakyrelu'
,
mode
=
'CNA'
,
upsample_mode
=
'upconv'
,
convtype
=
'Conv2D'
,
finalact
=
None
,
gaussian_noise
=
False
,
plus
=
False
):
super
(
RRDBNet
,
self
)
.
__init__
()
n_upscale
=
int
(
math
.
log
(
upscale
,
2
))
if
upscale
==
3
:
n_upscale
=
1
class
ResidualDenseBlock_5C
(
nn
.
Module
):
self
.
resrgan_scale
=
0
def
__init__
(
self
,
nf
=
64
,
gc
=
32
,
bias
=
True
):
if
in_nc
%
16
==
0
:
super
(
ResidualDenseBlock_5C
,
self
)
.
__init__
()
self
.
resrgan_scale
=
1
# gc: growth channel, i.e. intermediate channels
elif
in_nc
!=
4
and
in_nc
%
4
==
0
:
self
.
conv1
=
nn
.
Conv2d
(
nf
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
resrgan_scale
=
2
self
.
conv2
=
nn
.
Conv2d
(
nf
+
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv3
=
nn
.
Conv2d
(
nf
+
2
*
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv4
=
nn
.
Conv2d
(
nf
+
3
*
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv5
=
nn
.
Conv2d
(
nf
+
4
*
gc
,
nf
,
3
,
1
,
1
,
bias
=
bias
)
self
.
lrelu
=
nn
.
LeakyReLU
(
negative_slope
=
0.2
,
inplace
=
True
)
# initialization
fea_conv
=
conv_block
(
in_nc
,
nf
,
kernel_size
=
3
,
norm_type
=
None
,
act_type
=
None
,
convtype
=
convtype
)
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
rb_blocks
=
[
RRDB
(
nf
,
nr
,
kernel_size
=
3
,
gc
=
32
,
stride
=
1
,
bias
=
1
,
pad_type
=
'zero'
,
norm_type
=
norm_type
,
act_type
=
act_type
,
mode
=
'CNA'
,
convtype
=
convtype
,
gaussian_noise
=
gaussian_noise
,
plus
=
plus
)
for
_
in
range
(
nb
)]
LR_conv
=
conv_block
(
nf
,
nf
,
kernel_size
=
3
,
norm_type
=
norm_type
,
act_type
=
None
,
mode
=
mode
,
convtype
=
convtype
)
def
forward
(
self
,
x
):
if
upsample_mode
==
'upconv'
:
x1
=
self
.
lrelu
(
self
.
conv1
(
x
))
upsample_block
=
upconv_block
x2
=
self
.
lrelu
(
self
.
conv2
(
torch
.
cat
((
x
,
x1
),
1
)))
elif
upsample_mode
==
'pixelshuffle'
:
x3
=
self
.
lrelu
(
self
.
conv3
(
torch
.
cat
((
x
,
x1
,
x2
),
1
)))
upsample_block
=
pixelshuffle_block
x4
=
self
.
lrelu
(
self
.
conv4
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
),
1
)))
else
:
x5
=
self
.
conv5
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
,
x4
),
1
))
raise
NotImplementedError
(
'upsample mode [{:s}] is not found'
.
format
(
upsample_mode
))
return
x5
*
0.2
+
x
if
upscale
==
3
:
upsampler
=
upsample_block
(
nf
,
nf
,
3
,
act_type
=
act_type
,
convtype
=
convtype
)
else
:
upsampler
=
[
upsample_block
(
nf
,
nf
,
act_type
=
act_type
,
convtype
=
convtype
)
for
_
in
range
(
n_upscale
)]
HR_conv0
=
conv_block
(
nf
,
nf
,
kernel_size
=
3
,
norm_type
=
None
,
act_type
=
act_type
,
convtype
=
convtype
)
HR_conv1
=
conv_block
(
nf
,
out_nc
,
kernel_size
=
3
,
norm_type
=
None
,
act_type
=
None
,
convtype
=
convtype
)
outact
=
act
(
finalact
)
if
finalact
else
None
self
.
model
=
sequential
(
fea_conv
,
ShortcutBlock
(
sequential
(
*
rb_blocks
,
LR_conv
)),
*
upsampler
,
HR_conv0
,
HR_conv1
,
outact
)
def
forward
(
self
,
x
,
outm
=
None
):
if
self
.
resrgan_scale
==
1
:
feat
=
pixel_unshuffle
(
x
,
scale
=
4
)
elif
self
.
resrgan_scale
==
2
:
feat
=
pixel_unshuffle
(
x
,
scale
=
2
)
else
:
feat
=
x
return
self
.
model
(
feat
)
class
RRDB
(
nn
.
Module
):
class
RRDB
(
nn
.
Module
):
'''Residual in Residual Dense Block'''
"""
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
"""
def
__init__
(
self
,
nf
,
gc
=
32
):
def
__init__
(
self
,
nf
,
nr
=
3
,
kernel_size
=
3
,
gc
=
32
,
stride
=
1
,
bias
=
1
,
pad_type
=
'zero'
,
norm_type
=
None
,
act_type
=
'leakyrelu'
,
mode
=
'CNA'
,
convtype
=
'Conv2D'
,
spectral_norm
=
False
,
gaussian_noise
=
False
,
plus
=
False
):
super
(
RRDB
,
self
)
.
__init__
()
super
(
RRDB
,
self
)
.
__init__
()
self
.
RDB1
=
ResidualDenseBlock_5C
(
nf
,
gc
)
# This is for backwards compatibility with existing models
self
.
RDB2
=
ResidualDenseBlock_5C
(
nf
,
gc
)
if
nr
==
3
:
self
.
RDB3
=
ResidualDenseBlock_5C
(
nf
,
gc
)
self
.
RDB1
=
ResidualDenseBlock_5C
(
nf
,
kernel_size
,
gc
,
stride
,
bias
,
pad_type
,
norm_type
,
act_type
,
mode
,
convtype
,
spectral_norm
=
spectral_norm
,
gaussian_noise
=
gaussian_noise
,
plus
=
plus
)
self
.
RDB2
=
ResidualDenseBlock_5C
(
nf
,
kernel_size
,
gc
,
stride
,
bias
,
pad_type
,
norm_type
,
act_type
,
mode
,
convtype
,
spectral_norm
=
spectral_norm
,
gaussian_noise
=
gaussian_noise
,
plus
=
plus
)
self
.
RDB3
=
ResidualDenseBlock_5C
(
nf
,
kernel_size
,
gc
,
stride
,
bias
,
pad_type
,
norm_type
,
act_type
,
mode
,
convtype
,
spectral_norm
=
spectral_norm
,
gaussian_noise
=
gaussian_noise
,
plus
=
plus
)
else
:
RDB_list
=
[
ResidualDenseBlock_5C
(
nf
,
kernel_size
,
gc
,
stride
,
bias
,
pad_type
,
norm_type
,
act_type
,
mode
,
convtype
,
spectral_norm
=
spectral_norm
,
gaussian_noise
=
gaussian_noise
,
plus
=
plus
)
for
_
in
range
(
nr
)]
self
.
RDBs
=
nn
.
Sequential
(
*
RDB_list
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
if
hasattr
(
self
,
'RDB1'
):
out
=
self
.
RDB1
(
x
)
out
=
self
.
RDB1
(
x
)
out
=
self
.
RDB2
(
out
)
out
=
self
.
RDB2
(
out
)
out
=
self
.
RDB3
(
out
)
out
=
self
.
RDB3
(
out
)
else
:
out
=
self
.
RDBs
(
x
)
return
out
*
0.2
+
x
return
out
*
0.2
+
x
class
RRDBNet
(
nn
.
Module
):
class
ResidualDenseBlock_5C
(
nn
.
Module
):
def
__init__
(
self
,
in_nc
,
out_nc
,
nf
,
nb
,
gc
=
32
):
"""
super
(
RRDBNet
,
self
)
.
__init__
()
Residual Dense Block
RRDB_block_f
=
functools
.
partial
(
RRDB
,
nf
=
nf
,
gc
=
gc
)
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
Modified options that can be used:
- "Partial Convolution based Padding" arXiv:1811.11718
- "Spectral normalization" arXiv:1802.05957
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
{Rakotonirina} and A. {Rasoanaivo}
"""
def
__init__
(
self
,
nf
=
64
,
kernel_size
=
3
,
gc
=
32
,
stride
=
1
,
bias
=
1
,
pad_type
=
'zero'
,
norm_type
=
None
,
act_type
=
'leakyrelu'
,
mode
=
'CNA'
,
convtype
=
'Conv2D'
,
spectral_norm
=
False
,
gaussian_noise
=
False
,
plus
=
False
):
super
(
ResidualDenseBlock_5C
,
self
)
.
__init__
()
self
.
noise
=
GaussianNoise
()
if
gaussian_noise
else
None
self
.
conv1x1
=
conv1x1
(
nf
,
gc
)
if
plus
else
None
self
.
conv1
=
conv_block
(
nf
,
gc
,
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
act_type
,
mode
=
mode
,
convtype
=
convtype
,
spectral_norm
=
spectral_norm
)
self
.
conv2
=
conv_block
(
nf
+
gc
,
gc
,
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
act_type
,
mode
=
mode
,
convtype
=
convtype
,
spectral_norm
=
spectral_norm
)
self
.
conv3
=
conv_block
(
nf
+
2
*
gc
,
gc
,
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
act_type
,
mode
=
mode
,
convtype
=
convtype
,
spectral_norm
=
spectral_norm
)
self
.
conv4
=
conv_block
(
nf
+
3
*
gc
,
gc
,
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
act_type
,
mode
=
mode
,
convtype
=
convtype
,
spectral_norm
=
spectral_norm
)
if
mode
==
'CNA'
:
last_act
=
None
else
:
last_act
=
act_type
self
.
conv5
=
conv_block
(
nf
+
4
*
gc
,
nf
,
3
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
last_act
,
mode
=
mode
,
convtype
=
convtype
,
spectral_norm
=
spectral_norm
)
def
forward
(
self
,
x
):
x1
=
self
.
conv1
(
x
)
x2
=
self
.
conv2
(
torch
.
cat
((
x
,
x1
),
1
))
if
self
.
conv1x1
:
x2
=
x2
+
self
.
conv1x1
(
x
)
x3
=
self
.
conv3
(
torch
.
cat
((
x
,
x1
,
x2
),
1
))
x4
=
self
.
conv4
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
),
1
))
if
self
.
conv1x1
:
x4
=
x4
+
x2
x5
=
self
.
conv5
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
,
x4
),
1
))
if
self
.
noise
:
return
self
.
noise
(
x5
.
mul
(
0.2
)
+
x
)
else
:
return
x5
*
0.2
+
x
self
.
conv_first
=
nn
.
Conv2d
(
in_nc
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
RRDB_trunk
=
make_layer
(
RRDB_block_f
,
nb
)
self
.
trunk_conv
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
#### upsampling
self
.
upconv1
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
upconv2
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
HRconv
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
conv_last
=
nn
.
Conv2d
(
nf
,
out_nc
,
3
,
1
,
1
,
bias
=
True
)
self
.
lrelu
=
nn
.
LeakyReLU
(
negative_slope
=
0.2
,
inplace
=
True
)
####################
# ESRGANplus
####################
class
GaussianNoise
(
nn
.
Module
):
def
__init__
(
self
,
sigma
=
0.1
,
is_relative_detach
=
False
):
super
()
.
__init__
()
self
.
sigma
=
sigma
self
.
is_relative_detach
=
is_relative_detach
self
.
noise
=
torch
.
tensor
(
0
,
dtype
=
torch
.
float
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
fea
=
self
.
conv_first
(
x
)
if
self
.
training
and
self
.
sigma
!=
0
:
trunk
=
self
.
trunk_conv
(
self
.
RRDB_trunk
(
fea
))
self
.
noise
=
self
.
noise
.
to
(
x
.
device
)
fea
=
fea
+
trunk
scale
=
self
.
sigma
*
x
.
detach
()
if
self
.
is_relative_detach
else
self
.
sigma
*
x
sampled_noise
=
self
.
noise
.
repeat
(
*
x
.
size
())
.
normal_
()
*
scale
x
=
x
+
sampled_noise
return
x
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
return
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
1
,
stride
=
stride
,
bias
=
False
)
####################
# SRVGGNetCompact
####################
class
SRVGGNetCompact
(
nn
.
Module
):
"""A compact VGG-style network structure for super-resolution.
This class is copied from https://github.com/xinntao/Real-ESRGAN
"""
def
__init__
(
self
,
num_in_ch
=
3
,
num_out_ch
=
3
,
num_feat
=
64
,
num_conv
=
16
,
upscale
=
4
,
act_type
=
'prelu'
):
super
(
SRVGGNetCompact
,
self
)
.
__init__
()
self
.
num_in_ch
=
num_in_ch
self
.
num_out_ch
=
num_out_ch
self
.
num_feat
=
num_feat
self
.
num_conv
=
num_conv
self
.
upscale
=
upscale
self
.
act_type
=
act_type
self
.
body
=
nn
.
ModuleList
()
# the first conv
self
.
body
.
append
(
nn
.
Conv2d
(
num_in_ch
,
num_feat
,
3
,
1
,
1
))
# the first activation
if
act_type
==
'relu'
:
activation
=
nn
.
ReLU
(
inplace
=
True
)
elif
act_type
==
'prelu'
:
activation
=
nn
.
PReLU
(
num_parameters
=
num_feat
)
elif
act_type
==
'leakyrelu'
:
activation
=
nn
.
LeakyReLU
(
negative_slope
=
0.1
,
inplace
=
True
)
self
.
body
.
append
(
activation
)
# the body structure
for
_
in
range
(
num_conv
):
self
.
body
.
append
(
nn
.
Conv2d
(
num_feat
,
num_feat
,
3
,
1
,
1
))
# activation
if
act_type
==
'relu'
:
activation
=
nn
.
ReLU
(
inplace
=
True
)
elif
act_type
==
'prelu'
:
activation
=
nn
.
PReLU
(
num_parameters
=
num_feat
)
elif
act_type
==
'leakyrelu'
:
activation
=
nn
.
LeakyReLU
(
negative_slope
=
0.1
,
inplace
=
True
)
self
.
body
.
append
(
activation
)
fea
=
self
.
lrelu
(
self
.
upconv1
(
F
.
interpolate
(
fea
,
scale_factor
=
2
,
mode
=
'nearest'
)))
# the last conv
fea
=
self
.
lrelu
(
self
.
upconv2
(
F
.
interpolate
(
fea
,
scale_factor
=
2
,
mode
=
'nearest'
)))
self
.
body
.
append
(
nn
.
Conv2d
(
num_feat
,
num_out_ch
*
upscale
*
upscale
,
3
,
1
,
1
))
out
=
self
.
conv_last
(
self
.
lrelu
(
self
.
HRconv
(
fea
)))
# upsample
self
.
upsampler
=
nn
.
PixelShuffle
(
upscale
)
def
forward
(
self
,
x
):
out
=
x
for
i
in
range
(
0
,
len
(
self
.
body
)):
out
=
self
.
body
[
i
](
out
)
out
=
self
.
upsampler
(
out
)
# add the nearest upsampled image, so that the network learns the residual
base
=
F
.
interpolate
(
x
,
scale_factor
=
self
.
upscale
,
mode
=
'nearest'
)
out
+=
base
return
out
return
out
####################
# Upsampler
####################
class
Upsample
(
nn
.
Module
):
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
The input data is assumed to be of the form
`minibatch x channels x [optional depth] x [optional height] x width`.
"""
def
__init__
(
self
,
size
=
None
,
scale_factor
=
None
,
mode
=
"nearest"
,
align_corners
=
None
):
super
(
Upsample
,
self
)
.
__init__
()
if
isinstance
(
scale_factor
,
tuple
):
self
.
scale_factor
=
tuple
(
float
(
factor
)
for
factor
in
scale_factor
)
else
:
self
.
scale_factor
=
float
(
scale_factor
)
if
scale_factor
else
None
self
.
mode
=
mode
self
.
size
=
size
self
.
align_corners
=
align_corners
def
forward
(
self
,
x
):
return
nn
.
functional
.
interpolate
(
x
,
size
=
self
.
size
,
scale_factor
=
self
.
scale_factor
,
mode
=
self
.
mode
,
align_corners
=
self
.
align_corners
)
def
extra_repr
(
self
):
if
self
.
scale_factor
is
not
None
:
info
=
'scale_factor='
+
str
(
self
.
scale_factor
)
else
:
info
=
'size='
+
str
(
self
.
size
)
info
+=
', mode='
+
self
.
mode
return
info
def
pixel_unshuffle
(
x
,
scale
):
""" Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""
b
,
c
,
hh
,
hw
=
x
.
size
()
out_channel
=
c
*
(
scale
**
2
)
assert
hh
%
scale
==
0
and
hw
%
scale
==
0
h
=
hh
//
scale
w
=
hw
//
scale
x_view
=
x
.
view
(
b
,
c
,
h
,
scale
,
w
,
scale
)
return
x_view
.
permute
(
0
,
1
,
3
,
5
,
2
,
4
)
.
reshape
(
b
,
out_channel
,
h
,
w
)
def
pixelshuffle_block
(
in_nc
,
out_nc
,
upscale_factor
=
2
,
kernel_size
=
3
,
stride
=
1
,
bias
=
True
,
pad_type
=
'zero'
,
norm_type
=
None
,
act_type
=
'relu'
,
convtype
=
'Conv2D'
):
"""
Pixel shuffle layer
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
Neural Network, CVPR17)
"""
conv
=
conv_block
(
in_nc
,
out_nc
*
(
upscale_factor
**
2
),
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
None
,
act_type
=
None
,
convtype
=
convtype
)
pixel_shuffle
=
nn
.
PixelShuffle
(
upscale_factor
)
n
=
norm
(
norm_type
,
out_nc
)
if
norm_type
else
None
a
=
act
(
act_type
)
if
act_type
else
None
return
sequential
(
conv
,
pixel_shuffle
,
n
,
a
)
def
upconv_block
(
in_nc
,
out_nc
,
upscale_factor
=
2
,
kernel_size
=
3
,
stride
=
1
,
bias
=
True
,
pad_type
=
'zero'
,
norm_type
=
None
,
act_type
=
'relu'
,
mode
=
'nearest'
,
convtype
=
'Conv2D'
):
""" Upconv layer """
upscale_factor
=
(
1
,
upscale_factor
,
upscale_factor
)
if
convtype
==
'Conv3D'
else
upscale_factor
upsample
=
Upsample
(
scale_factor
=
upscale_factor
,
mode
=
mode
)
conv
=
conv_block
(
in_nc
,
out_nc
,
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
act_type
,
convtype
=
convtype
)
return
sequential
(
upsample
,
conv
)
####################
# Basic blocks
####################
def
make_layer
(
basic_block
,
num_basic_block
,
**
kwarg
):
"""Make layers by stacking the same blocks.
Args:
basic_block (nn.module): nn.module class for basic block. (block)
num_basic_block (int): number of blocks. (n_layers)
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
layers
=
[]
for
_
in
range
(
num_basic_block
):
layers
.
append
(
basic_block
(
**
kwarg
))
return
nn
.
Sequential
(
*
layers
)
def
act
(
act_type
,
inplace
=
True
,
neg_slope
=
0.2
,
n_prelu
=
1
,
beta
=
1.0
):
""" activation helper """
act_type
=
act_type
.
lower
()
if
act_type
==
'relu'
:
layer
=
nn
.
ReLU
(
inplace
)
elif
act_type
in
(
'leakyrelu'
,
'lrelu'
):
layer
=
nn
.
LeakyReLU
(
neg_slope
,
inplace
)
elif
act_type
==
'prelu'
:
layer
=
nn
.
PReLU
(
num_parameters
=
n_prelu
,
init
=
neg_slope
)
elif
act_type
==
'tanh'
:
# [-1, 1] range output
layer
=
nn
.
Tanh
()
elif
act_type
==
'sigmoid'
:
# [0, 1] range output
layer
=
nn
.
Sigmoid
()
else
:
raise
NotImplementedError
(
'activation layer [{:s}] is not found'
.
format
(
act_type
))
return
layer
class
Identity
(
nn
.
Module
):
def
__init__
(
self
,
*
kwargs
):
super
(
Identity
,
self
)
.
__init__
()
def
forward
(
self
,
x
,
*
kwargs
):
return
x
def
norm
(
norm_type
,
nc
):
""" Return a normalization layer """
norm_type
=
norm_type
.
lower
()
if
norm_type
==
'batch'
:
layer
=
nn
.
BatchNorm2d
(
nc
,
affine
=
True
)
elif
norm_type
==
'instance'
:
layer
=
nn
.
InstanceNorm2d
(
nc
,
affine
=
False
)
elif
norm_type
==
'none'
:
def
norm_layer
(
x
):
return
Identity
()
else
:
raise
NotImplementedError
(
'normalization layer [{:s}] is not found'
.
format
(
norm_type
))
return
layer
def
pad
(
pad_type
,
padding
):
""" padding layer helper """
pad_type
=
pad_type
.
lower
()
if
padding
==
0
:
return
None
if
pad_type
==
'reflect'
:
layer
=
nn
.
ReflectionPad2d
(
padding
)
elif
pad_type
==
'replicate'
:
layer
=
nn
.
ReplicationPad2d
(
padding
)
elif
pad_type
==
'zero'
:
layer
=
nn
.
ZeroPad2d
(
padding
)
else
:
raise
NotImplementedError
(
'padding layer [{:s}] is not implemented'
.
format
(
pad_type
))
return
layer
def
get_valid_padding
(
kernel_size
,
dilation
):
kernel_size
=
kernel_size
+
(
kernel_size
-
1
)
*
(
dilation
-
1
)
padding
=
(
kernel_size
-
1
)
//
2
return
padding
class
ShortcutBlock
(
nn
.
Module
):
""" Elementwise sum the output of a submodule to its input """
def
__init__
(
self
,
submodule
):
super
(
ShortcutBlock
,
self
)
.
__init__
()
self
.
sub
=
submodule
def
forward
(
self
,
x
):
output
=
x
+
self
.
sub
(
x
)
return
output
def
__repr__
(
self
):
return
'Identity +
\n
|'
+
self
.
sub
.
__repr__
()
.
replace
(
'
\n
'
,
'
\n
|'
)
def
sequential
(
*
args
):
""" Flatten Sequential. It unwraps nn.Sequential. """
if
len
(
args
)
==
1
:
if
isinstance
(
args
[
0
],
OrderedDict
):
raise
NotImplementedError
(
'sequential does not support OrderedDict input.'
)
return
args
[
0
]
# No sequential is needed.
modules
=
[]
for
module
in
args
:
if
isinstance
(
module
,
nn
.
Sequential
):
for
submodule
in
module
.
children
():
modules
.
append
(
submodule
)
elif
isinstance
(
module
,
nn
.
Module
):
modules
.
append
(
module
)
return
nn
.
Sequential
(
*
modules
)
def
conv_block
(
in_nc
,
out_nc
,
kernel_size
,
stride
=
1
,
dilation
=
1
,
groups
=
1
,
bias
=
True
,
pad_type
=
'zero'
,
norm_type
=
None
,
act_type
=
'relu'
,
mode
=
'CNA'
,
convtype
=
'Conv2D'
,
spectral_norm
=
False
):
""" Conv layer with padding, normalization, activation """
assert
mode
in
[
'CNA'
,
'NAC'
,
'CNAC'
],
'Wrong conv mode [{:s}]'
.
format
(
mode
)
padding
=
get_valid_padding
(
kernel_size
,
dilation
)
p
=
pad
(
pad_type
,
padding
)
if
pad_type
and
pad_type
!=
'zero'
else
None
padding
=
padding
if
pad_type
==
'zero'
else
0
if
convtype
==
'PartialConv2D'
:
c
=
PartialConv2d
(
in_nc
,
out_nc
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
bias
=
bias
,
groups
=
groups
)
elif
convtype
==
'DeformConv2D'
:
c
=
DeformConv2d
(
in_nc
,
out_nc
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
bias
=
bias
,
groups
=
groups
)
elif
convtype
==
'Conv3D'
:
c
=
nn
.
Conv3d
(
in_nc
,
out_nc
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
bias
=
bias
,
groups
=
groups
)
else
:
c
=
nn
.
Conv2d
(
in_nc
,
out_nc
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
bias
=
bias
,
groups
=
groups
)
if
spectral_norm
:
c
=
nn
.
utils
.
spectral_norm
(
c
)
a
=
act
(
act_type
)
if
act_type
else
None
if
'CNA'
in
mode
:
n
=
norm
(
norm_type
,
out_nc
)
if
norm_type
else
None
return
sequential
(
p
,
c
,
n
,
a
)
elif
mode
==
'NAC'
:
if
norm_type
is
None
and
act_type
is
not
None
:
a
=
act
(
act_type
,
inplace
=
False
)
n
=
norm
(
norm_type
,
in_nc
)
if
norm_type
else
None
return
sequential
(
n
,
a
,
p
,
c
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment