Paper:
Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
Reproduction code:
Github-SRcode
Image resolution is a set of performance parameters used to assess how much detail an image carries, including temporal resolution, spatial resolution, and tonal resolution; it reflects the imaging system's actual ability to capture the details of an object. Compared with a low-resolution image, a high-resolution image usually has a higher pixel density, richer texture detail, and higher fidelity. In practice, however, constrained by the acquisition device and environment, the transmission medium and bandwidth, and the image degradation model itself, we usually cannot directly obtain an ideal high-resolution image with sharp edges and no block blur. The most direct way to raise image resolution is to improve the optical hardware of the acquisition system, but because manufacturing processes are hard to improve substantially and manufacturing costs are very high, solving the low-resolution problem physically is usually too expensive. Consequently, approaching the problem from the software and algorithm side, image super-resolution reconstruction has become a hot research topic in image processing, computer vision, and related fields.
Image super-resolution reconstruction refers to recovering a corresponding high-resolution image from a given low-resolution image via a specific algorithm. More concretely, it is the process of reconstructing a high-resolution image from a given low-resolution image using knowledge from digital image processing, computer vision, and related fields, through specific algorithms and processing pipelines. It aims to overcome or compensate for problems such as image blur, low quality, and inconspicuous regions of interest caused by limitations of the acquisition system or environment.
Put simply, super-resolution reconstruction turns a small image into a large one, making the image "sharper". The effect is shown in the figure below.
The ill-posed nature of the super-resolution problem shows up especially at high upscaling factors: reconstructed SR images usually lack texture detail. Supervised SR algorithms typically take as their optimization objective the mean squared error (MSE) between the reconstructed high-resolution image and the ground truth; minimizing the MSE also maximizes the peak signal-to-noise ratio (PSNR), a common measure for evaluating and comparing SR algorithms. However, high MSE and PSNR values do not capture perceptual quality well. As Figure 2 shows, the highest PSNR does not necessarily reflect the perceptually best SR result.
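PSNR is referenced throughout what follows, so here is a minimal sketch of how it relates to MSE (assuming 8-bit images with peak value 255; the helper name psnr is illustrative):

import numpy as np

def psnr(img_a, img_b, peak=255.0):
    # PSNR = 10 * log10(peak^2 / MSE); minimizing MSE maximizes PSNR
    mse = np.mean((img_a.astype(np.float64) - img_b.astype(np.float64)) ** 2)
    return 10.0 * np.log10(peak ** 2 / mse)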
SRGAN was proposed in the paper "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network". The paper uses a deep residual network (ResNet) with skip-connections, and defines a new perceptual loss on high-level feature maps of the VGG network; together with a discriminator, this loss pushes the generated high-resolution images to be visually as similar as possible to the original ones.
The experiment in this post is based on SRGAN's super-resolution method; the paper's main theoretical contributions are analyzed below:
- Built SRResNet, a 16-block ResNet trained with MSE loss, as the backbone of the generator network.
- Proposed the SRGAN network optimized with a perceptual loss, replacing the plain MSE content loss with a Euclidean distance computed on VGG feature maps. The added discriminator increases the realism of generated images, and the improved perceptual loss measures similarity to the original image globally rather than pixel by pixel.
- Used a subjective evaluation measure, MOS (mean opinion score), which puts more emphasis on human perception.
Core goal: train a generator $G$ that performs super-resolution restoration on low-resolution images.

$I^{HR}$: the original high-resolution image
$I^{SR}$: the super-resolved image
$I^{LR}$: the low-resolution image obtained from the original high-resolution image by Gaussian filtering + bicubic downsampling
Generator parameter update:

$$\hat{\theta}_{G}=\arg \min _{\theta_{G}} \frac{1}{N} \sum_{n=1}^{N} l^{SR}\left(G_{\theta_{G}}\left(I_{n}^{LR}\right), I_{n}^{HR}\right)$$
This formula is the core of the generator's optimization: the low-resolution image obtained by downsampling the original high-resolution image is restored by the generator, the result is compared with the original high-resolution image to compute the loss, and the total loss is then minimized along the negative gradient direction.
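As a minimal sketch of this objective (assuming generic generator and perceptual_loss callables; the reproduction below uses Keras train_on_batch instead of an explicit tape):

import tensorflow as tf

def generator_step(generator, perceptual_loss, optimizer, imgs_lr, imgs_hr):
    # One gradient step on theta_G for a mini-batch of N (LR, HR) pairs
    with tf.GradientTape() as tape:
        imgs_sr = generator(imgs_lr, training=True)               # G(I_LR)
        loss = tf.reduce_mean(perceptual_loss(imgs_sr, imgs_hr))  # mean of l_SR
    grads = tape.gradient(loss, generator.trainable_variables)
    optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    return loss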
As one of the paper's innovations, the loss function $l^{SR}$ is not the usual MSE loss; instead, the authors propose an improved perceptual loss function, obtained as a weighted sum of a content loss and an adversarial loss:
$$l^{SR}=\underbrace{\underbrace{l_{\mathrm{X}}^{SR}}_{\text{content loss}}+\underbrace{10^{-3}\, l_{Gen}^{SR}}_{\text{adversarial loss}}}_{\text{perceptual loss (for VGG based content losses)}}$$
The usual content loss is the pixel-wise MSE loss:
$$l_{MSE}^{SR}=\frac{1}{r^{2} W H} \sum_{x=1}^{r W} \sum_{y=1}^{r H}\left(I_{x, y}^{HR}-G_{\theta_{G}}\left(I^{LR}\right)_{x, y}\right)^{2}$$
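A one-line sketch of this pixel-wise MSE (assuming img_hr and img_sr already share the same $rW \times rH$ shape):

import tensorflow as tf

def mse_content_loss(img_hr, img_sr):
    # Mean squared difference over all rW x rH pixel positions
    return tf.reduce_mean(tf.square(img_hr - img_sr))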
Here, however, the authors replace the traditional pixel-wise MSE loss with a VGG loss, which better accounts for global structure. $\phi_{i,j}$ denotes the $W_{i,j} \times H_{i,j}$ feature map obtained by the $j$-th convolution (after activation) before the $i$-th max-pooling layer of the VGG19 network; the loss is the Euclidean distance between the VGG19 feature maps of the original image and the generated image:
$$l_{VGG/i.j}^{SR}=\frac{1}{W_{i, j} H_{i, j}} \sum_{x=1}^{W_{i, j}} \sum_{y=1}^{H_{i, j}}\left(\phi_{i, j}\left(I^{HR}\right)_{x, y}-\phi_{i, j}\left(G_{\theta_{G}}\left(I^{LR}\right)\right)_{x, y}\right)^{2}$$
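A sketch of this VGG content loss (here $\phi$ is built from the named VGG19 layer 'block2_conv2', i.e. $\phi_{2,2}$, for clarity; the reproduction below instead slices vgg.layers[9]):

import tensorflow as tf
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.models import Model

vgg = VGG19(weights='imagenet', include_top=False, input_shape=(None, None, 3))
phi = Model(vgg.input, vgg.get_layer('block2_conv2').output)  # phi_{2,2}
phi.trainable = False

def vgg_content_loss(img_hr, img_sr):
    # MSE between VGG19 feature maps of the original and generated images
    return tf.reduce_mean(tf.square(phi(img_hr) - phi(img_sr)))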
The adversarial loss $l_{Gen}^{SR}$ is based on summing the discriminator's predictions over all training samples; $D_{\theta_{D}}\left(G_{\theta_{G}}\left(I^{LR}\right)\right)$ is the probability that the discriminator classifies the generated image as an original high-resolution image. For better gradient behavior, $-\log D_{\theta_{D}}\left(G_{\theta_{G}}\left(I^{LR}\right)\right)$ is minimized instead of $\log \left[1-D_{\theta_{D}}\left(G_{\theta_{G}}\left(I^{LR}\right)\right)\right]$:
$$l_{Gen}^{SR}=\sum_{n=1}^{N}-\log D_{\theta_{D}}\left(G_{\theta_{G}}\left(I^{LR}\right)\right)$$
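A sketch of this term (the small epsilon is an added numerical safeguard, not part of the paper's formula):

import tensorflow as tf

def adversarial_loss(d_fake, eps=1e-8):
    # d_fake = D(G(I_LR)): probability the discriminator assigns to the
    # generated image being a real HR image; -log D(G(I_LR)) gives stronger
    # gradients early in training than log(1 - D(G(I_LR)))
    return tf.reduce_sum(-tf.math.log(d_fake + eps))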
The authors propose a further evaluation measure: the mean opinion score (MOS).
Specifically, 26 raters were asked to assign the super-resolved images an integral score from 1 (bad quality) to 5 (excellent quality). Each rater judged 12 versions of every image in the datasets: nearest neighbor (NN), bicubic, SRCNN, SelfExSR, DRCN, ESPCN, SRResNet-MSE, SRResNet-VGG22* (* not rated on BSD100), SRGAN-MSE*, SRGAN-VGG22*, SRGAN-VGG54, and the original HR image. The raters were additionally calibrated on the NN (score 1) and HR (score 5) versions of 20 images from the BSD300 training set.
The experiments show that this measure is reliable, with no significant differences between ratings of the same image. Raters very consistently rated NN-interpolated test images as 1 and the original HR images as 5 (see the figure below).
The ImageNet dataset is used as the data source; images are randomly selected from it for training and kept separate from the test images.
For each mini-batch, the 16 HR training images are obtained by random 96×96 crops of the original images.
The HR images are then 4× downsampled with a bicubic kernel to obtain 24×24 LR images.
LR images are normalized to $[0, 1]$ (dividing by 255), and HR images are normalized to $[-1, 1]$ (dividing by 127.5 and subtracting 1).
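In code these two scalings amount to the following sketch (assuming uint8 pixel arrays; note that the reproduction below actually maps both LR and HR images to [-1, 1]):

import numpy as np

def normalize(imgs_lr, imgs_hr):
    # LR images -> [0, 1], HR images -> [-1, 1]
    return (np.asarray(imgs_lr, np.float32) / 255.0,
            np.asarray(imgs_hr, np.float32) / 127.5 - 1.0)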
SRResNet is trained with the Adam optimizer ($\beta_1 = 0.9$) at a learning rate of 1e-4 for 1e6 update iterations. To avoid undesirable local optima, SRResNet is trained first and its weights are used as the generator's initial weights.
SRGAN is also trained with Adam: first 1e5 iterations at a learning rate of 1e-4, then another 1e5 iterations at 1e-5. In each round the generator and discriminator parameters are updated alternately.
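A sketch of this two-stage schedule with Keras (the iteration counts stand in for whatever training loop is used):

import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import Adam

optimizer = Adam(learning_rate=1e-4, beta_1=0.9)
# ... run 10**5 update iterations at learning rate 1e-4 ...
K.set_value(optimizer.learning_rate, 1e-5)
# ... run another 10**5 update iterations at learning rate 1e-5 ...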
SRGAN-VGG22: $l_{VGG/2.2}^{SR}$, with $\phi_{2,2}$ defining the loss on feature maps of lower-level features.
SRGAN-VGG54: $l_{VGG/5.4}^{SR}$, with $\phi_{5,4}$ defining the loss on higher-level feature maps from deeper layers, which have more potential to focus on image content.
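In Keras these two feature extractors correspond to named VGG19 layers; a minimal sketch:

from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.models import Model

vgg = VGG19(weights='imagenet', include_top=False)
# phi_{2,2}: 2nd conv (after activation) before the 2nd max-pool
phi_22 = Model(vgg.input, vgg.get_layer('block2_conv2').output)
# phi_{5,4}: 4th conv (after activation) before the 5th max-pool
phi_54 = Model(vgg.input, vgg.get_layer('block5_conv4').output)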
The first two columns of the table are an ablation based on SRResNet, trained with the plain MSE loss and the feature (VGG) loss respectively. Even when the adversarial network is used, the MSE loss still yields a higher PSNR than the perceptual loss; perceptually, however, the resulting images are smoother and less convincing, and the MOS results faithfully reflect this.
Comparing further, the perceptual loss differs little from MSE on Set5, but on Set14 the MOS score of SRGAN-VGG54 is clearly better than that of the other methods. Moreover, compared with $\phi_{2,2}$, using the higher-level feature maps $\phi_{5,4}$ produces finer texture detail.
Super-resolution of CelebA using Generative Adversarial Networks.
The dataset can be downloaded from: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=0
Instructions on running the script:
1. Download the dataset from the provided link
2. Save the folder 'img_align_celeba' to 'datasets/'
3. Run the script using command 'python srgan.py'
from __future__ import print_function, division

import os
import datetime

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
from PIL import ImageFilter
from glob import glob

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from tensorflow.keras.layers import PReLU, LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.losses import MeanSquaredError, BinaryCrossentropy
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model

matplotlib.use('Agg')
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
class DataLoader(object):
    def __init__(self, dataset_name, img_res=(256, 256)):
        self.dataset_name = dataset_name
        self.img_res = img_res
        self.image_paths = glob('./datasets/%s/*' % self.dataset_name)
        np.random.shuffle(self.image_paths)
        self.images = self.image_paths

    def reload_dataset(self):
        self.images = self.image_paths
    def load_data(self, batch_size=1, is_testing=False):
        # Sample one source image without replacement
        batch_images = np.random.choice(self.images, size=1, replace=False)
        img = Image.open(batch_images[0]).convert('RGB')
        # If the source image is smaller than the crop size, keep re-sampling
        while (img.size[0] - self.img_res[0] - 1 <= 0
                or img.size[1] - self.img_res[1] - 1 <= 0):
            batch_images = np.random.choice(self.images, size=1, replace=False)
            img = Image.open(batch_images[0]).convert('RGB')

        imgs_hr = []
        imgs_lr = []
        for _ in range(batch_size):
            # Random crop of an img_res HR sub-image (96 x 96 in the paper)
            left_upx = np.random.randint(0, img.size[0] - self.img_res[0] - 1)
            left_upy = np.random.randint(0, img.size[1] - self.img_res[1] - 1)
            img_hr = img.crop((left_upx, left_upy,
                               left_upx + self.img_res[0],
                               left_upy + self.img_res[1]))
            # Gaussian filter
            img_lr = img_hr.filter(ImageFilter.GaussianBlur(1.5))
            # 4x bicubic downsampling
            img_lr = img_lr.resize((self.img_res[0] // 4, self.img_res[1] // 4),
                                   Image.Resampling.BICUBIC)
            # Convert to numpy
            img_hr = np.array(img_hr)
            img_lr = np.array(img_lr)
            # If training => random horizontal flip
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)
            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)

        # Normalize to [-1, 1]
        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.
        return imgs_hr, imgs_lr
    def load_spec_data(self):
        os.makedirs('images/_src', exist_ok=True)
        path = glob('./datasets/img_test/*')
        imgs_hr = []
        imgs_lr = []
        for i, img_path in enumerate(path):
            img_hr = Image.open(img_path).convert('RGB')
            # Gaussian filter
            img_lr = img_hr.filter(ImageFilter.GaussianBlur(1.5))
            # 4x bicubic downsampling
            img_lr = img_lr.resize((img_hr.size[0] // 4, img_hr.size[1] // 4),
                                   Image.Resampling.BICUBIC)
            img_lr.save(f'./images/_src/{i}.png')
            # Convert to numpy
            imgs_hr.append(np.array(img_hr))
            imgs_lr.append(np.array(img_lr))
        # Normalize to [-1, 1]
        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.
        return imgs_hr, imgs_lr
class SRGAN():
    def __init__(self):
        # Input shape: the network is fully convolutional, so the spatial
        # dimensions can be left undefined
        self.channels = 3
        self.lr_shape = (None, None, self.channels)
        self.hr_shape = (None, None, self.channels)

        # Number of residual blocks in the generator
        self.n_residual_blocks = 16
        # Number of filters in the first layer of G and D
        self.gf = self.df = 64

        # Configure data loader
        # self.dataset_name = 'img_align_celeba'
        self.dataset_name = 'test2017'
        self.data_loader = DataLoader(dataset_name=self.dataset_name)
        self.test_imgs = self.data_loader.load_spec_data()

        optimizer = Adam(0.0001, 0.9)

        # We use a pre-trained VGG19 model to extract image features from the
        # high resolution and the generated high resolution images and
        # minimize the mse between them
        vgg = VGG19(weights="imagenet", input_shape=self.hr_shape, include_top=False)
        self.vgg = Model(vgg.input, outputs=vgg.layers[9].output, trainable=False)
        self.vgg.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # High res. and low res. images
        img_hr = Input(self.hr_shape)
        img_lr = Input(self.lr_shape)

        # Build the generator
        self.generator = self.build_generator()
        self.generator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

        # Generate high res. version from low res.
        fake_hr = self.generator(img_lr)

        # Perceptual loss: 1e-3 * adversarial loss + vgg loss
        self.combined = Model([img_lr, img_hr],
                              [self.discriminator(fake_hr), self.vgg(fake_hr)])
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1 / (12.75 ** 2)],
                              optimizer=optimizer)

        plot_model(self.generator,            # Keras model
                   to_file="generator.png",   # output image path
                   show_shapes=True,          # show shape information
                   show_layer_names=True,     # show layer names
                   rankdir="TB",              # "TB": vertical, "LR": horizontal
                   expand_nested=True,        # expand nested models into clusters
                   dpi=96)                    # dots per inch
        plot_model(self.discriminator,
                   to_file="discriminator.png",
                   show_shapes=True,
                   show_layer_names=True,
                   rankdir="TB",
                   expand_nested=True,
                   dpi=96)
    def build_generator(self):

        def residual_block(layer_input):
            """Residual block described in paper"""
            d = Conv2D(64, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(64, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d

        def deconv2d(layer_input):
            """Layers used during upsampling"""
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(layer_input)
            u = UpSampling2D(size=2)(u)
            u = Activation('relu')(u)
            return u

        # Low resolution image input
        img_lr = Input(shape=(None, None, 3))

        # Pre-residual block
        c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)

        # Propagate through residual blocks
        r = residual_block(c1)
        for _ in range(self.n_residual_blocks - 1):
            r = residual_block(r)

        # Post-residual block
        c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])

        # Upsampling (two 2x steps => 4x)
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)

        # Generate high resolution output
        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1,
                        padding='same', activation='tanh')(u2)

        return Model(img_lr, gen_hr)
    def build_discriminator(self):

        def d_block(layer_input, filters, strides=1, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        # Input img
        d0 = Input(shape=(None, None, 3))

        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df * 2)
        d4 = d_block(d3, self.df * 2, strides=2)
        d5 = d_block(d4, self.df * 4)
        d6 = d_block(d5, self.df * 4, strides=2)
        d7 = d_block(d6, self.df * 8)
        d8 = d_block(d7, self.df * 8, strides=2)

        d9 = Dense(self.df * 16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        return Model(d0, validity)
    def pre_train(self, epochs, batch_size=1, sample_interval=50):
        self.generator.load_weights('./weights/pre_training_checkpoints/')
        self.data_loader.reload_dataset()
        start_time = datetime.datetime.now()
        last_time = datetime.datetime.now()
        for epoch in range(epochs):
            # ------------------
            #  Train Generator
            # ------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # Train the generator with plain MSE (SRResNet pre-training)
            self.generator.train_on_batch(imgs_lr, imgs_hr)

            # If at save interval => save weights and generated image samples
            if epoch % sample_interval == 0:
                self.generator.save_weights('./weights/pre_training_checkpoints/')
                loss = self.evaluate(epoch, comp_dir="_res_gen")
                elapsed_time = datetime.datetime.now() - start_time
                used_time = datetime.datetime.now() - last_time
                # Plot the progress
                print(f"epoch: {epoch}\tg_loss: {loss}\ttime: {elapsed_time}\tinterval: {used_time}")
                last_time = datetime.datetime.now()
    def train(self, epochs, batch_size=1, sample_interval=50):
        self.generator.load_weights('./weights/pre_training_checkpoints/')
        self.data_loader.reload_dataset()
        start_time = datetime.datetime.now()
        last_time = datetime.datetime.now()
        # self.generator.summary()
        # self.discriminator.summary()
        for epoch in range(epochs):
            # ----------------------
            #  Train Discriminator
            # ----------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # From low res. image generate high res. version
            fake_hr = self.generator.predict(imgs_lr)

            # Calculate output shape of D (PatchGAN)
            patch_w = imgs_hr.shape[1] // 2 ** 4
            patch_h = imgs_hr.shape[2] // 2 ** 4
            valid = np.ones((batch_size, patch_w, patch_h, 1))
            fake = np.zeros((batch_size, patch_w, patch_h, 1))

            # Train the discriminator (original images = real / generated = fake)
            d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
            d_loss = (0.5 * np.add(d_loss_real, d_loss_fake))[0]

            # ------------------
            #  Train Generator
            # ------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # The generator wants the discriminator to label all generated images as real
            patch_w = imgs_hr.shape[1] // 2 ** 4
            patch_h = imgs_hr.shape[2] // 2 ** 4
            valid = np.ones((batch_size, patch_w, patch_h, 1))

            # Extract ground truth image features using pre-trained VGG19 model
            image_features = self.vgg.predict(imgs_hr)

            # Train the generator
            self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

            # If at save interval => save weights and generated image samples
            if epoch % sample_interval == 0:
                self.generator.save_weights('./weights/training_checkpoints/')
                # checkpoint = tf.train.Checkpoint(self.generator)
                # checkpoint.save('./weights/training_checkpoints')
                # self.sample_images(epoch)
                loss = self.evaluate(epoch, comp_dir="_res_gan", src_dir="_res_src_gan")
                elapsed_time = datetime.datetime.now() - start_time
                used_time = datetime.datetime.now() - last_time
                # Plot the progress
                print(f"epoch: {epoch}\td_loss: %.5f\tg_loss: (%.5f, %.5f)\ttime: {elapsed_time}\tinterval: {used_time}"
                      % (d_loss, loss[0], loss[1]))
                last_time = datetime.datetime.now()
    def sample_images(self, epoch):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 2

        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, is_testing=True)
        fake_hr = self.generator.predict(imgs_lr)

        # Rescale images 0 - 1
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5

        # Save generated images and the high resolution originals
        titles = ['Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for row in range(r):
            for col, image in enumerate([fake_hr, imgs_hr]):
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')
            cnt += 1
        fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))
        plt.close()

        # # Save low resolution images for comparison
        # for i in range(r):
        #     fig = plt.figure()
        #     plt.imshow(imgs_lr[i])
        #     fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))
        #     plt.close()
    def evaluate(self, epoch, comp_dir='', src_dir='', testing=False):
        if testing:
            self.generator.load_weights('./weights/pre_training_checkpoints/')
        r, c = 2, 3
        imgs_hr = self.test_imgs[0]
        imgs_lr = self.test_imgs[1]
        fake_hr = self.generator.predict(imgs_lr)
        precision, vgg_feature = self.combined.predict([imgs_lr, imgs_hr])

        # MSE loss
        MSE_loss = MeanSquaredError()(fake_hr, imgs_hr).numpy()

        # Perceptual loss
        patch_w = imgs_hr.shape[1] // 2 ** 4
        patch_h = imgs_hr.shape[2] // 2 ** 4
        valid = np.ones((c, patch_w, patch_h, 1))
        content_loss = MeanSquaredError()(self.vgg.predict(imgs_hr), vgg_feature).numpy()
        adversarial_loss = BinaryCrossentropy()(valid, precision).numpy()
        Perceptual_loss = content_loss / (12.75 ** 2) + adversarial_loss / 1000

        # Rescale images 0 - 1
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_lr = 0.5 * imgs_lr + 0.5

        if comp_dir:
            os.makedirs(f'images/{comp_dir}', exist_ok=True)
            # Save generated images alongside the low resolution inputs
            titles = ['Generated', 'Low-resolution']
            fig, axs = plt.subplots(r, c)
            plt.suptitle(f'epoch: {epoch} MSE_loss: %.5f Perceptual_loss: %.5f'
                         % (MSE_loss, Perceptual_loss))
            for col in range(c):
                for row, image in enumerate([fake_hr, imgs_lr]):
                    axs[row, col].imshow(image[col])
                    axs[row, col].set_title(titles[row])
                    axs[row, col].axis('off')
            fig.savefig(f"images/{comp_dir}/%d.png" % epoch)
            plt.close()

        if src_dir:
            os.makedirs(f'images/{src_dir}', exist_ok=True)
            # Save generated SR images for comparison
            for i in range(c):
                im = Image.fromarray((255 * fake_hr[i]).astype(np.uint8))
                im.save(f"images/{src_dir}/%d_res%d.png" % (epoch, i))

        return MSE_loss, Perceptual_loss
if __name__ == '__main__':
    gan = SRGAN()
    gan.pre_train(epochs=10000, batch_size=16, sample_interval=100)
    gan.train(epochs=100000, batch_size=16, sample_interval=50)
    gan.evaluate(0, src_dir='_res', testing=True)