A HelloWorld Model in TensorFlow

Step 1: build and save the model.
Contents of saver_hello.py:

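#!/usr/bin/env python
# Uses the TensorFlow 1.x API (tf.Session, tf.train.Saver)
import tensorflow as tf


def helloFunc():
    print("helloFunc")


if __name__ == '__main__':
    # name="hello" is given to the constant, so the graph gets an op named "hello"
    # holding 'Hello World'; the variable itself gets the default name "Variable"
    hello = tf.Variable(tf.constant('Hello World', name="hello"))
    #init = tf.initialize_all_variables() #deprecated
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)

        # Writes hello_model.index, hello_model.data-00000-of-00001,
        # hello_model.meta and a "checkpoint" file to the current directory
        saver = tf.train.Saver()
        saver.save(sess, "./hello_model")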


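Run saver_hello.py. It generates three model files:
hello_model.data-00000-of-00001  hello_model.index  hello_model.meta
It also writes a small text file named checkpoint, as the ls output shows: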
(tensorflow) zm@linux-kl9l:~/workspace/flower/hello> ./saver_hello.py
(tensorflow) zm@linux-kl9l:~/workspace/flower/hello> ls
a.out  checkpoint  hello.cpp  hello_model.data-00000-of-00001  hello_model.index  hello_model.meta  restore_hello.py  saver_hello.py
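The following sketch shows how these file names are formed, using only the Python standard library.

- **.index:** in TensorFlow's V2 checkpoint format, this file maps variable names to where their values are stored.
- **.data-00000-of-00001:** holds the values themselves (shard 0 of 1).
- **.meta:** the serialized MetaGraphDef, i.e. the graph that restore_hello.py rebuilds.
- **global_step:** when Saver.save is called with this argument, the step number is appended to the prefix. That is how names such as model.ckpt-1000.index arise.
- **checkpoint:** this text file records the most recent prefix, and tf.train.latest_checkpoint reads it in real code.

The helper names below are hypothetical. The parser is a simplified stand-in for latest_checkpoint that assumes one key: "value" entry per line.

```python
import os
import re


def expected_checkpoint_files(prefix, global_step=None, num_shards=1):
    """File names tf.train.Saver.save() is expected to write for a save path
    (V2 checkpoint format, meta graph enabled). Hypothetical helper."""
    if global_step is not None:
        # save(sess, prefix, global_step=N) saves under "prefix-N"
        prefix = "%s-%d" % (prefix, global_step)
    data_files = ["%s.data-%05d-of-%05d" % (prefix, i, num_shards)
                  for i in range(num_shards)]
    return data_files + [prefix + ".index", prefix + ".meta"]


def latest_checkpoint_prefix(checkpoint_text):
    """Extract model_checkpoint_path from the text of a 'checkpoint' file.
    Simplified: assumes one 'key: "value"' entry per line."""
    match = re.search(r'^model_checkpoint_path:\s*"([^"]*)"',
                      checkpoint_text, re.M)
    return match.group(1) if match else None


if __name__ == "__main__":
    # Check the current directory for the files saver_hello.py should produce
    for name in expected_checkpoint_files("hello_model"):
        print(name, "found" if os.path.exists(name) else "missing")
    if os.path.exists("checkpoint"):
        with open("checkpoint") as f:
            print("latest:", latest_checkpoint_prefix(f.read()))
```

Run it in ~/workspace/flower/hello after saver_hello.py. It lists the three files for the prefix hello_model and marks each one as found or missing.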


Step 2: load the model and test it.
Contents of restore_hello.py:
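#!/usr/bin/env python
# Uses the TensorFlow 1.x API (tf.Session, tf.train.import_meta_graph)
import tensorflow as tf


if __name__ == '__main__':
    # Rebuild the saved graph from the .meta file; this returns a Saver for it
    saver = tf.train.import_meta_graph("hello_model.meta")
    with tf.Session() as sess:
        # Load the variable values from hello_model.index / .data-00000-of-00001
        saver.restore(sess, "hello_model")

        # "hello:0" is output 0 of the op named "hello" (the constant in saver_hello.py)
        print(sess.run(tf.get_default_graph().get_tensor_by_name("hello:0")))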
Run it:
(tensorflow) zm@linux-kl9l:~/workspace/flower/hello> ./restore_hello.py
It prints the following (string tensors come back as Python bytes, hence the b prefix):
b'Hello World'
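The ":0" suffix comes from TensorFlow's tensor naming scheme, "<op_name>:<output_index>". So "hello:0" is the first output of the operation named "hello". get_tensor_by_name expects this form, while get_operation_by_name takes the bare op name.

The small parser below makes the convention explicit. split_tensor_name is a hypothetical helper, not a TensorFlow function.

```python
def split_tensor_name(name):
    """Split a TensorFlow tensor name such as "hello:0" into
    (operation name, output index). Hypothetical helper."""
    op_name, sep, index = name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError('expected "<op_name>:<output_index>", got %r' % name)
    return op_name, int(index)


if __name__ == "__main__":
    print(split_tensor_name("hello:0"))
```

split_tensor_name("hello:0") returns ("hello", 0), and a name scope is kept as part of the op name ("scope/hello:1" gives ("scope/hello", 1)). A bare "hello" raises ValueError, much as get_tensor_by_name refuses a name that refers to an operation rather than a tensor.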


Loading the model with OpenCV looks like this.
Contents of main.cpp:
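#include <cstdio>
#include <cstdlib>
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>


int main(int argc, char **argv)
{
    if (argc != 2)
    {
        printf("argc %d != 2\n", argc);
        printf("usage: %s <model file>\n", argv[0]);
        exit(-1);
    }

    try
    {
        // Parse the TensorFlow model passed on the command line
        cv::dnn::Net net = cv::dnn::readNetFromTensorflow(argv[1]);
        if (net.empty())
        {
            printf("no network loaded from %s\n", argv[1]);
            return -1;
        }
    }
    catch (const cv::Exception &e)
    {
        // OpenCV reports files it cannot parse by throwing cv::Exception
        fprintf(stderr, "%s\n", e.what());
        return -1;
    }

    return 0;
}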

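After compiling, pass the model file on the command line:

./a.out model_data/model.ckpt-1000.index

Note that the OpenCV documentation describes the model argument of readNetFromTensorflow as a .pb file containing the network as a binary protobuf. This is typically a frozen graph, with the variables converted to constants. It can optionally be paired with a .pbtxt text graph as the config argument. A checkpoint .index file like the one above is not in that format.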
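readNetFromTensorflow takes a binary GraphDef (.pb), not TensorFlow checkpoint files. A quick check on the argument can flag that mix-up before OpenCV is called.

The sketch below looks only at the file name, so it is a heuristic and does not validate the file's contents. classify_tf_model_file is a hypothetical helper, not part of OpenCV.

```python
import os
import sys


def classify_tf_model_file(path):
    """Heuristically classify a path by file name with respect to
    cv::dnn::readNetFromTensorflow(model, config). Hypothetical helper."""
    name = os.path.basename(path)
    if name.endswith(".pbtxt"):
        return "config"      # text graph, optional second argument
    if name.endswith(".pb"):
        return "model"       # binary GraphDef (e.g. a frozen graph), first argument
    if name.endswith((".index", ".meta")) or ".data-" in name:
        return "checkpoint"  # tf.train.Saver output, not a GraphDef
    return "unknown"


if __name__ == "__main__":
    for arg in sys.argv[1:]:
        print(arg, classify_tf_model_file(arg))
```

For model_data/model.ckpt-1000.index it returns "checkpoint". For a frozen graph saved as, say, frozen.pb (a hypothetical name) it returns "model".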

Author: 帅得不敢出门


Reposted from blog.csdn.net/zmlovelx/article/details/80919193