超分辨率重建——SESR网络训练并推理测试(详细图文教程)

最近学了一个超轻量化的超分辨率重建网络SESR,效果还不错。
在这里插入图片描述

一、 源码包

SESR官网的地址为:官网

我自己调整过的源码包地址为:SESR完整包 提取码:b80m

论文地址:论文

源码包推荐使用我给的,我注释过很多地方,看起来不吃力,且我自己添加了推理测试脚本。

下载好源码包解压后的样子如下:

在这里插入图片描述

二、 数据集的准备

获取数据集可以有两种方法:

2.1 官网下载

直接运行源码包中的脚本文件train.py,会自动先下载div2k数据集,但是下载的非常慢,高分辨率数据集有3G多,容易下蹦了。默认会下载到系统C盘下,具体路径为:

C:\Users\Administrator\tensorflow_datasets\downloads,每次下载失败后再次运行又会重新生成序列码并重新下载,很麻烦。

如下:
在这里插入图片描述

2.2 网盘下载

我提供了一个我下载好并整理好的数据集,文件存放对应关系我都整理好了,学者可以直接下载导入使用,下载链接为:网盘下载 ,提取码为:32d4

三、 训练环境配置

该网络结构是在TensorFlow框架下运行的TensorFlow版本是2.3,还有一个包的版本是tensorflow_datasets==4.1,Pyhton3.6版本,额。。。。。。。。。。。。。。。。。。

踩了很多坑,最后我自己调通的版本是TensorFlow-gpu2.9,Python 3.7版本,tensorflow_datasets4.8.2,如下:

在这里插入图片描述

安装好TensorFlow-GPU后先测试一下能不能正常调用GPU,测试方法参考:添加链接描述

四、训练

4.1 修改配置参数

打开train.py文件,里面有些配置参数根据自己电脑情况修改:

在这里插入图片描述

train.py脚本中对应上图修改的地方如下:

在这里插入图片描述

4.2 导入数据集

下载好我提供的数据集后,解压好讲整个tensorflow_datasets文件夹放到data文件夹中,并将tensorflow_datasets文件夹所在路径赋值给变量data_dir,代码中具体的修改地方如下:

在这里插入图片描述

4.3 2倍超分网络训练

根据自己需求选择要训练深度:

4.3.1 训练SESR-M5网络

其中m = 5,f = 16,feature_size = 256,具有折叠线性块:

python train.py

4.3.2 训练SESR-M5网络

m = 5,f = 16,feature_size = 256,扩展线性块:

python train.py --linear_block_type expanded

4.3.3 训练SESR-M11网络

其中m = 11,f = 16,feature_size = 64,具有折叠线性块:

python train.py --m 11 --feature_size 64

4.4.4 训练SESR-XL网络

其中m = 11,f = 16,feature_size = 64,具有折叠线性块:

python train.py --m 11 --int_features 32 --feature_size 64

4.4 2倍超分网络模型

通过上面步骤训练好后会在logs文件中自动保存权重文件和模型,我自己训练好的模型权重文件都打包在源码包了,学者可以直接使用,如下:

在这里插入图片描述
上面各个文件代表内容为:

.pb:表示protocol buffers,是模型结构和参数的二进制序列化文件。存储了模型的网络结构,变量,权重等信息。是模型persist的主要文件。

.data-00000-of-00001:存储了模型变量的取值,即模型权重参数的值。模型训练完成后保存的权重。

.index:索引文件,存放了参数tensor的meta信息,如tensor名称、维度等。用于定位data文件中的tensor数据。

checkpoints文件:存储模型训练过程中的参数,用于恢复训练。

4.5 修改模型保存格式

上面是默认的保存方式,学长如果需要其他格式的自己修改保存方法,具体修改地方如下:

在这里插入图片描述

4.6 4倍超分网络训练

4倍超分网络得在2倍超分模型基础上训练才行,网络深度自己选择:

4.6.1 训练SESR-M5网络

其中m = 5,f = 16,feature_size = 256,具有折叠线性块:

python train.py --scale 4

4.6.2 训练SESR-M5网络

m = 5,f = 16,feature_size = 256,扩展线性块:

python train.py --linear_block_type expanded --scale 4

4.6.3 训练SESR-M11网络

其中m = 11,f = 16,feature_size = 64,具有折叠线性块:

python train.py --m 11 --feature_size 64 --scale 4

4.6.4 训练SESR-XL网络

其中m = 11,f = 16,feature_size = 64,具有折叠线性块:

python train.py --m 11 --int_features 32 --feature_size 64 --scale 4

4.7 4倍超分网络模型

训练好后,模型会自动保存在logs文件中,如下:

在这里插入图片描述

五、量化训练

运行以下命令,在训练时对网络进行调试,并生成TFLITE(用于x2 SISR、SESR-M5网络):

python train.py --quant_W --quant_A --gen_tflite

5.1 量化训练模型

训练好后自动保存在logs/x2_models文件下,如下:

在这里插入图片描述

六、模型推理测试

推理脚本是我自己写的,具体使用如下,根据需求自行选择:

在这里插入图片描述在这里插入图片描述
在这里插入图片描述

七、超分效果

在这里插入图片描述
在这里插入图片描述

八、总结

以上就是超分辨率重建——SESR网络训练并推理测试的详细图文教程,总结不易,给个三连多多支持,谢谢!欢迎留言讨论。

猜你喜欢

转载自blog.csdn.net/qq_40280673/article/details/134062403