DeepFM TensorFlow 模型导出及 Java 使用

接上篇。本文介绍如何把训练得到的 checkpoint 冻结导出为 .pb 文件，并在 Java 中加载该文件进行预测。

Python 导出

from tensorflow.python import pywrap_tensorflow
import tensorflow as tf
from tensorflow.python.framework import graph_util

def getAllNodes(checkpoint_path):
    """Print the name of every variable stored in a checkpoint and return them.

    Handy for discovering node names before calling ``freeze_graph``.

    Args:
        checkpoint_path: checkpoint prefix (e.g. ``'model'``) as accepted by
            ``tf.compat.v1.train.NewCheckpointReader``.

    Returns:
        The variable names found in the checkpoint, sorted alphabetically.
    """
    # Public reader API; the private pywrap_tensorflow.NewCheckpointReader is
    # not available in TF 2.x.
    reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
    # Sorted so the listing is stable and easy to scan.
    names = sorted(reader.get_variable_to_shape_map())
    for name in names:
        print("tensor_name: ", name)
    return names

def _patch_ref_ops(graph_def):
    """Rewrite, in place, the nodes of *graph_def* that require a variable (ref) input.

    Once variables are folded into constants, ``RefSwitch``, ``AssignSub`` and
    ``AssignAdd`` nodes can no longer receive the ref-typed input they need, so
    they are replaced by their value-typed counterparts ``Switch``, ``Sub`` and
    ``Add``. The ``moving_`` name check presumably targets batch-norm moving
    mean/variance variables -- confirm against the model definition.
    """
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Feed the switch from '<var>/read' (presumably the variable's
            # read op) instead of the ref-typed variable itself.
            for index, name in enumerate(list(node.input)):
                if 'moving_' in name:
                    node.input[index] = name + '/read'
        elif node.op in ('AssignSub', 'AssignAdd'):
            node.op = 'Sub' if node.op == 'AssignSub' else 'Add'
            # 'use_locking' belongs to the Assign* ops only; drop it so the
            # rewritten node is a valid Sub/Add.
            if 'use_locking' in node.attr:
                del node.attr['use_locking']


def freeze_graph(ckpt, output_graph,
                 output_node_names='feat_index,feat_value,label,dropout_keep_fm,'
                                   'dropout_keep_deep,train_phase,output/predictlabel'):
    """Freeze a checkpoint into a single self-contained GraphDef (.pb) file.

    Imports ``ckpt + '.meta'``, restores the variables from ``ckpt``, patches
    ref-typed ops (see ``_patch_ref_ops``), folds variables into constants and
    writes the serialized GraphDef to ``output_graph``.

    Args:
        ckpt: checkpoint prefix; ``ckpt + '.meta'`` must exist.
        output_graph: path of the .pb file to write.
        output_node_names: nodes to keep, as a comma-separated string or a
            list of names. The default keeps the DeepFM input placeholders as
            well as ``output/predictlabel`` so all of them can still be fed or
            fetched by name after the graph is pruned.
    """
    if isinstance(output_node_names, str):
        output_node_names = output_node_names.split(',')
    output_node_names = [name.strip() for name in output_node_names if name.strip()]

    # Use a private graph rather than the process-wide default one: repeated
    # calls do not clash on duplicate node names, and graph mode is active
    # even when TF 2.x runs eagerly by default.
    graph = tf.Graph()
    with graph.as_default():
        saver = tf.compat.v1.train.import_meta_graph(ckpt + '.meta', clear_devices=True)
        # as_graph_def() returns a fresh proto, so patching it leaves the live
        # graph (used for restoring) untouched.
        input_graph_def = graph.as_graph_def()
        _patch_ref_ops(input_graph_def)

        with tf.compat.v1.Session(graph=graph) as sess:
            saver.restore(sess, ckpt)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,
                output_node_names=output_node_names,
            )

    with tf.io.gfile.GFile(output_graph, 'wb') as fw:
        fw.write(output_graph_def.SerializeToString())
    print('{} ops in the final graph.'.format(len(output_graph_def.node)))

if __name__ == '__main__':
    # Checkpoint prefix: freeze_graph reads '<prefix>.meta' and restores from '<prefix>'.
    checkpoint_prefix = 'model'
    # Destination of the frozen GraphDef that the Java side loads.
    frozen_pb = 'res.pb'

    getAllNodes(checkpoint_prefix)  # list the checkpoint's variable names first
    freeze_graph(checkpoint_prefix, frozen_pb)

Java 加载（以下代码均为 Test 类的成员，类声明和 import 语句已省略）


private static final Logger logger = LoggerFactory.getLogger(Test.class);

// Frozen TensorFlow graph imported by the constructor.
private Graph graph;
// Session over `graph`; predict() runs every inference through it.
private Session sess;

/**
 * Loads a frozen GraphDef (.pb) file and opens a session on it.
 *
 * @param pbFile path of the frozen graph produced by the Python export script
 * @throws IllegalStateException if the file cannot be read; failing here
 *         avoids leaving {@code sess} unset, which would only surface later
 *         as a NullPointerException inside {@link #predict}
 */
public Test(String pbFile) {
    byte[] graphBytes;
    try {
        // Files.readAllBytes opens and closes the file itself, so no stream is leaked.
        graphBytes = java.nio.file.Files.readAllBytes(java.nio.file.Paths.get(pbFile));
    } catch (java.io.IOException e) {
        logger.error("Failed to read frozen graph " + pbFile, e);
        throw new IllegalStateException("Failed to read frozen graph " + pbFile, e);
    }

    graph = new Graph();
    try {
        graph.importGraphDef(graphBytes);
        sess = new Session(graph);
    } catch (RuntimeException e) {
        // Release the native graph if the bytes are not a loadable GraphDef.
        graph.close();
        throw e;
    }
}

/**
 * Runs the frozen DeepFM graph and returns the value of {@code output/predictlabel}.
 *
 * <p>Each argument is fed to the graph placeholder of the same name. All input
 * tensors and the fetched result hold native memory, so they are closed via
 * try-with-resources once the result has been copied into a Java array.
 *
 * @param feat_index        fed to {@code feat_index}; one row per example
 * @param feat_value        fed to {@code feat_value}; presumably aligned with {@code feat_index}
 * @param label             fed to {@code label}
 * @param dropout_keep_fm   keep probabilities fed to {@code dropout_keep_fm}
 * @param dropout_keep_deep keep probabilities fed to {@code dropout_keep_deep}
 * @param train_phase       fed to {@code train_phase}; presumably {@code false} for inference
 * @return the fetched output as a {@code float[rows][cols]} array sized from the tensor's shape
 * @throws IllegalStateException if the fetched output is not rank 2
 */
public float[][] predict(int[][] feat_index, float[][] feat_value, float[][]label, float[]dropout_keep_fm,
        float[]dropout_keep_deep, boolean train_phase) {
    try (Tensor indexTensor = Tensor.create(feat_index);
         Tensor valueTensor = Tensor.create(feat_value);
         Tensor labelTensor = Tensor.create(label);
         Tensor dropoutKeepFmTensor = Tensor.create(dropout_keep_fm);
         Tensor dropoutKeepDeepTensor = Tensor.create(dropout_keep_deep);
         Tensor trainPhaseTensor = Tensor.create(train_phase);
         Tensor result = sess.runner()
                 .feed("feat_index", indexTensor)
                 .feed("feat_value", valueTensor)
                 .feed("label", labelTensor)
                 .feed("dropout_keep_fm", dropoutKeepFmTensor)
                 .feed("dropout_keep_deep", dropoutKeepDeepTensor)
                 .feed("train_phase", trainPhaseTensor)
                 .fetch("output/predictlabel")
                 .run()
                 .get(0)) {
        long[] shape = result.shape();
        if (shape.length != 2) {
            throw new IllegalStateException(
                    "Expected a rank-2 output from output/predictlabel but got rank " + shape.length);
        }
        // Size the destination from the actual output instead of assuming [batch][1].
        float[][] scores = new float[(int) shape[0]][(int) shape[1]];
        result.copyTo(scores);
        return scores;
    }
}

/**
 * Smoke test: loads the frozen graph and scores a single hard-coded example.
 */
public static void main(String[] args) throws IOException {
    Test t = new Test("C:\\res.pb");

    // A batch of one example: feature indices and their matching values.
    int[][] featIndex = new int[][] {{1,12,16,21,330,1065,1078,1087,1089,1092,1093,1095,1098,1099,1101,1104,1105,1107,1125,2639,20838,20857,21171,21997,22109,23049,23550,24190,24217,29147,43920,43922,43924,43926,43930,43941,44055,44995,45496,46136,46163,51046,65866,65868,65870,65872,65876,65887,66001,66941,67442,68082,68109,72991,87812,87814,87817,87818,87822,87833}};
    float[][] featValue = new float[][]{{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}};
    float[][] label = new float[][]{{1}};

    // Inference must not drop units. A keep probability of 1.0 disables dropout
    // whether or not the graph gates it on train_phase; the previous 0.8 for the
    // deep part made predictions random if it does not. Array lengths are kept
    // as before, since they presumably must match the placeholder shapes.
    float[] dropoutKeepFm = new float[]{1.0f, 1.0f};
    float[] dropoutKeepDeep = new float[]{1.0f, 1.0f, 1.0f};

    float[][] res = t.predict(featIndex, featValue, label, dropoutKeepFm, dropoutKeepDeep, false);
    System.out.println(JSON.toJSONString(res));
}

猜你喜欢

转载自blog.51cto.com/12597095/2561301