[Tensorflow] Counting model parameters: how do I calculate the number of parameters in my model?

Copyright notice: Copyright reserved to Hazekiah Wang https://blog.csdn.net/u010909964/article/details/83090282
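```python
import logging

import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.trainable_variables)

# The original passes config.logger, a log-file path from the author's own
# config module; 'params.log' below is a hypothetical stand-in.
LOG_FILE = 'params.log'
logging.basicConfig(level=logging.INFO, format='%(message)s', filemode='w', filename=LOG_FILE)

def _params_usage():
    """Log the size of every trainable variable and return the report as a string."""
    total = 0
    prompt = []
    for v in tf.trainable_variables():
        shape = v.get_shape()  # a TensorShape
        cnt = 1
        for dim in shape:
            cnt *= dim.value  # Dimension.value is the raw integer (TF 1.x)
        prompt.append('{} with shape {} has {}'.format(v.name, shape, cnt))
        logging.info(prompt[-1])
        total += cnt
    prompt.append('totaling {}'.format(total))
    logging.info(prompt[-1])
    return '\n'.join(prompt)
```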

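shape is of type TensorShape. It is iterable, and each element is of type Dimension, whose attribute .value gives the dimension as a raw integer. This holds in TensorFlow 1.x, which the code above targets (tf.trainable_variables() is a 1.x API). In TensorFlow 2.x, iterating over a TensorShape yields plain integers (or None for unknown dimensions), so .value is not needed there.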
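Stripped of TensorFlow, the counting step is just "multiply the dimensions of each variable, then sum over variables". The framework-free sketch below assumes each shape is already a tuple of ints (in TensorFlow, shape.as_list() gives such a list, and TensorShape.num_elements() returns the same product directly). The function name count_params and the dense-layer variable names and shapes are hypothetical, chosen only for illustration.

```python
from math import prod

def count_params(shapes):
    """Return (per-variable counts, total) for a mapping of name -> shape tuple."""
    # prod(()) == 1, so a scalar variable counts as one parameter,
    # matching the cnt = 1 starting value in _params_usage().
    counts = {name: prod(shape) for name, shape in shapes.items()}
    return counts, sum(counts.values())

# Hypothetical variables of one dense layer mapping 784 inputs to 256 units.
shapes = {'dense/kernel:0': (784, 256), 'dense/bias:0': (256,)}
counts, total = count_params(shapes)
print(counts)  # {'dense/kernel:0': 200704, 'dense/bias:0': 256}
print(total)   # 200960
```

The kernel contributes 784 × 256 = 200704 parameters and the bias 256, for 200960 in total.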

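The function _params_usage() above writes each line through the logging configuration set up by basicConfig, and also returns the whole report as a single string. The intent is to get the output in two places at once: the log file receives it through logging, and the caller can print the returned string to stdout. Note that basicConfig with filename attaches only a file handler, so nothing reaches stdout unless the returned string is printed.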
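If you would rather have every logged line go to the file and to stdout at the same moment, one option is to give a logger two handlers instead of relying on the returned string. The sketch below uses only the standard logging module; the function name make_params_logger and the default logger name are hypothetical.

```python
import logging
import sys

def make_params_logger(log_file, name='params_usage'):
    """Return a logger that writes each message to log_file and to stdout.

    The function name and the default logger name are hypothetical.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # don't also pass records up to the root logger
    # Drop handlers left over from an earlier call so lines aren't duplicated.
    for old in list(logger.handlers):
        logger.removeHandler(old)
        old.close()
    fmt = logging.Formatter('%(message)s')  # same format as the basicConfig call
    for handler in (logging.FileHandler(log_file, mode='w'),  # like filemode='w'
                    logging.StreamHandler(sys.stdout)):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger
```

With this in place, the logging.info(...) calls in _params_usage() could become logger.info(...) on the returned logger. Setting propagate to False keeps records from also reaching any root handlers that basicConfig may have installed, which would otherwise write each line twice.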
