tensorflow 高版本模型如何兼容低版本

使用高版本的AI引擎训练,导出模型后转换成Ascend310芯片的OM格式时,有可能遇到算子不支持的情况出现,现在教大家如何合理规避这些算子。

以在TensorFlow-2.x上训练得到的模型为例,如何转换成低版本Ascend310芯片(如C32版本)可用的OM模型。更多的技巧通过这篇文章可以举一反三,灵活变通。

写在前面

由于Frozen Graph已经被TF-2.x抛弃,TF-2.x开始使用keras模型,导出是saved_model格式或者h5格式。想要转换OM模型,首先要得到TensorFlow-1.x上的Frozen Graph模型

在TF-2.x下导出Frozen Graph

假设你有一个TF-2.x下的keras model

model = tf.keras.Model(input_nodes, output_nodes)

通过以下这段代码转换成Frozen Graph 

import tensorflow as tf
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_graph_def = frozen_func.graph.as_graph_def()

# remove final nodes generated by keras which having so many indepencies inputs.
# this will help model to be opened by netron and to be converted to OM
frozen_sub_graph_def = tf.compat.v1.graph_util.extract_sub_graph(
  frozen_graph_def, dest_nodes=[out_node.name[:-2] for out_node in output_nodes])
  
# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_sub_graph_def,
                  logdir="/tmp/frozen_graph/",
                  name="model.pb",
                  as_text=False)

keras导出的模型会在最后加上Indentity节点,并且已整个模型为依赖,会导致模型不能转OM,并且netron也加载不了。

中间加入以下方法

tf.compat.v1.graph_util.extract_sub_graph(...)

可以把最后的Identity节点有效移除

一、算子版本过高

转换OM的时候,可能会遇到这种算子不存在的错误,如`FusedBatchNormV3`,这是因为低版本可能只支持到`FusedBatchNorm`为止,没有V3这个版本。

这个时候其实只要通过编辑Frozen Graph文件,简单的替换PB模型文件中的算子名称,把`FusedBatchNormV3`替换成`FusedBatchNorm`就可以了,计算是一样的,不会影响精度,只会影响性能。同类型的还有`AddV2`替换成`Add`,或者其他的算子,如果能找到早期版本的对应算子,就能合理规避。

二、算子新增参数不支持

比如在Conv2D这个算子上,高版本的TF引擎会有explicit_paddings这个选项,并且会写到Graph里,这时候转换OM就会报错,提示explicit_paddings这个attribute找不到,这个时候,也是编辑Frozen Graph文件,将这个attr从Conv2D这个op中去掉。一般来说这种低版本没有,高版本新增的参数特性,为了向前兼容,都默认关闭的,所以去掉一个attr不会影响精度。

奉上以上两种方法,通过编辑Frozen Graph来规避的实现代码

import os
import tempfile
import tensorflow as tf

from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2

TMP_PBTXT = 'tmp_model.pbtxt'
TMP_COMPAT_PBTXT = 'tmp_compat_model.pbtxt'

def merge_line_by_line(fo, input_graph_def):
  entry = 0
  item_lines = []
  for line in fo:
    if not line.strip():
      continue
    item_lines.append(line)
    if line.strip().endswith('{'):
      entry += 1
    elif line.strip().endswith('}'):
      entry -= 1
      if entry == 0:
        text_format.MergeLines(item_lines, input_graph_def)
        del item_lines[:]
        
def parse_input_graph_proto(input_graph, input_binary):
  if not os.path.exists(input_graph):
    raise ValueError('invalid input path')
  input_graph_def = graph_pb2.GraphDef()
  if input_binary:
    with open(input_graph, 'rb') as f:
      input_graph_content = f.read()
    input_graph_def.ParseFromString(input_graph_content)
  else:
    with open(input_graph, 'r') as f:
      merge_line_by_line(f, input_graph_def)
  return input_graph_def
  
def compat_pb(input_pb_path, replace=True):
  tmp_dir = tempfile.mkdtemp()
  tmp_pbtxt_file = os.path.join(tmp_dir, TMP_PBTXT)
  graph_def = parse_input_graph_proto(input_pb_path, input_binary=True)
  tf.train.write_graph(graph_or_graph_def=graph_def, logdir=tmp_dir, name=TMP_PBTXT, as_text=True)
  del graph_def
  new_graph_def_str = ''
  lines_to_cache = []
  num_lines_to_skip = 0
  with open(tmp_pbtxt_file, 'r') as f:
    for line in f:
      if num_lines_to_skip > 0:
        num_lines_to_skip -= 1
        continue
      if 'attr {' in line.strip():
        lines_to_cache.append(line)
        continue
      if line.strip().startswith('key: "explicit_paddings"'):
        del lines_to_cache[:]
        num_lines_to_skip = 5
        continue
      elif line.strip().startswith('key: "U"'):
        del lines_to_cache[:]
        num_lines_to_skip = 4
        continue
      elif line.strip().startswith('key: "half_pixel_centers"'):
        del lines_to_cache[:]
        num_lines_to_skip = 4
        continue
      if lines_to_cache:
        new_graph_def_str += ''.join(lines_to_cache)
        del lines_to_cache[:]
      new_graph_def_str += line
      
  new_graph_def_str = new_graph_def_str.replace('FusedBatchNormV3', 'FusedBatchNorm').replace('AddV2', 'Add')
  tmp_compat_pbtxt_file = os.path.join(tmp_dir, TMP_COMPAT_PBTXT)
  with open(tmp_compat_pbtxt_file, 'w') as f:
    f.write(new_graph_def_str)
    
  del new_graph_def_str
  graph_def_compat = parse_input_graph_proto(tmp_compat_pbtxt_file, input_binary=False)
  input_pb_dir, input_pb_name = os.path.split(input_pb_path)
  output_pb_dir = input_pb_dir
  if replace:
    output_pb_name = input_pb_name
  else:
    output_pb_name = 'compat_' + input_pb_name
    
  tf.train.write_graph(graph_or_graph_def=graph_def_compat, logdir=output_pb_dir, name=output_pb_name, as_text=False)
  
if __name__ == '__main__':
  compat_pb('/tmp/model.pb', replace=False)

虽然这里有一些硬编码,不过作为一个线下用用的工具,能满足功能就行了~

这个代片段主要是为了删除half_pixel_centers这个attribute对应的一组proto描述

      elif line.strip().startswith('key: "half_pixel_centers"'):
        del lines_to_cache[:]
        num_lines_to_skip = 4

如果读懂这个脚本,就能应对各种算子低版本兼容和属性删除了。

当前这个脚本已经可以应对很多TF-1.15向下兼容到Ascend310-C32的情况了。

三、算子本身不支持

比如leaky_relu在Ascend310-C32版本上是找不到算子实现的,那么这个时候只能通过用其他算子拼凑的方式替换了。这时候不能通过编译Frozen Graph文件来解决(过于复杂),推荐直接从源码修改。比如将

y =tf.nn.leaky_relu(x, alpha=alpha)

替换成

tf.maximum(alpha * x, x)

又例如mish激活函数找不到,那么可以用以下算子替换

y = x * tf.tanh(tf.math.log(1 + tf.exp(x)))

四、通过前/后处理规避

如果你的算子使用上述方式还不能支持,并且这个算子出现在模型的头上或者尾部,那恭喜你你还有希望。

你可以在导出模型的时候将算子涉及的这段计算从模型中拿出来,放到推理脚本的前处理后后处理。以伪代码举例

假设你的模型是:

def model(x):
    y = op1(x)
    y = op2(y)
    y = op3(y)

假设op1和op3都不支持,并且都是一些不含有网络权重的计算,那么你导出模型的时候只导出op2部分,将op1用numpy的API写在预处理中,将op3用numpy的API写在后处理中

例如:

在声音分类中,模型的最前面要对数据进行傅里叶变换,但是傅里叶变换算子在Ascend310-C32上不支持,那么在导出模型的时候将傅里叶变换从模型的最开始摘除,然后用numpy实现,写在推理脚本的前处理

在物体检测中,模型的最后要对结果做NMS,其中涉及动态shape,在Ascend310-C32上不支持,那么导出模型直接摘除后处理,模型直接输出feature_map,然后在推理脚本的后处理做NMS(numpy的NMS也很快,不用担心性能)

from tensorflow.python.compat import compat

with compat.forward_compatibility_horizon(2019, 05, 01):
    y = model(x)

猜你喜欢

转载自blog.csdn.net/yxpandjay/article/details/108780776