统计 TensorFlow 中可训练(trainable)参数的数量,计算方法举例如下:

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/tianguiyuyu/article/details/84800360
#!/usr/bin/python
# -*- coding:utf8 -*-

import pandas as pd
import numpy as np
import tensorflow as tf
import cv2
import config as cfg


def model(x, filters=20, units=20):
    """Build a small demo network: two 1x1 convolutions followed by a dense layer.

    Args:
        x: input tensor for ``tf.layers.conv2d`` (4-D). Every dimension except
            the first (batch) must be statically known, because the flattened
            size becomes the input width of the dense layer.
        filters: number of output channels of each convolution
            (default 20, the previously hard-coded value).
        units: number of output units of the dense layer (default 20).

    Returns:
        The output tensor of the dense layer.

    Raises:
        ValueError: if a non-batch dimension of the second convolution's
            output is unknown at graph-construction time.
    """
    conv1 = tf.layers.conv2d(inputs=x, filters=filters, kernel_size=1,
                             padding='same')
    conv2 = tf.layers.conv2d(inputs=conv1, filters=filters, kernel_size=1,
                             padding='same')
    # Flatten everything but the batch axis. The dense layer needs a static
    # width, so an unknown (None) dimension cannot be multiplied in; report it
    # clearly instead of failing with "int * NoneType".
    shape = conv2.get_shape().as_list()
    if any(d is None for d in shape[1:]):
        raise ValueError(
            "model() needs statically known non-batch dimensions, got shape %s"
            % (shape,))
    dim = 1
    for d in shape[1:]:
        dim *= d
    dense_input = tf.reshape(conv2, [-1, dim])
    output = tf.layers.dense(inputs=dense_input, units=units)
    return output

# Count the number of trainable parameters.
def count1(variables=None, verbose=True):
    """Count the total number of scalar parameters held by variables.

    Each variable's size is the product of its shape dimensions (e.g. a
    variable of shape [1, 1, 3, 20] holds 1*1*3*20 = 60 parameters; a scalar
    variable holds 1). The sizes are summed.

    Args:
        variables: iterable of variables to count. Defaults to
            ``tf.trainable_variables()`` (the original behaviour).
        verbose: when True (default, as before), print each variable and its
            parameter count.

    Returns:
        The total number of parameters as an int.

    Raises:
        ValueError: if a variable's shape is not fully defined, so its size
            cannot be computed statically.
    """
    if variables is None:
        variables = tf.trainable_variables()
    total_parameters = 0
    for variable in variables:
        if verbose:
            print(variable)
        # TensorShape.num_elements() multiplies the dimensions itself. It works
        # whether iterating a TensorShape yields tf.Dimension objects (TF1
        # behaviour) or plain ints (TF2 behaviour). The previous `dim.value`
        # access breaks on the latter.
        variable_parameters = variable.get_shape().num_elements()
        if variable_parameters is None:
            raise ValueError(
                "shape of %s is not fully defined: %s"
                % (variable.name, variable.get_shape()))
        if verbose:
            print(variable_parameters)
        total_parameters += variable_parameters
    return total_parameters

# Build the graph: a float32 input placeholder for a batch of 20x20x3 inputs,
# and the model output on top of it.
x = tf.placeholder(name="x", shape=[None, 20, 20, 3], dtype=tf.float32)
result = model(x)

# Use the session as a context manager so it is closed (and its resources
# released) once the demo has finished.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail here with a clear message rather than inside cv2.resize.
    image_path = "./3_2.png"
    image = cv2.imread(image_path)
    if image is None:
        raise IOError("could not read image: %s" % image_path)
    # Resize to 20x20 to match the placeholder's spatial dimensions.
    image = cv2.resize(image, (20, 20))
    # Feed a batch containing the single image.
    k = [image]

    print(count1())
    print(sess.run(result, feed_dict={x: k}))

猜你喜欢

转载自blog.csdn.net/tianguiyuyu/article/details/84800360