PocketFlow ChannelPrunedLearner代码详解

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/hw5226349/article/details/89025690

下面代码不包括DDPG强化学习参数优化器和Distill蒸馏训练

PocketFlow框架安装

conda create --name PocketFlow python=3.6

source activate PocketFlow

pip install tensorflow-gpu==1.10.0

pip install numpy==1.14.5

conda install pandas

conda install scikit-learn

cifar10数据集准备

cifar-10:使用binary版本

wget https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz

下载完成后,解压到data目录

并在path.conf中设置路径:data_dir_local_cifar10 = /home/mars/hewu/tensorflow/PocketFlow/data/cifar-10-batches-bin

执行代码

./scripts/run_local.sh nets/resnet_at_cifar10_run.py 

官方教程需要在工程根目录配置path.conf文件,然后执行上述脚本。个人觉得不太方便调试,直接启动py,单步跟踪,更加便于理解程序逻辑。

nets/resnet_at_cifar10_run.py --model_http_url https://api.ai.tencent.com/pocketflow --data_dir_local /home/mars/hewu/tensorflow/PocketFlow/data/cifar-10-batches-bin --learner channel --cp_prune_option uniform

程序入口-main.py

# /home/mars/hewu/tensorflow/PocketFlow/main.py
from nets.resnet_at_cifar10 import ModelHelper
from learners.learner_utils import create_learner

# 1. Create the model helper and the learner.
model_helper = ModelHelper()  # wraps the network definition and the dataset pipeline
# NOTE(review): the summary writer is named `sm_writer` everywhere else in this
# article (ChannelPrunedLearner, ChannelPruner); the original `sw_writer` was a typo.
learner = create_learner(sm_writer, model_helper)  # dispatches to the selected compression learner

# 2. Run training, or evaluation.
learner.train()
learner.evaluate()

通道裁剪学习器-channel_pruning/learner.py

#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/learner.py
from learners.distillation_helper import DistillationHelper #蒸馏相关
from learners.abstract_learner import AbstractLearner
from learners.channel_pruning.model_wrapper import Model #模型相关
from learners.channel_pruning.channel_pruner import ChannelPruner #裁剪相关
from rl_agents.ddpg.agent import Agent as DdpgAgent #强化学习代理DDPG

# Inherits from AbstractLearner.
class ChannelPrunedLearner(AbstractLearner):
  def __init__(self, sm_writer, model_helper):
    # Parent-class initialization.
    super(ChannelPrunedLearner, self).__init__(sm_writer, model_helper)

    # In-class initialization.
    # Distillation helper.
    self.learner_dst = DistillationHelper(sm_writer, model_helper)

    # Build: input data pipeline, model definition, pruning bounds, etc.
    self.__build(is_train=True)

  # 1. train() entry point.
  def train(self):
    """Download the pre-trained model, restore weights, create the pruner,
    then prune with the selected strategy."""
    # ...
    self.create_pruner()
    # Choose the pruning strategy: list, auto, or uniform.
    if FLAGS.cp_prune_option == 'list':
      self.__prune_and_finetune_list()
      #self.__prune_and_finetune_auto()
      #self.__prune_and_finetune_uniform()

  # 2. Create the channel pruner.
  def create_pruner(self):
    # ...
    self.model = Model(self.sess_train)
    # NOTE(review): the original excerpt misspelled the class as `ChannlPruner`;
    # the import above is `ChannelPruner`.
    self.pruner = ChannelPruner(
      self.model,               # wrapped model
      images=train_images,
      labels=train_labels,
      mem_images=mem_images,
      mem_labels=mem_labels,
      metrics=metrics,          # metrics: loss, accuracy
      lbound=self.lbound,       # lower bound on the preserved-channel ratio
      summary_op=summary_op,
      sm_writer=self.sm_writer)

  # 3. The `auto` strategy as an example of the concrete pruning flow.
  def __prune_and_finetune_auto(self):
    self.__prune_rl()  # init the RL agent and prune (calls compress) to learn the best strategy
    while not done:    # alternate pruning and fine-tuning until done
      done = self.__prune_list_layers(queue, [FLAGS.cp_list_group])

  def __prune_rl(self):
    # RL search for the pruning strategy (body elided in the article).
    pass

  # 5. Both __prune_rl() and __prune_list_layers() end up calling compress().
  def compress(self, c_ratio):
    # "Pruning" here only zeroes the weights of the selected channels; nothing
    # is physically removed from the graph at this stage.
    self.prune_kernel(conv_op, c_ratio)   # channel selection (e.g. LASSO)
    self.prune_W1(father_conv, idxs)      # prune the parent conv's OUTPUT channels (= this conv's input channels)
    self.prune_W2(conv_op, idxs, W2)      # prune this conv's INPUT channels

  def prune_kernel(self, op, nb_channel_new):
    """Concrete pruning steps for one convolution."""
    # Preserved-channel ratio -> absolute channel count (at least 1).
    nb_channel_new = max(int(np.around(c * nb_channel_new)), 1)
    newX = self.__extract_input(op)     # this conv's input feature map
    Y = self.feats_dict[outname]        # regression targets
    W2 = self._model.param_data(op)     # current conv weights
    # LASSO pruning: new weights newW2 plus a True/False channel mask.
    idxs, newW2 = self.compute_pruned_kernel(newX, W2, Y, c_new=nb_channel_new)

  def compute_pruned_kernel(
      self,
      X,
      W2,
      Y,
      alpha=1e-4,
      c_new=None,
      tolerance=0.02):
    # Alternating optimization (excerpt):
    # with beta fixed, solve for W:
    while True:
      _, tmp, coef = solve(right)
      ...
    # with W fixed, solve for beta (the idxs mask IS beta):
    while True:
      idxs, tmp, coef = solve(alpha)
      ...

模型封装器-channel_pruning/model_wrapper.py

#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/model_wrapper.py
  def get_Add_if_is_first_after_resblock(self, op):
    # The Add op on the OUTPUT side of the residual block (body elided in the article).
   
   
  def get_Add_if_is_last_in_resblock(cls, op):
    # The Add op on the INPUT side of the residual block (body elided in the article).
    
  def is_W1_prunable(self, conv):
    # Whether this conv layer's W1 may be pruned (body elided in the article).

通道裁剪器-channel_pruning/channel_pruner.py

#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/channel_pruner.py
from sklearn.linear_model import LassoLars
from sklearn.linear_model import LinearRegression


class ChannelPruner(object):
  def __init__(self,...):
    self._model = model
    self.thisconvs = self._model.get_operations_by_type()  # all convolution ops in the network
    self.__build()
  def __build(self):
    self.__extract_output_of_conv_and_sum()  # collect the conv and Add ops into the self.names list
    self.__create_extractor()  # build the extractor that fetches each conv's input feature map
    self.initialize_state()  # initial state: mainly which layers are prunable, pruning ratios, etc.

  def initialize_state(self):
    # Maps op name -> preserved-ratio range; e.g. the first and last conv are
    # never pruned, so their range is [1.0, 1.0].
    self.max_strategy_dict = {} # collection of initial max [inp preserve, out preserve]
    # Maps op name -> (input channel list, output channel list); True keeps the
    # channel, False prunes it.
    self.fake_pruning_dict = {} # collection of fake pruning indices
    # State features per layer:
    #   layer  n (out channels)  c (in channels)  H  W  stride  maxreduce (max reduction)  layercomp (layer compute)
    # each column is normalized by its column-wise maximum.

resnet20裁剪:
权值维度:[KH,KW,Cin,Cout]
当前卷积都是裁剪输入通道,父卷积都是裁剪输出通道。
depthwise conv可以往前递推,直至找到一个普通的Conv2D OP,因为depthwise conv中不同channel之间没有dependency

裁剪的当前卷积 裁剪的父卷积
conv2d_1 conv2d
conv2d_2 conv2d
conv2d_3 conv2d_2
conv2d_5 conv2d_4
conv2d_7 conv2d_6
conv2d_10 conv2d_9
conv2d_12 conv2d_11
conv2d_14 conv2d_13
conv2d_17 conv2d_16
conv2d_19 conv2d_18
conv2d_4 Add(conv2d_4可以裁剪输入通道,但是转pb时需要在其前面插入tf.gather)
conv2d_6 Add
conv2d_8 Add
conv2d_9 Add
conv2d_11 Add
conv2d_13 Add
conv2d_15 Add
conv2d_16 Add
conv2d_18 Add
conv2d_20 Add

最后一个卷积不裁剪
conv2d_21

猜你喜欢

转载自blog.csdn.net/hw5226349/article/details/89025690