tensorflow超分辨率代码解析

详见代码 : https://github.com/jiny2001/dcscn-super-resolution
在这里插入图片描述

load_datasets 主要读取文件夹下面的图片,并且把大的图片,裁剪成训练需要的小图片。保存在文件夹下面。
build_graph 主要是建立tensorflow graph 即建立cnn网络结构,建立loss公式。
x 表示有损原始输入。x2表示采用传统算法放大的图像。
self.y_ = self.H_out + self.x2//因为网络实际学习是放大的图像与原始图像的残差。
build_optimizer计算对应的loss以及mse损失:
diff = self.y_ - self.y
loss = mse
self.training_optimizer = self.add_optimizer_op(loss, self.lr_input)//最小化loss,对所有参数变理进行梯度并且相减。
build_input_batch读入待训练数据
train_batch sess运行,feed数据。
, mse = self.sess.run([self.training_optimizer, self.mse]//得到前一次训练。一次训练里batch_input读入batch_num张64x64大小图片。

对于超分辨率来说,input_image 是缩小的图像,quad_image是由于input_image放大2*2图像,然后再变为4通道图像,即quad_image图像大小与缩小图像一样,但是是4通道的。输出y 是input_image 的滤波与quad_image相对,然后再把4通道变为放大图像。

关于训练。
train(model, FLAGS, i) 表示训练一次完成,完成的条件是学习率下降到指定的学习率。
在此训练函数中,包括:
epochs_completed += 1//如果训练的所有数据完成,那么此加1
train_batch 调用一次就更新一次梯度,下一次使用更新后的梯度进行测试。
if epochs_completed < model.epochs_completed//每次一新epoch训练后,测试一下test。
调用mse = model.evaluate() 得到当前的mse,如果比上一次min_mse更小,那么保存模型。

mse = model.do_for_evaluate训练完成后,使用真实的图像进行测试

训练方法:
小网络训练:python train.py --dataset bsd200 --filters 16 --min_filters 8 --nin_filters 24 --nin_filters2 8
训练之后测试单张图片:
python sr.py --file ./data/valid_decode/20170930_131716.png --model_name epoch_00000063 --output_dir ./output --output_name c20170930_131716.png

使用训练好的模型:
主要是通过调用do_for_file来进行。
通过RGB转YUV,只对Y图像进行滤波。 input_y_image 是y的数据,shape[h][w][1]
通过使用output_y_image = self.do(input_y_image, input_y_image) 得到滤波后的Y
再与原始的uv数据,组合成yuv444,再转为RGB输出最终的图像。

猜你喜欢

转载自blog.csdn.net/bvngh3247/article/details/88366672