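Reference: https://www.cnblogs.com/hellcat/p/6906065.html
tf.nn.sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None)
1. Purpose: computes the sufficient statistics for the mean and variance of x.
2. Returns: a 4-tuple -> (count, sum of elements, sum of squared elements, shift). When shift is given, both sums are taken over x - shift.
3. Example: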
# tf.__version__ -> 1.4.0
import tensorflow as tf
size = 3
W = tf.constant([[1., 2., 3.], [4., 5., 6.]])
shift = tf.Variable(tf.zeros([size]))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # axes=[0] reduces over axis 0 (the rows), giving one statistic per column
    ss = sess.run(tf.nn.sufficient_statistics(W, axes=[0], shift=shift))
    print(ss)
    for i in ss:
        print(i)
Output:
(2.0, array([5., 7., 9.], dtype=float32), array([17., 29., 45.], dtype=float32), array([0., 0., 0.], dtype=float32))
2.0
[5. 7. 9.]
[17. 29. 45.]
[0. 0. 0.]
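The four values above are enough to recover the per-column mean and variance without touching the data again. The following is a minimal pure-Python sketch of that step, assuming the population (biased) variance is wanted; the helper name moments_from_stats is hypothetical and not part of TensorFlow.

```python
def moments_from_stats(count, sums, sums_of_squares, shift=None):
    """Recover mean and (population) variance from sufficient statistics.

    Hypothetical helper for illustration; works on plain Python lists.
    """
    if shift is None:
        shift = [0.0] * len(sums)
    # Mean of the shifted data, one entry per column.
    shifted_means = [s / count for s in sums]
    # Var(x - shift) == Var(x), so the shift only has to be added back to the mean.
    variances = [sq / count - m * m
                 for sq, m in zip(sums_of_squares, shifted_means)]
    means = [m + sh for m, sh in zip(shifted_means, shift)]
    return means, variances


# Values taken from the output above (count, sum, sum of squares, shift).
means, variances = moments_from_stats(
    2.0, [5.0, 7.0, 9.0], [17.0, 29.0, 45.0], [0.0, 0.0, 0.0])
print(means)      # [2.5, 3.5, 4.5]
print(variances)  # [2.25, 2.25, 2.25]
```

For the first column [1, 4] this gives mean 2.5 and variance 2.25, which matches a direct calculation.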
4. Reading the source:
def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
    """Calculate the sufficient statistics for the mean and variance of `x`.
    These sufficient statistics are computed using the one pass algorithm on
    an input that's optionally shifted. See:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
    Args:
      x: A `Tensor`.
      axes: Array of ints. Axes along which to compute mean and variance.
      shift: A `Tensor` containing the value by which to shift the data for
        numerical stability, or `None` if no shift is to be performed. A shift
        close to the true mean provides the most numerically stable results.
      keep_dims: produce statistics with the same dimensionality as the input.
      name: Name used to scope the operations that compute the sufficient stats.
    Returns:
      Four `Tensor` objects of the same type as `x`:
      * the count (number of elements to average over).
      * the (possibly shifted) sum of the elements in the array.
      * the (possibly shifted) sum of squares of the elements in the array.
      * the shift by which the mean must be corrected or None if `shift` is None.
    """
    axes = list(set(axes))
    with ops.name_scope(name, "sufficient_statistics", [x, shift]):
        x = ops.convert_to_tensor(x, name="x")
        x_shape = x.get_shape()
        if all(x_shape[d].value is not None for d in axes):
            counts = 1
            for d in axes:
                counts *= x_shape[d].value
            counts = constant_op.constant(counts, dtype=x.dtype)
        else:  # shape needs to be inferred at runtime.
            x_dims = array_ops.gather(
                math_ops.cast(array_ops.shape(x), x.dtype), axes)
            counts = math_ops.reduce_prod(x_dims, name="count")
        if shift is not None:
            shift = ops.convert_to_tensor(shift, name="shift")
            m_ss = math_ops.subtract(x, shift)
            v_ss = math_ops.squared_difference(x, shift)
        else:  # no shift.
            m_ss = x
            v_ss = math_ops.square(x)
        m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
        v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
    return counts, m_ss, v_ss, shift
4.1 How the count is computed:
for d in axes:
    counts *= x_shape[d].value
This shows where the count comes from: it is simply the product of the sizes of the dimensions listed in axes.
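As a sanity check on this static-shape branch, the same product can be reproduced in plain Python. This is only a sketch: count_for_axes is a hypothetical name, and the shape tuple stands in for the statically known x_shape.

```python
def count_for_axes(shape, axes):
    """Multiply the sizes of the dimensions listed in `axes`."""
    count = 1
    # set() mirrors `axes = list(set(axes))` in the source: duplicates count once.
    for d in set(axes):
        count *= shape[d]
    return count


shape = (2, 3)  # shape of W in the example
print(count_for_axes(shape, [0]))     # 2
print(count_for_axes(shape, [0, 1]))  # 6
print(count_for_axes(shape, [0, 0]))  # 2
```

With axes=[0] the count is 2, matching the first element of the example output; with axes=[0, 1] every element of W is counted.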
4.2 What does squared_difference(x, shift) return? From the source:
def squared_difference(x, y, name=None):
    r"""Returns (x - y)(x - y) element-wise.
    # ... omitted ...
Returns: (x - shift)^2, computed element-wise.
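The docstring quoted in section 4 says that a shift close to the true mean gives the most numerically stable results. The sketch below illustrates why, using plain Python double-precision floats rather than TensorFlow tensors; the data values and the helper name variance_from_stats are chosen purely for illustration.

```python
def variance_from_stats(xs, shift=0.0):
    """Population variance from the count, shifted sum and shifted sum of squares."""
    count = len(xs)
    m_ss = sum(x - shift for x in xs)         # like reduce_sum(subtract(x, shift))
    v_ss = sum((x - shift) ** 2 for x in xs)  # like reduce_sum(squared_difference(x, shift))
    mean_shifted = m_ss / count
    return v_ss / count - mean_shifted ** 2


# Large offset, small spread: the exact population variance is 22.5.
data = [1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16]

print(variance_from_stats(data))             # no shift: rounding error dominates
print(variance_from_stats(data, shift=1e9))  # 22.5
```

Without a shift, the squares are around 1e18 and the small spread is lost to rounding; with a shift near the mean, the intermediate values stay small and the result is exact here.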
5. A bug encountered when shift=None:
# tf.__version__ -> 1.4.0
import tensorflow as tf
size = 3
W = tf.constant([[1., 2., 3.], [4., 5., 6.]])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # axes=[0, 1] reduces over both axes, i.e. over every element of W
    ss = sess.run(tf.nn.sufficient_statistics(W, axes=[0,1]))
    print(ss)
Error:
# ... omitted ...
TypeError: Fetch argument None has invalid type <class 'NoneType'>
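This happens because, when shift=None, the tuple returned by tf.nn.sufficient_statistics(W, axes=[0,1]) contains None as its last element, and sess.run cannot fetch None.
Skipping sess.run and calling print(tf.nn.sufficient_statistics(W, axes=[0,1])) directly works fine, since it only prints the unevaluated tensors: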
(<tf.Tensor 'sufficient_statistics/Const:0' shape=() dtype=float32>, <tf.Tensor 'sufficient_statistics/mean_ss:0' shape=() dtype=float32>, <tf.Tensor 'sufficient_statistics/var_ss:0' shape=() dtype=float32>, None)
There is no doubt about it: the last element of the returned tuple, shift, is None.
As a workaround, here is a rather clumsy way to print the returned values:
# tf.__version__ -> 1.4.0
import tensorflow as tf
size = 3
W = tf.constant([[1., 2., 3.], [4., 5., 6.]])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    A = tf.nn.sufficient_statistics(W, axes=[0])  # axes=[0]: one statistic per column
    for i, Ai in enumerate(A):
        if Ai is not None:
            print(i, Ai.eval())
        else:
            print(i, None)
Output:
0 2.0
1 [5. 7. 9.]
2 [17. 29. 45.]
3 None
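A less clumsy option is to fetch all non-None entries in a single call and put None back afterwards. The sketch below shows the idea without requiring TensorFlow: run_skipping_none is a hypothetical helper, and run is assumed to behave like sess.run when given a list of fetches (returning a list of values in the same order).

```python
def run_skipping_none(run, fetches):
    """Evaluate the non-None entries of `fetches` with `run`, keeping None in place.

    `run` is assumed to take a list of fetches and return a list of values,
    as `sess.run` does in TensorFlow 1.x.
    """
    present = [f for f in fetches if f is not None]
    values = iter(run(present))
    return tuple(next(values) if f is not None else None for f in fetches)


# Stand-in for sess.run so the sketch runs without TensorFlow.
def fake_run(fetches):
    return [f * 10 for f in fetches]


print(run_skipping_none(fake_run, (1, 2, 3, None)))  # (10, 20, 30, None)
```

Inside a real session this would presumably be called as run_skipping_none(sess.run, tf.nn.sufficient_statistics(W, axes=[0])), evaluating the three tensors in one pass instead of three separate eval() calls.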